# Conditional Variational Autoencoders (CVAEs)

This notebook follows the Pyro tutorial on CVAEs (https://pyro.ai/examples/cvae.html).

CVAEs extend VAEs by conditioning the generative model on an observed input x in order to predict a structured output y. This allows the use of VAE machinery in structured prediction tasks. The CVAE is composed of several multilayer perceptrons: the conditional prior network p(z|x), the recognition network q(z|x,y), and the generation network p(y|x,z).

The CVAE can also be conceptualized in other ways, such as enforcing constraints c rather than generating predictions y. In this case, we may have the MLPs p(z|c), p(x|z,c), and q(z|x,c).

In this notebook, we will follow the first formulation. Our goal will be to predict an MNIST digit given only the bottom-left quadrant of the image. We will see that the CVAE offers an advantage over traditional neural networks, which can only make a single prediction. In some cases, the true digit will be genuinely uncertain given only the bottom-left quadrant. A traditional NN would give a blurred combination of the possible digits. The CVAE, on the other hand, will reflect the uncertainty: because we draw images from its posterior distribution, we will get a mix of clear digit images, with the mix proportions reflecting their posterior probabilities.

### Preliminaries

First we will load the MNIST data. Then we will implement a simple baseline MLP to compare to.

In [11]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, functional

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.contrib.examples.util import MNIST

In [6]:
# classes and functions for loading MNIST data and preparing for CVAE

# MNIST dataset loader to load data for CVAE prediction task
class CVAEMNIST(Dataset):
    """Dataset wrapper that serves MNIST examples as dictionaries.

    Each item is a dict with the keys ``'original'`` (the image) and
    ``'digit'`` (its label). If a ``transform`` was supplied, the dict is
    passed through it before being returned.
    """

    def __init__(self, root, train=True, transform=None, download=False):
        """Wrap the MNIST split selected by ``train``.

        ``root``, ``train`` and ``download`` are forwarded unchanged to
        ``MNIST``; ``transform`` is an optional callable applied to every
        sample dict.
        """
        self.original = MNIST(root, train=train, download=download)
        self.transform = transform

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.original)

    def __getitem__(self, item):
        """Return the sample dict for index ``item`` (transformed if set)."""
        image, digit = self.original[item]
        sample = dict(original=image, digit=digit)
        return self.transform(sample) if self.transform else sample

# callable transform that converts the sample's image and label to PyTorch tensors
class ToTensor:
    """Transform that converts a sample's image and label into tensors."""

    def __call__(self, sample):
        """Convert ``sample`` in place and return it.

        ``sample['original']`` is replaced by the result of torchvision's
        ``functional.to_tensor`` and ``sample['digit']`` by an int64 tensor.
        """
        sample['original'] = functional.to_tensor(sample['original'])
        # go through numpy first so plain Python ints and arrays are both accepted
        label = np.asarray(sample['digit'])
        sample['digit'] = torch.as_tensor(label, dtype=torch.int64)
        return sample

# class to mask quadrants of MNIST images for our prediction task
class MaskImages:
    """Split an MNIST image into a masked input and a complementary target.

    Depending on the number of quadrants to be used as inputs (1, 2, or 3),
    the transformation keeps those quadrants in ``sample['input']`` and fills
    the remaining (3, 2, 1) quadrant(s) with ``mask_with``. The target
    ``sample['output']`` is the complement: the input quadrants are filled
    with ``mask_with`` and the remaining pixels are kept.

    Input regions by ``num_quadrant_inputs``:
        1: bottom-left quadrant
        2: bottom-left and top-left quadrants (left half)
        3: bottom-left, top-left and top-right quadrants

    Args:
        num_quadrant_inputs: number of quadrants used as input (1, 2 or 3).
        mask_with: fill value written into hidden pixels (default -1).

    Raises:
        ValueError: if ``num_quadrant_inputs`` is not between 1 and 3.
    """
    def __init__(self, num_quadrant_inputs, mask_with=-1):
        if num_quadrant_inputs <= 0 or num_quadrant_inputs >= 4:
            raise ValueError('Number of quadrants as inputs must be 1, 2 or 3')
        self.num = num_quadrant_inputs
        self.mask_with = mask_with

    def __call__(self, sample):
        """Add ``'input'`` and ``'output'`` entries to ``sample`` and return it.

        ``sample['original']`` must squeeze down to a 2-D (height, width)
        tensor; it is left untouched.
        """
        tensor = sample['original'].squeeze()
        h, w = tensor.shape

        # Boolean map of the pixels that belong to the input region. Using an
        # explicit mask (instead of comparing pixel values against a sentinel)
        # keeps input and output complementary for any ``mask_with`` value.
        is_input = torch.zeros((h, w), dtype=torch.bool, device=tensor.device)
        # the bottom-left quadrant is always part of the input
        is_input[h // 2:, :w // 2] = True
        # with 2 input quadrants, the top-left quadrant is added (left half)
        if self.num == 2:
            is_input[:, :w // 2] = True
        # with 3 input quadrants, the whole top half is added
        if self.num == 3:
            is_input[:h // 2, :] = True

        # target: everything that is NOT given as input
        out = tensor.detach().clone()
        out[is_input] = self.mask_with

        # input: the complementary region
        inp = tensor.clone()
        inp[~is_input] = self.mask_with

        sample['input'] = inp
        sample['output'] = out
        return sample


# function to load, prepare, and mask data using the classes above
def get_data(num_quadrant_inputs, batch_size):
    """Build the masked-MNIST datasets and loaders for both splits.

    Args:
        num_quadrant_inputs: number of quadrants given as input, forwarded
            to ``MaskImages``.
        batch_size: batch size used by both data loaders.

    Returns:
        A tuple ``(datasets, dataloaders, dataset_sizes)`` of dicts, each
        keyed by ``'train'`` and ``'val'``.
    """
    # one shared pipeline: tensor conversion first, then quadrant masking
    pipeline = Compose([
        ToTensor(),
        MaskImages(num_quadrant_inputs=num_quadrant_inputs),
    ])

    datasets = {}
    dataloaders = {}
    dataset_sizes = {}
    for mode in ('train', 'val'):
        is_train = mode == 'train'
        split = CVAEMNIST(
            '../data',
            train=is_train,
            transform=pipeline,
            download=True,
        )
        datasets[mode] = split
        # only the training split is shuffled
        dataloaders[mode] = DataLoader(
            split,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=0,
        )
        dataset_sizes[mode] = len(split)

    return datasets, dataloaders, dataset_sizes

In [5]:
# baseline network to compare to -- simple feedforward MLP
class BaselineNet(nn.Module):
    """Feed-forward baseline: 784 -> hidden_1 -> hidden_2 -> 784.

    The two hidden layers use ReLU; the output layer uses a sigmoid, so
    every predicted pixel lies in [0, 1].
    """

    def __init__(self, hidden_1, hidden_2):
        super().__init__()
        self.fc1 = nn.Linear(784, hidden_1)
        self.fc2 = nn.Linear(hidden_1, hidden_2)
        self.fc3 = nn.Linear(hidden_2, 784)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the prediction."""
        flat = x.view(-1, 784)
        h1 = self.relu(self.fc1(flat))
        h2 = self.relu(self.fc2(h1))
        return torch.sigmoid(self.fc3(h2))

In [12]:
# define a modified binary cross entropy loss, which only computes the loss for non-masked pixels
class MaskedBCELoss(nn.Module):
    """Summed binary cross entropy that ignores masked target pixels.

    Any position where the target equals ``masked_with`` contributes zero to
    the loss.
    """

    def __init__(self, masked_with=-1):
        super().__init__()
        self.masked_with = masked_with

    def forward(self, input, target):
        """Return the scalar sum of the element-wise BCE over unmasked pixels.

        ``target`` is reshaped to the shape of ``input`` before comparison.
        """
        target = target.view(input.shape)
        per_pixel = F.binary_cross_entropy(input, target, reduction='none')
        # zero out the entries that correspond to masked target pixels
        hidden = target == self.masked_with
        return per_pixel.masked_fill(hidden, 0).sum()


### CVAE Implementation

The CVAE doesn't look too different from the VAE (see vae.ipynb) in Pyro. The conditional prior network p(z|x) and the recognition network q(z|x,y) are both like VAE Encoders, and the generation network p(y|x,z) is like a VAE decoder.

We first define the Encoder and Decoder, and then the CVAE with those modules.

In [8]:
# Encoder network
class Encoder(nn.Module):
    """MLP mapping a (partially masked) image pair to Normal parameters.

    The network is 784 -> hidden_1 -> hidden_2 followed by two parallel
    heads of size ``z_dim``: ``fc31`` gives the mean and ``fc32`` gives the
    log of the scale, which is exponentiated to keep the scale positive.
    """

    def __init__(self, z_dim, hidden_1, hidden_2):
        super().__init__()
        self.fc1 = nn.Linear(784, hidden_1)
        self.fc2 = nn.Linear(hidden_1, hidden_2)
        self.fc31 = nn.Linear(hidden_2, z_dim)  # head for the mean
        self.fc32 = nn.Linear(hidden_2, z_dim)  # head for the log-scale
        self.relu = nn.ReLU()

    def forward(self, x, y):
        """Return ``(z_loc, z_scale)``, each with one row per image.

        Pixels of ``x`` equal to -1 are treated as masked and are filled in
        from the same positions of ``y``; ``x`` itself is not modified.
        """
        # merge x and y into a single image so one network can consume both
        masked = x == -1
        merged = x.clone()
        merged[masked] = y[masked]
        flat = merged.view(-1, 784)

        h1 = self.relu(self.fc1(flat))
        h2 = self.relu(self.fc2(h1))

        z_loc = self.fc31(h2)
        z_scale = self.fc32(h2).exp()
        return z_loc, z_scale

In [9]:
# Decoder network
class Decoder(nn.Module):
    """MLP mapping a latent vector to 784 pixel values in [0, 1].

    Architecture: z_dim -> hidden_1 -> hidden_2 -> 784, with ReLU on the
    hidden layers and a sigmoid on the output.
    """

    def __init__(self, z_dim, hidden_1, hidden_2):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_1)
        self.fc2 = nn.Linear(hidden_1, hidden_2)
        self.fc3 = nn.Linear(hidden_2, 784)
        self.relu = nn.ReLU()

    def forward(self, z):
        """Decode ``z`` into a flattened image of 784 values per row."""
        h1 = self.relu(self.fc1(z))
        h2 = self.relu(self.fc2(h1))
        return torch.sigmoid(self.fc3(h2))

In [10]:
# now we define the CVAE class
# it has the prior, generation, and recognition modules, 
#   which are of class Encoder, Decoder, and Encoder, respectively
# it also includes a pre-trained baseline network for comparison and as a prior

class CVAE(nn.Module):
    """Conditional VAE made of a prior, a recognition and a generation net.

    ``model`` and ``guide`` form a Pyro model/guide pair. Pixels of ``xs``
    equal to -1 are treated as masked, i.e. as the part of the image to be
    predicted. A pre-trained baseline network supplies an initial guess
    ``y_hat`` that is fed, together with ``xs``, into the prior network; the
    baseline is always evaluated under ``torch.no_grad()`` so that no
    gradients are computed through it.
    """

    def __init__(self, z_dim, hidden_1, hidden_2, pre_trained_baseline_net):
        """Build the three sub-networks and store the baseline network.

        Args:
            z_dim: size of the latent vector z.
            hidden_1: width of the first hidden layer of every sub-network.
            hidden_2: width of the second hidden layer of every sub-network.
            pre_trained_baseline_net: module called as ``baseline_net(xs)``;
                its output is reshaped to ``xs.shape``.
        """
        super().__init__()
        # The CVAE is composed of multiple MLPs, such as recognition network
        # qφ(z|x, y), (conditional) prior network pθ(z|x), and generation
        # network pθ(y|x, z). Also, CVAE is built on top of the NN: not only
        # the direct input x, but also the initial guess y_hat made by the NN
        # are fed into the prior network.
        self.baseline_net = pre_trained_baseline_net
        self.prior_net = Encoder(z_dim, hidden_1, hidden_2)
        self.generation_net = Decoder(z_dim, hidden_1, hidden_2)
        self.recognition_net = Encoder(z_dim, hidden_1, hidden_2)

    def model(self, xs, ys=None):
        """Generative process p(z|x) p(y|x,z).

        Args:
            xs: batch of masked input images (masked pixels equal -1).
            ys: batch of target images, or None at test time.

        Returns:
            The output of the generation network (one row of 784 pixel
            probabilities per image).
        """
        # register this pytorch module and all of its submodules with pyro
        # (note: the whole CVAE is registered under the name "generation_net")
        pyro.module("generation_net", self)
        batch_size = xs.shape[0]

        with pyro.plate("data"):
            # Prior network uses the baseline predictions as an initial guess.
            # This is the generative process with recurrent connection.
            with torch.no_grad():
                # this ensures that the training process does not change the baseline network
                # view(ns) method returns a tensor with a new shape "ns"
                y_hat = self.baseline_net(xs).view(xs.shape)

            # sample handwriting style z from the prior distribution, which is modulated by the input xs (x)
            prior_loc, prior_scale = self.prior_net(xs, y_hat)
            # to_event(1) declares the last dimension as the event dimension,
            # so each z vector is one multivariate draw (with independent
            # components) instead of a batch of separate scalar sites
            zs = pyro.sample('z', dist.Normal(prior_loc, prior_scale).to_event(1))

            # the output y is generated from the distribution pθ(y|x, z)
            loc = self.generation_net(zs)

            if ys is not None:
                # In training, only the masked pixels are scored.
                # NOTE(review): the reshape to (batch_size, -1) assumes every
                # image in the batch has the same number of masked pixels.
                mask_loc = loc[(xs == -1).view(-1, 784)].view(batch_size, -1)
                mask_ys = ys[xs == -1].view(batch_size, -1)
                pyro.sample('y', dist.Bernoulli(mask_loc).to_event(1), obs=mask_ys)
            else:
                # In testing, no need to sample: the output is already a
                # probability in [0, 1] range, which better represent pixel
                # values considering grayscale. If we sample, we will force
                # each pixel to be  either 0 or 1, killing the grayscale
                pyro.deterministic('y', loc.detach())

        # return the loc so we can visualize it later
        return loc

    # now we define the guide.
    # at training time this is the recognition_net q(z|y,x).
    # at testing time we do not have ys, so we use the prior network
    def guide(self, xs, ys=None):
        """Variational distribution over z.

        Uses the recognition network q(z|x,y) when ``ys`` is given and the
        prior network (fed with the baseline guess) when ``ys`` is None.
        """
        with pyro.plate("data"):
            if ys is None:
                # at inference time, ys is not provided. In that case,
                # the model uses the prior network
                with torch.no_grad():
                    # same guard as in model(): the baseline network is
                    # frozen, so no graph is built through it
                    y_hat = self.baseline_net(xs).view(xs.shape)
                loc, scale = self.prior_net(xs, y_hat)
            else:
                # at training time, uses the variational distribution
                # q(z|x,y) = normal(loc(x,y),scale(x,y))
                loc, scale = self.recognition_net(xs, ys)

            pyro.sample("z", dist.Normal(loc, scale).to_event(1))
