### Importing Training Libraries

In [None]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision.datasets import MNIST # MNIST
from tqdm import tqdm
from torchvision import transforms # For Image Augmentation
from torchvision.utils import save_image
from torch.utils.data import DataLoader # Easier data management by creating mini batches
import matplotlib.pyplot as plt

### Configuration

In [None]:
# Model Hyperparameters
# Shared configuration for the VAE architecture and its training run.

INPUT_DIM = 784  # Flattened image length (28 * 28 pixels)
H_DIM = 200  # Hidden layer width passed to the VAE
Z_DIM = 20  # Latent space dimensionality passed to the VAE
NUM_EPOCHS = 6  # Number of passes over the training loader
BATCH_SIZE = 64  # Mini-batch size used by the DataLoader
LR_RATE = 3e-4 # Karpathy Constant (learning rate handed to Adam)

In [None]:
# Input img -> Hidden dim -> mean, log-variance -> Reparameterization trick -> Decoder -> Output img [One Hidden Layer]

class VAE(nn.Module):
    """Variational autoencoder with one hidden layer on each side.

    Data flow: input -> hidden -> (mu, log-variance) -> reparameterized latent
    sample -> hidden -> reconstruction.

    Args:
        input_dim: Length of the flattened input vector.
        h_dim: Width of the hidden layer in both encoder and decoder.
        z_dim: Dimensionality of the latent space.
    """

    def __init__(self, input_dim, h_dim=200, z_dim=20):
        # nn.Module.__init__ already puts the module in training mode, so no
        # explicit `self.training = True` is needed.
        super().__init__()

        # Encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)

        # Two heads on the hidden features: one for the latent mean and one
        # for the latent spread (consumed as log-variance by `forward`).
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # Decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

        self.relu = nn.ReLU()

    def encode(self, x):
        """Map inputs to the two latent heads.

        Returns:
            Tuple ``(mu, sigma)`` of raw, unactivated head outputs. `forward`
            interprets the second one as log-variance.
        """
        h = self.relu(self.img_2hid(x))
        return self.hid_2mu(h), self.hid_2sigma(h)

    def decode(self, z):
        """Map latent vectors back to input space.

        The final sigmoid keeps every output value inside [0, 1], which is
        what a binary-cross-entropy reconstruction loss requires.
        """
        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h))

    def reparameterization(self, mu, sigma):
        """Return ``mu + sigma * epsilon`` with ``epsilon ~ Gamma(3, 2)``.

        Args:
            mu: Latent mean.
            sigma: Latent scale, multiplied element-wise with the noise.

        Note:
            The noise is Gamma(concentration=3, rate=2), not a standard
            normal, so it has non-zero mean and is always positive.
        """
        # Build the distribution parameters on sigma's device/dtype so the
        # noise is sampled in place rather than on the CPU and then copied.
        concentration = sigma.new_tensor(3.0)
        rate = sigma.new_tensor(2.0)
        epsilon = torch.distributions.Gamma(concentration, rate).sample(sigma.shape)
        return mu + sigma * epsilon  # Element-wise product

    def forward(self, x):
        """Encode, sample a latent vector, and decode.

        Returns:
            Tuple ``(x_reconstructed, mu, sigma)`` where ``sigma`` is the raw
            second encoder output (used here as log-variance).
        """
        mu, sigma = self.encode(x)
        # exp(0.5 * log-variance) is the standard deviation.
        std = torch.exp(0.5 * sigma)
        z_reparametrized = self.reparameterization(mu, std)
        x_reconstructed = self.decode(z_reparametrized)
        return x_reconstructed, mu, sigma

### Test Case

In [None]:
if __name__ == "__main__":
    # Shape smoke test: push a random batch of four flattened 28x28 inputs
    # through a fresh model and print the shape of each returned tensor.
    x = torch.randn(4, 28 * 28)
    vae = VAE(input_dim=784)
    x_recon, mu, sigma = vae(x)
    print(x_recon.shape, mu.shape, sigma.shape, sep="\n")

torch.Size([4, 784])
torch.Size([4, 20])
torch.Size([4, 20])


### Loading The Dataset

In [None]:
# Only tensor conversion is applied to the images; no further augmentation.
mnist_transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split stored under dataset/ (download requested).
dataset = MNIST("dataset/", train=True, transform=mnist_transform, download=True)

# Shuffled mini-batches for training.
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Model and its Adam optimizer, configured from the hyperparameters above.
model = VAE(input_dim=INPUT_DIM, h_dim=H_DIM, z_dim=Z_DIM)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)

In [None]:
# Compute loss
def loss_fn(x, x_hat, mean, var):
    """Return the VAE loss: summed reconstruction BCE plus KL divergence.

    Args:
        x: Target inputs with values in [0, 1].
        x_hat: Reconstructions with values in [0, 1].
        mean: Latent means produced by the encoder.
        var: Second encoder output. It is interpreted as log-variance, matching
            ``VAE.forward``, which turns it into a standard deviation with
            ``exp(0.5 * var)``.

    Returns:
        Scalar tensor, summed (not averaged) over every element of the batch.
    """
    # Binary cross-entropy suits targets that lie in [0, 1].
    recon_loss = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')
    # Closed-form KL(N(mean, exp(var)) || N(0, 1)). Uses the function's own
    # arguments rather than module-level globals, and stays finite when the
    # encoder output is 0.
    kl_div = -0.5 * torch.sum(1 + var - mean.pow(2) - var.exp())
    return recon_loss + kl_div

### Training The Model

In [None]:
# Train the VAE for NUM_EPOCHS passes over train_loader, printing the mean
# per-sample loss after each epoch.
model.train()
for epoch in range(NUM_EPOCHS):
    overall_loss = 0.0
    num_samples = 0
    loop = tqdm(train_loader)  # Wrapping the loader directly gives the bar a total
    for x, _ in loop:
        # Forward pass on flattened images. The decoder ends in a sigmoid, so
        # the reconstruction is already within [0, 1] and needs no clamp.
        x = x.view(x.shape[0], INPUT_DIM)
        x_recon, mu, sigma = model(x)
        loss = loss_fn(x, x_recon, mu, sigma)

        # Backprop
        optimizer.zero_grad()  # No accumulated gradients from before
        loss.backward()  # Compute grads
        optimizer.step()

        batch_loss = loss.item()
        overall_loss += batch_loss
        # Count real samples: the last batch may be smaller than BATCH_SIZE.
        num_samples += x.shape[0]
        loop.set_postfix(loss=batch_loss)
    # The loss is summed per batch, so divide by the number of samples seen.
    print("\tEpoch Num:", epoch + 1, "\tAverage Loss: ", overall_loss / max(num_samples, 1))

938it [00:19, 47.86it/s, loss=5.19e+3]


	Epoch Num: 1 	Average Loss:  196.0456906281006


938it [00:18, 50.67it/s, loss=4.81e+3]


	Epoch Num: 2 	Average Loss:  152.7770244472182


938it [00:17, 52.22it/s, loss=4.17e+3]


	Epoch Num: 3 	Average Loss:  142.08048244219833


938it [00:18, 50.14it/s, loss=4.49e+3]


	Epoch Num: 4 	Average Loss:  135.57431303856848


938it [00:18, 50.99it/s, loss=4e+3]


	Epoch Num: 5 	Average Loss:  132.6663132609527


938it [00:20, 45.26it/s, loss=4.17e+3]

	Epoch Num: 6 	Average Loss:  129.42804249017445





### Running The Model On The MNIST Dataset

In [None]:
def inference(digit, num_examples=1):
    """Generate and save sample images of one digit.

    A reference image of ``digit`` is taken from the module-level ``dataset``,
    encoded with the module-level ``model``, and ``num_examples`` latent
    samples around that encoding are decoded and written to
    ``generated_{digit}_ex{example}.png``.

    Args:
        digit: Digit class to generate, from 0 to 9.
        num_examples: Number of images to sample and save.

    Raises:
        ValueError: If ``digit`` is outside 0-9 or no reference image for it
            is found in the dataset.
    """
    if not 0 <= digit <= 9:
        raise ValueError(f"digit must be between 0 and 9, got {digit!r}")

    # Walk the dataset picking the first 0, then the next 1 after it, and so
    # on, stopping as soon as the requested digit's image is reached.
    wanted = 0
    image = None
    for x, y in dataset:
        if y == wanted:
            if wanted == digit:
                image = x
                break
            wanted += 1
    if image is None:
        raise ValueError(f"no reference image found for digit {digit}")

    # No gradients are needed for generation.
    with torch.no_grad():
        mu, sigma = model.encode(image.view(1, 784))
        # The model's forward pass treats the second encoder output as
        # log-variance, so convert it to a standard deviation here too.
        std = torch.exp(0.5 * sigma)

        # Decoding the Digits from the Encodings
        for example in range(num_examples):
            epsilon = torch.randn_like(std)
            z = mu + std * epsilon
            out = model.decode(z).view(-1, 1, 28, 28)
            save_image(out, f"generated_{digit}_ex{example}.png")

In [None]:
# Save one generated sample for every digit class from 0 to 9.
for idx in range(10):
    inference(digit=idx, num_examples=1)