In [None]:
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

try:
    import jupyter_black

    jupyter_black.load()
except:
    print("black not installed")

# Basic Image Generation

## Goals

- Train and understand different basic generative models for image generation
- Train a Generative Adversarial Network (GAN)
- Train a Variational Autoencoder (VAE) and understand its latent space.
- Train a Diffusion Model (DM) using `diffusers` components

## Setup

Let's define paths, install & load the necessary Python packages.

**Optionally: Save the notebook to your personal google drive to persist changes.**

Mount your google drive to store data and results (if running the code in Google Colab).

In [None]:
# Detect whether the notebook runs inside Google Colab (the module only exists there).
try:
    import google.colab  # noqa: F401

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

print(f"In colab: {IN_COLAB}")

In [None]:
if IN_COLAB:
    # Mount Google Drive so data and results persist across Colab sessions.
    from google.colab import drive as colab_drive

    colab_drive.mount("/content/drive")

**Modify the following paths if necessary.**

That is where your data will be stored.

In [None]:
from pathlib import Path

# Root directory for datasets and training logs (Drive in Colab, repo-relative locally).
DATA_PATH = (
    Path("/content/drive/MyDrive/cas-dl-module-genai-part2")
    if IN_COLAB
    else Path("../../data")
)

Install `dl_genai_lectures`

In [None]:
try:
    import dl_genai_lectures  # noqa: F401

    print("dl_genai_lectures installed, all good")
except ImportError:
    import importlib
    import subprocess
    import sys

    if Path("/workspace/code/src").exists():
        print("Installing from local repo")
        pip_target = "/workspace/code"
    else:
        print("Installing from git repo")
        pip_target = "git+https://github.com/marco-willi/cas-dl-genai-exercises-fs2025"
    # Use the kernel's own interpreter so the package lands in the environment this
    # notebook runs in (a bare `pip` on PATH may belong to another environment), and
    # fail loudly instead of silently continuing when the installation fails.
    subprocess.run([sys.executable, "-m", "pip", "install", pip_target], check=True)
    # Let the running kernel discover the freshly installed package.
    importlib.invalidate_caches()

Load all packages

In [None]:
import io
from typing import Callable

import lightning as L
import numpy as np
import requests
import seaborn as sns
import torch
from lightning.pytorch.loggers import TensorBoardLogger
from matplotlib import pyplot as plt
from PIL import Image
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import transforms
from torchvision.transforms.v2 import functional as TF
from torchvision.utils import make_grid

from dl_genai_lectures import visualize

Define a default device for your computations.

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise compute on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using: {device}")

## 1) Prepare Dataset

We use the famous MNIST dataset because of its small size and illustrative power. Feel free to use a different dataset (which would require some adaptations to the models later).

In [None]:
from torchvision.datasets import MNIST

# Training split of MNIST; items are (PIL image, integer label). Downloaded if missing.
mnist_root = DATA_PATH.joinpath("mnist")
ds_mnist_train = MNIST(root=mnist_root, train=True, download=True)

We inspect the dataset.

In [None]:
# Each dataset item is an (image, label) tuple.
first_item = ds_mnist_train[0]
first_item

Each element consists of a [PIL.Image](https://pillow.readthedocs.io/en/stable/reference/Image.html), a commonly used class to store and process images in Python, and a label.

In [None]:
# The first 16 training images, each paired with a caption showing its label.
images, labels = [], []
for image_idx in range(16):
    sample_image, sample_label = ds_mnist_train[image_idx]
    images.append(sample_image)
    labels.append(f"Label: {sample_label}")

In [None]:
# Show the first 16 samples with their labels as captions.
fig, ax = visualize.plot_collage(images, captions=labels)

Now we select specific numbers to inspect their variations:

In [None]:
num_to_show = 16
digit_to_select = 1

# Select by the label tensor instead of indexing the dataset: `ds_mnist_train[i]`
# builds a PIL image for every training sample just to read its label.
selected_indices = (
    (ds_mnist_train.targets == digit_to_select).nonzero().flatten()[:num_to_show].tolist()
)

images = [ds_mnist_train[i][0] for i in selected_indices]
labels = [f"Label: {ds_mnist_train[i][1]}" for i in selected_indices]

In [None]:
# Show the selected examples of one digit to compare their variations.
fig, ax = visualize.plot_collage(images, captions=labels)

**Task**: Inspect a few more digits and observe how they vary. 

We create a [lightning data module](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) to handle our dataset.

In [None]:
class MNISTDataModule(L.LightningDataModule):
    """Lightning data module wrapping the MNIST train/test splits.

    Args:
        data_dir: Directory where MNIST is (or will be) stored; a ``pathlib.Path``
            works as well.
        batch_size: Mini-batch size for both data loaders.
        num_workers: Number of worker processes per data loader.
        transform_fn: Transform applied to every PIL image; defaults to
            ``transforms.ToTensor()`` (float tensor with values in [0, 1]).
        shuffle: Whether the training loader reshuffles the samples every epoch.
    """

    def __init__(
        self,
        data_dir: str,
        batch_size: int = 32,
        num_workers: int = 2,
        transform_fn: Callable | None = None,
        shuffle: bool = True,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle

        if transform_fn is not None:
            self.transform = transform_fn
        else:
            self.transform = transforms.Compose([transforms.ToTensor()])

    def prepare_data(self):
        """Ensure both MNIST splits are downloaded."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` ("fit", "test", or None for both)."""
        if stage in ("fit", None):
            self.ds_train = MNIST(self.data_dir, train=True, transform=self.transform)

        if stage in ("test", None):
            self.ds_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        # Shuffle so that each epoch sees the samples in a new order and with new batch
        # compositions; a fixed order yields identical batches every epoch.
        return DataLoader(
            self.ds_train,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        # Evaluation order does not matter, so the test loader is not shuffled.
        return DataLoader(self.ds_test, batch_size=self.batch_size, num_workers=self.num_workers)

Let's test the data module.

In [None]:
# Default settings (ToTensor only); build just the training split.
data_module = MNISTDataModule(data_dir=DATA_PATH / "mnist")
data_module.setup(stage="fit")

In [None]:
# After the transform, an image is a (C, H, W) tensor.
first_sample = data_module.ds_train[0]
image, label = first_sample
image.shape

In [None]:
# One batch from the loader: images are (N, C, H, W), labels are (N,).
dl = data_module.train_dataloader()
batch_iterator = iter(dl)
image_batch, label_batch = next(batch_iterator)
image_batch.shape, label_batch.shape

## 2) Generative Adversarial Networks

In the following we will implement a Generative Adversarial Network (GAN) to generate samples.


We begin by implementing a Generator network, denoted by $\mathcal{G}$, which takes as input a latent vector $\mathbf{z} \in \mathbb{R}^d$, where the dimensionality $d$ is a configurable hyperparameter.

Given that the task involves image generation, $\mathcal{G}$ is implemented as a convolutional neural network (CNN).

In [None]:
class Generator(nn.Module):
    """Convolutional generator mapping latent vectors to 1x28x28 images in [-1, 1].

    Args:
        latent_dim: Dimensionality ``d`` of the input latent vector ``z``.
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        self.latent_dim = latent_dim

        layers = [
            # z -> 6 * 7 * 7 features, reshaped into a (6, 7, 7) feature map
            nn.Linear(latent_dim, 6 * 7 * 7),
            nn.Unflatten(1, (6, 7, 7)),
            # upsample 7x7 -> 14x14
            nn.ConvTranspose2d(6, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(32),
            # upsample 14x14 -> 28x28
            nn.ConvTranspose2d(32, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(64),
            # refine features at full resolution
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            # 1x1 projection to one grayscale channel, squashed into [-1, 1] by Tanh
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map latent vectors ``(N, latent_dim)`` to images ``(N, 1, 28, 28)``."""
        return self.model(z)

Let's verify the model's input and output shapes are correct.

In [None]:
# Check layer shapes for a batch of 32 latent vectors (latent_dim=5 just for illustration).
summary_latent_dim = 5
generator = Generator(latent_dim=summary_latent_dim)
summary(generator, input_size=(32, summary_latent_dim))

Next, we implement a Discriminator network, denoted by $\mathcal{D}$, which takes an image $\mathbf{x}$ as input and outputs a single logit (an unnormalized log-odds score) indicating how likely it is that the image is real (from the dataset) rather than generated (from $\mathcal{G}(\mathbf{z})$).

Since the input consists of images, a convolutional neural network (CNN) is an appropriate architecture for $\mathcal{D}$.

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one real-vs-fake logit per 1x28x28 image.

    The output is an unnormalized score (no sigmoid), to be combined with a
    logit-based loss such as ``F.binary_cross_entropy_with_logits``.
    """

    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            # 1x28x28 -> 32x14x14
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(),
            # 32x14x14 -> 64x7x7
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(),
            # (N, 64, 7, 7) -> (N, 64 * 7 * 7)
            nn.Flatten(),
            # Linear read-out: directly produces the single logit per image.
            nn.Linear(64 * 7 * 7, 1),
        )

    def forward(self, img):
        """Return logits of shape (N, 1) for images of shape (N, 1, 28, 28)."""
        return self.model(img)

Now we can build the whole GAN model. We use a [L.LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html) to help with managing the training loop.

In [None]:
class GAN(L.LightningModule):
    """Generator/discriminator pair trained adversarially with manual optimization.

    Args:
        latent_dim: Size of the latent vectors fed to the generator.
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        self.save_hyperparameters()
        self.generator = Generator(latent_dim=latent_dim)
        self.discriminator = Discriminator()
        self.latent_dim = latent_dim
        # Two optimizers are stepped alternately in training_step, so Lightning's
        # automatic optimization is switched off.
        self.automatic_optimization = False

    def forward(self, z):
        """Generate images from latent vectors ``z``."""
        return self.generator(z)

    def _generator_loss(self, x_fake):
        """Generator objective: make the discriminator label the fakes as real (target 1)."""
        fake_logits = self.discriminator(x_fake)
        return F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))

    def _discriminator_loss(self, x_real, x_fake):
        """Mean BCE over real and fake images with soft (label-smoothed) targets.

        Real targets are drawn from (0.95, 1], fake targets from [0, 0.05).
        """
        real_logits = self.discriminator(x_real)
        real_targets = 1.0 - 0.05 * torch.rand_like(real_logits)
        real_loss = F.binary_cross_entropy_with_logits(real_logits, real_targets)

        # detach(): the discriminator update must not backpropagate into the generator
        fake_logits = self.discriminator(x_fake.detach())
        fake_targets = 0.05 * torch.rand_like(fake_logits)
        fake_loss = F.binary_cross_entropy_with_logits(fake_logits, fake_targets)

        return 0.5 * (real_loss + fake_loss)

    def training_step(self, batch, batch_idx):
        x_real, _ = batch
        opt_gen, opt_disc = self.optimizers()

        z = torch.randn((x_real.shape[0], self.latent_dim), device=self.device)

        # --- generator update ---
        self.toggle_optimizer(opt_gen)
        x_fake = self.generator(z)
        g_loss = self._generator_loss(x_fake)
        self.log("g_loss", g_loss, prog_bar=True)
        self.manual_backward(g_loss)
        opt_gen.step()
        opt_gen.zero_grad()
        self.untoggle_optimizer(opt_gen)

        # --- discriminator update (reuses the same fake batch) ---
        self.toggle_optimizer(opt_disc)
        d_loss = self._discriminator_loss(x_real, x_fake)
        self.log("d_loss", d_loss, prog_bar=True)
        self.manual_backward(d_loss)
        opt_disc.step()
        opt_disc.zero_grad()
        self.untoggle_optimizer(opt_disc)

    def configure_optimizers(self):
        """Return (generator optimizer, discriminator optimizer), both Adam with lr=1e-3."""
        lr = 1e-3
        return (
            torch.optim.Adam(self.generator.parameters(), lr=lr),
            torch.optim.Adam(self.discriminator.parameters(), lr=lr),
        )

**Question**: What loss value do you expect in the limit (if both models are equally good in the minimax game)?

To assess the quality of generative models it is important to monitor sample quality. A simple way to do this is by visual inspection. We implement a callback which regularly logs generated samples.

In [None]:
class SampleMonitor(L.Callback):
    """Log a grid of generated samples to TensorBoard at the end of every epoch.

    The latent vectors are drawn once at construction, so successive grids show how
    the generator's output for the *same* inputs evolves during training.

    Args:
        latent_dim: Size of the latent vectors expected by ``pl_module.forward``.
        num_samples: Number of images in the logged grid.
        value_range: Output range of the generator (``(-1, 1)`` for a Tanh output);
            used to rescale the grid to [0, 1] before logging.
    """

    def __init__(self, latent_dim, num_samples=16, value_range=(-1.0, 1.0)):
        super().__init__()
        self.num_samples = num_samples
        self.value_range = value_range
        self.test_z = torch.randn(num_samples, latent_dim)

    def on_train_epoch_end(self, trainer, pl_module):
        # Sample in eval mode without autograd: in train mode BatchNorm layers would
        # normalize with this small batch's statistics and update their running stats.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                test_images = pl_module(self.test_z.to(pl_module.device)).cpu()
        finally:
            pl_module.train(was_training)
        grid = make_grid(test_images, normalize=True, value_range=self.value_range)
        pl_module.logger.experiment.add_image(
            "train/generated_images", grid, trainer.current_epoch
        )

We can start a tensorboard server to observe training progress. 

You can open your browser at localhost:6006

In [None]:
%reload_ext tensorboard
# Serve TensorBoard on port 6006 for all Lightning runs logged under DATA_PATH/lightning_logs.
# NOTE(review): --host 0.0.0.0 listens on all network interfaces (handy in containers/VMs);
# consider 127.0.0.1 on shared machines.
%tensorboard --logdir={DATA_PATH.joinpath("lightning_logs")} --host 0.0.0.0 --port=6006

Now we train the model.

In [None]:
# Runs are written to <DATA_PATH>/lightning_logs/gan/ for comparison in TensorBoard.
logger = TensorBoardLogger(DATA_PATH.joinpath("lightning_logs"), name="gan/")

L.seed_everything(123)

# The generator ends in Tanh, so real images are rescaled from [0, 1] to [-1, 1] too.
gan_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
data_module = MNISTDataModule(data_dir=DATA_PATH.joinpath("mnist"), transform_fn=gan_transform)

LATENT_DIM = 20

gan_model = GAN(latent_dim=LATENT_DIM)
# Created after the model so the seeded random sequence matches the model-first order.
gan_callbacks = [SampleMonitor(latent_dim=LATENT_DIM, num_samples=16)]
trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=15,
    logger=logger,
    default_root_dir=DATA_PATH.joinpath("lightning_logs"),
    callbacks=gan_callbacks,
)
trainer.fit(gan_model, data_module)

**Task**: Feel free to change the architecture in order to improve the result.

**Task**: Sample from the model and display the result. You can look at the `SampleMonitor` class for inspiration.

## 3) Variational Autoencoders

In the following we will implement a Variational Autoencoder (VAE) to generate samples.

A VAE consists of an Encoder $\mathcal{E}$ which takes images $\mathbf{x}$ as input and outputs the parameters of the distribution of a latent variable $\mathbf{z}$.

We define $\mathbf{z}$ to be normally distributed with a diagonal covariance matrix.

The dimensionality of $\mathbf{z}$ should be configurable.

Since we are dealing with images, we use a convolutional neural network.

In [None]:
class Sampling(nn.Module):
    """Reparameterization trick: draw z ~ N(z_mean, exp(z_log_var)) differentiably."""

    def forward(self, z_mean, z_log_var):
        # z = mu + sigma * eps with eps ~ N(0, I); gradients flow through mu and sigma.
        sigma = (0.5 * z_log_var).exp()
        noise = torch.randn_like(sigma)
        return z_mean + noise * sigma


class Encoder(nn.Module):
    """Convolutional encoder returning the parameters of q(z|x) and one sample from it.

    Args:
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, latent_dim):
        super().__init__()
        hidden_dim = 16
        self.network = nn.Sequential(
            # 1x28x28 -> 32x14x14
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # 32x14x14 -> 64x7x7
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, hidden_dim),
            nn.ReLU(),
        )
        # Two linear heads: mean and log-variance of a diagonal Gaussian.
        self.z_mean = nn.Linear(hidden_dim, latent_dim)
        self.z_log_var = nn.Linear(hidden_dim, latent_dim)
        self.sampling = Sampling()

    def forward(self, x):
        """Encode images ``x`` of shape (N, 1, 28, 28).

        Returns:
            Tuple ``(z_mean, z_log_var, z)``, each of shape (N, latent_dim).
        """
        features = self.network(x)  # (N, 16)
        mu, log_var = self.z_mean(features), self.z_log_var(features)
        return mu, log_var, self.sampling(mu, log_var)

Let's check if the encoder layers are correctly parameterized.

In [None]:
# Check layer shapes for a batch of 16 single-channel 28x28 images and a 2-D latent space.
encoder = Encoder(latent_dim=2)
mnist_batch_shape = (16, 1, 28, 28)
summary(encoder, input_size=mnist_batch_shape)

In [None]:
class Decoder(nn.Module):
    """Map latent vectors (N, latent_dim) to images (N, 1, 28, 28) with values in (0, 1).

    Args:
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Project z to a flattened 64x7x7 feature map (reshaped in forward).
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 64 * 7 * 7),
            nn.ReLU(),
        )
        up_layers = [
            # 7x7 -> 14x14
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            # 14x14 -> 28x28
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            # 32 channels -> 1 grayscale channel; Sigmoid keeps pixels in (0, 1)
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]
        self.deconv = nn.Sequential(*up_layers)

    def forward(self, z):
        feature_map = self.fc(z).view(-1, 64, 7, 7)
        return self.deconv(feature_map)

In [None]:
# Check layer shapes for a batch of 16 two-dimensional latent vectors.
decoder = Decoder(latent_dim=2)
latent_batch_shape = (16, 2)
summary(decoder, input_size=latent_batch_shape)

In [None]:
class VAE(L.LightningModule):
    """Variational autoencoder trained on BCE reconstruction loss plus KL divergence.

    Args:
        latent_dim: Dimensionality of the latent space.
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = Encoder(latent_dim=latent_dim)
        self.decoder = Decoder(latent_dim=latent_dim)

    def forward(self, z):
        """Decode latent vectors into images (used for sampling)."""
        return self.decoder(z)

    def training_step(self, batch, batch_idx):
        images, _ = batch
        batch_size = images.shape[0]
        mu, log_var, z = self.encoder(images)
        reconstruction = self.decoder(z)

        # Reconstruction: pixel-wise BCE summed per image, averaged over the batch.
        per_pixel_bce = F.binary_cross_entropy(reconstruction, images, reduction="none")
        rec_loss = per_pixel_bce.sum(dim=(1, 2, 3)).mean()

        # Closed-form KL(N(mu, sigma^2) || N(0, I)) summed over batch and latent dims,
        # then divided by the batch size, i.e. the mean per-sample KL.
        kl_loss = -0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var)) / batch_size

        loss = rec_loss + kl_loss
        for name, value in (("rec_loss", rec_loss), ("kl_loss", kl_loss), ("loss", loss)):
            self.log(name, value, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

To assess the quality of the generated samples we monitor sample quality.

In [None]:
class VAESampleMonitor(L.Callback):
    """Log a grid of images decoded from fixed prior samples at the end of every epoch.

    Args:
        latent_dim: Dimensionality of the VAE latent space.
        num_samples: Number of images in the logged grid.
    """

    def __init__(self, latent_dim, num_samples=16):
        super().__init__()
        self.num_samples = num_samples
        # Fixed z ~ N(0, I), reused every epoch so successive grids are comparable.
        self.test_z = torch.randn(num_samples, latent_dim)

    def on_train_epoch_end(self, trainer, pl_module):
        # Decode without autograd and in eval mode; restore the previous mode afterwards.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                test_images = pl_module(self.test_z.to(pl_module.device)).cpu()
        finally:
            pl_module.train(was_training)
        grid = make_grid(test_images)
        pl_module.logger.experiment.add_image(
            "train/generated_images", grid, trainer.current_epoch
        )

Additionally we can also monitor the reconstruction quality.

In [None]:
class VAEReconstructionQualityMonitor(L.Callback):
    """Log originals (first row) and their reconstructions (second row) every epoch.

    Reconstructions decode the posterior mean ``z_mean`` rather than the sampled ``z``,
    so the logged images are not affected by sampling noise.

    Args:
        image_batch: Fixed batch of input images (N, C, H, W); may live on any device.
    """

    def __init__(self, image_batch):
        super().__init__()
        self.image_batch = image_batch

    def on_train_epoch_end(self, trainer, pl_module):
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                z_mean, _, _ = pl_module.encoder(self.image_batch.to(pl_module.device))
                reconstructed = pl_module.decoder(z_mean).cpu()
        finally:
            pl_module.train(was_training)

        # Both halves must be on the CPU before concatenation.
        all_images = torch.cat([self.image_batch.cpu(), reconstructed])
        grid = make_grid(all_images, nrow=self.image_batch.shape[0])
        pl_module.logger.experiment.add_image(
            "train/reconstructed_images", grid, trainer.current_epoch
        )

We start the tensorboard server, if not already running:

In [None]:
%reload_ext tensorboard
# (Re)start TensorBoard for DATA_PATH/lightning_logs; harmless if it is already running.
# NOTE(review): --host 0.0.0.0 listens on all network interfaces; consider 127.0.0.1 on shared machines.
%tensorboard --logdir={DATA_PATH.joinpath("lightning_logs")} --host 0.0.0.0 --port=6006

And we train the model.

In [None]:
logger = TensorBoardLogger(DATA_PATH.joinpath("lightning_logs"), name="vae/")

L.seed_everything(123)

LATENT_DIM = 2  # 2-D latent space so it can be visualized directly below
MAX_EPOCHS = 15

# Default transform (ToTensor) keeps pixels in [0, 1], as the BCE reconstruction loss needs.
data_module = MNISTDataModule(data_dir=DATA_PATH.joinpath("mnist"))
data_module.setup("fit")

vae_model = VAE(latent_dim=LATENT_DIM)

# The first 16 training images, used to track reconstruction quality across epochs.
reconstruction_test_batch = torch.stack(
    [data_module.ds_train[idx][0] for idx in range(16)], dim=0
)
callbacks_list = [
    VAESampleMonitor(latent_dim=LATENT_DIM, num_samples=16),
    VAEReconstructionQualityMonitor(image_batch=reconstruction_test_batch),
]

trainer_kwargs = dict(
    accelerator="auto",
    devices=1,
    max_epochs=MAX_EPOCHS,
    logger=logger,
    default_root_dir=DATA_PATH.joinpath("lightning_logs"),
    callbacks=callbacks_list,
)
trainer = L.Trainer(**trainer_kwargs)
trainer.fit(vae_model, data_module)

By training a VAE to map samples from a probabilistic latent space to images, we can now investigate the latent space itself.

This only works for visualizing 2-dimensional latent spaces.

In [None]:
def get_latent_space_samples(
    generator: nn.Module,
    num_points_per_dim=30,
    latent_range=1.0,
    figsize=15,
):
    """Decode a regular 2-D grid of latent points into a list of PIL images.

    The grid spans ``[-latent_range, latent_range]`` in both latent dimensions with
    ``num_points_per_dim`` points each. Images are returned in row-major order for a
    ``num_points_per_dim x num_points_per_dim`` collage: rows go from the largest to the
    smallest ``z[1]`` (top to bottom) and columns from the smallest to the largest
    ``z[0]`` (left to right), i.e. the same orientation as a ``z[0]``/``z[1]`` scatter plot.

    Args:
        generator: Module mapping latent tensors of shape (N, 2) to images (N, C, H, W)
            with values in [0, 1].
        num_points_per_dim: Grid resolution along each latent dimension.
        latent_range: Half-width of the grid around the origin.
        figsize: Unused; kept for backward compatibility.

    Returns:
        List of ``num_points_per_dim ** 2`` PIL images.
    """
    from torchvision.transforms.v2 import functional as TF

    grid_z0 = torch.linspace(-latent_range, latent_range, num_points_per_dim)
    grid_z1 = torch.linspace(-latent_range, latent_range, num_points_per_dim).flip(0)
    # rows <-> z[1] (descending), columns <-> z[0] (ascending); flattened row-major.
    z1_rows, z0_cols = torch.meshgrid(grid_z1, grid_z0, indexing="ij")
    z_grid = torch.stack([z0_cols.flatten(), z1_rows.flatten()], dim=1)

    # Works for plain nn.Modules too (they have no `.device` attribute).
    device = next(generator.parameters()).device
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            # One batched forward pass instead of one call per grid point.
            decoded = generator(z_grid.to(device)).cpu()
    finally:
        generator.train(was_training)
    return [TF.to_pil_image(image) for image in decoded]

We decode points in the latent space by sampling along a regular 2-D grid.

In [None]:
# 30 x 30 grid covering [-3, 3] in both latent dimensions.
samples = get_latent_space_samples(vae_model, latent_range=3.0, num_points_per_dim=30)

Now we display the samples.

In [None]:
# Fill the collage row by row ("C" order) so it mirrors the order of the decoded grid.
fig, ax = visualize.plot_collage(samples, axes_iteration_order="C")

**Question**: Explain the observed structure of the latent space.

We can also exploit the fact that we have labels to inspect the latent space. We can simply encode original data points to the latent space and then color the points according to their class label.

In [None]:
# Plot label clusters
def plot_label_clusters(vae, dl, num_smples=1000):
    """Scatter-plot posterior means of ``num_smples`` images, coloured by their label.

    Only the first two latent dimensions are plotted.

    Args:
        vae: Model with an ``encoder`` returning ``(z_mean, z_log_var, z)``.
        dl: Iterable of ``(images, labels)`` batches.
        num_smples: Number of data points to encode and plot (parameter name kept
            unchanged for backward compatibility).

    Raises:
        ValueError: If ``dl`` yields no batches.
    """
    device = next(vae.parameters()).device
    latents, labels = [], []
    num_collected = 0

    was_training = vae.training
    vae.eval()
    try:
        with torch.no_grad():
            for x, y in dl:
                z_mean, _, _ = vae.encoder(x.to(device))
                latents.append(z_mean.cpu().numpy())
                labels.extend(y.tolist())
                num_collected += x.shape[0]
                if num_collected >= num_smples:
                    break
    finally:
        vae.train(was_training)

    if not latents:
        raise ValueError("dl yielded no batches; nothing to plot")
    # Trim the last batch so exactly num_smples points are shown (when available).
    z_mean = np.concatenate(latents, axis=0)[:num_smples]
    labels = labels[:num_smples]

    fig, ax = plt.subplots(figsize=(8, 6))

    sns.scatterplot(
        x=z_mean[:, 0],
        y=z_mean[:, 1],
        hue=labels,
        palette="tab10",
        s=12,
        ax=ax,
    )
    ax.set_xlabel("z[0]")
    ax.set_ylabel("z[1]")
    ax.set_title("Latent space clusters")
    plt.show()

In [None]:
NUM_SAMPLES = 20000

# Encode NUM_SAMPLES training images (default ToTensor transform, pixels in [0, 1]) and
# plot their posterior means coloured by digit label.
data_module = MNISTDataModule(data_dir=DATA_PATH.joinpath("mnist"))
data_module.setup("fit")
dl = data_module.train_dataloader()

plot_label_clusters(vae_model, dl, num_smples=NUM_SAMPLES)

**Question**: Without using any labels in the training process, the latent space separates the different digits quite nicely. How could this be useful for other tasks?

**Task**: Try to increase sample quality. What happens with a larger latent space?

## 4) Diffusion Models

Diffusion models (DMs) consist of two processes:

- a _forward|noise|diffusion process_ which gradually adds random Gaussian noise to the input $\mathbf{x}$. It can also be considered an encoder.
- a _reverse|denoising process_ which learns how to (gradually) remove noise as added by the _forward process_. It can also be considered a decoder.

DMs are latent variable models where the latent variable $\mathbf{z}$ has the same dimensionality as the input data $\mathbf{x}$.

We follow the original implementation [Denoising Diffusion Probabilistic Models](http://arxiv.org/abs/2006.11239). The notation of the variables follows the paper.

We start by implementing the forward process.

In [None]:
class ForwardProcess(nn.Module):
    """DDPM forward (noising) process q(x_t | x_0) with a linear beta schedule.

    Timesteps are 1-based in the maths (t = 1..T) but 0-based when indexing the
    precomputed buffers: index ``i`` corresponds to step ``t = i + 1``.

    Buffers (each of shape (T,)):
        beta_ts: noise variances beta_t, linear from ``beta_0`` (t=1) to ``beta_T`` (t=T).
        alpha_ts: alpha_t = 1 - beta_t.
        alpha_bar_ts: cumulative products alpha_bar_t = prod_{s<=t} alpha_s.

    Args:
        T: Number of diffusion steps.
        beta_0: Variance of the first step (beta_1 in the DDPM paper).
        beta_T: Variance of the last step.
    """

    def __init__(self, T=1000, beta_0=1e-4, beta_T=0.02):
        super().__init__()
        self.beta_T = beta_T
        self.beta_0 = beta_0
        self.T = T

        alpha_bars, alphas, betas = self._precompute_variances()
        self.register_buffer("alpha_bar_ts", torch.tensor(alpha_bars))
        self.register_buffer("alpha_ts", torch.tensor(alphas))
        self.register_buffer("beta_ts", torch.tensor(betas))

    def get_beta(self, t):
        """Linear schedule for 1-based ``t``: beta_1 = beta_0 and beta_T = beta_T."""
        if self.T == 1:
            return self.beta_0
        return self.beta_0 + (self.beta_T - self.beta_0) * (t - 1) / (self.T - 1)

    def _precompute_variances(self):
        """Return lists ``(alpha_bar, alpha, beta)``, each of length T, for t = 1..T."""
        betas = [self.get_beta(t) for t in range(1, self.T + 1)]
        alphas = [1.0 - beta for beta in betas]
        alpha_bars = []
        running_product = 1.0
        for alpha in alphas:
            running_product *= alpha
            alpha_bars.append(running_product)
        return alpha_bars, alphas, betas

    def forward(self, x_0, t):
        """Sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I).

        Args:
            x_0: Clean inputs with a leading batch dimension, e.g. (N, C, H, W).
            t: 0-based timestep index; an int shared by the whole batch, or an integer
                tensor of shape (N,) with one index per sample.

        Returns:
            Tuple ``(x_t, eps)``: the noised inputs and the Gaussian noise that was added.
        """
        t = torch.as_tensor(t, device=self.alpha_bar_ts.device)
        # (N,) or scalar -> (N or 1, 1, ..., 1) so it broadcasts over all non-batch dims.
        broadcast_shape = (-1,) + (1,) * (x_0.dim() - 1)
        alpha_bar = self.alpha_bar_ts[t].reshape(broadcast_shape).to(x_0.device)

        eps = torch.randn_like(x_0)
        x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1.0 - alpha_bar) * eps
        return x_t, eps

In [None]:
# Forward process with the DDPM defaults spelled out: 1000 steps, betas from 1e-4 to 0.02.
fw = ForwardProcess(T=1000, beta_0=1e-4, beta_T=0.02)

Now we test the forward process. We use $T=1000$ for the number of steps where we add noise.

In [None]:
# Per-step noise variance beta_t for t = 1..T (get_beta is plain arithmetic, so it
# evaluates element-wise on a NumPy array).
T = 1000
betas = fw.get_beta(np.arange(1, T + 1))
ax = sns.lineplot(betas)
_ = ax.set(xlabel="Timestep [t]", ylabel=r"$\beta_t$", title="Noise Schedule")

plt.show()

The plot above shows how the noise variance increases over time. Here a linear schedule is used.

In [None]:
# alpha_bar_t is the fraction of the original signal variance still present in x_t
# (the noise variance is 1 - alpha_bar_t); it is not itself a signal-to-noise ratio.
alpha_bars = fw.alpha_bar_ts.cpu().numpy()
_ = sns.lineplot(alpha_bars).set(
    xlabel="Timestep [t]", ylabel=r"$\bar{\alpha}_t$", title="Remaining Signal Variance"
)
plt.show()

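The plot above shows how much of the original signal remains at each step: $\bar{\alpha}_t$ is the fraction of signal variance left in $\mathbf{x}_t$, and $1 - \bar{\alpha}_t$ is the noise variance. We see that, in the limit, (almost) no signal is left and the result of the diffusion process is (approximately) a standard Gaussian distribution.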

Let's take an image.

In [None]:
url = "https://github.com/pytorch/vision/blob/main/gallery/assets/dog2.jpg?raw=true"
# A timeout prevents the cell from hanging forever on a stalled connection.
r = requests.get(url, allow_redirects=True, timeout=30)
# Fail loudly on HTTP errors instead of handing an HTML error page to PIL.
r.raise_for_status()
image = Image.open(io.BytesIO(r.content))
image

We need to convert the image to a tensor and scale it into [-1, 1].

In [None]:
# transf: PIL image -> float tensor in [-1, 1];  inverse_transf: tensor -> PIL image.
transf = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
inverse_transf = transforms.Compose(
    [
        transforms.Normalize(0.0, 1.0 / 0.5),  # x / 2
        transforms.Normalize(-0.5, 1.0),  # x + 0.5  -> back to [0, 1]
        # Noised images leave [0, 1]; ToPILImage converts via `.byte()`, which would wrap
        # out-of-range values around instead of saturating them.
        transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
        transforms.ToPILImage(),
    ]
)

x = transf(image)
x_batch = x.unsqueeze(0)  # add a batch dimension: (1, C, H, W)
x_rec = inverse_transf(x)

# The round trip must reproduce the original pixels up to 8-bit rounding.
np.testing.assert_allclose(
    np.asarray(image, dtype=np.int64), np.asarray(x_rec, dtype=np.int64), atol=1
)
x.shape
x.max()
x.min()

In [None]:
# Apply the first noising step (index 0, i.e. t = 1) and map the result back to an image.
output, noise_added = fw(x_batch, 0)
output.shape
image_rec = inverse_transf(output[0])
image_rec

Now let's take a look at the image degradation over multiple steps.


In [None]:
# Ten (0-based) step indices evenly spaced over the 1000-step process.
steps = np.linspace(0, 1000 - 1, 10).astype(int)
steps

noised_images = [inverse_transf(fw(x_batch, t)[0].squeeze(0)) for t in steps]

In [None]:
step_captions = [f"t={step + 1}" for step in steps]  # captions use 1-based timesteps
fig, ax = visualize.plot_collage(
    noised_images, captions=step_captions, nrows=1, ncols=len(steps)
)

Going forward, we simplify things and use the forward process from the diffusers library. It also includes tricks like clipping the predicted samples to a specified range during denoising.

In [None]:
from diffusers import DDPMScheduler

# Linear beta schedule from 1e-4 to 0.02 over 1000 steps (diffusers defaults, made explicit).
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule="linear"
)

Let's quickly test it:

In [None]:
# Reuse one noise tensor for every step so that only the noise level changes between images.
noise = torch.randn_like(x_batch)
steps = np.linspace(0, 1000 - 1, 10).astype(int)

noised_images = []
for t in steps:
    timestep = torch.tensor([t], device=x.device)
    output = noise_scheduler.add_noise(x_batch, noise=noise, timesteps=timestep)
    noised_images.append(inverse_transf(output.squeeze(0)))

In [None]:
step_captions = [f"t={step + 1}" for step in steps]  # captions use 1-based timesteps
fig, ax = visualize.plot_collage(
    noised_images, captions=step_captions, nrows=1, ncols=len(steps)
)

That looks very similar!

The forward process seems to work. Now we need to implement the backward or _Denoising_ process.

This is a neural network that estimates the noise added to an image at a given time step.

We use a U-Net with convolutional layers. More details can be found here: [Keras Example](https://keras.io/examples/generative/ddpm/)

We use the HuggingFace Library to configure our network.

In [None]:
from diffusers import UNet2DModel

# Small UNet for 1x32x32 inputs: four resolution levels, with self-attention in the
# third down block and the second up block.
unet_config = dict(
    sample_size=32,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(32, 32, 64, 128),
    down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D"),
)
rp = UNet2DModel(**unet_config)

In [None]:
# Layer summary for a batch of 16 single-channel 32x32 images at timestep 1.
summary(rp, sample=torch.randn((16, 1, 32, 32)), timestep=1)

# Sanity check: the predicted noise has the same shape as the input.
random_input = torch.randn((1, 1, 32, 32)).to(rp.device)
x_out = rp(sample=random_input, timestep=0)
x_out["sample"].shape

Now we build our model class.

In [None]:
class DM(L.LightningModule):
    """Diffusion model: a UNet learns to predict the noise added by a DDPM scheduler.

    Args:
        T: Number of diffusion steps.
        beta_0: First value of the linear beta schedule.
        beta_T: Last value of the linear beta schedule.
    """

    def __init__(self, T=1000, beta_0=1e-4, beta_T=0.02):
        super().__init__()
        self.save_hyperparameters()
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=T,
            beta_start=beta_0,
            beta_end=beta_T,
        )
        self.denoising = UNet2DModel(
            sample_size=32,
            in_channels=1,
            out_channels=1,
            layers_per_block=2,
            block_out_channels=(32, 32, 64, 128),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )

    @torch.no_grad()
    def forward(self, x_T):
        """Run the full reverse (denoising) process, starting from pure noise ``x_T``.

        Implements DDPM ancestral sampling (Ho et al., 2020, Algorithm 2) with
        sigma_t^2 = beta_t, using the schedule stored in ``self.noise_scheduler``.

        Args:
            x_T: Gaussian noise of shape (N, 1, 32, 32).

        Returns:
            The denoised samples x_0, with the same shape as ``x_T``.
        """
        n = x_T.shape[0]
        device = x_T.device
        betas = self.noise_scheduler.betas.to(device)
        alphas = self.noise_scheduler.alphas.to(device)
        alpha_bars = self.noise_scheduler.alphas_cumprod.to(device)

        x_t = x_T
        for t in reversed(range(self.noise_scheduler.config.num_train_timesteps)):
            t_batch = torch.full((n,), t, device=device, dtype=torch.long)

            # Predict the noise contained in x_t.
            eps_hat = self.denoising(x_t, t_batch)["sample"]

            # Mean of p(x_{t-1} | x_t): (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
            mean = (x_t - betas[t] / torch.sqrt(1.0 - alpha_bars[t]) * eps_hat) / torch.sqrt(
                alphas[t]
            )

            # Add fresh noise on every step except the last one (0-based t == 0).
            if t > 0:
                x_t = mean + torch.sqrt(betas[t]) * torch.randn_like(x_t)
            else:
                x_t = mean

        return x_t

    def training_step(self, batch, batch_idx):
        x_zero, _ = batch

        batch_size = x_zero.shape[0]

        # Sample one (0-based) timestep per image.
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (batch_size,),
            device=x_zero.device,
        )

        # Noise the inputs: x_t = sqrt(alpha_bar_t) x_0 + sqrt(1 - alpha_bar_t) eps.
        eps = torch.randn_like(x_zero)
        x_t = self.noise_scheduler.add_noise(x_zero, eps, timesteps)

        # Predict the noise and regress it onto the true noise (DDPM simple loss).
        noise_estimate = self.denoising(x_t, timesteps)["sample"]
        loss = F.mse_loss(noise_estimate, eps)

        self.log("loss", loss, prog_bar=True)

        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

Monitor the quality of the generated samples.

In [None]:
from diffusers import DDPMPipeline


class DMSampleMonitor(L.Callback):
    """Log a grid of diffusion samples to TensorBoard at the end of every epoch.

    Sampling uses its own freshly seeded CPU generator on every call. As a result the
    starting noise is identical across epochs, so grids are comparable, and the global
    RNG that drives training is not advanced by the monitor.

    Args:
        latent_dim: Shape of one sample, e.g. ``(1, 32, 32)``; kept for interface
            compatibility (the pipeline derives the sample shape from the UNet config).
        num_samples: Number of images per grid.
        seed: Seed of the sampling noise.
    """

    def __init__(self, latent_dim, num_samples=16, seed=0):
        super().__init__()
        self.latent_dim = tuple(latent_dim)
        self.num_samples = num_samples
        self.seed = seed
        self.pipe = None

    def on_fit_start(self, trainer, pl_module):
        self.pipe = DDPMPipeline(unet=pl_module.denoising, scheduler=pl_module.noise_scheduler)
        # The reverse process runs every scheduler step; hide its per-epoch progress bar.
        self.pipe.set_progress_bar_config(disable=True)

    def on_train_epoch_end(self, trainer, pl_module):
        generator = torch.Generator(device="cpu").manual_seed(self.seed)
        was_training = pl_module.training
        pl_module.eval()
        try:
            test_images = self.pipe(
                batch_size=self.num_samples, generator=generator, output_type="pil"
            ).images
        finally:
            pl_module.train(was_training)
        test_images = [TF.to_tensor(image) for image in test_images]
        grid = make_grid(test_images)
        pl_module.logger.experiment.add_image(
            "train/generated_images", grid, trainer.current_epoch
        )

Let's monitor our progress:

In [None]:
%reload_ext tensorboard
# (Re)start TensorBoard for DATA_PATH/lightning_logs; harmless if it is already running.
# NOTE(review): --host 0.0.0.0 listens on all network interfaces; consider 127.0.0.1 on shared machines.
%tensorboard --logdir={DATA_PATH.joinpath("lightning_logs")} --host 0.0.0.0 --port=6006

We train the model.

In [None]:
logger = TensorBoardLogger(DATA_PATH.joinpath("lightning_logs"), name="dm/")

L.seed_everything(123)

LATENT_DIM = (1, 32, 32)  # shape of one sample after padding MNIST from 28x28 to 32x32
MAX_EPOCHS = 15
NUM_STEPS = 1000

transf = transforms.Compose(
    [
        transforms.ToTensor(),  # Converts to [0, 1] float tensor
        transforms.Normalize(0.5, 0.5),  # Scale to [-1, 1]
        # Pad 2 pixels per side (28x28 -> 32x32). Padding happens after Normalize, so the
        # fill must be -1 (the normalized black background); 0 would add a gray border.
        transforms.Pad(padding=2, fill=-1.0),
    ]
)

inverse_transf = transforms.Compose(
    [
        transforms.CenterCrop((28, 28)),  # remove the 2-pixel padding
        transforms.Normalize(0.0, 1.0 / 0.5),  # x / 2
        transforms.Normalize(-0.5, 1.0),  # x + 0.5  -> back to [0, 1]
        # ToPILImage converts via `.byte()`, which wraps out-of-range values; clamp first.
        transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
        transforms.ToPILImage(),
    ]
)

data_module = MNISTDataModule(data_dir=DATA_PATH.joinpath("mnist"), transform_fn=transf)

data_module.setup("fit")

dm_model = DM(T=NUM_STEPS, beta_0=1e-4, beta_T=0.02)


callbacks_list = [
    DMSampleMonitor(latent_dim=LATENT_DIM, num_samples=16),
]


trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=MAX_EPOCHS,
    logger=logger,
    default_root_dir=DATA_PATH.joinpath("lightning_logs"),
    callbacks=callbacks_list,
)
trainer.fit(dm_model, data_module)

We can sample from the model:

In [None]:
# Wrap the trained UNet and its scheduler in a diffusers sampling pipeline on `device`.
pipe = DDPMPipeline(unet=dm_model.denoising, scheduler=dm_model.noise_scheduler)
pipe = pipe.to(device)
images = pipe(batch_size=32, output_type="pil").images

In [None]:
# Derive the grid from the number of samples (32 -> 4 x 8) so no panels are left empty.
n_cols = 8
fig, ax = visualize.plot_collage(images, nrows=-(-len(images) // n_cols), ncols=n_cols)