# Sequence modeling for ranking task

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [None]:
import os
import sys

import lightning as L
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.exceptions import MlflowException
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from torch.utils.data import DataLoader

import mlflow

load_dotenv()

sys.path.insert(0, "..")

from src.dataset import UserItemBinaryDFDataset as UserItemRatingDFDataset
from src.id_mapper import IDMapper
from src.sequence.inference import SequenceRatingPredictionInferenceWrapper
from src.sequence.model import SequenceRatingPrediction
from src.sequence.trainer import LitSequenceRatingPrediction
from src.sequence.utils import generate_item_sequences
from src.viz import blueq_colors



# Controller

In [3]:
# This is a parameter cell used by papermill
max_epochs = 100

In [4]:
class Args(BaseModel):
    testing: bool = False
    author: str = "quy.dinh"
    log_to_mlflow: bool = True
    experiment_name: str = "RecSys MVP - Ranker"
    run_name: str = "000-sequence-modeling-baseline"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    max_epochs: int = max_epochs
    batch_size: int = 128

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"

    top_K: int = 100
    top_k: int = 10

    batch_size: int = 128

    embedding_dim: int = 128
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.001
    l2_reg: float = 1e-5

    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "sequence_rating_prediction"
    min_roc_auc: float = 0.7

    best_checkpoint_path: str = None

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        if self.device is None:
            self.device = (
                "cuda"
                if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_available() else "cpu"
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2025-03-01 19:52:24.215[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m47[0m - [1mSetting up MLflow experiment RecSys MVP - Ranker - run 000-sequence-modeling-baseline...[0m


{
  "testing": false,
  "author": "quy.dinh",
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Ranker",
  "run_name": "000-sequence-modeling-baseline",
  "notebook_persist_dp": "/home/dvq/frostmourne/recsys-mvp/notebooks/data/000-sequence-modeling-baseline",
  "random_seed": 41,
  "device": "cuda",
  "max_epochs": 100,
  "batch_size": 128,
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "top_K": 100,
  "top_k": 10,
  "embedding_dim": 128,
  "dropout": 0.3,
  "early_stopping_patience": 5,
  "learning_rate": 0.001,
  "l2_reg": 0.00001,
  "mlf_item2vec_model_name": "item2vec",
  "mlf_model_name": "sequence_rating_prediction",
  "min_roc_auc": 0.7,
  "best_checkpoint_path": null
}


# Implement

In [5]:
def init_model(n_users, n_items, embedding_dim, dropout, item_embedding=None):
    model = SequenceRatingPrediction(
        n_users, n_items, embedding_dim, dropout=dropout, item_embedding=item_embedding
    )
    return model

## Load pretrained Item2Vec embeddings

In [6]:
mlf_client = mlflow.MlflowClient()
model = mlflow.pyfunc.load_model(
    model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
)
skipgram_model = model.unwrap_python_model().model
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.size()[0]
id_mapping = model.unwrap_python_model().id_mapping
pretrained_item_embedding = skipgram_model.embeddings

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]



In [7]:
assert (
    pretrained_item_embedding.embedding_dim == args.embedding_dim
), "Mismatch pretrained item_embedding dimension"

# Test implementation

In [8]:
embedding_dim = 8
batch_size = 2

# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 4, 5, 3, 0]
item_sequences = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 3],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 2, 1],
]

n_users = len(set(user_indices))
n_items = len(set(item_indices))

train_df = pd.DataFrame(
    {
        "user_indice": user_indices,
        "item_indice": item_indices,
        args.timestamp_col: timestamps,
        args.rating_col: ratings,
        "item_sequence": item_sequences,
    }
)

model = init_model(n_users, n_items, embedding_dim, args.dropout)

# Example forward pass
model.eval()
user = torch.tensor([0])
item_sequence = torch.tensor([[-1, -1, -1, 0, 1]])
target_item = torch.tensor([2])
predictions = model.predict(user, item_sequence, target_item)
print(predictions)
model.train()

tensor([[0.5419]], grad_fn=<SigmoidBackward0>)


SequenceRatingPrediction(
  (item_embedding): Embedding(6, 8, padding_idx=5)
  (user_embedding): Embedding(3, 8)
  (gru): GRU(8, 8, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (fc_rating): Sequential(
    (0): Linear(in_features=24, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=8, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [9]:
rating_dataset = UserItemRatingDFDataset(
    train_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)

train_loader = DataLoader(
    rating_dataset, batch_size=batch_size, shuffle=False, drop_last=True
)

In [10]:
for batch_input in train_loader:
    print(batch_input)

{'user': tensor([0, 0]), 'item': tensor([0, 1]), 'rating': tensor([0., 1.]), 'item_sequence': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  3]]), 'item_sequence_ts_bucket': tensor([], size=(2, 0), dtype=torch.int64), 'item_feature': tensor([], size=(2, 0))}
{'user': tensor([1, 2]), 'item': tensor([2, 3]), 'rating': tensor([1., 1.]), 'item_sequence': tensor([[-1, -1,  1,  3],
        [-1, -1,  2,  1]]), 'item_sequence_ts_bucket': tensor([], size=(2, 0), dtype=torch.int64), 'item_feature': tensor([], size=(2, 0))}


In [11]:
# model
lit_model = LitSequenceRatingPrediction(model, log_dir=args.notebook_persist_dp)

# train model
trainer = L.Trainer(
    default_root_dir=f"{args.notebook_persist_dp}/test",
    max_epochs=2,
    accelerator=args.device if args.device else "auto",
)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4070 SUPER') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type                     | Params | Mode 
------------------------------------------------------------------------
0 | model              | SequenceRatingPrediction | 729    | train
1 | val_roc_auc_metric | BinaryAUROC              | 0      | train
2 | val_pr_auc_metric  | BinaryAveragePrecision   | 0      | train
------------------------------------------------------------------------
729       Trainable params
0         Non-trainable params

Sanity Checking: |                                                                                            …

/home/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
/home/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.
[32m2025-03-01 19:52:25.141[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m172[0m - [1mLogging classification metrics...[0m


In [12]:
users = torch.tensor([0, 0, 0, 0])
item_sequences = torch.tensor(
    [[-1, -1, 2, 3], [-1, -1, 2, 3], [-1, -1, 1, 3], [-1, -1, 2, 1]]
)
items = torch.tensor([0, 1, 2, 3])
predictions = model.predict(users, item_sequences, items)
print(predictions)

tensor([[0.5498],
        [0.5998],
        [0.5506],
        [0.5526]], grad_fn=<SigmoidBackward0>)


In [13]:
def create_predict_df(
    train_df,
    val_user_indices,
    val_timestamp,
    rating_col,
    timestamp_col,
    sequence_length=10,
):
    predict_df = pd.DataFrame(
        {
            "user_indice": val_user_indices,
            "item_indice": -1,  # placeholder
            "timestamp": val_timestamp,
            "source": "predict",
        }
    )

    predict_df = (
        pd.concat(
            [
                train_df.loc[lambda df: df[rating_col].gt(0)][
                    ["user_indice", "item_indice", timestamp_col]
                ].assign(source="train"),
                predict_df,
            ],
            axis=0,
        )
        .pipe(
            generate_item_sequences,
            "user_indice",
            "item_indice",
            timestamp_col,
            sequence_length=sequence_length,
            padding=True,
            padding_value=-1,
        )
        .loc[lambda df: df["source"].eq("predict")]
        .assign(item_sequence=lambda df: df["item_sequence"].apply(np.array))
    )

    return predict_df


predict_df = create_predict_df(
    train_df,
    user_indices,
    timestamps[-1],
    args.rating_col,
    args.timestamp_col,
    sequence_length=10,
)

predict_df

Unnamed: 0,user_indice,item_indice,timestamp,source,item_sequence
0,0,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1]"
1,0,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1]"
2,1,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 2]"
3,2,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 3]"
4,2,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 3]"


In [14]:
recommendations = model.recommend(
    torch.tensor(predict_df["user_indice"].values),
    torch.tensor(predict_df["item_sequence"].values.tolist()),
    k=2,
    batch_size=4,
)
recommendations


Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:254.)



Generating recommendations:   0%|          | 0/2 [00:00<?, ?it/s]

{'user_indice': [0, 0, 0, 0, 1, 1, 2, 2, 2, 2],
 'recommendation': [4, 1, 4, 1, 3, 5, 3, 1, 3, 1],
 'score': [0.5868487358093262,
  0.5797895789146423,
  0.5868487358093262,
  0.5797895789146423,
  0.5909941792488098,
  0.5758971571922302,
  0.5616493225097656,
  0.5495050549507141,
  0.5616493225097656,
  0.5495050549507141]}

# Prep data

In [15]:
train_df = pd.read_parquet("../data/train_features_neg_df.parquet")
val_df = pd.read_parquet("../data/val_features_neg_df.parquet")
idm_fp = "../data/idm.json"
idm = IDMapper().load(idm_fp)

assert (
    train_df[args.user_col].map(lambda s: idm.get_user_index(s))
    != train_df["user_indice"]
).sum() == 0, "Mismatch IDM"
assert (
    val_df[args.user_col].map(lambda s: idm.get_user_index(s)) != val_df["user_indice"]
).sum() == 0, "Mismatch IDM"

In [16]:
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

[32m2025-03-01 19:52:27.228[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mlen(user_indices)=19,578, len(item_indices)=4,630[0m


In [17]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
0,AG57LGJFCNNQJ6P6ABQAVUKXDUDA,B0015AARJI,0.0,2016-01-12 11:59:11.000,,76.0,4.592105,10.0,4.3,3.0,...,B00J00BLRM,1452599936,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1452599936]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 0]",Video Games,PlayStation 3 Dualshock 3 Wireless Controller ...,"[Amazon.com, The Dualshock 3 wireless controll...","[Video Games, Legacy Systems, PlayStation Syst...",49.99
1,AHWG4EGOV5ZDKPETL56MAYGPLJRQ,B0BMGHMP23,0.0,2016-04-18 19:26:20.000,,,,,,,...,"B00YOGZFCO,B00KWFCSB2,B00L3LQ1FI,B0151K6J9Y,B0...","1449254540,1449256005,1449257733,1452715791,14...","[3028.0, 2742.0, 2755.0, 3159.0, 3101.0, 3036....","[1449254540, 1449256005, 1449257733, 145271579...","[5, 5, 5, 5, 5, 5, 5, 5, 5, 5]",Computers,Logitech G502 Lightspeed Wireless Gaming Mouse...,[G502 is the best gaming mouse from Logitech G...,"[Video Games, PC, Accessories, Gaming Mice]",87.95
2,AH5PTZ2U74OZ3HT6QVUWM4CV6OVQ,B009AP23NI,0.0,2016-02-10 18:45:08.000,,9.0,4.666667,0.0,,0.0,...,"B0199OXR0W,B00EVPR4FY,B00B7ELWAU,B00UH9DN58,B0...","1443454097,1455129080,1455129186,1455129499,14...","[-1.0, -1.0, 3234.0, 2508.0, 2318.0, 2964.0, 1...","[-1, -1, 1443454097, 1455129080, 1455129186, 1...","[-1, -1, 5, 1, 1, 0, 0, 0, 0, 0]",Video Games,Nintendo Wii U Pro U Controller (Japanese Vers...,[Wii U PRO controller (black) (WUP-A-RSKA)],"[Video Games, Legacy Systems, Nintendo Systems...",43.99
3,AFC5XTCF5D7J3NSDITB2Z26XWWYA,B001E8WQUY,5.0,2019-05-01 21:22:39.265,1.556746e+09,0.0,,0.0,,0.0,...,"B006HZA6VK,B0BN2FNKLM,B0086VPUHI,B0040UAYI4,B0...","1327120514,1377289907,1402605836,1402606396,14...","[1987.0, 4569.0, 2114.0, 1606.0, 2159.0, 2279....","[1327120514, 1377289907, 1402605836, 140260639...","[8, 8, 7, 7, 7, 7, 7, 7, 6, 6]",Video Games,Rock Band 2 - Nintendo Wii (Game only),"[Product description, Rock Band 2 lets you and...","[Video Games, Legacy Systems, Nintendo Systems...",28.49
4,AF7LJQOIWF3Y3YD7SGOJ34MA5JPA,B001E8WQKY,5.0,2015-01-09 12:53:25.000,1.420808e+09,16.0,4.375000,8.0,4.5,4.0,...,"B00A2ML6XG,B003VUO6LU",14208077931420807991,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, 1420807793, 1...","[-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]",Video Games,Resident Evil 5 - Xbox 360,[],"[Video Games, Legacy Systems, Xbox Systems, Xb...",29.88
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
328591,AG4RATLNVLOKZCPXN67HKOAK65CA,B078FBVJMB,0.0,2015-10-31 18:25:09.000,,,,,,,...,B00TFVD688,1425233294,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1425233294]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 5]",Video Games,A Way Out – PC Origin [Online Game Code],[From the creators of Brothers - A Tale of Two...,"[Video Games, PC, Games]",5.99
328592,AFBXO3BFWBJX6QS5NW73O37IXF2A,B0771ZXXV6,0.0,2011-03-08 02:06:38.000,,,,,,,...,"B003JVCA9Q,B0029NZ4HA",12995495171299549928,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, 1299549517, 1...","[-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]",Video Games,Nintendo Joy-Con (R) - Neon Red - Nintendo Switch,[To be determined],"[Video Games, Nintendo Switch, Accessories, Co...",
328593,AHVANA5GZNJ45UABPXWZNAF4ECBQ,B00BBF6MO6,0.0,2015-02-15 05:31:04.000,,3.0,4.666667,0.0,,0.0,...,"B002L93F0A,B002KJ02ZC,B001H4NMNA",137041433213704147071370416530,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 137...","[-1, -1, -1, -1, -1, -1, -1, 1370414332, 13704...","[-1, -1, -1, -1, -1, -1, -1, 6, 6, 6]",Video Games,Killer is Dead - Xbox 360,[Killer Is Dead is the latest title from the d...,"[Video Games, Legacy Systems, Xbox Systems, Xb...",39.82
328594,AHAVA5VKMJ3OMOLGDZ3W45CKXEWA,B00KTORA0K,5.0,2019-05-25 04:03:51.505,1.558757e+09,3.0,4.666667,1.0,5.0,1.0,...,"B004AYCNR0,B007NUQICE,B000TYQL1O,B000SEU92W,B0...","1431150669,1431150834,1432041664,1432041986,15...","[-1.0, -1.0, -1.0, 1657.0, 2074.0, 593.0, 583....","[-1, -1, -1, 1431150669, 1431150834, 143204166...","[-1, -1, -1, 7, 7, 7, 7, 5, 5, 0]",Video Games,Just Dance 2015 - Wii,[With more than 50 million copies of Just Danc...,"[Video Games, Legacy Systems, Nintendo Systems...",33.0


# Train

In [18]:
rating_dataset = UserItemRatingDFDataset(
    train_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)
val_rating_dataset = UserItemRatingDFDataset(
    val_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [19]:
n_items = len(item_indices)
n_users = len(user_indices)

model = init_model(n_users, n_items, args.embedding_dim, args.dropout)

#### Predict before train

In [20]:
model.item_embedding

Embedding(4631, 128, padding_idx=4630)

In [21]:
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
788,AHAOIPSKT4LEWU47ZEN7LMOKRMTA,B002DZKZ5K,0.0,2022-02-02 08:29:00.447,,2.0,3.0,0.0,,0.0,...,"B00L3LQ1FI,B017QU5G1O,B00A878J5I,B07G4YYZ1M,B0...","1417991979,1434316316,1434318550,1452921402,15...","[2755, 3202, 2267, 3973, 3834, 3698, 4426, 430...","[1417991979, 1434316316, 1434318550, 145292140...","[8, 8, 8, 8, 7, 7, 7, 6, 6, 5]",Video Games,Lego Indiana Jones 2: The Adventure Continues ...,"[Product Description, LEGO Indiana Jones 2: Th...","[Video Games, Legacy Systems, Nintendo Systems...",28.53
1178,AEJDEYIQP7GV6MFFDS4NW5M66YMA,B0BL65X86R,1.0,2022-01-15 23:49:08.812,1642291000.0,11.0,5.0,1.0,5.0,0.0,...,"B006PP41Q8,B00DBRNQZ0,B0050SX9I2,B00IFF0SIQ,B0...","1440968956,1445379263,1445379365,1445379373,14...","[2012, 2432, 1838, 2645, 2466, 1998, 2050, 359...","[1440968956, 1445379263, 1445379365, 144537937...","[8, 8, 8, 8, 8, 7, 7, 7, 7, 7]",Video Games,$25 PlayStation Store Gift Card [Digital Code],[Redeem against anything on PlayStation Store....,"[Video Games, Online Game Services, PlayStatio...",25.0
1585,AHTLGOWRXF4PNX6DEUZXVZWBBVTQ,B0C3KYVDWT,1.0,2022-03-15 17:13:55.253,1647364000.0,25.0,4.92,3.0,5.0,0.0,...,"B00PGLG79G,B01BF9X6LO,B07YBXFDYN,B07VRD1TT1,B0...","1528409892,1528410399,1571267805,1580915819,15...","[2893, 3277, 4282, 4209, 3360, 4299, 3940, 427...","[1528409892, 1528410399, 1571267805, 158091581...","[7, 7, 6, 6, 6, 6, 6, 3, 3, 3]",Computers,"SanDisk 128GB microSDXC-Card, Licensed for Nin...","[With incredible speed, the officially license...","[Video Games, Nintendo Switch, Accessories]",14.99
135,AGSRO7JUOTXSI76KD4A3J5XED5EQ,B00S1LRUVW,0.0,2022-03-16 17:42:47.025,,0.0,,0.0,,0.0,...,"B00002STFD,B001KN31ZM,B00503E8S2,B004YV9TSA,B0...","1378780997,1382760104,1399439878,1399440244,14...","[-1, -1, -1, 50, 1107, 1806, 1798, 2391, 4272,...","[-1, -1, -1, 1378780997, 1382760104, 139943987...","[-1, -1, -1, 8, 8, 8, 8, 8, 8, 4]",Video Games,Nintendo New 3DS Xl - Red [Discontinued],[THE NEXT DIMENSION IN ENTERTAINMENT. The New ...,"[Video Games, Legacy Systems, Nintendo Systems...",379.99
1377,AHCMSGZRS6NBUPM4DPVUAZHLOQ7Q,B00008URUF,0.0,2021-12-12 21:05:53.342,,1.0,4.0,0.0,,0.0,...,"B09918MSTF,B08MBQ51KG,B087T1FS9K,B087NNZZM8,B0...","1595603546,1595603593,1595783045,1597267305,16...","[-1, -1, -1, 4460, 4407, 4349, 4343, 4370, 430...","[-1, -1, -1, 1595603546, 1595603593, 159578304...","[-1, -1, -1, 6, 6, 6, 6, 6, 5, 5]",Video Games,Donkey Kong Country,"[Product Description, The arcade classic gets ...","[Video Games, Legacy Systems, Nintendo Systems...",38.66
1449,AETZPD7JKD42GBVYXBYPGOY4NF6Q,B09B35J159,1.0,2022-07-08 19:10:11.311,1657307000.0,4.0,2.25,2.0,3.0,2.0,...,"B07N5LL4YW,B00BN5T30E,B00C1TTF86,B00DE2W4PK,B0...","1384463846,1384464555,1384543182,1384543825,15...","[-1, -1, -1, -1, -1, 4074, 2352, 2368, 2444, 3...","[-1, -1, -1, -1, -1, 1384463846, 1384464555, 1...","[-1, -1, -1, -1, -1, 8, 8, 8, 8, 7]",Computers,Razer Basilisk Ultimate HyperSpeed Wireless Ga...,"[With a high-speed transmission, extremely low...","[Video Games, PC, Accessories, Gaming Mice]",
416,AFJCV5BC3AFKXNLMUSLRGCGXNLFQ,B09WB49MP6,1.0,2022-04-29 18:25:39.335,1651257000.0,0.0,,0.0,,0.0,...,"B00G9X4YRM,B00BN5T30E,B00K32USMU,B00OGNV5HY,B0...","1401869274,1403281337,1422155394,1422155844,14...","[-1, -1, -1, 2556, 2352, 2696, 2872, 4412, 308...","[-1, -1, -1, 1401869274, 1403281337, 142215539...","[-1, -1, -1, 8, 8, 8, 8, 8, 8, 8]",Video Games,OpenWheeler GEN3 Racing Wheel Stand Cockpit Bl...,[],"[Video Games, Legacy Systems, Xbox Systems, Xb...",399.0
1297,AETZPD7JKD42GBVYXBYPGOY4NF6Q,B07X1HF3V6,0.0,2022-07-08 19:13:28.232,,1.0,5.0,0.0,,0.0,...,"B07N5LL4YW,B00BN5T30E,B00C1TTF86,B00DE2W4PK,B0...","1384463846,1384464555,1384543182,1384543825,15...","[-1, -1, -1, -1, 4074, 2352, 2368, 2444, 3628,...","[-1, -1, -1, -1, 1384463846, 1384464555, 13845...","[-1, -1, -1, -1, 8, 8, 8, 8, 7, 0]",Video Games,WB Games Mortal Kombat: Komplete Edition - Pla...,[Note:The extra downloadable content is a bonu...,"[Video Games, Legacy Systems, PlayStation Syst...",34.43
789,AGLQLDOF6JVKQIZY7BU7OCAJFEIA,B004QEV0MI,0.0,2022-02-06 16:20:19.149,,1.0,1.0,0.0,,0.0,...,"B01M8GXNJX,B01FUWJAQC,B072JYVYCX,B06XX4D2KY,B0...","1488662110,1498345226,1507389221,1509398438,15...","[3481, 3332, 3694, 3591, 3871, 3942, 3059, 398...","[1488662110, 1498345226, 1507389221, 150939843...","[7, 7, 7, 7, 7, 7, 6, 6, 6, 6]",Video Games,Saint's Row: The Third - Xbox 360,"[Product Description, Years after taking Stilw...","[Video Games, Legacy Systems, Xbox Systems, Xb...",17.0
355,AGYJI6XSILTABA3ZWALFM6SMEVOQ,B08XD54VJY,1.0,2021-12-01 01:40:44.191,1638323000.0,18.0,4.055556,4.0,4.5,0.0,...,"B07SM7G9CN,B07G1SC6BW,B07L6MJ6LD,B07J3P1GJM,B0...","1527019952,1538858187,1557937823,1557937962,16...","[-1, -1, -1, 4164, 3970, 4049, 4018, 4388, 369...","[-1, -1, -1, 1527019952, 1538858187, 155793782...","[-1, -1, -1, 7, 7, 6, 6, 6, 6, 4]",Video Games,The Legend of Zelda: Skyward Sword HD - Ninten...,"[Solve puzzles, explore dungeons, and soar the...","[Video Games, Nintendo Switch, Games]",48.49


In [22]:
user_id = val_df.sample(1)[args.user_col].values[0]
# user_id = "AH4AOFTTDPHPAFAAVFMAF25H2LIQ"
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
4,AEBEAZUAX3HMA7EF3BA6L2DK3LPA,B008HPAXZ2,0.0,2022-06-25 22:48:11.729,,0.0,,0.0,,0.0,...,"B06XMRQ68B,B0BLXJC8MZ,B0748N6796,B0BS9YCBYY,B0BKRXQ5GL,B07HNW68ZC,B0754LGLFP",1494507012154275125615573684931557883885155882742915704952121655255201,"[-1, -1, -1, 3577, 4566, 3740, 4579, 4554, 4007, 3764]","[-1, -1, -1, 1494507012, 1542751256, 1557368493, 1557883885, 1558827429, 1570495212, 1655255201]","[-1, -1, -1, 8, 7, 7, 7, 7, 6, 4]",Video Games,HORI Nintendo 3DS XL Screen Protective Filter,[Officially licensed by Nintendo. This is the only screen protective filter you'll need to protect your new Nintendo 3DS XL LCD screens from dirt and scratches. Uses the same proven new and improved filter application method as HORI's Nintendo 3DS version. This method will allow anyone to easily and neatly apply the screen filters and get perfect results every time! Also includes a cleaning cloth. Nintendo 3DS XL system not included.],"[Video Games, Legacy Systems, Nintendo Systems, Nintendo 3DS & 2DS, Accessories, Faceplates, Protectors & Skins, Screen Protectors]",
27,AEBEAZUAX3HMA7EF3BA6L2DK3LPA,B0754LGLFP,1.0,2022-06-15 01:06:41.380,1655255000.0,2.0,2.5,1.0,3.0,0.0,...,"B06XMRQ68B,B0BLXJC8MZ,B0748N6796,B0BS9YCBYY,B0BKRXQ5GL,B07HNW68ZC",149450701215427512561557368493155788388515588274291570495212,"[-1, -1, -1, -1, 3577, 4566, 3740, 4579, 4554, 4007]","[-1, -1, -1, -1, 1494507012, 1542751256, 1557368493, 1557883885, 1558827429, 1570495212]","[-1, -1, -1, -1, 8, 7, 7, 7, 7, 6]",Computers,"Redragon K552 Mechanical Gaming Keyboard Rainbow LED Backlit Wired with Anti-Dust Proof Switches for Windows PC (White, 87 Keys Blue Switches)","[Redragon K552 KUMARA 87 Key Rainbow LED Backlit Mechanical Wired illuminated Gaming Keyboard with Anti-Dust Blue Switches The Redragon K552isn't your average gaming keyboard. Not only is it over-engineered and built to take a beating, it is loaded with pro features including solid metal and ABS construction, precision engineered keycaps, high-end mechanical dust proof switches and crisp, bright adjustable RGB LED backlighting, a gold-plated USB connector, and a splash-resistant design.The Blue Switches are clicky with medium resistance, audible loud click sound, crisp precise tactile feedback, good for gaming and typing Features: * Rainbow LED Backlit, 19 different Backlight, plus 2 user programmable modes * 6 different backlight colors & 6 brightness levels * 8 preset gaming modes * Durable solid Metal-ABS Construction * Dust Proof Tactile Blue Switches * Compact Space Saving Tenkeyless Design with 87 Full Sized Keys * All 87 keys are 100% conflict free, anti-ghosting * 12 Multimedia with FN Keys * WIN Key can Be disabled for Gaming * WASD and arrow keys are interchangeable * Keycaps offering crystal clear lettering that doesn't scratch off * Weight & Dimensions: 30.90oz, 13.93x4.86x1.46 inches* For Windows 11, Windows 10, Windows 8, Windows 7, Windows Vista, and Windows XP What's in the box * Keyboard (White, Blue Switches) * User guide * Warranty card]","[Video Games, PC, Accessories, Gaming Keyboards]",39.99
1494,AEBEAZUAX3HMA7EF3BA6L2DK3LPA,B0C6DH316S,1.0,2022-06-25 22:48:11.729,1656197000.0,1.0,4.0,0.0,,0.0,...,"B06XMRQ68B,B0BLXJC8MZ,B0748N6796,B0BS9YCBYY,B0BKRXQ5GL,B07HNW68ZC,B0754LGLFP",1494507012154275125615573684931557883885155882742915704952121655255201,"[-1, -1, -1, 3577, 4566, 3740, 4579, 4554, 4007, 3764]","[-1, -1, -1, 1494507012, 1542751256, 1557368493, 1557883885, 1558827429, 1570495212, 1655255201]","[-1, -1, -1, 8, 7, 7, 7, 7, 6, 4]",Computers,"Logitech G PRO X Wireless Lightspeed Gaming Headset - Shroud Edition, Black",[],"[Video Games, PC, Accessories, Headsets]",253.82
1824,AEBEAZUAX3HMA7EF3BA6L2DK3LPA,B00N2KKSNO,0.0,2022-06-15 01:06:41.380,,0.0,,0.0,,0.0,...,"B06XMRQ68B,B0BLXJC8MZ,B0748N6796,B0BS9YCBYY,B0BKRXQ5GL,B07HNW68ZC",149450701215427512561557368493155788388515588274291570495212,"[-1, -1, -1, -1, 3577, 4566, 3740, 4579, 4554, 4007]","[-1, -1, -1, -1, 1494507012, 1542751256, 1557368493, 1557883885, 1558827429, 1570495212]","[-1, -1, -1, -1, 8, 7, 7, 7, 7, 6]",Video Games,Sleeping Dogs: Definitive Edition- PlayStation 4,"[Sleeping Dogs is an open-world action game set on the exotic island of Hong Kongwith brutal martial arts combat, thrilling street races and a celebrated, grippingstory. Entirely remade for the new generation of consoles and the very latest PCsThe Definitive Edition includes all previously available content and a wealth of newtechnological upgrades. In this open world game, you play the roleof Wei Shen, an undercover cop trying to take down the Triads from the inside out.You’ll have to prove yourself worthy as you fight your way up the organization,taking part in brutal criminal activities without blowing your cover. Torn betweenyour loyalty to the badge and a criminal code of honor, you will risk everything asthe lines between truth, loyalty and justice become permanently blurred.]","[Video Games, PlayStation 4, Games]",22.33


In [23]:
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
item_sequence = torch.tensor([item_sequence])
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item)
model.train()

[32m2025-03-01 19:52:27.671[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mTest predicting before training with user_id = AEBEAZUAX3HMA7EF3BA6L2DK3LPA and parent_asin = B0754LGLFP[0m


SequenceRatingPrediction(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (gru): GRU(128, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (fc_rating): Sequential(
    (0): Linear(in_features=384, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Training loop

##### Overfit 1 batch

In [24]:
early_stopping = EarlyStopping(
    monitor="val_roc_auc", patience=10, mode="max", verbose=False
)

model = init_model(n_users, n_items, args.embedding_dim, dropout=0)
lit_model = LitSequenceRatingPrediction(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
    accelerator=args.device,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type                     | Params | Mode 
------------------------------------------------------------------------
0 | model              | SequenceRatingPrediction | 3.2 M  | train
1 | val_roc_auc_metric | BinaryAUROC              | 0      | train
2 | val_pr_auc_metric  | BinaryAveragePrecision   | 0      | train
------------------------------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.990    Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2025-03-01 19:52:28.627[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m172[0m - [1mLogging classification metrics...[0m
[32m2025-03-01 19:52:43.672[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mLogs available at /home/dvq/frostmourne/recsys-mvp/notebooks/data/000-sequence-modeling-baseline/logs/overfit/lightning_logs/version_5[0m


In [25]:
# Need to make sure port 6006 at local is accessible
%tensorboard --logdir $trainer.log_dir

##### Fit on all data

In [26]:
# papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_roc_auc", patience=args.early_stopping_patience, mode="max", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_roc_auc",
    mode="max",
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    dropout=args.dropout,
    item_embedding=pretrained_item_embedding,
)
lit_model = LitSequenceRatingPrediction(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    args=args,
    accelerator=args.device,
    checkpoint_callback=checkpoint_callback,
)

log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Checkpoint directory /home/dvq/frostmourne/recsys-mvp/notebooks/data/000-sequence-modeling-baseline/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type                     | Params | Mode 
------------------------------------------------------------------------
0 | model              | SequenceRatingPrediction | 3.2 M  | train
1 | val_roc_auc_metric | BinaryAUROC              | 0      | train
2 | val_pr_auc_metric  | BinaryAveragePrecision   | 0      | train
------------------------------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.990    Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2025-03-01 19:55:17.134[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m165[0m - [1mLoading best model from /home/dvq/frostmourne/recsys-mvp/notebooks/data/000-sequence-modeling-baseline/checkpoints/best-checkpoint-v5.ckpt...[0m
[32m2025-03-01 19:55:17.301[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m172[0m - [1mLogging classification metrics...[0m
[32m2025-03-01 19:55:18.251[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m175[0m - [1mLogging ranking metrics...[0m


Generating recommendations:   0%|          | 0/177 [00:00<?, ?it/s]

🏃 View run 000-sequence-modeling-baseline at: http://localhost:5002/#/experiments/3/runs/df9df9f3eed64e96acc4934cebab0afb
🧪 View experiment at: http://localhost:5002/#/experiments/3


In [27]:
logger.info(
    f"Test predicting after training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
model.eval()
model = model.to(user.device)
model.predict(user, item_sequence, item)
model.train()

[32m2025-03-01 19:55:23.558[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mTest predicting after training with user_id = AEBEAZUAX3HMA7EF3BA6L2DK3LPA and parent_asin = B0754LGLFP[0m


SequenceRatingPrediction(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (gru): GRU(128, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (fc_rating): Sequential(
    (0): Linear(in_features=384, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

# Load best checkpoint

In [28]:
logger.info(f"Loading best checkpoint from {checkpoint_callback.best_model_path}...")
args.best_checkpoint_path = checkpoint_callback.best_model_path

best_trainer = LitSequenceRatingPrediction.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    model=init_model(n_users, n_items, args.embedding_dim, dropout=0),
)

[32m2025-03-01 19:55:23.595[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mLoading best checkpoint from /home/dvq/frostmourne/recsys-mvp/notebooks/data/000-sequence-modeling-baseline/checkpoints/best-checkpoint-v5.ckpt...[0m


In [29]:
best_model = best_trainer.model.to(lit_model.device)

In [30]:
best_model.eval()
best_model.predict(user, item_sequence, item)
best_model.train()

SequenceRatingPrediction(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (gru): GRU(128, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0, inplace=False)
  (fc_rating): Sequential(
    (0): Linear(in_features=384, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

### Persist id mapping

In [31]:
if args.log_to_mlflow:
    # Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    mlf_client.log_artifact(run_id, idm_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [32]:
inferrer = SequenceRatingPredictionInferenceWrapper(best_model)

In [33]:
sample_input = {
    "user_ids": [idm.get_user_id(0)],
    "item_sequences": [[idm.get_item_id(0), idm.get_item_id(1)]],
    "item_ids": [idm.get_item_id(0)],
}
sample_output = inferrer.infer([0], [[0, 1]], [0])
sample_output

array([0.7713251], dtype=float32)

In [34]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    idm_filename = idm_fp.split("/")[-1]
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
            artifacts={"idm": mlflow.get_artifact_uri(idm_filename)},
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )

2025/03/01 19:55:23 INFO mlflow.pyfunc: Validating input example against model signature


Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Registered model 'sequence_rating_prediction' already exists. Creating a new version of this model...
2025/03/01 19:55:26 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: sequence_rating_prediction, version 7


🏃 View run 000-sequence-modeling-baseline at: http://localhost:5002/#/experiments/3/runs/df9df9f3eed64e96acc4934cebab0afb
🧪 View experiment at: http://localhost:5002/#/experiments/3


Created version '7' of model 'sequence_rating_prediction'.


# Set the newly trained model as champion

In [35]:
if args.log_to_mlflow:
    # Get current champion
    deploy_alias = "champion"
    curr_model_run_id = None

    min_roc_auc = args.min_roc_auc

    try:
        curr_champion_model = mlf_client.get_model_version_by_alias(
            args.mlf_model_name, deploy_alias
        )
        curr_model_run_id = curr_champion_model.run_id
    except MlflowException as e:
        if "not found" in str(e).lower():
            logger.info(
                f"There is no {deploy_alias} alias for model {args.mlf_model_name}"
            )

    # Compare new vs curr models
    new_mlf_run = trainer.logger.experiment.get_run(trainer.logger.run_id)
    new_metrics = new_mlf_run.data.metrics
    roc_auc = new_metrics["roc_auc"]
    if curr_model_run_id:
        curr_model_run_info = mlf_client.get_run(curr_model_run_id)
        curr_metrics = curr_model_run_info.data.metrics
        if (curr_roc_auc := curr_metrics["roc_auc"]) > min_roc_auc:
            logger.info(
                f"Current {deploy_alias} model has {curr_roc_auc:,.4f} ROC-AUC. Setting it to the deploy baseline..."
            )
            min_roc_auc = curr_roc_auc

        top_metrics = ["roc_auc"]
        vizer = ModelMetricsComparisonVisualizer(curr_metrics, new_metrics, top_metrics)
        print(f"Comparing metrics between new run and current champion:")
        display(vizer.compare_metrics_df())
        vizer.create_metrics_comparison_plot(n_cols=5)
        vizer.plot_diff()

    # Register new champion
    if roc_auc < min_roc_auc:
        logger.info(
            f"Current run has ROC-AUC = {roc_auc:,.4f}, smaller than {min_roc_auc:,.4f}. Skip aliasing this model as the new {deploy_alias}.."
        )
    else:
        logger.info(f"Aliasing the new model as champion...")
        # Get the model version for current run by assuming it's the most recent registered version
        model_version = (
            mlf_client.get_registered_model(args.mlf_model_name)
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value=args.author,
        )

[32m2025-03-01 19:55:26.749[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - [1mThere is no champion alias for model sequence_rating_prediction[0m
[32m2025-03-01 19:55:26.754[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m45[0m - [1mAliasing the new model as champion...[0m


# Clean up

In [36]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.dict()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)

🏃 View run 000-sequence-modeling-baseline at: http://localhost:5002/#/experiments/3/runs/df9df9f3eed64e96acc4934cebab0afb
🧪 View experiment at: http://localhost:5002/#/experiments/3


/tmp/ipykernel_297612/747004171.py:6: PydanticDeprecatedSince20:

The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/

