## Setup

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import pandas as pd
from pydantic import BaseModel
import sys
import os
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from load_dotenv import load_dotenv
import torch
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
import mlflow

sys.path.insert(0, "..")

from src.utils.embedding_id_mapper import IDMapper
from src.algo.sequence_two_tower.model import SequenceRatingPrediction
from src.algo.sequence_two_tower.dataset import UserItemBinaryRatingDFDataset
from src.algo.sequence_two_tower.trainer import SeqModellingLitModule
from src.eval.utils import create_rec_df, create_label_df, merge_recs_with_target
from src.eval.log_metrics import log_ranking_metrics, log_classification_metrics
from src.domain.model_request import SequenceModelRequest

In [3]:
load_dotenv(override = True)

True

In [4]:
class Args(BaseModel):
    testing: bool = False
    log_to_mlflow: bool = True
    experiment_name: str = "first-attempt"
    notebook_persit_dp: str = None
    
    run_name: str = None

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"
    group_name: str = "seq-modelling"

    top_K: int = 100
    top_k: int = 10

    batch_size: int = 128
    learning_rate: float = 0.001
    l2_reg: float = 1e-6
    early_stopping_patience: int = 10
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    max_epochs: int = 100

    # TwoTower specific
    dropout: float = 0.3
    embedding_dim: int = 256
    
    # Num negative sample
    negative_samples: int = 1


    train_data_fp: str = os.path.abspath("../data_for_ai/interim/train_sample_interactions_16407u_neg_seq.parquet")
    val_data_fp: str = os.path.abspath("../data_for_ai/interim/val_sample_interactions_16407u_neg_seq.parquet")

    best_checkpoint_path: str = None
    def init(self):
        self.run_name: str = f"006-sequence-modelling-attn-{self.embedding_dim}-dim-bce-prelu"
        self.notebook_persit_dp = os.path.abspath(f"data/{self.experiment_name}/{self.run_name}")

        data_prefix = "" if self.negative_samples == 1 else f"_{self.negative_samples}"
        if not os.path.exists(self.train_data_fp.split(".parquet")[0] + f"{data_prefix}.parquet"):
            logger.warning(
                f"Train data file {self.train_data_fp} does not exist. "
                "Set negative_samples to 1"
            )
            data_prefix = ""
            self.negative_samples = 1

        self.train_data_fp = self.train_data_fp.split(".parquet")[0] + f"{data_prefix}.parquet"
        self.val_data_fp = self.val_data_fp.split(".parquet")[0] + f"{data_prefix}.parquet"
        
        logger.info(
            f"Using train data: {self.train_data_fp}, val data: {self.val_data_fp}, "
            f"with negative samples: {self.negative_samples}"
        )

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            self.log_to_mlflow = False
            logger.warning("MLFlow is not enabled. Turn off tracking to Mlflow.")

        if self.log_to_mlflow:
            logger.info(
                f"Setting up Mlflow experiment: {self.experiment_name}, run_name: {self.run_name}"
            )

            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=False,  # Turn off to False due to Lightning internal bug.
            )

        if not self.testing:
            os.makedirs(self.notebook_persit_dp, exist_ok=True)
        return self
    
args = Args().init()
print(args.model_dump_json(indent=2))

[32m2025-07-01 18:52:24.891[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m53[0m - [1mUsing train data: /home/dinhln/Desktop/real_time_recsys/data_for_ai/interim/train_sample_interactions_16407u_neg_seq.parquet, val data: /home/dinhln/Desktop/real_time_recsys/data_for_ai/interim/val_sample_interactions_16407u_neg_seq.parquet, with negative samples: 1[0m
[32m2025-07-01 18:52:24.892[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m63[0m - [1mSetting up Mlflow experiment: first-attempt, run_name: 006-sequence-modelling-attn-256-dim-bce-prelu[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "first-attempt",
  "notebook_persit_dp": "/home/dinhln/Desktop/real_time_recsys/notebooks/data/first-attempt/006-sequence-modelling-attn-256-dim-bce-prelu",
  "run_name": "006-sequence-modelling-attn-256-dim-bce-prelu",
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "group_name": "seq-modelling",
  "top_K": 100,
  "top_k": 10,
  "batch_size": 128,
  "learning_rate": 0.001,
  "l2_reg": 1e-6,
  "early_stopping_patience": 10,
  "device": "cuda",
  "max_epochs": 100,
  "dropout": 0.3,
  "embedding_dim": 256,
  "negative_samples": 1,
  "train_data_fp": "/home/dinhln/Desktop/real_time_recsys/data_for_ai/interim/train_sample_interactions_16407u_neg_seq.parquet",
  "val_data_fp": "/home/dinhln/Desktop/real_time_recsys/data_for_ai/interim/val_sample_interactions_16407u_neg_seq.parquet",
  "best_checkpoint_path": null
}


## Init model

In [5]:
def init_model(n_users, n_items, embedding_dim, dropout, item_embedding=None):
    return SequenceRatingPrediction(
        item_embedding=item_embedding,
        num_users=n_users,
        num_items=n_items,
        embedding_dim=embedding_dim,
        dropout=dropout,
        use_user_embedding= False
    )

## Test implementation

In [6]:
embedding_dim = 32
batch_size = 2

# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 3, 1, 3, 0]
# item_sequences = [
#     [2, 3, -1, -1],
#     [2, 4, -1, -1],
#     [1, 3, -1, -1],
#     [2, 1, -1, -1],
#     [4, 1, -1, -1],
# ]

item_sequences = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 4],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 4, 1],
]


n_users = len(set(user_indices))

n_items = len(set(item_indices))

train_df = pd.DataFrame(
    {
        "user_indice": user_indices,
        "item_indice": item_indices,
        args.timestamp_col: timestamps,
        args.rating_col: ratings,
        "item_sequence": item_sequences,
    }
)

model = init_model(n_users, n_items, embedding_dim, args.dropout)

# Example forward pass
model.eval()
user = torch.tensor([0])
item_sequence = torch.tensor([[-0, 1, -1, -1]])
target_item = torch.tensor([2])
input = SequenceModelRequest(
    user_id=user,
    item_sequence=item_sequence,
    target_item=target_item,
    recommendation=False)
predictions = model(input)
print(predictions)
model.train()


enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.activation_relu_or_gelu was not True

[32m2025-07-01 18:53:27.023[0m | [1mINFO    [0m | [36msrc.algo.sequence_two_tower.model[0m:[36m__init__[0m:[36m122[0m - [1mStart token used: 4, Padding token used: 5[0m


tensor([0.5235], grad_fn=<DivBackward0>)


SequenceRatingPrediction(
  (item_embedding): Embedding(6, 32, padding_idx=5)
  (encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
    )
    (linear1): Linear(in_features=32, out_features=32, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
    (linear2): Linear(in_features=32, out_features=32, bias=True)
    (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.3, inplace=False)
    (dropout2): Dropout(p=0.3, inplace=False)
    (activation): PReLU(num_parameters=1)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
        )
        (linear1): Linear(in_features=32, out_feat

In [7]:
rating_dataset = UserItemBinaryRatingDFDataset(
    train_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col,"item_sequence"
)

train_loader = DataLoader(
    rating_dataset, batch_size=batch_size, shuffle=False, drop_last=True
)

In [8]:
for batch_input in train_loader:
    print(batch_input)

{'user': tensor([0, 0]), 'item': tensor([0, 1]), 'rating': tensor([0., 1.]), 'item_sequence': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  4]], dtype=torch.int32)}
{'user': tensor([1, 2]), 'item': tensor([2, 3]), 'rating': tensor([1., 1.]), 'item_sequence': tensor([[-1, -1,  1,  3],
        [-1, -1,  2,  1]], dtype=torch.int32)}


In [14]:
# model
lit_model = SeqModellingLitModule(model, log_dir=args.notebook_persit_dp)

# train model
trainer = L.Trainer(
    default_root_dir=f"{args.notebook_persit_dp}/test",
    max_epochs=100,
    accelerator=args.device if args.device else "auto",
)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type                     | Params | Mode 
------------------------------------------------------------------------
0 | model              | SequenceRatingPrediction | 15.2 K | eval 
1 | val_roc_auc_metric | BinaryAUROC              | 0      | train
2 | val_pr_auc_metric  | BinaryAveragePrecision   | 0      | train
------------------------------------------------------------------------
15.2 K    Trainable params
0         Non-trainable params
15.2 K    Total params
0.061     Total estimated model params size (MB)
2         Modules in train mode
30        Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [18]:
model.eval()
user = None
item_sequence = torch.tensor([[-1, -1, 2, 1]])
target_item = torch.tensor([3])

input_data = SequenceModelRequest(
    user_id=user,
    item_sequence=item_sequence,
    target_item=target_item,
    recommendation=False)
predictions = model.predict(input_data)
print(predictions)

tensor([1.], grad_fn=<DivBackward0>)


## Training loop

In [19]:
train_df = pd.read_parquet(args.train_data_fp)
val_df = pd.read_parquet(args.val_data_fp)

assert set(val_df[args.user_col].unique()).issubset(set(train_df[args.user_col].unique())), "Validation users must be present in training users."

assert set(val_df[args.item_col].unique()).issubset(set(train_df[args.item_col].unique())), "Validation items must be present in training items."
assert train_df[args.timestamp_col].max() < val_df[args.timestamp_col].min(), "Validation data must be after training data. Otherwise, its a data contamination problem."

In [20]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence
151343,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B009RUZ7TS,0.0,2014-07-17 19:15:55.000,1412,4220,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 455..."
40958,AF7KZV4NJ5GBDVFTB7PEEUN4U53A,B0BBMLD8QT,5.0,2015-07-29 20:38:06.000,4871,4476,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
218918,AFVQ4K4KZPLQ3E2VFYSGX6HFXGNQ,B0BB6R89VF,0.0,2017-12-13 20:35:02.334,7616,1218,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 129..."
43115,AFCLWJMGYFCOJQR7T4454OF5A5WA,B00ENFP224,5.0,2015-09-06 12:09:59.000,5250,1355,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
233421,AFP4PHJ6Q2RRXLDPSDSH6VXJRUTA,B07CMXS5FP,0.0,2018-11-23 09:44:21.734,6792,838,"[-1.0, -1.0, -1.0, 1055.0, 3572.0, 3865.0, 176..."
...,...,...,...,...,...,...,...
250960,AGQHC7YNLYP4QV2PSBD6URSMJSVA,B07H65KP63,0.0,2020-02-08 04:09:50.457,11001,3568,"[-1.0, -1.0, -1.0, -1.0, 3585.0, 1866.0, 4040...."
217058,AHD65JAOVTTPDNJWOLSSGS3QVK6Q,B07DKMJ61N,0.0,2017-11-02 15:25:18.351,13410,4239,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
61324,AF32PWYNLPCVAU4UX35IEAZOFA3Q,B011BRUOMO,5.0,2016-07-18 05:42:21.000,4264,2253,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
132003,AGM65FYYAPHOLESGIDMFMPUQIYNA,B0016BVDIK,0.0,2010-12-16 19:59:19.000,10445,4250,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."


In [21]:
with pd.option_context("display.max_colwidth", None):
    display(train_df.loc[train_df["user_id"] == "AEEV5YWQKPBTLFWHKOBBULYA2RDQ"].sort_values(by=args.timestamp_col, ascending=False))

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence
172167,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B07C1RSV9C,0.0,2015-08-05 16:31:49,1412,934,"[-1.0, -1.0, -1.0, 4559.0, 4443.0, 3164.0, 1047.0, 4685.0, 107.0, 3295.0]"
41296,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B07C1RSV9C,5.0,2015-08-05 16:31:49,1412,3276,"[-1.0, -1.0, -1.0, 4559.0, 4443.0, 3164.0, 1047.0, 4685.0, 107.0, 3295.0]"
151346,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B07CB22VVJ,0.0,2014-07-17 19:19:28,1412,4599,"[-1.0, -1.0, -1.0, -1.0, 4559.0, 4443.0, 3164.0, 1047.0, 4685.0, 107.0]"
20475,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B07CB22VVJ,5.0,2014-07-17 19:19:28,1412,3295,"[-1.0, -1.0, -1.0, -1.0, 4559.0, 4443.0, 3164.0, 1047.0, 4685.0, 107.0]"
20474,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B000I23TTE,5.0,2014-07-17 19:16:43,1412,107,"[-1.0, -1.0, -1.0, -1.0, -1.0, 4559.0, 4443.0, 3164.0, 1047.0, 4685.0]"
151345,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B000I23TTE,0.0,2014-07-17 19:16:43,1412,4659,"[-1.0, -1.0, -1.0, -1.0, -1.0, 4559.0, 4443.0, 3164.0, 1047.0, 4685.0]"
151344,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B0BYSP9676,0.0,2014-07-17 19:16:20,1412,3678,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 4559.0, 4443.0, 3164.0, 1047.0]"
20473,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B0BYSP9676,5.0,2014-07-17 19:16:20,1412,4685,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 4559.0, 4443.0, 3164.0, 1047.0]"
151343,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B009RUZ7TS,0.0,2014-07-17 19:15:55,1412,4220,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 4559.0, 4443.0, 3164.0]"
20472,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B009RUZ7TS,5.0,2014-07-17 19:15:55,1412,1047,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 4559.0, 4443.0, 3164.0]"


## Convert user_id and item_id into indices

In [39]:
# idm_path = os.path.abspath("../data_for_ai/interim/idm_16407u.json")
# idm = IDMapper().load(idm_path)
# idm.get_user_id(1)

In [40]:
# train_df = train_df.pipe(idm.map_indices)
# val_df = val_df.pipe(idm.map_indices)

# assert idm.unknown_item_index not in train_df["item_indice"].values, "Unknown item index must be present in training data."
# assert idm.unknown_user_index not in train_df["user_indice"].values, "Unknown user index must be present in training data."
# assert idm.unknown_item_index not in val_df["item_indice"].values, "Unknown item index must be present in validation data."
# assert idm.unknown_user_index not in val_df["user_indice"].values, "Unknown user index must be present in validation data."

In [41]:
# train_df.head(3)

In [42]:
# assert train_df.groupby(args.user_col)[args.item_col].nunique().min() >= 5, "Each user must have at least five items."
# assert train_df.groupby(args.item_col)[args.user_col].nunique().min() >= 10, "Each item must have at least ten users."

## Training loop

In [22]:
rating_dataset = UserItemBinaryRatingDFDataset(
    train_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col, "item_sequence"
)
val_rating_dataset = UserItemBinaryRatingDFDataset(
    val_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col, "item_sequence"
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [23]:
item_indices = train_df[args.item_col].unique()
user_indices = train_df[args.user_col].unique()
n_items = len(item_indices)
n_users = len(user_indices)

logger.info(f"Number of users: {n_users}, Number of items: {n_items}")
model = init_model(n_users, n_items, args.embedding_dim, args.dropout)

[32m2025-07-01 18:58:34.499[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mNumber of users: 16407, Number of items: 4817[0m

enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.activation_relu_or_gelu was not True

[32m2025-07-01 18:58:34.509[0m | [1mINFO    [0m | [36msrc.algo.sequence_two_tower.model[0m:[36m__init__[0m:[36m122[0m - [1mStart token used: 4816, Padding token used: 4817[0m


In [24]:
idm_path = os.path.abspath("../data_for_ai/interim/idm_16407u.json")
idm = IDMapper().load(idm_path)
idm.get_user_id(1)

'AE227WAM4NWQPJI33OPN7ZARNNZQ'

## Overfit 1 batch

In [25]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=5, mode="min", verbose=False
)

model = init_model(n_users, n_items, args.embedding_dim, args.dropout)
lit_model = SeqModellingLitModule(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persit_dp,
    accelerator=args.device,
    idm= idm,
    negative_samples=args.negative_samples,
)

log_dir = f"{args.notebook_persit_dp}/logs/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=300,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

[32m2025-07-01 18:59:31.110[0m | [1mINFO    [0m | [36msrc.algo.sequence_two_tower.model[0m:[36m__init__[0m:[36m122[0m - [1mStart token used: 4816, Padding token used: 4817[0m
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type                     | Params | Mode 
------------------------------------------------------------------------
0 | model              | SequenceRatingPrediction | 2.2 M  | train
1 | val_roc_auc_metric | BinaryAUROC              | 0      | train
2 | val_pr_auc_metric  | BinaryAveragePrecision   | 0      | train
---------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=300` reached.
[32m2025-07-01 19:00:06.590[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m31[0m - [1mLogs available at /home/dinhln/Desktop/real_time_recsys/notebooks/data/first-attempt/006-sequence-modelling-attn-256-dim-bce-prelu/logs/overfit/lightning_logs/version_34[0m


## Run on all data

In [26]:
early_stopping = EarlyStopping(
    monitor="val_roc_auc", patience=args.early_stopping_patience, mode="max", verbose=False, min_delta=0.001
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persit_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_roc_auc",
    mode="max",
)

model = init_model(n_users, n_items, args.embedding_dim, args.dropout)

print(f"Model: {model}")
lit_model = SeqModellingLitModule(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persit_dp,
    accelerator=args.device,
    idm= idm,
    negative_samples=args.negative_samples,
)

log_dir = f"{args.notebook_persit_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    logger=args._mlf_logger if args.log_to_mlflow else None,
    
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    
)

# Change the library as a workaround for the issue in the latest Lightning release
#https://github.com/Lightning-AI/pytorch-lightning/pull/20669/commits/429f732a0528c558e701da7ec01e51c1e2e4f32e

[32m2025-07-01 19:00:24.621[0m | [1mINFO    [0m | [36msrc.algo.sequence_two_tower.model[0m:[36m__init__[0m:[36m122[0m - [1mStart token used: 4816, Padding token used: 4817[0m
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Model: SequenceRatingPrediction(
  (item_embedding): Embedding(4818, 256, padding_idx=4817)
  (encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
    (linear1): Linear(in_features=256, out_features=256, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
    (linear2): Linear(in_features=256, out_features=256, bias=True)
    (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.3, inplace=False)
    (dropout2): Dropout(p=0.3, inplace=False)
    (activation): PReLU(num_parameters=1)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(


Checkpoint directory /home/dinhln/Desktop/real_time_recsys/notebooks/data/first-attempt/006-sequence-modelling-attn-256-dim-bce-prelu/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type                     | Params | Mode 
------------------------------------------------------------------------
0 | model              | SequenceRatingPrediction | 2.2 M  | train
1 | val_roc_auc_metric | BinaryAUROC              | 0      | train
2 | val_pr_auc_metric  | BinaryAveragePrecision   | 0      | train
------------------------------------------------------------------------
2.2 M     Trainable params
0         Non-trainable params
2.2 M     Total params
8.626     Total estimated model params size (MB)
32        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.




The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

🏃 View run 006-sequence-modelling-attn-256-dim-bce-prelu at: http://138.2.61.6:5002/#/experiments/2/runs/700c360745634be8869475cc07ba3efb
🧪 View experiment at: http://138.2.61.6:5002/#/experiments/2


In [27]:
from src.eval.recommendation import RankingMetricComputer

In [28]:
best_model_path = trainer.checkpoint_callback.best_model_path

best_model_path

'/home/dinhln/Desktop/real_time_recsys/notebooks/data/first-attempt/006-sequence-modelling-attn-256-dim-bce-prelu/checkpoints/best-checkpoint-v27.ckpt'

In [29]:
# best_model_path = trainer.checkpoint_callback.best_model_path
best_model = SeqModellingLitModule.load_from_checkpoint(
    best_model_path, model = init_model(n_users, n_items, args.embedding_dim, args.dropout)
).model


best_model.eval()


enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.activation_relu_or_gelu was not True

[32m2025-07-01 19:18:46.118[0m | [1mINFO    [0m | [36msrc.algo.sequence_two_tower.model[0m:[36m__init__[0m:[36m122[0m - [1mStart token used: 4816, Padding token used: 4817[0m


SequenceRatingPrediction(
  (item_embedding): Embedding(4818, 256, padding_idx=4817)
  (encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
    (linear1): Linear(in_features=256, out_features=256, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
    (linear2): Linear(in_features=256, out_features=256, bias=True)
    (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.3, inplace=False)
    (dropout2): Dropout(p=0.3, inplace=False)
    (activation): PReLU(num_parameters=1)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_feat

In [33]:

ranking_computer = RankingMetricComputer(
    rec_model = best_model,
    batch_size= 128,
    mlf_client=args._mlf_logger.experiment if args.log_to_mlflow else None,
    evidently_report_fp= log_dir)

ranking_computer.calculate(
    val_df,
    args._mlf_logger.run_id if args.log_to_mlflow else None,
    device=args.device,
    log_to_mlflow=True
)

[32m2025-07-01 19:26:04.177[0m | [1mINFO    [0m | [36msrc.domain.model_request[0m:[36mfrom_df_for_rec[0m:[36m39[0m - [1mUse user_col=user_indice[0m


AttributeError: 'ModelRequest' object has no attribute 'item_sequence'

In [56]:
mlflow.end_run()

🏃 View run 006-sequence-modelling-attn-256-dim-bce-prelu at: http://138.2.61.6:5002/#/experiments/2/runs/03bd52f2d031411cabccc87a7a06d1c0
🧪 View experiment at: http://138.2.61.6:5002/#/experiments/2


## Clean up

In [57]:
all_params = [args]

if args.log_to_mlflow:
    run_id = trainer.logger.run_id

    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.model_dump()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)

🏃 View run 006-sequence-modelling-attn-256-dim-bce-prelu at: http://138.2.61.6:5002/#/experiments/2/runs/03bd52f2d031411cabccc87a7a06d1c0
🧪 View experiment at: http://138.2.61.6:5002/#/experiments/2


## Log model and idm

In [None]:
if args.log_to_mlflow:
    # Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    mlf_client.log_artifact(run_id, idm_path)

In [32]:
from src.algo.sequence.inference import SequenceRatingPredictionInferenceWrapper
inferrer = SequenceRatingPredictionInferenceWrapper(best_model)

In [33]:
sample_input = {
    "user_ids": [idm.get_user_id(0)],
    "item_sequences": [[idm.get_item_id(0), idm.get_item_id(1)]],
    "item_ids": [idm.get_item_id(0)],
}
sample_output = inferrer.infer([0], [[0, 1]], [0])
sample_output

array([0.5657515], dtype=float32)

In [34]:
from mlflow.models import infer_signature

In [35]:
mlflow.end_run()

In [51]:
if args.log_to_mlflow:
    # run_id = trainer.logger.run_id
    run_id = "03bd52f2d031411cabccc87a7a06d1c0"
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    idm_filename = idm_path.split("/")[-1]
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
            artifacts={"idm_path": "/home/dinhln/Desktop/real_time_recsys/data_for_ai/interim/idm_16407u.json"},
            signature=signature,
            input_example=sample_input,
            registered_model_name="sequence_two_tower",
            model_config={
                "device": "cpu"}
        )

2025/06/28 17:00:35 INFO mlflow.pyfunc: Validating input example against model signature


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Registered model 'sequence_two_tower' already exists. Creating a new version of this model...
2025/06/28 17:00:48 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: sequence_two_tower, version 10
Created version '10' of model 'sequence_two_tower'.


🏃 View run 006-sequence-modelling-attn-256-dim-bce-prelu at: http://138.2.61.6:5002/#/experiments/2/runs/03bd52f2d031411cabccc87a7a06d1c0
🧪 View experiment at: http://138.2.61.6:5002/#/experiments/2


In [52]:
mlf_client = mlflow.MlflowClient()
if args.log_to_mlflow:
        model_version = (
            mlf_client.get_registered_model("sequence_two_tower")
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name="sequence_two_tower", alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name="sequence_two_tower",
            version=model_version,
            key="author",
            value="dinh_ln",
        )

In [53]:
import mlflow
model_uri = 'runs:/03bd52f2d031411cabccc87a7a06d1c0/inferrer'
# The model is logged with an input example
pyfunc_model = mlflow.pyfunc.load_model(model_uri)
input_data = pyfunc_model.input_example

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [54]:
import mlflow
logged_model = 'runs:/03bd52f2d031411cabccc87a7a06d1c0/inferrer'

# Load model as a PyFuncModel.
loaded_model = mlflow.pyfunc.load_model(logged_model)


Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [55]:
loaded_model.predict(input_data)

{'user_ids': ['AE22236AFRRSMQIKGG7TPTB75QEA'],
 'item_sequences': [['0972683275', '1449410243']],
 'item_ids': ['0972683275'],
 'scores': [0.5657538175582886]}