In [78]:
# imports
#import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST

In [90]:
# Static parameters
batch_size = 100
n_epochs = 1
n_latents = 50
lr = 1e-3
log_interval = 100
seed = 1

# Seed torch's global RNG so weight initialisation, the training-set shuffle
# and the reparameterization noise are reproducible on Restart & Run All.
torch.manual_seed(seed)

# Load data. ToTensor() converts each MNIST digit to a float image tensor; the
# test set is not shuffled so evaluation order is fixed.
train_loader = torch.utils.data.DataLoader(
    MNIST('./data', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)
N_mini_batches = len(train_loader)  # number of training batches per epoch
test_loader = torch.utils.data.DataLoader(
    MNIST('./data', train=False, download=True, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=False)

In [91]:
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """Variational autoencoder built from an MLP Encoder and an MLP Decoder.

    @param n_latents: int
                      dimensionality of the latent space
    """

    def __init__(self, n_latents):
        super(VAE, self).__init__()
        self.encoder = Encoder(n_latents)
        self.decoder = Decoder(n_latents)
        self.n_latents = n_latents

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) with the reparameterization trick.

        In training mode returns mu + eps * std with eps ~ N(0, I), so gradients
        flow back into mu and logvar. In eval mode returns mu unchanged, which
        makes evaluation deterministic.

        @param mu: torch.Tensor
                   latent means
        @param logvar: torch.Tensor
                       latent log-variances, same shape as mu
        @return z: torch.Tensor, same shape as mu
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        # randn_like matches std's shape, dtype and device; the noise itself
        # needs no gradient, so no Variable/.data juggling is required.
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        @param x: torch.Tensor
                  flattened input images (train/test pass image.view(-1, 784))
        @return (x_recon, mu, logvar): decoder output logits (no sigmoid applied)
                and the latent posterior parameters
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, mu, logvar
        
class Encoder(nn.Module):
    """MLP encoder: two ReLU hidden layers, then one linear head that emits
    the latent mean and log-variance side by side.

    @param n_latents: int
                      dimensionality of the latent space
    @param n_input: int [default: 784]
                    number of input features (784 = 28 * 28 flattened pixels)
    @param n_hidden: int [default: 200]
                     width of both hidden layers
    """

    def __init__(self, n_latents, n_input=784, n_hidden=200):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(n_input, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        # Single head for both posterior parameters: first half mu, second half logvar.
        self.fc3 = nn.Linear(n_hidden, 2 * n_latents)
        self.n_latents = n_latents

    def forward(self, x):
        """@param x: torch.Tensor of shape (batch, n_input)
        @return (mu, logvar): two tensors of shape (batch, n_latents)
        """
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        out = self.fc3(h)
        return out[:, :self.n_latents], out[:, self.n_latents:]
    
class Decoder(nn.Module):
    """MLP decoder: two ReLU hidden layers, then a linear layer producing one
    unnormalised logit per output pixel.

    @param n_latents: int
                      dimensionality of the latent space
    @param n_output: int [default: 784]
                     number of output features (784 = 28 * 28 flattened pixels)
    @param n_hidden: int [default: 200]
                     width of both hidden layers
    """

    def __init__(self, n_latents, n_output=784, n_hidden=200):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(n_latents, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, n_output)
        self.n_latents = n_latents

    def forward(self, z):
        """@param z: torch.Tensor of shape (batch, n_latents)
        @return torch.Tensor of shape (batch, n_output): pixel logits
        """
        h = F.relu(self.fc1(z))
        h = F.relu(self.fc2(h))
        out = self.fc3(h)
        return out  # No sigmoid here: the loss applies it via BCE-with-logits.

In [92]:
def elbo_loss(recon_image, image, mu, logvar,
              lambda_image=1.0, annealing_factor=1):
    """Negative ELBO (the quantity to minimise) for a single image modality.

    Per example: lambda_image * sum_pixels BCE(recon_image, image)
                 + annealing_factor * KL(N(mu, exp(logvar)) || N(0, I)),
    then averaged over the batch.

    @param recon_image: torch.Tensor or None
                        reconstructed image logits (pre-sigmoid); if this or
                        `image` is None the reconstruction term is dropped
    @param image: torch.Tensor or None
                  target image (BCE targets, expected in [0, 1])
    @param mu: torch.Tensor
               mean of latent distribution, shape (batch, n_latents)
    @param logvar: torch.Tensor
                   log-variance of latent distribution, same shape as mu
    @param lambda_image: float [default: 1.0]
                         weight for image BCE
    @param annealing_factor: float [default: 1]
                             multiplier for KL divergence term
    @return loss: torch.Tensor
                  scalar batch mean of the negative evidence lower bound
    """
    image_bce = 0  # reconstruction term contributes nothing without an image pair
    if recon_image is not None and image is not None:
        image_bce = torch.sum(binary_cross_entropy_with_logits(
            recon_image.view(-1, 1 * 28 * 28),
            image.view(-1, 1 * 28 * 28)), dim=1)

    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # Closed-form KL between the diagonal Gaussian posterior and N(0, I),
    # summed over latent dimensions.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    neg_elbo = torch.mean(lambda_image * image_bce + annealing_factor * KLD)
    return neg_elbo

def binary_cross_entropy_with_logits(input, target):
    """Sigmoid activation + binary cross-entropy, computed stably from logits.

    Element-wise max(x, 0) - x * t + log(1 + exp(-|x|)), which equals
    -[t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x))] without overflowing
    exp() for large |x|.

    @param input: torch.Tensor
                  logits (any shape)
    @param target: torch.Tensor
                   targets, same shape as input
    @return loss: torch.Tensor
                  unreduced element-wise loss, same shape as input
    @raises ValueError: if target and input shapes differ
    """
    if target.size() != input.size():
        raise ValueError("Target size ({}) must be the same as input size ({})".format(
            target.size(), input.size()))

    # log1p keeps precision when exp(-|x|) is tiny; log(1 + tiny) rounds to 0.
    return (torch.clamp(input, min=0) - input * target
            + torch.log1p(torch.exp(-torch.abs(input))))

In [93]:
class AverageMeter:
    """Keeps the most recent value and a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, running mean, running total and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the mean over `n` new samples and refresh `avg`.

        @param val: latest (per-sample mean) value
        @param n: number of samples `val` summarises [default: 1]
        """
        self.val = val
        self.sum += val * n      # weight the value by its sample count
        self.count += n
        self.avg = self.sum / self.count

In [94]:
# Build the VAE and an Adam optimiser over all of its trainable parameters.
model = VAE(n_latents=n_latents)
optimizer = optim.Adam(params=model.parameters(), lr=lr)

def train(epoch, verbose=False):
    """Run one optimisation epoch over `train_loader` and print the mean loss.

    Uses the module-level `model`, `optimizer`, `train_loader` and
    `log_interval` defined in earlier cells.

    @param epoch: int
                  epoch number, used only in log messages
    @param verbose: bool [default: False]
                    if True, also print the running mean loss every
                    `log_interval` batches
    @return float: example-weighted mean training loss over the epoch
    """
    model.train()
    train_loss_meter = AverageMeter()

    # Labels are ignored: the VAE is trained on images alone.
    for batch_idx, (image, _) in enumerate(train_loader):
        n_images = len(image)  # the final batch may be smaller than batch_size

        # refresh the optimizer
        optimizer.zero_grad()

        # pass flattened 28x28 images through the model
        recon_image, mu, logvar = model(image.view(-1, 1 * 28 * 28))

        # compute the negative ELBO; .item() keeps only a Python float in the meter
        train_loss = elbo_loss(recon_image, image, mu, logvar)
        train_loss_meter.update(train_loss.item(), n_images)

        # compute gradients and take step
        train_loss.backward()
        optimizer.step()

        if verbose and batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(image), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), train_loss_meter.avg))

    print('====> Epoch: {}\tLoss: {:.4f}'.format(epoch, train_loss_meter.avg))
    return train_loss_meter.avg


def test(epoch):
    """Evaluate the module-level `model` on `test_loader` and print the mean loss.

    model.eval() makes VAE.reparameterize return mu, so the loss is deterministic.

    @param epoch: int
                  epoch number (not used in the computation; kept so the call
                  mirrors train(epoch))
    @return float: example-weighted mean test loss
    """
    model.eval()
    test_loss_meter = AverageMeter()

    # Evaluation needs no gradients; skipping graph construction saves memory/time.
    with torch.no_grad():
        for image, _ in test_loader:
            n_images = len(image)
            recon_image, mu, logvar = model(image.view(-1, 1 * 28 * 28))
            test_loss = elbo_loss(recon_image, image, mu, logvar)
            test_loss_meter.update(test_loss.item(), n_images)

    print('====> Test Loss: {:.4f}'.format(test_loss_meter.avg))
    return test_loss_meter.avg

In [95]:
# Alternate one training epoch with one evaluation pass; afterwards `test_loss`
# holds the value returned by the final epoch's test(...) call.
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test_loss = test(epoch)

====> Epoch: 1	Loss: 189.6260
====> Test Loss: 152.8628
