## Setup

In [16]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [17]:
import pandas as pd
from pydantic import BaseModel
import sys
import os
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from load_dotenv import load_dotenv
import torch
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint, ModelSummary

import mlflow

sys.path.insert(0, "..")

from src.utils.embedding_id_mapper import IDMapper
from src.algo.two_tower.model import TwoTowerRating
from src.algo.two_tower.dataset import UserItemRatingDFDataset, UserItemBinaryDFDataset
from src.algo.two_tower.trainer import TwoTowerLitModule

In [18]:
load_dotenv(override = True)

True

In [19]:
hidden_dim = 128
embedding_dim: int = 128

In [20]:
class Args(BaseModel):
    testing: bool = False
    log_to_mlflow: bool = True
    experiment_name: str = "first-attempt"
    run_name: str = f"005-two-tower-{embedding_dim}-{hidden_dim}-neg-4"
    notebook_persit_dp: str = None
    
    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"
    group_name: str = "two-tower"

    top_K: int = 100
    top_k: int = 10

    batch_size: int = 128
    embedding_dim: int = embedding_dim
    learning_rate: float = 0.01
    l2_reg: float = 1e-5
    early_stopping_patience: int = 20
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    max_epochs: int = 100

    # TwoTower specific
    hidden_dim: int = hidden_dim
    dropout: float = 0
    
    best_checkpoint_path: str = None

    train_data_fp: str = os.path.abspath("../data_for_ai/interim/train_sample_interactions_16407u_neg_seq.parquet")
    val_data_fp: str = os.path.abspath("../data_for_ai/interim/val_sample_interactions_16407u_neg_seq.parquet")

    def init(self):
        self.notebook_persit_dp = os.path.abspath(f"data/{self.experiment_name}/{self.run_name}")

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            self.log_to_mlflow = False
            logger.warning("MLFlow is not enabled. Turn off tracking to Mlflow.")

        if self.log_to_mlflow:
            logger.info(
                f"Setting up Mlflow experiment: {self.experiment_name}, run_name: {self.run_name}"
            )

            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        if not self.testing:
            os.makedirs(self.notebook_persit_dp, exist_ok=True)
        return self
    
args = Args().init()
print(args.model_dump_json(indent=2))

[32m2025-06-20 23:42:03.649[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m42[0m - [1mSetting up Mlflow experiment: first-attempt, run_name: 005-two-tower-128-128-neg-4[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "first-attempt",
  "run_name": "005-two-tower-128-128-neg-4",
  "notebook_persit_dp": "/home/dinhln/Desktop/real_time_recsys/notebooks/data/first-attempt/005-two-tower-128-128-neg-4",
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "group_name": "two-tower",
  "top_K": 100,
  "top_k": 10,
  "batch_size": 128,
  "embedding_dim": 128,
  "learning_rate": 0.01,
  "l2_reg": 0.00001,
  "early_stopping_patience": 20,
  "device": "cuda",
  "max_epochs": 100,
  "hidden_dim": 128,
  "dropout": 0.0,
  "best_checkpoint_path": null,
  "train_data_fp": "/home/dinhln/Desktop/real_time_recsys/data_for_ai/interim/train_sample_interactions_16407u_neg_seq.parquet",
  "val_data_fp": "/home/dinhln/Desktop/real_time_recsys/data_for_ai/interim/val_sample_interactions_16407u_neg_seq.parquet"
}


## Init model

In [21]:
def init_model(n_user, n_items, embedding_dim, hidden_dim, dropout):
    """
    Initialize the model with the given parameters.
    """
    model = TwoTowerRating(
        num_users = n_user,
        num_items = n_items,
        embedding_dim = embedding_dim,
        hidden_units_dim = hidden_dim,
        dropout = dropout,
    )
    return model

## Test implementation

In [22]:
embedding_dim = 8
batch_size = 2

# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 1, 0, 1, 0]

n_users = len(set(user_indices))
n_items = len(set(item_indices))

train_df = pd.DataFrame(
    {
        "user_indice": user_indices,
        "item_indice": item_indices,
        args.timestamp_col: timestamps,
        args.rating_col: ratings,
    }
)

model = init_model(n_users, n_items, embedding_dim, embedding_dim, args.dropout)

# Example forward pass
model.eval()
user = torch.tensor([0])
target_item = torch.tensor([2])
predictions = model(user, target_item)
print(predictions)
model.train()

tensor([-0.1309], grad_fn=<SumBackward1>)


TwoTowerRating(
  (item_embedding): Embedding(5, 8)
  (item_fc): Linear(in_features=8, out_features=8, bias=True)
  (user_embedding): Embedding(3, 8)
  (user_fc): Linear(in_features=8, out_features=8, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0, inplace=False)
  (sigmoid_fn): Sigmoid()
)

In [23]:
rating_dataset = UserItemRatingDFDataset(
    train_df,
    user_col = "user_indice",
    item_col = "item_indice",
    rating_col = args.rating_col,
    timestamp_col = args.timestamp_col,
)
train_loader = DataLoader(
    rating_dataset,
    batch_size = batch_size,
    shuffle = False,
)

In [24]:
for batch_input in train_loader:
    print(batch_input)

{'user': tensor([0, 0]), 'item': tensor([0, 1]), 'rating': tensor([0., 1.])}
{'user': tensor([1, 2]), 'item': tensor([2, 3]), 'rating': tensor([0., 1.])}
{'user': tensor([2]), 'item': tensor([4]), 'rating': tensor([0.])}


In [25]:
# model
lit_model = TwoTowerLitModule(model, log_dir=args.notebook_persit_dp)

# callbacks
callbacks = [ModelSummary(max_depth=-1)]
# train model
trainer = L.Trainer(
    default_root_dir=f"{args.notebook_persit_dp}/logs/test",
    max_epochs=300,
    accelerator=args.device if args.device else "auto",
    callbacks=callbacks
)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

[32m2025-06-20 23:42:04.010[0m | [1mINFO    [0m | [36msrc.algo.two_tower.trainer[0m:[36m__init__[0m:[36m59[0m - [1mIDM is not provided. Skipping Evidently ranking metrics logging.[0m
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type                   | Params | Mode 
------------------------------------------------------------------------
0 | model                | TwoTowerRating         | 208    | train
1 | model.item_embedding | Embedding              | 40     | train
2 | model.item_fc        | Linear                 | 72     | t

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=300` reached.


In [26]:
users =torch.tensor([1, 1, 1, 1])
items = torch.tensor([0, 1, 2, 3])

predictions = model(users, items)
print(predictions)

tensor([ 0.0942,  0.1900, -0.8749,  0.4588], grad_fn=<SumBackward1>)


## Trainig loop

In [27]:
train_df = pd.read_parquet(args.train_data_fp)
val_df = pd.read_parquet(args.val_data_fp)

assert set(val_df[args.user_col].unique()).issubset(set(train_df[args.user_col].unique())), "Validation users must be present in training users."

assert set(val_df[args.item_col].unique()).issubset(set(train_df[args.item_col].unique())), "Validation items must be present in training items."
assert train_df[args.timestamp_col].max() < val_df[args.timestamp_col].min(), "Validation data must be after training data. Otherwise, its a data contamination problem."

In [28]:
train_df.head(3)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence
151343,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B009RUZ7TS,0.0,2014-07-17 19:15:55.000,1412,4220,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 455..."
40958,AF7KZV4NJ5GBDVFTB7PEEUN4U53A,B0BBMLD8QT,5.0,2015-07-29 20:38:06.000,4871,4476,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
218918,AFVQ4K4KZPLQ3E2VFYSGX6HFXGNQ,B0BB6R89VF,0.0,2017-12-13 20:35:02.334,7616,1218,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 129..."


## Convert user_id and item_id into indices

In [29]:
idm_path = os.path.abspath("../data_for_ai/interim/idm_16407u.json")
idm = IDMapper().load(idm_path)
idm.get_user_id(1)

'AE227WAM4NWQPJI33OPN7ZARNNZQ'

In [None]:
# train_df = train_df.pipe(idm.map_indices)
val_df = val_df.pipe(idm.map_indices)

assert idm.unknown_item_index not in train_df["item_indice"].values, "Unknown item index must be present in training data."
assert idm.unknown_user_index not in train_df["user_indice"].values, "Unknown user index must be present in training data."
assert idm.unknown_item_index not in val_df["item_indice"].values, "Unknown item index must be present in validation data."
assert idm.unknown_user_index not in val_df["user_indice"].values, "Unknown user index must be present in validation data."

In [31]:
train_df.head(3)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence
151343,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B009RUZ7TS,0.0,2014-07-17 19:15:55.000,1412,4220,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 455..."
40958,AF7KZV4NJ5GBDVFTB7PEEUN4U53A,B0BBMLD8QT,5.0,2015-07-29 20:38:06.000,4871,4476,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
218918,AFVQ4K4KZPLQ3E2VFYSGX6HFXGNQ,B0BB6R89VF,0.0,2017-12-13 20:35:02.334,7616,1218,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 129..."


In [32]:
assert train_df.groupby(args.user_col)[args.item_col].nunique().min() >= 5, "Each user must have at least five items."
assert train_df.groupby(args.item_col)[args.user_col].nunique().min() >= 10, "Each item must have at least ten users."

In [33]:
train_df = train_df.sample(frac=1, random_state=42)

In [34]:
rating_dataset = UserItemBinaryDFDataset(
    train_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)
val_rating_dataset = UserItemBinaryDFDataset(
    val_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [35]:
val_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence
260331,AGMJWWTZ6HMM2FBRDLFW2CWMV5DQ,B00E0ISVLI,0.0,2021-07-18 15:44:29.739,10483,2563,"[-1, 2906, 3011, 4674, 4593, 4755, 3810, 3921,..."
259198,AE3XVOCHEO5MTDIAIET5BZS26AJA,B07GPGVYGX,0.0,2021-03-12 03:28:00.854,254,3381,"[-1, -1, -1, -1, 1188, 1510, 4399, 3089, 2290,..."
258841,AESPJW3GNHXNJNW5CYV7PTEX44MQ,B07GZFM1ZM,0.0,2021-02-09 16:08:20.512,3190,921,"[-1, -1, -1, -1, -1, 2569, 2742, 2855, 2351, 346]"
259382,AE3HTD5GV52IDFUQ6MMXRNF4WDZQ,B09M3BZYVP,0.0,2021-03-30 11:48:08.855,181,971,"[-1, -1, -1, -1, -1, 1872, 1570, 2366, 3899, 3..."
127982,AHTGQCLAFVD43IQ2AIERW2FQ7P4A,B00006JPE1,5.0,2021-02-10 14:43:48.128,15577,25,"[4166, 3089, 3074, 3443, 3227, 3493, 4466, 355..."
...,...,...,...,...,...,...,...
130708,AFAFYOKLVZYF2VM2VZ6H37ATHOOA,B0BPZFW1JH,5.0,2022-01-26 18:05:13.266,4968,4595,"[-1, 3326, 2830, 1420, 4007, 426, 1375, 1678, ..."
261331,AG7QSQMGIWBT5EM6MQV63NXMGURA,B0BWD4WGJB,0.0,2021-12-24 18:36:35.171,8909,1810,"[2232, 4477, 2092, 2556, 330, 3898, 2931, 4345..."
130349,AECSIHUJ2JSARFYGYHETBMVPFONA,B074JT3698,5.0,2021-12-07 16:40:04.204,1171,3039,"[-1, -1, -1, -1, 3117, 4220, 4592, 3565, 3494,..."
130673,AFWI6SEQ5EP2YOTM3OOLLWHR4ITA,B08KGSZKCT,5.0,2022-01-22 20:39:35.970,7717,3959,"[1492, 2972, 4530, 3104, 3678, 1502, 3089, 770..."


In [36]:
for i in train_loader:
    print(i)
    break

{'user': tensor([13445, 13998,  9653, 10915,   634,  3384, 11749, 15479,  7150,  7221,
        13932,  8744,  9442,  3277,  5400, 15692, 10921,  6055,  2343, 13815,
         8478,  5797,  4178, 13295, 14451,  2349,  7202,  6023, 12976, 15696,
         7247, 13836, 11489, 14473,  2085, 11150,   138, 12747, 10832,  4172,
         3056, 11728,  2832,  2478,  7945,  9768, 14237,  2869,  9056, 16403,
         4975,  9507, 15696, 15772, 10286,  9521,  4518, 15426, 11894,  3903,
         9395, 12906, 13905,  1019,  7569,  8965,  9693, 12409, 12493,  4691,
        13658, 12993,  5860, 11547, 12749, 11504,  7773, 14082,  1848,  2930,
         5321,  8269,  5487,  9739,  3552,  1657,  2061,  2788,  1784,  5190,
         9613,    68, 12616,   403, 15456,  6468,  8629, 13689,  9530, 15360,
        13311, 15258, 13735,  7395,  7858, 14246,  5428,  5874,  2297,  5626,
         9006,  5677,   969,  8850,  1831,  9127,  8973,  7879,  2484,  9100,
         9231, 15384, 13048, 15114,  7396, 13876,  4593

In [37]:
item_indices = train_df[args.item_col].unique()
user_indices = train_df[args.user_col].unique()
n_items = len(item_indices)
n_users = len(user_indices)

logger.info(f"Number of users: {n_users}, Number of items: {n_items}")
model = init_model(n_users, n_items, args.embedding_dim, args.hidden_dim, args.dropout)

[32m2025-06-20 23:42:23.117[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mNumber of users: 16407, Number of items: 4817[0m


## Overfir 1 batch

In [38]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=5, mode="min", verbose=False
)

model_summary = ModelSummary(max_depth=-1)

model = init_model(n_users, n_items, args.embedding_dim, args.hidden_dim, args.dropout)
lit_model = TwoTowerLitModule(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persit_dp,
    accelerator=args.device,
    idm= idm,
    log_ranking_metrics=False
)

log_dir = f"{args.notebook_persit_dp}/logs/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping, model_summary],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type                   | Params | Mode 
------------------------------------------------------------------------
0 | model                | TwoTowerRating         | 2.7 M  | train
1 | model.item_embedding | Embedding              | 616 K  | train
2 | model.item_fc        | Linear                 | 16.5 K | train
3 | model.user_embedding | Embedding              | 2.1 M  | train
4 | model.user_fc        | Linear                 | 16

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
[32m2025-06-20 23:42:39.688[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m33[0m - [1mLogs available at /home/dinhln/Desktop/real_time_recsys/notebooks/data/first-attempt/005-two-tower-128-128-neg-4/logs/overfit/lightning_logs/version_3[0m


## Run on all data

In [None]:
early_stopping = EarlyStopping(
    monitor="val_roc_auc", patience=5, mode="max", verbose=False, min_delta=0.01
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persit_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_roc_auc",
    mode="max",
)

model_summary = ModelSummary(max_depth=-1)

model = init_model(n_users, n_items, args.embedding_dim, args.hidden_dim, args.dropout)
lit_model = TwoTowerLitModule(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persit_dp,
    accelerator=args.device,
    idm= idm
    
)

log_dir = f"{args.notebook_persit_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback, model_summary],
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# Change the library as a workaround for the issue in the latest Lightning release
#https://github.com/Lightning-AI/pytorch-lightning/pull/20669/commits/429f732a0528c558e701da7ec01e51c1e2e4f32e

Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Checkpoint directory /home/dinhln/Desktop/real_time_recsys/notebooks/data/first-attempt/005-two-tower-128-128-neg-4/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type                   | Params | Mode 
------------------------------------------------------------------------
0 | model                | TwoTowerRating         | 2.7 M  | train
1 | model.item_embedding | Embedding              | 616 K  | train
2 | model.item_fc        | Linear                 | 16.5 K | train
3 | model.user_embedding | Embedding              | 2.1 M  | train
4 | model.user_fc        | Linear                 | 16.5 K | train
5 | model.relu           | ReLU   

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

## Clean up

In [None]:
all_params = [args]

if args.log_to_mlflow:
    run_id = trainer.logger.run_id

    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.model_dump()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)

🏃 View run 005-two-tower-128-128-neg-4 at: http://138.2.61.6:5002/#/experiments/2/runs/89c2a8496f684e0a8ecf450a33b57793
🧪 View experiment at: http://138.2.61.6:5002/#/experiments/2


## Log metrics

In [None]:
from src.eval.utils import create_rec_df, create_label_df, merge_recs_with_target
from src.eval.log_metrics import log_ranking_metrics, log_classification_metrics

In [None]:
model

TwoTowerRating(
  (item_embedding): Embedding(4817, 128)
  (item_fc): Linear(in_features=128, out_features=128, bias=True)
  (user_embedding): Embedding(16407, 128)
  (user_fc): Linear(in_features=128, out_features=128, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0, inplace=False)
  (sigmoid_fn): Sigmoid()
)

In [None]:
logger.info(f"Loading best checkpoint from {checkpoint_callback.best_model_path}...")
args.best_checkpoint_path = checkpoint_callback.best_model_path

best_trainer = TwoTowerLitModule.load_from_checkpoint(
    checkpoint_path=checkpoint_callback.best_model_path,
    model = init_model(n_users, n_items, 128,128, args.dropout),
)

[32m2025-06-18 11:47:26.753[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mLoading best checkpoint from /home/dinhln/Desktop/real_time_recsys/notebooks/data/first-attempt/005-two-tower-128-128-neg-4/checkpoints/best-checkpoint-v1.ckpt...[0m
[32m2025-06-18 11:47:26.901[0m | [1mINFO    [0m | [36msrc.algo.two_tower.trainer[0m:[36m__init__[0m:[36m59[0m - [1mIDM is not provided. Skipping Evidently ranking metrics logging.[0m


In [None]:
best_model = best_trainer.model.to(args.device)
best_model.eval()

TwoTowerRating(
  (item_embedding): Embedding(4817, 128)
  (item_fc): Linear(in_features=128, out_features=128, bias=True)
  (user_embedding): Embedding(16407, 128)
  (user_fc): Linear(in_features=128, out_features=128, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0, inplace=False)
  (sigmoid_fn): Sigmoid()
)

In [None]:
val_recs_df = val_df.sort_values(by=args.timestamp_col).drop_duplicates(subset=[args.user_col], keep="first")

In [None]:
mlflow.start_run(run_id = trainer.logger.run_id)

<ActiveRun: >

## Classification metrics

In [None]:
val_user_indices = val_df["user_indice"].values
val_item_indices = val_df["item_indice"].values

In [None]:
len(val_user_indices), len(val_item_indices)

(6958, 6958)

In [None]:
users = torch.tensor(val_user_indices, device=args.device)
items = torch.tensor(val_item_indices, device=args.device)
classifications = best_model.predict(users, items)

In [None]:
eval_classification_df = val_df.assign(
    classification_proba=classifications.cpu().detach().numpy(),
    label=lambda df: df[args.rating_col].gt(0).astype(int),
)

In [None]:
eval_classification_df.head(3)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,classification_proba,label
260331,AGQFM7GX5UGRCK5F6EEGEEB25FKQ,B01C4W2P18,0.0,2021-06-02 04:46:55.478,10995,2482,0.583165,0
259198,AGYITA5HB3G7B5UQIIYBVCPLRFVA,B07GXDLJP9,0.0,2022-01-20 17:33:34.286,12058,3441,0.505395,0
258841,AFKBE4VLE3XEQ5IHUZI2Q5KAKFCQ,B00XAJD7NA,0.0,2021-08-15 20:42:07.057,6169,2147,0.511077,0


In [None]:
classification_report = log_classification_metrics(
    args,
    eval_classification_df,
    target_col="label",
    prediction_col="classification_proba",
)

## Ranking metrics

In [None]:
val_recs_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice
129191,AGSP5XAQPQBUUXZHEZSC65FD7NOQ,B004FV4ROA,1.0,2020-12-27 00:30:31.146,11295,528
128040,AEHS7YR7BGGWMZS24H5UR5IP46HQ,B08F1P3BCC,2.0,2020-12-27 01:44:52.242,1784,3925
128167,AGAVHCK42EGMVS7DGPRX6HBCUCNQ,B09Q3NR84W,5.0,2020-12-27 02:25:48.357,9042,4273
127542,AEFVBMCJAFNULDI5V2CKKTBCPURA,B07N1L5HX1,5.0,2020-12-27 02:32:15.171,1542,3550
258856,AGLXMKHBLTBNT3X2CLBAPW6QUTQA,B001EAQTRI,0.0,2020-12-27 03:37:22.772,10418,212
...,...,...,...,...,...,...
258901,AGGDNWGN3NDJ2DI5CBSFOMUAM6XA,B076XFGK32,0.0,2022-02-18 19:43:25.492,9711,3115
258608,AEKUF6AOVWDWFYOKPWO2CV72PEDQ,B009OBCAW2,0.0,2022-02-19 01:32:51.519,2171,1042
129601,AFBTD25HPE4BE4LUFV3DTI2E2N2A,B07TMJ8S5Z,5.0,2022-02-19 16:49:57.966,5159,3699
130454,AHLN6GKTKZE22AON34YAQXTGK63A,B0C682GZ5X,5.0,2022-02-19 17:28:55.519,14550,4772


In [None]:
recommendations = best_model.recommend(
    torch.tensor(val_recs_df["user_indice"].values, device=args.device),
    top_k=args.top_K,
    batch_size=1)

Generating recommendations:   0%|          | 0/2424 [00:00<?, ?it/s]

In [None]:
recommendations_df = pd.DataFrame(recommendations).pipe(
    create_rec_df, idm, args.user_col, args.item_col
)
recommendations_df

Unnamed: 0,user_indice,recommendation,score,rec_ranking,user_id,parent_asin
0,11295,2507,0.636937,1.0,AGSP5XAQPQBUUXZHEZSC65FD7NOQ,B01CW4AR9K
1,11295,4490,0.635667,2.0,AGSP5XAQPQBUUXZHEZSC65FD7NOQ,B0BD7FN8K9
2,11295,4254,0.632372,3.0,AGSP5XAQPQBUUXZHEZSC65FD7NOQ,B09N72XPK3
3,11295,2723,0.630959,4.0,AGSP5XAQPQBUUXZHEZSC65FD7NOQ,B01LLANEAU
4,11295,3024,0.628205,5.0,AGSP5XAQPQBUUXZHEZSC65FD7NOQ,B07456BG8N
...,...,...,...,...,...,...
242395,2446,1962,0.574869,96.0,AEMYBWDN67IB5IBTMHLHN76V4QHQ,B00STNUB04
242396,2446,2112,0.574857,97.0,AEMYBWDN67IB5IBTMHLHN76V4QHQ,B00WUICTXG
242397,2446,698,0.574786,98.0,AEMYBWDN67IB5IBTMHLHN76V4QHQ,B005M08NE8
242398,2446,3996,0.574724,99.0,AEMYBWDN67IB5IBTMHLHN76V4QHQ,B08PBS5CBX


In [None]:
label_df = create_label_df(
    val_df,
    user_col=args.user_col,
    item_col=args.item_col,
    rating_col=args.rating_col,
    timestamp_col=args.timestamp_col,
)
label_df

Unnamed: 0,user_id,parent_asin,rating,rating_rank
129956,AEMYBWDN67IB5IBTMHLHN76V4QHQ,B091K4WYD1,4.0,1.0
127398,AHZ6GFHFM6Z7CRPSXRIYQ5Z7GERQ,B07JMQP6T6,5.0,1.0
130809,AFQZQHAMZHP54BLVW3AZG2NDKAQA,B01N27P7ME,3.0,1.0
129383,AH7L2ZE36P7Q7ZDTDE2FIWWBU7ZA,B0B5J7MLTS,5.0,1.0
128262,AGOAZS3ZJNV74POYA7OW2JBZYAQQ,B0B2Y5WYRG,5.0,1.0
...,...,...,...,...
259295,AFKERAMSXU4MWO3H53R7DEFOHUVQ,B0BSF17PM2,0.0,17.0
259294,AFKERAMSXU4MWO3H53R7DEFOHUVQ,B003XRES32,0.0,18.0
258813,AEN2KQVSR5TWRXNQS3OTFT4EZQCA,B07D4Z36V8,0.0,18.0
259293,AFKERAMSXU4MWO3H53R7DEFOHUVQ,B0051VVOB2,0.0,19.0


In [None]:
eval_df = merge_recs_with_target(
    recommendations_df,
    label_df,
    k=args.top_K,
    user_col=args.user_col,
    item_col=args.item_col,
    rating_col=args.rating_col,
)
eval_df

Unnamed: 0,user_indice,recommendation,score,rec_ranking,user_id,parent_asin,rating,rating_rank
8,8.0,603.0,0.636124,1,AE24AB4DW5KYK3F5DYOT5VPW2VLA,B004XXMUCQ,0,
14,8.0,752.0,0.634447,2,AE24AB4DW5KYK3F5DYOT5VPW2VLA,B006C1ILUC,0,
3,8.0,161.0,0.634399,3,AE24AB4DW5KYK3F5DYOT5VPW2VLA,B000Z80ICM,0,
87,8.0,3937.0,0.634391,4,AE24AB4DW5KYK3F5DYOT5VPW2VLA,B08HKGXGML,0,
29,8.0,1349.0,0.630586,5,AE24AB4DW5KYK3F5DYOT5VPW2VLA,B00EHFJJHY,0,
...,...,...,...,...,...,...,...,...
249029,16403.0,4344.0,0.578279,98,AHZZM7BCJAF2UEMMBHZCLXBB2SVA,B09Y5PWHZM,0,
248996,16403.0,3078.0,0.578176,99,AHZZM7BCJAF2UEMMBHZCLXBB2SVA,B075LX7S6B,0,
248992,16403.0,2891.0,0.578106,100,AHZZM7BCJAF2UEMMBHZCLXBB2SVA,B06X9FX834,0,
248997,,,,101,AHZZM7BCJAF2UEMMBHZCLXBB2SVA,B075QC3TZY,1,1.0


In [None]:
ranking_report = log_ranking_metrics(args, eval_df)


invalid value encountered in divide



In [None]:
mlflow.end_run()

🏃 View run 005-two-tower-128-128-neg-4 at: http://138.2.61.6:5002/#/experiments/2/runs/89c2a8496f684e0a8ecf450a33b57793
🧪 View experiment at: http://138.2.61.6:5002/#/experiments/2


## Register model

In [None]:
run_id = "fd51bfab00c54c5290ce3691d6aafbbd"
mlf_client = mlflow.tracking.MlflowClient()

In [None]:
# Persit idm to mlflow
mlf_client.log_artifact(
    run_id=run_id,
    local_path=idm_path,
)

In [None]:
from src.algo.two_tower.inference import TwoTowerInferenceWrapper
from mlflow.models.signature import infer_signature


[33mAdd type hints to the `predict` method to enable data validation and automatic signature inference during model logging. Check https://mlflow.org/docs/latest/model/python_model.html#type-hint-usage-in-pythonmodel for more details.[0m



In [None]:
best_trainer = TwoTowerLitModule.load_from_checkpoint(
    checkpoint_path="/home/dinhln/Desktop/real_time_recsys/notebooks/data/first-attempt/005-two-tower-128-128-neg-4/checkpoints/best-checkpoint-v2.ckpt",
    model = init_model(n_users, n_items, 128,128, args.dropout),
)

best_model = best_trainer.model.cpu()

FileNotFoundError: [Errno 2] No such file or directory: '/home/dinhln/Desktop/real_time_recsys/notebooks/data/first-attempt/005-two-tower-128-128-neg-4/checkpoints/best-checkpoint-v2.ckpt'

In [None]:
inferrer = TwoTowerInferenceWrapper(best_model)

In [None]:
sample_user_indices = train_df[args.user_col].sample(1).values[0]
sample_item_indices = train_df[args.item_col].sample(1).values[0]

logger.info(f"Sample user: {sample_user_indices}, Sample item: {sample_item_indices}")
sample_input = {"user_id": [sample_user_indices], "item_id": [sample_item_indices]}
sample_output = inferrer.infer([0], [1])
sample_output

[32m2025-05-13 16:39:53.873[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mSample user: AGQ5ERLI2VUZVYLQV5WYJ5TLGVYA, Sample item: B0C2P7CNWG[0m


array([0.09537707], dtype=float32)

In [None]:
signature = infer_signature(sample_input,sample_output,
)

In [None]:
with mlflow.start_run(run_id=run_id):
    mlflow.pyfunc.log_model(
        python_model=inferrer,
        artifact_path="inferrer",
        artifacts={"idm": mlflow.get_artifact_uri("idm_16407u.json")},
        signature=signature,
        input_example=sample_input,
        registered_model_name="two-tower"
    )

2025/05/13 16:48:02 INFO mlflow.pyfunc: Validating input example against model signature


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Successfully registered model 'two-tower'.
2025/05/13 16:48:10 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: two-tower, version 1
Created version '1' of model 'two-tower'.


🏃 View run 005-two-tower-128-128-neg-4 at: http://138.2.61.6:5002/#/experiments/2/runs/fd51bfab00c54c5290ce3691d6aafbbd
🧪 View experiment at: http://138.2.61.6:5002/#/experiments/2


In [None]:
model_version = (
            mlf_client.get_registered_model("two-tower")
            .latest_versions[0]
            .version
        )

In [None]:
mlf_client.set_registered_model_alias(
            name="two-tower", alias="champion", version=model_version
        )

In [None]:
mlf_client.set_model_version_tag(
            name="two-tower",
            version=model_version,
            key="author",
            value="dinhln",
        )