# Variational Autoencoder (VAE)

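The VAE has two parts: an encoder and a decoder. Instead of mapping each image to a single point in z-space, the encoder outputs the parameters of a multivariate normal distribution over z: a mean vector and a covariance matrix. All of the dimensions are treated as independent, so the covariance matrix is diagonal and the encoder only needs to predict one standard deviation per dimension.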

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Run on the first CUDA device when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device)

cuda:0


## Encoder and Decoder
For your encoder and decoder, you can use architectures similar to ones you've seen before, with some tweaks. For example, for the decoder, you can use the DCGAN generator architecture. For the encoder, you can use a classifier that you used before: instead of having it produce a single classification output (for example, whether something is a cat or not), you can have it produce 2 different outputs, one for the mean and one for the standard deviation. Each of those outputs has `z_dim` entries, one per dimension of the multivariate normal distribution over z.

In [3]:
class Encoder(nn.Module):
    """Convolutional encoder that maps images to the parameters of a Gaussian over z.

    Three strided convolutions reduce the input to ``2 * output_chan``
    channels. The flattened result is split at column ``z_dim``: the first
    part is returned as the mean, the rest is exponentiated and returned as
    the standard deviation.

    Args:
        im_chan: number of channels in the input images.
        output_chan: dimensionality of the latent space (stored as ``z_dim``).
        hidden_dim: channel width of the first convolution; the second
            convolution uses twice as many channels.
    """

    def __init__(self, im_chan=1, output_chan=32, hidden_dim=16):
        super().__init__()
        self.z_dim = output_chan
        stem = self._block(im_chan, hidden_dim)
        middle = self._block(hidden_dim, hidden_dim * 2)
        # Twice z_dim channels: one half for the mean, one half for the spread.
        head = self._block(hidden_dim * 2, output_chan * 2, final_layer=True)
        self.disc = nn.Sequential(stem, middle, head)

    def _block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        """Build one stage: a bare convolution when ``final_layer`` is set,
        otherwise convolution + batch norm + LeakyReLU(0.2)."""
        layers = [nn.Conv2d(input_channels, output_channels, kernel_size, stride)]
        if not final_layer:
            layers.append(nn.BatchNorm2d(output_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def forward(self, images):
        """Encode ``images`` and return the tuple ``(mean, stddev)``.

        Both tensors have one row per input image. With 28x28 inputs the
        convolution stack ends at a 1x1 spatial size, so each has ``z_dim``
        columns.
        """
        features = self.disc(images)
        flat = features.view(features.size(0), -1)
        mean = flat[:, :self.z_dim]
        # The raw second half acts as a log standard deviation: it is
        # exponentiated directly (no factor of 1/2), so the network output
        # needs no positivity constraint of its own.
        stddev = flat[:, self.z_dim:].exp()
        return mean, stddev

In [4]:
# Single-channel images and a 32-dimensional latent space
im_chan = 1
z_dim = 32

# Build the encoder, then move it to the selected device
encode = Encoder(im_chan=im_chan, output_chan=z_dim)
encode = encode.to(device)

# Print the model
print(encode)

Encoder(
  (disc): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    )
  )
)


In [7]:
class Decoder(nn.Module):
    """Transposed-convolution decoder that turns latent vectors into images.

    Args:
        z_dim: dimensionality of the latent vectors.
        im_chan: number of channels in the produced images.
        hidden_dim: base channel width; the three hidden stages use 4x, 2x
            and 1x this value.
    """

    def __init__(self, z_dim=32, im_chan=1, hidden_dim=64):
        super().__init__()
        self.z_dim = z_dim
        stages = [
            self._block(z_dim, hidden_dim * 4),
            self._block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self._block(hidden_dim * 2, hidden_dim),
            # The last stage ends in a sigmoid, so outputs lie in [0, 1].
            self._block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        ]
        self.gen = nn.Sequential(*stages)

    def _block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        """Build one stage: transposed conv + sigmoid when ``final_layer`` is
        set, otherwise transposed conv + batch norm + ReLU."""
        upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
        if final_layer:
            return nn.Sequential(upsample, nn.Sigmoid())
        return nn.Sequential(
            upsample,
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, noise):
        """Decode a batch of latent vectors into images.

        ``noise`` is reshaped to ``(batch, z_dim, 1, 1)`` before it enters the
        transposed-convolution stack.
        """
        batch = noise.size(0)
        latent_map = noise.view(batch, self.z_dim, 1, 1)
        return self.gen(latent_map)

In [8]:
# Stand-alone decoder using the latent size and channel count set above
decoder = Decoder(z_dim=z_dim, im_chan=im_chan)

# Print the model
print(decoder)

Decoder(
  (gen): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(32, 256, kernel_size=(3, 3), stride=(2, 2))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2))
      (1): Sigmoid()
    )
  )
)


## Reparameterization Trick and VAE

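You can't backpropagate through a random sample, because the sampling step is not a differentiable function of the distribution's parameters. The reparameterization trick works around this by writing the sample as a deterministic function of the mean and standard deviation plus noise that does not depend on them. In `PyTorch`, you can do this by sampling with the `rsample` method, as in `Normal(mean, stddev).rsample()`. This is equivalent to `torch.randn(z_dim) * stddev + mean`.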

In [9]:
class VAE(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``.

    Args:
        z_dim: dimensionality of the latent space.
        im_chan: number of image channels, used for both input and output.
        hidden_dim: base channel width of the decoder. Its default matches
            the ``Decoder`` default, so a default-constructed model is
            unchanged. The encoder keeps its own default width.
    """

    def __init__(self, z_dim=32, im_chan=1, hidden_dim=64):
        super(VAE, self).__init__()
        self.z_dim = z_dim
        self.encode = Encoder(im_chan, z_dim)
        # Forward hidden_dim instead of silently ignoring it.
        self.decode = Decoder(z_dim, im_chan, hidden_dim=hidden_dim)

    def forward(self, images):
        """Encode ``images``, sample a latent per image and decode it.

        Returns:
            A tuple ``(decoding, q_dist)``: the decoder output for the sampled
            latents, and the ``Normal`` distribution produced by the encoder
            (needed by the caller for the KL term).
        """
        q_mean, q_stddev = self.encode(images)
        q_dist = Normal(q_mean, q_stddev)
        # rsample() draws via mean + stddev * noise, so gradients reach the
        # encoder outputs.
        z_sample = q_dist.rsample()
        decoding = self.decode(z_sample)
        return decoding, q_dist

In [10]:
# Build the full model, then move its parameters to the selected device
vae = VAE(z_dim=32, im_chan=1)
vae = vae.to(device)

print(vae)

VAE(
  (encode): Encoder(
    (disc): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (1): Sequential(
        (0): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (2): Sequential(
        (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      )
    )
  )
  (decode): Decoder(
    (gen): Sequential(
      (0): Sequential(
        (0): ConvTranspose2d(32, 256, kernel_size=(3, 3), stride=(2, 2))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), str

In [11]:
#!pip install torchsummary

from torchsummary import summary



In [12]:
# Layer-by-layer report for a single-channel 28x28 input
summary(vae, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 13, 13]             272
       BatchNorm2d-2           [-1, 16, 13, 13]              32
         LeakyReLU-3           [-1, 16, 13, 13]               0
            Conv2d-4             [-1, 32, 5, 5]           8,224
       BatchNorm2d-5             [-1, 32, 5, 5]              64
         LeakyReLU-6             [-1, 32, 5, 5]               0
            Conv2d-7             [-1, 64, 1, 1]          32,832
           Encoder-8       [[-1, 32], [-1, 32]]               0
   ConvTranspose2d-9            [-1, 256, 3, 3]          73,984
      BatchNorm2d-10            [-1, 256, 3, 3]             512
             ReLU-11            [-1, 256, 3, 3]               0
  ConvTranspose2d-12            [-1, 128, 6, 6]         524,416
      BatchNorm2d-13            [-1, 128, 6, 6]             256
             ReLU-14            [-1, 12

## Reconstruction Loss
Reconstruction loss refers to the distance between the real input image (that you put into the encoder) and the generated image (that comes out of the decoder). Explicitly, the reconstruction loss term is $\mathbb{E}[\log p(x|z)]$, the expected log probability of the true image given the latent value.

For MNIST, you can treat each grayscale prediction as a binary random variable (also known as a Bernoulli distribution) with the value between 0 and 1 of a pixel corresponding to the output brightness, so you can use the binary cross entropy loss between the real input image and the generated image in order to represent the reconstruction loss term. 

In [13]:
# Binary cross-entropy, summed (not averaged) over every element it is given
reconstruction_loss = nn.BCELoss(reduction="sum")

## KL Divergence

KL divergence, mentioned in a video (on Inception Score) last week, allows you to evaluate how different one probability distribution is from another. Here you care about two distributions and how different they are: (1) the learned latent distribution $q(z|x)$ that your encoder is trying to model and (2) your prior on the latent space $p(z)$, which you want the learned distribution to be as close to as possible. If both distributions are normal, you can calculate the KL divergence, $D_{KL}(q(z|x) \,\|\, p(z))$, with a simple closed-form formula.

The latent prior $p(z)$ is a simple distribution for the latent space with a mean of zero and a standard deviation of one in each dimension.

In [14]:
from torch.distributions.kl import kl_divergence

def kl_divergence_loss(q_dist):
    """Return the KL divergence from ``q_dist`` to a standard normal prior.

    The prior has zero mean and unit standard deviation, with the same shape
    as the parameters of ``q_dist``. The element-wise divergence is summed
    over the last dimension.
    """
    loc, scale = q_dist.mean, q_dist.stddev
    prior = Normal(torch.zeros_like(loc), torch.ones_like(scale))
    per_dimension = kl_divergence(q_dist, prior)
    return per_dimension.sum(-1)

## Optimizer

In [15]:
# Learning rate shared by all VAE parameters
lr = 0.002
# One Adam optimizer over everything returned by vae.parameters()
vae_opt = optim.Adam(params=vae.parameters(), lr=lr)

## Dataset

In [16]:
# Mini-batch size and number of image channels
batch_size = 1024
im_chan = 1

# The only preprocessing step: convert each image to a tensor
transform = transforms.Compose([transforms.ToTensor()])

# https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST
dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    transform=transform,
    download=True,
)

# Reshuffled on every pass over the data
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

## Weight Initialization

In [17]:
def weights_init(m):
    """Re-initialise ``m`` in place if it is a convolution or batch-norm layer.

    The layer type is detected from the class name: names containing "Conv"
    get weights drawn from N(0, 0.02); names containing "BatchNorm" get
    weights drawn from N(1, 0.02) and a zero bias. Any other module is left
    untouched. Only ``m`` itself is examined, not its children.
    """
    layer_kind = type(m).__name__
    if "Conv" in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    elif "BatchNorm" in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [18]:
## Weight Initialization
# weights_init only looks at the class name of the module it is handed.
# Called on the VAE container directly it matches neither the "Conv" nor the
# "BatchNorm" branch and leaves every layer untouched. Module.apply calls it
# on every submodule recursively, so the conv and batch-norm layers are
# actually re-initialised.
vae.apply(weights_init)

## Preparation of Training 

In [19]:
# Two TensorBoard writers, each logging under its own directory
writer_real = SummaryWriter("logs_vae/real_images")
writer_fake = SummaryWriter("logs_vae/recon_images")
# Counter passed as global_step for every logged entry
step = 0

# Put the model in training mode
vae.train()

VAE(
  (encode): Encoder(
    (disc): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2))
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (1): Sequential(
        (0): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (2): Sequential(
        (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      )
    )
  )
  (decode): Decoder(
    (gen): Sequential(
      (0): Sequential(
        (0): ConvTranspose2d(32, 256, kernel_size=(3, 3), stride=(2, 2))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Sequential(
        (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), str

In [21]:
# One batch of standard-normal latent vectors, drawn once and kept fixed
fixed_latent = torch.randn(batch_size, z_dim, 1, 1)
fixed_latent = fixed_latent.to(device)
print(fixed_latent.shape)

torch.Size([1024, 32, 1, 1])


## Training

In [22]:
print("Starting Training Loop...")

NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(device)
        recon_images, encoding = vae(real_images)

        ############################
        #  Update VAE: minimize reconstruction_loss + kl_divergence_loss
        ###########################
        # Both terms are summed over the whole batch (not averaged), so the
        # loss value scales with the number of images in the batch.
        loss = reconstruction_loss(recon_images, real_images) + kl_divergence_loss(encoding).sum()
        vae_opt.zero_grad()  # Clear out the gradients
        loss.backward()
        vae_opt.step()

        ############################

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            loss_value = loss.item()
            # Adjacent string literals instead of a backslash continuation
            # inside the string: the continuation embedded the next line's
            # indentation in the printed message.
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} "
                f"Loss VAE: {loss_value:.4f}"
            )

            with torch.no_grad():
                # These are decodings of the fixed prior samples, not
                # reconstructions of real_images, hence the separate name.
                # NOTE(review): vae is still in train mode here, so the
                # decoder's BatchNorm layers presumably use batch statistics
                # and update their running averages on this pass -- confirm
                # this is intended.
                sampled_images = vae.decode(fixed_latent)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(
                    real_images[:32], normalize=True
                )
                img_grid_fake = torchvision.utils.make_grid(
                    sampled_images[:32], normalize=True
                )

                writer_real.add_image("real_images", img_grid_real, global_step=step)
                writer_fake.add_image("recon_images", img_grid_fake, global_step=step)

            writer_real.add_scalar("Loss_VAE", loss_value, global_step=step)
            step += 1

# Push any buffered events to disk so TensorBoard shows the complete run.
# The writers stay open and can still be used afterwards.
writer_real.flush()
writer_fake.flush()

Starting Training Loop...
Epoch [0/10] Batch 0/59                   Loss VAE: 792486.0000
Epoch [1/10] Batch 0/59                   Loss VAE: 161875.0469
Epoch [2/10] Batch 0/59                   Loss VAE: 128755.1953
Epoch [3/10] Batch 0/59                   Loss VAE: 118238.9766
Epoch [4/10] Batch 0/59                   Loss VAE: 113607.9297
Epoch [5/10] Batch 0/59                   Loss VAE: 110856.3203
Epoch [6/10] Batch 0/59                   Loss VAE: 109380.6875
Epoch [7/10] Batch 0/59                   Loss VAE: 108001.7188
Epoch [8/10] Batch 0/59                   Loss VAE: 108881.4453
Epoch [9/10] Batch 0/59                   Loss VAE: 107004.8438


## References:
1. https://deepgenerativemodels.github.io/notes/vae/
2. http://www.cs.toronto.edu/~rgrosse/courses/csc421_2019/slides/lec17.pdf 
3. Build Better Generative Adversarial Networks, Coursera.org
 