# VAE guide

In this tutorial you'll learn how to build your own variational autoencoder (VAE).
We'll be using the [Keras functional API](https://keras.io/getting-started/functional-api-guide/). The original paper describing the VAE can be found [here](https://arxiv.org/abs/1312.6114).

### Autoencoder

First, we need to get an idea of what an autoencoder is. The goal of an autoencoder is to compress data while keeping the information that describes it. When compressing, the autoencoder implicitly extracts the most descriptive components of the data, which is why it is a commonly used component in creative deep learning (DL) architectures. The idea behind an autoencoder is relatively simple: push raw data (X) through a shrinking pipeline (q) and make the network learn which features to keep and which to remove. Then push the shrunken data (z) through an expanding pipeline (p), and make the network learn which features it needs to add to reproduce the raw input, as shown in the figure below.

<img src="images/autoencoder_Jwaaler.png" title="Autoencoder architecture" width="500"/>

#### Variational autoencoder

Let's think probability. The joint probability of the model above is:

$$p(X,z) = p (X | z)p(z)$$

According to Bayes' theorem, the posterior over the latent variables is:

$$p(z | X) = \frac{p(X | z)p(z)}{p(X)}$$

The denominator $p(X)$, called the evidence, can be calculated by marginalizing out the latent variables:

$$p(X) = \int p(X | z)p(z)dz $$

This integral is hard to calculate, as evaluating it over all configurations of the latent variables requires exponential time.
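
To make the marginalization concrete, here is a minimal sketch with a toy discrete latent variable, where the integral becomes a sum. The prior and likelihood values are made-up numbers chosen purely for illustration; they do not come from any model in this guide.

```python
# Toy example (all numbers are hypothetical): a latent variable z with
# three possible values, and one fixed observation X.
prior = [0.5, 0.3, 0.2]        # p(z)
likelihood = [0.1, 0.6, 0.3]   # p(X | z) for the observed X

# Evidence: p(X) = sum_z p(X | z) p(z)  (the integral becomes a sum)
evidence = sum(l * p for l, p in zip(likelihood, prior))

# Posterior via Bayes' theorem: p(z | X) = p(X | z) p(z) / p(X)
posterior = [l * p / evidence for l, p in zip(likelihood, prior)]

print(evidence)
print(posterior)
```

With three latent values the sum is trivial. With a continuous, multi-dimensional $z$ and a neural network as the likelihood there is no closed form, which is the difficulty described above.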

This is why we approximate $p(z | X)$ with a family of distributions $q_{\lambda}(z | X)$, where the variational parameter $\lambda$ indexes the members of the family. To measure the amount of information lost when approximating $p(z | X)$ with $q_{\lambda}(z | X)$, the [Kullback-Leibler divergence](https://towardsdatascience.com/demystifying-kl-divergence-7ebe4317ee68) is used. It cannot be computed directly because it depends on the intractable $p(z | X)$, hence the use of the [Evidence Lower BOund](http://edwardlib.org/tutorials/klqp) (ELBO).
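
The relation between the Kullback-Leibler (KL) divergence and the Evidence Lower BOund (ELBO) can be written out as a short derivation. This is the standard identity, sketched here with the same symbols as above; $E_q$ denotes an expectation over $z \sim q_{\lambda}(z | X)$.

$$KL(q_{\lambda}(z | X) \,\|\, p(z | X)) = E_q[\log q_{\lambda}(z | X)] - E_q[\log p(X, z)] + \log p(X)$$

$$ELBO(\lambda) = E_q[\log p(X, z)] - E_q[\log q_{\lambda}(z | X)]$$

$$\log p(X) = ELBO(\lambda) + KL(q_{\lambda}(z | X) \,\|\, p(z | X))$$

Since $\log p(X)$ does not depend on $\lambda$ and the KL divergence is never negative, maximizing the ELBO is equivalent to minimizing the KL divergence. Using $p(X, z) = p(X | z)p(z)$, the ELBO can also be written as a reconstruction term minus a KL term against the prior:

$$ELBO(\lambda) = E_q[\log p(X | z)] - KL(q_{\lambda}(z | X) \,\|\, p(z))$$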

I highly recommend reading [this](https://jaan.io/what-is-variational-autoencoder-vae-tutorial/) before you continue. Intuitively, the model is tasked with learning the distribution of the input data.

### Implementation

First, let's create a function that samples from the latent space, given the mean ($\mu$) and the log-variance ($\log \sigma^2$) as arguments. Sampling directly from $q(z | X)$ is not differentiable, so we draw the noise from an isotropic Gaussian distribution instead, which lets gradients flow through $\mu$ and $\sigma$ and makes the network trainable. This is called the reparameterization trick, which you can read about [here](https://stats.stackexchange.com/questions/199605/how-does-the-reparameterization-trick-for-vaes-work-and-why-is-it-important). Instead of sampling from $q(z | X)$, we sample $\epsilon \sim N(0, I)$, so that $z = \mu + \sqrt{\sigma^2}\epsilon$.

In [1]:
from keras import backend as K


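def sampling(input_tensor):
    # Unpack the mean and log-variance predicted by the encoder
    z_mean, z_log_var = input_tensor
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]

    # epsilon ~ N(0, I): one value per batch element and latent dimension
    epsilon = K.random_normal(shape=(batch, dim))

    # z = mu + sigma * epsilon, where sigma = exp(0.5 * log(sigma^2))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon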

Using TensorFlow backend.
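
To see what the sampling function does without needing a Keras backend, here is a minimal plain-Python sketch of the same formula. The helper name `sample_z` and the example values for the mean and log-variance are assumptions made for illustration; the check simply confirms that the samples end up with roughly the requested mean and standard deviation.

```python
import math
import random


def sample_z(z_mean, z_log_var, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1), per latent dimension."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(z_mean, z_log_var)]


rng = random.Random(0)              # fixed seed so the run is repeatable
z_mean = [1.0, -2.0]                # hypothetical encoder outputs
z_log_var = [math.log(0.25), 0.0]   # i.e. variances 0.25 and 1.0

samples = [sample_z(z_mean, z_log_var, rng) for _ in range(20000)]

stats = []
for d in range(len(z_mean)):
    values = [s[d] for s in samples]
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    stats.append((mean, std))
    print(d, round(mean, 2), round(std, 2))
```

The randomness lives entirely in `eps`, so `z` is a deterministic, differentiable function of the mean and log-variance. That is what allows backpropagation through the sampling step.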


In [2]:
# Let's set some parameters for the model
params = {
    'filters': 8, # How many filters should the first convolutional layer contain?
    'hidden_layers': 4, # How many hidden layers should the encoder and the decoder contain?
    'latent_dim': 4, # How many dimensions should the latent space have (one Gaussian per dimension)?
    'kernel_size': 3, # How big should the kernels be?
    'input_shape': (128, 128, 1), # What is the input-shape?
    'last_activation': 'sigmoid' # What should be the last activation?
}

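# The encoder doubles the filter count before every convolution,
# so halve it here to make the first layer use the value set above
params['filters'] //= 2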

In [3]:
from keras.layers import (Input,
                          Dense,
                          Conv2D,
                          Flatten,
                          Lambda)

from keras.models import Model


def encoder(input_shape,
            filters,
            hidden_layers,
            latent_dim,
            kernel_size,
            **kwargs) -> tuple:
    # Returns both the encoder model and the shape of the last convolutional output
    
    inputs = Input(shape=input_shape, name='encoder_input')
    x = inputs

    # The encoder is "shrinking" the input by stride = 2.
    # The number of filters is increasing by a factor of two for each hidden layer
    for i in range(hidden_layers):
        filters *= 2
        x = Conv2D(filters=filters,
                   kernel_size=kernel_size,
                   activation='relu',
                   strides=2,
                   padding='same',
                   name='encoder_conv_{}'.format(i))(x)
        
        

    intermediate_shape = K.int_shape(x)
    
    # The latent vector is created (q(z | X))
    x = Flatten(name='encoder_flatten')(x)
    x = Dense(latent_dim*8, activation='relu', name='encoder_intermediate_layer')(x)
    z_mean = Dense(latent_dim, name='encoder_z_mean')(x)
    z_log_var = Dense(latent_dim, name='encoder_z_log_var')(x)
    
    # Sample z
    z = Lambda(sampling, output_shape=(latent_dim,), name='encoder_z')([z_mean, z_log_var])
    
    encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
    return encoder, intermediate_shape
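
The shape bookkeeping in the encoder loop can be reproduced by hand. The sketch below is plain Python, not Keras, and assumes the parameter values set earlier (a 128x128x1 input, 3x3 kernels, 4 hidden layers, and 4 filters after the halving). The helper names `conv_same_out` and `conv2d_params` are made up for this example.

```python
import math


def conv_same_out(size, stride):
    # With padding='same', the output size is ceil(input / stride)
    return math.ceil(size / stride)


def conv2d_params(kernel_size, in_channels, filters):
    # One kernel_size x kernel_size x in_channels kernel plus one bias per filter
    return (kernel_size * kernel_size * in_channels + 1) * filters


height, width, channels = 128, 128, 1   # input_shape from params
filters = 8 // 2                        # params['filters'] after the halving
kernel_size, hidden_layers = 3, 4

layers = []
for i in range(hidden_layers):
    filters *= 2
    n_params = conv2d_params(kernel_size, channels, filters)
    height, width = conv_same_out(height, 2), conv_same_out(width, 2)
    channels = filters
    layers.append(((height, width, channels), n_params))
    print('encoder_conv_{}'.format(i), (height, width, channels), n_params)
```

The first three rows agree with the shapes and parameter counts in the model summary printed further down, and the last row gives the `intermediate_shape` that the decoder starts from.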

In [4]:
from keras.layers import (Conv2DTranspose,
                          Reshape)

def decoder(prev,
            intermediate_shape,
            hidden_layers,
            latent_dim,
            last_activation,
            kernel_size,
            **kwargs) -> "Model":
    # Note: `prev` is not used; the decoder builds its own Input layer below
    filters = intermediate_shape[3]
    # Upsampling and reshaping to intermediate shape
    latent_inputs = Input(shape=(latent_dim,), name='decoder_sampled_input')
    x = Dense(intermediate_shape[1] * intermediate_shape[2] * intermediate_shape[3],
              activation='relu',
              name='decoder_intermediate')(latent_inputs)
    x = Reshape((intermediate_shape[1], intermediate_shape[2], intermediate_shape[3]),
                name='decoder_reshape')(x)
    
    for i in range(hidden_layers):
        
        x = Conv2DTranspose(filters=filters,
                            kernel_size=kernel_size,
                            activation='relu',
                            strides=2,
                            padding='same',
                            name='decoder_conv_{}'.format(i))(x)
        
        filters //= 2
        
    outputs = Conv2DTranspose(filters=1,
                              kernel_size=kernel_size,
                              activation=last_activation,
                              padding='same',
                              name='decoder_output')(x)

    # Make model object
    decoder = Model(latent_inputs, outputs, name='decoder')
    
    return decoder
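
The decoder mirrors the encoder, which can be checked with the same kind of hand calculation. This is a plain-Python sketch under the assumption that the encoder output has shape `(None, 8, 8, 64)`, which is what a 128x128x1 input and 4 hidden layers would give.

```python
# Assumed encoder output shape for a 128x128x1 input and 4 hidden layers
intermediate_shape = (None, 8, 8, 64)
hidden_layers = 4

height, width = intermediate_shape[1], intermediate_shape[2]
filters = intermediate_shape[3]

# Number of units in the 'decoder_intermediate' Dense layer
dense_units = height * width * filters
print('decoder_intermediate', dense_units)

shapes = []
for i in range(hidden_layers):
    # Conv2DTranspose with strides=2 and padding='same' doubles height and width
    height, width = height * 2, width * 2
    shapes.append((height, width, filters))
    print('decoder_conv_{}'.format(i), (height, width, filters))
    filters //= 2

# The final Conv2DTranspose keeps the default stride of 1 and uses a single filter
output_shape = (height, width, 1)
print('decoder_output', output_shape)
```

The final shape matches the input shape, which is required for the reconstruction loss to compare input and output pixel by pixel.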

In [5]:
# Instantiating the models - gluing it all together
encoder_model, intermediate_shape = encoder(**params)
decoder_model = decoder(encoder_model.outputs[2], intermediate_shape, **params)

outputs = decoder_model(encoder_model(encoder_model.inputs)[2])

vae = Model(encoder_model.inputs, outputs, name = 'vae')

# summary() prints the table itself and returns None, so it is not wrapped in print()
print('ENCODER')
encoder_model.summary()

print('\nDECODER')
decoder_model.summary()

print('\nVAE')
vae.summary()

ENCODER
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      (None, 128, 128, 1)  0                                            
__________________________________________________________________________________________________
encoder_conv_0 (Conv2D)         (None, 64, 64, 8)    80          encoder_input[0][0]              
__________________________________________________________________________________________________
encoder_conv_1 (Conv2D)         (None, 32, 32, 16)   1168        encoder_conv_0[0][0]             
__________________________________________________________________________________________________
encoder_conv_2 (Conv2D)         (None, 16, 16, 32)   4640        encoder_conv_1[0][0]             
_____________________________

In [6]:
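from keras.losses import mse

# Adding loss - both (scaled) reconstruction and KL

# Reconstruction term: mean squared error over all pixels,
# scaled by the number of pixels per image (height * width)
reconstruction_loss = mse(K.flatten(encoder_model.inputs[0]), K.flatten(outputs))
reconstruction_loss *= params['input_shape'][0] * params['input_shape'][1]

# KL term: closed-form KL divergence between q(z | X) = N(z_mean, exp(z_log_var))
# and the standard normal prior, summed over the latent dimensions
z_mean, z_log_var = encoder_model.outputs[0], encoder_model.outputs[1]
kl_loss = 1. + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)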
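
The KL term above is the closed-form KL divergence between a Gaussian and the standard normal prior. As a sanity check, the sketch below compares that expression with a brute-force numerical integral for a single latent dimension. It is plain Python; the function names, the integration range, and the example values of the mean and log-variance are assumptions made for this illustration.

```python
import math


def kl_closed_form(mu, log_var):
    # Same expression as kl_loss above, for a single latent dimension
    return -0.5 * (1.0 + log_var - mu ** 2 - math.exp(log_var))


def kl_numerical(mu, log_var, lo=-12.0, hi=12.0, steps=20000):
    # Midpoint-rule integral of q(z) * (log q(z) - log p(z)),
    # with q = N(mu, sigma^2) and p = N(0, 1)
    sigma = math.exp(0.5 * log_var)
    dz = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        z = lo + (i + 0.5) * dz
        log_q = (-0.5 * math.log(2 * math.pi) - math.log(sigma)
                 - 0.5 * ((z - mu) / sigma) ** 2)
        log_p = -0.5 * math.log(2 * math.pi) - 0.5 * z ** 2
        total += math.exp(log_q) * (log_q - log_p) * dz
    return total


mu, log_var = 0.7, -0.5   # hypothetical encoder outputs
closed = kl_closed_form(mu, log_var)
numeric = kl_numerical(mu, log_var)
print(closed, numeric)
```

The two values should agree to several decimal places. The closed form is zero when the mean is 0 and the log-variance is 0, that is, when $q(z | X)$ equals the prior, so this term pulls the encoder's distributions toward $N(0, I)$.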