# Ranker that can takes into accound different features

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import os
import sys

import dill
import lightning as L
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from torch.utils.data import DataLoader

import mlflow

load_dotenv()

sys.path.insert(0, "..")

from cfg.run_cfg import RunCfg
from src.ann import AnnIndex
from src.data_prep_utils import chunk_transform
from src.dataset import UserItemBinaryDFDataset
from src.id_mapper import IDMapper
from src.ranker.inference import RankerInferenceWrapper
from src.ranker.model import Ranker
from src.ranker.trainer import LitRanker
from src.viz import blueq_colors

# Controller

In [3]:
max_epochs = 1

In [4]:
class Args(BaseModel):
    testing: bool = False
    author: str = "quy.dinh"
    log_to_mlflow: bool = True
    experiment_name: str = "RecSys MVP - Ranker"
    run_name: str = "037-add-llm-item-tags"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    rc: RunCfg = RunCfg().init()

    item_metadata_pipeline_fp: str = "../data/item_metadata_pipeline.dill"
    qdrant_url: str = None
    qdrant_collection_name: str = "item_desc_sbert"

    max_epochs: int = max_epochs
    batch_size: int = 128
    tfm_chunk_size: int = 10000
    neg_to_pos_ratio: int = 1

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"

    top_K: int = 100
    top_k: int = 10

    embedding_dim: int = 128
    item_sequence_ts_bucket_size: int = 10
    bucket_embedding_dim: int = 16
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.0003
    l2_reg: float = 1e-4

    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "ranker"
    min_roc_auc: float = 0.7

    best_checkpoint_path: str = None

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (qdrant_host := os.getenv("QDRANT_HOST")):
            raise Exception(f"Environment variable QDRANT_HOST is not set.")

        qdrant_port = os.getenv("QDRANT_PORT")
        self.qdrant_url = f"{qdrant_host}:{qdrant_port}"

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        if self.device is None:
            self.device = (
                "cuda"
                if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_available() else "cpu"
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2024-11-16 10:49:01.103[0m | [34m[1mDEBUG   [0m | [36mcfg.run_cfg[0m:[36minit[0m:[36m43[0m - [34m[1mChanging use_item_tags_from_llm requires re-running notebook 002-features-v2 to get the new item_metadata_pipeline.dill file[0m
[32m2024-11-16 10:49:01.106[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m61[0m - [1mSetting up MLflow experiment RecSys MVP - Ranker - run 037-add-llm-item-tags...[0m


{
  "testing": false,
  "author": "quy.dinh",
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Ranker",
  "run_name": "037-add-llm-item-tags",
  "notebook_persist_dp": "/home/dvquys/frostmourne/recsys-mvp/notebooks/data/037-add-llm-item-tags",
  "random_seed": 41,
  "device": "cuda",
  "rc": {
    "use_sbert_features": false,
    "use_item_tags_from_llm": false,
    "item_feature_cols": [
      "main_category",
      "categories",
      "price",
      "parent_asin_rating_cnt_365d",
      "parent_asin_rating_avg_prev_rating_365d",
      "parent_asin_rating_cnt_90d",
      "parent_asin_rating_avg_prev_rating_90d",
      "parent_asin_rating_cnt_30d",
      "parent_asin_rating_avg_prev_rating_30d",
      "parent_asin_rating_cnt_7d",
      "parent_asin_rating_avg_prev_rating_7d"
    ],
    "item_tags_from_llm_fp": "../data/item_tags_from_llm.parquet"
  },
  "item_metadata_pipeline_fp": "../data/item_metadata_pipeline.dill",
  "qdrant_url": "localhost:6333",
  "qdrant_collection_n

# Implement

In [5]:
def init_model(
    n_users,
    n_items,
    embedding_dim,
    item_sequence_ts_bucket_size,
    bucket_embedding_dim,
    item_feature_size,
    dropout,
    item_embedding=None,
):
    model = Ranker(
        n_users,
        n_items,
        embedding_dim,
        item_sequence_ts_bucket_size=item_sequence_ts_bucket_size,
        bucket_embedding_dim=bucket_embedding_dim,
        item_feature_size=item_feature_size,
        dropout=dropout,
        item_embedding=item_embedding,
    )
    return model

## Load pretrained Item2Vec embeddings

In [6]:
mlf_client = mlflow.MlflowClient()
model = mlflow.pyfunc.load_model(
    model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
)
skipgram_model = model.unwrap_python_model().model
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.size()[0]
id_mapping = model.unwrap_python_model().id_mapping
pretrained_item_embedding = skipgram_model.embeddings

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [7]:
assert (
    pretrained_item_embedding.embedding_dim == args.embedding_dim
), "Mismatch pretrained item_embedding dimension"

## Load vectorized item features

In [8]:
with open(args.item_metadata_pipeline_fp, "rb") as f:
    item_metadata_pipeline = dill.load(f)

## Load ANN Index

In [9]:
if args.rc.use_sbert_features:
    ann_index = AnnIndex(args.qdrant_url, args.qdrant_collection_name)
    vector = ann_index.get_vector_by_ids([0])[0]
    sbert_embedding_dim = vector.shape[0]
    logger.info(f"{sbert_embedding_dim=}")
    neighbors = ann_index.get_neighbors_by_ids([0])
    display(neighbors)

# Prep data

In [10]:
train_df = pd.read_parquet("../data/train_features_neg_df.parquet")
val_df = pd.read_parquet("../data/val_features_neg_df.parquet")
idm_fp = "../data/idm.json"
idm = IDMapper().load(idm_fp)

assert (
    train_df[args.user_col].map(lambda s: idm.get_user_index(s))
    != train_df["user_indice"]
).sum() == 0, "Mismatch IDM"
assert (
    val_df[args.user_col].map(lambda s: idm.get_user_index(s)) != val_df["user_indice"]
).sum() == 0, "Mismatch IDM"

if args.rc.use_item_tags_from_llm:
    assert (
        "tags" in train_df.columns
    ), "There is no column `tags` in train_df, please make sure you have run notebook 002, 020 with RunCfg.use_item_tags_from_llm=True"

In [11]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,tags,main_category,title,description,categories,price
0,AG57LGJFCNNQJ6P6ABQAVUKXDUDA,B0015AARJI,0.0,2016-01-12 11:59:11.000,,76.0,4.592105,10.0,4.3,3.0,...,1452599936,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1452599936]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 0]","[Controller, Wireless, Gamepad, Ergonomic, Pla...",Video Games,PlayStation 3 Dualshock 3 Wireless Controller ...,"[Amazon.com, The Dualshock 3 wireless controll...","[Video Games, Legacy Systems, PlayStation Syst...",49.99
1,AHWG4EGOV5ZDKPETL56MAYGPLJRQ,B0BMGHMP23,0.0,2016-04-18 19:26:20.000,,,,,,,...,"1449254540,1449256005,1449257733,1452715791,14...","[3028.0, 2742.0, 2755.0, 3159.0, 3101.0, 3036....","[1449254540, 1449256005, 1449257733, 145271579...","[5, 5, 5, 5, 5, 5, 5, 5, 5, 5]","[Wireless Mouse, Gaming Precision, High DPI, C...",Computers,Logitech G502 Lightspeed Wireless Gaming Mouse...,[G502 is the best gaming mouse from Logitech G...,"[Video Games, PC, Accessories, Gaming Mice]",87.95
2,AH5PTZ2U74OZ3HT6QVUWM4CV6OVQ,B009AP23NI,0.0,2016-02-10 18:45:08.000,,9.0,4.666667,0.0,,0.0,...,"1443454097,1455129080,1455129186,1455129499,14...","[-1.0, -1.0, 3234.0, 2508.0, 2318.0, 2964.0, 1...","[-1, -1, 1443454097, 1455129080, 1455129186, 1...","[-1, -1, 5, 1, 1, 0, 0, 0, 0, 0]","[Controller, Wii U, Japanese Version, Gaming A...",Video Games,Nintendo Wii U Pro U Controller (Japanese Vers...,[Wii U PRO controller (black) (WUP-A-RSKA)],"[Video Games, Legacy Systems, Nintendo Systems...",43.99
3,AFC5XTCF5D7J3NSDITB2Z26XWWYA,B001E8WQUY,5.0,2019-05-01 21:22:39.265,1.556746e+09,0.0,,0.0,,0.0,...,"1327120514,1377289907,1402605836,1402606396,14...","[1987.0, 4569.0, 2114.0, 1606.0, 2159.0, 2279....","[1327120514, 1377289907, 1402605836, 140260639...","[8, 8, 7, 7, 7, 7, 7, 7, 6, 6]","[Music Rhythm, Band Simulation, Multiplayer, N...",Video Games,Rock Band 2 - Nintendo Wii (Game only),"[Product description, Rock Band 2 lets you and...","[Video Games, Legacy Systems, Nintendo Systems...",28.49
4,AF7LJQOIWF3Y3YD7SGOJ34MA5JPA,B001E8WQKY,5.0,2015-01-09 12:53:25.000,1.420808e+09,16.0,4.375000,8.0,4.5,4.0,...,14208077931420807991,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, 1420807793, 1...","[-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]","[Survival Horror, Action, Co-op, Zombies, Thir...",Video Games,Resident Evil 5 - Xbox 360,[],"[Video Games, Legacy Systems, Xbox Systems, Xb...",29.88
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
328591,AG4RATLNVLOKZCPXN67HKOAK65CA,B078FBVJMB,0.0,2015-10-31 18:25:09.000,,,,,,,...,1425233294,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1425233294]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 5]","[Co-Op, Action-Adventure, Narrative, Online Mu...",Video Games,A Way Out – PC Origin [Online Game Code],[From the creators of Brothers - A Tale of Two...,"[Video Games, PC, Games]",5.99
328592,AFBXO3BFWBJX6QS5NW73O37IXF2A,B0771ZXXV6,0.0,2011-03-08 02:06:38.000,,,,,,,...,12995495171299549928,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, 1299549517, 1...","[-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]","[Controller, Joy-Con, Nintendo Accessory, Wire...",Video Games,Nintendo Joy-Con (R) - Neon Red - Nintendo Switch,[To be determined],"[Video Games, Nintendo Switch, Accessories, Co...",
328593,AHVANA5GZNJ45UABPXWZNAF4ECBQ,B00BBF6MO6,0.0,2015-02-15 05:31:04.000,,3.0,4.666667,0.0,,0.0,...,137041433213704147071370416530,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 137...","[-1, -1, -1, -1, -1, -1, -1, 1370414332, 13704...","[-1, -1, -1, -1, -1, -1, -1, 6, 6, 6]","[Action, Hack and Slash, Stylish Gameplay, Sto...",Video Games,Killer is Dead - Xbox 360,[Killer Is Dead is the latest title from the d...,"[Video Games, Legacy Systems, Xbox Systems, Xb...",39.82
328594,AHAVA5VKMJ3OMOLGDZ3W45CKXEWA,B00KTORA0K,5.0,2019-05-25 04:03:51.505,1.558757e+09,3.0,4.666667,1.0,5.0,1.0,...,"1431150669,1431150834,1432041664,1432041986,15...","[-1.0, -1.0, -1.0, 1657.0, 2074.0, 593.0, 583....","[-1, -1, -1, 1431150669, 1431150834, 143204166...","[-1, -1, -1, 7, 7, 7, 7, 5, 5, 0]","[Dance Game, Family-friendly, Party Game, Fitn...",Video Games,Just Dance 2015 - Wii,[With more than 50 million copies of Just Danc...,"[Video Games, Legacy Systems, Nintendo Systems...",33.0


In [12]:
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()
if args.rc.use_sbert_features:
    all_sbert_vectors = ann_index.get_vector_by_ids(
        item_indices.tolist(), chunk_size=1000
    ).astype(np.float32)

train_item_features = chunk_transform(
    train_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
train_item_features = train_item_features.astype(np.float32)

val_item_features = chunk_transform(
    val_df, item_metadata_pipeline, chunk_size=args.tfm_chunk_size
)
val_item_features = val_item_features.astype(np.float32)

if args.rc.use_sbert_features:
    train_sbert_vectors = all_sbert_vectors[train_df["item_indice"].values]
    train_item_features = np.hstack([train_item_features, train_sbert_vectors])
    val_sbert_vectors = all_sbert_vectors[val_df["item_indice"].values]
    val_item_features = np.hstack([val_item_features, val_sbert_vectors])

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

Transforming chunks:   0%|          | 0/33 [00:00<?, ?it/s]

Transforming chunks:   0%|          | 0/1 [00:00<?, ?it/s]

[32m2024-11-16 10:49:07.342[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m24[0m - [1mlen(user_indices)=19,578, len(item_indices)=4,630[0m


# Train

In [13]:
rating_dataset = UserItemBinaryDFDataset(
    train_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=train_item_features,
)
val_rating_dataset = UserItemBinaryDFDataset(
    val_df,
    "user_indice",
    "item_indice",
    args.rating_col,
    args.timestamp_col,
    item_feature=val_item_features,
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [15]:
n_items = len(item_indices)
n_users = len(user_indices)
item_feature_size = train_item_features.shape[1]

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    args.dropout,
)
model

Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(144, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=461, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Predict before train

In [16]:
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,tags,main_category,title,description,categories,price
725,AEBFEGEPBZHEFC4THOEGKOC4Z4JA,B00W435BL4,0.0,2021-09-22 21:00:16.647,,2.0,5.0,0.0,,0.0,...,"1491283332,1514903179,1618335983,1620190519,16...","[-1, -1, -1, -1, -1, 1767, 2622, 4043, 3703, 4...","[-1, -1, -1, -1, -1, 1491283332, 1514903179, 1...","[-1, -1, -1, -1, -1, 7, 7, 5, 5, 5]","[Sports, Football Simulation, Legacy Edition, ...",Video Games,Madden NFL 16 - PlayStation 3,[Be The Playmaker with Madden NFL 16],"[Video Games, Legacy Systems, PlayStation Syst...",35.03
1237,AEUEYNXO2K56N7ISKWYKSRBYSDOA,B073X4RF9Q,0.0,2022-01-16 14:30:31.984,,3.0,4.0,1.0,5.0,1.0,...,"1543285236,1543497284,1544594019,1544594115,15...","[-1, -1, -1, -1, -1, 3935, 4305, 4183, 4303, 3...","[-1, -1, -1, -1, -1, 1543285236, 1543497284, 1...","[-1, -1, -1, -1, -1, 7, 7, 7, 7, 6]","[Dock Accessory, LED Light, Customizable Desig...",Video Games,Nintendo Switch 500-042 Light Up Dock Shield b...,[],[],
864,AGFNYWOFY5N6MWIDLMEROWT5Z2IA,B09KL8P6DP,1.0,2022-03-08 18:42:41.757,1646765000.0,6.0,3.5,3.0,3.0,1.0,...,"1406431596,1521099892,1521099994,1540430617,15...","[-1, -1, 917, 1123, 1035, 2626, 4457, 1228, 28...","[-1, -1, 1406431596, 1521099892, 1521099994, 1...","[-1, -1, 8, 7, 7, 7, 7, 7, 6, 6]","[Wired Controller, Xbox Compatible, PC Gaming,...",All Electronics,"VOYEE PC Controller, Wired Controller Compatib...",[],"[Video Games, Legacy Systems, Xbox Systems, Xb...",16.99
1641,AHJUZFMUESAEQBPC2QQMBDVUBYFQ,B0B1PB5L93,1.0,2022-07-15 11:08:51.431,1657883000.0,3.0,4.0,1.0,4.0,0.0,...,"1608419609,1608420314,1608471521,1608471900,16...","[-1, -1, -1, 4520, 4417, 4368, 4415, 3788, 458...","[-1, -1, -1, 1608419609, 1608420314, 160847152...","[-1, -1, -1, 6, 6, 6, 6, 6, 6, 5]","[Gaming Mouse, RGB Lighting, Wireless, Ultra L...",Computers,Razer Viper Ultimate Lightweight Wireless Gami...,[Forget about average and claim the unfair adv...,"[Video Games, PC, Accessories, Gaming Mice]",89.99
630,AHOUEPXPGR4EI2WVY7LXIDFLQ2FQ,B081243BT6,0.0,2022-05-22 22:38:27.290,,8.0,4.625,2.0,4.5,1.0,...,"1293751188,1299694660,1304454064,1349126155,13...","[1351, 902, 1686, 1611, 526, 2461, 4351, 3263,...","[1293751188, 1299694660, 1304454064, 134912615...","[9, 9, 9, 8, 8, 8, 8, 7, 7, 6]","[Carrying Case, Travel, Protection, Nintendo S...",Cell Phones & Accessories,Orzly Carrying case for Nintendo Switch OLED a...,[],"[Video Games, Nintendo Switch, Accessories, Ca...",29.99
865,AG45MF7DODFAS4EU2UXZASTPFZKA,B0080CAO9C,0.0,2022-01-06 06:05:55.750,,0.0,,0.0,,0.0,...,"1298708776,1389166572,1390183166,1552010347,15...","[-1, -1, 1574, 2338, 2306, 2402, 3132, 4133, 3...","[-1, -1, 1298708776, 1389166572, 1390183166, 1...","[-1, -1, 9, 8, 8, 6, 6, 6, 6, 6]","[Soccer, Sports Simulation, Multiplayer, Compe...",Video Games,Pro Evolution Soccer 2013 - Xbox 360,"[Product Description, This fall Pro Evolution ...","[Video Games, Legacy Systems, Xbox Systems, Xb...",20.25
655,AGZUCRR3HOU7LFWM2PMBSD7TRO7Q,B00000JNHJ,0.0,2021-11-21 08:19:28.527,,0.0,,0.0,,0.0,...,"1546967806,1555758005,1555758319,1566666232,15...","[-1, -1, -1, -1, -1, 3379, 3998, 3985, 3849, 4...","[-1, -1, -1, -1, -1, 1546967806, 1555758005, 1...","[-1, -1, -1, -1, -1, 6, 6, 6, 6, 6]","[Platformer, Adventure, Classic, Puzzle Elemen...",Video Games,Ape Escape,"[Product description, The story begins when Sp...","[Video Games, Legacy Systems, PlayStation Syst...",72.96
154,AHQJ5UXX647PD77SHSTCPKTSA3XA,B01JS3F79I,0.0,2021-09-13 14:51:38.718,,13.0,4.0,2.0,4.5,0.0,...,"1514575134,1576179131,1582813328,1588087497,16...","[-1, -1, -1, -1, 2760, 1767, 4559, 3573, 4043,...","[-1, -1, -1, -1, 1514575134, 1576179131, 15828...","[-1, -1, -1, -1, 7, 6, 6, 6, 5, 5]","[Headset, Wireless, Audio Quality, Gaming Acce...",Video Games,Turtle Beach - Stealth 520 Premium Fully Wirel...,[Turtle Beach’s Stealth 520 brings 100% wirele...,"[Video Games, Legacy Systems, Microconsoles, G...",
1243,AGZTQGIO2TJFNFL4VLINTVS7TJ3A,B003MQMD2W,0.0,2022-05-11 18:02:11.734,,0.0,,0.0,,0.0,...,"1052848873,1055865403,1165586688,1208791991,12...","[-1, 152, 943, 996, 923, 926, 1692, 2958, 4508...","[-1, 1052848873, 1055865403, 1165586688, 12087...","[-1, 9, 9, 9, 9, 9, 8, 8, 8, 7]","[Console Bundle, Legacy Hardware, Family Gamin...",Video Games,Wii Hardware Bundle - Black,"[Product Description, Includes Black Wii Conso...","[Video Games, Legacy Systems, Nintendo Systems...",144.18
185,AEQLE5GSVZDI7I2N5K6P2DL6PNSQ,B00LE3EAIK,1.0,2022-06-14 20:20:27.992,1655238000.0,0.0,,0.0,,0.0,...,"1401296568,1404236800,1408392453,1408393139,14...","[1035, 95, 2384, 71, 130, 14, 3985, 3942, 4069...","[1401296568, 1404236800, 1408392453, 140839313...","[8, 8, 8, 8, 8, 8, 7, 7, 6, 6]","[Legacy Gaming, Video Cable, Nintendo Accessor...",Video Games,Gam3Gear SNES Nintendo N64 Gamecube S Video Cable,[High quaility S-Video Cable for SNES/N64/Game...,"[Video Games, Legacy Systems, Nintendo Systems...",9.99


In [17]:
user_id = val_df.sample(1)[args.user_col].values[0]
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,tags,main_category,title,description,categories,price
306,AFPONZEGVVCYYYP2SAEKZ6KXBDUA,B09R21G9DL,1.0,2022-06-19 18:10:50.167,1655662000.0,2.0,4.5,0.0,,0.0,...,162060403316206043221620604622162060530016232846761623285039,"[-1, -1, -1, -1, 4588, 4433, 4607, 4263, 4337, 4345]","[-1, -1, -1, -1, 1620604033, 1620604322, 1620604622, 1620605300, 1623284676, 1623285039]","[-1, -1, -1, -1, 6, 6, 6, 6, 6, 6]","[Controllers, Gamepad, Wired, Enhanced Grip, Compatible with Nintendo]",Computers,"Cipon Gamecube Controller, Wired Controller Gamepad Compatible with Nintendo Wii/GameCube - Enhanced (Black & Black)",[],"[Video Games, Legacy Systems, Nintendo Systems, Wii, Accessories, Controllers, Gamepads & Standard Controllers]",17.99
561,AFPONZEGVVCYYYP2SAEKZ6KXBDUA,B00503E8S2,0.0,2022-06-19 18:10:50.167,,1.0,1.0,1.0,1.0,0.0,...,162060403316206043221620604622162060530016232846761623285039,"[-1, -1, -1, -1, 4588, 4433, 4607, 4263, 4337, 4345]","[-1, -1, -1, -1, 1620604033, 1620604322, 1620604622, 1620605300, 1623284676, 1623285039]","[-1, -1, -1, -1, 6, 6, 6, 6, 6, 6]","[First-Person Shooter, Action, Military, Single Player, Multiplayer]",Video Games,Call of Duty: Modern Warfare 3 - Xbox 360,"[Product Description, Modern Warfare is back. On November 8th, the best selling first person action series of all time returns with the epic sequel to the multiple Game of the Year award winner Call of Duty: Modern Warfare 2., Amazon.com, Call of Duty: Modern Warfare 3, is First-person Shooter rooted in a fictional, but ultra realistic near-future conflict of mostly American forces with those of the Russian Federation around the globe. The third installment in the, Modern Warfare, branch of the, Call of Duty, franchise,, Modern Warfare 3, features a heavy focus on multiplayer gameplay which includes innovative new functionality that encourages multiple gameplay combat strategies, a new 2-player co-op option, new play modes, weapons and more. The game also includes a gripping single player campaign that picks up where, Call of Duty: Modern Warfare 2, left off, and game integration with the, Call of Duty, : Elite online service., Do What is Necessary in the Face of Invasion, Call of Duty: Modern Warfare 3, is a direct sequel to the previous game in the series,, Call of Duty: Modern Warfare 2, . In the game's single player campaign Russian Ultranationalist Vladimir Makarov continues his manipulation of Russian Federation forces in their invasion of the United States and Europe. In their way stands characters like Task Force 141 Captain John ""Soap"" MacTavish, former SAS Captain John Price as well as new playable characters from Delta Force and the British SAS. Engage enemy forces in New York, Paris, Berlin and other attack sites across the globe. The world stands on the brink, and Makarov is intent on bringing civilization to its knees. In this darkest hour, are you willing to do what is necessary., Multiplayer That is Bigger and Better Than Ever, Call of Duty: Modern Warfare 3, delivers a multiplayer experience that continues to raise the bar by focusing on fast-paced, gun-on-gun combat, along with innovative features that support and enhance a large variety of play-styles. Now, you can truly define your approach with a toolkit more expansive than any previous title., Continue the Call of Duty: Modern Warfare in the third release in the series. View larger Best in class multiplayer action. View larger, Pointstreaks and Strike Packages, Killstreaks, benefits and abilities awarded for stringing together multiple kills, have been transformed into Pointstreaks, creating a system that rewards players both for landing kills and completing objectives. These rewards have been broken up into three different categories, known as Strike Packages:, Assault - Pointstreaks within this package chain together and deal direct damage. It includes classics like the Predator Missile and Attack Helicopter. Your streak resets on death., Assault - Pointstreaks within this package chain together and deal direct damage. It includes classics like the Predator Missile and Attack Helicopter. Your streak resets on death., Support - Pointstreaks within this package do not chain, focusing instead on surveillance and disruption. Your streak does not reset on death meaning they will respawn with you., Support - Pointstreaks within this package do not chain, focusing instead on surveillance and disruption. Your streak does not reset on death meaning they will respawn with you., Specialist - Pointstreaks within this package are designed for advanced players. Rewards come in the form of additional perks for optimal performance. These perks last until death., Specialist - Pointstreaks within this package are designed for advanced players. Rewards come in the form of additional perks for optimal performance. These perks last until death., Weapon Proficiencies, Just like your player, weapons now rank up, unlocking additional attachments, reticules, camos and the new proficiency category. Weapon Proficiencies not only allow you to get better with weapons, they also allow you customize your weapons with helpful attributes such as ""Kick"" for reduced recoil, ""Impact"" for deeper penetration through hard surfaces and much more. Many proficiencies are specific to their weapons class. And all are geared towards enhancing a certain play style and can allow for efficient use of your favorite weapons in maps and game modes in which they might not otherwise be the best choice., Modes and Match Customization, Along with the return of the fan favorites, Call of Duty: Modern Warfare 3, introduces several new game modes. Collect dog tags from killed players, including those on your squad while you prevent the opposing squad from taking yours in Kill Confirmed mode. In Team Defender mode, grab the flag and protect the flag carrier for as long as you can to increase your team's score. In addition players will enjoy user generated match mode functionality which allows you to configure any mode how you want it and then share these over the, Call of Duty, : Elite online service., 2-player Co-op Special Ops Survival Mode, Special Ops cooperative action returns with a bevy of additions, including 16 new objective-based missions and the all-new Survival Mode. Team up online, locally, or play solo and face endless waves of attacking enemies throughout every multiplayer map. Purchase and customize your weapons, air support, equipment and abilities to stand up against increasingly difficult forces and land a spot on the leaderboards. Earn experience and rank up with the newly implemented progression system. The higher the rank, the more weapons, air support, and gear armories you will have available to customize so you can change your tactics on the fly. In addition to the action packed battle for freedom, the cooperative Survival Mode also serves as an effective training tool for competitive multiplayer action., Call of Duty, : Elite, Call of Duty, : Elite is an online multiplayer oriented service launched simultaneously with, Call of Duty: Modern Warfare 3, . The service offers both free and subscription based levels of access and is dedicated to the game franchise, featuring lifetime statistics across multiple games, social-networking options, competitions, a mobile app, Facebook integration and more., Additional Screenshots]","[Video Games, Legacy Systems, Xbox Systems, Xbox 360, Games]",40.99


In [18]:
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
item_sequence_ts_bucket = test_row["item_sequence_ts_bucket"]
row_idx = test_row.name
item_feature = val_item_features[row_idx]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
item_sequence = torch.tensor([item_sequence])
item_sequence_ts_bucket = torch.tensor([item_sequence_ts_bucket])
item_feature = torch.tensor([item_feature])
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
model.train()

[32m2024-11-16 10:50:10.342[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mTest predicting before training with user_id = AFPONZEGVVCYYYP2SAEKZ6KXBDUA and parent_asin = B09R21G9DL[0m
  item_sequence = torch.tensor([item_sequence])


Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(144, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=461, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Training loop

##### Overfit 1 batch

In [19]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    dropout=0,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
    accelerator=args.device,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 3.3 M  | train
-----------------------------------------
3.3 M     Trainable params
0         Non-trainable params
3.3 M     Total params
13.318    Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:251: You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.
/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:251: You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.
/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not 

Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=100` reached.
[32m2024-11-16 10:50:19.384[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m171[0m - [1mLogging classification metrics...[0m
[32m2024-11-16 10:50:37.894[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m37[0m - [1mLogs available at /home/dvquys/frostmourne/recsys-mvp/notebooks/data/037-add-llm-item-tags/logs/overfit/lightning_logs/version_1[0m


In [20]:
%tensorboard --logdir $trainer.log_dir --host localhost

##### Fit on all data

In [21]:
all_items_df = train_df.drop_duplicates(subset=["item_indice"])
all_items_indices = all_items_df["item_indice"].values
all_items_features = item_metadata_pipeline.transform(all_items_df).astype(np.float32)
logger.info(
    f"Mean std over categorical and numerical features: {all_items_features.std(axis=0).mean()}"
)
if args.rc.use_sbert_features:
    all_sbert_vectors = ann_index.get_vector_by_ids(all_items_indices.tolist()).astype(
        np.float32
    )
    logger.info(f"Mean std over text features: {all_sbert_vectors.std(axis=0).mean()}")
    all_items_features = np.hstack([all_items_features, all_sbert_vectors])

[32m2024-11-16 10:50:39.019[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mMean std over categorical and numerical features: 0.9988587498664856[0m


In [22]:
# papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_loss", patience=args.early_stopping_patience, mode="min", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    args.item_sequence_ts_bucket_size,
    args.bucket_embedding_dim,
    item_feature_size,
    dropout=args.dropout,
    item_embedding=pretrained_item_embedding,
)
lit_model = LitRanker(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    all_items_indices=all_items_indices,
    all_items_features=all_items_features,
    args=args,
    neg_to_pos_ratio=args.neg_to_pos_ratio,
    checkpoint_callback=checkpoint_callback,
    accelerator=args.device,
)

In [23]:
log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Checkpoint directory /home/dvquys/frostmourne/recsys-mvp/notebooks/data/037-add-llm-item-tags/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | Ranker | 3.3 M  | train
-----------------------------------------
3.3 M     Trainable params
0         Non-trainable params
3.3 M     Total params
13.318    Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.



Training: |                                                                                                   …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=1` reached.
[32m2024-11-16 10:51:02.194[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m164[0m - [1mLoading best model from /home/dvquys/frostmourne/recsys-mvp/notebooks/data/037-add-llm-item-tags/checkpoints/best-checkpoint-v1.ckpt...[0m
[32m2024-11-16 10:51:02.244[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m171[0m - [1mLogging classification metrics...[0m
[32m2024-11-16 10:51:03.087[0m | [1mINFO    [0m | [36msrc.ranker.trainer[0m:[36mon_fit_end[0m:[36m174[0m - [1mLogging ranking metrics...[0m


Generating recommendations:   0%|          | 0/177 [00:00<?, ?it/s]

2024/11/16 10:51:08 INFO mlflow.tracking._tracking_service.client: 🏃 View run 037-add-llm-item-tags at: http://localhost:5002/#/experiments/3/runs/c19da73d26ed451da50724137eca81ad.
2024/11/16 10:51:08 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.


In [24]:
logger.info(
    f"Test predicting after training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
model.eval()
model = model.to(user.device)  # Move model back to align with data device
model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
model.train()

[32m2024-11-16 10:51:08.528[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mTest predicting after training with user_id = AFPONZEGVVCYYYP2SAEKZ6KXBDUA and parent_asin = B09R21G9DL[0m


Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(144, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=461, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

# Load best checkpoint

In [25]:
logger.info(f"Loading best checkpoint from {checkpoint_callback.best_model_path}...")
args.best_checkpoint_path = checkpoint_callback.best_model_path

best_trainer = LitRanker.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    model=init_model(
        n_users,
        n_items,
        args.embedding_dim,
        args.item_sequence_ts_bucket_size,
        args.bucket_embedding_dim,
        item_feature_size,
        dropout=0,
        item_embedding=pretrained_item_embedding,
    ),
)

[32m2024-11-16 10:51:08.559[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mLoading best checkpoint from /home/dvquys/frostmourne/recsys-mvp/notebooks/data/037-add-llm-item-tags/checkpoints/best-checkpoint-v1.ckpt...[0m


In [26]:
best_model = best_trainer.model.to(lit_model.device)

In [27]:
best_model.eval()
best_model.predict(user, item_sequence, item_sequence_ts_bucket, item_feature, item)
best_model.train()

Ranker(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (item_sequence_ts_bucket_embedding): Embedding(11, 16, padding_idx=10)
  (gru): GRU(144, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0, inplace=False)
  (item_feature_tower): Sequential(
    (0): Linear(in_features=461, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0, inplace=False)
  )
  (fc_rating): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

### Persist artifacts

In [28]:
if args.log_to_mlflow:
    # Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    mlf_client.log_artifact(run_id, idm_fp)
    # Persist item_feature_metadata pipeline
    mlf_client.log_artifact(run_id, args.item_metadata_pipeline_fp)

    # Persist model architecture
    model_architecture_fp = f"{args.notebook_persist_dp}/model_architecture.txt"
    with open(model_architecture_fp, "w") as f:
        f.write(repr(model))
    mlf_client.log_artifact(run_id, model_architecture_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [29]:
inferrer = RankerInferenceWrapper(best_model)

In [30]:
def generate_sample_item_features():
    sample_row = train_df.iloc[0].fillna(0)
    output = dict()
    for col in args.rc.item_feature_cols:
        v = sample_row[col]
        if isinstance(v, np.ndarray):
            v = "__".join(
                sample_row[col].tolist()
            )  # Workaround to avoid MLflow Got error: Per-column arrays must each be 1-dimensional
        output[col] = [v]
    return output

In [31]:
sample_input = {
    args.user_col: [idm.get_user_id(0)],
    "item_sequence": [",".join([idm.get_item_id(0), idm.get_item_id(1)])],
    "item_sequence_ts": [
        "1095133116,109770848"
    ],  # Here we input unix timestamp seconds instead of timestamp bucket because we need to calculate the bucket
    # **{col: [train_df.iloc[0].fillna(0)[col]] for col in args.item_feature_cols},
    **generate_sample_item_features(),
    args.item_col: [idm.get_item_id(0)],
}
sample_output = inferrer.infer([0], [[0, 1]], [[2, 0]], [train_item_features[0]], [0])
sample_output

array([0.774318], dtype=float32)

In [32]:
sample_input

{'user_id': ['AE225O22SA7DLBOGOEIFL7FT5VYQ'],
 'item_sequence': ['0375869026,9625990674'],
 'item_sequence_ts': ['1095133116,109770848'],
 'main_category': ['Video Games'],
 'categories': ['Video Games__Legacy Systems__PlayStation Systems__PlayStation 3__Accessories__Controllers'],
 'price': ['49.99'],
 'parent_asin_rating_cnt_365d': [76.0],
 'parent_asin_rating_avg_prev_rating_365d': [4.592105263157895],
 'parent_asin_rating_cnt_90d': [10.0],
 'parent_asin_rating_avg_prev_rating_90d': [4.3],
 'parent_asin_rating_cnt_30d': [3.0],
 'parent_asin_rating_avg_prev_rating_30d': [5.0],
 'parent_asin_rating_cnt_7d': [1.0],
 'parent_asin_rating_avg_prev_rating_7d': [5.0],
 'parent_asin': ['0375869026']}

In [33]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    idm_filename = idm_fp.split("/")[-1]
    item_metadata_pipeline_filename = args.item_metadata_pipeline_fp.split("/")[-1]
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            artifacts={
                # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
                "idm": mlflow.get_artifact_uri(idm_filename),
                "item_metadata_pipeline": mlflow.get_artifact_uri(
                    item_metadata_pipeline_filename
                ),
            },
            model_config={"use_sbert_features": args.rc.use_sbert_features},
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )


Since MLflow 2.16.0, we no longer convert dictionary input example to pandas Dataframe, and directly save it as a json object. If the model expects a pandas DataFrame input instead, please pass the pandas DataFrame as input example directly.



Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

ValueError("columns are missing: {'tags'}")Traceback (most recent call last):


  File "/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/mlflow/utils/_capture_modules.py", line 165, in load_model_and_predict
    model.predict(input_example, params=params)


  File "/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/mlflow/pyfunc/model.py", line 637, in predict
    return self.python_model.predict(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^


  File "/home/dvquys/frostmourne/recsys-mvp/notebooks/../src/ranker/inference.py", line 91, in predict
    item_features = self.item_metadata_pipeline.transform(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


  File "/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/sklearn/pipeline.py", line 903, in transform
    Xt = transform.transform(Xt, **routed_params[name].transform)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


  File "/home/dvquys/frostmourne/rec

Downloading artifacts:   0%|          | 0/9 [00:00<?, ?it/s]

  "inputs": {
    "user_id": [
      "AE225O22SA7DLBOGOEIFL7FT5VYQ"
    ],
    "item_sequence": [
      "0375869026,9625990674"
    ],
    "item_sequence_ts": [
      "1095133116,109770848"
    ],
    "main_category": [
      "Video Games"
    ],
    "categories": [
      "Video Games__Legacy Systems__PlayStation Systems__PlayStation 3__Accessories__Controllers"
    ],
    "price": [
      "49.99"
    ],
    "parent_asin_rating_cnt_365d": [
      76.0
    ],
    "parent_asin_rating_avg_prev_rating_365d": [
      4.592105263157895
    ],
    "parent_asin_rating_cnt_90d": [
      10.0
    ],
    "parent_asin_rating_avg_prev_rating_90d": [
      4.3
    ],
    "parent_asin_rating_cnt_30d": [
      3.0
    ],
    "parent_asin_rating_avg_prev_rating_30d": [
      5.0
    ],
    "parent_asin_rating_cnt_7d": [
      1.0
    ],
    "parent_asin_rating_avg_prev_rating_7d": [
      5.0
    ],
    "parent_asin": [
      "0375869026"
    ]
  }
}. Alternatively, you can avoid passing input example 

# Set the newly trained model as champion

In [34]:
if args.log_to_mlflow:
    val_roc_auc = trainer.logger.experiment.get_run(trainer.logger.run_id).data.metrics[
        "val_roc_auc"
    ]

    if val_roc_auc > args.min_roc_auc:
        logger.info(f"Aliasing the new model as champion...")
        model_version = (
            mlf_client.get_registered_model(args.mlf_model_name)
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value=args.author,
        )

[32m2024-11-16 10:51:12.748[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mAliasing the new model as champion...[0m


# Clean up

In [35]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.dict()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)

2024/11/16 10:51:12 INFO mlflow.tracking._tracking_service.client: 🏃 View run 037-add-llm-item-tags at: http://localhost:5002/#/experiments/3/runs/c19da73d26ed451da50724137eca81ad.
2024/11/16 10:51:12 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/3.
