# Ranker that can takes into accound different features

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import os
import sys

import dill
import lightning as L
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.exceptions import MlflowException
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from torch.utils.data import DataLoader

import mlflow

load_dotenv()

sys.path.insert(0, "..")

from cfg.run_cfg import RunCfg
# from src.ann import AnnIndex
from src.utils.data_prep import chunk_transform
from src.algo.ranker.dataset import UserItemBinaryDFDataset
from src.utils.embedding_id_mapper import IDMapper
from src.algo.ranker.inference import RankerInferenceWrapper
from src.algo.ranker.model import Ranker
from src.algo.ranker.trainer import LitRanker
from src.algo.item2vec.trainer import LitSkipGram
from src.algo.item2vec.model import SkipGram



# Controller

In [3]:
# This is a parameter cell used by papermill
max_epochs = 1

In [4]:
class Args(BaseModel):
    testing: bool = False
    author: str = "dinh-trieu"
    log_to_mlflow: bool = True
    experiment_name: str = "RecSys MVP - Ranker"
    run_name: str = "004-use-sbert-features-and-llm-tags"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    rc: RunCfg = RunCfg().init()

    item_metadata_pipeline_fp: str = "../data_for_ai/interim/item_metadata_pipeline_wo_user_item_manipulate.dill"
    qdrant_url: str = None
    qdrant_collection_name: str = "item_desc_sbert"

    max_epochs: int = max_epochs
    batch_size: int = 128
    tfm_chunk_size: int = 10000
    neg_to_pos_ratio: int = 1

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"

    top_K: int = 100
    top_k: int = 10

    embedding_dim: int = 256
    item_sequence_ts_bucket_size: int = 10
    bucket_embedding_dim: int = 16
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.001
    l2_reg: float = 1e-5

    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "ranker"
    min_roc_auc: float = 0.7

    best_checkpoint_path: str = None
    use_item_feature: bool = True

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise Exception(f"Environment variable QDRANT_HOST is not set.")

        qdrant_port = os.getenv("QDRANT_PORT")
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=False,
            )

        if self.device is None:
            self.device = (
                "cuda"
                if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_available() else "cpu"
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2025-07-01 01:11:39.635[0m | [34m[1mDEBUG   [0m | [36mcfg.run_cfg[0m:[36minit[0m:[36m31[0m - [34m[1mSetting use_sbert_features=True requires running notebook 016-sentence-transformers[0m
[32m2025-07-01 01:11:39.635[0m | [34m[1mDEBUG   [0m | [36mcfg.run_cfg[0m:[36minit[0m:[36m38[0m - [34m[1mChanging use_item_tags_from_llm requires re-running notebook 002-features-v2 to get the new item_metadata_pipeline.dill file[0m


[32m2025-07-01 01:11:39.640[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m62[0m - [1mSetting up MLflow experiment RecSys MVP - Ranker - run 004-use-sbert-features-and-llm-tags...[0m


{
  "testing": false,
  "author": "dinh-trieu",
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Ranker",
  "run_name": "004-use-sbert-features-and-llm-tags",
  "notebook_persist_dp": "/home/dinhln/Desktop/real_time_recsys/notebooks/data/004-use-sbert-features-and-llm-tags",
  "random_seed": 41,
  "device": "cuda",
  "rc": {
    "use_sbert_features": true,
    "use_item_tags_from_llm": false,
    "item_feature_cols": [
      "main_category",
      "categories",
      "price",
      "parent_asin_rating_cnt_365d",
      "parent_asin_rating_avg_prev_rating_365d",
      "parent_asin_rating_cnt_90d",
      "parent_asin_rating_avg_prev_rating_90d",
      "parent_asin_rating_cnt_30d",
      "parent_asin_rating_avg_prev_rating_30d",
      "parent_asin_rating_cnt_7d",
      "parent_asin_rating_avg_prev_rating_7d"
    ]
  },
  "item_metadata_pipeline_fp": "../data_for_ai/interim/item_metadata_pipeline_wo_user_item_manipulate.dill",
  "qdrant_url": "138.2.61.6:6333",
  "qdrant_collecti

# Implement

In [5]:
def init_model(
    n_users,
    n_items,
    embedding_dim,
    item_sequence_ts_bucket_size,
    bucket_embedding_dim,
    item_feature_size,
    dropout,
    item_embedding=None,
    use_item_feature=False,
):
    model = Ranker(
        n_users,
        n_items,
        embedding_dim,
        item_sequence_ts_bucket_size=item_sequence_ts_bucket_size,
        bucket_embedding_dim=bucket_embedding_dim,
        use_item_feature=use_item_feature,
        item_feature_size=item_feature_size,
        dropout=dropout,
        item_embedding=item_embedding,
    )
    return model

## Load pretrained Item2Vec embeddings

In [6]:
# n_items = 4817  # This should be the number of unique items in your dataset
# assert args.embedding_dim == 256, "Embedding dimension must be 256"
# best_trainer = LitSkipGram.load_from_checkpoint(
#     "../data_for_ai/interim/best-item2vec-weight.ckpt",
#     skipgram_model=SkipGram(n_items, args.embedding_dim).to(args.device),
# )
# skipgram_item_embedding = best_trainer.skipgram_model.embeddings.weight.data.cpu()
# print(f"SkipGram Item embedding shape: {skipgram_item_embedding.shape}")
# print(f"SkipGram Item embedding dtype: {skipgram_item_embedding.dtype}")

# # create a embedding layer with num_items + 1 embedding, then apply the pretrained weights
# pretrained_item_embedding = torch.nn.Embedding(
#     num_embeddings=n_items + 1,  # +1 for the unknown item (-1 padding)
#     embedding_dim=args.embedding_dim,
#     padding_idx=n_items,  # Set padding_idx to the last index
# )
# pretrained_item_embedding.weight.data[:n_items] = skipgram_item_embedding[:n_items]
# pretrained_item_embedding.weight.data[n_items] = torch.zeros(
#     args.embedding_dim, dtype=skipgram_item_embedding.dtype
# )

In [7]:
# mlf_client = mlflow.MlflowClient()
# model = mlflow.pyfunc.load_model(
#     model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
# )
# skipgram_model = model.unwrap_python_model().model
# embedding_0 = skipgram_model.embeddings(torch.tensor(0))
# embedding_dim = embedding_0.size()[0]
# id_mapping = model.unwrap_python_model().id_mapping
# pretrained_item_embedding = skipgram_model.embeddings

In [8]:
# assert (
#     pretrained_item_embedding.embedding_dim == args.embedding_dim
# ), "Mismatch pretrained item_embedding dimension"

## Load vectorized item features

In [9]:
with open(args.item_metadata_pipeline_fp, "rb") as f:
    item_metadata_pipeline = dill.load(f)

## Load ANN Index

In [10]:
# if args.rc.use_sbert_features:
#     ann_index = AnnIndex(args.qdrant_url, args.qdrant_collection_name)
#     vector = ann_index.get_vector_by_ids([0])[0]
#     sbert_embedding_dim = vector.shape[0]
#     logger.info(f"{sbert_embedding_dim=}")
#     neighbors = ann_index.get_neighbors_by_ids([0])
#     display(neighbors)

# Prep data

In [11]:
train_df = pd.read_parquet("../data_for_ai/interim/train_sample_interactions_16407u_features_neg_seq_without_stats_item_user.parquet")
val_df = pd.read_parquet("../data_for_ai/interim/val_sample_interactions_16407u_features_neg_seq_without_stats_item_user.parquet")
idm_fp = "../data_for_ai/interim/idm_16407u.json"
idm = IDMapper().load(idm_fp)

assert (
    train_df[args.user_col].map(lambda s: idm.get_user_index(s))
    != train_df["user_indice"]
).sum() == 0, "Mismatch IDM"
assert (
    val_df[args.user_col].map(lambda s: idm.get_user_index(s)) != val_df["user_indice"]
).sum() == 0, "Mismatch IDM"

if args.rc.use_item_tags_from_llm:
    assert (
        "tags" in train_df.columns
    ), "There is no column `tags` in train_df, please make sure you have run notebook 002, 020 with RunCfg.use_item_tags_from_llm=True"

4817 items in the dataset


In [12]:
train_df.head()

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence,timestamp_unix,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
0,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B009RUZ7TS,0.0,2014-07-17 19:15:55.000,1412,4220,"[-1, -1, -1, -1, -1, -1, -1, 4559, 4443, 3164]",1405624555,"[-1, -1, -1, -1, -1, -1, -1, 1405624273, 14056...","[-1, -1, -1, -1, -1, -1, -1, 0, 0, 0]",All Electronics,"SanDisk 32GB Class 4 SDHC Memory Card, Frustra...",[],"[Electronics, Computers & Accessories, Compute...",
1,AF7KZV4NJ5GBDVFTB7PEEUN4U53A,B0BBMLD8QT,5.0,2015-07-29 20:38:06.000,4871,4476,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1924]",1438202286,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1436921997]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 4]",All Electronics,Logitech S150 USB Speakers with Digital Sound,"[There are plenty of speakers out there, with ...","[Electronics, Computers & Accessories, Compute...",10.78
2,AFVQ4K4KZPLQ3E2VFYSGX6HFXGNQ,B0BB6R89VF,0.0,2017-12-13 20:35:02.334,7616,1218,"[-1, -1, -1, -1, -1, -1, -1, 1293, 1728, 445]",1513197302,"[-1, -1, -1, -1, -1, -1, -1, 1427996903, 14279...","[-1, -1, -1, -1, -1, -1, -1, 6, 6, 6]",All Electronics,"Belkin Surge Protector Power Cube, Power Strip...",[The Belkin 6-Outlet Surge Protector Power Cub...,"[Electronics, Accessories & Supplies, Power St...",24.99
3,AFCLWJMGYFCOJQR7T4454OF5A5WA,B00ENFP224,5.0,2015-09-06 12:09:59.000,5250,1355,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]",1441541399,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]",Computers,(Old Model) Seagate 4TB Gaming SSHD(Solid Stat...,[<br><h3>Model</h3><strong>Brand: </strong>Sea...,"[Electronics, Computers & Accessories, Compute...",
4,AFP4PHJ6Q2RRXLDPSDSH6VXJRUTA,B07CMXS5FP,0.0,2018-11-23 09:44:21.734,6792,838,"[-1, -1, -1, 1055, 3572, 3865, 1761, 1591, 388...",1542966261,"[-1, -1, -1, 1403729520, 1447419458, 145791450...","[-1, -1, -1, 7, 7, 6, 6, 6, 6, 5]",Computers,A-Tech 1GB DDR 400MHz PC3200 184-pin DIMM Desk...,[],"[Electronics, Computers & Accessories, Compute...",24.97


In [13]:
print(train_df.shape)
train_df.head()

(254784, 15)


Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence,timestamp_unix,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
0,AEEV5YWQKPBTLFWHKOBBULYA2RDQ,B009RUZ7TS,0.0,2014-07-17 19:15:55.000,1412,4220,"[-1, -1, -1, -1, -1, -1, -1, 4559, 4443, 3164]",1405624555,"[-1, -1, -1, -1, -1, -1, -1, 1405624273, 14056...","[-1, -1, -1, -1, -1, -1, -1, 0, 0, 0]",All Electronics,"SanDisk 32GB Class 4 SDHC Memory Card, Frustra...",[],"[Electronics, Computers & Accessories, Compute...",
1,AF7KZV4NJ5GBDVFTB7PEEUN4U53A,B0BBMLD8QT,5.0,2015-07-29 20:38:06.000,4871,4476,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1924]",1438202286,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1436921997]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 4]",All Electronics,Logitech S150 USB Speakers with Digital Sound,"[There are plenty of speakers out there, with ...","[Electronics, Computers & Accessories, Compute...",10.78
2,AFVQ4K4KZPLQ3E2VFYSGX6HFXGNQ,B0BB6R89VF,0.0,2017-12-13 20:35:02.334,7616,1218,"[-1, -1, -1, -1, -1, -1, -1, 1293, 1728, 445]",1513197302,"[-1, -1, -1, -1, -1, -1, -1, 1427996903, 14279...","[-1, -1, -1, -1, -1, -1, -1, 6, 6, 6]",All Electronics,"Belkin Surge Protector Power Cube, Power Strip...",[The Belkin 6-Outlet Surge Protector Power Cub...,"[Electronics, Accessories & Supplies, Power St...",24.99
3,AFCLWJMGYFCOJQR7T4454OF5A5WA,B00ENFP224,5.0,2015-09-06 12:09:59.000,5250,1355,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]",1441541399,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]",Computers,(Old Model) Seagate 4TB Gaming SSHD(Solid Stat...,[<br><h3>Model</h3><strong>Brand: </strong>Sea...,"[Electronics, Computers & Accessories, Compute...",
4,AFP4PHJ6Q2RRXLDPSDSH6VXJRUTA,B07CMXS5FP,0.0,2018-11-23 09:44:21.734,6792,838,"[-1, -1, -1, 1055, 3572, 3865, 1761, 1591, 388...",1542966261,"[-1, -1, -1, 1403729520, 1447419458, 145791450...","[-1, -1, -1, 7, 7, 6, 6, 6, 6, 5]",Computers,A-Tech 1GB DDR 400MHz PC3200 184-pin DIMM Desk...,[],"[Electronics, Computers & Accessories, Compute...",24.97


In [14]:
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()

train_item_features = chunk_transform(
    train_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
train_item_features = train_item_features.astype(np.float32)

val_item_features = chunk_transform(
    val_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
val_item_features = val_item_features.astype(np.float32)

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

Transforming chunks:   0%|          | 0/26 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/1 [00:00<?, ?it/s]

[32m2025-07-01 01:11:41.501[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m14[0m - [1mlen(user_indices)=16,407, len(item_indices)=4,817[0m


In [15]:
val_item_features.shape

(6958, 1)

# Train

In [16]:
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)
val_rating_dataset = UserItemBinaryDFDataset(
    val_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=val_item_features,
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False
)

In [17]:
n_items = len(item_indices)
n_users = len(user_indices)
item_feature_size = train_item_features.shape[1]

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    args.dropout,
    use_item_feature=args.use_item_feature,
)
model.item_embedding.padding_idx

4817

In [18]:
for i in val_loader:
    print(i)
    break

{'user': tensor([ 2457, 13643,  7216,  7722,  1184,  2239, 10147,  6239,  1789,  3737,
         2513,  7386, 13819,   539,   975,  8896, 10435,  9564,  2021,  6275,
         8568,  2429,  2468,  2283,  8821, 12100,  1009, 14129, 10894, 16348,
         8238, 12733, 10063,  6119,  1463, 11479,  3905,  3181,  7304, 12115,
         1120, 15029,  8781, 13440,  1692,  5192,  4935,  2030,  1533,  6849,
        12712,  1701,  9927,  4590, 10816,  3638, 11302,  6273, 10123,  6910,
        10088,  3437,   226,  1163,  2844, 10497,  1051, 12811,  7674,  5777,
         4800, 10364,  4164,   554, 15401,  4802,  1318,   298,  2435,  9272,
          984,  1533, 10497,  7389,  6450, 12547,   466, 16139,  2094, 15010,
        13776,  4485, 15577,  2934,  6358,  2513,  4961, 10894, 10842, 14531,
        13682,    21,  5568,  9403,  8531,  2922,  7196,  1150, 14821,  2767,
        12339,  1789,   896,  5677,  4376,  9004,  9411,  5435,  7805, 10487,
        10932,   174, 10198,  4775,  6071,  8944,  7731

#### Predict before train

In [19]:
print(val_df.shape)
val_df.head()

(6958, 15)


Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence,timestamp_unix,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
0,AGMJWWTZ6HMM2FBRDLFW2CWMV5DQ,B00E0ISVLI,0.0,2021-07-18 15:44:29.739,10483,2563,"[-1, 2906, 3011, 4674, 4593, 4755, 3810, 3921,...",1626623069,"[-1, 1502521296, 1539601394, 1539993160, 15472...","[-1, 7, 6, 6, 6, 6, 6, 5, 5, 0]",Home Audio & Theater,Kaito KA500 5-way Powered Emergency AM/FM/SW N...,[],"[Electronics, Portable Audio & Video, Radios]",49.98
1,AE3XVOCHEO5MTDIAIET5BZS26AJA,B07GPGVYGX,0.0,2021-03-12 03:28:00.854,254,3381,"[-1, -1, -1, -1, 1188, 1510, 4399, 3089, 2290,...",1615519680,"[-1, -1, -1, -1, 1413239830, 1419709485, 14936...","[-1, -1, -1, -1, 8, 8, 7, 7, 7, 7]",All Electronics,Amazon Basics 8-Outlet Power Strip Surge Prote...,[],"[Electronics, Accessories & Supplies, Power St...",17.23
2,AESPJW3GNHXNJNW5CYV7PTEX44MQ,B07GZFM1ZM,0.0,2021-02-09 16:08:20.512,3190,921,"[-1, -1, -1, -1, -1, 2569, 2742, 2855, 2351, 346]",1612886900,"[-1, -1, -1, -1, -1, 1527447297, 1527447304, 1...","[-1, -1, -1, -1, -1, 6, 6, 6, 6, 6]",Amazon Devices,Fire TV Stick 4K streaming device with Alexa V...,[],[],
3,AE3HTD5GV52IDFUQ6MMXRNF4WDZQ,B09M3BZYVP,0.0,2021-03-30 11:48:08.855,181,971,"[-1, -1, -1, -1, -1, 1872, 1570, 2366, 3899, 3...",1617104888,"[-1, -1, -1, -1, -1, 1438559226, 1469567052, 1...","[-1, -1, -1, -1, -1, 8, 7, 7, 6, 6]",Computers,Seagate BarraCuda 4TB Internal Hard Drive HDD ...,"[Store more, compute faster, and do it confide...","[Electronics, Computers & Accessories, Compute...",67.99
4,AHTGQCLAFVD43IQ2AIERW2FQ7P4A,B00006JPE1,5.0,2021-02-10 14:43:48.128,15577,25,"[4166, 3089, 3074, 3443, 3227, 3493, 4466, 355...",1612968228,"[1507923309, 1510784806, 1513265514, 155008170...","[7, 7, 7, 6, 6, 6, 6, 6, 5, 0]",All Electronics,CHANNEL PLUS 2532 2-Way Splitter/Combiner CHAN...,[The model 2532 2-way splitter/combiner is a b...,"[Electronics, Television & Video, Accessories,...",


In [20]:
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence,timestamp_unix,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
5826,AGRWOSMEDY4NK3LICTJNSZ3CTYTA,B08D7638C8,1.0,2021-03-02 15:53:49.190,11178,3914,"[4366, 1479, 3206, 4423, 2418, 4107, 3785, 330...",1614700429,"[1573764914, 1573764979, 1582742685, 158799808...","[6, 6, 6, 5, 5, 5, 5, 5, 5, 5]",Computers,Samsung 970 EVO Plus SSD 2TB NVMe M.2 Internal...,[>For intensive workloads on PCs and workstati...,"[Electronics, Computers & Accessories, Data St...",114.99
3919,AE42UVGYZ3PMRSOQWSSCYNLZXKRQ,B07QMV54LY,0.0,2021-02-28 21:24:26.247,267,1852,"[3104, 2457, 2253, 3025, 3089, 3188, 3147, 323...",1614547466,"[1475241070, 1486080908, 1486081003, 151640326...","[7, 7, 7, 7, 7, 6, 6, 6, 6, 4]",All Electronics,XIRON [2 PACK] Paper Screen Protector Compatib...,[],"[Electronics, Computers & Accessories, Tablet ...",8.98
5961,AFRKNECR7GWCTPK5J54FQKWTV2NA,B07P9V8GSH,1.0,2021-08-17 02:04:14.209,7114,3585,"[-1, -1, -1, -1, 3221, 3493, 3245, 58, 2665, 4...",1629165854,"[-1, -1, -1, -1, 1417826500, 1430793725, 14307...","[-1, -1, -1, -1, 8, 8, 8, 8, 7, 7]",Computers,SanDisk Ultra 32GB UHS-I/Class 10 Micro SDHC M...,[SanDisk Ultra plus 32GB class 10 memory card ...,"[Electronics, Computers & Accessories, Compute...",8.29
2133,AEB2GBRQETUHD2W5DRQBSDAFGTVQ,B07GZFM1ZM,0.0,2021-07-22 19:12:50.377,929,2375,"[-1, 3279, 1647, 856, 1859, 2694, 1395, 3517, ...",1626981170,"[-1, 1364744454, 1496760660, 1496760740, 14967...","[-1, 8, 7, 7, 7, 7, 7, 6, 6, 6]",Amazon Devices,Fire TV Stick 4K streaming device with Alexa V...,[],[],
5066,AECNFQYIBHLAPCYARMOTD3PZMJKA,B0BQGMX5PJ,0.0,2021-12-15 19:53:26.964,1157,3214,"[-1, -1, -1, -1, 4599, 4730, 4709, 4311, 3360,...",1639598006,"[-1, -1, -1, -1, 1364069196, 1489268976, 15337...","[-1, -1, -1, -1, 8, 7, 7, 7, 6, 5]",Computers,Western Digital 8TB WD Red Plus NAS Internal H...,[Packed with power to handle the small- to med...,"[Electronics, Computers & Accessories, Data St...",149.99
4442,AGHCXQP3CLMP4K5DRAMJTVRXT5VQ,B0BTVN2YTV,0.0,2021-12-30 22:30:20.132,9836,3520,"[-1, -1, -1, 4504, 1781, 1433, 1434, 2886, 400...",1640903420,"[-1, -1, -1, 1520816751, 1525414800, 157176265...","[-1, -1, -1, 7, 7, 6, 6, 6, 6, 1]",Computers,Lexar Professional 1667x 64GB SDXC UHS-II Memo...,"[Whether you're a professional photographer, v...","[Electronics, Computers & Accessories, Compute...",29.99
2207,AGKZ7OXKTIXPKEDDKBUAKBD6OHRQ,B09S6Y5BRG,0.0,2022-01-10 22:37:05.953,10300,2519,"[-1, -1, -1, -1, 2694, 4304, 1866, 3074, 2811,...",1641854225,"[-1, -1, -1, -1, 1498939359, 1520349976, 15206...","[-1, -1, -1, -1, 7, 7, 7, 7, 7, 6]",All Electronics,Otium Bluetooth Earbuds Wireless Headphones Bl...,[],"[Electronics, Headphones, Earbuds & Accessorie...",19.99
115,AGUUWPWWPXT5A7TXKB7POFPI6G7A,B071JN4FW6,0.0,2021-03-31 00:40:23.540,11582,1373,"[-1, -1, -1, -1, 2858, 4471, 1791, 2475, 4031,...",1617151223,"[-1, -1, -1, -1, 1454037483, 1459833067, 14807...","[-1, -1, -1, -1, 8, 7, 7, 7, 5, 5]",All Electronics,DOSS SoundBox XL Bluetooth Speaker with Subwoo...,[],"[Electronics, Portable Audio & Video, Portable...",86.09
6466,AGCDEWAHYEFIWC2Q6GNQTPEZCOCA,B09V1FT19S,1.0,2021-04-27 23:54:07.446,9216,4313,"[-1, -1, -1, -1, -1, 2786, 3820, 2276, 4542, 3...",1619567647,"[-1, -1, -1, -1, -1, 1585864247, 1588274027, 1...","[-1, -1, -1, -1, -1, 6, 5, 5, 5, 5]",Computers,SanDisk 256GB Extreme microSDXC UHS-I Memory C...,[With the SanDisk Extreme 256GB microSD UHS-I ...,"[Electronics, Computers & Accessories, Compute...",28.99
4463,AHGX6VUMJZXM2EGN5PWBUJWQC32A,B07HL4PSH4,1.0,2021-01-24 03:03:45.954,13951,3465,"[-1, -1, -1, 198, 1570, 909, 2230, 4227, 1451,...",1611457425,"[-1, -1, -1, 1394897043, 1414690318, 141641750...","[-1, -1, -1, 8, 8, 8, 8, 7, 7, 7]",Amazon Devices,Kindle Paperwhite Leather Cover (10th Generati...,[],[],


In [21]:
user_id = val_df.sample(1)[args.user_col].values[0]
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence,timestamp_unix,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
3360,AE7BUUEUVNMJQNQ4KHO55A4QLERA,B08F1P3BCC,0.0,2021-06-26 20:24:08.969,703,561,"[-1, 2237, 2229, 3645, 2508, 4287, 1400, 3024, 1636, 4403]",1624739048,"[-1, 1472560432, 1472560660, 1482602111, 1486598946, 1521945475, 1525131934, 1549294790, 1560094375, 1590940338]","[-1, 7, 7, 7, 7, 7, 7, 6, 6, 6]",Amazon Devices,Echo Dot (4th Gen) | Smart speaker with Alexa | Twilight Blue,[],[],
5100,AE7BUUEUVNMJQNQ4KHO55A4QLERA,B08F1P3BCC,1.0,2021-06-26 20:24:08.969,703,3925,"[-1, 2237, 2229, 3645, 2508, 4287, 1400, 3024, 1636, 4403]",1624739048,"[-1, 1472560432, 1472560660, 1482602111, 1486598946, 1521945475, 1525131934, 1549294790, 1560094375, 1590940338]","[-1, 7, 7, 7, 7, 7, 7, 6, 6, 6]",Amazon Devices,Echo Dot (4th Gen) | Smart speaker with Alexa | Twilight Blue,[],[],


In [22]:
val_item_features.shape, train_item_features.shape

((6958, 1), (254784, 1))

In [23]:
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
item_sequence_ts_bucket = test_row["item_sequence_ts_bucket"]
row_idx = test_row.name
item_feature = val_item_features[row_idx]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
item_sequence = torch.tensor([item_sequence])
item_sequence_ts_bucket = torch.tensor([item_sequence_ts_bucket])
item_feature = torch.tensor([item_feature])
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
model.train()

[32m2025-07-01 01:11:42.129[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mTest predicting before training with user_id = AE7BUUEUVNMJQNQ4KHO55A4QLERA and parent_asin = B08F1P3BCC[0m

Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:254.)



Ranker(
  (item_embedding): Embedding(4818, 256, padding_idx=4817)
  (user_embedding): Embedding(16407, 256)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(272, 256, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=1, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=1024, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=256, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Training loop

##### Overfit 1 batch

In [24]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    dropout=args.dropout,
    use_item_feature=args.use_item_feature,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
    accelerator=args.device,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# # train model
# trainer = L.Trainer(
#     default_root_dir=log_dir,
#     accelerator=args.device if args.device else "auto",
#     max_epochs=2,
#     overfit_batches=1,
#     callbacks=[early_stopping],
# )
# trainer.fit(
#     model=lit_model,
#     train_dataloaders=train_loader,
#     val_dataloaders=val_loader,

# )
# logger.info(f"Logs available at {trainer.log_dir}")

In [25]:
# Need to make sure port 6006 at local is accessible
# %tensorboard --logdir $trainer.log_dir

##### Fit on all data

In [26]:
# print the number of rows in train_df that has rating = 0 and 1
print(
    f"Number of rows in train_df that has rating = 0: {train_df[train_df[args.rating_col] == 0.0].shape[0]}"
)
print(
    f"Number of rows in train_df that has rating = 1: {train_df[train_df[args.rating_col] >= 1.0].shape[0]}"
)
print(f"Number of rows in train_df: {train_df.shape[0]}")

Number of rows in train_df that has rating = 0: 127392
Number of rows in train_df that has rating = 1: 127392
Number of rows in train_df: 254784


In [27]:
# group by a specific user_id and all the rows for that user
user_id = "AF5KKBAOVY7J7LGPHAECKUTDQVTA"
user_df = train_df.loc[lambda df: df[args.user_col].eq(user_id)]
print(f"Number of rows for user {user_id}: {user_df.shape[0]}")
user_df = user_df.sort_values(by=args.timestamp_col, ascending=False)
user_df

Number of rows for user AF5KKBAOVY7J7LGPHAECKUTDQVTA: 28


Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence,timestamp_unix,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
141974,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B077V2BF3C,5.0,2020-01-16 15:17:23.469,4590,3145,"[1396, 511, 582, 3096, 3795, 1314, 1610, 2253,...",1579187843,"[1455485886, 1455486035, 1455486042, 149334639...","[7, 7, 7, 6, 6, 6, 6, 6, 6, 5]",Computers,"Moread HDMI to VGA, 2 Pack, Gold-Plated HDMI t...",[],"[Electronics, Computers & Accessories, Compute...",14.99
80879,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B077V2BF3C,0.0,2020-01-16 15:17:23.469,4590,1661,"[1396, 511, 582, 3096, 3795, 1314, 1610, 2253,...",1579187843,"[1455485886, 1455486035, 1455486042, 149334639...","[7, 7, 7, 6, 6, 6, 6, 6, 6, 5]",Computers,"Moread HDMI to VGA, 2 Pack, Gold-Plated HDMI t...",[],"[Electronics, Computers & Accessories, Compute...",14.99
47872,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B0779V61XB,0.0,2019-01-29 21:29:22.979,4590,3310,"[1333, 1396, 511, 582, 3096, 3795, 1314, 1610,...",1548797362,"[1455485483, 1455485886, 1455486035, 145548604...","[6, 6, 6, 6, 6, 6, 6, 6, 5, 5]",Computers,UGREEN SD Card Reader Portable USB 3.0 Dual Sl...,[],"[Electronics, Computers & Accessories, Compute...",12.99
90924,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B0779V61XB,5.0,2019-01-29 21:29:22.979,4590,3128,"[1333, 1396, 511, 582, 3096, 3795, 1314, 1610,...",1548797362,"[1455485483, 1455485886, 1455486035, 145548604...","[6, 6, 6, 6, 6, 6, 6, 6, 5, 5]",Computers,UGREEN SD Card Reader Portable USB 3.0 Dual Sl...,[],"[Electronics, Computers & Accessories, Compute...",12.99
205044,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B08CLNX58K,0.0,2018-11-16 22:40:45.180,4590,2054,"[3585, 1333, 1396, 511, 582, 3096, 3795, 1314,...",1542408045,"[1455485241, 1455485483, 1455485886, 145548603...","[6, 6, 6, 6, 6, 6, 6, 6, 5, 5]",Computers,SanDisk 32GB 2-Pack Ultra MicroSDHC UHS-I Memo...,[Transfer speeds of up to 98MB/sec . Records F...,"[Electronics, Computers & Accessories, Compute...",13.4
231845,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B08CLNX58K,5.0,2018-11-16 22:40:45.180,4590,3908,"[3585, 1333, 1396, 511, 582, 3096, 3795, 1314,...",1542408045,"[1455485241, 1455485483, 1455485886, 145548603...","[6, 6, 6, 6, 6, 6, 6, 6, 5, 5]",Computers,SanDisk 32GB 2-Pack Ultra MicroSDHC UHS-I Memo...,[Transfer speeds of up to 98MB/sec . Records F...,"[Electronics, Computers & Accessories, Compute...",13.4
43162,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B011BRUOMO,5.0,2018-02-16 03:17:37.395,4590,2253,"[150, 3585, 1333, 1396, 511, 582, 3096, 3795, ...",1518751057,"[1455485232, 1455485241, 1455485483, 145548588...","[6, 6, 6, 6, 6, 6, 5, 5, 5, 5]",Computers,SanDisk Ultra 32GB microSDHC UHS-I Card with A...,"[Capture, carry and keep more high-quality pho...","[Electronics, Computers & Accessories, Compute...",8.99
70579,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B011BRUOMO,0.0,2018-02-16 03:17:37.395,4590,1933,"[150, 3585, 1333, 1396, 511, 582, 3096, 3795, ...",1518751057,"[1455485232, 1455485241, 1455485483, 145548588...","[6, 6, 6, 6, 6, 6, 5, 5, 5, 5]",Computers,SanDisk Ultra 32GB microSDHC UHS-I Card with A...,"[Capture, carry and keep more high-quality pho...","[Electronics, Computers & Accessories, Compute...",8.99
27145,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B00JO6RO8C,5.0,2017-11-29 20:39:49.054,4590,1610,"[-1, 150, 3585, 1333, 1396, 511, 582, 3096, 37...",1511987989,"[-1, 1455485232, 1455485241, 1455485483, 14554...","[-1, 6, 6, 6, 6, 6, 6, 5, 4, 4]",Computers,SanDisk Cruzer Fit CZ33 32GB USB 2.0 Low-Profi...,"[With its low-profile design, the Cruzer Fit U...","[Electronics, Computers & Accessories, Data St...",21.5
235398,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B00JO6RO8C,0.0,2017-11-29 20:39:49.054,4590,3146,"[-1, 150, 3585, 1333, 1396, 511, 582, 3096, 37...",1511987989,"[-1, 1455485232, 1455485241, 1455485483, 14554...","[-1, 6, 6, 6, 6, 6, 6, 5, 4, 4]",Computers,SanDisk Cruzer Fit CZ33 32GB USB 2.0 Low-Profi...,"[With its low-profile design, the Cruzer Fit U...","[Electronics, Computers & Accessories, Data St...",21.5


In [28]:
# sort the train_df by timestamp and get the lastest item features from train_df
all_items_df = train_df.sort_values(by=args.timestamp_col, ascending=False)
# get the lastest item features from train_df
all_items_indices = all_items_df.drop_duplicates(subset=[args.item_col], keep="first")["item_indice"].values
all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
logger.info(
    f"Mean std over categorical and numerical features: {all_items_features.std(axis=0).mean()}"
)
# if args.rc.use_sbert_features:
#     all_sbert_vectors = ann_index.get_vector_by_ids(all_items_indices.tolist()).astype(
#         np.float32
#     )
#     logger.info(f"Mean std over text features: {all_sbert_vectors.std(axis=0).mean()}")
#     all_items_features = np.hstack([all_items_features, all_sbert_vectors])

[32m2025-07-01 01:11:42.725[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mMean std over categorical and numerical features: 0.8554474711418152[0m


In [29]:
# all_items_df = train_df.drop_duplicates(subset=["item_indice"])
# all_items_indices = all_items_df["item_indice"].values
# all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
# logger.info(
#     f"Mean std over categorical and numerical features: {all_items_features.std(axis=0).mean()}"
# )
# # if args.rc.use_sbert_features:
# #     all_sbert_vectors = ann_index.get_vector_by_ids(all_items_indices.tolist()).astype(
# #         np.float32
# #     )
# #     logger.info(f"Mean std over text features: {all_sbert_vectors.std(axis=0).mean()}")
# #     all_items_features = np.hstack([all_items_features, all_sbert_vectors])

In [30]:
# papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_roc_auc", patience=args.early_stopping_patience, mode="max", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_roc_auc",
    mode="max",
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    dropout=args.dropout,
    item_embedding=None,
    use_item_feature=args.use_item_feature,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    all_items_indices=all_items_indices,
    all_items_features=all_items_features,
    args=args,
    neg_to_pos_ratio=args.neg_to_pos_ratio,
    checkpoint_callback=checkpoint_callback,
    accelerator=args.device,
)

In [31]:
log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3050 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision

Checkpoint directory /home/dinhln/Desktop/real_time_recsys/notebooks/data/004-use-sbert-features-and-llm-tags/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type                   | Params | Mode 
----------------------------------------------------------------------
0 | model              | Ranker                 | 6.1 M  | train
1 | val_roc_auc_metric | BinaryAUROC            | 0      | train
2 | val_pr_auc_metric  | BinaryAveragePrecision | 0

Sanity Checking: |          | 0/? [00:00<?, ?it/s]


Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
[32m2025-07-01 01:12:26.166[0m | [1mINFO    [0m | [36msrc.algo.ranker.trainer[0m:[36mon_fit_end[0m:[36m199[0m - [1mLoading best model from /home/dinhln/Desktop/real_time_recsys/notebooks/data/004-use-sbert-features-and-llm-tags/checkpoints/best-checkpoint-v8.ckpt...[0m
[32m2025-07-01 01:12:26.315[0m | [1mINFO    [0m | [36msrc.algo.ranker.trainer[0m:[36mon_fit_end[0m:[36m206[0m - [1mLogging classification metrics...[0m
[32m2025-07-01 01:12:42.299[0m | [1mINFO    [0m | [36msrc.algo.ranker.trainer[0m:[36mon_fit_end[0m:[36m209[0m - [1mLogging ranking metrics...[0m


🏃 View run 004-use-sbert-features-and-llm-tags at: http://138.2.61.6:5002/#/experiments/12/runs/651210476fe94867b5e14d928a3f4415
🧪 View experiment at: http://138.2.61.6:5002/#/experiments/12


OutOfMemoryError: CUDA out of memory. Tried to allocate 892.00 MiB. GPU 0 has a total capacity of 3.69 GiB of which 821.44 MiB is free. Process 239034 has 264.00 MiB memory in use. Process 256669 has 264.00 MiB memory in use. Process 258769 has 2.02 GiB memory in use. Including non-PyTorch memory, this process has 326.00 MiB memory in use. Of the allocated memory 213.24 MiB is allocated by PyTorch, and 10.76 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# print the number of unique items in val_df
print(f"Number of unique items in val_df: {val_df['item_indice'].nunique()}")

In [None]:
# logger.info(
#     f"Test predicting after training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
# )
# model.eval()
# model = model.to(user.device)  # Move model back to align with data device
# model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
# model.train()

# Load best checkpoint

In [None]:
logger.info(f"Loading best checkpoint from {checkpoint_callback.best_model_path}...")
args.best_checkpoint_path = checkpoint_callback.best_model_path

best_trainer = LitRanker.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    model=init_model(
        n_users,
        n_items,
        args.embedding_dim,
        args.item_sequence_ts_bucket_size,
        args.bucket_embedding_dim,
        item_feature_size,
        dropout=0,
        item_embedding=pretrained_item_embedding,
        use_item_feature=args.use_item_feature,
    ),
)

In [None]:
# best_trainer = LitRanker.load_from_checkpoint(
#     "C:/Users/Trieu/Downloads/best-checkpoint (2).ckpt",
#     model=init_model(
#         n_users,
#         n_items,
#         args.embedding_dim,
#         args.item_sequence_ts_bucket_size,
#         args.bucket_embedding_dim,
#         item_feature_size,
#         dropout=0,
#         item_embedding=pretrained_item_embedding,
#     ),
# )

## testing after train

In [None]:
# user_id = val_df.sample(1)[args.user_col].values[0]
# test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
# # with pd.option_context("display.max_colwidth", None):
# #     display(test_df)
# test_df.shape

In [None]:
# test_row = test_df.loc[lambda df: df[args.rating_col].eq(0)].iloc[0]
# item_id = test_row[args.item_col]
# item_sequence = test_row["item_sequence"]
# item_sequence_ts_bucket = test_row["item_sequence_ts_bucket"]
# row_idx = test_row.name
# item_feature = val_item_features[row_idx]
# logger.info(
#     f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
# )
# user_indice = idm.get_user_index(user_id)
# item_indice = idm.get_item_index(item_id)
# user = torch.tensor([user_indice])
# item_sequence = torch.tensor([item_sequence])
# item_sequence_ts_bucket = torch.tensor([item_sequence_ts_bucket])
# item_feature = torch.tensor([item_feature])
# item = torch.tensor([item_indice])

# # print the information of the user and item and rating we are testing as a row in dataframe
# user_df = pd.DataFrame({
#     args.user_col: [user_id],
#     args.item_col: [item_id],
#     args.rating_col: [test_row[args.rating_col]],
#     "item_sequence": [item_sequence.tolist()],
#     "item_sequence_ts_bucket": [item_sequence_ts_bucket],
#     "item_feature": [item_feature.tolist()],
# })
# user_df


In [None]:
best_model = best_trainer.model.to(lit_model.device)

In [None]:
# best_model.eval()
# best_model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)

### Persist artifacts

In [None]:
if args.log_to_mlflow:
    # Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    mlf_client.log_artifact(run_id, idm_fp)
    # Persist item_feature_metadata pipeline
    mlf_client.log_artifact(run_id, args.item_metadata_pipeline_fp)

    # Persist model architecture
    model_architecture_fp = f"{args.notebook_persist_dp}/model_architecture.txt"
    with open(model_architecture_fp, "w") as f:
        f.write(repr(model))
    mlf_client.log_artifact(run_id, model_architecture_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [None]:
inferrer = RankerInferenceWrapper(best_model)

In [None]:
def generate_sample_item_features():
    sample_row = train_df.iloc[0].fillna(0)
    output = dict()
    for col in args.rc.item_feature_cols:
        v = sample_row[col]
        if isinstance(v, np.ndarray):
            v = "__".join(
                sample_row[col].tolist()
            )  # Workaround to avoid MLflow Got error: Per-column arrays must each be 1-dimensional
        output[col] = [v]
    return output

In [None]:
sample_input = {
    args.user_col: [idm.get_user_id(0)],
    "item_sequence": [",".join([idm.get_item_id(0), idm.get_item_id(1)])],
    "item_sequence_ts": [
        "1095133116,109770848"
    ],  # Here we input unix timestamp seconds instead of timestamp bucket because we need to calculate the bucket
    # **{col: [train_df.iloc[0].fillna(0)[col]] for col in args.item_feature_cols},
    **generate_sample_item_features(),
    args.item_col: [idm.get_item_id(0)],
}
sample_output = inferrer.infer([0], [[0, 1]], [[2, 0]], [train_item_features[0]], [0])
sample_output

In [None]:
sample_input

In [None]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    idm_filename = idm_fp.split("/")[-1]
    item_metadata_pipeline_filename = args.item_metadata_pipeline_fp.split("/")[-1]
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            artifacts={
                # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
                "idm": mlflow.get_artifact_uri(idm_filename),
                "item_metadata_pipeline": mlflow.get_artifact_uri(
                    item_metadata_pipeline_filename
                ),
            },
            model_config={"use_sbert_features": args.rc.use_sbert_features},
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )

# Set the newly trained model as champion

In [None]:
if args.log_to_mlflow:
    # Get current champion
    deploy_alias = "champion"
    curr_model_run_id = None

    min_roc_auc = args.min_roc_auc

    try:
        curr_champion_model = mlf_client.get_model_version_by_alias(
            args.mlf_model_name, deploy_alias
        )
        curr_model_run_id = curr_champion_model.run_id
    except MlflowException as e:
        if "not found" in str(e).lower():
            logger.info(
                f"There is no {deploy_alias} alias for model {args.mlf_model_name}"
            )

    # Compare new vs curr models
    new_mlf_run = trainer.logger.experiment.get_run(trainer.logger.run_id)
    new_metrics = new_mlf_run.data.metrics
    roc_auc = new_metrics["roc_auc"]
    if curr_model_run_id:
        curr_model_run_info = mlf_client.get_run(curr_model_run_id)
        curr_metrics = curr_model_run_info.data.metrics
        if (curr_roc_auc := curr_metrics["roc_auc"]) > min_roc_auc:
            logger.info(
                f"Current {deploy_alias} model has {curr_roc_auc:,.4f} ROC-AUC..."
            )
            min_roc_auc = curr_roc_auc

        top_metrics = ["roc_auc", "val_PersonalizationMetric"]
        vizer = ModelMetricsComparisonVisualizer(curr_metrics, new_metrics, top_metrics)
        print(f"Comparing metrics between new run and current champion:")
        display(vizer.compare_metrics_df())
        vizer.create_metrics_comparison_plot(n_cols=5)
        vizer.plot_diff()

    # Register new champion
    if roc_auc < min_roc_auc:
        logger.info(
            f"Current run has ROC-AUC = {roc_auc:,.4f}, smaller than {min_roc_auc:,.4f}. Skip aliasing this model as the new {deploy_alias}.."
        )
    else:
        logger.info(f"Aliasing the new model as champion...")
        # Get the model version for current run by assuming it's the most recent registered version
        model_version = (
            mlf_client.get_registered_model(args.mlf_model_name)
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value=args.author,
        )

# Clean up

In [None]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.dict()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)