# Autoencoders

Variational autoencoders (VAEs), and even plain autoencoders (AEs), are fairly complicated compared to the models we've built so far. This notebook goes through training an AE in torchbearer and introduces callbacks and their usefulness.

## Setup 

We've done the boring bit of setting up the data loading and transforms. We've also created a validation set from the training data using the [dataset splitter](https://torchbearer.readthedocs.io/en/latest/code/main.html#torchbearer.cv_utils.DatasetValidationSplitter) provided in torchbearer.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt


import torchbearer
from torchbearer.cv_utils import DatasetValidationSplitter

BATCH_SIZE = 128

transform = transforms.Compose([transforms.ToTensor()])

# Define the standard MNIST classification dataset with a random validation set

dataset = torchvision.datasets.MNIST('./data/mnist', train=True, download=True, transform=transform)
splitter = DatasetValidationSplitter(len(dataset), 0.1)
trainset = splitter.get_train_dataset(dataset)
valset = splitter.get_val_dataset(dataset)

traingen = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

valgen = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

# State keys
MU, LOGVAR = torchbearer.state_key('mu'), torchbearer.state_key('logvar')

## Data targets <a id='data_targets'></a>

For this problem, our targets are the input images. 
We have some options to make the dataset we have output images as targets:
- Re-write or wrap the dataset to return images
- Write a callback that replaces the target with the data
- Replace the target with data in the model forward pass

If you would like to create a callback to do this, there is a skeleton below; however, any of these solutions can be implemented here. The list of preset state keys is defined [here](https://torchbearer.readthedocs.io/en/latest/_modules/torchbearer/state.html#State) at the bottom of the file until we write them up in the docs.

In [2]:
@torchbearer.callbacks.on_sample
@torchbearer.callbacks.on_sample_validation
def replace_targets(state):
    ## TODO: Implement? 
    pass
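
As an illustration of the first option in the list above (wrapping the dataset), here is a minimal sketch. The class name `ImageAsTarget` and the random stand-in data are assumptions made for this example and are not part of torchbearer or the notebook.

```python
import torch
from torch.utils.data import Dataset, TensorDataset


class ImageAsTarget(Dataset):
    """Hypothetical wrapper: returns (image, image) instead of (image, label)."""

    def __init__(self, base):
        self.base = base  # any dataset yielding (image, label) pairs

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        image, _ = self.base[index]  # discard the class label
        return image, image


# Tiny random stand-in for MNIST so the sketch runs without downloading anything
fake_mnist = TensorDataset(torch.rand(10, 1, 28, 28), torch.randint(0, 10, (10,)))
wrapped = ImageAsTarget(fake_mnist)
x, y = wrapped[0]  # both elements are now the same image tensor
```

With the names from the setup cell, the same idea would presumably be applied as `ImageAsTarget(trainset)` and `ImageAsTarget(valset)` before building the data loaders.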

## Model

The PyTorch model is a bit fiddly, so it's provided mostly finished below. It might be useful to have access to the mean and log-variance later on; can you add something to the model to do this?
Note that we reserved the state keys MU and LOGVAR for possible use at this stage.

In [3]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5*logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x, state):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        ## Perhaps add something here? 

        return self.decode(z)
    
model = VAE()
optimizer = optim.Adam(model.parameters(), lr=0.001)
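
One possible approach, sketched here with plain PyTorch, is to write the mean and log-variance into the `state` argument during the forward pass. The string keys and the plain `dict` below are stand-ins assumed for this example; in the notebook the `MU` and `LOGVAR` state keys and the state passed in by the trial would take their place. `TinyVAE` is a hypothetical condensed copy of the model above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins for the MU and LOGVAR state keys from the setup cell
MU, LOGVAR = 'mu', 'logvar'


class TinyVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)  # mean head
        self.fc22 = nn.Linear(400, 20)  # log-variance head
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def forward(self, x, state):
        h1 = F.relu(self.fc1(x.view(-1, 784)))
        mu, logvar = self.fc21(h1), self.fc22(h1)
        # Record the latent statistics so later code can read them from state
        state[MU], state[LOGVAR] = mu, logvar
        std = torch.exp(0.5 * logvar)
        # Sample during training, use the mean at evaluation time
        z = mu + torch.randn_like(std) * std if self.training else mu
        return torch.sigmoid(self.fc4(F.relu(self.fc3(z))))


state = {}  # a plain dict standing in for the state object passed by the trial
recon = TinyVAE()(torch.rand(4, 1, 28, 28), state)
```

After the call, `state[MU]` and `state[LOGVAR]` each hold one 20-dimensional vector per input image.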

## Visualising

For this example we make use of a callback to visualise results. We want this to output the true and reconstructed version of one validation image. 

Here the visualiser function is decorated twice: one decorator ensures it is called after a validation step, and the other ensures it is only called once per epoch.

In [4]:
@torchbearer.callbacks.once_per_epoch # Calls the function at most once per epoch
@torchbearer.callbacks.on_step_validation # Calls only after a validation step
def visualiser(state):
    data = state[torchbearer.X]
    # Get predictions and format them back into square image
    recon_batch = state[torchbearer.Y_PRED].view(-1, 1, 28, 28)
        
    fig = plt.figure()
    plt.subplot(1,2,1)
    plt.title('True')
    # Get the first image in batch and format it so pyplot can show it
    plt.imshow(data[0].repeat(3,1,1).permute(1,2,0).cpu())
            
    plt.subplot(1,2,2)
    plt.title('Reconstruction')
    plt.imshow(recon_batch[0].repeat(3,1,1).permute(1,2,0).cpu())
    plt.show()
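
The `repeat` and `permute` calls in the visualiser can be hard to read at a glance. The short sketch below traces the tensor shapes on a random stand-in image (an assumption for illustration), relying on the fact that `imshow` accepts colour images laid out as height x width x channels.

```python
import torch

image = torch.rand(1, 28, 28)      # one greyscale image: (channels, height, width)
rgb = image.repeat(3, 1, 1)        # copy the single channel three times: (3, 28, 28)
for_pyplot = rgb.permute(1, 2, 0)  # move channels last: (28, 28, 3)
```

Each of the three colour channels is a copy of the original greyscale image, so the plot still appears grey.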


## VAE Loss

The VAE loss is the sum of the BCE loss (or MSE loss) and the Kullback-Leibler divergence (KLD) loss. The BCE is a reconstruction loss and so takes y_pred and y_true; the KLD, however, requires the mean and log-variance.
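
Written out, and assuming the usual standard normal prior on the latent code (which matches the `kld_Loss` function in the cell below), the total loss is

$$
\mathcal{L} = \mathrm{BCE}(\hat{x}, x) + D_{\mathrm{KL}},
\qquad
D_{\mathrm{KL}} = -\frac{1}{2} \sum_{j} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right),
$$

where $\mu_j$ is the mean, $\log \sigma_j^2$ is the log-variance (`logvar` in the code), and the sum runs over every latent dimension of every image in the batch.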

Can you find out how to add the KLD to the loss? Perhaps take a look at some of the callback decorators ([here](https://torchbearer.readthedocs.io/en/latest/code/callbacks.html#module-torchbearer.callbacks.decorators)) which decorate a function of state and create a callback from it, there might be one that's useful here. 

In [5]:
def bce_loss(y_pred, y_true):
    BCE = F.binary_cross_entropy(y_pred, y_true.view(-1, 784), reduction='sum')
    return BCE


def kld_Loss(mu, logvar):
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return KLD

loss = bce_loss

## TODO: Add something here
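
As a rough sketch of what the combined loss computes, the plain PyTorch snippet below adds the two terms by hand. The helper `vae_loss` and the random tensors are assumptions for illustration only. Within torchbearer, one candidate is the `add_to_loss` decorator from the callback decorators: a function of state that returns the KLD term, which would then be passed to the trial as a callback.

```python
import torch
import torch.nn.functional as F


def bce(y_pred, y_true):
    return F.binary_cross_entropy(y_pred, y_true.view(-1, 784), reduction='sum')


def kld(mu, logvar):
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


def vae_loss(y_pred, y_true, mu, logvar):
    # Hypothetical helper: reconstruction term plus KL term
    return bce(y_pred, y_true) + kld(mu, logvar)


# In torchbearer, a function of state decorated with
# torchbearer.callbacks.add_to_loss could return kld(state[MU], state[LOGVAR])
# (assuming the model stored those values in state); its result is added to the loss.

y_true = torch.rand(4, 1, 28, 28)              # stand-in target images
y_pred = torch.rand(4, 784).clamp(0.01, 0.99)  # stand-in reconstructions
zeros = torch.zeros(4, 20)                     # mu = 0 and logvar = 0 match the prior
total = vae_loss(y_pred, y_true, zeros, zeros)
```

When the encoder output matches the prior exactly (zero mean, zero log-variance), the KL term vanishes and the total equals the reconstruction loss alone.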






## Training

Finally, let's start training and see how the model does. Note: if the [target replacement step](#data_targets) has not been done, this will show a pretty ugly error about target sizes.

In [6]:
from torchbearer import Trial

torchbearer_trial = Trial(model, optimizer, loss, metrics=['loss'],
                          callbacks=[replace_targets, visualiser], pass_state=True).to('cuda')
torchbearer_trial.with_generators(train_generator=traingen, val_generator=valgen)
torchbearer_trial.run(epochs=10)

0/10(t):   0%|          | 0/422 [00:00<?, ?it/s]

RuntimeError: invalid argument 2: size '[-1 x 784]' is invalid for input with 128 elements at /pytorch/aten/src/TH/THStorage.cpp:80
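
The error above is consistent with the targets still being class labels: a batch of 128 labels has only 128 elements, which cannot be reshaped into rows of 784. The sketch below reproduces that failure in isolation with stand-in tensors (an assumption for illustration); the exact error message depends on the PyTorch version.

```python
import torch

labels = torch.zeros(128, dtype=torch.long)  # one class label per image in a batch of 128

try:
    labels.view(-1, 784)  # what bce_loss does to y_true
    failed = False
except RuntimeError as err:
    failed = True
    message = str(err)

images = torch.rand(128, 1, 28, 28)
flat = images.view(-1, 784)  # works: 128 * 784 elements reshape cleanly
```

Once the targets are the images themselves, the same `view` call succeeds and the loss can be computed.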