# Trainable MNIST with ASHA


ASHA (Asynchronous Successive Halving Algorithm, available in Ray Tune as the asynchronous HyperBand scheduler) is a scheduling algorithm that can be used with random and grid searches. It monitors training performance and stops trials that are not performing well.

This frees up resources to launch other runs with potentially better parameters.

[ASHA paper](https://arxiv.org/pdf/1810.05934.pdf)


In [None]:
%load_ext autoreload
%autoreload 2

from dependencies import *
from mnist_pytorch import get_data_loaders
from mnist_training import *

In [None]:
# Stop any Ray instance left over from an earlier run of this notebook so the
# init below starts from a clean state, then start a local Ray instance with
# 6 CPUs and 1 GPU available to trials.
ray.shutdown()
ray.init(
    include_webui=True,  # NOTE(review): later Ray releases call this `include_dashboard` — check installed version
    num_gpus=1,
    num_cpus=6,
)

### Adding the scheduler

Adding ASHA is as simple as configuring a scheduler object and passing it to `tune.run` in place of the `stop` argument.

In [None]:
# Search space: each trial gets one learning rate and one momentum value,
# each sampled uniformly from the given range.
config = dict(
    lr=tune.uniform(0.001, 0.1),
    momentum=tune.uniform(0.1, 0.9),
)

# ASHA scheduler. Trials are compared on "mean_accuracy" (higher is better),
# with time measured in training iterations. No trial is cut before
# iteration 10 (grace_period) and none runs past iteration 100 (max_t). At
# each rung, roughly the best 1/reduction_factor of trials keep going.
asha = tune.schedulers.AsyncHyperBandScheduler(
    metric="mean_accuracy",
    mode="max",
    time_attr="training_iteration",
    grace_period=10,
    max_t=100,
    reduction_factor=3,
    brackets=3,
)

# The scheduler decides when each trial ends (at the latest at max_t), so no
# explicit `stop` criterion is passed. Without a scheduler one would use e.g.
#     stop={"mean_accuracy": 0.95, "training_iteration": 100}
analysis = tune.run(
    TrainMNIST,
    config=config,
    scheduler=asha,
    num_samples=15,  # number of trials sampled from `config`
    resources_per_trial=dict(cpu=1, gpu=0),  # each trial is CPU-only
    checkpoint_freq=10,  # checkpoint every 10 training iterations
    checkpoint_at_end=True,
    keep_checkpoints_num=3,
    local_dir="~/ray_results/torch_mnist_asha",
)

In [None]:
# Report the hyperparameters of the trial with the highest "mean_accuracy".
# `mode` is passed explicitly to match the scheduler's mode="max". Older Ray
# releases default to "max"; newer ones require it and raise when it is
# missing (unless a mode was given to tune.run itself).
print("Best config is:", analysis.get_best_config(metric="mean_accuracy", mode="max"))

### Check TensorBoard

In [None]:
# Load the TensorBoard notebook extension and display TensorBoard inline,
# pointed at the same directory passed as `local_dir` to tune.run above.
%load_ext tensorboard
# NOTE(review): `notebook` is not used in this cell; presumably kept for
# interactive use (e.g. notebook.list()) — confirm before removing.
from tensorboard import notebook 
%tensorboard --logdir "~/ray_results/torch_mnist_asha"

In [None]:
# Shut down the Ray instance started by ray.init earlier in the notebook.
ray.shutdown()

In [None]:
# Exercises
# - replace the optimiser with Adam
# - add network hyperparameters to the search space