# Diffusion Models
Diffusion models are a type of generative model that describes the data generation process as a Markov chain of diffusion steps. They gradually add noise to the data and learn to reverse this process to generate new samples. Diffusion models have several advantages over other generative models (VAEs, GANs, ...), including improved sample quality (compared to VAEs) and better mode coverage (compared to GANs). They are defined as a sequence of conditional Gaussian distributions, where each step adds or removes noise from the data. The key challenge in training diffusion models is learning the reverse process, which requires optimizing a complex objective function. Despite this challenge, diffusion models are used in major image generation systems such as DALL-E 2, Imagen, and Stable Diffusion, and are at the core of more recent video generation models such as Sora.

We will cover the [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) *(Ho et al.)* approach in this notebook.

Some parts are adapted from this excellent [YouTube video](https://www.youtube.com/watch?v=a4Yfz2FxXiY).

The U-Net backbone architecture is adapted from [this GitHub repository](https://github.com/milesial/Pytorch-UNet).

In [None]:
import time
import math
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchvision.utils import make_grid

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Hyperparameters
num_epochs = 100
batch_size = 8
lr = 0.001

train_ratio = 0.002   # ratio of training data to use
test_ratio = 0.1
img_size = 32

T = 300               # diffusion steps

In [None]:
# Utility functions used later
def plot_grid(noisy_imgs):
    noisy_grid = make_grid(noisy_imgs, nrow=8, padding=2, normalize=True)

    # Convert the grid tensor to a numpy array
    noisy_grid_np = noisy_grid.permute(1, 2, 0).detach().cpu().numpy()

    # Display the grid of noisy images
    plt.figure(figsize=(10, 10))
    plt.imshow(noisy_grid_np)
    plt.axis('off')
    plt.show()

def show_images(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))

# Forward Pass
## Build the noise scheduler

We want to define a forward diffusion process by introducing a kernel $q$ that describes the transition probability from the data at the previous timestep ${x_{t-1}}$ to the data at the current timestep ${x_t}$. This kernel is defined as a conditional Gaussian distribution:

$$q(x_t| x_{t-1}) = \mathcal{N}(x_t;\sqrt{1- \beta_t} x_{t-1},\beta_t I )$$

Here, ${\mathcal{N}(x;\mu,\sigma^2 I )}$ denotes the density, evaluated at $x$, of a Gaussian distribution with mean ${\mu}$ and covariance matrix ${\sigma^2 I}$, where $I$ is the identity matrix. Equivalently, ${x \sim \mathcal{N}(\mu,\sigma^2 I )}$ means that $x$ is a sample drawn from this distribution. In the kernel above, ${x_t}$ is therefore Gaussian given ${x_{t-1}}$, with mean ${\sqrt{1-\beta_t}\, x_{t-1}}$ and variance ${\beta_t}$ in every coordinate.

Let's first define the variance schedule ${\beta_1, \beta_2, ..., \beta_T}$ that controls how much noise is added to the data at each `timestep`. The choice of the variance schedule can have a significant impact on the performance of the model. In this implementation, we use a linear variance schedule, where the variance increases linearly from a small value `start` to a larger value `end` over a fixed number of `timesteps`.

In [None]:
# Linear variance schedule
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """ Return `timesteps` evenly spaced values from `start` to `end` """

    ## TO DO

    return beta_schedule
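
One possible way to build the linear schedule described above is `torch.linspace`, which returns evenly spaced values and includes both endpoints. The snippet below is a minimal sketch under that assumption, not necessarily the intended solution; the names `linear_beta_schedule_sketch` and `betas_sketch` are chosen here so that they do not clash with the exercise function.

```python
import torch

def linear_beta_schedule_sketch(timesteps, start=0.0001, end=0.02):
    """Return `timesteps` evenly spaced betas from `start` to `end` (both included)."""
    return torch.linspace(start, end, timesteps)

betas_sketch = linear_beta_schedule_sketch(300)
print(betas_sketch.shape)                                # torch.Size([300])
print(betas_sketch[0].item(), betas_sketch[-1].item())   # smallest and largest beta
```

The returned tensor is one-dimensional, so it can be indexed directly by a batch of integer timesteps.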

We define $\alpha_t := 1-\beta_t$ and $\overline{\alpha_t} = \prod_{i=1}^{t} \alpha_i$ so that we can sample forward directly at timestep t from the original data $\mathbf{x}_0$ :
$$q(\mathbf{x}_{1:T}|\mathbf{x}_0) = \prod_{t=1}^{T}q(\mathbf{x}_t|\mathbf{x}_{t-1})$$
$$ q(\mathbf{x}_t| \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t;\sqrt{\overline{\alpha_t}} \mathbf{x}_0, (1-\overline{\alpha_t})I)$$
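
The closed form for $q(\mathbf{x}_t|\mathbf{x}_0)$ can be recovered by unrolling the one-step kernel. As a short sketch of the argument, write one step as $\mathbf{x}_t = \sqrt{\alpha_t}\,\mathbf{x}_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_t$ with $\epsilon_t \sim \mathcal{N}(0, I)$ (the symbols $\epsilon_t$ and $\bar{\epsilon}$ are introduced here only for this derivation) and substitute the same expression for $\mathbf{x}_{t-1}$:

$$\mathbf{x}_t = \sqrt{\alpha_t \alpha_{t-1}}\,\mathbf{x}_{t-2} + \sqrt{\alpha_t (1-\alpha_{t-1})}\,\epsilon_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_t$$

The two noise terms are independent Gaussians, so their sum is Gaussian with variance $\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) = 1 - \alpha_t \alpha_{t-1}$:

$$\mathbf{x}_t = \sqrt{\alpha_t \alpha_{t-1}}\,\mathbf{x}_{t-2} + \sqrt{1-\alpha_t \alpha_{t-1}}\,\bar{\epsilon}, \qquad \bar{\epsilon} \sim \mathcal{N}(0, I)$$

Repeating the substitution down to $\mathbf{x}_0$ gives

$$\mathbf{x}_t = \sqrt{\overline{\alpha_t}}\,\mathbf{x}_0 + \sqrt{1-\overline{\alpha_t}}\,\bar{\epsilon}$$

which is exactly a sample from $\mathcal{N}(\mathbf{x}_t;\sqrt{\overline{\alpha_t}}\,\mathbf{x}_0, (1-\overline{\alpha_t})I)$.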

In [None]:
def create_alphas(betas) :
  """ return the alphas as defined above """

  ## TO DO

  return alphas


def create_alphas_cumprod(alphas) :
  """ return the cumulative product of the alphas (alpha bar) """

  ## TO DO

  return alphas_cumprod

# NB : We break down the construction of the alpha_cumprod in 2 steps as we will
# need the intermediate alphas values for the backward process
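
The two helpers above can be sketched with an elementwise subtraction and `torch.cumprod`. This is one possible implementation, shown with hypothetical `_sketch` and `_demo` names so that it does not overwrite the exercise functions; the schedule values reuse the defaults given earlier.

```python
import torch

def create_alphas_sketch(betas):
    """alpha_t = 1 - beta_t, computed elementwise."""
    return 1.0 - betas

def create_alphas_cumprod_sketch(alphas):
    """alpha_bar_t = alpha_1 * alpha_2 * ... * alpha_t."""
    return torch.cumprod(alphas, dim=0)

betas_demo = torch.linspace(0.0001, 0.02, 300)
alphas_demo = create_alphas_sketch(betas_demo)
alphas_cumprod_demo = create_alphas_cumprod_sketch(alphas_demo)

# First value is close to 1 (almost no noise), last value is close to 0
print(alphas_cumprod_demo[0].item(), alphas_cumprod_demo[-1].item())
```

Because every $\alpha_t$ is strictly below 1, the cumulative product decreases monotonically with the timestep.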

In [None]:
## Define linear variance schedule
betas = linear_beta_schedule(timesteps=T)

## Pre-calculate different terms for closed form

alphas = create_alphas(betas)
alphas_cumprod = create_alphas_cumprod(alphas)

# cumulative product of the alpha values up to the previous timestep
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# Variance of the forward-process posterior q(x_{t-1} | x_t, x_0), used as the noise variance when sampling x_{t-1}
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
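
For reference, the `posterior_variance` line matches the variance of the forward-process posterior given in Ho et al. Conditioned on both $\mathbf{x}_t$ and $\mathbf{x}_0$, the previous state is Gaussian (the symbols $\tilde{\mu}_t$ and $\tilde{\beta}_t$ follow the paper and are not used elsewhere in this notebook):

$$q(\mathbf{x}_{t-1}|\mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\mu}_t(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t I), \qquad \tilde{\beta}_t = \frac{1-\overline{\alpha_{t-1}}}{1-\overline{\alpha_t}}\,\beta_t$$

The padding value of 1.0 in `alphas_cumprod_prev` corresponds to the convention $\overline{\alpha_0} = 1$, so the posterior variance of the very first step is zero.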

In [None]:
## Plot the cumulative product of the alphas (alpha bar)
## TO DO

### Q1) What does $\overline{\alpha_t}$ represent for the forward process at step `t`?

A1)

In [None]:
## Define the forward process function

def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu()) # alpha bar for values t for instance
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(device)

def forward_diffusion_sample(x_0, t):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + standard deviation * noise (reparameterized sample from q(x_t | x_0))
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
    + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
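
The gather-and-reshape step in `get_index_from_list` is what lets one coefficient per image broadcast over the channel and spatial dimensions. The sketch below illustrates the shapes on CPU with stand-in data; `extract_sketch` and the `_demo` variables are hypothetical names used only for this illustration.

```python
import torch

def extract_sketch(vals, t, x_shape):
    """Pick vals[t_b] for each batch element b and reshape to (B, 1, 1, ...)."""
    out = vals.gather(-1, t)
    return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))

T_demo = 300
alphas_cumprod_demo = torch.cumprod(1.0 - torch.linspace(0.0001, 0.02, T_demo), dim=0)

x0_demo = torch.zeros(4, 1, 32, 32)          # stand-in batch of 4 images
t_demo = torch.tensor([0, 99, 199, 299])     # one timestep per image

coef = extract_sketch(torch.sqrt(alphas_cumprod_demo), t_demo, x0_demo.shape)
print(coef.shape)                 # torch.Size([4, 1, 1, 1])
print((coef * x0_demo).shape)     # torch.Size([4, 1, 32, 32])
```

Each image in the batch is therefore scaled by the coefficient of its own timestep, which is what allows a different `t` per sample during training.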

In [None]:
## Prepare the MNIST dataset

def get_subset(dataset, ratio):
    """
    Returns a subset of the dataset containing a specified ratio of examples.
    """
    indices = list(range(len(dataset)))
    subset_size = max(1, int(len(dataset) * ratio))
    subset_indices = indices[:subset_size]
    return torch.utils.data.Subset(dataset, subset_indices)

# Define transforms
transform = transforms.Compose([
    transforms.Pad((2, 2, 2, 2), fill=0, padding_mode='constant'),  # Add padding to obtain 32x32 images
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download and load the MNIST dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainset = get_subset(trainset, ratio=train_ratio)  # Use only train_ratio of the training set for faster training during debugging
print(f'Length of training set: {len(trainset)}')
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
print(f'Length of test set: {len(testset)}')
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, drop_last=True)

In [None]:
# Noise an image
for batch_nbr, (imgs, labels) in enumerate(trainloader):
    t = torch.randint(0, T, (imgs.shape[0],))
    print("Noising steps list :\n", t)
    noisy_imgs, noise = forward_diffusion_sample(imgs, t)
    plot_grid(noisy_imgs)
    break

You should notice that the larger the noising timestep, the noisier the image.

## U-Net architecture


In [None]:
## Let's now define the UNet architecture
## U-Net model taken from the very clean pytorch implementation of milesial
## https://github.com/milesial/Pytorch-UNet

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, time_emb_dim, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True)
        )
        self.time_mlp = nn.Sequential(nn.Linear(time_emb_dim, mid_channels), nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, t):
        x = self.conv1(x)
        t = self.time_mlp(t)
        t = t[(..., ) + (None, ) * 2]
        x = x + t # We add the diffusion step embedding to the image embedding
        x = self.conv2(x)
        return x


class Down(nn.Module):
    """Double conv then downscaling with maxpool"""

    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_channels, out_channels, time_emb_dim)

    def forward(self, x, t):
        return self.maxpool(self.conv(x, t))


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, time_emb_dim, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, time_emb_dim, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels, time_emb_dim)

    def forward(self, x1, x2, t):
        x1 = self.up(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x, t)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
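
To make the time conditioning more concrete, the sketch below reproduces the sinusoidal embedding as a plain function and shows how a per-channel vector is broadcast over a feature map, as `DoubleConv.forward` does with `t[(..., ) + (None, ) * 2]`. The function name and the `_demo` tensors are illustrative assumptions, not part of the notebook.

```python
import math
import torch

def sinusoidal_embedding_sketch(time, dim):
    """Map integer timesteps of shape (B,) to embeddings of shape (B, dim)."""
    half_dim = dim // 2
    scale = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim) * -scale)   # frequencies from 1 down to 1/10000
    angles = time[:, None] * freqs[None, :]               # (B, half_dim)
    return torch.cat((angles.sin(), angles.cos()), dim=-1)

t_demo = torch.tensor([0, 1, 150, 299])
emb_demo = sinusoidal_embedding_sketch(t_demo, dim=32)
print(emb_demo.shape)    # torch.Size([4, 32])

# Broadcasting a (B, C) vector over a (B, C, H, W) feature map
feat_demo = torch.zeros(4, 32, 8, 8)
out_demo = feat_demo + emb_demo[(..., ) + (None, ) * 2]   # emb becomes (4, 32, 1, 1)
print(out_demo.shape)    # torch.Size([4, 32, 8, 8])
```

The first half of each embedding holds sines and the second half cosines, so timestep 0 maps to a vector of zeros followed by ones.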

In [None]:
## Full assembly of the parts to form the complete network

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=False):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.bilinear = bilinear

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim=32),
                nn.Linear(32, 32),
                nn.ReLU()
            )

        self.inc = (DoubleConv(self.in_channels, 64, time_emb_dim=32))
        self.down1 = (Down(64, 128, time_emb_dim=32))
        self.down2 = (Down(128, 256, time_emb_dim=32))
        self.down3 = (Down(256, 512, time_emb_dim=32))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor, time_emb_dim=32))
        self.up1 = (Up(1024, 512 // factor, time_emb_dim=32, bilinear=bilinear))
        self.up2 = (Up(512, 256 // factor, time_emb_dim=32, bilinear=bilinear))
        self.up3 = (Up(256, 128 // factor, time_emb_dim=32, bilinear=bilinear))
        self.up4 = (Up(128, 64, time_emb_dim=32, bilinear=bilinear))
        self.outc = (OutConv(64, self.out_channels))

    def forward(self, x, timestep):
        t = self.time_mlp(timestep)
        x1 = self.inc(x, t)
        x2 = self.down1(x1, t)
        x3 = self.down2(x2, t)
        x4 = self.down3(x3, t)
        x5 = self.down4(x4, t)
        x = self.up1(x5, x4, t)
        x = self.up2(x, x3, t)
        x = self.up3(x, x2, t)
        x = self.up4(x, x1, t)
        logits = self.outc(x)
        return logits

In [None]:
## Sampling function for generation

@torch.no_grad()
def sample_plot_reconstruction(nbr_imgs):
    # Sample noise
    img = torch.randn((10, 1, img_size, img_size), device=device)
    plt.figure(figsize=(8,8))
    plt.axis('off')
    stepsize = int(T/nbr_imgs)

    for i in range(0,T)[::-1]:
        t = torch.full((img.shape[0],), i, device=device, dtype=torch.long)

        betas_t = get_index_from_list(betas, t, img.shape)
        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, img.shape)
        sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, img.shape)

        # Predict the noise with the model and compute the mean of x_{t-1}
        model_mean = sqrt_recip_alphas_t * (img - betas_t * model(img, t) / sqrt_one_minus_alphas_cumprod_t)
        posterior_variance_t = get_index_from_list(posterior_variance, t, img.shape)

        if t[0] == 0:
            img = model_mean
        else:
            noise = torch.randn_like(img)
            img = model_mean + torch.sqrt(posterior_variance_t) * noise

        # To maintain the natural range of the distribution
        img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0: # at diffusion step i
          for j in range(img.shape[0]) : # for each img of batch
            plt.subplot(10, nbr_imgs, int((T-i)/stepsize) + j*nbr_imgs)
            show_images(img[j].detach().cpu())
    plt.show()
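
Each iteration of the sampling loop above performs one reverse step. Written in the notation of this notebook, with $\epsilon_\theta$ denoting the noise predicted by the U-Net and $\tilde{\beta}_t$ the precomputed `posterior_variance` (both symbols are introduced here for readability), the update reads:

$$\mu_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\overline{\alpha_t}}}\,\epsilon_\theta(\mathbf{x}_t, t)\right)$$

$$\mathbf{x}_{t-1} = \mu_\theta(\mathbf{x}_t, t) + \sqrt{\tilde{\beta}_t}\,\mathbf{z}, \qquad \mathbf{z} \sim \mathcal{N}(0, I)$$

At the last step (index 0 in the code) no noise is added and $\mathbf{x}_{t-1} = \mu_\theta(\mathbf{x}_t, t)$. The clamp to $[-1, 1]$ after each step is a choice of this implementation and is not part of the equations.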

In [None]:
## Initialisation and visualisation of the model

model = UNet(1, 1) # 1 input channel and 1 output channel for MNIST data (greyscale images)
model.to(device)
print(model)

### Q2) How many parameters does this model have?

A2)

In [None]:
## TO DO
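
A common way to count the parameters of a PyTorch module is to sum `numel()` over `parameters()`. The sketch below applies this to a small stand-in layer so that the result can be checked by hand; `count_parameters_sketch` and `toy_layer` are hypothetical names, and the same function could be called on the U-Net.

```python
import torch.nn as nn

def count_parameters_sketch(module):
    """Return (total, trainable) parameter counts of a module."""
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return total, trainable

toy_layer = nn.Linear(10, 5)                   # 10 * 5 weights + 5 biases
print(count_parameters_sketch(toy_layer))      # (55, 55)
```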

## Train the model and visualise generated samples and training loss

In [None]:
## Define Adam optimizer and MSE loss

## TO DO

train_losses = []

In [None]:
start = time.time()
for epoch in range(num_epochs):
    for step, (imgs_batch, _) in enumerate(trainloader):
      # imgs_batch holds the MNIST images; the labels (the digit shown on each
      # image) are discarded here because unconditional generation does not
      # need them (they are useful for classification tasks, for instance)
      optimizer.zero_grad()

      # Sample a random timestep between 0 and T-1 of shape batch_size
      # It has to be of type "long integer"
      # Don't forget to bring it to GPU (.to(device))
      t = ## TO DO

      # Compute the noised images and the corresponding added noise
      # (closed-form forward process: the noise is sampled, nothing is learned)
      x_noisy, noise = ## TO DO

      # Predict the added noise with the model
      noise_pred = ## TO DO

      # Compute the loss
      train_loss = ## TO DO

      # Backward pass and update weights
      train_loss.backward()
      optimizer.step()

    train_losses.append(train_loss.item())

    if epoch % 10 == 0:
      print("Epoch", epoch, "Loss:", train_loss.item(), "Time elapsed :", np.round(time.time() - start, 2), "secs")
      sample_plot_reconstruction(10)
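
The structure of one training step can be sketched in isolation as follows. This is a minimal illustration on random data, assuming an Adam optimizer and an MSE loss on the predicted noise; `ToyDenoiser` is a hypothetical stand-in for the U-Net that ignores the timestep, and all `_demo` names are introduced here only so the example runs on its own.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
T_demo = 300
alphas_cumprod_demo = torch.cumprod(1.0 - torch.linspace(0.0001, 0.02, T_demo), dim=0)

class ToyDenoiser(nn.Module):
    """Hypothetical stand-in for the U-Net: ignores t and applies one convolution."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x, t):
        return self.conv(x)

def q_sample_sketch(x_0, t):
    """Closed-form forward noising: returns (x_t, noise)."""
    noise = torch.randn_like(x_0)
    a_bar = alphas_cumprod_demo[t].reshape(-1, 1, 1, 1)
    return torch.sqrt(a_bar) * x_0 + torch.sqrt(1.0 - a_bar) * noise, noise

toy_model = ToyDenoiser()
optimizer_demo = optim.Adam(toy_model.parameters(), lr=0.001)
criterion_demo = nn.MSELoss()

imgs_demo = torch.rand(8, 1, 32, 32) * 2 - 1            # fake batch scaled to [-1, 1]

optimizer_demo.zero_grad()
t_demo = torch.randint(0, T_demo, (imgs_demo.shape[0],))   # one int64 timestep per image
x_noisy_demo, noise_demo = q_sample_sketch(imgs_demo, t_demo)
noise_pred_demo = toy_model(x_noisy_demo, t_demo)           # predict the added noise
loss_demo = criterion_demo(noise_pred_demo, noise_demo)     # MSE between predicted and true noise
loss_demo.backward()
optimizer_demo.step()
print(loss_demo.item())
```

In the notebook itself, the images and timesteps would additionally be moved to `device`, and the model would be the U-Net defined earlier.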

In [None]:
plt.plot(train_losses)

Once your training loop works properly, you can train the model with more training data to improve generation quality (change `train_ratio` in the hyperparameters).

## Conditional Generation
We will now add the embeddings of MNIST labels (just like we already do with time embeddings for diffusion) to transform our model into a conditional model. These embeddings are learned throughout training. Another approach would be to add fixed embeddings (one-hot encoding for instance). <br><br>
In the following, your task is to edit the architecture of the model to take into account these MNIST label embeddings.

In [None]:
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, time_emb_dim, label_emb_dim, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True)
        )
        self.time_mlp = nn.Sequential(nn.Linear(time_emb_dim, mid_channels), nn.ReLU())
        self.label_mlp = nn.Sequential(nn.Linear(label_emb_dim, mid_channels), nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, t, l):

      ## TO DO
      # Add the diffusion step embedding and the label embedding to the image embedding

        return x


class Down(nn.Module):
    """Double conv then downscaling with maxpool (same order as the unconditional `Down` block)"""

    def __init__(self, in_channels, out_channels, time_emb_dim, label_emb_dim):
        super().__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.conv = DoubleConv(in_channels, out_channels, time_emb_dim, label_emb_dim)

    def forward(self, x, t, l):
        return       ## TO DO (take into account label embedding)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, time_emb_dim, label_emb_dim, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, time_emb_dim, label_emb_dim, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels, time_emb_dim, label_emb_dim)

    def forward(self, x1, x2, t, l):
        x1 = self.up(x1)

        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return ## TO DO (take into account label embedding)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
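
One possible way to inject the label embedding is to project it to the number of feature channels and add it to the feature map next to the time embedding. The sketch below shows this for a block with the same layers as the conditional `DoubleConv`; combining the embeddings by summation is an assumption made here (concatenation would be another option), and `DoubleConvCondSketch` is a hypothetical name.

```python
import torch
import torch.nn as nn

class DoubleConvCondSketch(nn.Module):
    """Same layers as the conditional DoubleConv, with one possible forward pass."""

    def __init__(self, in_channels, out_channels, time_emb_dim, label_emb_dim, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
        )
        self.time_mlp = nn.Sequential(nn.Linear(time_emb_dim, mid_channels), nn.ReLU())
        self.label_mlp = nn.Sequential(nn.Linear(label_emb_dim, mid_channels), nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, t, l):
        x = self.conv1(x)
        t = self.time_mlp(t)[(..., ) + (None, ) * 2]    # (B, C) -> (B, C, 1, 1)
        l = self.label_mlp(l)[(..., ) + (None, ) * 2]   # (B, C) -> (B, C, 1, 1)
        x = x + t + l   # assumption: time and label embeddings are summed
        return self.conv2(x)

block_demo = DoubleConvCondSketch(1, 64, time_emb_dim=32, label_emb_dim=32)
out_demo = block_demo(torch.randn(4, 1, 32, 32), torch.randn(4, 32), torch.randn(4, 32))
print(out_demo.shape)   # torch.Size([4, 64, 32, 32])
```

Under the same assumption, the `Down` and `Up` blocks would simply forward the extra argument, for example `self.maxpool(self.conv(x, t, l))` and `self.conv(x, t, l)`.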

In [None]:
## We specify the label embedding dimension and add the label embedding
## to the DoubleConv layers

class Conditional_UNet(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=False):
        super(Conditional_UNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.bilinear = bilinear

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim=32),
                nn.Linear(32, 32),
                nn.ReLU()
            )

        # Label embedding for conditional generation
        self.label_embeddings = nn.Embedding(num_embeddings=10, embedding_dim=32)

        self.inc = (DoubleConv(self.in_channels, 64, time_emb_dim=32, label_emb_dim=32))
        self.down1 = (Down(64, 128, time_emb_dim=32, label_emb_dim=32))
        self.down2 = (Down(128, 256, time_emb_dim=32, label_emb_dim=32))
        self.down3 = (Down(256, 512, time_emb_dim=32, label_emb_dim=32))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor, time_emb_dim=32, label_emb_dim=32))
        self.up1 = (Up(1024, 512 // factor, time_emb_dim=32, label_emb_dim=32, bilinear=bilinear))
        self.up2 = (Up(512, 256 // factor, time_emb_dim=32, label_emb_dim=32, bilinear=bilinear))
        self.up3 = (Up(256, 128 // factor, time_emb_dim=32, label_emb_dim=32, bilinear=bilinear))
        self.up4 = (Up(128, 64, time_emb_dim=32, label_emb_dim=32, bilinear=bilinear))
        self.outc = (OutConv(64, self.out_channels))

    def forward(self, x, timestep, labels):
        ## TO DO (take into account label embedding)
        return logits
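
The label conditioning relies on `nn.Embedding`, which acts as a learned lookup table with one row per digit. The sketch below only illustrates that lookup on a stand-in table with the same sizes as `self.label_embeddings`; the `_demo` names are hypothetical.

```python
import torch
import torch.nn as nn

label_embeddings_demo = nn.Embedding(num_embeddings=10, embedding_dim=32)

labels_demo = torch.tensor([1, 2, 3, 3])        # integer class labels, one per image
l_demo = label_embeddings_demo(labels_demo)     # one 32-dimensional vector per label

print(l_demo.shape)                             # torch.Size([4, 32])
print(torch.equal(l_demo[2], l_demo[3]))        # True: same digit, same embedding
```

In the conditional forward pass, a plausible wiring is to compute the time embedding from `timestep` and the label embedding from `labels` once, then pass both to every `DoubleConv`, `Down`, and `Up` block.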

In [None]:
## New sampling function for generation

@torch.no_grad()
def sample_plot_reconstruction_conditional(labels_list):
    # Sample noise
    img = torch.randn((len(labels_list), 1, img_size, img_size), device=device)
    plt.figure(figsize=(8,len(labels_list)))
    plt.axis('off')
    stepsize = int(T/10)
    #labels = torch.tensor([i for i in range(img.shape[0])]).to(device)
    labels = torch.tensor(labels_list).to(device)

    for i in range(0,T)[::-1]:
        t = torch.full((img.shape[0],), i, device=device, dtype=torch.long)

        betas_t = get_index_from_list(betas, t, img.shape)
        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, img.shape)
        sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, img.shape)

        # Predict the noise with the label-conditioned model and compute the mean of x_{t-1}
        model_mean = sqrt_recip_alphas_t * (img - betas_t * model(img, t, labels) / sqrt_one_minus_alphas_cumprod_t)
        posterior_variance_t = get_index_from_list(posterior_variance, t, img.shape)

        if t[0] == 0:
            img = model_mean
        else:
            noise = torch.randn_like(img)
            img = model_mean + torch.sqrt(posterior_variance_t) * noise

        # To maintain the natural range of the distribution
        img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0: # at diffusion step i
          for j in range(img.shape[0]) : # for each img of batch
              plt.subplot(len(labels_list), 10, int((T-i)/stepsize) + j*10)
              show_images(img[j].detach().cpu())
    plt.show()

In [None]:
## Initialisation and visualisation of the model

model = Conditional_UNet(1, 1) # 1 input channel and 1 output channel for MNIST data (greyscale images)
model.to(device)
print(model)

In [None]:
## Define Adam optimizer and MSE loss

## TO DO

train_losses = []

In [None]:
start = time.time()
for epoch in range(num_epochs):
    for step, (imgs_batch, labels_batch) in enumerate(trainloader):
      # imgs_batch holds the MNIST images and labels_batch holds the digit
      # shown on each image
      labels_batch = labels_batch.to(device)
      optimizer.zero_grad()

      # Sample a random timestep between 0 and T-1 of shape batch_size
      # It has to be of type "long integer"
      # Don't forget to bring it to GPU (.to(device))
      t = ## TO DO

      # Compute the noised images and the corresponding added noise
      # (closed-form forward process: the noise is sampled, nothing is learned)
      x_noisy, noise = ## TO DO

      # Predict the added noise with the model
      # --> Don't forget to add the new label argument for conditional generation
      noise_pred = ## TO DO

      # Compute the loss
      train_loss = ## TO DO

      # Backward pass and update weights
      train_loss.backward()
      optimizer.step()

    train_losses.append(train_loss.item())

    if epoch % 10 == 0:
      print("Epoch", epoch, "Loss:", train_loss.item(), "Time elapsed :", np.round(time.time() - start, 2), "secs")
      labels_list = [1,2,3] # Choose what numbers you would like to generate (you can increase the length of the list as you wish)
      sample_plot_reconstruction_conditional(labels_list)

In [None]:
plt.plot(train_losses)
plt.show()

Note: We trained this model from scratch again to make the diffusion process conditional in a classifier-free way, since the label embeddings are fed directly to the denoising network. Another approach is classifier guidance, which boils down to training a separate classifier and modifying the sampling step during generation by adding the gradient, with respect to the noisy image, of the log probability of the label ($x_{t-1} = f(x_t, t) + \alpha \nabla_{x_t} \log P(y \mid x_t)$, where $\alpha$ is a guidance scale and not the $\alpha_t$ of the noise schedule). <br> <br> The advantages of this method are:

1.   no need to retrain the diffusion model from scratch
2.   you can leverage an existing pretrained classifier for conditioning
3.   you usually get better quality results with this approach

You can try to implement that yourself if you wish :) (Bonus)
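
As a starting point for the bonus, the gradient term in the formula above can be obtained with autograd. The sketch below is only an illustration under strong assumptions: `toy_classifier` is an untrained, hypothetical stand-in (a real one would be trained on noisy images and would usually receive the timestep too), and `model_mean_demo` is a placeholder for the output $f(x_t, t)$ of the sampling step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in classifier over 32x32 single-channel images
toy_classifier = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 10))

def classifier_grad_sketch(classifier, x_t, y):
    """Return the gradient of log p(y | x_t) with respect to x_t."""
    with torch.enable_grad():                      # works even inside a no_grad sampling loop
        x_in = x_t.detach().requires_grad_(True)
        log_probs = F.log_softmax(classifier(x_in), dim=-1)
        selected = log_probs[torch.arange(x_in.shape[0]), y].sum()
        return torch.autograd.grad(selected, x_in)[0]

x_t_demo = torch.randn(3, 1, 32, 32)               # stand-in noisy images
y_demo = torch.tensor([1, 2, 3])                   # target digits
grad_demo = classifier_grad_sketch(toy_classifier, x_t_demo, y_demo)

guidance_scale = 1.0                               # plays the role of alpha in the note
model_mean_demo = torch.zeros_like(x_t_demo)       # placeholder for f(x_t, t)
guided_mean_demo = model_mean_demo + guidance_scale * grad_demo
print(grad_demo.shape)                             # torch.Size([3, 1, 32, 32])
```

The gradient has the same shape as the image batch, so it can be added to the predicted mean before the noise term in the sampling loop.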