# Variational Autoencoder

Following several tutorials as listed below
https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
https://towardsdatascience.com/intuitively-understanding-variational-autoencoders-1bfe67eb5daf
https://medium.com/dataseries/variational-autoencoder-with-pytorch-2d359cbf027b

Something I didn't realise initially was the difference between Variational Autoencoders and straight Autoencoders. As it so happens there is quite a large and destinct difference which relates fundamentally to whether the models are generative (variational) or discriminative (straight). Generative models appear to be more powerful than straight descriminative models but with that power comes additional complexities.

In this first foray into generative models below I will hopefully learn to understand the difference and the benefits of such models.


## Pytorch Implementation

Using pytorch and the mnist dataset we will code this up.

Import everything:

In [13]:
import torch
from torch import nn
from torch.utils.data import DataLoader,Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import torchvision.models as models
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

Lets use weights and bisases. You will need an account.
If running this on a cluster with jupyter nbconvert then get rid of all wandb code as you are unable to sign in as below.
There is probably a way around this that I haven't researched.

In [14]:
import wandb

Edit the below cell to setup project global variables

In [15]:
#Variables
batch_size = 64
learning_rate = 1e-4
epochs = 50
input_size = 32*32*3
latent_size = 512
hidden_size = 1000

wandb.init(project="VariationalAutoencoder",
           config={
               "batch_size": batch_size,
               "learning_rate": learning_rate,
               "dataset": "MNIST",
           })


# Get cpu or gpu device for training.
# device = "cpu"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: wandb version 0.12.9 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Using cuda device


Now, lets set up our MNIST dataset. This is a simple setup and uses built in functions that can be found almost line-for-line in the quickstart guide

In [16]:
training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size)

#Printing data
for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

Files already downloaded and verified
Files already downloaded and verified
Shape of X [N, C, H, W]:  torch.Size([64, 3, 32, 32])
Shape of y:  torch.Size([64]) torch.int64


Now to create a class for the network

In [17]:
class EncoderNeuralNetwork(nn.Module):
    def __init__(self,input_size, hidden_size):
        super().__init__()
        self.network = models.resnet18()
        
    def forward(self, x):
        return self.network(x)
    
class DecoderNeuralNetwork(nn.Module):
    def __init__(self,input_size, hidden_size, output_size):
        super().__init__()
        self.network = models.resnet18()
        
    def forward(self, x):
        return self.network(x)

Now to instantiate our class to create a variational autoencoder.
Up until here it is pretty easy to follow. This next bit gets a bit complicated and it really comesdown to "latent_size*2". Why?
Well because at this point we have both mu and sigma to be used in our z-space (latent space) and this will be shrunk to a singular value that represents them both.


In [18]:
class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.encoder = EncoderNeuralNetwork(input_size, hidden_size)
        self.decoder = DecoderNeuralNetwork(latent_size, hidden_size, input_size)
        
        # Parameters for decoding the output of the encoder
        self.fc_mu = nn.Linear(hidden_size, latent_size) #mu is the mean
        self.fc_var = nn.Linear(hidden_size, latent_size) #var is the variance
        
        # for the gaussian likelihood
        self.log_scale = nn.Parameter(torch.Tensor([0.0]))

    def gaussian_likelihood(self, x_hat, logscale, x):
        scale = torch.exp(logscale)
        mean = x_hat
        dist = torch.distributions.Normal(mean, scale)

        # measure prob of seeing image under p(x|z)
        log_pxz = dist.log_prob(x)
        return log_pxz.sum(-1)

    def kl_divergence(self, z, mu, std):
        # --------------------------
        # Monte carlo KL divergence
        # --------------------------
        # 1. define the first two probabilities (in this case Normal for both)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)

        # 2. get the probabilities from the equation
        log_qzx = q.log_prob(z)
        log_pz = p.log_prob(z)

        # kl
        kl = (log_qzx - log_pz)
        kl = kl.sum(-1)
        return kl 
        
    def forward(self, x):
        #Encoder otherwise known as q. Pass x through it.
        encoded_x = self.encoder(x)
        
        #Use the encoding to find the mean and variance
        mu, log_var = self.fc_mu(encoded_x), self.fc_var(encoded_x)
        
        # sample z from q
        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        
        # decode the latent space, otherwise known as the function p.
        # x_hat because this is the new x
        x_hat = self.decoder(z)
        
        # Calculate the ELBO loss. Remember the two parts....1 Reconstruction Loss and 2 KL divergence
        # 1 reconstruction loss
        recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)
        
        # 2 KL divergence
        kl = self.kl_divergence(z, mu, std)
        
        #Evidence lower bound
        elbo = -recon_loss + kl
        elbo = elbo.mean()
        
        return elbo, recon_loss.mean(), kl.mean(),  x_hat
    


Now that we have this class we can go about training the network.

In [19]:
model = VAE(input_size = input_size, hidden_size = hidden_size, latent_size = latent_size).to(device)
# print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

def train(dataloader, model, optimizer):
    size = len(dataloader.dataset)
    model.train()
    model.zero_grad()
    for (image, _) in dataloader:
    # for batch, (X, y) in enumerate(dataloader):
    
        # Compute prediction and loss
        image = image.to(device)
        # X, y = X.to(device), y.to(device)
        optimizer.zero_grad()

        elbo, recon_loss, kl,  x_hat = model(image)
        loss = elbo

        # Backpropagation
        loss.backward()
        optimizer.step()
        
def test(dataloader, model):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss= 0
    kl_loss= 0
    recon_loss= 0

    with torch.no_grad():
        for X, y in dataloader:
            image = image.to(device)
            elbo, recon, kl, x_hat = model(image)
            test_loss += elbo
            kl_loss += kl
            recon_loss += recon
            
    
    test_loss_avg = test_loss/num_batches
    kl_loss /= num_batches
    recon_loss /= num_batches
    wandb.log({"loss": test_loss_avg, 
               "elbo": test_loss, 
               "kl": kl_loss,
               "reconstruction": recon_loss
              })
    wandb.watch(model)
    return test_loss_avg
    

Training and testing datasets

In [20]:
%%wandb

#Timing the hardware difference
tic = time.perf_counter()

for t in range(epochs):
    train(train_dataloader, model, optimizer)
    test_loss = test(test_dataloader, model)
    if t%20==0:
        print(f"Epoch {t+1}\n-------------------------------")    
        print(f"Avg loss: {test_loss:>8f} \n")
print("Done!")

toc = time.perf_counter()
if {device} == "cpu":
    print(f"CPU time {toc - tic:0.4f} seconds")
else:
    print(f"GPU time {toc - tic:0.4f} seconds")

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 2-dimensional input of size [64, 512] instead

In [21]:
figure = plt.figure(figsize=(16, 8))
cols, rows = 6, 2
for i in range(1, cols +1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    image = img.reshape(-1, 32*32*3)
    image = image.to(device)
    elbo, recon, kl, x_hat = model(image)
    pred = torch.reshape(x_hat,[32,32,3])
    pred = pred.to("cpu")
    pred = pred.detach().numpy()
    figure.add_subplot(rows, cols, i)
    plt.title("Real")
    plt.axis("off")
    plt.imshow(img.reshape(3,32,32).permute(1, 2, 0))
    figure.add_subplot(rows, cols, i+cols)
    plt.title("Decoded")
    plt.axis("off")
    plt.imshow(pred.reshape(3,32,32).transpose(1, 2, 0))
plt.show()

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 2-dimensional input of size [1, 3072] instead

<Figure size 1152x576 with 0 Axes>

In [None]:

torch.save(model, 'model.pth')
torch.save(model.state_dict(), 'model_weights.pth')
