In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from loguru import logger

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.sparse as sparse
import torch.optim as optim

import pandas as pd
import plotly.express as px

sys.path.insert(0, '..')

from src.viz import blueq_colors

# Implement

In [3]:
from src.train_utils import mse_loss, train, MetricLogCallback
from src.model import LightGCN

In [4]:
# device = (
#     "cuda"
#     if torch.cuda.is_available()
#     else "mps"
#     if torch.backends.mps.is_available()
#     else "cpu"
# )
device = 'cpu'
logger.info(f"Using {device} device")

[32m2024-09-09 00:07:12.291[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mUsing cpu device[0m


# Test implementation

In [5]:
# Mock data
user_ids = [0, 0, 1, 2, 2]
item_ids = [0, 1, 2, 3, 1]
ratings = [1, 4, 5, 3, 2]
n_users = len(set(user_ids))
n_items = len(set(item_ids))

val_user_ids = [0, 1, 2]
val_item_ids = [2, 1, 2]
val_ratings = [2, 4, 5]

model = LightGCN(embedding_dim=64, n_layers=3, user_ids=user_ids, item_ids=item_ids, ratings=ratings, device=device)

# Example forward pass
users = torch.tensor([0, 1, 2])
items = torch.tensor([0, 1, 2])
predictions = model.predict(users, items)
print(predictions)

tensor([-0.0043,  0.0017, -0.0015], grad_fn=<SumBackward1>)


In [6]:
import random
import numpy as np
from torch.utils.data import DataLoader
from src.dataset_loader import UserItemRatingDataset

# Fixing random seed for reproducibility
random.seed(42)
np.random.seed(42)

embedding_dim = 8
n_layers = 3
batch_size = 4

# Display mock dataset
print("Mock User IDs:", user_ids)
print("Mock Item IDs:", item_ids)
print("Ratings:", ratings)

# class RatingDataset(Dataset):
#     def __init__(self, user_ids, item_ids, ratings):
#         """
#         Args:
#             user_ids (list or array): List of user indices.
#             item_ids (list or array): List of item indices.
#             ratings (list or array): List of corresponding ratings.
#         """
#         self.user_ids = user_ids
#         self.item_ids = item_ids
#         self.ratings = ratings
    
#     def __len__(self):
#         return len(self.user_ids)
    
#     def __getitem__(self, idx):
#         user = self.user_ids[idx]
#         item = self.item_ids[idx]
#         rating = self.ratings[idx]
#         return user, item, rating

rating_dataset = UserItemRatingDataset(user_ids, item_ids, ratings)
dataloader = DataLoader(rating_dataset, batch_size=batch_size, shuffle=True)

val_rating_dataset = UserItemRatingDataset(val_user_ids, val_item_ids, val_ratings)
val_dataloader = DataLoader(val_rating_dataset, batch_size=batch_size, shuffle=True)

# Instantiate LightGCN model
model = LightGCN(embedding_dim=embedding_dim, n_layers=n_layers,
                 user_ids=user_ids, item_ids=item_ids, ratings=ratings, device=device)

# Training loop
n_epochs = 50

train(model, dataloader, val_dataloader, epochs=n_epochs, patience=2, print_steps=1, lr=0.001, device=device, progress_bar_type='tqdm_notebook')

Mock User IDs: [0, 0, 1, 2, 2]
Mock Item IDs: [0, 1, 2, 3, 1]
Ratings: [1, 4, 5, 3, 2]


Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

Training Epoch 1:   0%|          | 0/2 [00:00<?, ?it/s]

[32m2024-09-09 00:07:12.785[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 1, Global Loss: 11.5151[0m
[32m2024-09-09 00:07:12.786[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 1, Learning Rate: 0.001000[0m
[32m2024-09-09 00:07:12.786[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 1, Gradient Norms: {'grad_norm_user_embedding.weight': 0.06376536190509796, 'grad_norm_item_embedding.weight': 0.07702846080064774, 'total_grad_norm': 0.0999970257167934}[0m
[32m2024-09-09 00:07:12.788[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 2, Global Loss: 10.2530[0m
[32m2024-09-09 00:07:12.788[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 2, Learning Rate: 0.001000[0m
[32m2024-09-09 00:07:12.789[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep

Training Epoch 2:   0%|          | 0/2 [00:00<?, ?it/s]

[32m2024-09-09 00:07:12.795[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 3, Global Loss: 13.5106[0m
[32m2024-09-09 00:07:12.795[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 3, Learning Rate: 0.001000[0m
[32m2024-09-09 00:07:12.796[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 3, Gradient Norms: {'grad_norm_user_embedding.weight': 0.07017427682876587, 'grad_norm_item_embedding.weight': 0.08113785833120346, 'total_grad_norm': 0.10727432676560925}[0m
[32m2024-09-09 00:07:12.797[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 4, Global Loss: 7.2579[0m
[32m2024-09-09 00:07:12.797[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 4, Learning Rate: 0.001000[0m
[32m2024-09-09 00:07:12.797[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep

Training Epoch 3:   0%|          | 0/2 [00:00<?, ?it/s]

[32m2024-09-09 00:07:12.803[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 5, Global Loss: 12.7595[0m
[32m2024-09-09 00:07:12.803[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 5, Learning Rate: 0.001000[0m
[32m2024-09-09 00:07:12.803[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 5, Gradient Norms: {'grad_norm_user_embedding.weight': 0.06755920499563217, 'grad_norm_item_embedding.weight': 0.07502468675374985, 'total_grad_norm': 0.10096013966977334}[0m
[32m2024-09-09 00:07:12.804[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 6, Global Loss: 8.3827[0m
[32m2024-09-09 00:07:12.805[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 6, Learning Rate: 0.001000[0m
[32m2024-09-09 00:07:12.805[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep

# Prep data

In [7]:
train_df = pd.read_parquet("../data/train.parquet")
val_df = pd.read_parquet("../data/val.parquet")

In [8]:
from src.id_mapper import IDMapper

In [9]:
user_ids = train_df['user_id'].values
item_ids = train_df['parent_asin'].values
unique_user_ids = list(set(user_ids))
unique_item_ids = list(set(item_ids))

logger.info(f"{len(unique_user_ids)=:,.0f}, {len(unique_item_ids)=:,.0f}")

[32m2024-09-09 00:07:12.882[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mlen(unique_user_ids)=5,223, len(unique_item_ids)=2,653[0m


In [10]:
idm = IDMapper()
idm.fit(unique_user_ids, unique_item_ids)

In [11]:
user_indices = [idm.get_user_index(user_id) for user_id in user_ids]
item_indices = [idm.get_item_index(item_id) for item_id in item_ids]
ratings = train_df['rating'].values.tolist()

In [12]:
val_user_indices = [idm.get_user_index(user_id) for user_id in val_df['user_id']]
val_item_indices = [idm.get_item_index(item_id) for item_id in val_df['parent_asin']]
val_ratings = val_df['rating'].values.tolist()

# Train

In [13]:
batch_size = 32

rating_dataset = UserItemRatingDataset(user_indices, item_indices, ratings)
dataloader = DataLoader(rating_dataset, batch_size=batch_size, shuffle=True)

val_rating_dataset = UserItemRatingDataset(val_user_indices, val_item_indices, val_ratings)
val_dataloader = DataLoader(val_rating_dataset, batch_size=batch_size, shuffle=True)

In [14]:
embedding_dim = 128
n_layers = 3

# Instantiate LightGCN model
model = LightGCN(embedding_dim=embedding_dim, n_layers=n_layers,
                 user_ids=user_indices, item_ids=item_indices, ratings=ratings, device=device)

#### Predict before train

In [15]:
user_id = 'AEHW2B54HDLZ3APBEWXHYLZ6SSYQ'
val_df.loc[lambda df: df['user_id'].eq(user_id)]

Unnamed: 0,user_id,parent_asin,rating,timestamp
34367,AEHW2B54HDLZ3APBEWXHYLZ6SSYQ,B07MYVF61Y,4.0,1654225907045


In [16]:
item_id = 'B07MYVF61Y'
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)

model.predict([user_indice], [item_indice])

tensor([0.0066], grad_fn=<SumBackward1>)

#### Training loop

In [None]:
n_epochs = 50

metric_log_cb = MetricLogCallback()

train(
    model,
    dataloader,
    val_dataloader,
    epochs=n_epochs,
    patience=2,
    print_steps=100,
    lr=0.03,
    device=device,
    progress_bar_type='tqdm_notebook',
    callbacks=[metric_log_cb.process_payload]
)

Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

Training Epoch 1:   0%|          | 0/566 [00:00<?, ?it/s]

[32m2024-09-09 00:07:14.880[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 100, Global Loss: 19.7837[0m
[32m2024-09-09 00:07:14.881[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 100, Learning Rate: 0.030000[0m
[32m2024-09-09 00:07:14.881[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 100, Gradient Norms: {'grad_norm_user_embedding.weight': 0.2209496945142746, 'grad_norm_item_embedding.weight': 0.17535638809204102, 'total_grad_norm': 0.2820791207279223}[0m
[32m2024-09-09 00:07:16.753[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 200, Global Loss: 19.6663[0m
[32m2024-09-09 00:07:16.753[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 200, Learning Rate: 0.030000[0m
[32m2024-09-09 00:07:16.754[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m -

Training Epoch 2:   0%|          | 0/566 [00:00<?, ?it/s]

[32m2024-09-09 00:07:25.383[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 600, Global Loss: 6.5678[0m
[32m2024-09-09 00:07:25.384[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 600, Learning Rate: 0.030000[0m
[32m2024-09-09 00:07:25.384[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 600, Gradient Norms: {'grad_norm_user_embedding.weight': 0.3292803168296814, 'grad_norm_item_embedding.weight': 0.32689398527145386, 'total_grad_norm': 0.4639883669426518}[0m
[32m2024-09-09 00:07:27.395[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 700, Global Loss: 6.1948[0m
[32m2024-09-09 00:07:27.395[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 700, Learning Rate: 0.030000[0m
[32m2024-09-09 00:07:27.396[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - 

Training Epoch 3:   0%|          | 0/566 [00:00<?, ?it/s]

[32m2024-09-09 00:07:38.323[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 1200, Global Loss: 1.6251[0m
[32m2024-09-09 00:07:38.324[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 1200, Learning Rate: 0.027000[0m
[32m2024-09-09 00:07:38.324[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 1200, Gradient Norms: {'grad_norm_user_embedding.weight': 0.2486582100391388, 'grad_norm_item_embedding.weight': 0.22259633243083954, 'total_grad_norm': 0.33373647183298577}[0m
[32m2024-09-09 00:07:40.403[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 1300, Global Loss: 1.5008[0m
[32m2024-09-09 00:07:40.404[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 1300, Learning Rate: 0.027000[0m
[32m2024-09-09 00:07:40.404[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[

Training Epoch 4:   0%|          | 0/566 [00:00<?, ?it/s]

[32m2024-09-09 00:07:49.631[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 1700, Global Loss: 0.7929[0m
[32m2024-09-09 00:07:49.631[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 1700, Learning Rate: 0.027000[0m
[32m2024-09-09 00:07:49.632[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 1700, Gradient Norms: {'grad_norm_user_embedding.weight': 0.21320492029190063, 'grad_norm_item_embedding.weight': 0.15109604597091675, 'total_grad_norm': 0.26131657648285744}[0m
[32m2024-09-09 00:07:51.738[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 1800, Global Loss: 1.1383[0m
[32m2024-09-09 00:07:51.739[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 1800, Learning Rate: 0.027000[0m
[32m2024-09-09 00:07:51.739[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128

Training Epoch 5:   0%|          | 0/566 [00:00<?, ?it/s]

[32m2024-09-09 00:08:03.097[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 2300, Global Loss: 1.1182[0m
[32m2024-09-09 00:08:03.098[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 2300, Learning Rate: 0.024300[0m
[32m2024-09-09 00:08:03.098[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 2300, Gradient Norms: {'grad_norm_user_embedding.weight': 0.20204001665115356, 'grad_norm_item_embedding.weight': 0.1578778326511383, 'total_grad_norm': 0.256409006025567}[0m
[32m2024-09-09 00:08:05.177[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 2400, Global Loss: 1.0930[0m
[32m2024-09-09 00:08:05.178[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 2400, Learning Rate: 0.024300[0m
[32m2024-09-09 00:08:05.178[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m

Training Epoch 6:   0%|          | 0/566 [00:00<?, ?it/s]

[32m2024-09-09 00:08:16.716[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 2900, Global Loss: 1.3460[0m
[32m2024-09-09 00:08:16.716[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 2900, Learning Rate: 0.024300[0m
[32m2024-09-09 00:08:16.717[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 2900, Gradient Norms: {'grad_norm_user_embedding.weight': 0.25276660919189453, 'grad_norm_item_embedding.weight': 0.18699786067008972, 'total_grad_norm': 0.3144187631448833}[0m
[32m2024-09-09 00:08:18.820[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 3000, Global Loss: 1.3134[0m
[32m2024-09-09 00:08:18.820[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 3000, Learning Rate: 0.021870[0m
[32m2024-09-09 00:08:18.821[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[

Training Epoch 7:   0%|          | 0/566 [00:00<?, ?it/s]

[32m2024-09-09 00:08:28.233[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 3400, Global Loss: 1.1464[0m
[32m2024-09-09 00:08:28.233[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 3400, Learning Rate: 0.021870[0m
[32m2024-09-09 00:08:28.233[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 3400, Gradient Norms: {'grad_norm_user_embedding.weight': 0.3041734993457794, 'grad_norm_item_embedding.weight': 0.2293393313884735, 'total_grad_norm': 0.3809436265721858}[0m
[32m2024-09-09 00:08:30.359[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 3500, Global Loss: 1.1496[0m
[32m2024-09-09 00:08:30.360[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 3500, Learning Rate: 0.021870[0m
[32m2024-09-09 00:08:30.360[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m

Training Epoch 8:   0%|          | 0/566 [00:00<?, ?it/s]

[32m2024-09-09 00:08:41.862[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 4000, Global Loss: 1.2495[0m
[32m2024-09-09 00:08:41.863[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 4000, Learning Rate: 0.019683[0m
[32m2024-09-09 00:08:41.863[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 4000, Gradient Norms: {'grad_norm_user_embedding.weight': 0.2811598479747772, 'grad_norm_item_embedding.weight': 0.20301806926727295, 'total_grad_norm': 0.3467956120861553}[0m
[32m2024-09-09 00:08:43.999[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 4100, Global Loss: 1.1203[0m
[32m2024-09-09 00:08:44.000[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 4100, Learning Rate: 0.019683[0m
[32m2024-09-09 00:08:44.000[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0

Training Epoch 9:   0%|          | 0/566 [00:00<?, ?it/s]

[32m2024-09-09 00:08:55.635[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 4600, Global Loss: 0.9214[0m
[32m2024-09-09 00:08:55.636[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 4600, Learning Rate: 0.019683[0m
[32m2024-09-09 00:08:55.636[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 4600, Gradient Norms: {'grad_norm_user_embedding.weight': 0.20199033617973328, 'grad_norm_item_embedding.weight': 0.14372995495796204, 'total_grad_norm': 0.24790803912382403}[0m
[32m2024-09-09 00:08:57.808[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 4700, Global Loss: 0.9491[0m
[32m2024-09-09 00:08:57.808[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 4700, Learning Rate: 0.019683[0m
[32m2024-09-09 00:08:57.808[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128

Training Epoch 10:   0%|          | 0/566 [00:00<?, ?it/s]

[32m2024-09-09 00:09:07.239[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 5100, Global Loss: 0.7466[0m
[32m2024-09-09 00:09:07.239[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 5100, Learning Rate: 0.017715[0m
[32m2024-09-09 00:09:07.240[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[0m - [1mStep 5100, Gradient Norms: {'grad_norm_user_embedding.weight': 0.22143200039863586, 'grad_norm_item_embedding.weight': 0.14372944831848145, 'total_grad_norm': 0.2639891761312886}[0m
[32m2024-09-09 00:09:09.451[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m122[0m - [1mStep 5200, Global Loss: 0.8400[0m
[32m2024-09-09 00:09:09.452[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m125[0m - [1mStep 5200, Learning Rate: 0.017715[0m
[32m2024-09-09 00:09:09.452[0m | [1mINFO    [0m | [36msrc.train_utils[0m:[36mtrain[0m:[36m128[

# Visualize training

In [None]:
step_metrics = [p for p in metric_log_cb.payloads if 'step' in p]
epoch_metrics = [p for p in metric_log_cb.payloads if 'epoch' in p]
step_metrics_df = pd.DataFrame(step_metrics)
step_metrics_df

In [None]:
epoch_metrics_df = pd.DataFrame(epoch_metrics)
epoch_metrics_df = epoch_metrics_df.pipe(lambda df: pd.melt(df, id_vars=["epoch"], var_name="loss_type", value_name="value"))
epoch_metrics_df

In [None]:
def plot_metric(df, index='step', col: str = None, color=None):
    # Create the plot
    cols = [index, col]
    if color:
        cols.append(color)
    fig = px.line(
        df[cols].dropna(),
        x=index,
        y=col,
        color=color,
        labels={'x': index, 'y': col},
        title=f'{col} by {index}',
    )

    if color:
        fig.update_layout(showlegend=True)

    fig.show()

plot_metric(step_metrics_df, col='global_loss')
plot_metric(step_metrics_df, col='learning_rate')
plot_metric(step_metrics_df, col='total_grad_norm')
plot_metric(epoch_metrics_df, index='epoch', col='value', color='loss_type')

In [None]:
fig = px.bar(
    step_metrics_df,
    x='step',
    y=['grad_norm_user_embedding.weight', 'grad_norm_item_embedding.weight'],
    title='Norm gradients of user embeddings and item embeddings',
    height=500
)
fig.update_layout(showlegend=True, barmode='stack')
fig.show()

# Predict

In [None]:
train_df

In [None]:
user_id = 'AEHW2B54HDLZ3APBEWXHYLZ6SSYQ'
val_df.loc[lambda df: df['user_id'].eq(user_id)]

In [None]:
item_id = 'B07MYVF61Y'
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)

model.predict([user_indice], [item_indice])