In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torch.distributions as distribution
import math

from torchvision.datasets import utils
import torch.utils.data as data_utils
import torch
import os
import numpy as np
from torch import nn
from torch.nn.modules import upsampling
from torch.functional import F
from torch.optim import Adam

In [0]:
# Device configuration: prefer the GPU when PyTorch can see one, otherwise
# fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [0]:
# Model size: number of latent dimensions of the VAE.
latent_variable_dim = 100

# Optimisation settings: minibatch size, Adam step size, and number of
# passes over the training split.
batch_size = 32
learning_rate = 3e-4
epochs = 20

In [0]:
#Code provided to load the data

def get_data_loader(dataset_location, batch_size):
    """Fetch the binarized MNIST splits and wrap each one for PyTorch.

    For each of the "train", "valid" and "test" splits the matching `.amat`
    file is downloaded into `dataset_location` (via
    `torchvision.datasets.utils.download_url`), parsed into a float32 array
    and reshaped to `(N, 1, 28, 28)`.

    Args:
        dataset_location: Directory that holds (or will hold) the `.amat` files.
        batch_size: Batch size given to every `DataLoader`.

    Returns:
        A flat list `[train_loader, train_dataset, valid_loader,
        valid_dataset, test_loader, test_dataset]`. Only the training loader
        shuffles.
    """
    base_url = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"

    def parse_amat(path):
        # Every line of the file is one image: whitespace-separated integers.
        with open(path) as handle:
            rows = handle.readlines()
        return np.array([[int(token) for token in row.split()] for row in rows])

    loaders_and_datasets = []
    for split in ("train", "valid", "test"):
        fname = "binarized_mnist_%s.amat" % split
        utils.download_url(base_url + fname, dataset_location)

        pixels = parse_amat(os.path.join(dataset_location, fname)).astype('float32')
        images = pixels.reshape(pixels.shape[0], 1, 28, 28)

        tensor_dataset = data_utils.TensorDataset(torch.from_numpy(images))
        # NOTE: the loader is built over the raw NumPy array, not over
        # `tensor_dataset`; the two returned objects wrap the same pixels
        # but are iterated differently.
        loader = data_utils.DataLoader(
            images,
            batch_size=batch_size,
            shuffle=(split == "train"),
        )
        loaders_and_datasets += [loader, tensor_dataset]
    return loaders_and_datasets

In [5]:
# get_data_loader returns a flat list of (loader, dataset) pairs in
# train / valid / test order; unpack it pair by pair.
(
    trainloader, train,
    validloader, valid,
    testloader, test,
) = get_data_loader("binarized_mnist", batch_size)

Using downloaded and verified file: binarized_mnist/binarized_mnist_train.amat
Using downloaded and verified file: binarized_mnist/binarized_mnist_valid.amat
Using downloaded and verified file: binarized_mnist/binarized_mnist_test.amat


We used the tutorials at the following links to help us implement our VAE:
http://hameddaily.blogspot.com/2018/12/yet-another-tutorial-on-variational.html
and https://github.com/pytorch/examples/blob/master/vae/main.py

**Train a VAE**

In [0]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian over the latent code; the decoder maps a latent code back to
    per-pixel Bernoulli probabilities (sigmoid output).

    Args:
        latent_dim: Size of the latent code. When omitted, the notebook-level
            `latent_variable_dim` is used, so `VAE()` keeps its previous
            behaviour.
    """

    def __init__(self, latent_dim=None):
        super(VAE, self).__init__()
        # Resolve the latent size once so every method agrees on it instead
        # of each re-reading the global.
        self.latent_dim = latent_variable_dim if latent_dim is None else latent_dim

        # For encoder (fcE1 emits mean and log-variance side by side)
        self.fcE1   = nn.Linear(in_features=256, out_features=2*self.latent_dim)
        self.convE1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.convE2 = nn.Conv2d(in_channels=32 , out_channels=64 , kernel_size=3)
        self.convE3 = nn.Conv2d(in_channels=64 , out_channels=256, kernel_size=5)

        # For decoder
        self.fcD1   = nn.Linear(in_features=self.latent_dim, out_features=256)
        self.convD1 = nn.Conv2d(in_channels=256, out_channels=64, kernel_size=5, padding=4)
        self.convD2 = nn.Conv2d(in_channels=64 , out_channels=32, kernel_size=3, padding=2)
        self.convD3 = nn.Conv2d(in_channels=32 , out_channels=16, kernel_size=3, padding=2)
        self.convD4 = nn.Conv2d(in_channels=16 , out_channels=1 , kernel_size=3, padding=2)
        self.upsampling = nn.UpsamplingBilinear2d(scale_factor=2)

    def encode(self, x):
        """Return `(mu, logvar)` of the approximate posterior for images `x`.

        For 28x28 inputs the spatial size shrinks 28 -> 26 -> 13 -> 11 -> 5 -> 1,
        leaving 256 features per image for the final linear layer.
        """
        h1 = F.elu(self.convE1(x))
        h2 = F.avg_pool2d(h1, kernel_size=2, stride=2)
        h3 = F.elu(self.convE2(h2))
        h4 = F.avg_pool2d(h3, kernel_size=2, stride=2)
        h5 = F.elu(self.convE3(h4))
        h5 = h5.reshape(-1, 256)
        h6 = self.fcE1(h5)
        # First half of the features is the mean, second half the log-variance.
        mu = h6[:, :self.latent_dim]
        logvar = h6[:, self.latent_dim:]
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw `z = mu + eps * std` with `eps ~ N(0, I)` (reparameterisation trick)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Map latent codes `z` to per-pixel probabilities in [0, 1].

        Spatial size grows 1 -> 5 -> 10 -> 12 -> 24 -> 26 -> 28 through the
        padded convolutions and the two 2x bilinear upsamplings.
        """
        h7 = F.elu(self.fcD1(z))
        h7 = h7.reshape(z.shape[0], 256, 1, 1)
        h8 = F.elu(self.convD1(h7))
        h9 = self.upsampling(h8)
        h10 = F.elu(self.convD2(h9))
        h11 = self.upsampling(h10)
        h12 = F.elu(self.convD3(h11))
        return torch.sigmoid(self.convD4(h12))

    def forward(self, x):
        """Encode `x`, sample a latent code, and return `(reconstruction, mu, logvar)`."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def get_sample(self, mu, logvar, x):
        """Sample one latent code per row of `x` from `N(mu, exp(logvar))`.

        Only the leading dimension of `x` is read (it sets the number of
        samples). The noise is created directly on `mu`'s device with `mu`'s
        dtype and latent width, so the method does not depend on the
        module-level `device` or `latent_variable_dim`.
        """
        epsilon = torch.randn(x.shape[0], mu.shape[-1], device=mu.device, dtype=mu.dtype)
        return epsilon * (logvar/2).exp() + mu


In [0]:
# Resolve the target device (first CUDA GPU when available), build the model
# on it, and create an Adam optimizer over the model's parameters.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = VAE().to(device)
optimizer = Adam(model.parameters(), lr=learning_rate)

In [0]:
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO summed over the batch.

    Args:
        recon_x: Decoder output (per-pixel probabilities), viewable as (-1, 784).
        x: Target images, viewable as (-1, 784).
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        Scalar tensor: summed binary cross-entropy plus the KL divergence
        between N(mu, exp(logvar)) and the standard normal.
    """
    flat_recon = recon_x.view(-1, 784)
    flat_target = x.view(-1, 784)
    reconstruction_term = F.binary_cross_entropy(flat_recon, flat_target, reduction='sum')

    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and latent dims.
    kl_per_element = -1 - logvar + mu**2 + logvar.exp()
    kl_term = 0.5 * torch.sum(kl_per_element)

    return reconstruction_term + kl_term

In [9]:
log_interval = batch_size


def _summed_loss(loader):
    """Return `loss_function` summed over every batch of `loader`, without gradients.

    The caller is responsible for putting `model` in the desired mode first.
    """
    total = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            batch_recon, batch_mu, batch_logvar = model(batch)
            total += loss_function(batch_recon, batch, batch_mu, batch_logvar).item()
    return total


for epoch in range(epochs):
    # One optimisation pass over the training split; the running total is the
    # loss seen *during* the epoch, i.e. with the weights still changing.
    model.train()
    train_loss = 0
    for data in trainloader:
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    # Held-out evaluation; both splits are only reported here, neither feeds
    # back into training.
    model.eval()
    valid_loss = _summed_loss(validloader)
    test_loss = _summed_loss(testloader)

    # The loss is the negative ELBO summed over examples, so negate it and
    # divide by the split size to report a per-example ELBO.
    print('Epoch: {} --- ELBO Train: {:.2f} ---  ELBO Validation: {:.2f} ---  ELBO Test: {:.2f}'.format(
        epoch + 1,
        -train_loss / len(trainloader.dataset),
        -valid_loss / len(validloader.dataset),
        -test_loss / len(testloader.dataset),
    ))
    



Epoch: 1 --- ELBO Train: -157.37 ---  ELBO Validation: -121.96 ---  ELBO Test: -120.69
Epoch: 2 --- ELBO Train: -114.38 ---  ELBO Validation: -109.88 ---  ELBO Test: -108.72
Epoch: 3 --- ELBO Train: -106.48 ---  ELBO Validation: -104.90 ---  ELBO Test: -103.84
Epoch: 4 --- ELBO Train: -102.78 ---  ELBO Validation: -101.42 ---  ELBO Test: -100.45
Epoch: 5 --- ELBO Train: -100.57 ---  ELBO Validation: -100.49 ---  ELBO Test: -99.48
Epoch: 6 --- ELBO Train: -99.15 ---  ELBO Validation: -98.85 ---  ELBO Test: -97.99
Epoch: 7 --- ELBO Train: -98.10 ---  ELBO Validation: -98.06 ---  ELBO Test: -97.31
Epoch: 8 --- ELBO Train: -97.25 ---  ELBO Validation: -97.21 ---  ELBO Test: -96.44
Epoch: 9 --- ELBO Train: -96.63 ---  ELBO Validation: -96.76 ---  ELBO Test: -96.05
Epoch: 10 --- ELBO Train: -96.01 ---  ELBO Validation: -96.18 ---  ELBO Test: -95.38
Epoch: 11 --- ELBO Train: -95.56 ---  ELBO Validation: -95.79 ---  ELBO Test: -94.89
Epoch: 12 --- ELBO Train: -95.15 ---  ELBO Validation: -95.2

**Evaluating log-likelihood with Variational Autoencoders**

In [0]:
def log_likelihood(model, data, M, K=200, D=784, L=100):
    """Importance-sampling estimate of log p(x) for the first batch of `data`.

    For each example x the estimate is
    `log( 1/K * sum_k p(x|z_k) p(z_k) / q(z_k|x) )` with `z_k ~ q(z|x)`,
    computed in log space with `logsumexp`.

    Only the FIRST batch of at most `M` examples is scored; the rest of
    `data` is ignored.

    Args:
        model: Object exposing `encode(x) -> (mu, logvar)`,
            `get_sample(mu, logvar, x) -> z` and `decode(z) -> probabilities`.
        data: Dataset accepted by `torch.utils.data.DataLoader`; element 0 of
            the first batch is used as the input images.
        M: Batch size, i.e. the maximum number of examples scored.
        K: Number of importance samples per example (must be >= 1).
        D: Unused; kept so existing calls keep working.
        L: Unused; the latent size is taken from the encoder output. Kept so
            existing calls keep working.

    Returns:
        1-D NumPy array with one log-likelihood estimate per scored example.

    Raises:
        ValueError: If `K < 1` or `data` yields no batch.
    """
    if K < 1:
        raise ValueError("K must be at least 1, got {}".format(K))

    with torch.no_grad():
        loader = torch.utils.data.DataLoader(data, batch_size=M, shuffle=False)
        # Pull just the first batch instead of materialising the whole loader.
        first_batch = next(iter(loader), None)
        if first_batch is None:
            raise ValueError("log_likelihood needs a non-empty dataset")
        x = first_batch[0]

        # Follow the model's own device rather than a module-level global.
        reference_param = next(model.parameters(), None)
        if reference_param is not None:
            x = x.to(reference_param.device)

        mu, logvar = model.encode(x)
        # The last batch of a short dataset can hold fewer than M rows.
        n_examples = x.shape[0]

        posterior = distribution.Normal(mu, (logvar/2).exp())
        # Standard-normal prior shaped like the posterior, so any latent size works.
        prior = distribution.Normal(torch.zeros_like(mu), torch.ones_like(mu))

        log_weights = []
        for _ in range(K):
            z = model.get_sample(mu, logvar, x)
            recon = model.decode(z)
            log_p_z = prior.log_prob(z).sum(1)
            log_q_zx = posterior.log_prob(z).sum(1)
            # Bernoulli log-likelihood = negative BCE, summed over pixels.
            log_p_xz = -F.binary_cross_entropy(recon, x, reduction="none")
            log_p_xz = log_p_xz.reshape(n_examples, -1).sum(1)
            log_weights.append(log_p_xz + log_p_z - log_q_zx)

        # log-mean-exp over the K importance weights.
        log_p_x = torch.logsumexp(torch.stack(log_weights), 0) - math.log(K)
        log_p_x = log_p_x.cpu().numpy()
    return log_p_x

In [11]:
# NOTE: log_likelihood scores only the first M examples of the dataset it is
# given, so this average covers `batch_size` validation images, not the
# whole split.
valid_scores = log_likelihood(model, valid, M=batch_size)
val_log_likelihood = np.mean(valid_scores)
print("Log-likelihood  estimate on Validation: {:.2f}".format(val_log_likelihood))



Log-likelihood  estimate on Validation: -83.33


In [12]:
# NOTE: log_likelihood scores only the first M examples of the dataset it is
# given, so this average covers `batch_size` test images, not the whole split.
test_scores = log_likelihood(model, test, M=batch_size)
test_log_likelihood = np.mean(test_scores)
print("Log-likelihood  estimate on Test: {:.2f}".format(test_log_likelihood))



Log-likelihood  estimate on Test: -95.09
