# Training an Autoencoder with FishLeg versus ADAM

FishLeg is a second-order optimizer for training neural networks. This notebook uses FishLeg to train an autoencoder on the MNIST dataset of handwritten digits, and compares its training performance with ADAM, a widely used first-order neural network optimizer.
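"Second-order" means the optimizer uses curvature information, not just the gradient. FishLeg learns an approximation to the damped inverse Fisher information matrix and uses it to rescale (precondition) the gradient. The toy sketch below does not reproduce FishLeg. It only illustrates why preconditioning helps, assuming a hand-picked ill-conditioned quadratic loss whose curvature matrix `A` stands in for the Fisher matrix, and a damping value chosen to echo the `damping = 0.1` used later.

```python
# Toy illustration (not FishLeg itself): compare a plain gradient step with a
# damped, curvature-preconditioned step on an ill-conditioned quadratic.

def matvec(m, v):
    """Multiply a 2x2 matrix (nested lists) by a length-2 vector."""
    return [m[0][0] * v[0] + m[0][1] * v[1], m[1][0] * v[0] + m[1][1] * v[1]]


def inverse_2x2(m):
    """Invert a 2x2 matrix with the closed-form formula."""
    det = m[0][0] * m[1][1] - m[0][1] * m[1][0]
    return [[m[1][1] / det, -m[0][1] / det], [-m[1][0] / det, m[0][0] / det]]


# Loss L(theta) = 0.5 * theta^T A theta: gradient is A @ theta, curvature is A.
A = [[10.0, 0.0], [0.0, 1.0]]  # steep in one direction, flat in the other
theta = [1.0, 1.0]
damping = 0.1  # gamma, added to the diagonal: we invert (A + gamma * I)

grad = matvec(A, theta)

# First-order step: the learning rate must stay below 2/10 for stability,
# so progress along the flat direction is slow.
gd_theta = [t - 0.1 * g for t, g in zip(theta, grad)]

# Second-order step: rescale the gradient by the damped inverse curvature.
damped = [[A[0][0] + damping, A[0][1]], [A[1][0], A[1][1] + damping]]
precond_grad = matvec(inverse_2x2(damped), grad)
so_theta = [t - 1.0 * p for t, p in zip(theta, precond_grad)]

print("gradient step:      ", gd_theta)
print("preconditioned step:", so_theta)
```

The preconditioned step moves both coordinates most of the way to the minimum at the origin in a single step, while the gradient step barely moves along the flat direction.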

In [None]:
# Imports

import torch            # import pytorch
import torch.nn as nn
import numpy as np
from tqdm import tqdm   # tqdm ('te quiero demasiado') creates progress bars for loops
import time
import os               # used below to set the CUDA_VISIBLE_DEVICES environment variable
import sys              # used below to add the FishLeg source directory to the import path
import matplotlib.pyplot as plt
import torch.optim as optim
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter   # writes TensorBoard event files (here under `runs/`); view them with `tensorboard --logdir runs`
from data_utils import read_data_sets   # this allows for reading in the MNIST data from a web source

In [None]:
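# Set the default floating-point dtype of torch tensors to 32-bit floats:
torch.set_default_dtype(torch.float32)

# Make the FishLeg source code in ../src importable (needed for `from optim.FishLeg import ...` below):
sys.path.append("../src")

# Expose only GPU 1 to PyTorch. This must run before CUDA is initialised; on a machine
# with a single GPU, change "1" to "0" or remove this line, otherwise no GPU is visible.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Specify the seed for generating random numbers, for reproducibility of results:
seed = 13
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Select a GPU if one is available (the next cell refines this to also detect Apple Silicon):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")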

## GPU Acceleration
Depending on your system, PyTorch can use one of several hardware accelerators for training: NVIDIA GPUs via CUDA, Apple Silicon GPUs via MPS, or, if neither is available, the CPU.

In [None]:
if torch.cuda.is_available(): # e.g. for NVIDIA GPUs
    device_type = "cuda"
elif torch.backends.mps.is_available(): # e.g. for Apple Silicon GPUs
    device_type = "mps"
else:
    device_type = "cpu"

device = torch.device(device_type)
print(f'Running on device: {device}')

## Import the MNIST Dataset

MNIST is a well-known dataset of handwritten digits. More information can be found [here](https://www.tensorflow.org/datasets/catalog/mnist).

The dataset comes already split into training and test sets. Each set is wrapped in a PyTorch `DataLoader`, an iterable that optionally shuffles the data and groups it into batches that can be passed to the model.
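The batching that a shuffled `DataLoader` performs can be sketched with the standard library alone. The function `shuffled_batches` below is hypothetical and only mimics the index bookkeeping: the real `DataLoader` uses a `RandomSampler` and also collates the examples into tensors. As with the default `drop_last=False`, the last batch may be smaller than `batch_size`.

```python
import random


def shuffled_batches(num_examples, batch_size, seed=None):
    """Yield lists of example indices, mimicking DataLoader(shuffle=True)."""
    rng = random.Random(seed)
    indices = list(range(num_examples))
    rng.shuffle(indices)  # a fresh random permutation of the dataset
    for start in range(0, num_examples, batch_size):
        yield indices[start:start + batch_size]  # the last batch may be smaller


batches = list(shuffled_batches(60_000, 100, seed=0))
print(len(batches))  # 600 batches of 100 for the MNIST training set

test_batches = list(shuffled_batches(10_000, 1000, seed=0))
print(len(test_batches))  # 10 test batches

print([len(b) for b in shuffled_batches(10, 4, seed=0)])  # [4, 4, 2]
```

Two loaders built on the same dataset draw independent permutations, which is why the training and auxiliary loaders below yield different batches.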

In [None]:
dataset = read_data_sets("MNIST", "../data/", if_autoencoder=True)

# the dataset is already split into test and train groups
train_dataset = dataset.train
test_dataset = dataset.test

# using batches of 100 
batch_size = 100

# This is the main loader for the training loop: it splits the data into batches of 100 randomly
# shuffled examples. MNIST has 60,000 training examples, so this gives 600 batches.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)

# The auxiliary loader supplies batches of the same training data for FishLeg's auxiliary
# (inverse-Fisher) updates. It is shuffled independently of train_loader, so its batches differ.
aux_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, batch_size=batch_size
)

# The test loader does not need shuffling, since the order of evaluation does not affect the test
# loss. There are 10,000 test examples, so batches of 1000 give 10 test batches.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1000, shuffle=False
)


## Create autoencoder model

Here we create our autoencoder, `model`, and then prepare two models with identical initial weights: one to be trained with the FishLeg optimizer and the other to be trained with ADAM, for comparison.
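Two models with the same starting weights need two separate objects. In PyTorch, `nn.Module.to()` moves a module's parameters in place and returns the module itself, so `model_ADAM = model.to(device)` gives the same network a second name rather than creating a copy; `copy.deepcopy` produces an independent one. The sketch below shows the difference using a hypothetical `ToyModule` class as a stand-in for `nn.Module`, so that it runs without PyTorch.

```python
import copy


class ToyModule:
    """Hypothetical stand-in for nn.Module: .to() mutates and returns self."""

    def __init__(self):
        self.weights = [0.5, -0.2]
        self.device = "cpu"

    def to(self, device):
        self.device = device
        return self  # the same object, as with nn.Module.to()


model = ToyModule()
alias = model.to("cuda")            # NOT a copy: another name for `model`
independent = copy.deepcopy(model)  # separate weights with the same initial values

model.weights[0] = 99.0             # pretend one optimizer has updated `model`
print(alias.weights[0])        # 99.0, shared with `model`
print(independent.weights[0])  # 0.5, unaffected
```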


The architecture of the simple autoencoder model is as follows:

| Layer | Role |
| --- | --- |
| `nn.Linear(784, 1000, dtype=torch.float32)` | Input layer: takes a flattened 28x28-pixel image (784 values) and outputs a 1000-dimensional vector |
| `nn.ReLU()` | Activation function: applies the ReLU (rectified linear unit) function to introduce non-linearity |
| `nn.Linear(1000, 500, dtype=torch.float32)` | Hidden layer: maps the 1000-dimensional vector to a 500-dimensional vector |
| `nn.ReLU()` | Activation function |
| `nn.Linear(500, 250, dtype=torch.float32)` | Hidden layer: reduces the 500-dimensional vector to a 250-dimensional vector |
| `nn.ReLU()` | Activation function |
| `nn.Linear(250, 30, dtype=torch.float32)` | Waist (bottleneck) layer: maps the 250-dimensional vector to a 30-dimensional code. This is the most compressed representation of the data: the network must describe each image with only 30 numbers. No activation follows this layer. |
| `nn.Linear(30, 250, dtype=torch.float32)` | Start of the decoder: expands the 30-dimensional code to a 250-dimensional vector |
| `nn.ReLU()` | Activation function |
| `nn.Linear(250, 500, dtype=torch.float32)` | Hidden layer: expands the 250-dimensional vector to a 500-dimensional vector |
| `nn.ReLU()` | Activation function |
| `nn.Linear(500, 1000, dtype=torch.float32)` | Hidden layer: expands the 500-dimensional vector to a 1000-dimensional vector |
| `nn.ReLU()` | Activation function |
| `nn.Linear(1000, 784, dtype=torch.float32)` | Output layer: maps the 1000-dimensional vector back to 784 values, the same size as the input. This is the reconstructed image. |
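A worked count of the trainable parameters gives a sense of the model's size. Each `nn.Linear(n_in, n_out)` layer holds an `n_out x n_in` weight matrix plus `n_out` biases; the `nn.ReLU()` layers have no parameters. The short sketch below simply does this arithmetic for the eight linear layers listed above.

```python
# (in_features, out_features) of each nn.Linear layer, in order
layer_shapes = [
    (784, 1000), (1000, 500), (500, 250), (250, 30),  # encoder
    (30, 250), (250, 500), (500, 1000), (1000, 784),  # decoder
]


def linear_param_count(n_in, n_out):
    """Weights (n_out x n_in) plus one bias per output unit."""
    return n_in * n_out + n_out


counts = [linear_param_count(i, o) for i, o in layer_shapes]
total = sum(counts)
print(counts)  # [785000, 500500, 125250, 7530, 7750, 125500, 501000, 784784]
print(total)   # 2837314
```

With roughly 2.8 million parameters, a dense Fisher information matrix would have about 8 x 10^12 entries, far too many to store, which is why second-order methods such as FishLeg work with approximations rather than the full matrix.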


In [None]:
model = nn.Sequential(
    nn.Linear(784, 1000, dtype=torch.float32),
    nn.ReLU(),
    nn.Linear(1000, 500, dtype=torch.float32),
    nn.ReLU(),
    nn.Linear(500, 250, dtype=torch.float32),
    nn.ReLU(),
    nn.Linear(250, 30, dtype=torch.float32),
    nn.Linear(30, 250, dtype=torch.float32),
    nn.ReLU(),
    nn.Linear(250, 500, dtype=torch.float32),
    nn.ReLU(),
    nn.Linear(500, 1000, dtype=torch.float32),
    nn.ReLU(),
    nn.Linear(1000, 784, dtype=torch.float32),
)

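import copy

from optim.FishLeg import FishLeg, FISH_LIKELIHOODS, initialise_FishModel

# Keep an independent copy of the freshly initialised model for ADAM. initialise_FishModel
# may modify `model` in place, and `.to(device)` returns the same object, so without a deep
# copy both "models" would share (and jointly update) the same weights.
model_ADAM = copy.deepcopy(model)

# For compatibility with FishLeg, our model layers need to be initialised with additional
# parameters. This is done by initialise_FishModel():
scale_factor = 1
damping = 0.1

model_FishLeg = initialise_FishModel(
    model, module_names="__ALL__", fish_scale=scale_factor / damping
)

# Move both models to the hardware device that will be used to train them
model_FishLeg = model_FishLeg.to(device)
model_ADAM = model_ADAM.to(device)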


## Initialising FishLeg optimizer

In [None]:
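# Setting FishLeg optimizer parameters:
lr = 0.005                # learning rate for the model parameters
beta = 0.9                # momentum coefficient
weight_decay = 1e-5       # L2 regularisation strength
aux_lr = 1e-4             # learning rate for FishLeg's auxiliary (inverse-Fisher) parameters
aux_eps = 1e-8            # epsilon used by the auxiliary optimizer
update_aux_every = 10     # how often (in optimizer steps) the auxiliary parameters are updated
likelihood = FISH_LIKELIHOODS["bernoulli"](device=device)  # pixel intensities modelled as Bernoulli
writer = SummaryWriter(
    log_dir=f"runs/MNIST_fishleg/lr={lr}_auxlr={aux_lr}/{datetime.now().strftime('%Y%m%d-%H%M%S')}",
)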

# Initialising FishLeg:
opt = FishLeg(
    model_FishLeg,
    aux_loader,
    likelihood,
    lr=lr,
    beta=beta,
    weight_decay=weight_decay,
    aux_lr=aux_lr,
    aux_betas=(0.9, 0.999),
    aux_eps=aux_eps,
    damping=damping,
    update_aux_every=update_aux_every,
    writer=writer,
    method="antithetic",
    method_kwargs={"eps": 1e-4},
    precondition_aux=True,
)
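Both optimizers are trained with `FISH_LIKELIHOODS["bernoulli"]`, which treats each pixel's intensity as the outcome of a Bernoulli distribution (assuming, as is usual, that `read_data_sets` scales pixel intensities to [0, 1]). Since the model has no final sigmoid, a natural reading is that its 784 outputs are logits; under that assumption, the per-pixel negative log-likelihood is binary cross-entropy with logits. The hypothetical function below sketches that quantity in a numerically stable form, mathematically equivalent to `torch.nn.functional.binary_cross_entropy_with_logits`, averaged over all pixels. The exact reduction FishLeg applies over pixels and batch is not shown here; check the FishLeg source for it.

```python
import math


def bernoulli_nll_with_logits(logits, targets):
    """Mean binary cross-entropy between logits and targets in [0, 1].

    Uses the stable form max(x, 0) - x*y + log(1 + exp(-|x|)), which equals
    -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))] without overflowing.
    """
    total = 0.0
    for x, y in zip(logits, targets):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


# A logit of 0 means probability 0.5, so the loss is log(2) whatever the target:
print(bernoulli_nll_with_logits([0.0, 0.0], [1.0, 0.0]))
# Confident, correct predictions give a small loss:
print(bernoulli_nll_with_logits([8.0, -8.0], [1.0, 0.0]))
```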

## Training with FishLeg

In [None]:
epochs = 100

st = time.time()
eval_time = 0

for epoch in range(1, epochs + 1):
    with tqdm(train_loader, unit="batch") as tepoch:
        running_loss = 0
        for n, (batch_data, batch_labels) in enumerate(tepoch, start=1):
            tepoch.set_description(f"Epoch {epoch}")

            batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)

            opt.zero_grad()
            output = model_FishLeg(batch_data)

            loss = likelihood(output, batch_labels)

            running_loss += loss.item()

            loss.backward()
            opt.step()

            et = time.time()
            if n % 50 == 0:
                # Every 50 batches, evaluate on the whole test set (no gradients needed)
                model_FishLeg.eval()

                running_test_loss = 0

                with torch.no_grad():
                    for m, (test_batch_data, test_batch_labels) in enumerate(test_loader, start=1):
                        test_batch_data, test_batch_labels = test_batch_data.to(
                            device
                        ), test_batch_labels.to(device)

                        test_output = model_FishLeg(test_batch_data)

                        test_loss = likelihood(test_output, test_batch_labels)

                        running_test_loss += test_loss.item()

                # m counts from 1, so this is the mean over all test batches
                running_test_loss /= m

                tepoch.set_postfix(loss=loss.item(), test_loss=running_test_loss)
                model_FishLeg.train()
                eval_time += time.time() - et

        epoch_time = time.time() - st - eval_time

        tepoch.set_postfix(
            loss=running_loss / n, test_loss=running_test_loss, epoch_time=epoch_time
        )
        # Write out the losses per epoch
        writer.add_scalar("Loss/train", running_loss / n, epoch)
        writer.add_scalar("Loss/test", running_test_loss, epoch)

        # Write out the losses per wall clock time
        writer.add_scalar("Loss/train/time", running_loss / n, epoch_time)
        writer.add_scalar("Loss/test/time", running_test_loss, epoch_time)
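The loop above logs the losses against cumulative training wall-clock time as well as against epochs, subtracting the time spent on test-set evaluation (`eval_time`) so that FishLeg and ADAM are compared on optimisation time only. The same bookkeeping can be packaged as a small helper. `TrainingClock` below is a hypothetical sketch, not part of FishLeg or PyTorch; its `clock` argument lets a fake timer be injected so the behaviour can be checked deterministically.

```python
import time
from contextlib import contextmanager


class TrainingClock:
    """Hypothetical helper: wall-clock training time with evaluation excluded."""

    def __init__(self, clock=time.perf_counter):
        self._clock = clock
        self._start = clock()
        self._excluded = 0.0

    @contextmanager
    def excluded(self):
        """Time spent inside this block is not counted as training time."""
        t0 = self._clock()
        try:
            yield
        finally:
            self._excluded += self._clock() - t0

    def training_time(self):
        """Seconds since construction, minus all excluded intervals."""
        return self._clock() - self._start - self._excluded


# Usage with a fake clock that advances by one second per reading:
ticks = iter(range(100))
clock = TrainingClock(clock=lambda: float(next(ticks)))
with clock.excluded():  # e.g. the test-set evaluation
    pass
print(clock.training_time())  # 2.0
```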

## Initialising ADAM optimizer

In [None]:
lr = 0.001
# betas and eps are left at PyTorch's defaults, (0.9, 0.999) and 1e-8; uncomment to override
# betas = (0.7, 0.9)
weight_decay = 1e-5
# eps = 1e-8
likelihood = FISH_LIKELIHOODS["bernoulli"](device=device)

opt = optim.Adam(
    model_ADAM.parameters(),
    lr=lr,
    # betas=betas,
    weight_decay=weight_decay,
    # eps=eps,
)

writer = SummaryWriter(
    log_dir=f"runs/MNIST_adam/lr={lr}_lambda={weight_decay}/{datetime.now().strftime('%Y%m%d-%H%M%S')}",
)
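For reference, the textbook Adam update that `torch.optim.Adam` implements can be sketched on a list of scalar parameters. `adam_step` is a hypothetical stdlib-only re-implementation following the algorithm in the PyTorch documentation, not PyTorch's code. As in `torch.optim.Adam`, `weight_decay` is added to the gradient as an L2 term. Unlike FishLeg's curvature-based preconditioning, Adam rescales each parameter separately using running averages of the gradient and of its square.

```python
import math


def adam_step(params, grads, m, v, t, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    """One Adam update on plain lists of floats; returns (params, m, v).

    `t` is the 1-based step count used for bias correction.
    """
    beta1, beta2 = betas
    new_params, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(params, grads, m, v):
        g = g + weight_decay * p                 # L2 penalty folded into the gradient
        m_i = beta1 * m_i + (1 - beta1) * g      # running mean of gradients
        v_i = beta2 * v_i + (1 - beta2) * g * g  # running mean of squared gradients
        m_hat = m_i / (1 - beta1 ** t)           # correct the bias from zero initialisation
        v_hat = v_i / (1 - beta2 ** t)
        new_params.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_params, new_m, new_v


# On the first step the update is roughly lr * sign(gradient), whatever the gradient's size:
params, m, v = adam_step([1.0, -2.0], [0.5, -3.0], [0.0, 0.0], [0.0, 0.0], t=1,
                         lr=0.001, weight_decay=1e-5)
print(params)  # approximately [0.999, -1.999]
```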

## Training with ADAM

In [None]:
epochs = 100

st = time.time()
eval_time = 0

for epoch in range(1, epochs + 1):
    with tqdm(train_loader, unit="batch") as tepoch:
        running_loss = 0
        for n, (batch_data, batch_labels) in enumerate(tepoch, start=1):
            tepoch.set_description(f"Epoch {epoch}")

            batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)

            opt.zero_grad()
            output = model_ADAM(batch_data)

            loss = likelihood(output, batch_labels)

            running_loss += loss.item()

            loss.backward()
            opt.step()

            et = time.time()
            if n % 50 == 0:
                # Every 50 batches, evaluate on the whole test set (no gradients needed)
                model_ADAM.eval()

                running_test_loss = 0

                with torch.no_grad():
                    for m, (test_batch_data, test_batch_labels) in enumerate(test_loader, start=1):
                        test_batch_data, test_batch_labels = test_batch_data.to(
                            device
                        ), test_batch_labels.to(device)

                        test_output = model_ADAM(test_batch_data)

                        test_loss = likelihood(test_output, test_batch_labels)

                        running_test_loss += test_loss.item()

                # m counts from 1, so this is the mean over all test batches
                running_test_loss /= m

                tepoch.set_postfix(loss=loss.item(), test_loss=running_test_loss)
                model_ADAM.train()
                eval_time += time.time() - et

        epoch_time = time.time() - st - eval_time

        tepoch.set_postfix(
            loss=running_loss / n, test_loss=running_test_loss, epoch_time=epoch_time
        )
        # Write out the losses per epoch
        writer.add_scalar("Loss/train", running_loss / n, epoch)
        writer.add_scalar("Loss/test", running_test_loss, epoch)

        # Write out the losses per wall clock time
        writer.add_scalar("Loss/train/time", running_loss / n, epoch_time)
        writer.add_scalar("Loss/test/time", running_test_loss, epoch_time)