# Variational Autoencoders (VAE)

A VAE is made up of two sub-networks, the encoder and the decoder. Instead of compressing the input to a single fixed vector in the latent space, the encoder maps the input image to the two parameters of a probability distribution: the mean $\mu$, which represents the most likely position of the image in the latent space, and the standard deviation $\sigma$, which represents the size of the region around that position where the image could lie. In each forward pass, a random latent vector $z$ is sampled from the normal distribution defined by the encoder's $\mu$ and $\sigma$. The decoder then takes the sampled $z$ (a draw from a multivariate normal distribution) and tries to reconstruct the original input data from it as closely as possible.


Let $x$ be the vector representing the set of all observed variables. We assume it is a sample from an unknown underlying process, whose true (probability) distribution is $P^{*}(x)$. We choose a model $P(x)$ to approximate this process and the true distribution of the data $P^{*}(x)$:
$$x \sim P(x) \approx P^{*}(x)$$


Computing the conditional density of the latent variable given the observed variable, $P(z|x)$, is intractable. So we pick a distribution $Q(z|x)$ to approximate it, with the intention of making $Q(z|x)$ as close as possible to the true posterior $P(z|x)$. In this case $Q(z|x)$ is taken to be a Gaussian distribution.

$P(x)$: the marginal likelihood

$Q(z|x)$: the encoder

$P(x|z)$: the decoder


**Reparameterizing the Sampling layer**

Sampling $z$ directly from $z \sim \mathcal{N}(\mu,\sigma^2)$ creates a bottleneck: gradients cannot flow back through a random sampling operation, so we cannot backpropagate through the random variable $z$. To allow gradients to flow during backpropagation, instead of sampling $z$ directly from the encoder output we use the **reparameterization trick**, in which $z$ is computed as


$$z=\mu +\sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0,I)$$

Instead of predicting the variance directly, we have the encoder output the **log-var vector** $\log(\sigma^2)$:

$$\log(\sigma ^2) =2 \log(\sigma)$$
$$ \sigma=\exp\left(\frac{\log(\sigma^2)}{2}\right)$$
so $z$ becomes

$$z=\mu +\exp\left(\frac{\log(\sigma^2)}{2}\right) \odot \epsilon$$
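
The effect of the trick on gradient flow can be checked directly. The following is a minimal sketch, assuming stand-in tensors `mu` and `log_var` (hypothetical values, not taken from the model below) in place of real encoder outputs. Because the randomness is isolated in `eps`, autograd can differentiate `z` with respect to both parameters.

```python
import torch

torch.manual_seed(0)

# Stand-in encoder outputs for a batch of 4 inputs and a 3-dimensional
# latent space (hypothetical values; in a VAE these come from the encoder).
mu = torch.zeros(4, 3, requires_grad=True)
log_var = torch.zeros(4, 3, requires_grad=True)

# Reparameterization: all randomness lives in eps, which needs no gradient.
eps = torch.randn_like(mu)
sigma = torch.exp(0.5 * log_var)
z = mu + sigma * eps

# Any scalar function of z can now be differentiated w.r.t. mu and log_var.
z.sum().backward()

print(mu.grad)       # dz/dmu = 1 for every element
print(log_var.grad)  # dz/dlog_var = 0.5 * sigma * eps
```

Here $\partial z / \partial \mu = 1$ and $\partial z / \partial \log(\sigma^2) = \frac{1}{2}\sigma \odot \epsilon$, which is what the two printed gradients contain.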

Because $\epsilon$ is random noise, each input is decoded from slightly different points around its encoding. This forces points in the neighborhood of where you encode the input image to decode to something similar to the input image, which gives a meaningful, continuous latent distribution.

![](../images/gan1.JPG)

# LOSS
Using a loss function we want to know the amount of information lost when we go from $x$ to $z$ and then to $\acute{x}$. A VAE is trained via two losses: the **reconstruction loss**, which measures how effectively the decoder has learned to reconstruct $x$ given the latent representation $z$, and a **regularization term**, which keeps the encoder distribution close to the prior.

Loss = reconstruction loss + regularization term

$$P(x, z) = P(x|z) P(z) = P(z|x) P(x)$$

Applying Bayes' rule:

$$P(x)=\frac{P(x|z)P(z)}{P(z|x)}$$

where $P(x|z)$ is computed using the decoder and $P(z)$ is the prior distribution of the latent vector, assumed to be Gaussian.
Computing the true posterior $P(z|x)$ is intractable, so we train another network, the encoder $Q(z|x)$, to approximate it:

$$ Q(z|x) \approx P(z|x)$$
$$P(x)=\frac{P(x|z)P(z)}{P(z|x)}$$

Multiplying both the numerator and the denominator by $Q(z|x)$, we get
$$ P(x)=\frac{P(x|z)P(z)}{P(z|x)} = \frac{P(x|z)P(z) Q(z|x)}{P(z|x)Q(z|x)}$$

Taking the logarithm of both sides:

$$\log P(x)=\log \frac{P(x|z)P(z)}{P(z|x)} = \log \frac{P(x|z)P(z) Q(z|x)}{P(z|x)Q(z|x)}$$

Regrouping the factors and splitting the logarithm:

$$\log P(x) =\log P(x|z) - \log \frac{Q(z|x)}{P(z)} + \log \frac{Q(z|x)}{P(z|x)} $$

**NOTE:** $\log P(x)=E_{z \sim Q(z|x)}\left[\log P(x)\right]$, because $\log P(x)$ does not depend on $z$. Taking the expectation of both sides over $z \sim Q(z|x)$:

$$\log P(x) =E_{z \sim Q(z|x)} \left[ \log P(x|z) \right] - E_{z \sim Q(z|x)} \left[ \log \frac{Q(z|x)}{P(z)}\right] +  E_{z \sim Q(z|x)} \left[ \log \frac{Q(z|x)}{P(z|x)} \right]$$


$$\log P(x) =E_{z \sim Q(z|x)} \left[ \log P(x|z) \right] - D_{KL}(Q(z|x)||P(z)) + D_{KL}(Q(z|x)||P(z|x))$$


$E_{z \sim Q(z|x)} \left[ \log P(x|z) \right]$ is the data reconstruction term (its negative is the reconstruction error), and $D_{KL}(Q(z|x)||P(z))$ is the KL divergence between the distribution produced by the encoder network (the latent vector $z$ conditioned on the input $x$) and the prior over $z$, which we assumed to be Gaussian.


$D_{KL}(Q(z|x)||P(z|x))$ is the KL divergence between the encoder and the true posterior. Because $P(z|x)$ is intractable, we drop this term, and since $D_{KL} \ge 0$, dropping it gives a lower bound on the data log-likelihood:

$$\log P(x) \ge E_{z \sim Q(z|x)} \left[ \log P(x|z) \right] - D_{KL}(Q(z|x)||P(z))$$
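
The bound can be verified numerically on a toy model where every quantity has a closed form. The sketch below assumes a one-dimensional model that is not part of this tutorial: prior $z \sim \mathcal{N}(0,1)$ and likelihood $x|z \sim \mathcal{N}(z,1)$, so that $P(x)=\mathcal{N}(0,2)$ and the true posterior is $\mathcal{N}(x/2, 1/2)$. The functions `log_marginal` and `elbo` are illustrative names.

```python
import math

def log_marginal(x):
    # log N(x; 0, 2) for the toy model z ~ N(0, 1), x | z ~ N(z, 1)
    return -0.5 * math.log(2 * math.pi * 2.0) - x**2 / 4.0

def elbo(x, m, s2):
    # E_{z~Q}[log P(x|z)] with Q = N(m, s2) and P(x|z) = N(z, 1)
    recon = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    # D_KL(N(m, s2) || N(0, 1))
    kl = 0.5 * (s2 + m**2 - 1.0 - math.log(s2))
    return recon - kl

x = 1.5
print(log_marginal(x))
print(elbo(x, m=0.0, s2=1.0))    # a poor Q: strictly below log P(x)
print(elbo(x, m=x / 2, s2=0.5))  # Q equal to the true posterior: the bound is tight
```

The gap between $\log P(x)$ and the bound is exactly the dropped term $D_{KL}(Q(z|x)||P(z|x))$, so it vanishes only when $Q$ matches the true posterior.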


From [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback–Leibler_divergence):
A special case, and a common quantity in variational inference, is the relative entropy between a diagonal multivariate normal and a standard normal distribution (with zero mean and unit variance):

$$D_{KL}(\mathcal{N}(\mu,\sigma^2)||\mathcal{N}(0,I))=\frac{1}{2}\sum_{i=1}^{n}\left(\sigma_i^2+\mu_i^{2}-1-\ln(\sigma_i^{2})\right)$$

so in this tutorial

**Loss = Binary cross entropy + $\frac{1}{2}\sum_{i=1}^{n}\left(\sigma_i^2+\mu_i^{2}-1-\ln(\sigma_i^{2})\right)$**
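
As a sanity check, the closed-form expression $\frac{1}{2}\sum_i(\sigma_i^2+\mu_i^2-1-\ln\sigma_i^2)$ can be compared against PyTorch's own KL computation. This is a minimal sketch with randomly drawn stand-in values for `mu` and `log_var` (assumptions for illustration, not encoder outputs).

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Stand-in parameters of a 5-dimensional diagonal Gaussian
mu = torch.randn(5)
log_var = torch.randn(5)
sigma = torch.exp(0.5 * log_var)

# Closed form written in terms of the log-variance, as used in the loss
closed_form = 0.5 * torch.sum(sigma**2 + mu**2 - 1.0 - log_var)

# Reference: per-dimension KL between N(mu_i, sigma_i^2) and N(0, 1), summed
prior = Normal(torch.zeros(5), torch.ones(5))
reference = kl_divergence(Normal(mu, sigma), prior).sum()

print(closed_form.item(), reference.item())
```

Written with the log-variance, the same quantity is `-0.5 * sum(1 + log_var - mu**2 - exp(log_var))`, which is the form that appears in the training loop.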


In [3]:
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch
from torchsummary import summary
from torch.utils.data import DataLoader
import torch.nn as nn
plt.rcParams['image.cmap']='gray'

In [4]:

train_data = datasets.FashionMNIST( root = '../data',train = True,transform = ToTensor(), download = True)
test_data = datasets.FashionMNIST(root = '../data', train = False, transform = ToTensor())

In [5]:
len(train_data),len(test_data)

(60000, 10000)

In [6]:
train_data_loader=DataLoader(train_data,batch_size=100,shuffle=True)
test_data_loader=DataLoader(test_data,batch_size=100,shuffle=True)

In [7]:
for images, labels in train_data_loader:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break

Image batch dimensions: torch.Size([100, 1, 28, 28])
Image label dimensions: torch.Size([100])


In [8]:
latent_dim = 64

In [9]:
class Reshape(nn.Module):
    def __init__(self,*args):
        super().__init__()
        self.shape=args
    def forward(self, x):
        return x.view(self.shape)

In [14]:


class VAE(nn.Module):
    def __init__(self,latent_dim):
        super().__init__()
        self.latent_dim=latent_dim
        
        self.encoder = nn.Sequential(
                nn.Conv2d(1, 32, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(32, 64, stride=(2, 2), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(64, 72, stride=(2, 2), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Conv2d(72, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.Flatten(),
        ) 
        
        self.z_mean=nn.Linear(3136,self.latent_dim)
        self.z_log_var=nn.Linear(3136,self.latent_dim)
        
        self.decoder = nn.Sequential(
                torch.nn.Linear(self.latent_dim, 3136),
                Reshape(-1, 64, 7, 7),
                nn.ConvTranspose2d(64, 64, kernel_size=3,stride=1, padding=1),
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(64, 72, kernel_size=3,stride=2, padding=1),                
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(72, 32, kernel_size=3, stride=2,padding=0),                
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(32, 1, kernel_size=4,stride=1, padding=1), 
                nn.Sigmoid()
                )


        
    def Reparametrization(self,x):
        z_mean=self.z_mean(x)
        z_log_var=self.z_log_var(x)
        # sample eps ~ N(0, I) with the same shape, dtype and device as z_mean
        eps=torch.randn_like(z_mean)
        # z = mu + sigma * eps, with sigma = exp(0.5 * log(sigma^2))
        z=z_mean+torch.exp(0.5*z_log_var)*eps
        return z_mean,z_log_var,z

    def forward(self, x):
        x = self.encoder(x)
        z_mean,z_log_var,z=self.Reparametrization(x)
        output = self.decoder(z)
        return z_mean,z_log_var,output
vae=VAE(latent_dim=latent_dim)
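
The decoder's padding and kernel choices are what bring the 7×7 feature map back to 28×28. The sketch below retraces that arithmetic with the standard `nn.ConvTranspose2d` output-size formula for `dilation=1` and `output_padding=0`; the helper name `conv_transpose_out` is hypothetical.

```python
def conv_transpose_out(size, kernel, stride, padding):
    # Output size of nn.ConvTranspose2d with dilation=1 and output_padding=0
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) of the four ConvTranspose2d layers in the decoder
layers = [(3, 1, 1), (3, 2, 1), (3, 2, 0), (4, 1, 1)]

size = 7  # spatial size after Reshape(-1, 64, 7, 7)
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [7, 7, 13, 27, 28]
```

The third layer uses `padding=0` and the last one a 4×4 kernel precisely so that the intermediate sizes 13 and 27 end at 28.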

In [16]:
x=torch.randn((1,1,28,28))
summary(vae,x)

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 3136]                --
|    └─Conv2d: 2-1                       [-1, 32, 28, 28]          320
|    └─LeakyReLU: 2-2                    [-1, 32, 28, 28]          --
|    └─Conv2d: 2-3                       [-1, 64, 14, 14]          18,496
|    └─LeakyReLU: 2-4                    [-1, 64, 14, 14]          --
|    └─Conv2d: 2-5                       [-1, 72, 7, 7]            41,544
|    └─LeakyReLU: 2-6                    [-1, 72, 7, 7]            --
|    └─Conv2d: 2-7                       [-1, 64, 7, 7]            41,536
|    └─Flatten: 2-8                      [-1, 3136]                --
├─Linear: 1-2                            [-1, 64]                  200,768
├─Linear: 1-3                            [-1, 64]                  200,768
├─Sequential: 1-4                        [-1, 1, 28, 28]           --
|    └─Linear: 2-9                       [-1, 3136]           

for name ,params in vae.named_parameters():
    print(name, '\t\t' ,params.shape)

In [32]:
optimizer=torch.optim.Adam(vae.parameters())

In [None]:
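for epoch in range(30):
    vae.train()
    for x,_ in train_data_loader:
        z_mean,z_log_var, output=vae(x)
        # KL term: sum over latent dimensions, giving one value per sample
        kl_loss=-0.5*torch.sum(1+z_log_var-z_mean**2-torch.exp(z_log_var),dim=1)
        # reconstruction term: sum over pixels, averaged over the batch,
        # so that both terms are on the same per-sample scale
        bce=nn.functional.binary_cross_entropy(output,x.reshape(output.shape),reduction='sum')/x.shape[0]
        train_loss=bce+kl_loss.mean()
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
    # report the loss of the last batch in the epoch
    print("Epoch %d: train loss %.3f" % (epoch+1, train_loss.item()))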

In [None]:
vae.eval()
for x_test,_ in test_data_loader:
    z_mean_t,z_log_var_t,pred_con=vae(x_test)
    break

In [None]:
n = 10
image_width = 28

plt.figure(figsize=(20, 4))
orig_imgs = x_test[:n]
decoded_imgs = pred_con[:n]

for i in range(n):

    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.title("original")
    plt.imshow(orig_imgs[i].detach().view((image_width, image_width)) )
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    bx = plt.subplot(2, n, i + n + 1)
    plt.title("reconstructed")
    plt.imshow(decoded_imgs[i].detach().view((image_width, image_width))   )
    plt.gray()
    bx.get_xaxis().set_visible(False)
    bx.get_yaxis().set_visible(False)
plt.show()
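
Since the prior $P(z)$ is a standard Gaussian, the decoder alone can also turn random latent vectors into images. The following is a minimal sketch of that sampling step, assuming an untrained stand-in network `toy_decoder` (hypothetical) so that the snippet runs on its own. With the model from this tutorial, `vae.decoder` would take its place.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim = 64

# Stand-in decoder (hypothetical, untrained): maps a latent vector to 28*28
# pixel intensities in [0, 1]. A trained decoder would be used in practice.
toy_decoder = nn.Sequential(nn.Linear(latent_dim, 28 * 28), nn.Sigmoid())

toy_decoder.eval()
with torch.no_grad():
    # draw 8 latent vectors from the prior N(0, I)
    z = torch.randn(8, latent_dim)
    # decode and reshape to image batches of shape (N, 1, 28, 28)
    samples = toy_decoder(z).view(-1, 1, 28, 28)

print(samples.shape)
```

The resulting batch can be displayed with the same `plt.imshow` calls used for the reconstructions.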

<b>References </b>

- [Evidence lower bound](https://en.wikipedia.org/wiki/Evidence_lower_bound)
- [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback–Leibler_divergence)
- [Variational Bayesian methods](https://en.wikipedia.org/wiki/Variational_Bayesian_methods)
- [Tutorial on Variational Autoencoders](https://arxiv.org/abs/1606.05908)
- [An Introduction to Variational Autoencoders](https://arxiv.org/abs/1906.02691)
- [Variational AutoEncoder (keras.io)](https://keras.io/examples/generative/vae/)

