In [None]:
import numpy as np
import torch
from pyro.contrib.examples.util import MNIST
import torch.nn as nn
import torchvision.transforms as transforms

import pyro
import pyro.distributions as dist
import pyro.contrib.examples.util
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

from matplotlib import pyplot as plt

In [None]:
# Turn off Pyro's distribution validation (enable_validation(False)) so the
# Bernoulli likelihood accepts the MNIST images as observations: their pixel
# intensities lie anywhere in [0, 1], whereas a validated Bernoulli only
# accepts values that are exactly 0 or 1.
pyro.distributions.enable_validation(False)

In [None]:
class Encoder(nn.Module):
    """Amortized encoder: flattened 28x28 images -> diagonal Normal over z.

    Two tanh hidden layers (784 -> 400 -> 100) feed two linear heads. One
    head gives the mean; the other's output is treated as a log-variance and
    turned into a standard deviation.
    """

    def __init__(self, z_dim):
        super().__init__()
        # Attribute names determine the parameter / state_dict keys, so keep
        # them stable.
        self.fc1 = nn.Linear(784, 400)
        self.fc2 = nn.Linear(400, 100)
        self.fc31 = nn.Linear(100, z_dim)  # mean head
        self.fc32 = nn.Linear(100, z_dim)  # log-variance head

    def forward(self, x):
        """Return ``(z_loc, z_scale)`` for a batch of images ``x``."""
        flat = x.reshape(-1, 784)
        features = self.fc2(self.fc1(flat).tanh()).tanh()
        # exp(0.5 * log_var) == sqrt(var): a strictly positive scale.
        return self.fc31(features), self.fc32(features).mul(0.5).exp()

In [None]:
class BVAE(nn.Module):
    """VAE whose decoder weights and biases are Bayesian latent variables.

    The decoder is an MLP ``z_dim -> 100 -> 400 -> 784`` (tanh, tanh, sigmoid).
    Every weight/bias has an independent N(0, 1) prior. The guide fits a
    mean-field Normal posterior to each of them. The per-image latent ``z`` has
    a N(0, I) prior and an amortized posterior given by ``Encoder``.

    Args:
        z_dim: dimensionality of the latent code ``z``. The first decoder layer
            is sized from it, so any positive value works.
        num_data: optional total number of training examples. When given, the
            ``"data"`` plate is declared with this size and subsampled to the
            current batch. Pyro then scales the per-image terms by
            ``num_data / batch_size``, so each minibatch ELBO estimates the
            full-dataset ELBO instead of over-weighting the KL of the global
            weights. Reported losses are scaled accordingly. ``None``
            (default) keeps the plate at the batch size (original behavior).
    """

    def __init__(self, z_dim = 2, num_data = None):
        super().__init__()
        self.z_dim = z_dim
        self.num_data = num_data
        self.encoder = Encoder(self.z_dim)

    def _layer_shapes(self):
        """Return ``(site_name, shape)`` for every Bayesian decoder parameter."""
        # The first layer consumes z, so its input width must follow z_dim
        # (a hard-coded 2 breaks ``z @ w1`` for any other z_dim).
        return [
            ("w1", (self.z_dim, 100)), ("b1", (100,)),
            ("w2", (100, 400)), ("b2", (400,)),
            ("w3", (400, 784)), ("b3", (784,)),
        ]

    def _data_plate(self, x):
        """Plate over the images in ``x``; full-dataset sized if ``num_data`` is set."""
        batch_size = x.shape[0]
        if self.num_data is None:
            return pyro.plate("data", batch_size)
        return pyro.plate("data", self.num_data, subsample_size=batch_size)

    def model(self, x):
        """Generative model p(W) p(z) p(x | z, W) for a batch of images ``x``."""
        # Place N(0, 1) priors on the weights/biases of the linear layers.
        params = {}
        for name, shape in self._layer_shapes():
            prior = dist.Normal(0, 1).expand(shape).to_event(len(shape))
            params[name] = pyro.sample(name, prior)

        batch_size = x.shape[0]
        with self._data_plate(x):
            # N(0, I) prior on each image's latent code z.
            z_loc = x.new_zeros(torch.Size((batch_size, self.z_dim)))
            z_scale = x.new_ones(torch.Size((batch_size, self.z_dim)))
            z = pyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1))
            hidden = torch.tanh((z @ params["w1"]) + params["b1"])
            hidden = torch.tanh((hidden @ params["w2"]) + params["b2"])
            loc_img = torch.sigmoid((hidden @ params["w3"]) + params["b3"])
            # Pixel intensities in [0, 1] are scored as Bernoulli observations;
            # this relies on distribution validation being disabled.
            pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1),
                        obs=x.reshape(-1, 784))

    def guide(self, x):
        """Variational family: mean-field Normal q(W) times amortized q(z | x)."""
        # Each weight/bias gets a learnable Normal posterior with mean
        # initialized to 0 and sd to 1. Param names ("w1_mu", "w1_sd", ...)
        # are the same as those stored in the param store.
        positive = torch.distributions.constraints.positive
        for name, shape in self._layer_shapes():
            loc = pyro.param(f"{name}_mu", torch.zeros(shape))
            scale = pyro.param(f"{name}_sd", torch.ones(shape),
                               constraint = positive)
            pyro.sample(name, dist.Normal(loc, scale).to_event(len(shape)))

        # Amortized posterior over each image's latent code z.
        pyro.module("encoder", self.encoder)
        with self._data_plate(x):
            z_loc, z_scale = self.encoder(x)
            pyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1))

In [None]:
def train(svi, train_loader):
    """Run one training epoch: one ``svi.step`` per minibatch.

    Labels yielded by the loader are ignored. Returns the summed step losses
    divided by the number of examples in ``train_loader.dataset``.
    """
    accumulated = 0.
    for images, _labels in train_loader:
        accumulated = accumulated + svi.step(images)
    n_examples = len(train_loader.dataset)
    return accumulated / n_examples

def evaluate(svi, test_loader):
    """Score the test set without updating parameters.

    Calls ``svi.evaluate_loss`` on every minibatch (labels ignored) and
    returns the summed loss divided by ``len(test_loader.dataset)``.
    """
    accumulated = 0.
    for images, _labels in test_loader:
        accumulated = accumulated + svi.evaluate_loss(images)
    n_examples = len(test_loader.dataset)
    return accumulated / n_examples

In [None]:
def setup_data_loaders(batch_size=128, root='./data', download=True, num_workers=1):
    """Build the MNIST train and test ``DataLoader``s.

    Args:
        batch_size: minibatch size used by both loaders.
        root: directory passed to ``MNIST`` for storing the dataset.
        download: passed through to ``MNIST`` for both splits.
        num_workers: loader worker processes for each loader.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles, so
        the test batches come out in a fixed order.
    """
    trans = transforms.ToTensor()
    train_set = MNIST(root=root, train=True, transform=trans,
                      download=download)
    # Pass ``download`` to the test split as well, so it does not depend on the
    # train-split call having fetched the files.
    test_set = MNIST(root=root, train=False, transform=trans,
                     download=download)

    kwargs = {'num_workers': num_workers, 'pin_memory': False}
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
        batch_size=batch_size, shuffle=False, **kwargs)
    return train_loader, test_loader

In [None]:
# --- Configuration -----------------------------------------------------------
LEARNING_RATE = 1.0e-3
BATCH_SIZE = 128

NUM_EPOCHS = 1500
TEST_FREQUENCY = 5   # evaluate on the test set every this many epochs
SEED = 0             # fixes shuffling, weight init and SVI sampling noise

# Seed before building loaders and the model so that data shuffling, network
# initialization and the stochastic ELBO estimates are reproducible.
pyro.set_rng_seed(SEED)

train_loader, test_loader = setup_data_loaders(batch_size=BATCH_SIZE)

# Start from an empty param store so re-running this cell does not reuse
# parameters from a previous run.
pyro.clear_param_store()

bvae = BVAE(z_dim = 2)

adam_args = {"lr": LEARNING_RATE}
optimizer = Adam(adam_args)

svi = SVI(bvae.model, bvae.guide, optimizer, loss=Trace_ELBO())

In [None]:
# train_elbo gets one entry per epoch; test_elbo one entry per evaluated epoch
# (epochs 0, TEST_FREQUENCY, 2*TEST_FREQUENCY, ...). Both store ELBO = -loss.
train_elbo, test_elbo = [], []
for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader)
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

    # Only score the held-out set every TEST_FREQUENCY epochs.
    if epoch % TEST_FREQUENCY:
        continue
    total_epoch_loss_test = evaluate(svi, test_loader)
    test_elbo.append(-total_epoch_loss_test)
    print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))

In [None]:
# Pick one image (index n_im) from the first batch of the unshuffled test
# loader; it is the image reconstructed below.
n_im = 16
batch0 = next(iter(test_loader))[0]  # keep the images, drop the labels
plt.imshow(batch0[n_im].reshape(28, 28), cmap="gray")

In [None]:
# Draw one joint sample of the decoder weights/biases and of z for the chosen
# image from the fitted guide. Then decode it by hand with the same MLP as the
# model. This is inference only, so skip building the autograd graph.
with torch.no_grad():
    predictive = pyro.infer.Predictive(bvae.model, guide = bvae.guide, num_samples = 1)
    pred = predictive(batch0[n_im])

    # Use parameters sampled from the variational distributions to reconstruct
    # the image. Each pred[...] tensor carries a leading num_samples dimension.
    hidden = torch.tanh((pred["z"] @ pred["w1"]) + pred["b1"])
    hidden = torch.tanh((hidden @ pred["w2"]) + pred["b2"])
    reconstructed = torch.sigmoid((hidden @ pred["w3"]) + pred["b3"])

In [None]:
# Plot the reconstruction with the explicit Figure/Axes interface.
fig, ax = plt.subplots()
ax.imshow(reconstructed.detach().numpy().reshape([28, 28]), cmap = "gray")
ax.axis("off")
# NOTE: the title assumes the selected test image (n_im) is a 9.
ax.set_title("Bayesian VAE: MNIST 9")

# Save before showing: some backends close or clear the figure in show().
fig.savefig('bvae.png', format = 'png', dpi = 600, bbox_inches = 'tight')
plt.show()

In [None]:
# Persist the trained model in two files:
#  - bvae.pt: the whole pickled module (loading requires the BVAE/Encoder
#    classes to be importable);
#  - bvae_params.pyro: the Pyro param store, which holds the guide's
#    variational parameters and the registered encoder weights.
torch.save(bvae, "bvae.pt")
param_store = pyro.get_param_store()
param_store.save("bvae_params.pyro")