# Course 1: Classical Generative Diffusion Models

## Intro to Deep Learning with PyTorch

### Manipulating tensors, transfer between CPU and GPU devices

In [None]:
import torch

def get_device():
    """
    Pick the compute backend to use, as a device-name string.

    Preference order: 'cuda' when CUDA is available, then 'mps' when the
    MPS backend is available, and 'cpu' otherwise.
    """
    if torch.cuda.is_available():
        return 'cuda'

    # The hasattr guard keeps this working on PyTorch builds that do not
    # expose `torch.backends.mps` at all.
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return 'mps' if has_mps else 'cpu'
    
device = get_device()
print('Device in use:', device)

**Simply creating tensors**

In [None]:
# create a torch.tensor from a list
x = torch.tensor([1, 2, 3])
print(x)

# create a torch.tensor from a numpy array is as straightforward
import numpy as np
x = np.array([1, 2, 3])
x = torch.tensor(x)
print(x)

# also works with a list of lists
x = [[1, 2], [3, 4]]
x = torch.tensor(x)
print(x)


**Tensor operations** Typically element-wise operations

In [None]:
# element-wise operations
x = torch.tensor([1, 2, 3]).to(device)
y = torch.tensor([4, 5, 6]).to(device)
add = x + y
print('add', add)
mul = x * y
print('mul', mul)

**stacking tensors**

In [None]:
# You can stack multiple tensors together. See how the shape changes:
# torch.stack adds a new leading dimension, so two tensors of shape (3,)
# become a single tensor of shape (2, 3).
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
print('x shape:', x.shape)
print('y shape:', y.shape)
z = torch.stack([x, y])
print('z shape:', z.shape)

**Special tensors: zero and one tensors**

In [None]:
# create a tensor filled with zeros of dimension 2x3
x = torch.zeros(2, 3)
y = torch.ones(2, 3)
print('x', x)
print('y', y)

# IMPORTANT : you can create a tensor filled with zeros with the SAME SHAPE as another tensor AND ON THE SAME DEVICE
x = torch.tensor([1, 2, 3]).to(device)
z = torch.zeros_like(x)
print('z', z)


**Draw random variables**

In [None]:
# Sample a tensor from a uniform distribution
x = torch.rand(2, 3)
print('x', x)

# Sample a tensor from a normal distribution
y = torch.randn(2, 3)
print('y', y)

### Example CPU vs GPU: Mandelbrot

In [None]:
# Create a grid of complex numbers c = x + i*y covering [-lim, lim]^2.
lim = 1.5
x = torch.linspace(-lim, lim, 300)
y = torch.linspace(-lim, lim, 300)
# Pass `indexing` explicitly: calling torch.meshgrid without it emits a
# deprecation warning. 'xy' puts the real part along columns and the
# imaginary part along rows, which is the layout plt.imshow expects
# (horizontal axis = real part).
X, Y = torch.meshgrid(x, y, indexing='xy')
C = X + 1j*Y
C = C.to(device) # comment out to test on CPU
print('C shape:', C.shape)

# Mandelbrot membership: c belongs to the set when the orbit
# z_{n+1} = z_n^2 + c (started from z_0 = 0) does not diverge.
# torch evaluates the recurrence for every grid point at once.
def mandelbrot(c, max_iter):
    """
    Iterate z <- z*z + c `max_iter` times, starting from zeros shaped like `c`.

    Returns a boolean tensor, moved to the CPU, that is True wherever the
    final iterate satisfies |z| < 2 (the membership criterion used here).
    """
    orbit = torch.zeros_like(c)
    for _step in range(max_iter):
        orbit = orbit * orbit + c

    # Transfer to the CPU first, then threshold the modulus.
    return orbit.cpu().abs() < 2

In [None]:
# C = C.to(device)
Z = mandelbrot(C, 50)
%timeit Z = mandelbrot(C, 50)

In [None]:
# plot the Mandelbrot set
import matplotlib.pyplot as plt
plt.imshow(Z.numpy(), extent=(-lim, lim, -lim, lim))
plt.show()

### Preparing data: Dataset and Dataloader

In [None]:
import torch

# Create a dataset with a Gaussian mixture distribution:
def get_gaussian_mixture_datapoints(mean1, mean2, std, n_samples):
    """
    Build the datapoints of a two-component Gaussian mixture (exercise skeleton).

    The sampling code is the `...` placeholder below and still has to be
    written: as it stands `samples` is never assigned, so calling this
    function raises a NameError after printing the banner.

    Args:
        mean1: mean of the first component; read as mean1[0], mean1[1].
        mean2: mean of the second component; read as mean2[0], mean2[1].
        std: standard deviation (only printed here; presumably shared by
            both components -- TODO confirm once implemented).
        n_samples: number of samples (only printed here).

    Returns:
        samples: the generated datapoints. The caller indexes the result
        with a permutation of n_samples indices, so its first dimension
        is expected to have length n_samples.
    """
    # Banner describing the parameters of the dataset being generated.
    print('Using Gaussian Mixture dataset, with parameters mean=[{}, {}], [{}, {}] and std={}. {} samples.'
          .format(mean1[0], mean1[1], mean2[0], mean2[1], std, n_samples))
    
    ...
    
    return samples

def get_default_gaussian_mixture_datapoints():
    """
    Generate the default mixture dataset and return it in shuffled order.

    Uses component means (-0.5, 0) and (0.5, 0), std 0.1 and 10000 samples,
    all passed to `get_gaussian_mixture_datapoints`.
    """
    n_samples = 10000
    left_mean = torch.tensor([-0.5, 0])
    right_mean = torch.tensor([0.5, 0])
    points = get_gaussian_mixture_datapoints(left_mean, right_mean, 0.1, n_samples)

    # Shuffle the rows with a random permutation of the sample indices.
    shuffled_order = torch.randperm(n_samples)
    return points[shuffled_order]

gaussian_datapoints = get_default_gaussian_mixture_datapoints()

# plot the Gaussian mixture dataset
import matplotlib.pyplot as plt
plt.scatter(gaussian_datapoints[:, 0], gaussian_datapoints[:, 1], s=1)
plt.axis('equal')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Gaussian Mixture Dataset')
plt.show()


In [None]:
# Create a TensorDataset
def create_dataset(datapoints):
    """
    Wrap `datapoints` in a TensorDataset, pairing every row with a zero label.

    The labels tensor is 1D with one entry per row of `datapoints`.
    """
    n_points = datapoints.shape[0]
    return torch.utils.data.TensorDataset(datapoints, torch.zeros(n_points))

def create_dataloader(dataset, batch_size):
    """
    Build a DataLoader over `dataset` that yields shuffled batches of
    `batch_size` elements.
    """
    loader_options = {'batch_size': batch_size, 'shuffle': True}
    return torch.utils.data.DataLoader(dataset, **loader_options)


gaussian_datapoints = get_default_gaussian_mixture_datapoints()
dataset_obj = create_dataset(gaussian_datapoints)
dataloader = create_dataloader(dataset_obj, batch_size=500)

### Evaluation metric: assess the distance between two empirical distributions

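We will use the Wasserstein-2 metric:
$$W_2(\mu, \nu) = \left( \underset{\gamma \in \mathcal{M}(\mu, \nu)}{\inf} \int \| x - y \|^2 \, \gamma(dx, dy) \right)^{1/2}, $$
where $\mathcal{M}(\mu, \nu)$ denotes the set of couplings of $\mu$ and $\nu$.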

In [None]:
import pyemd

# Compare two Gaussian sample sets that share the same std but have different means
mean1 = torch.tensor([-0.5, 0])
mean2 = torch.tensor([0.5, 0])
std1 = 0.1
std2 = 0.1
n_samples = 10000
# Scale and shift standard normal draws: each row is one 2D sample.
gaussian_1 = torch.randn(n_samples, 2) * std1 + mean1
gaussian_2 = torch.randn(n_samples, 2) * std2 + mean2

# Compute the EMD between the two distributions
# NOTE(review): pyemd.emd_samples is documented as flattening multi-dimensional
# arrays and comparing 1D histograms of the values, so this is presumably a
# histogram-based proxy rather than the 2D Wasserstein-2 distance -- confirm.
emd = pyemd.emd_samples(gaussian_1, gaussian_2)

print('Empirical EMD between the two distributions:', emd)

## Generative Diffusion Process

We need to define three functions:
* **The forward diffusion**, i.e., sample $x_t$ given $x_0, t$. Since $p_{t |0}$ is available in closed form, we do not need to simulate a forward SDE; this is the *simulation-free* property.
* **The objective function**, which is the denoising squared $L_2$ loss.
* **The sampling algorithm**, i.e., a simulation of the backward SDE.

### Forward : sample $p_{t | 0}$

Implement two noise schedules:
* **linear noise schedule** $\beta_t$ scales linearly from $\beta_{min} = 0.1$ to $\beta_{max} = 20.0$.
* **cosine noise schedule** Directly parameterize $\bar \alpha_t$ as 
$$\bar \alpha_t = \cos^2\left(\frac{\bar t + s}{2(1 + s)} \pi\right),$$ 
where $s = 0.008$ and $\bar t = t / T$.


In [None]:
import math

def match_last_dims(data, shape):
    """
    Broadcast one value per batch element to a full batch shape.

    `data` must be a 1D tensor of length B. The result has shape
    (B, *shape[1:]): entry i of `data` is copied across all trailing
    positions of row i. Useful for multiplying per-sample coefficients
    with batched data.
    """
    assert len(data.shape) == 1, "Data must be 1-dimensional (one value per batch)"
    n_trailing = len(shape) - 1
    # Append one singleton axis per trailing dimension, then tile along them.
    column = data.reshape(data.shape[0], *([1] * n_trailing))
    return column.repeat(1, *shape[1:])

def compute_beta_t(t_norm, T, schedule = 'linear'):
    """
    Compute the noise rate β(t) for the chosen schedule.

    Args:
        t_norm: normalized time t / T; the cosine branch calls torch.tan
            on it, so it must be a tensor there.
        T: time horizon used to normalize t.
        schedule: 'linear' or 'cosine'.

    Returns:
        beta_t for the requested schedule.

    Raises:
        ValueError: if `schedule` is neither 'linear' nor 'cosine'.

    Note: the 'linear' branch is an exercise placeholder (`...`); until it
    is filled in, `beta_t` is never assigned on that path and the final
    `return` fails.
    """
    # Compute β(t) depending on schedule.
    
    if schedule == 'linear':
        ...
    elif schedule == 'cosine':
        s = 0.008
        # β(t) = -d/dt log ᾱ(t) for ᾱ(t) = cos²(((t/T + s) / (1 + s)) · π/2),
        # the cosine ᾱ defined in compute_alpha_bar.
        beta_t = (torch.pi / (T * (1 + s))) * torch.tan(((t_norm + s) / (1 + s)) * (torch.pi / 2))
    else:
        raise ValueError('Unknown schedule')
    return beta_t

def compute_alpha_bar(t_norm, schedule = 'linear'):
    """
    Compute ᾱ(t), the coefficient consumed by `forward`, for the chosen schedule.

    Args:
        t_norm: normalized time t / T, as a tensor (the cosine branch
            applies torch.cos to it).
        schedule: 'linear' or 'cosine'.

    Returns:
        alpha_bar, with the same shape as `t_norm` for the cosine schedule.

    Raises:
        ValueError: if `schedule` is neither 'linear' nor 'cosine'.

    Note: the 'linear' branch is an exercise placeholder (`...`); until it
    is filled in, `alpha_bar` is never assigned on that path and the final
    `return` fails.
    """
    if schedule == 'linear':
        ...
    elif schedule == 'cosine':
        # ᾱ(t) = cos²(((t/T + s) / (1 + s)) · π/2) with the small offset s.
        # (A previous, immediately-overwritten assignment was removed here.)
        s = 0.008
        alpha_bar = torch.cos((t_norm + s) / (1 + s) * (torch.pi / 2))**2
    else:
        raise ValueError('Unknown schedule')
    return alpha_bar

# must return x_t and the added noise, we will need it later to compute the loss
def forward(x_start, t, T, schedule = 'linear'):
    """
    Sample x_t from p_{t|0} given clean data x_start (exercise skeleton).

    `noise` and `x_t` are `...` placeholders that still have to be written.
    NOTE(review): presumably x_t = sqrt(ᾱ) * x_start + sqrt(1 - ᾱ) * noise
    with standard normal noise -- TODO confirm when implementing.

    Args:
        x_start: batch of clean datapoints.
        t: 1D tensor with one time per batch element (match_last_dims
            asserts that the derived ᾱ is 1-dimensional).
        T: time horizon; t is normalized as t / T.
        schedule: noise schedule name forwarded to compute_alpha_bar.

    Returns:
        (x_t, noise): the noised sample and the noise that was added.
    """
    t_norm = t / T 
    alpha_bar = compute_alpha_bar(t_norm, schedule)
    # expand alpha_bar to the same shape as x_start, so that we can multiply them
    alpha_bar = match_last_dims(alpha_bar, x_start.shape)
    noise = ...
    x_t = ...
    return x_t, noise

In [None]:

# check the forward process by plotting the empirical marginals p_t, for some t and empirical samples
x_start = dataset_obj[:1000][0]

# create four subplots 
fig, axs = plt.subplots(1, 4, figsize=(20, 5))

# helper drawing one empirical marginal p_t on a given subplot
def plot_ax_i(ax, x, y, title):
    """
    Scatter the points (x, y) on `ax` with equal axis scaling, axis labels
    'x' / 'y' and the given title.
    """
    # Title and labels do not depend on the data, so set them up front.
    ax.set_title(title)
    ax.set_ylabel('y')
    ax.set_xlabel('x')
    ax.scatter(x, y, s=1)
    ax.axis('equal')

T = 1

for i, t_norm in enumerate([0.0, 0.25, 0.5, 1.0]):
    t = t_norm * T * torch.ones_like(x_start)[:, 0]
    x_t, _ = forward(x_start, t, T)
    plot_ax_i(axs[i], x_t[:, 0], x_t[:, 1], 't = {}'.format(t[0]))

plt.show()

### Training and Setting up the Neural Network

**Neural network**

In [None]:
from model.SimpleModel import MLPModel

# define a simple MLP model. For the moment, take the default one I provide

simple_model = MLPModel(
    nfeatures = 2,
    time_emb_size= 8,
    nblocks = 2,
    nunits = 32,
    skip_connection = True,
    layer_norm = True,
    dropout_rate = 0.1,
    learn_variance = False,
)

simple_model = simple_model.to(device)

# setting up the optimizer
import torch.optim as optim

optimizer = optim.AdamW(
    simple_model.parameters(), 
    lr=2e-3, 
    betas=(0.9, 0.999))

# potentially set up a learning schedule too ...

**Objective function**

In [None]:
import torch.nn.functional as F

def training_losses(model, x_start, T):
    """
    Compute the denoising training loss for one batch (exercise skeleton).

    The time sampling (`t = ...`) and the loss expression (`loss = ...`)
    are placeholders that still have to be written.

    Args:
        model: noise-predicting network, called as model(x_t, t_norm).
        x_start: batch of clean datapoints.
        T: time horizon; times are normalized as t / T before being fed
            to the model.

    Returns:
        loss: the training loss (the caller runs .backward() and .item()
        on it, so a scalar tensor is expected).
    """
    
    # Sample t uniformly from [0, T]
    t = ... 
    # Noise the batch; keep the added noise as the regression target.
    x_t, noise = forward(x_start, t, T)
    # The model takes x_t and t as input and predicts the noise. time t should be of shape (batch_size, 1)
    # we can pass normalized time to model as input
    t_norm = t / T
    t_norm = t_norm.view(-1, 1)
    predicted_noise = model(x_t, t_norm)
    loss = ...
    return loss

# training loop

import os

def train(
    num_epochs,
    checkpoint_interval,
    dataloader,
    model,
    optimizer,
    checkpoint_dir,
    device,
    T = 1
):
    """
    Run the training loop and periodically write checkpoints.

    Args:
        num_epochs: number of passes over `dataloader`.
        checkpoint_interval: a checkpoint is saved whenever the (1-based)
            epoch index is a multiple of this value.
        dataloader: yields (data, label) batches; labels are ignored.
        model: network to optimize; switched to training mode.
        optimizer: optimizer stepping the model parameters.
        checkpoint_dir: directory for checkpoint files (created on demand).
        device: device each data batch is moved to.
        T: time horizon forwarded to `training_losses`.

    Each checkpoint stores the epoch, model and optimizer state dicts and
    the list of per-epoch average losses.
    """
    print("Training on device:", device)

    # Set the model to training mode.
    model.train()
    epoch_losses = []
    for epoch in range(1, num_epochs + 1):
        loss_total = 0.0
        for data, _ in dataloader:
            batch = data.to(device)
            optimizer.zero_grad()

            # One optimization step on the training loss of this batch.
            batch_loss = training_losses(model, x_start=batch, T=T)
            batch_loss.backward()
            optimizer.step()

            loss_total += batch_loss.item()

        avg_loss = loss_total / len(dataloader)
        epoch_losses.append(avg_loss)
        print(f"Epoch [{epoch}] Average Loss: {avg_loss:.4f}")

        # Only every checkpoint_interval-th epoch writes a checkpoint.
        if epoch % checkpoint_interval != 0:
            continue

        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch_losses': epoch_losses,
        }
        torch.save(state, checkpoint_path)
        print("Saved checkpoint to", checkpoint_path)

    print("Training finished.")

In [None]:
# RUN!
train(
    num_epochs=50,
    checkpoint_interval=50,
    dataloader=dataloader,
    model=simple_model,
    optimizer=optimizer,
    checkpoint_dir='checkpoints',
    device=device
)

### Generation: simulate the backward SDE

In [None]:
from tqdm import tqdm

def score_fn(model, x, t_norm):
    """
    Given the noise-predicting model, returns the score (i.e. ∇_x log p_t(x))
    at actual time t. Note that the model expects a normalized time (t/T).
    For VP: score = - (predicted noise) / sqrt(1 - ᾱ(t))

    Exercise skeleton: `alpha_bar`, `epsilon` and `score` are `...`
    placeholders that still have to be written.

    Args:
        model: noise-predicting network.
        x: current samples x_t.
        t_norm: normalized time t / T for each sample.

    Returns:
        score: the estimated score at (x, t).
    """
    alpha_bar = ...
    epsilon = ...
    score = ...
    return score

def sample(
    model,
    n_samples,
    reverse_steps,
    schedule = 'linear',
    T = 1):
    """
    Generate samples by simulating the backward SDE with Euler–Maruyama steps.

    Exercise skeleton: the `...` placeholders (initial noise shape, time
    grid, time batch, drift `f`, diffusion `g` and the update itself)
    still have to be written.

    Args:
        model: trained noise-predicting network (switched to eval mode).
        n_samples: number of samples to generate.
        reverse_steps: number of discretization steps from T down to 0.
        schedule: noise schedule name forwarded to compute_beta_t.
        T: time horizon.

    Returns:
        xt: the samples after the last reverse step.
    """
    xt = torch.randn(...)
    model.eval()
    with torch.inference_mode():
        # Create a time discretization from T to 0.
        # It is indexed up to t_seq[reverse_steps], so it needs
        # reverse_steps + 1 entries.
        t_seq = ...
        for i in tqdm(range(reverse_steps)):
            t_current = t_seq[i]
            t_next = t_seq[i + 1]
            dt = t_next - t_current  # dt is negative (reverse time)

            # Create a batch of current time values for the update.
            t_batch = ...
            t_norm_batch = t_batch / T

            beta_t = compute_beta_t(t_norm_batch, T, schedule)

            f = ...
            g = ...

            # Get the score (using the noise-predicting network).
            # score_fn takes the *normalized* time (its parameter is `t_norm`),
            # so pass t_norm_batch rather than t_batch; the two only coincide
            # when T == 1.
            score = score_fn(model, xt, t_norm_batch)

            # Euler–Maruyama update:
            z = torch.randn_like(xt)
            xt += ...

    return xt

In [None]:
samples = sample(
    model=simple_model,
    n_samples=1000,
    reverse_steps=100,
    schedule='linear',
    T=1
)
samples = samples.cpu().detach().numpy()

# plot samples
tmp_samples = samples.clip(-1, 1)
plt.scatter(tmp_samples[:, 0], tmp_samples[:, 1], s=1)
plt.axis('equal')
plt.xlabel('x')
plt.ylabel('y')
plt.xlim(-1, 1)
plt.ylim(-1, 1)
plt.title('Generated Samples')
plt.show()

# Case Study 1: SDE vs ODE sampling

Modify the `sample` function to accept a `deterministic : bool` argument, according to which the sampling procedure will correspond to SDE or ODE sampling.

In [None]:
def sample(
    model,
    n_samples,
    reverse_steps,
    deterministic = False,
    schedule = 'linear',
    T = 1):
    """
    Sampler supporting both SDE and ODE sampling (exercise stub).

    The body is still `pass`, so the function currently returns None.
    Per the exercise statement, `deterministic` selects whether the
    sampling procedure corresponds to SDE or ODE sampling.

    Args:
        model: trained noise-predicting network.
        n_samples: number of samples to generate.
        reverse_steps: number of reverse-time discretization steps.
        deterministic: bool switch between SDE and ODE sampling
            (presumably True = ODE -- confirm when implementing).
        schedule: noise schedule name.
        T: time horizon.
    """
    pass
    

In [None]:
samples = sample(
    model=simple_model,
    n_samples=1000,
    reverse_steps=20,
    deterministic=True,
    schedule='linear',
    T=1
)
samples = samples.cpu().detach().numpy()

# plot samples
tmp_samples = samples.clip(-1, 1)
plt.scatter(tmp_samples[:, 0], tmp_samples[:, 1], s=1)
plt.axis('equal')
plt.xlabel('x')
plt.ylabel('y')
plt.xlim(-1, 1)
plt.ylim(-1, 1)
plt.title('Generated Samples')
plt.show()

Make a plot comparing the performance of SDE vs ODE sampling as a function of reverse steps, for example using the Wasserstein-2 metric.

In [None]:
# now compare ODE vs SDE performance

timesteps = [2, 5, 10, 20, 50, ]
n_samples = 5000

samples_sde = ...

samples_ode = ...

# retrieve true samples from dataset
true_samples = dataset_obj[:n_samples][0].cpu().detach().numpy()

In [None]:
# compute emd distance between samples
emd_sde = [pyemd.emd_samples(true_samples, samples.detach().cpu().numpy()) for samples in samples_sde]
emd_ode = [pyemd.emd_samples(true_samples, samples.detach().cpu().numpy()) for samples in samples_ode]

In [None]:
# now plot the EMD distance as a function of the number of reverse steps
plt.plot(timesteps, emd_sde, label='SDE')
plt.plot(timesteps, emd_ode, label='ODE')
plt.xlabel('Number of reverse steps')
plt.ylabel('EMD distance')
plt.legend()
plt.title('EMD distance as a function of the number of reverse steps')
plt.show()

# Case Study 2: Conditioning with Classifier-Free Guidance

In [None]:
import torch

def get_device():
    """
    Pick the compute backend to use, as a device-name string.

    Preference order: 'cuda' when CUDA is available, then 'mps' when the
    MPS backend is available, and 'cpu' otherwise.
    """
    if torch.cuda.is_available():
        return 'cuda'

    # The hasattr guard keeps this working on PyTorch builds that do not
    # expose `torch.backends.mps` at all.
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return 'mps' if has_mps else 'cpu'
    
device = get_device()
print('Device in use:', device)

**Modify the Dataset** 

Integrate class labels $y \in \{0, 1\}$ into the two-component Gaussian mixture dataset.

In [None]:
import torch

# Create a dataset with a Gaussian mixture distribution:
def get_gaussian_mixture_datapoints(mean1, mean2, std, n_samples):
    """
    Build a labelled two-component Gaussian mixture dataset (exercise skeleton).

    The body is a `...` placeholder: `samples` and `labels` are never
    assigned yet, so calling this function raises a NameError.

    Args:
        mean1: mean of the first component.
        mean2: mean of the second component.
        std: standard deviation (presumably shared by both components --
            TODO confirm once implemented).
        n_samples: number of samples to generate.

    Returns:
        (samples, labels): the datapoints and their class labels. The
        caller indexes both with the same permutation of n_samples
        indices, so both are expected to have n_samples entries along
        their first dimension.
    """
    ...
    return samples, labels

def get_default_gaussian_mixture_datapoints():
    """
    Generate the default labelled mixture dataset and return it shuffled.

    Uses component means (-0.5, 0) and (0.5, 0), std 0.1 and 10000 samples,
    all passed to `get_gaussian_mixture_datapoints`. Returns the pair
    (datapoints, labels), shuffled consistently.
    """
    n_samples = 10000
    left_mean = torch.tensor([-0.5, 0])
    right_mean = torch.tensor([0.5, 0])
    points, classes = get_gaussian_mixture_datapoints(left_mean, right_mean, 0.1, n_samples)

    # A single shared permutation keeps every point aligned with its label.
    shuffled_order = torch.randperm(n_samples)
    return points[shuffled_order], classes[shuffled_order]

# random shuffle
gaussian_datapoints, labels = get_default_gaussian_mixture_datapoints()

# plot the Gaussian mixture dataset
import matplotlib.pyplot as plt
plt.scatter(gaussian_datapoints[labels == 0, 0], gaussian_datapoints[labels == 0, 1], s=1, label='Class 0')
plt.scatter(gaussian_datapoints[labels == 1, 0], gaussian_datapoints[labels == 1, 1], s=1, label='Class 1')
plt.axis('equal')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.title('Gaussian Mixture Dataset')
plt.show()


# Create a TensorDataset
def create_dataset(datapoints, labels):
    """
    Pair each datapoint with its class label in a TensorDataset.
    """
    paired_tensors = (datapoints, labels)
    return torch.utils.data.TensorDataset(*paired_tensors)

def create_dataloader(dataset, batch_size):
    """
    Build a DataLoader over `dataset` that yields shuffled batches of
    `batch_size` elements.
    """
    loader_options = {'batch_size': batch_size, 'shuffle': True}
    return torch.utils.data.DataLoader(dataset, **loader_options)


gaussian_datapoints, labels = get_default_gaussian_mixture_datapoints()
dataset_obj = create_dataset(gaussian_datapoints, labels)
dataloader = create_dataloader(dataset_obj, batch_size=500)

next(iter(dataloader))

**Modify the neural network** 
* It should accept all possible class labels $y$, plus the null class label $\emptyset$, which will correspond to the *unconditional* label. It can be represented by any value you like.
* Try to use `nn.Embedding`; the model will learn to embed class labels in $\mathbb{R}^d$.

In [None]:
import model.SimpleModelConditioned as SimpleModelConditioned

simple_model_conditioned = SimpleModelConditioned.MLPModel(
    nfeatures = 2,
    time_emb_size=8,
    nblocks = 2,
    nunits = 32,
    skip_connection = True,
    layer_norm = True,
    dropout_rate = 0.1,
    num_classes = 2
)

simple_model_conditioned = simple_model_conditioned.to(device)

# setting up the optimizer
import torch.optim as optim

optimizer = optim.AdamW(
    simple_model_conditioned.parameters(), 
    lr=2e-3, 
    betas=(0.9, 0.999))

# potentially set up a learning schedule too ...

**Modify training**

In [None]:
import torch.nn.functional as F

def training_losses(model, x_start, T, y):
    """
    Compute the class-conditional denoising loss for one batch (exercise skeleton).

    The body is a `...` placeholder: `loss` is never assigned yet, so
    calling this function raises a NameError.

    Args:
        model: class-conditional noise-predicting network.
        x_start: batch of clean datapoints.
        T: time horizon.
        y: class labels for the batch; the training loop passes an integer
            tensor in which some entries have been replaced by the null
            (unconditional) label.

    Returns:
        loss: the training loss (the caller runs .backward() and .item()
        on it, so a scalar tensor is expected).
    """
    ...
    return loss

# training loop

import os

def train(
    num_epochs, 
    checkpoint_interval, 
    dataloader,
    model,
    optimizer,
    checkpoint_dir,
    device,
    T = 1,
    null_label = 2,
    p_uncond = 0.1
):
    """
    Train a class-conditional model with random label dropout.

    Args:
        num_epochs: number of passes over `dataloader`.
        checkpoint_interval: a checkpoint is saved whenever the (1-based)
            epoch index is a multiple of this value.
        dataloader: yields (data, label) batches.
        model: network to optimize; switched to training mode.
        optimizer: optimizer stepping the model parameters.
        checkpoint_dir: directory for checkpoint files (created on demand).
        device: device each batch is moved to.
        T: time horizon forwarded to `training_losses`.
        null_label: label value standing for "no class" (unconditional).
            The default 2 is the first unused index when the real labels
            are {0, 1}.
        p_uncond: probability with which each label is replaced by
            `null_label`.

    Each checkpoint stores the epoch, model and optimizer state dicts and
    the list of per-epoch average losses.
    """
    print("Training on device:", device)

    # Set the model to training mode.
    model.train()
    epoch_losses = []
    for epoch in range(1, num_epochs + 1):
        running_loss = 0.0
        for data, y in dataloader:
            data = data.to(device) 
            y = y.to(device)
            # With probability p_uncond, replace a label by the null label so
            # the model also learns the unconditional prediction.
            # The mask is drawn with torch.rand (not rand_like(y)) so that it
            # also works when the labels are stored as an integer tensor.
            drop_mask = torch.rand(y.shape, device=y.device) < p_uncond
            y = y.masked_fill(drop_mask, null_label)
            optimizer.zero_grad()

            # Compute the training loss.
            loss = training_losses(model, x_start=data, T=T, y = y.int()) # y must be an integer tensor if using nn.Embedding
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(dataloader)
        epoch_losses.append(avg_loss)
        print(f"Epoch [{epoch}] Average Loss: {avg_loss:.4f}")

        # Save a checkpoint every checkpoint_interval epochs.
        if epoch % checkpoint_interval == 0:
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch_losses': epoch_losses,
            }, checkpoint_path)
            print("Saved checkpoint to", checkpoint_path)

    print("Training finished.")

In [None]:
# RUN!
train(
    num_epochs=50,
    checkpoint_interval=50,
    dataloader=dataloader,
    model=simple_model_conditioned,
    optimizer=optimizer,
    checkpoint_dir='checkpoints_condtioned',
    device=device
)

**Modify the sampling algorithm** 

It should accept `guidance_scale` as argument

In [None]:
from tqdm import tqdm

def score_fn(model, x, t_norm, y):
    """
    Given the noise-predicting model, returns the score (i.e. ∇_x log p_t(x))
    at actual time t. Note that the model expects a normalized time (t/T).
    For VP: score = - (predicted noise) / sqrt(1 - ᾱ(t))

    Exercise skeleton: `alpha_bar`, `epsilon` and `score` are `...`
    placeholders that still have to be written.

    Args:
        model: class-conditional noise-predicting network.
        x: current samples x_t.
        t_norm: normalized time t / T for each sample.
        y: class label to condition on (the sampler also calls this with
            the unconditional label).

    Returns:
        score: the estimated score at (x, t) for label y.
    """

    alpha_bar = ...
    epsilon = ...
    score = ...
    return score

def sample(
    model,
    n_samples,
    reverse_steps,
    class_label,
    deterministic = False,
    guidance_scale = 3.0,
    schedule = 'linear',
    T = 1):
    """
    Class-conditional sampler with classifier-free guidance (exercise skeleton).

    All numerical steps are `...` placeholders that still have to be
    written: initial noise, time grid, per-step quantities, the two
    score_fn argument lists, the guided score and both update rules.

    Args:
        model: trained class-conditional noise-predicting network
            (switched to eval mode).
        n_samples: number of samples to generate.
        reverse_steps: number of reverse-time discretization steps.
        class_label: class to condition on.
        deterministic: selects the update branch (presumably ODE when
            True, SDE otherwise -- confirm when implementing).
        guidance_scale: strength of the guidance.
        schedule: noise schedule name.
        T: time horizon.

    Returns:
        xt: the samples after the last reverse step.
    """
    
    xt = ...
    model.eval()
    with torch.inference_mode():
        # Create a time discretization from T to 0
        t_seq = ...
        for i in tqdm(range(reverse_steps)):
            
            ...
            
            
            # Get the score (using the noise-predicting network)
            
            score_cond = score_fn(..., class_label)
            
            # Second evaluation with the null (unconditional) label.
            uncond_class_label = ...
            score_uncond = score_fn(..., uncond_class_label)
            
            # NOTE(review): presumably
            # score_uncond + guidance_scale * (score_cond - score_uncond);
            # confirm the convention when implementing.
            score = ... # include guidance
            
            
            if deterministic:
                ...
            else:
                ...
            
    return xt

## Visualize Results

Visually observe what happens as the guidance scale increases. Quantify the effect with the Wasserstein metric.

In [None]:
class_label = 0
guidance_scale = 1
samples = sample(
    model=simple_model_conditioned,
    n_samples=1000,
    reverse_steps=20,
    class_label=class_label,
    guidance_scale=guidance_scale,
    deterministic=False,
    schedule='linear',
    T=1
)
samples = samples.cpu().detach().numpy()

# plot samples
tmp_samples = samples.clip(-2, 2)
plt.scatter(tmp_samples[:, 0], tmp_samples[:, 1], s=1)
plt.axis('equal')
plt.xlabel('x')
plt.ylabel('y')
plt.xlim(-1, 1)
plt.ylim(-1, 1)
plt.title('Generated Samples with class {}, guidance {}'.format(class_label, guidance_scale))
plt.show()

# Go further: Elucidated Diffusion Model 

(Elucidating the Design Space of Diffusion-Based Generative Models)

Discuss the paper with your colleagues or with me. Implement the recommended design choices, in terms of: 
* Sampling
* Network and pre-conditioning
* Training

All these design choices require choosing the right hyper-parameters for the dataset at hand. 

Typically, one does not want to go through all this trouble... there is a reason why people have settled on default choices:
* the VP process 
* epsilon-prediction