In [1]:
# Note: This is a hack to allow importing from the parent directory
import sys
from pathlib import Path

sys.path.append(str(Path().resolve().parent))

# Note: Ignore warnings, be brave (YoLo)
import warnings

warnings.filterwarnings("ignore")

In [2]:
import torch
import optuna
import torch.nn as nn
import torch.optim as optim
from models import DeepFlatAutoencoder
from data import CIFAR10GaussianSplatsDataset
from utils import train, transform_and_collate


input_dim = 23552
train_dataset = CIFAR10GaussianSplatsDataset(
    root="../data/CIFAR10GS",
    train=True,
    init_type="grid",
)
val_dataset = CIFAR10GaussianSplatsDataset(
    root="../data/CIFAR10GS",
    val=True,
    init_type="grid",
)

In [3]:
def objective(trial):
    # Define hyperparameter search space
    latent_dim = trial.suggest_categorical("latent_dim", [256, 512, 1024])
    lr = trial.suggest_loguniform("lr", 1e-5, 1e-2)
    weight_decay = trial.suggest_loguniform("weight_decay", 1e-6, 1e-3)
    loss_fn = trial.suggest_categorical("loss_fn", [nn.L1Loss, nn.MSELoss])
    epochs = trial.suggest_int("epochs", 10, 100, 50)
    grad_clip = trial.suggest_uniform("grad_clip", 0.5, 2.0)
    weight_init = trial.suggest_categorical("weight_init", [True, False])

    # Define train parameters
    model = DeepFlatAutoencoder(
        input_dim=input_dim, latent_dim=latent_dim, weight_init=weight_init
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        collate_fn=transform_and_collate,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        collate_fn=transform_and_collate,
    )
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = loss_fn()
    epochs = epochs
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=5)
    grad_clip = grad_clip
    compile_model = True

    results = train(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        epochs=epochs,
        device=device,
        scheduler=scheduler,
        grad_clip=grad_clip,
        logger=print,
        compile_model=compile_model,
    )
    return results["val_loss"][-1]


# Run hyperparameter search
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=100)
print(f"Best trial:{study.best_trial}")

[I 2025-01-16 22:35:07,875] A new study created in memory with name: no-name-cfdb33d6-3284-41fe-94b7-aed0aaf809d3
Epoch 1/1: 100%|██████████| 9/9 [01:11<00:00,  7.98s/batch]


Epoch 1/1 | Train Loss: 1.0799 | Val Loss: 0.7746


[I 2025-01-16 22:36:33,213] Trial 0 finished with value: 0.774617592493693 and parameters: {'latent_dim': 1024, 'lr': 0.01, 'weight_decay': 0.001, 'loss_fn': <class 'torch.nn.modules.loss.MSELoss'>, 'epochs': 1, 'grad_clip': 0.0, 'weight_init': True}. Best is trial 0 with value: 0.774617592493693.


Best trial:FrozenTrial(number=0, state=1, values=[0.774617592493693], datetime_start=datetime.datetime(2025, 1, 16, 22, 35, 7, 876694), datetime_complete=datetime.datetime(2025, 1, 16, 22, 36, 33, 209816), params={'latent_dim': 1024, 'lr': 0.01, 'weight_decay': 0.001, 'loss_fn': <class 'torch.nn.modules.loss.MSELoss'>, 'epochs': 1, 'grad_clip': 0.0, 'weight_init': True}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'latent_dim': CategoricalDistribution(choices=(1024,)), 'lr': FloatDistribution(high=0.01, log=True, low=0.01, step=None), 'weight_decay': FloatDistribution(high=0.001, log=True, low=0.001, step=None), 'loss_fn': CategoricalDistribution(choices=(<class 'torch.nn.modules.loss.MSELoss'>,)), 'epochs': IntDistribution(high=1, log=False, low=1, step=1), 'grad_clip': FloatDistribution(high=0.0, log=False, low=0.0, step=None), 'weight_init': CategoricalDistribution(choices=(True,))}, trial_id=0, value=None)
