# Sequence modeling for ranking task

# Set up

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

In [None]:
import os
import sys

import lightning as L
import numpy as np
import pandas as pd
import torch
from dotenv import load_dotenv
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from loguru import logger
from mlflow.models.signature import infer_signature
from pydantic import BaseModel
from torch.utils.data import DataLoader

import mlflow

load_dotenv()

sys.path.insert(0, "..")

from src.dataset import UserItemBinaryDFDataset as UserItemRatingDFDataset
from src.id_mapper import IDMapper
from src.sequence.inference import SequenceRatingPredictionInferenceWrapper
from src.sequence.model import SequenceRatingPrediction
from src.sequence.trainer import LitSequenceRatingPrediction
from src.sequence.utils import generate_item_sequences
from src.viz import blueq_colors

# Controller

In [3]:
max_epochs = 1

In [4]:
class Args(BaseModel):
    testing: bool = False
    log_to_mlflow: bool = True
    experiment_name: str = "RecSys MVP - Sequence Modeling"
    run_name: str = "002-try-binary"
    notebook_persist_dp: str = None
    random_seed: int = 41
    device: str = None

    max_epochs: int = max_epochs
    batch_size: int = 128

    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"

    top_K: int = 100
    top_k: int = 10

    batch_size: int = 128

    embedding_dim: int = 128
    dropout: float = 0.3
    early_stopping_patience: int = 5
    learning_rate: float = 0.001
    l2_reg: float = 1e-4

    mlf_item2vec_model_name: str = "item2vec"
    mlf_model_name: str = "sequence_rating_prediction"
    min_roc_auc: float = 0.7

    best_checkpoint_path: str = None

    def init(self):
        self.notebook_persist_dp = os.path.abspath(f"data/{self.run_name}")
        os.makedirs(self.notebook_persist_dp, exist_ok=True)

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            logger.warning(
                f"Environment variable MLFLOW_TRACKING_URI is not set. Setting self.log_to_mlflow to false."
            )
            self.log_to_mlflow = False

        if self.log_to_mlflow:
            logger.info(
                f"Setting up MLflow experiment {self.experiment_name} - run {self.run_name}..."
            )
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        if self.device is None:
            self.device = (
                "cuda"
                if torch.cuda.is_available()
                else "mps" if torch.backends.mps.is_available() else "cpu"
            )

        return self


args = Args().init()

print(args.model_dump_json(indent=2))

[32m2024-11-10 16:26:10.314[0m | [1mINFO    [0m | [36m__main__[0m:[36minit[0m:[36m46[0m - [1mSetting up MLflow experiment RecSys MVP - Sequence Modeling - run 002-try-binary...[0m


{
  "testing": false,
  "log_to_mlflow": true,
  "experiment_name": "RecSys MVP - Sequence Modeling",
  "run_name": "002-try-binary",
  "notebook_persist_dp": "/home/dvquys/frostmourne/recsys-mvp/notebooks/data/002-try-binary",
  "random_seed": 41,
  "device": "cuda",
  "max_epochs": 1,
  "batch_size": 128,
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "top_K": 100,
  "top_k": 10,
  "embedding_dim": 128,
  "dropout": 0.3,
  "early_stopping_patience": 5,
  "learning_rate": 0.001,
  "l2_reg": 0.0001,
  "mlf_item2vec_model_name": "item2vec",
  "mlf_model_name": "sequence_rating_prediction",
  "min_roc_auc": 0.7,
  "best_checkpoint_path": null
}


# Implement

In [5]:
def init_model(n_users, n_items, embedding_dim, dropout, item_embedding=None):
    model = SequenceRatingPrediction(
        n_users, n_items, embedding_dim, dropout=dropout, item_embedding=item_embedding
    )
    return model

## Load pretrained Item2Vec embeddings

In [6]:
mlf_client = mlflow.MlflowClient()
model = mlflow.pyfunc.load_model(
    model_uri=f"models:/{args.mlf_item2vec_model_name}@champion"
)
skipgram_model = model.unwrap_python_model().model
embedding_0 = skipgram_model.embeddings(torch.tensor(0))
embedding_dim = embedding_0.size()[0]
id_mapping = model.unwrap_python_model().id_mapping
pretrained_item_embedding = skipgram_model.embeddings

Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

In [7]:
assert (
    pretrained_item_embedding.embedding_dim == args.embedding_dim
), "Mismatch pretrained item_embedding dimension"

# Test implementation

In [8]:
embedding_dim = 8
batch_size = 2

# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 4, 5, 3, 0]
item_sequences = [
    [-1, -1, 2, 3],
    [-1, -1, 2, 3],
    [-1, -1, 1, 3],
    [-1, -1, 2, 1],
    [-1, -1, 2, 1],
]

n_users = len(set(user_indices))
n_items = len(set(item_indices))

train_df = pd.DataFrame(
    {
        "user_indice": user_indices,
        "item_indice": item_indices,
        args.timestamp_col: timestamps,
        args.rating_col: ratings,
        "item_sequence": item_sequences,
    }
)

model = init_model(n_users, n_items, embedding_dim, args.dropout)

# Example forward pass
model.eval()
user = torch.tensor([0])
item_sequence = torch.tensor([[-1, -1, -1, 0, 1]])
target_item = torch.tensor([2])
predictions = model.predict(user, item_sequence, target_item)
print(predictions)
model.train()

tensor([[0.5199]], grad_fn=<SigmoidBackward0>)


SequenceRatingPrediction(
  (item_embedding): Embedding(6, 8, padding_idx=5)
  (user_embedding): Embedding(3, 8)
  (gru): GRU(8, 8, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (fc_rating): Sequential(
    (0): Linear(in_features=24, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=8, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [None]:
rating_dataset = UserItemRatingDFDataset(
    train_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)

train_loader = DataLoader(
    rating_dataset, batch_size=batch_size, shuffle=False, drop_last=True
)

In [10]:
for batch_input in train_loader:
    print(batch_input)

{'user': tensor([0, 0]), 'item': tensor([0, 1]), 'rating': tensor([0., 1.]), 'item_sequence': tensor([[-1, -1,  2,  3],
        [-1, -1,  2,  3]]), 'item_sequence_ts_bucket': tensor([], size=(2, 0), dtype=torch.int64), 'item_feature': tensor([], size=(2, 0))}
{'user': tensor([1, 2]), 'item': tensor([2, 3]), 'rating': tensor([1., 1.]), 'item_sequence': tensor([[-1, -1,  1,  3],
        [-1, -1,  2,  1]]), 'item_sequence_ts_bucket': tensor([], size=(2, 0), dtype=torch.int64), 'item_feature': tensor([], size=(2, 0))}


In [11]:
# model
lit_model = LitSequenceRatingPrediction(model, log_dir=args.notebook_persist_dp)

# train model
trainer = L.Trainer(
    default_root_dir=f"{args.notebook_persist_dp}/test",
    max_epochs=2,
    accelerator=args.device if args.device else "auto",
)
trainer.fit(
    model=lit_model, train_dataloaders=train_loader, val_dataloaders=train_loader
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                     | Params | Mode 
-----------------------------------------------------------
0 | model | SequenceRatingPrediction | 729    | train
-----------------------------------------------------------
729       Trainable params
0         Non-trainable params
729       Total params
0.003     Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/home/dvquys/frostmourne/recsys-mvp/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=2` reached.
[32m2024-11-10 16:26:11.209[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m127[0m - [1mLogging classification metrics...[0m


In [12]:
%debug

ERROR:root:No traceback has been produced, nothing to debug.


ipdb>  input


tensor([[-0.1289,  0.3273, -0.0985, -0.8232,  0.7046, -0.1893,  0.0574, -0.5758]],
       device='cuda:0', grad_fn=<AddmmBackward0>)


ipdb>  exit


In [13]:
users = torch.tensor([0, 0, 0, 0])
item_sequences = torch.tensor(
    [[-1, -1, 2, 3], [-1, -1, 2, 3], [-1, -1, 1, 3], [-1, -1, 2, 1]]
)
items = torch.tensor([0, 1, 2, 3])
predictions = model.predict(users, item_sequences, items)
print(predictions)

tensor([[0.5487],
        [0.5351],
        [0.5420],
        [0.5283]], grad_fn=<SigmoidBackward0>)


In [14]:
def create_predict_df(
    train_df,
    val_user_indices,
    val_timestamp,
    rating_col,
    timestamp_col,
    sequence_length=10,
):
    predict_df = pd.DataFrame(
        {
            "user_indice": val_user_indices,
            "item_indice": -1,  # placeholder
            "timestamp": val_timestamp,
            "source": "predict",
        }
    )

    predict_df = (
        pd.concat(
            [
                train_df.loc[lambda df: df[rating_col].gt(0)][
                    ["user_indice", "item_indice", timestamp_col]
                ].assign(source="train"),
                predict_df,
            ],
            axis=0,
        )
        .pipe(
            generate_item_sequences,
            "user_indice",
            "item_indice",
            timestamp_col,
            sequence_length=sequence_length,
            padding=True,
            padding_value=-1,
        )
        .loc[lambda df: df["source"].eq("predict")]
        .assign(item_sequence=lambda df: df["item_sequence"].apply(np.array))
    )

    return predict_df


predict_df = create_predict_df(
    train_df,
    user_indices,
    timestamps[-1],
    args.rating_col,
    args.timestamp_col,
    sequence_length=10,
)

predict_df

Unnamed: 0,user_indice,item_indice,timestamp,source,item_sequence
0,0,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1]"
1,0,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1]"
2,1,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 2]"
3,2,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 3]"
4,2,-1,4,predict,"[-1, -1, -1, -1, -1, -1, -1, -1, -1, 3]"


In [15]:
recommendations = model.recommend(
    torch.tensor(predict_df["user_indice"].values),
    torch.tensor(predict_df["item_sequence"].values.tolist()),
    k=2,
    batch_size=4,
)
recommendations


Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:278.)



Generating recommendations:   0%|          | 0/2 [00:00<?, ?it/s]

{'user_indice': [0, 0, 0, 0, 1, 1, 2, 2, 2, 2],
 'recommendation': [4, 5, 4, 5, 0, 4, 4, 0, 4, 0],
 'score': [0.5549459457397461,
  0.5477004647254944,
  0.5549459457397461,
  0.5477004647254944,
  0.6304539442062378,
  0.6078678369522095,
  0.5229425430297852,
  0.5174740552902222,
  0.5229425430297852,
  0.5174740552902222]}

# Prep data

In [16]:
train_df = pd.read_parquet("../data/train_features_neg_df.parquet")
val_df = pd.read_parquet("../data/val_features_neg_df.parquet")
idm_fp = "../data/idm.json"
idm = IDMapper().load(idm_fp)

assert (
    train_df[args.user_col].map(lambda s: idm.get_user_index(s))
    != train_df["user_indice"]
).sum() == 0, "Mismatch IDM"
assert (
    val_df[args.user_col].map(lambda s: idm.get_user_index(s)) != val_df["user_indice"]
).sum() == 0, "Mismatch IDM"

In [17]:
user_indices = train_df["user_indice"].unique()
item_indices = train_df["item_indice"].unique()

logger.info(f"{len(user_indices)=:,.0f}, {len(item_indices)=:,.0f}")

[32m2024-11-10 16:26:13.446[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mlen(user_indices)=19,578, len(item_indices)=4,630[0m


In [18]:
train_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
0,AG57LGJFCNNQJ6P6ABQAVUKXDUDA,B0015AARJI,0.0,2016-01-12 11:59:11.000,,76.0,4.592105,10.0,4.3,3.0,...,B00J00BLRM,1452599936,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1452599936]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 0]",Video Games,PlayStation 3 Dualshock 3 Wireless Controller ...,"[Amazon.com, The Dualshock 3 wireless controll...","[Video Games, Legacy Systems, PlayStation Syst...",49.99
1,AHWG4EGOV5ZDKPETL56MAYGPLJRQ,B0BMGHMP23,0.0,2016-04-18 19:26:20.000,,,,,,,...,"B00YOGZFCO,B00KWFCSB2,B00L3LQ1FI,B0151K6J9Y,B0...","1449254540,1449256005,1449257733,1452715791,14...","[3028.0, 2742.0, 2755.0, 3159.0, 3101.0, 3036....","[1449254540, 1449256005, 1449257733, 145271579...","[5, 5, 5, 5, 5, 5, 5, 5, 5, 5]",Computers,Logitech G502 Lightspeed Wireless Gaming Mouse...,[G502 is the best gaming mouse from Logitech G...,"[Video Games, PC, Accessories, Gaming Mice]",87.95
2,AH5PTZ2U74OZ3HT6QVUWM4CV6OVQ,B009AP23NI,0.0,2016-02-10 18:45:08.000,,9.0,4.666667,0.0,,0.0,...,"B0199OXR0W,B00EVPR4FY,B00B7ELWAU,B00UH9DN58,B0...","1443454097,1455129080,1455129186,1455129499,14...","[-1.0, -1.0, 3234.0, 2508.0, 2318.0, 2964.0, 1...","[-1, -1, 1443454097, 1455129080, 1455129186, 1...","[-1, -1, 5, 1, 1, 0, 0, 0, 0, 0]",Video Games,Nintendo Wii U Pro U Controller (Japanese Vers...,[Wii U PRO controller (black) (WUP-A-RSKA)],"[Video Games, Legacy Systems, Nintendo Systems...",43.99
3,AFC5XTCF5D7J3NSDITB2Z26XWWYA,B001E8WQUY,5.0,2019-05-01 21:22:39.265,1.556746e+09,0.0,,0.0,,0.0,...,"B006HZA6VK,B0BN2FNKLM,B0086VPUHI,B0040UAYI4,B0...","1327120514,1377289907,1402605836,1402606396,14...","[1987.0, 4569.0, 2114.0, 1606.0, 2159.0, 2279....","[1327120514, 1377289907, 1402605836, 140260639...","[8, 8, 7, 7, 7, 7, 7, 7, 6, 6]",Video Games,Rock Band 2 - Nintendo Wii (Game only),"[Product description, Rock Band 2 lets you and...","[Video Games, Legacy Systems, Nintendo Systems...",28.49
4,AF7LJQOIWF3Y3YD7SGOJ34MA5JPA,B001E8WQKY,5.0,2015-01-09 12:53:25.000,1.420808e+09,16.0,4.375000,8.0,4.5,4.0,...,"B00A2ML6XG,B003VUO6LU",14208077931420807991,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, 1420807793, 1...","[-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]",Video Games,Resident Evil 5 - Xbox 360,[],"[Video Games, Legacy Systems, Xbox Systems, Xb...",29.88
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
328591,AG4RATLNVLOKZCPXN67HKOAK65CA,B078FBVJMB,0.0,2015-10-31 18:25:09.000,,,,,,,...,B00TFVD688,1425233294,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 1425233294]","[-1, -1, -1, -1, -1, -1, -1, -1, -1, 5]",Video Games,A Way Out – PC Origin [Online Game Code],[From the creators of Brothers - A Tale of Two...,"[Video Games, PC, Games]",5.99
328592,AFBXO3BFWBJX6QS5NW73O37IXF2A,B0771ZXXV6,0.0,2011-03-08 02:06:38.000,,,,,,,...,"B003JVCA9Q,B0029NZ4HA",12995495171299549928,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1....","[-1, -1, -1, -1, -1, -1, -1, -1, 1299549517, 1...","[-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]",Video Games,Nintendo Joy-Con (R) - Neon Red - Nintendo Switch,[To be determined],"[Video Games, Nintendo Switch, Accessories, Co...",
328593,AHVANA5GZNJ45UABPXWZNAF4ECBQ,B00BBF6MO6,0.0,2015-02-15 05:31:04.000,,3.0,4.666667,0.0,,0.0,...,"B002L93F0A,B002KJ02ZC,B001H4NMNA",137041433213704147071370416530,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 137...","[-1, -1, -1, -1, -1, -1, -1, 1370414332, 13704...","[-1, -1, -1, -1, -1, -1, -1, 6, 6, 6]",Video Games,Killer is Dead - Xbox 360,[Killer Is Dead is the latest title from the d...,"[Video Games, Legacy Systems, Xbox Systems, Xb...",39.82
328594,AHAVA5VKMJ3OMOLGDZ3W45CKXEWA,B00KTORA0K,5.0,2019-05-25 04:03:51.505,1.558757e+09,3.0,4.666667,1.0,5.0,1.0,...,"B004AYCNR0,B007NUQICE,B000TYQL1O,B000SEU92W,B0...","1431150669,1431150834,1432041664,1432041986,15...","[-1.0, -1.0, -1.0, 1657.0, 2074.0, 593.0, 583....","[-1, -1, -1, 1431150669, 1431150834, 143204166...","[-1, -1, -1, 7, 7, 7, 7, 5, 5, 0]",Video Games,Just Dance 2015 - Wii,[With more than 50 million copies of Just Danc...,"[Video Games, Legacy Systems, Nintendo Systems...",33.0


# Train

In [19]:
rating_dataset = UserItemRatingDFDataset(
    train_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)
val_rating_dataset = UserItemRatingDFDataset(
    val_df, "user_indice", "item_indice", args.rating_col, args.timestamp_col
)

train_loader = DataLoader(
    rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True
)
val_loader = DataLoader(
    val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False
)

In [20]:
n_items = len(item_indices)
n_users = len(user_indices)

model = init_model(n_users, n_items, args.embedding_dim, args.dropout)

#### Predict before train

In [21]:
model.item_embedding

Embedding(4631, 128, padding_idx=4630)

In [22]:
val_df = val_rating_dataset.df
val_df.sample(10)

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
1736,AFGNF7LSRDIJDBJ2WOUIINMAKTPQ,B001FRTK4O,0.0,2021-12-31 08:44:58.568,,0.0,,0.0,,0.0,...,"B0BJKR3LJJ,B00K0NV5J2,B08VFQ3XJX,B07CB12726,B0...","1595187952,1620593210,1620609469,1620609531,16...","[-1, -1, 4553, 2692, 4436, 3894, 119, 168, 9, 42]","[-1, -1, 1595187952, 1620593210, 1620609469, 1...","[-1, -1, 6, 5, 5, 5, 5, 5, 5, 5]",Video Games,Spore - PC/Mac,"[Product description, The creators of The Sims...","[Video Games, PC, Games]",30.0
138,AGMLADPSZMDXW53UV52EUUDNOLTQ,B09T2QP4WV,1.0,2021-09-16 04:27:30.127,1631766000.0,7.0,3.428571,1.0,1.0,0.0,...,"B003V4AK8E,B00IMVRVC4,B00MNP9PD8,B071H7XCPV,B0...","1371059318,1397947292,1447605322,1460390105,14...","[-1, -1, 1572, 2651, 2802, 3644, 3681, 3372, 3...","[-1, -1, 1371059318, 1397947292, 1447605322, 1...","[-1, -1, 8, 8, 8, 8, 7, 7, 7, 7]",Computers,Redragon M908 Impact RGB LED MMO Mouse with Si...,[],"[Video Games, PC, Accessories, Gaming Mice]",32.89
1492,AGLQLDOF6JVKQIZY7BU7OCAJFEIA,B01K2O2RSG,1.0,2022-02-06 16:20:19.149,1644164000.0,1.0,1.0,0.0,,0.0,...,"B01M8GXNJX,B01FUWJAQC,B072JYVYCX,B06XX4D2KY,B0...","1488662110,1498345226,1507389221,1509398438,15...","[3481, 3332, 3694, 3591, 3871, 3942, 3059, 398...","[1488662110, 1498345226, 1507389221, 150939843...","[7, 7, 7, 7, 7, 7, 6, 6, 6, 6]",Video Games,Sniper: Ghost Warrior 3 Season Pass Edition - ...,[Go behind enemy lines with the ultimate moder...,"[Video Games, PlayStation 4, Games]",21.69
613,AEE2GRES5JQA6J5B3N5ALF3V3OPA,B07N8XW8WX,0.0,2022-06-02 18:11:11.861,,1.0,5.0,0.0,,0.0,...,"B004C43FH0,B001EYUXIK,B000EMJA3M,B004APAEL6,B0...","1359870667,1360436187,1361047278,1369028864,13...","[-1, -1, -1, -1, 1664, 1025, 440, 1654, 1656, ...","[-1, -1, -1, -1, 1359870667, 1360436187, 13610...","[-1, -1, -1, -1, 8, 8, 8, 8, 8, 8]",Video Games,Axiom Verge: Standard Edition - Nintendo Switch,"[In Axiom Verge , you play as Trace, a scienti...","[Video Games, Prime Member Pre-Orders in Video...",51.17
905,AHTLGOWRXF4PNX6DEUZXVZWBBVTQ,B0C4KN63KM,0.0,2022-03-11 08:20:00.705,,7.0,4.428571,1.0,5.0,0.0,...,"B00DDILSBG,B07YBXFDYK,B00PGLG79G,B01BF9X6LO,B0...","1514194117,1528409448,1528409892,1528410399,15...","[2441, 4281, 2893, 3277, 4282, 4209, 3360, 429...","[1514194117, 1528409448, 1528409892, 152841039...","[7, 7, 7, 7, 6, 6, 6, 6, 6, 0]",Computers,"NPET K10V1 Wired Computer Keyboard, Plug and P...",[],"[Video Games, PC, Accessories, Gaming Keyboards]",18.99
249,AFSDTMNFTR6SW5GKERMS424O22XQ,B00F27JGVA,0.0,2021-08-18 18:48:32.126,,6.0,4.0,0.0,,0.0,...,"B000TLU67W,B001QCWSJC,B07CQXSC7K,B00MOQWBQ4,B0...","1421078256,1421078322,1421078350,1421078392,14...","[589, 1142, 3904, 2803, 439, 770, 1706, 2670, ...","[1421078256, 1421078322, 1421078350, 142107839...","[8, 8, 8, 8, 8, 8, 8, 7, 7, 5]",Computers,PlayStation Vita Memory Card 64GB (PCH-Z641J),[This is a memory card for Play Station Vita t...,"[Video Games, Legacy Systems, PlayStation Syst...",251.7
881,AGYJI6XSILTABA3ZWALFM6SMEVOQ,B0BFT941YQ,1.0,2022-04-22 22:32:31.047,1650667000.0,2.0,4.5,0.0,,0.0,...,"B07SM7G9CN,B07G1SC6BW,B07L6MJ6LD,B07J3P1GJM,B0...","1527019952,1538858187,1557937823,1557937962,16...","[-1, 4164, 3970, 4049, 4018, 4388, 3697, 3576,...","[-1, 1527019952, 1538858187, 1557937823, 15579...","[-1, 7, 7, 6, 6, 6, 6, 5, 5, 5]",Video Games,DRAGON BALL Z: Kakarot - PlayStation 5,[Relive the story of Goku and other Z Fighters...,"[Video Games, PlayStation 5, Games]",19.99
1509,AHQJ5UXX647PD77SHSTCPKTSA3XA,B07WZ78VRN,1.0,2021-09-13 14:51:38.718,1631545000.0,2.0,5.0,0.0,,0.0,...,"B00LE3EAIK,B004RMK5QG,B0BL65X86R,B06XHLM4DX,B0...","1514575134,1576179131,1582813328,1588087497,16...","[-1, -1, -1, -1, 2760, 1767, 4559, 3573, 4043,...","[-1, -1, -1, -1, 1514575134, 1576179131, 15828...","[-1, -1, -1, -1, 7, 6, 6, 6, 5, 5]",Computers,8Bitdo SN30 Pro Wireless Bluetooth Controller ...,[],"[Video Games, Mac, Accessories, Controllers, G...",44.99
975,AH4LS5HFN2VGSKNKQGNZGENHBVVA,B07624RBWB,0.0,2021-12-19 22:28:25.889,,32.0,4.59375,2.0,5.0,0.0,...,"B008FHL56S,B0053B66KE,B00CQ9L1Z6,B017C6OK7S,B0...","1397806962,1399892741,1402389343,1402389660,14...","[2151, 1871, 2404, 3199, 4403, 2682, 1915, 275...","[1397806962, 1399892741, 1402389343, 140238966...","[8, 8, 8, 8, 8, 8, 8, 7, 6, 6]",Video Games,Nintendo Switch Pro Controller,[],"[Video Games, Nintendo Switch, Accessories, Co...",69.0
1373,AH2FY6PMPRUFGY4CLK4ER356Q2JQ,B0044XU27A,0.0,2021-08-14 21:34:28.025,,2.0,5.0,1.0,5.0,0.0,...,"B00ZO1SUSE,B01CKGI0XA,B007BGUGVO,B08NYV2VLS,B0...","1451833963,1466268353,1468513788,1504289735,15...","[-1, 3079, 3294, 2052, 4421, 3650, 3555, 4372,...","[-1, 1451833963, 1466268353, 1468513788, 15042...","[-1, 8, 8, 8, 7, 7, 6, 6, 6, 6]",Video Games,Kingdoms of Amalur: Reckoning - Xbox 360,"[Product Description, The minds of, New York T...","[Video Games, Legacy Systems, Xbox Systems, Xb...",14.99


In [23]:
user_id = val_df.sample(1)[args.user_col].values[0]
# user_id = "AH4AOFTTDPHPAFAAVFMAF25H2LIQ"
test_df = val_df.loc[lambda df: df[args.user_col].eq(user_id)]
with pd.option_context("display.max_colwidth", None):
    display(test_df)

Unnamed: 0,user_id,parent_asin,rating,timestamp,timestamp_unix,parent_asin_rating_cnt_365d,parent_asin_rating_avg_prev_rating_365d,parent_asin_rating_cnt_90d,parent_asin_rating_avg_prev_rating_90d,parent_asin_rating_cnt_30d,...,user_rating_list_10_recent_asin,user_rating_list_10_recent_asin_timestamp,item_sequence,item_sequence_ts,item_sequence_ts_bucket,main_category,title,description,categories,price
39,AEUA6C3BAGAKZFYIEUQKXB74VNIQ,B003171CEW,0.0,2022-02-15 23:56:18.381,,5.0,4.6,0.0,,0.0,...,"B004QWYV2Q,B002BSH3K4,B07HYW4XMP,B007HUNYE0,B009AGXH64,B00J4Y6L2Y,B00004S9AF,B007RNWUC4,B07K3KHFSY,B09JDJYPD8",1311866699131585476613358164221341508406135335985813653860541476066169150290649315120747891568498076,"[1761, 1224, 4013, 2066, 2217, 2669, 85, 2079, 4028, 4478]","[1311866699, 1315854766, 1335816422, 1341508406, 1353359858, 1365386054, 1476066169, 1502906493, 1512074789, 1568498076]","[9, 9, 8, 8, 8, 8, 8, 7, 7, 6]",Video Games,MLB 10: The Show - Playstation 3,"[Product description, Welcome to The Show All-Star. The best selling and highest rated baseball franchise is back in MLB 10 The Show throwing you into an unsurpassed baseball experience where big moments come alive. It's all here too; the All New Home Run Derby, MLB All-Star Futures Game, Movie Maker, Catcher Mode, Personalized cheers and yells and Joe Mauer, newly crowned American League MVP, as the new cover athlete., Amazon.com, Welcome to The Show all-star. The best selling and highest rated baseball franchise is back in, MLB 10: The Show, throwing you into an unsurpassed baseball experience where big moments come alive. It’s all here too. In addition to a wealth of improvements to longtime franchise features this newest release includes the return of the Home Run Derby, MLB All-Star Futures Game, Movie Maker, Catcher Mode, Personalized cheers and yells and Joe Mauer, newly crowned American League MVP, as the new cover athlete. So get ready,, The Show, is about to begin and you are leading off. .caption { font-family: Verdana, Helvetica neue, Arial, serif; font-size: 10px; font-weight: bold; font-style: italic; } ul.indent { list-style: inside disc; text-indent: -15px; } table.callout { font-family: verdana; font-size: 11px; line-height: 1. 3em; } td.vgoverview { height: 125px; background: #9DC4D8 url(https://images-na.ssl-images-amazon.com/images/G/01/electronics/detail-page/callout-bg.png) repeat-x; border-left: 1px solid #999999; border-right: 1px solid #999999; padding-left: 20px; padding-right: 20px; padding-bottom: 10px; width: 250px; font-family: verdana; font-size: 12px; }, The return of the Home Run Derby., View larger, ., Game calling from the catcher position., View larger, ., New fielding and pitcher training options., View larger, ., 11 new stadiums and 1,250 new gameplay animations., View larger, ., The Return of Home Run Derby, The arcade style action of the Home Run Derby returns in, MLB 10: The Show, . Available through a variety play options, including season modes and a stand-alone mode that can be selected at any time, Home Run Derby is based on the MLB rules and flow of the actual MLB All-star game. In addition, the MLB All-star Futures Game is also available via season modes in its correct timeframe (just before the HR Derby). Will your Road to The Show player be invited to compete in the Home Run Derby or the Futures Game? Swing for the fences and keep your stats high and you may just get the call as the All-star break approaches., Call the Game as Catcher, The catcher is the brains of any defensive squad on the field, and as such, MLB 10: The Show, lets you test both your baseball IQ and your skills in the crouch as you call the game, one pitch at a time, from behind the plate. Available during Road to The Show, and exhibition game play options, players cultivate their rookie Minor League catcher and bring him up to the Majors where he can lay down signs to your pitcher, offer additional pitch selections if his first pitch call is shaken off and even change signs at will, all via your controller's face buttons. Pitches are called from a unique first-person perspective, after which the camera angle changes to a standard third-person perspective affording a better view of the entire field. This new functionality defines a whole new level of strategy, demanding knowledge of pitches, opposing batters' tendencies and the state of your own pitcher's well-being, as well as the ability to check runners on base and handle and/or block balls that are in the dirt or wild., New Defensive Training Options, No ballplayer becomes a golden glove overnight. With that in mind, MLB 10: The Show, includes new fielding and pitching training modes that augment the existing training functionality built into the game's improved Road to the Show mode. Fielding drills focus both on the basics of player's throwing arm, utilizing the new throw meter, as well as the more advanced combination of throwing and decision-making that players will need during game situations. Pitching training consists of a multi-pronged focus designed to improved control and accuracy. Training in these areas is available in isolated one-on-one battles known as ""Knockout,"" as well as simulated game situations, where goals are clearly defined for each drill., Key Game Features, A Wealth of New Features Including:, A Wealth of New Features Including:, MLB All-star Week consisting of the Home Run Derby and MLB All-star Futures Game.Full online season leagues with better multiplayer functionality.Catcher game calling functionality in certain modes.Movie Maker functionality to create personal highlight reels11 new StadiumsNew fielding and pitching training modesCustom music, fan yells, and chants., MLB All-star Week consisting of the Home Run Derby and MLB All-star Futures Game., MLB All-star Week consisting of the Home Run Derby and MLB All-star Futures Game., Full online season leagues with better multiplayer functionality., Full online season leagues with better multiplayer functionality., Catcher game calling functionality in certain modes., Catcher game calling functionality in certain modes., Movie Maker functionality to create personal highlight reels, Movie Maker functionality to create personal highlight reels, 11 new Stadiums, 11 new Stadiums, New fielding and pitching training modes, New fielding and pitching training modes, Custom music, fan yells, and chants., Custom music, fan yells, and chants., Improved Stadium Realism and Experience - From crowd ambiance to enhanced presentation system, even transitional daylight., Improved Stadium Realism and Experience, - From crowd ambiance to enhanced presentation system, even transitional daylight., Road to the Show v4.0 – Play the way that you want to with multiple new options settings, a rewards/penalty system based on play and a new, more accessible stat tracking system., Road to the Show v4.0, – Play the way that you want to with multiple new options settings, a rewards/penalty system based on play and a new, more accessible stat tracking system., Improved Online Gameplay – This year the online gameplay experience has been vastly improved and will detect and respond better to adverse network conditions along with reduced bandwidth to help the speed and flow of online gameplay., Improved Online Gameplay, – This year the online gameplay experience has been vastly improved and will detect and respond better to adverse network conditions along with reduced bandwidth to help the speed and flow of online gameplay., Full Online Season Leagues – Fully functional online season leagues, save and display MLB Player stats, track player energy, allow for trades/injuries, and offer 40-man roster functionality., Full Online Season Leagues, – Fully functional online season leagues, save and display MLB Player stats, track player energy, allow for trades/injuries, and offer 40-man roster functionality., Real-time Presentations – More than 1,250 new gameplay animations, more than 1,000 new presentation animations, and more than 400 personalized pitcher and batter animations., Real-time Presentations, – More than 1,250 new gameplay animations, more than 1,000 new presentation animations, and more than 400 personalized pitcher and batter animations., New Stadiums Available – The PlayStation 3 version of MLB 10: The Show includes five new Minor League stadiums, as well as classic parks including Forbes Field, Crosley Field, Polo Grounds, Shibe Park, Sportsman Park, and Griffith Stadium., New Stadiums Available, – The PlayStation 3 version of, MLB 10: The Show, includes five new Minor League stadiums, as well as classic parks including Forbes Field, Crosley Field, Polo Grounds, Shibe Park, Sportsman Park, and Griffith Stadium., Additional Features – Additional features include: multiplayer support, in-game messaging, voice chat, custom soundtracks, add-on content and HD support up to 1080p., Additional Features, – Additional features include: multiplayer support, in-game messaging, voice chat, custom soundtracks, add-on content and HD support up to 1080p.]","[Video Games, Legacy Systems, PlayStation Systems, PlayStation 3, Games]",19.71
1611,AEUA6C3BAGAKZFYIEUQKXB74VNIQ,B08QFPS48J,1.0,2022-02-15 23:56:18.381,1644969000.0,33.0,3.636364,3.0,2.333333,1.0,...,"B004QWYV2Q,B002BSH3K4,B07HYW4XMP,B007HUNYE0,B009AGXH64,B00J4Y6L2Y,B00004S9AF,B007RNWUC4,B07K3KHFSY,B09JDJYPD8",1311866699131585476613358164221341508406135335985813653860541476066169150290649315120747891568498076,"[1761, 1224, 4013, 2066, 2217, 2669, 85, 2079, 4028, 4478]","[1311866699, 1315854766, 1335816422, 1341508406, 1353359858, 1365386054, 1476066169, 1502906493, 1512074789, 1568498076]","[9, 9, 8, 8, 8, 8, 8, 7, 7, 6]",Video Games,Returnal - PlayStation 5,"[After crash-landing on this shape-shifting world, selene must search through the barren landscape of an ancient civilization for her escape. Isolated and alone, she finds herself fighting tooth and nail for survival. Again and again, she's defeated - forced to restart her journey every time she dies. Through relentless roguelike gameplay, you'll discover that just as the planet changes with every cycle, so do the items at your disposal. Every loop offers new combinations, forcing you to push your boundaries and approach combat with a different strategy each time. Brought to life by stunning visual effects, the dark beauty of the decaying world around you is packed with explosive surprises. From high stakes, Bullet hell-fueled combat, to visceral twists and turns through Stark and contrasting environments. You'll explore, discover and fight your way through an unforgiving journey, where mystery stalks your every move. Designed for extreme playability, the procedural world of returnal invites you to dust yourself off in the face of defeat and take on new, evolving challenges with every Rebirth.]","[Video Games, PlayStation 5, Games]",26.99


In [24]:
test_row = test_df.loc[lambda df: df[args.rating_col].gt(0)].iloc[0]
item_id = test_row[args.item_col]
item_sequence = test_row["item_sequence"]
logger.info(
    f"Test predicting before training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
user_indice = idm.get_user_index(user_id)
item_indice = idm.get_item_index(item_id)
user = torch.tensor([user_indice])
item_sequence = torch.tensor([item_sequence])
item = torch.tensor([item_indice])

model.eval()
model.predict(user, item_sequence, item)
model.train()

[32m2024-11-10 16:26:14.001[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mTest predicting before training with user_id = AEUA6C3BAGAKZFYIEUQKXB74VNIQ and parent_asin = B08QFPS48J[0m


SequenceRatingPrediction(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (gru): GRU(128, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (fc_rating): Sequential(
    (0): Linear(in_features=384, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

#### Training loop

##### Overfit 1 batch

In [None]:
early_stopping = EarlyStopping(
    monitor="val_loss", patience=10, mode="min", verbose=False
)

model = init_model(n_users, n_items, args.embedding_dim, dropout=0)
lit_model = LitSequenceRatingPrediction(
    model,
    learning_rate=args.learning_rate,
    l2_reg=0.0,
    log_dir=args.notebook_persist_dp,
    accelerator=args.device,
)

log_dir = f"{args.notebook_persist_dp}/logs/overfit"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device if args.device else "auto",
    max_epochs=100,
    overfit_batches=1,
    callbacks=[early_stopping],
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                     | Params | Mode 
-----------------------------------------------------------
0 | model | SequenceRatingPrediction | 3.2 M  | train
-----------------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.990    Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |                                                                                                   …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=100` reached.
[32m2024-11-10 16:26:20.998[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m127[0m - [1mLogging classification metrics...[0m
[32m2024-11-10 16:26:38.399[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mLogs available at /home/dvquys/frostmourne/recsys-mvp/notebooks/data/002-try-binary/logs/overfit/lightning_logs/version_2[0m


In [26]:
%tensorboard --logdir $trainer.log_dir

##### Fit on all data

In [None]:
# papermill_description=fit-model
early_stopping = EarlyStopping(
    monitor="val_loss", patience=args.early_stopping_patience, mode="min", verbose=False
)

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{args.notebook_persist_dp}/checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

model = init_model(
    n_users,
    n_items,
    args.embedding_dim,
    dropout=args.dropout,
    item_embedding=pretrained_item_embedding,
)
lit_model = LitSequenceRatingPrediction(
    model,
    learning_rate=args.learning_rate,
    l2_reg=args.l2_reg,
    log_dir=args.notebook_persist_dp,
    evaluate_ranking=True,
    idm=idm,
    args=args,
    accelerator=args.device,
)

log_dir = f"{args.notebook_persist_dp}/logs/run"

# train model
trainer = L.Trainer(
    default_root_dir=log_dir,
    max_epochs=args.max_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    accelerator=args.device if args.device else "auto",
    logger=args._mlf_logger if args.log_to_mlflow else None,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

Checkpoint directory /home/dvquys/frostmourne/recsys-mvp/notebooks/data/002-try-binary/checkpoints exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                     | Params | Mode 
-----------------------------------------------------------
0 | model | SequenceRatingPrediction | 3.2 M  | train
-----------------------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.990    Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.



Training: |                                                                                                   …

Validation: |                                                                                                 …

`Trainer.fit` stopped: `max_epochs=1` reached.
[32m2024-11-10 16:26:59.710[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m127[0m - [1mLogging classification metrics...[0m
[32m2024-11-10 16:27:00.544[0m | [1mINFO    [0m | [36msrc.sequence.trainer[0m:[36mon_fit_end[0m:[36m130[0m - [1mLogging ranking metrics...[0m


Generating recommendations:   0%|          | 0/177 [00:00<?, ?it/s]


invalid value encountered in divide

2024/11/10 16:27:07 INFO mlflow.tracking._tracking_service.client: 🏃 View run 002-try-binary at: http://localhost:5002/#/experiments/2/runs/36a66aac50dd4c9ba306cc8176bca28e.
2024/11/10 16:27:07 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/2.


In [29]:
logger.info(
    f"Test predicting after training with {args.user_col} = {user_id} and {args.item_col} = {item_id}"
)
model.eval()
model = model.to(user.device)
model.predict(user, item_sequence, item)
model.train()

[32m2024-11-10 16:27:08.694[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mTest predicting after training with user_id = AEUA6C3BAGAKZFYIEUQKXB74VNIQ and parent_asin = B08QFPS48J[0m


SequenceRatingPrediction(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (gru): GRU(128, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.3, inplace=False)
  (fc_rating): Sequential(
    (0): Linear(in_features=384, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

# Load best checkpoint

In [30]:
logger.info(f"Loading best checkpoint from {checkpoint_callback.best_model_path}...")
args.best_checkpoint_path = checkpoint_callback.best_model_path

best_trainer = LitSequenceRatingPrediction.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    model=init_model(n_users, n_items, args.embedding_dim, dropout=0),
)

[32m2024-11-10 16:27:09.502[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mLoading best checkpoint from /home/dvquys/frostmourne/recsys-mvp/notebooks/data/002-try-binary/checkpoints/best-checkpoint-v2.ckpt...[0m


In [31]:
best_model = best_trainer.model.to(lit_model.device)

In [32]:
best_model.eval()
best_model.predict(user, item_sequence, item)
best_model.train()

SequenceRatingPrediction(
  (item_embedding): Embedding(4631, 128, padding_idx=4630)
  (user_embedding): Embedding(19578, 128)
  (gru): GRU(128, 128, batch_first=True)
  (relu): ReLU()
  (dropout): Dropout(p=0, inplace=False)
  (fc_rating): Sequential(
    (0): Linear(in_features=384, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0, inplace=False)
    (4): Linear(in_features=128, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

### Persist id mapping

In [33]:
if args.log_to_mlflow:
    # Persist id_mapping so that at inference we can predict based on item_ids (string) instead of item_index
    run_id = trainer.logger.run_id
    mlf_client = trainer.logger.experiment
    mlf_client.log_artifact(run_id, idm_fp)

### Wrap inference function and register best checkpoint as MLflow model

In [34]:
inferrer = SequenceRatingPredictionInferenceWrapper(best_model)

In [35]:
sample_input = {
    "user_ids": [idm.get_user_id(0)],
    "item_sequences": [[idm.get_item_id(0), idm.get_item_id(1)]],
    "item_ids": [idm.get_item_id(0)],
}
sample_output = inferrer.infer([0], [[0, 1]], [0])
sample_output

array([0.5575273], dtype=float32)

In [36]:
if args.log_to_mlflow:
    run_id = trainer.logger.run_id
    sample_output_np = sample_output
    signature = infer_signature(sample_input, sample_output_np)
    idm_filename = idm_fp.split("/")[-1]
    with mlflow.start_run(run_id=run_id):
        mlflow.pyfunc.log_model(
            python_model=inferrer,
            artifact_path="inferrer",
            # We log the id_mapping to the predict function so that it can accept item_id and automatically convert ot item_indice for PyTorch model to use
            artifacts={"idm": mlflow.get_artifact_uri(idm_filename)},
            signature=signature,
            input_example=sample_input,
            registered_model_name=args.mlf_model_name,
        )


Since MLflow 2.16.0, we no longer convert dictionary input example to pandas Dataframe, and directly save it as a json object. If the model expects a pandas DataFrame input instead, please pass the pandas DataFrame as input example directly.



Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Registered model 'sequence_rating_prediction' already exists. Creating a new version of this model...
2024/11/10 16:27:13 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: sequence_rating_prediction, version 2
Created version '2' of model 'sequence_rating_prediction'.


Downloading artifacts:   0%|          | 0/8 [00:00<?, ?it/s]

2024/11/10 16:27:13 INFO mlflow.tracking._tracking_service.client: 🏃 View run 002-try-binary at: http://localhost:5002/#/experiments/2/runs/36a66aac50dd4c9ba306cc8176bca28e.
2024/11/10 16:27:13 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/2.


# Set the newly trained model as champion

In [37]:
if args.log_to_mlflow:
    val_roc_auc = trainer.logger.experiment.get_run(trainer.logger.run_id).data.metrics[
        "val_roc_auc"
    ]

    if val_roc_auc > args.min_roc_auc:
        logger.info(f"Aliasing the new model as champion...")
        model_version = (
            mlf_client.get_registered_model(args.mlf_model_name)
            .latest_versions[0]
            .version
        )

        mlf_client.set_registered_model_alias(
            name=args.mlf_model_name, alias="champion", version=model_version
        )

        mlf_client.set_model_version_tag(
            name=args.mlf_model_name,
            version=model_version,
            key="author",
            value="quy.dinh",
        )

[32m2024-11-10 16:27:13.446[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mAliasing the new model as champion...[0m


# Clean up

In [38]:
all_params = [args]

if args.log_to_mlflow:
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            params_dict = params.dict()
            params_ = dict()
            for k, v in params_dict.items():
                if k == "top_K":
                    k = "top_big_K"
                if k == "top_k":
                    k = "top_small_k"
                params_[f"{params.__repr_name__()}.{k}"] = v
            mlflow.log_params(params_)

2024/11/10 16:27:13 INFO mlflow.tracking._tracking_service.client: 🏃 View run 002-try-binary at: http://localhost:5002/#/experiments/2/runs/36a66aac50dd4c9ba306cc8176bca28e.
2024/11/10 16:27:13 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://localhost:5002/#/experiments/2.
