> This is a self-correcting activity generated by [nbgrader](https://nbgrader.readthedocs.io). Fill in any place that says `YOUR CODE HERE` or `YOUR ANSWER HERE`. Run subsequent cells to check your code.

---

# Generate handwritten digits with a VAE (PyTorch)

The goal here is to train a VAE to generate handwritten digits.

![VAE digits](images/vae_digits.png)

## Environment setup

In [None]:
import os
import math

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Setup plots
# IPython magic: render matplotlib figures inline in the notebook output
%matplotlib inline
# Default size (width, height) applied to figures created afterwards
plt.rcParams['figure.figsize'] = 10, 8
# IPython magic: ask the inline backend for 'retina' (high-resolution) figures
%config InlineBackend.figure_format = 'retina'

In [None]:
import torch

# Report the installed PyTorch version and whether a CUDA GPU is usable
print(f'PyTorch version: {torch.__version__}')
if torch.cuda.is_available():
    print("GPU found :)")
else:
    print("No GPU :(")

import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms

# Device configuration: use the CUDA GPU when one is available, else the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Data loading

In [None]:
# Load MNIST dataset
# Both splits are stored under ./data (download=True fetches them on request)
# and every image goes through ToTensor() when it is read.
trainset = torchvision.datasets.MNIST(
    "./data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
testset = torchvision.datasets.MNIST(
    "./data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

### Question

Create batch data loaders `trainloader` and `testloader` for the training and test datasets, respectively.

In [None]:
# Batch size to use for both data loaders
batch_size = 128

# YOUR CODE HERE

## Model definition

### Question

Complete the following class to create a variational autoencoder.

In [None]:
# VAE model
class VAE(nn.Module):
    """Fully-connected variational autoencoder.

    The encoder maps an input vector to the mean and log-variance of a
    Gaussian over the latent space; the decoder maps latent codings back to
    vectors of the input size with values in [0, 1].
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        # Encoder: one shared hidden layer followed by two parallel heads
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc3 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        # Decoder: latent codings -> hidden layer -> reconstructed input
        self.fc4 = nn.Linear(latent_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Encode input into its latent representation.

        Returns a (mean, log-variance) pair describing the latent Gaussian.
        """
        hidden = F.relu(self.fc1(x))
        mu = self.fc2(hidden)
        log_var = self.fc3(hidden)
        return mu, log_var

    def sample(self, mu, log_var):
        """Sample a random codings vector from a Gaussian distribution.

        Takes the mean and the log-variance (gamma) as parameters and
        applies the reparameterization trick: mean + noise * std.
        """
        # sigma = exp(log(sigma^2) / 2)
        std = (log_var / 2).exp()
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Decode codings into reconstructions with values in [0, 1]."""
        hidden = F.relu(self.fc4(z))
        logits = self.fc5(hidden)
        return torch.sigmoid(logits)

    def forward(self, x):
        """Run a full encode / sample / decode pass.

        Encode inputs to obtain mean and log-variance, sample codings from
        the corresponding Gaussian distribution, then decode them.
        Returns decoded codings, mean and log-variance.
        """
        # YOUR CODE HERE

## Model training

### Question

Complete the following training loop to:
- instantiate the variational autoencoder on target device;
- instantiate the Adam optimizer;
- implement forward pass and gradient descent.

In [None]:
# Hyperparameters
input_dim = 784
hidden_dim = 400
latent_dim = 20
num_epochs = 15
learning_rate = 1e-3
step_count = len(trainloader)  # Number of batches per epoch
prints_per_epoch = 1  # Increase to see more feedback during training
# Number of steps between two progress reports; it does not change during
# training, so it is computed once instead of at every step
print_threshold = math.ceil(step_count / prints_per_epoch)

# Instantiate VAE and optimizer
# YOUR CODE HERE

# Train model
for epoch in range(num_epochs):
    # Labels are ignored: the model learns to reconstruct its own inputs
    for i, (x, _) in enumerate(trainloader):
        # Forward pass
        # YOUR CODE HERE

        # Compute reconstruction loss and KL divergence
        # Reconstruction term: binary cross-entropy summed over the batch
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction="sum")
        # Closed-form KL divergence between N(mu, exp(log_var)) and N(0, I)
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = reconst_loss + kl_div

        # Backprop and optimize
        # YOUR CODE HERE

        # Print losses at regular intervals, and always at the last step
        if (i + 1) % print_threshold == 0 or (i + 1) == step_count:
            print(
                f"Epoch [{epoch + 1}/{num_epochs}]"
                f", step [{i + 1}/{step_count}]"
                f", reconst loss: {reconst_loss.item():.4f}"
                f", KL div: {kl_div.item():.4f}"
            )

## Reconstructions visualization

In [None]:
def plot_image(image):
    """Display a single image on the current matplotlib axes, axis hidden.

    Args:
        image: PyTorch tensor (on any device, with or without gradient
            tracking) or array-like holding the pixel values. Singleton
            dimensions are squeezed out before display.
    """
    if isinstance(image, torch.Tensor):
        # detach() drops gradient tracking (numpy() refuses tensors that
        # require grad); cpu() is a no-op for tensors already on the CPU.
        # Deciding per tensor replaces the former check on CUDA availability.
        pixels = image.detach().cpu().numpy()
    else:
        pixels = np.asarray(image)
    plt.imshow(pixels.squeeze(), cmap="binary")
    plt.axis("off")

def show_reconstructions(model, images, n_images=8):
    """Show original and reconstructed images side-by-side.

    Originals are drawn on the top row, reconstructions on the bottom row.

    Args:
        model: callable returning a 3-tuple whose first item holds the
            reconstructions of its input batch.
        images: batch of 28x28 images; flattened to rows of 784 values and
            moved to the target device before being fed to the model.
        n_images: maximum number of images to display.
    """
    inputs = images.reshape(-1, 28 * 28).to(device)
    # Inference only: without no_grad() the outputs would track gradients
    # and could not be converted to NumPy arrays for plotting
    with torch.no_grad():
        reconstructions, _, _ = model(inputs)

    # Never index past the end of a batch smaller than n_images
    n_images = min(n_images, len(images))
    plt.figure(figsize=(n_images * 1.5, 3))
    for image_index in range(n_images):
        plt.subplot(2, n_images, 1 + image_index)
        plot_image(images[image_index])
        plt.subplot(2, n_images, 1 + n_images + image_index)
        plot_image(reconstructions[image_index].view(1, 28, 28))

### Question

Show reconstructions for one batch of test data.

In [None]:
# YOUR CODE HERE

## Generating new images

In [None]:
def plot_multiple_images(images, n_cols=None):
    """Show a series of images on a grid of subplots.

    Args:
        images: indexable batch of images; a trailing axis of size 1 is
            squeezed out before plotting.
        n_cols: number of grid columns; when omitted (or 0), all images are
            laid out on a single row.
    """
    count = len(images)
    if not n_cols:
        n_cols = count
    # Rows needed to hold `count` images, `n_cols` per row
    n_rows = 1 + (count - 1) // n_cols
    has_channel_axis = images.shape[-1] == 1
    if has_channel_axis:
        images = np.squeeze(images, axis=-1)
    plt.figure(figsize=(n_cols * 1.5, 3))
    # Subplot positions are 1-based
    for position, image in enumerate(images, start=1):
        plt.subplot(n_rows, n_cols, position)
        plot_image(image)

### Question

Use the VAE to show several generated digits.

In [None]:
# Gradient tracking is disabled for everything inside this block
with torch.no_grad():
    # Draw 16 random codings from a standard normal distribution,
    # then move them to the target device
    z = torch.randn(16, latent_dim).to(device)
    # YOUR CODE HERE