# Ranker that can takes into accound different features

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import os
import sys
from typing import List

import dill
import lightning as L
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from torch.utils.data import DataLoader

import mlflow

load_dotenv()

sys.path.insert(0, "..")

from src.data_prep_utils import chunk_transform
from src.dataset import UserItemBinaryDFDataset
from src.id_mapper import IDMapper
from src.ranker.inference import RankerInferenceWrapper
from src.ranker.model import Ranker
from src.ranker.trainer import LitRanker
from src.viz import blueq_colors

# Controller

In [3]:
max_epochs = 100

In [4]:
class Args(BaseModel):
    testing: bool = False
    log_to_mlflow: bool = True
    experiment_name: str = "RecSys MVP - Ranker"
    run_name: str = "001-add-item-categories"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    item_metadata_pipeline_fp: str = "../data/item_metadata_pipeline.dill"

    max_epochs: int = max_epochs
    batch_size: int = 128
    tfm_chunk_size: int = 10000

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"
    item_feature_cols: List[str] = ["main_category", "categories"]

    top_K: int = 100
    top_k: int = 10

    batch_size: int = 128

    embedding_dim: int = 128
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.001
    l2_reg: float = 1e-4

    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "ranker"
    min_roc_auc: float = 0.7

    best_checkpoint_path: str = None

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2024-10-26 16:40:25.300[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m50[0m - [1mSetting up MLflow experiment RecSys MVP - Ranker - run 001-add-item-categories...[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Ranker",
  "run_name": "001-add-item-categories",
  "notebook_persist_dp": "/Users/dvq/frostmourne/recsys-mvp/notebooks/data/001-add-item-categories",
  "random_seed": 41,
  "device": null,
  "item_metadata_pipeline_fp": "../data/item_metadata_pipeline.dill",
  "max_epochs": 100,
  "batch_size": 128,
  "tfm_chunk_size": 10000,
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "item_feature_cols": [
    "main_category",
    "categories"
  ],
  "top_K": 100,
  "top_k": 10,
  "embedding_dim": 128,
  "dropout": 0.3,
  "early_stopping_patience": 5,
  "learning_rate": 0.001,
  "l2_reg": 0.0001,
  "mlf_item2vec_model_name": "item2vec",
  "mlf_model_name": "ranker",
  "min_roc_auc": 0.7,
  "best_checkpoint_path": null
}


# Implement

In [5]:
def init_model(
    n_users, n_items, embedding_dim, item_feature_size, dropout, item_embedding=None
):
    model = Ranker(
        n_users,
        n_items,
        embedding_dim,
        item_feature_size=item_feature_size,
        dropout=dropout,
        item_embedding=item_embedding,
    )
    return model

## Load pretrained Item2Vec embeddings

In [6]:
mlf_client = mlflow.MlflowClient()
model = mlflow.pyfunc.load_model(
    model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
)
skipgram_model = model.unwrap_python_model().model
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.size()[0]
id_mapping = model.unwrap_python_model().id_mapping
pretrained_item_embedding = skipgram_model.embeddings

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [7]:
assert (
    pretrained_item_embedding.embedding_dim == args.embedding_dim
), "Mismatch pretrained item_embedding dimension"

## Load vectorized item features

In [8]:
with open(args.item_metadata_pipeline_fp, "rb") as f:
    item_metadata_pipeline = dill.load(f)

# Test implementation

In [9]:
embedding_dim = 8
batch_size = 2

# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 4, 5, 3, 0]
item_sequences = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 3],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 2, 1],
]
main_category = [
    "All Electronics",
    "Video Games",
    "All Electronics",
    "Video Games",
    "Unknown",
]
categories = [[], ["Headsets"], ["Video Games"], [], ["blah blah"]]
title = ["World of Warcraft", "DotA 2", "Diablo IV", "Football Manager 2024", "Unknown"]
description = [[], [], ["Video games blah blah"], [], ["blah blah"]]
price = ["from 14.99", "14.99", "price: 9.99", "20 dollars", "None"]

train_df = pd.DataFrame(
    {
        "user_indice": user_indices,
        "item_indice": item_indices,
        args.timestamp_col: timestamps,
        args.rating_col: ratings,
        "item_sequence": item_sequences,
        "main_category": main_category,
        "title": title,
        "description": description,
        "categories": categories,
        "price": price,
    }
)
train_item_features = item_metadata_pipeline.transform(train_df).astype(np.float32)

n_users = len(set(user_indices))
n_items = len(set(item_indices))
item_feature_size = train_item_features.shape[1]

model = init_model(n_users, n_items, embedding_dim, item_feature_size, args.dropout)

# Example forward pass
model.eval()
users = torch.tensor(user_indices)
items = torch.tensor(item_indices)
item_sequences = torch.tensor(item_sequences)
item_features = torch.tensor(train_item_features)
predictions = model.predict(users, item_sequences, item_features, items)
print(predictions)

tensor([[0.4384],
        [0.4400],
        [0.4097],
        [0.4688],
        [0.4615]], grad_fn=<SigmoidBackward0>)


In [10]:
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)

train_loader = DataLoader(rating_dataset, batch_size=batch_size, shuffle=False)

In [11]:
for batch_input in train_loader:
    print(batch_input)

{'user': tensor([0, 0]), 'item': tensor([0, 1]), 'rating': tensor([0., 1.]), 'item_sequence': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  3]]), 'item_feature': tensor([[-0.0147,  5.6424, -0.0147, -0.0255, -0.0147, -0.0147, -0.0208, -0.0821,
         -0.2873, -0.0147, -0.0389, -0.0294, -0.0255, -0.0147, -0.0147, -0.0389,
         -0.0208, -0.0389, -0.0510, -2.4504, -0.1343,  7.1430, -0.6101, -0.0821,
         -0.0794, -0.0208, -0.0147, -0.0147, -0.0510, -0.1326, -0.0255, -0.0147,
         -0.0910, -0.1359, -0.0147, -0.1486, -0.1045, -0.0147, -0.0147, -0.2057,
         -0.2260, -0.0294, -0.0208, -0.0147, -0.0208, -0.0147, -0.0294, -0.0737,
         -0.0208, -0.0208, -0.0147, -0.1034, -0.0208, -0.0208, -0.0389, -0.1034,
         -0.0551, -0.1248, -0.0675, -0.1239, -0.1638, -1.1327, -0.0208, -0.0208,
         -0.1300, -0.1359, -0.0329, -0.0147, -0.0147, -0.0208, -0.0147, -0.0208,
         -0.0465, -0.1836, -0.0910, -0.0255, -0.0208, -0.0389, -1.1646, -0.0691,
         -0.0551, -0.0147,

In [12]:
# model
lit_model = LitRanker(model, log_dir=args.notebook_persist_dp)

# train model
trainer = L.Trainer(
    default_root_dir=f"{args.notebook_persist_dp}/test",
    max_epochs=2,
    accelerator=args.device if args.device else "auto",
)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type   | Params | Mode
----------------------------------------
0 | model | Ranker | 1.9 K  | eval
----------------------------------------
1.9 K     Trainable params
0         Non-trainable params
1.9 K     Total params
0.008     Total estimated model params size (MB)
0         Modules in train mode
11        Modules in eval mode


Sanity Checking: |                                                                                            …

/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.
[32m2024-10-26 16:40:27.280[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m129[0m - [1mLogging classification metrics...[0m
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [13]:
# After fitting
model.eval()
predictions = model.predict(users, item_sequences, item_features, items)
print(predictions)

tensor([[0.4332],
        [0.4573],
        [0.4145],
        [0.4764],
        [0.4587]], grad_fn=<SigmoidBackward0>)


In [14]:
all_items_df = train_df.drop_duplicates(subset=["item_indice"])
all_items_indices = all_items_df["item_indice"].values
all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)

# Get the last row of each item as input for recommendations (containing the most updated item_sequence)
to_rec_df = train_df.sort_values(args.timestamp_col, ascending=False).drop_duplicates(
    subset=["user_indice"]
)
recommendations = model.recommend(
    torch.tensor(to_rec_df["user_indice"].values.tolist()),
    torch.tensor(to_rec_df["item_sequence"].values.tolist()),
    torch.tensor(all_items_features),
    torch.tensor(all_items_indices),
    k=2,
    batch_size=4,
)
recommendations

Generating recommendations:   0%|          | 0/1 [00:00<?, ?it/s]

{'user_indice': [2, 2, 1, 1, 0, 0],
 'recommendation': [1, 3, 1, 0, 1, 0],
 'score': [0.48788541555404663,
  0.47636258602142334,
  0.4562387466430664,
  0.43075892329216003,
  0.4572566747665405,
  0.43322157859802246]}

# Prep data

In [15]:
train_df = pd.read_parquet("../data/train_features_neg_df.parquet")
val_df = pd.read_parquet("../data/val_features_neg_df.parquet")
idm_fp = "../data/idm.json"
idm = IDMapper().load(idm_fp)

assert (
    train_df[args.user_col].map(lambda s: idm.get_user_index(s))
    != train_df["user_indice"]
).sum() == 0, "Mismatch IDM"
assert (
    val_df[args.user_col].map(lambda s: idm.get_user_index(s)) != val_df["user_indice"]
).sum() == 0, "Mismatch IDM"

In [16]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,categories,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
0,AG4SECUU2KPB4QPGSXIAJWCH7ZBQ,B07HGT7JC8,0.0,2010-05-19 09:25:09.000,10225,4002,Video Games,"[Video Games, Xbox One, Games]",1,,"B001E8WQKY,B00269QLI8,B001QCWRY8,B001EYUXWQ","[-1, -1, -1, -1, -1, -1, 722, 1181, 1124, 1030]"
1,AHVNPE7LYJL3EAUTJOKBIGJ6THTQ,B00IRHE892,0.0,2019-10-11 12:12:42.670,18909,2655,Computers,"[Video Games, PC, Accessories, Gaming Mice]",4,4.666667,"B0018RYC9Y,B00KCCNMYW,B00C2PJTKS,B000X37732,B0...","[679, 2702, 2370, 628, 685, 1234, 400, 1105, 1..."
2,AH53UXKY56CRB4JH6H4BWB2H7KJA,B00AZJ25Z4,0.0,2018-02-25 22:25:58.651,15099,2301,Video Games,"[Video Games, Legacy Systems, Xbox Systems, Xb...",2,5.000000,B00SZ1DQFM,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 2944]"
3,AFC5XTCF5D7J3NSDITB2Z26XWWYA,B001E8WQUY,5.0,2019-05-01 21:22:39.265,6190,724,Video Games,"[Video Games, Legacy Systems, Nintendo Systems...",1,,"B006HZA6VK,B0BN2FNKLM,B0086VPUHI,B0040UAYI4,B0...","[1987, 4569, 2114, 1606, 2159, 2279, 2447, 441..."
4,AF7LJQOIWF3Y3YD7SGOJ34MA5JPA,B001E8WQKY,5.0,2015-01-09 12:53:25.000,5792,722,Video Games,"[Video Games, Legacy Systems, Xbox Systems, Xb...",3,5.000000,"B00A2ML6XG,B003VUO6LU","[-1, -1, -1, -1, -1, -1, -1, -1, 2261, 1579]"
...,...,...,...,...,...,...,...,...,...,...,...,...
328591,AG4GFMDMBE6FFJV3C6PQVHCQEPPQ,B00DJRLDMU,0.0,2019-11-04 04:57:24.372,10165,2448,Video Games,"[Video Games, Xbox One, Downloadable Content]",4,5.000000,"B0030MEITE,B00Z9TJBUW,B00CQ35C1Q","[-1, -1, -1, -1, -1, -1, -1, 1411, 3047, 2402]"
328592,AFBNG2LH66VMDJRSSXQLF7QXWYRQ,B004HILZUU,0.0,2017-06-08 11:49:11.000,6121,1695,Video Games,"[Video Games, Legacy Systems, PlayStation Syst...",1,,"B0BLFYF8K2,B00J34ZCWU,B07CRC2X77","[-1, -1, -1, -1, -1, -1, -1, 4561, 2665, 3905]"
328593,AHUJKRR2OSWGYIELO564WEEFUIPA,B01GWHPDEW,0.0,2002-06-14 16:40:12.000,18720,3368,Video Games,"[Video Games, Nintendo Switch, Accessories]",1,,,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]"
328594,AHAVA5VKMJ3OMOLGDZ3W45CKXEWA,B00KTORA0K,5.0,2019-05-25 04:03:51.505,15701,2726,Video Games,"[Video Games, Legacy Systems, Nintendo Systems...",2,5.000000,"B004AYCNR0,B007NUQICE,B000TYQL1O,B000SEU92W,B0...","[-1, -1, -1, 1657, 2074, 593, 583, 3715, 3448,..."


In [17]:
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()

train_item_features = chunk_transform(
    train_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
train_item_features = train_item_features.astype(np.float32)
val_item_features = chunk_transform(
    val_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
val_item_features = val_item_features.astype(np.float32)

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

Transforming chunks:   0%|          | 0/33 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/1 [00:00<?, ?it/s]

[32m2024-10-26 16:40:28.753[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1mlen(user_indices)=19,578, len(item_indices)=4,630[0m


# Train

In [18]:
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)
val_rating_dataset = UserItemBinaryDFDataset(
    val_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=val_item_features,
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [19]:
n_items = len(item_indices)
n_users = len(user_indices)

model = init_model(
    n_users, n_items, args.embedding_dim, item_feature_size, args.dropout
)
model

Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (gru): GRU(128, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (fc_rating): Sequential(
    (0): Linear(in_features=536, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Predict before train

In [20]:
model.item_embedding

Embedding(4631, 128, padding_idx=4630)

In [21]:
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,categories,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
1379,AFKYHTQPYWI34Q5GFNGWNEZUOBQA,B0C4KN63KM,1.0,2021-10-10 23:31:12.208,7472,4610,Computers,"[Video Games, PC, Accessories, Gaming Keyboards]",1,,"B003YMMMFM,B0029ZBZ0I,B002BSH82M,B0000ZUGZ4,B0...","[-1, -1, -1, -1, -1, 1591, 1200, 1226, 306, 926]"
1176,AHI7TQQDFYVQ7Q5AQE4WEUUYQAYA,B002I0K622,0.0,2022-05-13 07:46:17.073,16832,1320,Video Games,"[Video Games, Legacy Systems, PlayStation Syst...",2,5.0,"B00DDILSBG,B07DKYN13M,B07895QZBF,B07WLT1C27,B0...","[-1, -1, -1, -1, 2441, 3942, 3828, 4212, 2442,..."
980,AF54JR3WONKVAUZUYDLOOPZN7NFQ,B0B1N7619L,1.0,2022-02-16 15:46:03.730,5408,4521,Video Games,"[Video Games, Legacy Systems, PlayStation Syst...",1,,"B01IFJEWTM,B0044XU27A,B07PFT19MG,B0181R6WUA,B0...","[-1, -1, -1, 3405, 1630, 4131, 3217, 4341, 420..."
1403,AHSSHPP7VQRWVDRV2SX3DGX5MELA,B07XV4NHHN,1.0,2022-01-20 21:00:29.759,18433,4267,Video Games,"[Video Games, Nintendo Switch, Accessories]",1,,"B000N5Z2L4,B087NNZZM8,B087NNPYP3,B087SHFL9B,B0...","[-1, -1, -1, -1, 526, 4343, 4342, 4344, 4203, ..."
1420,AG4ZLBG3I7CCMT3GUIZIBJ3P3SZA,B077GG9D5D,1.0,2021-09-03 22:26:42.967,10257,3813,Video Games,"[Video Games, PlayStation 4, Accessories, Cont...",1,,"B07NQV3X7K,B00NE5D4RE,B00EQNP8F4,B001H4NMNA,B0...","[-1, -1, -1, 4098, 2829, 2500, 1092, 2867, 142..."
139,AG6RPZLSVKORZN6ENXVURLOQY65Q,B00KWOSQC8,0.0,2021-09-23 21:16:15.379,10529,2746,Video Games,"[Video Games, Legacy Systems, Nintendo Systems...",1,,"B079ZGRDZH,B013HMN66M,B08TG138F1,B01LC9A6M4,B0...","[-1, -1, -1, -1, 3859, 3126, 4433, 3440, 4270,..."
1263,AGY6NCPSGMS4H2CITPJ2PD6KHG4Q,B0BL65X86R,1.0,2021-08-27 00:54:06.913,14354,4559,Video Games,"[Video Games, Online Game Services, PlayStatio...",1,,"B071WNSD3Q,B0118YZG0A,B00K32USMU,B014R4KYMS,B0...","[-1, -1, -1, 3660, 3100, 2696, 3151, 4168, 339..."
378,AEFPUTHKDMUQHHAB5BAZITHEUD3Q,B004Q8N7IE,0.0,2022-01-09 11:02:35.024,1765,1753,Video Games,"[Video Games, Legacy Systems, Xbox Systems, Xb...",1,,"B0009351PM,B07X1HF3V6,B07PZ8NZSZ,B0030JTD7O,B0...","[370, 4237, 4134, 1410, 4259, 4244, 3706, 3012..."
503,AHB2ZFDVYR7PQAJQL5Y7NHYLZQ6Q,B07SRWRH5D,0.0,2022-01-18 20:47:52.177,15726,4167,Video Games,"[Video Games, PlayStation 4, Games]",1,,"B002ST7AEU,B0721GGGS9,B01NBDG49R,B01N1P0KFI,B0...","[-1, -1, -1, -1, 1390, 3676, 3552, 3517, 1259,..."
1286,AGCACN7WBQ7BMZMNETXJGQ5TX42A,B0BVVTQ5JP,1.0,2022-06-08 01:33:14.208,11010,4584,Computers,"[Video Games, PC, Accessories, Gaming Mice]",1,,"B004G5YI3U,B0032C9V7G,B07YBXFDYN,B07YBX6T95,B0...","[-1, -1, -1, -1, -1, 1682, 1419, 4282, 4273, 4..."


In [22]:
user_id = val_df.sample(1)[args.user_col].values[0]
# user_id = "AH4AOFTTDPHPAFAAVFMAF25H2LIQ"
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,categories,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
1656,AHOAOZCYMMMWYAKWISZRPQOKMGSA,B0C39GFK7P,1.0,2021-11-08 03:41:35.391,17762,4604,Computers,"[Video Games, PC, Accessories, Gaming Mice]",1,,"B00B67ZS3U,B0053BCML6,B008HPAXZ2,B00G2EVF3E,B00FE8WKMO,B07NQTN66P,B006JKASAC,B017C6OK7S,B00DC7G0GG,B00KWIYPZG","[2316, 1874, 2157, 2552, 2529, 4097, 1998, 3199, 2435, 2745]"
1863,AHOAOZCYMMMWYAKWISZRPQOKMGSA,B003TLNHCK,0.0,2021-11-08 03:41:35.391,17762,1560,Video Games,"[Video Games, Legacy Systems, PlayStation Systems, PlayStation 3, Games]",1,,"B00B67ZS3U,B0053BCML6,B008HPAXZ2,B00G2EVF3E,B00FE8WKMO,B07NQTN66P,B006JKASAC,B017C6OK7S,B00DC7G0GG,B00KWIYPZG","[2316, 1874, 2157, 2552, 2529, 4097, 1998, 3199, 2435, 2745]"


In [23]:
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
row_idx = test_row.name
item_feature = val_item_features[row_idx]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
item_sequence = torch.tensor([item_sequence])
item_feature = torch.tensor([item_feature])
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item_feature, item)

[32m2024-10-26 16:40:28.929[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mTest predicting before training with user_id = AHOAOZCYMMMWYAKWISZRPQOKMGSA and parent_asin = B0C39GFK7P[0m

Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:281.)



tensor([[0.5151]], grad_fn=<SigmoidBackward0>)

#### Training loop

##### Overfit 1 batch

In [24]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

model = init_model(n_users, n_items, args.embedding_dim, item_feature_size, dropout=0)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 3.3 M  | train
-----------------------------------------
3.3 M     Trainable params
0         Non-trainable params
3.3 M     Total params
13.068    Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=100` reached.
[32m2024-10-26 16:40:36.131[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m129[0m - [1mLogging classification metrics...[0m
[32m2024-10-26 16:40:54.266[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1mLogs available at /Users/dvq/frostmourne/recsys-mvp/notebooks/data/001-add-item-categories/logs/overfit/lightning_logs/version_1[0m


In [25]:
%tensorboard --logdir $trainer.log_dir

##### Fit on all data

In [26]:
# papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_loss", patience=args.early_stopping_patience, mode="min", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    item_feature_size,
    dropout=args.dropout,
    item_embedding=pretrained_item_embedding,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    item_metadata_pipeline=item_metadata_pipeline,
    args=args,
)

log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Checkpoint directory /Users/dvq/frostmourne/recsys-mvp/notebooks/data/001-add-item-categories/checkpoints exists and is not empty.


  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 3.3 M  | train
-----------------------------------------
3.3 M     Trainable params
0         Non-trainable params
3.3 M     Total params
13.068    Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2024-10-26 16:49:15.777[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m129[0m - [1mLogging classification metrics...[0m
[32m2024-10-26 16:49:16.411[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m132[0m - [1mLogging ranking metrics...[0m


Generating recommendations:   0%|          | 0/177 [00:00<?, ?it/s]

2024/10/26 16:50:03 INFO mlflow.tracking._tracking_service.client: 🏃 View run 001-add-item-categories at: http://localhost:5002/#/experiments/3/runs/2a3649d8aa3644408d996a0643087aed.
2024/10/26 16:50:03 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.


In [27]:
logger.info(
    f"Test predicting after training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
model.eval()
model.predict(user, item_sequence, item_feature, item)

[32m2024-10-26 16:50:03.324[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mTest predicting after training with user_id = AHOAOZCYMMMWYAKWISZRPQOKMGSA and parent_asin = B0C39GFK7P[0m


tensor([[0.0539]], grad_fn=<SigmoidBackward0>)

# Load best checkpoint

In [28]:
logger.info(f"Loading best checkpoint from {checkpoint_callback.best_model_path}...")
args.best_checkpoint_path = checkpoint_callback.best_model_path

best_trainer = LitRanker.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    model=init_model(
        n_users, n_items, args.embedding_dim, item_feature_size, dropout=0
    ),
)

[32m2024-10-26 16:50:03.349[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mLoading best checkpoint from /Users/dvq/frostmourne/recsys-mvp/notebooks/data/001-add-item-categories/checkpoints/best-checkpoint-v1.ckpt...[0m


In [29]:
best_model = best_trainer.model.to(lit_model.device)

In [30]:
best_model.eval()
best_model.predict(user, item_sequence, item_feature, item)

tensor([[0.1139]], grad_fn=<SigmoidBackward0>)

### Persist id mapping

In [31]:
if args.log_to_mlflow:
    # Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    mlf_client.log_artifact(run_id, idm_fp)
    # Persist item_feature_metadata pipeline
    mlf_client.log_artifact(run_id, args.item_metadata_pipeline_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [32]:
inferrer = RankerInferenceWrapper(best_model)

In [33]:
sample_input = {
    "user_ids": [idm.get_user_id(0)],
    "item_sequences": [[idm.get_item_id(0), idm.get_item_id(1)]],
    **{col: [train_df[col].iloc[0]] for col in args.item_feature_cols},
    "item_ids": [idm.get_item_id(0)],
}
sample_output = inferrer.infer([0], [[0, 1]], [train_item_features[0]], [0])
sample_output

array([0.5342971], dtype=float32)

In [34]:
sample_input

{'user_ids': ['AE225O22SA7DLBOGOEIFL7FT5VYQ'],
 'item_sequences': [['0375869026', '9625990674']],
 'main_category': ['Video Games'],
 'categories': [array(['Video Games', 'Xbox One', 'Games'], dtype=object)],
 'item_ids': ['0375869026']}

In [35]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    idm_filename = idm_fp.split("/")[-1]
    item_metadata_pipeline_filename = args.item_metadata_pipeline_fp.split("/")[-1]
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
            artifacts={
                "idm": mlflow.get_artifact_uri(idm_filename),
                "item_metadata_pipeline": mlflow.get_artifact_uri(
                    item_metadata_pipeline_filename
                ),
            },
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )


Since MLflow 2.16.0, we no longer convert dictionary input example to pandas Dataframe, and directly save it as a json object. If the model expects a pandas DataFrame input instead, please pass the pandas DataFrame as input example directly.



Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Registered model 'ranker' already exists. Creating a new version of this model...
2024/10/26 16:50:11 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: ranker, version 7
Created version '7' of model 'ranker'.


Downloading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

  "inputs": {
    "user_ids": [
      "AE225O22SA7DLBOGOEIFL7FT5VYQ"
    ],
    "item_sequences": [
      [
        "0375869026",
        "9625990674"
      ]
    ],
    "main_category": [
      "Video Games"
    ],
    "categories": [
      [
        "Video Games",
        "Xbox One",
        "Games"
      ]
    ],
    "item_ids": [
      "0375869026"
    ]
  }
}. Alternatively, you can avoid passing input example and pass model signature instead when logging the model. To ensure the input example is valid prior to serving, please try calling `mlflow.models.validate_serving_input` on the model uri and serving input example. A serving input example can be generated from model input example using `mlflow.models.convert_input_example_to_serving_input` function.
Got error: Per-column arrays must each be 1-dimensional
2024/10/26 16:50:11 INFO mlflow.tracking._tracking_service.client: 🏃 View run 001-add-item-categories at: http://localhost:5002/#/experiments/3/runs/2a3649d8aa3644408d996a064

# Set the newly trained model as champion

In [36]:
if args.log_to_mlflow:
    val_roc_auc = trainer.logger.experiment.get_run(trainer.logger.run_id).data.metrics[
        "val_roc_auc"
    ]

    if val_roc_auc > args.min_roc_auc:
        logger.info(f"Aliasing the new model as champion...")
        model_version = (
            mlf_client.get_registered_model(args.mlf_model_name)
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value="quy.dinh",
        )

[32m2024-10-26 16:50:11.847[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mAliasing the new model as champion...[0m


# Clean up

In [37]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.dict()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)

2024/10/26 16:50:11 INFO mlflow.tracking._tracking_service.client: 🏃 View run 001-add-item-categories at: http://localhost:5002/#/experiments/3/runs/2a3649d8aa3644408d996a0643087aed.
2024/10/26 16:50:11 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.
