# Lab 4 - Tune the hyperparameters of a CNN on MNIST

This tutorial walks through using Ax to tune two hyperparameters (learning rate and momentum) for a PyTorch CNN on the MNIST dataset, trained using SGD with momentum. It is adapted from https://ax.dev/tutorials/tune_cnn.html; a tutorial on the underlying methods is available at https://ax.dev/docs/bayesopt.html

1. Run through the tutorial.
2. Write your own code for grid search and random search (remember the logarithmic transforms).
3. Create your own `net` structure to go into the `train_evaluate()` function and try optimising your own network. Consider also adapting the network structure. You could also try adapting some of the examples used in earlier labs to see whether you can improve on your results with hyperparameter search.
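
As a starting point for step 2, here is a minimal sketch of how grid and random candidates could be generated, with the learning rate spaced or sampled uniformly in log10 space. The helper names (`log_grid`, `grid_candidates`, `random_candidates`, `search`) are hypothetical and not part of Ax; the bounds are the ones used for `lr` and `momentum` later in this lab.

```python
import itertools
import math
import random


def log_grid(low, high, num):
    """Return `num` points evenly spaced in log10 space between low and high."""
    lo, hi = math.log10(low), math.log10(high)
    step = (hi - lo) / (num - 1)
    return [10 ** (lo + i * step) for i in range(num)]


def linear_grid(low, high, num):
    """Return `num` points evenly spaced between low and high."""
    step = (high - low) / (num - 1)
    return [low + i * step for i in range(num)]


def grid_candidates(lr_bounds, momentum_bounds, num_per_axis):
    """Cartesian product of a log-spaced lr grid and a linear momentum grid."""
    lrs = log_grid(lr_bounds[0], lr_bounds[1], num_per_axis)
    momenta = linear_grid(momentum_bounds[0], momentum_bounds[1], num_per_axis)
    return [{"lr": lr, "momentum": m} for lr, m in itertools.product(lrs, momenta)]


def random_candidates(lr_bounds, momentum_bounds, num_trials, seed=0):
    """Sample lr log-uniformly and momentum uniformly."""
    rng = random.Random(seed)
    lo, hi = math.log10(lr_bounds[0]), math.log10(lr_bounds[1])
    return [
        {
            "lr": 10 ** rng.uniform(lo, hi),
            "momentum": rng.uniform(momentum_bounds[0], momentum_bounds[1]),
        }
        for _ in range(num_trials)
    ]


def search(candidates, objective):
    """Evaluate every candidate and return (best_params, best_score)."""
    scored = [(objective(params), params) for params in candidates]
    best_score, best_params = max(scored, key=lambda pair: pair[0])
    return best_params, best_score
```

Once `train_evaluate()` is defined, a call such as `search(grid_candidates((1e-6, 0.4), (0.0, 1.0), 4), train_evaluate)` would run a 16-point grid search. Each evaluation trains a network, so keep the number of candidates small.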


In [None]:
!pip3 install ax-platform 

In [None]:
import torch
import numpy as np

from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate, CNN

init_notebook_plotting(offline=True)

In [None]:
torch.manual_seed(12345)
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 1. Load MNIST data
First, we need to load the MNIST data and partition it into training, validation, and test sets.

Note: this will download the dataset if necessary.

In [None]:
BATCH_SIZE = 512
train_loader, valid_loader, test_loader = load_mnist(batch_size=BATCH_SIZE)

## 2. Define function to optimize
In this tutorial, we want to optimize classification accuracy on the validation set as a function of the learning rate and momentum. The function takes in a parameterization (a set of parameter values), trains the network with it, and returns the classification accuracy on the validation set, which Ax records under the objective name ('accuracy').

`CNN` is a simple preconfigured ConvNet class; see the code at https://ax.dev/api/_modules/ax/utils/tutorials/cnn_utils.html


In [None]:
print(CNN())

def train_evaluate(parameterization):
    net = CNN()
    net = train(net=net, train_loader=train_loader, parameters=parameterization, dtype=dtype, device=device)
    return evaluate(
        net=net,
        data_loader=valid_loader,
        dtype=dtype,
        device=device,
    )

## 3. Run the optimization loop
Here, we set the bounds on the learning rate and momentum and set the parameter space for the learning rate to be on a log scale. 

In [None]:
#ax.optimize(parameters, evaluation_function, experiment_name=None, objective_name=None, minimize=False, parameter_constraints=None, outcome_constraints=None, total_trials=20, arms_per_trial=1, random_seed=None, generation_strategy=None)

best_parameters, values, experiment, model = optimize(
    parameters=[
        {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
        {"name": "momentum", "type": "range", "bounds": [0.0, 1.0]},
    ],
    evaluation_function=train_evaluate,
    objective_name='accuracy',
)

We can introspect the optimal parameters and their outcomes:

In [None]:
best_parameters

In [None]:
means, covariances = values
means, covariances
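
To make the shape of `values` concrete, here is a small sketch that turns it into a mean and a standard deviation for one metric. The numbers are made up for illustration, and `summarize` is a hypothetical helper; it assumes `values` is a `(means, covariances)` pair keyed by metric name, as unpacked above.

```python
import math

# Hypothetical numbers, shaped like the `values` tuple unpacked above:
# means maps metric name -> mean, covariances maps metric -> metric -> covariance.
values = (
    {"accuracy": 0.97},
    {"accuracy": {"accuracy": 4e-6}},
)


def summarize(values, metric="accuracy"):
    """Return (mean, standard deviation) of the model's estimate for one metric."""
    means, covariances = values
    variance = covariances[metric][metric]  # diagonal entry = variance
    return means[metric], math.sqrt(variance)


mean, std = summarize(values)
print(f"accuracy = {mean:.4f} +/- {std:.4f}")
```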

## 4. Plot response surface

Contour plot showing classification accuracy as a function of the two hyperparameters.

The black squares show the points that were actually evaluated. Notice how they cluster in the optimal region.

In [None]:
# some boilerplate to make things render on colab
import plotly.io as pio
pio.renderers.default = 'colab'

In [None]:
render(plot_contour(model=model, param_x='lr', param_y='momentum', metric_name='accuracy'))

## 5. Plot best objective as function of the iteration

Show the model accuracy improving as we identify better hyperparameters.

In [None]:
# `optimization_trace_single_method` expects a 2-d array of means, because it averages means from multiple
# optimization runs, so we wrap our best objectives array in another array.
best_objectives = np.array([[trial.objective_mean*100 for trial in experiment.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    y=np.maximum.accumulate(best_objectives, axis=1),
    title="Model performance vs. # of iterations",
    ylabel="Classification Accuracy, %",
)
render(best_objective_plot)
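
The "best so far" curve is just a running maximum over the per-trial accuracies, which is what `np.maximum.accumulate` computes above. A plain-Python sketch of the same idea, using made-up accuracies and a hypothetical helper name:

```python
from itertools import accumulate


def running_best(scores):
    """Return the best score seen up to and including each trial."""
    return list(accumulate(scores, max))


# Hypothetical per-trial accuracies (in %), for illustration only.
trial_accuracies = [91.2, 88.5, 95.1, 94.0, 97.3, 96.8]
print(running_best(trial_accuracies))
# [91.2, 91.2, 95.1, 95.1, 97.3, 97.3]
```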

## 6. Train CNN with best hyperparameters and evaluate on test set
Note that the resulting accuracy on the test set might not be exactly the same as the maximum accuracy achieved on the validation set during optimization.

In [None]:
data = experiment.fetch_data()
df = data.df
best_arm_name = df.arm_name[df['mean'] == df['mean'].max()].values[0]
best_arm = experiment.arms_by_name[best_arm_name]
best_arm

In [None]:
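# Concatenate the two splits themselves. `loader.dataset.dataset` is the full
# underlying dataset behind each split, so using it here would include it twice.
combined_train_valid_set = torch.utils.data.ConcatDataset([
    train_loader.dataset,
    valid_loader.dataset,
])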
combined_train_valid_loader = torch.utils.data.DataLoader(
    combined_train_valid_set, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
)

In [None]:
net = train(
    net=CNN(),
    train_loader=combined_train_valid_loader, 
    parameters=best_arm.parameters,
    dtype=dtype,
    device=device,
)
test_accuracy = evaluate(
    net=net,
    data_loader=test_loader,
    dtype=dtype,
    device=device,
)

In [None]:
print(f"Classification Accuracy (test set): {round(test_accuracy*100, 2)}%")

# Task - apply this approach to a new problem
You can apply this approach to any of the lab tasks we have tried earlier, or you can optimise parameters other than those in this example.

For example, you can try adding $\ell_1$ or $\ell_2$ regularisation to the model. The $\ell_2$ case is straightforward: add a `"weight_decay"` entry for the $\ell_2$ coefficient to the `parameters=[...]` list passed to `optimize` above.
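
A sketch of what the extended search space might look like. The bounds for `weight_decay` and the choice of a log scale are assumptions for illustration, not values from the original tutorial.

```python
parameters = [
    {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
    {"name": "momentum", "type": "range", "bounds": [0.0, 1.0]},
    # Assumed range for the l2 coefficient; searched on a log scale like lr.
    {"name": "weight_decay", "type": "range", "bounds": [1e-6, 1e-2], "log_scale": True},
]
```

This list would then be passed as the `parameters` argument of `optimize`, with `train_evaluate` unchanged.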



You will also find these hyperparameter optimisation tools useful in your Assessed Exercise.

If you want to go a bit further, you could try experimenting with $\ell_1$ regularisation, which is slightly trickier. One approach is to adapt the `train()` function in https://ax.dev/api/_modules/ax/utils/tutorials/cnn_utils.html and augment its training loop with the additional loss term for the $\ell_1$ cost.
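
Written out, the penalised objective would take the following form, where $\mathcal{L}_{\text{data}}$ is the usual classification loss, $w_i$ ranges over all network weights, and $\lambda \ge 0$ is the penalty weight (a further hyperparameter you could tune). The symbols are introduced here for illustration.

$$
\mathcal{L}_{\text{total}}(w) = \mathcal{L}_{\text{data}}(w) + \lambda \sum_i |w_i|
$$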

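    # Fragment for the training loop of an adapted `train()`;
    # `net`, `loss` and `factor` (the l1 weight, lambda) come from the surrounding code.
    reg_loss = 0.0
    for param in net.parameters():
        reg_loss = reg_loss + param.abs().sum()  # l1 norm of this tensor
    loss = loss + factor * reg_loss  # add the penalty once, after the loop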
