In [1]:
# Note: This is a hack to allow importing from the parent directory
import sys
from pathlib import Path

# Add the parent directory only once, so re-running this cell does not keep
# growing sys.path with duplicate entries.
parent_dir = str(Path().resolve().parent)
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Note: Every warning is silenced from here on (this also hides deprecation
# warnings raised by the libraries used below).
import warnings

warnings.filterwarnings("ignore")

In [2]:
import torch
import optuna
import torch.nn as nn
import torch.optim as optim
from models import ResNetAutoencoder
from data import CIFAR10GaussianSplatsDataset
from utils import train, transform_and_collate

# Use one model for the whole splat (all at once)
channels_dim = 23
join_mode = "concat"

# Output directory for this run; created up front if it does not exist yet.
results_path = Path("../logs/resnet_autoencoder_test_1/")
results_path.mkdir(parents=True, exist_ok=True)

# Both splits are built from the same root directory and init type; only the
# split flag differs between them.
dataset_kwargs = {"root": "../data/CIFAR10GS", "init_type": "grid"}
train_dataset = CIFAR10GaussianSplatsDataset(train=True, **dataset_kwargs)
val_dataset = CIFAR10GaussianSplatsDataset(val=True, **dataset_kwargs)

In [3]:
def objective(trial):
    """Train one ``ResNetAutoencoder`` with hyperparameters sampled from ``trial``.

    Samples the learning rate and weight decay (log scale), the number of
    epochs, the gradient-clipping value and the ``weight_init`` flag, builds
    the data loaders, optimizer, loss and scheduler, and hands everything to
    ``train``.

    Args:
        trial: The Optuna trial used to sample the hyperparameters.

    Returns:
        The last entry of ``results["val_loss"]`` as returned by ``train``.

    Note:
        No intermediate values are reported through ``trial.report`` here, so
        a pruner configured on the study is never consulted for these trials.
    """
    # Define hyperparameter search space.
    # ``suggest_float`` replaces the deprecated ``suggest_loguniform`` and
    # ``suggest_uniform``; parameter names and ranges are unchanged.
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
    epochs = trial.suggest_categorical("epochs", [10, 25, 50])
    grad_clip = trial.suggest_float("grad_clip", 0.5, 2.0)
    weight_init = trial.suggest_categorical("weight_init", [True, False])

    def collate(batch):
        # Single collate function shared by both loaders.
        return transform_and_collate(batch, join_mode)

    # Define train parameters
    model = ResNetAutoencoder(channels_dim=channels_dim, weight_init=weight_init)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=4,
        collate_fn=collate,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        collate_fn=collate,
    )
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.MSELoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=5)
    compile_model = True

    # Each trial gets its own model file: the study runs several trials in
    # parallel, and a single shared path would let them overwrite each other.
    # NOTE(review): presumably ``train`` saves the model to ``model_path`` --
    # confirm against ``utils.train``.
    model_path = results_path / f"model_trial_{trial.number}.pt"

    results = train(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        epochs=epochs,
        device=device,
        scheduler=scheduler,
        grad_clip=grad_clip,
        logger=print,
        compile_model=compile_model,
        model_path=model_path,
    )
    return results["val_loss"][-1]


# Run hyperparameter search
study = optuna.create_study(
    study_name="resnet_autoencoder_test_1",
    direction="minimize",
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
)
study.optimize(objective, n_trials=100, n_jobs=4)

# Print the best trial and store the very same text next to the other results.
best_trial_summary = f"Best trial:{study.best_trial}"
print(best_trial_summary)
(results_path / "best_trial.txt").write_text(best_trial_summary)

# Render each diagnostic plot and save it as an image in the results directory.
plots = {
    "opt_history.png": optuna.visualization.plot_optimization_history,
    "param_importances.png": optuna.visualization.plot_param_importances,
}
for filename, make_plot in plots.items():
    make_plot(study).write_image(str(results_path / filename))

[I 2025-01-20 18:51:43,337] A new study created in memory with name: resnet_autoencoder_test_1


Epoch 1/1: 100%|██████████| 9/9 [03:29<00:00, 23.32s/batch]
[I 2025-01-20 18:56:24,215] Trial 0 finished with value: 0.326013445854187 and parameters: {'lr': 0.1, 'weight_decay': 0.1, 'loss_fn': <class 'torch.nn.modules.loss.L1Loss'>, 'epochs': 1, 'grad_clip': 1.0, 'weight_init': True}. Best is trial 0 with value: 0.326013445854187.


Epoch 1/1 | Train Loss: 0.4190 | Val Loss: 0.3260
Train Loss: 0.4190 | Val Loss: 0.3260 | Training time: 278.91s
Best trial:FrozenTrial(number=0, state=1, values=[0.326013445854187], datetime_start=datetime.datetime(2025, 1, 20, 18, 51, 43, 339111), datetime_complete=datetime.datetime(2025, 1, 20, 18, 56, 24, 215417), params={'lr': 0.1, 'weight_decay': 0.1, 'loss_fn': <class 'torch.nn.modules.loss.L1Loss'>, 'epochs': 1, 'grad_clip': 1.0, 'weight_init': True}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'lr': FloatDistribution(high=0.1, log=True, low=0.1, step=None), 'weight_decay': FloatDistribution(high=0.1, log=True, low=0.1, step=None), 'loss_fn': CategoricalDistribution(choices=(<class 'torch.nn.modules.loss.L1Loss'>,)), 'epochs': IntDistribution(high=1, log=False, low=1, step=1), 'grad_clip': FloatDistribution(high=1.0, log=False, low=1.0, step=None), 'weight_init': CategoricalDistribution(choices=(True,))}, trial_id=0, value=None)
