# [VAE in Pyro](http://pyro.ai/examples/vae.html#VAE-in-Pyro)

Let’s see how we implement a VAE in Pyro. The dataset we’re going to model is MNIST, a collection of images of handwritten digits. Since this is a popular benchmark dataset, we can make use of PyTorch’s convenient data loader functionalities to reduce the amount of boilerplate code we need to write:

In [13]:
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader

In [14]:
import os

In [15]:
pixel_path = os.path.join(os.environ['HOME'], 'mnist')

In [16]:
os.makedirs(pixel_path, exist_ok=True)

In [17]:
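# download=True fetches MNIST into pixel_path on first use and reuses the local copy afterwards
train_set = MNIST(root=pixel_path, train=True, transform=transforms.ToTensor(), download=True)
test_set = MNIST(root=pixel_path, train=False, transform=transforms.ToTensor(), download=True)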

The main thing to draw attention to here is that we use transforms.ToTensor(), which converts each image to a float tensor and rescales the 8-bit pixel intensities from [0, 255] to the range [0, 1].

In [18]:
use_cuda = False
kwargs = {'num_workers': 1, 'pin_memory': use_cuda}

In [19]:
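# Shuffle only the training data; both loaders share the worker and pinned-memory settings
train_dl = DataLoader(dataset=train_set, batch_size=128, shuffle=True, **kwargs)
test_dl = DataLoader(dataset=test_set, batch_size=128, shuffle=False, **kwargs)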
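As a rough illustration of what such a loader yields, the sketch below iterates over a DataLoader built on synthetic data rather than MNIST, so it runs without any download. The tensor sizes (300 random 1x28x28 "images") and the flattening step are assumptions chosen for the example; flattening to 784 features is one common way to feed images to fully connected layers.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST: 300 random "images" in [0, 1) with integer labels.
images = torch.rand(300, 1, 28, 28)
labels = torch.randint(0, 10, (300,))
loader = DataLoader(TensorDataset(images, labels), batch_size=128, shuffle=True)

batch_shapes = []
for x, y in loader:
    # Flatten each image to a vector of 28 * 28 = 784 pixel values.
    flat = x.reshape(x.shape[0], -1)
    batch_shapes.append(tuple(flat.shape))

print(batch_shapes)
```

With a batch size of 128, the last batch is smaller whenever the dataset size is not a multiple of 128, so downstream code should not hard-code the batch dimension.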

___

## VAE

In [20]:
import torch.nn as nn

In [21]:
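class Decoder(nn.Module):
    """Maps a latent code z to a flattened image with per-pixel values in (0, 1)."""

    def __init__(self, z_dim, hidden_dim, out_dim) -> None:
        super().__init__()
        # Two fully connected layers: latent -> hidden -> image
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        # Non-linearities
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        # Hidden representation of the latent code
        hidden = self.softplus(self.fc1(z))
        # The sigmoid keeps every output pixel in (0, 1), matching the ToTensor() range
        loc_img = self.sigmoid(self.fc2(hidden))
        return loc_img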
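
A quick shape check can help confirm how the decoder transforms its input. The sketch below rebuilds the same layer sequence as Decoder.forward with nn.Sequential so that it runs on its own. The sizes z_dim=50 and hidden_dim=400 and the batch of 16 latent codes are assumptions for illustration; only out_dim=784 follows from the 28x28 MNIST images.

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration; only out_dim = 28 * 28 follows from MNIST.
z_dim, hidden_dim, out_dim = 50, 400, 784

# Same layer sequence as Decoder.forward: Linear -> Softplus -> Linear -> Sigmoid.
decoder_sketch = nn.Sequential(
    nn.Linear(z_dim, hidden_dim),
    nn.Softplus(),
    nn.Linear(hidden_dim, out_dim),
    nn.Sigmoid(),
)

z = torch.randn(16, z_dim)  # batch of 16 random latent codes
with torch.no_grad():
    loc_img = decoder_sketch(z)

print(loc_img.shape)
```

Each row of `loc_img` is one flattened image; reshaping a row to 28x28 gives something that can be displayed, although an untrained network will only produce noise.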