In [None]:
import torch

from project.config import TARGET_EDGE

from project.model import MF
from project.utils.loss import BPRLoss
from torch_geometric.loader import LinkNeighborLoader
from torch.optim import Adam

from functools import partial

from project.utils.train import dispatch_epoch, dispatch_session
from project.utils.score import composite, recall_score, ndcg_score
from project.model import mf

## Data Handling

In [None]:
# --- Datasets and loaders used to compute the loss ---------------------------
# Each split is loaded from disk and the label indices of the target edge type
# are pulled out right away, since the loaders below need them.
trn_data = torch.load('data/out/trn_Video_Games.pt')
trn_edge_label_index = trn_data[TARGET_EDGE].edge_label_index

vld_data = torch.load('data/out/vld_Video_Games.pt')
vld_edge_label_index = vld_data[TARGET_EDGE].edge_label_index

# Keyword arguments shared by the training and validation loaders.
kwargs = {
    'num_neighbors': [0],  # [8, 4, 2],
    'neg_sampling': 'triplet',
    'num_workers': 10,
    'shuffle': True,
    'pin_memory': True,
}

# Sub-graph loaders feeding the loss computation.
trn_loader = LinkNeighborLoader(
    data=trn_data,
    edge_label_index=[TARGET_EDGE, trn_edge_label_index],
    batch_size=2048,
    **kwargs,
)
vld_loader = LinkNeighborLoader(
    data=vld_data,
    edge_label_index=[TARGET_EDGE, vld_edge_label_index],
    batch_size=2048,
    **kwargs,
)

# --- Dataset and loader for the all-ranking protocol -------------------------
rnk_data = torch.load('data/out/rnk_vld_Video_Games.pt')

# Both the label index and the labels live on the target edge store.
rnk_target = rnk_data[TARGET_EDGE]
rnk_edge_label_index = rnk_target.edge_label_index
rnk_edge_label = rnk_target.edge_label

# This loader is built without the shared kwargs: it passes explicit labels
# and sets neither `neg_sampling` nor `shuffle`.
rnk_loader = LinkNeighborLoader(
    data=rnk_data,
    edge_label_index=[TARGET_EDGE, rnk_edge_label_index],
    edge_label=rnk_edge_label,
    num_neighbors=[0],  # [8, 4, 2],
    batch_size=8192,
    num_workers=10,
    pin_memory=True,
)

## Training

In [None]:
# Creates the model to train.
model = MF(
    num_embeddings=trn_data.num_nodes_dict,
    embedding_dim=64
)
display(model)

# `model.parameters()` returns a one-shot generator. The loss criterion is a
# long-lived object that is used on every batch, so the parameters are
# materialised once into a list that can be iterated any number of times.
# NOTE(review): BPRLoss is defined outside this notebook; if it stores the
# iterable and walks it on every call, a bare generator would be exhausted
# after the first pass and the regularisation term would silently vanish.
# Passing a list is safe either way - confirm against project.utils.loss.
model_params = list(model.parameters())

# Instantiates the learning algorithm and loss criterion.
optimizer = Adam(model_params, lr=1e-3)
loss_fn = BPRLoss(model_params, reg_factor=1e-4)

# Builds the evaluation function, bound to the target edge type.
batch_handler = partial(mf.evaluate, edge_type=TARGET_EDGE)

# Builds the update and validate functions. Both wrap `dispatch_epoch`; only
# the update function is given an optimizer.
update_fn = partial(
    dispatch_epoch,
    loader=trn_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    batch_handler=batch_handler,
)
validate_fn = partial(
    dispatch_epoch,
    loader=vld_loader,
    loss_fn=loss_fn,
    batch_handler=batch_handler,
)

# Builds the score function: recall and NDCG at k=20 over the all-ranking
# loader.
score_fn = partial(
    composite,
    loader=rnk_loader,
    pred_fn=mf.predict,
    edge_type=TARGET_EDGE,
    score_fns=[
        recall_score,
        ndcg_score
    ],
    at_k=20
)

In [None]:
# Arguments for the training session, gathered up-front so the call below
# stays short.
# NOTE(review): `path` is an absolute, machine-specific location. The plotting
# cell further down reads `trc.pkl` from this same directory, so the two must
# be changed together.
session_kwargs = {
    'module': model,
    'update_fn': update_fn,
    'validate_fn': validate_fn,
    'score_fn': score_fn,
    'verbose': True,
    'num_epochs': 32,
    'path': '/home/jowin/Git/GNN-CF/out/mf/1e-4',
}

# Runs the session; kept as the last expression so any return value is still
# shown as the cell output.
dispatch_session(**session_kwargs)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pickle

from pandas import DataFrame


# Setting the plotting style.
sns.set()

# Loading the trace from the directory that was passed as `path` to
# `dispatch_session` above.
# NOTE(review): `pickle.load` can execute arbitrary code; only load trace
# files that were produced locally by this project.
with open('/home/jowin/Git/GNN-CF/out/mf/1e-4/trc.pkl', 'rb') as file:
    trace = pickle.load(file)

# Prepares the data for plotting: one row per (epoch, metric) pair with the
# score expressed in percent (hence the multiplication by 100).
# NOTE(review): the column names assume `trace['score']` holds two values per
# epoch in the order recall, NDCG - presumably the order of `score_fns` used
# during training; confirm against the code that writes the trace.
score = (
    DataFrame(trace['score'], columns=['Recall', 'NDCG'])
    .mul(100)
    .rename_axis(index='epoch', columns='metric')
    .stack()
    .rename('score')
    .reset_index()
)

# Plots the score trace.
fig, ax = plt.subplots(figsize=[7, 3], dpi=192)
sns.lineplot(
    data=score,
    x='epoch',
    y='score',
    hue='metric',
    style='metric',
    ax=ax
)
ax.set_xlabel('Epoch')
ax.set_ylabel('Score@20 (%)')
ax.legend(title='Metric')

# `plt.show` displays all open figures and takes no figure argument; passing
# `fig` positionally is not part of its interface.
plt.show()