# Variational Autoencoder
_(Requires Python 3, PyTorch 1.0.1, TorchVision 0.2.2)_

**Reference**: _D.Kingma and M.Welling,_ [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114)

### Libraries
Import PyTorch and the submodules used throughout this notebook.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as func
import torch.utils.data as Data

We'll also need NumPy and Matplotlib.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

## Model
The variational autoencoder has a conventional structure of encoder plus decoder. 

In a VAE, the encoder does not directly generate the latent representations. Instead, it generates the parameters for a multidimensional random variable whose distribution is to be defined. 

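> We will use independent multivariate Normal distributions, so our encoder will generate the mean $\mu$ and the log-variance $\log\sigma^2$ for each latent dimension. The standard deviation is recovered as $\sigma = \exp\left(\tfrac{1}{2}\log\sigma^2\right)$.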

### Encoder

In [None]:
class Encoder(nn.Module):
    """Fully connected VAE encoder.

    Maps an input vector to the mean and log-variance of a diagonal Gaussian
    over the latent space and draws one sample from that distribution using
    the reparameterization trick.

    Args:
        input_dim: Size of each input vector.
        hidden_dims: Width of the hidden layer(s): an int for a single layer,
            or a sequence of ints for several layers.
        lat_dim: Dimensionality of the latent space.
        activations: A single callable applied after every hidden layer, or a
            list/tuple holding one callable per hidden layer.

    Raises:
        ValueError: If a sequence of activations does not contain exactly one
            entry per hidden layer.
    """

    def __init__(self, input_dim, hidden_dims, lat_dim, activations=func.relu):
        super(Encoder, self).__init__()

        # Parse input arguments: accept a bare int (Python or NumPy) as a
        # single hidden layer and normalise widths to plain Python ints.
        if isinstance(hidden_dims, (int, np.integer)):
            hidden_dims = [hidden_dims]
        hidden_dims = [int(d) for d in hidden_dims]

        if isinstance(activations, (list, tuple)):
            if len(activations) != len(hidden_dims):
                raise ValueError('activations and hidden_dims must have the same dimensions')
            # Keep a private copy so the caller's sequence is never shared.
            activations = list(activations)
        else:
            activations = [activations] * len(hidden_dims)

        # Store arguments
        self.input_dim = input_dim
        self.lat_dim = lat_dim
        self.hidden_dims = np.array(hidden_dims)
        self.activations = activations

        # Create hidden layers, chaining each layer's output width into the next
        self.layers = nn.ModuleList()
        prev_d = input_dim
        for d in hidden_dims:
            self.layers.append(nn.Linear(prev_d, d))
            prev_d = d

        # Latent layer parameters: two parallel heads on the last hidden layer
        self.mu = nn.Linear(prev_d, lat_dim)
        self.logvar = nn.Linear(prev_d, lat_dim)

    def forward(self, x):
        """Encode ``x`` and sample a latent vector.

        Returns:
            A tuple ``(z, mu, logvar)`` where ``z = mu + exp(0.5 * logvar) * eps``
            and ``eps`` is standard normal noise.
        """
        for layer, activation in zip(self.layers, self.activations):
            x = activation(layer(x))

        mu = self.mu(x)
        logvar = self.logvar(x)

        # Reparameterization trick. randn_like draws the noise with the same
        # device and dtype as the network output, so the module also works
        # after .to(device) or .double().
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)

        return z, mu, logvar

### Decoder

In [None]:
class Decoder(nn.Module):
    """Fully connected VAE decoder.

    Maps a latent vector back to data space through a stack of linear layers.
    The final layer is always followed by a sigmoid, so every output value
    lies in [0, 1].

    Args:
        lat_dim: Dimensionality of the latent space (decoder input size).
        hidden_dims: Width of the hidden layer(s): an int for a single layer,
            or a sequence of ints for several layers.
        output_dim: Size of each reconstructed output vector.
        activations: A single callable applied after every hidden layer, or a
            list/tuple holding one callable per hidden layer.

    Raises:
        ValueError: If a sequence of activations does not contain exactly one
            entry per hidden layer.
    """

    def __init__(self, lat_dim, hidden_dims, output_dim, activations=func.relu):
        super(Decoder, self).__init__()

        # Parse input arguments: accept a bare int (Python or NumPy) as a
        # single hidden layer and normalise widths to plain Python ints.
        if isinstance(hidden_dims, (int, np.integer)):
            hidden_dims = [hidden_dims]
        hidden_dims = [int(d) for d in hidden_dims]

        if isinstance(activations, (list, tuple)):
            if len(activations) != len(hidden_dims):
                raise ValueError('activations and hidden_dims must have the same dimensions')
            # Copy before extending: the sigmoid appended below must not leak
            # into the sequence owned by the caller.
            activations = list(activations)
        else:
            activations = [activations] * len(hidden_dims)

        # The output layer always uses a sigmoid
        activations.append(torch.sigmoid)

        # Store arguments
        self.lat_dim = lat_dim
        self.output_dim = output_dim
        # Backward-compatible alias: this attribute used to be filled from an
        # undefined name `input_dim`, which only resolved when a global of
        # that name happened to exist. It now mirrors the data dimensionality.
        self.input_dim = output_dim
        self.hidden_dims = np.array(hidden_dims)
        self.activations = activations

        # Create hidden layers, starting from the latent dimension
        self.layers = nn.ModuleList()
        prev_d = lat_dim
        for d in hidden_dims:
            self.layers.append(nn.Linear(prev_d, d))
            prev_d = d

        # Output layer
        self.layers.append(nn.Linear(prev_d, output_dim))

    def forward(self, x):
        """Decode latent vectors ``x`` into reconstructions in [0, 1]."""
        for layer, activation in zip(self.layers, self.activations):
            x = activation(layer(x))

        return x

### Variational Autoencoder

In [None]:
class VariationalAutoencoder(nn.Module):
    """Variational autoencoder built from a mirrored Encoder/Decoder pair.

    The decoder uses the encoder's hidden widths (and activations) in reverse
    order, so the network is symmetric around the latent layer.

    Args:
        intput_dim: Size of each input vector. The spelling is kept as-is
            because the parameter name is part of the public signature.
        hidden_dims: Width of the encoder hidden layer(s): an int for a single
            layer, or a sequence of ints for several layers.
        lat_dim: Dimensionality of the latent space.
        activations: A single callable applied after every hidden layer, or a
            list/tuple holding one callable per encoder hidden layer.

    Raises:
        ValueError: If a sequence of activations does not contain exactly one
            entry per hidden layer.
    """

    def __init__(self, intput_dim, hidden_dims, lat_dim, activations=func.relu):
        super(VariationalAutoencoder, self).__init__()

        # Bind the (misspelled) parameter to a local name. Previously the body
        # read `input_dim`, which silently resolved to a global variable.
        input_dim = intput_dim

        # Parse input arguments
        if isinstance(hidden_dims, (int, np.integer)):
            hidden_dims = [hidden_dims]
        hidden_dims = [int(d) for d in hidden_dims]

        if isinstance(activations, (list, tuple)):
            if len(activations) != len(hidden_dims):
                raise ValueError('activations and hidden_dims must have the same dimensions')
            activations = list(activations)
        else:
            activations = [activations] * len(hidden_dims)

        # Store arguments
        self.input_dim = input_dim
        self.lat_dim = lat_dim
        self.hidden_dims = np.array(hidden_dims)
        self.activations = activations

        # Create encoder-decoder. The requested activations are forwarded
        # (they used to be parsed and then ignored). Each sub-module receives
        # its own fresh list so no list object is shared between them.
        self.encoder = Encoder(input_dim, hidden_dims, lat_dim, list(activations))
        self.decoder = Decoder(lat_dim, hidden_dims[::-1], input_dim, activations[::-1])

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            A tuple ``(out, mu, logvar, z)``: the reconstruction, the latent
            mean and log-variance produced by the encoder, and the sampled
            latent vector that was decoded.
        """
        z, mu, logvar = self.encoder(x)
        out = self.decoder(z)

        return out, mu, logvar, z

## Training example

### Training Data
Let's train the variational autoencoder on the MNIST dataset.

In [None]:
import torchvision
# MNIST training and test splits stored under the working directory
# (download=True asks torchvision to fetch the files when needed).
# ToTensor is the transform attached to both datasets.
train = torchvision.datasets.MNIST(
    root='./',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
test = torchvision.datasets.MNIST(
    root='./',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

### Instantiate Autoencoder

In [None]:
# Network dimensions: flattened 28x28 images, three hidden layers and a
# two-dimensional latent space.
input_dim = 28 * 28
layers = [200, 100, 30]
latent_dim = 2

# Build the model from the dimensions above
vae = VariationalAutoencoder(input_dim, layers, latent_dim)

In [None]:
# Evaluating the module as the last expression of the cell displays its layer structure
vae

### Optimizer

In [None]:
# Adam over every parameter of the VAE (encoder and decoder together)
optimizer = torch.optim.Adam(
    params=vae.parameters(),
    lr=1e-3,
)

### Training loop

In [None]:
%matplotlib notebook

n_test_img = 6
epochs     = 15
batch_size = 100

# Reshape data
train_samples = train.data.view(-1, 28*28).type(torch.float32)/255.0
test_samples  = test.data.view(-1,28*28).type(torch.float32)/255.0

# Lists to store training losses
train_loss = []
test_loss  = []

# Set model to training mode
vae.train()

# Plot test input images
test_imgs = test.data[0:n_test_img,:].type(torch.float32).view(-1,28*28)/255.0
f, a = plt.subplots(2, n_test_img, figsize=(8, 3))
for i in range(n_test_img):
    a[0][i].imshow(255-np.reshape(test_imgs.data.numpy()[i], (28,28)), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())
    
loss_text = f.text(0, 0, "epoch: 0, loss: 0")

# Data iterator
train_batches = Data.DataLoader(dataset=train, batch_size=batch_size, shuffle=True)

for e in np.arange(epochs):
    batch_loss = 0
    for batch_no, (batch, batch_labels) in enumerate(train_batches):
        # Input and target data (flatten)
        b_in = batch.view(-1, 28*28)
        target = batch.view(-1, 28*28)
        # Forward pass of the data through the network
        out, mu, logvar, z = vae(b_in)
        # Compute the VAE Losses
        ae_loss = func.binary_cross_entropy(out, target, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Total loss 
        loss = ae_loss + kl_loss
        batch_loss += float(loss)
        # Reset the gradients
        optimizer.zero_grad()
        # Compute gradients
        loss.backward()
        # Update parameters
        optimizer.step()

        # Test images
        if batch_no % 50 == 0:
            test_out, _, _, _ = vae(test_imgs)
            for i in range(n_test_img):
                a[1][i].imshow(1.0-np.reshape(test_out.data.numpy()[i], (28,28)), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            loss_text.set_text("epoch: {}, loss: {:.3f}".format(e+1, loss))
            f.canvas.draw()

    # End of epoch, compute train & test loss
    train_loss.append(batch_loss/batch_no)
    out, mu, logvar, z = vae(test_samples)
    test_loss.append(func.binary_cross_entropy(out, test_samples) - 0.5*torch.sum(1+logvar - mu.pow(2) - logvar.exp()))


### Training plots

In [None]:
# Convert to plain floats first: entries may be Python numbers or 0-dim
# tensors, and tensors that still require grad cannot be plotted directly.
train_curve = [float(v) for v in train_loss]
test_curve = [float(v) for v in test_loss]

plt.figure()
plt.plot(train_curve)
plt.plot(test_curve)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['train', 'test'])

## Results
### Loss on test dataset
We can report the final losses by reading the values recorded at the end of the last training epoch:

#### Train Loss

In [None]:
# Loss recorded at the end of the last training epoch
final_train_loss = train_loss[-1]
print("Train loss: {}".format(final_train_loss))

#### Test Loss

In [None]:
# Test-set loss recorded at the end of the last training epoch
final_test_loss = test_loss[-1]
print("Test loss: {}".format(final_test_loss))

### Latent space
The projections of the test images in the latent space after training look like this:

In [None]:
# Encode the test set. This is only for visualisation, so skip building the
# autograd graph. The encoder's third output is a log-variance.
with torch.no_grad():
    test_z, test_mu, test_logvar = vae.encoder(test_samples)

# Scatter the latent means, coloured by digit label
test_mu_np = test_mu.numpy()
plt.figure(figsize=(10,7))
plt.scatter(test_mu_np[:,0], test_mu_np[:,1], c=test.targets.numpy(), s=1.7)

In [None]:
# Same projection for the training set, stored under train_* names so the
# test-set encodings are not overwritten. No autograd graph is needed here.
with torch.no_grad():
    train_z, train_mu, train_logvar = vae.encoder(train_samples)

# Scatter the latent means, coloured by digit label
train_mu_np = train_mu.numpy()
plt.figure(figsize=(10,7))
plt.scatter(train_mu_np[:,0], train_mu_np[:,1], c=train.targets.numpy(), s=1.7)

### Image generation
Generate images sampling the latent space:

In [None]:
# Regular grid over the latent square [-2, 2] x [-2, 2]. linspace fixes the
# number of points exactly, whereas arange with a float step depends on rounding.
n_grid = 21
x_samples = np.linspace(2, -2, n_grid)  # 21 samples from 2 down to -2
y_samples = np.linspace(-2, 2, n_grid)  # 21 samples from -2 up to 2

# One 28x28 tile per grid point
imgs = np.zeros((28*n_grid, 28*n_grid))

# Decoding for display only, so no autograd graph is needed
with torch.no_grad():
    for y_idx, y in enumerate(y_samples):
        for x_idx, x in enumerate(x_samples):
            latent = torch.tensor([x, y], dtype=torch.float32)
            # Invert so digits are drawn dark on a light background
            y_img = 1 - vae.decoder(latent).view(28, 28)
            # First latent coordinate varies along rows, second along columns
            imgs[(x_idx*28):((x_idx+1)*28), (y_idx*28):((y_idx+1)*28)] = y_img.numpy()

plt.figure(figsize=(10,10))
plt.imshow(imgs, cmap=plt.cm.gray)