# Variational Autoencoder

In [None]:
import os
from itertools import chain

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm.autonotebook import tqdm

In [None]:

# Mini-batch size and the length of one flattened 28x28 image.
batch_size = 256
image_dim = 784

# Convert each image to a tensor, then flatten it into a single vector.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.view(image_dim)),
    ]
)

# FashionMNIST training split, stored under ./data/FashionMNIST.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transform,
)

# No shuffle argument is passed, so the DataLoader default applies.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)

In [None]:

class Encoder(nn.Module):
    """Map an input vector to the parameters of a diagonal Gaussian.

    One dense hidden layer with ReLU activation feeds two independent
    linear heads: ``fc2`` produces the mean and ``fc2_sigma`` produces
    the log-scale term.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.fc2_sigma = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the pair ``(mu, logsigma)`` computed from ``x``."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden), self.fc2_sigma(hidden)
        
class Decoder(nn.Module):
    """Map a latent vector back to an output vector with values in [0, 1].

    One dense hidden layer with ReLU activation is followed by a linear
    output layer whose result is passed through a sigmoid.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the sigmoid-squashed decoding of ``x``."""
        hidden = F.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(hidden))

In [None]:
def sample(mu, log_sigma2):
    """Draw ``z ~ N(mu, sigma^2)`` using the reparameterisation trick.

    Args:
        mu: Tensor of Gaussian means.
        log_sigma2: Tensor of log-variances, broadcastable with ``mu``.

    Returns:
        ``mu + sigma * eps`` where ``sigma = exp(log_sigma2 / 2)`` and
        ``eps`` is standard-normal noise shaped like ``mu``.
    """
    # randn_like copies mu's shape, dtype and device, so the noise matches
    # inputs of any rank and tensors that do not live on the CPU.
    eps = torch.randn_like(mu)

    return mu + torch.exp(log_sigma2 / 2) * eps

In [None]:
# Hidden-layer widths passed to the encoder and the decoder.
enc_hidden_units = 512
dec_hidden_units = 512

# Size of the latent space (encoder output / decoder input).
embedding_dim = 2

# Number of passes over the training data.
nEpoch = 10

In [None]:

# Build both halves of the autoencoder: image -> latent and latent -> image.
enc = Encoder(image_dim, enc_hidden_units, embedding_dim)
dec = Decoder(embedding_dim, dec_hidden_units, image_dim)

# A single Adam optimiser updates encoder and decoder weights together.
optimizer = optim.Adam(
    chain(enc.parameters(), dec.parameters()),
    lr=1e-3,
)


In [None]:

# Mean training loss of each epoch, in epoch order.
loss_plot = []

# Training loop: minimise reconstruction loss + KL divergence.
for epoch in range(nEpoch):
    losses = []  # per-batch losses of the current epoch
    trainloader = tqdm(train_loader)

    # The class labels are not used, so they are discarded.
    for inputs, _ in trainloader:
        optimizer.zero_grad()

        # Encode, draw a latent sample, decode.
        mu, log_sigma2 = enc(inputs)
        z = sample(mu, log_sigma2)
        outputs = dec(z)

        # Reconstruction term, -E[log P(X|z)]: binary cross entropy between
        # the inputs and the decoder's sigmoid outputs. By default torch
        # averages over every element (each pixel of each image); summing
        # and dividing by the batch size gives the average per image instead.
        recon = F.binary_cross_entropy(outputs, inputs, reduction='sum') / inputs.shape[0]

        # KL(Q(z|X) || P(z)) against a standard-normal prior P(z), in closed
        # form for a diagonal Gaussian: summed over latent dimensions, then
        # averaged over the batch.
        kl = torch.mean(-0.5 * torch.sum(1 + log_sigma2 - mu ** 2 - log_sigma2.exp(), dim=1), dim=0)
        loss = recon + kl

        loss.backward()
        optimizer.step()

        # Track the loss and show the running epoch mean on the progress bar.
        losses.append(loss.item())
        trainloader.set_postfix(loss=np.mean(losses), epoch=epoch)

    loss_plot.append(np.average(losses))

In [None]:
# Reconstructions from the last training batch as a stack of 28x28 images.
# -1 lets NumPy infer the batch dimension instead of hard-coding the size
# of the final (partial) batch.
outputs_d = outputs.detach().numpy().reshape(-1, 28, 28)

In [None]:
# Top row: inputs of the last training batch. Bottom row: the decoder's
# reconstructions of those same inputs, so each column is a matching pair.
# (`outputs_d` holds the outputs for the batch stored in `inputs`.)
n_show = 4
originals = inputs.detach().numpy().reshape(-1, 28, 28)

for i in range(n_show):
    plt.subplot(2, n_show, i + 1)
    plt.imshow(originals[i], cmap=plt.get_cmap('gray'))

    plt.subplot(2, n_show, n_show + i + 1)
    plt.imshow(outputs_d[i], cmap=plt.get_cmap('gray'))


In [None]:
# Display the shapes of the encoder outputs left over from the last batch.
mu.shape, log_sigma2.shape

In [None]:
# Gaussian parameters of the first item in the last training batch.
mean = mu[0].detach().numpy()
# The encoder's second output is the log-variance (it is used as such by
# `sample` and by the KL term), so convert it to a standard deviation.
sigma = np.exp(0.5 * log_sigma2[0].detach().numpy())

In [None]:
# Display the two arrays extracted for the first item of the batch.
mean, sigma

In [None]:

# 21 evenly spaced coordinates per latent dimension, centred on `mean` and
# reaching 4 * `sigma` to either side.
# NOTE(review): x runs from high to low while y runs from low to high;
# confirm the mirrored x axis is intended.
x = np.linspace(mean[0] + 4 * sigma[0], mean[0] - 4 * sigma[0], num=21)
y = np.linspace(mean[1] - 4 * sigma[1], mean[1] + 4 * sigma[1], num=21)

In [None]:
# Display the number of grid coordinates along the first latent axis.
x.shape

In [None]:

# Decode every point of the (x, y) latent grid and tile the decoded images
# into one large canvas.
tile = 28  # side length, in pixels, of one decoded image
n_x, n_y = len(x), len(y)  # canvas is sized from the grid, not hard-coded
sample_img = np.zeros((tile * n_y, tile * n_x))

# Inference only: no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
    for row, zx in enumerate(x):
        for col, zy in enumerate(y):
            z = torch.tensor([zx, zy], dtype=torch.float32)
            out = dec(z).numpy().reshape(tile, tile)

            # x varies along the canvas columns, y along the canvas rows.
            sample_img[col * tile:(col + 1) * tile, row * tile:(row + 1) * tile] = out

In [None]:
 # Render the tiled canvas on a 10x10 inch figure with the axes hidden.
 plt.figure(figsize=(10, 10))
 plt.axis("off")
 plt.imshow(sample_img, cmap='gray')