# Variational Autoencoders in PyTorch and Pyro

This notebook follows the Pyro introductory tutorial on building variational autoencoders.

#### Setup: library imports, smoke test, data loaders

In [1]:
import os

import numpy as np
import torch
from pyro.contrib.examples.util import MNIST
import torch.nn as nn
import torchvision.transforms as transforms

import pyro
import pyro.distributions as dist
import pyro.contrib.examples.util  # patches torchvision
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

In [2]:
# Fail fast when the installed Pyro release is not the 1.9.1 series this
# notebook was written against.
assert pyro.__version__.startswith('1.9.1')
# `dist` is the `pyro.distributions` alias imported above; switch off argument
# validation for distributions, then fix the random seed.
dist.enable_validation(False)
pyro.set_rng_seed(0)
# Smoke-test mode (a shortened run) is active whenever a `CI` environment
# variable is present.
smoke_test = os.environ.get('CI') is not None

We will be experimenting on the MNIST dataset, which PyTorch has built-in dataloaders for.

In [3]:
# build the MNIST train / test data loaders
def setup_data_loaders(batch_size=128, use_cuda=False):
    """Create the MNIST training and test ``DataLoader`` objects.

    Args:
        batch_size: number of images per mini-batch, used for both loaders.
        use_cuda: forwarded to the loaders as ``pin_memory``.

    Returns:
        A ``(train_loader, test_loader)`` tuple; only the training loader
        shuffles its data.
    """
    data_root = './data'
    to_tensor = transforms.ToTensor()
    # only the training split is constructed with download=True
    train_set = MNIST(root=data_root, train=True, transform=to_tensor, download=True)
    test_set = MNIST(root=data_root, train=False, transform=to_tensor)

    # options shared by both loaders
    shared_options = dict(batch_size=batch_size, num_workers=1, pin_memory=use_cuda)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set, shuffle=True, **shared_options)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set, shuffle=False, **shared_options)
    return train_loader, test_loader

## Encoder and Decoder modules in PyTorch

The decoder module takes in the latent layer z and applies a neural network to decode it into the original space x. Here we will use a simple network with two linear layers.

The encoder module takes the input data and encodes it into the latent space. In a VAE, this is the variational family q(z|x), which will approximate the posterior p(z|x).

In [5]:
class Decoder(nn.Module):
    """Two-layer network mapping a latent code z to Bernoulli means over pixels.

    Args:
        z_dim: size of the latent code.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # latent -> hidden -> 784 outputs (one per pixel of a flattened 28x28 image)
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)
        # non-linearities: softplus on the hidden layer, sigmoid on the output
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        """Decode ``z`` into per-pixel values in (0, 1).

        The result has 784 entries along its last dimension and is used as
        the parameter of the output Bernoulli.
        """
        hidden_units = self.softplus(self.fc1(z))
        pixel_logits = self.fc21(hidden_units)
        return self.sigmoid(pixel_logits)


In [15]:
class Encoder(nn.Module):
    """Network mapping an image batch to the parameters of q(z|x).

    Args:
        z_dim: size of the latent code.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # one shared hidden layer followed by two heads: fc21 produces the
        # location (mean) and fc22 the log of the scale of the Normal over z
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        # hidden-layer non-linearity
        self.softplus = nn.Softplus()

    def forward(self, x):
        """Return ``(z_loc, z_scale)``, each with one row per image and z_dim columns."""
        # flatten every image in the mini-batch to a 784-vector
        flat_pixels = x.reshape(-1, 784)
        hidden_units = self.softplus(self.fc1(flat_pixels))
        z_loc = self.fc21(hidden_units)
        # exponentiating the second head keeps the scale strictly positive
        log_scale = self.fc22(hidden_units)
        z_scale = torch.exp(log_scale)
        return z_loc, z_scale

## VAE Model and Guide

Now we define the probabilistic model p(x|z)p(z) for our VAE. We then define the guide (variational family). The likelihood p(x|z) of the model is parameterized by the decoder (generating x from z), and the guide q(z|x) is parameterized by the encoder (inferring z from x). We define our VAE PyTorch module with these components.

It also contains a helper function for reconstructing images, i.e. taking an input image and then running it through the encoder and then decoder to reconstruct the image. This will be used for training and testing.


In [12]:
# define our VAE class with the model and guide
class VAE(nn.Module):
    """Variational autoencoder bundling the Pyro model, the guide and a reconstruction helper.

    Args:
        z_dim: size of the latent code (50 by default).
        hidden_dim: width of the hidden layer in encoder and decoder (400 by default).
        use_cuda: when true, the module's parameters are moved to the GPU.
    """

    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        super().__init__()
        # create the encoder and decoder networks, using the classes defined above
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)

        if use_cuda:
            # calling cuda here puts the encoder and decoder parameters into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    # define the model p(x|z)p(z)
    def model(self, x):
        """Generative model: sample z from the prior and score x under the decoder.

        Args:
            x: mini-batch of images; it is reshaped to rows of 784 pixels for scoring.

        Returns:
            The decoded Bernoulli means ``loc_img``, so that callers such as
            ``plot_vae_samples`` can index and display them.
        """
        # register PyTorch module 'decoder' with Pyro
        pyro.module("decoder", self.decoder)  # the first arg is what we name the module
        with pyro.plate("data", x.shape[0]):
            # pyro.plate designates independence among the samples in x
            # setup hyperparameters for prior p(z)
            # new_zeros / new_ones create the tensors on the same device as x
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            # sample from prior p(z) (value will be sampled by guide when computing the ELBO)
            # to_event(1) treats the last dimension of z as a single event, i.e. an MVN
            #   with diagonal covariance; this avoids needing another pyro.plate call.
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z, using our decoder, i.e. p(x|z)
            loc_img = self.decoder(z)
            # score against the actual images (how close is our reconstruction to the original?)
            pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))
            # hand the decoded means back; previously the method returned None,
            # which made `vae.model(x)[0]` in the plotting helpers fail
            return loc_img

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        """Variational distribution: sample z from the Normal produced by the encoder."""
        # register PyTorch module 'encoder' with Pyro (again, "encoder" is the name we assign)
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    # define a helper function for reconstructing images
    def reconstruct_img(self, x):
        """Encode ``x``, draw one latent sample, and return the decoded image means."""
        # encode image x
        z_loc, z_scale = self.encoder(x)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode the image
        loc_img = self.decoder(z)
        return loc_img


## Inference

Now we create an instance of our VAE module, define an optimizer and our training and test functions, and run the model.

In [8]:
# run options
LEARNING_RATE = 1e-3  # step size handed to the Adam optimizer
USE_CUDA = False      # set to True to train on the GPU
# a single epoch is enough for a smoke test; otherwise train for five
NUM_EPOCHS = 5 if not smoke_test else 1
TEST_FREQUENCY = 5    # evaluate on the test set every this many epochs

In [9]:
def train(svi, train_loader, use_cuda=False):
    """Run one training epoch and return the average loss per data point.

    Args:
        svi: object whose ``step(x)`` takes a gradient step and returns the loss.
        train_loader: iterable of ``(x, label)`` mini-batches with a ``dataset``
            attribute supporting ``len``.
        use_cuda: when true, each mini-batch is moved with ``x.cuda()`` first.

    Returns:
        The summed loss of the epoch divided by ``len(train_loader.dataset)``.
    """
    running_loss = 0.
    for images, _ in train_loader:
        # move the mini-batch into CUDA memory when training on the GPU
        batch = images.cuda() if use_cuda else images
        # one ELBO gradient step; add its loss to the running total
        running_loss = running_loss + svi.step(batch)

    # normalize by the number of data points, not the number of batches
    return running_loss / len(train_loader.dataset)


In [10]:
# mirrors the training loop, but calls svi.evaluate_loss instead of svi.step
#   because no gradient should be computed
def evaluate(svi, test_loader, use_cuda=False):
    """Return the average loss per data point over the whole test set.

    Args:
        svi: object whose ``evaluate_loss(x)`` returns the loss of a mini-batch.
        test_loader: iterable of ``(x, label)`` mini-batches with a ``dataset``
            attribute supporting ``len``.
        use_cuda: when true, each mini-batch is moved with ``x.cuda()`` first.

    Returns:
        The summed loss divided by ``len(test_loader.dataset)``.
    """
    running_loss = 0.
    for images, _ in test_loader:
        # move the mini-batch into CUDA memory when evaluating on the GPU
        batch = images.cuda() if use_cuda else images
        # ELBO estimate for this mini-batch, added to the running total
        running_loss = running_loss + svi.evaluate_loss(batch)
    return running_loss / len(test_loader.dataset)

In [26]:
# initialize data loaders
train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA)

# clear param store
pyro.clear_param_store()

# setup the VAE
vae = VAE(use_cuda=USE_CUDA)

# setup the optimizer
adam_args = {"lr": LEARNING_RATE}
optimizer = Adam(adam_args)

# setup the inference algorithm
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

# ELBO history (negated average loss), in the order the values were computed
train_elbo = []
test_elbo = []
# Average loss keyed by epoch. plot_llk reads its arguments with .keys() /
# .values() and negates the values itself, so these dicts (not the lists
# above) are what can be passed to it. Keying by epoch also records which
# epoch each test evaluation belongs to, which the list form cannot.
train_loss_by_epoch = {}
test_loss_by_epoch = {}
# training loop
for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    train_elbo.append(-total_epoch_loss_train)
    train_loss_by_epoch[epoch] = total_epoch_loss_train
    print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

    if epoch % TEST_FREQUENCY == 0:
        # report test diagnostics
        total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
        test_elbo.append(-total_epoch_loss_test)
        test_loss_by_epoch[epoch] = total_epoch_loss_test
        print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))


[epoch 000]  average training loss: 193.8537
[epoch 000] average test loss: 158.7788
[epoch 001]  average training loss: 148.1595
[epoch 002]  average training loss: 133.7032
[epoch 003]  average training loss: 125.0177
[epoch 004]  average training loss: 119.5332


## Visualization

Here we plot some reconstructed images from our model, and also plot a t-SNE embedding of our latent space, with colors corresponding to the different classes. The latter helps us see whether the embedded space separates the classes, i.e. whether it has captured latent structure in the data, as we had hoped.

We will use some code from: https://github.com/pyro-ppl/pyro/blob/daea9a65ac6aefabce110e0f3a79c483138c3d08/examples/vae/utils/vae_plots.py

In [27]:
# plot code from pyro github
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
from pathlib import Path
import pandas as pd
import seaborn as sns
import visdom


def plot_conditional_samples_ssvae(ssvae, visdom_session):
    """
    Send class-conditional samples to visdom.

    For each of the 10 classes, ``ssvae.model`` is called 100 times with an
    all-zero (1, 784) input and that class's one-hot (1, 10) label. The first
    row of each result is viewed as a (1, 28, 28) array, and the 100 arrays
    are passed to ``visdom_session.images`` together with the positional
    arguments 10 and 2 (presumably visdom's ``nrow`` and ``padding`` --
    confirm against the installed visdom version).
    """
    num_classes = 10
    samples_per_class = 100
    # row i of the identity matrix is the one-hot encoding of class i
    identity = torch.eye(num_classes)
    one_hot_labels = [identity[digit].unsqueeze(0).clone() for digit in range(num_classes)]
    blank_input = torch.zeros(1, 784)

    for label in one_hot_labels:
        # get the loc from the model, once per requested sample
        images = [
            ssvae.model(blank_input, label)[0].view(1, 28, 28).cpu().data.numpy()
            for _ in range(samples_per_class)
        ]
        visdom_session.images(images, 10, 2)


def plot_llk(train_elbo, test_elbo):
    """
    Plot train and test curves per epoch and save them as a PNG.

    Both arguments are read through ``.keys()`` (plotted as "Epoch") and
    ``.values()`` (negated and plotted as "ELBO"), so they are expected to be
    mappings from epoch to loss. The figure is written to
    ``./vae_results/test_elbo_vae.png``; the directory is created if missing.
    """
    Path("vae_results").mkdir(parents=True, exist_ok=True)

    plt.figure(figsize=(30, 10))
    sns.set_style("whitegrid")

    def _as_frame(history, split_name):
        # one row per epoch; values are sign-flipped before plotting
        return pd.DataFrame(
            {
                "Epoch": history.keys(),
                "ELBO": [-val for val in history.values()],
                "dataset": split_name,
            }
        )

    curves = pd.concat(
        [_as_frame(train_elbo, "Train"), _as_frame(test_elbo, "Test")], axis=0
    )

    # one colour per dataset: scatter points joined by a dashed line
    grid = sns.FacetGrid(curves, height=4, aspect=1.5, hue="dataset")
    grid.map(sns.scatterplot, "Epoch", "ELBO")
    grid.map(sns.lineplot, "Epoch", "ELBO", linestyle="--")
    # restrict the y axis to integer tick locations
    grid.ax.yaxis.get_major_locator().set_params(integer=True)
    grid.add_legend()
    plt.savefig("./vae_results/test_elbo_vae.png")
    plt.close("all")


def plot_vae_samples(vae, visdom_session):
    """
    Send 10 batches of 100 model samples to visdom.

    ``vae.model`` is called with an all-zero (1, 784) input and its return
    value is indexed with ``[0]``, so the model must return the decoded image
    tensor. Each batch goes to ``visdom_session.images`` with the positional
    arguments 10 and 2.
    """
    blank_input = torch.zeros([1, 784])
    for _batch_index in range(10):
        # get loc from the model, once per image in the batch
        images = [
            vae.model(blank_input)[0].view(1, 28, 28).cpu().data.numpy()
            for _sample_index in range(100)
        ]
        visdom_session.images(images, 10, 2)


def mnist_test_tsne(vae=None, test_loader=None):
    """
    This is used to generate a t-sne embedding of the vae

    Args:
        vae: model whose ``encoder`` maps an image batch to ``(z_loc, z_scale)``.
        test_loader: data loader whose ``dataset`` exposes the raw images and
            labels as tensors (``data`` / ``targets``, or the older
            ``test_data`` / ``test_labels`` names).

    The encoded locations and the labels are handed to ``plot_tsne``.
    """
    name = "VAE"
    dataset = test_loader.dataset
    # torchvision renamed test_data / test_labels to data / targets; prefer the
    # current names and fall back to the old ones for older datasets
    raw_images = dataset.data if hasattr(dataset, "data") else dataset.test_data
    mnist_labels = dataset.targets if hasattr(dataset, "targets") else dataset.test_labels
    data = raw_images.float()
    if not raw_images.is_floating_point():
        # raw pixels are integers in 0..255, but the encoder is trained on
        # ToTensor() output in [0, 1]; rescale so inputs match training
        data = data / 255.0
    # only the values are needed for plotting, so skip building an autograd graph
    with torch.no_grad():
        z_loc, _ = vae.encoder(data)
    plot_tsne(z_loc, mnist_labels, name)


def mnist_test_tsne_ssvae(name=None, ssvae=None, test_loader=None):
    """
    Generate a t-sne embedding of the ss-vae latent space.

    Args:
        name: label used in the output file names; "SS-VAE" when omitted.
        ssvae: model whose ``encoder_z`` takes ``[images, labels]`` and returns
            ``(z_loc, z_scale)``.
        test_loader: loader whose ``dataset`` exposes ``test_data`` and
            ``test_labels`` (presumably already in the form the ss-vae
            encoder expects -- confirm against the dataset class used).
    """
    plot_name = "SS-VAE" if name is None else name
    dataset = test_loader.dataset
    images = dataset.test_data.float()
    labels = dataset.test_labels
    # only the locations are plotted; the scales are discarded
    latent_loc, _ = ssvae.encoder_z([images, labels])
    plot_tsne(latent_loc, labels, plot_name)


def plot_tsne(z_loc, classes, name):
    """
    Embed latent locations in 2-D with t-SNE and save per-class scatter plots.

    Args:
        z_loc: tensor of latent locations, one row per data point.
        classes: tensor of labels, either 1-D integer class indices (as
            provided by the MNIST dataset) or 2-D one-hot rows.
        name: prefix used in the output file names.

    Files are written to ``./vae_results/``: one cumulative snapshot per class
    (``<name>_embedding_<class>.png``) plus the final ``<name>_embedding.png``.
    """
    # savefig does not create directories, so make sure the target exists
    Path("./vae_results").mkdir(parents=True, exist_ok=True)

    model_tsne = TSNE(n_components=2, random_state=0)
    z_states = z_loc.detach().cpu().numpy()
    z_embed = model_tsne.fit_transform(z_states)
    classes = classes.detach().cpu().numpy()
    # 2-D labels are treated as one-hot rows, 1-D labels as class indices
    one_hot = classes.ndim > 1
    fig = plt.figure()
    for ic in range(10):
        ind_class = (classes[:, ic] == 1) if one_hot else (classes == ic)
        # tab10 has ten distinct colours; Set1 only has nine, so classes 8 and
        # 9 would otherwise be drawn in the same colour
        color = plt.cm.tab10(ic)
        plt.scatter(z_embed[ind_class, 0], z_embed[ind_class, 1], s=10, color=color)
        plt.title("Latent Variable T-SNE per Class")
        fig.savefig("./vae_results/" + str(name) + "_embedding_" + str(ic) + ".png")
    fig.savefig("./vae_results/" + str(name) + "_embedding.png")

In [29]:
# plot some sample images
# vis = visdom.Visdom()
# plot_vae_samples(vae, vis)

In [30]:
# tSNE plot
# mnist_test_tsne(vae=vae, test_loader=test_loader)