# Convolutional Variational Autoencoders using PyTorch

Markus Enzweiler, markus.enzweiler@hs-esslingen.de

This is a demo used in a Computer Vision & Machine Learning lecture. Feel free to use and contribute.

We analyze convolutional variational autoencoders (VAEs) on datasets such as MNIST, Fashion MNIST and CelebA. We use the Python code and pretrained models from https://github.com/menzHSE/torch-vae. This notebook does not show how to train VAEs. Please refer to https://github.com/menzHSE/torch-vae for that. 

Good overviews of variational autoencoders are provided in [arXiv:1906.02691](https://arxiv.org/abs/1906.02691) and [arXiv:1312.6114](https://arxiv.org/abs/1312.6114).

In our implementation, the input image is not directly mapped to a single latent vector. Instead, it's transformed into a probability distribution within the latent space, from which we sample a latent vector for reconstruction. The process involves:

1. **Encoding to Probability Distribution**: 
   - The encoder network maps the input image to two vectors: 
     - A **mean vector**.
     - A **standard deviation vector**.
   - These vectors define a normal distribution (with diagonal covariance) in the latent space.

2. **Auxiliary Loss for Distribution Shape**: 
   - We want the latent space distribution to resemble a zero-mean unit-variance Gaussian distribution (standard normal distribution).
   - To this end, an auxiliary loss, the Kullback-Leibler (KL) divergence between the mapped distribution and the standard normal distribution, is used in addition to the standard reconstruction loss.
   - This loss guides the training to shape the latent distribution accordingly.
   - It encourages a well-structured and generalizable latent space for generating new images.

3. **Sampling and Decoding**: 
   - The variational approach allows for sampling from the defined distribution in the latent space.
   - These samples are then used by the decoder to generate new images.

4. **Reparametrization Trick**:
   - This trick enables backpropagation through random sampling, a crucial step in VAEs. Drawing a sample directly from a distribution with mean ```mu``` and standard deviation ```sigma``` is not a differentiable operation, so gradients cannot simply be backpropagated through it to ```mu``` and ```sigma```.
   - The solution is to first sample random values ```epsilon``` from a standard normal distribution (mean 0, standard deviation 1). These values are then transformed by multiplying them with ```sigma``` and adding ```mu```, i.e., ```z = mu + sigma * epsilon```. The result is a sample from the target distribution with mean ```mu``` and standard deviation ```sigma```.
   - The key benefit of this approach is that the randomness (the standard normal sample) is separated from the quantities predicted by the encoder (```mu``` and ```sigma```). Given ```epsilon```, ```z``` is a deterministic and differentiable function of ```mu``` and ```sigma```, so gradients with respect to them can be computed during backpropagation. This enables the model to learn effectively from the data.
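The encoding, KL loss and reparametrization steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not the torch-vae implementation: the function names are hypothetical, and the KL term is the standard closed form for a diagonal Gaussian against N(0, I) from [arXiv:1312.6114](https://arxiv.org/abs/1312.6114). Many implementations let the encoder predict the log-variance instead of ```sigma``` for numerical stability; this sketch uses ```sigma``` directly to match the description above.

```python
import numpy as np


def reparametrize(mu, sigma, rng):
    """Sample z ~ N(mu, diag(sigma^2)) as z = mu + sigma * eps with eps ~ N(0, I).

    All randomness lives in eps, so z is a deterministic, differentiable
    function of mu and sigma: dz/dmu = 1 and dz/dsigma = eps.
    """
    eps = rng.standard_normal(np.shape(mu))
    return mu + sigma * eps, eps


def kl_to_standard_normal(mu, sigma):
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dims."""
    return 0.5 * np.sum(sigma**2 + mu**2 - 1.0 - np.log(sigma**2), axis=-1)


rng = np.random.default_rng(0)
mu = np.array([1.0, -2.0])
sigma = np.array([0.5, 2.0])

# Many reparametrized samples have (approximately) the target mean and std
z, eps = reparametrize(np.tile(mu, (100_000, 1)), sigma, rng)
print(z.mean(axis=0), z.std(axis=0))  # close to mu and sigma

# The KL term is zero only for the standard normal itself
print(kl_to_standard_normal(np.zeros(2), np.ones(2)))  # 0.0
print(kl_to_standard_normal(mu, sigma))  # positive
```

During training, this KL value (averaged over a batch) would be added to the reconstruction loss, while ```z``` is passed on to the decoder.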

## Setup

Adapt `package_path` to point to the directory containing this notebook.

In [None]:
# Notebook id
nb_id = "vae/torch"

# Imports
import sys
import os

In [None]:
# Package Path (folder of this notebook)

#####################
# Local environment #
#####################

package_path = "./"


#########
# Colab #
#########


def check_for_colab():
    try:
        import google.colab

        return True
    except ImportError:
        return False


# running on Colab?
on_colab = check_for_colab()

if on_colab:
    # assume this notebook is run from Google Drive and the whole
    # cv-ml-lecture-notebooks repo has been set up via setupOnColab.ipynb

    # Google Drive mount point
    gdrive_mnt = "/content/drive"

    ##########################################################################
    # Ensure that this is the same as gdrive_repo_root in setupOnColab.ipynb #
    ##########################################################################
    # Path on Google Drive to the cv-ml-lecture-notebooks repo
    gdrive_repo_root = f"{gdrive_mnt}/MyDrive/cv-ml-lecture-notebooks"

    # mount drive
    from google.colab import drive

    drive.mount(gdrive_mnt, force_remount=True)

    # set package path
    package_path = f"{gdrive_repo_root}/{nb_id}"

# check whether package path exists
if not os.path.isdir(package_path):
    raise FileNotFoundError(f"Package path does not exist: {package_path}")

print(f"Package path: {package_path}")

In [None]:
# Additional imports

# Repository Root
repo_root = os.path.abspath(os.path.join(package_path, "..", ".."))
# Add the repository root to the system path
if repo_root not in sys.path:
    sys.path.append(repo_root)

# Package Imports
from nbutils import requirements as nb_reqs
from nbutils import colab as nb_colab
from nbutils import git as nb_git
from nbutils import exec as nb_exec
from nbutils import data as nb_data

In [None]:
# Clone git repository

# Absolute path of the repository directory
repo_dir = os.path.join(package_path, "torch-vae")
repo_url = "https://github.com/menzHSE/torch-vae.git"

nb_git.clone(repo_url, repo_dir, on_colab)

In [None]:
# Install requirements in the current Jupyter kernel
req_file = os.path.join(repo_dir, "requirements.txt")
nb_reqs.pip_install_reqs(req_file, on_colab)

# Additional requirements for this notebook
req_file = os.path.join(package_path, "requirements.txt")
nb_reqs.pip_install_reqs(req_file, on_colab)

# Latent Space Analysis using Fashion MNIST

To analyze the concept of the latent space, we use a VAE with two latent dimensions pretrained on Fashion MNIST from https://github.com/menzHSE/torch-vae. Two latent dimensions make the latent space easy to visualize. 

## Load the Fashion MNIST VAE model

In [None]:
import torch
import torchvision
import numpy as np

# random seed
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# Add the directory containing model.py, dataset.py and device.py to the system path
sys.path.append(os.path.join(package_path, "torch-vae"))


# Now we can import the model and dataset
import model
import dataset
import device

# parameters
dataset_id = "fashion-mnist"
num_latent_dims = 2
max_num_filters = 128
img_size = (64, 64)
batch_size = 32
model_id = f"vae_filters_{max_num_filters:04d}_dims_{num_latent_dims:04d}.pth"
vae_fname = os.path.join(package_path, "torch-vae", "models", dataset_id, model_id)
device = device.autoselectDevice()

# load dataset
(
    mnist_train_loader,
    mnist_test_loader,
    mnist_classes_list,
    mnist_num_img_channels,
) = dataset.get_loaders(dataset_id, img_size=img_size, batch_size=batch_size)

# load the VAE model
vae_mnist_2 = model.VAE(
    num_latent_dims, mnist_num_img_channels, max_num_filters, device
)
vae_mnist_2.load_state_dict(torch.load(vae_fname, map_location=device))

if vae_mnist_2:
    print(f"Model {vae_fname} loaded successfully!")
    print(f"Device used: {device}")
    vae_mnist_2.to(device)
    vae_mnist_2.eval()

## Show some training images

In [None]:
import matplotlib.pyplot as plt

# get a batch of images from the training set and display them
# we use the torchvision.utils.make_grid function to create a grid of images
images, labels = next(iter(mnist_train_loader))
grid_img = torchvision.utils.make_grid(images, nrow=batch_size // 4)
plt.figure(figsize=(10, 10))
plt.imshow(grid_img.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
plt.axis("off")
plt.title("Batch from the Training Set")
plt.show()

## Visualize the data distribution in latent space

Classes are clustered quite well, with similar classes lying close to each other in latent space. The overall distribution roughly resembles a 2D zero-mean unit-variance Gaussian, which can be checked with the mean and covariance printed at the end of the next cells. 

In [None]:
# Encode all training images and plot the two-dimensional latent space
# representation colored by the class label

# Initialize lists to collect latent vectors and labels
latent_vectors = []
all_labels = []

# Loop through the dataset
for i, data in enumerate(mnist_train_loader):
    with torch.no_grad():
        print(f"Encoded batch {i+1}/{len(mnist_train_loader)}", end="\r")
        images, labels = data[0].to(device), data[1].to(device)

        # Encode image(s) to latent vector(s)
        z = vae_mnist_2.encode(images).cpu().numpy()
        latent_vectors.append(z)
        all_labels.append(labels.cpu().numpy())

# Concatenate all collected vectors and labels
latent_vectors = np.concatenate(latent_vectors, axis=0)
all_labels = np.concatenate(all_labels, axis=0)

In [None]:
# Create a figure
plt.figure(figsize=(10, 8))

# Scatter plot of latent vectors
# Adjust size (s) and color (c) as needed
plt.scatter(
    latent_vectors[:, 0],
    latent_vectors[:, 1],
    alpha=0.7,
    c=all_labels,
    cmap="tab10",
    s=10,
)

# Colorbar and labels
plt.colorbar()
plt.axis("equal")
plt.xlabel("Latent Dimension 1")
plt.ylabel("Latent Dimension 2")
plt.title("Latent Space Representation of Fashion MNIST Training Set")

# Adjust plot limits if needed
xlim = [-4, 4]
ylim = [-4, 4]
plt.xlim(xlim)
plt.ylim(ylim)

# Show plot
plt.show()

# mean and covariance of the latent space
mu = latent_vectors.mean(axis=0)
cov = np.cov(latent_vectors.T)
print(f"Mean and covariance of the latent space:\nmu={mu}\ncov={cov}")
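The printed mean and covariance can be condensed into a single number by measuring how far a Gaussian with these moments is from N(0, I). The following is a minimal sketch using the closed-form KL divergence between multivariate Gaussians; `gaussian_kl_to_standard_normal` is a hypothetical helper, and the example uses synthetic samples as a stand-in for the notebook's `latent_vectors`.

```python
import numpy as np


def gaussian_kl_to_standard_normal(mu, cov):
    """KL(N(mu, cov) || N(0, I)) for a k-dimensional Gaussian.

    Closed form: 0.5 * (tr(cov) + mu^T mu - k - log det(cov)).
    """
    mu = np.asarray(mu, dtype=float)
    cov = np.asarray(cov, dtype=float)
    k = mu.shape[0]
    sign, logdet = np.linalg.slogdet(cov)
    if sign <= 0:
        raise ValueError("covariance must be positive definite")
    return 0.5 * (np.trace(cov) + mu @ mu - k - logdet)


# Synthetic stand-in for latent_vectors: samples from a slightly shifted,
# slightly stretched 2-D Gaussian
rng = np.random.default_rng(0)
samples = rng.normal(loc=[0.1, -0.1], scale=[1.2, 0.9], size=(60_000, 2))

mu_hat = samples.mean(axis=0)
cov_hat = np.cov(samples.T)
print(gaussian_kl_to_standard_normal(mu_hat, cov_hat))  # small positive value
print(gaussian_kl_to_standard_normal(np.zeros(2), np.eye(2)))  # 0.0
```

Applied to the notebook, `gaussian_kl_to_standard_normal(mu, cov)` on the values computed above would give a value close to zero if the encoded training set is close to a standard normal distribution.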

## Visualize decoded images across the latent space

In [None]:
# Generate a grid of images by uniformly sampling the latent space
# and decode the latent vectors to images

# Number of images per row and column
n = 20

# Size of each image (assuming square images)
image_size = 64

# Limits of the latent space
xlim = [-4, 4]
ylim = [-4, 4]

# Number of ticks on each axis
num_ticks = 9

# Create a grid of latent vectors
x = np.linspace(xlim[0], xlim[1], n)
y = np.linspace(ylim[1], ylim[0], n)
xx, yy = np.meshgrid(x, y)

# Create an empty array for the large image
large_image = np.zeros((n * image_size, n * image_size))

# Loop through the grid
for i in range(n):
    for j in range(n):
        # Get the latent vector
        z = np.array([[xx[i, j], yy[i, j]]])

        # Decode the latent vector to an image
        with torch.no_grad():
            x_decoded = (
                vae_mnist_2.decode(torch.from_numpy(z).float().to(device)).cpu().numpy()
            )

        # Place the decoded image in the large array
        large_image[
            i * image_size : (i + 1) * image_size, j * image_size : (j + 1) * image_size
        ] = x_decoded[0, 0]

# Create a figure
plt.figure(figsize=(10, 10))

# Display the large image
plt.imshow(large_image, cmap="gray")

# Set the ticks to correspond to the latent space values
tick_positions_x = np.linspace(0, n * image_size, num_ticks)
tick_labels_x = np.linspace(xlim[0], xlim[1], num_ticks)
plt.xticks(ticks=tick_positions_x, labels=[f"{val:.1f}" for val in tick_labels_x])

tick_positions_y = np.linspace(0, n * image_size, num_ticks)
tick_labels_y = np.linspace(ylim[0], ylim[1], num_ticks)
plt.yticks(
    ticks=tick_positions_y, labels=[f"{val:.1f}" for val in reversed(tick_labels_y)]
)  # Reversed y-labels

# Labels and title
plt.xlabel("Latent Dimension 1")
plt.ylabel("Latent Dimension 2")
plt.title("Grid of Images Sampled from Latent Space")

# Show the plot
plt.show()
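The double loop above decodes one latent vector at a time and places each image into `large_image` by slicing. The same mosaic can be built with a single reshape and transpose, which is convenient when all grid points are decoded as one batch. This is a minimal NumPy sketch in which random arrays stand in for decoder outputs; `tile_images` is a hypothetical helper, not part of torch-vae.

```python
import numpy as np


def tile_images(images, n):
    """Arrange an (n*n, H, W) stack row by row into an (n*H, n*W) mosaic."""
    num, h, w = images.shape
    assert num == n * n, "expected exactly n*n images"
    # (row, col, H, W) -> (row, H, col, W) -> (n*H, n*W)
    return images.reshape(n, n, h, w).transpose(0, 2, 1, 3).reshape(n * h, n * w)


# Latent grid in the same order as the notebook: rows go from top (y = 4)
# to bottom (y = -4), columns from left (x = -4) to right (x = 4)
n, size = 20, 64
x = np.linspace(-4, 4, n)
y = np.linspace(4, -4, n)
xx, yy = np.meshgrid(x, y)
z_grid = np.stack([xx.ravel(), yy.ravel()], axis=1)  # shape (n*n, 2), one batch

# Stand-in for decoder output: one (size x size) image per latent vector
rng = np.random.default_rng(0)
decoded = rng.random((n * n, size, size))
mosaic = tile_images(decoded, n)

# Same result as the explicit slicing loop
loop = np.zeros((n * size, n * size))
for i in range(n):
    for j in range(n):
        loop[i * size:(i + 1) * size, j * size:(j + 1) * size] = decoded[i * n + j]
print(mosaic.shape, np.array_equal(mosaic, loop))  # (1280, 1280) True
```

With the notebook's model, `z_grid` could be converted to a float tensor and decoded in one call, assuming `decode` accepts a batch of shape (N, 2); the channel dimension of the result would be dropped before tiling.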

# Reconstruction of the Fashion MNIST test samples

Here, we reconstruct the (unseen) Fashion MNIST test samples by encoding and decoding them. We additionally use a model with more latent dimensions. Two latent dimensions are well suited for visually analyzing the latent space but are typically too few for good reconstructions.

In [None]:
def reconstructAndPlot(vae, num_latent_dims, data_loader):
    # Take the first batch from the data_loader
    data = next(iter(data_loader))
    with torch.no_grad():
        # Get the testing data and push the data to the device we are using
        images = data[0].to(device)

        # Reconstruct (encode and decode) the images
        images_recon = vae(images)

        # Interleave original and reconstructed images
        images_comparison = torch.stack([images, images_recon], dim=1).view(
            -1, *images.size()[1:]
        )

        # Display the images in a grid
        # nrow (batch_size // 4 = 8) is even, so each original and its reconstruction end up side by side in the same row
        grid_img = torchvision.utils.make_grid(
            images_comparison.cpu(), nrow=batch_size // 4
        )

    # Convert grid to numpy and transpose axes for plotting
    grid_np = grid_img.numpy()
    grid_np = np.transpose(grid_np, (1, 2, 0))

    # Plotting
    plt.figure(figsize=(15, 15))
    plt.imshow(grid_np)
    plt.axis("off")
    plt.title(
        f"Original and Reconstructed Images with {num_latent_dims} Latent Dimensions"
    )
    plt.show()
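In `reconstructAndPlot`, `torch.stack([images, images_recon], dim=1)` followed by `.view(-1, ...)` interleaves the two batches as original, reconstruction, original, reconstruction, and so on. The same indexing can be checked in NumPy, where `np.stack` and `reshape` behave analogously for contiguous arrays; the toy arrays below are stand-ins for image batches.

```python
import numpy as np

batch, c, h, w = 4, 1, 2, 2
# Image k is filled with the value k; its "reconstruction" with k + 0.5
originals = np.arange(batch)[:, None, None, None] * np.ones((batch, c, h, w))
recons = originals + 0.5

# (B, C, H, W) x 2 -> (B, 2, C, H, W) -> (2B, C, H, W)
pairs = np.stack([originals, recons], axis=1).reshape(-1, c, h, w)

# First pixel of each image alternates between original k and reconstruction k + 0.5
print(pairs[:, 0, 0, 0])
```

Because each grid row holds `nrow` images, an even `nrow` keeps every original next to its reconstruction.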

Visualization using the VAE with two latent dimensions.

In [None]:
# load the VAE model
num_latent_dims = 2
model_id = f"vae_filters_{max_num_filters:04d}_dims_{num_latent_dims:04d}.pth"
vae_fname = os.path.join(package_path, "torch-vae", "models", "fashion-mnist", model_id)

vae_mnist_2 = model.VAE(num_latent_dims, mnist_num_img_channels, max_num_filters, device)
vae_mnist_2.load_state_dict(torch.load(vae_fname, map_location=device))
vae_mnist_2.to(device)
vae_mnist_2.eval()

reconstructAndPlot(vae_mnist_2, num_latent_dims, mnist_test_loader)

Visualization using the VAE with eight latent dimensions.

In [None]:
# load the VAE model
num_latent_dims = 8
model_id = f"vae_filters_{max_num_filters:04d}_dims_{num_latent_dims:04d}.pth"
vae_fname = os.path.join(package_path, "torch-vae", "models", "fashion-mnist", model_id)

vae_mnist_8 = model.VAE(
    num_latent_dims, mnist_num_img_channels, max_num_filters, device
)
vae_mnist_8.load_state_dict(torch.load(vae_fname, map_location=device))
vae_mnist_8.to(device)
vae_mnist_8.eval()

reconstructAndPlot(vae_mnist_8, num_latent_dims, mnist_test_loader)

# Generate random Fashion MNIST-like samples from the VAE

The variational autoencoders are trained so that the distribution in latent space resembles a standard normal distribution (see above). To generate new samples, we can therefore draw a random latent vector from a standard normal distribution and let the decoder generate an image from it.
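Because generation draws latent vectors from N(0, I), the [-4, 4] window used in the latent-space plots covers nearly all of the prior's probability mass. A short check with the standard library, assuming independent standard normal latent dimensions (`prob_inside_box` is a hypothetical helper):

```python
import math


def prob_inside_box(limit, num_dims):
    """P(|z_i| <= limit for all i) for z ~ N(0, I) in num_dims dimensions."""
    per_dim = math.erf(limit / math.sqrt(2.0))  # P(|z_i| <= limit) in one dimension
    return per_dim**num_dims


for dims in (2, 8):
    print(dims, prob_inside_box(4.0, dims))  # both very close to 1
```

Samples from the prior therefore rarely fall outside this window, although encoded data points need not follow the prior exactly.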

In [None]:
def sampleAndPlot(vae, num_latent_dims, num_samples=batch_size):
    with torch.no_grad():
        # Generate a batch of random latent vectors.
        # During training, we have made sure that the distribution in latent
        # space remains close to a standard normal distribution.
        z = torch.randn(num_samples, num_latent_dims, device=device)

        # Generate one image per latent vector
        pics = vae.decode(z)

        # Create a grid of images
        grid_img = torchvision.utils.make_grid(pics, nrow=batch_size // 4)

        # Convert grid to numpy and transpose axes for plotting
        grid_np = grid_img.cpu().numpy()
        grid_np = np.transpose(grid_np, (1, 2, 0))

        # Plotting
        plt.figure(figsize=(10, 10))
        plt.imshow(grid_np)
        plt.axis("off")
        plt.title(
            f"Randomly Generated Images from the VAE with {num_latent_dims} Latent Dimensions"
        )
        plt.show()

In [None]:
sampleAndPlot(vae_mnist_8, num_latent_dims=8)