# Ranker that can takes into accound different features

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import os
import sys

import dill
import lightning as L
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from torch.utils.data import DataLoader

import mlflow

load_dotenv()

sys.path.insert(0, "..")

from cfg.run_cfg import RunCfg
from src.ann import AnnIndex
from src.data_prep_utils import chunk_transform
from src.dataset import UserItemBinaryDFDataset
from src.id_mapper import IDMapper
from src.ranker.inference import RankerInferenceWrapper
from src.ranker.model import Ranker
from src.ranker.trainer import LitRanker
from src.viz import blueq_colors

# Controller

In [3]:
max_epochs = 100

In [4]:
class Args(BaseModel):
    testing: bool = False
    author: str = "quy.dinh"
    log_to_mlflow: bool = True
    experiment_name: str = "RecSys MVP - Ranker"
    run_name: str = "037-add-llm-item-tags"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    rc: RunCfg = RunCfg().init()

    item_metadata_pipeline_fp: str = "../data/item_metadata_pipeline.dill"
    qdrant_url: str = None
    qdrant_collection_name: str = "item_desc_sbert"

    max_epochs: int = max_epochs
    batch_size: int = 128
    tfm_chunk_size: int = 10000
    neg_to_pos_ratio: int = 1

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"

    top_K: int = 100
    top_k: int = 10

    embedding_dim: int = 128
    item_sequence_ts_bucket_size: int = 10
    bucket_embedding_dim: int = 16
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.0003
    l2_reg: float = 1e-4

    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "ranker"
    min_roc_auc: float = 0.7

    best_checkpoint_path: str = None

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise Exception(f"Environment variable QDRANT_HOST is not set.")

        qdrant_port = os.getenv("QDRANT_PORT")
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        if self.device is None:
            self.device = (
                "cuda"
                if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_available() else "cpu"
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2024-11-10 19:23:55.703[0m | [34m[1mDEBUG   [0m | [36mcfg.run_cfg[0m:[36minit[0m:[36m36[0m - [34m[1mSetting use_sbert_features=True requires running notebook 016-sentence-transformers[0m
[32m2024-11-10 19:23:55.704[0m | [34m[1mDEBUG   [0m | [36mcfg.run_cfg[0m:[36minit[0m:[36m40[0m - [34m[1mSetting use_item_tags_from_llm=True requires running notebook 040-retrieve-item-tags-from-llm[0m
[32m2024-11-10 19:23:55.704[0m | [34m[1mDEBUG   [0m | [36mcfg.run_cfg[0m:[36minit[0m:[36m43[0m - [34m[1mChanging use_item_tags_from_llm requires re-running notebook 002-features-v2 to get the new item_metadata_pipeline.dill file[0m
[32m2024-11-10 19:23:55.707[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m61[0m - [1mSetting up MLflow experiment RecSys MVP - Ranker - run 037-add-llm-item-tags...[0m


{
  "testing": false,
  "author": "quy.dinh",
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Ranker",
  "run_name": "037-add-llm-item-tags",
  "notebook_persist_dp": "/home/dvquys/frostmourne/recsys-mvp/notebooks/data/037-add-llm-item-tags",
  "random_seed": 41,
  "device": "cuda",
  "rc": {
    "use_sbert_features": true,
    "use_item_tags_from_llm": true,
    "item_feature_cols": [
      "main_category",
      "categories",
      "price",
      "parent_asin_rating_cnt_365d",
      "parent_asin_rating_avg_prev_rating_365d",
      "parent_asin_rating_cnt_90d",
      "parent_asin_rating_avg_prev_rating_90d",
      "parent_asin_rating_cnt_30d",
      "parent_asin_rating_avg_prev_rating_30d",
      "parent_asin_rating_cnt_7d",
      "parent_asin_rating_avg_prev_rating_7d",
      "tags"
    ],
    "item_tags_from_llm_fp": "../data/item_tags_from_llm.parquet"
  },
  "item_metadata_pipeline_fp": "../data/item_metadata_pipeline.dill",
  "qdrant_url": "localhost:6333",
  "qdrant_

# Implement

In [5]:
def init_model(
    n_users,
    n_items,
    embedding_dim,
    item_sequence_ts_bucket_size,
    bucket_embedding_dim,
    item_feature_size,
    dropout,
    item_embedding=None,
):
    model = Ranker(
        n_users,
        n_items,
        embedding_dim,
        item_sequence_ts_bucket_size=item_sequence_ts_bucket_size,
        bucket_embedding_dim=bucket_embedding_dim,
        item_feature_size=item_feature_size,
        dropout=dropout,
        item_embedding=item_embedding,
    )
    return model

## Load pretrained Item2Vec embeddings

In [6]:
mlf_client = mlflow.MlflowClient()
model = mlflow.pyfunc.load_model(
    model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
)
skipgram_model = model.unwrap_python_model().model
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.size()[0]
id_mapping = model.unwrap_python_model().id_mapping
pretrained_item_embedding = skipgram_model.embeddings

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [7]:
assert (
    pretrained_item_embedding.embedding_dim == args.embedding_dim
), "Mismatch pretrained item_embedding dimension"

## Load vectorized item features

In [8]:
with open(args.item_metadata_pipeline_fp, "rb") as f:
    item_metadata_pipeline = dill.load(f)

## Load ANN Index

In [9]:
if args.rc.use_sbert_features:
    ann_index = AnnIndex(args.qdrant_url, args.qdrant_collection_name)
    vector = ann_index.get_vector_by_ids([0])[0]
    sbert_embedding_dim = vector.shape[0]
    logger.info(f"{sbert_embedding_dim=}")
    neighbors = ann_index.get_neighbors_by_ids([0])
    display(neighbors)

  0%|          | 0/1 [00:00<?, ?it/s]

[32m2024-11-10 19:23:56.200[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1msbert_embedding_dim=768[0m


  0%|          | 0/1 [00:00<?, ?it/s]

[ScoredPoint(id=0, version=0, score=0.0, payload={'parent_asin': '0375869026', 'title': 'Wonder'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=1916, version=59, score=0.3048898, payload={'parent_asin': 'B005GFPZYK', 'title': 'American Sniper: The Autobiography of the Most Lethal Sniper in U.S. Military History'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=166, version=5, score=0.35464483, payload={'parent_asin': 'B00005OARM', 'title': 'Golden Sun'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=3845, version=120, score=0.35932326, payload={'parent_asin': 'B0794W1LWG', 'title': 'Life is Strange: Before The Storm Limited Edition - Xbox One'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=3896, version=121, score=0.36754838, payload={'parent_asin': 'B07CD6F5PX', 'title': 'Dragon Quest Xi: Echoes of An Elusive Age - PlayStation 4'}, vector=None, shard_key=None, order_value=None)]

# Test implementation

In [10]:
embedding_dim = 8
batch_size = 2

# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 4, 5, 3, 0]
item_sequences = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 3],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 2, 1],
]
item_sequences_ts_buckets = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 3],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 2, 1],
]
main_category = [
    "All Electronics",
    "Video Games",
    "All Electronics",
    "Video Games",
    "Unknown",
]
categories = [[], ["Headsets"], ["Video Games"], [], ["blah blah"]]
tags = categories  # Mock data not important
title = ["World of Warcraft", "DotA 2", "Diablo IV", "Football Manager 2024", "Unknown"]
description = [[], [], ["Video games blah blah"], [], ["blah blah"]]
price = ["from 14.99", "14.99", "price: 9.99", "20 dollars", "None"]
parent_asin_rating_cnt_365d = [0, 1, 2, 3, 4]
parent_asin_rating_avg_prev_rating_365d = [4.0, 3.5, 4.5, 5.0, 2.0]
parent_asin_rating_cnt_90d = [0, 1, 2, 3, 4]
parent_asin_rating_avg_prev_rating_90d = [4.0, 3.5, 4.5, 5.0, 2.0]
parent_asin_rating_cnt_30d = [0, 1, 2, 3, 4]
parent_asin_rating_avg_prev_rating_30d = [4.0, 3.5, 4.5, 5.0, 2.0]
parent_asin_rating_cnt_7d = [0, 1, 2, 3, 4]
parent_asin_rating_avg_prev_rating_7d = [4.0, 3.5, 4.5, 5.0, 2.0]

train_data = {
    "user_indice": user_indices,
    "item_indice": item_indices,
    args.timestamp_col: timestamps,
    args.rating_col: ratings,
    "item_sequence": item_sequences,
    "item_sequence_ts_bucket": item_sequences_ts_buckets,
    "main_category": main_category,
    "title": title,
    "description": description,
    "categories": categories,
    "price": price,
    "parent_asin_rating_cnt_365d": parent_asin_rating_cnt_365d,
    "parent_asin_rating_avg_prev_rating_365d": parent_asin_rating_avg_prev_rating_365d,
    "parent_asin_rating_cnt_90d": parent_asin_rating_cnt_90d,
    "parent_asin_rating_avg_prev_rating_90d": parent_asin_rating_avg_prev_rating_90d,
    "parent_asin_rating_cnt_30d": parent_asin_rating_cnt_30d,
    "parent_asin_rating_avg_prev_rating_30d": parent_asin_rating_avg_prev_rating_30d,
    "parent_asin_rating_cnt_7d": parent_asin_rating_cnt_7d,
    "parent_asin_rating_avg_prev_rating_7d": parent_asin_rating_avg_prev_rating_7d,
}

if args.rc.use_item_tags_from_llm:
    train_data["tags"] = tags

train_df = pd.DataFrame(train_data)
train_item_features = item_metadata_pipeline.transform(train_df).astype(np.float32)
if args.rc.use_sbert_features:
    sbert_vectors = ann_index.get_vector_by_ids(
        train_df["item_indice"].values.tolist()
    ).astype(np.float32)
    train_item_features = np.hstack([train_item_features, sbert_vectors])

n_users = len(set(user_indices))
n_items = len(set(item_indices))
item_feature_size = train_item_features.shape[1]

model = init_model(
    n_users,
    n_items,
    embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    args.dropout,
)

# Example forward pass
model.eval()
users = torch.tensor(user_indices)
items = torch.tensor(item_indices)
item_sequences = torch.tensor(item_sequences)
item_sequences_ts_buckets = torch.tensor(item_sequences_ts_buckets)
item_features = torch.tensor(train_item_features)
predictions = model.predict(
    users, item_sequences, item_sequences_ts_buckets, item_features, items
)
print(predictions)
model.train()

  0%|          | 0/1 [00:00<?, ?it/s]

tensor([[0.4574],
        [0.4652],
        [0.4525],
        [0.5698],
        [0.5422]], grad_fn=<SigmoidBackward0>)


Ranker(
  (item_embedding): Embedding(6, 8, padding_idx=5)
  (user_embedding): Embedding(3, 8)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(24, 8, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=1229, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=32, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=8, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [11]:
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)

train_loader = DataLoader(
    rating_dataset, batch_size=batch_size, shuffle=False, drop_last=True
)

In [12]:
for batch_input in train_loader:
    print(batch_input)

{'user': tensor([0, 0]), 'item': tensor([0, 1]), 'rating': tensor([0., 1.]), 'item_sequence': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  3]]), 'item_sequence_ts_bucket': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  3]]), 'item_feature': tensor([[-1.4698e-02,  5.6424e+00, -1.4698e-02,  ...,  2.2739e-02,
         -2.3894e-02,  1.1594e-03],
        [-1.4698e-02, -1.7723e-01, -1.4698e-02,  ..., -2.3171e-03,
         -2.9986e-02,  9.9219e-03]])}
{'user': tensor([1, 2]), 'item': tensor([2, 3]), 'rating': tensor([1., 1.]), 'item_sequence': tensor([[-1, -1,  1,  3],
        [-1, -1,  2,  1]]), 'item_sequence_ts_bucket': tensor([[-1, -1,  1,  3],
        [-1, -1,  2,  1]]), 'item_feature': tensor([[-1.4698e-02,  5.6424e+00, -1.4698e-02,  ...,  2.1281e-03,
         -3.6321e-02,  4.1318e-04],
        [-1.4698e-02, -1.7723e-01, -1.4698e-02,  ..., -4.9983e-03,
         -5.8112e-02,  4.0446e-03]])}


In [13]:
# Prepare all item features for recommendation
all_items_df = train_df.drop_duplicates(subset=["item_indice"])
all_items_indices = all_items_df["item_indice"].values
all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
if args.rc.use_sbert_features:
    all_sbert_vectors = ann_index.get_vector_by_ids(all_items_indices.tolist()).astype(
        np.float32
    )
    all_items_features = np.hstack([all_items_features, all_sbert_vectors])

lit_model = LitRanker(
    model,
    log_dir=args.notebook_persist_dp,
    all_items_indices=all_items_indices,
    all_items_features=all_items_features,
)

# train model
trainer = L.Trainer(
    default_root_dir=f"{args.notebook_persist_dp}/test",
    max_epochs=2,
    accelerator=args.device if args.device else "auto",
)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

  0%|          | 0/1 [00:00<?, ?it/s]

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 11.2 K | train
-----------------------------------------
11.2 K    Trainable params
0         Non-trainable params
11.2 K    Total params
0.045     Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.
[32m2024-11-10 19:23:56.726[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m171[0m - [1mLogging classification metrics...[0m


In [14]:
# After fitting
model.eval()
predictions = model.predict(
    users, item_sequences, item_sequences_ts_buckets, item_features, items
)
print(predictions)
model.train()

tensor([[0.4590],
        [0.4805],
        [0.4564],
        [0.5912],
        [0.5565]], grad_fn=<SigmoidBackward0>)


Ranker(
  (item_embedding): Embedding(6, 8, padding_idx=5)
  (user_embedding): Embedding(3, 8)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(24, 8, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=1229, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=32, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=8, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [15]:
# Get the last row of each item as input for recommendations (containing the most updated item_sequence)
to_rec_df = train_df.sort_values(args.timestamp_col, ascending=False).drop_duplicates(
    subset=["user_indice"]
)
recommendations = model.recommend(
    torch.tensor(to_rec_df["user_indice"].values.tolist()),
    torch.tensor(to_rec_df["item_sequence"].values.tolist()),
    torch.tensor(to_rec_df["item_sequence_ts_bucket"].values.tolist()),
    torch.tensor(lit_model.all_items_features),
    torch.tensor(lit_model.all_items_indices),
    k=2,
    batch_size=4,
)
recommendations

Generating recommendations:   0%|          | 0/1 [00:00<?, ?it/s]

{'user_indice': [2, 2, 1, 1, 0, 0],
 'recommendation': [3, 1, 4, 3, 4, 3],
 'score': [0.591189444065094,
  0.5595508813858032,
  0.5168175101280212,
  0.4937076270580292,
  0.5425406098365784,
  0.49553555250167847]}

# Prep data

In [16]:
train_df = pd.read_parquet("../data/train_features_neg_df.parquet")
val_df = pd.read_parquet("../data/val_features_neg_df.parquet")
idm_fp = "../data/idm.json"
idm = IDMapper().load(idm_fp)

assert (
    train_df[args.user_col].map(lambda s: idm.get_user_index(s))
    != train_df["user_indice"]
).sum() == 0, "Mismatch IDM"
assert (
    val_df[args.user_col].map(lambda s: idm.get_user_index(s)) != val_df["user_indice"]
).sum() == 0, "Mismatch IDM"

if args.rc.use_item_tags_from_llm:
    assert (
        "tags" in train_df.columns
    ), "There is no column `tags` in train_df, please make sure you have run notebook 002, 020 with RunCfg.use_item_tags_from_llm=True"

In [17]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,tags,main_category,title,description,categories,price
0,AG57LGJFCNNQJ6P6ABQAVUKXDUDA,B0015AARJI,0.0,2016-01-12 11:59:11.000,,76.0,4.592105,10.0,4.3,3.0,...,1452599936,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1452599936]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 0]","[Controller, Wireless, Gamepad, Ergonomic, Pla...",Video Games,PlayStation 3 Dualshock 3 Wireless Controller ...,"[Amazon.com, The Dualshock 3 wireless controll...","[Video Games, Legacy Systems, PlayStation Syst...",49.99
1,AHWG4EGOV5ZDKPETL56MAYGPLJRQ,B0BMGHMP23,0.0,2016-04-18 19:26:20.000,,,,,,,...,"1449254540,1449256005,1449257733,1452715791,14...","[3028.0, 2742.0, 2755.0, 3159.0, 3101.0, 3036....","[1449254540, 1449256005, 1449257733, 145271579...","[5, 5, 5, 5, 5, 5, 5, 5, 5, 5]","[Wireless Mouse, Gaming Precision, High DPI, C...",Computers,Logitech G502 Lightspeed Wireless Gaming Mouse...,[G502 is the best gaming mouse from Logitech G...,"[Video Games, PC, Accessories, Gaming Mice]",87.95
2,AH5PTZ2U74OZ3HT6QVUWM4CV6OVQ,B009AP23NI,0.0,2016-02-10 18:45:08.000,,9.0,4.666667,0.0,,0.0,...,"1443454097,1455129080,1455129186,1455129499,14...","[-1.0, -1.0, 3234.0, 2508.0, 2318.0, 2964.0, 1...","[-1, -1, 1443454097, 1455129080, 1455129186, 1...","[-1, -1, 5, 1, 1, 0, 0, 0, 0, 0]","[Controller, Wii U, Japanese Version, Gaming A...",Video Games,Nintendo Wii U Pro U Controller (Japanese Vers...,[Wii U PRO controller (black) (WUP-A-RSKA)],"[Video Games, Legacy Systems, Nintendo Systems...",43.99
3,AFC5XTCF5D7J3NSDITB2Z26XWWYA,B001E8WQUY,5.0,2019-05-01 21:22:39.265,1.556746e+09,0.0,,0.0,,0.0,...,"1327120514,1377289907,1402605836,1402606396,14...","[1987.0, 4569.0, 2114.0, 1606.0, 2159.0, 2279....","[1327120514, 1377289907, 1402605836, 140260639...","[8, 8, 7, 7, 7, 7, 7, 7, 6, 6]","[Music Rhythm, Band Simulation, Multiplayer, N...",Video Games,Rock Band 2 - Nintendo Wii (Game only),"[Product description, Rock Band 2 lets you and...","[Video Games, Legacy Systems, Nintendo Systems...",28.49
4,AF7LJQOIWF3Y3YD7SGOJ34MA5JPA,B001E8WQKY,5.0,2015-01-09 12:53:25.000,1.420808e+09,16.0,4.375000,8.0,4.5,4.0,...,14208077931420807991,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, 1420807793, 1...","[-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]","[Survival Horror, Action, Co-op, Zombies, Thir...",Video Games,Resident Evil 5 - Xbox 360,[],"[Video Games, Legacy Systems, Xbox Systems, Xb...",29.88
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
328591,AG4RATLNVLOKZCPXN67HKOAK65CA,B078FBVJMB,0.0,2015-10-31 18:25:09.000,,,,,,,...,1425233294,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1425233294]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 5]","[Co-Op, Action-Adventure, Narrative, Online Mu...",Video Games,A Way Out – PC Origin [Online Game Code],[From the creators of Brothers - A Tale of Two...,"[Video Games, PC, Games]",5.99
328592,AFBXO3BFWBJX6QS5NW73O37IXF2A,B0771ZXXV6,0.0,2011-03-08 02:06:38.000,,,,,,,...,12995495171299549928,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, 1299549517, 1...","[-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]","[Controller, Joy-Con, Nintendo Accessory, Wire...",Video Games,Nintendo Joy-Con (R) - Neon Red - Nintendo Switch,[To be determined],"[Video Games, Nintendo Switch, Accessories, Co...",
328593,AHVANA5GZNJ45UABPXWZNAF4ECBQ,B00BBF6MO6,0.0,2015-02-15 05:31:04.000,,3.0,4.666667,0.0,,0.0,...,137041433213704147071370416530,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 137...","[-1, -1, -1, -1, -1, -1, -1, 1370414332, 13704...","[-1, -1, -1, -1, -1, -1, -1, 6, 6, 6]","[Action, Hack and Slash, Stylish Gameplay, Sto...",Video Games,Killer is Dead - Xbox 360,[Killer Is Dead is the latest title from the d...,"[Video Games, Legacy Systems, Xbox Systems, Xb...",39.82
328594,AHAVA5VKMJ3OMOLGDZ3W45CKXEWA,B00KTORA0K,5.0,2019-05-25 04:03:51.505,1.558757e+09,3.0,4.666667,1.0,5.0,1.0,...,"1431150669,1431150834,1432041664,1432041986,15...","[-1.0, -1.0, -1.0, 1657.0, 2074.0, 593.0, 583....","[-1, -1, -1, 1431150669, 1431150834, 143204166...","[-1, -1, -1, 7, 7, 7, 7, 5, 5, 0]","[Dance Game, Family-friendly, Party Game, Fitn...",Video Games,Just Dance 2015 - Wii,[With more than 50 million copies of Just Danc...,"[Video Games, Legacy Systems, Nintendo Systems...",33.0


In [18]:
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()
if args.rc.use_sbert_features:
    all_sbert_vectors = ann_index.get_vector_by_ids(
        item_indices.tolist(), chunk_size=1000
    ).astype(np.float32)

train_item_features = chunk_transform(
    train_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
train_item_features = train_item_features.astype(np.float32)

val_item_features = chunk_transform(
    val_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
val_item_features = val_item_features.astype(np.float32)

if args.rc.use_sbert_features:
    train_sbert_vectors = all_sbert_vectors[train_df["item_indice"].values]
    train_item_features = np.hstack([train_item_features, train_sbert_vectors])
    val_sbert_vectors = all_sbert_vectors[val_df["item_indice"].values]
    val_item_features = np.hstack([val_item_features, val_sbert_vectors])

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

  0%|          | 0/5 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/33 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/1 [00:00<?, ?it/s]

[32m2024-11-10 19:24:03.978[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m24[0m - [1mlen(user_indices)=19,578, len(item_indices)=4,630[0m


# Train

In [19]:
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)
val_rating_dataset = UserItemBinaryDFDataset(
    val_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=val_item_features,
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [20]:
n_items = len(item_indices)
n_users = len(user_indices)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    args.dropout,
)
model

Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(144, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=1229, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Predict before train

In [21]:
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,tags,main_category,title,description,categories,price
1746,AFWMJSQPYQIAEJY3A6RPFDHSB7BA,B001ELJE5Q,0.0,2021-10-20 09:53:52.036,,2.0,5.0,1.0,5.0,0.0,...,"1591649949,1591649979,1591650004,1596447928,16...","[-1, -1, -1, -1, 4067, 4343, 3527, 4357, 4373,...","[-1, -1, -1, -1, 1591649949, 1591649979, 15916...","[-1, -1, -1, -1, 6, 6, 6, 6, 5, 5]","[Music Game, Band Kit, Guitar Controller, Coop...",Video Games,Wii Guitar Hero World Tour Band Kit,"[Product description, We've all dreamed of bei...","[Video Games, Legacy Systems, Nintendo Systems...",
1205,AGW7IQV7JCRQATCFS46QDGRJ2PMQ,B00OQLWM8M,0.0,2021-08-22 00:57:05.420,,0.0,,0.0,,0.0,...,"1612769950,1617677646,1619503875,1619503928,16...","[4170, 1596, 4479, 254, 3723, 3988, 3903, 695,...","[1612769950, 1617677646, 1619503875, 161950392...","[5, 5, 5, 5, 5, 5, 5, 5, 5, 5]","[Role-Playing Game, Pokemon Franchise, Dual Pa...",Video Games,Pokemon Omega Ruby and Pokemon Alpha Sapphire ...,[],"[Video Games, Legacy Systems, Nintendo Systems...",
1869,AGCJRJNBQTHOSAWCI7SW5ZV3AZPQ,B004NBXRDE,0.0,2022-04-10 20:34:55.586,,1.0,5.0,0.0,,0.0,...,"1547168420,1586266416,1611119565,1615734901,16...","[-1, -1, -1, 4134, 3594, 3686, 3818, 3851, 369...","[-1, -1, -1, 1547168420, 1586266416, 161111956...","[-1, -1, -1, 7, 6, 6, 6, 5, 5, 5]","[Racing, Adventure Game, Family-Friendly, Base...",Video Games,Cars 2 3DS,"[Product Description, Inspired by the upcoming...","[Video Games, Legacy Systems, Nintendo Systems...",49.99
74,AGPNXILEFTDD373LDWNEGXEAVGGA,B0171RL3P0,0.0,2021-08-30 13:08:09.315,,0.0,,0.0,,0.0,...,"1500398091,1500398564,1500399206,1500399307,15...","[-1, -1, -1, -1, -1, 3601, 2478, 3362, 3681, 4...","[-1, -1, -1, -1, -1, 1500398091, 1500398564, 1...","[-1, -1, -1, -1, -1, 7, 7, 7, 7, 7]","[Amiibo, Collector's Item, Card Holder, Ninten...",Video Games,HORI Amiibo Card Folio Officially Licensed by ...,[Officially Licensed by Nintendo. Store and or...,"[Video Games, Legacy Systems, Nintendo Systems...",
1145,AE76SZJUUWJVL62I2Q44F37SPFXQ,B07D3JSZ8F,1.0,2021-09-27 09:37:20.508,1632735000.0,6.0,4.833333,2.0,5.0,1.0,...,"1515276861,1516560815,1521206577,1521316155,15...","[-1, -1, -1, -1, 4577, 2643, 3150, 2143, 4003,...","[-1, -1, -1, -1, 1515276861, 1516560815, 15212...","[-1, -1, -1, -1, 7, 7, 7, 7, 6, 6]","[Tactical RPG, Strategy, Crossover Game, Digit...",Video Games,Mario + Rabbids Kingdom Battle - Nintendo Swit...,[Two worlds collide in Mario + Rabbids Kingdom...,"[Video Games, Nintendo Switch, Accessories]",59.99
1071,AH2TF7A2AXRSXVAYTV2CB2OHUYHA,B00Z9TM1KY,0.0,2022-01-03 13:53:41.920,,6.0,4.666667,1.0,5.0,0.0,...,"1436247878,1441733146,1452635529,1470854964,14...","[-1, 2726, 2186, 2436, 3052, 3256, 3334, 3966,...","[-1, 1436247878, 1441733146, 1452635529, 14708...","[-1, 8, 8, 8, 8, 8, 7, 7, 6, 6]","[Open World, Action-Adventure, Single Player, ...",Video Games,Mafia III - PC,"[1968. New Bordeaux., After years of combat in...","[Video Games, PC, Games]",7.96
1518,AG7VW3N64W3DWP46K3ZPBPTHPS5A,B073SC6V1D,1.0,2022-06-24 11:00:40.306,1656068000.0,1.0,5.0,0.0,,0.0,...,"1368635174,1427214888,1427944362,1473917569,14...","[1556, 2844, 2985, 3665, 2542, 3228, 3319, 338...","[1368635174, 1427214888, 1427944362, 147391756...","[8, 8, 8, 8, 8, 8, 7, 7, 7, 7]","[Gaming Accessories, RGB Backlit, Gaming Mouse...",Computers,"havit Gaming Keyboard and Mouse Combo, Backlit...",[],"[Video Games, PC, Accessories, Gaming Keyboards]",
1212,AED7Y4FAIQCIYYVDJ6WIW2XTP4YA,B004LLHFAW,0.0,2021-09-15 20:43:08.668,,4.0,3.5,1.0,5.0,0.0,...,"1222549506,1222550100,1222550300,1222550607,12...","[827, 1043, 1703, 442, 1102, 668, 2055, 3151, ...","[1222549506, 1222550100, 1222550300, 122255060...","[9, 9, 9, 9, 9, 9, 8, 7, 4, 4]","[First-Person Shooter, Multiplayer, War Game, ...",Video Games,Battlefield 3 - Playstation 3,[Battlefield 3 leaps ahead of the competition ...,"[Video Games, Legacy Systems, PlayStation Syst...",12.45
424,AFQDH6ZKGPLWKKVNX2OJY6XPMCFA,B094YHB1QK,1.0,2022-06-06 13:41:16.029,1654523000.0,29.0,4.310345,6.0,4.333333,4.0,...,"1441052986,1504993804,1507574462,1511927540,15...","[-1, -1, 3077, 3752, 3373, 4344, 4418, 4518, 2...","[-1, -1, 1441052986, 1504993804, 1507574462, 1...","[-1, -1, 8, 7, 7, 7, 7, 7, 7, 7]","[Wireless, Controller, PlayStation 5, Ergonomi...",Video Games,PlayStation DualSense Wireless Controller – Ga...,[Plot a course for astronomical adventures on ...,"[Video Games, PlayStation 5, Accessories, Cont...",74.99
1027,AGVAG2GSFQZUAXMRSKKSGKEHGG5A,B09JY72CNG,1.0,2022-01-19 20:57:42.960,1642626000.0,2.0,4.5,1.0,4.0,1.0,...,"1357571879,1357572295,1456940615,1456940845,15...","[-1, -1, -1, 1736, 117, 1911, 3747, 4598, 3455...","[-1, -1, -1, 1357571879, 1357572295, 145694061...","[-1, -1, -1, 8, 8, 8, 8, 6, 6, 0]","[Mouse Pad, RGB Lighting, High Precision, Gami...",Computers,Razer Goliathus Extended Chroma Gaming Mouse P...,[The Razer Goliathus extended Chroma soft gami...,"[Video Games, PC, Accessories, Gaming Mice]",59.99


In [22]:
user_id = val_df.sample(1)[args.user_col].values[0]
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,tags,main_category,title,description,categories,price
1369,AECSORN72UHJMCT4OBY5WFFS7J2A,B0044R8X9U,0.0,2022-01-04 15:59:11.724,,1.0,2.0,0.0,,0.0,...,1537948546154058457015405860561556836554156169040515802828611580354512,"[-1, -1, -1, 3886, 4534, 3840, 4068, 4115, 4267, 4171]","[-1, -1, -1, 1537948546, 1540584570, 1540586056, 1556836554, 1561690405, 1580282861, 1580354512]","[-1, -1, -1, 7, 7, 7, 6, 6, 6, 6]","[Flight Simulator, Combat, Military, Action, Realistic]",Video Games,Ace Combat Assault Horizon - Xbox 360,"[Product Description, Developed by the Project Aces team, Ace Combat Assault Horizon intensifies the franchise, escalating combat to the next level with aircraft that are literally torn apart, spewing oil and debris across the sky. Players will engage in combat across the globe, dodging skyscrapers, and turning their enemies into fiery supersonic debris in both single player and online multiplayer. Never before has combat been so fast and in-your-face., From the Manufacturer, Sit in the pilot's seat and take flight, soaring across real-world locations and experiencing an engaging and dramatic wartime storyline!, View larger, ., STEEL CARNAGE, Developed by the Project Aces team, ACE COMBAT® ASSAULT HORIZON intensifies the franchise, escalating combat to the next level with aircraft that are literally torn apart, spewing oil and debris across the sky. Players will engage in combat across the globe, dodging skyscrapers, and turning their enemies into fiery supersonic debris in both single player and online multiplayer. Never before has combat been so fast and in-your-face., Key Game Features, Dramatic realistic storyline – Written by New York Times Best Seller and military author Jim DeFelice, players will experience an engaging war drama spanning real-world locations across the globe, Dramatic realistic storyline, – Written by New York Times Best Seller and military author Jim DeFelice, players will experience an engaging war drama spanning real-world locations across the globe, Steel carnage destruction – Incredible detail and visual reaction for every explosive attack (aircrafts are shredded to pieces, enemy troops annihilated, buildings shattered, machines bleed), Steel carnage destruction, – Incredible detail and visual reaction for every explosive attack (aircrafts are shredded to pieces, enemy troops annihilated, buildings shattered, machines bleed), Entirely new aircrafts to pilot – Experience split-second maneuvering and positioning, pinpoint targeting, hovering attacks and other gameplay diversity through the introduction of the Attack Helicopter, Door Gunner and more, Entirely new aircrafts to pilot, – Experience split-second maneuvering and positioning, pinpoint targeting, hovering attacks and other gameplay diversity through the introduction of the Attack Helicopter, Door Gunner and more, Revolutionary Close-Range Assault system – Delivering high-speed acrobatics, dizzying one-on-one encounters, satisfying visceral low-altitude and high-flying death from above, Revolutionary Close-Range Assault system, – Delivering high-speed acrobatics, dizzying one-on-one encounters, satisfying visceral low-altitude and high-flying death from above, ACE COMBAT online reinvented – Take to the skies and engage hostile forces in a variety of modes, ACE COMBAT online reinvented, – Take to the skies and engage hostile forces in a variety of modes, Additional Screenshots, :, Watch planes explode and shred to pieces in amazing detail!, View larger, ., Deliver defeat using an arsenal of weapons and aircrafts!, View larger, ., Take 'em down with the Attack Helicopter!, View larger, ., Play with your friends to determine the king of the skies!, View larger, .]","[Video Games, Legacy Systems, Xbox Systems, Xbox 360, Games]",12.99
1414,AECSORN72UHJMCT4OBY5WFFS7J2A,B0BFT941YQ,1.0,2022-01-04 15:59:11.724,1641312000.0,4.0,5.0,1.0,5.0,0.0,...,1537948546154058457015405860561556836554156169040515802828611580354512,"[-1, -1, -1, 3886, 4534, 3840, 4068, 4115, 4267, 4171]","[-1, -1, -1, 1537948546, 1540584570, 1540586056, 1556836554, 1561690405, 1580282861, 1580354512]","[-1, -1, -1, 7, 7, 7, 6, 6, 6, 6]","[Action RPG, Anime, Single-player, Dragon Ball Franchise, Open World]",Video Games,DRAGON BALL Z: Kakarot - PlayStation 5,"[Relive the story of Goku and other Z Fighters in DRAGON BALL Z: KAKAROT! Beyond the epic battles, experience life in the DRAGON BALL Z world as you fight, fish, eat, and train with Goku, Gohan, Vegeta and others. Explore the new areas and adventures as you advance through the story and form powerful bonds with other heroes from the DRAGON BALL Z universe.]","[Video Games, PlayStation 5, Games]",19.99


In [23]:
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
item_sequence_ts_bucket = test_row["item_sequence_ts_bucket"]
row_idx = test_row.name
item_feature = val_item_features[row_idx]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
item_sequence = torch.tensor([item_sequence])
item_sequence_ts_bucket = torch.tensor([item_sequence_ts_bucket])
item_feature = torch.tensor([item_feature])
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
model.train()

[32m2024-11-10 19:24:04.159[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mTest predicting before training with user_id = AECSORN72UHJMCT4OBY5WFFS7J2A and parent_asin = B0BFT941YQ[0m

Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:278.)



Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(144, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=1229, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Training loop

##### Overfit 1 batch

In [24]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    dropout=0,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
    accelerator=args.device,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 3.4 M  | train
-----------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.4 M     Total params
13.712    Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=100` reached.
[32m2024-11-10 19:24:12.241[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m171[0m - [1mLogging classification metrics...[0m
[32m2024-11-10 19:24:30.729[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m37[0m - [1mLogs available at /home/dvquys/frostmourne/recsys-mvp/notebooks/data/037-add-llm-item-tags/logs/overfit/lightning_logs/version_0[0m


In [25]:
%tensorboard --logdir $trainer.log_dir --host localhost

##### Fit on all data

In [26]:
all_items_df = train_df.drop_duplicates(subset=["item_indice"])
all_items_indices = all_items_df["item_indice"].values
all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
logger.info(
    f"Mean std over categorical and numerical features: {all_items_features.std(axis=0).mean()}"
)
if args.rc.use_sbert_features:
    all_sbert_vectors = ann_index.get_vector_by_ids(all_items_indices.tolist()).astype(
        np.float32
    )
    logger.info(f"Mean std over text features: {all_sbert_vectors.std(axis=0).mean()}")
    all_items_features = np.hstack([all_items_features, all_sbert_vectors])

[32m2024-11-10 19:24:31.841[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mMean std over categorical and numerical features: 0.9988587498664856[0m


  0%|          | 0/47 [00:00<?, ?it/s]

[32m2024-11-10 19:24:33.073[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mMean std over text features: 0.015866756439208984[0m


In [27]:
# papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_loss", patience=args.early_stopping_patience, mode="min", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    dropout=args.dropout,
    item_embedding=pretrained_item_embedding,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    all_items_indices=all_items_indices,
    all_items_features=all_items_features,
    args=args,
    neg_to_pos_ratio=args.neg_to_pos_ratio,
    checkpoint_callback=checkpoint_callback,
    accelerator=args.device,
)

In [28]:
log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 3.4 M  | train
-----------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.4 M     Total params
13.712    Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2024-11-10 19:29:15.787[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m164[0m - [1mLoading best model from /home/dvquys/frostmourne/recsys-mvp/notebooks/data/037-add-llm-item-tags/checkpoints/best-checkpoint.ckpt...[0m
[32m2024-11-10 19:29:15.967[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m171[0m - [1mLogging classification metrics...[0m
[32m2024-11-10 19:29:16.469[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m174[0m - [1mLogging ranking metrics...[0m


Generating recommendations:   0%|          | 0/177 [00:00<?, ?it/s]

2024/11/10 19:29:22 INFO mlflow.tracking._tracking_service.client: 🏃 View run 037-add-llm-item-tags at: http://localhost:5002/#/experiments/3/runs/0fa27d4da91a41f0bb5b3bce829e506b.
2024/11/10 19:29:22 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.


In [29]:
logger.info(
    f"Test predicting after training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
model.eval()
model = model.to(user.device)  # Move model back to align with data device
model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
model.train()

[32m2024-11-10 19:29:22.930[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mTest predicting after training with user_id = AECSORN72UHJMCT4OBY5WFFS7J2A and parent_asin = B0BFT941YQ[0m


Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(144, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=1229, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

# Load best checkpoint

In [30]:
logger.info(f"Loading best checkpoint from {checkpoint_callback.best_model_path}...")
args.best_checkpoint_path = checkpoint_callback.best_model_path

best_trainer = LitRanker.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    model=init_model(
        n_users,
        n_items,
        args.embedding_dim,
        args.item_sequence_ts_bucket_size,
        args.bucket_embedding_dim,
        item_feature_size,
        dropout=0,
        item_embedding=pretrained_item_embedding,
    ),
)

[32m2024-11-10 19:29:22.958[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mLoading best checkpoint from /home/dvquys/frostmourne/recsys-mvp/notebooks/data/037-add-llm-item-tags/checkpoints/best-checkpoint.ckpt...[0m


In [31]:
best_model = best_trainer.model.to(lit_model.device)

In [32]:
best_model.eval()
best_model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
best_model.train()

Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(144, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=1229, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

### Persist artifacts

In [33]:
if args.log_to_mlflow:
    # Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    mlf_client.log_artifact(run_id, idm_fp)
    # Persist item_feature_metadata pipeline
    mlf_client.log_artifact(run_id, args.item_metadata_pipeline_fp)

    # Persist model architecture
    model_architecture_fp = f"{args.notebook_persist_dp}/model_architecture.txt"
    with open(model_architecture_fp, "w") as f:
        f.write(repr(model))
    mlf_client.log_artifact(run_id, model_architecture_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [34]:
inferrer = RankerInferenceWrapper(best_model)

In [35]:
def generate_sample_item_features():
    sample_row = train_df.iloc[0].fillna(0)
    output = dict()
    for col in args.rc.item_feature_cols:
        v = sample_row[col]
        if isinstance(v, np.ndarray):
            v = "__".join(
                sample_row[col].tolist()
            )  # Workaround to avoid MLflow Got error: Per-column arrays must each be 1-dimensional
        output[col] = [v]
    return output

In [36]:
sample_input = {
    args.user_col: [idm.get_user_id(0)],
    "item_sequence": [",".join([idm.get_item_id(0), idm.get_item_id(1)])],
    "item_sequence_ts": [
        "1095133116,109770848"
    ],  # Here we input unix timestamp seconds instead of timestamp bucket because we need to calculate the bucket
    # **{col: [train_df.iloc[0].fillna(0)[col]] for col in args.item_feature_cols},
    **generate_sample_item_features(),
    args.item_col: [idm.get_item_id(0)],
}
sample_output = inferrer.infer([0], [[0, 1]], [[2, 0]], [train_item_features[0]], [0])
sample_output

array([0.8922505], dtype=float32)

In [37]:
sample_input

{'user_id': ['AE225O22SA7DLBOGOEIFL7FT5VYQ'],
 'item_sequence': ['0375869026,9625990674'],
 'item_sequence_ts': ['1095133116,109770848'],
 'main_category': ['Video Games'],
 'categories': ['Video Games__Legacy Systems__PlayStation Systems__PlayStation 3__Accessories__Controllers'],
 'price': ['49.99'],
 'parent_asin_rating_cnt_365d': [76.0],
 'parent_asin_rating_avg_prev_rating_365d': [4.592105263157895],
 'parent_asin_rating_cnt_90d': [10.0],
 'parent_asin_rating_avg_prev_rating_90d': [4.3],
 'parent_asin_rating_cnt_30d': [3.0],
 'parent_asin_rating_avg_prev_rating_30d': [5.0],
 'parent_asin_rating_cnt_7d': [1.0],
 'parent_asin_rating_avg_prev_rating_7d': [5.0],
 'tags': ['Controller__Wireless__Gamepad__Ergonomic__PlayStation 3 Accessory'],
 'parent_asin': ['0375869026']}

In [38]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    idm_filename = idm_fp.split("/")[-1]
    item_metadata_pipeline_filename = args.item_metadata_pipeline_fp.split("/")[-1]
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            artifacts={
                # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
                "idm": mlflow.get_artifact_uri(idm_filename),
                "item_metadata_pipeline": mlflow.get_artifact_uri(
                    item_metadata_pipeline_filename
                ),
            },
            model_config={"use_sbert_features": args.rc.use_sbert_features},
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )


Since MLflow 2.16.0, we no longer convert dictionary input example to pandas Dataframe, and directly save it as a json object. If the model expects a pandas DataFrame input instead, please pass the pandas DataFrame as input example directly.



Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Registered model 'ranker' already exists. Creating a new version of this model...
2024/11/10 19:29:26 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: ranker, version 5
Created version '5' of model 'ranker'.


Downloading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

2024/11/10 19:29:26 INFO mlflow.tracking._tracking_service.client: 🏃 View run 037-add-llm-item-tags at: http://localhost:5002/#/experiments/3/runs/0fa27d4da91a41f0bb5b3bce829e506b.
2024/11/10 19:29:26 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.


# Set the newly trained model as champion

In [39]:
if args.log_to_mlflow:
    val_roc_auc = trainer.logger.experiment.get_run(trainer.logger.run_id).data.metrics[
        "val_roc_auc"
    ]

    if val_roc_auc > args.min_roc_auc:
        logger.info(f"Aliasing the new model as champion...")
        model_version = (
            mlf_client.get_registered_model(args.mlf_model_name)
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value=args.author,
        )

[32m2024-11-10 19:29:26.786[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mAliasing the new model as champion...[0m


# Clean up

In [40]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.dict()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)

2024/11/10 19:29:26 INFO mlflow.tracking._tracking_service.client: 🏃 View run 037-add-llm-item-tags at: http://localhost:5002/#/experiments/3/runs/0fa27d4da91a41f0bb5b3bce829e506b.
2024/11/10 19:29:26 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.
