In [1]:
import numpy as np
import pandas as pd
import os
import sys
import re
import random
import torch
from typing import List, Dict, Optional
import matplotlib.pyplot as plt
from tqdm import tqdm
import optuna
from optuna.pruners import MedianPruner
from optuna.exceptions import TrialPruned
import warnings

In [2]:
sys.path.append('..') 

from src.models.autoencoder import AutoEncoder
from src.models.autoencoder_trainer import *
from src.data.data_utils import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
data_dir = '../NETFLIX_DATA/partitions/train'
val_dir = '../NETFLIX_DATA/partitions/validation'
checkpoint_dir = '../model_checkpoints'

In [4]:
# retreive training data info
train_partition_files = get_data(data_dir)
print(f"Number of training partitions: {len(train_partition_files)}")
val_partition_files = get_data(val_dir)
print(f"Number of validation partitions: {len(val_partition_files)}")

seed = 42
random.seed(seed)

# testing
sample_train_partitions = random.sample(train_partition_files, 1)

sample_val_partitions = []
for partition in sample_train_partitions:
  val_partition = partition.copy()
  val_partition['path'] = partition['path'].replace('train', 'validation')
  sample_val_partitions.append(val_partition)


print(f"Train EX: {sample_train_partitions[0]}")
print(f"Val EX: {sample_val_partitions[0]}")

Number of training partitions: 34
Number of validation partitions: 34
Train EX: {'path': '../NETFLIX_DATA/partitions/train/part_1_7.parquet', 'part': 1, 'group': 7}
Val EX: {'path': '../NETFLIX_DATA/partitions/validation/part_1_7.parquet', 'part': 1, 'group': 7}


In [5]:
# build user_map and movie_map
user_map, movie_map = map_id(sample_train_partitions)

Mapping IDs: 100%|██████████| 1/1 [00:00<00:00,  4.16it/s]

Map successful for 317577 users, 354 movies





In [6]:
# preload user rating profiles

train_user_data = AutoEncoder.load_user_data(partitions=sample_train_partitions, 
                                             user_map=user_map)

validation_user_data = AutoEncoder.load_user_data(partitions=sample_val_partitions, 
                                                  user_map=user_map)

Loading user data: 100%|██████████| 1/1 [00:00<00:00,  9.53it/s]
Building user rating profiles: 100%|██████████| 317577/317577 [02:20<00:00, 2256.42it/s]
Loading user data: 100%|██████████| 1/1 [00:00<00:00, 28.56it/s]
Building user rating profiles: 100%|██████████| 71258/71258 [00:13<00:00, 5425.97it/s]


In [7]:
# optuna objective
def objective(trial):
    params = {
        "num_epochs": 5,
        "batch_size": 512,
        "learning_rate": trial.suggest_float("learning_rate", 0.0001, 0.001, log=True),
        "hidden_dims": trial.suggest_categorical("hidden_dims", 
                                                 [[512,256,128], [256,128], [512,128]]),
        "dropout": trial.suggest_float("dropout", 0.3, 0.7),
        "l2_reg": trial.suggest_float("l2_reg", 0.00001, 0.01, log=True),
        "checkpoint_interval": 5,
        "eval_interval": 1,
    }

                #[[512,256,128], [256,128], [512,128]]

    try:
        model, rmse = train_autoencoder(
            train_partitions=sample_train_partitions,
            user_map=user_map,
            movie_map=movie_map,
            validation_partitions=sample_val_partitions,
            checkpoint_dir=checkpoint_dir,
            trial=trial,
            user_data=train_user_data,
            validation_data=validation_user_data,
            **params
        )

        return float(rmse)
    except optuna.TrialPruned:
        raise # reraise prune error for tuning

In [8]:
# tuning
warnings.filterwarnings("ignore", module="optuna.*")

study = optuna.create_study(
    study_name="autoencoder_tuning", 
    direction='minimize',
    
    # prune after 2 trials, after 1 if really bad
    pruner=MedianPruner(n_startup_trials=2, n_warmup_steps=1),
    
    sampler=optuna.samplers.TPESampler(),
    storage="sqlite:///optuna_study.db",
    load_if_exists=True
)

study.optimize(objective, n_trials=10, timeout=2*3600)

[I 2025-06-02 09:17:46,554] A new study created in RDB with name: autoencoder_tuning


Device: cuda


Calculating global mean: 100%|██████████| 1/1 [00:00<00:00,  8.21it/s]

Global mean rating: 3.583



Epoch 1/3: 100%|██████████| 311/311 [00:21<00:00, 14.34it/s, loss=1.3305]


Epoch 1 - Average Loss: 3.7559
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 0.9830920715694842, rmse 0.987948818008201, mae 0.7829750955523882
Validation | Loss: 0.9831, RMSE: 0.9879


Epoch 2/3: 100%|██████████| 311/311 [00:20<00:00, 15.35it/s, loss=1.1669]


Epoch 2 - Average Loss: 1.2822
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 0.8947120360706163, rmse 0.9425775280063254, mae 0.7507508517359577
Validation | Loss: 0.8947, RMSE: 0.9426


Epoch 3/3: 100%|██████████| 311/311 [00:20<00:00, 15.06it/s, loss=1.1043]


Epoch 3 - Average Loss: 1.1798


[I 2025-06-02 09:19:06,963] Trial 0 finished with value: 0.944735352318519 and parameters: {'learning_rate': 0.00034504886479831817, 'hidden_dims': [128, 64], 'dropout': 0.4328524281461293, 'l2_reg': 3.0715465313408836e-05}. Best is trial 0 with value: 0.944735352318519.


DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 0.8989250037981116, rmse 0.944735352318519, mae 0.7568164213516373
Validation | Loss: 0.8989, RMSE: 0.9447
Saved final model at ../model_checkpoints/final_model.pth.gz
Device: cuda


Calculating global mean: 100%|██████████| 1/1 [00:00<00:00,  7.66it/s]


Global mean rating: 3.583


Epoch 1/3: 100%|██████████| 311/311 [00:20<00:00, 15.15it/s, loss=1.3452]


Epoch 1 - Average Loss: 3.5568
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 1.0516647171715032, rmse 1.022334470418015, mae 0.8129651361721121
Validation | Loss: 1.0517, RMSE: 1.0223


Epoch 2/3: 100%|██████████| 311/311 [00:20<00:00, 14.97it/s, loss=1.3839]


Epoch 2 - Average Loss: 1.3858
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 0.9906695232443188, rmse 0.9919896551341109, mae 0.7982776722705925
Validation | Loss: 0.9907, RMSE: 0.9920


Epoch 3/3: 100%|██████████| 311/311 [00:19<00:00, 15.77it/s, loss=1.2061]


Epoch 3 - Average Loss: 1.2423
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 0.9775777847870536, rmse 0.9853540251685841, mae 0.794346536736603
Validation | Loss: 0.9776, RMSE: 0.9854
Saved final model at ../model_checkpoints/final_model.pth.gz


[I 2025-06-02 09:20:22,844] Trial 1 finished with value: 0.9853540251685841 and parameters: {'learning_rate': 0.0005152298545546065, 'hidden_dims': [128, 64], 'dropout': 0.5553667813028409, 'l2_reg': 2.6140669424286085e-05}. Best is trial 0 with value: 0.944735352318519.


Device: cuda


Calculating global mean: 100%|██████████| 1/1 [00:00<00:00,  5.13it/s]

Global mean rating: 3.583



Epoch 1/3: 100%|██████████| 311/311 [00:19<00:00, 15.93it/s, loss=1.5997]


Epoch 1 - Average Loss: 3.9476
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 1.1655144432316655, rmse 1.0784262460752958, mae 0.866568843474663
Validation | Loss: 1.1655, RMSE: 1.0784


Epoch 2/3: 100%|██████████| 311/311 [00:20<00:00, 15.23it/s, loss=1.3149]


Epoch 2 - Average Loss: 1.4913


[I 2025-06-02 09:21:12,573] Trial 2 pruned. 


DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 1.0479064778141354, rmse 1.0222274705591397, mae 0.822929079020441
Device: cuda


Calculating global mean: 100%|██████████| 1/1 [00:00<00:00, 10.18it/s]


Global mean rating: 3.583


Epoch 1/3: 100%|██████████| 311/311 [00:19<00:00, 15.89it/s, loss=4.4163] 


Epoch 1 - Average Loss: 8.7924
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 6.294871454653532, rmse 2.334510769206626, mae 2.095570325244494
Validation | Loss: 6.2949, RMSE: 2.3345


Epoch 2/3: 100%|██████████| 311/311 [00:20<00:00, 15.00it/s, loss=2.9324]


Epoch 2 - Average Loss: 3.4288


[I 2025-06-02 09:22:02,458] Trial 3 pruned. 


DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 2.578501185645228, rmse 1.3569629348620862, mae 1.0965442325641621
Device: cuda


Calculating global mean: 100%|██████████| 1/1 [00:00<00:00,  7.38it/s]


Global mean rating: 3.583


Epoch 1/3: 100%|██████████| 311/311 [00:20<00:00, 14.88it/s, loss=2.4760]


Epoch 1 - Average Loss: 6.4239
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 3.2059603618538897, rmse 1.789200579719088, mae 1.5509571426079138
Validation | Loss: 3.2060, RMSE: 1.7892


Epoch 2/3: 100%|██████████| 311/311 [00:20<00:00, 15.06it/s, loss=1.5671]


Epoch 2 - Average Loss: 1.9815


[I 2025-06-02 09:22:53,458] Trial 4 pruned. 


DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 1.1967748520166979, rmse 1.092114716216632, mae 0.8722747022912323
Device: cuda


Calculating global mean: 100%|██████████| 1/1 [00:00<00:00,  7.88it/s]


Global mean rating: 3.583


Epoch 1/3: 100%|██████████| 311/311 [00:20<00:00, 15.47it/s, loss=2.7203]


Epoch 1 - Average Loss: 5.7880
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 2.1553056745425514, rmse 1.465342088744428, mae 1.2358801428236879
Validation | Loss: 2.1553, RMSE: 1.4653


Epoch 2/3: 100%|██████████| 311/311 [00:20<00:00, 15.08it/s, loss=1.8833]


Epoch 2 - Average Loss: 2.0756


[I 2025-06-02 09:23:43,849] Trial 5 pruned. 


DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 1.2795620742051497, rmse 1.1278997278467995, mae 0.9054297533686181
Device: cuda


Calculating global mean: 100%|██████████| 1/1 [00:00<00:00,  7.45it/s]

Global mean rating: 3.583



Epoch 1/3: 100%|██████████| 311/311 [00:21<00:00, 14.57it/s, loss=1.5230]


Epoch 1 - Average Loss: 3.6705
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 1.271733265856038, rmse 1.0046217971902565, mae 0.786093086924457
Validation | Loss: 1.2717, RMSE: 1.0046


Epoch 2/3: 100%|██████████| 311/311 [00:19<00:00, 16.07it/s, loss=1.3546]


Epoch 2 - Average Loss: 1.4161


[I 2025-06-02 09:24:34,259] Trial 6 pruned. 


DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 1.1787229050760684, rmse 0.9812964433961257, mae 0.7733497704737755
Device: cuda


Calculating global mean: 100%|██████████| 1/1 [00:00<00:00,  8.40it/s]

Global mean rating: 3.583



Epoch 1/3: 100%|██████████| 311/311 [00:21<00:00, 14.17it/s, loss=1.3201]


Epoch 1 - Average Loss: 2.6520
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 1.0153481286504995, rmse 0.9758140199787945, mae 0.7801277363336242
Validation | Loss: 1.0153, RMSE: 0.9758


Epoch 2/3: 100%|██████████| 311/311 [00:19<00:00, 16.17it/s, loss=1.1754]


Epoch 2 - Average Loss: 1.2353
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 0.9904710203409195, rmse 0.9625994599913547, mae 0.7669571674749329
Validation | Loss: 0.9905, RMSE: 0.9626


Epoch 3/3: 100%|██████████| 311/311 [00:20<00:00, 15.23it/s, loss=1.0924]


Epoch 3 - Average Loss: 1.1478


[I 2025-06-02 09:25:50,771] Trial 7 finished with value: 0.9446336096232201 and parameters: {'learning_rate': 0.0009560093645864017, 'hidden_dims': [128, 64], 'dropout': 0.4909998063083557, 'l2_reg': 0.0002766435018640519}. Best is trial 7 with value: 0.9446336096232201.


DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 0.9581321400144825, rmse 0.9446336096232201, mae 0.7509628027172833
Validation | Loss: 0.9581, RMSE: 0.9446
Saved final model at ../model_checkpoints/final_model.pth.gz
Device: cuda


Calculating global mean: 100%|██████████| 1/1 [00:00<00:00,  6.91it/s]

Global mean rating: 3.583



Epoch 1/3: 100%|██████████| 311/311 [00:19<00:00, 15.84it/s, loss=1.7313]


Epoch 1 - Average Loss: 3.0399
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 1.3974514836850374, rmse 1.034490560351379, mae 0.8098144881085955
Validation | Loss: 1.3975, RMSE: 1.0345


Epoch 2/3: 100%|██████████| 311/311 [00:19<00:00, 15.67it/s, loss=1.3555]


Epoch 2 - Average Loss: 1.4026


[I 2025-06-02 09:26:40,054] Trial 8 pruned. 


DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 1.3012506624926692, rmse 1.0285760007761926, mae 0.8072015794951295
Device: cuda


Calculating global mean: 100%|██████████| 1/1 [00:00<00:00,  7.71it/s]


Global mean rating: 3.583


Epoch 1/3: 100%|██████████| 311/311 [00:20<00:00, 15.05it/s, loss=2.3507]


Epoch 1 - Average Loss: 6.5927
DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 2.596964957921401, rmse 1.59963049616769, mae 1.3640244320695716
Validation | Loss: 2.5970, RMSE: 1.5996


Epoch 2/3: 100%|██████████| 311/311 [00:19<00:00, 15.75it/s, loss=1.8408]


Epoch 2 - Average Loss: 2.0834


[I 2025-06-02 09:27:30,258] Trial 9 pruned. 


DEBUG: Final - num_batches: 92, total_predictions: 100561
DEBUG: loss 1.318318934544273, rmse 1.1317986374580662, mae 0.9118149209433396


In [9]:
print("Best trial")
print(f"RMSE: {study.best_value:.4f}")
print(f"Params: {study.best_params}")

Best trial
RMSE: 0.9446
Params: {'learning_rate': 0.0009560093645864017, 'hidden_dims': [128, 64], 'dropout': 0.4909998063083557, 'l2_reg': 0.0002766435018640519}
