# Score-Based Diffusion Models
Exercise by [Jes Frellsen](https://frellsen.org) (Technical University of Denmark), June 2025 (version 1.0).

In this programming exercise, you will work with Score-Based Diffusion Models as described by [Song et al., 2021](https://arxiv.org/abs/2011.13456). We consider the MNIST dataset, with pixel values transformed to the interval $[-1,1]$.

The provided code is a simple, modular implementation of most of the functionality of a variance preserving diffusion model. Your task is to implement the learning and sampling functionality, as described in **the exercise following the code**.
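For reference, the variance preserving (VP) SDE from Song et al. (2021) with the linear schedule $\beta(t)$ used in `Diffusion.beta`, together with its closed-form Gaussian transition kernel, reads as follows. The kernel is what `Diffusion.mu_sigma` evaluates, with $c(t)$ corresponding to the variable `c` in the code.

$$
\begin{aligned}
\mathrm{d}\mathbf{x} &= -\tfrac{1}{2}\beta(t)\,\mathbf{x}\,\mathrm{d}t + \sqrt{\beta(t)}\,\mathrm{d}\mathbf{w},
\qquad \beta(t) = \beta_{\min} + t\,(\beta_{\max}-\beta_{\min}),\\
c(t) &= -\int_0^t \beta(s)\,\mathrm{d}s = -\tfrac{1}{2}t^2(\beta_{\max}-\beta_{\min}) - t\,\beta_{\min},\\
p_{0t}\big(\mathbf{x}(t)\mid\mathbf{x}(0)\big) &= \mathcal{N}\Big(\mathbf{x}(t);\ e^{c(t)/2}\,\mathbf{x}(0),\ \big(1-e^{c(t)}\big)\,\mathbf{I}\Big).
\end{aligned}
$$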

We have provided you with one file:
* `unet.py` contains the code for a U-Net that predicts the noise $\epsilon$ for the reverse process on MNIST. The architecture and implementation of the U-Net are adapted from an implementation by
[Muhammad Firmansyah Kasim](https://github.com/mfkasim1/score-based-tutorial/blob/main/03-SGM-with-SDE-MNIST.ipynb).

You can download the file using the following command:

In [None]:
! curl -O https://raw.githubusercontent.com/frellsen/ProbAI-2025/refs/heads/main/unet.py

# Implementing the diffusion model
**Implementation:** Below we provide an implementation of a variance preserving diffusion model. The code is missing the implementation of the `loss` and `sample` functions. Your task is to complete them (see Exercise 1 below).

In [None]:
import torch
import torch.nn as nn
import torch.distributions as td
import math

class Diffusion(nn.Module):
    def __init__(self, network, beta_min=0.1, beta_max=20., tau=1e-3):
        """
        Initialize a variance preserving score-based diffusion model.

        Parameters:
        network: [nn.Module]
            The network to use for the diffusion process.
        beta_min: [float]
            The minimal beta for linear beta noise schedule.
        beta_max: [float]
            The maximal beta for linear beta noise schedule.
        tau: [float]
            The time interval goes from [tau, 1] for the diffusion process.
        """
        super(Diffusion, self).__init__()
        self.network = network
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.tau = tau

    def beta(self, t):
        """
        Compute the beta noise schedule.

        Parameters:
        t: [torch.Tensor]
            The time steps for which to compute the beta values.
        """
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def mu_sigma(self, t, x_0=None):
        """
        Compute the mean and standard deviation for the variance preserving transition kernel.

        Parameters:
        t: [torch.Tensor]
            The time steps for which to compute the mean and standard deviation.
        x_0: [torch.Tensor, optional]
            The original data (x_0) to compute the mean. If None, only the standard deviation is computed.
        """

        c = -0.5 * t**2 * (self.beta_max-self.beta_min) - t*self.beta_min
        sigma = torch.sqrt(1-torch.exp(c))

        if x_0 is None:
            return sigma
        else:
            mu = torch.exp(0.5*c) * x_0
            return mu, sigma

    def loss(self, x_0):
        """
        Evaluate the denoising score matching loss.

        Parameters:
        x_0: [torch.Tensor]
            A batch of data (x) of dimension `(batch_size, *)`.
        Returns:
        [torch.Tensor]
            The computed loss value.
        """

        # Sample the time steps
        t = td.uniform.Uniform(self.tau, 1).sample((x_0.shape[0],) + (x_0.dim()-1)*(1,)).to(x_0.device)

        # Sample the noise
        epsilon = torch.randn_like(x_0)

        # Sample x_t from the noising process
        mu, sigma = self.mu_sigma(t, x_0)
        x_t = mu + sigma * epsilon

        ### Your code here
        loss = 0
        ###

        return loss.mean()


    def sample(self, shape, T):
        """
        Sample from the model using the Euler-Maruyama method for SDEs.

        Parameters:
        shape: [tuple]
            The shape of the samples to generate.
        T: [int]
            The number of time steps to sample.
        Returns:
        [torch.Tensor]
            The generated samples.
        """

        device = next(self.network.parameters()).device

        delta_t = (1-self.tau)/T

        # Sample x_t for i=T (i.e., Gaussian noise)
        x_t = torch.randn(shape).to(device)

        # Step backwards in time, sampling x_{t - delta_t} given x_t
        for i in range(T, 0, -1):
            t = self.tau + i*delta_t
            t = torch.full((x_t.shape[0],1), t).to(device)

            beta = self.beta(t)
            sigma = self.mu_sigma(t)

            delta_w = math.sqrt(delta_t)*torch.randn_like(x_t).to(device)

            ### Your code here
            x_t = 0
            ###

        return x_t
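
As a sanity check on `mu_sigma`, the sketch below simulates the forward VP SDE $\mathrm{d}\mathbf{x} = -\tfrac{1}{2}\beta(t)\mathbf{x}\,\mathrm{d}t + \sqrt{\beta(t)}\,\mathrm{d}\mathbf{w}$ for a single scalar value (rather than a batch of images) with the Euler-Maruyama method, and compares the empirical mean and variance of $x(t)$ with the closed-form kernel. It assumes the default `beta_min=0.1` and `beta_max=20.` and uses only the Python standard library; the helper names `vp_kernel` and `forward_euler_maruyama`, the start value `x0 = 0.8` and the horizon `t_end = 0.3` are illustrative choices, not part of the exercise code. The update `x + drift*dt + diffusion*dw` with `dw ~ N(0, dt)` is the generic Euler-Maruyama pattern; the sampler applies the same idea backwards in time.

```python
import math
import random

BETA_MIN, BETA_MAX = 0.1, 20.0  # defaults of Diffusion.__init__


def beta(t):
    # Linear noise schedule, same formula as Diffusion.beta
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def vp_kernel(t, x0):
    # Closed-form mean and standard deviation, same formula as Diffusion.mu_sigma
    c = -0.5 * t**2 * (BETA_MAX - BETA_MIN) - t * BETA_MIN
    return math.exp(0.5 * c) * x0, math.sqrt(1 - math.exp(c))


def forward_euler_maruyama(x0, t_end, n_steps, rng):
    # Hypothetical helper: Euler-Maruyama for dx = -0.5*beta(t)*x dt + sqrt(beta(t)) dw,
    # integrated forwards from t = 0 to t = t_end
    dt = t_end / n_steps
    x = x0
    for i in range(n_steps):
        t = i * dt
        dw = math.sqrt(dt) * rng.gauss(0.0, 1.0)  # Brownian increment ~ N(0, dt)
        x = x - 0.5 * beta(t) * x * dt + math.sqrt(beta(t)) * dw
    return x


rng = random.Random(0)
x0, t_end = 0.8, 0.3
paths = [forward_euler_maruyama(x0, t_end, 200, rng) for _ in range(4000)]

# Empirical moments of the simulated x(t_end) versus the closed-form kernel
emp_mean = sum(paths) / len(paths)
emp_var = sum((x - emp_mean) ** 2 for x in paths) / (len(paths) - 1)
mu, sigma = vp_kernel(t_end, x0)

print(f"mean:     simulated {emp_mean:.3f}   closed form {mu:.3f}")
print(f"variance: simulated {emp_var:.3f}   closed form {sigma**2:.3f}")
```

The two pairs of numbers should agree up to Monte Carlo and discretization error. Note also that the squared mean scale plus the variance equals $1$ for every $t$, which is the "variance preserving" property.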

**Training loop**: We have also implemented a generic training loop for learning the diffusion model.

In [None]:
from tqdm.notebook import tqdm

def train(model, optimizer, data_loader, epochs, device):
    """
    Train a Diffusion model.

    Parameters:
    model: [Diffusion]
        The model to train.
    optimizer: [torch.optim.Optimizer]
        The optimizer to use for training.
    data_loader: [torch.utils.data.DataLoader]
        The data loader to use for training.
    epochs: [int]
        Number of epochs to train for.
    device: [torch.device]
        The device to use for training.
    model.train()

    total_steps = len(data_loader)*epochs
    progress_bar = tqdm(range(total_steps), desc="Training")

    for epoch in range(epochs):
        data_iter = iter(data_loader)
        for x in data_iter:
            if isinstance(x, (list, tuple)):
                x = x[0]
            x = x.to(device)
            optimizer.zero_grad()
            loss = model.loss(x)
            loss.backward()
            optimizer.step()

            # Update progress bar
            progress_bar.set_postfix(loss=f"⠀{loss.item():12.4f}", epoch=f"{epoch+1}/{epochs}")
            progress_bar.update()

**Training data**: Next, we load the MNIST training set, add uniform dequantization noise to the pixel values and rescale them to (approximately) the interval $[-1,1]$. **For faster training, we initially load only the first 64 images from the training set.**

In [None]:
from torchvision import datasets, transforms

batch_size = 64

# Define the transform to use for the data
transform=transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambda x: x + torch.rand(x.shape)/255),
                                transforms.Lambda(lambda x: (x-0.5)*2.0),
                                transforms.Lambda(lambda x: x.flatten())])

# Load the data
train_data = datasets.MNIST('data/', train=True, download=True, transform=transform)

# Load only the first 64 samples for training
train_data = torch.utils.data.Subset(train_data, range(batch_size))

# Create the data loaders
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Get the dimension of the dataset
D = next(iter(train_loader))[0].shape[1]
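
The transform above maps a raw 8-bit pixel value $p \in \{0,\dots,255\}$ to $2\,(p/255 + u/255 - 0.5)$ with $u \sim \mathcal{U}[0,1)$, and the sampling code undoes the rescaling with `samples/2 + 0.5` followed by clamping to $[0,1]$. The scalar sketch below mirrors both steps for a single pixel; `to_model_range` and `from_model_range` are hypothetical helper names introduced only for illustration. It shows that the dequantization noise can push white pixels slightly above $1$ (by less than $2/255$), which the clamp removes again.

```python
def to_model_range(pixel, u):
    # pixel: raw 8-bit value in {0, ..., 255}; u: dequantization noise in [0, 1)
    x = pixel / 255.0       # scaling done by transforms.ToTensor() for uint8 images
    x = x + u / 255.0       # the x + torch.rand(x.shape)/255 step
    return (x - 0.5) * 2.0  # the (x-0.5)*2.0 step


def from_model_range(y):
    # Inverse used before plotting: y/2 + 0.5, clamped to [0, 1]
    return min(max(y / 2 + 0.5, 0.0), 1.0)


print(to_model_range(0, 0.0))      # black pixel without noise
print(to_model_range(255, 0.0))    # white pixel without noise
print(to_model_range(255, 0.999))  # white pixel with near-maximal noise, slightly above 1
print(from_model_range(to_model_range(255, 0.999)))  # clamped back into [0, 1]
```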

**Initialize the model and run the training loop**: Finally, we initialize the model and run the training loop. Note that this cell will not work until you have completed Exercise 1 below.

In [None]:
# Define the network
from unet import Unet
network = Unet()

# Define model
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
model = Diffusion(network).to(device)

# Define optimizer
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Train model
epochs = 5000
train(model, optimizer, train_loader, epochs, device)

**Sampling**: The following code draws 64 samples from the trained model and displays them as an $8 \times 8$ image grid.

In [None]:
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid
from IPython.display import display

# Generate samples
model.eval()
with torch.no_grad():
    samples = (model.sample((64,D), T=1000)).cpu()
    samples = (samples /2 + 0.5).clamp(0, 1)

image_pil = to_pil_image(make_grid(samples.view(64, 1, 28, 28)))
display(image_pil)

# Exercise 1
Complete the `Diffusion` implementation above by implementing the following parts:
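* `Diffusion.loss(...)` should implement the denoising score matching objective. Use the weighting $\lambda(t) = \sigma^2_{VP}(t)$, so that the loss reduces to $\|\tilde{s}_{\theta}(\mathbf{x}(t), t)-\epsilon\|_2^2$, where $\tilde{s}_{\theta}$ is the network's prediction of the noise $\epsilon$.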
* `Diffusion.sample(...)` should implement the Euler-Maruyama solver for the reverse-time SDE, starting from Gaussian noise at $t=1$.
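
To see why the weighting $\lambda(t) = \sigma^2_{VP}(t)$ turns denoising score matching into noise prediction, write $\mathbf{x}(t) = \boldsymbol{\mu}(t) + \sigma(t)\,\epsilon$ as in `Diffusion.loss`. Here $\mathbf{s}_\theta$ denotes a score model and $\tilde{\mathbf{s}}_\theta = -\sigma(t)\,\mathbf{s}_\theta$ the rescaled output that the U-Net predicts; this parameterization is our reading of the exercise statement rather than something fixed by `unet.py`. The last line is the general reverse-time SDE from Song et al. (2021), specialized to the VP drift $-\tfrac{1}{2}\beta(t)\mathbf{x}$ and diffusion coefficient $\sqrt{\beta(t)}$; it is integrated backwards from $t=1$ to $t=\tau$.

$$
\begin{aligned}
\nabla_{\mathbf{x}(t)} \log p_{0t}\big(\mathbf{x}(t)\mid\mathbf{x}(0)\big) &= -\frac{\mathbf{x}(t)-\boldsymbol{\mu}(t)}{\sigma(t)^2} = -\frac{\epsilon}{\sigma(t)},\\
\lambda(t)\,\big\|\mathbf{s}_\theta(\mathbf{x}(t),t) + \epsilon/\sigma(t)\big\|_2^2 &= \big\|\sigma(t)\,\mathbf{s}_\theta(\mathbf{x}(t),t) + \epsilon\big\|_2^2 = \big\|\tilde{\mathbf{s}}_\theta(\mathbf{x}(t),t) - \epsilon\big\|_2^2,\\
\mathrm{d}\mathbf{x} &= \Big[-\tfrac{1}{2}\beta(t)\,\mathbf{x} - \beta(t)\,\nabla_{\mathbf{x}}\log p_t(\mathbf{x})\Big]\mathrm{d}t + \sqrt{\beta(t)}\,\mathrm{d}\bar{\mathbf{w}},
\qquad \nabla_{\mathbf{x}}\log p_t(\mathbf{x}) \approx -\frac{\tilde{\mathbf{s}}_\theta(\mathbf{x},t)}{\sigma(t)}.
\end{aligned}
$$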