# Sequence modeling for ranking task

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import os
import sys

import lightning as L
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from torch.utils.data import DataLoader

import mlflow

load_dotenv()

sys.path.insert(0, "..")

from src.dataset import UserItemBinaryDFDataset as UserItemRatingDFDataset
from src.id_mapper import IDMapper
from src.sequence.inference import SequenceRatingPredictionInferenceWrapper
from src.sequence.model import SequenceRatingPrediction
from src.sequence.trainer import LitSequenceRatingPrediction
from src.sequence.utils import generate_item_sequences
from src.viz import blueq_colors

# Controller

In [3]:
max_epochs = 100

In [4]:
class Args(BaseModel):
    testing: bool = False
    log_to_mlflow: bool = True
    experiment_name: str = "RecSys MVP - Sequence Modeling"
    run_name: str = "002-try-binary"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    max_epochs: int = max_epochs
    batch_size: int = 128

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"

    top_K: int = 100
    top_k: int = 10

    batch_size: int = 128

    embedding_dim: int = 128
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.001
    l2_reg: float = 1e-4

    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "sequence_rating_prediction"
    min_roc_auc: float = 0.7

    best_checkpoint_path: str = None

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2024-10-26 11:44:56.670[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m46[0m - [1mSetting up MLflow experiment RecSys MVP - Sequence Modeling - run 002-try-binary...[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Sequence Modeling",
  "run_name": "002-try-binary",
  "notebook_persist_dp": "/Users/dvq/frostmourne/recsys-mvp/notebooks/data/002-try-binary",
  "random_seed": 41,
  "device": null,
  "max_epochs": 100,
  "batch_size": 128,
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "top_K": 100,
  "top_k": 10,
  "embedding_dim": 128,
  "dropout": 0.3,
  "early_stopping_patience": 5,
  "learning_rate": 0.001,
  "l2_reg": 0.0001,
  "mlf_item2vec_model_name": "item2vec",
  "mlf_model_name": "sequence_rating_prediction",
  "min_roc_auc": 0.7,
  "best_checkpoint_path": null
}


# Implement

In [5]:
def init_model(n_users, n_items, embedding_dim, dropout, item_embedding=None):
    model = SequenceRatingPrediction(
        n_users, n_items, embedding_dim, dropout=dropout, item_embedding=item_embedding
    )
    return model

## Load pretrained Item2Vec embeddings

In [6]:
mlf_client = mlflow.MlflowClient()
model = mlflow.pyfunc.load_model(
    model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
)
skipgram_model = model.unwrap_python_model().model
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.size()[0]
id_mapping = model.unwrap_python_model().id_mapping
pretrained_item_embedding = skipgram_model.embeddings

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [7]:
assert (
    pretrained_item_embedding.embedding_dim == args.embedding_dim
), "Mismatch pretrained item_embedding dimension"

# Test implementation

In [8]:
embedding_dim = 8
batch_size = 2

# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 4, 5, 3, 0]
item_sequences = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 3],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 2, 1],
]

n_users = len(set(user_indices))
n_items = len(set(item_indices))

train_df = pd.DataFrame(
    {
        "user_indice": user_indices,
        "item_indice": item_indices,
        args.timestamp_col: timestamps,
        args.rating_col: ratings,
        "item_sequence": item_sequences,
    }
)

model = init_model(n_users, n_items, embedding_dim, args.dropout)

# Example forward pass
model.eval()
user = torch.tensor([0])
item_sequence = torch.tensor([[-1, -1, -1, 0, 1]])
target_item = torch.tensor([2])
predictions = model.predict(user, item_sequence, target_item)
print(predictions)

tensor([[0.5778]], grad_fn=<SigmoidBackward0>)


In [9]:
rating_dataset = UserItemRatingDFDataset(
    train_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)

train_loader = DataLoader(rating_dataset, batch_size=batch_size, shuffle=False)

In [10]:
for batch_input in train_loader:
    print(batch_input)

{'user': tensor([0, 0]), 'item': tensor([0, 1]), 'rating': tensor([0., 1.]), 'item_sequence': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  3]]), 'item_feature': tensor([], size=(2, 0))}
{'user': tensor([1, 2]), 'item': tensor([2, 3]), 'rating': tensor([1., 1.]), 'item_sequence': tensor([[-1, -1,  1,  3],
        [-1, -1,  2,  1]]), 'item_feature': tensor([], size=(2, 0))}
{'user': tensor([2]), 'item': tensor([4]), 'rating': tensor([0.]), 'item_sequence': tensor([[-1, -1,  2,  1]]), 'item_feature': tensor([], size=(1, 0))}


In [11]:
# model
lit_model = LitSequenceRatingPrediction(model, log_dir=args.notebook_persist_dp)

# train model
trainer = L.Trainer(default_root_dir=f"{args.notebook_persist_dp}/test", max_epochs=2, accelerator=args.device if args.device else "auto")
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type                     | Params | Mode
----------------------------------------------------------
0 | model | SequenceRatingPrediction | 729    | eval
----------------------------------------------------------
729       Trainable params
0         Non-trainable params
729       Total params
0.003     Total estimated model params size (MB)
0         Modules in train mode
11        Modules in eval mode


Sanity Checking: |                                                                                            …

/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.
[32m2024-10-26 11:44:57.877[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m125[0m - [1mLogging classification metrics...[0m
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [12]:
users = torch.tensor([0, 0, 0, 0])
item_sequences = torch.tensor(
    [[-1, -1, 2, 3], [-1, -1, 2, 3], [-1, -1, 1, 3], [-1, -1, 2, 1]]
)
items = torch.tensor([0, 1, 2, 3])
predictions = model.predict(users, item_sequences, items)
print(predictions)

tensor([[0.5703],
        [0.5769],
        [0.5813],
        [0.6149]], grad_fn=<SigmoidBackward0>)


In [13]:
def create_predict_df(
    train_df,
    val_user_indices,
    val_timestamp,
    rating_col,
    timestamp_col,
    sequence_length=10,
):
    predict_df = pd.DataFrame(
        {
            "user_indice": val_user_indices,
            "item_indice": -1,  # placeholder
            "timestamp": val_timestamp,
            "source": "predict",
        }
    )

    predict_df = (
        pd.concat(
            [
                train_df.loc[lambda df: df[rating_col].gt(0)][
                    ["user_indice", "item_indice", timestamp_col]
                ].assign(source="train"),
                predict_df,
            ],
            axis=0,
        )
        .pipe(
            generate_item_sequences,
            "user_indice",
            "item_indice",
            timestamp_col,
            sequence_length=sequence_length,
            padding=True,
            padding_value=-1,
        )
        .loc[lambda df: df["source"].eq("predict")]
        .assign(item_sequence=lambda df: df["item_sequence"].apply(np.array))
    )

    return predict_df


predict_df = create_predict_df(
    train_df,
    user_indices,
    timestamps[-1],
    args.rating_col,
    args.timestamp_col,
    sequence_length=10,
)

predict_df

Unnamed: 0,user_indice,item_indice,timestamp,source,item_sequence
0,0,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1]"
1,0,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1]"
2,1,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 2]"
3,2,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 3]"
4,2,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 3]"


In [14]:
recommendations = model.recommend(
    torch.tensor(predict_df["user_indice"].values),
    torch.tensor(predict_df["item_sequence"].values.tolist()),
    k=2,
    batch_size=4,
)
recommendations


Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:281.)



Generating recommendations:   0%|          | 0/2 [00:00<?, ?it/s]

{'user_indice': [0, 0, 0, 0, 1, 1, 2, 2, 2, 2],
 'recommendation': [3, 4, 3, 4, 3, 4, 3, 2, 3, 2],
 'score': [0.6181633472442627,
  0.6073523163795471,
  0.6181633472442627,
  0.6073523163795471,
  0.5995869636535645,
  0.5976166725158691,
  0.6558717489242554,
  0.6382187008857727,
  0.6558717489242554,
  0.6382187008857727]}

# Prep data

In [15]:
train_df = pd.read_parquet("../data/train_features_neg_df.parquet")
val_df = pd.read_parquet("../data/val_features_neg_df.parquet")
idm_fp = "../data/idm.json"
idm = IDMapper().load(idm_fp)

assert (train_df[args.user_col].map(lambda s: idm.get_user_index(s)) != train_df['user_indice']).sum() == 0, "Mismatch IDM"
assert (val_df[args.user_col].map(lambda s: idm.get_user_index(s)) != val_df['user_indice']).sum() == 0, "Mismatch IDM"

In [16]:
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

[32m2024-10-26 11:44:59.700[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mlen(user_indices)=19,578, len(item_indices)=4,630[0m


In [17]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
0,AG5MDHR7VXSEQKENQFUIGIFCGAHQ,B00BHRD4BM,0.0,2019-05-28 06:28:22.706,10116,3676,Video Games,Destiny Expansion II: House of Wolves - PS4 [D...,[],"[Video Games, Game Genre of the Month]",,3,1.5,"B071J42387,B072MQNKYV,B087NNPYP3,B07P6MD9B7,B0...","[-1, -1, 3588, 3139, 3860, 1700, 4619, 2365, 8..."
1,AHXA345XC2AXKKGJR4DKW6HHMFWA,B00CD90R7M,0.0,2015-04-02 04:46:11.000,18921,1618,Video Games,FIFA 14 Legacy Edition - PlayStation Vita,[Experience the emotion of scoring great goals...,"[Video Games, Legacy Systems, PlayStation Syst...",39.62,2,5.0,B07KRWJCQW,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 903]"
2,AH2EBLRPCUCJSPM7QFBS6OWVGZPA,B07X56RNBY,0.0,2017-07-26 23:49:38.082,15098,2300,Video Games,Batman: Arkham Asylum (Game of The Year Editio...,[Become the Invisible Predator with Batman's f...,"[Video Games, Legacy Systems, PlayStation Syst...",25.45,1,,"B002I0H79C,B0041CWZEM,B001SEQWXQ,B019WRM1IA,B0...","[-1, -1, -1, -1, -1, 2951, 2992, 3178, 4392, 1..."
3,AFC5XTCF5D7J3NSDITB2Z26XWWYA,B001E8WQUY,5.0,2019-05-01 21:22:39.265,11036,1395,Video Games,Rock Band 2 - Nintendo Wii (Game only),"[Product description, Rock Band 2 lets you and...","[Video Games, Legacy Systems, Nintendo Systems...",28.49,1,,"B006HZA6VK,B0BN2FNKLM,B0086VPUHI,B0040UAYI4,B0...","[3150, 1398, 2104, 3393, 4380, 1527, 4285, 156..."
4,AF7LJQOIWF3Y3YD7SGOJ34MA5JPA,B001E8WQKY,5.0,2015-01-09 12:53:25.000,13126,1653,Video Games,Resident Evil 5 - Xbox 360,[],"[Video Games, Legacy Systems, Xbox Systems, Xb...",29.88,3,5.0,"B00A2ML6XG,B003VUO6LU","[-1, -1, -1, -1, -1, -1, -1, -1, 597, 1633]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
328591,AEUAL3EJKUSTNB4YY6STLYGTJALA,B005A0MBR0,0.0,2018-12-05 21:24:25.638,10044,3155,,Forza Motorsport 4 - Xbox 360,"[Product Description, Forza Motorsport, the hi...","[Video Games, Legacy Systems, Xbox Systems, Xb...",39.0,2,3.0,"B01F7S5NJW,B00BZS9JV2,B006VR689I,B00IRHE892,B0...","[-1, 2216, 6, 2912, 11, 1746, 1637, 3427, 3856..."
328592,AEALHPZXEOAMWWIBOPCNNUTYKYDA,B0072A4JTY,0.0,2018-12-21 18:20:50.687,6135,673,Video Games,PDP PSVita Pull 'N Go Folio,[The ultimate organization and protection solu...,"[Video Games, Legacy Systems, PlayStation Syst...",,2,4.0,"B06XSMSL45,B07796MBJ7","[-1, -1, -1, -1, -1, -1, -1, -1, 629, 2587]"
328593,AHLPBRVL6UTVIRKCMJ3MFRBSGG7Q,B004QIY0Y4,0.0,2018-11-04 08:36:05.714,18743,472,Video Games,"3 pack - LIMBO, Trials HD, Splosion Man - Xbox...","[Product Description, Three complete and criti...","[Video Games, Legacy Systems, Xbox Systems, Xb...",45.0,3,5.0,"B01LW6SS1X,B004RMK57U,B07CV6LH3V,B087SHFL9B","[-1, -1, -1, -1, -1, -1, 3213, 1381, 3958, 2765]"
328594,AHAVA5VKMJ3OMOLGDZ3W45CKXEWA,B00KTORA0K,5.0,2019-05-25 04:03:51.505,4091,3238,Video Games,Just Dance 2015 - Wii,[With more than 50 million copies of Just Danc...,"[Video Games, Legacy Systems, Nintendo Systems...",33.0,2,5.0,"B004AYCNR0,B007NUQICE,B000TYQL1O,B000SEU92W,B0...","[-1, -1, -1, 4394, 4158, 4013, 869, 2540, 1314..."


# Train

In [18]:
rating_dataset = UserItemRatingDFDataset(
    train_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)
val_rating_dataset = UserItemRatingDFDataset(
    val_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [19]:
n_items = len(item_indices)
n_users = len(user_indices)

model = init_model(n_users, n_items, args.embedding_dim, args.dropout)

#### Predict before train

In [20]:
model.item_embedding

Embedding(4631, 128, padding_idx=4630)

In [21]:
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
1245,AEHAFZUQPE7N3KZ7YF5X27HTGA5A,B07R6YZQB8,1.0,2022-01-10 19:38:36.201,16119,1649,All Electronics,Glistco Simple Feet- Horizontal Stand/Feet Com...,[],"[Video Games, PlayStation 4]",12.6,1,,"B004R9OVEG,B002JTX5RK,B003S3RFAY,B00341B3LW,B0...","[-1, -1, -1, -1, 4322, 2767, 1719, 3307, 496, ..."
1573,AFBHJIL3UN5EUHHEALANUIXYGEYA,B0B1N7619L,0.0,2021-10-02 02:52:29.606,18994,1248,Video Games,Deruitu Switch Accessories Bundle Compatible w...,[],"[Video Games, Legacy Systems, PlayStation Syst...",39.99,1,,"B0C5K4M7WJ,B01N3ASPNV,B005N4I24Y,B07C79LP8M,B0...","[-1, -1, 415, 2487, 4278, 2907, 3157, 2540, 18..."
1576,AEVHKUG4MXISWGQAY6YJROC3O5YQ,B08P1NS2X1,1.0,2022-04-25 18:41:18.626,15558,2998,Video Games,LEGO City Undercover - PlayStation 4,"[Join the Chase! In LEGO CITY Undercover, play...","[Video Games, PlayStation 4, Games]",19.74,1,,"B00L59D9HG,B00P8EMB5A,B01L1Y0RZQ,B01GW3H3U8,B0...","[-1, -1, 1392, 2978, 1999, 2266, 2672, 1700, 3..."
718,AFVSWUUY2EIM5FVF7LZGCRRKQ2KA,B079Y44LDC,1.0,2022-01-19 01:30:09.733,16979,4383,Video Games,Hello Neighbor - Nintendo Switch,[You move into a brand new suburb and notice y...,"[Video Games, Nintendo Switch, Games]",19.99,2,1.0,"B00BZS9JV2,B00Z9LUDX4,B07P9VKCF6,B08P1NS2X1,B0...","[6, 1101, 4523, 2998, 2488, 218, 1191, 2777, 1..."
1877,AF3TAMOYZEDTYYEX4ZB23A6CS7ZA,B002I0K6X6,0.0,2022-02-03 23:14:16.202,17110,3970,Video Games,Playstation Move Navigation Controller,"[Product Description, The PlayStation Move nav...","[Video Games, Legacy Systems, PlayStation Syst...",39.71,2,5.0,"B000FGA1US,B00ZMBLKPG,B016P6KIFY,B01KZUL1CA,B0...","[746, 902, 786, 217, 965, 827, 2388, 2190, 942..."
1540,AE3U66S5YBEMPF36PVYR6QAS5ETA,B002KAS4OW,0.0,2021-09-05 03:56:39.104,1629,1023,Video Games,Demon's Souls - PS3 [Digital Code],"[Deep beneath the Nexus, the Old One has awake...","[Video Games, Legacy Systems, PlayStation Syst...",,2,5.0,"B0B9MJK753,B07H13GWRH,B0C1K1R6HK,B001ELJEJW,B0...","[-1, -1, -1, -1, 1903, 4192, 165, 990, 3873, 4..."
1579,AGD2KE77JSUWQKD5CGYVGCQYJPHQ,B00XBLQCLQ,0.0,2021-12-13 02:24:44.720,14691,2438,Video Games,Assassin’s Creed Syndicate - Gold Edition | PC...,"[London, 1868. The Industrial Revolution unlea...","[Video Games, PC, Games]",66.35,1,,"B008DBJPLS,B00DWXV1B4,B00PQ1OQ4Y,B008KGN9DG,B0...","[2444, 2194, 2465, 2144, 2725, 3942, 3986, 562..."
359,AFSKM4IMOVT5G3IAFHQ6Y7BI7TIA,B000MFOOLY,0.0,2021-12-31 00:17:17.881,7556,3056,Video Games,Teenage Mutant Ninja Turtles II: The Arcade Game,[The second TMNT game for the NES is based on ...,"[Video Games, Legacy Systems, Nintendo Systems...",26.06,1,,"B00Z9TJHEC,B00TKLFES8,B017W1771Y,B01LD7MR2C,B0...","[19, 4369, 2011, 2562, 602, 271, 1789, 2487, 1..."
1784,AHPJHWUFX7DFIVS5B3XNEK7JLSAQ,B001E56K5Y,1.0,2021-12-04 04:56:55.945,4637,3020,Video Games,Pure - Xbox 360,"[Product Description, PURE, the next gold stan...","[Video Games, Legacy Systems, Xbox Systems, Xb...",25.0,1,,"B00DB2BI8M,B016XBGWAQ,B071S8M8TB,B07CZTVHY8,B0...","[431, 1783, 514, 4275, 1710, 4552, 3539, 2020,..."
634,AH3YUGFVELCSHVCFS6WAMA3YLW4A,B005ZBNXMG,0.0,2022-06-16 16:35:49.712,6742,1750,Video Games,Mario & Sonic at the London 2012 Olympic Games...,"[Product Description, Arriving on Nintendo’s n...","[Video Games, Legacy Systems, Nintendo Systems...",29.72,1,,"B00341B3LW,B07YBXFF99,B00KQHHDZC,B00OBZNI0O,B0...","[-1, -1, -1, -1, 3307, 4217, 728, 2102, 995, 828]"


In [22]:
idm.get_item_index("B00FA1CIUE")

3450

In [23]:
# user_id = val_df.sample(1)[args.user_col].values[0]
user_id = "AH4AOFTTDPHPAFAAVFMAF25H2LIQ"
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
461,AH4AOFTTDPHPAFAAVFMAF25H2LIQ,B07DK1H3H5,1.0,2022-01-09 17:19:07.823,9913,3480,Video Games,Cyberpunk 2077 - PC [Game Download Code in Box],"[Cyberpunk 2077 is an open world, an action adventure story set in Night City, a megalopolis obsessed with power, glamour and body modification. You play at V, a mercenary outlaw going after a one of a kind implant that is the key to immortality. You can customize your character's cyberwar, skill set and playstyle, and explore a vast city where the choices you make shape the story and the world around you.]","[Video Games, PC, Games]",,2,5.0,"B000PS4X9G,B00J5C3Z10,B00DBLBMBQ,B07WS18ZS3,B002XH972U,B0088TN5FM,B00XBLQCLQ,B01GY35W22,B07DKYN13M,B08JDVKWHS","[1090, 2820, 828, 4439, 3090, 2330, 2438, 1289, 1344, 4335]"
709,AH4AOFTTDPHPAFAAVFMAF25H2LIQ,B00IAVDOS6,0.0,2022-01-09 17:19:07.823,9913,1236,Video Games,Xbox One Stereo Headset Adapter,"[Plug your favorite compatible headset into the Xbox One Stereo Headset Adapter and hear the action just the way you like it. Easily adjust chat audio without taking your hands off the controller. Add game audio by connecting directly to your console or TV., Non-compatible headsets The following headsets are incompatible with the Xbox One Stereo Headset Adapter: Mad Catz Tritton Warhead headset Mad Catz Tritton Primer headset (Mad Catz offers an adapter for the Primer headset to convert the 2.5-mm audio jack to a 3.5-mm audio jack. Contact Mad Catz for support.) The Xbox 360 Wireless Headset and Xbox 360 Wireless Bluetooth Headset Headsets with this 2.5-mm connector will not work because of the connector format. This connector includes a long, cylindrical pin in the middle of the connecting side that does not fit into the Xbox One Stereo Headset Adapter.]","[Video Games, Xbox One, Accessories, Cables & Adapters, Adapters]",36.97,2,5.0,"B000PS4X9G,B00J5C3Z10,B00DBLBMBQ,B07WS18ZS3,B002XH972U,B0088TN5FM,B00XBLQCLQ,B01GY35W22,B07DKYN13M,B08JDVKWHS","[1090, 2820, 828, 4439, 3090, 2330, 2438, 1289, 1344, 4335]"
755,AH4AOFTTDPHPAFAAVFMAF25H2LIQ,B08JDVKWHS,1.0,2022-01-09 17:09:46.436,9913,4335,Video Games,Marvel's Spider-Man: Miles Morales Launch Edition - PlayStation 4,"[In the latest adventure in the Marvel's spider-man universe, teenager miles morales is adjusting to his new home while following in the footsteps of his mentor, peter parker, as a new spider-man. But when a fierce power struggle threatens to destroy his new home, the aspiring hero realizes that with great power, there must also come great responsibility. To save all of Marvel's new York, miles must take up the mantle of spider-man and own it.]","[Video Games, PlayStation 4, Games]",57.98,1,,"B008PQU3E4,B000PS4X9G,B00J5C3Z10,B00DBLBMBQ,B07WS18ZS3,B002XH972U,B0088TN5FM,B00XBLQCLQ,B01GY35W22,B07DKYN13M","[589, 1090, 2820, 828, 4439, 3090, 2330, 2438, 1289, 1344]"
1467,AH4AOFTTDPHPAFAAVFMAF25H2LIQ,B07HHVF2XG,0.0,2022-01-09 17:09:46.436,9913,2986,All Electronics,PlayStation Classic,"[Introducing PlayStation Classic A miniature recreation of the iconic PlayStation console, pre loaded with 20 fan favorite games along with two wired controllers for local multiplayer showdowns and a virtual memory card for vital game saves., PlayStation Classic also features the same famous logo, button layout and outer packaging – but this mini console is approximately 45 percent smaller than the original PlayStation and includes a HDMI cable to connect directly to a TV., PlayStation Classic is the perfect console for retro loving fans – and for a new generation wanting to experience the dawn of PlayStation for the first time.]","[Video Games, Legacy Systems, PlayStation Systems, PlayStation, Consoles]",99.0,1,,"B008PQU3E4,B000PS4X9G,B00J5C3Z10,B00DBLBMBQ,B07WS18ZS3,B002XH972U,B0088TN5FM,B00XBLQCLQ,B01GY35W22,B07DKYN13M","[589, 1090, 2820, 828, 4439, 3090, 2330, 2438, 1289, 1344]"


In [24]:
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
item_sequence = torch.tensor([item_sequence])
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item)

[32m2024-10-26 11:44:59.929[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mTest predicting before training with user_id = AH4AOFTTDPHPAFAAVFMAF25H2LIQ and parent_asin = B07DK1H3H5[0m


tensor([[0.4554]], grad_fn=<SigmoidBackward0>)

#### Training loop

##### Overfit 1 batch

In [25]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

model = init_model(n_users, n_items, args.embedding_dim, dropout=0)
lit_model = LitSequenceRatingPrediction(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

  | Name  | Type                     | Params | Mode 
-----------------------------------------------------------
0 | model | SequenceRatingPrediction | 3.2 M  | train
-----------------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.990    Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=100` reached.
[32m2024-10-26 11:45:07.809[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m125[0m - [1mLogging classification metrics...[0m
[32m2024-10-26 11:45:24.514[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1mLogs available at /Users/dvq/frostmourne/recsys-mvp/notebooks/data/002-try-binary/logs/overfit/lightning_logs/version_3[0m


In [26]:
%tensorboard --logdir $trainer.log_dir

##### Fit on all data

In [27]:
# papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_loss", patience=args.early_stopping_patience, mode="min", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    dropout=args.dropout,
    item_embedding=pretrained_item_embedding,
)
lit_model = LitSequenceRatingPrediction(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    args=args,
)

log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Checkpoint directory /Users/dvq/frostmourne/recsys-mvp/notebooks/data/002-try-binary/checkpoints exists and is not empty.


  | Name  | Type                     | Params | Mode 
-----------------------------------------------------------
0 | model | SequenceRatingPrediction | 3.2 M  | train
-----------------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.990    Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2024-10-26 11:52:38.356[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m125[0m - [1mLogging classification metrics...[0m
[32m2024-10-26 11:52:43.874[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m128[0m - [1mLogging ranking metrics...[0m


Generating recommendations:   0%|          | 0/177 [00:00<?, ?it/s]

2024/10/26 11:53:24 INFO mlflow.tracking._tracking_service.client: 🏃 View run 002-try-binary at: http://localhost:5002/#/experiments/2/runs/b1099dfb94094f41be93d54550de42fc.
2024/10/26 11:53:24 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/2.


In [28]:
logger.info(
    f"Test predicting after training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
model.eval()
model.predict(user, item_sequence, item)

[32m2024-10-26 11:53:24.799[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mTest predicting after training with user_id = AH4AOFTTDPHPAFAAVFMAF25H2LIQ and parent_asin = B07DK1H3H5[0m


tensor([[0.9002]], grad_fn=<SigmoidBackward0>)

# Load best checkpoint

In [29]:
logger.info(f"Loading best checkpoint from {checkpoint_callback.best_model_path}...")
args.best_checkpoint_path = checkpoint_callback.best_model_path

best_trainer = LitSequenceRatingPrediction.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    model=init_model(n_users, n_items, args.embedding_dim, dropout=0),
)

[32m2024-10-26 11:53:24.828[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mLoading best checkpoint from /Users/dvq/frostmourne/recsys-mvp/notebooks/data/002-try-binary/checkpoints/best-checkpoint-v3.ckpt...[0m


In [30]:
best_model = best_trainer.model.to(lit_model.device)

In [31]:
best_model.eval()
best_model.predict(user, item_sequence, item)

tensor([[0.8660]], grad_fn=<SigmoidBackward0>)

### Persist id mapping

In [32]:
if args.log_to_mlflow:
    # Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    mlf_client.log_artifact(run_id, idm_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [33]:
inferrer = SequenceRatingPredictionInferenceWrapper(best_model)

In [34]:
sample_input = {
    "user_ids": [idm.get_user_id(0)],
    "item_sequences": [[idm.get_item_id(0), idm.get_item_id(1)]],
    "item_ids": [idm.get_item_id(0)],
}
sample_output = inferrer.infer([0], [[0, 1]], [0])
sample_output

array([0.20323314], dtype=float32)

In [35]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    idm_filename = idm_fp.split("/")[-1]
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
            artifacts={"idm": mlflow.get_artifact_uri(idm_filename)},
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )


Since MLflow 2.16.0, we no longer convert dictionary input example to pandas Dataframe, and directly save it as a json object. If the model expects a pandas DataFrame input instead, please pass the pandas DataFrame as input example directly.



Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Registered model 'sequence_rating_prediction' already exists. Creating a new version of this model...
2024/10/26 11:53:27 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: sequence_rating_prediction, version 2
Created version '2' of model 'sequence_rating_prediction'.


Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

2024/10/26 11:53:27 INFO mlflow.tracking._tracking_service.client: 🏃 View run 002-try-binary at: http://localhost:5002/#/experiments/2/runs/b1099dfb94094f41be93d54550de42fc.
2024/10/26 11:53:27 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/2.


# Set the newly trained model as champion

In [36]:
if args.log_to_mlflow:
    val_roc_auc = trainer.logger.experiment.get_run(trainer.logger.run_id).data.metrics[
        "val_roc_auc"
    ]

    if val_roc_auc > args.min_roc_auc:
        logger.info(f"Aliasing the new model as champion...")
        model_version = (
            mlf_client.get_registered_model(args.mlf_model_name)
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value="quy.dinh",
        )

[32m2024-10-26 11:53:27.721[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mAliasing the new model as champion...[0m


# Clean up

In [37]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.dict()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)

2024/10/26 11:53:27 INFO mlflow.tracking._tracking_service.client: 🏃 View run 002-try-binary at: http://localhost:5002/#/experiments/2/runs/b1099dfb94094f41be93d54550de42fc.
2024/10/26 11:53:27 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/2.
