# Ranker that can takes into accound different features

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import os
import sys

import dill
import lightning as L
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.exceptions import MlflowException
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from torch.utils.data import DataLoader

import mlflow

load_dotenv()

sys.path.insert(0, "..")

from cfg.run_cfg import RunCfg
# from src.ann import AnnIndex
from src.utils.data_prep import chunk_transform
from src.algo.ranker.dataset import UserItemBinaryDFDataset
from src.utils.embedding_id_mapper import IDMapper
from src.algo.ranker.inference import RankerInferenceWrapper
from src.algo.ranker.model import Ranker
from src.algo.ranker.trainer import LitRanker
from src.algo.item2vec.trainer import LitSkipGram
from src.algo.item2vec.model import SkipGram



# Controller

In [3]:
# This is a parameter cell used by papermill
max_epochs = 100

In [None]:
class Args(BaseModel):
    testing: bool = False
    author: str = "dinh-trieu"
    log_to_mlflow: bool = True
    experiment_name: str = "RecSys MVP - Ranker"
    run_name: str = "004-use-sbert-features-and-llm-tags"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    rc: RunCfg = RunCfg().init()

    item_metadata_pipeline_fp: str = "../data_for_ai/interim/item_metadata_pipeline.dill"
    qdrant_url: str = None
    qdrant_collection_name: str = "item_desc_sbert"

    max_epochs: int = max_epochs
    batch_size: int = 2
    tfm_chunk_size: int = 10000
    neg_to_pos_ratio: int = 1

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"

    top_K: int = 100
    top_k: int = 10

    embedding_dim: int = 256
    item_sequence_ts_bucket_size: int = 10
    bucket_embedding_dim: int = 16
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.001
    l2_reg: float = 1e-5

    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "ranker"
    min_roc_auc: float = 0.7

    best_checkpoint_path: str = None

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise Exception(f"Environment variable QDRANT_HOST is not set.")

        qdrant_port = os.getenv("QDRANT_PORT")
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        if self.device is None:
            self.device = (
                "cuda"
                if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_available() else "cpu"
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2025-06-29 14:55:26.118[0m | [34m[1mDEBUG   [0m | [36mcfg.run_cfg[0m:[36minit[0m:[36m31[0m - [34m[1mSetting use_sbert_features=True requires running notebook 016-sentence-transformers[0m
[32m2025-06-29 14:55:26.123[0m | [34m[1mDEBUG   [0m | [36mcfg.run_cfg[0m:[36minit[0m:[36m38[0m - [34m[1mChanging use_item_tags_from_llm requires re-running notebook 002-features-v2 to get the new item_metadata_pipeline.dill file[0m
[32m2025-06-29 14:55:26.123[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m61[0m - [1mSetting up MLflow experiment RecSys MVP - Ranker - run 004-use-sbert-features-and-llm-tags...[0m


{
  "testing": false,
  "author": "dinh-trieu",
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Ranker",
  "run_name": "004-use-sbert-features-and-llm-tags",
  "notebook_persist_dp": "c:\\Users\\Trieu\\OneDrive\\Desktop\\recsys\\real_time_recsys\\notebooks\\data\\004-use-sbert-features-and-llm-tags",
  "random_seed": 41,
  "device": "cpu",
  "rc": {
    "use_sbert_features": true,
    "use_item_tags_from_llm": false,
    "item_feature_cols": [
      "main_category",
      "categories",
      "price",
      "parent_asin_rating_cnt_365d",
      "parent_asin_rating_avg_prev_rating_365d",
      "parent_asin_rating_cnt_90d",
      "parent_asin_rating_avg_prev_rating_90d",
      "parent_asin_rating_cnt_30d",
      "parent_asin_rating_avg_prev_rating_30d",
      "parent_asin_rating_cnt_7d",
      "parent_asin_rating_avg_prev_rating_7d"
    ]
  },
  "item_metadata_pipeline_fp": "../data_for_ai/interim/item_metadata_pipeline.dill",
  "qdrant_url": "138.2.61.6:6333",
  "qdrant_collec

# Implement

In [5]:
def init_model(
    n_users,
    n_items,
    embedding_dim,
    item_sequence_ts_bucket_size,
    bucket_embedding_dim,
    item_feature_size,
    dropout,
    item_embedding=None,
):
    model = Ranker(
        n_users,
        n_items,
        embedding_dim,
        item_sequence_ts_bucket_size=item_sequence_ts_bucket_size,
        bucket_embedding_dim=bucket_embedding_dim,
        item_feature_size=item_feature_size,
        dropout=dropout,
        item_embedding=item_embedding,
    )
    return model

## Load pretrained Item2Vec embeddings

In [6]:
n_items = 4817  # This should be the number of unique items in your dataset
assert args.embedding_dim == 256, "Embedding dimension must be 256"
best_trainer = LitSkipGram.load_from_checkpoint(
    "../data_for_ai/interim/best-item2vec-weight.ckpt",
    skipgram_model=SkipGram(n_items, args.embedding_dim).to(args.device),
)
skipgram_item_embedding = best_trainer.skipgram_model.embeddings.weight.data.cpu()
print(f"SkipGram Item embedding shape: {skipgram_item_embedding.shape}")
print(f"SkipGram Item embedding dtype: {skipgram_item_embedding.dtype}")

# create a embedding layer with num_items + 1 embedding, then apply the pretrained weights
pretrained_item_embedding = torch.nn.Embedding(
    num_embeddings=n_items + 1,  # +1 for the unknown item (-1 padding)
    embedding_dim=args.embedding_dim,
    padding_idx=n_items,  # Set padding_idx to the last index
)
pretrained_item_embedding.weight.data[:n_items] = skipgram_item_embedding[:n_items]
pretrained_item_embedding.weight.data[n_items] = torch.zeros(
    args.embedding_dim, dtype=skipgram_item_embedding.dtype
)

[32m2025-06-29 14:55:26.972[0m | [1mINFO    [0m | [36msrc.algo.item2vec.model[0m:[36m__init__[0m:[36m12[0m - [1mInitializing item embeddings with num items 4817, embedding dim 256[0m


SkipGram Item embedding shape: torch.Size([4818, 256])
SkipGram Item embedding dtype: torch.float32



The loaded checkpoint was produced with Lightning v2.5.2, which is newer than your current Lightning version: v2.5.0



In [7]:
# mlf_client = mlflow.MlflowClient()
# model = mlflow.pyfunc.load_model(
#     model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
# )
# skipgram_model = model.unwrap_python_model().model
# embedding_0 = skipgram_model.embeddings(torch.tensor(0))
# embedding_dim = embedding_0.size()[0]
# id_mapping = model.unwrap_python_model().id_mapping
# pretrained_item_embedding = skipgram_model.embeddings

In [8]:
assert (
    pretrained_item_embedding.embedding_dim == args.embedding_dim
), "Mismatch pretrained item_embedding dimension"

## Load vectorized item features

In [9]:
with open(args.item_metadata_pipeline_fp, "rb") as f:
    item_metadata_pipeline = dill.load(f)

## Load ANN Index

In [10]:
# if args.rc.use_sbert_features:
#     ann_index = AnnIndex(args.qdrant_url, args.qdrant_collection_name)
#     vector = ann_index.get_vector_by_ids([0])[0]
#     sbert_embedding_dim = vector.shape[0]
#     logger.info(f"{sbert_embedding_dim=}")
#     neighbors = ann_index.get_neighbors_by_ids([0])
#     display(neighbors)

# Prep data

In [11]:
train_df = pd.read_parquet("../data_for_ai/interim/train_sample_interactions_16407u_features_neg_seq.parquet")
val_df = pd.read_parquet("../data_for_ai/interim/val_sample_interactions_16407u_features_neg_seq.parquet")
idm_fp = "../data_for_ai/interim/idm_16407u.json"
idm = IDMapper().load(idm_fp)

assert (
    train_df[args.user_col].map(lambda s: idm.get_user_index(s))
    != train_df["user_indice"]
).sum() == 0, "Mismatch IDM"
assert (
    val_df[args.user_col].map(lambda s: idm.get_user_index(s)) != val_df["user_indice"]
).sum() == 0, "Mismatch IDM"

if args.rc.use_item_tags_from_llm:
    assert (
        "tags" in train_df.columns
    ), "There is no column `tags` in train_df, please make sure you have run notebook 002, 020 with RunCfg.use_item_tags_from_llm=True"

4817 items in the dataset


In [12]:
print(train_df.shape)
train_df.head()

(254784, 27)


Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_indice,item_indice,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
0,AENOXSRSNC5VGY3JQKZQ5DD7HIUA,B00SG3CWGS,0.0,2017-06-10 00:30:32.698,1497054632,10.0,4.5,1.0,5.0,0.0,...,2546,4213,"[-1, -1, -1, -1, -1, -1, -1, -1, 218, 2648]","[-1, -1, -1, -1, -1, -1, -1, -1, 1457886402, 1...","[-1, -1, -1, -1, -1, -1, -1, -1, 6, 0]",Cell Phones & Accessories,Garmin Nuvi 67LMT 6-Inch GPS Navigator,"[With bright 6” dual-orientation displays, spo...","[Electronics, GPS, Finders & Accessories, Spor...",199.0
1,AEMPVT2U6BIHQDV52BDEDDKPH4HA,B01BCWKBZI,2.0,2017-08-03 00:40:30.172,1501720830,16.0,4.1875,3.0,4.333333,2.0,...,2416,2467,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]",Computers,Samsung T3 Portable SSD - 2TB - USB 3.1 Extern...,[Portability is the key element shared among a...,"[Electronics, Computers & Accessories, Data St...",348.69
2,AF3CKYP3BTJ7MEKU6J64BS57MQBA,B09BW3XJQV,0.0,2018-12-08 16:57:03.101,1544288223,5.0,3.4,1.0,4.0,1.0,...,4292,1208,"[-1, -1, -1, -1, 3541, 3089, 4168, 3936, 4066,...","[-1, -1, -1, -1, 1488569087, 1499723220, 15334...","[-1, -1, -1, -1, 6, 6, 5, 5, 4, 4]",Computers,ASUS AC1300 WiFi Router (RT-ACRH13) - Dual Ban...,[Upgrade to AC Wi-Fi for your bandwidth-hungry...,"[Electronics, Computers & Accessories, Network...",
3,AE7IGXXTK7XTWRJGLIAL5BJDTEAQ,B005L38VRU,5.0,2014-09-04 02:03:39.000,1409796219,5.0,4.6,1.0,5.0,0.0,...,728,689,"[-1, -1, -1, -1, -1, -1, 193, 3945, 1849, 4407]","[-1, -1, -1, -1, -1, -1, 1327177801, 133520743...","[-1, -1, -1, -1, -1, -1, 6, 6, 6, 5]",All Electronics,Logitech K750 Wireless Solar Keyboard for Mac ...,[Battery hassles are a thing of the past with ...,"[Electronics, Computers & Accessories, Compute...",49.99
4,AFEJ5GRYG2PQD6EWSAKVG56XMKNA,B001W6Q7SU,0.0,2016-09-14 16:29:39.000,1473870579,0.0,,0.0,,0.0,...,5481,834,"[-1, -1, -1, -1, -1, -1, -1, 3965, 4617, 2003]","[-1, -1, -1, -1, -1, -1, -1, 1473870313, 14738...","[-1, -1, -1, -1, -1, -1, -1, 0, 0, 0]",All Electronics,PNY Optima 2GB (2x1GB) Dual Channel Kit DDR2 6...,[PNY OPTIMA 2GB (2x1GB) Dual Channel Kit DDR2 ...,"[Electronics, Computers & Accessories, Compute...",65.99


In [13]:
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()
# if args.rc.use_sbert_features:
#     all_sbert_vectors = ann_index.get_vector_by_ids(
#         item_indices.tolist(), chunk_size=1000
#     ).astype(np.float32)

train_item_features = chunk_transform(
    train_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
train_item_features = train_item_features.astype(np.float32)

val_item_features = chunk_transform(
    val_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
val_item_features = val_item_features.astype(np.float32)

# if args.rc.use_sbert_features:
#     train_sbert_vectors = all_sbert_vectors[train_df["item_indice"].values]
#     train_item_features = np.hstack([train_item_features, train_sbert_vectors])
#     val_sbert_vectors = all_sbert_vectors[val_df["item_indice"].values]
#     val_item_features = np.hstack([val_item_features, val_sbert_vectors])

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

Transforming chunks:   0%|          | 0/26 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/1 [00:00<?, ?it/s]

[32m2025-06-29 14:55:49.504[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m24[0m - [1mlen(user_indices)=16,407, len(item_indices)=4,817[0m


# Train

In [14]:
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)
val_rating_dataset = UserItemBinaryDFDataset(
    val_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=val_item_features,
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [None]:
for i in train_loader:
    print(f"Batch {i['user_indice'].shape=}, {i['item_indice'].shape=}, {i['rating'].shape=}")
    print(f"{i['item_feature'].shape=}")
    break

In [15]:
n_items = len(item_indices)
n_users = len(user_indices)
item_feature_size = train_item_features.shape[1]

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    args.dropout,
)
model.item_embedding.padding_idx

4817

#### Predict before train

In [16]:
print(val_df.shape)
val_df.head()

(6958, 27)


Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_indice,item_indice,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
0,AGQFM7GX5UGRCK5F6EEGEEB25FKQ,B0C1J8RZ46,0.0,2021-06-02 04:46:55.478,1622609215,10.0,4.3,3.0,2.666667,1.0,...,10995,2570,"[-1, 1003, 3206, 4162, 3115, 4694, 1542, 4801,...","[-1, 1455812253, 1518767927, 1526164990, 15371...","[-1, 8, 7, 7, 6, 6, 6, 6, 6, 0]",Computers,TP-Link TL-SG108 | 8 Port Gigabit Unmanaged Et...,[TP-Link 8 Port Gigabit Ethernet Network Switch.],"[Electronics, Computers & Accessories, Network...",18.99
1,AGYITA5HB3G7B5UQIIYBVCPLRFVA,B09RS2KZK4,0.0,2022-01-20 17:33:34.286,1642700014,4.0,4.0,0.0,,0.0,...,12058,3350,"[4612, 2735, 154, 3987, 1287, 3811, 1991, 3585...","[1409842813, 1414422135, 1414422581, 141442273...","[8, 8, 8, 8, 8, 8, 8, 8, 6, 6]",Computers,"TP-Link USB WiFi Adapter for PC(TL-WN725N), N1...",[Maximum wireless transmission rates are the p...,"[Electronics, Computers & Accessories, Network...",9.99
2,AFKBE4VLE3XEQ5IHUZI2Q5KAKFCQ,B00PKTU83U,0.0,2021-08-15 20:42:07.057,1629060127,10.0,3.8,2.0,4.5,0.0,...,6169,922,"[2325, 746, 3957, 2556, 4445, 2612, 4650, 4157...","[1558980368, 1558980387, 1569515004, 157851333...","[6, 6, 6, 6, 6, 6, 6, 6, 6, 6]",Home Audio & Theater,Sony ZX Series Wired On-Ear Headphones with Mi...,[Form meets function with smartphone control f...,"[Electronics, Headphones, Earbuds & Accessorie...",18.25
3,AHKE5QFJM747GCEVBUF7QRH47HJA,B0BS1QXF6M,0.0,2021-07-16 00:07:10.093,1626394030,3.0,5.0,0.0,,0.0,...,14395,973,"[-1, -1, -1, -1, 4648, 651, 1232, 3311, 1754, ...","[-1, -1, -1, -1, 1561749246, 1561749319, 15645...","[-1, -1, -1, -1, 6, 6, 6, 6, 6, 5]",All Electronics,"iJoy Bluetooth Headphones Over Ear, Wireless a...",[],"[Electronics, Headphones, Earbuds & Accessorie...",24.99
4,AGLAJCRBNNNKVIASX3KKSULI2CFQ,B006R0VWSG,4.0,2022-01-21 18:46:08.771,1642790768,,,,,,...,10330,769,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]",Home Audio & Theater,Dim It light dimming sheets,[Package contains 2 (6 inch by 3 inch) static ...,[],9.99


In [17]:
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_indice,item_indice,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
6620,AGAU256T5BRSADFHWRAUEFGPXDOA,B07KGVB6D6,1.0,2021-04-12 19:10:17.091,1618254617,25.0,4.4,2.0,4.5,0.0,...,9039,3504,"[-1, 3186, 1423, 3957, 3025, 2481, 2678, 3074,...","[-1, 1553779161, 1553993949, 1553994093, 15539...","[-1, 6, 6, 6, 6, 6, 6, 6, 6, 6]",Amazon Devices,"Fire TV Cube, Hands-free streaming device with...",[],[],
4393,AHROZWP46O7EO3GK2ITTTJYXC2NA,B08F1P3BCC,0.0,2021-06-19 00:11:31.891,1624061491,57.0,4.333333,7.0,5.0,2.0,...,15344,2359,"[-1, 2091, 3861, 445, 838, 2084, 2112, 1791, 1...","[-1, 1357628230, 1370362186, 1387503843, 14050...","[-1, 8, 8, 8, 8, 8, 8, 8, 8, 7]",Amazon Devices,Echo Dot (4th Gen) | Smart speaker with Alexa ...,[],[],
6419,AGH4HYH2UCRNA2W5ENOAC64IYQLA,B00JX3Q28Y,0.0,2021-03-23 21:26:39.666,1616534799,2.0,3.5,0.0,,0.0,...,9811,883,"[4710, 3981, 3954, 3736, 3076, 4398, 1834, 202...","[1425554561, 1425554713, 1426617752, 142667512...","[8, 8, 8, 8, 8, 7, 7, 7, 7, 7]",Computers,Plugable USB 3.0 Sharing Switch for One-Button...,[],"[Electronics, Computers & Accessories, Compute...",
2319,AGL76ZOS5KCDI2A2NUGKAMHLKSVQ,B08WR4DVYV,0.0,2021-02-02 12:49:53.486,1612270193,2.0,3.0,1.0,5.0,0.0,...,10324,4653,"[2779, 2093, 1319, 1422, 4528, 1907, 2890, 212...","[1395766331, 1425913862, 1425914196, 148357311...","[8, 8, 8, 7, 7, 7, 7, 6, 6, 6]",Car Electronics,BOSS Audio Systems 625UAB Multimedia Car Stere...,[Pump out your digital music with the BOSS Aud...,"[Electronics, Car & Vehicle Electronics, Car E...",39.99
3472,AEAOEDTHXDFQRCJFTCOB7CXZDATQ,B0BWVDRR6L,1.0,2021-06-25 16:02:18.728,1624636938,5.0,3.6,1.0,5.0,1.0,...,875,4665,"[-1, -1, -1, 1702, 3668, 3519, 4307, 3805, 407...","[-1, -1, -1, 1515614605, 1519175181, 156528718...","[-1, -1, -1, 7, 7, 6, 6, 5, 5, 5]",Computers,Toshiba Canvio Basics 1TB Portable External Ha...,[Discover one of the easiest ways to free up s...,"[Electronics, Computers & Accessories, Data St...",45.99
1056,AFQIJGFKAVOPYUXMIEYITHHQALOQ,B08YF1VBYD,1.0,2022-01-06 19:49:22.760,1641498562,5.0,3.6,3.0,2.666667,2.0,...,6981,4063,"[-1, -1, -1, -1, 1440, 2452, 3296, 2520, 707, ...","[-1, -1, -1, -1, 1422067107, 1426897277, 14751...","[-1, -1, -1, -1, 8, 8, 8, 8, 7, 7]",All Electronics,"DOSS Bluetooth Speaker, SoundBox Touch Portabl...",[],"[Electronics, Portable Audio & Video, Portable...",29.99
6678,AE5KRVZLJZQ7DAARLXKQ53KIMVMQ,B01054S5FM,0.0,2021-06-08 22:59:57.487,1623193197,4.0,5.0,1.0,5.0,0.0,...,476,3512,"[89, 4311, 1042, 4013, 427, 1208, 3100, 1600, ...","[1365005458, 1419865413, 1419865518, 143213495...","[8, 8, 8, 8, 8, 8, 8, 8, 8, 6]",,"Screen Cleaner Kit - Best for LED & LCD TV, Co...",[],"[Electronics, Television & Video, Accessories,...",
2565,AFSLRFECOJQNSO6V5H6YU6OYMTDA,B0C22WW175,1.0,2021-10-01 20:05:24.685,1633118724,1.0,1.0,0.0,,0.0,...,7230,4714,"[-1, -1, -1, -1, -1, 625, 626, 4040, 4369, 1791]","[-1, -1, -1, -1, -1, 1430343180, 1430343212, 1...","[-1, -1, -1, -1, -1, 8, 8, 7, 7, 6]",Computers,Tenda Mesh WiFi System (MW6) - Up to 4000 Sq.F...,[The ultimate Wi-Fi experience: The mesh Wi-Fi...,"[Electronics, Computers & Accessories, Network...",77.99
4630,AE6M2ZOVQ72YONNCHTY46LTX7YYA,B00MARDJZ4,1.0,2021-03-11 21:00:07.812,1615496407,0.0,,0.0,,0.0,...,628,1754,"[804, 1572, 1653, 154, 1091, 1906, 3702, 3399,...","[1412646886, 1436908587, 1474059165, 148415613...","[8, 8, 7, 7, 7, 7, 7, 6, 6, 6]",Computers,CanaKit 5V 2.5A Raspberry Pi 3 B+ Power Supply...,[2.5A is now a requirement for the Raspberry P...,"[Electronics, Computers & Accessories, Compute...",9.95
110,AGF43FVZ4PIC3C2UQDAJ2IF66VHA,B00XNYXQHE,0.0,2022-01-29 01:22:12.127,1643419332,,,,,,...,9564,1102,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]",Camera & Photo,Plugable USB Digital Microscope with Flexible ...,[],"[Electronics, Camera & Photo, Binoculars & Sco...",31.5


In [18]:
user_id = val_df.sample(1)[args.user_col].values[0]
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_indice,item_indice,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
1250,AHGFAYBN2AD62SSMMCAFSWWPS46A,B07S764D9V,0.0,2021-10-18 02:33:38.428,1634524418,3.0,5.0,1.0,5.0,0.0,...,13866,4699,"[-1, -1, -1, -1, -1, 1744, 4176, 1639, 1819, 4086]","[-1, -1, -1, -1, -1, 1423437697, 1434671932, 1475371805, 1502936939, 1507689283]","[-1, -1, -1, -1, -1, 8, 8, 8, 7, 7]",Home Audio & Theater,"Panasonic ErgoFit Wired Earbuds, In-Ear Headphones with Microphone and Call Controller, Ergonomic Custom-Fit Earpieces (S/M/L), 3.5mm Jack for Phones and Laptops - RP-TCM125-A (Blue)","[The Panasonic RP-TCM125 ErgoFit Earbud Headphones with Microphone and Call Controller are the perfect combination of style, comfort, functionality and most of all, high-quality sound. The In-line microphone on the cord of earbud headphones is used for answering calls or voice commands; Compatible with iPhone, Android and Blackberry. Earbuds with three sets of earpads (S/M/L included), provide a custom, comfortable ergonomic fit that won’t slip out. Choose from eleven (11) vivid color options with color-matching earbuds, headphone cord and call controller to best complement your personal style and mood. Large 9mm drivers with neodymium magnets along with a wide frequency response and smart ergonomic fit earbuds deliver dynamic, crystal-clear sound while helping to keep out unwanted outside noise. Stereo headset tonally balanced audio with crisp highs and deep low notes, plus wider frequency response and lively sound quality for recorded audio. Long, 3. 6-ft headphone cord threads comfortably through clothing and bags making it easy to connect. DID YOU KNOW? You can use your headphones as a microphone! A microphone is the same as a speaker. The diaphragm of the driver of the headphones is moved by the molecules of air that make up a soundwave. This driver is attached to a coil of wire (voice coil) which rests between a magnets. DISCLAIMER: iPhone is a Trademark of Apple, registered in the US and other countries. The Trademark BlackBerry is owned by Research In Motion Limited and is registered in the United States and may be pending or registered in other countries. Panasonic is not endorsed, sponsored, affiliated with or otherwise authorized by Research In Motion Limited. Android is a trademark of Google requires lightning plug adaptor for iPhone 7 and later models (not included).]","[Electronics, Headphones, Earbuds & Accessories, Headphones & Earbuds, Earbud Headphones]",13.99
3883,AHGFAYBN2AD62SSMMCAFSWWPS46A,B07S764D9V,1.0,2021-10-18 02:33:38.428,1634524418,3.0,5.0,1.0,5.0,0.0,...,13866,3678,"[-1, -1, -1, -1, -1, 1744, 4176, 1639, 1819, 4086]","[-1, -1, -1, -1, -1, 1423437697, 1434671932, 1475371805, 1502936939, 1507689283]","[-1, -1, -1, -1, -1, 8, 8, 8, 7, 7]",Home Audio & Theater,"Panasonic ErgoFit Wired Earbuds, In-Ear Headphones with Microphone and Call Controller, Ergonomic Custom-Fit Earpieces (S/M/L), 3.5mm Jack for Phones and Laptops - RP-TCM125-A (Blue)","[The Panasonic RP-TCM125 ErgoFit Earbud Headphones with Microphone and Call Controller are the perfect combination of style, comfort, functionality and most of all, high-quality sound. The In-line microphone on the cord of earbud headphones is used for answering calls or voice commands; Compatible with iPhone, Android and Blackberry. Earbuds with three sets of earpads (S/M/L included), provide a custom, comfortable ergonomic fit that won’t slip out. Choose from eleven (11) vivid color options with color-matching earbuds, headphone cord and call controller to best complement your personal style and mood. Large 9mm drivers with neodymium magnets along with a wide frequency response and smart ergonomic fit earbuds deliver dynamic, crystal-clear sound while helping to keep out unwanted outside noise. Stereo headset tonally balanced audio with crisp highs and deep low notes, plus wider frequency response and lively sound quality for recorded audio. Long, 3. 6-ft headphone cord threads comfortably through clothing and bags making it easy to connect. DID YOU KNOW? You can use your headphones as a microphone! A microphone is the same as a speaker. The diaphragm of the driver of the headphones is moved by the molecules of air that make up a soundwave. This driver is attached to a coil of wire (voice coil) which rests between a magnets. DISCLAIMER: iPhone is a Trademark of Apple, registered in the US and other countries. The Trademark BlackBerry is owned by Research In Motion Limited and is registered in the United States and may be pending or registered in other countries. Panasonic is not endorsed, sponsored, affiliated with or otherwise authorized by Research In Motion Limited. Android is a trademark of Google requires lightning plug adaptor for iPhone 7 and later models (not included).]","[Electronics, Headphones, Earbuds & Accessories, Headphones & Earbuds, Earbud Headphones]",13.99


In [19]:
val_item_features.shape, train_item_features.shape

((6958, 626), (254784, 626))

In [20]:
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
item_sequence_ts_bucket = test_row["item_sequence_ts_bucket"]
row_idx = test_row.name
item_feature = val_item_features[row_idx]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
item_sequence = torch.tensor([item_sequence])
item_sequence_ts_bucket = torch.tensor([item_sequence_ts_bucket])
item_feature = torch.tensor([item_feature])
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
# model.train()

[32m2025-06-29 14:55:53.067[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mTest predicting before training with user_id = AHGFAYBN2AD62SSMMCAFSWWPS46A and parent_asin = B07S764D9V[0m

Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\utils\tensor_new.cpp:257.)



tensor([[0.5280]], grad_fn=<SigmoidBackward0>)

#### Training loop

##### Overfit 1 batch

In [21]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    dropout=args.dropout,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
    accelerator=args.device,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# # train model
# trainer = L.Trainer(
#     default_root_dir=log_dir,
#     accelerator=args.device if args.device else "auto",
#     max_epochs=2,
#     overfit_batches=1,
#     callbacks=[early_stopping],
# )
# trainer.fit(
#     model=lit_model,
#     train_dataloaders=train_loader,
#     val_dataloaders=val_loader,

# )
# logger.info(f"Logs available at {trainer.log_dir}")

In [22]:
# Need to make sure port 6006 at local is accessible
# %tensorboard --logdir $trainer.log_dir

##### Fit on all data

In [23]:
# print the number of rows in train_df that has rating = 0 and 1
print(
    f"Number of rows in train_df that has rating = 0: {train_df[train_df[args.rating_col] == 0.0].shape[0]}"
)
print(
    f"Number of rows in train_df that has rating = 1: {train_df[train_df[args.rating_col] >= 1.0].shape[0]}"
)
print(f"Number of rows in train_df: {train_df.shape[0]}")

Number of rows in train_df that has rating = 0: 127392
Number of rows in train_df that has rating = 1: 127392
Number of rows in train_df: 254784


In [24]:
# group by a specific user_id and all the rows for that user
user_id = "AF5KKBAOVY7J7LGPHAECKUTDQVTA"
user_df = train_df.loc[lambda df: df[args.user_col].eq(user_id)]
print(f"Number of rows for user {user_id}: {user_df.shape[0]}")
user_df = user_df.sort_values(by=args.timestamp_col, ascending=False)
user_df

Number of rows for user AF5KKBAOVY7J7LGPHAECKUTDQVTA: 28


Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_indice,item_indice,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
90079,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B077V2BF3C,0.0,2020-01-16 15:17:23.469,1579187843,12.0,4.583333,4.0,5.0,3.0,...,4590,1605,"[1396, 511, 582, 3096, 3795, 1314, 1610, 2253,...","[1455485886, 1455486035, 1455486042, 149334639...","[7, 7, 7, 6, 6, 6, 6, 6, 6, 5]",Computers,"Moread HDMI to VGA, 2 Pack, Gold-Plated HDMI t...",[],"[Electronics, Computers & Accessories, Compute...",14.99
146629,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B077V2BF3C,5.0,2020-01-16 15:17:23.469,1579187843,12.0,4.583333,4.0,5.0,3.0,...,4590,3145,"[1396, 511, 582, 3096, 3795, 1314, 1610, 2253,...","[1455485886, 1455486035, 1455486042, 149334639...","[7, 7, 7, 6, 6, 6, 6, 6, 6, 5]",Computers,"Moread HDMI to VGA, 2 Pack, Gold-Plated HDMI t...",[],"[Electronics, Computers & Accessories, Compute...",14.99
161061,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B0779V61XB,0.0,2019-01-29 21:29:22.979,1548797362,6.0,3.666667,4.0,3.0,1.0,...,4590,2586,"[1333, 1396, 511, 582, 3096, 3795, 1314, 1610,...","[1455485483, 1455485886, 1455486035, 145548604...","[6, 6, 6, 6, 6, 6, 6, 6, 5, 5]",Computers,UGREEN SD Card Reader Portable USB 3.0 Dual Sl...,[],"[Electronics, Computers & Accessories, Compute...",12.99
146024,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B0779V61XB,5.0,2019-01-29 21:29:22.979,1548797362,6.0,3.666667,4.0,3.0,1.0,...,4590,3128,"[1333, 1396, 511, 582, 3096, 3795, 1314, 1610,...","[1455485483, 1455485886, 1455486035, 145548604...","[6, 6, 6, 6, 6, 6, 6, 6, 5, 5]",Computers,UGREEN SD Card Reader Portable USB 3.0 Dual Sl...,[],"[Electronics, Computers & Accessories, Compute...",12.99
27042,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B08CLNX58K,5.0,2018-11-16 22:40:45.180,1542408045,77.0,4.571429,18.0,4.277778,2.0,...,4590,3908,"[3585, 1333, 1396, 511, 582, 3096, 3795, 1314,...","[1455485241, 1455485483, 1455485886, 145548603...","[6, 6, 6, 6, 6, 6, 6, 6, 5, 5]",Computers,SanDisk 32GB 2-Pack Ultra MicroSDHC UHS-I Memo...,[Transfer speeds of up to 98MB/sec . Records F...,"[Electronics, Computers & Accessories, Compute...",13.4
171506,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B08CLNX58K,0.0,2018-11-16 22:40:45.180,1542408045,77.0,4.571429,18.0,4.277778,2.0,...,4590,4251,"[3585, 1333, 1396, 511, 582, 3096, 3795, 1314,...","[1455485241, 1455485483, 1455485886, 145548603...","[6, 6, 6, 6, 6, 6, 6, 6, 5, 5]",Computers,SanDisk 32GB 2-Pack Ultra MicroSDHC UHS-I Memo...,[Transfer speeds of up to 98MB/sec . Records F...,"[Electronics, Computers & Accessories, Compute...",13.4
140755,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B011BRUOMO,0.0,2018-02-16 03:17:37.395,1518751057,174.0,4.712644,30.0,4.666667,7.0,...,4590,890,"[150, 3585, 1333, 1396, 511, 582, 3096, 3795, ...","[1455485232, 1455485241, 1455485483, 145548588...","[6, 6, 6, 6, 6, 6, 5, 5, 5, 5]",Computers,SanDisk Ultra 32GB microSDHC UHS-I Card with A...,"[Capture, carry and keep more high-quality pho...","[Electronics, Computers & Accessories, Compute...",8.99
141119,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B011BRUOMO,5.0,2018-02-16 03:17:37.395,1518751057,174.0,4.712644,30.0,4.666667,7.0,...,4590,2253,"[150, 3585, 1333, 1396, 511, 582, 3096, 3795, ...","[1455485232, 1455485241, 1455485483, 145548588...","[6, 6, 6, 6, 6, 6, 5, 5, 5, 5]",Computers,SanDisk Ultra 32GB microSDHC UHS-I Card with A...,"[Capture, carry and keep more high-quality pho...","[Electronics, Computers & Accessories, Compute...",8.99
248528,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B00JO6RO8C,0.0,2017-11-29 20:39:49.054,1511987989,24.0,4.833333,5.0,4.6,2.0,...,4590,1047,"[-1, 150, 3585, 1333, 1396, 511, 582, 3096, 37...","[-1, 1455485232, 1455485241, 1455485483, 14554...","[-1, 6, 6, 6, 6, 6, 6, 5, 4, 4]",Computers,SanDisk Cruzer Fit CZ33 32GB USB 2.0 Low-Profi...,"[With its low-profile design, the Cruzer Fit U...","[Electronics, Computers & Accessories, Data St...",21.5
64548,AF5KKBAOVY7J7LGPHAECKUTDQVTA,B00JO6RO8C,5.0,2017-11-29 20:39:49.054,1511987989,24.0,4.833333,5.0,4.6,2.0,...,4590,1610,"[-1, 150, 3585, 1333, 1396, 511, 582, 3096, 37...","[-1, 1455485232, 1455485241, 1455485483, 14554...","[-1, 6, 6, 6, 6, 6, 6, 5, 4, 4]",Computers,SanDisk Cruzer Fit CZ33 32GB USB 2.0 Low-Profi...,"[With its low-profile design, the Cruzer Fit U...","[Electronics, Computers & Accessories, Data St...",21.5


In [25]:
# sort the train_df by timestamp and get the lastest item features from train_df
all_items_df = train_df.sort_values(by=args.timestamp_col, ascending=False)
# get the lastest item features from train_df
all_items_indices = all_items_df.drop_duplicates(subset=[args.item_col], keep="first")["item_indice"].values
all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
logger.info(
    f"Mean std over categorical and numerical features: {all_items_features.std(axis=0).mean()}"
)
# if args.rc.use_sbert_features:
#     all_sbert_vectors = ann_index.get_vector_by_ids(all_items_indices.tolist()).astype(
#         np.float32
#     )
#     logger.info(f"Mean std over text features: {all_sbert_vectors.std(axis=0).mean()}")
#     all_items_features = np.hstack([all_items_features, all_sbert_vectors])

[32m2025-06-29 14:56:22.944[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mMean std over categorical and numerical features: 0.9046086072921753[0m


In [26]:
# all_items_df = train_df.drop_duplicates(subset=["item_indice"])
# all_items_indices = all_items_df["item_indice"].values
# all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
# logger.info(
#     f"Mean std over categorical and numerical features: {all_items_features.std(axis=0).mean()}"
# )
# # if args.rc.use_sbert_features:
# #     all_sbert_vectors = ann_index.get_vector_by_ids(all_items_indices.tolist()).astype(
# #         np.float32
# #     )
# #     logger.info(f"Mean std over text features: {all_sbert_vectors.std(axis=0).mean()}")
# #     all_items_features = np.hstack([all_items_features, all_sbert_vectors])

In [27]:
# papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_roc_auc", patience=args.early_stopping_patience, mode="max", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_roc_auc",
    mode="max",
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    dropout=args.dropout,
    item_embedding=pretrained_item_embedding,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    all_items_indices=all_items_indices,
    all_items_features=all_items_features,
    args=args,
    neg_to_pos_ratio=args.neg_to_pos_ratio,
    checkpoint_callback=checkpoint_callback,
    accelerator=args.device,
)

In [28]:
log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
# trainer.fit(
#     model=lit_model,
#     train_dataloaders=train_loader,
#     val_dataloaders=val_loader,
# )

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [29]:
# print the number of unique items in val_df
print(f"Number of unique items in val_df: {val_df['item_indice'].nunique()}")

Number of unique items in val_df: 2680


In [30]:
logger.info(
    f"Test predicting after training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
model.eval()
model = model.to(user.device)  # Move model back to align with data device
model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
model.train()

[32m2025-06-29 14:56:26.538[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mTest predicting after training with user_id = AF5KKBAOVY7J7LGPHAECKUTDQVTA and parent_asin = B07S764D9V[0m


Ranker(
  (item_embedding): Embedding(4818, 256, padding_idx=4817)
  (user_embedding): Embedding(16407, 256)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(272, 256, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=626, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=1024, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=256, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

# Load best checkpoint

In [31]:
# logger.info(f"Loading best checkpoint from {checkpoint_callback.best_model_path}...")
# args.best_checkpoint_path = checkpoint_callback.best_model_path

# best_trainer = LitRanker.load_from_checkpoint(
#     checkpoint_callback.best_model_path,
#     model=init_model(
#         n_users,
#         n_items,
#         args.embedding_dim,
#         args.item_sequence_ts_bucket_size,
#         args.bucket_embedding_dim,
#         item_feature_size,
#         dropout=0,
#         item_embedding=pretrained_item_embedding,
#     ),
# )

In [38]:
best_trainer = LitRanker.load_from_checkpoint(
    "C:/Users/Trieu/Downloads/best-checkpoint (2).ckpt",
    model=init_model(
        n_users,
        n_items,
        args.embedding_dim,
        args.item_sequence_ts_bucket_size,
        args.bucket_embedding_dim,
        item_feature_size,
        dropout=0,
        item_embedding=pretrained_item_embedding,
    ),
)

## testing after train

In [33]:
# user_id = val_df.sample(1)[args.user_col].values[0]
# test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
# # with pd.option_context("display.max_colwidth", None):
# #     display(test_df)
# test_df.shape

In [34]:
# test_row = test_df.loc[lambda df: df[args.rating_col].eq(0)].iloc[0]
# item_id = test_row[args.item_col]
# item_sequence = test_row["item_sequence"]
# item_sequence_ts_bucket = test_row["item_sequence_ts_bucket"]
# row_idx = test_row.name
# item_feature = val_item_features[row_idx]
# logger.info(
#     f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
# )
# user_indice = idm.get_user_index(user_id)
# item_indice = idm.get_item_index(item_id)
# user = torch.tensor([user_indice])
# item_sequence = torch.tensor([item_sequence])
# item_sequence_ts_bucket = torch.tensor([item_sequence_ts_bucket])
# item_feature = torch.tensor([item_feature])
# item = torch.tensor([item_indice])

# # print the information of the user and item and rating we are testing as a row in dataframe
# user_df = pd.DataFrame({
#     args.user_col: [user_id],
#     args.item_col: [item_id],
#     args.rating_col: [test_row[args.rating_col]],
#     "item_sequence": [item_sequence.tolist()],
#     "item_sequence_ts_bucket": [item_sequence_ts_bucket],
#     "item_feature": [item_feature.tolist()],
# })
# user_df


In [40]:
best_model = best_trainer.model.to(lit_model.device)

In [36]:
# best_model.eval()
# best_model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)

In [42]:
best_model._log_classification_metrics()

AttributeError: 'Ranker' object has no attribute '_log_classification_metrics'

### Persist artifacts

In [None]:
if args.log_to_mlflow:
    # Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    mlf_client.log_artifact(run_id, idm_fp)
    # Persist item_feature_metadata pipeline
    mlf_client.log_artifact(run_id, args.item_metadata_pipeline_fp)

    # Persist model architecture
    model_architecture_fp = f"{args.notebook_persist_dp}/model_architecture.txt"
    with open(model_architecture_fp, "w") as f:
        f.write(repr(model))
    mlf_client.log_artifact(run_id, model_architecture_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [None]:
inferrer = RankerInferenceWrapper(best_model)

In [None]:
def generate_sample_item_features():
    sample_row = train_df.iloc[0].fillna(0)
    output = dict()
    for col in args.rc.item_feature_cols:
        v = sample_row[col]
        if isinstance(v, np.ndarray):
            v = "__".join(
                sample_row[col].tolist()
            )  # Workaround to avoid MLflow Got error: Per-column arrays must each be 1-dimensional
        output[col] = [v]
    return output

In [None]:
sample_input = {
    args.user_col: [idm.get_user_id(0)],
    "item_sequence": [",".join([idm.get_item_id(0), idm.get_item_id(1)])],
    "item_sequence_ts": [
        "1095133116,109770848"
    ],  # Here we input unix timestamp seconds instead of timestamp bucket because we need to calculate the bucket
    # **{col: [train_df.iloc[0].fillna(0)[col]] for col in args.item_feature_cols},
    **generate_sample_item_features(),
    args.item_col: [idm.get_item_id(0)],
}
sample_output = inferrer.infer([0], [[0, 1]], [[2, 0]], [train_item_features[0]], [0])
sample_output

In [None]:
sample_input

In [None]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    idm_filename = idm_fp.split("/")[-1]
    item_metadata_pipeline_filename = args.item_metadata_pipeline_fp.split("/")[-1]
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            artifacts={
                # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
                "idm": mlflow.get_artifact_uri(idm_filename),
                "item_metadata_pipeline": mlflow.get_artifact_uri(
                    item_metadata_pipeline_filename
                ),
            },
            model_config={"use_sbert_features": args.rc.use_sbert_features},
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )

# Set the newly trained model as champion

In [None]:
if args.log_to_mlflow:
    # Get current champion
    deploy_alias = "champion"
    curr_model_run_id = None

    min_roc_auc = args.min_roc_auc

    try:
        curr_champion_model = mlf_client.get_model_version_by_alias(
            args.mlf_model_name, deploy_alias
        )
        curr_model_run_id = curr_champion_model.run_id
    except MlflowException as e:
        if "not found" in str(e).lower():
            logger.info(
                f"There is no {deploy_alias} alias for model {args.mlf_model_name}"
            )

    # Compare new vs curr models
    new_mlf_run = trainer.logger.experiment.get_run(trainer.logger.run_id)
    new_metrics = new_mlf_run.data.metrics
    roc_auc = new_metrics["roc_auc"]
    if curr_model_run_id:
        curr_model_run_info = mlf_client.get_run(curr_model_run_id)
        curr_metrics = curr_model_run_info.data.metrics
        if (curr_roc_auc := curr_metrics["roc_auc"]) > min_roc_auc:
            logger.info(
                f"Current {deploy_alias} model has {curr_roc_auc:,.4f} ROC-AUC..."
            )
            min_roc_auc = curr_roc_auc

        top_metrics = ["roc_auc", "val_PersonalizationMetric"]
        vizer = ModelMetricsComparisonVisualizer(curr_metrics, new_metrics, top_metrics)
        print(f"Comparing metrics between new run and current champion:")
        display(vizer.compare_metrics_df())
        vizer.create_metrics_comparison_plot(n_cols=5)
        vizer.plot_diff()

    # Register new champion
    if roc_auc < min_roc_auc:
        logger.info(
            f"Current run has ROC-AUC = {roc_auc:,.4f}, smaller than {min_roc_auc:,.4f}. Skip aliasing this model as the new {deploy_alias}.."
        )
    else:
        logger.info(f"Aliasing the new model as champion...")
        # Get the model version for current run by assuming it's the most recent registered version
        model_version = (
            mlf_client.get_registered_model(args.mlf_model_name)
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value=args.author,
        )

# Clean up

In [None]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.dict()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)