This notebook is adapted from https://github.com/pytorch/examples/tree/master/vae, which uses ReLUs and the Adam optimizer instead of the sigmoids and Adagrad used in the original paper. These changes make the network converge faster.

The notebook expects images that are 28x28 pixels in size and encodes each one into 2 latent variables. It uses the MNIST dataset by default and downloads it automatically if you don't already have it.

Before running the notebook, create a directory called "results" in the directory that you run it from.
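
If you would rather not create the directory by hand, it can be created from Python before any image is saved. This is a minimal sketch, assuming the outputs should go under the current working directory; the `results_dir` variable is introduced here for illustration and is not used elsewhere in the notebook.

```python
import os

# Hypothetical variable for this sketch; the notebook hard-codes 'results/'.
results_dir = "results"

# exist_ok=True makes this a no-op when the directory already exists.
os.makedirs(results_dir, exist_ok=True)
```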

The outputs are: 

* **reconstruction_NN.png** - the top row shows 8 randomly selected samples from the test dataset and the bottom row shows the model's reconstruction of each one from its encoded latent variables. NN is the epoch number.

* **sample_NN.png** - 64 images generated by decoding latent variable values drawn at random from a standard normal distribution. NN is the epoch number.

* **manifold.png** - the equivalent of Figure 4(b) in Kingma & Welling.

---

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.distributions import Normal
from torchsummary import summary
from VAE import VAE

First set the parameters of the training:

In [2]:
batch_size=128
epochs=10
no_cuda=True
seed=1
log_interval=100

Decide whether you can use a GPU:

In [3]:
# Use the GPU only if it has not been disabled and one is actually available.
cuda = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

Set a random seed:

In [4]:
torch.manual_seed(seed)

<torch._C.Generator at 0x10a34d710>

Load the training data:

In [5]:
train_loader = torch.utils.data.DataLoader(
                            datasets.MNIST('../data', train=True, download=True,
                            transform=transforms.ToTensor()),
                            batch_size=batch_size, shuffle=True, **kwargs)

Load the test data:

In [6]:
test_loader = torch.utils.data.DataLoader(
                            datasets.MNIST('../data', train=False, 
                            transform=transforms.ToTensor()),
                            batch_size=batch_size, shuffle=True, **kwargs)

Instantiate the network and move it to the selected device:

In [7]:
model = VAE().to(device)

Have a look at the different layers:

In [8]:
summary(model,(1,28,28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 400]         314,000
            Linear-2                    [-1, 2]             802
            Linear-3                    [-1, 2]             802
            Linear-4                  [-1, 400]           1,200
            Linear-5                  [-1, 784]         314,384
Total params: 631,188
Trainable params: 631,188
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 2.41
Estimated Total Size (MB): 2.42
----------------------------------------------------------------
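
`VAE.py` is not reproduced in this notebook. The following is a minimal sketch of a module that would produce the layer summary above, assuming it follows the upstream PyTorch example with a 2-dimensional latent space. The layer names (`fc1`, `fc21`, and so on) and the constructor arguments are assumptions, not the contents of the actual file.

```python
import torch
from torch import nn
from torch.nn import functional as F


class VAE(nn.Module):
    """Sketch of a VAE consistent with the summary; not the actual VAE.py."""

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=2):
        super().__init__()
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)    # encoder hidden layer
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean of q(z|x)
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance of q(z|x)
        self.fc3 = nn.Linear(latent_dim, hidden_dim)   # decoder hidden layer
        self.fc4 = nn.Linear(hidden_dim, input_dim)    # per-pixel probabilities

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc21(h), self.fc22(h)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps keeps the sampling step differentiable
        # with respect to mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        # Flatten each 1x28x28 image into a 784-vector before encoding.
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
```

With the default sizes, the five linear layers hold 314,000 + 802 + 802 + 1,200 + 314,384 = 631,188 parameters, which matches the total reported in the summary.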


Define the optimizer and the learning rate:

In [9]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)

Implement the loss function from the Kingma & Welling paper (Appendix B):

In [10]:
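def loss_function(recon_x, x, mu, logvar):
    # Reconstruction term: binary cross-entropy summed over all pixels
    # and all images in the batch.
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # KL divergence between q(z|x) = N(mu, exp(logvar)) and the N(0, I) prior,
    # in closed form: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD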
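
For reference, the two terms of this function can be written out as follows. This is a restatement in the notation of the code, assuming $\hat{x}$ stands for `recon_x`, $\mu_j$ for `mu`, and $\log\sigma_j^2$ for `logvar`, with $J = 2$ latent dimensions; both sums are additionally taken over every image in the batch.

$$
\mathrm{BCE} = -\sum_{i=1}^{784} \left[ x_i \log \hat{x}_i + (1 - x_i) \log (1 - \hat{x}_i) \right]
$$

$$
\mathrm{KLD} = D_{KL}\!\left( \mathcal{N}(\mu, \sigma^2 I) \,\|\, \mathcal{N}(0, I) \right) = -\frac{1}{2} \sum_{j=1}^{J} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)
$$

Their sum is the negative of the variational lower bound, so minimizing it maximizes the bound.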

Define a function for training:

In [11]:
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

Define a function for testing:

In [12]:
def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
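                # view(-1, ...) infers the batch dimension, so this also works
                # when the batch holds fewer than batch_size images.
                comparison = torch.cat([data[:n],
                                      recon_batch.view(-1, 1, 28, 28)[:n]])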
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

Loop over the epochs, training the model and then evaluating it against the test dataset in each one. After every epoch, 64 images are decoded from random latent variable values and saved:

In [13]:
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        sample = torch.randn(64, 2).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 1, 28, 28),
                   'results/sample_' + str(epoch) + '.png')

====> Epoch: 1 Average loss: 191.6132
====> Test set loss: 172.1348
====> Epoch: 2 Average loss: 168.1606
====> Test set loss: 164.8987
====> Epoch: 3 Average loss: 163.3929
====> Test set loss: 161.6210
====> Epoch: 4 Average loss: 160.7454
====> Test set loss: 159.5269
====> Epoch: 5 Average loss: 158.9231
====> Test set loss: 158.1781
====> Epoch: 6 Average loss: 157.7054
====> Test set loss: 157.0081
====> Epoch: 7 Average loss: 156.6586
====> Test set loss: 156.1518
====> Epoch: 8 Average loss: 155.7604
====> Test set loss: 155.4686
====> Epoch: 9 Average loss: 155.0102
====> Test set loss: 154.8375
====> Epoch: 10 Average loss: 154.3146
====> Test set loss: 154.1702


Recreate Figure 4(b) from Kingma & Welling:

In [14]:
nside = 20
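# Stay strictly inside (0, 1): the inverse CDF is -inf at 0 and +inf at 1,
# which would feed non-finite latent values to the decoder.
x, y = torch.meshgrid([torch.linspace(0.05,0.95,nside), torch.linspace(0.05,0.95,nside)])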
m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
x = m.icdf(x).view(-1,nside**2)
y = m.icdf(y).view(-1,nside**2)
z = torch.cat((x,y),0).t()
with torch.no_grad():
    sample = z.to(device)
    sample = model.decode(sample).cpu()
    save_image(sample.view(nside**2, 1, 28, 28),
               'results/manifold.png', nrow=nside)
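
The grid for this figure is built by spacing points evenly in probability and mapping them through the inverse CDF of the standard normal, so the latent values are densest where the prior puts most of its mass. The following is a standard-library sketch of that mapping along one axis. It uses cell midpoints as the probabilities, which is an assumption of this sketch chosen so that no probability is exactly 0 or 1, and a smaller `nside` for readability.

```python
from statistics import NormalDist

nside = 5  # small grid for illustration; the notebook uses 20

# Evenly spaced probabilities strictly inside (0, 1).
probs = [(i + 0.5) / nside for i in range(nside)]

# Map each probability to the latent value with that cumulative probability
# under the standard normal prior.
prior = NormalDist(mu=0.0, sigma=1.0)
latent = [prior.inv_cdf(p) for p in probs]

print([round(v, 3) for v in latent])
```

The resulting values are symmetric about zero and spread further apart toward the tails.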