# Ranker that can take into account different features

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import os
import sys
from typing import List, Optional

import dill
import lightning as L
import mlflow
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from qdrant_client import QdrantClient
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

load_dotenv()

sys.path.insert(0, "..")

from src.data_prep_utils import chunk_transform
from src.dataset import UserItemBinaryDFDataset
from src.id_mapper import IDMapper
from src.ranker.inference import RankerInferenceWrapper
from src.ranker.model import Ranker
from src.ranker.trainer import LitRanker
from src.viz import blueq_colors

# Controller

In [3]:
# Module-level epoch cap; Args.max_epochs uses this value as its default.
max_epochs = 100

In [4]:
class Args(BaseModel):
    """Notebook-wide configuration for the ranker experiment.

    Instantiate with ``Args().init()``: the constructor only stores the
    defaults below, while ``init()`` resolves everything that depends on the
    environment (persist directory, Qdrant URL, MLflow logger).
    """

    # --- Run identification / logging -------------------------------------
    testing: bool = False
    log_to_mlflow: bool = True  # forced to False by init() when no tracking URI is set
    experiment_name: str = "RecSys MVP - Ranker"
    run_name: str = "016-reduce-dim-and-no-pretrained-item-embeddings"
    notebook_persist_dp: Optional[str] = None  # filled by init(): abs path of data/<run_name>
    random_seed: int = 41
    device: Optional[str] = None  # Lightning accelerator; None means "auto" at the call sites

    # --- External artifacts / services ------------------------------------
    item_metadata_pipeline_fp: str = "../data/item_metadata_pipeline.dill"
    qdrant_url: Optional[str] = None  # filled by init() from QDRANT_HOST / QDRANT_PORT
    qdrant_collection_name: str = "item_desc_sbert"

    # --- Training loop / data preparation ---------------------------------
    max_epochs: int = max_epochs
    batch_size: int = 128
    tfm_chunk_size: int = 10000  # chunk size handed to chunk_transform
    neg_to_pos_ratio: int = 3

    # --- Column names of the interaction data -----------------------------
    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"
    item_feature_cols: List[str] = ["main_category", "categories"]

    # NOTE(review): presumably candidate-set size vs. final list size — confirm.
    top_K: int = 100
    top_k: int = 10

    # --- Model / optimisation hyper-parameters ----------------------------
    embedding_dim: int = 128
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.0003
    l2_reg: float = 1e-4

    # --- MLflow model registry --------------------------------------------
    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "ranker"
    min_roc_auc: float = 0.7

    best_checkpoint_path: Optional[str] = None

    def init(self):
        """Resolve environment-dependent settings and return ``self``.

        Side effects:
            * creates ``data/<run_name>`` (relative to the working directory)
              and stores its absolute path in ``notebook_persist_dp``;
            * builds ``qdrant_url`` from ``QDRANT_HOST`` and ``QDRANT_PORT``;
            * when ``MLFLOW_TRACKING_URI`` is set and ``log_to_mlflow`` is
              true, creates an ``MLFlowLogger`` stored on ``_mlf_logger``.

        Raises:
            Exception: if ``QDRANT_HOST`` or ``QDRANT_PORT`` is not set.
        """
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise Exception("Environment variable QDRANT_HOST is not set.")

        # Fail fast instead of silently building the unusable URL "<host>:None".
        if not (qdrant_port := os.getenv("QDRANT_PORT")):
            raise Exception("Environment variable QDRANT_PORT is not set.")
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                "Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            # Only assigned on this branch; the attribute is absent when MLflow logging is off.
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2024-10-27 18:01:21.535[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m57[0m - [1mSetting up MLflow experiment RecSys MVP - Ranker - run 016-reduce-dim-and-no-pretrained-item-embeddings...[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Ranker",
  "run_name": "016-reduce-dim-and-no-pretrained-item-embeddings",
  "notebook_persist_dp": "/Users/dvq/frostmourne/recsys-mvp/notebooks/data/016-reduce-dim-and-no-pretrained-item-embeddings",
  "random_seed": 41,
  "device": null,
  "item_metadata_pipeline_fp": "../data/item_metadata_pipeline.dill",
  "qdrant_url": "localhost:6333",
  "qdrant_collection_name": "item_desc_sbert",
  "max_epochs": 100,
  "batch_size": 128,
  "tfm_chunk_size": 10000,
  "neg_to_pos_ratio": 3,
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "item_feature_cols": [
    "main_category",
    "categories"
  ],
  "top_K": 100,
  "top_k": 10,
  "embedding_dim": 64,
  "dropout": 0.3,
  "early_stopping_patience": 5,
  "learning_rate": 0.0003,
  "l2_reg": 0.0001,
  "mlf_item2vec_model_name": "item2vec",
  "mlf_model_name": "ranker",
  "min_roc_auc": 0.7,
  "best_c

# Implement

In [5]:
def init_model(
    n_users, n_items, embedding_dim, item_feature_size, dropout, item_embedding=None
):
    """Build a fresh ``Ranker`` from the given sizes and hyper-parameters.

    The first three arguments are forwarded positionally; the remaining ones
    are forwarded by keyword, exactly as ``Ranker`` is called throughout this
    notebook. ``item_embedding`` defaults to ``None`` and is passed through
    unchanged.
    """
    ranker_kwargs = {
        "item_feature_size": item_feature_size,
        "dropout": dropout,
        "item_embedding": item_embedding,
    }
    return Ranker(n_users, n_items, embedding_dim, **ranker_kwargs)

## Load pretrained Item2Vec embeddings

In [6]:
# Load the registered Item2Vec model carrying the "champion" alias from the
# MLflow model registry and pull out the pieces this notebook needs.
# NOTE: mlf_client is not referenced again in the visible part of the notebook.
mlf_client = mlflow.MlflowClient()
model = mlflow.pyfunc.load_model(
    model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
)
# Unwrap the pyfunc wrapper to reach the underlying model object.
skipgram_model = model.unwrap_python_model().model
# Embed index 0 only to read the embedding dimension off the result.
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.size()[0]
id_mapping = model.unwrap_python_model().id_mapping
# NOTE(review): presumably an embedding layer (it is called like one above) — confirm.
pretrained_item_embedding = skipgram_model.embeddings

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [7]:
# assert (
#     pretrained_item_embedding.embedding_dim == args.embedding_dim
# ), "Mismatch pretrained item_embedding dimension"

## Load vectorized item features

In [8]:
# Deserialize the item-metadata pipeline; later cells call `.transform(df)` on it.
# NOTE: dill.load can execute arbitrary code — only load artifacts you trust.
with open(args.item_metadata_pipeline_fp, "rb") as f:
    item_metadata_pipeline = dill.load(f)

## Load ANN Index

In [9]:
# Connect to Qdrant at the URL resolved by Args.init() and abort early if the
# collection this notebook reads vectors from is missing.
ann_index = QdrantClient(url=args.qdrant_url)
if not ann_index.collection_exists(args.qdrant_collection_name):
    raise Exception(
        f"Required Qdrant collection {args.qdrant_collection_name} does not exist"
    )

In [10]:
def get_vector_by_ids(ids: List[int], chunk_size=100):
    """Fetch the stored vectors for ``ids`` from the Qdrant collection.

    Points are retrieved in chunks of ``chunk_size`` ids. The result has
    exactly one row per entry of ``ids`` and row ``k`` is the vector of
    ``ids[k]`` (duplicates in ``ids`` yield repeated rows), so callers can
    safely stack it next to other per-id arrays.

    Args:
        ids: Point ids to look up in ``args.qdrant_collection_name``.
        chunk_size: Maximum number of ids sent per ``retrieve`` call.

    Returns:
        ``np.ndarray`` of the vectors, aligned with ``ids``. An empty ``ids``
        gives an empty array.

    Raises:
        ValueError: if any requested id has no point in the collection.
    """
    vector_by_id = {}
    for start in tqdm(range(0, len(ids), chunk_size)):
        chunk_ids = ids[start : start + chunk_size]
        chunk_records = ann_index.retrieve(
            collection_name=args.qdrant_collection_name,
            ids=chunk_ids,
            with_vectors=True,
        )
        # Key by point id: the order in which retrieve() returns points is not
        # relied upon, and absent ids are simply missing from the response.
        vector_by_id.update((record.id, record.vector) for record in chunk_records)

    missing_ids = [point_id for point_id in ids if point_id not in vector_by_id]
    if missing_ids:
        raise ValueError(
            f"{len(missing_ids)} id(s) not found in Qdrant collection "
            f"{args.qdrant_collection_name}, e.g. {missing_ids[:5]}"
        )

    # Re-emit in the caller's order so row k always belongs to ids[k].
    return np.array([vector_by_id[point_id] for point_id in ids])

In [11]:
# Smoke-test the ANN index: fetch the vector of point 0, record its
# dimensionality, and query its 5 nearest neighbours.
vector = get_vector_by_ids([0])[0]
sbert_embedding_dim = vector.shape[0]
neighbors = ann_index.search(
    collection_name=args.qdrant_collection_name, query_vector=vector, limit=5
)

  0%|          | 0/1 [00:00<?, ?it/s]

In [12]:
neighbors

[ScoredPoint(id=0, version=0, score=1.0, payload={'parent_asin': '0375869026', 'title': 'Wonder'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=1916, version=59, score=0.9272537, payload={'parent_asin': 'B005GFPZYK', 'title': 'American Sniper: The Autobiography of the Most Lethal Sniper in U.S. Military History'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=1968, version=61, score=0.9080587, payload={'parent_asin': 'B005ZBO4VA', 'title': 'Tekken Hybrid - Playstation 3'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=166, version=5, score=0.899632, payload={'parent_asin': 'B00005OARM', 'title': 'Golden Sun'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=3896, version=121, score=0.89472985, payload={'parent_asin': 'B07CD6F5PX', 'title': 'Dragon Quest Xi: Echoes of An Elusive Age - PlayStation 4'}, vector=None, shard_key=None, order_value=None)]

# Test implementation

In [13]:
# Tiny hand-written dataset used to exercise the model end to end.
# NOTE: this rebinds the notebook-level `embedding_dim`, `train_df`, `model`,
# `n_users`, `n_items` and `item_feature_size` names.
embedding_dim = 8
batch_size = 2

# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 4, 5, 3, 0]
# NOTE(review): -1 looks like a padding marker for short histories — confirm.
item_sequences = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 3],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 2, 1],
]
main_category = [
    "All Electronics",
    "Video Games",
    "All Electronics",
    "Video Games",
    "Unknown",
]
categories = [[], ["Headsets"], ["Video Games"], [], ["blah blah"]]
title = ["World of Warcraft", "DotA 2", "Diablo IV", "Football Manager 2024", "Unknown"]
description = [[], [], ["Video games blah blah"], [], ["blah blah"]]
price = ["from 14.99", "14.99", "price: 9.99", "20 dollars", "None"]

train_df = pd.DataFrame(
    {
        "user_indice": user_indices,
        "item_indice": item_indices,
        args.timestamp_col: timestamps,
        args.rating_col: ratings,
        "item_sequence": item_sequences,
        "main_category": main_category,
        "title": title,
        "description": description,
        "categories": categories,
        "price": price,
    }
)
# Item features = pipeline output with the Qdrant vectors appended column-wise.
# The hstack pairs rows positionally, so the fetched vectors must come back
# one per row and in the same order as train_df["item_indice"].
train_item_features = item_metadata_pipeline.transform(train_df).astype(np.float32)
sbert_vectors = get_vector_by_ids(train_df['item_indice'].values.tolist()).astype(np.float32)
train_item_features = np.hstack([train_item_features, sbert_vectors])

n_users = len(set(user_indices))
n_items = len(set(item_indices))
item_feature_size = train_item_features.shape[1]

model = init_model(n_users, n_items, embedding_dim, item_feature_size, args.dropout)

# Example forward pass
model.eval()
users = torch.tensor(user_indices)
items = torch.tensor(item_indices)
item_sequences = torch.tensor(item_sequences)
item_features = torch.tensor(train_item_features)
predictions = model.predict(users, item_sequences, item_features, items)
print(predictions)

  0%|          | 0/1 [00:00<?, ?it/s]

tensor([[0.6009],
        [0.6075],
        [0.6160],
        [0.6107],
        [0.5799]], grad_fn=<SigmoidBackward0>)


In [14]:
# Wrap the mock frame and its feature matrix in the project dataset class and
# iterate it in fixed order (no shuffling) so the printed batches are reproducible.
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)

train_loader = DataLoader(rating_dataset, batch_size=batch_size, shuffle=False)

In [15]:
# Print every batch to eyeball the keys and tensors the dataset emits.
for batch_input in train_loader:
    print(batch_input)

{'user': tensor([0, 0]), 'item': tensor([0, 1]), 'rating': tensor([0., 1.]), 'item_sequence': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  3]]), 'item_feature': tensor([[-1.4698e-02,  5.6424e+00, -1.4698e-02,  ...,  2.9136e-02,
         -3.0617e-02,  1.4856e-03],
        [-1.4698e-02, -1.7723e-01, -1.4698e-02,  ..., -2.6627e-03,
         -3.4459e-02,  1.1402e-02]])}
{'user': tensor([1, 2]), 'item': tensor([2, 3]), 'rating': tensor([1., 1.]), 'item_sequence': tensor([[-1, -1,  1,  3],
        [-1, -1,  2,  1]]), 'item_feature': tensor([[-1.4698e-02,  5.6424e+00, -1.4698e-02,  ...,  2.4071e-03,
         -4.1083e-02,  4.6736e-04],
        [-1.4698e-02, -1.7723e-01, -1.4698e-02,  ..., -5.8320e-03,
         -6.7804e-02,  4.7192e-03]])}
{'user': tensor([2]), 'item': tensor([4]), 'rating': tensor([0.]), 'item_sequence': tensor([[-1, -1,  2,  1]]), 'item_feature': tensor([[-1.4698e-02, -1.7723e-01, -1.4698e-02, -2.5463e-02, -1.4698e-02,
         -1.4698e-02, -2.0788e-02, -8.2101e-02, -2.872

In [16]:
# Wrap the mock model in the Lightning module used for training.
lit_model = LitRanker(model, log_dir=args.notebook_persist_dp)

# Two-epoch smoke run; the mock loader doubles as the validation loader, so
# the metrics here only prove the training loop executes.
trainer = L.Trainer(
    default_root_dir=f"{args.notebook_persist_dp}/test",
    max_epochs=2,
    accelerator=args.device if args.device else "auto",
)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type   | Params | Mode
----------------------------------------
0 | model | Ranker | 8.4 K  | eval
----------------------------------------
8.4 K     Trainable params
0         Non-trainable params
8.4 K     Total params
0.034     Total estimated model params size (MB)
0         Modules in train mode
14        Modules in eval mode


Sanity Checking: |                                                                                            …

/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.
[32m2024-10-27 18:01:23.216[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m133[0m - [1mLogging classification metrics...[0m
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [17]:
# Re-run the same forward pass after the smoke training to compare scores
# against the pre-training output above.
model.eval()
predictions = model.predict(users, item_sequences, item_features, items)
print(predictions)

tensor([[0.5814],
        [0.6179],
        [0.6267],
        [0.6108],
        [0.5697]], grad_fn=<SigmoidBackward0>)


In [18]:
# Build one feature row per distinct item: pipeline output with the Qdrant
# vectors appended column-wise (rows are paired positionally by the hstack).
all_items_df = train_df.drop_duplicates(subset=["item_indice"])
all_items_indices = all_items_df["item_indice"].values
all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
all_sbert_vectors = get_vector_by_ids(all_items_indices.tolist()).astype(np.float32)
all_items_features = np.hstack([all_items_features, all_sbert_vectors])

# Keep the most recent row of each user as input for recommendations
# (it carries that user's most up-to-date item_sequence).
to_rec_df = train_df.sort_values(args.timestamp_col, ascending=False).drop_duplicates(
    subset=["user_indice"]
)
recommendations = model.recommend(
    torch.tensor(to_rec_df["user_indice"].values.tolist()),
    torch.tensor(to_rec_df["item_sequence"].values.tolist()),
    torch.tensor(all_items_features),
    torch.tensor(all_items_indices),
    k=2,
    batch_size=4,
)
recommendations

  0%|          | 0/1 [00:00<?, ?it/s]

Generating recommendations:   0%|          | 0/1 [00:00<?, ?it/s]

{'user_indice': [2, 2, 1, 1, 0, 0],
 'recommendation': [3, 1, 1, 3, 1, 3],
 'score': [0.6108070015907288,
  0.6010299324989319,
  0.6547285318374634,
  0.6444249749183655,
  0.6178779602050781,
  0.6057987809181213]}

# Prep data

In [19]:
# Load the prepared train/validation frames and the persisted ID mapper.
train_df = pd.read_parquet("../data/train_features_neg_df.parquet")
val_df = pd.read_parquet("../data/val_features_neg_df.parquet")
idm_fp = "../data/idm.json"
idm = IDMapper().load(idm_fp)

# Sanity check: the user index stored in each frame must equal what the ID
# mapper returns for the raw user id. Only the user mapping is verified here.
assert (
    train_df[args.user_col].map(lambda s: idm.get_user_index(s))
    != train_df["user_indice"]
).sum() == 0, "Mismatch IDM"
assert (
    val_df[args.user_col].map(lambda s: idm.get_user_index(s)) != val_df["user_indice"]
).sum() == 0, "Mismatch IDM"

In [20]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
0,AEKSUPM7CH53J3G5PA3JLWLJXUMQ,B00QXJFDZO,0.0,2017-10-30 14:23:22.389,2561,2919,Video Games,"[Video Games, PlayStation 4, Games]",,2,5.000000,"B005FVBYV8,B003FMTZSI,B01MS6WG9S,B073W2T5F6","[-1, -1, -1, -1, -1, -1, 1912, 1470, 3498, 3723]"
1,AHSNMFN6DUFTNEZAXBVPIYMXWIFQ,B075MYT126,0.0,2017-11-27 22:01:33.258,18413,3777,Video Games,"[Video Games, Nintendo Switch, Accessories, Co...",94.98,2,4.000000,"B00CJ9OTNE,B0118YZG0A,B008M502H6,B003Y70W4U,B0...","[2391, 3100, 2176, 1588, 3161, 2133, 2906, 166..."
2,AFE47G5MX35LSHZHZXRYEJFMYPUA,B007VYW5K6,0.0,2017-03-23 21:41:18.000,6463,2086,Video Games,"[Video Games, PC]",,1,,"B07YBX8RNF,B0166QDJDQ,B01CHU4IY4,B00Z9LUDX4,B0...","[4278, 3183, 3288, 3038, 4508, 3391, 3403, 368..."
3,AFJDWGBE3MGULXTO3FUZ5YB6FKDA,B07L5FKGQH,0.0,2017-01-18 15:50:12.000,7246,4048,Video Games,"[Video Games, Xbox One, Games]",49.88,25,4.416667,"B00I6E6SH6,B00O65I2VY,B005GISQQG,B00008KTNW,B0...","[2632, 2859, 1920, 253, 1428, 1053, 584, 732, ..."
4,AHFDYGJR3SM2D463ZWKGHJPNBKDA,B002BSA2LQ,0.0,2014-01-29 22:50:20.000,16376,1215,Video Games,"[Video Games, Legacy Systems, Xbox Systems, Xb...",31.49,4,5.000000,"B002I0K956,B008CZN458,B0050SXVK8","[-1, -1, -1, -1, -1, -1, -1, 1328, 2146, 1843]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...
657187,AE5TQ7DBEX2L5T665M6ZDPGYZ32Q,B01LDUYTYS,0.0,2013-10-05 20:20:52.000,592,3442,Video Games,"[Video Games, Legacy Systems, Nintendo Systems...",249.99,1,,B07X1HF3V6,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 4237]"
657188,AFDG3CXM4DP7X436YNOKTJHVKJQA,B087NNPYP3,5.0,2018-07-10 21:22:10.594,6351,4342,Video Games,"[Video Games, Nintendo Switch, Consoles]",,3,5.000000,"B002I0H79C,B00503E9FY,B00KVOVBGM,B00SHXKC8M,B0...","[1292, 1807, 2734, 2940, 3402, 2759, 2702, 104..."
657189,AFOUC3S3RH7AXMPZBZHLO4WMLLVA,B004AM65C6,0.0,2018-12-16 13:39:37.174,8062,1651,Video Games,"[Video Games, Legacy Systems, Xbox Systems, Xb...",12.48,2,5.000000,"B002BSA388,B00PIEI1DG,B08MBHYJP4,B071GPJVTQ,B0...","[1216, 2898, 4397, 3643, 3642, 3527, 3423, 367..."
657190,AEPOGF2QMAXO4W3TYP27DCQRITGA,B07X1HF3V6,0.0,2013-05-30 22:53:17.000,3373,4237,Video Games,"[Video Games, Legacy Systems, PlayStation Syst...",34.43,4,3.666667,"B0013OL0BK,B002D2Y3IS,B0044R8X9U,B07VLCRZ21,B0...","[-1, -1, -1, -1, -1, 652, 1240, 1629, 4207, 1707]"


In [21]:
# Distinct users/items of the training split (their counts size the model later).
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()

# Fetch the stored vector of every item referenced by either split, in
# ascending id order (np.unique sorts). Row k of `all_sbert_vectors` is then
# the vector of `sbert_item_ids[k]`; this relies on get_vector_by_ids returning
# one row per requested id, in request order.
sbert_item_ids = np.unique(
    np.concatenate([train_df["item_indice"].values, val_df["item_indice"].values])
)
all_sbert_vectors = get_vector_by_ids(
    sbert_item_ids.tolist(), chunk_size=1000
).astype(np.float32)
if len(all_sbert_vectors) != len(sbert_item_ids):
    raise ValueError(
        f"Expected {len(sbert_item_ids)} item vectors but got {len(all_sbert_vectors)}"
    )

train_item_features = chunk_transform(
    train_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
train_item_features = train_item_features.astype(np.float32)
# Translate each row's item id into its row position in `all_sbert_vectors`.
# Indexing with the raw id is only valid when ids happen to be 0..n-1 in fetch
# order; every id is present in the sorted `sbert_item_ids`, so searchsorted
# returns its exact position.
train_sbert_vectors = all_sbert_vectors[
    np.searchsorted(sbert_item_ids, train_df["item_indice"].values)
]
train_item_features = np.hstack([train_item_features, train_sbert_vectors])

val_item_features = chunk_transform(
    val_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
val_item_features = val_item_features.astype(np.float32)
val_sbert_vectors = all_sbert_vectors[
    np.searchsorted(sbert_item_ids, val_df["item_indice"].values)
]
val_item_features = np.hstack([val_item_features, val_sbert_vectors])

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

  0%|          | 0/5 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/66 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/1 [00:00<?, ?it/s]

[32m2024-10-27 18:01:30.252[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m19[0m - [1mlen(user_indices)=19,578, len(item_indices)=4,630[0m


# Train

In [22]:
# Datasets pairing each interaction frame with its row-aligned feature matrix.
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)
val_rating_dataset = UserItemBinaryDFDataset(
    val_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=val_item_features,
)

# Training batches are shuffled and the trailing partial batch is dropped;
# validation keeps file order and every row.
train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [23]:
# Size the embedding tables from the distinct training users/items.
n_items = len(item_indices)
n_users = len(user_indices)
# Derive the feature width from the real training matrix instead of reusing
# whatever value an earlier cell left in the `item_feature_size` global.
item_feature_size = train_item_features.shape[1]

model = init_model(
    n_users, n_items, args.embedding_dim, item_feature_size, args.dropout
)
model

Ranker(
  (item_embedding): Embedding(4631, 64, padding_idx=4630)
  (user_embedding): Embedding(19578, 64)
  (gru): GRU(128, 64, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=921, out_features=64, bias=True)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=64, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Predict before train

In [24]:
# Rebind val_df to the frame held by the validation dataset and show 10 random rows.
# NOTE(review): whether the dataset alters the frame it was given is not visible here — confirm.
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
364,AGFNYWOFY5N6MWIDLMEROWT5Z2IA,B09KL8P6DP,1.0,2022-03-08 18:42:41.757,11525,4484,All Electronics,"[Video Games, Legacy Systems, Xbox Systems, Xb...",16.99,1,,"B001EYUQKU,B001QCWRWK,B001EYUY3Y,B00HX1UXD8,B0...","[-1, -1, 917, 1123, 1035, 2626, 4457, 1228, 28..."
221,AGKBFG5VF5ZBTL5BUYVVZ2PCB5GA,B0764BLMJL,0.0,2021-09-06 01:55:57.087,12224,3789,Video Games,"[Video Games, PlayStation 4, Games]",39.99,1,,"B00J7N6C26,B07BW77M9V,B00BZS9JV2,B00DC7O77A,B0...","[-1, -1, -1, -1, -1, 2673, 3882, 2367, 2438, 4..."
1134,AHMRBNO7J7WSCIPKFIT743QO55XA,B0C5K4M7WJ,1.0,2022-04-28 23:36:09.433,17534,4614,Computers,"[Video Games, PC, Accessories, Headsets]",119.99,1,,"B006PP404Q,B00R0ZS9YW,B009MRZAUC,B00L59D9HG,B0...","[2010, 2920, 2244, 2759, 2281, 3199, 2778, 287..."
2218,AHCFZPQLIV7236UBXSGU3I5SMXBQ,B007WPGT2Y,0.0,2021-11-01 01:28:20.415,15933,2092,Video Games,"[Video Games, Legacy Systems, PlayStation Syst...",22.5,3,5.0,"B004LLHFAW,B07STWQ38X,B072V478NR,B076GXJNDZ,B0...","[1733, 4170, 3705, 3798, 4612, 3691, 3087, 347..."
3105,AF3TAMOYZEDTYYEX4ZB23A6CS7ZA,B001EYUPJ2,0.0,2021-12-18 23:14:30.970,5197,874,Video Games,"[Video Games, Legacy Systems, Xbox Systems, Xb...",33.02,1,,"B000HZFCT2,B000FGA1US,B00ZMBLKPG,B016P6KIFY,B0...","[496, 450, 3077, 3190, 3434, 1869, 4284, 2152,..."
3422,AFIXV7CY3OC6WI5DXCS3JAGP5SQA,B001EYUQTQ,0.0,2022-07-15 12:10:21.161,7189,923,,"[Video Games, Legacy Systems, PlayStation Syst...",15.0,1,,"B009CYJ8SA,B019WRM1IA,B015IX3X3E,B011ALF4XK,B0...","[2225, 3242, 3168, 3103, 2931, 2759, 1820, 367..."
1781,AFW22P67ZJR422WXAPZDLFQSTCAQ,B0053B5RGI,0.0,2022-07-13 01:03:37.935,9188,1869,Video Games,"[Video Games, Legacy Systems, Nintendo Systems...",49.99,1,,"B0CB8LZT7K,B0795GHTBC,B087NMYQYG,B07D13QGXM,B0...","[-1, -1, -1, -1, 4626, 3846, 4339, 3913, 4533,..."
776,AHSENHOIWSXPI2JWYX3RJXDYFRVA,B08GCGZ221,1.0,2021-10-29 14:57:32.907,18368,4380,All Electronics,"[Video Games, Nintendo Switch, Accessories, Ca...",9.99,2,5.0,"B07V3G4SNR,B07TCPVHTM,B07ST7LY8F,B087NNZZM8,B0...","[4193, 4181, 4169, 4343, 4477, 4603, 4195, 444..."
1601,AFWMPFOSPMWMIJUMWAPCWBTTAHJQ,B0714FFZD8,0.0,2021-11-18 04:25:17.644,9277,3625,Video Games,"[Video Games, Nintendo Switch, Games]",44.99,1,,"B0036EWMIK,B07HHVF2XG,B07YBXFF5C,B00MAM0ALA,B0...","[1426, 4003, 4284, 2791, 3504, 2804, 1875, 395..."
826,AHQXEAJC5ZNBESCNQNJDSU63AJTQ,B000N5Z2L4,0.0,2022-02-17 06:14:34.666,18163,526,Video Games,"[Video Games, Xbox Digital Content, Subscripti...",9.99,1,,"B0088TN7BO,B0084A97Y8,B00KWIYPZG,B0053BCML6,B0...","[2124, 2110, 2745, 1874, 1947, 2582, 2905, 314..."


In [25]:
# Pick one random validation user and display all of that user's rows untruncated.
user_id = val_df.sample(1)[args.user_col].values[0]
# user_id = "AH4AOFTTDPHPAFAAVFMAF25H2LIQ"
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
1391,AECCEXRYXAFKLBRLXSLG434XB2LQ,B07YBX7Y3P,0.0,2022-03-20 20:30:20.825,1258,4276,Video Games,"[Video Games, Xbox One, Games]",14.98,1,,"B07NPYW85T,B07D13QGXM,B01N3ASPNV,B06XHLM4DX,B0771ZXXV6,B07N8T1K97,B017W17G1U,B07V3T7WF1,B00Z9TKWQO","[-1, 4089, 3913, 3527, 3573, 3806, 4075, 3213, 4195, 3053]"
1568,AECCEXRYXAFKLBRLXSLG434XB2LQ,B001EYUQNM,0.0,2022-03-20 20:30:20.825,1258,920,Video Games,"[Video Games, Legacy Systems, Xbox Systems, Xbox, Games]",6.97,1,,"B07NPYW85T,B07D13QGXM,B01N3ASPNV,B06XHLM4DX,B0771ZXXV6,B07N8T1K97,B017W17G1U,B07V3T7WF1,B00Z9TKWQO","[-1, 4089, 3913, 3527, 3573, 3806, 4075, 3213, 4195, 3053]"
2255,AECCEXRYXAFKLBRLXSLG434XB2LQ,B00L3LQ1FI,0.0,2022-03-20 20:30:20.825,1258,2755,Video Games,"[Video Games, Legacy Systems, Nintendo Systems, Wii U, Accessories]",69.98,1,,"B07NPYW85T,B07D13QGXM,B01N3ASPNV,B06XHLM4DX,B0771ZXXV6,B07N8T1K97,B017W17G1U,B07V3T7WF1,B00Z9TKWQO","[-1, 4089, 3913, 3527, 3573, 3806, 4075, 3213, 4195, 3053]"
2934,AECCEXRYXAFKLBRLXSLG434XB2LQ,B00HBUOVIO,1.0,2022-03-20 20:30:20.825,1258,2596,Grocery,"[Video Games, PlayStation 4, Games]",10.49,1,,"B07NPYW85T,B07D13QGXM,B01N3ASPNV,B06XHLM4DX,B0771ZXXV6,B07N8T1K97,B017W17G1U,B07V3T7WF1,B00Z9TKWQO","[-1, 4089, 3913, 3527, 3573, 3806, 4075, 3213, 4195, 3053]"


In [26]:
# Take the sampled user's first row with a positive rating as the test case.
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
# NOTE(review): the row's index *label* is used as a *position* into the
# feature matrix; this assumes val_df carries a default 0..n-1 index — confirm.
row_idx = test_row.name
item_feature = val_item_features[row_idx]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
# Add the batch dimension with unsqueeze instead of wrapping the array in a
# Python list: torch.tensor([ndarray]) takes a slow path and emits a warning.
item_sequence = torch.tensor(item_sequence).unsqueeze(0)
item_feature = torch.tensor(item_feature).unsqueeze(0)
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item_feature, item)

[32m2024-10-27 18:01:30.685[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mTest predicting before training with user_id = AECCEXRYXAFKLBRLXSLG434XB2LQ and parent_asin = B00HBUOVIO[0m

Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:281.)



tensor([[0.5514]], grad_fn=<SigmoidBackward0>)

#### Training loop

##### Overfit 1 batch

In [27]:
# Sanity check: the model should be able to memorise a single batch.
# Regularisation is switched off (dropout=0, l2_reg=0.0) and the training
# loader is reused for validation, so val_loss tracks the fit on that batch.
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

model = init_model(n_users, n_items, args.embedding_dim, item_feature_size, dropout=0)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# overfit_batches=1 restricts the run to one batch (Lightning also disables
# shuffling in this mode, as reported in the cell output).
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 1.7 M  | train
-----------------------------------------
1.7 M     Trainable params
0         Non-trainable params
1.7 M     Total params
6.650     Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=100` reached.
[32m2024-10-27 18:01:37.981[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m133[0m - [1mLogging classification metrics...[0m
[32m2024-10-27 18:02:10.301[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1mLogs available at /Users/dvq/frostmourne/recsys-mvp/notebooks/data/016-reduce-dim-and-no-pretrained-item-embeddings/logs/overfit/lightning_logs/version_0[0m


In [28]:
%tensorboard --logdir $trainer.log_dir

##### Fit on all data

In [29]:
# papermill_description=fit-model
# Main training run: build the ranker, wrap it in the Lightning module, then
# fit with early stopping and best-checkpoint saving.
model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    item_feature_size,
    dropout=args.dropout,
    item_embedding=pretrained_item_embedding,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    item_metadata_pipeline=item_metadata_pipeline,
    args=args,
    neg_to_pos_ratio=args.neg_to_pos_ratio,
)

# Early stopping and checkpointing both track the same metric (lower is better).
monitored_metric = "val_loss"
early_stopping = EarlyStopping(
    monitor=monitored_metric,
    mode="min",
    patience=args.early_stopping_patience,
    verbose=False,
)
# Only the single best checkpoint is kept, under <persist dir>/checkpoints.
checkpoint_callback = ModelCheckpoint(
    monitor=monitored_metric,
    mode="min",
    save_top_k=1,
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
)

log_dir = f"{args.notebook_persist_dp}/logs/run"

# Fall back to Lightning's automatic accelerator choice when no device is set.
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device or "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(
    lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 1.7 M  | train
-----------------------------------------
1.7 M     Trainable params
0         Non-trainable params
1.7 M     Total params
6.650     Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2024-10-27 18:21:51.777[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m133[0m - [1mLogging classification metrics...[0m
[32m2024-10-27 18:21:57.593[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m136[0m - [1mLogging ranking metrics...[0m


Generating recommendations:   0%|          | 0/177 [00:00<?, ?it/s]

2024/10/27 18:22:03 INFO mlflow.tracking._tracking_service.client: 🏃 View run 016-reduce-dim-and-no-pretrained-item-embeddings at: http://localhost:5002/#/experiments/3/runs/19ebdbea212b40acacf3e2c351e36dee.
2024/10/27 18:22:03 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.


RuntimeError: mat1 and mat2 shapes cannot be multiplied (18520x153 and 921x64)

In [None]:
# Smoke test: run one prediction with the freshly trained model, using the
# sample tensors (`user`, `item_sequence`, `item_feature`, `item`) prepared
# earlier in the notebook.
predict_log_msg = (
    f"Test predicting after training with {args.user_col} = {user_id} "
    f"and {args.item_col} = {item_id}"
)
logger.info(predict_log_msg)
# Switch to eval mode before predicting.
model.eval()
model.predict(user, item_sequence, item_feature, item)

# Load best checkpoint

In [None]:
# Remember where the best checkpoint lives and restore a LitRanker from it.
best_checkpoint_fp = checkpoint_callback.best_model_path
logger.info(f"Loading best checkpoint from {best_checkpoint_fp}...")
args.best_checkpoint_path = best_checkpoint_fp

# A fresh model instance is passed in for the checkpoint to be loaded into;
# dropout is set to 0 for this inference copy.
# NOTE(review): the training cell builds the model with
# `item_embedding=pretrained_item_embedding`, while this rebuild omits it.
# Confirm the two architectures still match so the checkpoint loads cleanly.
restored_model = init_model(
    n_users, n_items, args.embedding_dim, item_feature_size, dropout=0
)
best_trainer = LitRanker.load_from_checkpoint(
    best_checkpoint_fp,
    model=restored_model,
)

In [None]:
# Place the restored ranker on the same device as the trained Lightning module.
inference_device = lit_model.device
best_model = best_trainer.model.to(inference_device)

In [None]:
# Sanity check: the model restored from the best checkpoint should be able to
# score the same sample inputs used for the post-training test prediction.
best_model.eval()
# Last expression of the cell, so the notebook displays the prediction.
best_model.predict(user, item_sequence, item_feature, item)

### Persist id mapping

In [None]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    # Attach two files to the run:
    #   1. the id mapping, so that at inference we can predict based on
    #      item_ids (string) instead of item_index;
    #   2. the item_feature_metadata pipeline.
    for artifact_fp in (idm_fp, args.item_metadata_pipeline_fp):
        mlf_client.log_artifact(run_id, artifact_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [None]:
# Wrap the best-checkpoint model in the inference wrapper; this object is what
# gets logged as the MLflow pyfunc `python_model` further down.
inferrer = RankerInferenceWrapper(best_model)

In [None]:
# Example payload keyed by raw ids; it is later used to infer the MLflow model
# signature and as the logged input example. Keys are inserted in the order:
# user ids, item sequences, item feature columns, target item ids.
sample_input = {
    "user_ids": [idm.get_user_id(0)],
    "item_sequences": [[idm.get_item_id(0), idm.get_item_id(1)]],
}
# One single-element list per item feature column, taken from the first
# training row.
sample_input.update({col: [train_df[col].iloc[0]] for col in args.item_feature_cols})
sample_input["item_ids"] = [idm.get_item_id(0)]

# Matching example output, produced by calling the wrapper directly with
# indices rather than raw ids.
sample_output = inferrer.infer([0], [[0, 1]], [train_item_features[0]], [0])
sample_output

In [None]:
# Display the example input payload for visual inspection.
sample_input

In [None]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    # Base names of the local files; used to look up the matching artifact
    # URIs inside the run.
    idm_filename = idm_fp.rsplit("/", 1)[-1]
    item_metadata_pipeline_filename = args.item_metadata_pipeline_fp.rsplit("/", 1)[-1]
    # Re-open the training run so the model is logged (and registered) under it.
    with mlflow.start_run(run_id=run_id):
        # We log the id_mapping to the predict function so that it can accept
        # item_id and automatically convert it to the item index for the
        # PyTorch model to use.
        inferrer_artifacts = {
            "idm": mlflow.get_artifact_uri(idm_filename),
            "item_metadata_pipeline": mlflow.get_artifact_uri(
                item_metadata_pipeline_filename
            ),
        }
        mlflow.pyfunc.log_model(
            artifact_path="inferrer",
            python_model=inferrer,
            artifacts=inferrer_artifacts,
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )

# Set the newly trained model as champion

In [None]:
if args.log_to_mlflow:
    # Read back the validation ROC-AUC recorded on this training run.
    run_metrics = trainer.logger.experiment.get_run(
        trainer.logger.run_id
    ).data.metrics
    val_roc_auc = run_metrics["val_roc_auc"]

    # Only promote the model when it clears the configured quality bar.
    if val_roc_auc > args.min_roc_auc:
        logger.info("Aliasing the new model as champion...")
        registered_model = mlf_client.get_registered_model(args.mlf_model_name)
        # `latest_versions` holds the latest version *per stage*, in no
        # guaranteed order, so taking element [0] can pick an older version.
        # The version just registered has the highest version number, so
        # select by that instead.
        model_version = max(
            registered_model.latest_versions,
            key=lambda mv: int(mv.version),
        ).version

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value="quy.dinh",
        )

# Clean up

In [None]:
# Parameter objects whose fields are recorded on the MLflow run.
all_params = [args]

# `top_K` and `top_k` differ only by letter case, so they are logged under
# distinct names (presumably to avoid a name clash in the tracking backend;
# TODO confirm).
renamed_keys = {"top_K": "top_big_K", "top_k": "top_small_k"}

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            # Prefix every key with the params object's repr name.
            prefix = params.__repr_name__()
            params_ = {
                f"{prefix}.{renamed_keys.get(k, k)}": v
                for k, v in params.dict().items()
            }
            mlflow.log_params(params_)