In [1]:
%pylab inline
import torch
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch import nn, optim
import VAE

Populating the interactive namespace from numpy and matplotlib


In [2]:
# declare loss functions

def log_bernoulli_loss(x_hat, x):
    
    loss = - torch.sum(x * torch.log(x_hat) + (1-x) *(torch.log(1-x_hat)), 1)
    
    return torch.sum(loss)


def KL_loss(mu, logvar):
    
    D = torch.FloatTensor([mu.size(1)])
    log_D = torch.log(D)
    sum_logvar = torch.sum(logvar, 1)
    norm_var = torch.sum(torch.exp(logvar), 1)
    norm_mu = torch.sum(mu * mu, 1)
    loss = (log_D - sum_logvar + norm_var + norm_mu - D)/2
    
    return torch.sum(loss)


def loss_function(x_hat, x, mu, logvar):
    
    log_bernoulli = log_bernoulli_loss(x_hat, x)
    KL = KL_loss(mu, logvar)
    
    return log_bernoulli+KL

In [3]:
from torch import nn
from torch.nn import functional as F 

class VAE(nn.Module):
    def __init__(self, fc1_dims, fc21_dims, fc22_dims, fc3_dims, fc4_dims):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(*fc1_dims)
        self.fc21 = nn.Linear(*fc21_dims)
        self.fc22 = nn.Linear(*fc22_dims)
        self.fc3 = nn.Linear(*fc3_dims)
        self.fc4 = nn.Linear(*fc4_dims)

    def encode(self, x):
        embedding = F.relu(self.fc1(x))
        mu = F.sigmoid(self.fc21(embedding))
        logvar = F.tanh(self.fc22(embedding))
        return mu, logvar

    def reparameterize(self, mu, logvar):
        epsilon = torch.normal(torch.zeros(logvar.size()), torch.ones(logvar.size()))
        sigma = torch.sqrt(torch.exp(logvar))
        z = mu + epsilon * logvar
        return z

    def decode(self, z):
        x_hat =  F.sigmoid(self.fc4(self.fc3(z)))
        return x_hat

    def forward(self, x):
        x = x.view(-1, 784)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return x_hat, mu, logvar

In [4]:
def train(epoch, train_loader, model, optimizer):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = Variable(data)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data.view(-1, 784), mu, logvar)
        loss.backward()
        train_loss += loss.data
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.data / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

In [5]:
# Hyperparameters
fc1_dims = (784,400)
fc21_dims = (400,20)
fc22_dims = (400,20)
fc3_dims = (20, 400)
fc4_dims = (400,784)
lr = 0.001
batch_size = 32
epochs = 10

In [6]:
# Load data
train_data = datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=batch_size, shuffle=True, **{})

# Init model
VAE_MNIST = VAE(fc1_dims=fc1_dims, fc21_dims=fc21_dims, fc22_dims=fc22_dims, fc3_dims=fc3_dims, fc4_dims=fc4_dims)

# Init optimizer
optimizer = optim.Adam(VAE_MNIST.parameters(), lr=lr)

# Train
for epoch in range(1, epochs + 1):
    train(epoch, train_loader, VAE_MNIST, optimizer)

====> Epoch: 1 Average loss: 137.8186
====> Epoch: 2 Average loss: 100.6542
====> Epoch: 3 Average loss: 91.7727
====> Epoch: 4 Average loss: 88.3313
====> Epoch: 5 Average loss: 86.8586
====> Epoch: 6 Average loss: 86.0654
====> Epoch: 7 Average loss: 85.5360
====> Epoch: 8 Average loss: 85.0949


====> Epoch: 9 Average loss: 84.8151
====> Epoch: 10 Average loss: 84.5317


In [None]:
### Let's check if the reconstructions make sense
# Set model to test mode
VAE_MNIST.eval()
    
# Reconstructed
train_data_plot = datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor())

train_loader_plot = torch.utils.data.DataLoader(train_data_plot,
                                           batch_size=1, shuffle=False, **{})

for batch_idx, (data, _) in enumerate(train_loader_plot):
    x_hat, mu, logvar = VAE_MNIST(data)
    plt.imshow(x_hat.view(1,28,28).squeeze().data.numpy(), cmap='gray')
    plt.title('%i' % train_data.train_labels[batch_idx])
    plt.show()
    if batch_idx == 3:
        break