In [1]:
import numpy as np
import pandas as pd
from pathlib import Path
import pickle
import torch

import SASRec_class as sasrec

SEED = 42
np.random.seed(SEED)
# Sanity check that the global NumPy RNG is seeded as expected. This draw
# consumes one value from the global stream, so the user train/test split
# further down depends on it being here.
if np.random.choice(np.arange(1000)) != 102:
    raise ValueError("Random seed is not set correctly.")
# Seeding NumPy does not affect torch's RNG; seed torch too so that SASRec
# weight initialisation and dropout during training are reproducible.
torch.manual_seed(SEED)

## 1. Load Processed Data

In [2]:
DATASET = 'steam'
# The artifacts folder sits inside the grandparent of the current working directory.
base_artifacts = Path.cwd().resolve().parents[1] / 'CausalI2I_artifacts'
processed_csv = base_artifacts.joinpath('Datasets', 'Processed', DATASET, 'data_sasrec.csv')
data_sasrec = pd.read_csv(processed_csv)

In [3]:
# Per-dataset SASRec settings:
#   L          - window length of each training sequence (also max_seq_len)
#   dropout    - dropout rate passed to SASRecTorch
#   batch_size - mini-batch size passed to fit()
#   num_epochs - number of training epochs passed to fit()
parameters_dict = {
    'ml-1m': dict(L=200, dropout=0.2, batch_size=128, num_epochs=10),
    'steam': dict(L=50, dropout=0.5, batch_size=2**11, num_epochs=20),
    'goodreads': dict(L=50, dropout=0.5, batch_size=2**11, num_epochs=20),
}

dataset_params = parameters_dict[DATASET]
L, dropout, batch_size, num_epochs = (
    dataset_params[key] for key in ('L', 'dropout', 'batch_size', 'num_epochs')
)

In [4]:
# Split at the user level (not the interaction level): 80% of users are drawn
# without replacement from the seeded global NumPy RNG, and the rest are held out.
unique_users = data_sasrec['user_id'].unique()
n_users = len(unique_users)
n_train_users = int(0.8 * n_users)

train_users = np.random.choice(unique_users, size=n_train_users, replace=False)
test_users = np.setdiff1d(unique_users, train_users)

# The next cell uses num_items itself as the padding index, which presumes item
# ids are 0..num_items-1 — TODO confirm against the preprocessing step.
num_items = data_sasrec['item_id'].nunique()

In [5]:
# Map each user to their list of item ids. groupby keeps the original row order
# within each group (presumably data_sasrec is already time-ordered — confirm).
item_sequences = data_sasrec.groupby('user_id')['item_id'].apply(list)
users_dict = item_sequences.to_dict()
lens = list(map(len, users_dict.values()))
np.mean(lens)

np.float64(162.60485413677665)

In [None]:
# Turn every user's history into fixed-length windows for SASRec.
# Each history is left-padded with L-2 copies of `padding_idx`, then a window of
# length L slides over it one step at a time. A user with k items therefore
# yields k-1 windows, each ending in at least two real items. Windows go to
# train_dataset or test_dataset according to the user-level split above.
padding_idx = num_items  # presumably one past the largest item id — TODO confirm


def make_windows(sequence, window_len, pad_value):
    """Return every contiguous `window_len`-long slice of `sequence` after
    left-padding it with `window_len - 2` copies of `pad_value`."""
    padded = [pad_value] * (window_len - 2) + list(sequence)
    return [padded[start:start + window_len]
            for start in range(len(padded) - window_len + 1)]


# A set gives O(1) membership; `x in ndarray` scans the whole array on every call.
train_user_set = set(train_users.tolist())

train_windows, test_windows = [], []
for user_id, items in users_dict.items():
    # users_dict itself is not modified, so re-running this cell is safe.
    target = train_windows if user_id in train_user_set else test_windows
    target.extend(make_windows(items, L, padding_idx))

# reshape keeps a 2-D (n_windows, L) shape even when a split ends up empty.
train_dataset = np.asarray(train_windows, dtype=np.int64).reshape(-1, L)
test_dataset = np.asarray(test_windows, dtype=np.int64).reshape(-1, L)

## 2. Train the Model

In [None]:
# Architecture sizes (d_model, n_heads, n_layers) are the same for every
# dataset; sequence length, dropout and the training schedule come from
# parameters_dict. The held-out users' windows (test_dataset) are passed as
# valid_dataset, so presumably the per-epoch V-Loss/HR10/NDCG in the log are
# measured on them.
# Use the GPU when one is visible; otherwise fall back to CPU instead of
# hard-coding "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

model = sasrec.SASRecTorch(
    num_items=num_items,
    max_seq_len=L,
    d_model=50,
    n_heads=1,
    n_layers=2,
    dropout=dropout,
    device=device,
)
model.fit(
    train_dataset=train_dataset,
    valid_dataset=test_dataset,
    batch_size=batch_size,
    lr=1e-3,
    weight_decay=0.0,
    num_epochs=num_epochs,
)



Epoch | T-Loss | V-Loss | Pctl  | HR10  | NDCG  | Cosθ  | Elapsed Time
    1 |  1.062 |  0.889 | 0.856 | 0.564 | 0.415 | None  |     00:24.3
    2 |  0.853 |  0.796 | 0.887 | 0.642 | 0.455 | 0.515 |     00:48.3
    3 |  0.779 |  0.741 | 0.903 | 0.689 | 0.482 | 0.555 |     01:11.9
    4 |  0.722 |  0.699 | 0.915 | 0.727 | 0.509 | 0.554 |     01:35.7
    5 |  0.682 |  0.677 | 0.921 | 0.748 | 0.525 | 0.578 |     01:59.5
    6 |  0.661 |  0.668 | 0.924 | 0.758 | 0.533 | 0.609 |     02:23.5
    7 |  0.647 |  0.661 | 0.926 | 0.765 | 0.539 | 0.613 |     02:47.4
    8 |  0.635 |  0.655 | 0.928 | 0.770 | 0.544 | 0.575 |     03:11.5
    9 |  0.624 |  0.648 | 0.930 | 0.776 | 0.551 | 0.548 |     03:35.7
   10 |  0.612 |  0.643 | 0.931 | 0.781 | 0.555 | 0.558 |     03:59.7
   11 |  0.604 |  0.639 | 0.932 | 0.785 | 0.558 | 0.523 |     04:23.7
   12 |  0.596 |  0.636 | 0.933 | 0.788 | 0.561 | 0.498 |     04:47.8
   13 |  0.589 |  0.633 | 0.934 | 0.790 | 0.563 | 0.473 |     05:11.7
   14 |  0.584 |  

In [None]:
# Save the trained weights together with the constructor arguments needed to
# rebuild an identical SASRecTorch instance in a later session.
folder_path = base_artifacts / 'SASRec_Models'
# Create the output directory on first use so that saving does not fail
# because the folder is missing.
folder_path.mkdir(parents=True, exist_ok=True)
model.save(path=folder_path / f'sasrec_{DATASET}.pt')

init_dict = {
    "num_items": num_items,
    "max_seq_len": L,
    "d_model": model.d_model,
    "n_heads": model.n_heads,
    "n_layers": model.n_layers,
    "dropout": model.dropout,
    # Device the model was built on; it is passed back unchanged when reloading.
    "device": model.device
}

with open(folder_path / f'sasrec_{DATASET}_init_dict.pkl', 'wb') as f:
    pickle.dump(init_dict, f)


## 3. Load a Model

In [4]:
# Rebuild the model from its pickled constructor arguments, then load the weights.
# The saved arguments include the training-time device, so loading on a machine
# without that device may need the "device" entry overridden — TODO confirm.
folder_path = base_artifacts / 'SASRec_Models'
init_dict_path = folder_path / f'sasrec_{DATASET}_init_dict.pkl'
weights_path = folder_path / f'sasrec_{DATASET}.pt'

# Only unpickle trusted, locally produced files: pickle.load can run arbitrary code.
with init_dict_path.open('rb') as fh:
    init_dict_loaded = pickle.load(fh)

loaded_model = sasrec.SASRecTorch(**init_dict_loaded)
loaded_model.load(weights_path)



Model loaded from /home/gouni/CausalI2I_artifacts/SASRec_Models/sasrec_steam.pt.
num_items:     12434
max_seq_len:   50
device:        cuda
batch_size:    2048
lr:            0.001
weight_decay:  0.0
num_epochs:    20
saved_at:      2026-01-03 12:00:39
note:          None
