# HW 4: Variational Autoencoders

In this homework you will build a deep generative model of binary images (MNIST) using variational autoencoders and generative adversarial networks.
The original VAE paper can be found [here](https://arxiv.org/abs/1312.6114) and the GAN paper [here](https://arxiv.org/abs/1406.2661), and there are many excellent tutorials
online, e.g. [here](https://arxiv.org/abs/1606.05908) and [here](https://jaan.io/what-is-variational-autoencoder-vae-tutorial/).

**For this homework there will not be a Kaggle submission**

## Goals


1. Build a discrete deep generative model of binary digits (MNIST) using variational autoencoders
2. Examine the learned latent space with visualizations 
3. Build a continuous deep generative model using generative adversarial networks.
4. Additionally extend the above in any way, for example by:
    - using better encoder/decoders (e.g. CNN as the encoder, PixelCNN as the decoder. Description of PixelCNN 
    can be found [here](https://arxiv.org/abs/1601.06759))
    - using different variational families, e.g. with [normalizing flows](https://arxiv.org/abs/1505.05770), 
    [inverse autoregressive flows](https://arxiv.org/pdf/1606.04934.pdf), 
    [hierarchical models](https://arxiv.org/pdf/1602.02282.pdf)
    - comparing with stochastic variational inference (i.e. where your variational parameters are randomly initialized and
    then updated with gradient ascent on the ELBO)
    - or your own extension.
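
As a rough illustration of the stochastic variational inference comparison mentioned in the list above, the sketch below optimizes per-example variational parameters directly instead of predicting them with an encoder. Everything here is an assumption made for illustration: the tiny `decoder` is an untrained stand-in that is held fixed, `x` is synthetic data, and the step count and learning rate are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical decoder, held fixed so that only the variational parameters move
decoder = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 784))
for p in decoder.parameters():
    p.requires_grad_(False)

x = torch.bernoulli(torch.full((8, 784), 0.3))  # synthetic stand-in for binary images

# One (mu, logvar) pair per example: mu is randomly initialized, logvar starts at 0
mu = torch.randn(8, 2, requires_grad=True)
logvar = torch.zeros(8, 2, requires_grad=True)
opt = torch.optim.Adam([mu, logvar], lr=0.05)

def negative_elbo(mu, logvar):
    # single-sample reparameterized estimate of the ELBO, negated for minimization
    z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
    rec = -F.binary_cross_entropy_with_logits(decoder(z), x, reduction='none').sum(dim=1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return (kl - rec).mean()

for step in range(50):
    opt.zero_grad()
    loss = negative_elbo(mu, logvar)
    loss.backward()   # gradient ascent on the ELBO = descent on its negative
    opt.step()
```

In a real comparison the decoder would be the trained generative model, and the ELBO reached this way would be compared against the one obtained from the encoder's output.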

For your encoder/decoder, we suggest starting off with simple models (e.g. 2-layer MLP with ReLUs).
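
A minimal sketch of such a starting point, assuming a 784-dimensional flattened input, a hidden size of 256, and 2 latent dimensions (all three numbers are placeholders, and the class names `Encoder` and `Decoder` are hypothetical). The encoder returns a mean and a log-variance; the decoder returns logits, one per pixel.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """2-layer MLP mapping an image to the mean and log-variance of q(z)."""
    def __init__(self, in_dim=784, h_dim=256, latent_dim=2):
        super().__init__()
        self.hidden = nn.Sequential(nn.Linear(in_dim, h_dim), nn.ReLU())
        self.mu = nn.Linear(h_dim, latent_dim)
        self.logvar = nn.Linear(h_dim, latent_dim)

    def forward(self, x):
        h = self.hidden(x)
        return self.mu(h), self.logvar(h)

class Decoder(nn.Module):
    """2-layer MLP mapping a latent vector to per-pixel logits."""
    def __init__(self, latent_dim=2, h_dim=256, out_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, out_dim))

    def forward(self, z):
        return self.net(z)  # logits; apply a sigmoid to get pixel probabilities

encoder, decoder = Encoder(), Decoder()
x = torch.rand(4, 784)            # stand-in batch of flattened images
mu, logvar = encoder(x)
logits = decoder(mu)
```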

Consult the papers provided for hyperparameters, and the course notes for formal definitions.


## Setup

This notebook provides a working definition of the setup of the problem itself. Feel free to construct your models inline, or use an external setup (preferred) to build your system.

First, as always, let's download the data.

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# torchvision datasets return PIL images; ToTensor converts them to Tensors in [0, 1].
# Normalize then rescales them to the range [-1, 1].
# MNIST has a single channel, so mean and std are 1-tuples.
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.5,), std=(0.5,))
    ])

train_dataset = datasets.MNIST(root='./data/',
                            train=True, 
                            transform=transform,
                            download=True)
test_dataset = datasets.MNIST(root='./data/',
                           train=False, 
                           transform=transform)

In [2]:
print(len(train_dataset))
print(len(test_dataset))

60000
10000


By default MNIST gives grayscale values between [0,1]. Since we are modeling binary images, we have to turn these
into binary values, i.e. $\{0,1\}^{784}$. A standard way to do this is to interpret the grayscale values as
probabilities and sample Bernoulli random vectors based on these probabilities. Because the transform above rescales the images to [-1, 1], map them back to [0, 1] before sampling. (Note that you should not binarize the data for GANs.)
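
A minimal sketch of this Bernoulli binarization, assuming the pixel intensities already lie in [0, 1]. The helper name `binarize` is hypothetical, and `gray` is random stand-in data rather than MNIST; images that were rescaled to [-1, 1] would first need `(x + 1) / 2`.

```python
import torch

def binarize(x, generator=None):
    """Sample a {0, 1} image where each pixel is 1 with probability x."""
    # torch.bernoulli expects probabilities in [0, 1], so clamp defensively
    return torch.bernoulli(x.clamp(0, 1), generator=generator)

gen = torch.Generator().manual_seed(3435)
gray = torch.rand(4, 1, 28, 28)          # stand-in for grayscale images in [0, 1]
binary = binarize(gray, generator=gen)   # same shape, entries are 0.0 or 1.0
```

Sampling once up front gives a fixed binarized dataset, while calling `binarize` on every batch gives a fresh sample each epoch; which of the two you use affects how ELBO numbers compare across papers.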


In [3]:
torch.manual_seed(3435)
train_img = torch.stack([d[0] for d in train_dataset])
train_label = torch.LongTensor([d[1] for d in train_dataset])
test_img = torch.stack([d[0] for d in test_dataset])
test_label = torch.LongTensor([d[1] for d in test_dataset])
print(train_img.size(), train_label.size(), test_img.size(), test_label.size())

torch.Size([60000, 1, 28, 28]) torch.Size([60000]) torch.Size([10000, 1, 28, 28]) torch.Size([10000])


MNIST does not have an official validation dataset, so we will use the last 10000 training points as the validation set.

In [4]:
val_img = train_img[-10000:].clone()
val_label = train_label[-10000:].clone()
train_img = train_img[:-10000]
train_label = train_label[:-10000]

Now we use the dataloader to split into batches.

In [5]:
train = torch.utils.data.TensorDataset(train_img, train_label)
val = torch.utils.data.TensorDataset(val_img, val_label)
test = torch.utils.data.TensorDataset(test_img, test_label)

BATCH_SIZE = 100
train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE, shuffle=True)

In [6]:
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
from torchvision.utils import save_image

USE_CUDA = torch.cuda.is_available()

In [7]:
def to_var(x):
    if USE_CUDA:
        x = x.cuda()
    return Variable(x)

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

In [37]:
LATENT_SIZE = 64
H_DIM = 256
OUT_DIM = 784

# Discriminator
_D = nn.Sequential(
    nn.Linear(OUT_DIM, H_DIM),
    nn.LeakyReLU(0.2),
    nn.Linear(H_DIM, H_DIM),
    nn.LeakyReLU(0.2),
)

Dz = nn.Sequential(
    _D,
    nn.Linear(H_DIM, 1),
    nn.Sigmoid()
)

Dy = nn.Sequential(
    _D,
    nn.Linear(H_DIM, 10),
)

# Generator 
G = nn.Sequential(
    nn.Linear(LATENT_SIZE, H_DIM),
    nn.LeakyReLU(0.2),
    nn.Linear(H_DIM, H_DIM),
    nn.LeakyReLU(0.2),
    nn.Linear(H_DIM, OUT_DIM),
    nn.Tanh()
)

if USE_CUDA:
    # There is no module named D; move both discriminator heads (and the shared trunk _D)
    Dz.cuda()
    Dy.cuda()
    G.cuda()

In [38]:
criterion_z = nn.BCELoss() # binary classification
criterion_y = nn.CrossEntropyLoss() # digit classification
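# _D is shared by Dz and Dy, so list its parameters once, followed by the two heads
d_optimizer = torch.optim.Adam(
    list(_D.parameters()) + list(Dz[1].parameters()) + list(Dy[1].parameters()),
    lr=0.0003)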
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

In [41]:
# Start training
N_EPOCHS = 200

for epoch in range(N_EPOCHS):
    for i, (images, labels) in enumerate(train_loader):        
        batch_size = images.size(0)
        images = to_var(images.view(batch_size, -1))
        
        labels = to_var(labels)
        real_labels = to_var(torch.ones(batch_size))
        fake_labels = to_var(torch.zeros(batch_size))

        #============= discriminator =============#
        # k=1 (least expensive to train)     
        
        # real images
        outputs_z, outputs_y = Dz(images).squeeze(), Dy(images).squeeze()
        class_loss = criterion_y(outputs_y, labels) 
        d_loss_real = criterion_z(outputs_z, real_labels) + class_loss
        real_score = outputs_z
        
        # fake images
        z = to_var(torch.randn(batch_size, LATENT_SIZE))
        fake_images = G(z)
        outputs_z = Dz(fake_images).squeeze()
        d_loss_fake = criterion_z(outputs_z, fake_labels)
        fake_score = outputs_z
        
        d_loss = d_loss_real + d_loss_fake
        Dz.zero_grad()
        Dy.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        #=============== generator ===============#
        # fake images
        z = to_var(torch.randn(batch_size, LATENT_SIZE))
        fake_images = G(z)
        outputs = Dz(fake_images).squeeze()
        
        g_loss = criterion_z(outputs, real_labels)
        
        Dz.zero_grad()
        Dy.zero_grad()
        G.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 100 == 0:
            print('Epoch [%d/%d], Step[%d/%d], class_loss: %.4f, d_loss: %.4f, '
                  'g_loss: %.4f, D(x): %.2f, D(G(z)): %.2f' 
                  %(epoch, N_EPOCHS, i+1, 600, class_loss.item(), d_loss.item(), g_loss.item(),
                    real_score.mean().item(), fake_score.mean().item()))
        
    fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images.data), './data/fake_images-%d.png' %(epoch+1))

Epoch [0/200], Step[100/600], class_loss: 0.0010, d_loss: 0.7989, g_loss: 1.8391, D(x): 0.71, D(G(z)): 0.23
Epoch [0/200], Step[200/600], class_loss: 0.0175, d_loss: 0.7396, g_loss: 2.0045, D(x): 0.78, D(G(z)): 0.26
Epoch [0/200], Step[300/600], class_loss: 0.0010, d_loss: 0.8119, g_loss: 1.9115, D(x): 0.77, D(G(z)): 0.29
Epoch [0/200], Step[400/600], class_loss: 0.0016, d_loss: 0.7518, g_loss: 1.6935, D(x): 0.72, D(G(z)): 0.20
Epoch [0/200], Step[500/600], class_loss: 0.0299, d_loss: 0.9119, g_loss: 1.5215, D(x): 0.84, D(G(z)): 0.35
Epoch [1/200], Step[100/600], class_loss: 0.0042, d_loss: 0.6506, g_loss: 2.0551, D(x): 0.77, D(G(z)): 0.21
Epoch [1/200], Step[200/600], class_loss: 0.0008, d_loss: 0.7460, g_loss: 2.0849, D(x): 0.77, D(G(z)): 0.28
Epoch [1/200], Step[300/600], class_loss: 0.0012, d_loss: 0.8240, g_loss: 1.8011, D(x): 0.75, D(G(z)): 0.27
Epoch [1/200], Step[400/600], class_loss: 0.0123, d_loss: 0.8404, g_loss: 1.9831, D(x): 0.77, D(G(z)): 0.30
Epoch [1/200], Step[500/600]

KeyboardInterrupt: 

In [47]:
import os

batch_size = 1
N_INTERP_IMGS = 10
os.makedirs('./data2', exist_ok=True)  # save_image does not create missing directories

for i in range(N_INTERP_IMGS):
    # two random endpoints in latent space
    z1 = to_var(torch.randn(batch_size, LATENT_SIZE))
    z2 = to_var(torch.randn(batch_size, LATENT_SIZE))

    # points on the line between z2 (alpha = 0) and z1 (alpha = 1), stacked into one batch
    alphas = [0, 0.2, 0.4, 0.6, 0.8, 1]
    z = torch.cat([alpha * z1 + (1 - alpha) * z2 for alpha in alphas])
    fake_images = G(z)

    fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images.data), './data2/interp_fake_images-%d.png' % (i + 1))

Great, now we are ready to begin modeling. Performance-wise, you want to tune your hyperparameters based on the **evidence lower bound (ELBO)**. Recall that the ELBO is given by:

$$ELBO = \mathbb{E}_{q(\mathbf{z} ; \lambda)} [\log p(\mathbf{x} \,|\,\mathbf{z} ; \theta)] - \mathbb{KL}[q(\mathbf{z};\lambda) \, \Vert \, p(\mathbf{z})]$$

The variational parameters are given by running the encoder over the input, i.e. $\lambda = encoder(\mathbf{x};\phi)$. The generative model (i.e. decoder) is parameterized by $\theta$. Since we are working with binarized digits, $\log p(\mathbf{x} \, | \, \mathbf{z} ; \theta)$ is given by:

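$$ \log p(\mathbf{x} \, | \, \mathbf{z} ; \theta) = \sum_{i=1}^{784} \left[ x_i \log \sigma(h_i) + (1 - x_i) \log \left(1 - \sigma(h_i)\right) \right] $$

where $\mathbf{h}$ is the final layer of the generative model (i.e. a 28*28 = 784 dimensional vector), $x_i \in \{0, 1\}$ is the $i$-th pixel of $\mathbf{x}$, and $\sigma(\cdot)$ is the sigmoid function. 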
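
As a sketch, the Bernoulli log-likelihood $\sum_i [x_i \log \sigma(h_i) + (1 - x_i) \log(1 - \sigma(h_i))]$ can be computed from the decoder's logits with `F.binary_cross_entropy_with_logits`, which returns its negative per pixel. The function name `bernoulli_log_likelihood` is hypothetical, and `h` and `x` below are stand-ins rather than model outputs.

```python
import torch
import torch.nn.functional as F

def bernoulli_log_likelihood(h, x):
    """log p(x | z) per example, given decoder logits h and binary targets x."""
    # per-pixel negative log-likelihood, computed stably from the logits
    nll = F.binary_cross_entropy_with_logits(h, x, reduction='none')
    return -nll.sum(dim=1)

h = torch.zeros(2, 784)                          # logit 0 means probability 0.5 per pixel
x = torch.bernoulli(torch.full((2, 784), 0.5))   # stand-in binary images
log_px = bernoulli_log_likelihood(h, x)          # each entry equals 784 * log(0.5)
```

Working from logits rather than applying the sigmoid first avoids taking the log of values that round to 0 or 1.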

For the baseline model in this assignment you will be using a spherical normal prior, i.e. $p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I})$. The variational family will also be normal with diagonal covariance, i.e. $q(\mathbf{z} ; \lambda) = \mathcal{N}(\boldsymbol{\mu}, \mathrm{diag}(\boldsymbol{\sigma}^2))$, where $\lambda = (\boldsymbol{\mu}, \log \boldsymbol{\sigma}^2)$ (predicting the log-variance keeps $\boldsymbol{\sigma}^2$ positive). The KL-divergence between the variational posterior $q(\mathbf{z} ; \lambda)$ and the prior $p(\mathbf{z})$ has a closed-form analytic solution, which is available in the original VAE paper referenced above. (If you are using the torch distributions package, it will calculate this automatically; however, you will need PyTorch 0.4 or later.)
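
The sketch below writes out the closed-form KL term for a diagonal normal against $\mathcal{N}(\mathbf{0}, \mathbf{I})$, together with the reparameterized sample needed to backpropagate through $\mathbf{z}$, and cross-checks the KL against `torch.distributions`. It assumes the variance is stored as a log-variance `logvar`; the function names are hypothetical and the inputs are random stand-ins.

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_to_standard_normal(mu, logvar):
    """KL[N(mu, diag(exp(logvar))) || N(0, I)], summed over latent dimensions."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

def reparameterize(mu, logvar):
    """z = mu + sigma * eps with eps ~ N(0, I), so gradients reach mu and logvar."""
    eps = torch.randn_like(mu)
    return mu + torch.exp(0.5 * logvar) * eps

mu = torch.randn(5, 2)
logvar = torch.randn(5, 2)
kl_closed = kl_to_standard_normal(mu, logvar)

# Cross-check with the distributions package (Normal takes a standard deviation)
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_lib = kl_divergence(q, p).sum(dim=1)
```

A single-sample ELBO estimate is then the reconstruction log-likelihood evaluated at `reparameterize(mu, logvar)` minus this KL term.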

For GANs you should use the same data in its continuous form. Here use the same prior, but use a multilayer network to map to a continuous 28x28 output space. Then use a multilayer discriminator to classify. 

For both models you may also consider trying a deconvolutional network (as in DCGAN) to produce output from the latent variable.

## Visualizations

In addition to quantitative metrics (i.e. ELBO), we are also going to ask you to do some qualitative analysis via visualizations. Please include the following in your report:

1. Generate a bunch of digits from your generative model (sample $\mathbf{z} \sim p(\mathbf{z})$, then $\mathbf{x} \sim p (\mathbf{x} \, | \, \mathbf{z} ; \theta)$)
2. Sample two random latent vectors $\mathbf{z}_1, \mathbf{z}_2 \sim p(\mathbf{z})$, then sample from their interpolated values, i.e. $\mathbf{x} \sim p (\mathbf{x} \, | \, \alpha\mathbf{z}_1 + (1-\alpha)\mathbf{z}_2; \theta)$ for $\alpha \in \{0, 0.2, 0.4, 0.6, 0.8, 1.0 \}$.
3. Train a VAE with 2 latent dimensions. Make a scatter plot of the variational means, $\mu_1, \mu_2$, where the color
corresponds to the digit.
4. With the same model as in (3), pick a 2d grid around the origin (0,0), e.g. with
`np.meshgrid(np.linspace(-2, 2, 10), np.linspace(-2, 2, 10))`. For each point in the grid $(z_1, z_2)$, generate
$\mathbf{x}$ and show the corresponding digit in the 2d plot. For an example see [here](http://fastforwardlabs.github.io/blog-images/miriam/tableau.1493x693.png) (the right image)
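
One possible way to assemble the grid plot from item 4 is sketched below. The `decoder` here is an untrained, hypothetical stand-in for your trained 2-dimensional VAE decoder (assumed to output 784 logits), and the canvas shows pixel probabilities rather than sampled digits.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-in for a trained decoder with 2 latent dimensions
decoder = nn.Sequential(nn.Linear(2, 256), nn.ReLU(), nn.Linear(256, 784))

z1, z2 = np.meshgrid(np.linspace(-2, 2, 10), np.linspace(-2, 2, 10))
grid = np.stack([z1.ravel(), z2.ravel()], axis=1)   # (100, 2), one row per grid point
z = torch.from_numpy(grid).float()

with torch.no_grad():
    # probs[i, j] is the 28x28 image for z_2 = linspace[i], z_1 = linspace[j]
    probs = torch.sigmoid(decoder(z)).view(10, 10, 28, 28)

# Tile the 100 images into one 280x280 canvas: rows follow z_2, columns follow z_1
canvas = probs.permute(0, 2, 1, 3).reshape(10 * 28, 10 * 28)
```

The resulting `canvas` can be displayed with, for example, `plt.imshow(canvas, cmap='gray')` from matplotlib.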
