# [VAE in Pyro](http://pyro.ai/examples/vae.html#VAE-in-Pyro)

Let’s see how we implement a VAE in Pyro. The dataset we’re going to model is MNIST, a collection of images of handwritten digits. Since this is a popular benchmark dataset, we can make use of PyTorch’s convenient data loader functionalities to reduce the amount of boilerplate code we need to write:

In [62]:
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader

In [63]:
import os

In [64]:
pixel_path = os.path.join(os.environ['HOME']+'/mnist/')

In [65]:
if not os.path.exists(pixel_path):
    os.mkdir(pixel_path)

In [66]:
train_set = MNIST(root=pixel_path, train=True, transform=transforms.ToTensor(), download=False)
test_set = MNIST(root=pixel_path, train=False, transform=transforms.ToTensor())

The main thing to draw attention to here is that we use transforms.ToTensor() to normalize the pixel intensities to the range [0, 1].

In [67]:
use_cuda = False
kwargs = {'num_workers': 1, 'pin_memory': use_cuda}

In [68]:
train_dl = DataLoader(dataset=train_set, batch_size=128, shuffle=True)
test_dl = DataLoader(dataset=test_set, batch_size=128, shuffle=True, **kwargs)

___

## VAE

In [69]:
import torch.nn as nn

In [70]:
class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, out_dim) -> None:
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        hidden = self.softplus(self.fc1(z))
        loc_img = self.sigmoid(self.fc2(hidden))
        return loc_img

In [71]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, z_dim) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)

        self.softplus = nn.Softplus()

    def forward(self, x):
        x = x.reshape(-1, self.input_dim)
        hidden = self.softplus(self.fc1(x))
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale

In [86]:
encoder = Encoder(784, 400, 50)
decoder = Decoder(50, 400, 784)

In [87]:
for x, _ in train_dl:
    decoder(encoder(x))

TypeError: linear(): argument 'input' (position 1) must be Tensor, not tuple

Given an image  the forward call of Encoder returns a mean and covariance that together parameterize a (diagonal) Gaussian distribution in latent space.

In [74]:
import pyro
import pyro.distributions as dist
import torch

In [75]:
class VAE(nn.Module):
    def __init__(self, input_dim = 784, hidden_dim = 400, z_dim = 50, out_dim = 784, use_cuda = False) -> None:
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dim, z_dim)
        self.decoder = Decoder(z_dim, hidden_dim, out_dim)

        if use_cuda:
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim
        self.input_dim = input_dim

    def model(self, x):
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            #p(z)
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            #we use new_zeros and new_ones to ensure that newly created tensors are on the same GPU device.
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            # sample code from prior
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # use of .to_event(1) when sampling from the latent z - this ensures that instead of 
            # treating our sample as being generated from a univariate normal with batch_size = z_dim, 
            # we treat them as being generated from a multivariate normal distribution with diagonal covariance. 
            # As such, the log probabilities along each dimension is summed out when we evaluate .log_prob for 
            # a “latent” sample.
            loc_img = self.decoder(z)
            pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, self.input_dim)) 

    
    def guide(self, x):
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            z_loc, z_scale = self.encoder(x)
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))



    def reconstruct(self, x):
        z_loc, z_scale = self.encoder(x)
        z = dist.Normal(z_loc, z_scale).sample()
        loc_img = self.decoder(z)
        return loc_img


The point we’d like to make here is that the two Modules encoder and decoder are attributes of VAE (which itself inherits from nn.Module). This has the consequence they are both automatically registered as belonging to the VAE module. So, for example, when we call parameters() on an instance of VAE, PyTorch will know to return all the relevant parameters. It also means that if we’re running on a GPU, the call to cuda() will move all the parameters of all the (sub)modules into GPU memory.

___

In [76]:
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

In [77]:
LR = 1.0e-3
USE_CUDA = False
NUM_EPOCKS = 1
TEST_FREQUENCY = 5

In [78]:
pyro.clear_param_store()

In [79]:
vae = VAE(use_cuda=USE_CUDA)

In [80]:
adam_args = {"lr": LR}
optimizer = Adam(adam_args)

In [81]:
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

In [82]:
for x, _ in train_dl:
    svi.step(x)
    

ValueError: Error while computing log_prob at site 'obs':
Expected value argument (Tensor of shape (128, 784)) to be within the support (Boolean()) of the distribution Bernoulli(probs: torch.Size([128, 784])), but found invalid values:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
       Trace Shapes:            
        Param Sites:            
decoder$$$fc1.weight 400  50    
  decoder$$$fc1.bias     400    
decoder$$$fc2.weight 784 400    
  decoder$$$fc2.bias     784    
       Sample Sites:            
         latent dist 128   |  50
               value 128   |  50
            log_prob 128   |    
            obs dist 128   | 784
               value 128   | 784