# Ranker that can takes into accound different features

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import os
import sys

import dill
import lightning as L
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from torch.utils.data import DataLoader

import mlflow

load_dotenv()

sys.path.insert(0, "..")

from cfg.run_cfg import RunCfg
from src.ann import AnnIndex
from src.data_prep_utils import chunk_transform
from src.dataset import UserItemBinaryDFDataset
from src.id_mapper import IDMapper
from src.ranker.inference import RankerInferenceWrapper
from src.ranker.model import Ranker
from src.ranker.trainer import LitRanker
from src.viz import blueq_colors

# Controller

In [3]:
max_epochs = 1

In [4]:
class Args(BaseModel):
    testing: bool = False
    author: str = "quy.dinh"
    log_to_mlflow: bool = True
    experiment_name: str = "RecSys MVP - Ranker"
    run_name: str = "035-rerun-add-features-l9"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    rc: RunCfg = RunCfg().init()

    item_metadata_pipeline_fp: str = "../data/item_metadata_pipeline.dill"
    qdrant_url: str = None
    qdrant_collection_name: str = "item_desc_sbert"

    max_epochs: int = max_epochs
    batch_size: int = 128
    tfm_chunk_size: int = 10000
    neg_to_pos_ratio: int = 1

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"

    top_K: int = 100
    top_k: int = 10

    embedding_dim: int = 128
    item_sequence_ts_bucket_size: int = 10
    bucket_embedding_dim: int = 16
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.0003
    l2_reg: float = 1e-4

    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "ranker"
    min_roc_auc: float = 0.7

    best_checkpoint_path: str = None

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise Exception(f"Environment variable QDRANT_HOST is not set.")

        qdrant_port = os.getenv("QDRANT_PORT")
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        if self.device is None:
            self.device = (
                "cuda"
                if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_available() else "cpu"
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2024-11-10 16:33:56.874[0m | [34m[1mDEBUG   [0m | [36mcfg.run_cfg[0m:[36minit[0m:[36m43[0m - [34m[1mChanging use_item_tags_from_llm requires re-running notebook 002-features-v2 to get the new item_metadata_pipeline.dill file[0m
[32m2024-11-10 16:33:56.877[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m61[0m - [1mSetting up MLflow experiment RecSys MVP - Ranker - run 035-rerun-add-features-l9...[0m


{
  "testing": false,
  "author": "quy.dinh",
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Ranker",
  "run_name": "035-rerun-add-features-l9",
  "notebook_persist_dp": "/home/dvquys/frostmourne/recsys-mvp/notebooks/data/035-rerun-add-features-l9",
  "random_seed": 41,
  "device": "cuda",
  "rc": {
    "use_sbert_features": false,
    "use_item_tags_from_llm": false,
    "item_feature_cols": [
      "main_category",
      "categories",
      "price",
      "parent_asin_rating_cnt_365d",
      "parent_asin_rating_avg_prev_rating_365d",
      "parent_asin_rating_cnt_90d",
      "parent_asin_rating_avg_prev_rating_90d",
      "parent_asin_rating_cnt_30d",
      "parent_asin_rating_avg_prev_rating_30d",
      "parent_asin_rating_cnt_7d",
      "parent_asin_rating_avg_prev_rating_7d"
    ],
    "item_tags_from_llm_fp": "../data/item_tags_from_llm.parquet"
  },
  "item_metadata_pipeline_fp": "../data/item_metadata_pipeline.dill",
  "qdrant_url": "localhost:6333",
  "qdrant_coll

# Implement

In [5]:
def init_model(
    n_users,
    n_items,
    embedding_dim,
    item_sequence_ts_bucket_size,
    bucket_embedding_dim,
    item_feature_size,
    dropout,
    item_embedding=None,
):
    model = Ranker(
        n_users,
        n_items,
        embedding_dim,
        item_sequence_ts_bucket_size=item_sequence_ts_bucket_size,
        bucket_embedding_dim=bucket_embedding_dim,
        item_feature_size=item_feature_size,
        dropout=dropout,
        item_embedding=item_embedding,
    )
    return model

## Load pretrained Item2Vec embeddings

In [6]:
mlf_client = mlflow.MlflowClient()
model = mlflow.pyfunc.load_model(
    model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
)
skipgram_model = model.unwrap_python_model().model
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.size()[0]
id_mapping = model.unwrap_python_model().id_mapping
pretrained_item_embedding = skipgram_model.embeddings

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [7]:
assert (
    pretrained_item_embedding.embedding_dim == args.embedding_dim
), "Mismatch pretrained item_embedding dimension"

## Load vectorized item features

In [8]:
with open(args.item_metadata_pipeline_fp, "rb") as f:
    item_metadata_pipeline = dill.load(f)

## Load ANN Index

In [9]:
if args.rc.use_sbert_features:
    ann_index = AnnIndex(args.qdrant_url, args.qdrant_collection_name)
    vector = ann_index.get_vector_by_ids([0])[0]
    sbert_embedding_dim = vector.shape[0]
    logger.info(f"{sbert_embedding_dim=}")
    neighbors = ann_index.get_neighbors_by_ids([0])
    display(neighbors)

# Test implementation

In [10]:
embedding_dim = 8
batch_size = 2

# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 4, 5, 3, 0]
item_sequences = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 3],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 2, 1],
]
item_sequences_ts_buckets = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 3],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 2, 1],
]
main_category = [
    "All Electronics",
    "Video Games",
    "All Electronics",
    "Video Games",
    "Unknown",
]
categories = [[], ["Headsets"], ["Video Games"], [], ["blah blah"]]
tags = categories  # Mock data not important
title = ["World of Warcraft", "DotA 2", "Diablo IV", "Football Manager 2024", "Unknown"]
description = [[], [], ["Video games blah blah"], [], ["blah blah"]]
price = ["from 14.99", "14.99", "price: 9.99", "20 dollars", "None"]
parent_asin_rating_cnt_365d = [0, 1, 2, 3, 4]
parent_asin_rating_avg_prev_rating_365d = [4.0, 3.5, 4.5, 5.0, 2.0]
parent_asin_rating_cnt_90d = [0, 1, 2, 3, 4]
parent_asin_rating_avg_prev_rating_90d = [4.0, 3.5, 4.5, 5.0, 2.0]
parent_asin_rating_cnt_30d = [0, 1, 2, 3, 4]
parent_asin_rating_avg_prev_rating_30d = [4.0, 3.5, 4.5, 5.0, 2.0]
parent_asin_rating_cnt_7d = [0, 1, 2, 3, 4]
parent_asin_rating_avg_prev_rating_7d = [4.0, 3.5, 4.5, 5.0, 2.0]

train_data = {
    "user_indice": user_indices,
    "item_indice": item_indices,
    args.timestamp_col: timestamps,
    args.rating_col: ratings,
    "item_sequence": item_sequences,
    "item_sequence_ts_bucket": item_sequences_ts_buckets,
    "main_category": main_category,
    "title": title,
    "description": description,
    "categories": categories,
    "price": price,
    "parent_asin_rating_cnt_365d": parent_asin_rating_cnt_365d,
    "parent_asin_rating_avg_prev_rating_365d": parent_asin_rating_avg_prev_rating_365d,
    "parent_asin_rating_cnt_90d": parent_asin_rating_cnt_90d,
    "parent_asin_rating_avg_prev_rating_90d": parent_asin_rating_avg_prev_rating_90d,
    "parent_asin_rating_cnt_30d": parent_asin_rating_cnt_30d,
    "parent_asin_rating_avg_prev_rating_30d": parent_asin_rating_avg_prev_rating_30d,
    "parent_asin_rating_cnt_7d": parent_asin_rating_cnt_7d,
    "parent_asin_rating_avg_prev_rating_7d": parent_asin_rating_avg_prev_rating_7d,
}

if args.rc.use_item_tags_from_llm:
    train_data["tags"] = tags

train_df = pd.DataFrame(train_data)
train_item_features = item_metadata_pipeline.transform(train_df).astype(np.float32)
if args.rc.use_sbert_features:
    sbert_vectors = ann_index.get_vector_by_ids(
        train_df["item_indice"].values.tolist()
    ).astype(np.float32)
    train_item_features = np.hstack([train_item_features, sbert_vectors])

n_users = len(set(user_indices))
n_items = len(set(item_indices))
item_feature_size = train_item_features.shape[1]

model = init_model(
    n_users,
    n_items,
    embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    args.dropout,
)

# Example forward pass
model.eval()
users = torch.tensor(user_indices)
items = torch.tensor(item_indices)
item_sequences = torch.tensor(item_sequences)
item_sequences_ts_buckets = torch.tensor(item_sequences_ts_buckets)
item_features = torch.tensor(train_item_features)
predictions = model.predict(
    users, item_sequences, item_sequences_ts_buckets, item_features, items
)
print(predictions)
model.train()

tensor([[0.3810],
        [0.4688],
        [0.4127],
        [0.4198],
        [0.4208]], grad_fn=<SigmoidBackward0>)


Ranker(
  (item_embedding): Embedding(6, 8, padding_idx=5)
  (user_embedding): Embedding(3, 8)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(24, 8, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=161, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=32, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=8, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [11]:
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)

train_loader = DataLoader(
    rating_dataset, batch_size=batch_size, shuffle=False, drop_last=True
)

In [12]:
for batch_input in train_loader:
    print(batch_input)

{'user': tensor([0, 0]), 'item': tensor([0, 1]), 'rating': tensor([0., 1.]), 'item_sequence': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  3]]), 'item_sequence_ts_bucket': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  3]]), 'item_feature': tensor([[-0.0147,  5.6424, -0.0147, -0.0255, -0.0147, -0.0147, -0.0208, -0.0821,
         -0.2873, -0.0147, -0.0389, -0.0294, -0.0255, -0.0147, -0.0147, -0.0389,
         -0.0208, -0.0389, -0.0510, -2.4504, -0.1343,  7.1430, -0.6101, -0.0821,
         -0.0794, -0.0208, -0.0147, -0.0147, -0.0510, -0.1326, -0.0255, -0.0147,
         -0.0910, -0.1359, -0.0147, -0.1486, -0.1045, -0.0147, -0.0147, -0.2057,
         -0.2260, -0.0294, -0.0208, -0.0147, -0.0208, -0.0147, -0.0294, -0.0737,
         -0.0208, -0.0208, -0.0147, -0.1034, -0.0208, -0.0208, -0.0389, -0.1034,
         -0.0551, -0.1248, -0.0675, -0.1239, -0.1638, -1.1327, -0.0208, -0.0208,
         -0.1300, -0.1359, -0.0329, -0.0147, -0.0147, -0.0208, -0.0147, -0.0208,
         -0.0465, -0.1836,

In [13]:
# Prepare all item features for recommendation
all_items_df = train_df.drop_duplicates(subset=["item_indice"])
all_items_indices = all_items_df["item_indice"].values
all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
if args.rc.use_sbert_features:
    all_sbert_vectors = ann_index.get_vector_by_ids(all_items_indices.tolist()).astype(
        np.float32
    )
    all_items_features = np.hstack([all_items_features, all_sbert_vectors])

lit_model = LitRanker(
    model,
    log_dir=args.notebook_persist_dp,
    all_items_indices=all_items_indices,
    all_items_features=all_items_features,
)

# train model
trainer = L.Trainer(
    default_root_dir=f"{args.notebook_persist_dp}/test",
    max_epochs=2,
    accelerator=args.device if args.device else "auto",
)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 2.7 K  | train
-----------------------------------------
2.7 K     Trainable params
0         Non-trainable params
2.7 K     Total params
0.011     Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.
[32m2024-11-10 16:33:57.846[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m171[0m - [1mLogging classification metrics...[0m
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [14]:
# After fitting
model.eval()
predictions = model.predict(
    users, item_sequences, item_sequences_ts_buckets, item_features, items
)
print(predictions)
model.train()

tensor([[0.3702],
        [0.4686],
        [0.4217],
        [0.4018],
        [0.4010]], grad_fn=<SigmoidBackward0>)


Ranker(
  (item_embedding): Embedding(6, 8, padding_idx=5)
  (user_embedding): Embedding(3, 8)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(24, 8, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=161, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=32, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=8, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [15]:
# Get the last row of each item as input for recommendations (containing the most updated item_sequence)
to_rec_df = train_df.sort_values(args.timestamp_col, ascending=False).drop_duplicates(
    subset=["user_indice"]
)
recommendations = model.recommend(
    torch.tensor(to_rec_df["user_indice"].values.tolist()),
    torch.tensor(to_rec_df["item_sequence"].values.tolist()),
    torch.tensor(to_rec_df["item_sequence_ts_bucket"].values.tolist()),
    torch.tensor(lit_model.all_items_features),
    torch.tensor(lit_model.all_items_indices),
    k=2,
    batch_size=4,
)
recommendations

Generating recommendations:   0%|          | 0/1 [00:00<?, ?it/s]

{'user_indice': [2, 2, 1, 1, 0, 0],
 'recommendation': [1, 3, 4, 1, 4, 1],
 'score': [0.4282514750957489,
  0.40176281332969666,
  0.4904431998729706,
  0.4629577696323395,
  0.4821385443210602,
  0.46864748001098633]}

# Prep data

In [16]:
train_df = pd.read_parquet("../data/train_features_neg_df.parquet")
val_df = pd.read_parquet("../data/val_features_neg_df.parquet")
idm_fp = "../data/idm.json"
idm = IDMapper().load(idm_fp)

assert (
    train_df[args.user_col].map(lambda s: idm.get_user_index(s))
    != train_df["user_indice"]
).sum() == 0, "Mismatch IDM"
assert (
    val_df[args.user_col].map(lambda s: idm.get_user_index(s)) != val_df["user_indice"]
).sum() == 0, "Mismatch IDM"

if args.rc.use_item_tags_from_llm:
    assert (
        "tags" in train_df.columns
    ), "There is no column `tags` in train_df, please make sure you have run notebook 002, 020 with RunCfg.use_item_tags_from_llm=True"

In [17]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
0,AG57LGJFCNNQJ6P6ABQAVUKXDUDA,B0015AARJI,0.0,2016-01-12 11:59:11.000,,76.0,4.592105,10.0,4.3,3.0,...,B00J00BLRM,1452599936,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1452599936]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 0]",Video Games,PlayStation 3 Dualshock 3 Wireless Controller ...,"[Amazon.com, The Dualshock 3 wireless controll...","[Video Games, Legacy Systems, PlayStation Syst...",49.99
1,AHWG4EGOV5ZDKPETL56MAYGPLJRQ,B0BMGHMP23,0.0,2016-04-18 19:26:20.000,,,,,,,...,"B00YOGZFCO,B00KWFCSB2,B00L3LQ1FI,B0151K6J9Y,B0...","1449254540,1449256005,1449257733,1452715791,14...","[3028.0, 2742.0, 2755.0, 3159.0, 3101.0, 3036....","[1449254540, 1449256005, 1449257733, 145271579...","[5, 5, 5, 5, 5, 5, 5, 5, 5, 5]",Computers,Logitech G502 Lightspeed Wireless Gaming Mouse...,[G502 is the best gaming mouse from Logitech G...,"[Video Games, PC, Accessories, Gaming Mice]",87.95
2,AH5PTZ2U74OZ3HT6QVUWM4CV6OVQ,B009AP23NI,0.0,2016-02-10 18:45:08.000,,9.0,4.666667,0.0,,0.0,...,"B0199OXR0W,B00EVPR4FY,B00B7ELWAU,B00UH9DN58,B0...","1443454097,1455129080,1455129186,1455129499,14...","[-1.0, -1.0, 3234.0, 2508.0, 2318.0, 2964.0, 1...","[-1, -1, 1443454097, 1455129080, 1455129186, 1...","[-1, -1, 5, 1, 1, 0, 0, 0, 0, 0]",Video Games,Nintendo Wii U Pro U Controller (Japanese Vers...,[Wii U PRO controller (black) (WUP-A-RSKA)],"[Video Games, Legacy Systems, Nintendo Systems...",43.99
3,AFC5XTCF5D7J3NSDITB2Z26XWWYA,B001E8WQUY,5.0,2019-05-01 21:22:39.265,1.556746e+09,0.0,,0.0,,0.0,...,"B006HZA6VK,B0BN2FNKLM,B0086VPUHI,B0040UAYI4,B0...","1327120514,1377289907,1402605836,1402606396,14...","[1987.0, 4569.0, 2114.0, 1606.0, 2159.0, 2279....","[1327120514, 1377289907, 1402605836, 140260639...","[8, 8, 7, 7, 7, 7, 7, 7, 6, 6]",Video Games,Rock Band 2 - Nintendo Wii (Game only),"[Product description, Rock Band 2 lets you and...","[Video Games, Legacy Systems, Nintendo Systems...",28.49
4,AF7LJQOIWF3Y3YD7SGOJ34MA5JPA,B001E8WQKY,5.0,2015-01-09 12:53:25.000,1.420808e+09,16.0,4.375000,8.0,4.5,4.0,...,"B00A2ML6XG,B003VUO6LU",14208077931420807991,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, 1420807793, 1...","[-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]",Video Games,Resident Evil 5 - Xbox 360,[],"[Video Games, Legacy Systems, Xbox Systems, Xb...",29.88
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
328591,AG4RATLNVLOKZCPXN67HKOAK65CA,B078FBVJMB,0.0,2015-10-31 18:25:09.000,,,,,,,...,B00TFVD688,1425233294,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1425233294]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 5]",Video Games,A Way Out – PC Origin [Online Game Code],[From the creators of Brothers - A Tale of Two...,"[Video Games, PC, Games]",5.99
328592,AFBXO3BFWBJX6QS5NW73O37IXF2A,B0771ZXXV6,0.0,2011-03-08 02:06:38.000,,,,,,,...,"B003JVCA9Q,B0029NZ4HA",12995495171299549928,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, 1299549517, 1...","[-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]",Video Games,Nintendo Joy-Con (R) - Neon Red - Nintendo Switch,[To be determined],"[Video Games, Nintendo Switch, Accessories, Co...",
328593,AHVANA5GZNJ45UABPXWZNAF4ECBQ,B00BBF6MO6,0.0,2015-02-15 05:31:04.000,,3.0,4.666667,0.0,,0.0,...,"B002L93F0A,B002KJ02ZC,B001H4NMNA",137041433213704147071370416530,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 137...","[-1, -1, -1, -1, -1, -1, -1, 1370414332, 13704...","[-1, -1, -1, -1, -1, -1, -1, 6, 6, 6]",Video Games,Killer is Dead - Xbox 360,[Killer Is Dead is the latest title from the d...,"[Video Games, Legacy Systems, Xbox Systems, Xb...",39.82
328594,AHAVA5VKMJ3OMOLGDZ3W45CKXEWA,B00KTORA0K,5.0,2019-05-25 04:03:51.505,1.558757e+09,3.0,4.666667,1.0,5.0,1.0,...,"B004AYCNR0,B007NUQICE,B000TYQL1O,B000SEU92W,B0...","1431150669,1431150834,1432041664,1432041986,15...","[-1.0, -1.0, -1.0, 1657.0, 2074.0, 593.0, 583....","[-1, -1, -1, 1431150669, 1431150834, 143204166...","[-1, -1, -1, 7, 7, 7, 7, 5, 5, 0]",Video Games,Just Dance 2015 - Wii,[With more than 50 million copies of Just Danc...,"[Video Games, Legacy Systems, Nintendo Systems...",33.0


In [18]:
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()
if args.rc.use_sbert_features:
    all_sbert_vectors = ann_index.get_vector_by_ids(
        item_indices.tolist(), chunk_size=1000
    ).astype(np.float32)

train_item_features = chunk_transform(
    train_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
train_item_features = train_item_features.astype(np.float32)

val_item_features = chunk_transform(
    val_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
val_item_features = val_item_features.astype(np.float32)

if args.rc.use_sbert_features:
    train_sbert_vectors = all_sbert_vectors[train_df["item_indice"].values]
    train_item_features = np.hstack([train_item_features, train_sbert_vectors])
    val_sbert_vectors = all_sbert_vectors[val_df["item_indice"].values]
    val_item_features = np.hstack([val_item_features, val_sbert_vectors])

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

Transforming chunks:   0%|          | 0/33 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/1 [00:00<?, ?it/s]

[32m2024-11-10 16:34:01.699[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m24[0m - [1mlen(user_indices)=19,578, len(item_indices)=4,630[0m


# Train

In [19]:
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)
val_rating_dataset = UserItemBinaryDFDataset(
    val_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=val_item_features,
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [20]:
n_items = len(item_indices)
n_users = len(user_indices)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    args.dropout,
)
model

Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(144, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=161, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Predict before train

In [21]:
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
305,AH4LS5HFN2VGSKNKQGNZGENHBVVA,B0BL65X86R,0.0,2021-12-20 05:58:28.874,,12.0,5.0,0.0,,0.0,...,"B00CQ9L1Z6,B017C6OK7S,B08MBMV3VG,B00JM3R6M6,B0...","1402389343,1402389660,1405159126,1405159214,14...","[2404, 3199, 4403, 2682, 1915, 2759, 2428, 430...","[1402389343, 1402389660, 1405159126, 140515921...","[8, 8, 8, 8, 8, 7, 6, 6, 2, 0]",Video Games,$25 PlayStation Store Gift Card [Digital Code],[Redeem against anything on PlayStation Store....,"[Video Games, Online Game Services, PlayStatio...",25.0
179,AG7A42N537XQMGC5URJ6VZ2JO2FA,B0823F9NB8,1.0,2021-10-19 04:50:46.379,1634619000.0,3.0,4.0,0.0,,0.0,...,"B07HGT7JC8,B087NM8BB6,B07L73KHYH,B08DFB488B,B0...","1566358144,1577074836,1577075033,1607700519,16...","[-1, -1, -1, 4002, 4338, 4050, 4370, 4387, 438...","[-1, -1, -1, 1566358144, 1577074836, 157707503...","[-1, -1, -1, 6, 6, 6, 5, 5, 5, 5]",Video Games,Hori Nintendo Switch D-Pad Controller (L) (Pok...,[Hori Nintendo Switch D-Pad Controller (L) (Po...,[],37.98
993,AG4RCXKPTC6QRORJLUSBY4SO2IAA,B009DL2TBA,0.0,2022-06-09 23:25:32.480,,4.0,3.5,0.0,,0.0,...,"B00PV515DU,B07QQ8N7LL,B0BFT941YQ,B087XRWHHL,B0...","1551303197,1578542946,1582164927,1612064838,16...","[2903, 4140, 4545, 4350, 3908, 3527, 3788, 421...","[1551303197, 1578542946, 1582164927, 161206483...","[7, 6, 6, 6, 5, 5, 5, 5, 5, 5]",Video Games,PlayStation 3 500 GB System,"[Play, Stream, and Watch your favorite games, ...","[Video Games, Legacy Systems, PlayStation Syst...",233.98
414,AEL2FTDUA5KWOWBLN2HXG3CJRFRA,B08DHTZNNF,1.0,2021-09-02 13:28:34.010,1630589000.0,13.0,3.769231,1.0,5.0,0.0,...,"B00RJNE3TK,B018K6KV68,B09JDLC31H,B087NN2K41,B0...","1462572545,1471776059,1491916082,1493417579,14...","[-1, -1, -1, -1, 2922, 3223, 4479, 4341, 4497,...","[-1, -1, -1, -1, 1462572545, 1471776059, 14919...","[-1, -1, -1, -1, 8, 8, 7, 7, 7, 5]",Computers,"Razer Mamba Elite Wired Gaming Mouse: 16,000 D...",[The Razer Mamba Elite features our acclaimed ...,"[Video Games, Legacy Systems, PlayStation Syst...",36.95
1654,AEFTCHDMQEMR77C5VMPO7OH5CDQA,B00XBLQCLQ,1.0,2022-02-05 03:59:17.229,1644034000.0,0.0,,0.0,,0.0,...,"B00503E8S2,B005GISQX4,B00BHRD4BM,B00ZQB28XK,B0...","1321458588,1350061407,1418486090,1471894225,14...","[-1, -1, 1806, 1921, 2343, 3081, 3058, 2207, 1...","[-1, -1, 1321458588, 1350061407, 1418486090, 1...","[-1, -1, 9, 8, 8, 8, 8, 7, 7, 6]",Video Games,Assassin’s Creed Syndicate - Gold Edition | PC...,"[London, 1868. The Industrial Revolution unlea...","[Video Games, PC, Games]",66.35
730,AHMB2TBSRX4E5LG5A4BBYRXTFSLA,B0BVPBPLZL,1.0,2022-05-20 19:28:55.185,1653075000.0,0.0,,0.0,,0.0,...,"B0154OZU4M,B01N4JYY1H,B01D63UU52,B06X9898LZ,B0...","1518371361,1545676142,1554156978,1564847685,15...","[-1, -1, -1, -1, -1, 3161, 3528, 3303, 3563, 3...","[-1, -1, -1, -1, -1, 1518371361, 1545676142, 1...","[-1, -1, -1, -1, -1, 7, 7, 7, 6, 6]",Cell Phones & Accessories,TotalMount for Nintendo Switch – Mounts Ninten...,[],"[Video Games, Nintendo Switch, Accessories, Mo...",24.99
94,AGKLFWGDSETQCTZ62LHABEGT3YTQ,B087NN2K41,1.0,2021-10-01 18:42:53.063,1633114000.0,22.0,4.909091,3.0,4.666667,0.0,...,"B07R4C3ZVH,B07DJ6D4ZY,B01GY3601E,B01N4X95QL,B0...","1527729322,1546119187,1546216617,1557103543,15...","[4145, 3931, 3379, 3531, 2402, 4016, 4440, 420...","[1527729322, 1546119187, 1546216617, 155710354...","[7, 6, 6, 6, 6, 6, 5, 5, 5, 5]",Video Games,Mario Kart 8 Deluxe – Booster Course Pass - Ni...,[​Double your course options with the Mario Ka...,"[Video Games, Game Genre of the Month]",24.88
824,AF7SKKZS25RBMFPGXRTX5DWLPKZA,B000VIUNZI,1.0,2021-12-22 02:18:29.668,1640140000.0,0.0,,0.0,,0.0,...,"B00ATNFLDY,B0080CAOYM,B002JTX5RK,B00503E9FY,B0...","1452372758,1452372877,1452372928,1452373043,14...","[-1, -1, -1, -1, 2295, 2102, 1337, 1807, 1178,...","[-1, -1, -1, -1, 1452372758, 1452372877, 14523...","[-1, -1, -1, -1, 8, 8, 8, 8, 8, 8]",Video Games,PSP AV Cable,[Analog Audio and Video Connection Cable for u...,"[Video Games, Legacy Systems, PlayStation Syst...",15.99
61,AH6LKVRXW6JNSFAAGCUTDPZAZQ5Q,B0C3KYVDWT,1.0,2021-12-21 21:29:06.336,1640122000.0,37.0,4.945946,2.0,5.0,0.0,...,"B002I0J5JW,B002BRZ712,B005OGL766,B001EHD9N8,B0...","1402002580,1404595385,1415317246,1436474670,14...","[1304, 1207, 1947, 758, 575, 3453, 1579, 3135,...","[1402002580, 1404595385, 1415317246, 143647467...","[8, 8, 8, 8, 8, 7, 7, 6, 6, 6]",Computers,"SanDisk 128GB microSDXC-Card, Licensed for Nin...","[With incredible speed, the officially license...","[Video Games, Nintendo Switch, Accessories]",14.99
1704,AHJUN2GLK4WZ6C4QOZFMC7VXE6HA,B000FQBF1M,0.0,2021-08-16 13:23:17.734,,0.0,,0.0,,0.0,...,"B07WXKQWTV,B07DK1H3H5,B094YHB1QK,B08FC5L3RG,B0...","1602018186,1607967431,1610282831,1614692333,16...","[4231, 3936, 4455, 4376, 1765, 4318, 3385, 437...","[1602018186, 1607967431, 1610282831, 161469233...","[5, 5, 5, 5, 5, 5, 5, 5, 5, 5]",Video Games,Killzone 2 - Playstation 3,"[Product Description, In and out within a mont...","[Video Games, Legacy Systems, PlayStation Syst...",9.85


In [22]:
user_id = val_df.sample(1)[args.user_col].values[0]
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
1457,AHALPVQPNWXDAXMQ4DGOKASLXMZA,B002B9XB0E,0.0,2021-12-10 02:10:57.888,,1.0,5.0,0.0,,0.0,...,"B0029LJIFG,B0050SYEYK,B005OGKYVK,B001EYUQTQ,B0086VPUHI",14672653161467265322146726532514672653311467265337,"[-1, -1, -1, -1, -1, 1197, 1846, 1944, 923, 2114]","[-1, -1, -1, -1, -1, 1467265316, 1467265322, 1467265325, 1467265331, 1467265337]","[-1, -1, -1, -1, -1, 8, 8, 8, 8, 8]",Computers,Buffalo iBuffalo Classic USB Gamepad for PC,"[Buffalo super Nintendo famicom snes, USB gamepad for pc. 8 buttons product. USB connection. Support windows 7/ME/2000/XP/Vista.]","[Video Games, PC, Accessories, Controllers, Gamepads & Standard Controllers]",
1549,AHALPVQPNWXDAXMQ4DGOKASLXMZA,B08N7PX4WX,1.0,2021-12-10 02:10:57.888,1639102000.0,10.0,4.7,1.0,5.0,0.0,...,"B0029LJIFG,B0050SYEYK,B005OGKYVK,B001EYUQTQ,B0086VPUHI",14672653161467265322146726532514672653311467265337,"[-1, -1, -1, -1, -1, 1197, 1846, 1944, 923, 2114]","[-1, -1, -1, -1, -1, 1467265316, 1467265322, 1467265325, 1467265331, 1467265337]","[-1, -1, -1, -1, -1, 8, 8, 8, 8, 8]",Video Games,PowerA Joy Con Comfort Grips for Nintendo Switch - Black,[How to setup: Align the left Joy-Con controllers with the left rail of the Joy-Con Comfort Grip making sure the Minus symbol is at the top. Slide the Joy-Con controller all the way down until you see the player indicator light turn on. Repeat for the right Joy-Con controller.],"[Video Games, Nintendo Switch, Accessories, Hand Grips]",9.88


In [23]:
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
item_sequence_ts_bucket = test_row["item_sequence_ts_bucket"]
row_idx = test_row.name
item_feature = val_item_features[row_idx]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
item_sequence = torch.tensor([item_sequence])
item_sequence_ts_bucket = torch.tensor([item_sequence_ts_bucket])
item_feature = torch.tensor([item_feature])
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
model.train()

[32m2024-11-10 16:34:01.895[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mTest predicting before training with user_id = AHALPVQPNWXDAXMQ4DGOKASLXMZA and parent_asin = B08N7PX4WX[0m

Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:278.)



Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(144, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=161, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Training loop

##### Overfit 1 batch

In [None]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    dropout=0,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
    accelerator=args.device,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 3.3 M  | train
-----------------------------------------
3.3 M     Trainable params
0         Non-trainable params
3.3 M     Total params
13.165    Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=100` reached.
[32m2024-11-10 16:34:09.924[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m171[0m - [1mLogging classification metrics...[0m
[32m2024-11-10 16:34:27.759[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m37[0m - [1mLogs available at /home/dvquys/frostmourne/recsys-mvp/notebooks/data/035-rerun-add-features-l9/logs/overfit/lightning_logs/version_6[0m


In [41]:
%tensorboard --logdir $trainer.log_dir --host localhost

##### Fit on all data

In [26]:
all_items_df = train_df.drop_duplicates(subset=["item_indice"])
all_items_indices = all_items_df["item_indice"].values
all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
logger.info(
    f"Mean std over categorical and numerical features: {all_items_features.std(axis=0).mean()}"
)
if args.rc.use_sbert_features:
    all_sbert_vectors = ann_index.get_vector_by_ids(all_items_indices.tolist()).astype(
        np.float32
    )
    logger.info(f"Mean std over text features: {all_sbert_vectors.std(axis=0).mean()}")
    all_items_features = np.hstack([all_items_features, all_sbert_vectors])

[32m2024-11-10 16:34:28.859[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mMean std over categorical and numerical features: 0.996731698513031[0m


In [None]:
# papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_loss", patience=args.early_stopping_patience, mode="min", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    dropout=args.dropout,
    item_embedding=pretrained_item_embedding,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    all_items_indices=all_items_indices,
    all_items_features=all_items_features,
    args=args,
    neg_to_pos_ratio=args.neg_to_pos_ratio,
    checkpoint_callback=checkpoint_callback,
    accelerator=args.device,
)

In [28]:
log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Checkpoint directory /home/dvquys/frostmourne/recsys-mvp/notebooks/data/035-rerun-add-features-l9/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 3.3 M  | train
-----------------------------------------
3.3 M     Trainable params
0         Non-trainable params
3.3 M     Total params
13.165    Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.



Training: |                                                                                                   …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=1` reached.
[32m2024-11-10 16:34:50.111[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m164[0m - [1mLoading best model from /home/dvquys/frostmourne/recsys-mvp/notebooks/data/035-rerun-add-features-l9/checkpoints/best-checkpoint-v6.ckpt...[0m
[32m2024-11-10 16:34:50.274[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m171[0m - [1mLogging classification metrics...[0m
[32m2024-11-10 16:34:50.946[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m174[0m - [1mLogging ranking metrics...[0m


Generating recommendations:   0%|          | 0/177 [00:00<?, ?it/s]

2024/11/10 16:34:57 INFO mlflow.tracking._tracking_service.client: 🏃 View run 035-rerun-add-features-l9 at: http://localhost:5002/#/experiments/3/runs/53672526046b44969cae410b490d2857.
2024/11/10 16:34:57 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.


In [29]:
logger.info(
    f"Test predicting after training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
model.eval()
model = model.to(user.device)  # Move model back to align with data device
model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
model.train()

[32m2024-11-10 16:34:57.094[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mTest predicting after training with user_id = AHALPVQPNWXDAXMQ4DGOKASLXMZA and parent_asin = B08N7PX4WX[0m


Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(144, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=161, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

# Load best checkpoint

In [30]:
logger.info(f"Loading best checkpoint from {checkpoint_callback.best_model_path}...")
args.best_checkpoint_path = checkpoint_callback.best_model_path

best_trainer = LitRanker.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    model=init_model(
        n_users,
        n_items,
        args.embedding_dim,
        args.item_sequence_ts_bucket_size,
        args.bucket_embedding_dim,
        item_feature_size,
        dropout=0,
        item_embedding=pretrained_item_embedding,
    ),
)

[32m2024-11-10 16:34:57.120[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mLoading best checkpoint from /home/dvquys/frostmourne/recsys-mvp/notebooks/data/035-rerun-add-features-l9/checkpoints/best-checkpoint-v6.ckpt...[0m


In [31]:
best_model = best_trainer.model.to(lit_model.device)

In [32]:
best_model.eval()
best_model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
best_model.train()

Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(144, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=161, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

### Persist artifacts

In [33]:
if args.log_to_mlflow:
    # Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    mlf_client.log_artifact(run_id, idm_fp)
    # Persist item_feature_metadata pipeline
    mlf_client.log_artifact(run_id, args.item_metadata_pipeline_fp)

    # Persist model architecture
    model_architecture_fp = f"{args.notebook_persist_dp}/model_architecture.txt"
    with open(model_architecture_fp, "w") as f:
        f.write(repr(model))
    mlf_client.log_artifact(run_id, model_architecture_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [34]:
inferrer = RankerInferenceWrapper(best_model)

In [35]:
def generate_sample_item_features():
    sample_row = train_df.iloc[0].fillna(0)
    output = dict()
    for col in args.rc.item_feature_cols:
        v = sample_row[col]
        if isinstance(v, np.ndarray):
            v = "__".join(
                sample_row[col].tolist()
            )  # Workaround to avoid MLflow Got error: Per-column arrays must each be 1-dimensional
        output[col] = [v]
    return output

In [36]:
sample_input = {
    args.user_col: [idm.get_user_id(0)],
    "item_sequence": [",".join([idm.get_item_id(0), idm.get_item_id(1)])],
    "item_sequence_ts": [
        "1095133116,109770848"
    ],  # Here we input unix timestamp seconds instead of timestamp bucket because we need to calculate the bucket
    # **{col: [train_df.iloc[0].fillna(0)[col]] for col in args.item_feature_cols},
    **generate_sample_item_features(),
    args.item_col: [idm.get_item_id(0)],
}
sample_output = inferrer.infer([0], [[0, 1]], [[2, 0]], [train_item_features[0]], [0])
sample_output

array([0.72171295], dtype=float32)

In [37]:
sample_input

{'user_id': ['AE225O22SA7DLBOGOEIFL7FT5VYQ'],
 'item_sequence': ['0375869026,9625990674'],
 'item_sequence_ts': ['1095133116,109770848'],
 'main_category': ['Video Games'],
 'categories': ['Video Games__Legacy Systems__PlayStation Systems__PlayStation 3__Accessories__Controllers'],
 'price': ['49.99'],
 'parent_asin_rating_cnt_365d': [76.0],
 'parent_asin_rating_avg_prev_rating_365d': [4.592105263157895],
 'parent_asin_rating_cnt_90d': [10.0],
 'parent_asin_rating_avg_prev_rating_90d': [4.3],
 'parent_asin_rating_cnt_30d': [3.0],
 'parent_asin_rating_avg_prev_rating_30d': [5.0],
 'parent_asin_rating_cnt_7d': [1.0],
 'parent_asin_rating_avg_prev_rating_7d': [5.0],
 'parent_asin': ['0375869026']}

In [38]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    idm_filename = idm_fp.split("/")[-1]
    item_metadata_pipeline_filename = args.item_metadata_pipeline_fp.split("/")[-1]
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            artifacts={
                # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
                "idm": mlflow.get_artifact_uri(idm_filename),
                "item_metadata_pipeline": mlflow.get_artifact_uri(
                    item_metadata_pipeline_filename
                ),
            },
            model_config={"use_sbert_features": args.rc.use_sbert_features},
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )


Since MLflow 2.16.0, we no longer convert dictionary input example to pandas Dataframe, and directly save it as a json object. If the model expects a pandas DataFrame input instead, please pass the pandas DataFrame as input example directly.



Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Registered model 'ranker' already exists. Creating a new version of this model...
2024/11/10 16:35:00 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: ranker, version 3
Created version '3' of model 'ranker'.


Downloading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

2024/11/10 16:35:00 INFO mlflow.tracking._tracking_service.client: 🏃 View run 035-rerun-add-features-l9 at: http://localhost:5002/#/experiments/3/runs/53672526046b44969cae410b490d2857.
2024/11/10 16:35:00 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.


# Set the newly trained model as champion

In [39]:
if args.log_to_mlflow:
    val_roc_auc = trainer.logger.experiment.get_run(trainer.logger.run_id).data.metrics[
        "val_roc_auc"
    ]

    if val_roc_auc > args.min_roc_auc:
        logger.info(f"Aliasing the new model as champion...")
        model_version = (
            mlf_client.get_registered_model(args.mlf_model_name)
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value=args.author,
        )

[32m2024-11-10 16:35:00.872[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mAliasing the new model as champion...[0m


# Clean up

In [40]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.dict()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)

2024/11/10 16:35:00 INFO mlflow.tracking._tracking_service.client: 🏃 View run 035-rerun-add-features-l9 at: http://localhost:5002/#/experiments/3/runs/53672526046b44969cae410b490d2857.
2024/11/10 16:35:00 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.
