In [1]:
import os

import numpy as np
import torch
import torchvision.datasets as dset
import torch.nn as nn
import torchvision.transforms as transforms

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

In [2]:
# Switch on Pyro's validation checks globally ...
pyro.enable_validation(True)
# ... then switch validation back off for the distributions module only.
# NOTE(review): presumably this is because the Bernoulli likelihood used below
# observes real-valued MNIST pixels rather than strict 0/1 values -- confirm.
pyro.distributions.enable_validation(False)
# Fix the random seed so repeated runs of the notebook are comparable.
pyro.set_rng_seed(0)
# Enable smoke test - run the notebook cells on CI.
# True whenever an environment variable named CI is set (to any value).
smoke_test = 'CI' in os.environ

In [3]:
# for loading and batching MNIST dataset
def setup_data_loaders(batch_size=128, use_cuda=False):
    """Create the MNIST training and test data loaders.

    Both splits are read from ``./data`` and converted with
    ``transforms.ToTensor()``. Only the training split is constructed with
    ``download=True``.

    Args:
        batch_size: number of images per mini-batch, used for both loaders.
        use_cuda: forwarded to both loaders as ``pin_memory``.

    Returns:
        A ``(train_loader, test_loader)`` tuple; the training loader shuffles
        its data, the test loader does not.
    """
    data_root = './data'
    to_tensor = transforms.ToTensor()

    mnist_train = dset.MNIST(root=data_root, train=True,
                             transform=to_tensor, download=True)
    mnist_test = dset.MNIST(root=data_root, train=False, transform=to_tensor)

    def make_loader(dataset, shuffle):
        # options shared by both loaders: one worker process, optional pinning
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           num_workers=1,
                                           pin_memory=use_cuda)

    return make_loader(mnist_train, True), make_loader(mnist_test, False)

In [4]:
class Decoder(nn.Module):
    """Network that maps a latent code ``z`` to per-pixel Bernoulli means.

    Architecture: ``Linear(z_dim, hidden_dim)`` -> softplus ->
    ``Linear(hidden_dim, 784)`` -> sigmoid.
    """

    def __init__(self, z_dim, hidden_dim):
        super(Decoder, self).__init__()
        # latent -> hidden, then hidden -> 784 outputs
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)
        # activation functions
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        """Return values in (0, 1) whose last dimension has size 784.

        Args:
            z: tensor whose last dimension has size ``z_dim``.
        """
        pre_hidden = self.fc1(z)
        hidden_units = self.softplus(pre_hidden)
        pixel_logits = self.fc21(hidden_units)
        # the sigmoid squashes each output so it can parameterise a Bernoulli
        return self.sigmoid(pixel_logits)

In [5]:
class Encoder(nn.Module):
    """Network that maps an image to the parameters of a diagonal Gaussian.

    Architecture: ``Linear(784, hidden_dim)`` -> softplus, followed by two
    parallel ``Linear(hidden_dim, z_dim)`` heads: one for the location and
    one whose exponential gives the (positive) scale.
    """

    def __init__(self, z_dim, hidden_dim):
        super(Encoder, self).__init__()
        # shared hidden layer plus one output head per Gaussian parameter
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        # activation function
        self.softplus = nn.Softplus()

    def forward(self, x):
        """Return ``(z_loc, z_scale)``, each with last dimension ``z_dim``.

        Args:
            x: image mini-batch; it is flattened to rows of 784 values.
        """
        flat_pixels = x.reshape(-1, 784)
        hidden_units = self.softplus(self.fc1(flat_pixels))
        location = self.fc21(hidden_units)
        # exponentiating the second head keeps the scale strictly positive
        scale = torch.exp(self.fc22(hidden_units))
        return location, scale

In [28]:
#Now define VAE class:
class VAE(nn.Module):
    """Variational autoencoder built from ``Encoder`` and ``Decoder``.

    ``model`` and ``guide`` are the two callables handed to Pyro's SVI;
    ``reconstruct_img`` runs an image through the encoder and decoder
    outside of any inference.

    Args:
        hidden_dim: width of the hidden layer in both networks.
        z_dim: dimensionality of the latent code.
        use_cuda: when true, the module is moved to the GPU with ``.cuda()``.
    """

    def __init__(self, hidden_dim = 400, z_dim = 50, use_cuda = False):
        super(VAE, self).__init__()
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)

        if use_cuda:
            self.cuda()
        self.z_dim = z_dim
        self.use_cuda = use_cuda

    def model(self, x):
        """Generative model p(x|z)p(z) for a mini-batch of images ``x``."""
        # register the decoder's parameters with Pyro
        pyro.module('decoder', self.decoder)
        with pyro.iarange('data', x.size(0)):
            # prior p(z): unit Gaussian with one row per data point.
            # The second argument of Normal is a scale (standard deviation),
            # not a variance, hence the name z_scale.
            z_loc = x.new_zeros(torch.Size((x.size(0), self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.size(0), self.z_dim)))
            z = pyro.sample('latent', dist.Normal(z_loc, z_scale).independent(1))
            # Decode the latent code. The module is called directly rather
            # than through .forward() so that nn.Module.__call__ (and any
            # registered hooks) runs.
            loc_img = self.decoder(z)
            # score the observed pixels, flattened to rows of 784 values
            pyro.sample('obs', dist.Bernoulli(loc_img).independent(1), obs=x.reshape(-1,784))

    def guide(self, x):
        """Variational distribution q(z|x) for a mini-batch of images ``x``."""
        # register the encoder's parameters with Pyro
        pyro.module('encoder', self.encoder)
        with pyro.iarange('data', x.shape[0]):
            # the encoder produces the location and scale of q(z|x)
            z_loc, z_scale = self.encoder(x)
            # the sample site name matches the 'latent' site in the model
            pyro.sample('latent', dist.Normal(z_loc, z_scale).independent(1))

    def reconstruct_img(self, x):
        """Encode ``x``, sample a latent code, and return the decoder output.

        The returned tensor holds the decoder's Bernoulli means; no sample is
        drawn in image space.
        """
        # encode image x
        z_loc, z_scale = self.encoder(x)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode the image (note we don't sample in image space)
        loc_img = self.decoder(z)
        return loc_img
        

In [29]:
def train(svi, train_loader, use_cuda=False):
    """Run one training epoch and return the average loss per data point.

    Args:
        svi: object whose ``step(batch)`` performs one gradient step and
            returns the loss for that mini-batch.
        train_loader: iterable of ``(batch, label)`` pairs that also exposes
            a ``dataset`` attribute supporting ``len``.
        use_cuda: when true, each mini-batch is moved with ``.cuda()`` first.

    Returns:
        The summed mini-batch losses divided by ``len(train_loader.dataset)``.
    """
    running_loss = 0.
    for batch, _labels in train_loader:
        if use_cuda:
            # move the mini-batch into CUDA memory
            batch = batch.cuda()
        # one optimisation step; its loss is added to the epoch total
        running_loss += svi.step(batch)

    # normalise by the number of data points, not the number of batches
    return running_loss / len(train_loader.dataset)

In [30]:
def evaluate(svi, test_loader, use_cuda=False):
    """Compute the average loss per data point over the test set.

    Args:
        svi: object whose ``evaluate_loss(batch)`` returns the loss for a
            mini-batch.
        test_loader: iterable of ``(batch, label)`` pairs that also exposes
            a ``dataset`` attribute supporting ``len``.
        use_cuda: when true, each mini-batch is moved with ``.cuda()`` first.

    Returns:
        The summed mini-batch losses divided by ``len(test_loader.dataset)``.
    """
    running_loss = 0.
    for batch, _labels in test_loader:
        if use_cuda:
            # move the mini-batch into CUDA memory
            batch = batch.cuda()
        # add this mini-batch's loss estimate to the total
        running_loss += svi.evaluate_loss(batch)

    # normalise by the number of data points, not the number of batches
    return running_loss / len(test_loader.dataset)

In [31]:
## Settings for the inference run
# step size handed to the optimizer, and whether to run on the GPU
LEARNING_RATE = 1e-3
USE_CUDA = False

# a smoke-test run trains for a single epoch; a normal run trains for 100
if smoke_test:
    NUM_EPOCHS = 1
else:
    NUM_EPOCHS = 100
# the test set is evaluated on epochs that are a multiple of this value
TEST_FREQUENCY = 5

In [None]:
# reset Pyro's parameter store before building a new model
pyro.clear_param_store()
train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA)

# build the model/guide pair
vae = VAE(use_cuda=USE_CUDA)

# optimizer configuration
adam_args = {"lr": LEARNING_RATE}
optimizer = Adam(adam_args)

# stochastic variational inference with the Trace_ELBO loss
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

# ELBO histories; each entry is the negated average loss of one evaluation
train_elbo = []
test_elbo = []
for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

    # the test set is only evaluated every TEST_FREQUENCY epochs
    if epoch % TEST_FREQUENCY != 0:
        continue

    total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
    test_elbo.append(-total_epoch_loss_test)
    print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))

[epoch 000]  average training loss: 192.1734
[epoch 000] average test loss: 157.7320
[epoch 001]  average training loss: 147.2542
[epoch 002]  average training loss: 133.9142
[epoch 003]  average training loss: 125.7661
[epoch 004]  average training loss: 120.0069
[epoch 005]  average training loss: 116.3472
[epoch 005] average test loss: 114.3448
[epoch 006]  average training loss: 113.8410
[epoch 007]  average training loss: 112.1232
[epoch 008]  average training loss: 110.8459
[epoch 009]  average training loss: 109.8398
[epoch 010]  average training loss: 109.0897
[epoch 010] average test loss: 108.2109
[epoch 011]  average training loss: 108.3587
[epoch 012]  average training loss: 107.7644
[epoch 013]  average training loss: 107.3083
[epoch 014]  average training loss: 106.9059
[epoch 015]  average training loss: 106.5298
[epoch 015] average test loss: 105.8550
[epoch 016]  average training loss: 106.1874
[epoch 017]  average training loss: 105.9015
[epoch 018]  average training 