In [1]:
# Setup cell for running this notebook in Google Colab: uncomment the installs below.

## %%capture
# !pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://pytorch-geometric.com/whl/torch-2.2.1+cu121.html
# !pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-2.2.0+cu121.html
# !pip install faiss-gpu

In [1]:
import numpy as np
import pandas as pd
import torch

from torch_geometric import EdgeIndex
from torch_geometric.utils import degree
from torch_geometric.loader import LinkNeighborLoader, NeighborLoader
from torch_geometric.nn import MIPSKNNIndex
from torch_geometric.metrics import LinkPredMAP, LinkPredPrecision, LinkPredRecall
from torch_geometric.nn.models.lightgcn import BPRLoss

from tqdm import tqdm
import os
os.environ['PYDEVD_DISABLE_FILE_VALIDATION']='1'
os.environ['KMP_DUPLICATE_LIB_OK']='True'

from utils.model import GNN

# Device for the GNN (training and inference): the GPU whenever torch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Device for the FAISS-backed MIPS index: 'cuda' for Colab (faiss-gpu, see the setup cell),
# 'cpu' for a machine with only the CPU build (e.g. the author's Windows laptop).
faiss_device = torch.device('cpu')

## Building the HeteroData graph and the temporal train/test split

In [2]:
from utils.data_transformation import data_to_heterograph

# Build the user/movie HeteroData from 'encoded_data.npz'. The two extra return values map
# graph node index -> original id for users and movies (they are inverted further below).
data, users_rev_mapping, movies_rev_mapping = data_to_heterograph('encoded_data.npz')
data

HeteroData(
  movie={ x=[15008, 202] },
  user={ x=[744288, 14] },
  (user, watched, movie)={
    edge_index=[2, 4424477],
    time=[4424477],
  },
  (movie, rev_watched, user)={
    edge_index=[2, 4424477],
    time=[4424477],
  }
)

In [3]:
# Density of the user-movie interaction matrix: observed edges / (num_users * num_movies).
# Derived from `data` rather than hard-coded counts so it stays correct if the dataset changes.
data['user', 'movie'].num_edges / (data['user'].num_nodes * data['movie'].num_nodes)

0.0003960938540619029

In [4]:
# Split: the first `train_ratio` share of user->movie edges is used for training, the rest
# for testing. NOTE(review): this is a temporal split only if the edges in `data` are ordered
# by `time` (the inference cutoff below is `time[train_size] - 1`) -- confirm in
# `data_to_heterograph`.
train_ratio = 0.8
watch_threshold = 5  # users need at least this many train interactions to be evaluated

edges = data['user', 'movie'].edge_index
time = data['user', 'movie'].time
train_size = int(train_ratio * data['user', 'movie'].num_edges)

# Timestamp just before the first test interaction; the inference loaders sample neighbors
# as of this moment.
cutoff_time = time[train_size] - 1

loader_kwargs = dict(
    data=data,
    batch_size=1024,
    num_neighbors=[5, 5, 5],
    time_attr='time',
    temporal_strategy='last',
    num_workers=4,
)

# Supervision edges are the training edges; each seed time is one tick before the interaction
# (presumably to keep the supervision edge itself out of the sampled temporal subgraph).
train_loader = LinkNeighborLoader(
    edge_label_index=(('user', 'movie'), edges[:, :train_size]),
    edge_label_time=time[:train_size] - 1,
    neg_sampling=dict(mode='binary', amount=1),
    shuffle=True,
    **loader_kwargs,
)

user_loader = NeighborLoader(
    input_nodes='user',
    input_time=cutoff_time.repeat(data['user'].num_nodes),
    **loader_kwargs,
)

movie_loader = NeighborLoader(
    input_nodes='movie',
    input_time=cutoff_time.repeat(data['movie'].num_nodes),
    **loader_kwargs,
)


def _row_sorted_edges(edge_slice):
    """Move a 2 x E user->movie edge tensor to `device` as an EdgeIndex sorted by row (user)."""
    return EdgeIndex(edge_slice.contiguous().to(device), sparse_size=sparse_size).sort_by('row').values


sparse_size = data['user'].num_nodes, data['movie'].num_nodes
train_edges = _row_sorted_edges(edges[:, :train_size])
test_edges = _row_sorted_edges(edges[:, train_size:])

# Only users with at least `watch_threshold` training interactions take part in evaluation.
is_test_node = degree(train_edges[0], num_nodes=data['user'].num_nodes) >= watch_threshold
train_edges = train_edges[:, is_test_node[train_edges[0]]]
test_edges = test_edges[:, is_test_node[test_edges[0]]]

## GNN model: training and evaluation

In [7]:
# Seed torch's global RNG so weight initialisation is reproducible across runs (the train
# loader's shuffling / negative sampling presumably also draws from it -- verify with
# multi-worker loading if exact reproducibility matters).
SEED = 42
torch.manual_seed(SEED)

gnn_model = GNN(num_layers=3, hidden_channels=64).to(device)
optimizer = torch.optim.Adam(gnn_model.parameters(), lr=0.005, weight_decay=1e-4)
bpr_loss = BPRLoss()

In [8]:
from utils.data_transformation import sparse_batch_narrow

def train():
    """Run one epoch over `train_loader` and return the example-weighted mean BPR loss.

    Each mini-batch holds `n_pos` positive user->movie supervision edges followed by the same
    number of sampled negatives (binary negative sampling, amount=1). The first half of the
    model output is therefore scored against the second half.
    """
    gnn_model.train()
    loss_sum, seen = 0, 0
    for mini_batch in tqdm(train_loader):
        mini_batch = mini_batch.to(device)
        n_pos = len(mini_batch['user', 'movie'].input_id)

        optimizer.zero_grad()
        scores = gnn_model(mini_batch.x_dict, mini_batch.edge_index_dict,
                           mini_batch['user', 'movie'].edge_label_index)
        pos_scores, neg_scores = scores[:n_pos], scores[n_pos:]
        loss = bpr_loss(pos_scores, neg_scores)
        loss.backward()
        optimizer.step()

        loss_sum += float(loss) * n_pos
        seen += n_pos

    return loss_sum / seen

@torch.no_grad()
def test(test_edges: EdgeIndex, train_edges: EdgeIndex, k: int, top_count: int = None):
    """Evaluate top-`k` retrieval for the evaluated users; return ``(MAP, precision, recall)``.

    Movie embeddings come from `movie_loader` and user embeddings from `user_loader`. Users
    are assumed to be visited in index order; ``users_infered`` tracks the offset of the
    current batch. Only users flagged in `is_test_node` are queried. Movies linked to a user
    in `train_edges` are excluded from that user's ranking.

    Args:
        test_edges: row-sorted user->movie ground-truth edges.
        train_edges: row-sorted user->movie edges to exclude from each user's ranking.
        k: number of movies retrieved per user.
        top_count: if given, only movies with index < `top_count` are retrievable. Every edge
            to a movie with index >= `top_count` is redirected to a single sentinel column
            (index `top_count`, zero embedding). NOTE(review): the sentinel (score 0) is
            itself retrievable and presumably then counts as a hit for any user with a test
            edge to a non-top movie -- consider excluding it. If `top_count` exceeds the
            number of movies, all movies are used.

    The caller's edge tensors are not modified; they are cloned before remapping.
    """
    gnn_model.eval()
    movie_embs = gnn_model.get_movies_embeddings(movie_loader, device)  # list of embedding chunks
    num_movies = sum(chunk.size(0) for chunk in movie_embs)

    if top_count is None or top_count > num_movies:
        movie_embs = torch.cat(movie_embs, dim=0)
    else:
        # Clone first: the remapping below is in-place, and callers pass the global edge sets.
        test_edges = test_edges.clone()
        train_edges = train_edges.clone()
        test_edges[1, test_edges[1] >= top_count] = top_count
        train_edges[1, train_edges[1] >= top_count] = top_count
        # Keep the first `top_count` movies and append the zero sentinel at index `top_count`.
        sentinel = torch.zeros((1, movie_embs[0].size(1)), device=device)
        movie_embs = torch.cat([torch.cat(movie_embs, dim=0)[:top_count], sentinel], dim=0)

    mipsknn = MIPSKNNIndex(movie_embs.to(faiss_device))
    metrics = LinkPredMAP(k), LinkPredPrecision(k), LinkPredRecall(k)
    users_infered = 0  # global index of the first user in the current batch
    for batch in user_loader:
        batch = batch.to(device)
        batch_size = batch['user'].batch_size
        batch_user_embs = gnn_model.encoder(batch.x_dict, batch.edge_index_dict)\
            ['user'][:batch_size]

        batch_test_user_embs =\
            batch_user_embs[is_test_node[users_infered:users_infered+batch_size]].to(faiss_device)

        batch_test_edges = sparse_batch_narrow(test_edges, users_infered, batch_size)
        batch_train_edges = sparse_batch_narrow(train_edges, users_infered, batch_size).to(faiss_device)

        top_indices_mat = mipsknn.search(batch_test_user_embs, k, exclude_links=batch_train_edges)[1]
        for metric in metrics:
            metric.update(top_indices_mat.cpu(), batch_test_edges)

        users_infered += batch_size

    return tuple(float(metric.compute()) for metric in metrics)

# Evaluation settings: retrieve `k` movies per user among the first `eval_top_count` movie
# indices (see `test`).
k = 20
eval_top_count = 2000
num_epochs = 29
metrics_list = []  # one [loss, MAP@k, precision@k, recall@k] row per completed epoch
try:
    for epoch_num in range(1, num_epochs + 1):
        loss = train()
        print(f'Train: Epoch №{epoch_num:02d}, Loss: {loss:.4f}')
        # `map_at_k` rather than `map` so the builtin is not shadowed in the notebook namespace.
        map_at_k, precision, recall = test(test_edges, train_edges, k=k, top_count=eval_top_count)
        print('Test@%d, MAP: %.4f, Precision: %.4f, Recall: %.4f' % (k, map_at_k, precision, recall))
        metrics_list.append([loss, map_at_k, precision, recall])

except KeyboardInterrupt:
    # Allows stopping training early from the notebook while keeping the model and metrics so far.
    print('--KeyboardInterrupt--')

  0%|          | 0/3457 [00:00<?, ?it/s]

100%|██████████| 3457/3457 [07:48<00:00,  7.38it/s]


Train: Epoch №01, Loss: 0.1225
Test@20, MAP: 0.0304, Precision: 0.0192, Recall: 0.1208


100%|██████████| 3457/3457 [07:52<00:00,  7.31it/s]


Train: Epoch №02, Loss: 0.1056
Test@20, MAP: 0.0380, Precision: 0.0199, Recall: 0.1241


100%|██████████| 3457/3457 [07:55<00:00,  7.28it/s]


Train: Epoch №03, Loss: 0.1031
Test@20, MAP: 0.0310, Precision: 0.0197, Recall: 0.1224


100%|██████████| 3457/3457 [07:50<00:00,  7.34it/s]


Train: Epoch №04, Loss: 0.1020
Test@20, MAP: 0.0305, Precision: 0.0199, Recall: 0.1222


100%|██████████| 3457/3457 [07:35<00:00,  7.59it/s]


Train: Epoch №05, Loss: 0.1014
Test@20, MAP: 0.0329, Precision: 0.0197, Recall: 0.1205


100%|██████████| 3457/3457 [07:56<00:00,  7.25it/s]


Train: Epoch №06, Loss: 0.1010
Test@20, MAP: 0.0389, Precision: 0.0204, Recall: 0.1274


100%|██████████| 3457/3457 [07:44<00:00,  7.44it/s]


Train: Epoch №07, Loss: 0.1002
Test@20, MAP: 0.0323, Precision: 0.0195, Recall: 0.1175


100%|██████████| 3457/3457 [07:30<00:00,  7.68it/s]


Train: Epoch №08, Loss: 0.0999
Test@20, MAP: 0.0302, Precision: 0.0195, Recall: 0.1219


 47%|████▋     | 1610/3457 [03:32<04:04,  7.56it/s]

--KeyboardInterrupt--





In [137]:
# Persist only the trained weights (state_dict); a later cell reloads them into a fresh GNN.
torch.save(obj=gnn_model.state_dict(), f='./gnn_state.pt')

## Example: making recommendations with the graph neural network

In [5]:
%%capture
# Rebuild the architecture with the training hyper-parameters, then load the saved weights.
# `map_location='cpu'` lets a checkpoint saved on a GPU (e.g. Colab) load on a CPU-only
# machine; the model is moved to `device` afterwards. `weights_only=True` restricts
# unpickling to tensors and plain containers instead of arbitrary objects.
gnn_model = GNN(num_layers=3, hidden_channels=64)
gnn_model.load_state_dict(torch.load('./gnn_state.pt', map_location='cpu', weights_only=True))
gnn_model.to(device)

In [6]:
# Raw tables used to inspect recommendations: interactions, item metadata, user demographics.
df_inter, df_items, df_users = (
    pd.read_csv(path)
    for path in ('processed_data/interactions.csv',
                 'processed_data/items.csv',
                 'processed_data/users.csv')
)

In [40]:
# Example user whose watch history (joined with item metadata) and recommendations are shown.
user_id = 6343
pd.merge(df_inter[df_inter['user_id'].eq(user_id)], df_items)  # joins on the shared column(s)

Unnamed: 0,user_id,item_id,last_watch_dt,watched_pct,content_type,title,genres,age_rating,keywords,views
0,6343,13867,2021-07-26,100.0,film,Рио,"мультфильм, криминал, приключения, мюзиклы, ко...",0,"домашнее животное, птица, попугай, мюзикл, кан...",2074
1,6343,15568,2021-07-26,100.0,film,Как поймать перо Жар-Птицы,"мультфильм, фэнтези",0,"2013, россия, как, поймать, перо, жар, птицы",853
2,6343,85,2021-07-29,13.0,film,Турбо,"мультфильм, приключения, спорт, фантастика, ко...",0,"аутсайдер, гоночный автомобиль, мечта, скорост...",914
3,6343,7310,2021-07-29,27.0,film,Гадкий я 2,"мультфильм, приключения, фантастика, фэнтези, ...",0,"отношения родитель-ребенок, секретный агент, п...",3289
4,6343,565,2021-08-02,34.0,film,Ледниковый период 3: Эра динозавров,"мультфильм, приключения, комедии",0,"ледниковый период, мост, безумие, джунгли, дин...",2687


In [41]:
df_users.loc[df_users['user_id'].eq(user_id)]  # demographic row(s) of the example user

Unnamed: 0,user_id,age,income,sex,kids_flg
543304,6343,age_25_34,income_40_60,Ж,0


In [42]:
# Invert the node-index -> original-id lookups, and map item id -> row label in df_items.
users_mapping = dict(zip(users_rev_mapping.values(), users_rev_mapping.keys()))
movies_mapping = dict(zip(movies_rev_mapping.values(), movies_rev_mapping.keys()))
movies_id_to_df_idx = dict(zip(df_items['item_id'], df_items.index))

In [43]:
# Recommend `k` movies for `user_id`, restricted to the first `top_count` movie indices and
# excluding movies linked to the user in `train_edges`.
top_count, k = 1000, 20
with torch.no_grad():
    gnn_model.eval()

    # Candidate movie embeddings for maximum-inner-product search.
    movie_embs = torch.cat(gnn_model.get_movies_embeddings(movie_loader, device), dim=0).to(faiss_device)
    if top_count is not None:
        movie_embs = movie_embs[:top_count]
    mipsknn = MIPSKNNIndex(movie_embs)

    # Embed the single user from its temporal neighborhood as of `time[train_size] - 1`.
    user_idx = users_mapping[user_id]
    user_neighbors_loader = NeighborLoader(
        data=data,
        num_neighbors=[5, 5, 5],
        time_attr='time',
        temporal_strategy='last',
        input_nodes=('user', torch.tensor([user_idx])),
        input_time=(time[train_size] - 1).repeat(1),
    )
    batch = next(iter(user_neighbors_loader)).to(device)
    user_emb = gnn_model.encoder(batch.x_dict, batch.edge_index_dict)['user'][0].unsqueeze(0).to(faiss_device)

    # This user's train edges that point into the index; the query has a single row, so every
    # exclusion link is re-pointed at row 0.
    is_user_candidate = (train_edges[0] == user_idx) & (train_edges[1] < top_count)
    user_train_edges = train_edges[:, is_user_candidate].to(faiss_device)
    user_train_edges[0, :] = torch.zeros(user_train_edges.size(1))

    top_indices_mat = mipsknn.search(user_emb, k, exclude_links=user_train_edges)[1]

In [44]:
# Map retrieved index positions -> movie ids -> df_items row labels, and show the top 5.
# `.squeeze(0)` keeps a 1-D result even when k == 1, `.tolist()` works for CPU and GPU
# tensors alike, and plain dict indexing raises KeyError on an unknown id instead of
# silently yielding None.
rec_ids = np.array([movies_rev_mapping[idx] for idx in top_indices_mat.squeeze(0).tolist()])
rec_df_idx = np.array([movies_id_to_df_idx[movie_id] for movie_id in rec_ids.tolist()])
# `movies_id_to_df_idx` stores index *labels* (built from `Series.items()`), so select by label;
# with read_csv's default RangeIndex this matches the previous positional `.iloc`.
df_items.loc[rec_df_idx].head().reset_index(drop=True)

Unnamed: 0,item_id,content_type,title,genres,age_rating,keywords,views
0,12743,film,Ледниковый период 4: Континентальный дрейф,"мультфильм, приключения, комедии",0,"тюлень (животное), доисторические времена, таю...",5410
1,12988,film,Гномео и Джульетта,"мелодрамы, мультфильм, приключения, комедии",0,"сад, запретная любовь, поцелуй, садовый гном, ...",1826
2,14942,film,История игрушек: Большой побег,"мультфильм, фэнтези, комедии",6,"заложник, колледж, игрушка, побег, детский сад...",2525
3,16270,film,Тайна Коко,"мультфильм, фэнтези, приключения",12,"Мексика, гитара, музыкант, скелет, музыка, заг...",6201
4,6774,film,Тачки 2,"мультфильм, комедии",0,"автомобильная гонка, продолжение, антропоморфи...",3287
