# Hyperparameter Tuning with Optuna

This script was used for tuning hyperparameters with Optuna, using https://github.com/optuna/optuna-examples/blob/main/pytorch/pytorch_simple.py as a guide for the objective function. Important considerations are marked by `NOTE` in the comments.

In [None]:
import os

import optuna
from optuna.trial import TrialState
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms

In [None]:
def _select_device():
    """Return the best available torch device: MPS first, then CUDA, then CPU."""
    # `torch.backends.mps` is looked up defensively so that builds without the
    # MPS backend fall through to the other options instead of raising.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Device used for the model and every batch. MPS is still preferred; the
# fallback only matters on machines where MPS is unavailable.
DEVICE = _select_device()
BATCHSIZE = 512
CLASSES = 10  # number of output logits (one per digit class)
DIR = os.getcwd()  # dataset download/cache directory, e.g. ./datafiles/
EPOCHS = 20
# Each epoch only uses a fixed number of batches to keep trials fast.
N_TRAIN_EXAMPLES = BATCHSIZE * 30
N_VALID_EXAMPLES = BATCHSIZE * 10

LOSS_FN = nn.CrossEntropyLoss()

# Seed torch's global RNG.
torch.manual_seed(0)

In [None]:
def get_mnist(augment=True):
    """Build the MNIST training and validation data loaders.

    Args:
        augment: When True (the default), random rotation, random affine and
            elastic distortions are applied to the training images. Pass False
            for the "no data augmentation" runs instead of commenting the
            transforms out.

    Returns:
        A ``(train_loader, valid_loader)`` tuple. Both loaders use batches of
        ``BATCHSIZE`` and shuffle their data; the training loader reads the
        ``train=True`` split and the validation loader the ``train=False``
        split, both stored under ``DIR``.
    """
    normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))

    # Any data augmentation should be added to training only.
    # NOTE: these three transforms are the ones that were disabled for the
    # "no data augmentation" runs.
    augmentations = []
    if augment:
        augmentations = [
            transforms.RandomRotation(15),
            transforms.RandomAffine(25),
            transforms.ElasticTransform(alpha=70.0),
        ]
    train_transform = transforms.Compose(
        [transforms.ToTensor(), *augmentations, normalize]
    )

    # Test data should have normalization applied, but no augmentation.
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    # Load the MNIST dataset. The transforms built above are now actually
    # passed to the datasets; previously both loaders used a bare ToTensor(),
    # so neither normalization nor augmentation was applied.
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(DIR, train=True, download=True, transform=train_transform),
        batch_size=BATCHSIZE,
        shuffle=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        datasets.MNIST(DIR, train=False, download=True, transform=test_transform),
        batch_size=BATCHSIZE,
        shuffle=True,
    )

    return train_loader, valid_loader

In [None]:
import math
def get_output_size(input_size, padding, stride, kernel):
    """Return the output length along one axis of a strided, padded window op.

    Computes ``floor((input_size + 2 * padding - kernel) / stride) + 1``.

    Args:
        input_size: Length of the input along the axis.
        padding: Padding added on each side of the axis.
        stride: Step between consecutive windows.
        kernel: Window length.

    Returns:
        The number of window positions, as an ``int``.
    """
    padded_span = input_size + 2 * padding - kernel
    return 1 + math.floor(padded_span / stride)

In [None]:
class MNIST_Model(nn.Module):
    """Two-stage CNN classifier whose widths and dropout rate come from Optuna.

    Each stage is conv -> ReLU -> max-pool -> batch-norm. The result goes
    through dropout, is flattened per sample, and is mapped to ``CLASSES``
    logits by a single linear layer.
    """

    def __init__(self, trial):
        """Sample this trial's hyperparameters and create the layers.

        Args:
            trial: Optuna trial used to suggest ``layer1_channels_exp``,
                ``layer2_channels_exp`` and ``dropout_p:``.
        """
        super().__init__()

        side = 28  # spatial size the linear layer is dimensioned for

        # NOTE: optuna params: channel counts are sampled as powers of two.
        #   1st layer: 2**3 .. 2**5 -> 8, 16 or 32
        #   2nd layer: 2**6 .. 2**8 -> 64, 128 or 256
        exp1 = trial.suggest_int("layer1_channels_exp", 3, 5)
        exp2 = trial.suggest_int("layer2_channels_exp", 6, 8)
        width1, width2 = 2 ** exp1, 2 ** exp2

        self.conv1 = nn.Conv2d(1, width1, kernel_size=(5, 5), padding='same')
        self.conv2 = nn.Conv2d(width1, width2, kernel_size=(3, 3), padding='same')
        self.mp = nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=1)

        # Batch norm, one per conv stage
        self.bn1 = nn.BatchNorm2d(width1)
        self.bn2 = nn.BatchNorm2d(width2)

        # The convs use 'same' padding, so only the two pooling passes shrink
        # the feature map; apply the pooling geometry once per stage.
        for _ in range(2):
            side = get_output_size(input_size=side, padding=1, stride=2, kernel=2)

        # Activation
        self.relu = nn.ReLU()

        # NOTE: optuna param.
        # NOTE(review): the key really is "dropout_p:" (trailing colon); it is
        # kept as-is because it is the name recorded in the study.
        drop_rate = trial.suggest_float("dropout_p:", 0, 0.1)
        self.dropout = nn.Dropout(drop_rate)

        # Linear layer width depends on the sampled channel count.
        self.output_layer = nn.Linear(width2 * side * side, CLASSES)

    def forward(self, x):
        """Run both conv stages, then dropout, flatten and the linear head.

        Args:
            x: Input batch; the first conv expects a single input channel.

        Returns:
            The output of the final linear layer (``CLASSES`` values per sample).
        """
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = norm(self.mp(self.relu(conv(x))))

        flat = self.dropout(x).view(x.size(0), -1)
        return self.output_layer(flat)

def define_model(trial):
    """Create the ``MNIST_Model`` configured by the given Optuna trial."""
    return MNIST_Model(trial)


In [None]:
def objective(trial):
    """Train and validate one sampled configuration for an Optuna study.

    Samples an optimizer type, weight decay and learning rate, trains the
    model for up to ``EPOCHS`` epochs on a capped number of batches, and
    reports the validation accuracy after every epoch so the trial can be
    pruned.

    Args:
        trial: The Optuna trial supplying hyperparameters and receiving
            intermediate reports and user attributes.

    Returns:
        The validation accuracy of the last completed epoch (not the best one;
        the best is stored in the ``best_accuracy`` user attribute).

    Raises:
        optuna.exceptions.TrialPruned: If the trial reports that it should be
            pruned after an epoch.
    """
    # NOTE: Credit to https://github.com/optuna/optuna-examples/blob/main/pytorch/pytorch_simple.py
    # for providing this training loop example code, which we used with our own added metrics and
    # other small changes for the experimentation

    # Build the network for this trial.
    net = define_model(trial).to(DEVICE)

    # NOTE: Optuna params: optimizer type, weight decay and learning rate.
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop"])
    weight_decay = trial.suggest_float("weight_decay", 1e-4, 1e-2, log=True)
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer_cls = getattr(optim, optimizer_name)
    optimizer = optimizer_cls(net.parameters(), lr=lr, weight_decay=weight_decay)

    # Get the MNIST data loaders.
    train_loader, valid_loader = get_mnist()

    # Early-stopping bookkeeping: the longest run of non-improving epochs that
    # was followed by an improvement, plus the best accuracy and its epoch.
    best_patience = 0
    stale_epochs = 0
    best_accuracy = 0
    best_epoch = -1

    for epoch in range(EPOCHS):
        # ---- training pass ----
        net.train()
        for step, (inputs, labels) in enumerate(train_loader):
            # Limiting training data for faster epochs.
            if step * BATCHSIZE >= N_TRAIN_EXAMPLES:
                break

            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            LOSS_FN(net(inputs), labels).backward()
            optimizer.step()

        # ---- validation pass ----
        net.eval()
        n_correct = 0
        with torch.no_grad():
            for step, (inputs, labels) in enumerate(valid_loader):
                # Limiting validation data.
                if step * BATCHSIZE >= N_VALID_EXAMPLES:
                    break
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                # Predicted class is the index of the largest logit.
                predicted = net(inputs).argmax(dim=1)
                n_correct += (predicted == labels).sum().item()

        accuracy = n_correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)

        # NOTE: Added to track if early stopping is necessary
        if accuracy > best_accuracy:
            best_accuracy, best_epoch = accuracy, epoch
            # Remember the longest wait that still ended in an improvement.
            best_patience = max(best_patience, stale_epochs)
            stale_epochs = 0
        else:
            stale_epochs += 1

        trial.report(accuracy, epoch)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    # NOTE: these were added for tracking early stopping
    print(f"Best patience value: {best_patience}")
    trial.set_user_attr('best_epoch', best_epoch)
    trial.set_user_attr('best_accuracy', best_accuracy)
    return accuracy


In [None]:
# Run the search: higher validation accuracy is better.
study = optuna.create_study(direction="maximize")
# 200 trials were done with no data augmentation, then another 200 trials WITH data augmentation
# NOTE(review): both n_trials=200 and timeout=600 are passed; if the timeout
# ends the study first, fewer than 200 trials will run -- confirm this is intended.
study.optimize(objective, n_trials=200, timeout=600)

# Split the finished trials by outcome for the summary below.
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

print("Study statistics: ")
print("  Number of finished trials: ", len(study.trials))
print("  Number of pruned trials: ", len(pruned_trials))
print("  Number of complete trials: ", len(complete_trials))

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

# Sampled hyperparameters first, then the extra attributes recorded by the
# objective (best_epoch / best_accuracy), all under the same heading.
print("  Params: ")
for key, value in (*trial.params.items(), *trial.user_attrs.items()):
    print(f"    {key}: {value}")