# Programming for Data Science and Artificial Intelligence

## 9.8 Deep Learning - PyTorch - Variational Autoencoders

- Variational Autoencoders - https://arxiv.org/abs/1312.6114

### Variational Autoencoders

Variational autoencoders (VAEs) are a group of generative models in the field of deep learning.

#### Standard Autoencoders

Before we move to variational autoencoders, let's review what a standard autoencoder is. A standard autoencoder consists of an encoder and a decoder. Let the input data be $X$. The encoder compresses $X$ into a latent vector $z$, and the decoder then tries to reconstruct $X$ from $z$. Because $z$ usually has far fewer dimensions than $X$, the network can serve as a form of dimensionality reduction.
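To make the contrast with VAEs concrete, here is a minimal sketch of a standard (deterministic) autoencoder in PyTorch. The class name `Autoencoder` and the layer sizes (784, 400 and 20 units, chosen to match the hyper-parameters used later in this section) are illustrative assumptions, and random tensors stand in for real MNIST images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Autoencoder(nn.Module):
    """Deterministic autoencoder: x -> z -> x_hat, with no sampling anywhere."""

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super().__init__()
        # Encoder compresses a flattened image into a single latent vector z
        self.encoder = nn.Sequential(
            nn.Linear(image_size, h_dim), nn.ReLU(),
            nn.Linear(h_dim, z_dim),
        )
        # Decoder maps z back to pixel intensities in (0, 1)
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, image_size), nn.Sigmoid(),
        )

    def forward(self, x):
        z = self.encoder(x)  # one fixed point per input, not a distribution
        return self.decoder(z), z


model = Autoencoder()
x = torch.rand(8, 784)  # a fake batch of 8 flattened 28x28 images
x_hat, z = model(x)
loss = F.binary_cross_entropy(x_hat, x, reduction='sum')  # reconstruction error
print(x_hat.shape, z.shape)  # torch.Size([8, 784]) torch.Size([8, 20])
```

Calling the model twice on the same input returns exactly the same `z`; that determinism is what the VAE replaces with a learned distribution.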

The problem is that a standard autoencoder is trained only to compress each input into a latent vector and reproduce that same input from it, so it works well only for reconstructing the kinds of images it was trained on. Another limitation is that its latent space is not continuous: each training image is mapped to an isolated point, and points in between (or points picked at random) need not decode to anything meaningful. We can therefore reproduce inputs, but we cannot reliably generate new images from the latent space. This is where variational autoencoders work much better than standard autoencoders.

#### Variational Autoencoders

The concept of variational autoencoders was introduced by Diederik P. Kingma and Max Welling in their paper *Auto-Encoding Variational Bayes*.

Variational autoencoders (VAEs) are good at generating new images from a latent vector. They can still reconstruct images similar to the data they were trained on, but they can also generate many variations of those images.

Moreover, the latent space of a variational autoencoder is continuous, which is what makes generating new images possible.
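One way to see what a continuous latent space buys us is to walk in a straight line between two latent vectors and decode every step. The sketch below assumes a decoder shaped like the one in this section's VAE (20-dimensional $z$, 28x28 output); here it is an untrained stand-in, so the frames are noise, but with a trained VAE the decoded frames would be expected to change gradually from one image to another.

```python
import torch
import torch.nn as nn

z_dim = 20
# Stand-in decoder with the same layer shapes as the VAE decoder in this section
# (untrained here, so the outputs are not meaningful digits)
decoder = nn.Sequential(
    nn.Linear(z_dim, 400), nn.ReLU(),
    nn.Linear(400, 784), nn.Sigmoid(),
)

torch.manual_seed(0)
z_start = torch.randn(z_dim)  # two points drawn from the prior N(0, I)
z_end = torch.randn(z_dim)

# Eight evenly spaced points on the line segment from z_start to z_end
alphas = torch.linspace(0, 1, steps=8).unsqueeze(1)  # shape (8, 1)
z_path = (1 - alphas) * z_start + alphas * z_end     # shape (8, z_dim)

with torch.no_grad():
    frames = decoder(z_path).view(-1, 1, 28, 28)     # one image per step
print(frames.shape)  # torch.Size([8, 1, 28, 28])
```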

Architecturally, VAEs resemble standard autoencoders: they also consist of an encoder and a decoder. The major difference is that the VAE encoder does not output a single latent vector. Instead, it outputs the parameters (a mean and a variance) of a distribution over latent vectors, and $z$ is sampled from that distribution. This is what makes the latent space continuous and places VAEs in the family of generative models.

#### Types of Variational Autoencoders

VAEs also allow us to control or condition the outputs of the decoder to some extent. This conditioning of the decoder’s actions leads to the concept of **Conditional Variational Autoencoders (CVAEs)**.
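A common way to condition a VAE is to feed a label into the network alongside $z$, for example by concatenating a one-hot encoding of the class to the latent vector (a full CVAE usually also gives the label to the encoder). The sketch below shows only the decoder side; the class name `ConditionalDecoder`, the concatenation approach, and the layer sizes are illustrative assumptions rather than the method used later in this section.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditionalDecoder(nn.Module):
    """Decoder that receives both a latent vector z and a class label y."""

    def __init__(self, z_dim=20, num_classes=10, h_dim=400, image_size=784):
        super().__init__()
        self.num_classes = num_classes
        self.fc1 = nn.Linear(z_dim + num_classes, h_dim)
        self.fc2 = nn.Linear(h_dim, image_size)

    def forward(self, z, y):
        # Condition the decoder by appending a one-hot label to z
        y_onehot = F.one_hot(y, num_classes=self.num_classes).float()
        h = F.relu(self.fc1(torch.cat([z, y_onehot], dim=1)))
        return torch.sigmoid(self.fc2(h))


decoder = ConditionalDecoder()
z = torch.randn(10, 20)
y = torch.arange(10)  # request one image per digit class 0..9
images = decoder(z, y).view(-1, 1, 28, 28)
print(images.shape)  # torch.Size([10, 1, 28, 28])
```

After training, keeping `y` fixed while varying `z` would be expected to produce different styles of the same digit.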

We can also build variational autoencoders whose latent vectors are more disentangled, meaning that individual latent dimensions tend to capture separate factors of variation in the input data. The degree of disentanglement can be controlled through a parameter called beta ($\beta$), which weights the KL divergence term of the loss described below. Such VAEs are called $\beta$-VAEs.
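In code, the $\beta$ weighting is a one-line change to the VAE loss: the KL term is multiplied by $\beta$, and $\beta = 1$ recovers the plain VAE. Values of $\beta$ greater than 1 strengthen the pull toward the prior, which is what encourages disentanglement. A minimal sketch, assuming the same summed BCE and closed-form KL used in this section's training code; the function name `beta_vae_loss` and the default `beta=4.0` are illustrative choices, not tuned values.

```python
import torch
import torch.nn.functional as F


def beta_vae_loss(x_reconst, x, mu, log_var, beta=4.0):
    """Reconstruction loss plus a beta-weighted KL term (beta=1 is the plain VAE loss)."""
    reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims and batch
    kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return reconst_loss + beta * kl_div


torch.manual_seed(0)
x = torch.rand(4, 784)
x_reconst = torch.rand(4, 784).clamp(1e-3, 1 - 1e-3)
mu, log_var = torch.randn(4, 20), torch.randn(4, 20)
print(beta_vae_loss(x_reconst, x, mu, log_var, beta=1.0).item())
print(beta_vae_loss(x_reconst, x, mu, log_var, beta=4.0).item())  # larger: KL counts 4x
```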

Here, we will take a look at the simple VAE.

#### Problem

In the case of a variational autoencoder, let $z$ be the latent vector, $x$ the input, $P(x)$ the probability distribution of the input data, $P(z)$ the probability distribution of the latent variable, $Q(z|x)$ the encoder, and $P(x|z)$ the distribution of generating data given the latent variable, i.e., the decoder. In the equations below, the encoder is written $q_{\phi}(z|x)$ with parameters $\phi$, and the decoder $p_{\theta}(x|z)$ with parameters $\theta$.

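The objective of a variational autoencoder consists of two parts: (a) the KL divergence, and (b) the reconstruction term. For a data point $x^{(i)}$, the quantity below (the evidence lower bound) is maximized; equivalently, its negative is minimized as the loss: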

$$\mathcal{L}(\theta, \phi; x^{(i)}) = -D_{KL}\left(q_{\phi}(z|x^{(i)}) \,\|\, p_{\theta}(z)\right) + \mathbb{E}_{z \sim q_{\phi}(z|x^{(i)})}\left[\log p_{\theta}(x^{(i)}|z)\right]$$

The first term, the KL divergence, measures how different two distributions are: $D_{KL}$ is zero when the two distributions are identical and grows as they move apart. What we want is to keep the learned distribution $q_{\phi}(z|x^{(i)})$ close to the prior $p_{\theta}(z)$, which is assumed to be a standard normal distribution ($\mu=0$, $\sigma=1$) to make the calculation easy. Without this term, nothing would stop the encoder from behaving like a standard autoencoder, shrinking $\sigma$ toward 0 so that each input maps to a single point, which gives the non-continuous latent space described earlier.
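A quick numeric check of how KL divergence behaves, using `torch.distributions` (one-dimensional Gaussians are chosen purely for illustration): it is exactly zero for identical distributions and positive otherwise, whether the mean or the spread differs from the standard normal prior.

```python
import torch
from torch.distributions import Normal, kl_divergence

prior = Normal(torch.tensor(0.0), torch.tensor(1.0))  # p(z) = N(0, 1)

same = Normal(torch.tensor(0.0), torch.tensor(1.0))
shifted = Normal(torch.tensor(2.0), torch.tensor(1.0))
wide = Normal(torch.tensor(0.0), torch.tensor(3.0))

print(kl_divergence(same, prior).item())     # 0.0: identical distributions
print(kl_divergence(shifted, prior).item())  # 2.0: mean moved away from 0
print(kl_divergence(wide, prior).item())     # > 0: standard deviation differs from 1
```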

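The KL divergence that we need to minimize has a closed form when both distributions are Gaussian:

$$D_{KL}(q_{\phi}(z|x^{(i)}) \,\|\, p_{\theta}(z)) = -\frac{1}{2}\sum_{j=1}^{J}\left(1+\log(\sigma_j^2)-\mu_j^2-\sigma_j^2\right)$$

Here, $J$ is the dimensionality of the latent vector, and $\sigma_j$ and $\mu_j$ are the standard deviation and mean of the $j$-th latent dimension produced by the encoder. The divergence is minimized (equal to zero) when $\mu_j = 0$ and $\sigma_j = 1$ for every $j$.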
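The training code in this section computes this closed form as `kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())`, where `log_var` is $\log \sigma^2$. As a sanity check, the sketch below compares that expression against PyTorch's built-in Gaussian KL on random encoder outputs (the batch size of 5 and $J = 20$ are arbitrary).

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(5, 20)       # encoder means for a batch of 5, J = 20
log_var = torch.randn(5, 20)  # encoder log-variances log(sigma^2)

# Closed-form KL(q || N(0, I)), summed over latent dims and batch
kl_closed = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# The same quantity via torch.distributions, with sigma = exp(log_var / 2)
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_library = kl_divergence(q, p).sum()

print(torch.allclose(kl_closed, kl_library, rtol=1e-4))  # True
```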

The second term is the expected log-likelihood of the data point under the decoder, i.e., how well the decoder reconstructs the real data point from a latent vector sampled from the encoder. In practice, we compare the real data points with the reconstructed ones and minimize the reconstruction loss. This term is commonly denoted $\mathcal{L}_R$, and for images with pixel values in $[0, 1]$ (such as MNIST) it is commonly computed with the Binary Cross-Entropy loss (BCELoss).
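Why BCE? If each decoder output is read as the probability of a Bernoulli pixel, summed BCE is exactly the negative log-likelihood $-\log p_{\theta}(x|z)$, so minimizing it maximizes the second term of the objective. The sketch below checks this numerically on random tensors (the decoder outputs are kept away from 0 and 1 only so that both logs stay finite).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(4, 784)                        # "real" pixel values in [0, 1]
x_reconst = torch.rand(4, 784) * 0.98 + 0.01  # decoder outputs in [0.01, 0.99]

# Summed BCE, as in the training code of this section
bce = F.binary_cross_entropy(x_reconst, x, reduction='sum')

# The same quantity written out as a negative Bernoulli log-likelihood
manual = -(x * torch.log(x_reconst) + (1 - x) * torch.log(1 - x_reconst)).sum()

print(torch.allclose(bce, manual, rtol=1e-4))  # True
```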

So the final VAE loss that we need to minimize, where $\mathcal{L}_{KL}$ denotes the KL divergence term above, is:

$$\mathcal{L}_{VAE} = \mathcal{L}_R + \mathcal{L}_{KL}$$

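Finally, to obtain a latent vector $z$ for the decoder, we sample from the encoder's distribution using the following formula:

$$z = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$

Here, $\odot$ denotes element-wise multiplication. This formula is called the **reparameterization trick**, and it is perhaps the most important part of a variational autoencoder. Sampling $z$ directly from $\mathcal{N}(\mu, \sigma^2)$ is not a differentiable operation, so gradients could not flow back into the encoder. The trick moves all the randomness into $\epsilon$, which does not depend on any network parameters, so $z$ becomes a deterministic, differentiable function of $\mu$ and $\sigma$ and the whole network can be trained with backpropagation. In the code, the encoder outputs $\log \sigma^2$ (`log_var`) rather than $\sigma$, and $\sigma$ is recovered as $\exp(\log \sigma^2 / 2)$.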
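The effect of the trick can be seen directly with autograd. In this sketch, `mu` and `log_var` play the role of encoder outputs and `(z ** 2).sum()` is an arbitrary toy loss chosen only so the gradients are easy to verify by hand. The last lines contrast `rsample()` (reparameterized, keeps gradients) with `sample()` (detached from the graph) in `torch.distributions`.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.zeros(3, requires_grad=True)
log_var = torch.zeros(3, requires_grad=True)  # log(sigma^2) = 0, so sigma = 1

# Reparameterized sample: all randomness lives in eps, which has no parameters
eps = torch.randn(3)
z = mu + eps * torch.exp(0.5 * log_var)  # same form as VAE.reparameterize

# Toy downstream loss, differentiated w.r.t. the encoder outputs
loss = (z ** 2).sum()
loss.backward()

print(mu.grad)       # equals 2 * z, because dz/dmu = 1
print(log_var.grad)  # equals z * eps here: dz/dlog_var = 0.5 * eps * sigma, sigma = 1

# torch.distributions offers both behaviours
dist = Normal(mu, torch.exp(0.5 * log_var))
print(dist.rsample().requires_grad)  # True: gradients can flow to mu and log_var
print(dist.sample().requires_grad)   # False: plain sampling is cut off from the graph
```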

```python
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create a directory for the generated images if it does not exist
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

# Hyper-parameters
image_size = 784  # 28 x 28 MNIST images, flattened
h_dim = 400       # hidden layer size
z_dim = 20        # latent vector size
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

# MNIST dataset (ToTensor scales pixels to [0, 1], as BCE requires)
dataset = torchvision.datasets.MNIST(root='../../data',
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# VAE model
class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super().__init__()
        self.fc1 = nn.Linear(image_size, h_dim)  # encoder hidden layer
        self.fc2 = nn.Linear(h_dim, z_dim)       # encoder head for mu
        self.fc3 = nn.Linear(h_dim, z_dim)       # encoder head for log(sigma^2)
        self.fc4 = nn.Linear(z_dim, h_dim)       # decoder hidden layer
        self.fc5 = nn.Linear(h_dim, image_size)  # decoder output layer

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)          # mu, log_var

    def reparameterize(self, mu, log_var):
        # z = mu + sigma * eps, with sigma = exp(log_var / 2) and eps ~ N(0, I)
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(h))        # pixel values in (0, 1)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var


model = VAE(image_size, h_dim, z_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Reconstruction loss (BCE) and KL divergence, both summed over the batch
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, len(data_loader),
                          reconst_loss.item(), kl_div.item()))

    with torch.no_grad():
        # Save images generated from random latent vectors z ~ N(0, I)
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1)))

        # Save the last batch side by side with its reconstructions
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))
```

Epoch[1/15], Step [10/469], Reconst Loss: 35997.9648, KL Div: 3528.9819
Epoch[1/15], Step [20/469], Reconst Loss: 29627.0117, KL Div: 1011.9289
Epoch[1/15], Step [30/469], Reconst Loss: 26666.1836, KL Div: 1153.9070
Epoch[1/15], Step [40/469], Reconst Loss: 25800.6367, KL Div: 715.3208
Epoch[1/15], Step [50/469], Reconst Loss: 25302.5469, KL Div: 741.5033
Epoch[1/15], Step [60/469], Reconst Loss: 25220.3906, KL Div: 931.3972
Epoch[1/15], Step [70/469], Reconst Loss: 24315.3418, KL Div: 1034.2733
Epoch[1/15], Step [80/469], Reconst Loss: 23148.3418, KL Div: 1108.7080
Epoch[1/15], Step [90/469], Reconst Loss: 23523.0215, KL Div: 1347.5452
Epoch[1/15], Step [100/469], Reconst Loss: 21146.2910, KL Div: 1451.1224
Epoch[1/15], Step [110/469], Reconst Loss: 21475.1055, KL Div: 1522.2439
Epoch[1/15], Step [120/469], Reconst Loss: 21172.3984, KL Div: 1563.0320
Epoch[1/15], Step [130/469], Recon

Epoch[3/15], Step [210/469], Reconst Loss: 11635.3604, KL Div: 3144.9187
Epoch[3/15], Step [220/469], Reconst Loss: 11702.9619, KL Div: 3065.3311
Epoch[3/15], Step [230/469], Reconst Loss: 11753.8125, KL Div: 3040.0229
Epoch[3/15], Step [240/469], Reconst Loss: 11721.6045, KL Div: 3043.5205
Epoch[3/15], Step [250/469], Reconst Loss: 12098.2383, KL Div: 3077.1523
Epoch[3/15], Step [260/469], Reconst Loss: 11703.6641, KL Div: 3160.4749
Epoch[3/15], Step [270/469], Reconst Loss: 11087.0938, KL Div: 2960.7227
Epoch[3/15], Step [280/469], Reconst Loss: 11588.9531, KL Div: 3056.7163
Epoch[3/15], Step [290/469], Reconst Loss: 11843.6113, KL Div: 3065.3877
Epoch[3/15], Step [300/469], Reconst Loss: 11823.7510, KL Div: 3017.7693
Epoch[3/15], Step [310/469], Reconst Loss: 11397.3057, KL Div: 3147.2952
Epoch[3/15], Step [320/469], Reconst Loss: 11551.8701, KL Div: 3165.4055
Epoch[3/15], Step [330/469], Reconst Loss: 11111.6357, KL Div: 3018.5281
Epoch[3/15], Step [340/469], Reconst Loss: 11756.52

Epoch[5/15], Step [420/469], Reconst Loss: 10776.7852, KL Div: 3163.8567
Epoch[5/15], Step [430/469], Reconst Loss: 10950.8770, KL Div: 3286.9265
Epoch[5/15], Step [440/469], Reconst Loss: 10385.1270, KL Div: 3066.8127
Epoch[5/15], Step [450/469], Reconst Loss: 11020.1416, KL Div: 3126.3542
Epoch[5/15], Step [460/469], Reconst Loss: 10962.5557, KL Div: 3231.2507
Epoch[6/15], Step [10/469], Reconst Loss: 10895.3213, KL Div: 3172.8396
Epoch[6/15], Step [20/469], Reconst Loss: 10927.1631, KL Div: 3287.9685
Epoch[6/15], Step [30/469], Reconst Loss: 10296.2803, KL Div: 3256.9756
Epoch[6/15], Step [40/469], Reconst Loss: 10378.1797, KL Div: 3157.3025
Epoch[6/15], Step [50/469], Reconst Loss: 10792.6729, KL Div: 3095.2041
Epoch[6/15], Step [60/469], Reconst Loss: 10180.3223, KL Div: 3139.9712
Epoch[6/15], Step [70/469], Reconst Loss: 11175.7939, KL Div: 3292.1858
Epoch[6/15], Step [80/469], Reconst Loss: 10681.0254, KL Div: 3243.1489
Epoch[6/15], Step [90/469], Reconst Loss: 10498.7510, KL Di

Epoch[8/15], Step [170/469], Reconst Loss: 10385.6758, KL Div: 3283.0322
Epoch[8/15], Step [180/469], Reconst Loss: 10606.3740, KL Div: 3188.4648
Epoch[8/15], Step [190/469], Reconst Loss: 10237.3799, KL Div: 3314.6216
Epoch[8/15], Step [200/469], Reconst Loss: 10470.8691, KL Div: 2997.6401
Epoch[8/15], Step [210/469], Reconst Loss: 10357.5771, KL Div: 3343.6084
Epoch[8/15], Step [220/469], Reconst Loss: 10124.4609, KL Div: 3174.0579
Epoch[8/15], Step [230/469], Reconst Loss: 10637.8008, KL Div: 3159.7700
Epoch[8/15], Step [240/469], Reconst Loss: 10922.1846, KL Div: 3227.6096
Epoch[8/15], Step [250/469], Reconst Loss: 10509.6768, KL Div: 3170.9534
Epoch[8/15], Step [260/469], Reconst Loss: 10759.7139, KL Div: 3322.8594
Epoch[8/15], Step [270/469], Reconst Loss: 10450.9434, KL Div: 3214.1401
Epoch[8/15], Step [280/469], Reconst Loss: 10713.8047, KL Div: 3272.4685
Epoch[8/15], Step [290/469], Reconst Loss: 10453.1494, KL Div: 3220.3774
Epoch[8/15], Step [300/469], Reconst Loss: 10825.65

Epoch[10/15], Step [380/469], Reconst Loss: 10744.4756, KL Div: 3303.2786
Epoch[10/15], Step [390/469], Reconst Loss: 10468.0859, KL Div: 3328.0149
Epoch[10/15], Step [400/469], Reconst Loss: 10767.8936, KL Div: 3247.0728
Epoch[10/15], Step [410/469], Reconst Loss: 10477.1436, KL Div: 3241.9668
Epoch[10/15], Step [420/469], Reconst Loss: 10174.9580, KL Div: 3249.3704
Epoch[10/15], Step [430/469], Reconst Loss: 10257.8779, KL Div: 3239.3967
Epoch[10/15], Step [440/469], Reconst Loss: 10237.5127, KL Div: 3261.8223
Epoch[10/15], Step [450/469], Reconst Loss: 10135.8848, KL Div: 3231.4771
Epoch[10/15], Step [460/469], Reconst Loss: 10284.6719, KL Div: 3190.0610
Epoch[11/15], Step [10/469], Reconst Loss: 10654.4561, KL Div: 3249.2561
Epoch[11/15], Step [20/469], Reconst Loss: 10707.9600, KL Div: 3328.6704
Epoch[11/15], Step [30/469], Reconst Loss: 10711.5723, KL Div: 3340.4468
Epoch[11/15], Step [40/469], Reconst Loss: 10160.0801, KL Div: 3383.0254
Epoch[11/15], Step [50/469], Reconst Loss:

Epoch[13/15], Step [120/469], Reconst Loss: 10556.6611, KL Div: 3233.7148
Epoch[13/15], Step [130/469], Reconst Loss: 10408.9248, KL Div: 3235.0742
Epoch[13/15], Step [140/469], Reconst Loss: 10075.7695, KL Div: 3299.1680
Epoch[13/15], Step [150/469], Reconst Loss: 9913.7002, KL Div: 3230.2495
Epoch[13/15], Step [160/469], Reconst Loss: 9940.9863, KL Div: 3219.9619
Epoch[13/15], Step [170/469], Reconst Loss: 10607.9502, KL Div: 3299.2991
Epoch[13/15], Step [180/469], Reconst Loss: 10189.4980, KL Div: 3187.9370
Epoch[13/15], Step [190/469], Reconst Loss: 10267.3662, KL Div: 3245.7290
Epoch[13/15], Step [200/469], Reconst Loss: 10150.7041, KL Div: 3249.1333
Epoch[13/15], Step [210/469], Reconst Loss: 10131.6240, KL Div: 3219.9338
Epoch[13/15], Step [220/469], Reconst Loss: 10338.5781, KL Div: 3309.5422
Epoch[13/15], Step [230/469], Reconst Loss: 9976.4453, KL Div: 3207.8540
Epoch[13/15], Step [240/469], Reconst Loss: 9879.0342, KL Div: 3226.7844
Epoch[13/15], Step [250/469], Reconst Loss

Epoch[15/15], Step [320/469], Reconst Loss: 10417.8223, KL Div: 3177.1838
Epoch[15/15], Step [330/469], Reconst Loss: 9803.8467, KL Div: 3215.8281
Epoch[15/15], Step [340/469], Reconst Loss: 9983.5078, KL Div: 3223.2058
Epoch[15/15], Step [350/469], Reconst Loss: 10500.2100, KL Div: 3342.8684
Epoch[15/15], Step [360/469], Reconst Loss: 9925.4199, KL Div: 3187.9319
Epoch[15/15], Step [370/469], Reconst Loss: 9745.0762, KL Div: 3246.3953
Epoch[15/15], Step [380/469], Reconst Loss: 10011.7715, KL Div: 3136.9636
Epoch[15/15], Step [390/469], Reconst Loss: 10092.8359, KL Div: 3155.0017
Epoch[15/15], Step [400/469], Reconst Loss: 10070.7119, KL Div: 3341.8015
Epoch[15/15], Step [410/469], Reconst Loss: 10214.5088, KL Div: 3241.5994
Epoch[15/15], Step [420/469], Reconst Loss: 10122.7783, KL Div: 3281.0703
Epoch[15/15], Step [430/469], Reconst Loss: 9878.8193, KL Div: 3196.2974
Epoch[15/15], Step [440/469], Reconst Loss: 10156.1475, KL Div: 3249.3672
Epoch[15/15], Step [450/469], Reconst Loss: