In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import OneHotEncoder
from spacy.lang.ru.stop_words import STOP_WORDS

# Encoding

In [2]:
# Load the three input tables from the local ./mod_data directory:
# user profiles, the item catalogue and the user-item interactions.
df_users = pd.read_csv('./mod_data/users.csv')
df_items = pd.read_csv('./mod_data/items.csv')
df_inter = pd.read_csv('./mod_data/interactions.csv')

## Items

In [3]:
# Peek at one randomly chosen item row to see the available columns.
df_items.sample(n=1)

Unnamed: 0,item_id,content_type,title,genres,age_rating,keywords
4229,7434,series,Далеко от тебя,"мелодрамы, комедии",18,"Далеко, от, тебя, 2019, Россия"


In [4]:
# keywords extraction
def tokenize(line):
    """Split a comma-separated string into cleaned tokens.

    Every piece is stripped of surrounding whitespace *before* filtering, so
    the result does not depend on the exact spacing around the commas and
    padded values such as ``'2019 '`` or ``' nan'`` are still filtered out.

    Args:
        line: Comma-separated text, e.g. ``'мелодрамы, комедии'``.

    Returns:
        List of tokens in their original order, without empty strings,
        purely numeric tokens or the literal ``'nan'``.
    """
    # 'nan' is the string a missing value becomes once the column has been
    # cast with ``astype('U')`` by the callers in this notebook.
    ignored = {'', 'nan'}
    pieces = (piece.strip() for piece in line.split(','))
    return [piece for piece in pieces if not piece.isnumeric() and piece not in ignored]

# Keywords: TF-IDF over the comma-separated keyword list, limited to a
# 100-term vocabulary with the spaCy Russian stop words removed.
keywords_vectorizer = TfidfVectorizer(
    tokenizer=tokenize,
    token_pattern=None,
    stop_words=list(STOP_WORDS),
    max_features=100,
)
X_keywords = keywords_vectorizer.fit_transform(
    df_items['keywords'].to_numpy().astype('U')
)

# Content type: a single binary column, 1 for 'film' and 0 for anything else.
X_content_type = df_items['content_type'].eq('film').astype(int).to_numpy().reshape(-1, 1)

# Age rating: cast to string and one-hot encode each distinct value.
age_rating_encoder = OneHotEncoder()
X_age_rating = age_rating_encoder.fit_transform(
    df_items['age_rating'].to_numpy().astype('U').reshape(-1, 1)
).toarray()

# Genres: TF-IDF over the comma-separated genre list (no vocabulary cap).
genres_vectorizer = TfidfVectorizer(token_pattern=None, tokenizer=tokenize)
X_genres = genres_vectorizer.fit_transform(
    df_items['genres'].to_numpy().astype('U')
).toarray()

In [5]:
# Ten keywords with the largest total TF-IDF weight summed over all items.
(
    pd.Series(
        data=np.squeeze(np.asarray(X_keywords.sum(axis=0))),
        index=keywords_vectorizer.get_feature_names_out(),
    )
    .sort_values(ascending=False)
    .head(10)
)

россия                     2646.475781
соединенные штаты          1650.497265
отношения                  1344.449214
франция                    1185.161904
сша                        1085.736725
ссср                        756.435969
любовь                      633.931426
дружба                      599.568120
соединенное королевство     495.680623
женщины                     490.716083
dtype: float64

In [6]:
# Item feature matrix: concatenate the per-item feature blocks column-wise
# (content type flag, genre TF-IDF, age-rating one-hot, keyword TF-IDF).
X_items = np.hstack([
    X_content_type,
    X_genres,
    X_age_rating,
    # ``toarray()`` yields a plain ndarray; ``todense()`` returns the legacy
    # ``np.matrix`` type, which would make the whole stacked result a matrix.
    X_keywords.toarray(),
])
X_items.shape

(15963, 202)

## Users

In [7]:
# Peek at one randomly chosen user row to see the available columns.
df_users.sample(n=1)

Unnamed: 0,user_id,age,income,sex,kids_flg
171442,1066107,age_35_44,income_40_60,Ж,1


In [8]:
# Age bucket: cast to string and one-hot encode each distinct value.
age_encoder = OneHotEncoder()
X_age = age_encoder.fit_transform(
    df_users['age'].to_numpy().astype('U').reshape(-1, 1)
).toarray()

# Income bucket: same treatment as the age bucket.
income_encoder = OneHotEncoder()
X_income = income_encoder.fit_transform(
    df_users['income'].to_numpy().astype('U').reshape(-1, 1)
).toarray()

# Sex: a single binary column, 1 where the value equals 'М' and 0 for every
# other value (including missing ones, which never compare equal).
X_sex = df_users['sex'].eq('М').astype(int).to_numpy().reshape(-1, 1)

In [9]:
# User feature matrix: age one-hot, income one-hot, sex flag and the raw
# kids flag, joined column-wise in that order.
X_users = np.concatenate(
    [
        X_age,
        X_income,
        X_sex,
        df_users['kids_flg'].to_numpy().reshape(-1, 1),
    ],
    axis=1,
)
X_users.shape

(840197, 14)

# GNN

In [10]:
#pip install pyg-lib -f https://data.pyg.org/whl/nightly/torch-2.1.0+cu121.html

import torch
import torch.nn.functional as F
import torch_geometric.nn as gnn
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric import EdgeIndex
from torch_geometric.loader import LinkNeighborLoader, NeighborLoader
from torch_geometric.metrics import (
    LinkPredMAP,
    LinkPredPrecision,
    LinkPredRecall,
)
import os
from tqdm import tqdm
# Process-level environment switches.
# NOTE(review): presumably these silence the debugger's file-validation
# warning and tolerate a duplicate OpenMP runtime; they are set after the
# imports above - confirm that is early enough for them to take effect.
os.environ.update({
    'PYDEVD_DISABLE_FILE_VALIDATION': '1',
    'KMP_DUPLICATE_LIB_OK': 'True',
})

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [11]:
# Map raw ids to node indices. The index of a node is the row *position* of
# its id in df_users / df_items, i.e. the row order used for X_users / X_items.
# ``enumerate`` keeps this true even if the DataFrame index is not 0..n-1.
user_id_to_idx = {user_id: idx for idx, user_id in enumerate(df_users['user_id'])}
item_id_to_idx = {item_id: idx for idx, item_id in enumerate(df_items['item_id'])}

src_idx = df_inter['user_id'].map(user_id_to_idx)
dst_idx = df_inter['item_id'].map(item_id_to_idx)

# An id that is missing from the lookup maps to NaN, which would silently turn
# the edge list into floats. Fail early with an explicit message instead.
n_unmapped = int((src_idx.isna() | dst_idx.isna()).sum())
if n_unmapped:
    raise ValueError(
        f'{n_unmapped} interactions reference a user_id or item_id that is '
        'missing from df_users / df_items'
    )

# Row 0 holds user indices, row 1 the matching item indices.
edges = np.vstack([
    src_idx.to_numpy(dtype=np.int64),
    dst_idx.to_numpy(dtype=np.int64),
])

connections = np.ones(edges.shape[1])

# Interaction timestamps as integer seconds since the Unix epoch. Casting to
# datetime64[s] first makes the result independent of the resolution pandas
# picks for the parsed column.
time = (
    pd.to_datetime(df_inter['last_watch_dt'], format='%Y-%m-%d')
    .values.astype('datetime64[s]')
    .astype(np.int64)
)

In [12]:
# Heterogeneous graph with 'movie' and 'user' node types.
data = HeteroData()

# Node feature matrices: items first, then users (used for message passing).
data['movie'].x = torch.Tensor(X_items)
data['user'].x = torch.Tensor(X_users)

# user -> movie 'watched' edges together with their timestamps.
watched = data['user', 'watched', 'movie']
watched.edge_index = torch.tensor(edges)
watched.time = torch.tensor(time)

# Add the reverse relation so messages can also flow from movies to users.
to_undirected = T.ToUndirected()
data = to_undirected(data)
data

HeteroData(
  movie={ x=[15963, 202] },
  user={ x=[840197, 14] },
  (user, watched, movie)={
    edge_index=[2, 1288996],
    time=[1288996],
  },
  (movie, rev_watched, user)={
    edge_index=[2, 1288996],
    time=[1288996],
  }
)

In [13]:
# Number of recommendations retrieved per user during evaluation.
k = 10

# Temporal link-level split: order the user -> movie edges by timestamp, train
# on the earliest 80% and hold out the most recent 20% for testing.
edge_label_index = data['user', 'movie'].edge_index
time = data['user', 'movie'].time

perm = time.argsort()
num_train_edges = int(0.8 * perm.numel())
train_index = perm[:num_train_edges]
test_index = perm[num_train_edges:]

# Arguments shared by all three loaders below.
kwargs = {
    'data': data,
    'num_neighbors': [10, 10],
    'batch_size': 256,
    'time_attr': 'time',
    'num_workers': 0,
    'persistent_workers': False,
    'temporal_strategy': 'last',
}

# Training batches: positive training edges plus two sampled negatives each.
# The seed time is shifted back by one so the edge being predicted is not part
# of its own sampled neighbourhood (no leakage).
train_loader = LinkNeighborLoader(
    edge_label_index=(('user', 'movie'), edge_label_index[:, train_index]),
    edge_label_time=time[train_index] - 1,
    neg_sampling={'mode': 'binary', 'amount': 2},
    shuffle=True,
    **kwargs,
)

# For evaluation, node-level subgraphs are sampled around every user and every
# movie to obtain their embeddings, which allows an efficient k-NN search on
# top of them. All seeds share one timestamp: just before the first test edge.
eval_time = time[test_index].min() - 1
src_loader = NeighborLoader(
    input_nodes='user',
    input_time=eval_time.repeat(data['user'].num_nodes),
    **kwargs,
)
dst_loader = NeighborLoader(
    input_nodes='movie',
    input_time=eval_time.repeat(data['movie'].num_nodes),
    **kwargs,
)

# Ground-truth test edges and the (training) edges to exclude from the
# retrieved candidates, both sorted by user (row).
sparse_size = (data['user'].num_nodes, data['movie'].num_nodes)

test_edges = edge_label_index[:, test_index].to(device)
test_edge_label_index = EdgeIndex(test_edges, sparse_size=sparse_size).sort_by('row')[0]

train_edges = edge_label_index[:, train_index].to(device)
test_exclude_links = EdgeIndex(train_edges, sparse_size=sparse_size).sort_by('row')[0]


class GNN(torch.nn.Module):
    """Two stacked ``SAGEConv`` layers with a ReLU in between.

    Args:
        hidden_channels: Output size of both convolution layers.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # (-1, -1) leaves the input sizes unspecified; presumably PyG infers
        # them lazily on the first forward pass - confirm.
        lazy_in_channels = (-1, -1)
        self.conv1 = gnn.SAGEConv(lazy_in_channels, hidden_channels)
        self.conv2 = gnn.SAGEConv(lazy_in_channels, hidden_channels)

    def forward(self, x, edge_index):
        """Return the output of the second convolution (no final activation)."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)


class InnerProductDecoder(torch.nn.Module):
    """Score user-movie pairs by the dot product of their embeddings."""

    def forward(self, x_dict, edge_label_index):
        """Return one score per column of ``edge_label_index``.

        Args:
            x_dict: Mapping with ``'user'`` and ``'movie'`` embedding tensors.
            edge_label_index: Row 0 indexes into the user embeddings, row 1
                into the movie embeddings.

        Returns:
            Tensor with the inner product of each selected (user, movie) pair.
        """
        user_emb = x_dict['user'][edge_label_index[0]]
        movie_emb = x_dict['movie'][edge_label_index[1]]
        return torch.sum(user_emb * movie_emb, dim=-1)


class Model(torch.nn.Module):
    """Heterogeneous GNN encoder followed by an inner-product link decoder.

    Args:
        hidden_channels: Embedding size produced by the encoder.
        metadata: Optional graph metadata (as returned by
            ``HeteroData.metadata()``) used to convert the encoder with
            ``to_hetero``. Defaults to the metadata of the notebook-level
            ``data`` object, which keeps ``Model(hidden_channels=...)``
            working exactly as before.
    """

    def __init__(self, hidden_channels, metadata=None):
        super().__init__()
        if metadata is None:
            # Backward-compatible default: use the graph built in the notebook.
            metadata = data.metadata()
        self.encoder = gnn.to_hetero(GNN(hidden_channels), metadata, aggr='sum')
        self.decoder = InnerProductDecoder()

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Encode all nodes, then score the edges in ``edge_label_index``."""
        x_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(x_dict, edge_label_index)


# 64-dimensional embeddings, optimised with Adam at a learning rate of 0.01.
model = Model(hidden_channels=64)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    """Run one training epoch over ``train_loader``.

    Returns:
        Binary cross-entropy (with logits) averaged over every labelled edge
        seen during the epoch.
    """
    model.train()
    loss_sum = 0
    num_seen = 0
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        labelled_edges = batch['user', 'movie']

        optimizer.zero_grad()
        logits = model(
            batch.x_dict,
            batch.edge_index_dict,
            labelled_edges.edge_label_index,
        )
        targets = labelled_edges.edge_label

        loss = F.binary_cross_entropy_with_logits(logits, targets)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so the result is a per-edge mean.
        batch_edges = targets.numel()
        loss_sum += loss.item() * batch_edges
        num_seen += batch_edges

    return loss_sum / num_seen


@torch.no_grad()
def test(edge_label_index, exclude_links):
    """Evaluate top-``k`` movie retrieval for all users.

    Movie embeddings are computed via ``dst_loader`` and put into a
    maximum-inner-product k-NN index; user embeddings are then computed batch
    by batch via ``src_loader``, the top ``k`` movies are searched for each
    user and the MAP / precision / recall metrics are updated.

    Args:
        edge_label_index: Ground-truth user -> movie links. Must support
            ``sparse_narrow`` (the caller passes an ``EdgeIndex``).
        exclude_links: Links passed to the k-NN search as exclusions; same
            type requirements as ``edge_label_index``.

    Returns:
        Tuple ``(map, precision, recall)`` of Python floats, each at ``k``.
    """
    model.eval()

    dst_embs = []
    for batch in dst_loader:  # Collect destination node/movie embeddings:
        batch = batch.to(device)
        emb = model.encoder(batch.x_dict, batch.edge_index_dict)['movie'].cpu()
        # Keep only the first ``batch_size`` rows - presumably the seed nodes
        # of this batch, with sampled neighbours following them (confirm).
        emb = emb[:batch['movie'].batch_size]
        dst_embs.append(emb)
    dst_emb = torch.cat(dst_embs, dim=0)
    del dst_embs

    # Instantiate k-NN index based on maximum inner product search (MIPS):
    mips = gnn.MIPSKNNIndex(dst_emb)

    # Initialize metrics:
    map_metric = LinkPredMAP(k=k)
    precision_metric = LinkPredPrecision(k=k)
    recall_metric = LinkPredRecall(k=k)

    # Running row offset into the two link sets.
    # NOTE(review): this assumes ``src_loader`` yields users in index order
    # without shuffling, so batch i covers rows [num_processed, +batch size).
    num_processed = 0
    for batch in src_loader:  # Collect source node/user embeddings:
        batch = batch.to(device)

        # Compute user embeddings:
        emb = model.encoder(batch.x_dict, batch.edge_index_dict)['user']
        emb = emb[:batch['user'].batch_size]

        # Filter labels/exclusion by current batch:
        _edge_label_index = edge_label_index.sparse_narrow(
            dim=0,
            start=num_processed,
            length=emb.size(0),
        ).cpu()
        _exclude_links = exclude_links.sparse_narrow(
            dim=0,
            start=num_processed,
            length=emb.size(0),
        ).cpu()
        num_processed += emb.size(0)


        # Perform MIPS search:
        _, pred_index_mat = mips.search(emb.cpu(), k, _exclude_links)

        # Update retrieval metrics:
        map_metric.update(pred_index_mat, _edge_label_index)
        precision_metric.update(pred_index_mat, _edge_label_index)
        recall_metric.update(pred_index_mat, _edge_label_index)

    return (
        float(map_metric.compute()),
        float(precision_metric.compute()),
        float(recall_metric.compute()),
    )


# Train for 15 epochs; after each one, report the training loss and the
# retrieval metrics on the held-out test edges.
num_epochs = 15
for epoch in range(1, num_epochs + 1):
    train_loss = train()
    print(f'Epoch: {epoch:02d}, Loss: {train_loss:.4f}')

    val_map, val_precision, val_recall = test(test_edge_label_index, test_exclude_links)
    print(
        f'Test MAP@{k}: {val_map:.4f}, '
        f'Test Precision@{k}: {val_precision:.4f}, '
        f'Test Recall@{k}: {val_recall:.4f}'
    )

100%|██████████| 4029/4029 [01:35<00:00, 42.33it/s]


Epoch: 01, Loss: 0.2546
Test MAP@10: 0.0604, Test Precision@10: 0.0270, Test Recall@10: 0.1847


 28%|██▊       | 1111/4029 [00:26<01:10, 41.47it/s]


KeyboardInterrupt: 