<a href="https://colab.research.google.com/github/francisco-perez-sorrosal/miniai/blob/main/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import random, os
import numpy as np

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt

In [None]:
def seeder(seed):
    """Seed every RNG used by this notebook and request deterministic cuDNN.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU and every visible CUDA device).
    """
    random.seed(seed)
    # Only affects Python interpreters started after this point (e.g. worker
    # subprocesses); hashing in the current process is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN benchmarking auto-tunes kernel choice per run, which can change
    # results between runs; it must be off for reproducible training.
    torch.backends.cudnn.benchmark = False

seeder(42)

In [None]:
batch_size_train = 64
batch_size_test = 1000

# The VAE below is trained with a binary cross-entropy reconstruction loss
# whose targets are these very images, so pixels must stay in [0, 1].
# ToTensor() already scales MNIST to [0, 1]. A mean/std Normalize((0.1307,),
# (0.3081,)) would map pixels to roughly [-0.42, 2.82], and BCE with targets
# above 1 is unbounded below, so no normalization is applied here.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        '/files/', train=True, download=True,
        transform=torchvision.transforms.ToTensor(),
    ),
    batch_size=batch_size_train, shuffle=True)

In [None]:
# Peek at one batch from the (shuffled) training loader for visualization:
# example_data holds the images and example_targets the matching labels.
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)

In [None]:
def display_digits_grid(images, grid_size, example_targets=None):
  """Show up to ``rows * cols`` single-channel images in a grid.

  Args:
    images: indexable batch where ``images[i][0]`` is a 2-D image, e.g. a
      tensor or array shaped (N, 1, H, W) -- TODO confirm for other callers.
    grid_size: ``(n_grid_rows, n_grid_cols)``.
    example_targets: optional labels; when given, ``example_targets[i].item()``
      is drawn on image ``i``.
  """
  n_grid_rows, n_grid_cols = grid_size

  # Only the first rows*cols images fit in the grid.
  images = images[:n_grid_rows * n_grid_cols]

  # squeeze=False always returns a 2-D array of axes, so `.flat` also works
  # for a 1x1 grid (where plt.subplots would otherwise return a bare Axes).
  _, axs = plt.subplots(n_grid_rows, n_grid_cols, figsize=(n_grid_cols, n_grid_rows),
                        subplot_kw={'xticks': [], 'yticks': []}, squeeze=False)

  for i, ax in enumerate(axs.flat):
    if i >= len(images):
      # Fewer images than grid cells: blank the leftover axes instead of
      # raising IndexError.
      ax.axis('off')
      continue
    ax.imshow(images[i][0], cmap='gray', interpolation="none")
    if example_targets is not None:
      ax.text(0, 0.8, f"GT: {example_targets[i].item()}", transform=ax.transAxes, color="yellow")

  plt.show()

display_digits_grid(example_data, grid_size=(4, 10), example_targets=example_targets)

In [None]:
class DenseLayer(nn.Module):
    """Linear layer optionally followed by an element-wise activation.

    Args:
        input_size: number of input features.
        output_size: number of output features.
        activation_f: callable applied to the linear output (ELU by default);
            pass ``None`` to get the raw linear output.
    """

    def __init__(self, input_size, output_size, activation_f=nn.functional.elu):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)
        self.activation_f = activation_f

    def forward(self, x):
        """Apply the linear map, then the activation if one was configured."""
        projected = self.fc(x)
        return self.activation_f(projected) if self.activation_f else projected

class VAE(nn.Module):
    """Fully connected variational autoencoder built from ``DenseLayer``s.

    Encoder: ``input_size -> hidden_size -> hidden_size`` followed by two
    linear heads producing ``mu`` and ``gamma`` (each ``latent_size`` wide).
    ``gamma`` is used as the log-variance of the latent Gaussian (see
    ``reparametrize``). Decoder: ``latent_size -> hidden_size -> hidden_size
    -> n_outputs`` logits.

    Args:
        input_size: number of features per sample after flattening.
        n_outputs: number of decoder outputs per sample.
        hidden_size: width of every hidden layer (default 500).
        latent_size: dimensionality of the latent code (default 20).
    """

    def __init__(self, input_size, n_outputs, hidden_size=500, latent_size=20):
        super().__init__()
        # Exposed so callers can draw codes of the right width,
        # e.g. torch.randn(n, model.latent_size).
        self.latent_size = latent_size
        # Layer creation order is unchanged, so a fixed seed yields the same
        # initial weights as before.
        self.hidden1 = DenseLayer(input_size, hidden_size)
        self.hidden2 = DenseLayer(hidden_size, hidden_size)
        self.hidden3_mean = DenseLayer(hidden_size, latent_size, None)
        self.hidden3_gamma = DenseLayer(hidden_size, latent_size, None)
        self.hidden4 = DenseLayer(latent_size, hidden_size)
        self.hidden5 = DenseLayer(hidden_size, hidden_size)
        self.output = DenseLayer(hidden_size, n_outputs, None)

    def encode(self, x):
        """Map flattened inputs to ``(mu, gamma)``: latent mean and log-variance."""
        h = self.hidden2(self.hidden1(x))
        return self.hidden3_mean(h), self.hidden3_gamma(h)

    def reparametrize(self, mu, gamma):
        """Sample ``z = mu + exp(gamma / 2) * eps`` with ``eps ~ N(0, I)``.

        Drawing the noise separately keeps ``z`` differentiable w.r.t.
        ``mu`` and ``gamma``.
        """
        std = torch.exp(0.5 * gamma)
        return mu + std * torch.randn_like(std)

    def decode(self, x):
        """Decode latent codes; return ``(logits, sigmoid(logits))``."""
        h = self.hidden5(self.hidden4(x))
        logits = self.output(h)
        return logits, torch.sigmoid(logits)

    def forward(self, x):
        """Flatten each sample, encode, sample ``z`` and decode.

        Returns:
            ``(logits, probabilities, mu, gamma)``.
        """
        # reshape (unlike view) also accepts non-contiguous input tensors.
        x_flat = x.reshape(x.shape[0], -1)
        mu, gamma = self.encode(x_flat)
        z = self.reparametrize(mu, gamma)
        logits, probs = self.decode(z)
        return logits, probs, mu, gamma





In [None]:
# VAE mapping a flattened 28x28 image to a 28x28 reconstruction, trained with Adam.
model = VAE(input_size=28 * 28, n_outputs=28 * 28)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Training bookkeeping and logging cadence (in batches).
train_losses, train_counter = [], []
log_interval = 100

def vae_loss_function(recon_x, x, mu, gamma):
  """Return the summed VAE loss (negative ELBO) for a batch.

  Args:
    recon_x: decoder outputs as probabilities in [0, 1].
    x: reconstruction targets, same shape as ``recon_x``, values in [0, 1].
    mu: latent means.
    gamma: latent log-variances.

  Returns:
    Scalar tensor: summed binary cross-entropy plus the KL divergence of
    N(mu, exp(gamma)) from the standard normal prior.
  """
  # Reconstruction term: pixel-wise binary cross-entropy summed over the batch.
  reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
  # Closed-form KL(N(mu, exp(gamma)) || N(0, I)), summed over batch and latent dims.
  kl_divergence = -0.5 * torch.sum(1 + gamma - mu.pow(2) - gamma.exp())
  return reconstruction + kl_divergence


def train(epoch):
  """Run one training epoch of ``model`` over ``train_loader``.

  Every ``log_interval`` batches, prints progress and records the current
  batch loss in ``train_losses`` and the number of training examples seen so
  far (across epochs) in ``train_counter``.

  Args:
    epoch: 1-based epoch number, used for logging and for ``train_counter``.
  """
  model.train()
  batch_size = train_loader.batch_size
  n_examples = len(train_loader.dataset)
  for batch_idx, (data, _) in enumerate(train_loader):
    optimizer.zero_grad()
    _, sigmoid_outputs, hidden_mean, hidden_gamma = model(data)
    # Targets are the input images flattened per sample to match the decoder output.
    targets = data.reshape(data.shape[0], -1)
    # The KL term needs the latent mean as `mu` and the log-variance as `gamma`.
    loss = vae_loss_function(sigmoid_outputs, targets, hidden_mean, hidden_gamma)
    loss.backward()
    optimizer.step()

    if batch_idx % log_interval == 0:
      seen = batch_idx * batch_size
      print(f'Train Epoch: {epoch} [{seen}/{n_examples} ({(100. * batch_idx / len(train_loader)):.0f}%)]\tLoss: {loss.item():.6f}')
      train_losses.append(loss.item())
      train_counter.append(seen + (epoch - 1) * n_examples)



In [None]:
# Train for `epochs` passes over the training set; epoch numbers start at 1.
epochs = 10
for epoch in range(1, epochs + 1):
  train(epoch=epoch)

In [None]:
# Sample latent codes from the N(0, I) prior and decode them into digit images.
n_digits = 20
# Read the latent width from the model (input size of the first decoder
# layer) instead of hard-coding 20.
latent_size = model.hidden4.fc.in_features
# Generation needs no gradients; no_grad avoids building an autograd graph.
with torch.no_grad():
  codings_rnd = torch.randn((n_digits, latent_size), dtype=torch.float32)
  gen_logits, gen_numbers_sig = model.decode(codings_rnd)
# Decoder probabilities are flat per sample; reshape to (N, 1, 28, 28) for display.
gen_numbers = gen_numbers_sig.reshape(n_digits, 1, 28, 28)
display_digits_grid(gen_numbers.numpy(), grid_size=(2, 10), example_targets=None)

In [None]:
# Draw a fresh batch of prior samples and decode them into digit images.
# The latent width comes from the model (input size of the first decoder layer).
latent_size = model.hidden4.fc.in_features
# Generation needs no gradients; no_grad avoids building an autograd graph.
with torch.no_grad():
  codings_rnd = torch.randn((n_digits, latent_size), dtype=torch.float32)
  gen_logits, gen_numbers_sig = model.decode(codings_rnd)
gen_numbers = gen_numbers_sig.reshape(n_digits, 1, 28, 28)
display_digits_grid(gen_numbers.numpy(), grid_size=(2, 10), example_targets=None)

In [None]:
# Draw another batch of prior samples and decode them into digit images.
# The latent width comes from the model (input size of the first decoder layer).
latent_size = model.hidden4.fc.in_features
# Generation needs no gradients; no_grad avoids building an autograd graph.
with torch.no_grad():
  codings_rnd = torch.randn((n_digits, latent_size), dtype=torch.float32)
  gen_logits, gen_numbers_sig = model.decode(codings_rnd)
gen_numbers = gen_numbers_sig.reshape(n_digits, 1, 28, 28)
display_digits_grid(gen_numbers.numpy(), grid_size=(2, 10), example_targets=None)