# Understanding Variational Autoencoders

An Abridged History of Neural Networks:
- Linear [Perceptron](https://en.wikipedia.org/wiki/Perceptron)
- Non-linear [SVM](https://en.wikipedia.org/wiki/Support_vector_machine)
- Perceptron w/ Activation [Neural Networks](https://wiki.pathmind.com/neural-network)
- Activation Function [Activation Functions](https://en.wikipedia.org/wiki/Activation_function)
- Stochastic Gradient Descent [Stochastic Gradient Descent](https://en.wikipedia.org/wiki/Stochastic_gradient_descent)
- Shallow [Neural Network](https://en.wikipedia.org/wiki/Neural_network_(machine_learning))
- Deep [Neural Network](https://en.wikipedia.org/wiki/AlexNet)

Describing the training process on a deterministic autoencoder: 
- [Autoencoder Diagram](https://www.ibm.com/think/topics/variational-autoencoder)

VAE diagram introducing the probabilistic generative model:
- [VAE Wiki](https://en.wikipedia.org/wiki/Variational_autoencoder)

Transformers:
- [Transformers (the media franchise, not the ML model)](https://en.wikipedia.org/wiki/Transformers)
- [Transformer (the electrical device, not the ML model)](https://en.wikipedia.org/wiki/Transformer)
- [Transformers](https://en.wikipedia.org/wiki/Transformer_(deep_learning_architecture))

Diffusion:
- [Diffusion](https://en.wikipedia.org/wiki/Diffusion_model)

Academic source material:
- [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114)
- [Variational Autoencoders Tutorial](https://arxiv.org/abs/1906.02691)
- [Beta-VAE](https://arxiv.org/abs/1804.03599)
  

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from scipy.stats import norm
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import make_grid as make_image_grid
from tqdm import trange

torch.manual_seed(2017) # reproducibility: seed PyTorch's global random number generator
sns.set_style('dark') # seaborn 'dark' theme for the plots drawn below
%matplotlib qt

In [None]:
# Pick the compute device in priority order: Apple-Silicon MPS, then NVIDIA CUDA,
# and finally the CPU when no accelerator is available.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Model
class VAE(nn.Module):
    """Fully connected variational autoencoder for 28x28 (784-value) images.

    The encoder is one hidden ReLU layer followed by two linear heads that
    output the mean and the log-variance of a diagonal Gaussian over the
    latent code. The decoder is one hidden ReLU layer followed by a sigmoid
    output layer, so reconstructed pixel values lie between 0 and 1.

    Args:
        latent_dim: Size of the latent code z.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
    """

    def __init__(self,latent_dim=20,hidden_dim=500):
        super().__init__()
        # Encoder: 784 -> hidden -> (mean, log-variance)
        self.fc_e = nn.Linear(784,hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim,latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim,latent_dim)
        # Decoder: latent -> hidden -> 784
        self.fc_d1 = nn.Linear(latent_dim,hidden_dim)
        self.fc_d2 = nn.Linear(hidden_dim,784)

    def encoder(self,x_in):
        """Map a batch of images to the (mean, log-variance) of q(z|x).

        The input is flattened to rows of 784 values before the first layer.
        """
        x = F.relu(self.fc_e(x_in.view(-1,784)))
        mean = self.fc_mean(x)
        logvar = self.fc_logvar(x)
        return mean, logvar

    def decoder(self,z):
        """Map latent codes to images of shape (batch, 1, 28, 28)."""
        z = F.relu(self.fc_d1(z))
        # torch.sigmoid replaces the deprecated F.sigmoid alias.
        x_out = torch.sigmoid(self.fc_d2(z))
        return x_out.view(-1,1,28,28)

    def sample_normal(self,mean,logvar):
        """Sample z = mean + sd * eps with eps ~ N(0, I) (reparameterization)."""
        sd = torch.exp(logvar*0.5)
        # randn_like draws the noise on sd's own device and dtype, so sampling
        # works wherever the inputs live instead of depending on a
        # module-level `device` variable.
        e = torch.randn_like(sd)
        return e * sd + mean

    def forward(self,x_in):
        """Encode, sample a latent code and decode.

        Returns:
            Tuple ``(x_out, z_mean, z_logvar)``: the reconstruction and the
            parameters of the approximate posterior used to sample z.
        """
        z_mean, z_logvar = self.encoder(x_in)
        z = self.sample_normal(z_mean,z_logvar)
        x_out = self.decoder(z)
        return x_out, z_mean, z_logvar

# Instantiate the VAE with its default sizes (latent_dim=20, hidden_dim=500) and
# move its parameters to the selected device.
model = VAE().to(device)

In [None]:
# Loss function
def criterion(x_out,x_in,z_mu,z_logvar):
    """Return the VAE loss (reconstruction + KL) averaged over the batch.

    Args:
        x_out: Reconstructed batch with values in [0, 1].
        x_in: Target batch with the same shape as ``x_out`` and values in [0, 1].
        z_mu: Means of the approximate posterior.
        z_logvar: Log-variances of the approximate posterior.

    Returns:
        Scalar tensor: summed binary cross-entropy plus the summed KL
        divergence from N(z_mu, exp(z_logvar)) to the standard normal,
        divided by the batch size ``x_out.size(0)``.
    """
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False: add up the loss over every pixel of every sample.
    bce_loss = F.binary_cross_entropy(x_out,x_in,reduction='sum')
    # Closed-form KL divergence between a diagonal Gaussian and N(0, I).
    kld_loss = -0.5 * torch.sum(1 + z_logvar - (z_mu ** 2) - torch.exp(z_logvar))
    return (bce_loss + kld_loss) / x_out.size(0) # normalize by batch size

In [None]:
# Optimizer: Adam over all parameters of `model`, using PyTorch's default hyperparameters
optimizer = torch.optim.Adam(model.parameters())

In [None]:
# Data loaders: MNIST train and test splits stored under ./data (download=True),
# converted to tensors by ToTensor() and served in shuffled batches of 128.
# NOTE(review): shuffle=True on the test loader means code that only looks at the
# first batch sees different images each time the loader is iterated.
trainloader = DataLoader(
    MNIST(root='./data',train=True,download=True,transform=transforms.ToTensor()),
    batch_size=128,shuffle=True)
testloader = DataLoader(
    MNIST(root='./data',train=False,download=True,transform=transforms.ToTensor()),
    batch_size=128,shuffle=True)

In [None]:
# Training
def train(model,optimizer,dataloader,epochs=15):
    """Optimize ``model`` on ``dataloader`` and return the per-step losses.

    Args:
        model: Module whose forward pass returns ``(x_out, z_mu, z_logvar)``.
        optimizer: Optimizer built over ``model``'s parameters.
        dataloader: Iterable yielding ``(images, labels)`` batches; the labels
            are ignored.
        epochs: Number of full passes over ``dataloader``.

    Returns:
        List of Python floats, one loss value per mini-batch step.

    Side effects:
        Puts ``model`` in training mode and updates its parameters in place.
    """
    # Follow the model's own device rather than a module-level global, so the
    # batches always land where the parameters are.
    model_device = next(model.parameters()).device
    model.train()
    losses = []
    for _ in trange(epochs,desc='Epochs'):
        for images,_ in dataloader:
            x_in = images.to(model_device)
            optimizer.zero_grad()
            x_out, z_mu, z_logvar = model(x_in)
            loss = criterion(x_out,x_in,z_mu,z_logvar)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    return losses

# Run the training loop on `model` with the default number of epochs (15);
# one loss value is recorded per mini-batch step.
train_losses = train(model,optimizer,trainloader)

In [None]:
# Raw per-step training loss curve
plt.figure(figsize=(10,5))
plt.plot(train_losses)
plt.show()

In [None]:
# Visualize moving average of losses
def visualize_losses_moving_average(losses,window=50,boundary='valid',ylim=(95,125)):
    """Plot the raw losses together with their moving average.

    Args:
        losses: Sequence of per-step loss values.
        window: Number of consecutive steps averaged together.
        boundary: Mode passed to ``np.convolve`` ('valid', 'same' or 'full').
        ylim: ``(low, high)`` limits applied to the y axis.
    """
    losses = np.asarray(losses,dtype=float)
    if boundary == 'valid' and losses.size < window:
        # No complete window fits in the data: draw nothing for the average
        # rather than the meaningless values np.convolve yields when the
        # kernel is longer than the signal.
        mav_losses = np.full(losses.size,np.nan)
    else:
        mav_losses = np.convolve(losses,np.ones(window)/window,boundary)
        if boundary == 'valid':
            # 'valid' omits the first window-1 positions; pad with NaN so each
            # average is drawn at the step that completes its window. The
            # other modes already cover the leading positions, so padding them
            # would shift the curve to the right.
            mav_losses = np.append(np.full(window-1,np.nan),mav_losses)
    plt.figure(figsize=(10,5))
    plt.plot(losses)
    plt.plot(mav_losses)
    plt.ylim(ylim)
    plt.show()

# Training losses overlaid with their moving average, using the function's defaults
visualize_losses_moving_average(train_losses)

In [None]:
# Testing
def test(model,dataloader):
    """Return the mean per-sample loss of ``model`` over ``dataloader``.

    Args:
        model: Module whose forward pass returns ``(x_out, z_mu, z_logvar)``.
        dataloader: Iterable yielding ``(images, labels)`` batches and exposing
            ``.dataset``; the labels are ignored.

    Returns:
        Sum of the batch-averaged losses weighted by batch size, divided by
        ``len(dataloader.dataset)``.
    """
    # Follow the model's own device rather than a module-level global.
    model_device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    running_loss = 0.0
    try:
        # Evaluation needs no gradients; skipping the autograd graph saves
        # memory and time.
        with torch.no_grad():
            for images, _ in dataloader:
                x_in = images.to(model_device)
                x_out, z_mu, z_logvar = model(x_in)
                loss = criterion(x_out,x_in,z_mu,z_logvar)
                # criterion returns a batch mean; re-weight by the batch size
                # so a smaller final batch is not over-counted.
                running_loss = running_loss + (loss.item()*x_in.size(0))
    finally:
        # Leave the model in whichever mode the caller had it.
        model.train(was_training)
    return running_loss/len(dataloader.dataset)

# Mean per-sample loss of `model` over the MNIST test loader
test_loss = test(model,testloader)
print(test_loss)

In [None]:
# Visualize VAE input and reconstruction
def visualize_mnist_vae(model,dataloader,num=16):
    """Show input images and their reconstructions as two image grids.

    Args:
        model: Module whose forward pass returns the reconstruction as the
            first element of a 3-tuple.
        dataloader: Iterable yielding ``(images, labels)`` batches; only the
            first batch is used.
        num: Maximum number of images taken from that batch.
    """
    def imshow(img):
        # Move the channel axis last, the layout plt.imshow expects.
        plt.imshow(np.transpose(img.numpy(),(1,2,0)))
        plt.axis('off')
        plt.show()

    # Follow the model's own device rather than a module-level global.
    model_device = next(model.parameters()).device
    images,_ = next(iter(dataloader))
    images = images[:num].to(model_device)
    # no_grad replaces the deprecated Variable/.data idiom: the output carries
    # no autograd history and can be converted to numpy directly.
    with torch.no_grad():
        x_out,_,_ = model(images)
    imshow(make_image_grid(images.cpu()))
    imshow(make_image_grid(x_out.cpu()))

# Compare up to 16 (the default `num`) images from one test batch with their reconstructions
visualize_mnist_vae(model,testloader)

In [None]:
# Train, test and visualize reconstruction using a 2D latent space
# (a 2D code can be scatter-plotted directly; see visualize_encoder below)
model2 = VAE(latent_dim=2).to(device)
optimizer2 = torch.optim.Adam(model2.parameters())

In [None]:
# Run the training loop on model2 with optimizer2, then compute its mean test loss
train2_losses = train(model2,optimizer2,trainloader)
test2_loss = test(model2,testloader)

In [None]:
# Report model2's test loss and show its reconstructions
print(test2_loss)
visualize_mnist_vae(model2,testloader)

In [None]:
# Visualize test data encodings on the latent space
def visualize_encoder(model,dataloader):
    """Scatter-plot the first two latent-mean coordinates of every sample.

    Args:
        model: Module exposing ``encoder(images) -> (means, logvars)`` with at
            least two latent dimensions.
        dataloader: Iterable yielding ``(images, labels)`` batches; the labels
            determine the point colors.
    """
    # Follow the model's own device rather than a module-level global.
    model_device = next(model.parameters()).device
    mean_chunks, label_chunks = [], []
    # Collect per-batch arrays and concatenate once at the end; calling
    # np.append inside the loop re-copies everything gathered so far.
    with torch.no_grad():
        for images,labels in dataloader:
            z_means,_ = model.encoder(images.to(model_device))
            mean_chunks.append(z_means[:,:2].cpu().numpy())
            label_chunks.append(labels.numpy())

    if mean_chunks:
        z_all = np.concatenate(mean_chunks)
        all_labels = np.concatenate(label_chunks)
    else:
        # An empty loader still produces an (empty) plot.
        z_all = np.empty((0,2))
        all_labels = np.empty(0)

    plt.figure(figsize=(6.5,5))
    plt.scatter(z_all[:,0],z_all[:,1],c=all_labels,cmap='inferno', s=1)
    plt.colorbar()
    plt.show()

# Latent means of the test images under the 2D-latent model, colored by digit label
visualize_encoder(model2,testloader)

In [None]:
# Train, test and visualize reconstruction using a 3D latent space
model3 = VAE(latent_dim=3).to(device)
# The optimizer must own model3's parameters: an optimizer built from another
# model's parameters would leave model3 at its random initialization.
optimizer3 = torch.optim.Adam(model3.parameters())

In [None]:
# Run the training loop on model3 with optimizer3, then compute model3's mean test loss
train3_losses = train(model3,optimizer3,trainloader)
test3_loss = test(model3,testloader)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import torch
from torch.autograd import Variable

def visualize_encoder_3d(model, dataloader):
    """Scatter-plot the first three latent-mean coordinates of every sample in 3D.

    Args:
        model: Module exposing ``encoder(images) -> (means, logvars)`` with at
            least three latent dimensions.
        dataloader: Iterable yielding ``(images, labels)`` batches; the labels
            determine the point colors.

    Side effects:
        Turns on matplotlib's interactive mode (``plt.ion()``).
    """
    # Follow the model's own device rather than a module-level global.
    model_device = next(model.parameters()).device
    mean_chunks, label_chunks = [], []
    # Collect per-batch arrays and concatenate once at the end; calling
    # np.append inside the loop re-copies everything gathered so far.
    with torch.no_grad():
        for images, labels in dataloader:
            z_means, _ = model.encoder(images.to(model_device))
            mean_chunks.append(z_means[:, :3].cpu().numpy())
            label_chunks.append(labels.numpy())

    if mean_chunks:
        z_all = np.concatenate(mean_chunks)
        all_labels = np.concatenate(label_chunks)
    else:
        # An empty loader still produces an (empty) plot.
        z_all = np.empty((0, 3))
        all_labels = np.empty(0)

    plt.ion()  # Enable interactive mode

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')

    scatter = ax.scatter(z_all[:, 0], z_all[:, 1], z_all[:, 2], c=all_labels, cmap='inferno', s=1)

    ax.set_xlabel('Latent Dimension 1')
    ax.set_ylabel('Latent Dimension 2')
    ax.set_zlabel('Latent Dimension 3')

    fig.colorbar(scatter, ax=ax, shrink=0.5, aspect=5)

    plt.show()

# 3D scatter of model3's latent means for the test images, colored by digit label
visualize_encoder_3d(model3, testloader)

In [None]:
# Visualize digits generated from latent space grid
def visualize_decoder(model,num=20,range_type='g'):
    """Decode a square grid of 2D latent points and show the images as one mosaic.

    Args:
        model: Module exposing ``decoder(z)`` for 2-dimensional latent codes
            and returning 28x28 images.
        num: Number of grid points per axis for the built-in ranges.
        range_type: ``'l'`` for ``num`` evenly spaced values in [-4, 4],
            ``'g'`` for ``num`` standard-normal quantiles between the 1st and
            99th percentile, or a sequence of coordinates used for both axes.

    Returns:
        The coordinates used along each axis (the passed object itself when a
        custom sequence was supplied).

    Raises:
        ValueError: If ``range_type`` is a string other than ``'l'`` or ``'g'``.
    """
    # Check for a string first so that array-like ranges are never compared
    # element-wise against 'l'/'g'.
    if isinstance(range_type,str):
        if range_type == 'l': # linear range
            # intended to cover the range of latent means plotted by visualize_encoder()
            range_space = np.linspace(-4,4,num)
        elif range_type == 'g': # gaussian range
            range_space = norm.ppf(np.linspace(0.01,0.99,num))
        else:
            raise ValueError("range_type must be 'l', 'g' or a sequence of coordinates")
    else:
        range_space = range_type

    coords = np.asarray(range_space,dtype=float)
    # Size the mosaic from the actual number of coordinates so a custom range
    # whose length differs from `num` still fits.
    side = len(coords)
    image_grid = np.zeros([side*28,side*28])

    # Follow the model's own device rather than a module-level global.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        for i, x in enumerate(coords):
            # Reverse y so the largest value ends up in the top row of the image.
            for j, y in enumerate(coords[::-1]):
                z = torch.tensor([[x,y]],dtype=torch.float32,device=model_device)
                image = model.decoder(z).cpu().numpy().reshape(28,28)
                image_grid[(j*28):((j+1)*28),(i*28):((i+1)*28)] = image

    plt.figure(figsize=(10, 10))
    plt.imshow(image_grid)
    # The call parentheses are required; a bare `plt.show` displays nothing.
    plt.show()
    return range_space

# Decode a 20x20 (default `num`) linear grid over [-4, 4] with the 2D-latent model
# and keep the per-axis coordinates that were used
range_space = visualize_decoder(model2, range_type='l')