In [None]:
# Reload edited modules automatically before each cell runs (autoreload mode 2 = all modules).
%load_ext autoreload
%autoreload 2
# Load the TensorBoard notebook extension.
%load_ext tensorboard

In [None]:
import json
import os
import sys
import time
from typing import Optional

import lightning as L
import mlflow
import numpy as np
import pandas as pd
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from load_dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from torch.utils.data import DataLoader

sys.path.insert(0, "..")

from src.utils.embedding_id_mapper import IDMapper
from src.algo.gSASRec.model import SASRec
from src.algo.gSASRec.dataset import SASRecDataset
from src.algo.gSASRec.trainer import SASRecLitModule
from src.eval.utils import create_rec_df, create_label_df, merge_recs_with_target
from src.eval.log_metrics import log_ranking_metrics, log_classification_metrics

In [None]:
# Load settings from a .env file; Args.init() later reads MLFLOW_TRACKING_URI from os.environ.
# NOTE(review): override=True presumably lets .env values replace variables that are already
# set in the environment — confirm against the load_dotenv package in use.
load_dotenv(override = True)

In [None]:
class Args(BaseModel):
    """Run configuration for the SASRec training notebook.

    Construct it, then call ``init()`` once to resolve the output directory and,
    when MLflow logging is enabled, create the MLflow logger.
    """

    # --- run bookkeeping ---
    testing: bool = False  # when True, init() does not create the output directory
    log_to_mlflow: bool = True  # init() switches this off if MLFLOW_TRACKING_URI is unset
    experiment_name: str = "first-attempt"
    run_name: str = "018-sasrec"
    # Absolute output directory for this run; None until init() fills it in.
    notebook_persit_dp: Optional[str] = None

    # --- dataframe column names ---
    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"
    group_name: str = "seq-modelling"

    top_K: int = 100
    top_k: int = 10

    # --- optimisation ---
    batch_size: int = 256
    lr: float = 0.001
    l2_emb: float = 0.0001
    early_stopping_patience: int = 10
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_epochs: int = 100

    # SASrec specific
    max_len: int = 10
    dropout: float = 0.5
    hidden_units: int = 128
    num_blocks: int = 1
    num_heads: int = 2
    num_workers: int = 3
    pad_token: int = 4817
    # seq_length: int = 10

    train_data_fp: str = os.path.abspath("../data_for_ai/interim/train_sample_interactions_16407u_neg_seq.parquet")
    val_data_fp: str = os.path.abspath("../data_for_ai/interim/val_sample_interactions_16407u_neg_seq.parquet")

    # MLflow logger built by init(). Declared as a pydantic private attribute so it is
    # always defined (None while MLflow logging is disabled) and stays out of model_dump().
    _mlf_logger: Optional[MLFlowLogger] = None

    def init(self):
        """Resolve runtime-only settings and return ``self`` so calls can be chained.

        Steps, in order:
        - set ``notebook_persit_dp`` to the absolute path of
          ``data/<experiment_name>/<run_name>`` (relative to the current working directory);
        - disable ``log_to_mlflow`` with a warning when ``MLFLOW_TRACKING_URI`` is
          unset or empty;
        - create ``_mlf_logger`` when MLflow logging is still enabled;
        - create the output directory unless ``testing`` is True.
        """
        self.notebook_persit_dp = os.path.abspath(f"data/{self.experiment_name}/{self.run_name}")

        mlflow_uri = os.environ.get("MLFLOW_TRACKING_URI")
        if not mlflow_uri:
            self.log_to_mlflow = False
            logger.warning("MLFlow is not enabled. Turn off tracking to Mlflow.")

        if self.log_to_mlflow:
            logger.info(
                f"Setting up Mlflow experiment: {self.experiment_name}, run_name: {self.run_name}"
            )

            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        if not self.testing:
            os.makedirs(self.notebook_persit_dp, exist_ok=True)
        return self
    
# Build the run configuration and print it as indented JSON.
args = Args().init()
print(args.model_dump_json(indent=2))

In [None]:
# Load the training interactions and binarise the rating column: 1 when rating > 0, else 0.
# The comparison is vectorised instead of calling a Python lambda once per row.
train_df = pd.read_parquet(args.train_data_fp)
train_df[args.rating_col] = train_df[args.rating_col].gt(0).astype("int64")

# Same treatment for the validation interactions.
val_df = pd.read_parquet(args.val_data_fp)
val_df[args.rating_col] = val_df[args.rating_col].gt(0).astype("int64")

# Sanity checks: validation must only contain users and items seen in training ...
train_users = set(train_df[args.user_col].unique())
train_items = set(train_df[args.item_col].unique())

assert set(val_df[args.user_col].unique()).issubset(
    train_users
), "Validation users must be present in training users."

assert set(val_df[args.item_col].unique()).issubset(
    train_items
), "Validation items must be present in training items."

# ... and every validation interaction must be strictly later than every training one.
assert (
    train_df[args.timestamp_col].max() < val_df[args.timestamp_col].min()
), "Validation data must be after training data. Otherwise, its a data contamination problem."

In [None]:
# Preview only the first rows of the validation interactions instead of the whole frame.
val_df.head(3)

In [None]:
# Preview the first three rows of the training interactions.
train_df.head(3)

In [None]:
def init_model(n_user, n_items, dropout, hidden_units, num_blocks, num_heads):
    """
    Build a fresh SASRec network.

    Every argument is forwarded unchanged to the ``SASRec`` constructor:
    ``n_user`` as ``user_num``, ``n_items`` as ``item_num``, ``dropout`` as
    ``dropout_rate``, and ``hidden_units`` / ``num_blocks`` / ``num_heads``
    under their own names.

    Returns the newly constructed ``SASRec`` instance.
    """
    sasrec_kwargs = dict(
        user_num=n_user,
        item_num=n_items,
        dropout_rate=dropout,
        hidden_units=hidden_units,
        num_blocks=num_blocks,
        num_heads=num_heads,
    )
    return SASRec(**sasrec_kwargs)

In [None]:
# Smoke test: build a tiny SASRec on hand-made data and run a single prediction.
# NOTE(review): batch_size is not used in this cell; the data loaders use args.batch_size.
batch_size = 2
hidden_units = 8
dropout = 0.2
num_blocks = 1
num_heads = 2

# Mock data
user_indices = [0, 0, 1, 2, 2]
item_indices = [0, 1, 2, 3, 4]
timestamps = [0, 1, 2, 3, 4]
ratings = [0, 4, 5, 3, 0]

# Model sizes = number of distinct mock users / items (3 and 5 for the lists above).
user_num = len(set(user_indices))
item_num = len(set(item_indices))

# Toy interaction frame built from the mock lists; it is not used again in this cell.
train_test_df = pd.DataFrame(
    {
        "user_indice": user_indices,
        "item_indice": item_indices,
        args.timestamp_col: timestamps,
        args.rating_col: ratings,
    }
)

model = init_model(user_num, item_num, dropout,hidden_units, num_blocks, num_heads)

# Example forward pass
model.eval()
user = torch.tensor([[0]])
# NOTE(review): the sequence contains id 5 although item_num is 5, and leading zeros although
# 0 is also a mock item id — confirm which value SASRec treats as padding.
seq = torch.tensor([[0,0,0,0,0,1,2,3,4,5]])
target_item = torch.tensor([[2]])
predictions = model.predict(user, seq, target_item)
print(predictions)

In [None]:
# Largest item index in the training data, plus one.
# NOTE(review): presumably meant to be checked against Args.pad_token (4817) — confirm.
train_df["item_indice"].max() + 1

In [None]:
# Wrap the train / validation frames in SASRecDataset (same positional arguments for both).
rating_dataset = SASRecDataset(
    train_df,
    "user_indice",
    "item_sequence",
    "item_indice",
    "rating",
    args.max_len,
    args.pad_token,
    args.timestamp_col,
)
val_rating_dataset = SASRecDataset(
    val_df,
    "user_indice",
    "item_sequence",
    "item_indice",
    "rating",
    args.max_len,
    args.pad_token,
    args.timestamp_col,
)

# DataLoader raises a ValueError for persistent_workers=True when num_workers == 0,
# so only ask for persistent workers when worker processes are actually used.
use_persistent_workers = args.num_workers > 0

# Training batches are shuffled and the last incomplete batch is dropped.
train_loader = DataLoader(
    rating_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.num_workers,
    persistent_workers=use_persistent_workers,
)
# Validation keeps the dataset order and every sample.
val_loader = DataLoader(
    val_rating_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=args.num_workers,
    persistent_workers=use_persistent_workers,
)

In [None]:
# Print the four fields of the first validation batch, then stop iterating.
for batch in val_loader:
    for field in ("user", "sequence", "item", "rating"):
        print(batch[field])
    break

In [None]:
# Print the four fields of the first training batch, then stop iterating.
for batch in train_loader:
    for field in ("user", "sequence", "item", "rating"):
        print(batch[field])
    break

In [None]:
# Distinct users / items found in the training data.
# NOTE(review): these are taken from the raw id columns (args.user_col / args.item_col), while
# the datasets read "user_indice" / "item_indice" — confirm the counts match the index ranges.
user_indices = train_df[args.user_col].unique()
item_indices = train_df[args.item_col].unique()
n_users, n_items = len(user_indices), len(item_indices)

logger.info(f"Number of users: {n_users}, Number of items: {n_items}")

# Build the model from the counts above and the hyper-parameters in the run configuration.
model = init_model(
    n_user=n_users,
    n_items=n_items,
    dropout=args.dropout,
    hidden_units=args.hidden_units,
    num_blocks=args.num_blocks,
    num_heads=args.num_heads,
)

In [None]:
# Load the persisted ID mapper from its JSON file.
idm_path = os.path.abspath("../data_for_ai/interim/idm_16407u.json")
idm = IDMapper().load(idm_path)
# Quick look-up to check the mapper loaded.
# NOTE(review): presumably maps index 1 back to the original user id — confirm in IDMapper.
idm.get_user_id(1)

## Overfit a single batch

In [None]:
# early_stopping = EarlyStopping(
#     monitor="val_loss", patience=5, mode="min", verbose=False
# )
# # create log_dir if it does not exist
# if not os.path.exists(args.notebook_persit_dp):
#     os.makedirs(args.notebook_persit_dp, exist_ok=True)

# model = init_model(n_users, n_items, args.dropout, args.hidden_units, args.num_blocks, args.num_heads)
# lit_model = SASRecLitModule(
#     model,
#     log_dir=args.notebook_persit_dp,
#     accelerator=args.device,
#     lr=args.lr,
#     l2_emb=args.l2_emb,
#     idm= idm
# )

# log_dir = f"{args.notebook_persit_dp}/logs/overfit"
# # create log_dir if it does not exist
# if not os.path.exists(log_dir):
#     os.makedirs(log_dir, exist_ok=True)

# # train model
# trainer = L.Trainer(
#     default_root_dir=log_dir,
#     accelerator=args.device if args.device else "auto",
#     max_epochs=args.num_epochs,
#     overfit_batches=1,
#     callbacks=[early_stopping],
# )
# trainer.fit(
#     model=lit_model,
#     train_dataloaders=train_loader,
#     val_dataloaders=train_loader,
# )
# logger.info(f"Logs available at {trainer.log_dir}")

In [None]:
# Callbacks: stop when val_loss stops improving, and keep only the best checkpoint by val_loss.
early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=args.early_stopping_patience,
    min_delta=0.0025,
    verbose=False,
)

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    dirpath=f"{args.notebook_persit_dp}/checkpoints",
    filename="best-checkpoint",
)

# Fresh model wrapped in the Lightning module for this run.
model = init_model(n_users, n_items, args.dropout, args.hidden_units, args.num_blocks, args.num_heads)
lit_model = SASRecLitModule(
    model,
    log_dir=args.notebook_persit_dp,
    accelerator=args.device,
    lr=args.lr,
    l2_emb=args.l2_emb,
    idm=idm,
)

# Make sure the trainer's root directory exists (no-op when it is already there).
log_dir = f"{args.notebook_persit_dp}/logs/run"
os.makedirs(log_dir, exist_ok=True)

# Use the MLflow logger only when MLflow logging is enabled; None otherwise.
run_logger = args._mlf_logger if args.log_to_mlflow else None

# Train the model.
# NOTE(review): max_epochs is pinned to 1 here; args.num_epochs is the commented-out alternative.
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device or "auto",
    max_epochs=1,
    # max_epochs=args.num_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    logger=run_logger,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# NOTE: the installed Lightning library has to be patched locally as a workaround for an
# issue in the latest Lightning release; the required change is in this commit:
# https://github.com/Lightning-AI/pytorch-lightning/pull/20669/commits/429f732a0528c558e701da7ec01e51c1e2e4f32e

In [None]:
# Configuration objects whose fields are logged as MLflow parameters.
all_params = [args]

# "top_K" and "top_k" are logged under distinct names.
# NOTE(review): presumably because the two keys differ only by letter case — confirm.
key_aliases = {"top_K": "top_big_K", "top_k": "top_small_k"}

if args.log_to_mlflow:
    # Re-open the run created by the trainer's logger so the parameters land in the same run.
    run_id = trainer.logger.run_id

    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            prefix = params.__repr_name__()
            # Each key is logged as "<prefix>.<field name>".
            flat_params = {
                f"{prefix}.{key_aliases.get(field, field)}": value
                for field, value in params.model_dump().items()
            }
            mlflow.log_params(flat_params)