In [None]:
%pylab inline
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch import nn, optim
from VAE import *

In [None]:
# declare loss functions

# def log_bernoulli_loss(x_hat, x):
#     x = torch.unsqueeze(x, 1)
#     x_hat = torch.transpose(torch.unsqueeze(x_hat, 1), 1, 2)
#     loss = torch.bmm(x, torch.log(x_hat))
#     loss += torch.bmm(1 - x, torch.log(1 - x_hat))
#     return -torch.sum(loss)

def log_bernoulli_loss(x_hat, x, eps=1e-12):
    """Bernoulli negative log-likelihood of ``x`` under probabilities ``x_hat``.

    Computes ``-sum(x * log(x_hat) + (1 - x) * log(1 - x_hat))`` over every
    element, i.e. the summed binary cross-entropy.

    Args:
        x_hat: predicted probabilities, same shape as ``x``; expected in [0, 1].
        x: targets in [0, 1], same shape as ``x_hat``.
        eps: lower bound applied to the argument of each log. Without it, a
            saturated prediction (exactly 0 or 1) yields ``log(0) = -inf`` and
            ``0 * -inf = nan``, which corrupts both the loss and its gradient.

    Returns:
        Scalar tensor: the loss summed over all elements.
    """
    log_p = torch.log(torch.clamp(x_hat, min=eps))
    log_one_minus_p = torch.log(torch.clamp(1 - x_hat, min=eps))
    return -torch.sum(x * log_p + (1 - x) * log_one_minus_p)

# def KL_loss(mu, logvar):
#     _, D = mu.size()
#     var = torch.exp(logvar)
#     trace = torch.sum(var, dim=1)
#     logsum = torch.sum(logvar, dim=1)
#     mu = torch.unsqueeze(mu, 1)
#     mu_hat = torch.transpose(mu, 1, 2)
#     loss = 0.5 * (trace + torch.bmm(mu, mu_hat).squeeze() - logsum - D)
#     return torch.sum(loss)

def KL_loss(mu, logvar):
    """KL(q || p) for q = N(mu, diag(exp(logvar))) and prior p = N(0, I), summed.

    Uses the closed form, per row:
        0.5 * sum_j (exp(logvar_j) + mu_j**2 - 1 - logvar_j)
    The prior's covariance is the identity, so its log-determinant is 0 and
    no extra log term appears. The result is zero exactly when q equals the
    prior.

    Args:
        mu: 2-D tensor of means; each row is one sample, each column one
            latent dimension.
        logvar: tensor of log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: the per-row KL values summed over all rows.
    """
    # Plain Python int, so the arithmetic follows mu's dtype and device
    # (a FloatTensor constant would force CPU float32).
    D = mu.size(1)
    sum_logvar = torch.sum(logvar, 1)
    norm_var = torch.sum(torch.exp(logvar), 1)
    norm_mu = torch.sum(mu * mu, 1)
    loss = (norm_var + norm_mu - sum_logvar - D) / 2
    return torch.sum(loss)

def loss_function(x_hat, x, mu, logvar):
    """Total VAE objective to minimise: reconstruction term plus KL term.

    Args:
        x_hat: decoder output probabilities, passed to ``log_bernoulli_loss``.
        x: flattened input batch, the reconstruction target.
        mu, logvar: encoder outputs, passed to ``KL_loss``.

    Returns:
        ``log_bernoulli_loss(x_hat, x) + KL_loss(mu, logvar)``.
    """
    reconstruction = log_bernoulli_loss(x_hat, x)
    divergence = KL_loss(mu, logvar)
    return reconstruction + divergence

In [None]:
def train(epoch, train_loader, model, optimizer, log_interval=100):
    """Run one training epoch and print progress.

    Args:
        epoch: epoch number; used only in the printed messages.
        train_loader: iterable of ``(images, labels)`` batches with a
            ``.dataset`` attribute; labels are ignored.
        model: module whose forward returns ``(x_hat, mu, logvar)``.
        optimizer: optimizer over ``model``'s parameters.
        log_interval: print the per-sample batch loss every this many batches.

    Returns:
        Average per-sample loss over the epoch, as a Python float.
    """
    model.train()
    train_loss = 0.0
    for batch_idx, (data, _) in enumerate(train_loader):
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        # Flatten each image to a vector; the target then matches the model
        # output for any image size instead of assuming 784 pixels.
        loss = loss_function(recon_batch, data.view(data.size(0), -1), mu, logvar)
        loss.backward()
        optimizer.step()
        # .item() yields a plain float, so the running total keeps no tensors
        # alive and formats cleanly.
        batch_loss = loss.item()
        train_loss += batch_loss
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                batch_loss / len(data)))

    avg_loss = train_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss

In [None]:
# set learning rate, batch size, number of epochs and the random seed

# Seed torch's global RNG first; it drives parameter initialisation and
# DataLoader shuffling, so re-running the notebook gives the same results.
seed = 1
torch.manual_seed(seed)

lr = 1e-3
batch_size = 20
epochs = 3

In [None]:
# Load the MNIST training split as float tensors via ToTensor().
train_data = datasets.MNIST('../data', train=True, download=True,
                            transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True)

# Build the model (784 = 28 * 28, presumably the flattened input size VAE
# expects - confirm against VAE's constructor) and its Adam optimizer.
VAE_MNIST = VAE(784)
optimizer = optim.Adam(VAE_MNIST.parameters(), lr=lr)

# Run `epochs` full passes over the training data (epochs numbered from 1).
for epoch in range(1, epochs + 1):
    train(epoch, train_loader, VAE_MNIST, optimizer)

In [None]:
### Let's check if the reconstructions make sense
# Put the model in evaluation mode for inference.
VAE_MNIST.eval()

# Unshuffled loader with batch size 1, so the same first images are shown on
# every run.
train_data_plot = datasets.MNIST('../data', train=True, download=True,
                                 transform=transforms.ToTensor())
train_loader_plot = torch.utils.data.DataLoader(train_data_plot,
                                                batch_size=1, shuffle=False)

n_plot = 4  # number of examples to display

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for batch_idx, (data, label) in enumerate(train_loader_plot):
        if batch_idx >= n_plot:
            break
        x_hat, mu, logvar = VAE_MNIST(data)
        # Show the input next to its reconstruction, so "makes sense" can
        # actually be judged. The label comes from this loader, not from
        # another cell's dataset.
        fig, (ax_in, ax_out) = plt.subplots(1, 2)
        ax_in.imshow(data.view(28, 28).numpy(), cmap='gray')
        ax_in.set_title('input: %i' % label.item())
        ax_out.imshow(x_hat.view(28, 28).numpy(), cmap='gray')
        ax_out.set_title('reconstruction')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures in the loop