In [86]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Make the %tensorboard magic available in this notebook.
%load_ext tensorboard

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [87]:
import json
import os
import sys
import time
from typing import Optional

import lightning as L
import mlflow
import numpy as np
import pandas as pd
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import MLFlowLogger
from load_dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel
from torch.utils.data import DataLoader

sys.path.insert(0, "..")

from src.utils.embedding_id_mapper import IDMapper
from src.algo.gSASRec.model import SASRec
from src.algo.gSASRec.dataset import SASRecDataset
from src.algo.gSASRec.trainer import SASRecLitModule
from src.eval.utils import create_rec_df, create_label_df, merge_recs_with_target
from src.eval.log_metrics import log_ranking_metrics, log_classification_metrics

In [88]:
# Load variables from a .env file; the return value is shown as the cell output.
load_dotenv(override=True)

False

In [89]:
class Args(BaseModel):
    """Notebook configuration: run identity, column names, hyper-parameters and paths.

    Construct the model, then call :meth:`init` to derive the output directory
    and (when a tracking URI is configured) attach an MLflow logger.
    """

    # --- Run bookkeeping ---
    testing: bool = False
    log_to_mlflow: bool = True
    experiment_name: str = "first-attempt"
    run_name: str = "018-sasrec"
    # Output directory for this run; None until init() fills it in.
    notebook_persit_dp: Optional[str] = None

    # --- Dataframe column names ---
    user_col: str = "user_id"
    item_col: str = "parent_asin"
    rating_col: str = "rating"
    timestamp_col: str = "timestamp"
    group_name: str = "seq-modelling"

    # --- Evaluation cut-offs (two fields that differ only by letter case) ---
    top_K: int = 100
    top_k: int = 10

    # --- Optimisation ---
    batch_size: int = 256
    lr: float = 0.001
    l2_emb: float = 0.0001
    early_stopping_patience: int = 10
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_epochs: int = 100

    # --- SASRec specific ---
    max_len: int = 10
    dropout: float = 0.3
    hidden_units: int = 128
    num_blocks: int = 1
    num_heads: int = 2
    num_workers: int = 4
    # NOTE(review): hard-coded; the notebook later displays
    # train_df["item_indice"].max() + 1 == 4817, so this presumably has to stay
    # in sync with the item vocabulary of the loaded data - confirm.
    pad_token: int = 4817

    # --- Input data (resolved to absolute paths when the class is defined) ---
    train_data_fp: str = os.path.abspath("../data_for_ai/interim/train_sample_interactions_16407u_neg_seq.parquet")
    val_data_fp: str = os.path.abspath("../data_for_ai/interim/val_sample_interactions_16407u_neg_seq.parquet")

    def init(self) -> "Args":
        """Finish setup that depends on the environment and return ``self``.

        - Sets ``notebook_persit_dp`` to ``data/<experiment_name>/<run_name>``
          (absolute path).
        - Turns ``log_to_mlflow`` off and logs a warning when the
          ``MLFLOW_TRACKING_URI`` environment variable is missing or empty.
        - When MLflow logging stays enabled, stores an ``MLFlowLogger`` on
          ``self._mlf_logger``; the attribute is not created otherwise.
        - Creates the output directory unless ``testing`` is true.
        """
        self.notebook_persit_dp = os.path.abspath(f"data/{self.experiment_name}/{self.run_name}")

        if not (mlflow_uri := os.environ.get("MLFLOW_TRACKING_URI")):
            self.log_to_mlflow = False
            logger.warning("MLFlow is not enabled. Turn off tracking to Mlflow.")

        if self.log_to_mlflow:
            logger.info(
                f"Setting up Mlflow experiment: {self.experiment_name}, run_name: {self.run_name}"
            )

            # Only exists when MLflow logging is active; callers must check
            # `log_to_mlflow` before reading it.
            self._mlf_logger = MLFlowLogger(
                experiment_name=self.experiment_name,
                run_name=self.run_name,
                tracking_uri=mlflow_uri,
                log_model=True,
            )

        if not self.testing:
            os.makedirs(self.notebook_persit_dp, exist_ok=True)
        return self


args = Args().init()
print(args.model_dump_json(indent=2))



{
  "testing": false,
  "log_to_mlflow": false,
  "experiment_name": "first-attempt",
  "run_name": "018-sasrec",
  "notebook_persit_dp": "c:\\Users\\Trieu\\OneDrive\\Desktop\\recsys\\real_time_recsys\\notebooks\\data\\first-attempt\\018-sasrec",
  "user_col": "user_id",
  "item_col": "parent_asin",
  "rating_col": "rating",
  "timestamp_col": "timestamp",
  "group_name": "seq-modelling",
  "top_K": 100,
  "top_k": 10,
  "batch_size": 256,
  "lr": 0.001,
  "l2_emb": 0.0001,
  "early_stopping_patience": 10,
  "device": "cpu",
  "num_epochs": 100,
  "max_len": 10,
  "dropout": 0.3,
  "hidden_units": 128,
  "num_blocks": 1,
  "num_heads": 2,
  "num_workers": 4,
  "pad_token": 4817,
  "train_data_fp": "c:\\Users\\Trieu\\OneDrive\\Desktop\\recsys\\real_time_recsys\\data_for_ai\\interim\\train_sample_interactions_16407u_neg_seq.parquet",
  "val_data_fp": "c:\\Users\\Trieu\\OneDrive\\Desktop\\recsys\\real_time_recsys\\data_for_ai\\interim\\val_sample_interactions_16407u_neg_seq.parquet"
}


In [90]:
def _load_interactions(fp, rating_col):
    """Read an interactions parquet file and binarise its rating column.

    Args:
        fp: Path of the parquet file to read.
        rating_col: Name of the rating column; values > 0 become 1, everything
            else becomes 0.

    Returns:
        The loaded dataframe with ``rating_col`` replaced by int64 0/1 labels.
    """
    df = pd.read_parquet(fp)
    # Vectorised threshold instead of a per-row Python lambda; explicit int64
    # keeps the dtype the same on every platform.
    df[rating_col] = df[rating_col].gt(0).astype("int64")
    return df


train_df = _load_interactions(args.train_data_fp, args.rating_col)
val_df = _load_interactions(args.val_data_fp, args.rating_col)

# Every validation user and item must already be known from training.
assert set(val_df[args.user_col].unique()).issubset(set(train_df[args.user_col].unique())), "Validation users must be present in training users."

assert set(val_df[args.item_col].unique()).issubset(set(train_df[args.item_col].unique())), "Validation items must be present in training items."
# Strict temporal split: the newest training row is older than the oldest validation row.
assert train_df[args.timestamp_col].max() < val_df[args.timestamp_col].min(), "Validation data must be after training data. Otherwise, its a data contamination problem."

In [91]:
# Display the validation frame as the cell output.
val_df

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence
0,AGSP5XAQPQBUUXZHEZSC65FD7NOQ,B004FV4ROA,1,2020-12-27 00:30:31.146,11295,528,"[1898, 3479, 3908, 1570, 91, 2723, 2962, 106, ..."
1,AGSP5XAQPQBUUXZHEZSC65FD7NOQ,B07KFQFDNB,0,2020-12-27 00:30:31.146,11295,3503,"[3479, 3908, 1570, 91, 2723, 2962, 106, 3557, ..."
2,AEHS7YR7BGGWMZS24H5UR5IP46HQ,B08F1P3BCC,1,2020-12-27 01:44:52.242,1784,3925,"[4319, 3382, 4330, 1173, 1330, 423, 2868, 3167..."
3,AEHS7YR7BGGWMZS24H5UR5IP46HQ,B00HXT8EKE,0,2020-12-27 01:44:52.242,1784,1507,"[3382, 4330, 1173, 1330, 423, 2868, 3167, 1071..."
4,AGAVHCK42EGMVS7DGPRX6HBCUCNQ,B09Q3NR84W,1,2020-12-27 02:25:48.357,9042,4273,"[1311, 1416, 455, 3743, 1823, 2694, 3612, 3462..."
...,...,...,...,...,...,...,...
6953,AEEQZRQBOFHFBFPYBX2BZ5WOI33A,B01A08E70K,0,2022-02-19 16:56:53.030,1396,2441,"[3451, 3827, 1839, 1347, 2504, 2694, 4546, 427..."
6954,AHLN6GKTKZE22AON34YAQXTGK63A,B0C682GZ5X,1,2022-02-19 17:28:55.519,14550,4772,"[2950, 1812, 4735, 4165, 4575, 2440, 607, 4807..."
6955,AHLN6GKTKZE22AON34YAQXTGK63A,B09SWWCN6Q,0,2022-02-19 17:28:55.519,14550,4303,"[1812, 4735, 4165, 4575, 2440, 607, 4807, 374,..."
6956,AEMYBWDN67IB5IBTMHLHN76V4QHQ,B091K4WYD1,1,2022-02-19 22:08:53.253,2446,4086,"[644, 3602, 4569, 1865, 3030, 3653, 3803, 3998..."


In [92]:
# Peek at the first three training rows.
train_df.head(3)

Unnamed: 0,user_id,parent_asin,rating,timestamp,user_indice,item_indice,item_sequence
0,AFZ4EK2LJ655XQKTEUELCARO6RYA,B00002EQCW,1,2003-01-23 03:28:15,8071,4,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
1,AFZ4EK2LJ655XQKTEUELCARO6RYA,B095JX15XF,0,2003-01-23 03:28:15,8071,4132,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."
2,AFY2C4YOUP2SSMM43HD2L3FIEFZA,B00008SCFL,1,2003-11-25 18:12:09,7935,36,"[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1...."


In [93]:
def init_model(n_user, n_items, dropout, hidden_units, num_blocks, num_heads):
    """
    Build a new SASRec network.

    The arguments are forwarded unchanged to the ``SASRec`` constructor under
    its own keyword names (``user_num``, ``item_num``, ``dropout_rate``,
    ``hidden_units``, ``num_blocks``, ``num_heads``).

    Returns:
        The freshly constructed ``SASRec`` instance.
    """
    sasrec_kwargs = dict(
        user_num=n_user,
        item_num=n_items,
        dropout_rate=dropout,
        hidden_units=hidden_units,
        num_blocks=num_blocks,
        num_heads=num_heads,
    )
    return SASRec(**sasrec_kwargs)

In [94]:
# Smoke test: build a tiny SASRec on hand-made data and run one prediction.

# Toy hyper-parameters.
batch_size, hidden_units, dropout = 2, 8, 0.2
num_blocks, num_heads = 1, 2

# Mock data: five interactions spread over three users and five items.
user_indices = [0, 0, 1, 2, 2]
item_indices = list(range(5))
timestamps = list(range(5))
ratings = [0, 4, 5, 3, 0]

user_num, item_num = len(set(user_indices)), len(set(item_indices))

_mock_columns = {
    "user_indice": user_indices,
    "item_indice": item_indices,
    args.timestamp_col: timestamps,
    args.rating_col: ratings,
}
train_test_df = pd.DataFrame(_mock_columns)

model = init_model(user_num, item_num, dropout, hidden_units, num_blocks, num_heads)
model.eval()  # switch to eval mode before the example forward pass

# One user, one length-10 item sequence, one candidate item.
user = torch.tensor([[0]])
seq = torch.tensor([[0] * 5 + [1, 2, 3, 4, 5]])
target_item = torch.tensor([[2]])

predictions = model.predict(user, seq, target_item)
print(predictions)

tensor([[0.1449]], grad_fn=<SigmoidBackward0>)


In [95]:
# Largest item index in the training data, plus one.
# NOTE(review): the displayed value (4817) equals the hard-coded default of
# `Args.pad_token` - presumably the pad index is meant to be one past the
# largest real item index; confirm.
train_df["item_indice"].max() + 1

4817

In [96]:
# Positional arguments shared by both datasets (everything after the dataframe).
# NOTE(review): their meaning is defined by SASRecDataset's signature, which is
# not visible in this notebook - the order is kept exactly as before.
_dataset_args = (
    "user_indice",
    "item_sequence",
    "item_indice",
    "rating",
    args.max_len,
    args.pad_token,
    args.timestamp_col,
)

rating_dataset = SASRecDataset(train_df, *_dataset_args)
val_rating_dataset = SASRecDataset(val_df, *_dataset_args)

# Training batches are shuffled and the incomplete last batch is dropped;
# validation keeps the original order and every row.
train_loader = DataLoader(rating_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_rating_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)

In [97]:
# Print each field of the first validation batch, then stop iterating.
for batch in val_loader:
    for field in ("user", "sequence", "item", "rating"):
        print(batch[field])
    break

tensor([11295, 11295,  1784,  1784,  9042,  9042,  1542,  1542, 10418, 10418,
         2786,  2786,  6758,  6758,  8761,  8761,  6185,  6185,  1134,  1134,
        11437, 11437, 12823, 12823,  8666,  8666,  6710,  6710,  1590,  1590,
         1590,  1590, 11747, 11747,  2393,  2393, 10403, 10403,  3982,  3982,
         4098,  4098,  7417,  7417,  6394,  6394,  7293,  7293,  5421,  5421,
        14988, 14988,  1574,  1574, 16260, 16260, 10908, 10908,  1861,  1861,
        14374, 14374, 13038, 13038,  8429,  8429,  2164,  2164,  4053,  4053,
         6451,  6451,  1394,  1394, 15498, 15498, 11209, 11209, 14837, 14837,
        15761, 15761, 10092, 10092, 12914, 12914, 13467, 13467,  6383,  6383,
         7731,  7731,  1677,  1677,  7342,  7342,  7824,  7824, 11332, 11332,
         2189,  2189,  5829,  5829,  5942,  5942,   746,   746, 12588, 12588,
         4263,  4263, 13562, 13562, 12115, 12115, 12115, 12115, 12115, 12115,
        12115, 12115,  3016,  3016, 12115, 12115,  7119,  7119, 

In [98]:
# Print each field of the first (shuffled) training batch, then stop iterating.
for batch in train_loader:
    for field in ("user", "sequence", "item", "rating"):
        print(batch[field])
    break

tensor([ 8764, 11791,  3810,  1698,  7657,  5694, 11980,  6694,  6349, 14496,
        15097,  2119,  8156,  7260, 10062,  1871,  9245,  5886,  5896,  7455,
        14399, 13116,  3887, 14094,  4692, 11056,  7399, 13695,   774,  7558,
         8098,  3752,  3539, 10448, 14942,  7519, 13039, 14693, 14989,  1701,
        16034,  8216, 16331,  2094,  4891,  5236,  2260, 10058, 12990, 16044,
         8024,   656,  9890, 11125,  9945,  8688,  5327,  3158,  4140,  8502,
         3256,  5543, 10498, 15323,  8203, 11943, 14939,  6804, 16037,  1146,
        15469, 13265,  6954,   815,  8517,  3409,  5399,  6147, 11049,  6340,
        11078, 11906,  5861, 14041,  3537,  5991,   473,  8732,  3667,  4332,
        14393,  5055,  6763, 15848,  2682,  2194, 10789,  9596, 15559,  9195,
         7930, 16285, 11608,  9157, 15012,  8426,  3970, 12041,  1227, 14942,
          151, 11533,  3869, 11266, 11200,  4932,   295, 11579, 11159,  3846,
        13533, 15895, 14985,  7893, 11735, 11424, 12135, 11675, 

In [99]:
# Distinct values of the raw user / item ID columns seen in training.
user_indices = train_df[args.user_col].unique()
item_indices = train_df[args.item_col].unique()
n_users, n_items = len(user_indices), len(item_indices)

logger.info(f"Number of users: {n_users}, Number of items: {n_items}")

# Full-size model built from the notebook configuration.
model = init_model(
    n_user=n_users,
    n_items=n_items,
    dropout=args.dropout,
    hidden_units=args.hidden_units,
    num_blocks=args.num_blocks,
    num_heads=args.num_heads,
)

[32m2025-05-01 14:44:11.721[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mNumber of users: 16407, Number of items: 4817[0m


In [100]:
# Initialise the parameters of the current `model`.
# NOTE: the training cells below rebuild `model` with init_model(...), so this
# initialisation only affects the instance displayed by this cell.
for param in model.parameters():
    # Xavier init needs fan-in and fan-out, i.e. at least two dimensions.
    # 1-D parameters (biases, LayerNorm weights) are skipped explicitly instead
    # of hiding every possible error behind a bare `except`.
    if param.dim() < 2:
        continue
    torch.nn.init.xavier_normal_(param.data)

# Zero the padding row of each embedding table. When the embedding declares a
# `padding_idx`, that is the row to clear, because Xavier init above has just
# overwritten it; otherwise fall back to row 0 as before.
for emb in (model.pos_emb, model.item_emb):
    pad_row = emb.padding_idx if emb.padding_idx is not None else 0
    emb.weight.data[pad_row, :] = 0
model

SASRec(
  (item_emb): Embedding(4818, 128, padding_idx=4817)
  (pos_emb): Embedding(10, 128)
  (emb_dropout): Dropout(p=0.3, inplace=False)
  (attention_layernorms): ModuleList(
    (0): LayerNorm((128,), eps=1e-08, elementwise_affine=True)
  )
  (attention_layers): ModuleList(
    (0): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
  )
  (forward_layernorms): ModuleList(
    (0): LayerNorm((128,), eps=1e-08, elementwise_affine=True)
  )
  (forward_layers): ModuleList(
    (0): PointWiseFeedForward(
      (conv1): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
      (dropout1): Dropout(p=0.3, inplace=False)
      (relu): ReLU()
      (conv2): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
      (dropout2): Dropout(p=0.3, inplace=False)
    )
  )
  (final_layer): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [101]:
# Load the persisted ID mapper from the interim data folder.
idm_path = os.path.abspath("../data_for_ai/interim/idm_16407u.json")
idm = IDMapper().load(idm_path)
# Sanity check: look up the raw user id stored for user index 1.
idm.get_user_id(1)

'AE227WAM4NWQPJI33OPN7ZARNNZQ'

## Overfit a single batch (sanity check)

In [102]:
# Fresh, untrained network wrapped in the Lightning module.
model = init_model(
    n_users, n_items, args.dropout, args.hidden_units, args.num_blocks, args.num_heads
)
lit_model = SASRecLitModule(
    model,
    log_dir=args.notebook_persit_dp,
    accelerator=args.device,
    lr=args.lr,
    l2_emb=args.l2_emb,
    idm=idm,
)

# Stop once "val_loss" has not improved for five validation rounds.
early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=5,
    verbose=False,
)

log_dir = f"{args.notebook_persit_dp}/logs/overfit"

# Train on a single batch only (overfit_batches=1).
trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device or "auto",
    max_epochs=args.num_epochs,
    overfit_batches=1,
    callbacks=[early_stopping],
)

# The training loader doubles as the validation loader for this check.
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=train_loader,
)
logger.info(f"Logs available at {trainer.log_dir}")

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | SASRec | 717 K  | train
-----------------------------------------
717 K     Trainable params
0         Non-trainable params
717 K     Total params
2.871     Total estimated model params size (MB)
20        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


You requested to overfit but enabled val dataloader shuffling. We are turning off the val dataloader shuffling for you.


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[32m2025-05-01 14:44:55.735[0m | [1mINFO    [0m | [36msrc.algo.gSASRec.trainer[0m:[36mon_fit_end[0m:[36m134[0m - [1mLogging ranking metrics...[0m


Recommendations: {'user_indice': [8071, 7935, 13705, 12730, 3735, 14832, 10069, 15786, 6742, 8061], 'recommendation': [[1204, 2004, 2, 3, 4, 3359, 2006, 3357, 2012, 9, 2013, 3350, 2014, 13, 2015, 2558, 2017, 17, 18, 19, 3344, 2018, 22, 23, 24, 25, 26, 27, 28, 2024, 30, 3340, 2027, 3339, 3338, 35, 36, 37, 3337, 2555, 40, 2029, 2030, 43, 2031, 2032, 2033, 47, 2554, 2552, 50, 3330, 2038, 3328, 3326, 55, 2551, 2040, 2550, 59, 60, 61, 62, 63, 3324, 2045, 66, 2046, 68, 69, 70, 71, 2047, 73, 2549, 2050, 2052, 2546, 78, 2057, 80, 81, 3322, 2544, 84, 85, 86, 2541, 2061, 3320, 2062, 91, 92, 2066, 3318, 95, 96, 2068, 98, 2537], [1204, 2004, 2, 3, 4, 3359, 2006, 3357, 2012, 9, 2013, 3350, 2014, 13, 2015, 2558, 2017, 17, 18, 19, 3344, 2018, 22, 23, 24, 25, 26, 27, 28, 2024, 30, 3340, 2027, 3339, 3338, 35, 36, 37, 3337, 2555, 40, 2029, 2030, 43, 2031, 2032, 2033, 47, 2554, 2552, 50, 3330, 2038, 3328, 3326, 55, 2551, 2040, 2550, 59, 60, 61, 62, 63, 3324, 2045, 66, 2046, 68, 69, 70, 71, 2047, 73, 2549


invalid value encountered in divide

[32m2025-05-01 14:45:27.266[0m | [1mINFO    [0m | [36msrc.algo.gSASRec.trainer[0m:[36mon_fit_end[0m:[36m137[0m - [1mEvidently metrics are available at: c:\Users\Trieu\OneDrive\Desktop\recsys\real_time_recsys\notebooks\data\first-attempt\018-sasrec[0m
[32m2025-05-01 14:45:27.270[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m30[0m - [1mLogs available at c:\Users\Trieu\OneDrive\Desktop\recsys\real_time_recsys\notebooks\data\first-attempt\018-sasrec\logs\overfit\lightning_logs\version_77[0m


{'metrics': [{'metric': 'NDCGKMetric', 'result': {'k': 10, 'current': 1     0.000000
2     0.000000
3     0.000000
4     0.000000
5     0.000008
6     0.000007
7     0.000006
8     0.000006
9     0.000006
10    0.000005
dtype: float64, 'current_value': 5.5420266117497916e-06, 'reference': None, 'reference_value': None}}, {'metric': 'RecallTopKMetric', 'result': {'k': 100, 'current': 0     0.000000
1     0.000000
2     0.000000
3     0.000000
4     0.000003
        ...   
95    0.000059
96    0.000059
97    0.000059
98    0.000059
99    0.000059
Length: 100, dtype: float64, 'current_value': 5.9366488329735725e-05, 'reference': None, 'reference_value': None}}, {'metric': 'PrecisionTopKMetric', 'result': {'k': 100, 'current': 0     0.000000
1     0.000000
2     0.000000
3     0.000000
4     0.000012
        ...   
95    0.000005
96    0.000005
97    0.000005
98    0.000005
99    0.000005
Length: 100, dtype: float64, 'current_value': 4.875967574815627e-06, 'reference': None, 'reference_val

In [None]:
# Fresh model and Lightning wrapper for the real training run.
model = init_model(
    n_users, n_items, args.dropout, args.hidden_units, args.num_blocks, args.num_heads
)
lit_model = SASRecLitModule(
    model,
    log_dir=args.notebook_persit_dp,
    accelerator=args.device,
    lr=args.lr,
    l2_emb=args.l2_emb,
    idm=idm,
)

# Both callbacks watch the same quantity: lower "val_loss" is better.
early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=args.early_stopping_patience,
    min_delta=0.0025,
    verbose=False,
)
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    dirpath=f"{args.notebook_persit_dp}/checkpoints",
    filename="best-checkpoint",
)

log_dir = f"{args.notebook_persit_dp}/logs/run"

# `_mlf_logger` only exists when MLflow logging is enabled, so it is read
# lazily; None leaves the choice of logger to the Trainer.
_run_logger = args._mlf_logger if args.log_to_mlflow else None

trainer = L.Trainer(
    default_root_dir=log_dir,
    accelerator=args.device or "auto",
    max_epochs=args.num_epochs,
    callbacks=[early_stopping, checkpoint_callback],
    logger=_run_logger,
)
trainer.fit(
    model=lit_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# The installed Lightning library was patched as a workaround for an issue in
# the latest Lightning release:
# https://github.com/Lightning-AI/pytorch-lightning/pull/20669/commits/429f732a0528c558e701da7ec01e51c1e2e4f32e
# NOTE(review): the saved output of this cell shows
# "AttributeError: 'Args' object has no attribute 'embedding_dim'"; nothing in
# this cell reads `args.embedding_dim`, so the error presumably comes from
# project code or from a stale output - confirm by re-running.

AttributeError: 'Args' object has no attribute 'embedding_dim'

In [None]:
# Parameter objects whose fields are recorded on the MLflow run.
all_params = [args]

# Field names that are renamed before logging.
# NOTE(review): presumably needed because "top_K" and "top_k" differ only by
# letter case - confirm against the MLflow backend in use.
_param_key_aliases = {"top_K": "top_big_K", "top_k": "top_small_k"}

if args.log_to_mlflow:
    run_id = trainer.logger.run_id

    # Re-open the run created by the trainer's logger and attach the parameters.
    with mlflow.start_run(run_id=run_id):
        for params in all_params:
            prefix = params.__repr_name__()
            mlflow.log_params(
                {
                    f"{prefix}.{_param_key_aliases.get(field, field)}": value
                    for field, value in params.model_dump().items()
                }
            )