# Convolutional Variational Autoencoders using PyTorch

Markus Enzweiler, markus.enzweiler@hs-esslingen.de

This is a demo used in a Computer Vision & Machine Learning lecture. Feel free to use and contribute.

**See `vae_torch_mnist.ipynb` for a more in-depth notebook on variational autoencoders.**

## Setup

Adapt `package_path` to point to the directory containing this notebook.

In [None]:
# Imports
import sys
import os
import subprocess
import numpy as np

In [None]:
# Directory that contains this notebook (adapt it if the notebook lives elsewhere)
package_path = "./"  # local
print(f"Package path: {package_path}")


def check_for_colab():
    """Return True if the `google.colab` module can be imported, else False."""
    try:
        import google.colab  # noqa: F401
    except ImportError:
        return False
    return True


# Remember whether the notebook is running on Colab
on_colab = check_for_colab()


In [None]:
# Clone the git repository, or update an already existing clone

# Path of the local clone (relative to the working directory when
# package_path is relative)
repo_dir = os.path.join(package_path, "torch-vae")
repo_url = "https://github.com/menzHSE/torch-vae.git"

# Working directory at the time this cell runs; used to build the absolute
# path that is checked below
original_cwd = os.getcwd()


def _run_git(args, cwd=None):
    """Run `git <args>` and print a warning if it exits with a non-zero status.

    Args:
        args: List of git command-line arguments, e.g. ["fetch", "origin"].
        cwd: Directory to run git in; None uses the current working directory.

    Returns:
        True if git exited with status 0, False otherwise.
    """
    completed = subprocess.run(["git", *args], cwd=cwd)
    if completed.returncode != 0:
        print(f"Warning: git {' '.join(args)} failed with exit code {completed.returncode}")
    return completed.returncode == 0


# Check if the directory already exists using the absolute path
if os.path.exists(os.path.join(original_cwd, repo_dir)):
    print("Repository exists. Resetting to HEAD...")
    # git is run inside the clone via cwd= instead of os.chdir(), so the
    # working directory of the notebook is never changed, even if the cell
    # is interrupted half-way.
    # Fetch the latest changes from the remote
    _run_git(["fetch", "origin"], cwd=repo_dir)
    # Reset the local branch to the fetched origin/HEAD commit
    _run_git(["reset", "--hard", "origin/HEAD"], cwd=repo_dir)
else:
    print("Cloning repository...")
    # Clone the repository if it doesn't exist
    _run_git(["clone", repo_url, repo_dir])


In [None]:
# Install requirements in the current Jupyter kernel
req_file = os.path.join(repo_dir, "requirements.txt")
if os.path.exists(req_file):
    !{sys.executable} -m pip install -r {req_file}
else:
    print(f"Requirements file not found: {req_file}")


# Additional requirements for this notebook
req_file = "requirements.txt"
if os.path.exists(req_file):
    !{sys.executable} -m pip install -r {req_file}
else:
    print(f"Requirements file not found: {req_file}")
    

# Reconstruct and sample CelebA faces

If PyTorch cannot download the dataset automatically because of a **daily quota exceeded** error, you can download it manually and put it in the ```data/celeba``` folder (see the code cell below).

The following files are necessary:
- img_align_celeba.zip
- list_attr_celeba.txt
- list_bbox_celeba.txt
- list_eval_partition.txt
- list_landmarks_align_celeba.txt
- list_landmarks_celeba.txt

In [None]:
# Set this to True to download the dataset from an alternative source
use_alternative_data_source = False

if use_alternative_data_source:

    # Target folder, handed to wget via -P
    dest_dir = "./data/celeba"

    data_url = "https://graal.ift.ulaval.ca/public/celeba/"
    celeba_files = [
        "list_attr_celeba.txt",
        "list_bbox_celeba.txt",
        "list_eval_partition.txt",
        "list_landmarks_align_celeba.txt",
        "list_landmarks_celeba.txt",
        "img_align_celeba.zip",
    ]

    # Download each file using wget
    for file_name in celeba_files:
        file_url = data_url + file_name
        print(f"Downloading {file_url}...")
        # The arguments are passed as a list, so no shell is involved and
        # neither the URL nor the destination can be re-interpreted by one.
        # check=True raises CalledProcessError as soon as a download fails.
        subprocess.run(["wget", file_url, "-P", dest_dir], check=True)

    print("Download complete.")


## Load model and data

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt

# Seed the NumPy and PyTorch random number generators with the same fixed value
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Append the torch-vae directory to the module search path so that the
# modules imported below can be found
sys.path.append(os.path.join(package_path, "torch-vae"))


# Now we can import the model and dataset
import model
import dataset
import device

# Parameters
dataset_id = "celeb-a"
num_latent_dims = 64
max_num_filters = 128
img_size = (64, 64)
batch_size = 32
model_id = f"vae_filters_{max_num_filters:04d}_dims_{num_latent_dims:04d}.pth"
vae_fname = os.path.join(package_path, "torch-vae", "models", dataset_id, model_id)
# NOTE: this rebinds the name `device` from the imported module to the device
# it selected; the module is only reachable again after re-running the import.
device = device.autoselectDevice()

# Load the dataset (the third return value is not used in this notebook)
celeba_train_loader, celeba_test_loader, _, celeba_num_img_channels = dataset.get_loaders(
    dataset_id, img_size=img_size, batch_size=batch_size)

# Build the VAE model and load the stored weights.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the installed PyTorch supports it.
vae_celeb = model.VAE(num_latent_dims, celeba_num_img_channels, max_num_filters, device)
vae_celeb.load_state_dict(torch.load(vae_fname, map_location=device))

# torch.load / load_state_dict raise on failure, so reaching this point means
# the weights were loaded; no extra check of the model object is needed.
vae_celeb.to(device)
vae_celeb.eval()
print(f"Model {vae_fname} loaded successfully!")
print(f"Device used: {device}")

## Reconstruction of the CelebA test samples


In [None]:
def reconstructAndPlot(vae, num_latent_dims, data_loader):
    """Plot the first batch of `data_loader` next to its VAE reconstructions.

    Reads the notebook-level variables `device` and `batch_size`.

    Args:
        vae: Model that is called on the image batch and is expected to
            return the reconstructed images.
        num_latent_dims: Number of latent dimensions; only used in the title.
        data_loader: Iterable of batches whose first element holds the images.
    """
    # Take the first batch from the data_loader
    batch = next(iter(data_loader))

    # Number of grid columns: batch_size // 4, rounded down to an even number
    # and never below 2, so that an original and its reconstruction always
    # end up side by side in the same row.
    num_columns = max(2, (batch_size // 4) // 2 * 2)

    with torch.no_grad():
        # Push the images to the device we are using
        images = batch[0].to(device)

        # Reconstruct (encode and decode) the images
        images_recon = vae(images)

        # Interleave original and reconstructed images: stacking along a new
        # dim 1 and merging the first two dims yields orig0, recon0, orig1, ...
        images_comparison = torch.stack([images, images_recon], dim=1).view(-1, *images.size()[1:])

        # Arrange the images in a grid
        grid_img = torchvision.utils.make_grid(images_comparison.cpu(), nrow=num_columns)

    # Convert grid to numpy and move the channel axis last for plotting
    grid_np = np.transpose(grid_img.numpy(), (1, 2, 0))

    # Plotting
    plt.figure(figsize=(15, 15))
    plt.imshow(grid_np)
    plt.axis('off')
    plt.title(f'Original and Reconstructed Images with {num_latent_dims} Latent Dimensions')
    plt.show()
    

In [None]:
# Plot the first batch of the CelebA test loader together with its reconstructions
reconstructAndPlot(vae_celeb, num_latent_dims, celeba_test_loader)

## Generate random CelebA-like samples from the VAE

In [None]:
def sampleAndPlot(vae, num_latent_dims, num_samples=batch_size):
    """Decode random latent vectors with the VAE and plot the images in a grid.

    Reads the notebook-level variables `device` and `batch_size`.

    Args:
        vae: Model providing `decode()`.
        num_latent_dims: Length of each random latent vector; also shown in
            the plot title.
        num_samples: Number of images to generate. The default is the value
            `batch_size` had when this function was defined.

    Raises:
        ValueError: If `num_samples` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    with torch.no_grad():
        # Generate one image per random latent vector.
        # During training we have made sure that the distribution in latent
        # space remains close to a normal distribution, so standard normal
        # samples are used as latent vectors.
        samples = [
            vae.decode(torch.randn(num_latent_dims).to(device))
            for _ in range(num_samples)
        ]

        # Concatenate all generated images once along dim 0
        pics = torch.cat(samples, dim=0)

        # Create a grid of images; at least one column so make_grid never
        # receives nrow=0
        grid_img = torchvision.utils.make_grid(pics, nrow=max(1, batch_size // 4))

    # Convert grid to numpy and move the channel axis last for plotting
    grid_np = np.transpose(grid_img.cpu().numpy(), (1, 2, 0))

    # Plotting
    plt.figure(figsize=(15, 15))
    plt.imshow(grid_np)
    plt.axis('off')
    plt.title(f'Randomly Generated Images from the VAE with {num_latent_dims} Latent Dimensions')
    plt.show()

In [None]:
# Generate and plot random images, using the default number of samples
sampleAndPlot(vae_celeb, num_latent_dims)