# Ranker that can takes into accound different features

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import os
import sys
from typing import List

import dill
import lightning as L
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from torch.utils.data import DataLoader

import mlflow

load_dotenv()

sys.path.insert(0, "..")

from src.ann import AnnIndex
from src.data_prep_utils import chunk_transform
from src.dataset import UserItemBinaryDFDataset
from src.id_mapper import IDMapper
from src.ranker.inference import RankerInferenceWrapper
from src.ranker.model import Ranker
from src.ranker.trainer import LitRanker
from src.viz import blueq_colors

# Controller

In [3]:
max_epochs = 100

In [4]:
class Args(BaseModel):
    testing: bool = False
    log_to_mlflow: bool = True
    experiment_name: str = "RecSys MVP - Ranker"
    run_name: str = "024-l2-reg-to-0.0003"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    # Feature flagS
    use_sbert_features: bool = True

    item_metadata_pipeline_fp: str = "../data/item_metadata_pipeline.dill"
    qdrant_url: str = None
    qdrant_collection_name: str = "item_desc_sbert"

    max_epochs: int = max_epochs
    batch_size: int = 128
    tfm_chunk_size: int = 10000
    neg_to_pos_ratio: int = 1

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"
    item_feature_cols: List[str] = [
        "main_category",
        "categories",
        "price",
        "parent_asin_rating_cnt_365d",
        "parent_asin_rating_avg_prev_rating_365d",
        "parent_asin_rating_cnt_90d",
        "parent_asin_rating_avg_prev_rating_90d",
        "parent_asin_rating_cnt_30d",
        "parent_asin_rating_avg_prev_rating_30d",
        "parent_asin_rating_cnt_7d",
        "parent_asin_rating_avg_prev_rating_7d",
    ]

    top_K: int = 100
    top_k: int = 10

    embedding_dim: int = 128
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.0003
    l2_reg: float = 3e-4

    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "ranker"
    min_roc_auc: float = 0.7

    best_checkpoint_path: str = None

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise Exception(f"Environment variable QDRANT_HOST is not set.")

        qdrant_port = os.getenv("QDRANT_PORT")
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2024-10-28 02:35:28.873[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m72[0m - [1mSetting up MLflow experiment RecSys MVP - Ranker - run 023-dropout-to-0.4...[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Ranker",
  "run_name": "023-dropout-to-0.4",
  "notebook_persist_dp": "/Users/dvq/frostmourne/recsys-mvp/notebooks/data/023-dropout-to-0.4",
  "random_seed": 41,
  "device": null,
  "use_sbert_features": true,
  "item_metadata_pipeline_fp": "../data/item_metadata_pipeline.dill",
  "qdrant_url": "localhost:6333",
  "qdrant_collection_name": "item_desc_sbert",
  "max_epochs": 100,
  "batch_size": 128,
  "tfm_chunk_size": 10000,
  "neg_to_pos_ratio": 1,
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "item_feature_cols": [
    "main_category",
    "categories",
    "price",
    "parent_asin_rating_cnt_365d",
    "parent_asin_rating_avg_prev_rating_365d",
    "parent_asin_rating_cnt_90d",
    "parent_asin_rating_avg_prev_rating_90d",
    "parent_asin_rating_cnt_30d",
    "parent_asin_rating_avg_prev_rating_30d",
    "parent_asin_rating_cnt_7d",

# Implement

In [5]:
def init_model(
    n_users, n_items, embedding_dim, item_feature_size, dropout, item_embedding=None
):
    model = Ranker(
        n_users,
        n_items,
        embedding_dim,
        item_feature_size=item_feature_size,
        dropout=dropout,
        item_embedding=item_embedding,
    )
    return model

## Load pretrained Item2Vec embeddings

In [6]:
mlf_client = mlflow.MlflowClient()
model = mlflow.pyfunc.load_model(
    model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
)
skipgram_model = model.unwrap_python_model().model
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.size()[0]
id_mapping = model.unwrap_python_model().id_mapping
pretrained_item_embedding = skipgram_model.embeddings

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [7]:
assert (
    pretrained_item_embedding.embedding_dim == args.embedding_dim
), "Mismatch pretrained item_embedding dimension"

## Load vectorized item features

In [8]:
with open(args.item_metadata_pipeline_fp, "rb") as f:
    item_metadata_pipeline = dill.load(f)

## Load ANN Index

In [9]:
ann_index = AnnIndex(args.qdrant_url, args.qdrant_collection_name)

In [10]:
vector = ann_index.get_vector_by_ids([0])[0]
embedding_dim = vector.shape[0]

  0%|          | 0/1 [00:00<?, ?it/s]

In [11]:
sbert_embedding_dim = vector.shape[0]
neighbors = ann_index.get_neighbors_by_ids([0])
neighbors

  0%|          | 0/1 [00:00<?, ?it/s]

[ScoredPoint(id=0, version=0, score=1.0, payload={'parent_asin': '0375869026', 'title': 'Wonder'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=1916, version=59, score=0.9272537, payload={'parent_asin': 'B005GFPZYK', 'title': 'American Sniper: The Autobiography of the Most Lethal Sniper in U.S. Military History'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=1968, version=61, score=0.9080587, payload={'parent_asin': 'B005ZBO4VA', 'title': 'Tekken Hybrid - Playstation 3'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=166, version=5, score=0.899632, payload={'parent_asin': 'B00005OARM', 'title': 'Golden Sun'}, vector=None, shard_key=None, order_value=None),
 ScoredPoint(id=3896, version=121, score=0.89472985, payload={'parent_asin': 'B07CD6F5PX', 'title': 'Dragon Quest Xi: Echoes of An Elusive Age - PlayStation 4'}, vector=None, shard_key=None, order_value=None)]

# Test implementation

In [12]:
embedding_dim = 8
batch_size = 2

# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 4, 5, 3, 0]
item_sequences = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 3],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 2, 1],
]
main_category = [
    "All Electronics",
    "Video Games",
    "All Electronics",
    "Video Games",
    "Unknown",
]
categories = [[], ["Headsets"], ["Video Games"], [], ["blah blah"]]
title = ["World of Warcraft", "DotA 2", "Diablo IV", "Football Manager 2024", "Unknown"]
description = [[], [], ["Video games blah blah"], [], ["blah blah"]]
price = ["from 14.99", "14.99", "price: 9.99", "20 dollars", "None"]
parent_asin_rating_cnt_365d = [0, 1, 2, 3, 4]
parent_asin_rating_avg_prev_rating_365d = [4.0, 3.5, 4.5, 5.0, 2.0]
parent_asin_rating_cnt_90d = [0, 1, 2, 3, 4]
parent_asin_rating_avg_prev_rating_90d = [4.0, 3.5, 4.5, 5.0, 2.0]
parent_asin_rating_cnt_30d = [0, 1, 2, 3, 4]
parent_asin_rating_avg_prev_rating_30d = [4.0, 3.5, 4.5, 5.0, 2.0]
parent_asin_rating_cnt_7d = [0, 1, 2, 3, 4]
parent_asin_rating_avg_prev_rating_7d = [4.0, 3.5, 4.5, 5.0, 2.0]

train_df = pd.DataFrame(
    {
        "user_indice": user_indices,
        "item_indice": item_indices,
        args.timestamp_col: timestamps,
        args.rating_col: ratings,
        "item_sequence": item_sequences,
        "main_category": main_category,
        "title": title,
        "description": description,
        "categories": categories,
        "price": price,
        "parent_asin_rating_cnt_365d": parent_asin_rating_cnt_365d,
        "parent_asin_rating_avg_prev_rating_365d": parent_asin_rating_avg_prev_rating_365d,
        "parent_asin_rating_cnt_90d": parent_asin_rating_cnt_90d,
        "parent_asin_rating_avg_prev_rating_90d": parent_asin_rating_avg_prev_rating_90d,
        "parent_asin_rating_cnt_30d": parent_asin_rating_cnt_30d,
        "parent_asin_rating_avg_prev_rating_30d": parent_asin_rating_avg_prev_rating_30d,
        "parent_asin_rating_cnt_7d": parent_asin_rating_cnt_7d,
        "parent_asin_rating_avg_prev_rating_7d": parent_asin_rating_avg_prev_rating_7d,
    }
)
train_item_features = item_metadata_pipeline.transform(train_df).astype(np.float32)
if args.use_sbert_features:
    sbert_vectors = ann_index.get_vector_by_ids(
        train_df["item_indice"].values.tolist()
    ).astype(np.float32)
    train_item_features = np.hstack([train_item_features, sbert_vectors])

n_users = len(set(user_indices))
n_items = len(set(item_indices))
item_feature_size = train_item_features.shape[1]

model = init_model(n_users, n_items, embedding_dim, item_feature_size, args.dropout)

# Example forward pass
model.eval()
users = torch.tensor(user_indices)
items = torch.tensor(item_indices)
item_sequences = torch.tensor(item_sequences)
item_features = torch.tensor(train_item_features)
predictions = model.predict(users, item_sequences, item_features, items)
print(predictions)

  0%|          | 0/1 [00:00<?, ?it/s]

tensor([[0.5879],
        [0.5949],
        [0.5327],
        [0.5072],
        [0.5267]], grad_fn=<SigmoidBackward0>)


In [13]:
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)

train_loader = DataLoader(rating_dataset, batch_size=batch_size, shuffle=False)

In [14]:
for batch_input in train_loader:
    print(batch_input)

{'user': tensor([0, 0]), 'item': tensor([0, 1]), 'rating': tensor([0., 1.]), 'item_sequence': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  3]]), 'item_feature': tensor([[-1.4698e-02,  5.6424e+00, -1.4698e-02,  ...,  2.9136e-02,
         -3.0617e-02,  1.4856e-03],
        [-1.4698e-02, -1.7723e-01, -1.4698e-02,  ..., -2.6627e-03,
         -3.4459e-02,  1.1402e-02]])}
{'user': tensor([1, 2]), 'item': tensor([2, 3]), 'rating': tensor([1., 1.]), 'item_sequence': tensor([[-1, -1,  1,  3],
        [-1, -1,  2,  1]]), 'item_feature': tensor([[-1.4698e-02,  5.6424e+00, -1.4698e-02,  ...,  2.4071e-03,
         -4.1083e-02,  4.6736e-04],
        [-1.4698e-02, -1.7723e-01, -1.4698e-02,  ..., -5.8320e-03,
         -6.7804e-02,  4.7192e-03]])}
{'user': tensor([2]), 'item': tensor([4]), 'rating': tensor([0.]), 'item_sequence': tensor([[-1, -1,  2,  1]]), 'item_feature': tensor([[-1.4698e-02, -1.7723e-01, -1.4698e-02, -2.5463e-02, -1.4698e-02,
         -1.4698e-02, -2.0788e-02, -8.2101e-02, -2.872

In [15]:
# Prepare all item features for recommendation
all_items_df = train_df.drop_duplicates(subset=["item_indice"])
all_items_indices = all_items_df["item_indice"].values
all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
if args.use_sbert_features:
    all_sbert_vectors = ann_index.get_vector_by_ids(all_items_indices.tolist()).astype(
        np.float32
    )
    all_items_features = np.hstack([all_items_features, all_sbert_vectors])

lit_model = LitRanker(
    model,
    log_dir=args.notebook_persist_dp,
    all_items_indices=all_items_indices,
    all_items_features=all_items_features,
)

# train model
trainer = L.Trainer(
    default_root_dir=f"{args.notebook_persist_dp}/test",
    max_epochs=2,
    accelerator=args.device if args.device else "auto",
)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

  0%|          | 0/1 [00:00<?, ?it/s]

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type   | Params | Mode
----------------------------------------
0 | model | Ranker | 8.2 K  | eval
----------------------------------------
8.2 K     Trainable params
0         Non-trainable params
8.2 K     Total params
0.033     Total estimated model params size (MB)
0         Modules in train mode
14        Modules in eval mode


Sanity Checking: |                                                                                            …

/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.
[32m2024-10-28 02:35:30.609[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m158[0m - [1mLogging classification metrics...[0m
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [16]:
# After fitting
model.eval()
predictions = model.predict(users, item_sequences, item_features, items)
print(predictions)

tensor([[0.5808],
        [0.5949],
        [0.5346],
        [0.5208],
        [0.5225]], grad_fn=<SigmoidBackward0>)


In [17]:
# Get the last row of each item as input for recommendations (containing the most updated item_sequence)
to_rec_df = train_df.sort_values(args.timestamp_col, ascending=False).drop_duplicates(
    subset=["user_indice"]
)
recommendations = model.recommend(
    torch.tensor(to_rec_df["user_indice"].values.tolist()),
    torch.tensor(to_rec_df["item_sequence"].values.tolist()),
    torch.tensor(lit_model.all_items_features),
    torch.tensor(lit_model.all_items_indices),
    k=2,
    batch_size=4,
)
recommendations

Generating recommendations:   0%|          | 0/1 [00:00<?, ?it/s]

{'user_indice': [2, 2, 1, 1, 0, 0],
 'recommendation': [1, 0, 1, 0, 1, 0],
 'score': [0.5539135932922363,
  0.5513929128646851,
  0.5725988149642944,
  0.5582187175750732,
  0.5949248671531677,
  0.5808430910110474]}

# Prep data

In [18]:
train_df = pd.read_parquet("../data/train_features_neg_df.parquet")
val_df = pd.read_parquet("../data/val_features_neg_df.parquet")
idm_fp = "../data/idm.json"
idm = IDMapper().load(idm_fp)

assert (
    train_df[args.user_col].map(lambda s: idm.get_user_index(s))
    != train_df["user_indice"]
).sum() == 0, "Mismatch IDM"
assert (
    val_df[args.user_col].map(lambda s: idm.get_user_index(s)) != val_df["user_indice"]
).sum() == 0, "Mismatch IDM"

In [19]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,parent_asin_rating_avg_prev_rating_30d,...,item_indice,main_category,title,description,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
0,AG57LGJFCNNQJ6P6ABQAVUKXDUDA,B0015AARJI,0.0,2016-01-12 11:59:11.000,76.0,4.592105,10.0,4.3,3.0,5.0,...,660,Video Games,PlayStation 3 Dualshock 3 Wireless Controller ...,"[Amazon.com, The Dualshock 3 wireless controll...","[Video Games, Legacy Systems, PlayStation Syst...",49.99,2,5.000000,B00J00BLRM,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 2662]"
1,AHWG4EGOV5ZDKPETL56MAYGPLJRQ,B0BMGHMP23,0.0,2016-04-18 19:26:20.000,,,,,,,...,4568,Computers,Logitech G502 Lightspeed Wireless Gaming Mouse...,[G502 is the best gaming mouse from Logitech G...,"[Video Games, PC, Accessories, Gaming Mice]",87.95,3,5.000000,"B00YOGZFCO,B00KWFCSB2,B00L3LQ1FI,B0151K6J9Y,B0...","[3028, 2742, 2755, 3159, 3101, 3036, 3051, 313..."
2,AH5PTZ2U74OZ3HT6QVUWM4CV6OVQ,B009AP23NI,0.0,2016-02-10 18:45:08.000,9.0,4.666667,0.0,,0.0,,...,2219,Video Games,Nintendo Wii U Pro U Controller (Japanese Vers...,[Wii U PRO controller (black) (WUP-A-RSKA)],"[Video Games, Legacy Systems, Nintendo Systems...",43.99,8,4.428571,"B0199OXR0W,B00EVPR4FY,B00B7ELWAU,B00UH9DN58,B0...","[-1, -1, 3234, 2508, 2318, 2964, 1258, 2439, 4..."
3,AFC5XTCF5D7J3NSDITB2Z26XWWYA,B001E8WQUY,5.0,2019-05-01 21:22:39.265,0.0,,0.0,,0.0,,...,724,Video Games,Rock Band 2 - Nintendo Wii (Game only),"[Product description, Rock Band 2 lets you and...","[Video Games, Legacy Systems, Nintendo Systems...",28.49,1,,"B006HZA6VK,B0BN2FNKLM,B0086VPUHI,B0040UAYI4,B0...","[1987, 4569, 2114, 1606, 2159, 2279, 2447, 441..."
4,AF7LJQOIWF3Y3YD7SGOJ34MA5JPA,B001E8WQKY,5.0,2015-01-09 12:53:25.000,16.0,4.375000,8.0,4.5,4.0,4.5,...,722,Video Games,Resident Evil 5 - Xbox 360,[],"[Video Games, Legacy Systems, Xbox Systems, Xb...",29.88,3,5.000000,"B00A2ML6XG,B003VUO6LU","[-1, -1, -1, -1, -1, -1, -1, -1, 2261, 1579]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
328591,AG4RATLNVLOKZCPXN67HKOAK65CA,B078FBVJMB,0.0,2015-10-31 18:25:09.000,,,,,,,...,3829,Video Games,A Way Out – PC Origin [Online Game Code],[From the creators of Brothers - A Tale of Two...,"[Video Games, PC, Games]",5.99,1,,B00TFVD688,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 2951]"
328592,AFBXO3BFWBJX6QS5NW73O37IXF2A,B0771ZXXV6,0.0,2011-03-08 02:06:38.000,,,,,,,...,3806,Video Games,Nintendo Joy-Con (R) - Neon Red - Nintendo Switch,[To be determined],"[Video Games, Nintendo Switch, Accessories, Co...",,3,4.000000,"B003JVCA9Q,B0029NZ4HA","[-1, -1, -1, -1, -1, -1, -1, -1, 1488, 1199]"
328593,AHVANA5GZNJ45UABPXWZNAF4ECBQ,B00BBF6MO6,0.0,2015-02-15 05:31:04.000,3.0,4.666667,0.0,,0.0,,...,2327,Video Games,Killer is Dead - Xbox 360,[Killer Is Dead is the latest title from the d...,"[Video Games, Legacy Systems, Xbox Systems, Xb...",39.82,1,,"B002L93F0A,B002KJ02ZC,B001H4NMNA","[-1, -1, -1, -1, -1, -1, -1, 1377, 1374, 1092]"
328594,AHAVA5VKMJ3OMOLGDZ3W45CKXEWA,B00KTORA0K,5.0,2019-05-25 04:03:51.505,3.0,4.666667,1.0,5.0,1.0,5.0,...,2726,Video Games,Just Dance 2015 - Wii,[With more than 50 million copies of Just Danc...,"[Video Games, Legacy Systems, Nintendo Systems...",33.0,2,5.000000,"B004AYCNR0,B007NUQICE,B000TYQL1O,B000SEU92W,B0...","[-1, -1, -1, 1657, 2074, 593, 583, 3715, 3448,..."


In [20]:
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()
if args.use_sbert_features:
    all_sbert_vectors = ann_index.get_vector_by_ids(
        item_indices.tolist(), chunk_size=1000
    ).astype(np.float32)

train_item_features = chunk_transform(
    train_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
train_item_features = train_item_features.astype(np.float32)

val_item_features = chunk_transform(
    val_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
val_item_features = val_item_features.astype(np.float32)

if args.use_sbert_features:
    train_sbert_vectors = all_sbert_vectors[train_df["item_indice"].values]
    train_item_features = np.hstack([train_item_features, train_sbert_vectors])
    val_sbert_vectors = all_sbert_vectors[val_df["item_indice"].values]
    val_item_features = np.hstack([val_item_features, val_sbert_vectors])

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

  0%|          | 0/5 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/33 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/1 [00:00<?, ?it/s]

[32m2024-10-28 02:35:36.682[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m22[0m - [1mlen(user_indices)=19,578, len(item_indices)=4,630[0m


# Train

In [21]:
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)
val_rating_dataset = UserItemBinaryDFDataset(
    val_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=val_item_features,
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [22]:
n_items = len(item_indices)
n_users = len(user_indices)

model = init_model(
    n_users, n_items, args.embedding_dim, item_feature_size, args.dropout
)
model

Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (gru): GRU(128, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=929, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Predict before train

In [23]:
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,parent_asin_rating_avg_prev_rating_30d,...,item_indice,main_category,title,description,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
1533,AEGELTZNU45ZVX5RLS57L3XRI5QQ,B00897Z27C,0.0,2021-10-07 13:27:46.156,1.0,5.0,0.0,,0.0,,...,2128,Video Games,ZombiU - Nintendo Wii U,[Can you survive a zombie swarm? London's fall...,"[Video Games, Legacy Systems, Nintendo Systems...",14.6,1,,"B07KS74Q2X,B0C3KYVDWT,B0795GHTBC,B07P6MD9B7,B0...","[-1, -1, -1, -1, -1, 4034, 4607, 3846, 4124, 3..."
1750,AEKZWIRSLLUT5XOVLBIS5HLVEPCA,B01D63UU52,0.0,2021-11-17 17:37:48.476,2.0,4.5,0.0,,0.0,,...,3303,Computers,"CORSAIR M65 Pro RGB - FPS Gaming Mouse - 12,00...",[The M65 PRO RGB is a competition-grade FPS ga...,"[Video Games, PC, Accessories, Gaming Mice]",33.0,1,,"B07BDJHFQD,B06WVCWY41,B00MB1I3FU,B01NASEX0C,B0...","[3867, 3559, 2793, 3549, 4159, 3923, 4368, 404..."
1519,AGOQBWRYF7Z4A6F6A4DBSD7YU4VA,B09CG15F86,1.0,2021-12-31 20:55:58.723,5.0,4.6,1.0,5.0,0.0,,...,4468,Computers,Razer Doubleshot PBT Keycap Upgrade Set for Me...,[Enjoy durability without cramping your style....,"[Video Games, PC, Accessories, Gaming Keyboards]",,1,,"B072HGLS26,B073X4V4V4,B08L6L6KQL,B07ZJ6RY1W,B0...","[-1, -1, -1, 3693, 3727, 4391, 4303, 3788, 363..."
788,AHAOIPSKT4LEWU47ZEN7LMOKRMTA,B002DZKZ5K,0.0,2022-02-02 08:29:00.447,2.0,3.0,0.0,,0.0,,...,1244,Video Games,Lego Indiana Jones 2: The Adventure Continues ...,"[Product Description, LEGO Indiana Jones 2: Th...","[Video Games, Legacy Systems, Nintendo Systems...",28.53,2,2.0,"B00L3LQ1FI,B017QU5G1O,B00A878J5I,B07G4YYZ1M,B0...","[2755, 3202, 2267, 3973, 3834, 3698, 4426, 430..."
180,AHN5HTWFC6HFP5JRBXGALC456NCA,B08TG138F1,1.0,2021-08-19 13:56:12.440,8.0,4.75,1.0,5.0,1.0,5.0,...,4433,Video Games,PowerA Wired Controller for Nintendo Switch: G...,[GameCube style controllers are widely conside...,"[Video Games, Nintendo Switch, Accessories, Co...",24.88,1,,"B001EYUY7U,B002L8W5V6,B000FQ9R4E,B001D8PFIK,B0...","[-1, 1041, 1376, 457, 705, 1610, 1913, 1599, 1..."
78,AGC6BKQ2TR6SZAB4YMPA7FINKBXQ,B07P6MD9B7,0.0,2022-01-31 05:49:20.773,4.0,3.5,1.0,5.0,1.0,5.0,...,4124,Video Games,LEGO Worlds - Xbox One,[EXPLORE. DISCOVER. CREATE. TOGETHER. LEGO Wor...,"[Video Games, Xbox One, Games]",19.22,2,1.0,"B005N4HZRO,B07X6KDQ98,B07SM7G9CN,B07V5CFMY4,B0...","[1934, 4262, 4164, 4197, 3846, 2621, 4210, 433..."
1391,AHMKDXEUMTNYFEZLLQFYPU54RYPA,B094YHB1QK,1.0,2021-09-03 20:48:44.778,49.0,4.204082,12.0,4.166667,2.0,5.0,...,4455,Video Games,PlayStation DualSense Wireless Controller – Ga...,[Plot a course for astronomical adventures on ...,"[Video Games, PlayStation 5, Accessories, Cont...",74.99,1,,"B0C3KYVDWT,B0128UH1VU,B08C3WQ25C,B01N3ASPNV,B0...","[4607, 3113, 4365, 3527, 3811, 2512, 3010, 202..."
1394,AHSSHPP7VQRWVDRV2SX3DGX5MELA,B07XV4NHHN,1.0,2022-01-20 21:00:29.759,12.0,4.5,2.0,4.5,1.0,4.0,...,4267,Video Games,Ring Fit Adventure - Nintendo Switch,[Explore a fantastical adventure world to defe...,"[Video Games, Nintendo Switch, Accessories]",,1,,"B000N5Z2L4,B087NNZZM8,B087NNPYP3,B087SHFL9B,B0...","[-1, -1, -1, -1, 526, 4343, 4342, 4344, 4203, ..."
1234,AEVTLWMYYFATTYAEK3C5FSDFT3FA,B094YHB1QK,1.0,2021-11-17 15:17:52.304,44.0,4.431818,5.0,4.8,1.0,5.0,...,4455,Video Games,PlayStation DualSense Wireless Controller – Ga...,[Plot a course for astronomical adventures on ...,"[Video Games, PlayStation 5, Accessories, Cont...",74.99,2,1.0,"B0051D8QCA,B001JKTC9A,B00CEGCN76,B0995GXFV4,B0...","[-1, 1864, 1103, 2379, 4461, 3081, 4595, 4130,..."
1421,AFERSD24ISKJZNLIKZ5A2V5DGMRA,B087XRWHHL,1.0,2022-02-21 13:27:53.499,20.0,4.0,2.0,4.5,0.0,,...,4350,Video Games,"Assassin’s Creed Valhalla Xbox Series X|S, Xbo...",[Upgrade to the Xbox Series X S version of the...,"[Video Games, Xbox One, Games]",15.79,2,5.0,"B00BN5T30E,B07WS5R5DP,B00BAQXJMO,B0086VPUHI,B0...","[-1, -1, 2352, 4218, 2325, 2114, 1378, 1365, 3..."


In [24]:
user_id = val_df.sample(1)[args.user_col].values[0]
# user_id = "AH4AOFTTDPHPAFAAVFMAF25H2LIQ"
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,parent_asin_rating_avg_prev_rating_30d,...,item_indice,main_category,title,description,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
17,AEVKATB3HNRJWVKZ3QFJFFWWRLCA,B006G81HV6,0.0,2022-03-04 01:37:08.233,1.0,5.0,1.0,5.0,0.0,,...,1984,Video Games,FIFA Street - Playstation 3,"[Product Description, From the creators of the award-winning EA SPORTS FIFA Football franchise, and inspired by street football styles and stars from around the world, FIFA Street is the most authentic street football game ever made. Enjoy a unique and fun experience where everything from the environments to the gear to the music is true to the sport and its culture. Whether performing one panna after another without breaking a sweat the way the game is played in Amsterdam, or a physical, fight-for-possession style the way players compete in London, fans will enjoy a superior fidelity of ball control and responsiveness than anything ever experienced. Plus, for the first time ever, utilize aerial skills to maneuver past opponents, an all-new sophisticated wall-play mechanic and over 50 brand new spectacular skill moves. Complete Customization - Enjoy customizable matches to replicate the unique ways the game is played around the world. Choose from different size pitches, number of players, and match types. Take on the challenge of performing trick moves and panna's in a game in Amsterdam, authentic futsal styled matches with no wall play in Spain, or a physical 5v5 contest with assigned player positions in the UK. EA SPORTS Football Club - FIFA Street will be connected to EA SPORTS Football Club, so from the first nutmeg on, you will be contributing to your Football Club identity., Amazon.com, Whether you call it soccer or football, one thing's for sure — you've got the technique worthy of the sport's most competitive clubs. You sleep and wake by the black-and-white ball and you've got the foot to kick it all the way to the World Cup. But something's missing. The yellow- and white-painted lines, neatly groomed turf and perfectly squared-off goals are starting to get a bit old. You've seen one field, you've seen them all — or have you? Ditch your cleats, leave the stadium in the dust and get ready to experience fierce soccer action where it's really at — in the streets. From rooftops to parking lots, from Rio de Janeiro to Amsterdam, players are tearing up the pavement and taking soccer to a whole new level. Build up a killer roster and sharpen your skills, you're about to take it to the Street., Take Fifa to the Streets Play with familiar players More command & responsiveness, Synopsis, You know the award-winning FIFA franchise always scores when it comes to authentic, immersive and action-packed gameplay. Now, the same creators behind the revolutionary sports series take you out of the stadium and steep you in the rough and gritty culture of underground soccer in FIFA Street. Featuring all of the robust FIFA gameplay engines, like the Impact Engine, Precision Dribbling and Personality+, this fresh installment keeps what's already great, and also kicks it up a notch with the most responsive and authentic ball handling yet, Street Control. Bait and beat opponents using a precise standing dribble, maneuver through tight spaces with the deft street dribble, show your flair by juggling and master more than 50 never-before-seen skills. Prove your talent in global tournaments and challenges that take you from futsal-style matches in Spain to physical five-on-five contests in the UK and pit you against fellow gamers in massive, social competitions. Build your own team and climb your way to the top of the leaderboards, earning brag-worthy, in-game rewards along the way. Say goodbye to soccer the way you once knew it — once you hit the Street, you'll never want to look back., Key Features:, Immerse yourself in the gritty, competitive, action-packed culture of street soccer in painstakingly authentic gameplay from the creators of the award-winning FIFA franchise, Immerse yourself in the gritty, competitive, action-packed culture of street soccer in painstakingly authentic gameplay from the creators of the award-winning FIFA franchise, Own the ball with more command and responsiveness than ever as Street Ball Control replicates the touch, creativity and flair that players bring to the streets, Own the ball with more command and responsiveness than ever as Street Ball Control replicates the touch, creativity and flair that players bring to the streets, Maneuver the ball backwards, forwards and side-to-side in a standing dribble to lure your opponent into making the first move, then rally past him with a panna or trick move, Maneuver the ball backwards, forwards and side-to-side in a standing dribble to lure your opponent into making the first move, then rally past him with a panna or trick move, Slip through tight spaces or fend off the defense using close dribble touches thanks to the finely tuned Street Dribble system, Slip through tight spaces or fend off the defense using close dribble touches thanks to the finely tuned Street Dribble system, Master a massive skill set, including juggling and other aerial maneuvers and more than 50 never-before-seen moves, Master a massive skill set, including juggling and other aerial maneuvers and more than 50 never-before-seen moves, Climb your way to soccer stardom by building your own team and clinching wins in 16 tournaments and 20 challenges, starting off at the local level and rising to European and world-stage competitions, Climb your way to soccer stardom by building your own team and clinching wins in 16 tournaments and 20 challenges, starting off at the local level and rising to European and world-stage competitions, Compete against teams created by other FIFA Street gamers or face off one-on-one in the vast, social World Tour mode, and check your ranking on the leaderboards, Compete against teams created by other FIFA Street gamers or face off one-on-one in the vast, social World Tour mode, and check your ranking on the leaderboards, Discover wild ways to win Street Challenges against up to four players, by kicking the ball through your opponent's legs in Panna Rules, wowing the crowd in Entertainment Points, dwindling your team down one by one in Last Man Standing and more, Discover wild ways to win Street Challenges against up to four players, by kicking the ball through your opponent's legs in Panna Rules, wowing the crowd in Entertainment Points, dwindling your team down one by one in Last Man Standing and more, Flaunt your showmanship in front of Amsterdam crowds, compete without walls in the futsal-style matches of Spain or play rough in physical five-on-five contests in the UK as you experience the unique street cultures around the globe, Flaunt your showmanship in front of Amsterdam crowds, compete without walls in the futsal-style matches of Spain or play rough in physical five-on-five contests in the UK as you experience the unique street cultures around the globe, Take the action from parking lots and parks to gyms and rooftop arenas in authentic, real-world environments, Take the action from parking lots and parks to gyms and rooftop arenas in authentic, real-world environments, Utilize the critically acclaimed FIFA gameplay engines you've already mastered, including the Impact Engine, Precision Dribbling, Personality+ and more, Utilize the critically acclaimed FIFA gameplay engines you've already mastered, including the Impact Engine, Precision Dribbling, Personality+ and more, Play with your favorite stars from the world's top clubs, including Manchester United, Barcelona, Real Madrid and more, each sporting their authentic kits and gear, and meet real freestylers plucked from streets around the world, Play with your favorite stars from the world's top clubs, including Manchester United, Barcelona, Real Madrid and more, each sporting their authentic kits and gear, and meet real freestylers plucked from streets around the world, Earn more then 100 different styles, tricks and celebrations to grow your player and unlock more than 225 items for your squad, including team kits, street wear, boots, environments and teams, Earn more then 100 different styles, tricks and celebrations to grow your player and unlock more than 225 items for your squad, including team kits, street wear, boots, environments and teams]","[Video Games, Legacy Systems, PlayStation Systems, PlayStation 3, Games]",16.0,1,,"B0028Y4PUW,B001ELJE1A,B07MB4NPS4,B0026MS1XS,B001EYUXDK,B00005O0I2,B0013OL0BK","[-1, -1, -1, 1190, 770, 4057, 1183, 1022, 162, 652]"
112,AEVKATB3HNRJWVKZ3QFJFFWWRLCA,B07DY1JKRH,1.0,2022-03-04 01:37:08.233,0.0,,0.0,,0.0,,...,3951,Video Games,Dead by Daylight - Xbox One,"[""Death is not an escape… Dead by Daylight is an asymmetrical, multiplayer (4 vs. 1) horror game where one player takes on the role of a savage killer, and the other four players become Survivors, frantically scurrying to avoid being caught, tortured, and killed. Survivors play in third-person and have the advatage of better situational awareness. The Killer plays in first-person and is more focused on their prey. The Survivors' goal in each encounter is to escape the Killing Ground without getting caught by the Killer - in an environment that changes every time you play.""]","[Video Games, Xbox One, Games]",41.6,1,,"B0028Y4PUW,B001ELJE1A,B07MB4NPS4,B0026MS1XS,B001EYUXDK,B00005O0I2,B0013OL0BK","[-1, -1, -1, 1190, 770, 4057, 1183, 1022, 162, 652]"


In [25]:
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
row_idx = test_row.name
item_feature = val_item_features[row_idx]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
item_sequence = torch.tensor([item_sequence])
item_feature = torch.tensor([item_feature])
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item_feature, item)

[32m2024-10-28 02:35:37.052[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mTest predicting before training with user_id = AEVKATB3HNRJWVKZ3QFJFFWWRLCA and parent_asin = B07DY1JKRH[0m

Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:281.)



tensor([[0.4950]], grad_fn=<SigmoidBackward0>)

#### Training loop

##### Overfit 1 batch

In [26]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

model = init_model(n_users, n_items, args.embedding_dim, item_feature_size, dropout=0)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 3.4 M  | train
-----------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.4 M     Total params
13.533    Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=100` reached.
[32m2024-10-28 02:35:46.147[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m158[0m - [1mLogging classification metrics...[0m
[32m2024-10-28 02:36:03.719[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1mLogs available at /Users/dvq/frostmourne/recsys-mvp/notebooks/data/023-dropout-to-0.4/logs/overfit/lightning_logs/version_1[0m


In [27]:
%tensorboard --logdir $trainer.log_dir

##### Fit on all data

In [28]:
all_items_df = train_df.drop_duplicates(subset=["item_indice"])
all_items_indices = all_items_df["item_indice"].values
all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
if args.use_sbert_features:
    all_sbert_vectors = ann_index.get_vector_by_ids(all_items_indices.tolist()).astype(
        np.float32
    )
    all_items_features = np.hstack([all_items_features, all_sbert_vectors])

  0%|          | 0/47 [00:00<?, ?it/s]

In [29]:
all_items_features.shape

(4630, 929)

In [30]:
# papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_loss", patience=args.early_stopping_patience, mode="min", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    item_feature_size,
    dropout=args.dropout,
    item_embedding=pretrained_item_embedding,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    all_items_indices=all_items_indices,
    all_items_features=all_items_features,
    args=args,
    neg_to_pos_ratio=args.neg_to_pos_ratio,
    checkpoint_callback=checkpoint_callback,
)

log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Checkpoint directory /Users/dvq/frostmourne/recsys-mvp/notebooks/data/023-dropout-to-0.4/checkpoints exists and is not empty.


  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 3.4 M  | train
-----------------------------------------
3.4 M     Trainable params
0         Non-trainable params
3.4 M     Total params
13.533    Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2024-10-28 02:47:56.179[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m152[0m - [1mLoading best model from /Users/dvq/frostmourne/recsys-mvp/notebooks/data/023-dropout-to-0.4/checkpoints/best-checkpoint-v1.ckpt...[0m
[32m2024-10-28 02:47:56.389[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m158[0m - [1mLogging classification metrics...[0m
[32m2024-10-28 02:47:57.017[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m161[0m - [1mLogging ranking metrics...[0m


Generating recommendations:   0%|          | 0/177 [00:00<?, ?it/s]


invalid value encountered in divide

2024/10/28 02:48:39 INFO mlflow.tracking._tracking_service.client: 🏃 View run 023-dropout-to-0.4 at: http://localhost:5002/#/experiments/3/runs/9d81ccb8cda843fda71bb497e169ef00.
2024/10/28 02:48:39 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.


In [31]:
logger.info(
    f"Test predicting after training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
model.eval()
model.predict(user, item_sequence, item_feature, item)

[32m2024-10-28 02:48:39.882[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mTest predicting after training with user_id = AEVKATB3HNRJWVKZ3QFJFFWWRLCA and parent_asin = B07DY1JKRH[0m


tensor([[0.0104]], grad_fn=<SigmoidBackward0>)

# Load best checkpoint

In [32]:
logger.info(f"Loading best checkpoint from {checkpoint_callback.best_model_path}...")
args.best_checkpoint_path = checkpoint_callback.best_model_path

best_trainer = LitRanker.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    model=init_model(
        n_users,
        n_items,
        args.embedding_dim,
        item_feature_size,
        dropout=0,
        item_embedding=pretrained_item_embedding,
    ),
)

[32m2024-10-28 02:48:39.918[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mLoading best checkpoint from /Users/dvq/frostmourne/recsys-mvp/notebooks/data/023-dropout-to-0.4/checkpoints/best-checkpoint-v1.ckpt...[0m


In [33]:
best_model = best_trainer.model.to(lit_model.device)

In [34]:
best_model.eval()
best_model.predict(user, item_sequence, item_feature, item)

tensor([[0.0104]], grad_fn=<SigmoidBackward0>)

### Persist id mapping

In [35]:
if args.log_to_mlflow:
    # Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    mlf_client.log_artifact(run_id, idm_fp)
    # Persist item_feature_metadata pipeline
    mlf_client.log_artifact(run_id, args.item_metadata_pipeline_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [36]:
inferrer = RankerInferenceWrapper(best_model)

In [37]:
sample_input = {
    "user_ids": [idm.get_user_id(0)],
    "item_sequences": [[idm.get_item_id(0), idm.get_item_id(1)]],
    **{col: [train_df.iloc[0].fillna(0)[col]] for col in args.item_feature_cols},
    "item_ids": [idm.get_item_id(0)],
}
sample_output = inferrer.infer([0], [[0, 1]], [train_item_features[0]], [0])
sample_output

array([0.867823], dtype=float32)

In [38]:
sample_input

{'user_ids': ['AE225O22SA7DLBOGOEIFL7FT5VYQ'],
 'item_sequences': [['0375869026', '9625990674']],
 'main_category': ['Video Games'],
 'categories': [array(['Video Games', 'Legacy Systems', 'PlayStation Systems',
         'PlayStation 3', 'Accessories', 'Controllers'], dtype=object)],
 'price': ['49.99'],
 'parent_asin_rating_cnt_365d': [76.0],
 'parent_asin_rating_avg_prev_rating_365d': [4.592105263157895],
 'parent_asin_rating_cnt_90d': [10.0],
 'parent_asin_rating_avg_prev_rating_90d': [4.3],
 'parent_asin_rating_cnt_30d': [3.0],
 'parent_asin_rating_avg_prev_rating_30d': [5.0],
 'parent_asin_rating_cnt_7d': [1.0],
 'parent_asin_rating_avg_prev_rating_7d': [5.0],
 'item_ids': ['0375869026']}

In [39]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    idm_filename = idm_fp.split("/")[-1]
    item_metadata_pipeline_filename = args.item_metadata_pipeline_fp.split("/")[-1]
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            artifacts={
                # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
                "idm": mlflow.get_artifact_uri(idm_filename),
                "item_metadata_pipeline": mlflow.get_artifact_uri(
                    item_metadata_pipeline_filename
                ),
            },
            model_config={"use_sbert_features": args.use_sbert_features},
            metadata={"use_sbert_features": args.use_sbert_features},
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )


Since MLflow 2.16.0, we no longer convert dictionary input example to pandas Dataframe, and directly save it as a json object. If the model expects a pandas DataFrame input instead, please pass the pandas DataFrame as input example directly.



Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Registered model 'ranker' already exists. Creating a new version of this model...
2024/10/28 02:48:44 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: ranker, version 22
Created version '22' of model 'ranker'.


Downloading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

  "inputs": {
    "user_ids": [
      "AE225O22SA7DLBOGOEIFL7FT5VYQ"
    ],
    "item_sequences": [
      [
        "0375869026",
        "9625990674"
      ]
    ],
    "main_category": [
      "Video Games"
    ],
    "categories": [
      [
        "Video Games",
        "Legacy Systems",
        "PlayStation Systems",
        "PlayStation 3",
        "Accessories",
        "Controllers"
      ]
    ],
    "price": [
      "49.99"
    ],
    "parent_asin_rating_cnt_365d": [
      76.0
    ],
    "parent_asin_rating_avg_prev_rating_365d": [
      4.592105263157895
    ],
    "parent_asin_rating_cnt_90d": [
      10.0
    ],
    "parent_asin_rating_avg_prev_rating_90d": [
      4.3
    ],
    "parent_asin_rating_cnt_30d": [
      3.0
    ],
    "parent_asin_rating_avg_prev_rating_30d": [
      5.0
    ],
    "parent_asin_rating_cnt_7d": [
      1.0
    ],
    "parent_asin_rating_avg_prev_rating_7d": [
      5.0
    ],
    "item_ids": [
      "0375869026"
    ]
  }
}. Alternatively, yo

context.model_config={'use_sbert_features': True}
self.use_sbert_features=True


# Set the newly trained model as champion

In [40]:
if args.log_to_mlflow:
    val_roc_auc = trainer.logger.experiment.get_run(trainer.logger.run_id).data.metrics[
        "val_roc_auc"
    ]

    if val_roc_auc > args.min_roc_auc:
        logger.info(f"Aliasing the new model as champion...")
        model_version = (
            mlf_client.get_registered_model(args.mlf_model_name)
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value="quy.dinh",
        )

[32m2024-10-28 02:48:44.781[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mAliasing the new model as champion...[0m


# Clean up

In [41]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.dict()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)

2024/10/28 02:48:44 INFO mlflow.tracking._tracking_service.client: 🏃 View run 023-dropout-to-0.4 at: http://localhost:5002/#/experiments/3/runs/9d81ccb8cda843fda71bb497e169ef00.
2024/10/28 02:48:44 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.
