In [2]:
import numpy as np
import pandas as pd
import torch
import pickle
from pathlib import Path

import MF_class as MF

# Seed every RNG the notebook depends on. numpy covers the legacy global RNG;
# torch.manual_seed seeds torch's generators on all devices (CPU and CUDA),
# which the torch-based model below presumably uses for init/shuffling.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Sanity check that numpy's global RNG was seeded as expected: with seed 42 the
# first draw from np.arange(1000) is 102. (Changing SEED requires updating 102.)
# Note: this check consumes one draw from numpy's global RNG state.
if np.random.choice(np.arange(1000)) != 102:
    raise ValueError("Random seed is not set correctly.")

```
                                USERS                             
         ┌───────────────────────────────────────────────────────┐
         │                          │                            │
         │                          │                            │
         │                          │                            │
         │                          │                            │
ITEMS    │                          │                            │
         │                          │                            │
         │                          │                            │
         ├──────────────────────────┼────────────────────────────┤
         │                          │████████████████████████████│
         │                          │████████████████████████████│
         └───────────────────────────────────────────────────────┘
```

# 0. Choose Dataset

In [3]:
DATASET = 'goodreads'

# Artifacts directory sits inside the grandparent of the current working directory.
base_artifacts = Path.cwd().resolve().parents[1].joinpath('CausalI2I_artifacts')
data_path = base_artifacts.joinpath('Datasets', 'Processed', DATASET)

In [4]:
# Per-dataset training hyperparameters for the matrix-factorization model below.
parameters_dict = {
    'ml-1m':     {'n_factors': 20, 'lr': 5e-3, 'batch_size': 2**15, 'n_epochs': 20},
    'steam':     {'n_factors': 50, 'lr': 2e-3, 'batch_size': 2**16, 'n_epochs': 30},
    'goodreads': {'n_factors': 50, 'lr': 1e-3, 'batch_size': 2**15, 'n_epochs': 20},
}

# Unpack the settings for the selected dataset (KeyError if DATASET is not listed).
n_factors, lr, batch_size, n_epochs = (
    parameters_dict[DATASET][setting]
    for setting in ('n_factors', 'lr', 'batch_size', 'n_epochs')
)

# 1. Load Data

In [5]:
train = pd.read_csv(data_path / 'train.csv')
test = pd.read_csv(data_path / 'test.csv')

# NOTE: pickle.load can execute arbitrary code; only load this file from a
# trusted, locally produced artifacts directory.
with open(data_path / 'item_dict.pkl', 'rb') as f:
    item_dict = pickle.load(f)

n_users = train['user_id'].nunique()
n_items = train['item_id'].nunique()

# The model below is sized from the *training* users/items, yet `test` is fed to
# fit() as validation data. Fail early with a readable message if the split
# contains IDs never seen in training (presumably these would otherwise break
# embedding lookups inside the model — confirm in MF_class).
assert test['user_id'].isin(train['user_id']).all(), 'test contains users unseen in train'
assert test['item_id'].isin(train['item_id']).all(), 'test contains items unseen in train'

print(f'Number of users: {n_users}, Number of items: {n_items}')

Number of users: 7801, Number of items: 6384


# 2. Train Model

In [None]:
# Use the first GPU when available, otherwise fall back to CPU so the notebook
# still runs end to end; mixed precision (AMP) is only requested on CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = MF.MatrixFactorizationTorch(n_users, n_items, n_factors=n_factors)
model.fit(
    train_data=train.values,
    val_data=test.values,   # NOTE(review): the test split doubles as validation data
    lr=lr,
    wd=1e-7,                # weight decay (reported as "Weight decay" in the model summary)
    pos_weight=1,           # weight on positive examples; presumably a BCE pos_weight — confirm in MF_class
    batch_size=batch_size,
    n_epochs=n_epochs,
    device=device,
    use_amp=device.type == 'cuda')

Epoch  ||- - - - - - - - Train - - - - - - - -||- - - - - - Validation - - - - - - - || Epoch's | COS θ | Time     
Number || BCE    | BCE-POS | BCE-NEG | MPR    || BCE    | BCE-POS | BCE-NEG | MPR    || Change  |       | Elapsed  
   1   || 0.0766 |  3.3557 |  0.0182 | 0.7898 || 0.0767 |  3.3303 |  0.0204 | 0.7760 || 225.83  | None  | 00:08.09
   2   || 0.0750 |  3.2878 |  0.0178 | 0.8018 || 0.0753 |  3.3473 |  0.0187 | 0.7844 ||  31.85  | 0.188 | 00:15.71
   3   || 0.0689 |  2.9577 |  0.0175 | 0.8510 || 0.0706 |  3.1220 |  0.0178 | 0.8256 ||  64.21  | 0.798 | 00:23.61
   4   || 0.0647 |  2.7349 |  0.0171 | 0.8854 || 0.0669 |  2.9201 |  0.0175 | 0.8617 ||  49.36  | 0.790 | 00:31.65
   5   || 0.0619 |  2.5879 |  0.0168 | 0.9033 || 0.0644 |  2.7899 |  0.0173 | 0.8820 ||  37.96  | 0.838 | 00:39.69
   6   || 0.0597 |  2.4839 |  0.0165 | 0.9142 || 0.0627 |  2.7082 |  0.0169 | 0.8938 ||  31.70  | 0.837 | 00:47.73
   7   || 0.0579 |  2.3992 |  0.0162 | 0.9220 || 0.0613 |  2.6489 |  0.0165 | 

### Save Model

In [None]:
name = f'MF{n_factors}_{DATASET}'

model_path = base_artifacts / 'Propensity_Models'
# Ensure the output directory exists on a fresh checkout, in case model.save
# does not create missing parent directories itself.
model_path.mkdir(parents=True, exist_ok=True)
model.save(path=model_path / (name + '.pt'), note=None)

### Load Model

In [6]:
name = f'MF{n_factors}_{DATASET}'
model_path = base_artifacts / 'Propensity_Models'

# Build a model with the same dimensions as the saved one, then populate it
# from the checkpoint file on disk.
loaded_model = MF.MatrixFactorizationTorch(n_users, n_items, n_factors=n_factors)
loaded_model.load(path=model_path.joinpath(f'{name}.pt'))

Loaded model summary:
Model:                      MatrixFactorizationTorch
Number of users:            7801
Number of items:            6384
Number of factors:          50
Learning rate:              0.001
Weight decay:               1e-07
Positive weight:            1
Batch size:                 32768
Number of epochs:           20
Device:                     cuda:0
Use AMP:                    True
Timestamp:                  2026-01-02 15:50:38
