In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import time
import os

In [None]:
# Both MNIST splits are stored under the same root directory and go through
# the same ToTensor transform; only the `train` flag differs.
MNIST_ROOT = '/home/jovyan/MNIST/'
to_tensor = transforms.ToTensor()

mnist_train_data = datasets.MNIST(root=MNIST_ROOT, train=True, download=True, transform=to_tensor)
mnist_val_data = datasets.MNIST(root=MNIST_ROOT, train=False, download=True, transform=to_tensor)

In [None]:
# Training hyperparameters (tune these here, nowhere else)
BATCH_SIZE = 32        # samples per mini-batch handed to the data loaders
LEARNING_RATE = 1e-3   # learning rate given to the optimizer
NUM_EPOCHS = 50        # number of iterations of the outer training loop


In [None]:
# Mini-batch loaders: shuffling is enabled for the training split and
# disabled for the validation split; both use BATCH_SIZE.
train_dataloader = torch.utils.data.DataLoader(
    dataset=mnist_train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

val_dataloader = torch.utils.data.DataLoader(
    dataset=mnist_val_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input through two ReLU layers to the mean and
    log-variance of a diagonal Gaussian over the latent space. The decoder
    mirrors the encoder and ends in a sigmoid, so every output lies in [0, 1].

    Args:
        input_dim: Number of features per flattened input (default 784).
        hidden_dims: Widths ``(first, second)`` of the two encoder hidden
            layers; the decoder uses them in reverse order (default (512, 256)).
        latent_dim: Dimensionality of the latent code (default 2).

    The default arguments reproduce the original fixed 784-512-256-2 layout.
    Layer attribute names are part of the ``state_dict`` keys and must not be
    renamed, otherwise previously saved checkpoints stop loading.
    """

    def __init__(self, input_dim=784, hidden_dims=(512, 256), latent_dim=2):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        hidden_first, hidden_second = hidden_dims

        # encoder part
        self.layer1 = nn.Linear(input_dim, hidden_first)
        self.layer2 = nn.Linear(hidden_first, hidden_second)
        self.encoded_mean = nn.Linear(hidden_second, latent_dim)
        self.encoded_logvar = nn.Linear(hidden_second, latent_dim)
        # decoder part
        self.decoder1 = nn.Linear(latent_dim, hidden_second)
        self.decoder2 = nn.Linear(hidden_second, hidden_first)
        self.decoder3 = nn.Linear(hidden_first, input_dim)

    def encoder(self, x):
        """Return ``(mu, log_var)`` for a batch of flattened inputs ``x``."""
        h = F.relu(self.layer1(x))
        h = F.relu(self.layer2(h))
        return self.encoded_mean(h), self.encoded_logvar(h)  # mu, log_var

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling through this expression (the reparameterisation trick) keeps
        ``z`` differentiable with respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        # Out-of-place arithmetic: neither input tensor is modified.
        return mu + eps * std

    def decoder(self, z):
        """Map latent codes ``z`` to outputs in [0, 1] of width ``input_dim``."""
        h = F.relu(self.decoder1(z))
        h = F.relu(self.decoder2(h))
        # torch.sigmoid replaces the deprecated F.sigmoid alias.
        return torch.sigmoid(self.decoder3(h))

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        ``x`` is flattened to ``(-1, input_dim)`` first, so image-shaped
        batches can be passed directly.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encoder(x.view(-1, self.input_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var
    
    



In [None]:
# Instantiate the VAE; when CUDA is usable the module is moved to the GPU,
# otherwise it stays on the CPU.
vae = VAE().cuda() if torch.cuda.is_available() else VAE()

In [None]:
# Adam over all parameters of the VAE, using the configured learning rate.
optimizer = optim.Adam(params=vae.parameters(), lr=LEARNING_RATE)


In [None]:
# return reconstruction error + KL divergence losses
def loss_function(recon_x, x, mu, log_var):
    """VAE objective: summed reconstruction BCE plus KL divergence.

    Args:
        recon_x: Decoder output, compared against ``x`` flattened to 784 columns.
        x: Original inputs; reshaped with ``view(-1, 784)`` before comparison.
        mu: Latent means produced by the encoder.
        log_var: Latent log-variances produced by the encoder.

    Returns:
        Scalar tensor; both terms are summed (not averaged) over the batch.
    """
    flat_target = x.view(-1, 784)
    reconstruction = F.binary_cross_entropy(recon_x, flat_target, reduction='sum')
    # Closed-form KL between N(mu, exp(log_var)) and the standard normal.
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    divergence = -0.5 * torch.sum(kl_terms)
    return reconstruction + divergence

In [None]:
def train(vae, optimizer):
    """Run one optimisation pass over the global ``train_dataloader``.

    Each batch is moved to the device the model's parameters live on before
    the forward pass; without this the model (moved to CUDA when available)
    and the CPU batches end up on different devices and the forward call
    raises.

    Args:
        vae: Model whose forward call returns ``(reconstruction, mu, log_var)``.
        optimizer: Optimizer bound to ``vae``'s parameters.

    Returns:
        Mean loss per sample over the pass, as a Python float (0.0 if the
        loader yielded no samples). Callers that ignore the return value are
        unaffected.
    """
    vae.train()
    device = next(vae.parameters()).device

    total_loss = 0.0
    num_samples = 0
    for data, _ in train_dataloader:
        data = data.to(device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        optimizer.step()

        # loss_function sums over the batch, so accumulate the sum and
        # normalise once at the end.
        total_loss += loss.item()
        num_samples += data.size(0)

    return total_loss / max(num_samples, 1)
        


In [None]:
def plot_synthetic_data(epoch, vae):
    """Decode 64 random latent codes and save them as one image grid.

    The grid is written to ``./figures/sample_epoch_<epoch>.png``; the
    directory is created if it does not exist yet.

    Args:
        epoch: Value embedded in the output filename.
        vae: Model providing ``decoder``; latent codes are moved to the
            device of its parameters before decoding.
    """
    os.makedirs('./figures', exist_ok=True)
    device = next(vae.parameters()).device

    with torch.no_grad():
        # Sample on the CPU (same generator as before), then move to the model.
        z = torch.randn(64, 2).to(device)
        sample = vae.decoder(z).cpu()

    save_image(sample.view(64, 1, 28, 28), './figures/sample_epoch_' + str(epoch) + '.png')

In [None]:
# Samples and checkpoints are all written below ./figures.
os.makedirs('./figures', exist_ok=True)

# Baseline samples from the untrained model (labelled epoch -1).
plot_synthetic_data(-1, vae)
for epoch in range(NUM_EPOCHS):
    print("Epoch: ", epoch)
    train(vae, optimizer)
    plot_synthetic_data(epoch, vae)

    # Checkpoint every 10th epoch AND after the last one; previously the
    # final epochs were never saved (e.g. 50 epochs ended with epoch-40 weights).
    if epoch % 10 == 0 or epoch == NUM_EPOCHS - 1:
        state_dict = {'weights': vae.state_dict(),
                      'epoch': epoch
                      }
        torch.save(state_dict, './figures/model_weights.pth')

In [None]:
def plot_latent(vae, dataloader, num_batches=100):
    """Scatter-plot the encoder means of up to ``num_batches`` batches.

    Points are coloured by their label with the ``tab10`` colormap. All
    points are drawn in a single scatter call so that every label maps to
    the same colour, and the colorbar is always drawn, including when the
    loader is exhausted before ``num_batches`` is reached.

    Args:
        vae: Model providing ``encoder``; inputs are flattened to 784 columns
            and moved to the device of its parameters.
        dataloader: Iterable yielding ``(inputs, labels)`` tensor batches.
        num_batches: Maximum number of batches to plot.
    """
    device = next(vae.parameters()).device

    means, labels = [], []
    with torch.no_grad():
        for i, (x, y) in enumerate(dataloader):
            if i >= num_batches:
                break
            mu, _ = vae.encoder(x.view(-1, 784).to(device))
            means.append(mu.cpu())
            labels.append(y.cpu())

    if not means:
        # Nothing to draw (empty loader or num_batches <= 0).
        return

    means = torch.cat(means).numpy()
    labels = torch.cat(labels).numpy()
    points = plt.scatter(means[:, 0], means[:, 1], c=labels, cmap='tab10', alpha=0.5, s=2)
    plt.colorbar(points, alpha=1)

In [None]:
# Latent-space scatter of the validation split.
plot_latent(vae, dataloader=val_dataloader, num_batches=1000)

In [None]:
# pick an (x,y) to reconstruct
latent_xy = (-2., -1.)

os.makedirs('./figures', exist_ok=True)
with torch.no_grad():
    # Decode on whichever device the model lives on, then bring the image back
    # to the CPU so it can be saved and plotted.
    z = torch.tensor(latent_xy, device=next(vae.parameters()).device)
    sample = vae.decoder(z).cpu()

# The filename is derived from the chosen point rather than from the `epoch`
# variable left over from the training loop, so this cell also works when the
# training cell has not been run in the current kernel.
save_image(sample.view(1, 28, 28), './figures/sample_z_{}_{}.png'.format(*latent_xy))

In [None]:
# .cpu() is required before .numpy() when the model (and thus `sample`) is on the GPU.
plt.imshow(sample.detach().cpu().view(28, 28).numpy())

In [None]:
def plot_reconstructed(vae, y0=(-3.5, 3.5), x0=(-4, 2), n=12):
    """Decode an ``n x n`` grid of 2-D latent points and show them as tiles.

    The first latent coordinate runs left to right over ``x0`` and the second
    runs bottom to top over ``y0``, matching the ``extent`` passed to
    ``imshow``.

    Args:
        vae: Model providing ``decoder``; each decoded row is reshaped to a
            28 x 28 tile.
        y0: ``(min, max)`` range of the second latent coordinate.
        x0: ``(min, max)`` range of the first latent coordinate.
        n: Number of grid points along each axis.
    """
    w = 28
    device = next(vae.parameters()).device

    xs = np.linspace(*x0, n)
    ys = np.linspace(*y0, n)
    # Default 'xy' indexing: row i holds ys[i], column j holds xs[j].
    grid_x, grid_y = np.meshgrid(xs, ys)
    z = torch.tensor(
        np.stack([grid_x.ravel(), grid_y.ravel()], axis=1),
        dtype=torch.float32,
        device=device,
    )

    # One batched decoder call instead of n*n single-row calls.
    with torch.no_grad():
        x_hat = vae.decoder(z).cpu().numpy()

    # (row i, col j, h, w) -> flip rows so the largest y ends up at the top,
    # then interleave tile rows with pixel rows to form one big image.
    tiles = x_hat.reshape(n, n, w, w)[::-1]
    img = tiles.transpose(0, 2, 1, 3).reshape(n * w, n * w)
    plt.imshow(img, extent=[*x0, *y0])

In [None]:
# Tile the decoded latent grid, 12 points per axis.
plot_reconstructed(vae=vae, n=12)

In [None]:
def plot_interpolation(vae, p1, p2, n=20):
    """Show ``n`` decoded images along the straight line from ``p1`` to ``p2``.

    Args:
        vae: Model providing ``decoder``; each decoded row is reshaped to a
            28 x 28 tile.
        p1: Start point in latent space (tensor or sequence of floats).
        p2: End point in latent space, same number of elements as ``p1``.
        n: Number of interpolation steps, endpoints included.

    Raises:
        ValueError: If ``p1`` or ``p2`` is empty or their sizes differ.
    """
    device = next(vae.parameters()).device
    p1 = torch.as_tensor(p1, dtype=torch.float32, device=device).reshape(-1)
    p2 = torch.as_tensor(p2, dtype=torch.float32, device=device).reshape(-1)
    if p1.numel() == 0 or p1.shape != p2.shape:
        raise ValueError(
            "p1 and p2 must be non-empty latent points of equal size, got "
            "{} and {} elements".format(p1.numel(), p2.numel())
        )

    # Column of interpolation weights; broadcasting yields one latent row per step.
    t = torch.linspace(0, 1, n, device=device).unsqueeze(1)
    z = p1 + (p2 - p1) * t

    with torch.no_grad():
        interpolate_list = vae.decoder(z).cpu().numpy()

    w = 28
    # Lay the decoded tiles side by side, left (p1) to right (p2).
    img = np.hstack([x_hat.reshape(w, w) for x_hat in interpolate_list])
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])

In [None]:
# Endpoints must be 2-D latent points; the empty tensors used before made the
# decoder fail on a shape mismatch. Adjust these to any two locations of
# interest in the latent scatter plot.
p1 = torch.tensor([-2., -1.])
p2 = torch.tensor([1., 2.])
plot_interpolation(vae, p1, p2, n=20)