> Generative modeling

In this notebook, we’ll use a [Variational Autoencoder](https://en.wikipedia.org/wiki/Variational_autoencoder) (VAE) as a kind of “artist’s shorthand” for images of faces. The encoder learns to compress each image into just a few *latent* variables (a rough sketch that keeps the main traits but not every tiny detail) while the decoder learns to turn that sketch back into a full image. We also gently force this “compressed space” to follow a simple, standard [Gaussian](https://en.wikipedia.org/wiki/Normal_distribution) distribution (a nice, round “cloud” of points), so we know what kind of latent variables are valid. That’s the fun part: because this space is smooth and well-behaved, nearby points correspond to similar faces, and we can safely sample random points from this Gaussian cloud to generate new faces the model has never seen before. The picture below tries to convey the main idea (notice that the space on the left is 3D, whereas the one on the right is 2D).

<div style="text-align: center;">
<img src="https://raw.githubusercontent.com/manuvazquez/uc3m_computation_and_intelligence/master/labs/notebooks/figures/vae.svg" alt="Description" width="1200">
</div>

# Setup

In principle, you could run the notebook either in *Colab* or locally. Is the notebook running in *Colab*?

In [None]:
try:
    import google.colab
    running_in_colab = True
except ImportError:
    running_in_colab = False

running_in_colab

If running in *Colab* we must install a couple of Python libraries. If not, you might want to choose a GPU if several are available.

In [None]:
import os

if running_in_colab:
    !pip install equinox numpyro
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

The remaining required `import`s come here

In [None]:
import pathlib
import random

import numpy as np
import torch
from torch import nn
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset

from PIL import Image

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import optax
from jaxtyping import Array, Float, Int
import numpyro.distributions as dist

import matplotlib.pyplot as plt

import kagglehub

For the sake of reproducibility, we set random *seeds*.

In [None]:
random.seed(42) # Python's `random` is used below to pick/shuffle the image files
torch.manual_seed(42)
np.random.seed(42)

Devices to be used: you should see *[CudaDevice(id=0)]* or similar (the *Cuda* prefix being the important bit) if *JAX* is using the available GPU, which is strongly recommended for this notebook.

In [None]:
jax.devices()

In this notebook we will make use of the [JAX](https://docs.jax.dev/en/latest/) and [Equinox](https://docs.kidger.site/equinox/) libraries, which adopt a more [functional](https://en.wikipedia.org/wiki/Functional_programming) approach to computing. *PyTorch* is only used for data wrangling (loading and preprocessing the images).

# Dataset

Images from the [CelebFaces Dataset](https://www.kaggle.com/datasets/arnrob/celeba-small-images-dataset) are downloaded

In [None]:
imgs_dir = pathlib.Path(kagglehub.dataset_download("arnrob/celeba-small-images-dataset"))
print("Path to dataset files:", imgs_dir)

Some code to get the images ready for training. Essentially, we need to build a *PyTorch* `DataLoader` out of the images in the above directory.

In [None]:
class CustomImageDataset(Dataset):
    
    def __init__(self, img_dir, transform=None, n_samples=None):
        
        self.img_dir = img_dir
        self.transform = transform
        self.image_files = []

        # the names of *all* image files
        print(f"Scanning directory: {img_dir}")
        all_files = [os.path.join(img_dir, f) for f in os.listdir(img_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp'))]

        # optionally limit the number of samples and shuffle for randomness
        if n_samples is not None and n_samples < len(all_files):
            self.image_files = random.sample(all_files, n_samples)
        else:
            self.image_files = all_files
            random.shuffle(self.image_files) # Shuffle if using all files

        # we don't have *actual* labels, but in the usual `Dataset` one is expected; it is set to 0 (dummy label) for all images
        self.labels = [0] * len(self.image_files)

    def __len__(self):

        return len(self.image_files)

    def __getitem__(self, idx):

        img_path = self.image_files[idx]
        image = Image.open(img_path).convert('RGB')
        
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)
        
        return image, label


def get_celeba(
        batch_size: int,
        dataset_directory: str | pathlib.Path,
        n: int | None = None,
        data_subset: str = "training", # either 'training' or 'validation'
    ) -> torch.utils.data.DataLoader:

    # size of the images after resizing
    img_size: tuple[int, int] = (64, 64)

    train_transformation = transforms.Compose([
        transforms.Resize(img_size), # *images* are resized,...
        transforms.ToTensor(), # ...converted to *tensors*,...
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # ...and normalized to have values in [-1, 1]
    ])

    # the path to the specific data subset (e.g., 'training')
    actual_image_directory = pathlib.Path(dataset_directory) / data_subset
    
    if not actual_image_directory.is_dir():
        raise ValueError(f"Specified data_subset '{data_subset}' not found in '{dataset_directory}'.")

    train_dataset = CustomImageDataset(actual_image_directory, train_transformation, n_samples=n)

    # a `DataLoader` is returned
    return torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

Let us make a `DataLoader` that loops through $10{,}000$ images in batches of size $32$.

In [None]:
trainloader = get_celeba(32, imgs_dir, n=10_000)
trainloader

A *PyTorch* `DataLoader` is ultimately an iterable (calling `iter` on it gives you an iterator over its batches)...

<font color='red'>TO-DO</font>: Get the first element from it. What is it?

In [None]:
# image = 

<font color='red'>TO-DO</font>: Extract the first image from the above `DataLoader`. What is its type?

$d_x$ in the picture above would here be the overall number of values in an image, i.e., *width* $\times$ *height* $\times$ 3 channels (RGB). For the $64 \times 64$ images used here, that is $64 \times 64 \times 3 = 12{,}288$.

Let us write a convenience function to plot a *PyTorch* `Tensor` as an image. Notice that the `transforms.Normalize` call above computes $\frac{x - 0.5}{0.5} = 2(x-0.5)$, which must be undone before plotting in order to get an image ready for "human consumption".

In [None]:
def show_image(img):
    
    # image is "unnormalized"
    img = img / 2 + 0.5
    
    # pytorch expects the channel dimension first whereas matplotlib expects it last
    plt.imshow(np.transpose(img, (1, 2, 0)))
    
    plt.show()
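
As a quick sanity check of the arithmetic above, the following sketch applies the same normalization by hand to a few made-up pixel values and verifies that `x / 2 + 0.5` brings them back. The array `pixels` is purely illustrative, and plain *NumPy* is used here instead of `transforms.Normalize`.

```python
import numpy as np

# hypothetical pixel intensities in [0, 1], the range produced by `transforms.ToTensor`
pixels = np.array([0.0, 0.25, 0.5, 1.0])

# what normalizing with mean 0.5 and standard deviation 0.5 computes: (x - mean) / std
normalized = (pixels - 0.5) / 0.5

# the "unnormalization" used in `show_image`
recovered = normalized / 2 + 0.5

print(normalized)  # values now lie in [-1, 1]
print(recovered)   # back in [0, 1]
```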

<font color='red'>TO-DO</font>: Plot the image you extracted above.

In [None]:
# show_image(image)

<font color='red'>TO-DO</font>: What happens if you skip the *unnormalization* part in `show_image`?

*PyTorch* (through `torchvision`) provides a convenience function to stick a bunch of images together into a grid. We can use it to plot a whole batch.

In [None]:
show_image(torchvision.utils.make_grid(next(iter(trainloader))[0]))

# Model

The model comprises two components, the *encoder/compressor* (implementing function `f_enc` in the above picture) and the *decoder/decompressor* (implementing function `f_dec`). Between them we have the "compressed space", known as the *latent* space. We can choose its dimension (size).

In [None]:
d_z = 3

## Encoder

A `class` defining the architecture of the *encoder* (i.e., the compressor). This is assuming $64 \times 64$ images. If that's not the case, tweaks are required.

In [None]:
class Encoder(eqx.Module):

    layers: list

    def __init__(self, d_z: int, input_channels: int = 3, rng_key: jr.PRNGKey = jr.PRNGKey(42)):

        key1, key2, key3, key4, key5, key6 = jr.split(rng_key, 6)

        self.layers = [
            eqx.nn.Conv2d(in_channels=input_channels, out_channels=32, kernel_size=4, stride=2, padding=1, key=key1),
            jax.nn.relu,
            eqx.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1, key=key2),
            jax.nn.relu,
            eqx.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1, key=key3),
            jax.nn.relu,
            eqx.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1, key=key4),
            jax.nn.relu,
            eqx.nn.Conv2d(in_channels=64, out_channels=256, kernel_size=4, stride=1, padding=0, key=key5),
            jnp.ravel,
            eqx.nn.Linear(256, 2*d_z, key=key6),
            lambda x: x.at[d_z:].set(jax.nn.softplus(x[d_z:]))
        ]
    
    def __call__(self, x):

        for layer in self.layers:

            x = layer(x)
        
        return x
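
To see where the $64 \times 64$ assumption enters, here is a minimal sketch that tracks the spatial size of the array through the five convolutions, using the usual output-size formula for a convolution without dilation. The helper `conv_out_size` is a hypothetical name introduced only for this illustration; the (kernel size, stride, padding) triplets are copied from `Encoder`.

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    # usual output-size formula for a convolution (no dilation)
    return (size + 2 * padding - kernel) // stride + 1

# (kernel_size, stride, padding) of the five `Conv2d` layers in `Encoder`
conv_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

sizes = [64]  # side of the (square) input image
for kernel, stride, padding in conv_layers:
    sizes.append(conv_out_size(sizes[-1], kernel, stride, padding))

print(sizes)  # [64, 32, 16, 8, 4, 1]
```

The last convolution leaves a $256 \times 1 \times 1$ array, which is why `jnp.ravel` yields exactly the 256 features expected by `eqx.nn.Linear(256, 2*d_z)`. With, e.g., $128 \times 128$ inputs the final side would be 5 rather than 1, and the linear layer would no longer match.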

Let us instantiate it

In [None]:
encoder = Encoder(d_z=d_z)

Object `encoder` ultimately behaves as a function accepting an image (in the form of an array) as input, and returning a vector of size $2 \times d_z$ that holds the mean and standard deviation (stacked together) of a Gaussian distribution in the latent space. That is, the encoder not only gives you some $z$ in the compressed space, but also a measure of its uncertainty.

Let us call it on the first image from the first batch.

In [None]:
z_mean_std = encoder(next(iter(trainloader))[0][0].numpy())
z_mean, z_std = jnp.split(z_mean_std, 2)
z_mean, z_std

Notice that, as required, the standard deviations are positive (this is enforced by the `softplus` at the end of the encoder).
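
The positivity comes from the last entry of `Encoder.layers`. A minimal sketch of that step in isolation is given below; the vector `raw` is a made-up stand-in for the output of the final `Linear` layer.

```python
import jax
import jax.numpy as jnp

d_z = 3  # same latent dimension as above

# hypothetical raw output of the final `Linear` layer: any real numbers
raw = jnp.array([0.5, -1.0, 2.0, -3.0, 0.0, 4.0])

# same operation as the last entry of `Encoder.layers`: only the second half is transformed
out = raw.at[d_z:].set(jax.nn.softplus(raw[d_z:]))

mean, std = jnp.split(out, 2)
print(mean)  # unchanged: means may have any sign
print(std)   # softplus(x) = log(1 + exp(x)) > 0 for every x
```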

## Decoder

The architecture for the *decoder*, i.e., the decompressor.

In [None]:
class Decoder(eqx.Module):

    layers: list

    def __init__(self, d_z: int, input_channels: int = 3, rng_key: jr.PRNGKey = jr.PRNGKey(42)):

        key1, key2, key3, key4, key5, key6 = jr.split(rng_key, 6)

        self.layers = [
            eqx.nn.Linear(d_z, 256, key=key1),
            lambda x: jnp.reshape(x, (256, 1, 1)),
            jax.nn.relu,
            eqx.nn.ConvTranspose2d(in_channels=256, out_channels=64, kernel_size=4, stride=1, padding=0, key=key2),
            jax.nn.relu,
            eqx.nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1, key=key3),
            jax.nn.relu,
            eqx.nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1, key=key4),
            jax.nn.relu,
            eqx.nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1, key=key5),
            jax.nn.relu,
            eqx.nn.ConvTranspose2d(in_channels=32, out_channels=input_channels, kernel_size=4, stride=2, padding=1, key=key6),
            jax.nn.tanh
        ]
    
    def __call__(self, x):

        for layer in self.layers:

            x = layer(x)
        
        return x

`decoder` behaves as a function acting on vectors in the latent space (of dimension $d_z$).

In [None]:
decoder = Decoder(d_z=d_z)

Let us draw a sample from the above mean and standard deviation (in this *functional* programming paradigm, we must pass a pseudo-random number generator key/seed, here `jr.PRNGKey(42)`, every time a random number is to be produced)...

In [None]:
z = dist.Normal(loc=z_mean, scale=z_std).sample(jr.PRNGKey(42))

...and *decode* it to (maybe?) get back the original image.

In [None]:
x_est = decoder(z)
x_est.shape

<font color='red'>TO-DO</font>: Visualize the image. What is wrong?

# Training

Some hyperparameters that can be tweaked

In [None]:
learning_rate = 1e-3
# n_epochs = 40
n_epochs = 10
d_z = 75

Neural networks for the decoder and encoder are instantiated *and* initialized

In [None]:
encoder = Encoder(d_z=d_z)
decoder = Decoder(d_z=d_z)

For the sake of convenience, we will gather together both things in a Python `tuple`

In [None]:
model = encoder, decoder

We define the loss function (the one to be minimized). Prior to that, and for the sake of clarity, we also define the [Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence), which provides a way to tell how different two distributions are. You can **skip** this code in this course, since it deals with technical details of the VAE.

In [None]:
def kl_loss(mean: Float[Array, 'feature'], sd: Float[Array, 'feature']) -> Float[Array, '']:

    return -0.5 * jnp.sum(1 + 2*jnp.log(sd) - mean**2 - sd**2)

def loss(model, x: Float[Array, 'batch channel height width'], rng_key) -> Float[Array, '']:

    # never mind for now...but this is the variance assumed for the decoded `x`
    x_var = 0.1

    encoder, decoder = model
    z_mean_std = jax.vmap(encoder)(x)

    z = dist.Normal(loc=z_mean_std[:, :d_z], scale=z_mean_std[:, d_z:]).sample(rng_key)
    
    x_pred = jax.vmap(decoder)(z)

    log_likelihood = dist.Normal(loc=x_pred, scale=jnp.sqrt(x_var)).log_prob(x).sum()

    kl_divergence = jax.vmap(kl_loss)(z_mean_std[:, :d_z], z_mean_std[:, d_z:]).sum()
    
    return -log_likelihood + kl_divergence
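
For the curious (this is optional), the expression in `kl_loss` is the closed form of the divergence between a Gaussian and the standard Gaussian, which for a single latent dimension reads $\mathrm{KL}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\big) = -\tfrac{1}{2}\left(1 + \log \sigma^2 - \mu^2 - \sigma^2\right)$. The sketch below checks it against a brute-force numerical integration of the definition. Both helper functions, the integration range, and the test values are assumptions made only for this illustration.

```python
import math

def kl_closed_form(mean: float, sd: float) -> float:
    # same expression as `kl_loss`, for a single latent dimension
    return -0.5 * (1 + 2 * math.log(sd) - mean**2 - sd**2)

def log_normal(z: float, mean: float, sd: float) -> float:
    # log-density of a Gaussian
    return -0.5 * math.log(2 * math.pi) - math.log(sd) - 0.5 * ((z - mean) / sd) ** 2

def kl_numerical(mean: float, sd: float, lo: float = -12.0, hi: float = 12.0, n: int = 100_000) -> float:
    # midpoint-rule estimate of the integral of q(z) * [log q(z) - log p(z)]
    dz = (hi - lo) / n
    total = 0.0
    for i in range(n):
        z = lo + (i + 0.5) * dz
        log_q = log_normal(z, mean, sd)
        total += math.exp(log_q) * (log_q - log_normal(z, 0.0, 1.0)) * dz
    return total

print(kl_closed_form(0.0, 1.0))  # identical distributions: zero divergence
print(kl_closed_form(1.0, 0.5), kl_numerical(1.0, 0.5))  # the two should agree closely
```

The divergence is zero only when the encoder outputs exactly the standard Gaussian, and grows as the mean moves away from 0 or the standard deviation away from 1; this is the term that "gently forces" the latent space towards the Gaussian cloud.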

The loss function is just, well...a function, that you can actually call like any other function. Let us get a batch of images from the `DataLoader` above

In [None]:
images, _ = next(iter(trainloader))
images.shape

<font color='red'>TO-DO</font>: Explain the size of the above `Tensor`.

<font color='red'>TO-DO</font>: Call the loss function (with the above `model`) on the images. Looking at the above definition, `loss` expects:

- the model,

- either a *numpy* or a *jax* array, so you must convert `images`, which is a *PyTorch* tensor (you can use the method `numpy()` on `images`)

- a pseudo-random number generator key (you can use again `jr.PRNGKey(42)`...or something else) to produce the required random numbers.
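
Regarding the last item, here is a minimal sketch of how keys behave in *JAX* (the seed `0` and the sample shape are arbitrary choices for the illustration): the same key always yields the same numbers, and new randomness is obtained by *splitting* a key.

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)

# the same key always produces the same numbers...
a = jr.normal(key, (3,))
b = jr.normal(key, (3,))

# ...so fresh randomness requires a fresh key, obtained by splitting
key, subkey = jr.split(key)
c = jr.normal(subkey, (3,))

print(jnp.array_equal(a, b))  # True
print(jnp.array_equal(a, c))
```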

A nice thing about *JAX*/*Equinox* is that if you have a (Python) function, you can easily get the [gradient](https://en.wikipedia.org/wiki/Gradient) of that function by using `jax.grad`. In this case, since we are using *Equinox*, we call the equivalent *wrapper* `eqx.filter_grad`.
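
As a toy illustration of plain `jax.grad` (the function `f` and the vector `w` are made up for the example, and have nothing to do with the VAE), the gradient of a sum of squares can be checked against the analytical result $2w$:

```python
import jax
import jax.numpy as jnp

def f(w):
    # a scalar-valued function of a vector
    return jnp.sum(w ** 2)

# `jax.grad` turns `f` into another function, which computes its gradient
grad_f = jax.grad(f)

w = jnp.array([1.0, -2.0, 3.0])
g = grad_f(w)

print(g)      # gradient computed by JAX
print(2 * w)  # analytical gradient
```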

In [None]:
grad_loss = eqx.filter_grad(loss)

Now, `grad_loss` is a function taking the same arguments as `loss`, so you can...

<font color='red'>TO-DO</font>: Call the function just like you call `loss` above. What do you get? Keep in mind we are computing the *gradient* of the loss function!!

Now we make a function that packs the actions that must be carried out, during training, on every batch.

In [None]:
@eqx.filter_jit
def take_step(model, opt_state, x: jax.Array, rng_key):

    loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, rng_key)

    updates, opt_state = optim.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    
    return model, opt_state, loss_value

## Loop

Let's go for the training loop. You'll notice that the evolution of the loss is somewhat *bumpy*. That's OK. Also, the first iteration might take a while (the images are being read from disk and, on its first call, `take_step` is compiled by *JAX*).

In [None]:
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

key = jr.PRNGKey(42)

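for epoch in range(n_epochs):

    for x, _ in trainloader:

        # a fresh subkey for every batch, so that no two steps reuse the same noise
        key, subkey = jr.split(key)

        model, opt_state, loss_value = take_step(model, opt_state, x.numpy(), subkey)

    # loss on the last batch of the epoch
    print(epoch, loss_value)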

# Results

Let us look at the first picture in the last batch processed in the training loop (still in `x`).

In [None]:
show_image(x[0])

Let us encode it into the latent space and decode it back. Formally, the encoder returns a mean and a standard deviation in the latent space, and the decoder returns the mean of a Gaussian in the data space (whose variance is fixed to `x_var` in `loss`) but, for the sake of simplicity, we take both means as if they were samples.

In [None]:
encoder, decoder = model
show_image(decoder(encoder(x[0].numpy())[:d_z]))

# Experiments

<font color='red'>TO-DO</font>: Generate a couple of new images by drawing samples in the *latent space* and calling the decoder on them.

<font color='red'>TO-DO</font>: Train for a few more epochs to try and improve the quality of the *reconstructions*. Does it get better?

<font color='red'>TO-DO</font>: Train with very few images, say 10. The number of images you use for training (out of the overall number in the dataset) is controlled by the `n` parameter to the function `get_celeba` above. Use a much larger number of epochs, say 500, or the model won't have seen enough examples to learn anything. Then, generate a few images and compare them. What do you observe?

<font color='red'>TO-DO</font>: Play around with the dimension of the latent space. Can you get good results with a small dimension, say $d_z=10$?

# Sample questions

## What is a typical effect of using a *smaller* latent dimension (for example, going down to 10)?
- [ ] The model always trains faster and becomes perfect
- [ ] The model has less capacity to capture details, so reconstructions may become blurrier or lose information
- [ ] The model stops using the decoder network
- [ ] The model cannot be trained at all

## Why do we set random seeds (for example, for numpy or PyTorch) at the beginning?
- [ ] To make the training run only once
- [ ] To avoid using the GPU by mistake
- [ ] To make the results more reproducible when we run the code again
- [ ] To prevent the code from using any random numbers