In [1]:
import os
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import pytorch_model_summary

if not os.path.exists('./VAE_img'):
    os.mkdir('./VAE_img')

In [2]:
def normalization(tensor, min_value, max_value):
    min_tensor = tensor.min()
    tensor = (tensor - min_tensor)
    max_tensor = tensor.max()
    tensor = tensor / max_tensor
    tensor = tensor * (max_value - min_value) + min_value
    return tensor

def value_round(tensor):
    return torch.round(tensor)

def to_img(x):
    x = x.view(x.size(0), 1, 28, 28)
    return x

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda tensor:normalization(tensor, 0, 1)),
    transforms.Lambda(lambda tensor:value_round(tensor))
])
batch_size = 128

dataset = MNIST('./MNIST_dataset', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [3]:
class VariationalAutoencoder(nn.Module):
    def __init__(self):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 400),
            nn.ReLU(True),
            nn.Linear(400, 40))
        self.decoder = nn.Sequential(
            nn.Linear(20, 400),
            nn.ReLU(True),
            nn.Linear(400, 28 * 28),
            nn.Sigmoid())

    def reparametrization(self, mu, logvar):
        var = logvar.exp()
        std = var.sqrt()
        eps = Variable(torch.cuda.FloatTensor(std.size()).normal_())
        return eps.mul(std).add(mu)

    def forward(self, x):
        h = self.encoder(x)
        mu = h[:, :20]
        logvar = h[:, 20:]
        z = self.reparametrization(mu, logvar)
        x_gen = self.decoder(z)
        return x_gen, mu, logvar

    def interpolation(self, x_1, x_2, alpha):
        traverse_1 = self.encoder(x_1)
        traverse_2 = self.encoder(x_2)
        mu_1, mu_2 = traverse_1[:, :20], traverse_2[:, :20]
        logvar_1, logvar_2 = traverse_1[:, 20:], traverse_2[:, 20:]
        traverse_m = (1 - alpha) * mu_1 + alpha * mu_2
        traverse_logvar = (1 - alpha) * logvar_1 + alpha * logvar_2
        z = self.reparametrization(traverse_m, traverse_logvar)
        generated_image = self.decoder(z)
        return generated_image

In [4]:
model = VariationalAutoencoder().cuda()
print(pytorch_model_summary.summary(model, torch.zeros(1,784).cuda(), show_input = True))

-----------------------------------------------------------------------
      Layer (type)         Input Shape         Param #     Tr. Param #
          Linear-1            [1, 784]         314,000         314,000
            ReLU-2            [1, 400]               0               0
          Linear-3            [1, 400]          16,040          16,040
          Linear-4             [1, 20]           8,400           8,400
            ReLU-5            [1, 400]               0               0
          Linear-6            [1, 400]         314,384         314,384
         Sigmoid-7            [1, 784]               0               0
Total params: 652,824
Trainable params: 652,824
Non-trainable params: 0
-----------------------------------------------------------------------


  eps = Variable(torch.cuda.FloatTensor(std.size()).normal_())


In [12]:
BCE = nn.BCELoss()
num_epochs, learning_rate = 50, 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = img.view(img.size(0), -1)
        img = Variable(img).cuda()
        x_gen, mu, logvar = model(img)
        NKLD = mu.pow(2).add(logvar.exp()).mul(-1).add(logvar.add(1))
        KLD = torch.sum(NKLD).mul(-0.5)
        KLD /= batch_size * 784
        loss = BCE(x_gen, img) + KLD
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    if epoch % 10 == 0 or (epoch+1) == num_epochs:
        print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch + 1, num_epochs, loss.item()))
        x_gt = to_img(img.cpu().data)
        x_gen = to_img(x_gen.cpu().data)
        save_image(x_gt, './VAE_img/ground_truth_{}.png'.format(epoch))
        save_image(x_gen, './VAE_img/generated_x{}.png'.format(epoch))
        batch = next(iter(dataloader))
        batch = batch[0].clone().detach()
        batch = batch.view(batch.size(0), -1)
        batch = Variable(batch).cuda()
        x_1 = batch[0:1]
        x_2 = batch[1:2]
        generated_images = []
        for alpha in torch.arange(0.0, 1.0, 0.1):
            generated_images.append(model.interpolation(
                x_1, x_2, alpha))
        generated_images = torch.cat(generated_images, 0).cpu().data
        save_image(generated_images.view(-1, 1, 28, 28),
                   './VAE_img/interpolation_{}.png'.format(epoch),
                   nrow=1)
torch.save(model.state_dict(), './variational_autoencoder.pth')

epoch [1/50], loss:0.1465
epoch [11/50], loss:0.1034
epoch [21/50], loss:0.0949
epoch [31/50], loss:0.0933
epoch [41/50], loss:0.0999
epoch [50/50], loss:0.0914
