In [27]:
# pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-2.1.0+cu121.html

import torch
from torch.nn.functional import binary_cross_entropy_with_logits

from torch_geometric import EdgeIndex
from torch_geometric.utils import degree
from torch_geometric.loader import LinkNeighborLoader, NeighborLoader
from torch_geometric.nn import MIPSKNNIndex
from torch_geometric.metrics import LinkPredMAP, LinkPredPrecision, LinkPredRecall

from copy import deepcopy
from tqdm import tqdm
# import os
# os.environ['PYDEVD_DISABLE_FILE_VALIDATION']='1'
# os.environ['KMP_DUPLICATE_LIB_OK']='True'

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

## HeteroData, train/test split

In [28]:
from utils.data_transformation import data_to_heterograph

# Load the encoded interactions as a heterogeneous user/movie graph.
# NOTE(review): `temporal_order=True` presumably stores the edges in
# chronological order, which a positional train/test split would rely on —
# confirm in `data_to_heterograph`.
data = data_to_heterograph(
    'encoded_data.npz',
    temporal_order=True,
)
data

HeteroData(
  movie={ x=[11909, 200] },
  user={ x=[448798, 14] },
  (user, watched, movie)={
    edge_index=[2, 1288996],
    time=[1288996],
  },
  (movie, rev_watched, user)={
    edge_index=[2, 1288996],
    time=[1288996],
  }
)

In [29]:
# Density of the user-movie interaction matrix: observed edges divided by the
# number of possible (user, movie) pairs. The matrix is very sparse.
# The counts are read from `data` so the figure stays correct if the dataset changes.
data['user', 'movie'].num_edges / (data['movie'].num_nodes * data['user'].num_nodes)

0.0002411711540184161

In [30]:
# Positional split: the first 80% of the (user, watched, movie) edges are used
# for training, the remainder is held out for evaluation.
# NOTE(review): this is only a chronological split if the edges are stored in
# time order — presumably ensured by `temporal_order=True`; confirm.
watch_store = data['user', 'movie']
edges, time = watch_store.edge_index, watch_store.time
train_size = int(0.8 * watch_store.num_edges)
watch_threshold = 5  # minimum number of training interactions for an evaluated user

# Sampling options shared by all three loaders.
loader_kwargs = {
    'data': data,
    'batch_size': 256,
    'num_neighbors': [5, 5, 5],
    'time_attr': 'time',
    'temporal_strategy': 'last',
    'num_workers': 0,
}

# Each supervision edge is sampled at its own timestamp minus one.
train_loader = LinkNeighborLoader(
    edge_label_index=(('user', 'movie'), edges[:, :train_size]),
    edge_label_time=time[:train_size] - 1,
    neg_sampling={'mode': 'binary', 'amount': 2},
    shuffle=True,
    **loader_kwargs,
)

# All nodes are embedded at one shared timestamp: one less than the earliest
# held-out interaction time.
eval_time = time[train_size:].min() - 1

user_loader = NeighborLoader(
    input_nodes='user',
    input_time=eval_time.repeat(data['user'].num_nodes),
    **loader_kwargs,
)

movie_loader = NeighborLoader(
    input_nodes='movie',
    input_time=eval_time.repeat(data['movie'].num_nodes),
    **loader_kwargs,
)

sparse_size = data['user'].num_nodes, data['movie'].num_nodes


def _row_sorted(edge_block):
    """Wrap a [2, E] edge slice as a row-sorted `EdgeIndex` placed on `device`."""
    wrapped = EdgeIndex(edge_block.contiguous().to(device), sparse_size=sparse_size)
    return wrapped.sort_by('row').values


train_edges = _row_sorted(edges[:, :train_size])
test_edges = _row_sorted(edges[:, train_size:])

# Only users with at least `watch_threshold` training interactions are kept;
# the mask is computed once from the unfiltered training edges and then applied
# to both edge sets.
train_degree = degree(train_edges[0], num_nodes=data['user'].num_nodes)
is_test_node = train_degree >= watch_threshold
test_edges = test_edges[:, is_test_node[test_edges[0]]]
train_edges = train_edges[:, is_test_node[train_edges[0]]]

## MetaPath2Vec

In [35]:
from utils.node_representation import Metapath2Vec

# Only the first `train_size` edges of each relation are given to MetaPath2Vec.
# NOTE(review): this assumes the reverse relation stores its edges in the same
# order as the forward one — confirm in the data-loading helper.
metapath_edge_types = (
    ('user', 'watched', 'movie'),
    ('movie', 'rev_watched', 'user'),
)
train_edge_index_dict = {
    edge_type: data.edge_index_dict[edge_type][:, :train_size]
    for edge_type in metapath_edge_types
}

mp2v = Metapath2Vec(train_edge_index_dict, data.num_nodes_dict, device=device)
mp2v.train()

# Export the learned embeddings of both node types as NumPy arrays.
users_emb, movies_emb = (
    mp2v.get_embeddings(node_type).numpy() for node_type in ('user', 'movie')
)

## GNN

In [None]:
# Append the MetaPath2Vec embeddings to the node feature matrices.
# `torch.cat` only accepts tensors, while the embeddings were exported with
# `.numpy()`, so they are converted back first (matching the dtype and device
# of the existing features). `torch.as_tensor` also accepts tensors, so this
# keeps working if the embeddings are ever left as tensors.
# NOTE: this cell is not idempotent — re-running it appends the embeddings again.
user_x = data['user'].x
data['user'].x = torch.cat(
    [user_x, torch.as_tensor(users_emb, dtype=user_x.dtype, device=user_x.device)],
    dim=1)

movie_x = data['movie'].x
data['movie'].x = torch.cat(
    [movie_x, torch.as_tensor(movies_emb, dtype=movie_x.dtype, device=movie_x.device)],
    dim=1)

In [None]:
from utils.model import GNN

# Build the link-prediction model and move it to `device`: every batch is moved
# there before the forward pass, so leaving the model on the CPU would fail
# with a device mismatch when CUDA is available.
# NOTE(review): assumes `GNN` is a `torch.nn.Module` (it is used with
# `.parameters()`, `.train()` and `.eval()`), so `.to(device)` returns the model.
gnn_model = GNN(data.metadata(), hidden_channels=64, decoder='IP',
                dropout_encoder_p=0.2, dropout_decoder_p=0.4).to(device)

# The optimizer is created after the move so it tracks the on-device parameters.
optimizer = torch.optim.Adam(gnn_model.parameters(), lr=0.01)

In [None]:
from utils.data_transformation import sparse_batch_narrow

def train():
    """Run one optimisation pass over `train_loader`.

    Returns:
        The binary cross-entropy loss averaged over all supervision edges
        seen during the pass (each batch is weighted by its number of labels).
    """
    gnn_model.train()
    loss_sum, num_seen = 0.0, 0

    for batch in tqdm(train_loader):
        batch = batch.to(device)
        link_store = batch['user', 'movie']
        target = link_store.edge_label
        num_links = len(target)

        optimizer.zero_grad()
        logits = gnn_model(batch.x_dict, batch.edge_index_dict,
                           link_store.edge_label_index)
        loss = binary_cross_entropy_with_logits(logits, target)
        loss.backward()
        optimizer.step()

        # Weight the batch mean by the batch size so the epoch average is per edge.
        loss_sum += loss.item() * num_links
        num_seen += num_links

    return loss_sum / num_seen

@torch.no_grad()
def test(k: int, test_edges: EdgeIndex, train_edges: EdgeIndex):
    """Evaluate top-`k` movie retrieval for the evaluated users.

    Movie embeddings are computed once and indexed for maximum-inner-product
    search; user embeddings are then computed batch by batch and queried
    against that index.

    Args:
        k: Number of movies retrieved per user and cutoff of the metrics.
        test_edges: Held-out (user, movie) edges used as ground truth.
        train_edges: Training (user, movie) edges, passed to the search as
            `exclude_links`.

    Returns:
        A tuple `(MAP@k, Precision@k, Recall@k)` of floats.
    """
    gnn_model.eval()

    movie_embs = []
    for batch in movie_loader:
        batch = batch.to(device)
        batch_size = batch['movie'].batch_size
        # Only the first `batch_size` rows are kept (in PyG loaders these are
        # the seed nodes; the rest are sampled neighbours).
        batch_movie_embs = gnn_model.encoder(batch.x_dict, batch.edge_index_dict)\
            ['movie'][:batch_size].cpu()
        movie_embs.append(batch_movie_embs)

    mipsknn = MIPSKNNIndex(torch.cat(movie_embs, dim=0))
    metrics = LinkPredMAP(k), LinkPredPrecision(k), LinkPredRecall(k)

    # Offset of the current batch within the user id range.
    # NOTE(review): relies on `user_loader` yielding users in id order — confirm
    # it is never shuffled.
    users_infered = 0
    for batch in user_loader:
        batch = batch.to(device)
        batch_size = batch['user'].batch_size
        batch_user_embs = gnn_model.encoder(batch.x_dict, batch.edge_index_dict)\
            ['user'][:batch_size].cpu()

        # `is_test_node` lives on `device`, the embeddings on the CPU; indexing
        # a CPU tensor with a CUDA mask raises, so the mask slice is moved first.
        batch_mask = is_test_node[users_infered:users_infered + batch_size].cpu()
        batch_test_user_embs = batch_user_embs[batch_mask]

        # NOTE(review): the embeddings are filtered down to evaluated users, so
        # the row ids returned by `sparse_batch_narrow` must refer to those
        # filtered rows for the metrics to line up — confirm in the helper.
        batch_test_edges = sparse_batch_narrow(test_edges, users_infered, batch_size).cpu()
        batch_train_edges = sparse_batch_narrow(train_edges, users_infered, batch_size).cpu()

        # `search` returns a pair; element 1 holds the retrieved movie indices.
        top_indices_mat = mipsknn.search(
            batch_test_user_embs, k, exclude_links=batch_train_edges)[1]
        for metric in metrics:
            metric.update(top_indices_mat, batch_test_edges)

        users_infered += batch_size

    return tuple(float(metric.compute()) for metric in metrics)

# Train for a fixed number of epochs and report ranking metrics after each one.
k = 10  # number of recommendations per user / cutoff of the metrics
num_epochs = 15
for epoch_num in range(1, num_epochs + 1):
    loss = train()
    print(f'Train: Epoch №{epoch_num:02d}, Loss: {loss:.4f}')
    # `map_score` rather than `map`, so the builtin is not shadowed.
    map_score, precision, recall = test(k, test_edges, train_edges)
    # The `%` operator binds tighter than the commas, so the arguments must be
    # passed as one tuple; otherwise only `k` is formatted and a TypeError is raised.
    print('Test@%d, MAP: %.4f, Precision: %.4f, Recall: %.4f'
          % (k, map_score, precision, recall))