In [1]:
# Note: This is a hack to allow importing from the parent directory
import sys
from pathlib import Path

sys.path.append(str(Path().resolve().parent))

# Note: Ignore warnings, be brave (YoLo)
import warnings

warnings.filterwarnings("ignore")

In [2]:
import torch
import optuna
import torch.nn as nn
import torch.optim as optim
from models import ConvAutoencoder
from data import CIFAR10GaussianSplatsDataset
from utils import train, transform_and_collate

results_path = Path("../logs/conv_autoencoder_test_1/")
results_path.mkdir(parents=True, exist_ok=True)

# Use one model for the whole splat
channels_dim = 23
join_mode = "concat"

train_dataset = CIFAR10GaussianSplatsDataset(
    root="../data/CIFAR10GS",
    train=True,
    init_type="grid",
)
val_dataset = CIFAR10GaussianSplatsDataset(
    root="../data/CIFAR10GS",
    val=True,
    init_type="grid",
)

In [3]:
def objective(trial):
    # Define hyperparameter search space
    lr = trial.suggest_loguniform("lr", 1e-5, 1e-2)
    weight_decay = trial.suggest_loguniform("weight_decay", 1e-6, 1e-3)
    loss_fn = trial.suggest_categorical("loss_fn", [nn.L1Loss, nn.MSELoss])
    epochs = trial.suggest_int("epochs", 10, 100, 50)
    grad_clip = trial.suggest_uniform("grad_clip", 0.5, 2.0)
    weight_init = trial.suggest_categorical("weight_init", [True, False])

    # Define train parameters
    model = ConvAutoencoder(channels_dim=channels_dim, weight_init=weight_init)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        collate_fn=lambda batch: transform_and_collate(batch, join_mode),
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        collate_fn=lambda batch: transform_and_collate(batch, join_mode),
    )
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = loss_fn()
    epochs = epochs
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=5)
    grad_clip = grad_clip
    compile_model = True

    results = train(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        epochs=epochs,
        device=device,
        scheduler=scheduler,
        grad_clip=grad_clip,
        logger=print,
        compile_model=compile_model,
        model_path=results_path / "model.pth",
    )
    return results["val_loss"][-1]


# Run hyperparameter search
study = optuna.create_study(direction="minimize", study_name="conv_autoencoder_test_1")
study.optimize(objective, n_trials=100)
print(f"Best trial:{study.best_trial}")
with open(results_path / "best_trial.txt", "w") as f:
    f.write(f"Best trial:{study.best_trial}")

[I 2025-01-19 17:54:18,111] A new study created in memory with name: conv_autoencoder_test_1
Epoch 1/1: 100%|██████████| 9/9 [00:57<00:00,  6.36s/batch]
[I 2025-01-19 17:55:39,453] Trial 0 finished with value: 0.30177072683970135 and parameters: {'lr': 0.1, 'weight_decay': 0.1, 'loss_fn': <class 'torch.nn.modules.loss.L1Loss'>, 'epochs': 1, 'grad_clip': 1.0, 'weight_init': True}. Best is trial 0 with value: 0.30177072683970135.


Epoch 1/1 | Train Loss: 0.3089 | Val Loss: 0.3018
Train Loss: 0.3089 | Val Loss: 0.3018 | Training time: 79.77s
Best trial:FrozenTrial(number=0, state=1, values=[0.30177072683970135], datetime_start=datetime.datetime(2025, 1, 19, 17, 54, 18, 112621), datetime_complete=datetime.datetime(2025, 1, 19, 17, 55, 39, 452573), params={'lr': 0.1, 'weight_decay': 0.1, 'loss_fn': <class 'torch.nn.modules.loss.L1Loss'>, 'epochs': 1, 'grad_clip': 1.0, 'weight_init': True}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'lr': FloatDistribution(high=0.1, log=True, low=0.1, step=None), 'weight_decay': FloatDistribution(high=0.1, log=True, low=0.1, step=None), 'loss_fn': CategoricalDistribution(choices=(<class 'torch.nn.modules.loss.L1Loss'>,)), 'epochs': IntDistribution(high=1, log=False, low=1, step=1), 'grad_clip': FloatDistribution(high=1.0, log=False, low=1.0, step=None), 'weight_init': CategoricalDistribution(choices=(True,))}, trial_id=0, value=None)
