# Exercise 3: **Diffusion Models**

In this exercise, you will get some hands-on experience with **Deep Generative Modelling**. There are many different types of deep generative models (VAEs, GANs, Normalizing Flows, etc.), each with its own strengths and weaknesses. Over the last few years, Diffusion Models (DMs) have come to dominate the generative modelling landscape, largely thanks to the influential work on [DDPMs](https://arxiv.org/pdf/2006.11239) (Ho et al. 2020). That paper introduced a simple parameterisation of DMs, revealed an equivalence to a previously separate line of research on denoising score-based generative models, and showed that DMs can produce diverse samples of very high quality.

### ❗ Task 3.1: Read the DDPM paper
We will focus on DDPMs (Ho et al. 2020) in this exercise. Please read the [DDPM paper](https://arxiv.org/pdf/2006.11239) first. The following tasks ask you to implement the placeholders in this notebook (indicated by 💻), most of which correspond directly to equations or algorithms from the paper. Once you have gone through the paper, you can continue with this notebook.

<div>
<img src="https://huggingface.co/blog/assets/78_annotated-diffusion/ddpm_paper.png" width="500"/>
</div>

## 🚀 Notebook Requirements
First, install and import all the basic packages and dependencies.

In [None]:
!pip3 install -q --upgrade pip
!pip3 install -q diffusers

In [None]:
import random
import imageio
import numpy as np
from argparse import ArgumentParser

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

import torchvision
from torchvision.transforms import Compose, ToTensor, Lambda, Grayscale, Resize
from torchvision.datasets.mnist import MNIST, FashionMNIST

from IPython.display import display

## 🪲 The Datasets
For the purpose of this exercise, we added a few different datasets that you can experiment with:
- **MNIST**: You all know it! A classical simple grayscale dataset of 28x28 images of handwritten digits (10 classes).
- **FashionMNIST**: Like MNIST, this dataset contains 10 classes of 28x28 grayscale images, but this time of fashion items such as shoes, dresses, etc.
- **CIFAR100**: Well-known dataset with 100 classes of 32x32 RGB images. The classes cover natural things such as different types of animals and plants, but also objects like buses, trains, pickup trucks, etc.
- **CIFAR100Gray**: Same as CIFAR100, just transformed into grayscale. Instead of 3 channels, we only have 1.

### Loading the Dataset

In [None]:
def denormalize_to_zero_to_one(img):
    img = img.clamp(-1, 1)
    return (img + 1.0) / 2.0


def load_dataset(dataset_name, n_classes):

  if dataset_name in ["MNIST", "FashionMNIST"]:
      assert n_classes <= 10, "Please choose n_classes <= 10 for the selected dataset."

  print(f"Loading dataset: {dataset_name} (reduced to first {n_classes} classes)")

  if dataset_name == "CIFAR100":
      img_transform = Compose([
          ToTensor(),
          Lambda(lambda x: (x - 0.5) * 2)],
      )
      dataset = torchvision.datasets.CIFAR100(root="data", download=True, transform=img_transform)

      def get_class_name(label):
          classes = ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm']
          return classes[label]

      total_n_classes = 100
      image_size = (32, 32)
      image_channels = 3

  elif dataset_name == "CIFAR100Gray":
      img_transform = Compose([
          Grayscale(),
          ToTensor(),
          Lambda(lambda x: (x - 0.5) * 2)],
      )
      dataset = torchvision.datasets.CIFAR100(root="data", download=True, transform=img_transform)

      def get_class_name(label):
          classes = ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm']
          return classes[label]

      total_n_classes = 100
      image_size = (32, 32)
      image_channels = 1

  elif dataset_name == "FashionMNIST":
      img_transform = Compose([
          ToTensor(),
          Lambda(lambda x: (x - 0.5) * 2)]
      )
      dataset = torchvision.datasets.FashionMNIST("./datasets", download=True, train=True, transform=img_transform)

      def get_class_name(label):
          classes = ["T-shirt", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
          return classes[label]

      total_n_classes = 10
      image_size = (28, 28)
      image_channels = 1

  elif dataset_name == "MNIST":
      img_transform = Compose([
          ToTensor(),
          Lambda(lambda x: (x - 0.5) * 2)]
      )
      dataset = torchvision.datasets.MNIST("./datasets", download=True, train=True, transform=img_transform)

      def get_class_name(label):
          return label

      total_n_classes = 10
      image_size = (28, 28)  # MNIST images are 28x28 and no Resize transform is applied
      image_channels = 1

  else:
      raise NotImplementedError

  class_labels = list(range(n_classes))
  indices = [idx for idx, target in enumerate(dataset.targets) if target in class_labels]
  dataset = torch.utils.data.Subset(dataset, indices)

  return dataset, class_labels, get_class_name, image_size, image_channels

In [None]:
def show_images(images, title="", rows=None):
    """Shows the provided images as sub-pictures in a grid"""

    # Convert images to CPU numpy arrays if they are PyTorch tensors
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu()
    else:
        images = torch.tensor(images)

    # Ensure images are in the shape (N, C, H, W)
    if images.ndim == 3:
        images = images.unsqueeze(1)  # Add channel dimension for grayscale images

    # Determine the number of rows and columns
    num_images = len(images)
    rows = int(num_images ** 0.5) if rows is None else rows
    cols = (num_images + rows - 1) // rows  # Ensure all images are included in the grid

    # Calculate total cells and the number of padding images needed
    total_cells = rows * cols
    padding_images = total_cells - num_images

    # Create padding images (white images)
    if padding_images > 0:
        white_image = torch.ones_like(images[0])  # Create a single white image
        white_images = white_image.unsqueeze(0).repeat(padding_images, 1, 1, 1)  # Repeat it
        images = torch.cat((images, white_images), dim=0)  # Append to the images tensor

    # Create a grid of images
    grid = torchvision.utils.make_grid(images, nrow=cols, padding=2, pad_value=1)  # pad_value=1 makes padding white

    # Convert the grid to a numpy array for displaying
    grid_np = grid.permute(1, 2, 0).numpy()

    # Display the grid
    plt.figure(figsize=(8, 8))
    plt.imshow(grid_np.clip(0, 1))
    plt.title(title)
    plt.axis('off')
    plt.gcf().set_dpi(150)
    plt.show()

In [None]:
def get_images_per_class(dataset, label, n_images: int = 16):
  images = []
  for img, img_label in dataset:
    if img_label == label:
      images.append(img)
    if len(images) == n_images:
      break
  return images

### ❗ Task 3.2: Visualize the Datasets

Before we get started with our implementation, let's take a look at some images from the different datasets. Choose a dataset in the configuration below and specify the number of classes that you want to use from it. Using fewer classes reduces the size of the dataset.

In [None]:
# @title Dataset Configuration { display-mode: "form", run: "auto" }
# @markdown ### Choose a dataset:
dataset_name = "FashionMNIST" # @param ["MNIST", "FashionMNIST", "CIFAR100", "CIFAR100Gray"] {type:"string"}
n_classes = 10 # @param {type:"slider", min:1, max:100, step:1}

dataset, class_labels, get_class_name, image_size, image_channels = load_dataset(dataset_name, n_classes)

print("Dataset size:", len(dataset))

Now you can run the cell below to take a look at the images.

In [None]:
# visualize the chosen dataset
for label in class_labels:
    class_images = get_images_per_class(dataset, label)
    class_row = denormalize_to_zero_to_one(torchvision.utils.make_grid(class_images)).permute(1, 2, 0).numpy()

    plt.imshow(class_row)
    plt.title(f'{get_class_name(label)}')
    plt.axis('off')
    plt.show()

    print(" ")

In [None]:
# double-check that you are using "cuda" as your device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

## ❗ Task 3.3: Unconditional Generation
As you know from reading the paper, a **DDPM** (Ho et al. 2020) is a Diffusion Model that discretizes the diffusion process into a finite number $T$ of steps and learns to reverse this process by training a conditional deep neural network to estimate the noise that has been added to a sample $x_t$ at time step $t$.

### Forward Diffusion Process
The **forward diffusion process** $q(x_t\mid x_{t-1})$, i.e. the process that diffuses the data distribution by adding noise, is fixed and given as a Markov chain whose transition kernel adds Gaussian noise according to a variance schedule $\beta_1,\beta_2,...,\beta_T$ (see Equation 2 in the DDPM paper). A convenient property of this Gaussian definition is that we can sample $x_t$ in closed form at any timestep $t$ without iterating $x_0 \rightarrow x_1 \rightarrow ... \rightarrow x_t$. Given $x_0$, a raw sample from the data distribution, and a timestep $t$, we sample $x_t$ from a Gaussian with mean $\sqrt{\overline\alpha_t}\,x_0$ and variance $(1-\overline\alpha_t)I$, where $\alpha_t = 1-\beta_t$ and $\overline\alpha_t = \prod_{s=1}^{t}\alpha_s$ (Equation 4). The larger $t$, the more the mean shrinks towards zero and the larger the variance becomes; the schedule determines how quickly this happens.
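The equivalence between the step-by-step chain and the closed form can be checked numerically. The following is a minimal sketch in plain Python, independent of the rest of the notebook; the helper names `linear_betas`, `step_by_step` and `closed_form` are made up for this illustration. It tracks the coefficient on $x_0$ and the accumulated noise variance through the chain and compares them with $\sqrt{\overline\alpha_t}$ and $1-\overline\alpha_t$.

```python
import math

def linear_betas(min_beta, max_beta, n_steps):
    """Linearly spaced variances beta_1..beta_T (assumes n_steps >= 2)."""
    step = (max_beta - min_beta) / (n_steps - 1)
    return [min_beta + i * step for i in range(n_steps)]

def step_by_step(betas):
    """Follow the Markov chain one transition at a time.

    One transition maps x_{t-1} to sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * z,
    so we only need to track the factor on x_0 and the total noise variance.
    """
    mean_coef, noise_var = 1.0, 0.0
    for beta in betas:
        mean_coef *= math.sqrt(1.0 - beta)
        noise_var = (1.0 - beta) * noise_var + beta
    return mean_coef, noise_var

def closed_form(betas):
    """Closed form: factor sqrt(alpha_bar_t) on x_0 and variance 1 - alpha_bar_t."""
    alpha_bar = math.prod(1.0 - beta for beta in betas)
    return math.sqrt(alpha_bar), 1.0 - alpha_bar

betas = linear_betas(1e-4, 0.02, 100)
chain = step_by_step(betas)
direct = closed_form(betas)
print("step by step:", chain)
print("closed form: ", direct)
```

Both tuples should agree up to floating-point error, which is why the forward process never needs to be simulated step by step during training.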


<p align="center">
    <img src="https://user-images.githubusercontent.com/10695622/174349667-04e9e485-793b-429a-affe-096e8199ad5b.png" width="800"/>
    <br>
    <br>
    <em> Figure from DDPM paper (https://arxiv.org/abs/2006.11239). </em>
</p>

#### Variance Schedule
The variance schedule is fixed for our DDPM, so we can precompute it once. Here is a simple function that does that for you.

In [None]:
def get_variance_schedule(min_beta, max_beta, n_steps):
  # linear variance schedule
  betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
  alphas = 1 - betas
  # cumulative product alpha_1 * ... * alpha_t for every t (stays on the same device as betas)
  alpha_bars = torch.cumprod(alphas, dim=0)
  return betas, alphas, alpha_bars
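To get a feeling for how the number of steps interacts with the range of $\beta$ values, the sketch below computes $\overline\alpha_T$ at the final step for a few chain lengths. It is written in plain Python and does not use `get_variance_schedule`; the helper `final_alpha_bar` and the chosen chain lengths are assumptions for illustration.

```python
import math

def final_alpha_bar(min_beta, max_beta, n_steps):
    """Product of (1 - beta_t) over a linear schedule with n_steps >= 2 entries."""
    step = (max_beta - min_beta) / (n_steps - 1)
    return math.prod(1.0 - (min_beta + i * step) for i in range(n_steps))

for n_steps in (100, 200, 1000):
    alpha_bar_T = final_alpha_bar(1e-4, 0.02, n_steps)
    # sqrt(alpha_bar_T) scales the clean image, sqrt(1 - alpha_bar_T) scales the noise
    print(f"T={n_steps:4d}  signal={math.sqrt(alpha_bar_T):.3f}  noise={math.sqrt(1 - alpha_bar_T):.3f}")
```

For the short chains the signal factor stays clearly above zero, so $x_T$ still contains part of the image. Since sampling starts from pure Gaussian noise, a short chain may need a larger `max_beta` to make the two ends match.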

#### 💻 Forward Diffusion Process

Now it starts to get interesting. Please take a look at Equation 4 and Algorithm 1 in the DDPM paper, which show the closed form of the forward process and how it is used.

Given $\overline\alpha_t$=`alpha_bars[t]`, please complete the function called `sample_xt` below.
- `x0` is the original image tensor, `t` is the time step, `alpha_bars` is a precomputed array of numbers, and `epsilon` is an optionally provided noise map
- if `epsilon` is not provided, you MUST generate a basic noise map yourself (remember to push it to the same device as `x0`)
- this function is supposed to sample $x_t$ from $q(x_t\mid x_0)$.
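Because every image in a batch gets its own timestep, the per-sample values of $\overline\alpha_t$ have to be broadcast over the channel and spatial dimensions. The following sketch shows only that indexing and reshaping step; the tensor `alpha_bars` here holds stand-in values rather than a real schedule, and the batch shape is hypothetical.

```python
import torch

n, c, h, w = 4, 1, 28, 28                      # hypothetical batch of four grayscale images
alpha_bars = torch.linspace(0.99, 0.01, 100)   # stand-in values, not a real schedule
t = torch.tensor([0, 10, 50, 99])              # one timestep per image

alpha_bar_t = alpha_bars[t]                    # shape (4,): one scalar per image
alpha_bar_t = alpha_bar_t.reshape(n, 1, 1, 1)  # shape (4, 1, 1, 1), broadcastable over (c, h, w)

x0 = torch.ones(n, c, h, w)
scaled = alpha_bar_t.sqrt() * x0               # each image is scaled by its own factor

print(alpha_bar_t.shape, scaled.shape)
```

Without the reshape, multiplying a tensor of shape `(4,)` with one of shape `(4, 1, 28, 28)` would either fail or broadcast along the wrong axis.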

In [None]:
def sample_xt(x0, t, alpha_bars, epsilon=None):
    # --------------------- IMPLEMENTATION REQUIRED ------------------------

    # image dimensions
    n, c, h, w = x0.shape

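    if epsilon is None:
      # standard Gaussian noise with the same shape, dtype and device as x0
      epsilon = torch.randn_like(x0)

    # alpha_bar for each sample's timestep, broadcast over (c, h, w)
    alpha_bar_t = alpha_bars[t].reshape(n, 1, 1, 1)

    mean = alpha_bar_t.sqrt() * x0
    std = (1 - alpha_bar_t).sqrt()  # standard deviation, i.e. the square root of the variance

    xt = mean + epsilon * std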

    # ----------------------------------------------------------------------
    return xt

The following function visualizes the forward process:

In [None]:
def show_forward(sample, alpha_bars, device):
    p_list = np.linspace(0, 1, 10)

    img, label = sample
    # repeat the same image once per visualized timestep
    x0 = img.unsqueeze(0).repeat(len(p_list), 1, 1, 1).to(device)

    t = torch.tensor([max(0, int(p * n_steps) - 1) for p in p_list]).to(device)
    samples = sample_xt(x0, t, alpha_bars, None).cpu()

    fig, axes = plt.subplots(1, len(p_list), figsize=(16, 2), sharey=True)
    for tt, ax, xt in zip(t, axes, samples):

        xt = denormalize_to_zero_to_one(xt).numpy()

        ax.imshow(xt.transpose(1, 2, 0).clip(0, 1), cmap='gray' if image_channels == 1 else None)
        ax.set_title(f"$t={tt.item()}$")  # .item() prints the plain integer instead of tensor(...)
        ax.set_axis_off()

    plt.suptitle(f"{get_class_name(label)} class")
    plt.tight_layout()
    plt.show()

Here you can test your function on the dataset that you selected before:

In [None]:
# some hyperparameters for the visualization
min_beta = 10 ** -4
max_beta = 0.02
n_steps = 100

# precompute the variables of the variance schedule
_, _, alpha_bars = get_variance_schedule(min_beta, max_beta, n_steps)
alpha_bars = alpha_bars.to(device)

# show forward process for 3 first samples of the dataset
for x in list(iter(dataset))[:3]:
  show_forward(x, alpha_bars, device=device)

#### Reverse Diffusion Process
The reverse process is the direction parameterized by a neural network $\epsilon_\theta$. It is learned based on data!

For this $\epsilon_\theta$ model, the DDPM authors used a U-Net architecture. We will define the model a bit later, because its structure depends on the dataset you will be choosing.

#### DDPM Class
For convenience, we will create a DDPM wrapper class that unifies the forward and reverse process logic:

In [None]:
# DDPM class
class DDPM(nn.Module):

    def __init__(self, model, image_size, image_channels, n_steps=200, min_beta=10 ** -4, max_beta=0.02, device=None):
        super(DDPM, self).__init__()
        self.model = model.to(device)
        self.device = device

        self.image_size = image_size
        self.image_channels = image_channels
        self.n_steps = n_steps

        # precompute the variables of the variance schedule
        betas, alphas, alpha_bars = get_variance_schedule(min_beta, max_beta, n_steps)
        self.betas, self.alphas, self.alpha_bars = betas.to(device), alphas.to(device), alpha_bars.to(device)


    def forward(self, x0, t, epsilon=None):
        # Forward process (see Section 3.1 in paper)
        return sample_xt(x0, t, self.alpha_bars, epsilon)


    def reverse(self, x, t):
        # Reverse process (see section 3.2 in paper)
        return self.model(x, t, return_dict=False)[0]

#### 💻 Iterative **Reverse** Sampling
One of the original ideas from score-based generative modelling research is to represent a data distribution $p(x)$ through its score, the gradient of the log-density $\nabla_x \log p(x)$. Sampling then applies an iterative, gradient-based procedure that follows these local gradients (plus some injected noise) and gradually moves $x_t$ towards regions of high likelihood. This is separate from the gradient-based optimization of the parameters $\theta$ of $\epsilon_\theta$: the sampling procedure is used after the model has been trained. Take a look at Algorithm 2 in the DDPM paper and complete the placeholders in the function below.
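The mean used in each reverse step is written in terms of the predicted noise. As a sanity check, the sketch below verifies for a single scalar "pixel" that this noise-based mean coincides with the posterior mean of $q(x_{t-1}\mid x_t, x_0)$ (Equation 7 in the paper) when the true noise is plugged in. The schedule values, `x0` and `eps` are arbitrary illustrative numbers.

```python
import math

# Stand-in schedule: 10 linearly spaced betas (illustrative values only)
betas = [0.01 + 0.01 * i for i in range(10)]
alphas = [1.0 - b for b in betas]
alpha_bars = [math.prod(alphas[: i + 1]) for i in range(len(alphas))]

t = 5                  # any index >= 1, so that alpha_bars[t - 1] exists
x0, eps = 0.8, -1.3    # a scalar "pixel" and the noise that was added to it
xt = math.sqrt(alpha_bars[t]) * x0 + math.sqrt(1 - alpha_bars[t]) * eps

# Posterior mean of q(x_{t-1} | x_t, x_0), expressed with x_0 and x_t
posterior_mean = (
    math.sqrt(alpha_bars[t - 1]) * betas[t] / (1 - alpha_bars[t]) * x0
    + math.sqrt(alphas[t]) * (1 - alpha_bars[t - 1]) / (1 - alpha_bars[t]) * xt
)

# The same mean, expressed with x_t and the noise instead of x_0
mean_from_eps = (xt - (1 - alphas[t]) / math.sqrt(1 - alpha_bars[t]) * eps) / math.sqrt(alphas[t])

print(posterior_mean, mean_from_eps)
```

During sampling the true noise is unknown, so the network output $\epsilon_\theta(x_t, t)$ takes the place of `eps`.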

In [None]:
def sample(ddpm: DDPM, n_samples: int):
  h, w = ddpm.image_size
  c = ddpm.image_channels

  with torch.no_grad():

    # --------------------- IMPLEMENTATION REQUIRED ------------------------
    x = torch.randn(n_samples, c, h, w).to(device)
    # ----------------------------------------------------------------------

    for t in list(range(ddpm.n_steps))[::-1]:

      # --------------------- IMPLEMENTATION REQUIRED ------------------------
      time_tensor = (torch.ones(n_samples) * t).to(device).long()
      epsilon_theta = ddpm.reverse(x, time_tensor)

      alpha_t = ddpm.alphas[t]
      alpha_t_bar = ddpm.alpha_bars[t]

      # Partially denoising the image
      x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * epsilon_theta)

      # Langevin dynamics
      if t > 0:
          beta_t = ddpm.betas[t]
          sigma_t = beta_t.sqrt()
          z = torch.randn(n_samples, c, h, w).to(device)
          x = x + sigma_t * z
      # ----------------------------------------------------------------------

  x = denormalize_to_zero_to_one(x)

  return x

### 💻 Training Logic
At this point we have all the components we need except one: the logic to train our DDPMs. Take another look at Algorithm 1 of the paper and then complete the training loop below by implementing the two placeholders:
1. create the loss function
2. implement Algorithm 1 together with the standard PyTorch training logic
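A useful reference point when reading the training loss is the value obtained by a predictor that always outputs zero. The sketch below computes the Algorithm 1 loss for such a predictor on random stand-in data; `zero_predictor` is a hypothetical placeholder for $\epsilon_\theta$, and the random "images" are not drawn from any of the datasets.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mse = nn.MSELoss()

n_steps = 1000
alpha_bars = torch.cumprod(1 - torch.linspace(1e-4, 0.02, n_steps), dim=0)

x0 = torch.rand(64, 1, 28, 28) * 2 - 1          # random stand-in "images" in [-1, 1]
t = torch.randint(0, n_steps, (x0.shape[0],))   # one random timestep per image
epsilon = torch.randn_like(x0)                  # the noise the network should recover

a = alpha_bars[t].reshape(-1, 1, 1, 1)
xt = a.sqrt() * x0 + (1 - a).sqrt() * epsilon   # noisy input the network would see

def zero_predictor(x, t):
    """Hypothetical stand-in for epsilon_theta that ignores its input."""
    return torch.zeros_like(x)

baseline_loss = mse(zero_predictor(xt, t), epsilon)
print(f"baseline loss: {baseline_loss.item():.3f}")
```

Since the noise has unit variance, this baseline sits close to 1. A loss that stays near that value suggests the model has not learned anything useful yet.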

In [None]:
def train(ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm.pt"):

    # --------------------- IMPLEMENTATION REQUIRED ------------------------
    mse = nn.MSELoss()
    # ----------------------------------------------------------------------

    best_loss = float("inf")
    n_steps = ddpm.n_steps

    batch_losses = []
    for epoch in tqdm(range(n_epochs), desc=f"Training progress"):
        epoch_loss = 0.0
        for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}")):

            # --------------------- IMPLEMENTATION REQUIRED ------------------------
            # Loading data
            x0 = batch[0].to(device)
            n = len(x0)

            # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
            epsilon = torch.randn_like(x0).to(device)
            t = torch.randint(0, n_steps, (n,)).to(device)

            # Computing the noisy image based on x0 and the time-step (forward process)
            noisy_imgs = ddpm(x0, t, epsilon)

            # Getting model estimation of noise based on the images and the time-step
            epsilon_theta = ddpm.reverse(noisy_imgs, t)

            # Optimizing the MSE between the noise plugged and the predicted noise
            loss = mse(epsilon_theta, epsilon)

            optim.zero_grad()
            loss.backward()
            optim.step()
            # ----------------------------------------------------------------------

            epoch_loss += loss.item() * len(x0) / len(loader.dataset)

        # Display images generated at this epoch
        if display:
            show_images(sample(ddpm, 8), f"Images generated at epoch {epoch + 1}")

        log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"

        # Storing the model
        if best_loss > epoch_loss:
            best_loss = epoch_loss
            torch.save(ddpm.state_dict(), store_path)
            log_string += " --> Best model ever (stored)"

        print(log_string)

Before we can test your training loop, choose a dataset as before!

In [None]:
# @title Dataset Configuration { display-mode: "form", run: "auto" }
# @markdown ### Choose a dataset:
dataset_name = "FashionMNIST" # @param ["MNIST", "FashionMNIST", "CIFAR100", "CIFAR100Gray"] {type:"string"}
n_classes = 10 # @param {type:"slider", min:1, max:100, step:1}

dataset, class_labels, get_class_name, image_size, image_channels = load_dataset(dataset_name, n_classes)

print("Dataset size:", len(dataset))

#### The U-Net Architecture
Here is a nice U-Net implementation from the 🧨 diffusers library. Feel free to play around with it.


In [None]:
from diffusers import UNet2DModel

# instantiate our U-Net model
model = UNet2DModel(
  sample_size=image_size, # the target image resolution
  in_channels=image_channels, # the number of input channels, 3 for RGB images
  out_channels=image_channels, # the number of output channels
  layers_per_block=2, # how many ResNet layers to use per UNet block
  block_out_channels=(
    64,
    128,
    128,
  ), # the number of output channels for each UNet block
  down_block_types=(
    "DownBlock2D", # a regular ResNet downsampling block
    "DownBlock2D", # a regular ResNet downsampling block (no self-attention)
    "DownBlock2D", # a regular ResNet downsampling block
  ),
  up_block_types=(
    "UpBlock2D", # a regular ResNet upsampling block
    "UpBlock2D", # a regular ResNet upsampling block (no self-attention)
    "UpBlock2D", # a regular ResNet upsampling block
  ))

# count the number of parameters to see the model size
n_params = sum([p.numel() for p in model.parameters()])
print(f"Created model with {n_params} parameters!")


You now also have to choose the hyperparameters for the training:

In [None]:
# @title Training Configuration { display-mode: "form", run: "auto" }
# @markdown ### Enter a name for your training run:
run_name = "ddpm_fashion" # @param {type:"string"}

# @markdown ### Training Hyperparameters:
batch_size = 128 # @param {type:"slider", min:1, max:512, step:1}
n_epochs = 3 # @param {type:"slider", min:1, max:50, step:1}
lr = 0.001 # @param {type:"slider", min:0.0, max:1.0, step:0.0001}

# @markdown ### DDPM Hyperparameters:
n_steps = 1001 # @param {type:"slider", min:1, max:2000, step:10}
min_beta = 0.0001 # @param {type:"slider", min:0.0001, max:1.0, step:0.0001}
max_beta = 0.02 # @param {type:"slider", min:0.0001, max:1.0, step:0.0001}

Once you have made these choices, you can start the training! Let's see if it works:

In [None]:
# instantiate our DDPM, wrapped around our model
ddpm = DDPM(model, image_size=image_size, image_channels=image_channels, n_steps=n_steps, min_beta=min_beta, max_beta=max_beta, device=device)

# create our dataloader
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# start the training
train(ddpm, train_dataloader, n_epochs, optim=torch.optim.AdamW(ddpm.parameters(), lr), display=True, device=device, store_path=f'{run_name}.pt')

### Inference

In [None]:
# specify the model name for inference
inference_run_name = "ddpm_fashion"

ddpm.load_state_dict(torch.load(f'{inference_run_name}.pt', map_location=device))
ddpm.eval()
print("Model loaded:", inference_run_name)

# number of samples to generate
N = 64

# run the iterative sampling procedure
samples = sample(ddpm, N)

# show the samples
show_images(samples, "Unconditional Samples")

## ❗ Task 3.4: Conditional Generation
The denoising process has so far been totally unconditional, which gave a lot of freedom to the model during training but also zero control to us during inference. Now we want to change that and use the classes to condition the reverse process. Our model will then be class-conditional. There are different ways to achieve that, but for this exercise we will simply rely on the 🧨 diffusers library to create learnable class embeddings for us, which are incorporated in the forward pass of the network. Check out the documentation of the [UNet2DModel](https://huggingface.co/docs/diffusers/en/api/models/unet2d) if you want to explore it further.

Note that we still cannot force the model to generate images of the given class during inference. The hope is that it learns to rely on the additional information, because this helps it to better estimate the noise that was added to a particular sample $x_t$ at time step $t$.

First, we will have to upgrade our DDPM class by making sure that the additional class label information $y$ will be propagated to our model in the reverse process.

In [None]:
# DDPM class
class ClassConditionalDDPM(DDPM):

    def reverse(self, x, t, y):
        return self.model(x, t, y, return_dict=False)[0]

### 💻 Conditional Sampling
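Then we need to extend our sampling function as well. It now accepts an additional, optional argument called `classes`, a list of class labels. If it is given, `n_samples` images are generated for each label in the list.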
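The label tensor passed to the model needs one entry per generated image, grouped by class. The sketch below builds such a tensor in two equivalent ways; the labels and the number of images per class are hypothetical values.

```python
import torch

classes = [0, 3, 7]   # hypothetical class labels
n_per_class = 4

# Concatenate one constant block per class
y_cat = torch.cat([torch.ones(n_per_class) * label for label in classes]).long()

# Equivalent one-liner: repeat every label n_per_class times in a row
y_repeat = torch.tensor(classes).repeat_interleave(n_per_class)

print(y_cat.tolist())
```

Grouping the labels this way means that, when the images are later arranged in a grid with one row per class, each row shows samples of a single class.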

In [None]:
def sample(ddpm: DDPM, n_samples: int, classes: list[int] = None):
  h, w = ddpm.image_size
  c = ddpm.image_channels

  with torch.no_grad():

    # --------------------- IMPLEMENTATION REQUIRED ------------------------
    if classes:
      n_samples = n_samples * len(classes)

    x = torch.randn(n_samples, c, h, w).to(device)
    # ----------------------------------------------------------------------

    for t in list(range(ddpm.n_steps))[::-1]:

      # --------------------- IMPLEMENTATION REQUIRED ------------------------
      time_tensor = (torch.ones(n_samples) * t).to(device).long()

      if classes:
        y = torch.cat([torch.ones(n_samples // len(classes)) * label for label in classes]).long().to(device)
        epsilon_theta = ddpm.reverse(x, time_tensor, y)
      else:
        epsilon_theta = ddpm.reverse(x, time_tensor)

      alpha_t = ddpm.alphas[t]
      alpha_t_bar = ddpm.alpha_bars[t]

      # Partially denoising the image
      x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * epsilon_theta)

      # Langevin dynamics
      if t > 0:
          beta_t = ddpm.betas[t]
          sigma_t = beta_t.sqrt()
          z = torch.randn(n_samples, c, h, w).to(device)
          x = x + sigma_t * z
      # ----------------------------------------------------------------------

  x = denormalize_to_zero_to_one(x)

  return x

### 💻 Conditional DDPM Training
The same applies to the training loop. The overall logic is exactly the same as before. You just have to get the class information from the batch (hint: `y=batch[1]`) and make sure it is actually used.

In [None]:
def train_classconditional(ddpm, loader, n_epochs, optim, device, display=False, store_path="ddpm_model.pt"):

    # --------------------- IMPLEMENTATION REQUIRED ------------------------
    mse = nn.MSELoss()
    # ----------------------------------------------------------------------

    best_loss = float("inf")
    n_steps = ddpm.n_steps

    batch_losses = []
    for epoch in tqdm(range(n_epochs), desc=f"Training progress"):
        epoch_loss = 0.0
        for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{n_epochs}")):

            # --------------------- IMPLEMENTATION REQUIRED ------------------------
            # Loading data
            x0 = batch[0].to(device)
            y = batch[1].to(device)
            n = len(x0)

            # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
            epsilon = torch.randn_like(x0).to(device)
            t = torch.randint(0, n_steps, (n,)).to(device)

            # Computing the noisy image based on x0 and the time-step (forward process)
            noisy_imgs = ddpm(x0, t, epsilon)

            # Getting model estimation of noise based on the images and the time-step
            epsilon_theta = ddpm.reverse(noisy_imgs, t, y)

            # Optimizing the MSE between the noise plugged and the predicted noise
            loss = mse(epsilon_theta, epsilon)

            optim.zero_grad()
            loss.backward()
            optim.step()
            # ----------------------------------------------------------------------

            epoch_loss += loss.item() * len(x0) / len(loader.dataset)

        # Display images generated at this epoch
        if display:
            is_conditional = ddpm.model.class_embedding is not None
            # `sample` expects a list of class labels (not a count), so pass the global `class_labels`
            show_images(sample(ddpm, 8, class_labels if is_conditional else None), f"Images generated at epoch {epoch + 1}")

        log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"

        # Storing the model
        if best_loss > epoch_loss:
            best_loss = epoch_loss
            torch.save(ddpm.state_dict(), store_path)
            log_string += " --> Best model ever (stored)"

        print(log_string)

In [None]:
from diffusers import UNet2DModel

# instantiate our U-Net model
cond_model = UNet2DModel(

  num_class_embeds=n_classes, # the only change compared to the unconditional model

  sample_size=image_size, # the target image resolution
  in_channels=image_channels, # the number of input channels, 3 for RGB images
  out_channels=image_channels, # the number of output channels
  layers_per_block=2, # how many ResNet layers to use per UNet block
  block_out_channels=(
    64,
    128,
    128,
  ), # the number of output channels for each UNet block
  down_block_types=(
    "DownBlock2D", # a regular ResNet downsampling block
    "DownBlock2D", # a regular ResNet downsampling block (no self-attention)
    "DownBlock2D", # a regular ResNet downsampling block
  ),
  up_block_types=(
    "UpBlock2D", # a regular ResNet upsampling block
    "UpBlock2D", # a regular ResNet upsampling block (no self-attention)
    "UpBlock2D", # a regular ResNet upsampling block
  ))

# count the number of parameters to see the model size
n_params = sum([p.numel() for p in cond_model.parameters()])
print(f"Created model with {n_params} parameters!")


In [None]:
# @title Conditional Training Configuration { display-mode: "form", run: "auto" }

# @markdown ### Enter a name for your training run:
cond_run_name = "conditional_ddpm_fashion" # @param {type:"string"}

# @markdown ### Training Hyperparameters:
batch_size = 128 # @param {type:"slider", min:1, max:512, step:1}
n_epochs = 1 # @param {type:"slider", min:1, max:50, step:1}
lr = 0.001 # @param {type:"slider", min:0.0, max:1.0, step:0.0001}

# @markdown ### DDPM Hyperparameters:
n_steps = 500 # @param {type:"slider", min:1, max:2000, step:10}
min_beta = 0.0001 # @param {type:"slider", min:0.0001, max:1.0, step:0.0001}
max_beta = 0.02 # @param {type:"slider", min:0.0001, max:1.0, step:0.0001}

In [None]:
# instantiate our DDPM, wrapped around our model
cond_ddpm = ClassConditionalDDPM(cond_model, image_size=image_size, image_channels=image_channels, n_steps=n_steps, min_beta=min_beta, max_beta=max_beta, device=device)

# create our dataloader
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# start the training
train_classconditional(cond_ddpm, train_dataloader, n_epochs, optim=torch.optim.AdamW(cond_ddpm.parameters(), lr), display=False, device=device, store_path=f'{cond_run_name}.pt')

### Conditional Inference

In [None]:
# Loading the trained model
cond_ddpm.load_state_dict(torch.load(f'{cond_run_name}.pt', map_location=device))
cond_ddpm.eval()
print("Model loaded:", cond_run_name)

In [None]:
samples = sample(cond_ddpm, 4, class_labels)
show_images(samples, "Conditional Samples", rows=len(class_labels))

In [None]:
# in case of too many classes use this approach to avoid memory issues
for label in range(n_classes):
  print(get_class_name(label))
  samples = sample(cond_ddpm, 4, [label])
  show_images(samples, "Conditional Samples", rows=1)

## ❗ Task 3.5: Explore!
Go back through the notebook and play around with different hyperparameters and modelling choices. Can you make it work with the more challenging CIFAR100 dataset? You can also try to change the architecture of the $\epsilon_\theta$ model, i.e. change the parameters of the [UNet2DModel](https://huggingface.co/docs/diffusers/en/api/models/unet2d) class. Also feel free to share some of your favorite samples in the forum on Moodle.

## Further Reading

For more insights on the connections to score-based generative modelling, you can take a look at these great posts:

- [Score-based Perspective](https://yang-song.net/blog/2021/score/)
- [Diffusion Model Perspective](https://calvinyluo.com/2022/08/26/diffusion-tutorial.html)