### Import Libraries

In [1]:
!pip install torch_geometric
!pip install networkit
!apt install libgraphviz-dev
!pip install pygraphviz

Collecting torch_geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.5.3
Collecting networkit
  Downloading networkit-11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.7/13.7 MB[0m [31m50.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: networkit
Successfully installed networkit-11.0
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  libgail-common libgail18 libgtk2.0-0 libgtk2.0-bin libgtk2.0-common libgvc6-plugins-gtk
  librsvg2-common libxdot4
Suggested packages:
  gvfs
The following NEW packages will be installed:
  libgail-common libgail18 libgraphviz-dev

In [26]:
import numpy as np
import pandas as pd
import torch
from torch import nn

import json
import os
import networkx as nx
import networkit as nk
from networkit import vizbridges

from tqdm import tqdm
import ast
import matplotlib.pyplot as plt
import re
from sklearn.metrics import top_k_accuracy_score
from sklearn.metrics import ndcg_score

from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv

import copy
from typing import Callable, Tuple

import torch
from torch import Tensor
from torch.nn import Module, Parameter

from torch_geometric.nn.inits import reset, uniform
import random

In [3]:
torch.manual_seed(42)
np.random.seed(42)

### Util functions

In [4]:
def load_dataset_timestamp( n_users, n_context, seq_len):
    act_list = []
    time_list = []
    user_list = []

    max_timestamp = -1.0
    min_timestamp = float('inf')

    with open('gowalla_user_activity.txt', 'r') as raw_file:
        for line in raw_file:
            t_item_list = []
            t_time_list = []
            user = int(line.split(':')[0])
            entries = line.split()[1:]
            for a_entry in entries:
                item, time_stamp = a_entry.split(':')
                t_item_list.append(int(item.strip()))
                t_time_list.append(int(time_stamp.strip()))

                if min_timestamp > int(time_stamp.strip()):
                    min_timestamp = int(time_stamp.strip())
                if max_timestamp < int(time_stamp.strip()):
                    max_timestamp = int(time_stamp.strip())

            act_list.append(t_item_list[0: seq_len])
            time_list.append(t_time_list[0: seq_len])
            user_list.append(user)

    new_time_list = []
    num_bins = 0

    times_bins = np.linspace(min_timestamp, max_timestamp + 1, num=num_bins, dtype=np.int32)
    for a_time_list in time_list:
        temp_time_list = (np.digitize(np.asarray(a_time_list), times_bins) - 1).tolist()
        new_time_list.append(temp_time_list)

    all_examples = {}
    for i in range(0, len(act_list)):

        train_act_seq = act_list[i][:-2]

        train_time_seq = new_time_list[i][:-2]

        train_act_label = act_list[i][-2]
        train_time_label = new_time_list[i][-2]

        test_act_seq = act_list[i][1:-1]
        test_time_seq = new_time_list[i][1:-1]

        test_act_label = act_list[i][-1]
        test_time_label = new_time_list[i][-1]

        entry = {
            'train_act_seq': train_act_seq,
            'train_time_seq': train_time_seq,
            'train_act_label': train_act_label,
            'train_time_label': train_time_label,
            'test_act_seq': test_act_seq,
            'test_time_seq': test_time_seq,
            'test_act_label': test_act_label,
            'test_time_label': test_time_label,
            'seq_len': len(train_act_seq),
            'user': user_list[i]
        }

        all_examples[user_list[i]] = entry

    return all_examples

### Metrics

In [5]:
def print_metrics(hits1, hits5, hits10, hits20, hits50, hits100, map1, map5, map10, map20, map50, map100, ndcg1, ndcg5, ndcg10, ndcg20, ndcg50, ndcg100):
    print(f'hits@1: {hits1:.6f}, hits@5: {hits5:.6f}, hits@10: {hits10:.6f}, hits@20: {hits20:.6f}')
    print(f'hits@50: {hits50:.6f}, hits@100: {hits100:.6f}')
    print(f'map@1: {map1:.6f}, map@5: {map5:.6f}, map@10: {map10:.6f}, map@20: {map20:.6f}')
    print(f'map@50: {map50:.6f}, map@100: {map100:.6f}')
    print(f'ndcg@1: {ndcg1:.6f}, ndcg@5: {ndcg5:.6f}, ndcg@10: {ndcg10:.6f}, ndcg@20: {ndcg20:.6f}')
    print(f'ndcg@50: {ndcg50:.6f}, ndcg@100: {ndcg100:.6f}')

In [6]:
def apk(actual, predicted, k=10):

    if len(predicted) > k:
        predicted = predicted[:k]

    score = 0.0
    num_hits = 0.0

    for i, p in enumerate(predicted):
        if p in actual and p not in predicted[:i]:
            num_hits += 1.0
            score += num_hits / (i + 1.0)

    if not actual:
        return 0.0

    return score / min(len(actual), k)


def mapk(y_prob, y, k=10):

    predicted = [np.argsort(p_)[-k:][::-1] for p_ in y_prob]
    actual = [[y_] for y_ in y]
    return np.mean([apk(a, p, k) for a, p in zip(actual, predicted)])


def hits_k(y_prob, y, k=10):
    acc = []
    for p_, y_ in zip(y_prob, y):
        top_k = p_.argsort()[-k:][::-1]
        acc += [1. if y_ in top_k else 0.]
    return sum(acc) / len(acc)

In [7]:
def get_metrics_(probs, labels_batch, test_one_hot):
    hits1 = top_k_accuracy_score(labels_batch, probs.cpu().detach().numpy(), k=1, labels = classes)
    hits5 = top_k_accuracy_score(labels_batch, probs.cpu().detach().numpy(), k=5, labels = classes)
    hits10 = top_k_accuracy_score(labels_batch, probs.cpu().detach().numpy(), k=10, labels = classes)
    hits20 = top_k_accuracy_score(labels_batch, probs.cpu().detach().numpy(), k=20, labels = classes)
    hits50= top_k_accuracy_score(labels_batch, probs.cpu().detach().numpy(), k=50, labels = classes)
    hits100 = top_k_accuracy_score(labels_batch, probs.cpu().detach().numpy(), k=100, labels = classes)

    map1 = mapk(y_prob=probs.cpu().detach().numpy(), y = labels_batch, k=1)
    map5 = mapk(y_prob=probs.cpu().detach().numpy(), y = labels_batch, k=5)
    map10 = mapk(y_prob=probs.cpu().detach().numpy(), y = labels_batch, k=10)
    map20 = mapk(y_prob=probs.cpu().detach().numpy(), y = labels_batch, k=20)
    map50 = mapk(y_prob=probs.cpu().detach().numpy(), y = labels_batch, k=50)
    map100 = mapk(y_prob=probs.cpu().detach().numpy(), y = labels_batch, k=100)

    ndcg1 = ndcg_score(test_one_hot, probs.cpu().detach().numpy(), k=1)
    ndcg5 = ndcg_score(test_one_hot, probs.cpu().detach().numpy(), k=5)
    ndcg10 = ndcg_score(test_one_hot, probs.cpu().detach().numpy(), k=10)
    ndcg20 = ndcg_score(test_one_hot, probs.cpu().detach().numpy(), k=20)
    ndcg50 = ndcg_score(test_one_hot, probs.cpu().detach().numpy(), k=50)
    ndcg100 = ndcg_score(test_one_hot, probs.cpu().detach().numpy(), k=100)
    return hits1, hits5, hits10, hits20, hits50, hits100, map1, map5, map10, map20, map50, map100, ndcg1, ndcg5, ndcg10, ndcg20, ndcg50, ndcg100

### Graph construction

In [8]:
users = load_dataset_timestamp(20001, 128, 30)

In [9]:
friends = pd.read_csv('gowalla_edges.csv')

In [None]:
Gcl = nx.Graph()
for i in range(len(users)):
  if users[i]['user'] not in Gcl.nodes():
      Gcl.add_node(users[i]['user'], weight=10, color='seagreen')
for i in range(len(friends)):
  Gcl.add_edge(friends.loc[i, '1st friend'], friends.loc[i, '2nd friend'], color = 'blue')



nodes = list(Gcl.nodes())
node_colors = [Gcl.nodes[node]['color'] for node in Gcl.nodes]
node_weights = [Gcl.nodes[node]['weight'] for node in Gcl.nodes]
edge_colors = [Gcl.edges[edge]['color'] for edge in Gcl.edges]

plt.figure(figsize=(12, 8))
pos = nx.drawing.nx_agraph.graphviz_layout(Gcl, prog='neato')
nx.draw_networkx(Gcl, pos=pos, with_labels=False, node_size=node_weights, node_color=node_colors, edge_color=edge_colors, width=1)

In [10]:
Gcl = nx.Graph()
for i in range(len(users)):
  if users[i]['user'] not in Gcl.nodes():
      Gcl.add_node(users[i]['user'], weight=10, color='seagreen')
for i in range(len(friends)):
  Gcl.add_edge(friends.loc[i, '1st friend'], friends.loc[i, '2nd friend'], color = 'blue')


nodes = list(Gcl.nodes())
node_colors = [Gcl.nodes[node]['color'] for node in Gcl.nodes]
node_weights = [Gcl.nodes[node]['weight'] for node in Gcl.nodes]
edge_colors = [Gcl.edges[edge]['color'] for edge in Gcl.edges]

In [11]:
graphid2pid = {}
pid2graphid = {}

s = 0
for node_id in Gcl.nodes():
    pid2graphid[node_id] = s
    graphid2pid[s] = node_id
    s += 1

graph_edges = []
for (u, v) in Gcl.edges():
    graph_edges.append((pid2graphid[u], pid2graphid[v]))

### Graph characteristics

In [12]:
nkg = nk.Graph(n=len(pid2graphid), weighted=True, directed=False, edgesIndexed=True)
for (u, v) in graph_edges:
    nkg.addEdge(u, v)

In [13]:
nk.overview(nkg)

Network Properties:
nodes, edges			20001, 167752
directed?			False
weighted?			True
isolated nodes			275
self-loops			0
density				0.000839
clustering coefficient		0.294269
min/max/avg degree		0, 6171, 16.774361
degree assortativity		-0.064750
number of connected components	290
size of largest component	19698 (98.49 %)


### User dataset construction

In [14]:
class UserInfoDataset():
  def __init__(self, data, max_len):
    self.data = data
    self.max_len = max_len

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):

    user = self.data[idx]
    seq_len = user['seq_len']

    tr_act_seq = np.zeros((self.max_len,)).astype('int32')
    tr_act_seq[:seq_len] = np.array(user['train_act_seq'])
    tr_act_seq = np.transpose(tr_act_seq)

    tr_time_seq = np.zeros((self.max_len,)).astype('int32')
    tr_time_seq[:seq_len] = user['train_time_seq']
    tr_time_seq = np.transpose(tr_time_seq)

    t_act_seq = np.zeros((self.max_len, )).astype('int32')
    t_act_seq[:seq_len] = user['test_act_seq']
    t_act_seq = np.transpose(t_act_seq)

    t_time_seq = np.zeros((self.max_len, )).astype('int32')
    t_time_seq[:seq_len] = user['test_time_seq']
    t_time_seq = np.transpose(t_time_seq)


    return user['user'], tr_act_seq, \
    tr_time_seq, user['train_act_label'], \
    user['train_time_label'], t_act_seq, \
    t_time_seq, user['test_act_label'], \
    user['test_time_label'], user['seq_len']

In [15]:
from torch.utils.data import Subset
from torch.utils.data import DataLoader

user_dataset = UserInfoDataset(users, 30)

n = len(user_dataset)

indices = np.arange(n)
indices = np.random.permutation(indices)

train_indices = indices [:int(0.8*n)]
test_indices = indices[int(0.8*n):]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

user_dataloader = DataLoader(user_dataset, batch_size = 64, shuffle = True)
user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=True)

item_emb  = nn.init.xavier_uniform_(torch.empty(186, 128))

## Pre-training

In [16]:
rnn = nn.RNN(128, 128, batch_first = True)
norm = nn.BatchNorm1d(128)

### Getting initial user representations

In [17]:
user_representations = {}
for users, train_input, train_time, train_label, train_time_label, test_input, test_time, test_label, test_time_label, seq_len in user_dataloader:
      comb_input = np.concatenate([np.expand_dims(train_input, axis=-1),
                                                np.expand_dims(train_time, axis=-1)], axis=2)
      model_input = comb_input
      model_output = train_label
      rnn_input_emb = item_emb[model_input[:, :, 0]]
      x, h = rnn(rnn_input_emb)
      hx = torch.zeros(x.shape[0], x.shape[2])#.to(device)
      for i in range(hx.shape[0]):
        hx[i] = x[i][seq_len[i] - 1]
      user_representation = norm(hx)
      for i in range(len(users)):
        user_representations[users[i].item()] = (user_representation[i], train_label[i].item(), seq_len[i].item())

In [18]:
features = []
seq_lens = []
for graphid in Gcl.nodes():
    pid = graphid2pid[graphid]
    features.append(user_representations[graphid][0].detach().numpy())
    seq_lens.append(user_representations[graphid][2])
features = np.array(features)

In [19]:
dataset = Data(x=torch.tensor(features, dtype=torch.float), edge_index=torch.tensor(np.array(graph_edges).T, dtype=torch.int64))

### Deep Graph Infomax model

In [20]:
eps = 1e-15
class DeepGraphInfomax(torch.nn.Module):
    def __init__(
        self,
        hidden_channels,
        encoder,
        summary,
        corruption,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.encoder = encoder
        self.summary = summary
        self.corruption = corruption

        self.weight = Parameter(torch.empty(hidden_channels, hidden_channels))

        self.reset_parameters()

    def reset_parameters(self):
        reset(self.encoder)
        reset(self.summary)
        uniform(self.hidden_channels, self.weight)

    def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
        pos_z = self.encoder(*args, **kwargs)

        cor = self.corruption(*args, **kwargs)
        cor = cor if isinstance(cor, tuple) else (cor, )
        cor_args = cor[:len(args)]
        cor_kwargs = copy.copy(kwargs)
        for key, value in zip(kwargs.keys(), cor[len(args):]):
            cor_kwargs[key] = value

        neg_z = self.encoder(*cor_args, **cor_kwargs)

        summary = self.summary(pos_z, *args, **kwargs)

        return pos_z, neg_z, summary

    def discriminate(self, z: Tensor, summary: Tensor,
                     sigmoid: bool = True) -> Tensor:

        summary = summary.t() if summary.dim() > 1 else summary
        value = torch.matmul(z, torch.matmul(self.weight, summary))
        return torch.sigmoid(value) if sigmoid else value

    def loss(self, pos_z: Tensor, neg_z: Tensor, summary: Tensor) -> Tensor:
        pos_loss = -torch.log(
            self.discriminate(pos_z, summary, sigmoid=True) + eps).mean()
        neg_loss = -torch.log(1 -
                              self.discriminate(neg_z, summary, sigmoid=True) +
                              eps).mean()

        return pos_loss + neg_loss


    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.hidden_channels})'

In [21]:
class Encoder(nn.Module):
    def __init__(self, interm_channels, hidden_channels, num_features):
        super().__init__()

        self.num_features = num_features
        self.rnn = nn.RNN(128, 128, batch_first = True)
        self.norm = nn.BatchNorm1d(128)
        self.conv1 = GATv2Conv(self.num_features, interm_channels)
        self.prelu1 = nn.PReLU(interm_channels)
        self.dropout = nn.Dropout(0.2)
        self.conv2 = GATv2Conv(interm_channels, hidden_channels)
        self.prelu2 = nn.PReLU(hidden_channels)
        self.conv3 = GATv2Conv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.prelu1(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        x = self.prelu2(x)
        x = self.dropout(x)
        x = self.conv3(x, edge_index)
        return x

### Augmentation construction helping functions

In [22]:
def index_to_mask(index):

    index = index.view(-1)
    size = int(index.max()) + 1

    mask = index.new_zeros(20001, dtype=torch.bool)
    mask[index] = True
    return mask

In [23]:
def map_index(indices, edge_index):
  dictionary = {}
  label = 0
  for ind in indices:
    dictionary[ind] = label
    label += 1
  for i in range(edge_index.shape[1]):
    edge_index[0][i] = dictionary[edge_index[0][i].item()]
    edge_index[1][i] = dictionary[edge_index[1][i].item()]
  return edge_index

In [24]:
def construct_neighbors_dictionary(edge_index):
  dictionary = {}
  for i in range(edge_index.shape[1]):
    if edge_index[0][i].item() not in dictionary.keys():
      dictionary[edge_index[0][i].item()] = []
    dictionary[edge_index[0][i].item()].append(edge_index[1][i].item())
    if edge_index[1][i].item() not in dictionary.keys():
      dictionary[edge_index[1][i].item()] = []
    dictionary[edge_index[1][i].item()].append(edge_index[0][i].item())
  return dictionary

In [25]:
neighbors = construct_neighbors_dictionary(dataset.edge_index)

### Augmentations

In [27]:
def node_mix(x, edge_index):
    return x[torch.randperm(x.size(0))], edge_index

In [28]:
def node_dropout(x, edge_index):
  indices = []
  deleted_indices = []
  for i in range(x.shape[0]):
    drop = random.random()
    if drop >= 0.3:
      indices.append(i)
    else:
      deleted_indices.append(i)
  x = torch.index_select(x, 0, torch.tensor(indices, dtype = torch.int32))
  node_mask = index_to_mask(torch.tensor(indices, dtype = torch.int32))
  edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
  edge_index = edge_index[:, edge_mask]
  edge_index = map_index(indices, edge_index)

  return x, edge_index

In [29]:
def random_walk(x, edge_index):
  node = 0
  rwalk = [node]
  walk_length = 10000
  for i in range(walk_length - 1):
    temp = neighbors[node]
    temp = list(set(temp) - set(rwalk))
    if len(temp) == 0:
      break
    new_node = random.choice(temp)
    rwalk.append(new_node)
    node = new_node
  x = torch.index_select(x, 0, torch.tensor(rwalk, dtype = torch.int32))
  node_mask = index_to_mask(torch.tensor(rwalk, dtype = torch.int32))
  edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
  edge_index = edge_index[:, edge_mask]
  edge_index = map_index(rwalk, edge_index)

  return x, edge_index

In [30]:
def edge_dropout(x, edge_index):
  indices = []
  for i in range(edge_index.shape[1]):
    drop = random.random()
    if drop >= 0.3:
      indices.append(i)
  edge_index = torch.index_select(edge_index, 1, torch.tensor(indices, dtype = torch.int32))
  return x, edge_index

### Link prediction head

In [31]:
class LinkPredHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_1 = nn.Linear(256, 1)

    def forward(self, x, edge_index):
        x_src, x_dst = x[edge_index[0]], x[edge_index[1]]
        x = torch.cat((x_src, x_dst), dim = 1)
        x = self.fc_1(x)
        return x#torch.sum(x_src * x_dst, dim=1)

### Next action prediction head

In [32]:
class NextActionPredHead(nn.Module):
  def __init__(self):
    super(NextActionPredHead, self).__init__()
    self.fc1 = nn.Linear(128, 186)
    self.rnn = nn.RNN(128, 128, batch_first = True)

  def forward(self, x):
    x = self.fc1(x)
    return x

In [33]:
def negative_sampling(edge_index, num_nodes):
    mask_1 = torch.rand(edge_index.size(1)) < 0.5
    mask_2 = ~mask_1

    neg_edge_index = edge_index.clone()
    neg_edge_index[0, mask_1] = torch.randint(num_nodes, (mask_1.sum(),))
    neg_edge_index[1, mask_2] = torch.randint(num_nodes, (mask_2.sum(),))

    return neg_edge_index

### User representations dataset

In [34]:
class UserDataset():
    def __init__(self, user_representations):
        self.user_representations = user_representations

    def __len__(self):
        return len(self.user_representations)

    def __getitem__(self, idx):

        return self.user_representations[idx][0], self.user_representations[idx][1], self.user_representations[idx][2]

### Training functions

In [37]:
def train_only_infomax(epoch):
    model.train()
    optimizer.zero_grad()
    pos_z, neg_z, summary = model(dataset.x, dataset.edge_index)

    loss_infomax = model.loss(pos_z, neg_z, summary)

    pos_out = linkmodel(pos_z, dataset.edge_index)
    neg_edge_index = negative_sampling(dataset.edge_index, dataset.num_nodes)
    neg_out = linkmodel(pos_z, neg_edge_index)

    out = torch.cat([pos_out, neg_out])
    gt = torch.cat([torch.ones_like(pos_out), torch.zeros_like(neg_out)])
    loss_link_pred = loss_bce(out, gt)

    loss =  loss_infomax
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch}, Loss: {loss}')
    if epoch == num_epoch:
      return pos_z, neg_z, summary, loss.item()
    else:
      return loss.item()

In [38]:
def train_infomax_linkpred(epoch):
    model.train()
    optimizer.zero_grad()
    pos_z, neg_z, summary = model(dataset.x, dataset.edge_index)

    loss_infomax = model.loss(pos_z, neg_z, summary)

    pos_out = linkmodel(pos_z, dataset.edge_index)
    neg_edge_index = negative_sampling(dataset.edge_index, dataset.num_nodes)
    neg_out = linkmodel(pos_z, neg_edge_index)

    out = torch.cat([pos_out, neg_out])
    gt = torch.cat([torch.ones_like(pos_out), torch.zeros_like(neg_out)])
    loss_link_pred = loss_bce(out, gt)
    loss =  2*loss_infomax + 0.5*loss_link_pred
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch}, Loss: {loss}')
    if epoch == num_epoch:
      return pos_z, neg_z, summary, loss.item()
    else:
      return loss.item()

In [53]:
def test(head_model):
  metrics_val = []
  head_model.eval()
  index = 0
  for users, vectors, labels in user_test_dataloader:
      test_probs = head_model(vectors)
      test_pred = torch.argmax(test_probs, axis = 1)#torch.Tensor([0]*20001)

      test_one_hot = torch.zeros(len(test_probs), num_classes)
      test_one_hot[torch.arange(len(test_one_hot)), labels] = 1


      hits1, hits5, hits10, hits20, hits50, hits100, map1, map5, map10, map20, map50, map100, ndcg1, ndcg5, ndcg10, ndcg20, ndcg50, ndcg100 \
       = get_metrics_(test_probs, labels, test_one_hot)

      metrics_val.append([hits1, hits5, hits10, hits20, hits50, hits100, map1, map5, map10, map20, map50, map100, ndcg1, ndcg5, ndcg10, ndcg20, ndcg50, ndcg100])#, hits5, hits10, hits20, hits50, hits100, map1, map5, map10, map20, map50, map100])

  mean = torch.Tensor(metrics_val).mean(axis=0)
  test_hits1, test_hits5, test_hits10, test_hits20, test_hits50, test_hits100, test_map1, test_map5, test_map10, test_map20, test_map50, test_map100, \
  test_ndcg1, test_ndcg5, test_ndcg10, test_ndcg20, test_ndcg50, test_ndcg100 = mean

  return test_hits1, test_hits5, test_hits10, test_hits20, test_hits50, test_hits100, test_map1, test_map5, test_map10, test_map20, test_map50, test_map100, \
  test_ndcg1, test_ndcg5, test_ndcg10, test_ndcg20, test_ndcg50, test_ndcg100

In [50]:
def train2task():
  loss_fn = nn.CrossEntropyLoss()
  head_model  = NextActionPredHead()

  head_model.train()
  optimizer = torch.optim.Adam(head_model.parameters(), lr = 0.001)
  lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.97, last_epoch=-1)
  losses = []

  for i in range(200):
    for users, vectors, labels in user_train_dataloader:

        optimizer.zero_grad()

        probs = head_model(vectors)
        pred = torch.argmax(probs, axis = 1)
        one_hot = torch.zeros(len(probs), num_classes)
        one_hot[torch.arange(len(one_hot)), labels] = 1
        loss = loss_fn(probs, one_hot)

        losses.append(loss)
        loss.backward()
        optimizer.step()
    mean_loss = torch.Tensor(losses).mean(axis=0).item()
    print(f'Epoch: {i}, Loss: {mean_loss}')
  hits1, hits5, hits10, hits20, hits50, hits100, map1, map5, map10, map20, map50, map100, ndcg1, ndcg5, ndcg10, ndcg20, ndcg50, ndcg100 = test(head_model)
  print_metrics(hits1, hits5, hits10, hits20, hits50, hits100, map1, map5, map10, map20, map50, map100, ndcg1, ndcg5, ndcg10, ndcg20, ndcg50, ndcg100)
  return head_model

### Node mix augmentation, only infomax

In [42]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
model = DeepGraphInfomax(hidden_channels=128,
                         encoder=Encoder(interm_channels=74, hidden_channels=128,
                                         num_features=128),
                         summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
                         corruption=node_mix).to(device)
linkmodel = LinkPredHead()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [54]:
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    pos_z, neg_z, summary, loss = train_only_infomax(epoch)
  else:
    loss = train_only_infomax(epoch)
  loss_epoch.append(loss)

  1%|          | 1/100 [00:09<15:58,  9.68s/it]

Epoch: 1, Loss: 0.3002675175666809


  2%|▏         | 2/100 [00:16<13:11,  8.07s/it]

Epoch: 2, Loss: 0.30694717168807983


  3%|▎         | 3/100 [00:22<11:45,  7.27s/it]

Epoch: 3, Loss: 0.26518505811691284


  4%|▍         | 4/100 [00:30<11:35,  7.25s/it]

Epoch: 4, Loss: 0.37640535831451416


  5%|▌         | 5/100 [00:37<11:20,  7.16s/it]

Epoch: 5, Loss: 0.23803697526454926


  6%|▌         | 6/100 [00:43<10:57,  7.00s/it]

Epoch: 6, Loss: 0.3896709084510803


  7%|▋         | 7/100 [00:49<10:17,  6.64s/it]

Epoch: 7, Loss: 0.22360339760780334


  8%|▊         | 8/100 [00:56<10:09,  6.63s/it]

Epoch: 8, Loss: 0.2510923445224762


  9%|▉         | 9/100 [01:02<09:38,  6.36s/it]

Epoch: 9, Loss: 0.16009752452373505


 10%|█         | 10/100 [01:08<09:41,  6.46s/it]

Epoch: 10, Loss: 0.15542955696582794


 11%|█         | 11/100 [01:14<09:19,  6.29s/it]

Epoch: 11, Loss: 0.1545710563659668


 12%|█▏        | 12/100 [01:21<09:20,  6.37s/it]

Epoch: 12, Loss: 0.4257308840751648


 13%|█▎        | 13/100 [01:27<09:07,  6.29s/it]

Epoch: 13, Loss: 0.2021554410457611


 14%|█▍        | 14/100 [01:34<09:20,  6.52s/it]

Epoch: 14, Loss: 0.1513550579547882


 15%|█▌        | 15/100 [01:40<09:01,  6.37s/it]

Epoch: 15, Loss: 0.13279816508293152


 16%|█▌        | 16/100 [01:47<09:11,  6.56s/it]

Epoch: 16, Loss: 0.2059984803199768


 17%|█▋        | 17/100 [01:53<08:51,  6.41s/it]

Epoch: 17, Loss: 0.2528940737247467


 18%|█▊        | 18/100 [02:00<08:51,  6.48s/it]

Epoch: 18, Loss: 0.2083817422389984


 19%|█▉        | 19/100 [02:05<08:25,  6.25s/it]

Epoch: 19, Loss: 0.17879986763000488


 20%|██        | 20/100 [02:13<08:43,  6.54s/it]

Epoch: 20, Loss: 0.1757945716381073


 21%|██        | 21/100 [02:19<08:33,  6.50s/it]

Epoch: 21, Loss: 0.23683615028858185


 22%|██▏       | 22/100 [02:26<08:45,  6.73s/it]

Epoch: 22, Loss: 0.15496590733528137


 23%|██▎       | 23/100 [02:33<08:27,  6.59s/it]

Epoch: 23, Loss: 0.14650748670101166


 24%|██▍       | 24/100 [02:39<08:29,  6.70s/it]

Epoch: 24, Loss: 0.1578107476234436


 25%|██▌       | 25/100 [02:46<08:10,  6.54s/it]

Epoch: 25, Loss: 0.2268022894859314


 26%|██▌       | 26/100 [02:53<08:13,  6.67s/it]

Epoch: 26, Loss: 0.14022402465343475


 27%|██▋       | 27/100 [02:59<07:53,  6.49s/it]

Epoch: 27, Loss: 0.16834306716918945


 28%|██▊       | 28/100 [03:05<07:53,  6.57s/it]

Epoch: 28, Loss: 0.2711249589920044


 29%|██▉       | 29/100 [03:12<07:35,  6.41s/it]

Epoch: 29, Loss: 0.170964777469635


 30%|███       | 30/100 [03:18<07:35,  6.51s/it]

Epoch: 30, Loss: 0.20412488281726837


 31%|███       | 31/100 [03:24<07:16,  6.32s/it]

Epoch: 31, Loss: 0.15708477795124054


 32%|███▏      | 32/100 [03:31<07:16,  6.42s/it]

Epoch: 32, Loss: 0.15989485383033752


 33%|███▎      | 33/100 [03:36<06:55,  6.20s/it]

Epoch: 33, Loss: 0.15788307785987854


 34%|███▍      | 34/100 [03:43<06:58,  6.34s/it]

Epoch: 34, Loss: 0.19117417931556702


 35%|███▌      | 35/100 [03:49<06:38,  6.14s/it]

Epoch: 35, Loss: 0.15301983058452606


 36%|███▌      | 36/100 [03:56<06:45,  6.33s/it]

Epoch: 36, Loss: 0.16762995719909668


 37%|███▋      | 37/100 [04:01<06:30,  6.19s/it]

Epoch: 37, Loss: 0.20510390400886536


 38%|███▊      | 38/100 [04:08<06:30,  6.29s/it]

Epoch: 38, Loss: 0.22311821579933167


 39%|███▉      | 39/100 [04:14<06:13,  6.13s/it]

Epoch: 39, Loss: 0.16185416281223297


 40%|████      | 40/100 [04:21<06:19,  6.33s/it]

Epoch: 40, Loss: 0.1469678282737732


 41%|████      | 41/100 [04:26<06:03,  6.16s/it]

Epoch: 41, Loss: 0.13009658455848694


 42%|████▏     | 42/100 [04:33<06:05,  6.29s/it]

Epoch: 42, Loss: 0.205097496509552


 43%|████▎     | 43/100 [04:39<05:58,  6.29s/it]

Epoch: 43, Loss: 0.16404470801353455


 44%|████▍     | 44/100 [04:46<05:59,  6.42s/it]

Epoch: 44, Loss: 0.13248814642429352


 45%|████▌     | 45/100 [04:52<05:47,  6.31s/it]

Epoch: 45, Loss: 0.22499555349349976


 46%|████▌     | 46/100 [04:59<05:51,  6.52s/it]

Epoch: 46, Loss: 0.15422075986862183


 47%|████▋     | 47/100 [05:05<05:38,  6.39s/it]

Epoch: 47, Loss: 0.15052026510238647


 48%|████▊     | 48/100 [05:12<05:41,  6.56s/it]

Epoch: 48, Loss: 0.14086061716079712


 49%|████▉     | 49/100 [05:18<05:33,  6.54s/it]

Epoch: 49, Loss: 0.24439141154289246


 50%|█████     | 50/100 [05:26<05:34,  6.70s/it]

Epoch: 50, Loss: 0.1890474557876587


 51%|█████     | 51/100 [05:32<05:17,  6.47s/it]

Epoch: 51, Loss: 0.15591935813426971


 52%|█████▏    | 52/100 [05:38<05:11,  6.48s/it]

Epoch: 52, Loss: 0.13176171481609344


 53%|█████▎    | 53/100 [05:44<04:54,  6.26s/it]

Epoch: 53, Loss: 0.1342114806175232


 54%|█████▍    | 54/100 [05:50<04:52,  6.36s/it]

Epoch: 54, Loss: 0.15334489941596985


 55%|█████▌    | 55/100 [05:56<04:40,  6.24s/it]

Epoch: 55, Loss: 0.1330406367778778


 56%|█████▌    | 56/100 [06:06<05:16,  7.18s/it]

Epoch: 56, Loss: 0.13244040310382843


 57%|█████▋    | 57/100 [06:17<06:08,  8.57s/it]

Epoch: 57, Loss: 0.14043393731117249


 58%|█████▊    | 58/100 [06:24<05:37,  8.04s/it]

Epoch: 58, Loss: 0.21874941885471344


 59%|█████▉    | 59/100 [06:31<05:18,  7.77s/it]

Epoch: 59, Loss: 0.14789608120918274


 60%|██████    | 60/100 [06:38<04:51,  7.28s/it]

Epoch: 60, Loss: 0.14554563164710999


 61%|██████    | 61/100 [06:44<04:38,  7.14s/it]

Epoch: 61, Loss: 0.14839476346969604


 62%|██████▏   | 62/100 [06:50<04:16,  6.76s/it]

Epoch: 62, Loss: 0.1151055246591568


 63%|██████▎   | 63/100 [06:57<04:09,  6.73s/it]

Epoch: 63, Loss: 0.1547539234161377


 64%|██████▍   | 64/100 [07:03<03:51,  6.42s/it]

Epoch: 64, Loss: 0.14166928827762604


 65%|██████▌   | 65/100 [07:09<03:45,  6.45s/it]

Epoch: 65, Loss: 0.14612534642219543


 66%|██████▌   | 66/100 [07:15<03:32,  6.25s/it]

Epoch: 66, Loss: 0.12506693601608276


 67%|██████▋   | 67/100 [07:21<03:29,  6.34s/it]

Epoch: 67, Loss: 0.19205403327941895


 68%|██████▊   | 68/100 [07:27<03:18,  6.21s/it]

Epoch: 68, Loss: 0.126337468624115


 69%|██████▉   | 69/100 [07:34<03:16,  6.35s/it]

Epoch: 69, Loss: 0.181796133518219


 70%|███████   | 70/100 [07:40<03:05,  6.17s/it]

Epoch: 70, Loss: 0.1407671570777893


 71%|███████   | 71/100 [07:46<03:03,  6.33s/it]

Epoch: 71, Loss: 0.19136682152748108


 72%|███████▏  | 72/100 [07:52<02:51,  6.14s/it]

Epoch: 72, Loss: 0.1800568848848343


 73%|███████▎  | 73/100 [07:59<02:48,  6.26s/it]

Epoch: 73, Loss: 0.14817996323108673


 74%|███████▍  | 74/100 [08:04<02:37,  6.07s/it]

Epoch: 74, Loss: 0.17347538471221924


 75%|███████▌  | 75/100 [08:11<02:35,  6.22s/it]

Epoch: 75, Loss: 0.1272200644016266


 76%|███████▌  | 76/100 [08:17<02:25,  6.08s/it]

Epoch: 76, Loss: 0.1569189876317978


 77%|███████▋  | 77/100 [08:23<02:24,  6.30s/it]

Epoch: 77, Loss: 0.12976409494876862


 78%|███████▊  | 78/100 [08:30<02:20,  6.41s/it]

Epoch: 78, Loss: 0.14737485349178314


 79%|███████▉  | 79/100 [08:38<02:20,  6.71s/it]

Epoch: 79, Loss: 0.15660804510116577


 80%|████████  | 80/100 [08:43<02:08,  6.43s/it]

Epoch: 80, Loss: 0.13563519716262817


 81%|████████  | 81/100 [08:50<02:02,  6.47s/it]

Epoch: 81, Loss: 0.1281663030385971


 82%|████████▏ | 82/100 [08:56<01:52,  6.23s/it]

Epoch: 82, Loss: 0.19340670108795166


 83%|████████▎ | 83/100 [09:02<01:47,  6.32s/it]

Epoch: 83, Loss: 0.1464885026216507


 84%|████████▍ | 84/100 [09:08<01:40,  6.27s/it]

Epoch: 84, Loss: 0.15055608749389648


 85%|████████▌ | 85/100 [09:15<01:36,  6.44s/it]

Epoch: 85, Loss: 0.14342889189720154


 86%|████████▌ | 86/100 [09:21<01:27,  6.26s/it]

Epoch: 86, Loss: 0.13952520489692688


 87%|████████▋ | 87/100 [09:27<01:22,  6.35s/it]

Epoch: 87, Loss: 0.12110862135887146


 88%|████████▊ | 88/100 [09:33<01:14,  6.18s/it]

Epoch: 88, Loss: 0.15002602338790894


 89%|████████▉ | 89/100 [09:40<01:09,  6.32s/it]

Epoch: 89, Loss: 0.11163340508937836


 90%|█████████ | 90/100 [09:46<01:01,  6.17s/it]

Epoch: 90, Loss: 0.10398396849632263


 91%|█████████ | 91/100 [09:52<00:56,  6.31s/it]

Epoch: 91, Loss: 0.13373365998268127


 92%|█████████▏| 92/100 [09:59<00:51,  6.43s/it]

Epoch: 92, Loss: 0.14581242203712463


 93%|█████████▎| 93/100 [10:05<00:44,  6.34s/it]

Epoch: 93, Loss: 0.12000899016857147


 94%|█████████▍| 94/100 [10:12<00:38,  6.48s/it]

Epoch: 94, Loss: 0.13284677267074585


 95%|█████████▌| 95/100 [10:18<00:32,  6.42s/it]

Epoch: 95, Loss: 0.15432047843933105


 96%|█████████▌| 96/100 [10:25<00:25,  6.41s/it]

Epoch: 96, Loss: 0.1384410858154297


 97%|█████████▋| 97/100 [10:31<00:18,  6.30s/it]

Epoch: 97, Loss: 0.19565926492214203


 98%|█████████▊| 98/100 [10:37<00:12,  6.31s/it]

Epoch: 98, Loss: 0.10686999559402466


 99%|█████████▉| 99/100 [10:43<00:06,  6.23s/it]

Epoch: 99, Loss: 0.1630438268184662


100%|██████████| 100/100 [10:50<00:00,  6.50s/it]

Epoch: 100, Loss: 0.13529092073440552





In [55]:
final_user_representations = {}
repr = pos_z
for i in range(pos_z.shape[0]):
  final_user_representations[i] = (i, repr[i].detach(), user_representations[i][1])

In [56]:
user_dataset = UserDataset(final_user_representations)

n = len(user_dataset)

indices = np.arange(n)
indices = np.random.permutation(indices)

train_indices = indices [:int(0.8*n)]
test_indices = indices[int(0.8*n):]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=True)

In [57]:
num_classes = 186
classes = np.arange(0, num_classes)

In [58]:
head_model = train2task()
torch.save(head_model.state_dict(), 'node_mix_only_infomax_model.pth')

Epoch: 0, Loss: 4.285384654998779
Epoch: 1, Loss: 4.190764904022217
Epoch: 2, Loss: 4.137697696685791
Epoch: 3, Loss: 4.099329948425293
Epoch: 4, Loss: 4.069395542144775
Epoch: 5, Loss: 4.044259071350098
Epoch: 6, Loss: 4.022629261016846
Epoch: 7, Loss: 4.003567695617676
Epoch: 8, Loss: 3.9867825508117676
Epoch: 9, Loss: 3.971529960632324
Epoch: 10, Loss: 3.957684278488159
Epoch: 11, Loss: 3.9449241161346436
Epoch: 12, Loss: 3.933150291442871
Epoch: 13, Loss: 3.9221835136413574
Epoch: 14, Loss: 3.911931276321411
Epoch: 15, Loss: 3.902322292327881
Epoch: 16, Loss: 3.8933143615722656
Epoch: 17, Loss: 3.8847427368164062
Epoch: 18, Loss: 3.8766443729400635
Epoch: 19, Loss: 3.868969440460205
Epoch: 20, Loss: 3.861726999282837
Epoch: 21, Loss: 3.8548171520233154
Epoch: 22, Loss: 3.848130464553833
Epoch: 23, Loss: 3.841722249984741
Epoch: 24, Loss: 3.8356709480285645
Epoch: 25, Loss: 3.829836130142212
Epoch: 26, Loss: 3.8241353034973145
Epoch: 27, Loss: 3.818697214126587
Epoch: 28, Loss: 3.81

### Node dropout augmentation, only infomax

In [59]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
model = DeepGraphInfomax(hidden_channels=128,
                         encoder=Encoder(interm_channels=74, hidden_channels=128,
                                         num_features=128),
                         summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
                         corruption=node_dropout).to(device)
linkmodel = LinkPredHead()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [60]:
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    pos_z, neg_z, summary, loss = train_only_infomax(epoch)
  else:
    loss = train_only_infomax(epoch)
  loss_epoch.append(loss)

  1%|          | 1/100 [00:08<13:26,  8.15s/it]

Epoch: 1, Loss: 1.4168205261230469


  2%|▏         | 2/100 [00:15<13:00,  7.96s/it]

Epoch: 2, Loss: 1.591853141784668


  3%|▎         | 3/100 [00:23<12:13,  7.56s/it]

Epoch: 3, Loss: 1.4212048053741455


  4%|▍         | 4/100 [00:30<12:17,  7.68s/it]

Epoch: 4, Loss: 1.3830645084381104


  5%|▌         | 5/100 [00:38<12:17,  7.77s/it]

Epoch: 5, Loss: 1.475018858909607


  6%|▌         | 6/100 [00:46<12:12,  7.79s/it]

Epoch: 6, Loss: 1.4030492305755615


  7%|▋         | 7/100 [00:54<12:05,  7.80s/it]

Epoch: 7, Loss: 1.4170567989349365


  8%|▊         | 8/100 [01:01<11:30,  7.50s/it]

Epoch: 8, Loss: 1.1938517093658447


  9%|▉         | 9/100 [01:09<11:33,  7.62s/it]

Epoch: 9, Loss: 1.4324705600738525


 10%|█         | 10/100 [01:16<11:08,  7.42s/it]

Epoch: 10, Loss: 1.4478225708007812


 11%|█         | 11/100 [01:23<11:08,  7.51s/it]

Epoch: 11, Loss: 1.456478238105774


 12%|█▏        | 12/100 [01:32<11:34,  7.89s/it]

Epoch: 12, Loss: 1.1826938390731812


 13%|█▎        | 13/100 [01:40<11:16,  7.77s/it]

Epoch: 13, Loss: 1.4611215591430664


 14%|█▍        | 14/100 [01:48<11:11,  7.80s/it]

Epoch: 14, Loss: 1.396592617034912


 15%|█▌        | 15/100 [01:55<10:48,  7.63s/it]

Epoch: 15, Loss: 1.278359293937683


 16%|█▌        | 16/100 [02:03<10:49,  7.73s/it]

Epoch: 16, Loss: 1.3996843099594116


 17%|█▋        | 17/100 [02:10<10:26,  7.54s/it]

Epoch: 17, Loss: 1.1736801862716675


 18%|█▊        | 18/100 [02:17<10:13,  7.48s/it]

Epoch: 18, Loss: 1.280418872833252


 19%|█▉        | 19/100 [02:25<10:19,  7.64s/it]

Epoch: 19, Loss: 1.1688823699951172


 20%|██        | 20/100 [02:32<09:52,  7.41s/it]

Epoch: 20, Loss: 1.533452033996582


 21%|██        | 21/100 [02:40<09:57,  7.56s/it]

Epoch: 21, Loss: 1.376896619796753


 22%|██▏       | 22/100 [02:47<09:40,  7.45s/it]

Epoch: 22, Loss: 1.4878584146499634


 23%|██▎       | 23/100 [02:55<09:42,  7.57s/it]

Epoch: 23, Loss: 1.4129183292388916


 24%|██▍       | 24/100 [03:03<09:33,  7.55s/it]

Epoch: 24, Loss: 1.3518874645233154


 25%|██▌       | 25/100 [03:10<09:18,  7.45s/it]

Epoch: 25, Loss: 1.4134693145751953


 26%|██▌       | 26/100 [03:25<11:53,  9.65s/it]

Epoch: 26, Loss: 1.5204503536224365


 27%|██▋       | 27/100 [03:33<11:09,  9.17s/it]

Epoch: 27, Loss: 1.3951189517974854


 28%|██▊       | 28/100 [03:40<10:17,  8.58s/it]

Epoch: 28, Loss: 1.4244699478149414


 29%|██▉       | 29/100 [03:48<09:54,  8.37s/it]

Epoch: 29, Loss: 1.353857159614563


 30%|███       | 30/100 [03:55<09:18,  7.98s/it]

Epoch: 30, Loss: 1.3371782302856445


 31%|███       | 31/100 [04:03<09:10,  7.97s/it]

Epoch: 31, Loss: 1.284060001373291


 32%|███▏      | 32/100 [04:10<08:55,  7.87s/it]

Epoch: 32, Loss: 1.2138322591781616


 33%|███▎      | 33/100 [04:18<08:33,  7.67s/it]

Epoch: 33, Loss: 1.3548414707183838


 34%|███▍      | 34/100 [04:26<08:34,  7.80s/it]

Epoch: 34, Loss: 1.4006657600402832


 35%|███▌      | 35/100 [04:33<08:15,  7.62s/it]

Epoch: 35, Loss: 1.3051148653030396


 36%|███▌      | 36/100 [04:41<08:18,  7.78s/it]

Epoch: 36, Loss: 1.3959543704986572


 37%|███▋      | 37/100 [04:49<08:11,  7.80s/it]

Epoch: 37, Loss: 1.2533702850341797


 38%|███▊      | 38/100 [04:56<07:50,  7.59s/it]

Epoch: 38, Loss: 1.3324816226959229


 39%|███▉      | 39/100 [05:04<07:51,  7.74s/it]

Epoch: 39, Loss: 1.3885287046432495


 40%|████      | 40/100 [05:11<07:34,  7.58s/it]

Epoch: 40, Loss: 1.2417242527008057


 41%|████      | 41/100 [05:19<07:35,  7.72s/it]

Epoch: 41, Loss: 1.3939311504364014


 42%|████▏     | 42/100 [05:27<07:34,  7.83s/it]

Epoch: 42, Loss: 1.258837103843689


 43%|████▎     | 43/100 [05:35<07:17,  7.68s/it]

Epoch: 43, Loss: 1.405261516571045


 44%|████▍     | 44/100 [05:42<07:11,  7.70s/it]

Epoch: 44, Loss: 1.263129711151123


 45%|████▌     | 45/100 [05:50<06:53,  7.52s/it]

Epoch: 45, Loss: 1.3547449111938477


 46%|████▌     | 46/100 [05:58<06:53,  7.66s/it]

Epoch: 46, Loss: 1.2820305824279785


 47%|████▋     | 47/100 [06:06<06:52,  7.78s/it]

Epoch: 47, Loss: 1.2751684188842773


 48%|████▊     | 48/100 [06:14<06:53,  7.96s/it]

Epoch: 48, Loss: 1.4534101486206055


 49%|████▉     | 49/100 [06:22<06:46,  7.96s/it]

Epoch: 49, Loss: 1.3253811597824097


 50%|█████     | 50/100 [06:29<06:27,  7.75s/it]

Epoch: 50, Loss: 1.442108154296875


 51%|█████     | 51/100 [06:37<06:16,  7.68s/it]

Epoch: 51, Loss: 1.2980668544769287


 52%|█████▏    | 52/100 [06:45<06:15,  7.83s/it]

Epoch: 52, Loss: 1.3662607669830322


 53%|█████▎    | 53/100 [06:52<05:59,  7.66s/it]

Epoch: 53, Loss: 1.4389551877975464


 54%|█████▍    | 54/100 [07:00<05:59,  7.82s/it]

Epoch: 54, Loss: 1.3411061763763428


 55%|█████▌    | 55/100 [07:08<05:44,  7.67s/it]

Epoch: 55, Loss: 1.3281176090240479


 56%|█████▌    | 56/100 [07:15<05:36,  7.66s/it]

Epoch: 56, Loss: 1.2961795330047607


 57%|█████▋    | 57/100 [07:23<05:33,  7.75s/it]

Epoch: 57, Loss: 1.381026268005371


 58%|█████▊    | 58/100 [07:30<05:17,  7.55s/it]

Epoch: 58, Loss: 1.3331272602081299


 59%|█████▉    | 59/100 [07:39<05:18,  7.78s/it]

Epoch: 59, Loss: 1.3327281475067139


 60%|██████    | 60/100 [07:46<05:01,  7.54s/it]

Epoch: 60, Loss: 1.180821180343628


 61%|██████    | 61/100 [07:53<04:57,  7.63s/it]

Epoch: 61, Loss: 1.2973132133483887


 62%|██████▏   | 62/100 [08:01<04:53,  7.73s/it]

Epoch: 62, Loss: 1.268776774406433


 63%|██████▎   | 63/100 [08:08<04:37,  7.50s/it]

Epoch: 63, Loss: 1.1503398418426514


 64%|██████▍   | 64/100 [08:16<04:36,  7.67s/it]

Epoch: 64, Loss: 1.3991646766662598


 65%|██████▌   | 65/100 [08:23<04:21,  7.47s/it]

Epoch: 65, Loss: 1.3248542547225952


 66%|██████▌   | 66/100 [08:31<04:18,  7.60s/it]

Epoch: 66, Loss: 1.1819119453430176


 67%|██████▋   | 67/100 [08:39<04:11,  7.62s/it]

Epoch: 67, Loss: 1.076625943183899


 68%|██████▊   | 68/100 [08:46<03:58,  7.46s/it]

Epoch: 68, Loss: 1.4199209213256836


 69%|██████▉   | 69/100 [08:54<03:55,  7.58s/it]

Epoch: 69, Loss: 1.0919817686080933


 70%|███████   | 70/100 [09:01<03:42,  7.41s/it]

Epoch: 70, Loss: 1.2831720113754272


 71%|███████   | 71/100 [09:09<03:41,  7.63s/it]

Epoch: 71, Loss: 1.0863678455352783


 72%|███████▏  | 72/100 [09:17<03:32,  7.59s/it]

Epoch: 72, Loss: 0.9936298131942749


 73%|███████▎  | 73/100 [09:24<03:22,  7.49s/it]

Epoch: 73, Loss: 1.3257120847702026


 74%|███████▍  | 74/100 [09:32<03:19,  7.65s/it]

Epoch: 74, Loss: 1.3909547328948975


 75%|███████▌  | 75/100 [09:39<03:07,  7.49s/it]

Epoch: 75, Loss: 1.4824336767196655


 76%|███████▌  | 76/100 [09:47<03:04,  7.67s/it]

Epoch: 76, Loss: 0.9029867649078369


 77%|███████▋  | 77/100 [09:55<02:57,  7.71s/it]

Epoch: 77, Loss: 1.423929214477539


 78%|███████▊  | 78/100 [10:02<02:47,  7.60s/it]

Epoch: 78, Loss: 1.482262134552002


 79%|███████▉  | 79/100 [10:10<02:42,  7.72s/it]

Epoch: 79, Loss: 0.8917845487594604


 80%|████████  | 80/100 [10:17<02:31,  7.57s/it]

Epoch: 80, Loss: 0.9987632632255554


 81%|████████  | 81/100 [10:26<02:27,  7.76s/it]

Epoch: 81, Loss: 0.9751673340797424


 82%|████████▏ | 82/100 [10:33<02:18,  7.68s/it]

Epoch: 82, Loss: 1.3815313577651978


 83%|████████▎ | 83/100 [10:41<02:09,  7.64s/it]

Epoch: 83, Loss: 1.6007914543151855


 84%|████████▍ | 84/100 [10:50<02:10,  8.13s/it]

Epoch: 84, Loss: 1.5345025062561035


 85%|████████▌ | 85/100 [10:57<01:58,  7.92s/it]

Epoch: 85, Loss: 0.8411794900894165


 86%|████████▌ | 86/100 [11:06<01:53,  8.09s/it]

Epoch: 86, Loss: 1.505981683731079


 87%|████████▋ | 87/100 [11:13<01:43,  7.93s/it]

Epoch: 87, Loss: 0.9792666435241699


 88%|████████▊ | 88/100 [11:21<01:33,  7.79s/it]

Epoch: 88, Loss: 0.8973676562309265


 89%|████████▉ | 89/100 [11:29<01:26,  7.88s/it]

Epoch: 89, Loss: 1.3571677207946777


 90%|█████████ | 90/100 [11:36<01:17,  7.74s/it]

Epoch: 90, Loss: 0.8249561190605164


 91%|█████████ | 91/100 [11:45<01:10,  7.86s/it]

Epoch: 91, Loss: 0.753054678440094


 92%|█████████▏| 92/100 [11:52<01:02,  7.87s/it]

Epoch: 92, Loss: 1.5296446084976196


 93%|█████████▎| 93/100 [12:00<00:53,  7.71s/it]

Epoch: 93, Loss: 0.7504339218139648


 94%|█████████▍| 94/100 [12:08<00:47,  7.90s/it]

Epoch: 94, Loss: 1.7225679159164429


 95%|█████████▌| 95/100 [12:15<00:38,  7.64s/it]

Epoch: 95, Loss: 1.279322624206543


 96%|█████████▌| 96/100 [12:23<00:31,  7.83s/it]

Epoch: 96, Loss: 1.4029417037963867


 97%|█████████▋| 97/100 [12:31<00:23,  7.83s/it]

Epoch: 97, Loss: 0.7540469169616699


 98%|█████████▊| 98/100 [12:38<00:15,  7.60s/it]

Epoch: 98, Loss: 1.1193156242370605


 99%|█████████▉| 99/100 [12:46<00:07,  7.72s/it]

Epoch: 99, Loss: 0.758968710899353


100%|██████████| 100/100 [12:54<00:00,  7.74s/it]

Epoch: 100, Loss: 0.7786322832107544





In [61]:
final_user_representations = {}
repr = pos_z
for i in range(pos_z.shape[0]):
  final_user_representations[i] = (i, repr[i].detach(), user_representations[i][1])

In [62]:
user_dataset = UserDataset(final_user_representations)

n = len(user_dataset)

indices = np.arange(n)
indices = np.random.permutation(indices)

train_indices = indices [:int(0.8*n)]
test_indices = indices[int(0.8*n):]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=True)

In [63]:
num_classes = 186
classes = np.arange(0, num_classes)

In [64]:
head_model = train2task()
torch.save(head_model.state_dict(), 'node_dropout_only_infomax_model.pth')

Epoch: 0, Loss: 4.443563461303711
Epoch: 1, Loss: 4.228217601776123
Epoch: 2, Loss: 4.109987735748291
Epoch: 3, Loss: 4.02674674987793
Epoch: 4, Loss: 3.9619386196136475
Epoch: 5, Loss: 3.9084553718566895
Epoch: 6, Loss: 3.8629555702209473
Epoch: 7, Loss: 3.8232662677764893
Epoch: 8, Loss: 3.7881197929382324
Epoch: 9, Loss: 3.756572961807251
Epoch: 10, Loss: 3.7279157638549805
Epoch: 11, Loss: 3.7017226219177246
Epoch: 12, Loss: 3.677577257156372
Epoch: 13, Loss: 3.6552114486694336
Epoch: 14, Loss: 3.634401798248291
Epoch: 15, Loss: 3.6149425506591797
Epoch: 16, Loss: 3.596691131591797
Epoch: 17, Loss: 3.579517364501953
Epoch: 18, Loss: 3.563297748565674
Epoch: 19, Loss: 3.5479421615600586
Epoch: 20, Loss: 3.5333645343780518
Epoch: 21, Loss: 3.5194921493530273
Epoch: 22, Loss: 3.506268262863159
Epoch: 23, Loss: 3.4936490058898926
Epoch: 24, Loss: 3.4815561771392822
Epoch: 25, Loss: 3.4699771404266357
Epoch: 26, Loss: 3.458878517150879
Epoch: 27, Loss: 3.4482204914093018
Epoch: 28, Loss

### Edge dropout, only infomax

In [65]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
model = DeepGraphInfomax(hidden_channels=128,
                         encoder=Encoder(interm_channels=74, hidden_channels=128,
                                         num_features=128),
                         summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
                         corruption=edge_dropout).to(device)
linkmodel = LinkPredHead()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [66]:
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    pos_z, neg_z, summary, loss = train_only_infomax(epoch)
  else:
    loss = train_only_infomax(epoch)
  loss_epoch.append(loss)

  1%|          | 1/100 [00:05<09:25,  5.71s/it]

Epoch: 1, Loss: 1.5530227422714233


  2%|▏         | 2/100 [00:12<09:55,  6.08s/it]

Epoch: 2, Loss: 1.5028414726257324


  3%|▎         | 3/100 [00:17<09:24,  5.82s/it]

Epoch: 3, Loss: 1.4372378587722778


  4%|▍         | 4/100 [00:24<09:51,  6.16s/it]

Epoch: 4, Loss: 1.3760030269622803


  5%|▌         | 5/100 [00:30<09:45,  6.17s/it]

Epoch: 5, Loss: 1.371445655822754


  6%|▌         | 6/100 [00:36<09:25,  6.02s/it]

Epoch: 6, Loss: 1.3410354852676392


  7%|▋         | 7/100 [00:42<09:21,  6.04s/it]

Epoch: 7, Loss: 1.4272584915161133


  8%|▊         | 8/100 [00:47<08:58,  5.85s/it]

Epoch: 8, Loss: 1.3407437801361084


  9%|▉         | 9/100 [00:54<09:07,  6.01s/it]

Epoch: 9, Loss: 1.2757248878479004


 10%|█         | 10/100 [00:59<08:46,  5.85s/it]

Epoch: 10, Loss: 1.285747766494751


 11%|█         | 11/100 [01:05<08:52,  5.98s/it]

Epoch: 11, Loss: 1.3431694507598877


 12%|█▏        | 12/100 [01:11<08:33,  5.84s/it]

Epoch: 12, Loss: 1.4050923585891724


 13%|█▎        | 13/100 [01:17<08:39,  5.98s/it]

Epoch: 13, Loss: 1.3715261220932007


 14%|█▍        | 14/100 [01:23<08:21,  5.83s/it]

Epoch: 14, Loss: 1.280372142791748


 15%|█▌        | 15/100 [01:29<08:27,  5.97s/it]

Epoch: 15, Loss: 1.233243703842163


 16%|█▌        | 16/100 [01:34<08:09,  5.83s/it]

Epoch: 16, Loss: 1.2865524291992188


 17%|█▋        | 17/100 [01:40<08:07,  5.88s/it]

Epoch: 17, Loss: 1.2917814254760742


 18%|█▊        | 18/100 [01:46<07:59,  5.85s/it]

Epoch: 18, Loss: 1.3032691478729248


 19%|█▉        | 19/100 [01:52<07:48,  5.79s/it]

Epoch: 19, Loss: 1.2451093196868896


 20%|██        | 20/100 [01:58<07:53,  5.92s/it]

Epoch: 20, Loss: 1.2962908744812012


 21%|██        | 21/100 [02:04<07:39,  5.81s/it]

Epoch: 21, Loss: 1.2655482292175293


 22%|██▏       | 22/100 [02:10<07:45,  5.97s/it]

Epoch: 22, Loss: 1.0936450958251953


 23%|██▎       | 23/100 [02:16<07:29,  5.84s/it]

Epoch: 23, Loss: 1.1857309341430664


 24%|██▍       | 24/100 [02:22<07:37,  6.01s/it]

Epoch: 24, Loss: 1.265112280845642


 25%|██▌       | 25/100 [02:27<07:19,  5.86s/it]

Epoch: 25, Loss: 1.3547712564468384


 26%|██▌       | 26/100 [02:34<07:22,  5.98s/it]

Epoch: 26, Loss: 1.0556302070617676


 27%|██▋       | 27/100 [02:39<07:05,  5.83s/it]

Epoch: 27, Loss: 0.9517903327941895


 28%|██▊       | 28/100 [02:45<07:10,  5.97s/it]

Epoch: 28, Loss: 0.9913949966430664


 29%|██▉       | 29/100 [02:51<06:54,  5.84s/it]

Epoch: 29, Loss: 1.0773794651031494


 30%|███       | 30/100 [02:57<06:50,  5.87s/it]

Epoch: 30, Loss: 1.1767914295196533


 31%|███       | 31/100 [03:03<06:46,  5.89s/it]

Epoch: 31, Loss: 1.0668206214904785


 32%|███▏      | 32/100 [03:08<06:34,  5.81s/it]

Epoch: 32, Loss: 0.937820315361023


 33%|███▎      | 33/100 [03:15<06:38,  5.95s/it]

Epoch: 33, Loss: 0.902592658996582


 34%|███▍      | 34/100 [03:20<06:24,  5.82s/it]

Epoch: 34, Loss: 0.9042137861251831


 35%|███▌      | 35/100 [03:27<06:29,  5.99s/it]

Epoch: 35, Loss: 0.7982853651046753


 36%|███▌      | 36/100 [03:32<06:14,  5.85s/it]

Epoch: 36, Loss: 0.8056641817092896


 37%|███▋      | 37/100 [03:39<06:18,  6.01s/it]

Epoch: 37, Loss: 0.7302840352058411


 38%|███▊      | 38/100 [03:44<06:03,  5.87s/it]

Epoch: 38, Loss: 0.7521911859512329


 39%|███▉      | 39/100 [03:51<06:07,  6.03s/it]

Epoch: 39, Loss: 0.861760139465332


 40%|████      | 40/100 [03:56<05:53,  5.89s/it]

Epoch: 40, Loss: 0.8185726404190063


 41%|████      | 41/100 [04:02<05:55,  6.02s/it]

Epoch: 41, Loss: 0.7894019484519958


 42%|████▏     | 42/100 [04:08<05:40,  5.87s/it]

Epoch: 42, Loss: 0.6845830082893372


 43%|████▎     | 43/100 [04:14<05:37,  5.92s/it]

Epoch: 43, Loss: 0.7483940124511719


 44%|████▍     | 44/100 [04:20<05:30,  5.91s/it]

Epoch: 44, Loss: 0.6706234216690063


 45%|████▌     | 45/100 [04:26<05:21,  5.84s/it]

Epoch: 45, Loss: 0.5252074599266052


 46%|████▌     | 46/100 [04:32<05:20,  5.94s/it]

Epoch: 46, Loss: 0.7710164189338684


 47%|████▋     | 47/100 [04:37<05:08,  5.82s/it]

Epoch: 47, Loss: 0.6717877388000488


 48%|████▊     | 48/100 [04:44<05:10,  5.98s/it]

Epoch: 48, Loss: 0.9058713912963867


 49%|████▉     | 49/100 [04:49<04:56,  5.82s/it]

Epoch: 49, Loss: 0.8275600671768188


 50%|█████     | 50/100 [04:56<05:07,  6.15s/it]

Epoch: 50, Loss: 0.5637804865837097


 51%|█████     | 51/100 [05:02<04:56,  6.05s/it]

Epoch: 51, Loss: 0.6523633003234863


 52%|█████▏    | 52/100 [05:08<04:55,  6.15s/it]

Epoch: 52, Loss: 0.6091826558113098


 53%|█████▎    | 53/100 [05:14<04:39,  5.95s/it]

Epoch: 53, Loss: 0.6664692163467407


 54%|█████▍    | 54/100 [05:20<04:39,  6.07s/it]

Epoch: 54, Loss: 0.6547764539718628


 55%|█████▌    | 55/100 [05:25<04:24,  5.89s/it]

Epoch: 55, Loss: 0.9247150421142578


 56%|█████▌    | 56/100 [05:32<04:24,  6.01s/it]

Epoch: 56, Loss: 0.6564657688140869


 57%|█████▋    | 57/100 [05:37<04:12,  5.88s/it]

Epoch: 57, Loss: 0.6348287463188171


 58%|█████▊    | 58/100 [05:44<04:15,  6.08s/it]

Epoch: 58, Loss: 0.6362752318382263


 59%|█████▉    | 59/100 [05:50<04:07,  6.03s/it]

Epoch: 59, Loss: 0.651766836643219


 60%|██████    | 60/100 [05:55<03:56,  5.91s/it]

Epoch: 60, Loss: 0.5797518491744995


 61%|██████    | 61/100 [06:02<03:54,  6.01s/it]

Epoch: 61, Loss: 0.7314668893814087


 62%|██████▏   | 62/100 [06:07<03:43,  5.87s/it]

Epoch: 62, Loss: 0.5151911973953247


 63%|██████▎   | 63/100 [06:14<03:42,  6.02s/it]

Epoch: 63, Loss: 0.5990029573440552


 64%|██████▍   | 64/100 [06:19<03:31,  5.87s/it]

Epoch: 64, Loss: 0.6065275073051453


 65%|██████▌   | 65/100 [06:26<03:31,  6.03s/it]

Epoch: 65, Loss: 0.5718080997467041


 66%|██████▌   | 66/100 [06:31<03:18,  5.83s/it]

Epoch: 66, Loss: 0.4960091710090637


 67%|██████▋   | 67/100 [06:37<03:16,  5.95s/it]

Epoch: 67, Loss: 0.5619643926620483


 68%|██████▊   | 68/100 [06:43<03:06,  5.82s/it]

Epoch: 68, Loss: 0.64657062292099


 69%|██████▉   | 69/100 [06:49<03:05,  5.97s/it]

Epoch: 69, Loss: 0.5532312393188477


 70%|███████   | 70/100 [06:55<02:55,  5.85s/it]

Epoch: 70, Loss: 0.5995537042617798


 71%|███████   | 71/100 [07:01<02:51,  5.90s/it]

Epoch: 71, Loss: 0.7022536396980286


 72%|███████▏  | 72/100 [07:06<02:44,  5.89s/it]

Epoch: 72, Loss: 0.6778134107589722


 73%|███████▎  | 73/100 [07:12<02:37,  5.82s/it]

Epoch: 73, Loss: 0.4786166250705719


 74%|███████▍  | 74/100 [07:18<02:34,  5.95s/it]

Epoch: 74, Loss: 0.6167012453079224


 75%|███████▌  | 75/100 [07:24<02:25,  5.83s/it]

Epoch: 75, Loss: 0.4903317093849182


 76%|███████▌  | 76/100 [07:30<02:22,  5.93s/it]

Epoch: 76, Loss: 0.49548137187957764


 77%|███████▋  | 77/100 [07:36<02:13,  5.80s/it]

Epoch: 77, Loss: 0.6782726645469666


 78%|███████▊  | 78/100 [07:42<02:11,  5.98s/it]

Epoch: 78, Loss: 0.6162128448486328


 79%|███████▉  | 79/100 [07:47<02:02,  5.84s/it]

Epoch: 79, Loss: 0.5221194624900818


 80%|████████  | 80/100 [07:54<02:00,  6.04s/it]

Epoch: 80, Loss: 0.6164288520812988


 81%|████████  | 81/100 [07:59<01:51,  5.89s/it]

Epoch: 81, Loss: 0.4847797751426697


 82%|████████▏ | 82/100 [08:06<01:48,  6.04s/it]

Epoch: 82, Loss: 0.6596328020095825


 83%|████████▎ | 83/100 [08:11<01:40,  5.90s/it]

Epoch: 83, Loss: 0.6651067733764648


 84%|████████▍ | 84/100 [08:18<01:35,  5.96s/it]

Epoch: 84, Loss: 0.5531089901924133


 85%|████████▌ | 85/100 [08:23<01:28,  5.91s/it]

Epoch: 85, Loss: 0.4573189914226532


 86%|████████▌ | 86/100 [08:29<01:21,  5.85s/it]

Epoch: 86, Loss: 0.4885534346103668


 87%|████████▋ | 87/100 [08:35<01:17,  5.97s/it]

Epoch: 87, Loss: 0.451158732175827


 88%|████████▊ | 88/100 [08:41<01:10,  5.85s/it]

Epoch: 88, Loss: 0.5822319984436035


 89%|████████▉ | 89/100 [08:47<01:06,  6.01s/it]

Epoch: 89, Loss: 0.6069812178611755


 90%|█████████ | 90/100 [08:53<00:58,  5.88s/it]

Epoch: 90, Loss: 0.4350907802581787


 91%|█████████ | 91/100 [08:59<00:54,  6.03s/it]

Epoch: 91, Loss: 0.400992214679718


 92%|█████████▏| 92/100 [09:05<00:47,  5.89s/it]

Epoch: 92, Loss: 0.3757736384868622


 93%|█████████▎| 93/100 [09:11<00:42,  6.04s/it]

Epoch: 93, Loss: 0.4884915351867676


 94%|█████████▍| 94/100 [09:17<00:35,  5.89s/it]

Epoch: 94, Loss: 0.40156930685043335


 95%|█████████▌| 95/100 [09:23<00:30,  6.04s/it]

Epoch: 95, Loss: 0.5747473239898682


 96%|█████████▌| 96/100 [09:29<00:24,  6.06s/it]

Epoch: 96, Loss: 0.5181127786636353


 97%|█████████▋| 97/100 [09:36<00:18,  6.24s/it]

Epoch: 97, Loss: 0.4606102705001831


 98%|█████████▊| 98/100 [09:41<00:12,  6.04s/it]

Epoch: 98, Loss: 0.45930391550064087


 99%|█████████▉| 99/100 [09:48<00:06,  6.10s/it]

Epoch: 99, Loss: 0.42324739694595337


100%|██████████| 100/100 [09:53<00:00,  5.94s/it]

Epoch: 100, Loss: 0.5164247155189514





In [67]:
final_user_representations = {}
repr = pos_z
for i in range(pos_z.shape[0]):
  final_user_representations[i] = (i, repr[i].detach(), user_representations[i][1])

In [68]:
user_dataset = UserDataset(final_user_representations)

n = len(user_dataset)

indices = np.arange(n)
indices = np.random.permutation(indices)

train_indices = indices [:int(0.8*n)]
test_indices = indices[int(0.8*n):]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=True)

In [69]:
num_classes = 186
classes = np.arange(0, num_classes)

In [70]:
head_model = train2task()
torch.save(head_model.state_dict(), 'edge_dropout_only_infomax_model.pth')

Epoch: 0, Loss: 4.3176398277282715
Epoch: 1, Loss: 4.181057929992676
Epoch: 2, Loss: 4.10624361038208
Epoch: 3, Loss: 4.054159641265869
Epoch: 4, Loss: 4.013622760772705
Epoch: 5, Loss: 3.980647563934326
Epoch: 6, Loss: 3.952723979949951
Epoch: 7, Loss: 3.928549289703369
Epoch: 8, Loss: 3.907130718231201
Epoch: 9, Loss: 3.887968063354492
Epoch: 10, Loss: 3.870573043823242
Epoch: 11, Loss: 3.8547472953796387
Epoch: 12, Loss: 3.8401265144348145
Epoch: 13, Loss: 3.8266243934631348
Epoch: 14, Loss: 3.814082622528076
Epoch: 15, Loss: 3.8023440837860107
Epoch: 16, Loss: 3.791351318359375
Epoch: 17, Loss: 3.780987024307251
Epoch: 18, Loss: 3.7712109088897705
Epoch: 19, Loss: 3.7618789672851562
Epoch: 20, Loss: 3.753053903579712
Epoch: 21, Loss: 3.7446439266204834
Epoch: 22, Loss: 3.7366271018981934
Epoch: 23, Loss: 3.7289466857910156
Epoch: 24, Loss: 3.721579074859619
Epoch: 25, Loss: 3.7145001888275146
Epoch: 26, Loss: 3.707683563232422
Epoch: 27, Loss: 3.701129674911499
Epoch: 28, Loss: 3.6

### Random walk, only infomax

In [71]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
model = DeepGraphInfomax(hidden_channels=128,
                         encoder=Encoder(interm_channels=74, hidden_channels=128,
                                         num_features=128),
                         summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
                         corruption=random_walk).to(device)
linkmodel = LinkPredHead()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [72]:
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    pos_z, neg_z, summary, loss = train_only_infomax(epoch)
  else:
    loss = train_only_infomax(epoch)
  loss_epoch.append(loss)

  1%|          | 1/100 [00:03<05:57,  3.61s/it]

Epoch: 1, Loss: 1.711686372756958


  2%|▏         | 2/100 [00:07<06:35,  4.03s/it]

Epoch: 2, Loss: 1.4071383476257324


  3%|▎         | 3/100 [00:11<06:07,  3.79s/it]

Epoch: 3, Loss: 1.3003087043762207


  4%|▍         | 4/100 [00:14<05:51,  3.66s/it]

Epoch: 4, Loss: 0.9399532675743103


  5%|▌         | 5/100 [00:18<05:49,  3.68s/it]

Epoch: 5, Loss: 1.3750450611114502


  6%|▌         | 6/100 [00:22<06:00,  3.84s/it]

Epoch: 6, Loss: 0.9901720285415649


  7%|▋         | 7/100 [00:26<05:47,  3.73s/it]

Epoch: 7, Loss: 1.1778173446655273


  8%|▊         | 8/100 [00:29<05:35,  3.65s/it]

Epoch: 8, Loss: 0.5564008951187134


  9%|▉         | 9/100 [00:34<05:50,  3.85s/it]

Epoch: 9, Loss: 0.7993941307067871


 10%|█         | 10/100 [00:37<05:36,  3.74s/it]

Epoch: 10, Loss: 0.9172267913818359


 11%|█         | 11/100 [00:41<05:26,  3.66s/it]

Epoch: 11, Loss: 1.1020342111587524


 12%|█▏        | 12/100 [00:44<05:26,  3.71s/it]

Epoch: 12, Loss: 0.43399906158447266


 13%|█▎        | 13/100 [00:48<05:29,  3.78s/it]

Epoch: 13, Loss: 1.9519416093826294


 14%|█▍        | 14/100 [00:52<05:17,  3.69s/it]

Epoch: 14, Loss: 0.9218191504478455


 15%|█▌        | 15/100 [00:55<05:08,  3.62s/it]

Epoch: 15, Loss: 0.7861045598983765


 16%|█▌        | 16/100 [01:00<05:22,  3.84s/it]

Epoch: 16, Loss: 0.5165513157844543


 17%|█▋        | 17/100 [01:03<05:10,  3.74s/it]

Epoch: 17, Loss: 0.9779471755027771


 18%|█▊        | 18/100 [01:07<04:59,  3.66s/it]

Epoch: 18, Loss: 1.0520133972167969


 19%|█▉        | 19/100 [01:11<05:04,  3.76s/it]

Epoch: 19, Loss: 0.9728301763534546


 20%|██        | 20/100 [01:14<05:01,  3.76s/it]

Epoch: 20, Loss: 0.6974212527275085


 21%|██        | 21/100 [01:18<04:50,  3.67s/it]

Epoch: 21, Loss: 1.4848932027816772


 22%|██▏       | 22/100 [01:21<04:40,  3.60s/it]

Epoch: 22, Loss: 0.8766263127326965


 23%|██▎       | 23/100 [01:25<04:53,  3.81s/it]

Epoch: 23, Loss: 0.6860183477401733


 24%|██▍       | 24/100 [01:29<04:41,  3.70s/it]

Epoch: 24, Loss: 0.49025699496269226


 25%|██▌       | 25/100 [01:32<04:32,  3.64s/it]

Epoch: 25, Loss: 0.7988259792327881


 26%|██▌       | 26/100 [01:37<04:40,  3.79s/it]

Epoch: 26, Loss: 1.4876835346221924


 27%|██▋       | 27/100 [01:40<04:35,  3.78s/it]

Epoch: 27, Loss: 1.200552225112915


 28%|██▊       | 28/100 [01:44<04:26,  3.70s/it]

Epoch: 28, Loss: 1.8663630485534668


 29%|██▉       | 29/100 [01:48<04:23,  3.71s/it]

Epoch: 29, Loss: 0.7294443845748901


 30%|███       | 30/100 [01:53<04:51,  4.17s/it]

Epoch: 30, Loss: 0.747558057308197


 31%|███       | 31/100 [01:56<04:33,  3.96s/it]

Epoch: 31, Loss: 0.7424470782279968


 32%|███▏      | 32/100 [02:00<04:19,  3.81s/it]

Epoch: 32, Loss: 1.1767586469650269


 33%|███▎      | 33/100 [02:04<04:15,  3.82s/it]

Epoch: 33, Loss: 0.7068695425987244


 34%|███▍      | 34/100 [02:08<04:13,  3.85s/it]

Epoch: 34, Loss: 0.8451048731803894


 35%|███▌      | 35/100 [02:11<04:05,  3.77s/it]

Epoch: 35, Loss: 1.5653131008148193


 36%|███▌      | 36/100 [02:15<03:56,  3.70s/it]

Epoch: 36, Loss: 0.8931060433387756


 37%|███▋      | 37/100 [02:19<04:04,  3.88s/it]

Epoch: 37, Loss: 0.7661505937576294


 38%|███▊      | 38/100 [02:22<03:53,  3.76s/it]

Epoch: 38, Loss: 0.5772753953933716


 39%|███▉      | 39/100 [02:26<03:44,  3.68s/it]

Epoch: 39, Loss: 1.033058524131775


 40%|████      | 40/100 [02:30<03:48,  3.81s/it]

Epoch: 40, Loss: 0.7141474485397339


 41%|████      | 41/100 [02:34<03:43,  3.79s/it]

Epoch: 41, Loss: 1.3491281270980835


 42%|████▏     | 42/100 [02:37<03:34,  3.70s/it]

Epoch: 42, Loss: 0.8802411556243896


 43%|████▎     | 43/100 [02:41<03:27,  3.64s/it]

Epoch: 43, Loss: 1.038277506828308


 44%|████▍     | 44/100 [02:45<03:35,  3.85s/it]

Epoch: 44, Loss: 0.9672227501869202


 45%|████▌     | 45/100 [02:49<03:26,  3.75s/it]

Epoch: 45, Loss: 0.5891546010971069


 46%|████▌     | 46/100 [02:52<03:18,  3.67s/it]

Epoch: 46, Loss: 0.8722288012504578


 47%|████▋     | 47/100 [02:56<03:24,  3.87s/it]

Epoch: 47, Loss: 0.730302631855011


 48%|████▊     | 48/100 [03:00<03:16,  3.78s/it]

Epoch: 48, Loss: 1.3399732112884521


 49%|████▉     | 49/100 [03:04<03:08,  3.70s/it]

Epoch: 49, Loss: 0.8549857139587402


 50%|█████     | 50/100 [03:07<03:06,  3.73s/it]

Epoch: 50, Loss: 0.7024387717247009


 51%|█████     | 51/100 [03:11<03:07,  3.82s/it]

Epoch: 51, Loss: 0.5977055430412292


 52%|█████▏    | 52/100 [03:15<02:58,  3.73s/it]

Epoch: 52, Loss: 0.7001810669898987


 53%|█████▎    | 53/100 [03:18<02:52,  3.67s/it]

Epoch: 53, Loss: 0.5755147337913513


 54%|█████▍    | 54/100 [03:23<02:57,  3.87s/it]

Epoch: 54, Loss: 0.8487597703933716


 55%|█████▌    | 55/100 [03:26<02:49,  3.76s/it]

Epoch: 55, Loss: 0.458515465259552


 56%|█████▌    | 56/100 [03:30<02:42,  3.68s/it]

Epoch: 56, Loss: 0.6594005823135376


 57%|█████▋    | 57/100 [03:34<02:42,  3.77s/it]

Epoch: 57, Loss: 0.4925558269023895


 58%|█████▊    | 58/100 [03:38<02:39,  3.79s/it]

Epoch: 58, Loss: 0.5689266920089722


 59%|█████▉    | 59/100 [03:41<02:32,  3.71s/it]

Epoch: 59, Loss: 1.3245470523834229


 60%|██████    | 60/100 [03:45<02:26,  3.67s/it]

Epoch: 60, Loss: 1.1076594591140747


 61%|██████    | 61/100 [03:49<02:31,  3.89s/it]

Epoch: 61, Loss: 0.7005027532577515


 62%|██████▏   | 62/100 [03:53<02:23,  3.77s/it]

Epoch: 62, Loss: 0.6459134817123413


 63%|██████▎   | 63/100 [03:56<02:17,  3.72s/it]

Epoch: 63, Loss: 0.5450297594070435


 64%|██████▍   | 64/100 [04:01<02:21,  3.92s/it]

Epoch: 64, Loss: 1.4361319541931152


 65%|██████▌   | 65/100 [04:04<02:12,  3.80s/it]

Epoch: 65, Loss: 0.6036865711212158


 66%|██████▌   | 66/100 [04:08<02:06,  3.71s/it]

Epoch: 66, Loss: 0.6196240782737732


 67%|██████▋   | 67/100 [04:11<02:02,  3.73s/it]

Epoch: 67, Loss: 0.6672444343566895


 68%|██████▊   | 68/100 [04:15<02:02,  3.84s/it]

Epoch: 68, Loss: 1.1080158948898315


 69%|██████▉   | 69/100 [04:19<01:56,  3.75s/it]

Epoch: 69, Loss: 0.8614253401756287


 70%|███████   | 70/100 [04:22<01:50,  3.67s/it]

Epoch: 70, Loss: 0.6354131102561951


 71%|███████   | 71/100 [04:27<01:52,  3.87s/it]

Epoch: 71, Loss: 0.7012133002281189


 72%|███████▏  | 72/100 [04:30<01:45,  3.76s/it]

Epoch: 72, Loss: 0.8832562565803528


 73%|███████▎  | 73/100 [04:34<01:39,  3.69s/it]

Epoch: 73, Loss: 0.5837719440460205


 74%|███████▍  | 74/100 [04:38<01:37,  3.76s/it]

Epoch: 74, Loss: 0.750355064868927


 75%|███████▌  | 75/100 [04:42<01:34,  3.80s/it]

Epoch: 75, Loss: 0.5192269086837769


 76%|███████▌  | 76/100 [04:45<01:28,  3.70s/it]

Epoch: 76, Loss: 0.38322216272354126


 77%|███████▋  | 77/100 [04:49<01:23,  3.64s/it]

Epoch: 77, Loss: 0.573243260383606


 78%|███████▊  | 78/100 [04:53<01:24,  3.84s/it]

Epoch: 78, Loss: 0.5429826378822327


 79%|███████▉  | 79/100 [04:56<01:18,  3.73s/it]

Epoch: 79, Loss: 0.3419991731643677


 80%|████████  | 80/100 [05:00<01:13,  3.65s/it]

Epoch: 80, Loss: 0.4553370475769043


 81%|████████  | 81/100 [05:04<01:12,  3.79s/it]

Epoch: 81, Loss: 0.7708407044410706


 82%|████████▏ | 82/100 [05:08<01:07,  3.76s/it]

Epoch: 82, Loss: 0.9793410301208496


 83%|████████▎ | 83/100 [05:11<01:02,  3.67s/it]

Epoch: 83, Loss: 0.8054823875427246


 84%|████████▍ | 84/100 [05:15<00:58,  3.63s/it]

Epoch: 84, Loss: 0.4554477632045746


 85%|████████▌ | 85/100 [05:19<00:57,  3.82s/it]

Epoch: 85, Loss: 0.8651556372642517


 86%|████████▌ | 86/100 [05:22<00:52,  3.72s/it]

Epoch: 86, Loss: 0.49185678362846375


 87%|████████▋ | 87/100 [05:26<00:47,  3.66s/it]

Epoch: 87, Loss: 0.596562922000885


 88%|████████▊ | 88/100 [05:30<00:46,  3.88s/it]

Epoch: 88, Loss: 0.5497109889984131


 89%|████████▉ | 89/100 [05:34<00:41,  3.77s/it]

Epoch: 89, Loss: 0.5123849511146545


 90%|█████████ | 90/100 [05:37<00:37,  3.71s/it]

Epoch: 90, Loss: 0.5087294578552246


 91%|█████████ | 91/100 [05:41<00:33,  3.76s/it]

Epoch: 91, Loss: 0.5047659277915955


 92%|█████████▏| 92/100 [05:45<00:30,  3.85s/it]

Epoch: 92, Loss: 0.5625943541526794


 93%|█████████▎| 93/100 [05:49<00:26,  3.76s/it]

Epoch: 93, Loss: 1.2534276247024536


 94%|█████████▍| 94/100 [05:52<00:22,  3.69s/it]

Epoch: 94, Loss: 0.3239452540874481


 95%|█████████▌| 95/100 [05:57<00:19,  3.90s/it]

Epoch: 95, Loss: 0.39297789335250854


 96%|█████████▌| 96/100 [06:00<00:15,  3.80s/it]

Epoch: 96, Loss: 0.2911743223667145


 97%|█████████▋| 97/100 [06:04<00:11,  3.71s/it]

Epoch: 97, Loss: 0.4084412157535553


 98%|█████████▊| 98/100 [06:08<00:07,  3.83s/it]

Epoch: 98, Loss: 0.7245776653289795


 99%|█████████▉| 99/100 [06:12<00:03,  3.82s/it]

Epoch: 99, Loss: 1.1137895584106445


100%|██████████| 100/100 [06:15<00:00,  3.76s/it]

Epoch: 100, Loss: 0.4775034785270691





In [73]:
final_user_representations = {}
repr = pos_z
for i in range(pos_z.shape[0]):
  final_user_representations[i] = (i, repr[i].detach(), user_representations[i][1])

In [74]:
user_dataset = UserDataset(final_user_representations)

n = len(user_dataset)

indices = np.arange(n)
indices = np.random.permutation(indices)

train_indices = indices [:int(0.8*n)]
test_indices = indices[int(0.8*n):]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=True)

In [75]:
num_classes = 186
classes = np.arange(0, num_classes)

In [76]:
head_model = train2task()
torch.save(head_model.state_dict(), 'random_walk_only_infomax_model.pth')

Epoch: 0, Loss: 4.408990383148193
Epoch: 1, Loss: 4.213274955749512
Epoch: 2, Loss: 4.10835075378418
Epoch: 3, Loss: 4.036770820617676
Epoch: 4, Loss: 3.982470750808716
Epoch: 5, Loss: 3.9388766288757324
Epoch: 6, Loss: 3.9026286602020264
Epoch: 7, Loss: 3.8715105056762695
Epoch: 8, Loss: 3.8444066047668457
Epoch: 9, Loss: 3.820385456085205
Epoch: 10, Loss: 3.7988123893737793
Epoch: 11, Loss: 3.7792716026306152
Epoch: 12, Loss: 3.7614333629608154
Epoch: 13, Loss: 3.7449991703033447
Epoch: 14, Loss: 3.729807138442993
Epoch: 15, Loss: 3.7157046794891357
Epoch: 16, Loss: 3.7025179862976074
Epoch: 17, Loss: 3.690117120742798
Epoch: 18, Loss: 3.678478717803955
Epoch: 19, Loss: 3.6674821376800537
Epoch: 20, Loss: 3.6570746898651123
Epoch: 21, Loss: 3.6471683979034424
Epoch: 22, Loss: 3.637773036956787
Epoch: 23, Loss: 3.628807306289673
Epoch: 24, Loss: 3.620236873626709
Epoch: 25, Loss: 3.6120357513427734
Epoch: 26, Loss: 3.6041738986968994
Epoch: 27, Loss: 3.5966131687164307
Epoch: 28, Loss

### Node mix, Infomax + Link pred

In [77]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
model = DeepGraphInfomax(hidden_channels=128,
                         encoder=Encoder(interm_channels=74, hidden_channels=128,
                                         num_features=128),
                         summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
                         corruption=node_mix).to(device)
linkmodel = LinkPredHead()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [78]:
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    pos_z, neg_z, summary, loss = train_infomax_linkpred(epoch)
  else:
    loss = train_infomax_linkpred(epoch)
  loss_epoch.append(loss)

  1%|          | 1/100 [00:09<15:38,  9.48s/it]

Epoch: 1, Loss: 3.34798264503479


  2%|▏         | 2/100 [00:16<13:18,  8.15s/it]

Epoch: 2, Loss: 2.848660707473755


  3%|▎         | 3/100 [00:23<12:09,  7.52s/it]

Epoch: 3, Loss: 2.7590811252593994


  4%|▍         | 4/100 [00:30<12:01,  7.51s/it]

Epoch: 4, Loss: 2.729125738143921


  5%|▌         | 5/100 [00:37<11:22,  7.19s/it]

Epoch: 5, Loss: 2.609304189682007


  6%|▌         | 6/100 [00:44<11:21,  7.25s/it]

Epoch: 6, Loss: 2.4844982624053955


  7%|▋         | 7/100 [00:51<10:53,  7.03s/it]

Epoch: 7, Loss: 1.8426642417907715


  8%|▊         | 8/100 [00:58<10:57,  7.14s/it]

Epoch: 8, Loss: 1.4918996095657349


  9%|▉         | 9/100 [01:05<10:32,  6.95s/it]

Epoch: 9, Loss: 2.163602828979492


 10%|█         | 10/100 [01:12<10:36,  7.07s/it]

Epoch: 10, Loss: 2.3957228660583496


 11%|█         | 11/100 [01:19<10:17,  6.93s/it]

Epoch: 11, Loss: 1.6912658214569092


 12%|█▏        | 12/100 [01:26<10:15,  7.00s/it]

Epoch: 12, Loss: 1.7507094144821167


 13%|█▎        | 13/100 [01:33<10:15,  7.07s/it]

Epoch: 13, Loss: 1.3219034671783447


 14%|█▍        | 14/100 [01:40<10:00,  6.99s/it]

Epoch: 14, Loss: 1.3419702053070068


 15%|█▌        | 15/100 [01:48<10:08,  7.16s/it]

Epoch: 15, Loss: 1.545944094657898


 16%|█▌        | 16/100 [01:54<09:47,  6.99s/it]

Epoch: 16, Loss: 1.4216115474700928


 17%|█▋        | 17/100 [02:02<09:52,  7.14s/it]

Epoch: 17, Loss: 1.8323876857757568


 18%|█▊        | 18/100 [02:09<09:56,  7.28s/it]

Epoch: 18, Loss: 1.4754638671875


 19%|█▉        | 19/100 [02:17<09:55,  7.35s/it]

Epoch: 19, Loss: 1.173267126083374


 20%|██        | 20/100 [02:24<09:45,  7.32s/it]

Epoch: 20, Loss: 1.121749758720398


 21%|██        | 21/100 [02:31<09:26,  7.17s/it]

Epoch: 21, Loss: 1.6118253469467163


 22%|██▏       | 22/100 [02:38<09:24,  7.23s/it]

Epoch: 22, Loss: 1.0595853328704834


 23%|██▎       | 23/100 [02:45<09:01,  7.03s/it]

Epoch: 23, Loss: 1.4063026905059814


 24%|██▍       | 24/100 [02:52<09:03,  7.15s/it]

Epoch: 24, Loss: 1.0364270210266113


 25%|██▌       | 25/100 [02:59<08:43,  6.98s/it]

Epoch: 25, Loss: 1.0539063215255737


 26%|██▌       | 26/100 [03:06<08:46,  7.12s/it]

Epoch: 26, Loss: 2.460965633392334


 27%|██▋       | 27/100 [03:13<08:28,  6.97s/it]

Epoch: 27, Loss: 1.0610300302505493


 28%|██▊       | 28/100 [03:20<08:32,  7.12s/it]

Epoch: 28, Loss: 0.9587936997413635


 29%|██▉       | 29/100 [03:27<08:22,  7.08s/it]

Epoch: 29, Loss: 1.0158443450927734


 30%|███       | 30/100 [03:34<08:16,  7.09s/it]

Epoch: 30, Loss: 0.9861739873886108


 31%|███       | 31/100 [03:42<08:15,  7.18s/it]

Epoch: 31, Loss: 1.0190565586090088


 32%|███▏      | 32/100 [03:49<07:59,  7.06s/it]

Epoch: 32, Loss: 1.0783450603485107


 33%|███▎      | 33/100 [03:56<08:00,  7.17s/it]

Epoch: 33, Loss: 0.8828232288360596


 34%|███▍      | 34/100 [04:03<07:41,  6.99s/it]

Epoch: 34, Loss: 0.9615055322647095


 35%|███▌      | 35/100 [04:10<07:42,  7.12s/it]

Epoch: 35, Loss: 0.887057363986969


 36%|███▌      | 36/100 [04:17<07:25,  6.95s/it]

Epoch: 36, Loss: 0.9485597610473633


 37%|███▋      | 37/100 [04:24<07:25,  7.08s/it]

Epoch: 37, Loss: 1.7526440620422363


 38%|███▊      | 38/100 [04:31<07:10,  6.94s/it]

Epoch: 38, Loss: 1.3078804016113281


 39%|███▉      | 39/100 [04:38<07:10,  7.05s/it]

Epoch: 39, Loss: 0.8526381254196167


 40%|████      | 40/100 [04:45<07:10,  7.17s/it]

Epoch: 40, Loss: 0.9802340865135193


 41%|████      | 41/100 [04:52<06:56,  7.07s/it]

Epoch: 41, Loss: 0.8796459436416626


 42%|████▏     | 42/100 [05:00<06:57,  7.21s/it]

Epoch: 42, Loss: 0.8720883131027222


 43%|████▎     | 43/100 [05:06<06:39,  7.01s/it]

Epoch: 43, Loss: 0.8418129086494446


 44%|████▍     | 44/100 [05:14<06:40,  7.16s/it]

Epoch: 44, Loss: 1.1977577209472656


 45%|████▌     | 45/100 [05:20<06:22,  6.96s/it]

Epoch: 45, Loss: 0.7963316440582275


 46%|████▌     | 46/100 [05:28<06:24,  7.13s/it]

Epoch: 46, Loss: 0.8710719347000122


 47%|████▋     | 47/100 [05:35<06:13,  7.05s/it]

Epoch: 47, Loss: 1.1282212734222412


 48%|████▊     | 48/100 [05:42<06:12,  7.16s/it]

Epoch: 48, Loss: 0.7836456298828125


 49%|████▉     | 49/100 [05:50<06:10,  7.26s/it]

Epoch: 49, Loss: 0.8171826601028442


 50%|█████     | 50/100 [05:56<05:52,  7.04s/it]

Epoch: 50, Loss: 0.8397074937820435


 51%|█████     | 51/100 [06:03<05:48,  7.12s/it]

Epoch: 51, Loss: 0.802282452583313


 52%|█████▏    | 52/100 [06:10<05:32,  6.92s/it]

Epoch: 52, Loss: 0.7683838605880737


 53%|█████▎    | 53/100 [06:17<05:30,  7.03s/it]

Epoch: 53, Loss: 0.7830791473388672


 54%|█████▍    | 54/100 [06:24<05:16,  6.88s/it]

Epoch: 54, Loss: 0.8506889343261719


 55%|█████▌    | 55/100 [06:31<05:16,  7.04s/it]

Epoch: 55, Loss: 1.1276350021362305


 56%|█████▌    | 56/100 [06:39<05:15,  7.17s/it]

Epoch: 56, Loss: 0.8493479490280151


 57%|█████▋    | 57/100 [06:46<05:08,  7.17s/it]

Epoch: 57, Loss: 0.7071700096130371


 58%|█████▊    | 58/100 [06:53<05:01,  7.18s/it]

Epoch: 58, Loss: 0.7808656692504883


 59%|█████▉    | 59/100 [07:00<04:49,  7.05s/it]

Epoch: 59, Loss: 0.7279914617538452


 60%|██████    | 60/100 [07:07<04:45,  7.15s/it]

Epoch: 60, Loss: 0.7450906038284302


 61%|██████    | 61/100 [07:14<04:32,  6.98s/it]

Epoch: 61, Loss: 0.7874808311462402


 62%|██████▏   | 62/100 [07:21<04:30,  7.12s/it]

Epoch: 62, Loss: 0.8226426839828491


 63%|██████▎   | 63/100 [07:28<04:16,  6.95s/it]

Epoch: 63, Loss: 1.7490012645721436


 64%|██████▍   | 64/100 [07:35<04:16,  7.12s/it]

Epoch: 64, Loss: 0.7931612133979797


 65%|██████▌   | 65/100 [07:42<04:03,  6.95s/it]

Epoch: 65, Loss: 0.7659151554107666


 66%|██████▌   | 66/100 [07:49<04:00,  7.07s/it]

Epoch: 66, Loss: 0.7396543025970459


 67%|██████▋   | 67/100 [07:56<03:49,  6.97s/it]

Epoch: 67, Loss: 0.9455547332763672


 68%|██████▊   | 68/100 [08:03<03:44,  7.03s/it]

Epoch: 68, Loss: 0.8690348863601685


 69%|██████▉   | 69/100 [08:10<03:39,  7.07s/it]

Epoch: 69, Loss: 1.8928561210632324


 70%|███████   | 70/100 [08:17<03:29,  6.98s/it]

Epoch: 70, Loss: 0.8313658237457275


 71%|███████   | 71/100 [08:24<03:26,  7.12s/it]

Epoch: 71, Loss: 0.8779860138893127


 72%|███████▏  | 72/100 [08:31<03:14,  6.95s/it]

Epoch: 72, Loss: 1.0625269412994385


 73%|███████▎  | 73/100 [08:38<03:10,  7.07s/it]

Epoch: 73, Loss: 0.8247459530830383


 74%|███████▍  | 74/100 [08:45<02:59,  6.92s/it]

Epoch: 74, Loss: 0.9459308981895447


 75%|███████▌  | 75/100 [08:52<02:56,  7.05s/it]

Epoch: 75, Loss: 0.8236207962036133


 76%|███████▌  | 76/100 [08:59<02:45,  6.88s/it]

Epoch: 76, Loss: 0.9261663556098938


 77%|███████▋  | 77/100 [09:06<02:41,  7.02s/it]

Epoch: 77, Loss: 1.1616692543029785


 78%|███████▊  | 78/100 [09:13<02:32,  6.91s/it]

Epoch: 78, Loss: 0.7148345708847046


 79%|███████▉  | 79/100 [09:20<02:26,  6.99s/it]

Epoch: 79, Loss: 0.6924749612808228


 80%|████████  | 80/100 [09:27<02:20,  7.05s/it]

Epoch: 80, Loss: 0.7595512866973877


 81%|████████  | 81/100 [09:34<02:12,  6.98s/it]

Epoch: 81, Loss: 0.6818481683731079


 82%|████████▏ | 82/100 [09:41<02:07,  7.08s/it]

Epoch: 82, Loss: 0.7423734068870544


 83%|████████▎ | 83/100 [09:48<01:57,  6.89s/it]

Epoch: 83, Loss: 1.1222681999206543


 84%|████████▍ | 84/100 [09:55<01:52,  7.05s/it]

Epoch: 84, Loss: 0.7354104518890381


 85%|████████▌ | 85/100 [10:01<01:42,  6.86s/it]

Epoch: 85, Loss: 1.0236214399337769


 86%|████████▌ | 86/100 [10:09<01:38,  7.02s/it]

Epoch: 86, Loss: 0.6785696744918823


 87%|████████▋ | 87/100 [10:15<01:29,  6.86s/it]

Epoch: 87, Loss: 0.6574486494064331


 88%|████████▊ | 88/100 [10:23<01:24,  7.04s/it]

Epoch: 88, Loss: 0.7964504957199097


 89%|████████▉ | 89/100 [10:29<01:16,  6.93s/it]

Epoch: 89, Loss: 0.6888188123703003


 90%|█████████ | 90/100 [10:37<01:09,  6.96s/it]

Epoch: 90, Loss: 0.9549484252929688


 91%|█████████ | 91/100 [10:44<01:03,  7.05s/it]

Epoch: 91, Loss: 0.6809626221656799


 92%|█████████▏| 92/100 [10:51<00:55,  6.96s/it]

Epoch: 92, Loss: 0.681032121181488


 93%|█████████▎| 93/100 [10:58<00:49,  7.10s/it]

Epoch: 93, Loss: 0.8444888591766357


 94%|█████████▍| 94/100 [11:05<00:42,  7.16s/it]

Epoch: 94, Loss: 0.8858747482299805


 95%|█████████▌| 95/100 [11:13<00:36,  7.29s/it]

Epoch: 95, Loss: 0.7405177354812622


 96%|█████████▌| 96/100 [11:19<00:28,  7.08s/it]

Epoch: 96, Loss: 0.6692734360694885


 97%|█████████▋| 97/100 [11:27<00:21,  7.20s/it]

Epoch: 97, Loss: 0.7152574062347412


 98%|█████████▊| 98/100 [11:34<00:14,  7.09s/it]

Epoch: 98, Loss: 0.7284792065620422


 99%|█████████▉| 99/100 [11:41<00:07,  7.10s/it]

Epoch: 99, Loss: 0.8222225904464722


100%|██████████| 100/100 [11:48<00:00,  7.09s/it]

Epoch: 100, Loss: 0.9420307874679565





In [79]:
final_user_representations = {}
repr = pos_z
for i in range(pos_z.shape[0]):
  final_user_representations[i] = (i, repr[i].detach(), user_representations[i][1])

In [80]:
user_dataset = UserDataset(final_user_representations)

n = len(user_dataset)

indices = np.arange(n)
indices = np.random.permutation(indices)

train_indices = indices [:int(0.8*n)]
test_indices = indices[int(0.8*n):]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=True)

In [81]:
num_classes = 186
classes = np.arange(0, num_classes)

In [82]:
head_model = train2task()
torch.save(head_model.state_dict(), 'node_mix_infomax_linkpred_model.pth')

Epoch: 0, Loss: 4.320979595184326
Epoch: 1, Loss: 4.219214916229248
Epoch: 2, Loss: 4.162557125091553
Epoch: 3, Loss: 4.1225810050964355
Epoch: 4, Loss: 4.091488838195801
Epoch: 5, Loss: 4.066256046295166
Epoch: 6, Loss: 4.044521331787109
Epoch: 7, Loss: 4.025753021240234
Epoch: 8, Loss: 4.0090861320495605
Epoch: 9, Loss: 3.9941675662994385
Epoch: 10, Loss: 3.980618953704834
Epoch: 11, Loss: 3.9680333137512207
Epoch: 12, Loss: 3.9565818309783936
Epoch: 13, Loss: 3.945878267288208
Epoch: 14, Loss: 3.9358906745910645
Epoch: 15, Loss: 3.9264256954193115
Epoch: 16, Loss: 3.917603015899658
Epoch: 17, Loss: 3.909235715866089
Epoch: 18, Loss: 3.901371717453003
Epoch: 19, Loss: 3.8938417434692383
Epoch: 20, Loss: 3.8866257667541504
Epoch: 21, Loss: 3.879770517349243
Epoch: 22, Loss: 3.873267412185669
Epoch: 23, Loss: 3.8669345378875732
Epoch: 24, Loss: 3.8609118461608887
Epoch: 25, Loss: 3.8550972938537598
Epoch: 26, Loss: 3.849501848220825
Epoch: 27, Loss: 3.844071626663208
Epoch: 28, Loss: 3

### Node dropout, Infomax + Link pred

In [83]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
model = DeepGraphInfomax(hidden_channels=128,
                         encoder=Encoder(interm_channels=74, hidden_channels=128,
                                         num_features=128),
                         summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
                         corruption=node_dropout).to(device)
linkmodel = LinkPredHead()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [84]:
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    pos_z, neg_z, summary, loss = train_infomax_linkpred(epoch)
  else:
    loss = train_infomax_linkpred(epoch)
  loss_epoch.append(loss)

  1%|          | 1/100 [00:07<13:10,  7.99s/it]

Epoch: 1, Loss: 3.664515256881714


  2%|▏         | 2/100 [00:16<13:31,  8.28s/it]

Epoch: 2, Loss: 3.490450859069824


  3%|▎         | 3/100 [00:25<13:42,  8.47s/it]

Epoch: 3, Loss: 3.3777847290039062


  4%|▍         | 4/100 [00:32<13:01,  8.14s/it]

Epoch: 4, Loss: 3.4287264347076416


  5%|▌         | 5/100 [00:41<13:08,  8.30s/it]

Epoch: 5, Loss: 2.96075439453125


  6%|▌         | 6/100 [00:49<13:03,  8.33s/it]

Epoch: 6, Loss: 3.1812496185302734


  7%|▋         | 7/100 [00:57<12:36,  8.14s/it]

Epoch: 7, Loss: 3.216418743133545


  8%|▊         | 8/100 [01:06<12:46,  8.33s/it]

Epoch: 8, Loss: 3.1292529106140137


  9%|▉         | 9/100 [01:14<12:33,  8.28s/it]

Epoch: 9, Loss: 2.971038818359375


 10%|█         | 10/100 [01:22<12:29,  8.33s/it]

Epoch: 10, Loss: 3.2290589809417725


 11%|█         | 11/100 [01:31<12:24,  8.36s/it]

Epoch: 11, Loss: 2.716290235519409


 12%|█▏        | 12/100 [01:39<12:11,  8.31s/it]

Epoch: 12, Loss: 3.33065128326416


 13%|█▎        | 13/100 [01:46<11:40,  8.05s/it]

Epoch: 13, Loss: 2.9940178394317627


 14%|█▍        | 14/100 [01:55<11:39,  8.13s/it]

Epoch: 14, Loss: 3.175687313079834


 15%|█▌        | 15/100 [02:02<11:15,  7.95s/it]

Epoch: 15, Loss: 2.8403100967407227


 16%|█▌        | 16/100 [02:11<11:18,  8.08s/it]

Epoch: 16, Loss: 2.9970505237579346


 17%|█▋        | 17/100 [02:19<11:21,  8.21s/it]

Epoch: 17, Loss: 3.1157748699188232


 18%|█▊        | 18/100 [02:27<11:00,  8.06s/it]

Epoch: 18, Loss: 3.0964436531066895


 19%|█▉        | 19/100 [02:36<11:11,  8.29s/it]

Epoch: 19, Loss: 3.1601898670196533


 20%|██        | 20/100 [02:44<11:08,  8.36s/it]

Epoch: 20, Loss: 3.246617555618286


 21%|██        | 21/100 [02:52<10:44,  8.16s/it]

Epoch: 21, Loss: 3.091555595397949


 22%|██▏       | 22/100 [03:01<10:53,  8.38s/it]

Epoch: 22, Loss: 3.139610767364502


 23%|██▎       | 23/100 [03:09<10:47,  8.41s/it]

Epoch: 23, Loss: 2.939831495285034


 24%|██▍       | 24/100 [03:17<10:21,  8.18s/it]

Epoch: 24, Loss: 3.301394462585449


 25%|██▌       | 25/100 [03:26<10:22,  8.30s/it]

Epoch: 25, Loss: 2.9104530811309814


 26%|██▌       | 26/100 [03:34<10:09,  8.23s/it]

Epoch: 26, Loss: 3.052274227142334


 27%|██▋       | 27/100 [03:42<10:00,  8.23s/it]

Epoch: 27, Loss: 3.107346773147583


 28%|██▊       | 28/100 [03:50<09:54,  8.26s/it]

Epoch: 28, Loss: 2.7719452381134033


 29%|██▉       | 29/100 [03:58<09:33,  8.07s/it]

Epoch: 29, Loss: 2.760962963104248


 30%|███       | 30/100 [04:06<09:34,  8.21s/it]

Epoch: 30, Loss: 2.9412007331848145


 31%|███       | 31/100 [04:15<09:26,  8.21s/it]

Epoch: 31, Loss: 2.756744146347046


 32%|███▏      | 32/100 [04:22<09:04,  8.01s/it]

Epoch: 32, Loss: 2.7730839252471924


 33%|███▎      | 33/100 [04:31<09:04,  8.13s/it]

Epoch: 33, Loss: 2.6887807846069336


 34%|███▍      | 34/100 [04:39<09:03,  8.24s/it]

Epoch: 34, Loss: 3.0596060752868652


 35%|███▌      | 35/100 [04:47<08:49,  8.15s/it]

Epoch: 35, Loss: 3.465054512023926


 36%|███▌      | 36/100 [04:56<08:52,  8.32s/it]

Epoch: 36, Loss: 3.296570062637329


 37%|███▋      | 37/100 [05:04<08:43,  8.32s/it]

Epoch: 37, Loss: 3.194810390472412


 38%|███▊      | 38/100 [05:12<08:21,  8.09s/it]

Epoch: 38, Loss: 2.9805750846862793


 39%|███▉      | 39/100 [05:20<08:26,  8.30s/it]

Epoch: 39, Loss: 3.2198398113250732


 40%|████      | 40/100 [05:28<08:05,  8.08s/it]

Epoch: 40, Loss: 2.779831647872925


 41%|████      | 41/100 [05:36<07:59,  8.13s/it]

Epoch: 41, Loss: 3.1248583793640137


 42%|████▏     | 42/100 [05:45<07:57,  8.23s/it]

Epoch: 42, Loss: 3.1689541339874268


 43%|████▎     | 43/100 [05:53<07:46,  8.18s/it]

Epoch: 43, Loss: 3.440685987472534


 44%|████▍     | 44/100 [06:02<07:55,  8.50s/it]

Epoch: 44, Loss: 2.7286717891693115


 45%|████▌     | 45/100 [06:10<07:48,  8.51s/it]

Epoch: 45, Loss: 3.393552303314209


 46%|████▌     | 46/100 [06:18<07:26,  8.26s/it]

Epoch: 46, Loss: 2.8403677940368652


 47%|████▋     | 47/100 [06:27<07:22,  8.35s/it]

Epoch: 47, Loss: 2.5722086429595947


 48%|████▊     | 48/100 [06:35<07:14,  8.35s/it]

Epoch: 48, Loss: 3.2635674476623535


 49%|████▉     | 49/100 [06:44<07:12,  8.48s/it]

Epoch: 49, Loss: 2.5551443099975586


 50%|█████     | 50/100 [06:52<07:05,  8.51s/it]

Epoch: 50, Loss: 2.966099500656128


 51%|█████     | 51/100 [07:01<06:57,  8.52s/it]

Epoch: 51, Loss: 3.1227238178253174


 52%|█████▏    | 52/100 [07:08<06:34,  8.23s/it]

Epoch: 52, Loss: 3.0967533588409424


 53%|█████▎    | 53/100 [07:17<06:32,  8.35s/it]

Epoch: 53, Loss: 2.682342529296875


 54%|█████▍    | 54/100 [07:32<07:56, 10.37s/it]

Epoch: 54, Loss: 2.93943190574646


 55%|█████▌    | 55/100 [07:41<07:25,  9.90s/it]

Epoch: 55, Loss: 3.1342790126800537


 56%|█████▌    | 56/100 [07:49<06:47,  9.26s/it]

Epoch: 56, Loss: 2.9758522510528564


 57%|█████▋    | 57/100 [07:58<06:35,  9.19s/it]

Epoch: 57, Loss: 3.128324508666992


 58%|█████▊    | 58/100 [08:06<06:14,  8.93s/it]

Epoch: 58, Loss: 2.503632068634033


 59%|█████▉    | 59/100 [08:14<05:51,  8.57s/it]

Epoch: 59, Loss: 3.087893009185791


 60%|██████    | 60/100 [08:22<05:42,  8.56s/it]

Epoch: 60, Loss: 3.125723123550415


 61%|██████    | 61/100 [08:30<05:20,  8.23s/it]

Epoch: 61, Loss: 3.017693042755127


 62%|██████▏   | 62/100 [08:39<05:18,  8.39s/it]

Epoch: 62, Loss: 2.5308191776275635


 63%|██████▎   | 63/100 [08:47<05:10,  8.40s/it]

Epoch: 63, Loss: 2.618776798248291


 64%|██████▍   | 64/100 [08:55<04:53,  8.15s/it]

Epoch: 64, Loss: 3.122264862060547


 65%|██████▌   | 65/100 [09:03<04:49,  8.27s/it]

Epoch: 65, Loss: 3.1275148391723633


 66%|██████▌   | 66/100 [09:12<04:43,  8.34s/it]

Epoch: 66, Loss: 3.250262975692749


 67%|██████▋   | 67/100 [09:19<04:29,  8.15s/it]

Epoch: 67, Loss: 3.113621711730957


 68%|██████▊   | 68/100 [09:28<04:22,  8.21s/it]

Epoch: 68, Loss: 2.950868606567383


 69%|██████▉   | 69/100 [09:36<04:19,  8.37s/it]

Epoch: 69, Loss: 2.507024049758911


 70%|███████   | 70/100 [09:44<04:03,  8.13s/it]

Epoch: 70, Loss: 2.676743745803833


 71%|███████   | 71/100 [09:53<03:58,  8.24s/it]

Epoch: 71, Loss: 3.271341323852539


 72%|███████▏  | 72/100 [10:00<03:46,  8.07s/it]

Epoch: 72, Loss: 2.7084808349609375


 73%|███████▎  | 73/100 [10:09<03:39,  8.15s/it]

Epoch: 73, Loss: 2.712580919265747


 74%|███████▍  | 74/100 [10:17<03:33,  8.22s/it]

Epoch: 74, Loss: 3.005690574645996


 75%|███████▌  | 75/100 [10:26<03:30,  8.43s/it]

Epoch: 75, Loss: 3.4528884887695312


 76%|███████▌  | 76/100 [10:34<03:21,  8.38s/it]

Epoch: 76, Loss: 3.4657819271087646


 77%|███████▋  | 77/100 [10:43<03:13,  8.41s/it]

Epoch: 77, Loss: 2.8269662857055664


 78%|███████▊  | 78/100 [10:50<03:01,  8.23s/it]

Epoch: 78, Loss: 3.1453449726104736


 79%|███████▉  | 79/100 [10:59<02:55,  8.36s/it]

Epoch: 79, Loss: 2.5275485515594482


 80%|████████  | 80/100 [11:08<02:48,  8.44s/it]

Epoch: 80, Loss: 3.3626534938812256


 81%|████████  | 81/100 [11:15<02:34,  8.14s/it]

Epoch: 81, Loss: 2.363163471221924


 82%|████████▏ | 82/100 [11:24<02:30,  8.36s/it]

Epoch: 82, Loss: 2.3826136589050293


 83%|████████▎ | 83/100 [11:32<02:21,  8.32s/it]

Epoch: 83, Loss: 2.6772749423980713


 84%|████████▍ | 84/100 [11:40<02:10,  8.15s/it]

Epoch: 84, Loss: 2.8820641040802


 85%|████████▌ | 85/100 [11:49<02:05,  8.39s/it]

Epoch: 85, Loss: 3.4995357990264893


 86%|████████▌ | 86/100 [11:57<01:57,  8.39s/it]

Epoch: 86, Loss: 2.201569080352783


 87%|████████▋ | 87/100 [12:05<01:45,  8.12s/it]

Epoch: 87, Loss: 2.206315279006958


 88%|████████▊ | 88/100 [12:13<01:38,  8.21s/it]

Epoch: 88, Loss: 2.7482175827026367


 89%|████████▉ | 89/100 [12:21<01:29,  8.15s/it]

Epoch: 89, Loss: 2.8475990295410156


 90%|█████████ | 90/100 [12:29<01:21,  8.14s/it]

Epoch: 90, Loss: 3.0092356204986572


 91%|█████████ | 91/100 [12:38<01:14,  8.29s/it]

Epoch: 91, Loss: 3.4208266735076904


 92%|█████████▏| 92/100 [12:46<01:04,  8.12s/it]

Epoch: 92, Loss: 2.9296605587005615


 93%|█████████▎| 93/100 [12:54<00:58,  8.32s/it]

Epoch: 93, Loss: 3.459144353866577


 94%|█████████▍| 94/100 [13:03<00:50,  8.46s/it]

Epoch: 94, Loss: 3.1057488918304443


 95%|█████████▌| 95/100 [13:11<00:40,  8.20s/it]

Epoch: 95, Loss: 3.009251832962036


 96%|█████████▌| 96/100 [13:20<00:33,  8.34s/it]

Epoch: 96, Loss: 2.9634342193603516


 97%|█████████▋| 97/100 [13:28<00:25,  8.41s/it]

Epoch: 97, Loss: 2.752070188522339


 98%|█████████▊| 98/100 [13:36<00:16,  8.23s/it]

Epoch: 98, Loss: 3.2625858783721924


 99%|█████████▉| 99/100 [13:45<00:08,  8.40s/it]

Epoch: 99, Loss: 3.368649482727051


100%|██████████| 100/100 [13:53<00:00,  8.34s/it]

Epoch: 100, Loss: 2.653125047683716





In [85]:
final_user_representations = {}
repr = pos_z
for i in range(pos_z.shape[0]):
  final_user_representations[i] = (i, repr[i].detach(), user_representations[i][1])

In [86]:
user_dataset = UserDataset(final_user_representations)

n = len(user_dataset)

indices = np.arange(n)
indices = np.random.permutation(indices)

train_indices = indices [:int(0.8*n)]
test_indices = indices[int(0.8*n):]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=True)

In [87]:
num_classes = 186
classes = np.arange(0, num_classes)

In [88]:
head_model = train2task()
torch.save(head_model.state_dict(), 'node_dropout_infomax_linkpred_model.pth')

Epoch: 0, Loss: 4.470142841339111
Epoch: 1, Loss: 4.250609874725342
Epoch: 2, Loss: 4.13272762298584
Epoch: 3, Loss: 4.051883697509766
Epoch: 4, Loss: 3.9903104305267334
Epoch: 5, Loss: 3.9406611919403076
Epoch: 6, Loss: 3.899014472961426
Epoch: 7, Loss: 3.863306999206543
Epoch: 8, Loss: 3.832075595855713
Epoch: 9, Loss: 3.804293632507324
Epoch: 10, Loss: 3.7793848514556885
Epoch: 11, Loss: 3.756803274154663
Epoch: 12, Loss: 3.736203908920288
Epoch: 13, Loss: 3.7172863483428955
Epoch: 14, Loss: 3.699819326400757
Epoch: 15, Loss: 3.6835925579071045
Epoch: 16, Loss: 3.6684701442718506
Epoch: 17, Loss: 3.6543116569519043
Epoch: 18, Loss: 3.6410152912139893
Epoch: 19, Loss: 3.6284961700439453
Epoch: 20, Loss: 3.6166610717773438
Epoch: 21, Loss: 3.605457067489624
Epoch: 22, Loss: 3.5947988033294678
Epoch: 23, Loss: 3.5846564769744873
Epoch: 24, Loss: 3.5749940872192383
Epoch: 25, Loss: 3.565760850906372
Epoch: 26, Loss: 3.5569303035736084
Epoch: 27, Loss: 3.548447608947754
Epoch: 28, Loss: 

### Edge dropout, Infomax + Link pred

In [89]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
model = DeepGraphInfomax(hidden_channels=128,
                         encoder=Encoder(interm_channels=74, hidden_channels=128,
                                         num_features=128),
                         summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
                         corruption=edge_dropout).to(device)
linkmodel = LinkPredHead()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [90]:
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    pos_z, neg_z, summary, loss = train_infomax_linkpred(epoch)
  else:
    loss = train_infomax_linkpred(epoch)
  loss_epoch.append(loss)

  1%|          | 1/100 [00:06<10:03,  6.10s/it]

Epoch: 1, Loss: 3.43694806098938


  2%|▏         | 2/100 [00:12<10:36,  6.49s/it]

Epoch: 2, Loss: 3.374835968017578


  3%|▎         | 3/100 [00:18<10:03,  6.23s/it]

Epoch: 3, Loss: 3.1819753646850586


  4%|▍         | 4/100 [00:25<10:18,  6.45s/it]

Epoch: 4, Loss: 3.0721793174743652


  5%|▌         | 5/100 [00:31<09:56,  6.28s/it]

Epoch: 5, Loss: 3.2621967792510986


  6%|▌         | 6/100 [00:38<10:07,  6.46s/it]

Epoch: 6, Loss: 3.0787882804870605


  7%|▋         | 7/100 [00:44<09:45,  6.29s/it]

Epoch: 7, Loss: 3.058798313140869


  8%|▊         | 8/100 [00:51<09:54,  6.46s/it]

Epoch: 8, Loss: 3.0738697052001953


  9%|▉         | 9/100 [00:57<09:34,  6.31s/it]

Epoch: 9, Loss: 3.0536210536956787


 10%|█         | 10/100 [01:03<09:40,  6.45s/it]

Epoch: 10, Loss: 3.0486812591552734


 11%|█         | 11/100 [01:09<09:18,  6.27s/it]

Epoch: 11, Loss: 3.115920066833496


 12%|█▏        | 12/100 [01:22<12:13,  8.34s/it]

Epoch: 12, Loss: 3.001203775405884


 13%|█▎        | 13/100 [01:30<11:42,  8.07s/it]

Epoch: 13, Loss: 3.002033233642578


 14%|█▍        | 14/100 [01:36<10:39,  7.44s/it]

Epoch: 14, Loss: 2.9755520820617676


 15%|█▌        | 15/100 [01:43<10:17,  7.27s/it]

Epoch: 15, Loss: 2.983144760131836


 16%|█▌        | 16/100 [01:49<09:39,  6.89s/it]

Epoch: 16, Loss: 2.9323089122772217


 17%|█▋        | 17/100 [01:55<09:31,  6.88s/it]

Epoch: 17, Loss: 2.992875337600708


 18%|█▊        | 18/100 [02:01<08:59,  6.58s/it]

Epoch: 18, Loss: 3.0085365772247314


 19%|█▉        | 19/100 [02:08<09:00,  6.67s/it]

Epoch: 19, Loss: 2.893618106842041


 20%|██        | 20/100 [02:14<08:37,  6.47s/it]

Epoch: 20, Loss: 2.829777240753174


 21%|██        | 21/100 [02:21<08:40,  6.58s/it]

Epoch: 21, Loss: 2.984802007675171


 22%|██▏       | 22/100 [02:27<08:16,  6.37s/it]

Epoch: 22, Loss: 2.9602713584899902


 23%|██▎       | 23/100 [02:34<08:20,  6.50s/it]

Epoch: 23, Loss: 2.9966659545898438


 24%|██▍       | 24/100 [02:40<08:00,  6.33s/it]

Epoch: 24, Loss: 2.641376256942749


 25%|██▌       | 25/100 [02:46<08:02,  6.44s/it]

Epoch: 25, Loss: 2.8564422130584717


 26%|██▌       | 26/100 [02:52<07:45,  6.29s/it]

Epoch: 26, Loss: 2.742447853088379


 27%|██▋       | 27/100 [02:59<07:47,  6.41s/it]

Epoch: 27, Loss: 2.9004666805267334


 28%|██▊       | 28/100 [03:06<07:51,  6.55s/it]

Epoch: 28, Loss: 2.8065786361694336


 29%|██▉       | 29/100 [03:13<07:47,  6.59s/it]

Epoch: 29, Loss: 2.7035770416259766


 30%|███       | 30/100 [03:18<07:26,  6.38s/it]

Epoch: 30, Loss: 2.8482728004455566


 31%|███       | 31/100 [03:25<07:26,  6.48s/it]

Epoch: 31, Loss: 2.7365574836730957


 32%|███▏      | 32/100 [03:31<07:10,  6.33s/it]

Epoch: 32, Loss: 2.6767826080322266


 33%|███▎      | 33/100 [03:38<07:12,  6.45s/it]

Epoch: 33, Loss: 2.6945652961730957


 34%|███▍      | 34/100 [03:44<06:54,  6.28s/it]

Epoch: 34, Loss: 2.7486305236816406


 35%|███▌      | 35/100 [03:50<06:55,  6.39s/it]

Epoch: 35, Loss: 2.7918193340301514


 36%|███▌      | 36/100 [03:56<06:39,  6.25s/it]

Epoch: 36, Loss: 2.8068695068359375


 37%|███▋      | 37/100 [04:03<06:41,  6.38s/it]

Epoch: 37, Loss: 2.1688246726989746


 38%|███▊      | 38/100 [04:09<06:26,  6.24s/it]

Epoch: 38, Loss: 2.6420960426330566


 39%|███▉      | 39/100 [04:16<06:28,  6.36s/it]

Epoch: 39, Loss: 2.3489389419555664


 40%|████      | 40/100 [04:21<06:13,  6.22s/it]

Epoch: 40, Loss: 2.2116966247558594


 41%|████      | 41/100 [04:28<06:13,  6.33s/it]

Epoch: 41, Loss: 2.473954439163208


 42%|████▏     | 42/100 [04:34<05:59,  6.19s/it]

Epoch: 42, Loss: 2.2813544273376465


 43%|████▎     | 43/100 [04:41<06:00,  6.32s/it]

Epoch: 43, Loss: 2.3919692039489746


 44%|████▍     | 44/100 [04:46<05:45,  6.17s/it]

Epoch: 44, Loss: 2.0033438205718994


 45%|████▌     | 45/100 [04:53<05:48,  6.34s/it]

Epoch: 45, Loss: 1.846014380455017


 46%|████▌     | 46/100 [04:59<05:34,  6.19s/it]

Epoch: 46, Loss: 1.832003116607666


 47%|████▋     | 47/100 [05:06<05:36,  6.34s/it]

Epoch: 47, Loss: 2.4317984580993652


 48%|████▊     | 48/100 [05:11<05:22,  6.20s/it]

Epoch: 48, Loss: 2.2515370845794678


 49%|████▉     | 49/100 [05:18<05:22,  6.32s/it]

Epoch: 49, Loss: 1.7482659816741943


 50%|█████     | 50/100 [05:24<05:07,  6.14s/it]

Epoch: 50, Loss: 2.086815357208252


 51%|█████     | 51/100 [05:30<05:06,  6.25s/it]

Epoch: 51, Loss: 1.9069534540176392


 52%|█████▏    | 52/100 [05:36<04:55,  6.16s/it]

Epoch: 52, Loss: 2.5395874977111816


 53%|█████▎    | 53/100 [05:43<04:54,  6.26s/it]

Epoch: 53, Loss: 1.9039366245269775


 54%|█████▍    | 54/100 [05:49<04:44,  6.18s/it]

Epoch: 54, Loss: 1.9664582014083862


 55%|█████▌    | 55/100 [05:55<04:39,  6.22s/it]

Epoch: 55, Loss: 2.0258982181549072


 56%|█████▌    | 56/100 [06:01<04:33,  6.22s/it]

Epoch: 56, Loss: 2.3273837566375732


 57%|█████▋    | 57/100 [06:07<04:24,  6.15s/it]

Epoch: 57, Loss: 1.7517691850662231


 58%|█████▊    | 58/100 [06:14<04:19,  6.18s/it]

Epoch: 58, Loss: 1.8474456071853638


 59%|█████▉    | 59/100 [06:19<04:09,  6.09s/it]

Epoch: 59, Loss: 1.6507267951965332


 60%|██████    | 60/100 [06:26<04:09,  6.25s/it]

Epoch: 60, Loss: 1.6201262474060059


 61%|██████    | 61/100 [06:32<03:59,  6.15s/it]

Epoch: 61, Loss: 1.9215202331542969


 62%|██████▏   | 62/100 [06:39<04:00,  6.32s/it]

Epoch: 62, Loss: 1.6842265129089355


 63%|██████▎   | 63/100 [06:45<03:48,  6.18s/it]

Epoch: 63, Loss: 2.380747079849243


 64%|██████▍   | 64/100 [06:51<03:48,  6.34s/it]

Epoch: 64, Loss: 1.481950283050537


 65%|██████▌   | 65/100 [06:57<03:37,  6.21s/it]

Epoch: 65, Loss: 1.4654607772827148


 66%|██████▌   | 66/100 [07:04<03:36,  6.36s/it]

Epoch: 66, Loss: 1.6652863025665283


 67%|██████▋   | 67/100 [07:10<03:24,  6.19s/it]

Epoch: 67, Loss: 1.9006166458129883


 68%|██████▊   | 68/100 [07:16<03:22,  6.32s/it]

Epoch: 68, Loss: 1.9269837141036987


 69%|██████▉   | 69/100 [07:22<03:11,  6.17s/it]

Epoch: 69, Loss: 1.835679054260254


 70%|███████   | 70/100 [07:29<03:09,  6.31s/it]

Epoch: 70, Loss: 1.654508352279663


 71%|███████   | 71/100 [07:35<03:06,  6.42s/it]

Epoch: 71, Loss: 2.1783249378204346


 72%|███████▏  | 72/100 [07:42<03:02,  6.51s/it]

Epoch: 72, Loss: 1.904577612876892


 73%|███████▎  | 73/100 [07:48<02:49,  6.29s/it]

Epoch: 73, Loss: 1.5626908540725708


 74%|███████▍  | 74/100 [07:55<02:46,  6.42s/it]

Epoch: 74, Loss: 1.5273444652557373


 75%|███████▌  | 75/100 [08:00<02:35,  6.23s/it]

Epoch: 75, Loss: 1.923181176185608


 76%|███████▌  | 76/100 [08:07<02:33,  6.38s/it]

Epoch: 76, Loss: 1.6889731884002686


 77%|███████▋  | 77/100 [08:13<02:22,  6.20s/it]

Epoch: 77, Loss: 1.766061782836914


 78%|███████▊  | 78/100 [08:20<02:20,  6.37s/it]

Epoch: 78, Loss: 2.739710569381714


 79%|███████▉  | 79/100 [08:25<02:10,  6.20s/it]

Epoch: 79, Loss: 1.5315260887145996


 80%|████████  | 80/100 [08:32<02:06,  6.34s/it]

Epoch: 80, Loss: 1.4378925561904907


 81%|████████  | 81/100 [08:38<01:57,  6.17s/it]

Epoch: 81, Loss: 1.4884517192840576


 82%|████████▏ | 82/100 [08:45<01:53,  6.30s/it]

Epoch: 82, Loss: 1.7060127258300781


 83%|████████▎ | 83/100 [08:50<01:44,  6.15s/it]

Epoch: 83, Loss: 1.4190292358398438


 84%|████████▍ | 84/100 [08:57<01:41,  6.33s/it]

Epoch: 84, Loss: 1.7566395998001099


 85%|████████▌ | 85/100 [09:03<01:33,  6.25s/it]

Epoch: 85, Loss: 1.5642518997192383


 86%|████████▌ | 86/100 [09:10<01:28,  6.36s/it]

Epoch: 86, Loss: 1.5340303182601929


 87%|████████▋ | 87/100 [09:16<01:20,  6.22s/it]

Epoch: 87, Loss: 1.4752646684646606


 88%|████████▊ | 88/100 [09:22<01:16,  6.34s/it]

Epoch: 88, Loss: 1.3723173141479492


 89%|████████▉ | 89/100 [09:28<01:07,  6.18s/it]

Epoch: 89, Loss: 1.5187358856201172


 90%|█████████ | 90/100 [09:35<01:03,  6.32s/it]

Epoch: 90, Loss: 1.2923630475997925


 91%|█████████ | 91/100 [09:40<00:55,  6.16s/it]

Epoch: 91, Loss: 1.59932279586792


 92%|█████████▏| 92/100 [09:47<00:50,  6.30s/it]

Epoch: 92, Loss: 1.2937438488006592


 93%|█████████▎| 93/100 [09:53<00:43,  6.19s/it]

Epoch: 93, Loss: 1.3293522596359253


 94%|█████████▍| 94/100 [10:00<00:37,  6.29s/it]

Epoch: 94, Loss: 1.3653085231781006


 95%|█████████▌| 95/100 [10:05<00:30,  6.17s/it]

Epoch: 95, Loss: 1.2880083322525024


 96%|█████████▌| 96/100 [10:12<00:24,  6.23s/it]

Epoch: 96, Loss: 1.7510793209075928


 97%|█████████▋| 97/100 [10:18<00:18,  6.19s/it]

Epoch: 97, Loss: 1.651495337486267


 98%|█████████▊| 98/100 [10:24<00:12,  6.22s/it]

Epoch: 98, Loss: 1.4260797500610352


 99%|█████████▉| 99/100 [10:30<00:06,  6.22s/it]

Epoch: 99, Loss: 1.8688944578170776


100%|██████████| 100/100 [10:37<00:00,  6.37s/it]

Epoch: 100, Loss: 1.8372821807861328





In [91]:
final_user_representations = {}
repr = pos_z
for i in range(pos_z.shape[0]):
  final_user_representations[i] = (i, repr[i].detach(), user_representations[i][1])

In [92]:
user_dataset = UserDataset(final_user_representations)

n = len(user_dataset)

indices = np.arange(n)
indices = np.random.permutation(indices)

train_indices = indices [:int(0.8*n)]
test_indices = indices[int(0.8*n):]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=True)

In [93]:
num_classes = 186
classes = np.arange(0, num_classes)

In [94]:
head_model = train2task()
torch.save(head_model.state_dict(), 'edge_dropout_infomax_linkpred_model.pth')

Epoch: 0, Loss: 4.329938888549805
Epoch: 1, Loss: 4.177267074584961
Epoch: 2, Loss: 4.094830513000488
Epoch: 3, Loss: 4.038265705108643
Epoch: 4, Loss: 3.994664430618286
Epoch: 5, Loss: 3.959409236907959
Epoch: 6, Loss: 3.9297213554382324
Epoch: 7, Loss: 3.903956890106201
Epoch: 8, Loss: 3.881286859512329
Epoch: 9, Loss: 3.8610832691192627
Epoch: 10, Loss: 3.8427982330322266
Epoch: 11, Loss: 3.8262059688568115
Epoch: 12, Loss: 3.8109612464904785
Epoch: 13, Loss: 3.796868324279785
Epoch: 14, Loss: 3.783818483352661
Epoch: 15, Loss: 3.771604061126709
Epoch: 16, Loss: 3.760194778442383
Epoch: 17, Loss: 3.7494356632232666
Epoch: 18, Loss: 3.739291191101074
Epoch: 19, Loss: 3.7297346591949463
Epoch: 20, Loss: 3.7206382751464844
Epoch: 21, Loss: 3.712029218673706
Epoch: 22, Loss: 3.703831195831299
Epoch: 23, Loss: 3.6959879398345947
Epoch: 24, Loss: 3.6885006427764893
Epoch: 25, Loss: 3.6813390254974365
Epoch: 26, Loss: 3.6744658946990967
Epoch: 27, Loss: 3.667891263961792
Epoch: 28, Loss: 3

### Random walk, Infomax + Link pred

In [95]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
model = DeepGraphInfomax(hidden_channels=128,
                         encoder=Encoder(interm_channels=74, hidden_channels=128,
                                         num_features=128),
                         summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
                         corruption=random_walk).to(device)
linkmodel = LinkPredHead()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
loss_bce = nn.BCEWithLogitsLoss()

In [96]:
num_epoch = 100
loss_epoch = []
for epoch in tqdm(range(1, num_epoch+1)):
  if epoch == num_epoch:
    pos_z, neg_z, summary, loss = train_infomax_linkpred(epoch)
  else:
    loss = train_infomax_linkpred(epoch)
  loss_epoch.append(loss)

  1%|          | 1/100 [00:03<06:35,  4.00s/it]

Epoch: 1, Loss: 4.045363903045654


  2%|▏         | 2/100 [00:08<07:14,  4.43s/it]

Epoch: 2, Loss: 3.7003698348999023


  3%|▎         | 3/100 [00:12<06:45,  4.18s/it]

Epoch: 3, Loss: 2.869166135787964


  4%|▍         | 4/100 [00:16<06:30,  4.07s/it]

Epoch: 4, Loss: 3.1408703327178955


  5%|▌         | 5/100 [00:21<06:41,  4.23s/it]

Epoch: 5, Loss: 2.8785858154296875


  6%|▌         | 6/100 [00:25<06:31,  4.16s/it]

Epoch: 6, Loss: 2.4802544116973877


  7%|▋         | 7/100 [00:28<06:18,  4.07s/it]

Epoch: 7, Loss: 2.8583176136016846


  8%|▊         | 8/100 [00:33<06:24,  4.18s/it]

Epoch: 8, Loss: 2.4600753784179688


  9%|▉         | 9/100 [00:37<06:19,  4.17s/it]

Epoch: 9, Loss: 2.5935864448547363


 10%|█         | 10/100 [00:41<06:07,  4.08s/it]

Epoch: 10, Loss: 2.3615214824676514


 11%|█         | 11/100 [00:45<06:06,  4.11s/it]

Epoch: 11, Loss: 2.675266742706299


 12%|█▏        | 12/100 [00:49<06:06,  4.17s/it]

Epoch: 12, Loss: 2.827786922454834


 13%|█▎        | 13/100 [00:53<05:54,  4.08s/it]

Epoch: 13, Loss: 2.1640090942382812


 14%|█▍        | 14/100 [00:57<05:49,  4.06s/it]

Epoch: 14, Loss: 1.5669028759002686


 15%|█▌        | 15/100 [01:02<05:57,  4.20s/it]

Epoch: 15, Loss: 1.6848876476287842


 16%|█▌        | 16/100 [01:06<05:45,  4.12s/it]

Epoch: 16, Loss: 3.374389171600342


 17%|█▋        | 17/100 [01:10<05:41,  4.11s/it]

Epoch: 17, Loss: 1.9758387804031372


 18%|█▊        | 18/100 [01:14<05:48,  4.25s/it]

Epoch: 18, Loss: 1.3678864240646362


 19%|█▉        | 19/100 [01:18<05:34,  4.13s/it]

Epoch: 19, Loss: 1.661690592765808


 20%|██        | 20/100 [01:22<05:25,  4.07s/it]

Epoch: 20, Loss: 3.0064892768859863


 21%|██        | 21/100 [01:27<05:37,  4.27s/it]

Epoch: 21, Loss: 3.3410279750823975


 22%|██▏       | 22/100 [01:31<05:24,  4.17s/it]

Epoch: 22, Loss: 1.6410889625549316


 23%|██▎       | 23/100 [01:35<05:14,  4.09s/it]

Epoch: 23, Loss: 1.4748339653015137


 24%|██▍       | 24/100 [01:40<05:28,  4.32s/it]

Epoch: 24, Loss: 2.9077308177948


 25%|██▌       | 25/100 [01:43<05:13,  4.19s/it]

Epoch: 25, Loss: 2.0338001251220703


 26%|██▌       | 26/100 [01:47<05:02,  4.09s/it]

Epoch: 26, Loss: 1.7707874774932861


 27%|██▋       | 27/100 [01:52<05:12,  4.28s/it]

Epoch: 27, Loss: 1.5196064710617065


 28%|██▊       | 28/100 [01:56<05:00,  4.17s/it]

Epoch: 28, Loss: 1.7184540033340454


 29%|██▉       | 29/100 [02:00<04:50,  4.09s/it]

Epoch: 29, Loss: 2.062309503555298


 30%|███       | 30/100 [02:05<05:00,  4.29s/it]

Epoch: 30, Loss: 1.9644492864608765


 31%|███       | 31/100 [02:08<04:46,  4.16s/it]

Epoch: 31, Loss: 2.14428448677063


 32%|███▏      | 32/100 [02:12<04:36,  4.07s/it]

Epoch: 32, Loss: 2.3679356575012207


 33%|███▎      | 33/100 [02:17<04:46,  4.27s/it]

Epoch: 33, Loss: 2.5924949645996094


 34%|███▍      | 34/100 [02:21<04:34,  4.15s/it]

Epoch: 34, Loss: 2.1943230628967285


 35%|███▌      | 35/100 [02:25<04:24,  4.07s/it]

Epoch: 35, Loss: 1.8624025583267212


 36%|███▌      | 36/100 [02:30<04:32,  4.25s/it]

Epoch: 36, Loss: 1.4546563625335693


 37%|███▋      | 37/100 [02:33<04:20,  4.13s/it]

Epoch: 37, Loss: 1.6054141521453857


 38%|███▊      | 38/100 [02:37<04:11,  4.06s/it]

Epoch: 38, Loss: 2.587271213531494


 39%|███▉      | 39/100 [02:42<04:19,  4.25s/it]

Epoch: 39, Loss: 1.8662199974060059


 40%|████      | 40/100 [02:46<04:09,  4.16s/it]

Epoch: 40, Loss: 1.676074504852295


 41%|████      | 41/100 [02:50<04:00,  4.08s/it]

Epoch: 41, Loss: 2.71652889251709


 42%|████▏     | 42/100 [02:54<04:05,  4.23s/it]

Epoch: 42, Loss: 2.7993764877319336


 43%|████▎     | 43/100 [02:58<03:57,  4.16s/it]

Epoch: 43, Loss: 1.5105054378509521


 44%|████▍     | 44/100 [03:02<03:48,  4.08s/it]

Epoch: 44, Loss: 1.7108350992202759


 45%|████▌     | 45/100 [03:07<03:48,  4.16s/it]

Epoch: 45, Loss: 1.1944724321365356


 46%|████▌     | 46/100 [03:11<03:44,  4.16s/it]

Epoch: 46, Loss: 1.4973351955413818


 47%|████▋     | 47/100 [03:15<03:36,  4.08s/it]

Epoch: 47, Loss: 1.5779606103897095


 48%|████▊     | 48/100 [03:19<03:34,  4.13s/it]

Epoch: 48, Loss: 0.9685097932815552


 49%|████▉     | 49/100 [03:23<03:34,  4.20s/it]

Epoch: 49, Loss: 1.9786922931671143


 50%|█████     | 50/100 [03:27<03:27,  4.15s/it]

Epoch: 50, Loss: 1.0754640102386475


 51%|█████     | 51/100 [03:32<03:38,  4.46s/it]

Epoch: 51, Loss: 0.8921419382095337


 52%|█████▏    | 52/100 [03:37<03:29,  4.36s/it]

Epoch: 52, Loss: 1.8270515203475952


 53%|█████▎    | 53/100 [03:41<03:18,  4.22s/it]

Epoch: 53, Loss: 1.3453588485717773


 54%|█████▍    | 54/100 [03:45<03:15,  4.24s/it]

Epoch: 54, Loss: 1.703320026397705


 55%|█████▌    | 55/100 [03:49<03:11,  4.25s/it]

Epoch: 55, Loss: 0.9428843259811401


 56%|█████▌    | 56/100 [03:53<03:02,  4.14s/it]

Epoch: 56, Loss: 2.011641502380371


 57%|█████▋    | 57/100 [03:57<02:58,  4.16s/it]

Epoch: 57, Loss: 1.2622612714767456


 58%|█████▊    | 58/100 [04:02<02:58,  4.25s/it]

Epoch: 58, Loss: 4.091346263885498


 59%|█████▉    | 59/100 [04:06<02:49,  4.14s/it]

Epoch: 59, Loss: 1.4742070436477661


 60%|██████    | 60/100 [04:10<02:44,  4.12s/it]

Epoch: 60, Loss: 1.741357445716858


 61%|██████    | 61/100 [04:14<02:45,  4.24s/it]

Epoch: 61, Loss: 1.4301071166992188


 62%|██████▏   | 62/100 [04:18<02:37,  4.15s/it]

Epoch: 62, Loss: 1.671081781387329


 63%|██████▎   | 63/100 [04:22<02:30,  4.08s/it]

Epoch: 63, Loss: 1.574641466140747


 64%|██████▍   | 64/100 [04:27<02:34,  4.29s/it]

Epoch: 64, Loss: 1.7105236053466797


 65%|██████▌   | 65/100 [04:31<02:25,  4.17s/it]

Epoch: 65, Loss: 1.4652636051177979


 66%|██████▌   | 66/100 [04:34<02:18,  4.07s/it]

Epoch: 66, Loss: 1.1777663230895996


 67%|██████▋   | 67/100 [04:39<02:20,  4.27s/it]

Epoch: 67, Loss: 1.3724443912506104


 68%|██████▊   | 68/100 [04:43<02:12,  4.16s/it]

Epoch: 68, Loss: 1.2866837978363037


 69%|██████▉   | 69/100 [04:47<02:06,  4.07s/it]

Epoch: 69, Loss: 2.0251479148864746


 70%|███████   | 70/100 [04:52<02:08,  4.28s/it]

Epoch: 70, Loss: 1.0826435089111328


 71%|███████   | 71/100 [04:56<02:00,  4.16s/it]

Epoch: 71, Loss: 0.9856103658676147


 72%|███████▏  | 72/100 [05:00<01:55,  4.11s/it]

Epoch: 72, Loss: 1.9708340167999268


 73%|███████▎  | 73/100 [05:04<01:55,  4.29s/it]

Epoch: 73, Loss: 0.7935872673988342


 74%|███████▍  | 74/100 [05:08<01:48,  4.17s/it]

Epoch: 74, Loss: 1.3315649032592773


 75%|███████▌  | 75/100 [05:12<01:41,  4.08s/it]

Epoch: 75, Loss: 3.0733141899108887


 76%|███████▌  | 76/100 [05:17<01:42,  4.28s/it]

Epoch: 76, Loss: 0.8943969011306763


 77%|███████▋  | 77/100 [05:21<01:37,  4.24s/it]

Epoch: 77, Loss: 2.14449143409729


 78%|███████▊  | 78/100 [05:25<01:30,  4.13s/it]

Epoch: 78, Loss: 1.0924087762832642


 79%|███████▉  | 79/100 [05:30<01:30,  4.30s/it]

Epoch: 79, Loss: 2.0166471004486084


 80%|████████  | 80/100 [05:33<01:23,  4.17s/it]

Epoch: 80, Loss: 1.4941229820251465


 81%|████████  | 81/100 [05:37<01:17,  4.08s/it]

Epoch: 81, Loss: 1.2537046670913696


 82%|████████▏ | 82/100 [05:42<01:16,  4.27s/it]

Epoch: 82, Loss: 1.2746740579605103


 83%|████████▎ | 83/100 [05:46<01:10,  4.15s/it]

Epoch: 83, Loss: 1.1911382675170898


 84%|████████▍ | 84/100 [05:50<01:05,  4.06s/it]

Epoch: 84, Loss: 1.6215509176254272


 85%|████████▌ | 85/100 [05:55<01:04,  4.28s/it]

Epoch: 85, Loss: 1.722212314605713


 86%|████████▌ | 86/100 [05:58<00:58,  4.16s/it]

Epoch: 86, Loss: 1.1609619855880737


 87%|████████▋ | 87/100 [06:02<00:53,  4.08s/it]

Epoch: 87, Loss: 1.1838200092315674


 88%|████████▊ | 88/100 [06:07<00:50,  4.23s/it]

Epoch: 88, Loss: 1.0569336414337158


 89%|████████▉ | 89/100 [06:11<00:46,  4.19s/it]

Epoch: 89, Loss: 1.4495384693145752


 90%|█████████ | 90/100 [06:15<00:41,  4.12s/it]

Epoch: 90, Loss: 1.2206183671951294


 91%|█████████ | 91/100 [06:20<00:38,  4.32s/it]

Epoch: 91, Loss: 0.8731157779693604


 92%|█████████▏| 92/100 [06:24<00:33,  4.24s/it]

Epoch: 92, Loss: 0.8101173639297485


 93%|█████████▎| 93/100 [06:28<00:29,  4.14s/it]

Epoch: 93, Loss: 0.9698516726493835


 94%|█████████▍| 94/100 [06:32<00:25,  4.23s/it]

Epoch: 94, Loss: 0.9213635921478271


 95%|█████████▌| 95/100 [06:36<00:21,  4.22s/it]

Epoch: 95, Loss: 8.358936309814453


 96%|█████████▌| 96/100 [06:40<00:16,  4.13s/it]

Epoch: 96, Loss: 2.0286574363708496


 97%|█████████▋| 97/100 [06:45<00:12,  4.21s/it]

Epoch: 97, Loss: 1.3188148736953735


 98%|█████████▊| 98/100 [06:49<00:08,  4.23s/it]

Epoch: 98, Loss: 1.3365979194641113


 99%|█████████▉| 99/100 [06:53<00:04,  4.13s/it]

Epoch: 99, Loss: 1.445319652557373


100%|██████████| 100/100 [06:57<00:00,  4.18s/it]

Epoch: 100, Loss: 1.634057641029358





In [97]:
final_user_representations = {}
repr = pos_z
for i in range(pos_z.shape[0]):
  final_user_representations[i] = (i, repr[i].detach(), user_representations[i][1])

In [98]:
user_dataset = UserDataset(final_user_representations)

n = len(user_dataset)

indices = np.arange(n)
indices = np.random.permutation(indices)

train_indices = indices [:int(0.8*n)]
test_indices = indices[int(0.8*n):]

user_train_dataset = Subset(user_dataset, train_indices)
user_test_dataset = Subset(user_dataset, test_indices)

user_train_dataloader = DataLoader(user_train_dataset, batch_size=64, shuffle=True)
user_test_dataloader = DataLoader(user_test_dataset, batch_size=64, shuffle=True)

In [99]:
num_classes = 186
classes = np.arange(0, num_classes)

In [100]:
head_model = train2task()
torch.save(head_model.state_dict(), 'random_walk_infomax_linkpred_model.pth')

Epoch: 0, Loss: 4.424149513244629
Epoch: 1, Loss: 4.248197078704834
Epoch: 2, Loss: 4.152710914611816
Epoch: 3, Loss: 4.08593225479126
Epoch: 4, Loss: 4.034071922302246
Epoch: 5, Loss: 3.991478204727173
Epoch: 6, Loss: 3.9552371501922607
Epoch: 7, Loss: 3.923595666885376
Epoch: 8, Loss: 3.895565986633301
Epoch: 9, Loss: 3.8702945709228516
Epoch: 10, Loss: 3.847313165664673
Epoch: 11, Loss: 3.8262948989868164
Epoch: 12, Loss: 3.806889057159424
Epoch: 13, Loss: 3.7889153957366943
Epoch: 14, Loss: 3.772141456604004
Epoch: 15, Loss: 3.7564241886138916
Epoch: 16, Loss: 3.741687297821045
Epoch: 17, Loss: 3.7277867794036865
Epoch: 18, Loss: 3.714641809463501
Epoch: 19, Loss: 3.7021484375
Epoch: 20, Loss: 3.690312385559082
Epoch: 21, Loss: 3.679049015045166
Epoch: 22, Loss: 3.668288469314575
Epoch: 23, Loss: 3.6580002307891846
Epoch: 24, Loss: 3.6481425762176514
Epoch: 25, Loss: 3.638688087463379
Epoch: 26, Loss: 3.629626750946045
Epoch: 27, Loss: 3.6208925247192383
Epoch: 28, Loss: 3.61247682