
Training a Variational Auto-Encoder

This is a quick guide to training a variational auto-encoder (VAE) in torchbearer. We will use the VAE model from the PyTorch examples repository.

Defining the Model

We shall first copy the VAE example model.

/_static/examples/vae_standard.py
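If you do not have the example file to hand, the model is roughly the following: a sketch of the upstream PyTorch VAE, with an encoder mapping 784 pixels to a 20-dimensional latent space and a mirrored decoder:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class VAE(nn.Module):
        def __init__(self):
            super(VAE, self).__init__()
            # Encoder: 784 -> 400 -> (mu, logvar), each of size 20
            self.fc1 = nn.Linear(784, 400)
            self.fc21 = nn.Linear(400, 20)
            self.fc22 = nn.Linear(400, 20)
            # Decoder: 20 -> 400 -> 784
            self.fc3 = nn.Linear(20, 400)
            self.fc4 = nn.Linear(400, 784)

        def encode(self, x):
            h1 = F.relu(self.fc1(x))
            return self.fc21(h1), self.fc22(h1)

        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std

        def decode(self, z):
            h3 = F.relu(self.fc3(z))
            return torch.sigmoid(self.fc4(h3))

        def forward(self, x):
            mu, logvar = self.encode(x.view(-1, 784))
            z = self.reparameterize(mu, logvar)
            return self.decode(z), mu, logvar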

Defining the Data

We get the MNIST dataset from torchvision, split it into train and validation sets, and transform the images to torch tensors.

/_static/examples/vae_standard.py
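A minimal sketch of this step; the exact split sizes and any torchbearer splitting utilities used in the example file may differ:

    import torch
    import torchvision
    from torchvision import transforms

    transform = transforms.Compose([transforms.ToTensor()])

    # Download MNIST and split the 60000 training images into train and validation parts
    base_trainset = torchvision.datasets.MNIST('./data/mnist', train=True,
                                               download=True, transform=transform)
    base_trainset, base_valset = torch.utils.data.random_split(base_trainset, [50000, 10000])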

The label output by this dataset is the classification label. Since we are solving an auto-encoding problem, we want the label to be the original image, so we create a wrapper class which replaces the classification label with the image.

/_static/examples/vae_standard.py
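A sketch of such a wrapper (the class name here is our choice):

    class AutoEncoderMNIST(torch.utils.data.Dataset):
        """Wraps an MNIST dataset so that the target is the image itself."""

        def __init__(self, mnist_dataset):
            super().__init__()
            self.mnist_dataset = mnist_dataset

        def __getitem__(self, index):
            image, _ = self.mnist_dataset[index]  # discard the classification label
            return image, image

        def __len__(self):
            return len(self.mnist_dataset)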

We then wrap the original datasets and create training and testing data generators in the standard PyTorch way.

/_static/examples/vae_standard.py
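For example, assuming the wrapper and splits sketched above:

    BATCH_SIZE = 128

    traingen = torch.utils.data.DataLoader(AutoEncoderMNIST(base_trainset),
                                           batch_size=BATCH_SIZE, shuffle=True)
    valgen = torch.utils.data.DataLoader(AutoEncoderMNIST(base_valset),
                                         batch_size=BATCH_SIZE, shuffle=False)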

Defining the Loss

Now that we have the model and the data, we need a loss function to optimize. VAEs typically take the sum of a reconstruction loss and a KL-divergence loss to form the final loss value.

/_static/examples/vae.py

/_static/examples/vae.py
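The two components are a summed binary cross entropy between the reconstruction and the input image, and the closed-form KL divergence between the approximate posterior and a unit Gaussian. A sketch, with function names of our choosing:

    import torch
    import torch.nn.functional as F


    def binary_cross_entropy(y_pred, y_true):
        # Pixel-wise reconstruction error, summed over the batch
        return F.binary_cross_entropy(y_pred, y_true.view(-1, 784), reduction='sum')


    def kld(mu, logvar):
        # KL(N(mu, sigma^2) || N(0, 1)), see Kingma & Welling, Appendix B
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())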

There are two ways this can be done in torchbearer - one is very similar to the PyTorch example method and the other utilises the torchbearer state.

PyTorch method

The loss function slightly modified from the PyTorch example is:

/_static/examples/vae_standard.py
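If you don't have the file to hand, it looks something like the following; torchbearer calls the criterion with (y_pred, y_true), so the arguments of the upstream loss_function are rearranged:

    def loss_function(y_pred, y_true):
        recon_x, mu, logvar = y_pred  # unpack the tuple returned by the model

        bce = F.binary_cross_entropy(recon_x, y_true.view(-1, 784), reduction='sum')
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return bce + kld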

This requires packing the reconstruction, mean and log-variance into the model output and unpacking them in the loss function.

/_static/examples/vae_standard.py
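In other words, the forward pass packs everything the loss needs into its return value, as in the model sketch above:

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        # pack reconstruction, mean and log-variance into a single tuple output
        return self.decode(z), mu, logvar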

Using Torchbearer State

Instead of having to pack and unpack the mean and variance in the forward pass, in torchbearer there is a persistent state dictionary which can be used to conveniently hold such intermediate tensors. We can (and should) generate unique state keys for interacting with state:

/_static/examples/vae.py
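For example (the key names 'mu' and 'logvar' are our choice):

    import torchbearer

    # Keys under which the intermediate tensors will be stored in state
    MU = torchbearer.state_key('mu')
    LOGVAR = torchbearer.state_key('logvar')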

By default the model forward pass does not have access to the state dictionary, but setting the pass_state flag to true when initialising Trial gives the model access to state on forward.

/_static/examples/vae.py
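A sketch, assuming a model, optimizer and criterion defined as elsewhere in this guide:

    # model, optimizer and criterion as defined elsewhere in this guide
    trial = torchbearer.Trial(model, optimizer, criterion,
                              metrics=['loss'], pass_state=True)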

We can then modify the model forward pass to store the mean and log-variance under suitable keys.

/_static/examples/vae.py
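A sketch of the modified forward pass, using the MU and LOGVAR keys from above:

    def forward(self, x, state):
        mu, logvar = self.encode(x.view(-1, 784))
        # store the intermediate tensors in the torchbearer state
        state[MU] = mu
        state[LOGVAR] = logvar
        z = self.reparameterize(mu, logvar)
        return self.decode(z)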

The reconstruction loss is a standard loss taking the network output and the true label:

/_static/examples/vae.py
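This mirrors the binary cross entropy sketched earlier, except that the network output is now just the reconstruction (the mean and log-variance live in state):

    def bce_loss(y_pred, y_true):
        # y_pred is the reconstruction; y_true is the original image
        return F.binary_cross_entropy(y_pred, y_true.view(-1, 784), reduction='sum')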

Since loss functions cannot access state, we use a simple callback to add the KL-divergence loss, which does not act on the network output or the true label.

/_static/examples/vae.py
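One way to write this is a sketch using torchbearer's add_to_loss callback decorator and the MU / LOGVAR keys from above (torch assumed imported as before):

    from torchbearer import callbacks


    @callbacks.add_to_loss
    def add_kld_loss_callback(state):
        # Evaluated on each step; the returned value is added to the existing loss
        mu, logvar = state[MU], state[LOGVAR]
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())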

Visualising Results

For auto-encoding problems it is often useful to visualise the reconstructions. We can do this in torchbearer with another simple callback. We stack the first 8 images from the first validation batch and pass them to torchvision's save_image function, which saves the visualisations to file.

/_static/examples/vae.py
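A sketch of such a callback, using the on_step_validation decorator and torchbearer's built-in X, Y_PRED, BATCH and EPOCH state keys; the file naming scheme here is our own:

    from torchvision.utils import save_image


    @callbacks.on_step_validation
    def save_reconstructions(state):
        if state[torchbearer.BATCH] == 0:  # only the first validation batch
            num_images = 8
            data = state[torchbearer.X]
            recon = state[torchbearer.Y_PRED].view(-1, 1, 28, 28)
            comparison = torch.cat([data[:num_images], recon[:num_images]])
            save_image(comparison.cpu(),
                       'reconstruction_' + str(state[torchbearer.EPOCH]) + '.png',
                       nrow=num_images)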

Training the Model

We train the model by creating a torch model and a torchbearer Trial and calling run.

/_static/examples/vae.py
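A sketch of the final step, assuming the pieces defined above (the data generators, the reconstruction loss and the two callbacks):

    import torch.optim as optim

    model = VAE()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    trial = torchbearer.Trial(model, optimizer, bce_loss, metrics=['loss'],
                              callbacks=[add_kld_loss_callback, save_reconstructions],
                              pass_state=True)
    trial.with_generators(train_generator=traingen, val_generator=valgen)
    trial.run(epochs=10)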

The visualised results after ten epochs then look like this:

Source Code

The source code for the examples is given below:

Standard:

Download Python source code: vae_standard.py </_static/examples/vae_standard.py>

Using state:

Download Python source code: vae.py </_static/examples/vae.py>