### Import Libraries

In [None]:
!pip install torch_geometric
!pip install networkit
!apt install libgraphviz-dev
!pip install pygraphviz

Collecting torch_geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.5.3
Collecting networkit
  Downloading networkit-11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.7/13.7 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: networkit
Successfully installed networkit-11.0
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  libgail-common libgail18 libgtk2.0-0 libgtk2.0-bin libgtk2.0-common libgvc6-plugins-gtk
  librsvg2-common libxdot4
Suggested packages:
  gvfs
The following NEW packages will be installed:
  libgail-common libgail18 libgraphviz-dev 

In [None]:
import numpy as np
import pandas as pd
import torch
from torch import nn

import json
import os
import networkx as nx
import networkit as nk
from networkit import vizbridges

from tqdm import tqdm
import ast
import matplotlib.pyplot as plt
import re
from sklearn.metrics import top_k_accuracy_score
from sklearn.metrics import ndcg_score

from torch_geometric.data import Data
from torch_geometric.utils import subgraph
from torch_geometric.nn import GATv2Conv, GCNConv

import copy
from typing import Callable, Tuple

import torch
from torch import Tensor
from torch.nn import Module, Parameter

from torch_geometric.nn.inits import reset, uniform
import random

In [None]:
# Seed every random number generator the notebook draws from.
torch.manual_seed(200)
np.random.seed(200)
# node_dropout, random_walk and edge_dropout use Python's built-in `random` module,
# which is not covered by the torch / numpy seeds above.
random.seed(200)

### Util functions

In [None]:
def load_dataset_timestamp(n_users, n_context, seq_len,
                           path='gowalla_user_activity.txt', num_bins=0):
    """Load per-user activity sequences and build next-action train/test examples.

    Every non-blank line of the file is parsed as follows: the text before the
    first ``:`` is the integer user id, the first whitespace-separated token is
    skipped, and each remaining token is split on ``:`` into an integer item id
    and an integer timestamp (e.g. ``7: 1:100 2:200``).

    Only the first ``seq_len`` events of a user are kept. From the kept events
    the training example is ``events[:-2] -> events[-2]`` and the test example is
    ``events[1:-1] -> events[-1]``, so a user needs at least two events
    (otherwise an ``IndexError`` is raised).

    Args:
        n_users: Unused; kept so existing calls keep working.
        n_context: Unused; kept so existing calls keep working.
        seq_len: Maximum number of events kept per user.
        path: Location of the activity file. Defaults to the file name that was
            previously hard-coded.
        num_bins: Number of bin edges passed to ``np.linspace`` for discretising
            timestamps. The default ``0`` reproduces the previously hard-coded
            value; with no bin edges every time index comes out as ``-1``.

    Returns:
        dict mapping user id to a dict with the keys ``train_act_seq``,
        ``train_time_seq``, ``train_act_label``, ``train_time_label``,
        ``test_act_seq``, ``test_time_seq``, ``test_act_label``,
        ``test_time_label``, ``seq_len`` (length of the training sequence) and
        ``user``.
    """
    act_list = []
    time_list = []
    user_list = []

    max_timestamp = -1.0
    min_timestamp = float('inf')

    with open(path, 'r') as raw_file:
        for line in raw_file:
            if not line.strip():
                # A blank (e.g. trailing) line carries no user id to parse.
                continue
            user = int(line.split(':')[0])
            t_item_list = []
            t_time_list = []
            for a_entry in line.split()[1:]:
                item, time_stamp = a_entry.split(':')
                timestamp = int(time_stamp)  # int() already ignores surrounding whitespace
                t_item_list.append(int(item))
                t_time_list.append(timestamp)
                # The global range is tracked over ALL events, including those cut off by seq_len.
                min_timestamp = min(min_timestamp, timestamp)
                max_timestamp = max(max_timestamp, timestamp)

            act_list.append(t_item_list[0: seq_len])
            time_list.append(t_time_list[0: seq_len])
            user_list.append(user)

    times_bins = np.linspace(min_timestamp, max_timestamp + 1, num=num_bins, dtype=np.int32)
    new_time_list = [
        (np.digitize(np.asarray(a_time_list), times_bins) - 1).tolist()
        for a_time_list in time_list
    ]

    all_examples = {}
    for user, acts, times in zip(user_list, act_list, new_time_list):
        train_act_seq = acts[:-2]
        all_examples[user] = {
            'train_act_seq': train_act_seq,
            'train_time_seq': times[:-2],
            'train_act_label': acts[-2],
            'train_time_label': times[-2],
            'test_act_seq': acts[1:-1],
            'test_time_seq': times[1:-1],
            'test_act_label': acts[-1],
            'test_time_label': times[-1],
            'seq_len': len(train_act_seq),
            'user': user,
        }

    return all_examples

### Metrics

In [None]:
def print_metrics(hits1, hits5, hits10, hits20, hits50, hits100, map1, map5, map10, map20, map50, map100, ndcg1, ndcg5, ndcg10, ndcg20, ndcg50, ndcg100):
    """Print hits@k, map@k and ndcg@k for k in (1, 5, 10, 20, 50, 100).

    Each metric family takes two output lines: the first holds the cutoffs
    1, 5, 10 and 20, the second holds 50 and 100. Values are shown with six
    decimal places.
    """
    cutoffs = (1, 5, 10, 20, 50, 100)
    families = (
        ('hits', (hits1, hits5, hits10, hits20, hits50, hits100)),
        ('map', (map1, map5, map10, map20, map50, map100)),
        ('ndcg', (ndcg1, ndcg5, ndcg10, ndcg20, ndcg50, ndcg100)),
    )
    for family, values in families:
        cells = [f'{family}@{k}: {value:.6f}' for k, value in zip(cutoffs, values)]
        print(', '.join(cells[:4]))
        print(', '.join(cells[4:]))

In [None]:
def apk(actual, predicted, k=10):
    """Average precision at cutoff ``k`` for a single query.

    Args:
        actual: Collection of relevant items.
        predicted: Ranked predictions, best first; only the first ``k`` are scored.
        k: Cutoff.

    Returns:
        Sum of precision values at every rank holding a relevant item that did
        not already appear earlier in the ranking, divided by
        ``min(len(actual), k)``. ``0.0`` when ``actual`` is empty.
    """
    ranked = predicted[:k] if len(predicted) > k else predicted

    if not actual:
        return 0.0

    hits = 0.0
    precision_sum = 0.0
    for rank, candidate in enumerate(ranked, start=1):
        if candidate not in actual:
            continue
        if candidate in ranked[:rank - 1]:
            # A repeated prediction is only credited the first time it appears.
            continue
        hits += 1.0
        precision_sum += hits / rank

    return precision_sum / min(len(actual), k)


def mapk(y_prob, y, k=10):
    """Mean average precision at ``k`` when every sample has exactly one true label.

    Args:
        y_prob: Iterable of per-sample score vectors (one score per class).
        y: Iterable of true class indices, aligned with ``y_prob``.
        k: Cutoff; the ``k`` highest-scoring classes of each sample are ranked.

    Returns:
        ``np.mean`` of ``apk`` over all samples.
    """
    per_sample = []
    for scores, label in zip(y_prob, y):
        # argsort is ascending: keep the last k entries and flip them to best-first.
        top_classes = np.argsort(scores)[-k:][::-1]
        per_sample.append(apk([label], top_classes, k))
    return np.mean(per_sample)


def hits_k(y_prob, y, k=10):
    """Fraction of samples whose true label is among the ``k`` highest-scoring classes.

    Args:
        y_prob: Iterable of score arrays (each must provide ``argsort``).
        y: Iterable of true class indices, aligned with ``y_prob``.
        k: Cutoff.

    Raises:
        ZeroDivisionError: if there are no samples.
    """
    outcomes = [
        1. if label in scores.argsort()[-k:][::-1] else 0.
        for scores, label in zip(y_prob, y)
    ]
    return sum(outcomes) / len(outcomes)

In [None]:
def get_metrics_(probs, labels_batch, test_one_hot):
    """Compute hits@k, map@k and ndcg@k for k in (1, 5, 10, 20, 50, 100).

    Args:
        probs: Torch tensor of per-class scores, one row per sample.
        labels_batch: True class index of every sample.
        test_one_hot: One-hot encoding of ``labels_batch``; passed to
            ``ndcg_score`` as the relevance matrix.

    Returns:
        An 18-tuple ordered as the six hits values, then the six map values,
        then the six ndcg values, each in increasing cutoff order.

    Reads the module-level ``classes`` (passed as ``labels=`` to
    ``top_k_accuracy_score``).
    """
    # Convert once; the same array feeds all 18 metric computations.
    scores = probs.cpu().detach().numpy()
    cutoffs = (1, 5, 10, 20, 50, 100)

    hits = [top_k_accuracy_score(labels_batch, scores, k=k, labels=classes) for k in cutoffs]
    maps = [mapk(y_prob=scores, y=labels_batch, k=k) for k in cutoffs]
    ndcgs = [ndcg_score(test_one_hot, scores, k=k) for k in cutoffs]

    return (*hits, *maps, *ndcgs)

### Graph construction

In [None]:
# Build the per-user train/test examples. The positional arguments are
# (n_users, n_context, seq_len); the result is a dict keyed by user id.
users = load_dataset_timestamp(20001, 128, 30)

In [None]:
# Friendship edge list; the graph-construction cells read its
# '1st friend' and '2nd friend' columns.
friends = pd.read_csv('gowalla_edges.csv')

In [None]:
Gcl = nx.Graph()
# Register the users first, in index order 0..len(users)-1, so that node insertion
# order is fixed before any edge is added. add_node on an existing node only
# re-applies the same attributes, so no membership check is needed.
for i in range(len(users)):
    Gcl.add_node(users[i]['user'], weight=10, color='seagreen')

# Insert all friendship edges in one call instead of two scalar `.loc` lookups per row.
Gcl.add_edges_from(
    zip(friends['1st friend'].tolist(), friends['2nd friend'].tolist()),
    color='blue',
)

nodes = list(Gcl.nodes())
# These lookups raise KeyError if an edge introduced a node that is not in `users`
# (such a node would have no 'color' / 'weight' attribute).
node_colors = [Gcl.nodes[node]['color'] for node in Gcl.nodes]
node_weights = [Gcl.nodes[node]['weight'] for node in Gcl.nodes]
edge_colors = [Gcl.edges[edge]['color'] for edge in Gcl.edges]

plt.figure(figsize=(12, 8))
pos = nx.drawing.nx_agraph.graphviz_layout(Gcl, prog='neato')
nx.draw_networkx(Gcl, pos=pos, with_labels=False, node_size=node_weights, node_color=node_colors, edge_color=edge_colors, width=1)

In [None]:
# Rebuild the friendship graph (same construction as the plotting cell, without the plot).
Gcl = nx.Graph()
# Register the users first, in index order 0..len(users)-1, so that node insertion
# order is fixed before any edge is added. add_node on an existing node only
# re-applies the same attributes, so no membership check is needed.
for i in range(len(users)):
    Gcl.add_node(users[i]['user'], weight=10, color='seagreen')

# Insert all friendship edges in one call instead of two scalar `.loc` lookups per row.
Gcl.add_edges_from(
    zip(friends['1st friend'].tolist(), friends['2nd friend'].tolist()),
    color='blue',
)

nodes = list(Gcl.nodes())
# These lookups raise KeyError if an edge introduced a node that is not in `users`
# (such a node would have no 'color' / 'weight' attribute).
node_colors = [Gcl.nodes[node]['color'] for node in Gcl.nodes]
node_weights = [Gcl.nodes[node]['weight'] for node in Gcl.nodes]
edge_colors = [Gcl.edges[edge]['color'] for edge in Gcl.edges]

In [None]:
# Assign consecutive graph indices to the nodes in Gcl's iteration order and keep
# both directions of the mapping.
pid2graphid = {node_id: position for position, node_id in enumerate(Gcl.nodes())}
graphid2pid = {position: node_id for node_id, position in pid2graphid.items()}
s = len(pid2graphid)  # running counter of the original loop: number of mapped nodes

# Edge list expressed in graph indices, one tuple per edge yielded by Gcl.edges().
graph_edges = [(pid2graphid[u], pid2graphid[v]) for (u, v) in Gcl.edges()]

### Graph characteristics

In [None]:
# Mirror the relabelled edge list into a NetworKit graph with one node per mapped id.
nkg = nk.Graph(n=len(pid2graphid), weighted=True, directed=False, edgesIndexed=True)
for (u, v) in graph_edges:
    # No weight is passed here even though the graph is created with weighted=True.
    nkg.addEdge(u, v)

In [None]:
nk.overview(nkg)

Network Properties:
nodes, edges			20001, 167752
directed?			False
weighted?			True
isolated nodes			275
self-loops			0
density				0.000839
clustering coefficient		0.294269
min/max/avg degree		0, 6171, 16.774361
degree assortativity		-0.064750
number of connected components	290
size of largest component	19698 (98.49 %)


### User dataset construction

In [None]:
class UserInfoDataset():
  """Map-style dataset over the per-user example dicts.

  ``data`` is indexed with the requested ``idx`` and must return a dict with the
  keys read in ``__getitem__``. Every sequence is right-padded with zeros to
  ``max_len`` and returned as an int32 array.
  """

  def __init__(self, data, max_len):
    self.data = data
    self.max_len = max_len

  def __len__(self):
    return len(self.data)

  def _padded(self, values, length):
    """Copy ``values`` into the first ``length`` slots of a zero int32 vector of size ``max_len``."""
    buffer = np.zeros((self.max_len,)).astype('int32')
    buffer[:length] = values
    return np.transpose(buffer)  # no-op for a 1-D array; kept from the original code

  def __getitem__(self, idx):
    """Return (user, train items, train times, train item label, train time label,
    test items, test times, test item label, test time label, seq_len)."""
    record = self.data[idx]
    length = record['seq_len']

    train_items = self._padded(np.array(record['train_act_seq']), length)
    train_times = self._padded(record['train_time_seq'], length)
    test_items = self._padded(record['test_act_seq'], length)
    test_times = self._padded(record['test_time_seq'], length)

    return (
        record['user'],
        train_items,
        train_times,
        record['train_act_label'],
        record['train_time_label'],
        test_items,
        test_times,
        record['test_act_label'],
        record['test_time_label'],
        record['seq_len'],
    )

In [None]:
# These two imports are local to this cell rather than part of the import cell at the top.
from torch.utils.data import Subset
from torch.utils.data import DataLoader

# max_len=30: every sequence is padded to 30 positions.
user_dataset = UserInfoDataset(users, 30)

n = len(user_dataset)

# Shuffle the dataset positions with numpy's global RNG.
indices = np.arange(n)
indices = np.random.permutation(indices)

# 80% of the shuffled positions go to training, the remaining 20% to testing.
train_indices = indices [:int(0.8*n)]
test_indices = indices[int(0.8*n):]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

# Each loader yields the 10-field items returned by UserInfoDataset.__getitem__.
# NOTE(review): test(), middle_test() and train2task() unpack 3 fields from loaders named
# user_test_dataloader / user_train_dataloader, so these names are presumably rebound
# elsewhere before those functions run — confirm.
user_dataloader = DataLoader(user_dataset, batch_size = 64, shuffle = True)
user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=True)

# Xavier-initialised 186 x 128 lookup table, indexed by item id further below.
# It is a plain tensor, not an nn.Parameter. Presumably 186 is the number of
# distinct item ids — TODO confirm.
item_emb  = nn.init.xavier_uniform_(torch.empty(186, 128))

## Pre-training experiments

In [None]:
# Freshly initialised RNN (128 -> 128, batch-first) and batch norm; the next section
# feeds the embedded item sequences through them to obtain the initial user vectors.
rnn = nn.RNN(128, 128, batch_first = True)
norm = nn.BatchNorm1d(128)

### Getting initial user representations

In [None]:
user_representations = {}
# The batch of user ids is bound to `batch_users`, not `users`: using `users` as the loop
# variable would overwrite the dict returned by load_dataset_timestamp, which the
# graph-construction cells index by user id, so those cells could no longer be re-run.
for (batch_users, train_input, train_time, train_label, train_time_label,
     test_input, test_time, test_label, test_time_label, seq_len) in user_dataloader:
    # Stack item ids and time indices on a trailing axis; only channel 0 (items) is embedded.
    comb_input = np.concatenate([np.expand_dims(train_input, axis=-1),
                                 np.expand_dims(train_time, axis=-1)], axis=2)
    rnn_input_emb = item_emb[comb_input[:, :, 0]]

    test_comb_input = np.concatenate([np.expand_dims(test_input, axis=-1),
                                      np.expand_dims(test_time, axis=-1)], axis=2)
    test_rnn_input_emb = item_emb[test_comb_input[:, :, 0]]

    x, h = rnn(rnn_input_emb)
    test_x, test_h = rnn(test_rnn_input_emb)

    # RNN output at position seq_len - 1 of every sequence, i.e. the last unpadded step.
    # For seq_len == 0 the index is -1 and selects the final position, as the
    # element-wise loop it replaces did.
    batch_rows = torch.arange(x.shape[0])
    last_step = seq_len - 1

    hx = x[batch_rows, last_step]
    user_representation = norm(hx)

    test_hx = test_x[batch_rows, last_step]
    test_user_representation = norm(test_hx)

    # Stored tuple layout: (user id, train vector, train label, seq_len, test vector, test label).
    for i in range(len(batch_users)):
        uid = batch_users[i].item()
        user_representations[uid] = (uid, user_representation[i], train_label[i].item(),
                                     seq_len[i].item(), test_user_representation[i],
                                     test_label[i].item())

In [None]:
user_ids = []
features = []
train_labels = []
seq_lens = []
test_features = []
test_labels = []
# Gcl.nodes() yields the original user ids (not graph indices), in the same order in
# which the graph indices were assigned. The ids are therefore used directly as keys of
# user_representations; looking them up in graphid2pid would treat an id as an index and
# fail with KeyError whenever the ids are not exactly 0..n-1.
for pid in Gcl.nodes():
    if pid not in user_representations:
        # NOTE(review): a skipped node shifts every later row relative to the graph
        # indices used in graph_edges — confirm this can never happen for this data.
        continue
    rep = user_representations[pid]
    # rep layout: (user id, train vector, train label, seq_len, test vector, test label)
    user_ids.append(rep[0])
    features.append(rep[1].detach().numpy())
    train_labels.append(rep[2])
    test_features.append(rep[4].detach().numpy())
    seq_lens.append(rep[3])
    test_labels.append(rep[5])
user_ids = np.array(user_ids)
features = np.array(features)
train_labels = np.array(train_labels)
test_features = np.array(test_features)
seq_lens = np.array(seq_lens)
test_labels = np.array(test_labels)

In [None]:
# Node features plus the relabelled edge list as a PyG Data object; edge_index is the
# transpose of graph_edges, i.e. one column per edge yielded by Gcl.edges().
# NOTE(review): each undirected edge therefore appears once, in a single direction —
# confirm whether the reverse direction should be added for message passing.
dataset = Data(x=torch.tensor(features, dtype=torch.float), edge_index=torch.tensor(np.array(graph_edges).T, dtype=torch.int64))

In [None]:
friends

Unnamed: 0.1,Unnamed: 0,1st friend,2nd friend
0,0,0,1
1,1,0,2
2,2,0,3
3,3,0,4
4,4,0,5
...,...,...,...
335499,448763,20000,2892
335500,448764,20000,2937
335501,448765,20000,2940
335502,448766,20000,4375


In [None]:
# Attach the per-node arrays gathered above as extra attributes of the Data object.
dataset.user_ids = torch.tensor(user_ids, dtype=torch.int64)
dataset.train_labels = torch.tensor(train_labels, dtype=torch.int64)
dataset.seq_lens = torch.tensor(seq_lens, dtype=torch.int64)
dataset.test_labels = torch.tensor(test_labels, dtype=torch.int64)
# Node features derived from the test sequences (dataset.x holds the train-sequence ones).
dataset.test_features = torch.tensor(test_features, dtype=torch.float)

In [None]:
dataset.edge_index.shape

torch.Size([2, 167752])

In [None]:
print(len(features))

20001


### Deep Graph Infomax model

In [None]:
eps = 1e-15  # added inside the logarithms of the loss to avoid log(0)


class DeepGraphInfomax(torch.nn.Module):
    """Deep Graph Infomax wrapper around an encoder, a summary function and a corruption.

    NOTE(review): this looks like a local copy of
    ``torch_geometric.nn.DeepGraphInfomax`` — confirm whether the library class
    could be imported instead.

    Args:
        hidden_channels: Size of the node embeddings; the bilinear discriminator
            weight has shape ``(hidden_channels, hidden_channels)``.
        encoder: Callable producing node embeddings from the forward arguments.
        summary: Callable ``summary(pos_z, *args, **kwargs)`` producing the graph summary.
        corruption: Callable producing corrupted forward arguments.
    """

    def __init__(self, hidden_channels, encoder, summary, corruption):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.encoder = encoder
        self.summary = summary
        self.corruption = corruption

        self.weight = Parameter(torch.empty(hidden_channels, hidden_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the encoder, the summary and the discriminator weight."""
        reset(self.encoder)
        reset(self.summary)
        uniform(self.hidden_channels, self.weight)

    def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``(pos_z, neg_z, summary)``.

        ``pos_z`` encodes the given inputs and ``neg_z`` encodes the corrupted
        inputs. The corruption output is matched back to the call signature:
        its first ``len(args)`` entries replace the positional arguments and the
        remaining entries replace the keyword arguments in keyword order.
        """
        pos_z = self.encoder(*args, **kwargs)

        corrupted = self.corruption(*args, **kwargs)
        if not isinstance(corrupted, tuple):
            corrupted = (corrupted, )

        n_positional = len(args)
        corrupted_kwargs = copy.copy(kwargs)
        corrupted_kwargs.update(zip(kwargs.keys(), corrupted[n_positional:]))

        neg_z = self.encoder(*corrupted[:n_positional], **corrupted_kwargs)

        summary = self.summary(pos_z, *args, **kwargs)

        return pos_z, neg_z, summary

    def discriminate(self, z: Tensor, summary: Tensor,
                     sigmoid: bool = True) -> Tensor:
        """Score embeddings ``z`` against ``summary`` as ``z @ (weight @ summary)``.

        A summary with more than one dimension is transposed first. The raw
        scores are passed through a sigmoid unless ``sigmoid`` is False.
        """
        if summary.dim() > 1:
            summary = summary.t()
        scores = torch.matmul(z, torch.matmul(self.weight, summary))
        if sigmoid:
            return torch.sigmoid(scores)
        return scores

    def loss(self, pos_z: Tensor, neg_z: Tensor, summary: Tensor) -> Tensor:
        """Binary cross-entropy of the discriminator: positives towards 1, negatives towards 0."""
        pos_prob = self.discriminate(pos_z, summary, sigmoid=True)
        neg_prob = self.discriminate(neg_z, summary, sigmoid=True)

        pos_loss = -torch.log(pos_prob + eps).mean()
        neg_loss = -torch.log(1 - neg_prob + eps).mean()

        return pos_loss + neg_loss

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.hidden_channels})'

In [None]:
class Encoder(nn.Module):
    """Three-layer GATv2 encoder: num_features -> interm_channels -> hidden_channels -> hidden_channels.

    PReLU and dropout (p=0.2) follow the first two convolutions; the third
    convolution's output is returned as is. ``self.rnn`` and ``self.norm`` are
    constructed but not used in ``forward``.
    """

    def __init__(self, interm_channels, hidden_channels, num_features):
        super().__init__()

        self.num_features = num_features
        # Constructed in the original order so parameter registration and RNG use are unchanged.
        self.rnn = nn.RNN(128, 128, batch_first = True)
        self.norm = nn.BatchNorm1d(128)
        self.conv1 = GATv2Conv(self.num_features, interm_channels)
        self.prelu1 = nn.PReLU(interm_channels)
        self.dropout = nn.Dropout(0.2)
        self.conv2 = GATv2Conv(interm_channels, hidden_channels)
        self.prelu2 = nn.PReLU(hidden_channels)
        self.conv3 = GATv2Conv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        """Encode node features ``x`` over ``edge_index`` and return the node embeddings."""
        hidden = self.dropout(self.prelu1(self.conv1(x, edge_index)))
        hidden = self.dropout(self.prelu2(self.conv2(hidden, edge_index)))
        return self.conv3(hidden, edge_index)

### Helper functions for building augmentations

In [None]:
def index_to_mask(index, size=20001):
    """Convert a tensor of indices into a boolean membership mask.

    Args:
        index: Integer tensor of positions to mark; it is flattened first.
            May be empty, in which case the mask is all False.
        size: Length of the returned mask. Defaults to 20001, the value that was
            previously hard-coded.

    Returns:
        Bool tensor of length ``size`` that is True exactly at the given positions.
    """
    index = index.view(-1)

    mask = index.new_zeros(size, dtype=torch.bool)
    mask[index] = True
    return mask

In [None]:
def map_index(indices, edge_index):
  """Relabel the node ids in ``edge_index`` to their position in ``indices``.

  Args:
    indices: Sequence of node ids; the id at position ``k`` is mapped to ``k``.
    edge_index: Integer tensor whose rows 0 and 1 hold the edge endpoints.
      It is modified in place.

  Returns:
    The same ``edge_index`` tensor, with rows 0 and 1 relabelled.

  Raises:
    KeyError: if an endpoint does not occur in ``indices``.
  """
  relabel = {old_id: new_id for new_id, old_id in enumerate(indices)}
  for row in (0, 1):
    # One tolist() and one tensor write per row instead of a .item() read and a
    # scalar tensor write per edge endpoint.
    mapped = [relabel[node] for node in edge_index[row].tolist()]
    edge_index[row] = torch.tensor(mapped, dtype=edge_index.dtype, device=edge_index.device)
  return edge_index

In [None]:
def construct_neighbors_dictionary(edge_index):
  """Build an adjacency dictionary from rows 0 and 1 of ``edge_index``.

  Every column ``(u, v)`` adds ``v`` to the list of ``u`` and ``u`` to the list
  of ``v``, in column order; duplicates are kept.

  Returns:
    dict mapping node id to the list of its neighbour ids.
  """
  adjacency = {}
  sources = edge_index[0].tolist()
  targets = edge_index[1].tolist()
  for src, dst in zip(sources, targets):
    adjacency.setdefault(src, []).append(dst)
    adjacency.setdefault(dst, []).append(src)
  return adjacency

In [None]:
# Adjacency lists of the full graph; read as a global by random_walk.
# Nodes that appear in no edge have no entry in this dictionary.
neighbors = construct_neighbors_dictionary(dataset.edge_index)

### Augmentations

In [None]:
def node_mix(x, edge_index):
    """Corruption that shuffles the rows of ``x`` with a random permutation.

    Returns:
        Tuple of the row-permuted features and the unchanged ``edge_index``.
    """
    shuffled_rows = torch.randperm(x.size(0))
    return x[shuffled_rows], edge_index

In [None]:
def node_dropout(x, edge_index, p=0.3):
  """Augmentation that removes each node independently with probability ``p``.

  Args:
    x: Node feature tensor, one row per node.
    edge_index: Integer tensor whose rows 0 and 1 hold the edge endpoints.
    p: Drop probability per node. Defaults to 0.3, the value that was
      previously hard-coded.

  Returns:
    Tuple ``(x_kept, edge_index_kept)``: the rows of the surviving nodes and the
    edges whose two endpoints both survived, relabelled to positions in ``x_kept``.
  """
  # One draw from Python's `random` per node, in node order.
  kept = [i for i in range(x.shape[0]) if random.random() >= p]
  if not kept:
    # Nothing survived: return empty features and no edges instead of building a
    # mask from an empty index list.
    return x[:0], edge_index[:, :0]

  kept_tensor = torch.tensor(kept, dtype = torch.int32)
  x = torch.index_select(x, 0, kept_tensor)
  node_mask = index_to_mask(kept_tensor)
  edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
  # Boolean indexing produces a new tensor, so the in-place relabelling in map_index
  # does not modify the caller's edge_index.
  edge_index = edge_index[:, edge_mask]
  edge_index = map_index(kept, edge_index)

  return x, edge_index

In [None]:
def random_walk(x, edge_index, start=0, walk_length=10000):
  """Augmentation that keeps only the nodes visited by a self-avoiding random walk.

  The walk starts at ``start`` and repeatedly moves to a uniformly chosen
  neighbour that has not been visited yet. It stops after ``walk_length`` nodes
  or as soon as the current node has no unvisited neighbour. Neighbours are
  read from the module-level ``neighbors`` dictionary.

  Args:
    x: Node feature tensor, one row per node.
    edge_index: Integer tensor whose rows 0 and 1 hold the edge endpoints.
    start: First node of the walk. Defaults to 0 (previously hard-coded).
    walk_length: Maximum number of nodes in the walk. Defaults to 10000
      (previously hard-coded).

  Returns:
    Tuple ``(x_walk, edge_index_walk)``: the feature rows in visiting order and
    the edges between visited nodes, relabelled to positions in the walk.
  """
  node = start
  rwalk = [node]
  visited = {node}  # maintained incrementally instead of rebuilding set(rwalk) every step
  for _ in range(walk_length - 1):
    # A node without any edge has no entry in `neighbors`; treat it as a dead end.
    candidates = list(set(neighbors.get(node, ())) - visited)
    if len(candidates) == 0:
      break
    node = random.choice(candidates)
    rwalk.append(node)
    visited.add(node)

  walk_tensor = torch.tensor(rwalk, dtype = torch.int32)
  x = torch.index_select(x, 0, walk_tensor)
  node_mask = index_to_mask(walk_tensor)
  edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
  # Boolean indexing produces a new tensor, so the in-place relabelling in map_index
  # does not modify the caller's edge_index.
  edge_index = edge_index[:, edge_mask]
  edge_index = map_index(rwalk, edge_index)

  return x, edge_index

In [None]:
def edge_dropout(x, edge_index, p=0.3):
  """Augmentation that removes each edge independently with probability ``p``.

  Args:
    x: Node feature tensor; returned unchanged.
    edge_index: Integer tensor with one column per edge.
    p: Drop probability per edge. Defaults to 0.3, the value that was
      previously hard-coded.

  Returns:
    Tuple ``(x, edge_index_kept)`` where ``edge_index_kept`` holds the surviving
    columns in their original order (possibly none).
  """
  # One draw from Python's `random` per edge, in column order.
  kept = [i for i in range(edge_index.shape[1]) if random.random() >= p]
  kept_tensor = torch.tensor(kept, dtype = torch.int32, device = edge_index.device)
  edge_index = torch.index_select(edge_index, 1, kept_tensor)
  return x, edge_index

### Link prediction head

In [None]:
class LinkPredHead(nn.Module):
    """Scores an edge as the dot product of its two endpoint embeddings.

    ``fc_1`` is constructed but not used in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.fc_1 = nn.Linear(256, 1)

    def forward(self, x, edge_index):
        """Return one score per column of ``edge_index``.

        Args:
            x: Node embedding tensor, one row per node.
            edge_index: Integer tensor whose rows 0 and 1 hold the edge endpoints.
        """
        source_emb = x[edge_index[0]]
        target_emb = x[edge_index[1]]
        return (source_emb * target_emb).sum(dim=1)

### Next action prediction head

In [None]:
class NextActionPredHead(nn.Module):
  """Linear classification head from a 128-dimensional vector to ``num_classes`` scores.

  ``num_classes`` is read from module scope when the head is constructed.
  ``conv1`` and ``rnn`` are constructed but not used in ``forward``.
  """

  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(128, num_classes)
    self.conv1 = GATv2Conv(128, 128)
    self.rnn = nn.RNN(128, 128, batch_first = True)

  def forward(self, x):
    """Return the class scores produced by the linear layer."""
    return self.fc1(x)

In [None]:
def negative_sampling(edge_index, num_nodes):
    """Create corrupted edges by replacing one endpoint of every edge with a random node.

    For each edge a fair coin decides whether the source (row 0) or the target
    (row 1) is replaced by a node drawn uniformly from ``range(num_nodes)``.
    The result is not checked against the real edges. The input is left untouched.

    Returns:
        A new tensor with the same shape as ``edge_index``.
    """
    replace_source = torch.rand(edge_index.size(1)) < 0.5
    replace_target = ~replace_source

    corrupted = edge_index.clone()
    # Draw order (sources first, then targets) is kept so the RNG stream is unchanged.
    corrupted[0, replace_source] = torch.randint(num_nodes, (int(replace_source.sum()),))
    corrupted[1, replace_target] = torch.randint(num_nodes, (int(replace_target.sum()),))

    return corrupted

### User representations dataset

In [None]:
class UserDataset():
    """Map-style dataset over stored user representations.

    ``user_representations`` is indexed with the requested ``idx``; each entry
    must be indexable at positions 0, 1 and 2.
    """

    def __init__(self, user_representations):
        self.user_representations = user_representations

    def __len__(self):
        return len(self.user_representations)

    def __getitem__(self, idx):
        """Return the first three fields of the entry stored under ``idx`` as a tuple."""
        entry = self.user_representations[idx]
        return tuple(entry[position] for position in range(3))

### Training functions

In [None]:
def train_only_infomax(epoch):
    """Run one full-graph optimisation step on the Deep Graph Infomax loss alone.

    Uses the module-level ``model``, ``optimizer``, ``dataset`` and ``num_epoch``.

    The link-prediction scores, negative samples and BCE loss that used to be
    computed here never entered ``loss`` and have been removed. Because negative
    sampling no longer runs, fewer random numbers are drawn per step than before.

    Args:
        epoch: 1-based epoch number; used for logging and to detect the last epoch.

    Returns:
        ``loss.item()`` for every epoch except the last one; on the epoch equal to
        ``num_epoch`` the tuple ``(pos_z, neg_z, summary, loss.item())``.
    """
    model.train()
    optimizer.zero_grad()
    pos_z, neg_z, summary = model(dataset.x, dataset.edge_index)
    loss = model.loss(pos_z, neg_z, summary)

    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch}, Loss: {loss}')
    if epoch == num_epoch:
      return pos_z, neg_z, summary, loss.item()
    else:
      return loss.item()

In [None]:
def train_infomax_linkpred(epoch):
    """Run one full-graph optimisation step on ``2 * infomax + 0.5 * link-prediction BCE``.

    Uses the module-level ``model``, ``linkmodel``, ``optimizer``, ``loss_bce``,
    ``dataset`` and ``num_epoch``.

    Args:
        epoch: 1-based epoch number; used for logging and to detect the last epoch.

    Returns:
        ``loss.item()`` for every epoch except the last one; on the epoch equal to
        ``num_epoch`` the tuple ``(pos_z, neg_z, summary, loss.item())``.
    """
    model.train()
    optimizer.zero_grad()

    pos_z, neg_z, summary = model(dataset.x, dataset.edge_index)
    infomax_term = model.loss(pos_z, neg_z, summary)

    # Real edges are labelled 1, corrupted edges 0; both are scored on the positive embeddings.
    real_scores = linkmodel(pos_z, dataset.edge_index)
    corrupted_edges = negative_sampling(dataset.edge_index, dataset.num_nodes)
    corrupted_scores = linkmodel(pos_z, corrupted_edges)

    scores = torch.cat([real_scores, corrupted_scores])
    targets = torch.cat([torch.ones_like(real_scores), torch.zeros_like(corrupted_scores)])
    link_term = loss_bce(scores, targets)

    loss = 2*infomax_term + 0.5*link_term
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch}, Loss: {loss}')

    loss_value = loss.item()
    if epoch != num_epoch:
      return loss_value
    return pos_z, neg_z, summary, loss_value

In [None]:
def test(head_model):
  """Evaluate ``head_model`` on ``user_test_dataloader`` and average the metrics over batches.

  The model is switched to eval mode (and left in it). Each batch contributes
  equally to the average, regardless of its size. Uses the module-level
  ``user_test_dataloader`` and ``num_classes``.

  Returns:
    Tuple of 18 zero-dimensional tensors in the order produced by
    ``get_metrics_``: hits@k, map@k, ndcg@k for k in (1, 5, 10, 20, 50, 100).
  """
  metrics_val = []
  head_model.eval()
  # Evaluation only: no autograd graph is needed.
  with torch.no_grad():
    for _users, vectors, labels in user_test_dataloader:
      test_probs = head_model(vectors)

      test_one_hot = torch.zeros(len(test_probs), num_classes)
      test_one_hot[torch.arange(len(test_one_hot)), labels] = 1

      metrics_val.append(list(get_metrics_(test_probs, labels, test_one_hot)))

  mean = torch.Tensor(metrics_val).mean(axis=0)
  return tuple(mean)

In [None]:
def middle_test(model):
  """Print the batch-averaged hits@1 of ``model`` on ``user_test_dataloader``.

  Uses the module-level ``user_test_dataloader`` and ``classes``. The model's
  train/eval mode is not changed.
  """
  hits_1 = []
  # Evaluation only: no autograd graph is needed.
  with torch.no_grad():
    for _users, vectors, labels in user_test_dataloader:
      test_probs = model(vectors)
      hits_1.append(top_k_accuracy_score(labels, test_probs.cpu().detach().numpy(), k=1, labels = classes))
  print(f'Hits@1: {torch.mean(torch.Tensor(hits_1))}')


In [None]:
def train2task():
  """Train a fresh NextActionPredHead for 200 epochs and report test metrics.

  Uses the module-level ``user_train_dataloader`` and ``num_classes``. After
  every epoch the mean training loss of that epoch is printed and
  ``middle_test`` reports hits@1; after training the full metric set is
  printed via ``test`` and ``print_metrics``.

  Returns:
    The trained head model.
  """
  loss_fn = nn.CrossEntropyLoss()
  head_model  = NextActionPredHead()

  head_model.train()
  optimizer = torch.optim.Adam(head_model.parameters(), lr = 0.0001)
  # NOTE(review): the scheduler is created but .step() is never called, so the learning
  # rate stays constant. Confirm whether a per-epoch decay was intended.
  lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.97, last_epoch=-1)

  for i in range(200):
    # Collected per epoch and as plain floats: keeping the loss tensors themselves would
    # keep every batch's autograd graph alive, and a list shared across epochs would
    # report a running mean over all epochs instead of this epoch's loss.
    epoch_losses = []
    for _users, vectors, labels in user_train_dataloader:

        optimizer.zero_grad()

        probs = head_model(vectors)
        one_hot = torch.zeros(len(probs), num_classes)
        one_hot[torch.arange(len(one_hot)), labels] = 1
        loss = loss_fn(probs, one_hot)

        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
    mean_loss = torch.Tensor(epoch_losses).mean(axis=0).item()
    print(f'Epoch: {i}, Loss: {mean_loss}')
    middle_test(head_model)

  hits1, hits5, hits10, hits20, hits50, hits100, map1, map5, map10, map20, map50, map100, ndcg1, ndcg5, ndcg10, ndcg20, ndcg50, ndcg100 = test(head_model)
  print_metrics(hits1, hits5, hits10, hits20, hits50, hits100, map1, map5, map10, map20, map50, map100, ndcg1, ndcg5, ndcg10, ndcg20, ndcg50, ndcg100)
  return head_model

### Node mix augmentation, only infomax

In [None]:
# Everything runs on the CPU; the CUDA auto-detection line is kept commented out.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
# Deep Graph Infomax with the GATv2 encoder. The graph summary is the sigmoid of the mean
# node embedding and the corruption is node_mix (row-shuffled node features).
model = DeepGraphInfomax(hidden_channels=128,
                         encoder=Encoder(interm_channels=74, hidden_channels=128,
                                         num_features=128),
                         summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
                         corruption=node_mix).to(device)
linkmodel = LinkPredHead()
# Only the parameters of `model` are optimised; linkmodel's parameters are not passed in.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
# Expects raw scores (logits), which is what LinkPredHead.forward returns.
loss_bce = nn.BCEWithLogitsLoss()

In [None]:
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  result = train_only_infomax(epoch)
  if epoch == num_epoch:
    # The last epoch additionally returns the embeddings and the summary.
    pos_z, neg_z, summary, loss = result
  else:
    loss = result
  loss_epoch.append(loss)

# Embed the test-sequence features with dropout disabled and without building an
# autograd graph; the training functions switch back to train mode themselves.
model.eval()
with torch.no_grad():
  test_pos_z, test_neg_z, test_summary = model(dataset.test_features, dataset.edge_index)

  1%|          | 1/100 [00:07<12:41,  7.69s/it]

Epoch: 1, Loss: 2.0441513061523438


  2%|▏         | 2/100 [00:15<12:31,  7.67s/it]

Epoch: 2, Loss: 1.3706047534942627


  3%|▎         | 3/100 [00:21<11:28,  7.10s/it]

Epoch: 3, Loss: 1.5238572359085083


  4%|▍         | 4/100 [00:29<11:41,  7.31s/it]

Epoch: 4, Loss: 1.4519790410995483


  5%|▌         | 5/100 [00:35<11:03,  6.99s/it]

Epoch: 5, Loss: 1.0741887092590332


  6%|▌         | 6/100 [00:43<11:13,  7.16s/it]

Epoch: 6, Loss: 1.1899919509887695


  7%|▋         | 7/100 [00:49<10:40,  6.88s/it]

Epoch: 7, Loss: 0.8261829614639282


  8%|▊         | 8/100 [00:57<10:55,  7.12s/it]

Epoch: 8, Loss: 0.6506307721138


  9%|▉         | 9/100 [01:03<10:28,  6.90s/it]

Epoch: 9, Loss: 0.4852405786514282


 10%|█         | 10/100 [01:11<10:41,  7.13s/it]

Epoch: 10, Loss: 0.8739855289459229


 11%|█         | 11/100 [01:17<10:14,  6.91s/it]

Epoch: 11, Loss: 0.5356216430664062


 12%|█▏        | 12/100 [01:25<10:27,  7.14s/it]

Epoch: 12, Loss: 0.5005763173103333


 13%|█▎        | 13/100 [01:33<10:45,  7.42s/it]

Epoch: 13, Loss: 0.3467347025871277


 14%|█▍        | 14/100 [01:40<10:34,  7.38s/it]

Epoch: 14, Loss: 0.26977667212486267


 15%|█▌        | 15/100 [01:47<10:15,  7.24s/it]

Epoch: 15, Loss: 0.5275331139564514


 16%|█▌        | 16/100 [01:54<10:06,  7.22s/it]

Epoch: 16, Loss: 0.465753972530365


 17%|█▋        | 17/100 [02:01<09:47,  7.08s/it]

Epoch: 17, Loss: 0.3546474575996399


 18%|█▊        | 18/100 [02:09<09:51,  7.22s/it]

Epoch: 18, Loss: 0.5445301532745361


 19%|█▉        | 19/100 [02:15<09:34,  7.09s/it]

Epoch: 19, Loss: 0.3433173894882202


 20%|██        | 20/100 [02:23<09:30,  7.14s/it]

Epoch: 20, Loss: 0.43800729513168335


 21%|██        | 21/100 [02:29<09:16,  7.04s/it]

Epoch: 21, Loss: 0.3413238525390625


 22%|██▏       | 22/100 [02:37<09:12,  7.09s/it]

Epoch: 22, Loss: 0.36065128445625305


 23%|██▎       | 23/100 [02:43<08:55,  6.96s/it]

Epoch: 23, Loss: 0.33490121364593506


 24%|██▍       | 24/100 [02:51<08:56,  7.06s/it]

Epoch: 24, Loss: 0.3043622374534607


 25%|██▌       | 25/100 [02:57<08:37,  6.90s/it]

Epoch: 25, Loss: 0.374714195728302


 26%|██▌       | 26/100 [03:05<08:41,  7.04s/it]

Epoch: 26, Loss: 0.45965757966041565


 27%|██▋       | 27/100 [03:11<08:23,  6.90s/it]

Epoch: 27, Loss: 0.28340017795562744


 28%|██▊       | 28/100 [03:19<08:27,  7.05s/it]

Epoch: 28, Loss: 0.5449036955833435


 29%|██▉       | 29/100 [03:25<08:08,  6.88s/it]

Epoch: 29, Loss: 0.3212209641933441


 30%|███       | 30/100 [03:33<08:14,  7.07s/it]

Epoch: 30, Loss: 0.3507351279258728


 31%|███       | 31/100 [03:39<07:55,  6.89s/it]

Epoch: 31, Loss: 0.27348196506500244


 32%|███▏      | 32/100 [03:47<08:03,  7.11s/it]

Epoch: 32, Loss: 0.28317779302597046


 33%|███▎      | 33/100 [03:54<07:59,  7.16s/it]

Epoch: 33, Loss: 0.30881232023239136


 34%|███▍      | 34/100 [04:02<08:19,  7.56s/it]

Epoch: 34, Loss: 0.2708754539489746


 35%|███▌      | 35/100 [04:09<07:49,  7.22s/it]

Epoch: 35, Loss: 0.26705610752105713


 36%|███▌      | 36/100 [04:16<07:50,  7.36s/it]

Epoch: 36, Loss: 0.22535641491413116


 37%|███▋      | 37/100 [04:23<07:28,  7.12s/it]

Epoch: 37, Loss: 0.23792573809623718


 38%|███▊      | 38/100 [04:31<07:32,  7.30s/it]

Epoch: 38, Loss: 0.2754320502281189


 39%|███▉      | 39/100 [04:37<07:08,  7.03s/it]

Epoch: 39, Loss: 0.25759458541870117


 40%|████      | 40/100 [04:46<07:39,  7.65s/it]

Epoch: 40, Loss: 0.25509074330329895


 41%|████      | 41/100 [04:53<07:09,  7.28s/it]

Epoch: 41, Loss: 0.2309490442276001


 42%|████▏     | 42/100 [05:00<07:07,  7.37s/it]

Epoch: 42, Loss: 0.20666518807411194


 43%|████▎     | 43/100 [05:07<06:43,  7.08s/it]

Epoch: 43, Loss: 0.18510323762893677


 44%|████▍     | 44/100 [05:14<06:46,  7.25s/it]

Epoch: 44, Loss: 0.16460363566875458


 45%|████▌     | 45/100 [05:21<06:24,  6.98s/it]

Epoch: 45, Loss: 0.2396889477968216


 46%|████▌     | 46/100 [05:28<06:27,  7.18s/it]

Epoch: 46, Loss: 0.1893056035041809


 47%|████▋     | 47/100 [05:35<06:09,  6.97s/it]

Epoch: 47, Loss: 0.1790321171283722


 48%|████▊     | 48/100 [05:42<06:12,  7.17s/it]

Epoch: 48, Loss: 0.1982746422290802


 49%|████▉     | 49/100 [05:49<05:52,  6.91s/it]

Epoch: 49, Loss: 3.456930637359619


 50%|█████     | 50/100 [05:56<05:57,  7.15s/it]

Epoch: 50, Loss: 0.22017526626586914


 51%|█████     | 51/100 [06:03<05:39,  6.93s/it]

Epoch: 51, Loss: 0.3071228265762329


 52%|█████▏    | 52/100 [06:10<05:38,  7.06s/it]

Epoch: 52, Loss: 0.3499699831008911


 53%|█████▎    | 53/100 [06:17<05:32,  7.09s/it]

Epoch: 53, Loss: 0.5960848927497864


 54%|█████▍    | 54/100 [06:25<05:40,  7.40s/it]

Epoch: 54, Loss: 0.5207613706588745


 55%|█████▌    | 55/100 [06:32<05:18,  7.08s/it]

Epoch: 55, Loss: 0.5178110003471375


 56%|█████▌    | 56/100 [06:39<05:18,  7.24s/it]

Epoch: 56, Loss: 0.3621097803115845


 57%|█████▋    | 57/100 [06:46<05:00,  6.98s/it]

Epoch: 57, Loss: 1.0506043434143066


 58%|█████▊    | 58/100 [06:53<05:02,  7.19s/it]

Epoch: 58, Loss: 0.256161630153656


 59%|█████▉    | 59/100 [07:00<04:44,  6.93s/it]

Epoch: 59, Loss: 0.3388850688934326


 60%|██████    | 60/100 [07:07<04:45,  7.13s/it]

Epoch: 60, Loss: 0.1978786587715149


 61%|██████    | 61/100 [07:14<04:28,  6.89s/it]

Epoch: 61, Loss: 0.2442983239889145


 62%|██████▏   | 62/100 [07:21<04:30,  7.13s/it]

Epoch: 62, Loss: 0.26686060428619385


 63%|██████▎   | 63/100 [07:28<04:15,  6.90s/it]

Epoch: 63, Loss: 0.36671730875968933


 64%|██████▍   | 64/100 [07:35<04:16,  7.14s/it]

Epoch: 64, Loss: 0.20375382900238037


 65%|██████▌   | 65/100 [07:42<04:01,  6.90s/it]

Epoch: 65, Loss: 0.16889207065105438


 66%|██████▌   | 66/100 [07:50<04:02,  7.13s/it]

Epoch: 66, Loss: 0.22132810950279236


 67%|██████▋   | 67/100 [07:56<03:47,  6.89s/it]

Epoch: 67, Loss: 0.27139562368392944


 68%|██████▊   | 68/100 [08:04<03:48,  7.15s/it]

Epoch: 68, Loss: 0.172154501080513


 69%|██████▉   | 69/100 [08:10<03:34,  6.91s/it]

Epoch: 69, Loss: 0.20965620875358582


 70%|███████   | 70/100 [08:18<03:33,  7.13s/it]

Epoch: 70, Loss: 0.1667444407939911


 71%|███████   | 71/100 [08:24<03:20,  6.90s/it]

Epoch: 71, Loss: 0.1637231707572937


 72%|███████▏  | 72/100 [08:32<03:19,  7.13s/it]

Epoch: 72, Loss: 0.18219301104545593


 73%|███████▎  | 73/100 [08:38<03:06,  6.91s/it]

Epoch: 73, Loss: 0.2087913304567337


 74%|███████▍  | 74/100 [08:47<03:15,  7.50s/it]

Epoch: 74, Loss: 0.21877218782901764


 75%|███████▌  | 75/100 [08:53<02:59,  7.17s/it]

Epoch: 75, Loss: 0.23566022515296936


 76%|███████▌  | 76/100 [09:01<02:55,  7.32s/it]

Epoch: 76, Loss: 0.20990264415740967


 77%|███████▋  | 77/100 [09:07<02:41,  7.04s/it]

Epoch: 77, Loss: 0.19380119442939758


 78%|███████▊  | 78/100 [09:15<02:39,  7.24s/it]

Epoch: 78, Loss: 0.25988447666168213


 79%|███████▉  | 79/100 [09:21<02:26,  6.99s/it]

Epoch: 79, Loss: 0.17591333389282227


 80%|████████  | 80/100 [09:29<02:23,  7.19s/it]

Epoch: 80, Loss: 0.1559648960828781


 81%|████████  | 81/100 [09:35<02:11,  6.94s/it]

Epoch: 81, Loss: 0.20770439505577087


 82%|████████▏ | 82/100 [09:43<02:08,  7.16s/it]

Epoch: 82, Loss: 0.19117924571037292


 83%|████████▎ | 83/100 [09:49<01:57,  6.90s/it]

Epoch: 83, Loss: 0.1795436143875122


 84%|████████▍ | 84/100 [09:57<01:54,  7.14s/it]

Epoch: 84, Loss: 0.1867300271987915


 85%|████████▌ | 85/100 [10:04<01:43,  6.93s/it]

Epoch: 85, Loss: 0.15875087678432465


 86%|████████▌ | 86/100 [10:11<01:40,  7.15s/it]

Epoch: 86, Loss: 0.14846453070640564


 87%|████████▋ | 87/100 [10:18<01:29,  6.92s/it]

Epoch: 87, Loss: 0.1509300172328949


 88%|████████▊ | 88/100 [10:25<01:25,  7.12s/it]

Epoch: 88, Loss: 0.19810311496257782


 89%|████████▉ | 89/100 [10:32<01:15,  6.90s/it]

Epoch: 89, Loss: 0.1745585948228836


 90%|█████████ | 90/100 [10:39<01:10,  7.09s/it]

Epoch: 90, Loss: 0.2344718873500824


 91%|█████████ | 91/100 [10:46<01:02,  6.91s/it]

Epoch: 91, Loss: 0.17622226476669312


 92%|█████████▏| 92/100 [10:53<00:56,  7.07s/it]

Epoch: 92, Loss: 0.13568104803562164


 93%|█████████▎| 93/100 [11:00<00:48,  6.89s/it]

Epoch: 93, Loss: 0.13894811272621155


 94%|█████████▍| 94/100 [11:08<00:44,  7.39s/it]

Epoch: 94, Loss: 0.1394711285829544


 95%|█████████▌| 95/100 [11:15<00:36,  7.25s/it]

Epoch: 95, Loss: 1.7329628467559814


 96%|█████████▌| 96/100 [11:22<00:28,  7.20s/it]

Epoch: 96, Loss: 0.16524195671081543


 97%|█████████▋| 97/100 [11:29<00:21,  7.13s/it]

Epoch: 97, Loss: 0.1923018991947174


 98%|█████████▊| 98/100 [11:36<00:14,  7.07s/it]

Epoch: 98, Loss: 0.20094075798988342


 99%|█████████▉| 99/100 [11:43<00:07,  7.05s/it]

Epoch: 99, Loss: 0.26364076137542725


100%|██████████| 100/100 [11:50<00:00,  7.10s/it]

Epoch: 100, Loss: 0.25947657227516174





In [None]:
# Bundle every embedding row with a user id and a label, keyed by row index.
# NOTE(review): assumes row i of pos_z / test_pos_z lines up with
# dataset.user_ids[i] and dataset.train_labels[i] / dataset.test_labels[i] — confirm.
train_user_representations = {}
test_user_representations = {}

for i, train_embedding in enumerate(pos_z):
  train_user_representations[i] = (
      dataset.user_ids[i],
      train_embedding.detach(),  # drop the autograd history before storing
      dataset.train_labels[i],
  )

for i, test_embedding in enumerate(test_pos_z):
  test_user_representations[i] = (
      dataset.user_ids[i],
      test_embedding.detach(),
      dataset.test_labels[i],
  )

In [None]:
# Wrap the stored representations in a dataset and split it 80/20 into
# train and test subsets.
user_dataset = UserDataset(train_user_representations)
n = len(user_dataset)

# Random ordering of all sample indices, drawn from NumPy's global RNG.
indices = np.random.permutation(np.arange(n))

split_point = int(0.8*n)
train_indices, test_indices = indices[:split_point], indices[split_point:]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

# Both loaders use the same settings; note the test loader is shuffled as well.
loader_options = dict(batch_size=64, shuffle=True)
user_train_dataloader = DataLoader(user_train_dataset, **loader_options)
user_test_dataloader = DataLoader(user_test_dataset, **loader_options)

In [None]:
# Class count and the array of class ids 0 .. num_classes - 1.
# NOTE(review): 186 is hard-coded; presumably the number of distinct labels — confirm.
num_classes = 186
classes = np.arange(num_classes)

In [None]:
# Run train2task() and write the returned model's state_dict to disk.
checkpoint_path = 'node_mix_only_infomax_model.pth'
head_model = train2task()
torch.save(head_model.state_dict(), checkpoint_path)

Epoch: 0, Loss: 4.948060035705566
Hits@1: 0.04782196879386902
Epoch: 1, Loss: 4.755845546722412
Hits@1: 0.054548460990190506
Epoch: 2, Loss: 4.627336502075195
Hits@1: 0.07984607666730881
Epoch: 3, Loss: 4.5391411781311035
Hits@1: 0.07826779037714005
Epoch: 4, Loss: 4.475954532623291
Hits@1: 0.07955297082662582
Epoch: 5, Loss: 4.428366661071777
Hits@1: 0.08254419267177582
Epoch: 6, Loss: 4.390832424163818
Hits@1: 0.08524230122566223
Epoch: 7, Loss: 4.360129356384277
Hits@1: 0.08921055495738983
Epoch: 8, Loss: 4.334252834320068
Hits@1: 0.09741011261940002
Epoch: 9, Loss: 4.311925888061523
Hits@1: 0.10312950611114502
Epoch: 10, Loss: 4.292300224304199
Hits@1: 0.10609066486358643
Epoch: 11, Loss: 4.274774074554443
Hits@1: 0.10858585685491562
Epoch: 12, Loss: 4.2589192390441895
Hits@1: 0.11032196879386902
Epoch: 13, Loss: 4.244439125061035
Hits@1: 0.1118401288986206
Epoch: 14, Loss: 4.231086254119873
Hits@1: 0.11354617774486542
Epoch: 15, Loss: 4.218681812286377
Hits@1: 0.11503426730632782


### Node dropout augmentation, only infomax

In [None]:
# This run is pinned to the CPU; the commented line selects CUDA when available.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')


def mean_sigmoid_summary(z, *args, **kwargs):
  """Summary readout: sigmoid of the mean of `z` over dim 0 (extra arguments are ignored)."""
  return torch.sigmoid(z.mean(dim=0))


# Deep Graph Infomax model that uses `node_dropout` as its corruption function.
infomax_encoder = Encoder(interm_channels=74, hidden_channels=128, num_features=128)
model = DeepGraphInfomax(
    hidden_channels=128,
    encoder=infomax_encoder,
    summary=mean_sigmoid_summary,
    corruption=node_dropout,
).to(device)

linkmodel = LinkPredHead()
# Only the infomax model's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [None]:
# Train the infomax model for `num_epoch` epochs, recording the value returned
# as the loss after every epoch.
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    # On the last epoch the training step is expected to also return the
    # embeddings and summary, which the following cells consume.
    pos_z, neg_z, summary, loss = train_only_infomax(epoch)
  else:
    loss = train_only_infomax(epoch)
  loss_epoch.append(loss)

# Inference only: no backward pass follows, so skip autograd bookkeeping instead
# of building and keeping a computation graph for the test features.
# NOTE(review): if Encoder contains dropout/batch-norm layers, model.eval()
# should probably be called here as well — confirm against the Encoder definition.
with torch.no_grad():
  test_pos_z, test_neg_z, test_summary = model(dataset.test_features, dataset.edge_index)

  1%|          | 1/100 [00:10<16:43, 10.14s/it]

Epoch: 1, Loss: 1.434377908706665


  2%|▏         | 2/100 [00:18<14:57,  9.16s/it]

Epoch: 2, Loss: 1.6858642101287842


  3%|▎         | 3/100 [00:25<13:03,  8.07s/it]

Epoch: 3, Loss: 1.4676151275634766


  4%|▍         | 4/100 [00:33<13:02,  8.15s/it]

Epoch: 4, Loss: 1.4774839878082275


  5%|▌         | 5/100 [00:40<12:10,  7.69s/it]

Epoch: 5, Loss: 1.471682071685791


  6%|▌         | 6/100 [00:48<12:27,  7.95s/it]

Epoch: 6, Loss: 1.5278609991073608


  7%|▋         | 7/100 [00:55<11:48,  7.61s/it]

Epoch: 7, Loss: 1.493563175201416


  8%|▊         | 8/100 [01:04<12:00,  7.84s/it]

Epoch: 8, Loss: 1.413658857345581


  9%|▉         | 9/100 [01:11<11:37,  7.66s/it]

Epoch: 9, Loss: 1.4670861959457397


 10%|█         | 10/100 [01:18<11:22,  7.58s/it]

Epoch: 10, Loss: 1.2441600561141968


 11%|█         | 11/100 [01:26<11:23,  7.68s/it]

Epoch: 11, Loss: 1.3666484355926514


 12%|█▏        | 12/100 [01:34<11:03,  7.54s/it]

Epoch: 12, Loss: 1.3139638900756836


 13%|█▎        | 13/100 [01:42<11:24,  7.87s/it]

Epoch: 13, Loss: 1.436647891998291


 14%|█▍        | 14/100 [01:49<10:44,  7.50s/it]

Epoch: 14, Loss: 1.2914801836013794


 15%|█▌        | 15/100 [01:57<10:47,  7.62s/it]

Epoch: 15, Loss: 1.1389440298080444


 16%|█▌        | 16/100 [02:04<10:22,  7.41s/it]

Epoch: 16, Loss: 1.4007046222686768


 17%|█▋        | 17/100 [02:12<10:36,  7.67s/it]

Epoch: 17, Loss: 1.3011038303375244


 18%|█▊        | 18/100 [02:20<10:43,  7.85s/it]

Epoch: 18, Loss: 1.1491774320602417


 19%|█▉        | 19/100 [02:29<10:47,  8.00s/it]

Epoch: 19, Loss: 1.1295394897460938


 20%|██        | 20/100 [02:36<10:28,  7.86s/it]

Epoch: 20, Loss: 1.3284473419189453


 21%|██        | 21/100 [02:43<10:09,  7.72s/it]

Epoch: 21, Loss: 1.4089276790618896


 22%|██▏       | 22/100 [02:51<10:04,  7.75s/it]

Epoch: 22, Loss: 1.2242851257324219


 23%|██▎       | 23/100 [02:58<09:42,  7.57s/it]

Epoch: 23, Loss: 1.2039520740509033


 24%|██▍       | 24/100 [03:06<09:41,  7.65s/it]

Epoch: 24, Loss: 1.1842586994171143


 25%|██▌       | 25/100 [03:13<09:18,  7.45s/it]

Epoch: 25, Loss: 1.3701759576797485


 26%|██▌       | 26/100 [03:22<09:29,  7.70s/it]

Epoch: 26, Loss: 1.493953824043274


 27%|██▋       | 27/100 [03:28<09:01,  7.41s/it]

Epoch: 27, Loss: 1.0008609294891357


 28%|██▊       | 28/100 [03:36<09:08,  7.62s/it]

Epoch: 28, Loss: 1.1428148746490479


 29%|██▉       | 29/100 [03:43<08:45,  7.40s/it]

Epoch: 29, Loss: 0.9720543026924133


 30%|███       | 30/100 [03:51<08:50,  7.58s/it]

Epoch: 30, Loss: 1.5224685668945312


 31%|███       | 31/100 [03:59<08:36,  7.49s/it]

Epoch: 31, Loss: 1.3161308765411377


 32%|███▏      | 32/100 [04:07<08:40,  7.65s/it]

Epoch: 32, Loss: 1.5469311475753784


 33%|███▎      | 33/100 [04:14<08:35,  7.69s/it]

Epoch: 33, Loss: 0.9444793462753296


 34%|███▍      | 34/100 [04:22<08:18,  7.56s/it]

Epoch: 34, Loss: 1.0261913537979126


 35%|███▌      | 35/100 [04:30<08:25,  7.78s/it]

Epoch: 35, Loss: 1.585784673690796


 36%|███▌      | 36/100 [04:38<08:29,  7.95s/it]

Epoch: 36, Loss: 1.1696017980575562


 37%|███▋      | 37/100 [04:46<08:20,  7.95s/it]

Epoch: 37, Loss: 1.2164859771728516


 38%|███▊      | 38/100 [04:53<07:47,  7.54s/it]

Epoch: 38, Loss: 0.8097628355026245


 39%|███▉      | 39/100 [05:01<07:47,  7.66s/it]

Epoch: 39, Loss: 1.3112215995788574


 40%|████      | 40/100 [05:08<07:37,  7.62s/it]

Epoch: 40, Loss: 1.5048935413360596


 41%|████      | 41/100 [05:16<07:25,  7.55s/it]

Epoch: 41, Loss: 1.320711374282837


 42%|████▏     | 42/100 [05:23<07:20,  7.60s/it]

Epoch: 42, Loss: 0.9727059602737427


 43%|████▎     | 43/100 [05:30<07:02,  7.41s/it]

Epoch: 43, Loss: 1.4551103115081787


 44%|████▍     | 44/100 [05:39<07:08,  7.65s/it]

Epoch: 44, Loss: 1.502182960510254


 45%|████▌     | 45/100 [05:45<06:46,  7.40s/it]

Epoch: 45, Loss: 1.4356248378753662


 46%|████▌     | 46/100 [05:53<06:50,  7.61s/it]

Epoch: 46, Loss: 1.3414232730865479


 47%|████▋     | 47/100 [06:00<06:32,  7.40s/it]

Epoch: 47, Loss: 1.060674786567688


 48%|████▊     | 48/100 [06:09<06:40,  7.70s/it]

Epoch: 48, Loss: 1.32865309715271


 49%|████▉     | 49/100 [06:16<06:25,  7.57s/it]

Epoch: 49, Loss: 1.1351902484893799


 50%|█████     | 50/100 [06:24<06:27,  7.75s/it]

Epoch: 50, Loss: 1.5129246711730957


 51%|█████     | 51/100 [06:32<06:20,  7.76s/it]

Epoch: 51, Loss: 1.4626754522323608


 52%|█████▏    | 52/100 [06:39<06:02,  7.55s/it]

Epoch: 52, Loss: 1.3119215965270996


 53%|█████▎    | 53/100 [06:47<06:01,  7.68s/it]

Epoch: 53, Loss: 1.433196783065796


 54%|█████▍    | 54/100 [06:54<05:41,  7.42s/it]

Epoch: 54, Loss: 1.0706886053085327


 55%|█████▌    | 55/100 [07:04<06:05,  8.13s/it]

Epoch: 55, Loss: 1.1254193782806396


 56%|█████▌    | 56/100 [07:11<05:45,  7.86s/it]

Epoch: 56, Loss: 1.29121732711792


 57%|█████▋    | 57/100 [07:19<05:44,  8.01s/it]

Epoch: 57, Loss: 1.2475066184997559


 58%|█████▊    | 58/100 [07:26<05:23,  7.70s/it]

Epoch: 58, Loss: 1.3682318925857544


 59%|█████▉    | 59/100 [07:34<05:17,  7.74s/it]

Epoch: 59, Loss: 1.3778488636016846


 60%|██████    | 60/100 [07:41<05:02,  7.57s/it]

Epoch: 60, Loss: 0.9925915598869324


 61%|██████    | 61/100 [07:49<04:56,  7.60s/it]

Epoch: 61, Loss: 1.3898911476135254


 62%|██████▏   | 62/100 [07:57<04:54,  7.76s/it]

Epoch: 62, Loss: 1.3822014331817627


 63%|██████▎   | 63/100 [08:04<04:39,  7.56s/it]

Epoch: 63, Loss: 1.3768699169158936


 64%|██████▍   | 64/100 [08:13<04:43,  7.88s/it]

Epoch: 64, Loss: 1.2479016780853271


 65%|██████▌   | 65/100 [08:20<04:28,  7.68s/it]

Epoch: 65, Loss: 0.9863511323928833


 66%|██████▌   | 66/100 [08:28<04:27,  7.86s/it]

Epoch: 66, Loss: 1.253706693649292


 67%|██████▋   | 67/100 [08:35<04:07,  7.51s/it]

Epoch: 67, Loss: 1.4212249517440796


 68%|██████▊   | 68/100 [08:43<04:07,  7.72s/it]

Epoch: 68, Loss: 1.2988011837005615


 69%|██████▉   | 69/100 [08:50<03:51,  7.45s/it]

Epoch: 69, Loss: 1.0215299129486084


 70%|███████   | 70/100 [08:58<03:46,  7.56s/it]

Epoch: 70, Loss: 1.2296409606933594


 71%|███████   | 71/100 [09:05<03:35,  7.45s/it]

Epoch: 71, Loss: 1.2914327383041382


 72%|███████▏  | 72/100 [09:13<03:29,  7.49s/it]

Epoch: 72, Loss: 1.092300534248352


 73%|███████▎  | 73/100 [09:21<03:28,  7.74s/it]

Epoch: 73, Loss: 0.8509871959686279


 74%|███████▍  | 74/100 [09:28<03:20,  7.70s/it]

Epoch: 74, Loss: 0.8271988034248352


 75%|███████▌  | 75/100 [09:36<03:08,  7.56s/it]

Epoch: 75, Loss: 1.2355655431747437


 76%|███████▌  | 76/100 [09:43<03:03,  7.63s/it]

Epoch: 76, Loss: 1.5652377605438232


 77%|███████▋  | 77/100 [09:51<02:55,  7.64s/it]

Epoch: 77, Loss: 0.6490504145622253


 78%|███████▊  | 78/100 [09:58<02:45,  7.50s/it]

Epoch: 78, Loss: 1.328251600265503


 79%|███████▉  | 79/100 [10:06<02:39,  7.61s/it]

Epoch: 79, Loss: 1.3262487649917603


 80%|████████  | 80/100 [10:13<02:28,  7.44s/it]

Epoch: 80, Loss: 1.65399968624115


 81%|████████  | 81/100 [10:23<02:34,  8.12s/it]

Epoch: 81, Loss: 1.4593112468719482


 82%|████████▏ | 82/100 [10:30<02:22,  7.91s/it]

Epoch: 82, Loss: 1.4507437944412231


 83%|████████▎ | 83/100 [10:39<02:17,  8.07s/it]

Epoch: 83, Loss: 1.4126683473587036


 84%|████████▍ | 84/100 [10:47<02:09,  8.08s/it]

Epoch: 84, Loss: 1.4047579765319824


 85%|████████▌ | 85/100 [10:54<01:57,  7.84s/it]

Epoch: 85, Loss: 1.2953596115112305


 86%|████████▌ | 86/100 [11:02<01:50,  7.93s/it]

Epoch: 86, Loss: 1.3627493381500244


 87%|████████▋ | 87/100 [11:09<01:38,  7.55s/it]

Epoch: 87, Loss: 1.1717016696929932


 88%|████████▊ | 88/100 [11:17<01:32,  7.72s/it]

Epoch: 88, Loss: 1.198279857635498


 89%|████████▉ | 89/100 [11:24<01:22,  7.46s/it]

Epoch: 89, Loss: 1.4529569149017334


 90%|█████████ | 90/100 [11:32<01:16,  7.63s/it]

Epoch: 90, Loss: 0.8439153432846069


 91%|█████████ | 91/100 [11:40<01:10,  7.84s/it]

Epoch: 91, Loss: 1.5079786777496338


 92%|█████████▏| 92/100 [11:48<01:03,  7.88s/it]

Epoch: 92, Loss: 1.0753843784332275


 93%|█████████▎| 93/100 [11:56<00:54,  7.80s/it]

Epoch: 93, Loss: 1.3580098152160645


 94%|█████████▍| 94/100 [12:03<00:45,  7.59s/it]

Epoch: 94, Loss: 1.3142826557159424


 95%|█████████▌| 95/100 [12:11<00:37,  7.58s/it]

Epoch: 95, Loss: 0.8830126523971558


 96%|█████████▌| 96/100 [12:18<00:29,  7.42s/it]

Epoch: 96, Loss: 1.176457405090332


 97%|█████████▋| 97/100 [12:26<00:22,  7.66s/it]

Epoch: 97, Loss: 0.6481517553329468


 98%|█████████▊| 98/100 [12:33<00:15,  7.53s/it]

Epoch: 98, Loss: 0.5377374291419983


 99%|█████████▉| 99/100 [12:41<00:07,  7.80s/it]

Epoch: 99, Loss: 0.6868425607681274


100%|██████████| 100/100 [12:48<00:00,  7.69s/it]

Epoch: 100, Loss: 0.4218140244483948





In [None]:
# Bundle every embedding row with a user id and a label, keyed by row index.
# NOTE(review): assumes row i of pos_z / test_pos_z lines up with
# dataset.user_ids[i] and dataset.train_labels[i] / dataset.test_labels[i] — confirm.
train_user_representations = {}
test_user_representations = {}

for i, train_embedding in enumerate(pos_z):
  train_user_representations[i] = (
      dataset.user_ids[i],
      train_embedding.detach(),  # drop the autograd history before storing
      dataset.train_labels[i],
  )

for i, test_embedding in enumerate(test_pos_z):
  test_user_representations[i] = (
      dataset.user_ids[i],
      test_embedding.detach(),
      dataset.test_labels[i],
  )

In [None]:
# Wrap the stored representations in a dataset and split it 80/20 into
# train and test subsets.
user_dataset = UserDataset(train_user_representations)
n = len(user_dataset)

# Random ordering of all sample indices, drawn from NumPy's global RNG.
indices = np.random.permutation(np.arange(n))

split_point = int(0.8*n)
train_indices, test_indices = indices[:split_point], indices[split_point:]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

# Both loaders use the same settings; note the test loader is shuffled as well.
loader_options = dict(batch_size=64, shuffle=True)
user_train_dataloader = DataLoader(user_train_dataset, **loader_options)
user_test_dataloader = DataLoader(user_test_dataset, **loader_options)

In [None]:
# Class count and the array of class ids 0 .. num_classes - 1.
# NOTE(review): 186 is hard-coded; presumably the number of distinct labels — confirm.
num_classes = 186
classes = np.arange(num_classes)

In [None]:
# Run train2task() and write the returned model's state_dict to disk.
checkpoint_path = 'node_dropout_only_infomax_model.pth'
head_model = train2task()
torch.save(head_model.state_dict(), checkpoint_path)

Epoch: 0, Loss: 5.0717082023620605
Hits@1: 0.04041155427694321
Epoch: 1, Loss: 4.94442892074585
Hits@1: 0.05825366824865341
Epoch: 2, Loss: 4.8375325202941895
Hits@1: 0.08203312754631042
Epoch: 3, Loss: 4.747663974761963
Hits@1: 0.08848153799772263
Epoch: 4, Loss: 4.671762943267822
Hits@1: 0.09320887178182602
Epoch: 5, Loss: 4.607346534729004
Hits@1: 0.09345688670873642
Epoch: 6, Loss: 4.552304744720459
Hits@1: 0.09494498372077942
Epoch: 7, Loss: 4.5048627853393555
Hits@1: 0.09818422049283981
Epoch: 8, Loss: 4.463573455810547
Hits@1: 0.10140842944383621
Epoch: 9, Loss: 4.427262306213379
Hits@1: 0.1046476662158966
Epoch: 10, Loss: 4.394991874694824
Hits@1: 0.10905182361602783
Epoch: 11, Loss: 4.36603307723999
Hits@1: 0.11108104884624481
Epoch: 12, Loss: 4.339807510375977
Hits@1: 0.11474116146564484
Epoch: 13, Loss: 4.3158698081970215
Hits@1: 0.11943843215703964
Epoch: 14, Loss: 4.293863773345947
Hits@1: 0.12219666689634323
Epoch: 15, Loss: 4.273504734039307
Hits@1: 0.12543590366840363
E

### Edge dropout augmentation, only infomax

In [None]:
# This run is pinned to the CPU; the commented line selects CUDA when available.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')


def mean_sigmoid_summary(z, *args, **kwargs):
  """Summary readout: sigmoid of the mean of `z` over dim 0 (extra arguments are ignored)."""
  return torch.sigmoid(z.mean(dim=0))


# Deep Graph Infomax model that uses `edge_dropout` as its corruption function.
infomax_encoder = Encoder(interm_channels=74, hidden_channels=128, num_features=128)
model = DeepGraphInfomax(
    hidden_channels=128,
    encoder=infomax_encoder,
    summary=mean_sigmoid_summary,
    corruption=edge_dropout,
).to(device)

linkmodel = LinkPredHead()
# Only the infomax model's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [None]:
# Train the infomax model for `num_epoch` epochs, recording the value returned
# as the loss after every epoch.
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    # On the last epoch the training step is expected to also return the
    # embeddings and summary, which the following cells consume.
    pos_z, neg_z, summary, loss = train_only_infomax(epoch)
  else:
    loss = train_only_infomax(epoch)
  loss_epoch.append(loss)

# Inference only: no backward pass follows, so skip autograd bookkeeping instead
# of building and keeping a computation graph for the test features.
# NOTE(review): if Encoder contains dropout/batch-norm layers, model.eval()
# should probably be called here as well — confirm against the Encoder definition.
with torch.no_grad():
  test_pos_z, test_neg_z, test_summary = model(dataset.test_features, dataset.edge_index)

  1%|          | 1/100 [00:08<14:41,  8.90s/it]

Epoch: 1, Loss: 1.5315818786621094


  2%|▏         | 2/100 [00:14<11:28,  7.03s/it]

Epoch: 2, Loss: 1.4307504892349243


  3%|▎         | 3/100 [00:21<11:01,  6.82s/it]

Epoch: 3, Loss: 1.450163722038269


  4%|▍         | 4/100 [00:27<10:23,  6.49s/it]

Epoch: 4, Loss: 1.3130885362625122


  5%|▌         | 5/100 [00:33<09:57,  6.29s/it]

Epoch: 5, Loss: 1.4555678367614746


  6%|▌         | 6/100 [00:39<09:54,  6.32s/it]

Epoch: 6, Loss: 1.3451616764068604


  7%|▋         | 7/100 [00:45<09:24,  6.07s/it]

Epoch: 7, Loss: 1.2964013814926147


  8%|▊         | 8/100 [00:51<09:42,  6.33s/it]

Epoch: 8, Loss: 1.3289048671722412


  9%|▉         | 9/100 [00:57<09:17,  6.12s/it]

Epoch: 9, Loss: 1.3488438129425049


 10%|█         | 10/100 [01:04<09:29,  6.33s/it]

Epoch: 10, Loss: 1.4355223178863525


 11%|█         | 11/100 [01:10<09:05,  6.13s/it]

Epoch: 11, Loss: 1.3512799739837646


 12%|█▏        | 12/100 [01:16<08:59,  6.14s/it]

Epoch: 12, Loss: 1.3199876546859741


 13%|█▎        | 13/100 [01:22<08:58,  6.19s/it]

Epoch: 13, Loss: 1.2901902198791504


 14%|█▍        | 14/100 [01:28<08:38,  6.03s/it]

Epoch: 14, Loss: 1.3807094097137451


 15%|█▌        | 15/100 [01:36<09:31,  6.73s/it]

Epoch: 15, Loss: 1.323075771331787


 16%|█▌        | 16/100 [01:42<08:55,  6.37s/it]

Epoch: 16, Loss: 1.3552299737930298


 17%|█▋        | 17/100 [01:49<09:02,  6.54s/it]

Epoch: 17, Loss: 1.2928330898284912


 18%|█▊        | 18/100 [01:54<08:32,  6.25s/it]

Epoch: 18, Loss: 1.188621997833252


 19%|█▉        | 19/100 [02:00<08:22,  6.20s/it]

Epoch: 19, Loss: 1.235265851020813


 20%|██        | 20/100 [02:07<08:20,  6.25s/it]

Epoch: 20, Loss: 1.3095481395721436


 21%|██        | 21/100 [02:12<08:00,  6.08s/it]

Epoch: 21, Loss: 1.3086782693862915


 22%|██▏       | 22/100 [02:19<08:11,  6.30s/it]

Epoch: 22, Loss: 1.2790441513061523


 23%|██▎       | 23/100 [02:25<07:50,  6.11s/it]

Epoch: 23, Loss: 1.1856164932250977


 24%|██▍       | 24/100 [02:32<08:03,  6.36s/it]

Epoch: 24, Loss: 1.227426290512085


 25%|██▌       | 25/100 [02:37<07:40,  6.15s/it]

Epoch: 25, Loss: 1.2669401168823242


 26%|██▌       | 26/100 [02:44<07:43,  6.26s/it]

Epoch: 26, Loss: 1.163386583328247


 27%|██▋       | 27/100 [02:50<07:26,  6.12s/it]

Epoch: 27, Loss: 1.1860742568969727


 28%|██▊       | 28/100 [02:56<07:15,  6.05s/it]

Epoch: 28, Loss: 1.1444859504699707


 29%|██▉       | 29/100 [03:02<07:21,  6.21s/it]

Epoch: 29, Loss: 1.0480680465698242


 30%|███       | 30/100 [03:08<07:01,  6.03s/it]

Epoch: 30, Loss: 1.0864105224609375


 31%|███       | 31/100 [03:15<07:13,  6.28s/it]

Epoch: 31, Loss: 1.1932272911071777


 32%|███▏      | 32/100 [03:20<06:51,  6.05s/it]

Epoch: 32, Loss: 1.0998951196670532


 33%|███▎      | 33/100 [03:27<07:03,  6.32s/it]

Epoch: 33, Loss: 0.9680413603782654


 34%|███▍      | 34/100 [03:33<06:42,  6.10s/it]

Epoch: 34, Loss: 1.1599860191345215


 35%|███▌      | 35/100 [03:39<06:36,  6.10s/it]

Epoch: 35, Loss: 0.8700605630874634


 36%|███▌      | 36/100 [03:45<06:38,  6.23s/it]

Epoch: 36, Loss: 0.9803987145423889


 37%|███▋      | 37/100 [03:52<06:47,  6.47s/it]

Epoch: 37, Loss: 0.9227088689804077


 38%|███▊      | 38/100 [03:59<06:40,  6.46s/it]

Epoch: 38, Loss: 0.983038604259491


 39%|███▉      | 39/100 [04:04<06:21,  6.25s/it]

Epoch: 39, Loss: 1.0280659198760986


 40%|████      | 40/100 [04:11<06:28,  6.47s/it]

Epoch: 40, Loss: 0.8215065598487854


 41%|████      | 41/100 [04:17<06:08,  6.25s/it]

Epoch: 41, Loss: 0.827405571937561


 42%|████▏     | 42/100 [04:24<06:14,  6.45s/it]

Epoch: 42, Loss: 0.8928031325340271


 43%|████▎     | 43/100 [04:30<05:53,  6.20s/it]

Epoch: 43, Loss: 0.807645320892334


 44%|████▍     | 44/100 [04:36<05:54,  6.34s/it]

Epoch: 44, Loss: 0.7820622324943542


 45%|████▌     | 45/100 [04:42<05:40,  6.19s/it]

Epoch: 45, Loss: 0.6725160479545593


 46%|████▌     | 46/100 [04:48<05:29,  6.10s/it]

Epoch: 46, Loss: 0.7326535582542419


 47%|████▋     | 47/100 [04:55<05:31,  6.25s/it]

Epoch: 47, Loss: 0.7734333276748657


 48%|████▊     | 48/100 [05:00<05:15,  6.07s/it]

Epoch: 48, Loss: 0.9913886785507202


 49%|████▉     | 49/100 [05:07<05:22,  6.33s/it]

Epoch: 49, Loss: 0.8644290566444397


 50%|█████     | 50/100 [05:13<05:06,  6.12s/it]

Epoch: 50, Loss: 0.7474942207336426


 51%|█████     | 51/100 [05:20<05:11,  6.36s/it]

Epoch: 51, Loss: 0.7413913607597351


 52%|█████▏    | 52/100 [05:26<04:55,  6.15s/it]

Epoch: 52, Loss: 0.7447165846824646


 53%|█████▎    | 53/100 [05:32<04:50,  6.19s/it]

Epoch: 53, Loss: 0.7733690738677979


 54%|█████▍    | 54/100 [05:38<04:45,  6.20s/it]

Epoch: 54, Loss: 0.7040135264396667


 55%|█████▌    | 55/100 [05:44<04:31,  6.03s/it]

Epoch: 55, Loss: 0.6966257095336914


 56%|█████▌    | 56/100 [05:51<04:37,  6.31s/it]

Epoch: 56, Loss: 0.7589700222015381


 57%|█████▋    | 57/100 [05:56<04:21,  6.09s/it]

Epoch: 57, Loss: 0.6942002773284912


 58%|█████▊    | 58/100 [06:04<04:34,  6.54s/it]

Epoch: 58, Loss: 0.6957862377166748


 59%|█████▉    | 59/100 [06:10<04:26,  6.49s/it]

Epoch: 59, Loss: 0.6682083606719971


 60%|██████    | 60/100 [06:16<04:08,  6.20s/it]

Epoch: 60, Loss: 0.7383553981781006


 61%|██████    | 61/100 [06:23<04:11,  6.46s/it]

Epoch: 61, Loss: 0.6952181458473206


 62%|██████▏   | 62/100 [06:28<03:55,  6.20s/it]

Epoch: 62, Loss: 0.679351270198822


 63%|██████▎   | 63/100 [06:35<03:57,  6.42s/it]

Epoch: 63, Loss: 0.6721837520599365


 64%|██████▍   | 64/100 [06:41<03:42,  6.17s/it]

Epoch: 64, Loss: 0.6837379932403564


 65%|██████▌   | 65/100 [06:47<03:40,  6.29s/it]

Epoch: 65, Loss: 0.7463153600692749


 66%|██████▌   | 66/100 [06:53<03:30,  6.18s/it]

Epoch: 66, Loss: 0.8935039043426514


 67%|██████▋   | 67/100 [06:59<03:20,  6.06s/it]

Epoch: 67, Loss: 0.5790199041366577


 68%|██████▊   | 68/100 [07:06<03:17,  6.18s/it]

Epoch: 68, Loss: 0.6081664562225342


 69%|██████▉   | 69/100 [07:11<03:06,  6.02s/it]

Epoch: 69, Loss: 0.6213963627815247


 70%|███████   | 70/100 [07:18<03:07,  6.26s/it]

Epoch: 70, Loss: 0.8244941234588623


 71%|███████   | 71/100 [07:24<02:56,  6.07s/it]

Epoch: 71, Loss: 0.6187937259674072


 72%|███████▏  | 72/100 [07:31<02:57,  6.32s/it]

Epoch: 72, Loss: 0.601722002029419


 73%|███████▎  | 73/100 [07:36<02:45,  6.12s/it]

Epoch: 73, Loss: 0.5183278322219849


 74%|███████▍  | 74/100 [07:42<02:39,  6.15s/it]

Epoch: 74, Loss: 0.6255543231964111


 75%|███████▌  | 75/100 [07:49<02:34,  6.16s/it]

Epoch: 75, Loss: 0.7441462278366089


 76%|███████▌  | 76/100 [07:54<02:23,  5.99s/it]

Epoch: 76, Loss: 0.592313289642334


 77%|███████▋  | 77/100 [08:01<02:23,  6.25s/it]

Epoch: 77, Loss: 0.6758454442024231


 78%|███████▊  | 78/100 [08:07<02:14,  6.09s/it]

Epoch: 78, Loss: 0.6430345773696899


 79%|███████▉  | 79/100 [08:14<02:17,  6.54s/it]

Epoch: 79, Loss: 0.5575236082077026


 80%|████████  | 80/100 [08:21<02:09,  6.50s/it]

Epoch: 80, Loss: 0.6221998929977417


 81%|████████  | 81/100 [08:28<02:05,  6.62s/it]

Epoch: 81, Loss: 0.6545928716659546


 82%|████████▏ | 82/100 [08:33<01:53,  6.31s/it]

Epoch: 82, Loss: 0.662026047706604


 83%|████████▎ | 83/100 [08:40<01:47,  6.33s/it]

Epoch: 83, Loss: 0.6462627053260803


 84%|████████▍ | 84/100 [08:46<01:39,  6.24s/it]

Epoch: 84, Loss: 0.5960139632225037


 85%|████████▌ | 85/100 [08:51<01:30,  6.06s/it]

Epoch: 85, Loss: 0.6004881858825684


 86%|████████▌ | 86/100 [08:58<01:28,  6.30s/it]

Epoch: 86, Loss: 0.5678758025169373


 87%|████████▋ | 87/100 [09:04<01:19,  6.09s/it]

Epoch: 87, Loss: 0.5468395948410034


 88%|████████▊ | 88/100 [09:11<01:15,  6.32s/it]

Epoch: 88, Loss: 0.608731746673584


 89%|████████▉ | 89/100 [09:16<01:07,  6.12s/it]

Epoch: 89, Loss: 0.5690065026283264


 90%|█████████ | 90/100 [09:23<01:02,  6.21s/it]

Epoch: 90, Loss: 0.6468214392662048


 91%|█████████ | 91/100 [09:29<00:55,  6.12s/it]

Epoch: 91, Loss: 0.677240252494812


 92%|█████████▏| 92/100 [09:34<00:48,  6.02s/it]

Epoch: 92, Loss: 0.5484325885772705


 93%|█████████▎| 93/100 [09:41<00:43,  6.20s/it]

Epoch: 93, Loss: 0.552706241607666


 94%|█████████▍| 94/100 [09:47<00:36,  6.01s/it]

Epoch: 94, Loss: 0.5386564135551453


 95%|█████████▌| 95/100 [09:53<00:31,  6.27s/it]

Epoch: 95, Loss: 0.4837915599346161


 96%|█████████▌| 96/100 [09:59<00:24,  6.08s/it]

Epoch: 96, Loss: 0.6473405957221985


 97%|█████████▋| 97/100 [10:06<00:18,  6.33s/it]

Epoch: 97, Loss: 0.6031897068023682


 98%|█████████▊| 98/100 [10:12<00:12,  6.13s/it]

Epoch: 98, Loss: 0.5901374220848083


 99%|█████████▉| 99/100 [10:18<00:06,  6.15s/it]

Epoch: 99, Loss: 0.5753174424171448


100%|██████████| 100/100 [10:24<00:00,  6.25s/it]

Epoch: 100, Loss: 0.5405212640762329





In [None]:
# Bundle every embedding row with a user id and a label, keyed by row index.
# NOTE(review): assumes row i of pos_z / test_pos_z lines up with
# dataset.user_ids[i] and dataset.train_labels[i] / dataset.test_labels[i] — confirm.
train_user_representations = {}
test_user_representations = {}

for i, train_embedding in enumerate(pos_z):
  train_user_representations[i] = (
      dataset.user_ids[i],
      train_embedding.detach(),  # drop the autograd history before storing
      dataset.train_labels[i],
  )

for i, test_embedding in enumerate(test_pos_z):
  test_user_representations[i] = (
      dataset.user_ids[i],
      test_embedding.detach(),
      dataset.test_labels[i],
  )

In [None]:
# Wrap the stored representations in a dataset and split it 80/20 into
# train and test subsets.
user_dataset = UserDataset(train_user_representations)
n = len(user_dataset)

# Random ordering of all sample indices, drawn from NumPy's global RNG.
indices = np.random.permutation(np.arange(n))

split_point = int(0.8*n)
train_indices, test_indices = indices[:split_point], indices[split_point:]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

# Both loaders use the same settings; note the test loader is shuffled as well.
loader_options = dict(batch_size=64, shuffle=True)
user_train_dataloader = DataLoader(user_train_dataset, **loader_options)
user_test_dataloader = DataLoader(user_test_dataset, **loader_options)

In [None]:
# Class count and the array of class ids 0 .. num_classes - 1.
# NOTE(review): 186 is hard-coded; presumably the number of distinct labels — confirm.
num_classes = 186
classes = np.arange(num_classes)

In [None]:
# Run train2task() and write the returned model's state_dict to disk.
checkpoint_path = 'edge_dropout_only_infomax_model.pth'
head_model = train2task()
torch.save(head_model.state_dict(), checkpoint_path)

Epoch: 0, Loss: 4.957077503204346
Hits@1: 0.09370490163564682
Epoch: 1, Loss: 4.774438858032227
Hits@1: 0.09492995589971542
Epoch: 2, Loss: 4.651824951171875
Hits@1: 0.09392286092042923
Epoch: 3, Loss: 4.566333770751953
Hits@1: 0.09395292401313782
Epoch: 4, Loss: 4.503297328948975
Hits@1: 0.09492995589971542
Epoch: 5, Loss: 4.454209804534912
Hits@1: 0.09891323745250702
Epoch: 6, Loss: 4.414301872253418
Hits@1: 0.100679412484169
Epoch: 7, Loss: 4.380742073059082
Hits@1: 0.10337752103805542
Epoch: 8, Loss: 4.351809978485107
Hits@1: 0.10560967028141022
Epoch: 9, Loss: 4.326364517211914
Hits@1: 0.1086159199476242
Epoch: 10, Loss: 4.303642272949219
Hits@1: 0.10935996472835541
Epoch: 11, Loss: 4.283111572265625
Hits@1: 0.11449314653873444
Epoch: 12, Loss: 4.264362812042236
Hits@1: 0.11751443147659302
Epoch: 13, Loss: 4.247109889984131
Hits@1: 0.11849146336317062
Epoch: 14, Loss: 4.231112957000732
Hits@1: 0.12124969810247421
Epoch: 15, Loss: 4.216197490692139
Hits@1: 0.12369979172945023
Epoch

### Random walk augmentation, only infomax

In [None]:
# This run is pinned to the CPU; the commented line selects CUDA when available.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')


def mean_sigmoid_summary(z, *args, **kwargs):
  """Summary readout: sigmoid of the mean of `z` over dim 0 (extra arguments are ignored)."""
  return torch.sigmoid(z.mean(dim=0))


# Deep Graph Infomax model that uses `random_walk` as its corruption function.
infomax_encoder = Encoder(interm_channels=74, hidden_channels=128, num_features=128)
model = DeepGraphInfomax(
    hidden_channels=128,
    encoder=infomax_encoder,
    summary=mean_sigmoid_summary,
    corruption=random_walk,
).to(device)

linkmodel = LinkPredHead()
# Only the infomax model's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [None]:
# Train the infomax model for `num_epoch` epochs, recording the value returned
# as the loss after every epoch.
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    # On the last epoch the training step is expected to also return the
    # embeddings and summary, which the following cells consume.
    pos_z, neg_z, summary, loss = train_only_infomax(epoch)
  else:
    loss = train_only_infomax(epoch)
  loss_epoch.append(loss)

# Inference only: no backward pass follows, so skip autograd bookkeeping instead
# of building and keeping a computation graph for the test features.
# NOTE(review): if Encoder contains dropout/batch-norm layers, model.eval()
# should probably be called here as well — confirm against the Encoder definition.
with torch.no_grad():
  test_pos_z, test_neg_z, test_summary = model(dataset.test_features, dataset.edge_index)

  1%|          | 1/100 [00:05<08:20,  5.06s/it]

Epoch: 1, Loss: 1.3601748943328857


  2%|▏         | 2/100 [00:08<06:45,  4.14s/it]

Epoch: 2, Loss: 1.3325636386871338


  3%|▎         | 3/100 [00:13<07:06,  4.40s/it]

Epoch: 3, Loss: 1.3798646926879883


  4%|▍         | 4/100 [00:16<06:33,  4.10s/it]

Epoch: 4, Loss: 1.1963316202163696


  5%|▌         | 5/100 [00:20<06:07,  3.87s/it]

Epoch: 5, Loss: 1.0853121280670166


  6%|▌         | 6/100 [00:24<05:58,  3.81s/it]

Epoch: 6, Loss: 0.6716153025627136


  7%|▋         | 7/100 [00:28<06:15,  4.03s/it]

Epoch: 7, Loss: 1.2048636674880981


  8%|▊         | 8/100 [00:32<05:55,  3.86s/it]

Epoch: 8, Loss: 1.5446449518203735


  9%|▉         | 9/100 [00:35<05:41,  3.75s/it]

Epoch: 9, Loss: 1.4881179332733154


 10%|█         | 10/100 [00:39<05:52,  3.91s/it]

Epoch: 10, Loss: 1.5803169012069702


 11%|█         | 11/100 [00:43<05:49,  3.92s/it]

Epoch: 11, Loss: 0.9594547748565674


 12%|█▏        | 12/100 [00:47<05:34,  3.80s/it]

Epoch: 12, Loss: 0.7972549796104431


 13%|█▎        | 13/100 [00:50<05:20,  3.68s/it]

Epoch: 13, Loss: 0.6490609049797058


 14%|█▍        | 14/100 [00:55<05:40,  3.96s/it]

Epoch: 14, Loss: 0.6712852120399475


 15%|█▌        | 15/100 [00:58<05:23,  3.81s/it]

Epoch: 15, Loss: 0.8232851028442383


 16%|█▌        | 16/100 [01:02<05:11,  3.71s/it]

Epoch: 16, Loss: 1.639331579208374


 17%|█▋        | 17/100 [01:05<05:08,  3.72s/it]

Epoch: 17, Loss: 1.0431206226348877


 18%|█▊        | 18/100 [01:10<05:22,  3.93s/it]

Epoch: 18, Loss: 0.6099056601524353


 19%|█▉        | 19/100 [01:13<05:07,  3.80s/it]

Epoch: 19, Loss: 0.7573067545890808


 20%|██        | 20/100 [01:17<04:56,  3.70s/it]

Epoch: 20, Loss: 1.1397576332092285


 21%|██        | 21/100 [01:21<05:04,  3.85s/it]

Epoch: 21, Loss: 0.5661287307739258


 22%|██▏       | 22/100 [01:25<05:04,  3.90s/it]

Epoch: 22, Loss: 1.0171226263046265


 23%|██▎       | 23/100 [01:29<04:49,  3.76s/it]

Epoch: 23, Loss: 0.45251381397247314


 24%|██▍       | 24/100 [01:32<04:39,  3.68s/it]

Epoch: 24, Loss: 0.7087335586547852


 25%|██▌       | 25/100 [01:37<04:58,  3.98s/it]

Epoch: 25, Loss: 0.5964272618293762


 26%|██▌       | 26/100 [01:40<04:48,  3.90s/it]

Epoch: 26, Loss: 1.8610165119171143


 27%|██▋       | 27/100 [01:44<04:36,  3.79s/it]

Epoch: 27, Loss: 0.8373644351959229


 28%|██▊       | 28/100 [01:48<04:29,  3.75s/it]

Epoch: 28, Loss: 0.7456758618354797


 29%|██▉       | 29/100 [01:52<04:46,  4.04s/it]

Epoch: 29, Loss: 0.9508333802223206


 30%|███       | 30/100 [01:56<04:31,  3.88s/it]

Epoch: 30, Loss: 0.6101541519165039


 31%|███       | 31/100 [01:59<04:19,  3.76s/it]

Epoch: 31, Loss: 1.345369815826416


 32%|███▏      | 32/100 [02:04<04:24,  3.90s/it]

Epoch: 32, Loss: 1.0017595291137695


 33%|███▎      | 33/100 [02:08<04:24,  3.94s/it]

Epoch: 33, Loss: 0.7214342355728149


 34%|███▍      | 34/100 [02:11<04:10,  3.80s/it]

Epoch: 34, Loss: 0.6556901335716248


 35%|███▌      | 35/100 [02:15<04:15,  3.93s/it]

Epoch: 35, Loss: 0.7617374658584595


 36%|███▌      | 36/100 [02:20<04:35,  4.31s/it]

Epoch: 36, Loss: 0.6858187913894653


 37%|███▋      | 37/100 [02:24<04:19,  4.12s/it]

Epoch: 37, Loss: 0.6721980571746826


 38%|███▊      | 38/100 [02:28<04:04,  3.95s/it]

Epoch: 38, Loss: 0.8292343020439148


 39%|███▉      | 39/100 [02:31<03:54,  3.84s/it]

Epoch: 39, Loss: 1.040019154548645


 40%|████      | 40/100 [02:36<04:05,  4.10s/it]

Epoch: 40, Loss: 0.5001113414764404


 41%|████      | 41/100 [02:40<03:51,  3.93s/it]

Epoch: 41, Loss: 0.4677208960056305


 42%|████▏     | 42/100 [02:43<03:46,  3.91s/it]

Epoch: 42, Loss: 0.6839302778244019


 43%|████▎     | 43/100 [02:49<04:13,  4.45s/it]

Epoch: 43, Loss: 0.45360392332077026


 44%|████▍     | 44/100 [02:53<03:53,  4.18s/it]

Epoch: 44, Loss: 2.7614216804504395


 45%|████▌     | 45/100 [02:56<03:38,  3.97s/it]

Epoch: 45, Loss: 0.6343388557434082


 46%|████▌     | 46/100 [03:00<03:29,  3.87s/it]

Epoch: 46, Loss: 0.5497608780860901


 47%|████▋     | 47/100 [03:04<03:36,  4.08s/it]

Epoch: 47, Loss: 0.5573652982711792


 48%|████▊     | 48/100 [03:08<03:22,  3.89s/it]

Epoch: 48, Loss: 0.45027604699134827


 49%|████▉     | 49/100 [03:11<03:13,  3.79s/it]

Epoch: 49, Loss: 0.9700979590415955


 50%|█████     | 50/100 [03:16<03:15,  3.91s/it]

Epoch: 50, Loss: 0.5275843143463135


 51%|█████     | 51/100 [03:20<03:14,  3.96s/it]

Epoch: 51, Loss: 0.8243167400360107


 52%|█████▏    | 52/100 [03:23<03:03,  3.83s/it]

Epoch: 52, Loss: 0.5209327936172485


 53%|█████▎    | 53/100 [03:27<02:54,  3.72s/it]

Epoch: 53, Loss: 1.1006125211715698


 54%|█████▍    | 54/100 [03:31<03:03,  4.00s/it]

Epoch: 54, Loss: 0.4618203341960907


 55%|█████▌    | 55/100 [03:35<02:54,  3.89s/it]

Epoch: 55, Loss: 0.4196470379829407


 56%|█████▌    | 56/100 [03:38<02:45,  3.77s/it]

Epoch: 56, Loss: 0.6552998423576355


 57%|█████▋    | 57/100 [03:42<02:41,  3.76s/it]

Epoch: 57, Loss: 0.6829873323440552


 58%|█████▊    | 58/100 [03:47<02:48,  4.01s/it]

Epoch: 58, Loss: 0.36889898777008057


 59%|█████▉    | 59/100 [03:50<02:38,  3.86s/it]

Epoch: 59, Loss: 1.632730484008789


 60%|██████    | 60/100 [03:54<02:30,  3.76s/it]

Epoch: 60, Loss: 0.683708906173706


 61%|██████    | 61/100 [03:58<02:31,  3.89s/it]

Epoch: 61, Loss: 1.361624002456665


 62%|██████▏   | 62/100 [04:02<02:29,  3.95s/it]

Epoch: 62, Loss: 0.422578901052475


 63%|██████▎   | 63/100 [04:05<02:21,  3.81s/it]

Epoch: 63, Loss: 0.6483668684959412


 64%|██████▍   | 64/100 [04:09<02:13,  3.72s/it]

Epoch: 64, Loss: 1.6068445444107056


 65%|██████▌   | 65/100 [04:14<02:20,  4.00s/it]

Epoch: 65, Loss: 0.6636335849761963


 66%|██████▌   | 66/100 [04:17<02:12,  3.89s/it]

Epoch: 66, Loss: 0.6579346656799316


 67%|██████▋   | 67/100 [04:21<02:04,  3.77s/it]

Epoch: 67, Loss: 0.5483803749084473


 68%|██████▊   | 68/100 [04:24<01:59,  3.74s/it]

Epoch: 68, Loss: 0.76332688331604


 69%|██████▉   | 69/100 [04:30<02:11,  4.23s/it]

Epoch: 69, Loss: 2.23990797996521


 70%|███████   | 70/100 [04:34<02:06,  4.22s/it]

Epoch: 70, Loss: 0.6463057398796082


 71%|███████   | 71/100 [04:38<01:56,  4.00s/it]

Epoch: 71, Loss: 0.7627145648002625


 72%|███████▏  | 72/100 [04:41<01:47,  3.85s/it]

Epoch: 72, Loss: 0.7993032932281494


 73%|███████▎  | 73/100 [04:46<01:51,  4.11s/it]

Epoch: 73, Loss: 0.641755998134613


 74%|███████▍  | 74/100 [04:49<01:42,  3.93s/it]

Epoch: 74, Loss: 0.6518844366073608


 75%|███████▌  | 75/100 [04:53<01:34,  3.79s/it]

Epoch: 75, Loss: 0.4565773606300354


 76%|███████▌  | 76/100 [04:56<01:29,  3.74s/it]

Epoch: 76, Loss: 0.3337230682373047


 77%|███████▋  | 77/100 [05:01<01:31,  3.99s/it]

Epoch: 77, Loss: 1.8073581457138062


 78%|███████▊  | 78/100 [05:04<01:24,  3.84s/it]

Epoch: 78, Loss: 0.6676205396652222


 79%|███████▉  | 79/100 [05:08<01:18,  3.74s/it]

Epoch: 79, Loss: 0.5103175640106201


 80%|████████  | 80/100 [05:12<01:17,  3.86s/it]

Epoch: 80, Loss: 0.3151206374168396


 81%|████████  | 81/100 [05:16<01:13,  3.89s/it]

Epoch: 81, Loss: 0.42546531558036804


 82%|████████▏ | 82/100 [05:19<01:07,  3.77s/it]

Epoch: 82, Loss: 0.49405449628829956


 83%|████████▎ | 83/100 [05:23<01:02,  3.68s/it]

Epoch: 83, Loss: 0.5833823680877686


 84%|████████▍ | 84/100 [05:28<01:03,  3.96s/it]

Epoch: 84, Loss: 0.535919189453125


 85%|████████▌ | 85/100 [05:31<00:57,  3.84s/it]

Epoch: 85, Loss: 0.6563652753829956


 86%|████████▌ | 86/100 [05:35<00:52,  3.76s/it]

Epoch: 86, Loss: 1.0327975749969482


 87%|████████▋ | 87/100 [05:38<00:48,  3.73s/it]

Epoch: 87, Loss: 0.379071444272995


 88%|████████▊ | 88/100 [05:43<00:47,  3.97s/it]

Epoch: 88, Loss: 0.38967403769493103


 89%|████████▉ | 89/100 [05:46<00:42,  3.84s/it]

Epoch: 89, Loss: 0.8332091569900513


 90%|█████████ | 90/100 [05:50<00:37,  3.75s/it]

Epoch: 90, Loss: 0.43500861525535583


 91%|█████████ | 91/100 [05:54<00:34,  3.88s/it]

Epoch: 91, Loss: 0.5338840484619141


 92%|█████████▏| 92/100 [05:58<00:31,  3.92s/it]

Epoch: 92, Loss: 0.6744825839996338


 93%|█████████▎| 93/100 [06:02<00:26,  3.79s/it]

Epoch: 93, Loss: 0.3683546185493469


 94%|█████████▍| 94/100 [06:05<00:22,  3.72s/it]

Epoch: 94, Loss: 0.38512611389160156


 95%|█████████▌| 95/100 [06:10<00:20,  4.03s/it]

Epoch: 95, Loss: 0.38876864314079285


 96%|█████████▌| 96/100 [06:14<00:15,  3.95s/it]

Epoch: 96, Loss: 1.7054141759872437


 97%|█████████▋| 97/100 [06:17<00:11,  3.81s/it]

Epoch: 97, Loss: 0.5347980260848999


 98%|█████████▊| 98/100 [06:21<00:07,  3.78s/it]

Epoch: 98, Loss: 0.664185643196106


 99%|█████████▉| 99/100 [06:25<00:03,  3.96s/it]

Epoch: 99, Loss: 0.4624948799610138


100%|██████████| 100/100 [06:29<00:00,  3.89s/it]

Epoch: 100, Loss: 0.7446961402893066





In [None]:
# Build {row index: (user id, detached embedding row, label)} dictionaries for
# the training embeddings (pos_z) and the test embeddings (test_pos_z).
train_embeddings = pos_z.detach()
test_embeddings = test_pos_z.detach()

train_user_representations = {
  row: (dataset.user_ids[row], train_embeddings[row], dataset.train_labels[row])
  for row in range(train_embeddings.shape[0])
}
test_user_representations = {
  row: (dataset.user_ids[row], test_embeddings[row], dataset.test_labels[row])
  for row in range(test_embeddings.shape[0])
}

In [None]:
# Wrap the training-user representations in a dataset and split it 80/20.
# Note that the "test" subset below is a held-out part of
# train_user_representations, not test_user_representations.
user_dataset = UserDataset(train_user_representations)

n = len(user_dataset)

# Random ordering of 0..n-1 drawn from NumPy's global RNG.
indices = np.random.permutation(np.arange(n))

split_at = int(0.8*n)  # first 80% of the shuffled indices -> train, rest -> held-out
train_indices = indices[:split_at]
test_indices = indices[split_at:]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
# The held-out loader is not shuffled: reshuffling evaluation data only makes
# the batch order (and torch's global RNG stream) vary between evaluation passes.
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=False)

In [None]:
# Number of classes and the array of class ids 0 .. num_classes-1.
num_classes = 186
classes = np.arange(num_classes)

In [None]:
# Run train2task() (defined elsewhere in the notebook) and write the returned
# model's state_dict to disk.
head_model = train2task()
head_weights = head_model.state_dict()
torch.save(head_weights, 'random_walk_only_infomax_model.pth')

Epoch: 0, Loss: 5.0268049240112305
Hits@1: 0.038923460990190506
Epoch: 1, Loss: 4.866142272949219
Hits@1: 0.05279731750488281
Epoch: 2, Loss: 4.744835376739502
Hits@1: 0.08801557123661041
Epoch: 3, Loss: 4.6520304679870605
Hits@1: 0.10407647490501404
Epoch: 4, Loss: 4.579610347747803
Hits@1: 0.10286646336317062
Epoch: 5, Loss: 4.521663188934326
Hits@1: 0.10514369606971741
Epoch: 6, Loss: 4.474079608917236
Hits@1: 0.10661676526069641
Epoch: 7, Loss: 4.434088230133057
Hits@1: 0.10835286974906921
Epoch: 8, Loss: 4.399752140045166
Hits@1: 0.11255411058664322
Epoch: 9, Loss: 4.369768142700195
Hits@1: 0.11624428629875183
Epoch: 10, Loss: 4.343201160430908
Hits@1: 0.12146764993667603
Epoch: 11, Loss: 4.319370269775391
Hits@1: 0.12447390705347061
Epoch: 12, Loss: 4.297772407531738
Hits@1: 0.126473069190979
Epoch: 13, Loss: 4.278028964996338
Hits@1: 0.13014820218086243
Epoch: 14, Loss: 4.2598443031311035
Hits@1: 0.13210226595401764
Epoch: 15, Loss: 4.242992401123047
Hits@1: 0.13332732021808624


### Node mix, Infomax + Link pred

In [None]:
# Setup for this run: DeepGraphInfomax with node_mix as the corruption function,
# a LinkPredHead instance, an AdamW optimizer and a BCE-with-logits loss.
# The commented-out line would pick a GPU when available; the run is pinned to CPU.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')


def _sigmoid_mean_summary(z, *args, **kwargs):
  """Return sigmoid(mean of z over dim 0); extra arguments are ignored."""
  return torch.sigmoid(z.mean(dim=0))


_encoder = Encoder(interm_channels=74, hidden_channels=128, num_features=128)
_infomax = DeepGraphInfomax(hidden_channels=128,
                            encoder=_encoder,
                            summary=_sigmoid_mean_summary,
                            corruption=node_mix)
model = _infomax.to(device)
linkmodel = LinkPredHead()
# NOTE(review): the optimizer only receives model.parameters(); if LinkPredHead
# holds trainable weights they are not updated by this optimizer -- confirm.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [None]:
# Train with train_infomax_linkpred for num_epoch epochs, recording one loss value per epoch.
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    # The final epoch's result is unpacked into four values; pos_z is reused by
    # the next cell to build the per-user training representations.
    pos_z, neg_z, summary, loss = train_infomax_linkpred(epoch)
  else:
    loss = train_infomax_linkpred(epoch)
  loss_epoch.append(loss)

# Forward pass on the test features. This is inference only (the next cell
# detaches test_pos_z before using it), so gradient tracking is switched off to
# avoid building an autograd graph; the computed values are unchanged.
# NOTE(review): the model's train/eval mode is left as-is -- confirm whether
# model.eval() is wanted here if the encoder contains dropout/batch-norm.
with torch.no_grad():
  test_pos_z, test_neg_z, test_summary = model(dataset.test_features, dataset.edge_index)

  1%|          | 1/100 [00:10<16:33, 10.03s/it]

Epoch: 1, Loss: 4.500741958618164


  2%|▏         | 2/100 [00:17<13:39,  8.36s/it]

Epoch: 2, Loss: 3.718627452850342


  3%|▎         | 3/100 [00:25<13:23,  8.29s/it]

Epoch: 3, Loss: 3.4254417419433594


  4%|▍         | 4/100 [00:32<12:28,  7.80s/it]

Epoch: 4, Loss: 3.2371156215667725


  5%|▌         | 5/100 [00:40<12:36,  7.97s/it]

Epoch: 5, Loss: 3.07939076423645


  6%|▌         | 6/100 [00:48<12:10,  7.77s/it]

Epoch: 6, Loss: 2.6666293144226074


  7%|▋         | 7/100 [00:55<12:05,  7.80s/it]

Epoch: 7, Loss: 2.402069091796875


  8%|▊         | 8/100 [01:03<11:48,  7.70s/it]

Epoch: 8, Loss: 2.415858030319214


  9%|▉         | 9/100 [01:11<11:37,  7.67s/it]

Epoch: 9, Loss: 2.1719188690185547


 10%|█         | 10/100 [01:19<11:44,  7.82s/it]

Epoch: 10, Loss: 2.281268358230591


 11%|█         | 11/100 [01:26<11:21,  7.66s/it]

Epoch: 11, Loss: 3.3060061931610107


 12%|█▏        | 12/100 [01:34<11:27,  7.82s/it]

Epoch: 12, Loss: 1.639325499534607


 13%|█▎        | 13/100 [01:41<10:52,  7.50s/it]

Epoch: 13, Loss: 2.8956797122955322


 14%|█▍        | 14/100 [01:49<10:58,  7.66s/it]

Epoch: 14, Loss: 1.3943301439285278


 15%|█▌        | 15/100 [01:56<10:26,  7.37s/it]

Epoch: 15, Loss: 1.8662102222442627


 16%|█▌        | 16/100 [02:04<10:50,  7.75s/it]

Epoch: 16, Loss: 1.7523062229156494


 17%|█▋        | 17/100 [02:13<11:11,  8.09s/it]

Epoch: 17, Loss: 1.4031293392181396


 18%|█▊        | 18/100 [02:21<11:08,  8.15s/it]

Epoch: 18, Loss: 1.2916085720062256


 19%|█▉        | 19/100 [02:29<10:42,  7.94s/it]

Epoch: 19, Loss: 2.2038042545318604


 20%|██        | 20/100 [02:37<10:37,  7.96s/it]

Epoch: 20, Loss: 2.70727276802063


 21%|██        | 21/100 [02:44<10:16,  7.80s/it]

Epoch: 21, Loss: 1.91615629196167


 22%|██▏       | 22/100 [02:52<09:55,  7.63s/it]

Epoch: 22, Loss: 1.3311233520507812


 23%|██▎       | 23/100 [02:59<09:47,  7.63s/it]

Epoch: 23, Loss: 1.4793239831924438


 24%|██▍       | 24/100 [03:06<09:24,  7.43s/it]

Epoch: 24, Loss: 1.6419076919555664


 25%|██▌       | 25/100 [03:14<09:27,  7.56s/it]

Epoch: 25, Loss: 1.2551809549331665


 26%|██▌       | 26/100 [03:21<09:01,  7.32s/it]

Epoch: 26, Loss: 1.319373369216919


 27%|██▋       | 27/100 [03:29<09:09,  7.53s/it]

Epoch: 27, Loss: 1.2911067008972168


 28%|██▊       | 28/100 [03:36<08:45,  7.29s/it]

Epoch: 28, Loss: 1.3610830307006836


 29%|██▉       | 29/100 [03:44<08:54,  7.53s/it]

Epoch: 29, Loss: 1.0908353328704834


 30%|███       | 30/100 [03:50<08:31,  7.30s/it]

Epoch: 30, Loss: 1.2094981670379639


 31%|███       | 31/100 [03:59<08:40,  7.55s/it]

Epoch: 31, Loss: 1.606271743774414


 32%|███▏      | 32/100 [04:05<08:17,  7.32s/it]

Epoch: 32, Loss: 1.4398698806762695


 33%|███▎      | 33/100 [04:13<08:25,  7.55s/it]

Epoch: 33, Loss: 1.496410846710205


 34%|███▍      | 34/100 [04:20<08:03,  7.32s/it]

Epoch: 34, Loss: 1.11375892162323


 35%|███▌      | 35/100 [04:28<08:10,  7.54s/it]

Epoch: 35, Loss: 1.4157794713974


 36%|███▌      | 36/100 [04:35<07:46,  7.29s/it]

Epoch: 36, Loss: 1.0820322036743164


 37%|███▋      | 37/100 [04:43<07:53,  7.52s/it]

Epoch: 37, Loss: 1.1804450750350952


 38%|███▊      | 38/100 [04:52<08:05,  7.82s/it]

Epoch: 38, Loss: 1.381648302078247


 39%|███▉      | 39/100 [04:59<07:48,  7.67s/it]

Epoch: 39, Loss: 1.1143170595169067


 40%|████      | 40/100 [05:07<07:40,  7.67s/it]

Epoch: 40, Loss: 1.8352549076080322


 41%|████      | 41/100 [05:14<07:21,  7.49s/it]

Epoch: 41, Loss: 1.1110121011734009


 42%|████▏     | 42/100 [05:22<07:23,  7.64s/it]

Epoch: 42, Loss: 1.2400658130645752


 43%|████▎     | 43/100 [05:28<07:00,  7.37s/it]

Epoch: 43, Loss: 1.080901861190796


 44%|████▍     | 44/100 [05:36<07:03,  7.57s/it]

Epoch: 44, Loss: 1.1956465244293213


 45%|████▌     | 45/100 [05:43<06:42,  7.32s/it]

Epoch: 45, Loss: 1.021637201309204


 46%|████▌     | 46/100 [05:51<06:47,  7.55s/it]

Epoch: 46, Loss: 0.9964213967323303


 47%|████▋     | 47/100 [05:58<06:27,  7.31s/it]

Epoch: 47, Loss: 1.1061662435531616


 48%|████▊     | 48/100 [06:06<06:31,  7.53s/it]

Epoch: 48, Loss: 1.0074001550674438


 49%|████▉     | 49/100 [06:13<06:12,  7.30s/it]

Epoch: 49, Loss: 0.994920015335083


 50%|█████     | 50/100 [06:21<06:16,  7.54s/it]

Epoch: 50, Loss: 0.9733976125717163


 51%|█████     | 51/100 [06:28<05:58,  7.31s/it]

Epoch: 51, Loss: 0.9484869241714478


 52%|█████▏    | 52/100 [06:36<06:01,  7.52s/it]

Epoch: 52, Loss: 1.7245705127716064


 53%|█████▎    | 53/100 [06:42<05:42,  7.28s/it]

Epoch: 53, Loss: 0.8939281702041626


 54%|█████▍    | 54/100 [06:50<05:43,  7.47s/it]

Epoch: 54, Loss: 0.9867926239967346


 55%|█████▌    | 55/100 [06:57<05:27,  7.27s/it]

Epoch: 55, Loss: 0.9388020038604736


 56%|█████▌    | 56/100 [07:05<05:27,  7.44s/it]

Epoch: 56, Loss: 1.1811331510543823


 57%|█████▋    | 57/100 [07:12<05:14,  7.32s/it]

Epoch: 57, Loss: 0.9871504306793213


 58%|█████▊    | 58/100 [07:21<05:27,  7.80s/it]

Epoch: 58, Loss: 0.9777753353118896


 59%|█████▉    | 59/100 [07:29<05:20,  7.83s/it]

Epoch: 59, Loss: 0.8910067081451416


 60%|██████    | 60/100 [07:36<05:00,  7.52s/it]

Epoch: 60, Loss: 0.8786947727203369


 61%|██████    | 61/100 [07:44<04:59,  7.67s/it]

Epoch: 61, Loss: 0.9064375162124634


 62%|██████▏   | 62/100 [07:50<04:40,  7.38s/it]

Epoch: 62, Loss: 1.149276614189148


 63%|██████▎   | 63/100 [07:58<04:39,  7.56s/it]

Epoch: 63, Loss: 1.3048304319381714


 64%|██████▍   | 64/100 [08:05<04:22,  7.30s/it]

Epoch: 64, Loss: 0.8333754539489746


 65%|██████▌   | 65/100 [08:13<04:23,  7.53s/it]

Epoch: 65, Loss: 0.8392172455787659


 66%|██████▌   | 66/100 [08:20<04:08,  7.30s/it]

Epoch: 66, Loss: 0.795709490776062


 67%|██████▋   | 67/100 [08:28<04:08,  7.53s/it]

Epoch: 67, Loss: 1.0410243272781372


 68%|██████▊   | 68/100 [08:35<03:52,  7.27s/it]

Epoch: 68, Loss: 0.8275045156478882


 69%|██████▉   | 69/100 [08:42<03:51,  7.47s/it]

Epoch: 69, Loss: 1.0379537343978882


 70%|███████   | 70/100 [08:49<03:37,  7.24s/it]

Epoch: 70, Loss: 0.9351856708526611


 71%|███████   | 71/100 [08:57<03:36,  7.45s/it]

Epoch: 71, Loss: 0.902019739151001


 72%|███████▏  | 72/100 [09:04<03:22,  7.23s/it]

Epoch: 72, Loss: 0.8208370208740234


 73%|███████▎  | 73/100 [09:12<03:20,  7.44s/it]

Epoch: 73, Loss: 0.8728042840957642


 74%|███████▍  | 74/100 [09:19<03:08,  7.25s/it]

Epoch: 74, Loss: 1.1103196144104004


 75%|███████▌  | 75/100 [09:26<03:05,  7.41s/it]

Epoch: 75, Loss: 0.804144561290741


 76%|███████▌  | 76/100 [09:33<02:54,  7.29s/it]

Epoch: 76, Loss: 0.7732282876968384


 77%|███████▋  | 77/100 [09:41<02:49,  7.38s/it]

Epoch: 77, Loss: 0.8808126449584961


 78%|███████▊  | 78/100 [09:48<02:41,  7.34s/it]

Epoch: 78, Loss: 0.7988002300262451


 79%|███████▉  | 79/100 [09:57<02:43,  7.77s/it]

Epoch: 79, Loss: 0.7806093692779541


 80%|████████  | 80/100 [10:04<02:32,  7.64s/it]

Epoch: 80, Loss: 0.7347771525382996


 81%|████████  | 81/100 [10:12<02:23,  7.54s/it]

Epoch: 81, Loss: 0.9318368434906006


 82%|████████▏ | 82/100 [10:19<02:16,  7.57s/it]

Epoch: 82, Loss: 0.8014329671859741


 83%|████████▎ | 83/100 [10:26<02:05,  7.40s/it]

Epoch: 83, Loss: 0.7344963550567627


 84%|████████▍ | 84/100 [10:34<02:01,  7.57s/it]

Epoch: 84, Loss: 0.9185954928398132


 85%|████████▌ | 85/100 [10:41<01:49,  7.30s/it]

Epoch: 85, Loss: 0.848392903804779


 86%|████████▌ | 86/100 [10:49<01:45,  7.50s/it]

Epoch: 86, Loss: 0.7166116237640381


 87%|████████▋ | 87/100 [10:56<01:34,  7.26s/it]

Epoch: 87, Loss: 0.8188190460205078


 88%|████████▊ | 88/100 [11:04<01:29,  7.48s/it]

Epoch: 88, Loss: 0.8207554817199707


 89%|████████▉ | 89/100 [11:10<01:19,  7.23s/it]

Epoch: 89, Loss: 0.773224949836731


 90%|█████████ | 90/100 [11:18<01:14,  7.46s/it]

Epoch: 90, Loss: 0.7276590466499329


 91%|█████████ | 91/100 [11:25<01:05,  7.23s/it]

Epoch: 91, Loss: 0.8430317044258118


 92%|█████████▏| 92/100 [11:33<00:59,  7.45s/it]

Epoch: 92, Loss: 0.7088104486465454


 93%|█████████▎| 93/100 [11:40<00:50,  7.22s/it]

Epoch: 93, Loss: 0.7430026531219482


 94%|█████████▍| 94/100 [11:47<00:44,  7.44s/it]

Epoch: 94, Loss: 0.7606393694877625


 95%|█████████▌| 95/100 [11:54<00:36,  7.21s/it]

Epoch: 95, Loss: 0.6833232641220093


 96%|█████████▌| 96/100 [12:02<00:29,  7.44s/it]

Epoch: 96, Loss: 0.8263025283813477


 97%|█████████▋| 97/100 [12:09<00:21,  7.23s/it]

Epoch: 97, Loss: 0.6642577648162842


 98%|█████████▊| 98/100 [12:17<00:14,  7.48s/it]

Epoch: 98, Loss: 0.9257211685180664


 99%|█████████▉| 99/100 [12:26<00:07,  7.87s/it]

Epoch: 99, Loss: 0.6998384594917297


100%|██████████| 100/100 [12:33<00:00,  7.53s/it]

Epoch: 100, Loss: 0.7707847356796265





In [None]:
# Build {row index: (user id, detached embedding row, label)} dictionaries for
# the training embeddings (pos_z) and the test embeddings (test_pos_z).
train_embeddings = pos_z.detach()
test_embeddings = test_pos_z.detach()

train_user_representations = {
  row: (dataset.user_ids[row], train_embeddings[row], dataset.train_labels[row])
  for row in range(train_embeddings.shape[0])
}
test_user_representations = {
  row: (dataset.user_ids[row], test_embeddings[row], dataset.test_labels[row])
  for row in range(test_embeddings.shape[0])
}

In [None]:
# Wrap the training-user representations in a dataset and split it 80/20.
# Note that the "test" subset below is a held-out part of
# train_user_representations, not test_user_representations.
user_dataset = UserDataset(train_user_representations)

n = len(user_dataset)

# Random ordering of 0..n-1 drawn from NumPy's global RNG.
indices = np.random.permutation(np.arange(n))

split_at = int(0.8*n)  # first 80% of the shuffled indices -> train, rest -> held-out
train_indices = indices[:split_at]
test_indices = indices[split_at:]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
# The held-out loader is not shuffled: reshuffling evaluation data only makes
# the batch order (and torch's global RNG stream) vary between evaluation passes.
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=False)

In [None]:
# Number of classes and the array of class ids 0 .. num_classes-1.
num_classes = 186
classes = np.arange(num_classes)

In [None]:
# Run train2task() (defined elsewhere in the notebook) and write the returned
# model's state_dict to disk.
head_model = train2task()
head_weights = head_model.state_dict()
torch.save(head_weights, 'node_mix_infomax_linkpred_model.pth')

Epoch: 0, Loss: 5.155427932739258
Hits@1: 0.06047077849507332
Epoch: 1, Loss: 5.084137916564941
Hits@1: 0.08645232021808624
Epoch: 2, Loss: 5.018618583679199
Hits@1: 0.10460256785154343
Epoch: 3, Loss: 4.958208084106445
Hits@1: 0.12050565332174301
Epoch: 4, Loss: 4.902405261993408
Hits@1: 0.12737493216991425
Epoch: 5, Loss: 4.850795745849609
Hits@1: 0.12934403121471405
Epoch: 6, Loss: 4.803010940551758
Hits@1: 0.13211730122566223
Epoch: 7, Loss: 4.75872278213501
Hits@1: 0.13189934194087982
Epoch: 8, Loss: 4.717621326446533
Hits@1: 0.13388347625732422
Epoch: 9, Loss: 4.679427146911621
Hits@1: 0.13461248576641083
Epoch: 10, Loss: 4.643889427185059
Hits@1: 0.13459746539592743
Epoch: 11, Loss: 4.610775947570801
Hits@1: 0.13484548032283783
Epoch: 12, Loss: 4.579861640930176
Hits@1: 0.13509349524974823
Epoch: 13, Loss: 4.550949573516846
Hits@1: 0.13658158481121063
Epoch: 14, Loss: 4.523855209350586
Hits@1: 0.13707761466503143
Epoch: 15, Loss: 4.4984130859375
Hits@1: 0.13709264993667603
Epoch

### Node dropout, Infomax + Link pred

In [None]:
# Setup for this run: DeepGraphInfomax with node_dropout as the corruption
# function, a LinkPredHead instance, an AdamW optimizer and a BCE-with-logits loss.
# The commented-out line would pick a GPU when available; the run is pinned to CPU.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')


def _sigmoid_mean_summary(z, *args, **kwargs):
  """Return sigmoid(mean of z over dim 0); extra arguments are ignored."""
  return torch.sigmoid(z.mean(dim=0))


_encoder = Encoder(interm_channels=74, hidden_channels=128, num_features=128)
_infomax = DeepGraphInfomax(hidden_channels=128,
                            encoder=_encoder,
                            summary=_sigmoid_mean_summary,
                            corruption=node_dropout)
model = _infomax.to(device)
linkmodel = LinkPredHead()
# NOTE(review): the optimizer only receives model.parameters(); if LinkPredHead
# holds trainable weights they are not updated by this optimizer -- confirm.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [None]:
# Train with train_infomax_linkpred for num_epoch epochs, recording one loss value per epoch.
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    # The final epoch's result is unpacked into four values; pos_z is reused by
    # the next cell to build the per-user training representations.
    pos_z, neg_z, summary, loss = train_infomax_linkpred(epoch)
  else:
    loss = train_infomax_linkpred(epoch)
  loss_epoch.append(loss)

# Forward pass on the test features. This is inference only (the next cell
# detaches test_pos_z before using it), so gradient tracking is switched off to
# avoid building an autograd graph; the computed values are unchanged.
# NOTE(review): the model's train/eval mode is left as-is -- confirm whether
# model.eval() is wanted here if the encoder contains dropout/batch-norm.
with torch.no_grad():
  test_pos_z, test_neg_z, test_summary = model(dataset.test_features, dataset.edge_index)

  1%|          | 1/100 [00:10<16:50, 10.20s/it]

Epoch: 1, Loss: 3.8799002170562744


  2%|▏         | 2/100 [00:17<14:05,  8.63s/it]

Epoch: 2, Loss: 3.5561585426330566


  3%|▎         | 3/100 [00:27<14:44,  9.12s/it]

Epoch: 3, Loss: 3.5263118743896484


  4%|▍         | 4/100 [00:41<17:57, 11.23s/it]

Epoch: 4, Loss: 3.481123208999634


  5%|▌         | 5/100 [00:51<16:47, 10.61s/it]

Epoch: 5, Loss: 3.0215132236480713


  6%|▌         | 6/100 [00:59<15:09,  9.68s/it]

Epoch: 6, Loss: 3.2386960983276367


  7%|▋         | 7/100 [01:07<14:28,  9.33s/it]

Epoch: 7, Loss: 3.191718101501465


  8%|▊         | 8/100 [01:15<13:35,  8.86s/it]

Epoch: 8, Loss: 3.200770378112793


  9%|▉         | 9/100 [01:24<13:14,  8.73s/it]

Epoch: 9, Loss: 3.2664012908935547


 10%|█         | 10/100 [01:32<13:06,  8.74s/it]

Epoch: 10, Loss: 3.320202112197876


 11%|█         | 11/100 [01:40<12:25,  8.38s/it]

Epoch: 11, Loss: 3.019913673400879


 12%|█▏        | 12/100 [01:49<12:24,  8.46s/it]

Epoch: 12, Loss: 3.1790578365325928


 13%|█▎        | 13/100 [01:56<11:45,  8.11s/it]

Epoch: 13, Loss: 3.0948691368103027


 14%|█▍        | 14/100 [02:05<12:04,  8.43s/it]

Epoch: 14, Loss: 3.026639699935913


 15%|█▌        | 15/100 [02:13<11:52,  8.38s/it]

Epoch: 15, Loss: 3.095552921295166


 16%|█▌        | 16/100 [02:21<11:24,  8.15s/it]

Epoch: 16, Loss: 3.1355197429656982


 17%|█▋        | 17/100 [02:30<11:35,  8.38s/it]

Epoch: 17, Loss: 2.8428714275360107


 18%|█▊        | 18/100 [02:37<11:03,  8.09s/it]

Epoch: 18, Loss: 3.0672106742858887


 19%|█▉        | 19/100 [02:46<11:13,  8.32s/it]

Epoch: 19, Loss: 3.1016013622283936


 20%|██        | 20/100 [02:54<10:43,  8.04s/it]

Epoch: 20, Loss: 2.810856819152832


 21%|██        | 21/100 [03:02<10:37,  8.07s/it]

Epoch: 21, Loss: 2.846606492996216


 22%|██▏       | 22/100 [03:10<10:38,  8.19s/it]

Epoch: 22, Loss: 2.6299173831939697


 23%|██▎       | 23/100 [03:18<10:14,  7.98s/it]

Epoch: 23, Loss: 3.155036687850952


 24%|██▍       | 24/100 [03:26<10:21,  8.17s/it]

Epoch: 24, Loss: 3.1829307079315186


 25%|██▌       | 25/100 [03:35<10:24,  8.32s/it]

Epoch: 25, Loss: 2.9033031463623047


 26%|██▌       | 26/100 [03:44<10:22,  8.41s/it]

Epoch: 26, Loss: 2.6108057498931885


 27%|██▋       | 27/100 [03:52<10:20,  8.50s/it]

Epoch: 27, Loss: 3.133852005004883


 28%|██▊       | 28/100 [04:00<09:48,  8.17s/it]

Epoch: 28, Loss: 3.0625596046447754


 29%|██▉       | 29/100 [04:09<09:53,  8.36s/it]

Epoch: 29, Loss: 3.2569961547851562


 30%|███       | 30/100 [04:16<09:25,  8.07s/it]

Epoch: 30, Loss: 3.0811307430267334


 31%|███       | 31/100 [04:25<09:32,  8.30s/it]

Epoch: 31, Loss: 2.526677370071411


 32%|███▏      | 32/100 [04:33<09:30,  8.39s/it]

Epoch: 32, Loss: 3.2386057376861572


 33%|███▎      | 33/100 [04:41<09:06,  8.16s/it]

Epoch: 33, Loss: 3.127969980239868


 34%|███▍      | 34/100 [04:50<09:12,  8.36s/it]

Epoch: 34, Loss: 3.025425434112549


 35%|███▌      | 35/100 [04:57<08:40,  8.01s/it]

Epoch: 35, Loss: 3.091251850128174


 36%|███▌      | 36/100 [05:05<08:35,  8.06s/it]

Epoch: 36, Loss: 2.834530830383301


 37%|███▋      | 37/100 [05:12<08:11,  7.80s/it]

Epoch: 37, Loss: 2.9161033630371094


 38%|███▊      | 38/100 [05:21<08:10,  7.92s/it]

Epoch: 38, Loss: 2.5751819610595703


 39%|███▉      | 39/100 [05:29<08:15,  8.12s/it]

Epoch: 39, Loss: 3.0209598541259766


 40%|████      | 40/100 [05:37<07:54,  7.92s/it]

Epoch: 40, Loss: 2.6153688430786133


 41%|████      | 41/100 [05:45<07:59,  8.13s/it]

Epoch: 41, Loss: 3.047680377960205


 42%|████▏     | 42/100 [05:53<07:38,  7.91s/it]

Epoch: 42, Loss: 2.840416669845581


 43%|████▎     | 43/100 [06:01<07:46,  8.18s/it]

Epoch: 43, Loss: 2.499664783477783


 44%|████▍     | 44/100 [06:09<07:33,  8.09s/it]

Epoch: 44, Loss: 3.1664838790893555


 45%|████▌     | 45/100 [06:17<07:25,  8.11s/it]

Epoch: 45, Loss: 3.2506115436553955


 46%|████▌     | 46/100 [06:27<07:35,  8.44s/it]

Epoch: 46, Loss: 3.2048707008361816


 47%|████▋     | 47/100 [06:34<07:17,  8.26s/it]

Epoch: 47, Loss: 3.108206272125244


 48%|████▊     | 48/100 [06:43<07:11,  8.29s/it]

Epoch: 48, Loss: 2.3807852268218994


 49%|████▉     | 49/100 [06:50<06:47,  7.99s/it]

Epoch: 49, Loss: 3.074702501296997


 50%|█████     | 50/100 [06:59<06:47,  8.15s/it]

Epoch: 50, Loss: 2.4425833225250244


 51%|█████     | 51/100 [07:06<06:23,  7.83s/it]

Epoch: 51, Loss: 2.8964760303497314


 52%|█████▏    | 52/100 [07:14<06:27,  8.08s/it]

Epoch: 52, Loss: 2.2449018955230713


 53%|█████▎    | 53/100 [07:22<06:13,  7.95s/it]

Epoch: 53, Loss: 2.881822109222412


 54%|█████▍    | 54/100 [07:31<06:13,  8.13s/it]

Epoch: 54, Loss: 3.2685694694519043


 55%|█████▌    | 55/100 [07:39<06:13,  8.31s/it]

Epoch: 55, Loss: 3.2967002391815186


 56%|█████▌    | 56/100 [07:47<05:53,  8.03s/it]

Epoch: 56, Loss: 2.2019448280334473


 57%|█████▋    | 57/100 [07:56<05:59,  8.37s/it]

Epoch: 57, Loss: 2.061732530593872


 58%|█████▊    | 58/100 [08:03<05:38,  8.06s/it]

Epoch: 58, Loss: 2.0578713417053223


 59%|█████▉    | 59/100 [08:12<05:38,  8.26s/it]

Epoch: 59, Loss: 2.1723439693450928


 60%|██████    | 60/100 [08:20<05:24,  8.10s/it]

Epoch: 60, Loss: 1.81612229347229


 61%|██████    | 61/100 [08:27<05:11,  7.99s/it]

Epoch: 61, Loss: 3.583026885986328


 62%|██████▏   | 62/100 [08:36<05:07,  8.10s/it]

Epoch: 62, Loss: 3.879030704498291


 63%|██████▎   | 63/100 [08:43<04:53,  7.94s/it]

Epoch: 63, Loss: 3.8042047023773193


 64%|██████▍   | 64/100 [08:52<04:54,  8.19s/it]

Epoch: 64, Loss: 2.893742084503174


 65%|██████▌   | 65/100 [09:00<04:38,  7.97s/it]

Epoch: 65, Loss: 3.2911911010742188


 66%|██████▌   | 66/100 [09:08<04:39,  8.21s/it]

Epoch: 66, Loss: 3.205655336380005


 67%|██████▋   | 67/100 [09:17<04:33,  8.30s/it]

Epoch: 67, Loss: 3.2659752368927


 68%|██████▊   | 68/100 [09:26<04:30,  8.45s/it]

Epoch: 68, Loss: 3.3171753883361816


 69%|██████▉   | 69/100 [09:34<04:19,  8.38s/it]

Epoch: 69, Loss: 3.0113513469696045


 70%|███████   | 70/100 [09:42<04:05,  8.17s/it]

Epoch: 70, Loss: 3.2539029121398926


 71%|███████   | 71/100 [09:51<04:06,  8.49s/it]

Epoch: 71, Loss: 3.077662944793701


 72%|███████▏  | 72/100 [09:59<03:53,  8.35s/it]

Epoch: 72, Loss: 2.435187816619873


 73%|███████▎  | 73/100 [10:07<03:46,  8.38s/it]

Epoch: 73, Loss: 3.0788934230804443


 74%|███████▍  | 74/100 [10:16<03:40,  8.48s/it]

Epoch: 74, Loss: 2.8194491863250732


 75%|███████▌  | 75/100 [10:23<03:22,  8.12s/it]

Epoch: 75, Loss: 2.9657645225524902


 76%|███████▌  | 76/100 [10:32<03:19,  8.30s/it]

Epoch: 76, Loss: 2.997180938720703


 77%|███████▋  | 77/100 [10:45<03:45,  9.80s/it]

Epoch: 77, Loss: 2.9431307315826416


 78%|███████▊  | 78/100 [10:53<03:23,  9.25s/it]

Epoch: 78, Loss: 2.7808141708374023


 79%|███████▉  | 79/100 [11:02<03:11,  9.14s/it]

Epoch: 79, Loss: 3.0553770065307617


 80%|████████  | 80/100 [11:10<02:52,  8.64s/it]

Epoch: 80, Loss: 2.965224266052246


 81%|████████  | 81/100 [11:19<02:48,  8.88s/it]

Epoch: 81, Loss: 2.807588577270508


 82%|████████▏ | 82/100 [11:28<02:39,  8.85s/it]

Epoch: 82, Loss: 3.222757339477539


 83%|████████▎ | 83/100 [11:35<02:20,  8.28s/it]

Epoch: 83, Loss: 2.694793462753296


 84%|████████▍ | 84/100 [11:43<02:13,  8.35s/it]

Epoch: 84, Loss: 2.968139410018921


 85%|████████▌ | 85/100 [11:51<02:01,  8.11s/it]

Epoch: 85, Loss: 2.848836660385132


 86%|████████▌ | 86/100 [11:59<01:52,  8.02s/it]

Epoch: 86, Loss: 2.946662425994873


 87%|████████▋ | 87/100 [12:08<01:48,  8.33s/it]

Epoch: 87, Loss: 2.98275089263916


 88%|████████▊ | 88/100 [12:17<01:43,  8.58s/it]

Epoch: 88, Loss: 2.1955714225769043


 89%|████████▉ | 89/100 [12:26<01:35,  8.73s/it]

Epoch: 89, Loss: 2.313645839691162


 90%|█████████ | 90/100 [12:34<01:26,  8.65s/it]

Epoch: 90, Loss: 3.2426090240478516


 91%|█████████ | 91/100 [12:42<01:14,  8.30s/it]

Epoch: 91, Loss: 2.0695273876190186


 92%|█████████▏| 92/100 [12:51<01:07,  8.45s/it]

Epoch: 92, Loss: 3.1278977394104004


 93%|█████████▎| 93/100 [12:58<00:56,  8.11s/it]

Epoch: 93, Loss: 2.605182647705078


 94%|█████████▍| 94/100 [13:07<00:49,  8.28s/it]

Epoch: 94, Loss: 1.8050992488861084


 95%|█████████▌| 95/100 [13:15<00:41,  8.21s/it]

Epoch: 95, Loss: 3.070251226425171


 96%|█████████▌| 96/100 [13:22<00:32,  8.03s/it]

Epoch: 96, Loss: 3.2370011806488037


 97%|█████████▋| 97/100 [13:31<00:24,  8.17s/it]

Epoch: 97, Loss: 1.834581732749939


 98%|█████████▊| 98/100 [13:38<00:15,  7.95s/it]

Epoch: 98, Loss: 1.912804365158081


 99%|█████████▉| 99/100 [13:47<00:08,  8.13s/it]

Epoch: 99, Loss: 2.6845946311950684


100%|██████████| 100/100 [13:54<00:00,  8.35s/it]

Epoch: 100, Loss: 1.856714129447937





In [None]:
# Build {row index: (user id, detached embedding row, label)} dictionaries for
# the training embeddings (pos_z) and the test embeddings (test_pos_z).
train_embeddings = pos_z.detach()
test_embeddings = test_pos_z.detach()

train_user_representations = {
  row: (dataset.user_ids[row], train_embeddings[row], dataset.train_labels[row])
  for row in range(train_embeddings.shape[0])
}
test_user_representations = {
  row: (dataset.user_ids[row], test_embeddings[row], dataset.test_labels[row])
  for row in range(test_embeddings.shape[0])
}

In [None]:
# Wrap the training-user representations in a dataset and split it 80/20.
# Note that the "test" subset below is a held-out part of
# train_user_representations, not test_user_representations.
user_dataset = UserDataset(train_user_representations)

n = len(user_dataset)

# Random ordering of 0..n-1 drawn from NumPy's global RNG.
indices = np.random.permutation(np.arange(n))

split_at = int(0.8*n)  # first 80% of the shuffled indices -> train, rest -> held-out
train_indices = indices[:split_at]
test_indices = indices[split_at:]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
# The held-out loader is not shuffled: reshuffling evaluation data only makes
# the batch order (and torch's global RNG stream) vary between evaluation passes.
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=False)

In [None]:
# Size of the label space for the downstream head, and the array of class ids 0..num_classes-1.
num_classes = 186
classes = np.arange(num_classes)

In [None]:
# train2task is defined elsewhere in the notebook; it takes no arguments here, so it presumably
# reads the dataloaders / num_classes prepared in the cells above from globals -- TODO confirm.
head_model = train2task()
# Persist only the weights (state_dict) of the returned head, not the whole module object.
torch.save(head_model.state_dict(), 'node_dropout_infomax_linkpred_model.pth')

Epoch: 0, Loss: 5.17767858505249
Hits@1: 0.057772666215896606
Epoch: 1, Loss: 5.126504898071289
Hits@1: 0.09021764993667603
Epoch: 2, Loss: 5.078027248382568
Hits@1: 0.10703763365745544
Epoch: 3, Loss: 5.031926155090332
Hits@1: 0.11799542605876923
Epoch: 4, Loss: 4.9880218505859375
Hits@1: 0.13461248576641083
Epoch: 5, Loss: 4.946174621582031
Hits@1: 0.144578218460083
Epoch: 6, Loss: 4.906243801116943
Hits@1: 0.15295062959194183
Epoch: 7, Loss: 4.868104934692383
Hits@1: 0.15837691724300385
Epoch: 8, Loss: 4.831657886505127
Hits@1: 0.15811386704444885
Epoch: 9, Loss: 4.796802520751953
Hits@1: 0.15986502170562744
Epoch: 10, Loss: 4.7634501457214355
Hits@1: 0.15988004207611084
Epoch: 11, Loss: 4.731516361236572
Hits@1: 0.16087211668491364
Epoch: 12, Loss: 4.700922966003418
Hits@1: 0.16333724558353424
Epoch: 13, Loss: 4.671606540679932
Hits@1: 0.1624203324317932
Epoch: 14, Loss: 4.643500804901123
Hits@1: 0.16487042605876923
Epoch: 15, Loss: 4.61654806137085
Hits@1: 0.16435936093330383
Epoc

### Edge dropout: Infomax + link prediction

In [None]:
# Pinned to CPU for this run.
# GPU alternative: torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

# Deep Graph Infomax whose corrupted view is produced by edge dropout.
model = DeepGraphInfomax(
  hidden_channels=128,
  encoder=Encoder(
    interm_channels=74,
    hidden_channels=128,
    num_features=128,
  ),
  # Graph-level summary: sigmoid of the mean node embedding.
  summary=lambda z, *args, **kwargs: z.mean(dim=0).sigmoid(),
  corruption=edge_dropout,
)
model = model.to(device)

linkmodel = LinkPredHead()

# NOTE(review): only model.parameters() are handed to the optimizer; if LinkPredHead owns
# trainable weights they are not updated by this optimizer -- confirm this is intended.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
loss_bce = nn.BCEWithLogitsLoss()

In [None]:
# Joint Infomax + link-prediction training (edge-dropout corruption).
num_epoch = 100
loss_epoch = []  # one entry per epoch, as returned by train_infomax_linkpred
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    # On the final epoch the helper's result is unpacked into the embeddings as well as the
    # loss (train_infomax_linkpred is defined elsewhere; presumably it switches its return
    # value on the epoch number -- TODO confirm).
    pos_z, neg_z, summary, loss = train_infomax_linkpred(epoch)
  else:
    loss = train_infomax_linkpred(epoch)
  loss_epoch.append(loss)

# Embed the test features over the same edge_index. This is inference only (the results are
# detached right afterwards), so skip building an autograd graph.
with torch.no_grad():
  test_pos_z, test_neg_z, test_summary = model(dataset.test_features, dataset.edge_index)

  1%|          | 1/100 [00:08<13:23,  8.11s/it]

Epoch: 1, Loss: 7.972873687744141


  2%|▏         | 2/100 [00:15<12:37,  7.73s/it]

Epoch: 2, Loss: 6.6683197021484375


  3%|▎         | 3/100 [00:21<11:04,  6.86s/it]

Epoch: 3, Loss: 5.006956100463867


  4%|▍         | 4/100 [00:28<11:08,  6.97s/it]

Epoch: 4, Loss: 4.139936923980713


  5%|▌         | 5/100 [00:34<10:18,  6.51s/it]

Epoch: 5, Loss: 4.0134687423706055


  6%|▌         | 6/100 [00:41<10:22,  6.62s/it]

Epoch: 6, Loss: 3.50099778175354


  7%|▋         | 7/100 [00:47<09:56,  6.42s/it]

Epoch: 7, Loss: 3.5405797958374023


  8%|▊         | 8/100 [00:53<09:48,  6.40s/it]

Epoch: 8, Loss: 3.4753026962280273


  9%|▉         | 9/100 [01:00<09:51,  6.50s/it]

Epoch: 9, Loss: 3.3867475986480713


 10%|█         | 10/100 [01:05<09:26,  6.29s/it]

Epoch: 10, Loss: 3.2560577392578125


 11%|█         | 11/100 [01:13<09:41,  6.53s/it]

Epoch: 11, Loss: 3.376514196395874


 12%|█▏        | 12/100 [01:18<09:13,  6.29s/it]

Epoch: 12, Loss: 3.251241683959961


 13%|█▎        | 13/100 [01:26<09:45,  6.73s/it]

Epoch: 13, Loss: 3.2346372604370117


 14%|█▍        | 14/100 [01:32<09:29,  6.62s/it]

Epoch: 14, Loss: 3.2260708808898926


 15%|█▌        | 15/100 [01:40<09:38,  6.81s/it]

Epoch: 15, Loss: 3.153749465942383


 16%|█▌        | 16/100 [01:46<09:09,  6.55s/it]

Epoch: 16, Loss: 3.130607843399048


 17%|█▋        | 17/100 [01:53<09:18,  6.73s/it]

Epoch: 17, Loss: 3.065579652786255


 18%|█▊        | 18/100 [01:59<08:52,  6.50s/it]

Epoch: 18, Loss: 3.0902693271636963


 19%|█▉        | 19/100 [02:06<09:00,  6.68s/it]

Epoch: 19, Loss: 3.0478515625


 20%|██        | 20/100 [02:12<08:40,  6.51s/it]

Epoch: 20, Loss: 3.02191162109375


 21%|██        | 21/100 [02:18<08:34,  6.51s/it]

Epoch: 21, Loss: 2.9809556007385254


 22%|██▏       | 22/100 [02:25<08:26,  6.50s/it]

Epoch: 22, Loss: 2.9802358150482178


 23%|██▎       | 23/100 [02:31<08:12,  6.40s/it]

Epoch: 23, Loss: 3.042116403579712


 24%|██▍       | 24/100 [02:38<08:17,  6.54s/it]

Epoch: 24, Loss: 2.9695005416870117


 25%|██▌       | 25/100 [02:44<07:55,  6.34s/it]

Epoch: 25, Loss: 2.9641294479370117


 26%|██▌       | 26/100 [02:51<08:06,  6.58s/it]

Epoch: 26, Loss: 3.031702995300293


 27%|██▋       | 27/100 [02:57<07:42,  6.34s/it]

Epoch: 27, Loss: 3.035066604614258


 28%|██▊       | 28/100 [03:04<07:55,  6.60s/it]

Epoch: 28, Loss: 3.0239429473876953


 29%|██▉       | 29/100 [03:10<07:32,  6.37s/it]

Epoch: 29, Loss: 2.9365108013153076


 30%|███       | 30/100 [03:17<07:41,  6.60s/it]

Epoch: 30, Loss: 3.011253833770752


 31%|███       | 31/100 [03:23<07:20,  6.38s/it]

Epoch: 31, Loss: 3.009251832962036


 32%|███▏      | 32/100 [03:29<07:17,  6.44s/it]

Epoch: 32, Loss: 2.988787889480591


 33%|███▎      | 33/100 [03:36<07:07,  6.39s/it]

Epoch: 33, Loss: 3.0037593841552734


 34%|███▍      | 34/100 [03:41<06:48,  6.18s/it]

Epoch: 34, Loss: 2.9603185653686523


 35%|███▌      | 35/100 [03:48<07:00,  6.47s/it]

Epoch: 35, Loss: 3.044111728668213


 36%|███▌      | 36/100 [03:54<06:40,  6.26s/it]

Epoch: 36, Loss: 3.0083017349243164


 37%|███▋      | 37/100 [04:01<06:51,  6.54s/it]

Epoch: 37, Loss: 2.9340362548828125


 38%|███▊      | 38/100 [04:07<06:34,  6.36s/it]

Epoch: 38, Loss: 2.8995468616485596


 39%|███▉      | 39/100 [04:15<06:56,  6.83s/it]

Epoch: 39, Loss: 2.894562244415283


 40%|████      | 40/100 [04:22<06:41,  6.70s/it]

Epoch: 40, Loss: 2.925473690032959


 41%|████      | 41/100 [04:28<06:35,  6.70s/it]

Epoch: 41, Loss: 2.9335944652557373


 42%|████▏     | 42/100 [04:34<06:12,  6.42s/it]

Epoch: 42, Loss: 2.8973708152770996


 43%|████▎     | 43/100 [04:40<05:53,  6.20s/it]

Epoch: 43, Loss: 2.9480836391448975


 44%|████▍     | 44/100 [04:46<05:52,  6.30s/it]

Epoch: 44, Loss: 2.9056589603424072


 45%|████▌     | 45/100 [04:53<05:44,  6.26s/it]

Epoch: 45, Loss: 2.865370035171509


 46%|████▌     | 46/100 [05:00<05:53,  6.54s/it]

Epoch: 46, Loss: 2.9345810413360596


 47%|████▋     | 47/100 [05:06<05:36,  6.34s/it]

Epoch: 47, Loss: 2.8942737579345703


 48%|████▊     | 48/100 [05:13<05:41,  6.57s/it]

Epoch: 48, Loss: 2.9147911071777344


 49%|████▉     | 49/100 [05:19<05:24,  6.36s/it]

Epoch: 49, Loss: 2.893551826477051


 50%|█████     | 50/100 [05:26<05:28,  6.56s/it]

Epoch: 50, Loss: 2.9695916175842285


 51%|█████     | 51/100 [05:31<05:08,  6.30s/it]

Epoch: 51, Loss: 2.8322227001190186


 52%|█████▏    | 52/100 [05:38<05:06,  6.39s/it]

Epoch: 52, Loss: 2.871096134185791


 53%|█████▎    | 53/100 [05:44<04:54,  6.27s/it]

Epoch: 53, Loss: 2.8242156505584717


 54%|█████▍    | 54/100 [05:50<04:45,  6.21s/it]

Epoch: 54, Loss: 2.925766944885254


 55%|█████▌    | 55/100 [05:57<04:44,  6.32s/it]

Epoch: 55, Loss: 2.8468832969665527


 56%|█████▌    | 56/100 [06:02<04:30,  6.14s/it]

Epoch: 56, Loss: 2.8794198036193848


 57%|█████▋    | 57/100 [06:09<04:36,  6.43s/it]

Epoch: 57, Loss: 2.9931304454803467


 58%|█████▊    | 58/100 [06:15<04:23,  6.28s/it]

Epoch: 58, Loss: 2.9667513370513916


 59%|█████▉    | 59/100 [06:22<04:26,  6.50s/it]

Epoch: 59, Loss: 2.8371102809906006


 60%|██████    | 60/100 [06:28<04:09,  6.24s/it]

Epoch: 60, Loss: 2.8573575019836426


 61%|██████    | 61/100 [06:34<04:07,  6.34s/it]

Epoch: 61, Loss: 2.7421131134033203


 62%|██████▏   | 62/100 [06:41<03:57,  6.26s/it]

Epoch: 62, Loss: 2.89747953414917


 63%|██████▎   | 63/100 [06:46<03:47,  6.15s/it]

Epoch: 63, Loss: 2.819314956665039


 64%|██████▍   | 64/100 [06:53<03:46,  6.28s/it]

Epoch: 64, Loss: 2.8107454776763916


 65%|██████▌   | 65/100 [07:00<03:47,  6.49s/it]

Epoch: 65, Loss: 2.8509974479675293


 66%|██████▌   | 66/100 [07:08<03:51,  6.80s/it]

Epoch: 66, Loss: 2.8306057453155518


 67%|██████▋   | 67/100 [07:13<03:34,  6.50s/it]

Epoch: 67, Loss: 2.8052237033843994


 68%|██████▊   | 68/100 [07:21<03:36,  6.78s/it]

Epoch: 68, Loss: 2.7014517784118652


 69%|██████▉   | 69/100 [07:27<03:23,  6.56s/it]

Epoch: 69, Loss: 2.76137113571167


 70%|███████   | 70/100 [07:34<03:21,  6.72s/it]

Epoch: 70, Loss: 2.8849003314971924


 71%|███████   | 71/100 [07:40<03:07,  6.46s/it]

Epoch: 71, Loss: 2.6901602745056152


 72%|███████▏  | 72/100 [07:47<03:08,  6.75s/it]

Epoch: 72, Loss: 2.814882278442383


 73%|███████▎  | 73/100 [07:53<02:54,  6.47s/it]

Epoch: 73, Loss: 2.7026445865631104


 74%|███████▍  | 74/100 [08:00<02:50,  6.55s/it]

Epoch: 74, Loss: 2.6862411499023438


 75%|███████▌  | 75/100 [08:06<02:41,  6.44s/it]

Epoch: 75, Loss: 2.661720037460327


 76%|███████▌  | 76/100 [08:12<02:33,  6.38s/it]

Epoch: 76, Loss: 2.6175200939178467


 77%|███████▋  | 77/100 [08:19<02:28,  6.46s/it]

Epoch: 77, Loss: 2.5456700325012207


 78%|███████▊  | 78/100 [08:25<02:17,  6.26s/it]

Epoch: 78, Loss: 2.5091590881347656


 79%|███████▉  | 79/100 [08:32<02:18,  6.61s/it]

Epoch: 79, Loss: 2.8703455924987793


 80%|████████  | 80/100 [08:38<02:08,  6.44s/it]

Epoch: 80, Loss: 2.416902780532837


 81%|████████  | 81/100 [08:45<02:06,  6.68s/it]

Epoch: 81, Loss: 2.6998395919799805


 82%|████████▏ | 82/100 [08:51<01:55,  6.44s/it]

Epoch: 82, Loss: 2.416534900665283


 83%|████████▎ | 83/100 [08:58<01:52,  6.59s/it]

Epoch: 83, Loss: 2.6298694610595703


 84%|████████▍ | 84/100 [09:04<01:40,  6.30s/it]

Epoch: 84, Loss: 2.4910731315612793


 85%|████████▌ | 85/100 [09:10<01:33,  6.21s/it]

Epoch: 85, Loss: 2.415893793106079


 86%|████████▌ | 86/100 [09:15<01:24,  6.04s/it]

Epoch: 86, Loss: 2.355123996734619


 87%|████████▋ | 87/100 [09:21<01:15,  5.79s/it]

Epoch: 87, Loss: 2.3531363010406494


 88%|████████▊ | 88/100 [09:27<01:12,  6.04s/it]

Epoch: 88, Loss: 2.125384569168091


 89%|████████▉ | 89/100 [09:32<01:03,  5.80s/it]

Epoch: 89, Loss: 2.5547902584075928


 90%|█████████ | 90/100 [09:40<01:02,  6.27s/it]

Epoch: 90, Loss: 2.0218300819396973


 91%|█████████ | 91/100 [09:47<00:58,  6.46s/it]

Epoch: 91, Loss: 2.0381736755371094


 92%|█████████▏| 92/100 [09:54<00:53,  6.74s/it]

Epoch: 92, Loss: 1.927133560180664


 93%|█████████▎| 93/100 [10:00<00:45,  6.47s/it]

Epoch: 93, Loss: 2.1542325019836426


 94%|█████████▍| 94/100 [10:07<00:39,  6.51s/it]

Epoch: 94, Loss: 2.1112804412841797


 95%|█████████▌| 95/100 [10:13<00:32,  6.44s/it]

Epoch: 95, Loss: 1.9988771677017212


 96%|█████████▌| 96/100 [10:19<00:25,  6.30s/it]

Epoch: 96, Loss: 1.867392897605896


 97%|█████████▋| 97/100 [10:26<00:19,  6.44s/it]

Epoch: 97, Loss: 2.060044527053833


 98%|█████████▊| 98/100 [10:31<00:12,  6.23s/it]

Epoch: 98, Loss: 1.9780542850494385


 99%|█████████▉| 99/100 [10:39<00:06,  6.52s/it]

Epoch: 99, Loss: 1.9169942140579224


100%|██████████| 100/100 [10:45<00:00,  6.45s/it]

Epoch: 100, Loss: 2.0916476249694824





In [None]:
# Index every embedding row by its position: row -> (user id, detached embedding, label).
# pos_z is captured on the last training epoch; test_pos_z is the forward pass on test features.
train_user_representations = {
  row: (dataset.user_ids[row], pos_z[row].detach(), dataset.train_labels[row])
  for row in range(pos_z.shape[0])
}
test_user_representations = {
  row: (dataset.user_ids[row], test_pos_z[row].detach(), dataset.test_labels[row])
  for row in range(test_pos_z.shape[0])
}

In [None]:
# Random 80/20 hold-out split over the user embeddings (driven by NumPy's global RNG).
user_dataset = UserDataset(train_user_representations)
n = len(user_dataset)

# Shuffle all positions once, then cut the permutation at the 80% mark.
indices = np.random.permutation(np.arange(n))
train_indices, test_indices = np.split(indices, [int(0.8*n)])

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

# Both loaders reshuffle on every pass and yield mini-batches of 64.
user_train_dataloader = DataLoader(
  user_train_dataset,
  batch_size=64,
  shuffle=True,
)
user_test_dataloader = DataLoader(
  user_test_dataset,
  batch_size=64,
  shuffle=True,
)

In [None]:
# Size of the label space for the downstream head, and the array of class ids 0..num_classes-1.
num_classes = 186
classes = np.arange(num_classes)

In [None]:
# train2task is defined elsewhere in the notebook; it takes no arguments here, so it presumably
# reads the dataloaders / num_classes prepared in the cells above from globals -- TODO confirm.
head_model = train2task()
# Persist only the weights (state_dict) of the returned head, not the whole module object.
torch.save(head_model.state_dict(), 'edge_dropout_infomax_linkpred_model.pth')

Epoch: 0, Loss: 5.160297393798828
Hits@1: 0.05130922049283981
Epoch: 1, Loss: 5.097674369812012
Hits@1: 0.072420634329319
Epoch: 2, Loss: 5.039208889007568
Hits@1: 0.10237042605876923
Epoch: 3, Loss: 4.984537601470947
Hits@1: 0.11653739959001541
Epoch: 4, Loss: 4.9333577156066895
Hits@1: 0.12617994844913483
Epoch: 5, Loss: 4.885403156280518
Hits@1: 0.13459746539592743
Epoch: 6, Loss: 4.840405464172363
Hits@1: 0.14278198778629303
Epoch: 7, Loss: 4.798115253448486
Hits@1: 0.14825336635112762
Epoch: 8, Loss: 4.758296966552734
Hits@1: 0.15389759838581085
Epoch: 9, Loss: 4.720736026763916
Hits@1: 0.16037608683109283
Epoch: 10, Loss: 4.685244083404541
Hits@1: 0.16354015469551086
Epoch: 11, Loss: 4.651653289794922
Hits@1: 0.16536645591259003
Epoch: 12, Loss: 4.619807720184326
Hits@1: 0.17178480327129364
Epoch: 13, Loss: 4.589570999145508
Hits@1: 0.17503908276557922
Epoch: 14, Loss: 4.560819625854492
Hits@1: 0.17753426730632782
Epoch: 15, Loss: 4.533444881439209
Hits@1: 0.18020232021808624
Epo

### Random walk: Infomax + link prediction

In [None]:
# Pinned to CPU for this run.
# GPU alternative: torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

# Deep Graph Infomax whose corrupted view is produced by the random-walk corruption.
model = DeepGraphInfomax(
  hidden_channels=128,
  encoder=Encoder(
    interm_channels=74,
    hidden_channels=128,
    num_features=128,
  ),
  # Graph-level summary: sigmoid of the mean node embedding.
  summary=lambda z, *args, **kwargs: z.mean(dim=0).sigmoid(),
  corruption=random_walk,
)
model = model.to(device)

linkmodel = LinkPredHead()

# NOTE(review): only model.parameters() are handed to the optimizer; if LinkPredHead owns
# trainable weights they are not updated by this optimizer -- confirm this is intended.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
loss_bce = nn.BCEWithLogitsLoss()

In [None]:
# Joint Infomax + link-prediction training (random-walk corruption).
num_epoch = 100
loss_epoch = []  # one entry per epoch, as returned by train_infomax_linkpred
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    # On the final epoch the helper's result is unpacked into the embeddings as well as the
    # loss (train_infomax_linkpred is defined elsewhere; presumably it switches its return
    # value on the epoch number -- TODO confirm).
    pos_z, neg_z, summary, loss = train_infomax_linkpred(epoch)
  else:
    loss = train_infomax_linkpred(epoch)
  loss_epoch.append(loss)

# Embed the test features over the same edge_index. This is inference only (the results are
# detached right afterwards), so skip building an autograd graph.
with torch.no_grad():
  test_pos_z, test_neg_z, test_summary = model(dataset.test_features, dataset.edge_index)

  1%|          | 1/100 [00:05<09:24,  5.70s/it]

Epoch: 1, Loss: 4.511188983917236


  2%|▏         | 2/100 [00:10<08:10,  5.01s/it]

Epoch: 2, Loss: 3.2485270500183105


  3%|▎         | 3/100 [00:14<07:17,  4.51s/it]

Epoch: 3, Loss: 3.6351232528686523


  4%|▍         | 4/100 [00:17<06:47,  4.24s/it]

Epoch: 4, Loss: 3.651085376739502


  5%|▌         | 5/100 [00:23<07:11,  4.54s/it]

Epoch: 5, Loss: 2.938366651535034


  6%|▌         | 6/100 [00:26<06:44,  4.31s/it]

Epoch: 6, Loss: 2.899475574493408


  7%|▋         | 7/100 [00:30<06:27,  4.16s/it]

Epoch: 7, Loss: 2.7207932472229004


  8%|▊         | 8/100 [00:35<06:47,  4.43s/it]

Epoch: 8, Loss: 2.125828504562378


  9%|▉         | 9/100 [00:41<07:06,  4.69s/it]

Epoch: 9, Loss: 1.3321194648742676


 10%|█         | 10/100 [00:44<06:38,  4.43s/it]

Epoch: 10, Loss: 2.705251455307007


 11%|█         | 11/100 [00:49<06:46,  4.57s/it]

Epoch: 11, Loss: 2.5453333854675293


 12%|█▏        | 12/100 [00:53<06:28,  4.42s/it]

Epoch: 12, Loss: 3.61480975151062


 13%|█▎        | 13/100 [00:57<06:10,  4.26s/it]

Epoch: 13, Loss: 2.415259599685669


 14%|█▍        | 14/100 [01:02<06:07,  4.28s/it]

Epoch: 14, Loss: 1.7380506992340088


 15%|█▌        | 15/100 [01:06<06:10,  4.36s/it]

Epoch: 15, Loss: 2.565542697906494


 16%|█▌        | 16/100 [01:10<05:51,  4.19s/it]

Epoch: 16, Loss: 1.8678243160247803


 17%|█▋        | 17/100 [01:14<05:39,  4.09s/it]

Epoch: 17, Loss: 1.9113147258758545


 18%|█▊        | 18/100 [01:19<05:59,  4.39s/it]

Epoch: 18, Loss: 1.5109338760375977


 19%|█▉        | 19/100 [01:23<05:41,  4.21s/it]

Epoch: 19, Loss: 1.3621124029159546


 20%|██        | 20/100 [01:26<05:27,  4.09s/it]

Epoch: 20, Loss: 3.75809383392334


 21%|██        | 21/100 [01:31<05:38,  4.29s/it]

Epoch: 21, Loss: 1.9234745502471924


 22%|██▏       | 22/100 [01:35<05:30,  4.24s/it]

Epoch: 22, Loss: 1.9660276174545288


 23%|██▎       | 23/100 [01:39<05:16,  4.11s/it]

Epoch: 23, Loss: 1.4544951915740967


 24%|██▍       | 24/100 [01:43<05:10,  4.09s/it]

Epoch: 24, Loss: 3.1708004474639893


 25%|██▌       | 25/100 [01:48<05:22,  4.30s/it]

Epoch: 25, Loss: 2.0654702186584473


 26%|██▌       | 26/100 [01:52<05:08,  4.16s/it]

Epoch: 26, Loss: 1.7574079036712646


 27%|██▋       | 27/100 [01:56<04:55,  4.05s/it]

Epoch: 27, Loss: 2.4111247062683105


 28%|██▊       | 28/100 [02:01<05:13,  4.36s/it]

Epoch: 28, Loss: 1.1909468173980713


 29%|██▉       | 29/100 [02:04<04:58,  4.20s/it]

Epoch: 29, Loss: 1.956261396408081


 30%|███       | 30/100 [02:08<04:45,  4.08s/it]

Epoch: 30, Loss: 1.6135680675506592


 31%|███       | 31/100 [02:13<04:50,  4.22s/it]

Epoch: 31, Loss: 1.576305866241455


 32%|███▏      | 32/100 [02:17<04:48,  4.25s/it]

Epoch: 32, Loss: 1.5809212923049927


 33%|███▎      | 33/100 [02:21<04:36,  4.12s/it]

Epoch: 33, Loss: 1.2449508905410767


 34%|███▍      | 34/100 [02:25<04:29,  4.08s/it]

Epoch: 34, Loss: 2.092135429382324


 35%|███▌      | 35/100 [02:30<04:42,  4.35s/it]

Epoch: 35, Loss: 1.574366569519043


 36%|███▌      | 36/100 [02:34<04:29,  4.21s/it]

Epoch: 36, Loss: 1.21950364112854


 37%|███▋      | 37/100 [02:38<04:17,  4.09s/it]

Epoch: 37, Loss: 1.5938754081726074


 38%|███▊      | 38/100 [02:43<04:31,  4.37s/it]

Epoch: 38, Loss: 1.2114832401275635


 39%|███▉      | 39/100 [02:47<04:18,  4.24s/it]

Epoch: 39, Loss: 1.649965763092041


 40%|████      | 40/100 [02:50<04:07,  4.12s/it]

Epoch: 40, Loss: 1.0123355388641357


 41%|████      | 41/100 [02:55<04:09,  4.22s/it]

Epoch: 41, Loss: 1.2586123943328857


 42%|████▏     | 42/100 [02:59<04:08,  4.29s/it]

Epoch: 42, Loss: 5.310296535491943


 43%|████▎     | 43/100 [03:03<03:56,  4.15s/it]

Epoch: 43, Loss: 1.129935622215271


 44%|████▍     | 44/100 [03:07<03:47,  4.07s/it]

Epoch: 44, Loss: 1.0830315351486206


 45%|████▌     | 45/100 [03:12<03:59,  4.36s/it]

Epoch: 45, Loss: 3.2153217792510986


 46%|████▌     | 46/100 [03:17<03:59,  4.44s/it]

Epoch: 46, Loss: 1.1599193811416626


 47%|████▋     | 47/100 [03:21<03:54,  4.42s/it]

Epoch: 47, Loss: 1.6275129318237305


 48%|████▊     | 48/100 [03:26<03:58,  4.59s/it]

Epoch: 48, Loss: 1.283963680267334


 49%|████▉     | 49/100 [03:30<03:42,  4.36s/it]

Epoch: 49, Loss: 1.5670708417892456


 50%|█████     | 50/100 [03:34<03:32,  4.25s/it]

Epoch: 50, Loss: 2.7998526096343994


 51%|█████     | 51/100 [03:39<03:39,  4.48s/it]

Epoch: 51, Loss: 1.316085934638977


 52%|█████▏    | 52/100 [03:43<03:26,  4.31s/it]

Epoch: 52, Loss: 1.312340497970581


 53%|█████▎    | 53/100 [03:47<03:15,  4.17s/it]

Epoch: 53, Loss: 1.3568636178970337


 54%|█████▍    | 54/100 [03:51<03:13,  4.21s/it]

Epoch: 54, Loss: 2.2088911533355713


 55%|█████▌    | 55/100 [03:55<03:13,  4.30s/it]

Epoch: 55, Loss: 1.930277943611145


 56%|█████▌    | 56/100 [03:59<03:03,  4.17s/it]

Epoch: 56, Loss: 2.0289790630340576


 57%|█████▋    | 57/100 [04:03<02:54,  4.06s/it]

Epoch: 57, Loss: 1.2134690284729004


 58%|█████▊    | 58/100 [04:08<03:02,  4.35s/it]

Epoch: 58, Loss: 1.2428863048553467


 59%|█████▉    | 59/100 [04:12<02:52,  4.20s/it]

Epoch: 59, Loss: 3.5572304725646973


 60%|██████    | 60/100 [04:16<02:43,  4.08s/it]

Epoch: 60, Loss: 1.4177736043930054


 61%|██████    | 61/100 [04:21<02:47,  4.31s/it]

Epoch: 61, Loss: 1.7786927223205566


 62%|██████▏   | 62/100 [04:25<02:40,  4.23s/it]

Epoch: 62, Loss: 1.389376163482666


 63%|██████▎   | 63/100 [04:28<02:31,  4.10s/it]

Epoch: 63, Loss: 1.0315724611282349


 64%|██████▍   | 64/100 [04:33<02:28,  4.13s/it]

Epoch: 64, Loss: 2.045011520385742


 65%|██████▌   | 65/100 [04:37<02:30,  4.29s/it]

Epoch: 65, Loss: 1.5561038255691528


 66%|██████▌   | 66/100 [04:41<02:21,  4.15s/it]

Epoch: 66, Loss: 1.20207941532135


 67%|██████▋   | 67/100 [04:45<02:13,  4.06s/it]

Epoch: 67, Loss: 1.6902382373809814


 68%|██████▊   | 68/100 [04:50<02:20,  4.38s/it]

Epoch: 68, Loss: 1.739739179611206


 69%|██████▉   | 69/100 [04:54<02:11,  4.23s/it]

Epoch: 69, Loss: 1.7542521953582764


 70%|███████   | 70/100 [04:58<02:03,  4.10s/it]

Epoch: 70, Loss: 1.1743898391723633


 71%|███████   | 71/100 [05:03<02:04,  4.28s/it]

Epoch: 71, Loss: 1.2128801345825195


 72%|███████▏  | 72/100 [05:07<01:58,  4.22s/it]

Epoch: 72, Loss: 1.32389235496521


 73%|███████▎  | 73/100 [05:10<01:51,  4.11s/it]

Epoch: 73, Loss: 1.4665169715881348


 74%|███████▍  | 74/100 [05:15<01:47,  4.13s/it]

Epoch: 74, Loss: 1.1121770143508911


 75%|███████▌  | 75/100 [05:19<01:47,  4.31s/it]

Epoch: 75, Loss: 1.7261290550231934


 76%|███████▌  | 76/100 [05:23<01:39,  4.17s/it]

Epoch: 76, Loss: 1.5960215330123901


 77%|███████▋  | 77/100 [05:27<01:33,  4.07s/it]

Epoch: 77, Loss: 3.560223340988159


 78%|███████▊  | 78/100 [05:32<01:36,  4.37s/it]

Epoch: 78, Loss: 2.0181429386138916


 79%|███████▉  | 79/100 [05:36<01:28,  4.21s/it]

Epoch: 79, Loss: 1.878627061843872


 80%|████████  | 80/100 [05:40<01:21,  4.10s/it]

Epoch: 80, Loss: 1.3989757299423218


 81%|████████  | 81/100 [05:45<01:22,  4.32s/it]

Epoch: 81, Loss: 1.4625917673110962


 82%|████████▏ | 82/100 [05:49<01:16,  4.25s/it]

Epoch: 82, Loss: 1.5742082595825195


 83%|████████▎ | 83/100 [05:53<01:10,  4.16s/it]

Epoch: 83, Loss: 1.2698966264724731


 84%|████████▍ | 84/100 [05:57<01:08,  4.29s/it]

Epoch: 84, Loss: 1.6807007789611816


 85%|████████▌ | 85/100 [06:03<01:11,  4.78s/it]

Epoch: 85, Loss: 1.3083661794662476


 86%|████████▌ | 86/100 [06:07<01:03,  4.51s/it]

Epoch: 86, Loss: 1.7620327472686768


 87%|████████▋ | 87/100 [06:11<00:55,  4.31s/it]

Epoch: 87, Loss: 1.4557660818099976


 88%|████████▊ | 88/100 [06:15<00:52,  4.39s/it]

Epoch: 88, Loss: 1.4443100690841675


 89%|████████▉ | 89/100 [06:20<00:47,  4.29s/it]

Epoch: 89, Loss: 0.9974012970924377


 90%|█████████ | 90/100 [06:23<00:41,  4.18s/it]

Epoch: 90, Loss: 4.479115009307861


 91%|█████████ | 91/100 [06:28<00:37,  4.19s/it]

Epoch: 91, Loss: 1.3958474397659302


 92%|█████████▏| 92/100 [06:32<00:34,  4.35s/it]

Epoch: 92, Loss: 1.9911344051361084


 93%|█████████▎| 93/100 [06:36<00:29,  4.22s/it]

Epoch: 93, Loss: 1.173257827758789


 94%|█████████▍| 94/100 [06:40<00:24,  4.10s/it]

Epoch: 94, Loss: 1.5896925926208496


 95%|█████████▌| 95/100 [06:45<00:22,  4.41s/it]

Epoch: 95, Loss: 1.0592446327209473


 96%|█████████▌| 96/100 [06:49<00:17,  4.26s/it]

Epoch: 96, Loss: 1.3572649955749512


 97%|█████████▋| 97/100 [06:53<00:12,  4.15s/it]

Epoch: 97, Loss: 1.2067577838897705


 98%|█████████▊| 98/100 [06:58<00:08,  4.37s/it]

Epoch: 98, Loss: 1.7433596849441528


 99%|█████████▉| 99/100 [07:02<00:04,  4.29s/it]

Epoch: 99, Loss: 1.2765605449676514


100%|██████████| 100/100 [07:06<00:00,  4.26s/it]

Epoch: 100, Loss: 1.296288013458252





In [None]:
# Index every embedding row by its position: row -> (user id, detached embedding, label).
# pos_z is captured on the last training epoch; test_pos_z is the forward pass on test features.
train_user_representations = {
  row: (dataset.user_ids[row], pos_z[row].detach(), dataset.train_labels[row])
  for row in range(pos_z.shape[0])
}
test_user_representations = {
  row: (dataset.user_ids[row], test_pos_z[row].detach(), dataset.test_labels[row])
  for row in range(test_pos_z.shape[0])
}

In [None]:
# Random 80/20 hold-out split over the user embeddings (driven by NumPy's global RNG).
user_dataset = UserDataset(train_user_representations)
n = len(user_dataset)

# Shuffle all positions once, then cut the permutation at the 80% mark.
indices = np.random.permutation(np.arange(n))
train_indices, test_indices = np.split(indices, [int(0.8*n)])

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

# Both loaders reshuffle on every pass and yield mini-batches of 64.
user_train_dataloader = DataLoader(
  user_train_dataset,
  batch_size=64,
  shuffle=True,
)
user_test_dataloader = DataLoader(
  user_test_dataset,
  batch_size=64,
  shuffle=True,
)

In [None]:
# Size of the label space for the downstream head, and the array of class ids 0..num_classes-1.
num_classes = 186
classes = np.arange(num_classes)

In [None]:
# train2task is defined elsewhere in the notebook; it takes no arguments here, so it presumably
# reads the dataloaders / num_classes prepared in the cells above from globals -- TODO confirm.
head_model = train2task()
# Persist only the weights (state_dict) of the returned head, not the whole module object.
torch.save(head_model.state_dict(), 'random_walk_infomax_linkpred_model.pth')

Epoch: 0, Loss: 5.178340911865234
Hits@1: 0.04955808073282242
Epoch: 1, Loss: 5.1121368408203125
Hits@1: 0.07634379714727402
Epoch: 2, Loss: 5.050725936889648
Hits@1: 0.09246482700109482
Epoch: 3, Loss: 4.993563175201416
Hits@1: 0.10117544233798981
Epoch: 4, Loss: 4.940242767333984
Hits@1: 0.11373406648635864
Epoch: 5, Loss: 4.890415191650391
Hits@1: 0.12101671099662781
Epoch: 6, Loss: 4.843774318695068
Hits@1: 0.12841209769248962
Epoch: 7, Loss: 4.800050735473633
Hits@1: 0.13436447083950043
Epoch: 8, Loss: 4.759006500244141
Hits@1: 0.1408730149269104
Epoch: 9, Loss: 4.720427989959717
Hits@1: 0.14826840162277222
Epoch: 10, Loss: 4.684117794036865
Hits@1: 0.15346169471740723
Epoch: 11, Loss: 4.649898052215576
Hits@1: 0.15693391859531403
Epoch: 12, Loss: 4.61760950088501
Hits@1: 0.15960197150707245
Epoch: 13, Loss: 4.587100505828857
Hits@1: 0.16236020624637604
Epoch: 14, Loss: 4.558238506317139
Hits@1: 0.16190926730632782
Epoch: 15, Loss: 4.530893325805664
Hits@1: 0.16310425102710724
Epo

## Training on target experiments

### Concatenation

In [None]:
# Load the per-user records, spelling out what each positional number meant.
users_list = load_dataset_timestamp(n_users=20001, n_context=128, seq_len=30)

In [None]:
class UserInfoDataset():
  """Map-style dataset pairing a user's padded sequences with pre-computed embeddings.

  Args:
    data: indexable collection of per-user dicts with keys 'user', 'seq_len',
      'train_act_seq', 'train_time_seq', 'train_act_label', 'train_time_label',
      'test_act_seq', 'test_time_seq', 'test_act_label', 'test_time_label'.
    max_len: length every sequence is zero-padded to.
    train_user_representations: dict mapping dataset index -> tuple whose element [1]
      is that user's embedding (train side).
    test_user_representations: same layout, test side.
    repr_dim: size of the all-zero embedding returned for an index that has no entry
      in a representation dict (defaults to 128).
  """

  def __init__(self, data, max_len, train_user_representations, test_user_representations, repr_dim=128):
    self.data = data
    self.max_len = max_len
    self.train_user_representations = train_user_representations
    self.test_user_representations = test_user_representations
    self.repr_dim = repr_dim

  def __len__(self):
    return len(self.data)

  def _pad(self, values, seq_len):
    """Return `values` copied into the front of a zero int32 vector of length max_len."""
    padded = np.zeros((self.max_len,), dtype='int32')
    padded[:seq_len] = values
    return padded

  def _embedding(self, representations, idx):
    """Return the stored embedding for `idx`, or a zero vector when there is none.

    The caller's dict is only read, never written. Previously a 2-tuple placeholder was
    inserted into the shared dicts and then read back with [1], which yielded a scalar
    tensor(0) instead of the zero embedding.
    """
    entry = representations.get(idx)
    if entry is None:
      return torch.zeros(size=(self.repr_dim,))
    return entry[1]

  def __getitem__(self, idx):
    """Return the 12-tuple for user `idx`.

    Order: user id, padded train actions, padded train times, train action label,
    train time label, padded test actions, padded test times, test action label,
    test time label, seq_len, train embedding, test embedding.
    """
    user = self.data[idx]
    seq_len = user['seq_len']

    tr_act_seq = self._pad(user['train_act_seq'], seq_len)
    tr_time_seq = self._pad(user['train_time_seq'], seq_len)
    t_act_seq = self._pad(user['test_act_seq'], seq_len)
    t_time_seq = self._pad(user['test_time_seq'], seq_len)

    train_repr = self._embedding(self.train_user_representations, idx)
    test_repr = self._embedding(self.test_user_representations, idx)

    return user['user'], tr_act_seq, \
    tr_time_seq, user['train_act_label'], \
    user['train_time_label'], t_act_seq, \
    t_time_seq, user['test_act_label'], \
    user['test_time_label'], user['seq_len'], train_repr, test_repr

In [None]:
# Wrap the raw user records together with both embedding lookups; sequences are padded to 30.
user_dataset = UserInfoDataset(
  data=users_list,
  max_len=30,
  train_user_representations=train_user_representations,
  test_user_representations=test_user_representations,
)
# Shuffled mini-batches of 64 users.
user_dataloader = DataLoader(user_dataset, shuffle=True, batch_size=64)

Concatenation with pre-trained embedding

In [None]:
class RNNTargetModel(nn.Module):
  """RNN classifier that concatenates the last valid RNN output with a user embedding.

  Args:
    num_classes: number of output logits.
  """

  def __init__(self, num_classes):
    super(RNNTargetModel, self).__init__()
    # 256 = 128 (RNN output) + 128 (user embedding concatenated in forward).
    self.fc1 = nn.Linear(256, num_classes)
    self.rnn = nn.RNN(128, 128, batch_first = True)
    # norm / reduce_dim are not used by forward(); kept so the module's parameters and
    # state_dict keys stay exactly as before.
    self.norm = nn.BatchNorm1d(128)
    self.reduce_dim = nn.Linear(128, 128)

  def forward(self, x, repr, seq_len):
    """Return class logits of shape (batch, num_classes).

    Args:
      x: batch-first input sequences for the RNN, (batch, time, 128).
      repr: per-user embedding concatenated to the RNN state, (batch, 128).
      seq_len: per-row valid length; the RNN output at step seq_len - 1 is used.
    """
    x, _ = self.rnn(x)
    # Gather each row's last valid time step in one indexing op. The result inherits x's
    # device and dtype, unlike a separately allocated torch.zeros buffer filled row by row.
    rows = torch.arange(x.shape[0], device=x.device)
    last_step = torch.as_tensor(seq_len, dtype=torch.long, device=x.device) - 1
    hx = x[rows, last_step]

    hx = torch.cat([hx, repr], dim= 1)

    x = self.fc1(hx)
    return x

Concatenation with random embedding

In [None]:
class RNNTargetModel(nn.Module):
  """RNN classifier that concatenates the last valid RNN output with a *random* embedding.

  The `repr` argument is ignored and replaced by uniform noise on every forward pass.

  Args:
    num_classes: number of output logits.
  """

  def __init__(self, num_classes):
    super(RNNTargetModel, self).__init__()
    # 256 = 128 (RNN output) + 128 (random embedding concatenated in forward).
    self.fc1 = nn.Linear(256, num_classes)
    self.rnn = nn.RNN(128, 128, batch_first = True)
    # norm / reduce_dim are not used by forward(); kept so the module's parameters and
    # state_dict keys stay exactly as before.
    self.norm = nn.BatchNorm1d(128)
    self.reduce_dim = nn.Linear(128, 128)

  def forward(self, x, repr, seq_len):
    """Return class logits of shape (batch, num_classes).

    Args:
      x: batch-first input sequences for the RNN, (batch, time, 128).
      repr: accepted for signature compatibility; overwritten with torch.rand noise.
      seq_len: per-row valid length; the RNN output at step seq_len - 1 is used.
    """
    x, _ = self.rnn(x)
    # Gather each row's last valid time step in one indexing op. The result inherits x's
    # device and dtype, unlike a separately allocated torch.zeros buffer filled row by row.
    rows = torch.arange(x.shape[0], device=x.device)
    last_step = torch.as_tensor(seq_len, dtype=torch.long, device=x.device) - 1
    hx = x[rows, last_step]

    # Fresh U[0, 1) noise each call stands in for the user embedding.
    repr = torch.rand((len(x), 128), device=x.device)
    hx = torch.cat([hx, repr], dim= 1)
    x = self.fc1(hx)
    return x

Without adding an embedding (sequence-only baseline)

In [None]:
class RNNTargetModel(nn.Module):
  """RNN classifier that uses only the last valid RNN output (no user embedding).

  Args:
    num_classes: number of output logits.
  """

  def __init__(self, num_classes):
    super(RNNTargetModel, self).__init__()
    # forward() feeds fc1 the bare 128-d RNN output (nothing is concatenated), so the
    # input width must be 128; the previous 256 made the matmul fail on a shape mismatch.
    self.fc1 = nn.Linear(128, num_classes)
    self.rnn = nn.RNN(128, 128, batch_first = True)
    # norm / reduce_dim are not used by forward(); kept so their parameters and
    # state_dict keys stay exactly as before.
    self.norm = nn.BatchNorm1d(128)
    self.reduce_dim = nn.Linear(128, 128)

  def forward(self, x, repr, seq_len):
    """Return class logits of shape (batch, num_classes).

    Args:
      x: batch-first input sequences for the RNN, (batch, time, 128).
      repr: unused; accepted so all model variants share one call signature.
      seq_len: per-row valid length; the RNN output at step seq_len - 1 is used.
    """
    x, _ = self.rnn(x)
    # Gather each row's last valid time step in one indexing op. The result inherits x's
    # device and dtype, unlike a separately allocated torch.zeros buffer filled row by row.
    rows = torch.arange(x.shape[0], device=x.device)
    last_step = torch.as_tensor(seq_len, dtype=torch.long, device=x.device) - 1
    hx = x[rows, last_step]
    x = self.fc1(hx)
    return x

In [None]:
def rnn_target_test(model):
  """Evaluate ``model`` on the test part of every batch in ``user_dataloader``.

  Reads the notebook-level names ``user_dataloader``, ``item_emb``,
  ``num_classes`` and ``get_metrics_``. The model is switched to eval mode
  and is left in eval mode when the function returns.

  Args:
    model: module called as ``model(item_embeddings, test_repr, seq_len)``
      that returns one row of class scores per sample.

  Returns:
    Tuple of scalar tensors, one per value returned by ``get_metrics_``
    (hits, map and ndcg at 1/5/10/20/50/100, in that order), each being the
    unweighted mean of the per-batch values.

  Raises:
    ValueError: if ``user_dataloader`` yields no batches.
  """
  model.eval()
  metrics_val = []

  # Pure inference: skip autograd bookkeeping.
  with torch.no_grad():
    for (_user, _train_input, _train_time, _train_label, _train_time_label,
         test_input, _test_time, test_label, _test_time_label,
         seq_len, _train_repr, test_repr) in user_dataloader:

      # Only the item ids are looked up; the time channel was never used.
      test_rnn_input_emb = item_emb[np.asarray(test_input, dtype=np.int64)]
      test_probs = model(test_rnn_input_emb, test_repr, seq_len)

      # One-hot encoding of the labels, as expected by get_metrics_.
      test_one_hot = torch.zeros(len(test_probs), num_classes)
      test_one_hot[torch.arange(len(test_one_hot)), test_label] = 1

      metrics_val.append(list(get_metrics_(test_probs, test_label, test_one_hot)))

  if not metrics_val:
    raise ValueError("user_dataloader yielded no batches; nothing to evaluate")

  # NOTE: every batch has equal weight, so a smaller final batch counts as
  # much as a full one (unchanged from the previous behaviour).
  mean = torch.Tensor(metrics_val).mean(axis=0)
  return tuple(mean)


In [None]:
# Train whichever RNNTargetModel variant is currently defined, reporting the
# mean train loss, validation loss and Hits@1 per epoch, then print the full
# metric set on the test sequences.
model = RNNTargetModel(num_classes)
loss_fn = nn.CrossEntropyLoss()
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)
# NOTE(review): this scheduler is never stepped, so the learning rate stays at
# 1e-4 for the whole run. Call lr_scheduler.step() once per epoch if the 0.97
# decay is actually intended (this will change the training results).
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.97, last_epoch=-1)


for i in range(20):
  losses = []
  val_losses = []
  hits_1_scores = []

  for user, train_input, train_time, train_label, train_time_label, test_input, test_time, test_label, \
  test_time_label, seq_len, train_repr, test_repr in user_dataloader:
    optimizer.zero_grad()

    # Only the item ids are embedded; the time channel is not used here.
    rnn_input_emb = item_emb[np.asarray(train_input, dtype=np.int64)]
    test_rnn_input_emb = item_emb[np.asarray(test_input, dtype=np.int64)]

    # Training forward pass and loss against one-hot targets.
    probs = model(rnn_input_emb, train_repr, seq_len)
    one_hot = torch.zeros(len(probs), num_classes)
    one_hot[torch.arange(len(one_hot)), train_label] = 1
    loss = loss_fn(probs, one_hot)

    # Validation pass uses the weights from before this step's update and
    # needs no autograd graph.
    with torch.no_grad():
      test_probs = model(test_rnn_input_emb, test_repr, seq_len)
      test_one_hot = torch.zeros(len(test_probs), num_classes)
      test_one_hot[torch.arange(len(test_one_hot)), test_label] = 1
      val_loss = loss_fn(test_probs, test_one_hot)

    # Store plain floats (once per batch) so no autograd graph is retained.
    losses.append(loss.item())
    val_losses.append(val_loss.item())
    hits_1_scores.append(top_k_accuracy_score(test_label, test_probs.cpu().detach().numpy(), k=1, labels = classes))

    loss.backward()
    optimizer.step()

  mean_loss = torch.Tensor(losses).mean(axis=0)
  mean_val_loss = torch.Tensor(val_losses).mean(axis = 0).item()
  mean_hits = torch.Tensor(hits_1_scores).mean(axis=0).item()
  print(f'Epoch: {i} Loss: {mean_loss.item()}, Val loss: {mean_val_loss} Hits@1: {mean_hits}')
hits1, hits5, hits10, hits20, hits50, hits100, map1, map5, map10, map20, map50, map100, ndcg1, ndcg5, ndcg10, ndcg20, ndcg50, ndcg100 = rnn_target_test(model)
print_metrics(hits1, hits5, hits10, hits20, hits50, hits100, map1, map5, map10, map20, map50, map100, ndcg1, ndcg5, ndcg10, ndcg20, ndcg50, ndcg100)

Epoch: 0 Loss: 4.462273597717285, Val loss: 4.467819690704346 Hits@1: 0.06983523815870285
Epoch: 1 Loss: 4.01273775100708, Val loss: 4.0267839431762695 Hits@1: 0.13592343032360077
Epoch: 2 Loss: 3.4965574741363525, Val loss: 3.5176172256469727 Hits@1: 0.25534600019454956
Epoch: 3 Loss: 2.9872641563415527, Val loss: 3.0122857093811035 Hits@1: 0.3300824761390686
Epoch: 4 Loss: 2.592750310897827, Val loss: 2.621905565261841 Hits@1: 0.4544183909893036
Epoch: 5 Loss: 2.2767646312713623, Val loss: 2.3081257343292236 Hits@1: 0.5460777878761292
Epoch: 6 Loss: 2.023632049560547, Val loss: 2.0590434074401855 Hits@1: 0.6054298281669617
Epoch: 7 Loss: 1.8238966464996338, Val loss: 1.8630046844482422 Hits@1: 0.6465548872947693
Epoch: 8 Loss: 1.6682329177856445, Val loss: 1.712514042854309 Hits@1: 0.677602231502533
Epoch: 9 Loss: 1.5491427183151245, Val loss: 1.5955711603164673 Hits@1: 0.6961286067962646
Epoch: 10 Loss: 1.4537062644958496, Val loss: 1.5024302005767822 Hits@1: 0.7095329165458679
Epoc