# Sequence modeling for ranking task

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [2]:
import os
import sys

import lightning as L
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from torch.utils.data import DataLoader

import mlflow

load_dotenv()

sys.path.insert(0, "..")

from src.dataset import UserItemRatingDFDataset
from src.id_mapper import IDMapper
from src.sequence.inference import SequenceRatingPredictionInferenceWrapper
from src.sequence.model import SequenceRatingPrediction
from src.sequence.trainer import LitSequenceRatingPrediction
from src.sequence.utils import generate_item_sequences
from src.viz import blueq_colors

# Controller

In [3]:
max_epochs = 100

In [4]:
class Args(BaseModel):
    testing: bool = False
    log_to_mlflow: bool = True
    experiment_name: str = "RecSys MVP - Sequence Modeling"
    run_name: str = "001-seq-model"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    max_epochs: int = max_epochs
    batch_size: int = 128

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"

    top_K: int = 100
    top_k: int = 10

    batch_size: int = 128

    embedding_dim: int = 128
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.001
    l2_reg: float = 1e-4

    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "sequence_rating_prediction"
    min_roc_auc: float = 0.7

    best_checkpoint_path: str = None

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2024-10-25 21:53:28.060[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m44[0m - [1mSetting up MLflow experiment RecSys MVP - Sequence Modeling - run 001-seq-model...[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Sequence Modeling",
  "run_name": "001-seq-model",
  "notebook_persist_dp": "/Users/dvq/frostmourne/recsys-mvp/notebooks/data/001-seq-model",
  "random_seed": 41,
  "device": null,
  "max_epochs": 100,
  "batch_size": 128,
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "top_K": 100,
  "top_k": 10,
  "embedding_dim": 128,
  "dropout": 0.3,
  "early_stopping_patience": 5,
  "learning_rate": 0.001,
  "l2_reg": 0.0001,
  "mlf_item2vec_model_name": "item2vec",
  "mlf_model_name": "sequence_rating_prediction",
  "min_roc_auc": 0.7
}


# Implement

In [5]:
def init_model(n_users, n_items, embedding_dim, dropout, item_embedding=None):
    model = SequenceRatingPrediction(
        n_users, n_items, embedding_dim, dropout=dropout, item_embedding=item_embedding
    )
    return model

## Load pretrained Item2Vec embeddings

In [6]:
mlf_client = mlflow.MlflowClient()
model = mlflow.pyfunc.load_model(
    model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
)
skipgram_model = model.unwrap_python_model().model
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.size()[0]
id_mapping = model.unwrap_python_model().id_mapping
pretrained_item_embedding = skipgram_model.embeddings

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [7]:
assert (
    pretrained_item_embedding.embedding_dim == args.embedding_dim
), "Mismatch pretrained item_embedding dimension"

# Test implementation

In [8]:
embedding_dim = 8
batch_size = 2

# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 4, 5, 3, 0]
item_sequences = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 3],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 2, 1],
]

n_users = len(set(user_indices))
n_items = len(set(item_indices))

train_df = pd.DataFrame(
    {
        "user_indice": user_indices,
        "item_indice": item_indices,
        args.timestamp_col: timestamps,
        args.rating_col: ratings,
        "item_sequence": item_sequences,
    }
)

model = init_model(n_users, n_items, embedding_dim, args.dropout)

# Example forward pass
model.eval()
user = torch.tensor([0])
item_sequence = torch.tensor([[-1, -1, -1, 0, 1]])
target_item = torch.tensor([2])
predictions = model.predict(user, item_sequence, target_item)
print(predictions)

tensor([[0.5208]], grad_fn=<SigmoidBackward0>)


In [9]:
rating_dataset = UserItemRatingDFDataset(
    train_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)

train_loader = DataLoader(rating_dataset, batch_size=batch_size, shuffle=False)

In [10]:
for batch_input in train_loader:
    print(batch_input)

{'user': tensor([0, 0]), 'item': tensor([0, 1]), 'rating': tensor([0., 4.]), 'item_sequence': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  3]]), 'item_feature': tensor([], size=(2, 0))}
{'user': tensor([1, 2]), 'item': tensor([2, 3]), 'rating': tensor([5., 3.]), 'item_sequence': tensor([[-1, -1,  1,  3],
        [-1, -1,  2,  1]]), 'item_feature': tensor([], size=(2, 0))}
{'user': tensor([2]), 'item': tensor([4]), 'rating': tensor([0.]), 'item_sequence': tensor([[-1, -1,  2,  1]]), 'item_feature': tensor([], size=(1, 0))}


In [11]:
# model
lit_model = LitSequenceRatingPrediction(model, log_dir=args.notebook_persist_dp)

# train model
trainer = L.Trainer(default_root_dir=f"{args.notebook_persist_dp}/test", max_epochs=2)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type                     | Params | Mode
----------------------------------------------------------
0 | model | SequenceRatingPrediction | 729    | eval
----------------------------------------------------------
729       Trainable params
0         Non-trainable params
729       Total params
0.003     Total estimated model params size (MB)
0         Modules in train mode
10        Modules in eval mode


Sanity Checking: |                                                                                            …

/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.
/Users/dvq/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.
[32m2024-10-25 21:53:29.313[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m124[0m - [1mLogging classification metrics...[0m


In [12]:
users = torch.tensor([0, 0, 0, 0])
item_sequences = torch.tensor(
    [[-1, -1, 2, 3], [-1, -1, 2, 3], [-1, -1, 1, 3], [-1, -1, 2, 1]]
)
items = torch.tensor([0, 1, 2, 3])
predictions = model.predict(users, item_sequences, items)
print(predictions)

tensor([[0.5526],
        [0.5538],
        [0.5427],
        [0.5993]], grad_fn=<SigmoidBackward0>)


In [13]:
def create_predict_df(
    train_df,
    val_user_indices,
    val_timestamp,
    rating_col,
    timestamp_col,
    sequence_length=10,
):
    predict_df = pd.DataFrame(
        {
            "user_indice": val_user_indices,
            "item_indice": -1,  # placeholder
            "timestamp": val_timestamp,
            "source": "predict",
        }
    )

    predict_df = (
        pd.concat(
            [
                train_df.loc[lambda df: df[rating_col].gt(0)][
                    ["user_indice", "item_indice", timestamp_col]
                ].assign(source="train"),
                predict_df,
            ],
            axis=0,
        )
        .pipe(
            generate_item_sequences,
            "user_indice",
            "item_indice",
            timestamp_col,
            sequence_length=sequence_length,
            padding=True,
            padding_value=-1,
        )
        .loc[lambda df: df["source"].eq("predict")]
        .assign(item_sequence=lambda df: df["item_sequence"].apply(np.array))
    )

    return predict_df


predict_df = create_predict_df(
    train_df,
    user_indices,
    timestamps[-1],
    args.rating_col,
    args.timestamp_col,
    sequence_length=10,
)

predict_df

Unnamed: 0,user_indice,item_indice,timestamp,source,item_sequence
0,0,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1]"
1,0,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1]"
2,1,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 2]"
3,2,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 3]"
4,2,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 3]"


In [14]:
recommendations = model.recommend(
    torch.tensor(predict_df["user_indice"].values),
    torch.tensor(predict_df["item_sequence"].values.tolist()),
    k=2,
    batch_size=4,
)
recommendations


Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_new.cpp:281.)



Generating recommendations:   0%|          | 0/2 [00:00<?, ?it/s]

{'user_indice': [0, 0, 0, 0, 1, 1, 2, 2, 2, 2],
 'recommendation': [3, 5, 3, 5, 3, 1, 2, 3, 2, 3],
 'score': [0.6049464344978333,
  0.5866492986679077,
  0.6049464344978333,
  0.5866492986679077,
  0.49212324619293213,
  0.4886247515678406,
  0.584133505821228,
  0.5803021192550659,
  0.584133505821228,
  0.5803021192550659]}

# Prep data

In [15]:
train_df = pd.read_parquet("../data/train_features_neg_df.parquet")
val_df = pd.read_parquet("../data/val_features_neg_df.parquet")
idm_fp = "../data/idm.json"
idm = IDMapper().load(idm_fp)

In [16]:
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

[32m2024-10-25 21:53:31.083[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mlen(user_indices)=19,578, len(item_indices)=4,630[0m


In [17]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
0,AFDBCNCWNMRNRTMZLCJYAH25ZI2A,B08N7QBVBJ,0.0,2010-12-14 18:13:27.000,10157,681,Video Games,PowerA Charging Stand for Xbox One - White,[An Xbox wireless Controller looks great on th...,"[Video Games, Xbox One, Accessories, Batteries...",13.57,1,,,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]"
1,AERBGYRVU2NO24B5CNSMGDSCLD3Q,B004JLO65Q,0.0,2016-11-26 05:51:29.000,18922,3515,Video Games,Nintendo Official Executive Case for 3DS,[Carry your Nintendo 3DS system in style with ...,"[Video Games, Legacy Systems, Nintendo Systems...",,1,,"B0044R8X9U,B000P297JS","[-1, -1, -1, -1, -1, -1, -1, -1, 4206, 4417]"
2,AEXNTHZMDXE4GL2ZKHLMOZWEHUNA,B01B62OSTE,0.0,2018-10-17 22:02:48.892,15120,3587,Video Games,Turtle Beach - Ear Force Elite 800 - Premium F...,[Turtle Beach’s Elite 800 isn’t your ordinary ...,"[Video Games, PlayStation 4, Accessories, Head...",123.24,1,,"B0036F0V4G,B003HGGN82,B0088MVPFQ,B007CM0K86,B0...","[-1, 921, 1011, 4330, 2417, 3362, 1859, 1614, ..."
3,AFC5XTCF5D7J3NSDITB2Z26XWWYA,B001E8WQUY,5.0,2019-05-01 21:22:39.265,4316,229,Video Games,Rock Band 2 - Nintendo Wii (Game only),"[Product description, Rock Band 2 lets you and...","[Video Games, Legacy Systems, Nintendo Systems...",28.49,1,,"B006HZA6VK,B0BN2FNKLM,B0086VPUHI,B0040UAYI4,B0...","[550, 3643, 464, 400, 1177, 997, 4585, 440, 90..."
4,AF7LJQOIWF3Y3YD7SGOJ34MA5JPA,B001E8WQKY,5.0,2015-01-09 12:53:25.000,13887,2028,Video Games,Resident Evil 5 - Xbox 360,[],"[Video Games, Legacy Systems, Xbox Systems, Xb...",29.88,3,5.0,"B00A2ML6XG,B003VUO6LU","[-1, -1, -1, -1, -1, -1, -1, -1, 1378, 303]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
328591,AF2UEE65LG6WPCRSIP3UBGZLN7EQ,B00XBLQCLQ,0.0,2012-10-01 06:46:28.000,10087,3752,Video Games,Assassin’s Creed Syndicate - Gold Edition | PC...,"[London, 1868. The Industrial Revolution unlea...","[Video Games, PC, Games]",66.35,2,3.0,"B001EYUQVE,B001ELJFGO","[-1, -1, -1, -1, -1, -1, -1, -1, 1088, 309]"
328592,AH2RSPTE3H6XPONAC7XHIXFHE4IA,B002BSA298,0.0,2017-10-20 12:55:58.546,6088,2775,Video Games,Kinect Sensor with Kinect Adventures!,"[Product Description, Kinect for Xbox 360 brin...","[Video Games, Legacy Systems, Xbox Systems, Xb...",88.0,2,5.0,"B0118GJKIW,B00X3EDHZU,B004HO6CQG,B073CFJG46","[-1, -1, -1, -1, -1, -1, 3679, 3668, 2576, 3818]"
328593,AFKEXMJWTFZMBO7QF6OFI4AD2B5A,B087LSSNG1,0.0,2020-02-16 21:01:48.099,18733,611,Video Games,Xenoblade Chronicles: Definitive Edition - Nin...,[Discover the origins of Shulk as he and his c...,"[Video Games, Nintendo Switch, Games]",54.98,1,,,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]"
328594,AHAVA5VKMJ3OMOLGDZ3W45CKXEWA,B00KTORA0K,5.0,2019-05-25 04:03:51.505,17876,1189,Video Games,Just Dance 2015 - Wii,[With more than 50 million copies of Just Danc...,"[Video Games, Legacy Systems, Nintendo Systems...",33.0,2,5.0,"B004AYCNR0,B007NUQICE,B000TYQL1O,B000SEU92W,B0...","[-1, -1, -1, 2229, 4467, 4566, 262, 3157, 1682..."


# Train

In [18]:
rating_dataset = UserItemRatingDFDataset(
    train_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)
val_rating_dataset = UserItemRatingDFDataset(
    val_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [19]:
n_items = len(item_indices)
n_users = len(user_indices)

model = init_model(n_users, n_items, args.embedding_dim, args.dropout)

#### Predict before train

In [20]:
model.item_embedding

Embedding(4631, 128, padding_idx=4630)

In [21]:
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
373,AGSZJ35QJEYZFOJC7LQ3FGJKUCOA,B0C7BN9G35,2.0,2021-12-29 17:30:54.170,4463,4616,Video Games,"PDP Universal PS4/PS5 Media Remote Control, Pl...","[Manage your PS4 systems with this convenient,...","[Video Games, PlayStation 4, Accessories, Remo...",,1,,"B000UZVL58,B00030GS80,B002BSH82M,B001ELJDWA,B0...","[-1, -1, -1, -1, -1, 1176, 1240, 2241, 239, 1658]"
1806,AERYH6RJZT3LREKDLJQRMYWRAOFA,B07MYVF61Y,3.0,2021-11-09 16:47:51.954,10219,4025,All Electronics,LevelHike HDMI Cable for Playstation 2 & Plays...,[],"[Video Games, Legacy Systems, PlayStation Syst...",29.99,1,,"B002JJWH5G,B001EYUQWI,B000MIXFWA,B00000DMAR,B0...","[-1, -1, -1, -1, -1, 2080, 3694, 3740, 504, 4138]"
1121,AHX7CNBWGLD425HQOIGBZNPLHZCQ,B06X97FLWF,0.0,2021-08-21 01:02:28.576,3087,3492,Computers,amCase Carrying Case for Nintendo Switch-14 Ga...,[],"[Video Games, Nintendo Switch, Accessories, Ca...",9.99,1,,"B014N4RTS4,B01IC2A28C,B087STK7JZ,B07CD6F5PX,B0...","[811, 3826, 4482, 3165, 4201, 626, 3461, 3528,..."
693,AFLMA3UZEESTNQL24Y5OIMWK7K7A,B0891DDRYN,1.0,2021-10-28 22:00:11.394,8535,119,Computers,Corsair K95 RGB Platinum Mechanical Gaming Key...,[Corsair K95 RGB Platinum features Cherry MX B...,"[Video Games, PC, Accessories, Gaming Keyboards]",259.94,1,,"B00DU2CHE2,B01GW3H3U8,B0C3QFBBH5,B0754LGLFP,B0...","[-1, -1, -1, -1, -1, 1709, 983, 1540, 518, 1871]"
1612,AGTWGY3LVMNOJHRRO5N24SQRTU5Q,B0029LJIFG,0.0,2022-06-06 20:55:59.612,1930,3717,Video Games,Xbox LIVE 12 Month Gold Membership Card,"[From the Manufacturer, With an Xbox LIVE Gold...","[Video Games, Legacy Systems, Xbox Systems, Xb...",68.75,1,,"B07HHW8C4V,B01MQTQV5B,B073V9W4D8,B0BN942894,B0...","[-1, -1, -1, 2908, 3173, 3655, 1575, 1026, 991..."
996,AHXGULVDCWQKTANFRWLCOPOKOBMQ,B07WPS59HY,4.0,2021-11-30 05:41:02.891,9541,2472,Video Games,WB Games LEGOBatman2: DC Super Heroes - Ninten...,[The Melding of Video Game and Toy Lines - LEG...,"[Video Games, Legacy Systems, Nintendo Systems...",26.88,4,2.666667,"B003ZHMMEM,B002Z01QO2,B000FQ9QVI,B002BSC4ZS,B0...","[-1, -1, 398, 2223, 3061, 72, 4262, 2452, 2866..."
1307,AGODEE2NRFP5H2KRXIZPZPV5QSJQ,B07624RBWB,5.0,2022-03-14 13:22:39.302,15021,4405,Video Games,Nintendo Switch Pro Controller,[],"[Video Games, Nintendo Switch, Accessories, Co...",69.0,1,,"B01953YZ0S,B00XY0IVK4,B00W435BL4,B00QO4NAOO,B0...","[-1, 1665, 3508, 2791, 1478, 1054, 3111, 3536,..."
904,AF2LMB25WDXBRLGBS773HLXXBFZA,B003N18O5Q,0.0,2021-11-19 00:58:05.727,5491,4138,Video Games,Vanquish - Playstation 3,"[Product Description, The story of VANQUISH is...","[Video Games, Legacy Systems, PlayStation Syst...",24.99,1,,"B00KR2C0RC,B00KVP78FE,B00GOZSR96,B06X9898LZ,B0...","[-1, -1, -1, 715, 3884, 4120, 3244, 3937, 2137..."
1622,AE6TIR7FF4RCBKKGQ5GVD5YIA7IQ,B0041G5EU0,0.0,2022-04-13 06:46:13.165,13477,791,Video Games,Disney Epic Mickey - Nintendo Wii,"[Product Description, Disney Epic Mickey is an...","[Video Games, Legacy Systems, Nintendo Systems...",19.99,1,,"B01D63UU52,B01JGNF60A,B082VX16VJ,B08LT6PT1X,B0...","[-1, -1, -1, -1, -1, 1958, 2008, 794, 1018, 2720]"
577,AGCDPZJ4Z4VOGKLKJFRZBQKOVLLA,B0062UOBW0,0.0,2021-12-25 04:16:52.075,13227,136,Video Games,Skylanders Spyro's Adventure: Voodood,[The Item is platform independent.],"[Video Games, Legacy Systems, PlayStation Syst...",,1,,"B01GKFPFZS,B00Z9TMBOU,B00E4MQODC,B0995GXFV4,B0...","[2902, 4380, 1597, 3311, 2983, 2209, 1662, 349..."


In [22]:
idm.get_item_index("B00FA1CIUE")

2104

In [23]:
# user_id = val_df.sample(1)[args.user_col].values[0]
user_id = "AH4AOFTTDPHPAFAAVFMAF25H2LIQ"
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,main_category,title,description,categories,price,user_rating_cnt_90d,user_rating_avg_prev_rating_90d,user_rating_list_10_recent_asin,item_sequence
466,AH4AOFTTDPHPAFAAVFMAF25H2LIQ,B07DK1H3H5,1.0,2022-01-09 17:19:07.823,3550,483,Video Games,Cyberpunk 2077 - PC [Game Download Code in Box],"[Cyberpunk 2077 is an open world, an action adventure story set in Night City, a megalopolis obsessed with power, glamour and body modification. You play at V, a mercenary outlaw going after a one of a kind implant that is the key to immortality. You can customize your character's cyberwar, skill set and playstyle, and explore a vast city where the choices you make shape the story and the world around you.]","[Video Games, PC, Games]",,2,5.0,"B000PS4X9G,B00J5C3Z10,B00DBLBMBQ,B07WS18ZS3,B002XH972U,B0088TN5FM,B00XBLQCLQ,B01GY35W22,B07DKYN13M,B08JDVKWHS","[1156, 826, 4323, 4007, 4517, 2516, 3752, 816, 836, 3693]"
547,AH4AOFTTDPHPAFAAVFMAF25H2LIQ,B004NBXRDE,0.0,2022-01-09 17:09:46.436,3550,454,Video Games,Cars 2 3DS,"[Product Description, Inspired by the upcoming Disney Pixar animated film, Cars 2: The Video Game lets players jump into the Cars 2 universe with some of their favorite Cars personalities in exotic locations around the globe. Continuing the storyline from the upcoming film, players can choose to play as Mater and Lightning McQueen, as well as some brand new characters, as they train in the international training center - CHROME (Command Headquarters for Recon Operations and Motorized Espionage) to become world-class spies. They’ll take on dangerous missions, compete to become the fastest racecar in the world, or use their spy skills in exciting, action-packed combat racing and battle arenas. Players can race against friends and family in either single or multiplayer modes with up to four players to unlock challenging new tracks, characters, events and thrilling spy missions., From the Manufacturer]","[Video Games, Legacy Systems, Nintendo Systems, Nintendo 3DS & 2DS]",49.99,1,,"B008PQU3E4,B000PS4X9G,B00J5C3Z10,B00DBLBMBQ,B07WS18ZS3,B002XH972U,B0088TN5FM,B00XBLQCLQ,B01GY35W22,B07DKYN13M","[244, 1156, 826, 4323, 4007, 4517, 2516, 3752, 816, 836]"
769,AH4AOFTTDPHPAFAAVFMAF25H2LIQ,B08JDVKWHS,5.0,2022-01-09 17:09:46.436,3550,3693,Video Games,Marvel's Spider-Man: Miles Morales Launch Edition - PlayStation 4,"[In the latest adventure in the Marvel's spider-man universe, teenager miles morales is adjusting to his new home while following in the footsteps of his mentor, peter parker, as a new spider-man. But when a fierce power struggle threatens to destroy his new home, the aspiring hero realizes that with great power, there must also come great responsibility. To save all of Marvel's new York, miles must take up the mantle of spider-man and own it.]","[Video Games, PlayStation 4, Games]",57.98,1,,"B008PQU3E4,B000PS4X9G,B00J5C3Z10,B00DBLBMBQ,B07WS18ZS3,B002XH972U,B0088TN5FM,B00XBLQCLQ,B01GY35W22,B07DKYN13M","[244, 1156, 826, 4323, 4007, 4517, 2516, 3752, 816, 836]"
1525,AH4AOFTTDPHPAFAAVFMAF25H2LIQ,B01GH9MVWW,0.0,2022-01-09 17:19:07.823,3550,2343,Video Games,Skylanders Imaginators - Crash Bandicoot Edition - PlayStation 4,"[Skylanders Imaginators Featuring Crash Bandicoot. A wormhole opens in Skylands and the great Aku Aku appears! He comes to announce the once in two decades Synchronization Celebration! It is the time when all of the worlds align perfectly. They are having a huge celebration event in the Wumpa Islands and want to invite the Skylanders. But with Kaos on a quest to take over Skylands using his army of Doomlanders, the Skylanders must focus on stopping his evil plans. Never one to shy away from danger, the legendary marsupial Crash Bandicoot travels through the wormhole to join the Skylanders in the ultimate battle against Kaos!]","[Video Games, PlayStation 4, Games]",349.95,2,5.0,"B000PS4X9G,B00J5C3Z10,B00DBLBMBQ,B07WS18ZS3,B002XH972U,B0088TN5FM,B00XBLQCLQ,B01GY35W22,B07DKYN13M,B08JDVKWHS","[1156, 826, 4323, 4007, 4517, 2516, 3752, 816, 836, 3693]"


In [24]:
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
item_sequence = torch.tensor([item_sequence])
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item)

[32m2024-10-25 21:53:31.305[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mTest predicting before training with user_id = AH4AOFTTDPHPAFAAVFMAF25H2LIQ and parent_asin = B07DK1H3H5[0m


tensor([[0.5352]], grad_fn=<SigmoidBackward0>)

#### Training loop

##### Overfit 1 batch

In [25]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

model = init_model(n_users, n_items, args.embedding_dim, dropout=0)
lit_model = LitSequenceRatingPrediction(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

  | Name  | Type                     | Params | Mode 
-----------------------------------------------------------
0 | model | SequenceRatingPrediction | 3.2 M  | train
-----------------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.990    Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2024-10-25 21:53:36.161[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m124[0m - [1mLogging classification metrics...[0m
[32m2024-10-25 21:53:51.559[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mLogs available at /Users/dvq/frostmourne/recsys-mvp/notebooks/data/001-seq-model/logs/overfit/lightning_logs/version_2[0m


In [26]:
%tensorboard --logdir $trainer.log_dir

##### Fit on all data

In [27]:
# papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_loss", patience=args.early_stopping_patience, mode="min", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    dropout=args.dropout,
    item_embedding=pretrained_item_embedding,
)
lit_model = LitSequenceRatingPrediction(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    args=args,
)

log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Experiment with name RecSys MVP - Sequence Modeling not found. Creating it.

Checkpoint directory /Users/dvq/frostmourne/recsys-mvp/notebooks/data/001-seq-model/checkpoints exists and is not empty.


  | Name  | Type                     | Params | Mode 
-----------------------------------------------------------
0 | model | SequenceRatingPrediction | 3.2 M  | train
-----------------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.990    Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=10` in the `DataLoader` to improve performance.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

[32m2024-10-25 21:59:33.753[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m124[0m - [1mLogging classification metrics...[0m
[32m2024-10-25 21:59:39.331[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m127[0m - [1mLogging ranking metrics...[0m


Generating recommendations:   0%|          | 0/481 [00:00<?, ?it/s]

2024/10/25 22:01:03 INFO mlflow.tracking._tracking_service.client: 🏃 View run 001-seq-model at: http://localhost:5002/#/experiments/2/runs/d486880d20ea4d6ab4e6c62404c32280.
2024/10/25 22:01:03 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/2.


In [28]:
logger.info(
    f"Test predicting after training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
model.eval()
model.predict(user, item_sequence, item)

[32m2024-10-25 22:01:03.464[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mTest predicting after training with user_id = AH4AOFTTDPHPAFAAVFMAF25H2LIQ and parent_asin = B07DK1H3H5[0m


tensor([[0.9572]], grad_fn=<SigmoidBackward0>)

# Load best checkpoint

In [45]:
logger.info(f"Loading best checkpoint from {checkpoint_callback.best_model_path}...")
args.best_checkpoint_path = checkpoint_callback.best_model_path

best_trainer = LitSequenceRatingPrediction.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    model=init_model(n_users, n_items, args.embedding_dim, dropout=0),
)

[32m2024-10-25 22:11:37.982[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mLoading best checkpoint from /Users/dvq/frostmourne/recsys-mvp/notebooks/data/001-seq-model/checkpoints/best-checkpoint-v2.ckpt...[0m


In [46]:
best_model = best_trainer.model.to(lit_model.device)

In [47]:
best_model.eval()
best_model.predict(user, item_sequence, item)

tensor([[0.9551]], grad_fn=<SigmoidBackward0>)

### Persist id mapping

In [48]:
if args.log_to_mlflow:
    # Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    mlf_client.log_artifact(run_id, idm_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [49]:
inferrer = SequenceRatingPredictionInferenceWrapper(best_model)

In [50]:
sample_input = {
    "user_ids": [idm.get_user_id(0)],
    "item_sequences": [[idm.get_item_id(0), idm.get_item_id(1)]],
    "item_ids": [idm.get_item_id(0)],
}
sample_output = inferrer.infer([0], [[0, 1]], [0])
sample_output

array([0.7880765], dtype=float32)

In [51]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    idm_filename = idm_fp.split("/")[-1]
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
            artifacts={"idm": mlflow.get_artifact_uri(idm_filename)},
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )


Since MLflow 2.16.0, we no longer convert dictionary input example to pandas Dataframe, and directly save it as a json object. If the model expects a pandas DataFrame input instead, please pass the pandas DataFrame as input example directly.



Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Successfully registered model 'sequence_rating_prediction'.
2024/10/25 22:11:50 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: sequence_rating_prediction, version 1
Created version '1' of model 'sequence_rating_prediction'.


Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

2024/10/25 22:11:55 INFO mlflow.tracking._tracking_service.client: 🏃 View run 001-seq-model at: http://localhost:5002/#/experiments/2/runs/d486880d20ea4d6ab4e6c62404c32280.
2024/10/25 22:11:55 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/2.


# Set the newly trained model as champion

In [52]:
if args.log_to_mlflow:
    val_roc_auc = trainer.logger.experiment.get_run(trainer.logger.run_id).data.metrics[
        "val_roc_auc"
    ]

    if val_roc_auc > args.min_roc_auc:
        logger.info(f"Aliasing the new model as champion...")
        model_version = (
            mlf_client.get_registered_model(args.mlf_model_name)
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value="quy.dinh",
        )

[32m2024-10-25 22:11:55.663[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mAliasing the new model as champion...[0m


# Clean up

In [53]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.dict()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)

2024/10/25 22:11:55 INFO mlflow.tracking._tracking_service.client: 🏃 View run 001-seq-model at: http://localhost:5002/#/experiments/2/runs/d486880d20ea4d6ab4e6c62404c32280.
2024/10/25 22:11:55 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/2.
