# Ranker that can take into account different features

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import os
import sys
from typing import List, Optional

import dill
import lightning as L
import mlflow
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from qdrant_client import QdrantClient
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

load_dotenv()

sys.path.insert(0, "..")

from src.data_prep_utils import chunk_transform
from src.dataset import UserItemBinaryDFDataset
from src.id_mapper import IDMapper
from src.ranker.inference import RankerInferenceWrapper
from src.ranker.model import Ranker
from src.ranker.trainer import LitRanker
from src.viz import blueq_colors

# Controller

In [3]:
# Module-level epoch cap; read as the default value of Args.max_epochs.
max_epochs = 100

In [4]:
class Args(BaseModel):
    """Configuration for one ranker training run.

    Fields that default to ``None`` (``notebook_persist_dp``, ``qdrant_url``) are
    resolved by :meth:`init` from the run name and from environment variables.
    """

    testing: bool = False
    log_to_mlflow: bool = True
    experiment_name: str = "RecSys MVP - Ranker"
    run_name: str = "012-3-neg-samples-and-user-item-out-product"
    # Absolute output directory of this run; set by init().
    notebook_persist_dp: Optional[str] = None
    random_seed: int = 41
    # Lightning accelerator name; when None the trainers in this notebook pass "auto".
    device: Optional[str] = None

    item_metadata_pipeline_fp: str = "../data/item_metadata_pipeline.dill"
    # "<QDRANT_HOST>:<QDRANT_PORT>"; set by init().
    qdrant_url: Optional[str] = None
    qdrant_collection_name: str = "item_desc_sbert"

    max_epochs: int = max_epochs
    batch_size: int = 128
    tfm_chunk_size: int = 10000
    neg_to_pos_ratio: int = 3

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"
    item_feature_cols: List[str] = ["main_category", "categories"]

    top_K: int = 100
    top_k: int = 10

    embedding_dim: int = 128
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.001
    l2_reg: float = 1e-4

    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "ranker"
    min_roc_auc: float = 0.7

    best_checkpoint_path: Optional[str] = None

    def init(self):
        """Resolve environment-dependent settings and return ``self``.

        Creates the run's output directory, builds the Qdrant URL from
        ``QDRANT_HOST`` / ``QDRANT_PORT``, and sets up an ``MLFlowLogger`` when
        ``MLFLOW_TRACKING_URI`` is available.

        Raises:
            Exception: if ``QDRANT_HOST`` or ``QDRANT_PORT`` is not set.
        """
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise Exception("Environment variable QDRANT_HOST is not set.")

        # Without this check a missing port would silently produce "<host>:None".
        if not (qdrant_port := os.getenv("QDRANT_PORT")):
            raise Exception("Environment variable QDRANT_PORT is not set.")
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                "Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            # Kept on an underscore-prefixed attribute rather than a declared field;
            # it only exists when MLflow logging is enabled.
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        return self


# Build the run configuration; init() fills the env-dependent fields and returns the instance.
args = Args().init()

# Echo the resolved configuration so it is captured in the notebook output.
print(args.model_dump_json(indent=2))

[32m2024-10-27 15:11:35.237[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m59[0m - [1mSetting up MLflow experiment RecSys MVP - Ranker - run 012-3-neg-samples-and-user-item-out-product...[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Ranker",
  "run_name": "012-3-neg-samples-and-user-item-out-product",
  "notebook_persist_dp": "/Users/dvq/frostmourne/recsys-mvp/notebooks/data/012-3-neg-samples-and-user-item-out-product",
  "random_seed": 41,
  "device": null,
  "item_metadata_pipeline_fp": "../data/item_metadata_pipeline.dill",
  "qdrant_url": "localhost:6333",
  "qdrant_collection_name": "item_desc_sbert",
  "max_epochs": 100,
  "batch_size": 128,
  "tfm_chunk_size": 10000,
  "neg_to_pos_ratio": 3,
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "item_feature_cols": [
    "main_category",
    "categories"
  ],
  "top_K": 100,
  "top_k": 10,
  "embedding_dim": 128,
  "dropout": 0.3,
  "early_stopping_patience": 5,
  "learning_rate": 0.001,
  "l2_reg": 0.0001,
  "mlf_item2vec_model_name": "item2vec",
  "mlf_model_name": "ranker",
  "min_roc_auc": 0.7,
  "best_checkpoint_

# Implement

In [5]:
def init_model(
    n_users, n_items, embedding_dim, item_feature_size, dropout, item_embedding=None
):
    """Construct a ``Ranker`` from size and hyper-parameter arguments.

    ``n_users``, ``n_items`` and ``embedding_dim`` are forwarded positionally;
    ``item_feature_size``, ``dropout`` and ``item_embedding`` are forwarded by
    keyword. Returns the newly created model.
    """
    keyword_options = {
        "item_feature_size": item_feature_size,
        "dropout": dropout,
        "item_embedding": item_embedding,
    }
    return Ranker(n_users, n_items, embedding_dim, **keyword_options)

## Load pretrained Item2Vec embeddings

In [6]:
# MLflow client handle (not used again within this cell).
mlf_client = mlflow.MlflowClient()
# Load the registered item2vec model through its "champion" alias.
model = mlflow.pyfunc.load_model(
    model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
)
# Reach through the pyfunc wrapper to the underlying model object.
skipgram_model = model.unwrap_python_model().model
# Embed index 0 once to read off the embedding width.
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.size()[0]
# Id mapping stored alongside the model (presumably item id -> index; verify against the item2vec wrapper).
id_mapping = model.unwrap_python_model().id_mapping
# The pretrained embedding layer itself.
pretrained_item_embedding = skipgram_model.embeddings

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [7]:
# Fail fast when the pretrained embedding width differs from the configured embedding_dim.
assert (
    pretrained_item_embedding.embedding_dim == args.embedding_dim
), "Mismatch pretrained item_embedding dimension"

## Load vectorized item features

In [8]:
# Load the persisted item-metadata pipeline (used below via .transform()).
# NOTE(review): dill deserialization can execute arbitrary code; only load files from a trusted source.
with open(args.item_metadata_pipeline_fp, "rb") as f:
    item_metadata_pipeline = dill.load(f)

## Load ANN Index

In [9]:
# Connect to Qdrant and stop immediately when the expected collection is absent.
ann_index = QdrantClient(url=args.qdrant_url)
if not ann_index.collection_exists(args.qdrant_collection_name):
    raise Exception(
        f"Required Qdrant collection {args.qdrant_collection_name} does not exist"
    )

In [10]:
def get_vector_by_ids(ids: List[int], chunk_size=100):
    """Fetch the stored vectors for ``ids`` from the Qdrant collection.

    Points are retrieved in chunks of ``chunk_size`` and then re-assembled by
    point id, so row ``i`` of the result always belongs to ``ids[i]`` regardless
    of the order in which the server returns the records. Duplicate ids yield
    duplicate rows.

    Args:
        ids: Point ids to look up.
        chunk_size: Maximum number of ids sent per ``retrieve`` call.

    Returns:
        ``np.ndarray`` with one row per entry of ``ids``.

    Raises:
        KeyError: if any requested id is not present in the collection
            (otherwise the rows would silently shift out of alignment).
    """
    vector_by_id = {}
    for start in tqdm(range(0, len(ids), chunk_size)):
        chunk_ids = ids[start : start + chunk_size]
        chunk_records = ann_index.retrieve(
            collection_name=args.qdrant_collection_name,
            ids=chunk_ids,
            with_vectors=True,
        )
        for record in chunk_records:
            vector_by_id[record.id] = record.vector

    missing_ids = [point_id for point_id in ids if point_id not in vector_by_id]
    if missing_ids:
        raise KeyError(
            f"{len(missing_ids)} id(s) not found in Qdrant collection "
            f"{args.qdrant_collection_name}, e.g. {missing_ids[:5]}"
        )

    return np.array([vector_by_id[point_id] for point_id in ids])

In [11]:
# Smoke test: fetch the vector of point 0 and query its five nearest neighbours.
vector = get_vector_by_ids([0])[0]
# Width of the vectors stored in the collection.
sbert_embedding_dim = vector.shape[0]
neighbors = ann_index.search(
    collection_name=args.qdrant_collection_name, query_vector=vector, limit=5
)

  0%|          | 0/1 [00:00<?, ?it/s]

In [12]:
neighbors

[ScoredPoint(id=0, version=125, score=1.0, payload={'parent_asin': '0375869026', 'title': 'Wonder'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=3206, version=95, score=0.81854767, payload={'parent_asin': 'B017VLXJ7G', 'title': ''}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=202, version=22, score=0.8096904, payload={'parent_asin': 'B000066RKC', 'title': 'The Thing'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=56, version=126, score=0.7248029, payload={'parent_asin': 'B00002SVFQ', 'title': 'F-Zero'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=437, version=111, score=0.72028804, payload={'parent_asin': 'B000EGELQ4', 'title': 'Big Brain Academy'}, vector=None, shard_key=None, order_value=None)]

# Test implementation

In [13]:
# Small sizes for the mock-data smoke test.
# NOTE: this rebinds `embedding_dim` (and, further down, `model`), which previously held item2vec values.
embedding_dim = 8
batch_size = 2

# Mock data: five interactions across three users and five items.
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 4, 5, 3, 0]
# One fixed-length item history per row; -1 fills the unused leading slots.
item_sequences = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 3],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 2, 1],
]
main_category = [
    "All Electronics",
    "Video Games",
    "All Electronics",
    "Video Games",
    "Unknown",
]
categories = [[], ["Headsets"], ["Video Games"], [], ["blah blah"]]
title = ["World of Warcraft", "DotA 2", "Diablo IV", "Football Manager 2024", "Unknown"]
description = [[], [], ["Video games blah blah"], [], ["blah blah"]]
price = ["from 14.99", "14.99", "price: 9.99", "20 dollars", "None"]

train_df = pd.DataFrame(
    {
        "user_indice": user_indices,
        "item_indice": item_indices,
        args.timestamp_col: timestamps,
        args.rating_col: ratings,
        "item_sequence": item_sequences,
        "main_category": main_category,
        "title": title,
        "description": description,
        "categories": categories,
        "price": price,
    }
)
# Item features = pipeline output concatenated column-wise with the vectors fetched from Qdrant.
train_item_features = item_metadata_pipeline.transform(train_df).astype(np.float32)
sbert_vectors = get_vector_by_ids(train_df['item_indice'].values.tolist()).astype(np.float32)
train_item_features = np.hstack([train_item_features, sbert_vectors])

n_users = len(set(user_indices))
n_items = len(set(item_indices))
item_feature_size = train_item_features.shape[1]

model = init_model(n_users, n_items, embedding_dim, item_feature_size, args.dropout)

# Example forward pass on the untrained model (eval mode).
model.eval()
users = torch.tensor(user_indices)
items = torch.tensor(item_indices)
item_sequences = torch.tensor(item_sequences)
item_features = torch.tensor(train_item_features)
predictions = model.predict(users, item_sequences, item_features, items)
print(predictions)

  0%|          | 0/1 [00:00<?, ?it/s]

tensor([[0.4631],
        [0.4790],
        [0.5072],
        [0.5027],
        [0.4979]], grad_fn=<SigmoidBackward0>)


In [14]:
# Wrap the mock frame and its feature matrix in the dataset class used for training.
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)

# No shuffling, so the printed batches follow the row order of the mock frame.
train_loader = DataLoader(rating_dataset, batch_size=batch_size, shuffle=False)

In [15]:
# Print every batch to inspect what the loader yields.
for batch_input in train_loader:
    print(batch_input)

{'user': tensor([0, 0]), 'item': tensor([0, 1]), 'rating': tensor([0., 1.]), 'item_sequence': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  3]]), 'item_feature': tensor([[-1.4698e-02,  5.6424e+00, -1.4698e-02,  ...,  2.6757e-02,
         -5.9982e-02,  2.5429e-03],
        [-1.4698e-02, -1.7723e-01, -1.4698e-02,  ..., -1.7954e-02,
         -2.8491e-02,  2.2718e-03]])}
{'user': tensor([1, 2]), 'item': tensor([2, 3]), 'rating': tensor([1., 1.]), 'item_sequence': tensor([[-1, -1,  1,  3],
        [-1, -1,  2,  1]]), 'item_feature': tensor([[-1.4698e-02,  5.6424e+00, -1.4698e-02,  ...,  1.4054e-02,
         -3.8501e-02, -7.5163e-03],
        [-1.4698e-02, -1.7723e-01, -1.4698e-02,  ...,  1.0250e-02,
         -7.4514e-02,  2.7393e-03]])}
{'user': tensor([2]), 'item': tensor([4]), 'rating': tensor([0.]), 'item_sequence': tensor([[-1, -1,  2,  1]]), 'item_feature': tensor([[-1.4698e-02, -1.7723e-01, -1.4698e-02, -2.5463e-02, -1.4698e-02,
         -1.4698e-02, -2.0788e-02, -8.2101e-02, -2.872

In [16]:
# Wrap the mock-data model in the Lightning module.
lit_model = LitRanker(model, log_dir=args.notebook_persist_dp)

# Two-epoch fit as a wiring check; the training loader doubles as the validation loader.
trainer = L.Trainer(
    default_root_dir=f"{args.notebook_persist_dp}/test",
    max_epochs=2,
    accelerator=args.device if args.device else "auto",
)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type   | Params | Mode
----------------------------------------
0 | model | Ranker | 8.2 K  | eval
----------------------------------------
8.2 K     Trainable params
0         Non-trainable params
8.2 K     Total params
0.033     Total estimated model params size (MB)
0         Modules in train mode
14        Modules in eval mode


Sanity Checking: |                                                                                            …

/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.
[32m2024-10-27 15:11:37.489[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m133[0m - [1mLogging classification metrics...[0m


In [17]:
# After fitting: repeat the same forward pass to compare with the pre-fit predictions.
model.eval()
predictions = model.predict(users, item_sequences, item_features, items)
print(predictions)

tensor([[0.4670],
        [0.4987],
        [0.5152],
        [0.5208],
        [0.4987]], grad_fn=<SigmoidBackward0>)


In [18]:
# Candidate catalogue: one row per distinct item, with pipeline features and fetched vectors side by side.
all_items_df = train_df.drop_duplicates(subset=["item_indice"])
all_items_indices = all_items_df["item_indice"].values
all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
all_sbert_vectors = get_vector_by_ids(all_items_indices.tolist()).astype(np.float32)
all_items_features = np.hstack([all_items_features, all_sbert_vectors])

# Keep the latest row of each user (by timestamp) as the recommendation input, so the
# item_sequence used is that user's most recent one.
to_rec_df = train_df.sort_values(args.timestamp_col, ascending=False).drop_duplicates(
    subset=["user_indice"]
)
recommendations = model.recommend(
    torch.tensor(to_rec_df["user_indice"].values.tolist()),
    torch.tensor(to_rec_df["item_sequence"].values.tolist()),
    torch.tensor(all_items_features),
    torch.tensor(all_items_indices),
    k=2,
    batch_size=4,
)
recommendations

  0%|          | 0/1 [00:00<?, ?it/s]

Generating recommendations:   0%|          | 0/1 [00:00<?, ?it/s]

{'user_indice': [2, 2, 1, 1, 0, 0],
 'recommendation': [1, 3, 1, 2, 1, 3],
 'score': [0.5275313258171082,
  0.5208275318145752,
  0.5420741438865662,
  0.5151524543762207,
  0.49870458245277405,
  0.4703948497772217]}

# Prep data

In [19]:
# Load the train/validation frames and the persisted id mapper.
train_df = pd.read_parquet("../data/train_features_neg_df.parquet")
val_df = pd.read_parquet("../data/val_features_neg_df.parquet")
idm_fp = "../data/idm.json"
idm = IDMapper().load(idm_fp)

# Consistency check: the stored user_indice column must equal what the mapper
# returns for each raw user id, in both frames.
assert (
    train_df[args.user_col].map(lambda s: idm.get_user_index(s))
    != train_df["user_indice"]
).sum() == 0, "Mismatch IDM"
assert (
    val_df[args.user_col].map(lambda s: idm.get_user_index(s)) != val_df["user_indice"]
).sum() == 0, "Mismatch IDM"

In [20]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
0,AEKSUPM7CH53J3G5PA3JLWLJXUMQ,B00QXJFDZO,0.0,2017-10-30 14:23:22.389,2561,2919,Video Games,"[Video Games, PlayStation 4, Games]",,2,5.000000,"B005FVBYV8,B003FMTZSI,B01MS6WG9S,B073W2T5F6","[-1, -1, -1, -1, -1, -1, 1912, 1470, 3498, 3723]"
1,AHSNMFN6DUFTNEZAXBVPIYMXWIFQ,B075MYT126,0.0,2017-11-27 22:01:33.258,18413,3777,Video Games,"[Video Games, Nintendo Switch, Accessories, Co...",94.98,2,4.000000,"B00CJ9OTNE,B0118YZG0A,B008M502H6,B003Y70W4U,B0...","[2391, 3100, 2176, 1588, 3161, 2133, 2906, 166..."
2,AFE47G5MX35LSHZHZXRYEJFMYPUA,B007VYW5K6,0.0,2017-03-23 21:41:18.000,6463,2086,Video Games,"[Video Games, PC]",,1,,"B07YBX8RNF,B0166QDJDQ,B01CHU4IY4,B00Z9LUDX4,B0...","[4278, 3183, 3288, 3038, 4508, 3391, 3403, 368..."
3,AFJDWGBE3MGULXTO3FUZ5YB6FKDA,B07L5FKGQH,0.0,2017-01-18 15:50:12.000,7246,4048,Video Games,"[Video Games, Xbox One, Games]",49.88,25,4.416667,"B00I6E6SH6,B00O65I2VY,B005GISQQG,B00008KTNW,B0...","[2632, 2859, 1920, 253, 1428, 1053, 584, 732, ..."
4,AHFDYGJR3SM2D463ZWKGHJPNBKDA,B002BSA2LQ,0.0,2014-01-29 22:50:20.000,16376,1215,Video Games,"[Video Games, Legacy Systems, Xbox Systems, Xb...",31.49,4,5.000000,"B002I0K956,B008CZN458,B0050SXVK8","[-1, -1, -1, -1, -1, -1, -1, 1328, 2146, 1843]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...
657187,AE5TQ7DBEX2L5T665M6ZDPGYZ32Q,B01LDUYTYS,0.0,2013-10-05 20:20:52.000,592,3442,Video Games,"[Video Games, Legacy Systems, Nintendo Systems...",249.99,1,,B07X1HF3V6,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 4237]"
657188,AFDG3CXM4DP7X436YNOKTJHVKJQA,B087NNPYP3,5.0,2018-07-10 21:22:10.594,6351,4342,Video Games,"[Video Games, Nintendo Switch, Consoles]",,3,5.000000,"B002I0H79C,B00503E9FY,B00KVOVBGM,B00SHXKC8M,B0...","[1292, 1807, 2734, 2940, 3402, 2759, 2702, 104..."
657189,AFOUC3S3RH7AXMPZBZHLO4WMLLVA,B004AM65C6,0.0,2018-12-16 13:39:37.174,8062,1651,Video Games,"[Video Games, Legacy Systems, Xbox Systems, Xb...",12.48,2,5.000000,"B002BSA388,B00PIEI1DG,B08MBHYJP4,B071GPJVTQ,B0...","[1216, 2898, 4397, 3643, 3642, 3527, 3423, 367..."
657190,AEPOGF2QMAXO4W3TYP27DCQRITGA,B07X1HF3V6,0.0,2013-05-30 22:53:17.000,3373,4237,Video Games,"[Video Games, Legacy Systems, PlayStation Syst...",34.43,4,3.666667,"B0013OL0BK,B002D2Y3IS,B0044R8X9U,B07VLCRZ21,B0...","[-1, -1, -1, -1, -1, 652, 1240, 1629, 4207, 1707]"


In [21]:
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()

# Fetch the SBERT vectors once for every item referenced by either split.
# `sbert_item_indices` is sorted and unique, and row r of `all_sbert_vectors` is
# taken to belong to `sbert_item_indices[r]`.
# NOTE(review): this relies on get_vector_by_ids returning rows in the order of
# the ids passed in - confirm.
sbert_item_indices = np.union1d(
    train_df["item_indice"].values, val_df["item_indice"].values
)
all_sbert_vectors = get_vector_by_ids(
    sbert_item_indices.tolist(), chunk_size=1000
).astype(np.float32)
if len(all_sbert_vectors) != len(sbert_item_indices):
    raise ValueError(
        f"Expected {len(sbert_item_indices)} SBERT vectors, got {len(all_sbert_vectors)}"
    )

train_item_features = chunk_transform(
    train_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
train_item_features = train_item_features.astype(np.float32)
# Map each row's item index to its row in `all_sbert_vectors`. Indexing the vector
# matrix with the raw item index is only valid when the ids happen to be exactly
# 0..n-1 and were fetched in that order.
train_sbert_rows = np.searchsorted(
    sbert_item_indices, train_df["item_indice"].values
)
train_sbert_vectors = all_sbert_vectors[train_sbert_rows]
train_item_features = np.hstack([train_item_features, train_sbert_vectors])

val_item_features = chunk_transform(
    val_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
val_item_features = val_item_features.astype(np.float32)
val_sbert_rows = np.searchsorted(sbert_item_indices, val_df["item_indice"].values)
val_sbert_vectors = all_sbert_vectors[val_sbert_rows]
val_item_features = np.hstack([val_item_features, val_sbert_vectors])

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

  0%|          | 0/5 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/66 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/1 [00:00<?, ?it/s]

[32m2024-10-27 15:11:44.041[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m19[0m - [1mlen(user_indices)=19,578, len(item_indices)=4,630[0m


# Train

In [22]:
# Training dataset: the train frame paired with its item-feature matrix.
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)
# Validation dataset, built the same way from the validation frame.
val_rating_dataset = UserItemBinaryDFDataset(
    val_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=val_item_features,
)

# Training batches are shuffled and the final partial batch is dropped.
train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
# Validation keeps a fixed order and every row.
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [23]:
n_items = len(item_indices)
n_users = len(user_indices)
# Take the feature width from the real training matrix instead of reusing the
# value left behind by the mock-data cell, so this section also runs on its own.
item_feature_size = train_item_features.shape[1]

model = init_model(
    n_users, n_items, args.embedding_dim, item_feature_size, args.dropout
)
model

Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (gru): GRU(128, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=921, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=640, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Predict before train

In [24]:
# Rebind val_df to the frame held by the validation dataset, then look at a random sample of it.
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
3759,AE3NRCMFIBBA2XVODR47YYNLKRDA,B001EYUQC8,1.0,2021-11-13 09:59:46.634,268,908,Video Games,"[Video Games, Legacy Systems, PlayStation Syst...",44.49,1,,"B000OLXX86,B000B9RI14,B0050SWQ86,B00CTKHXFO,B0...","[-1, -1, -1, 543, 417, 1835, 2407, 3123, 1526,..."
2114,AEILFB67AZEGIKLQ7AQ5XPSIAMFQ,B007VM72BA,0.0,2022-03-09 22:27:41.699,2218,2083,Video Games,"[Video Games, Legacy Systems, PlayStation Syst...",89.95,1,,"B005B8DRVU,B0015PHMFU,B00TKLFES8,B00MOR1A7Y,B0...","[-1, -1, -1, -1, -1, 1899, 666, 2954, 2806, 2804]"
1467,AGTPSXXK4B2NSMTIDJCOIXELXOIA,B072199RYC,0.0,2021-08-13 13:37:54.242,13677,3674,Video Games,"[Video Games, PlayStation 4, Games]",61.99,5,4.5,"B07T5QKKVP,B08JHX17ZZ,B07MWB5YJW,B094WQR3H3,B0...","[-1, -1, 4176, 4386, 4068, 4453, 4070, 4387, 3..."
1653,AH4LS5HFN2VGSKNKQGNZGENHBVVA,B0BFFJTZNN,0.0,2021-12-19 22:28:25.889,15027,4544,Video Games,"[Video Games, Xbox One, Accessories, Headsets]",135.0,1,,"B008FHL56S,B0053B66KE,B00CQ9L1Z6,B017C6OK7S,B0...","[2151, 1871, 2404, 3199, 4403, 2682, 1915, 275..."
1054,AFWIWPE67QV6CN77CXHRUW6PKVHQ,B00K5HTPR2,0.0,2022-01-30 23:38:08.357,9258,2697,Video Games,"[Video Games, PlayStation 4, Games]",26.64,1,,"B0C8MPWZ1H,B07WZ78VRN,B07624RBWB,B07DK1H3H5,B0...","[-1, -1, -1, -1, 4624, 4233, 3788, 3936, 4603,..."
2229,AHYEDK3YUJ272WRGEYYPKMJYYGBQ,B000N5Z2L4,1.0,2022-01-02 19:30:15.954,19324,526,Video Games,"[Video Games, Xbox Digital Content, Subscripti...",9.99,2,4.0,"B00DTWEOZ8,B00H727K20,B00JWSJ6G0,B00SN1QEGW,B0...","[2460, 2592, 2690, 2941, 4004, 4476, 4613, 461..."
3470,AERAX4VNX4JDFBK6BOH6NQ57U4BA,B07TC8J6HK,1.0,2022-01-26 17:59:37.364,3634,4180,Cell Phones & Accessories,"[Video Games, Legacy Systems, Nintendo Systems...",29.99,3,5.0,"B003BFW4OG,B001G7PS4Y,B001BP4JY6,B001QCWSII,B0...","[-1, -1, -1, 1447, 1078, 690, 1140, 1124, 4455..."
3798,AFZ5WJ7R2S75BH4XUIYV7AMPUBZQ,B019WRM1IA,0.0,2021-10-19 15:12:46.725,9657,3242,All Electronics,"[Video Games, Legacy Systems, Xbox Systems, Xb...",67.83,2,5.0,"B08F4C6HCD,B08JHYYTMT,B08JHX17ZZ,B07NVVV9R7,B0...","[-1, -1, -1, 4372, 4387, 4386, 4103, 3440, 438..."
3089,AF3FBQYJGE2RQRMCI5OOOFTPIJ6A,B00111SFEU,0.0,2022-01-24 20:05:10.086,5128,641,Video Games,"[Video Games, Value Games Under $20]",,2,5.0,"B00ZM5ON88,B07Z9Z39ZW,B00DTY9B0O,B07DJRFSL1,B0...","[3073, 4299, 2461, 3933, 4140, 2943, 4158, 272..."
1393,AFQPVFAE6XX4B2HICHSV4BOA2DGA,B002JTX7KA,0.0,2022-03-15 16:34:12.635,8332,1352,Video Games,"[Video Games, PC, Games]",,6,5.0,"B06XHLM4DX,B01N3ASPNV,B087NNPYP3,B00MB1I3FU,B0...","[3573, 3527, 4342, 2793, 4391, 4078, 3829, 424..."


In [25]:
# Pick one random validation row and use its user for the spot check.
user_id = val_df.sample(1)[args.user_col].values[0]
# Uncomment to pin a specific user instead of sampling:
# user_id = "AH4AOFTTDPHPAFAAVFMAF25H2LIQ"
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
# Show the user's rows without truncating the long list-valued columns.
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
635,AGDAPPCYV472FOUKDGAHZRW766GA,B00CMQTVK0,0.0,2021-11-12 00:10:40.778,11176,2399,Video Games,"[Video Games, Xbox One, Consoles]",589.99,2,1.0,"B001EYUW72,B00HM0RJJG,B00GM5UIN6,B087SHFL9B,B00000DMAT,B001EYUOA2,B0086VPV86,B000095ZH5,B00005O0I2,B07NPZZSRC","[996, 2605, 2569, 4344, 7, 844, 2115, 264, 162, 4090]"
636,AGDAPPCYV472FOUKDGAHZRW766GA,B07NPZZSRC,1.0,2021-09-13 02:35:26.496,11176,4090,Video Games,"[Video Games, Xbox One, Games]",26.5,1,,"B001ELJPSC,B001EYUW72,B00HM0RJJG,B00GM5UIN6,B087SHFL9B,B00000DMAT,B001EYUOA2,B0086VPV86,B000095ZH5,B00005O0I2","[815, 996, 2605, 2569, 4344, 7, 844, 2115, 264, 162]"
942,AGDAPPCYV472FOUKDGAHZRW766GA,B014SIVGAW,0.0,2021-11-12 00:10:40.778,11176,3152,Computers,"[Video Games, Legacy Systems, Xbox Systems, Xbox 360, Accessories, Memory]",33.99,2,1.0,"B001EYUW72,B00HM0RJJG,B00GM5UIN6,B087SHFL9B,B00000DMAT,B001EYUOA2,B0086VPV86,B000095ZH5,B00005O0I2,B07NPZZSRC","[996, 2605, 2569, 4344, 7, 844, 2115, 264, 162, 4090]"
1181,AGDAPPCYV472FOUKDGAHZRW766GA,B00UH9DN58,0.0,2022-01-26 23:23:50.857,11176,2964,Video Games,"[Video Games, Legacy Systems, Nintendo Systems, Nintendo 3DS & 2DS, Games]",70.74,2,5.0,"B00HM0RJJG,B00GM5UIN6,B087SHFL9B,B00000DMAT,B001EYUOA2,B0086VPV86,B000095ZH5,B00005O0I2,B07NPZZSRC,B00004TN9O","[2605, 2569, 4344, 7, 844, 2115, 264, 162, 4090, 102]"
1205,AGDAPPCYV472FOUKDGAHZRW766GA,B00005AV8W,0.0,2022-01-26 23:23:50.857,11176,130,Video Games,"[Video Games, Legacy Systems, Nintendo Systems, Nintendo NES, Games]",39.87,2,5.0,"B00HM0RJJG,B00GM5UIN6,B087SHFL9B,B00000DMAT,B001EYUOA2,B0086VPV86,B000095ZH5,B00005O0I2,B07NPZZSRC,B00004TN9O","[2605, 2569, 4344, 7, 844, 2115, 264, 162, 4090, 102]"
1259,AGDAPPCYV472FOUKDGAHZRW766GA,B002D2Y3IS,0.0,2022-04-07 05:49:55.659,11176,1240,Video Games,"[Video Games, Legacy Systems, Nintendo Systems, Wii, Accessories, Controllers]",24.74,2,1.0,"B00GM5UIN6,B087SHFL9B,B00000DMAT,B001EYUOA2,B0086VPV86,B000095ZH5,B00005O0I2,B07NPZZSRC,B00004TN9O,B004UPBXDO","[2569, 4344, 7, 844, 2115, 264, 162, 4090, 102, 1778]"
1545,AGDAPPCYV472FOUKDGAHZRW766GA,B01GWHR1TC,0.0,2022-01-26 23:23:50.857,11176,3369,Video Games,"[Video Games, PlayStation 4, Games]",18.0,2,5.0,"B00HM0RJJG,B00GM5UIN6,B087SHFL9B,B00000DMAT,B001EYUOA2,B0086VPV86,B000095ZH5,B00005O0I2,B07NPZZSRC,B00004TN9O","[2605, 2569, 4344, 7, 844, 2115, 264, 162, 4090, 102]"
1658,AGDAPPCYV472FOUKDGAHZRW766GA,B00GOOSV98,0.0,2021-09-13 02:35:26.496,11176,2580,Video Games,"[Video Games, PlayStation 4, Accessories]",,1,,"B001ELJPSC,B001EYUW72,B00HM0RJJG,B00GM5UIN6,B087SHFL9B,B00000DMAT,B001EYUOA2,B0086VPV86,B000095ZH5,B00005O0I2","[815, 996, 2605, 2569, 4344, 7, 844, 2115, 264, 162]"
1874,AGDAPPCYV472FOUKDGAHZRW766GA,B004UPBXDO,1.0,2022-01-26 23:23:50.857,11176,1778,Video Games,"[Video Games, Legacy Systems, PlayStation Systems, PlayStation 3, Games]",34.29,2,5.0,"B00HM0RJJG,B00GM5UIN6,B087SHFL9B,B00000DMAT,B001EYUOA2,B0086VPV86,B000095ZH5,B00005O0I2,B07NPZZSRC,B00004TN9O","[2605, 2569, 4344, 7, 844, 2115, 264, 162, 4090, 102]"
2017,AGDAPPCYV472FOUKDGAHZRW766GA,B00KCCNMYW,0.0,2021-11-12 00:10:40.778,11176,2702,Video Games,"[Video Games, PlayStation 4, Games]",14.99,2,1.0,"B001EYUW72,B00HM0RJJG,B00GM5UIN6,B087SHFL9B,B00000DMAT,B001EYUOA2,B0086VPV86,B000095ZH5,B00005O0I2,B07NPZZSRC","[996, 2605, 2569, 4344, 7, 844, 2115, 264, 162, 4090]"


In [26]:
# First row of the sampled user with a rating above zero.
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
# The row label is used as the row position in the validation feature matrix.
row_idx = test_row.name
item_feature = val_item_features[row_idx]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
# Add the batch axis in NumPy first: torch.tensor([ndarray]) goes through the slow
# list-of-arrays path and emits a UserWarning.
item_sequence = torch.tensor(np.asarray(item_sequence)[np.newaxis, ...])
item_feature = torch.tensor(np.asarray(item_feature)[np.newaxis, ...])
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item_feature, item)

[32m2024-10-27 15:11:44.832[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mTest predicting before training with user_id = AGDAPPCYV472FOUKDGAHZRW766GA and parent_asin = B07NPZZSRC[0m

Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:281.)



tensor([[0.4913]], grad_fn=<SigmoidBackward0>)

#### Training loop

##### Overfit 1 batch

In [27]:
# Stop when val_loss has failed to improve for 10 validation checks.
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

# Fresh model with regularisation switched off (dropout=0, l2_reg=0.0) so a single
# batch can be memorised. The feature width comes straight from the real training
# matrix rather than from a variable defined in the mock-data cell.
model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    train_item_features.shape[1],
    dropout=0,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# Train on one batch only; the same loader is passed for validation so val_loss
# is measured on training data.
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 3.4 M  | train
-----------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.4 M     Total params
13.594    Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=100` reached.
[32m2024-10-27 15:11:53.315[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m133[0m - [1mLogging classification metrics...[0m
[32m2024-10-27 15:12:28.807[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1mLogs available at /Users/dvq/frostmourne/recsys-mvp/notebooks/data/012-3-neg-samples-and-user-item-out-product/logs/overfit/lightning_logs/version_0[0m


In [28]:
# Open TensorBoard inline, reading event files from the log directory of the
# `trainer` object currently in scope.
%tensorboard --logdir $trainer.log_dir

##### Fit on all data

In [29]:
# papermill_description=fit-model
# Full training run: early stopping and best-checkpoint saving both track the
# same validation metric, minimised.
monitored_metric = "val_loss"

early_stopping = EarlyStopping(
    monitor=monitored_metric,
    patience=args.early_stopping_patience,
    mode="min",
    verbose=False,
)

# Keep only the single best checkpoint under the notebook's persist directory.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor=monitored_metric,
    mode="min",
)

# Fresh model, initialised with the pretrained item embedding.
model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    item_feature_size,
    dropout=args.dropout,
    item_embedding=pretrained_item_embedding,
)

# Lightning wrapper; ranking evaluation is switched on for this run, so it also
# receives the id mapper and the item metadata pipeline.
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    item_metadata_pipeline=item_metadata_pipeline,
    args=args,
    neg_to_pos_ratio=args.neg_to_pos_ratio,
)

log_dir = f"{args.notebook_persist_dp}/logs/run"

# Use the MLflow logger only when MLflow logging is enabled.
if args.log_to_mlflow:
    run_logger = args._mlf_logger
else:
    run_logger = None

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device or "auto",
    logger=run_logger,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 3.4 M  | train
-----------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.4 M     Total params
13.594    Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2024-10-27 15:27:07.501[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m133[0m - [1mLogging classification metrics...[0m
[32m2024-10-27 15:27:13.344[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m136[0m - [1mLogging ranking metrics...[0m


Generating recommendations:   0%|          | 0/177 [00:00<?, ?it/s]

2024/10/27 15:27:24 INFO mlflow.tracking._tracking_service.client: 🏃 View run 012-3-neg-samples-and-user-item-out-product at: http://localhost:5002/#/experiments/3/runs/e8efa2dd82ea449aaef04caf5a22c27d.
2024/10/27 15:27:24 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.


RuntimeError: mat1 and mat2 shapes cannot be multiplied (18520x153 and 921x128)

In [None]:
# Smoke-test the freshly trained model on the sample inputs, in eval mode.
test_predict_msg = f"Test predicting after training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
logger.info(test_predict_msg)

model.eval()
# Last expression of the cell, so the prediction is displayed as cell output.
model.predict(user, item_sequence, item_feature, item)

# Load best checkpoint

In [None]:
# Restore the Lightning module from the best checkpoint kept by the
# ModelCheckpoint callback, and record that path on `args`.
best_ckpt_path = checkpoint_callback.best_model_path
logger.info(f"Loading best checkpoint from {best_ckpt_path}...")
args.best_checkpoint_path = best_ckpt_path

# A new model instance is built to receive the checkpoint, with dropout set to 0.
# NOTE(review): unlike the training cell, no `item_embedding` is passed to
# `init_model` here -- confirm the checkpoint restores that layer correctly.
checkpoint_model = init_model(
    n_users, n_items, args.embedding_dim, item_feature_size, dropout=0
)
best_trainer = LitRanker.load_from_checkpoint(best_ckpt_path, model=checkpoint_model)

In [None]:
# Take the `model` attribute of the restored Lightning module and move it to the
# same device that `lit_model` is on.
best_model = best_trainer.model.to(lit_model.device)

In [None]:
# Put the restored model in eval mode and run it on the same sample inputs used
# for the post-training test prediction; the result is shown as cell output.
best_model.eval()
best_model.predict(user, item_sequence, item_feature, item)

### Persist id mapping

In [None]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    # Files uploaded to the run, in order:
    #   1. the id mapping, so that at inference we can predict based on item_ids
    #      (string) instead of item_index
    #   2. the item_feature_metadata pipeline
    for artifact_fp in (idm_fp, args.item_metadata_pipeline_fp):
        mlf_client.log_artifact(run_id, artifact_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [None]:
# Wrap the best model in the inference wrapper; this object is what gets passed
# as `python_model` to `mlflow.pyfunc.log_model` further down.
inferrer = RankerInferenceWrapper(best_model)

In [None]:
# Example payload keyed by raw ids: the user and items at index 0/1 of the id
# mapper, plus one value per item feature column taken from the first training row.
sample_input = {
    "user_ids": [idm.get_user_id(0)],
    "item_sequences": [[idm.get_item_id(0), idm.get_item_id(1)]],
}
for feature_col in args.item_feature_cols:
    sample_input[feature_col] = [train_df[feature_col].iloc[0]]
sample_input["item_ids"] = [idm.get_item_id(0)]

# Matching example output, produced by calling the wrapper with indices directly.
sample_output = inferrer.infer([0], [[0, 1]], [train_item_features[0]], [0])
sample_output

In [None]:
# Display the example payload that is used below as the MLflow input example.
sample_input

In [None]:
# Register the inference wrapper as an MLflow pyfunc model on the training run,
# with a signature inferred from the sample input/output pair.
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    # The artifacts were logged at the run's artifact root under their file names,
    # so only the final path component is needed to build their URIs.
    # os.path.basename is used instead of splitting on "/" so this also works
    # with platform-specific path separators.
    idm_filename = os.path.basename(idm_fp)
    item_metadata_pipeline_filename = os.path.basename(args.item_metadata_pipeline_fp)
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            # We log the id_mapping to the predict function so that it can accept
            # item_id and automatically convert to item_index for the PyTorch model
            artifacts={
                "idm": mlflow.get_artifact_uri(idm_filename),
                "item_metadata_pipeline": mlflow.get_artifact_uri(
                    item_metadata_pipeline_filename
                ),
            },
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )

# Set the newly trained model as champion

In [None]:
# Promote the newly registered model version to the "champion" alias, but only
# when its validation ROC AUC clears the configured minimum.
if args.log_to_mlflow:
    val_roc_auc = trainer.logger.experiment.get_run(trainer.logger.run_id).data.metrics[
        "val_roc_auc"
    ]

    if val_roc_auc > args.min_roc_auc:
        logger.info("Aliasing the new model as champion...")
        # `latest_versions` holds the latest version per stage, in no guaranteed
        # order, so pick the highest version number rather than index 0.
        # Versions are compared as integers because they may be strings.
        registered_model = mlf_client.get_registered_model(args.mlf_model_name)
        model_version = max(
            registered_model.latest_versions, key=lambda mv: int(mv.version)
        ).version

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value="quy.dinh",
        )

# Clean up

In [None]:
# Log every field of each params object to the MLflow run, with keys prefixed by
# the object's repr name.
all_params = [args]

# `top_K` and `top_k` differ only by letter case, so both are logged under
# distinct names (presumably to avoid a key clash in the tracking backend).
param_key_aliases = {"top_K": "top_big_K", "top_k": "top_small_k"}

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            field_values = params.dict()
            mlflow.log_params(
                {
                    f"{params.__repr_name__()}.{param_key_aliases.get(field, field)}": value
                    for field, value in field_values.items()
                }
            )