# CS5670 Project 5B

In this problem set, you will train your own diffusion model on the MNIST dataset. You’ll build and train UNet-based diffusion models, gaining hands-on experience with both unconditional and class-conditioned generation.

Please refer to the [Project 5B instructions page](https://www.cs.cornell.edu/courses/cs5670/2025sp/projects/5_project/partB.html) for detailed descriptions of each task and submission instructions.

## Setup environment

In [None]:
# We recommend using these utils.
# https://google.github.io/mediapy/mediapy.html
# https://einops.rocks/
!pip install mediapy einops --quiet

In [None]:
# Import essential modules. Feel free to add whatever you need.
import dataclasses
import random

import matplotlib.pyplot as plt
import mediapy as media
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Part 1: Training a Single-step Denoising UNet


## Deliverables
In summary, your deliverables should include the following for this problem:

1.   A visualization of the noising process over $\sigma = [0.0, 0.25, 0.5, 0.75, 1.0]$ (see Figure 3 on the [instruction page](https://www.cs.cornell.edu/courses/cs5670/2025sp/projects/5_project/partB.html)).
2.   A training loss curve, plotted every few iterations over the whole training process (see Figure 4).
3.   Sample results on the test set after the first and fifth epochs (see Figures 5 and 6). The staff solution takes ~7 minutes for 5 epochs on a Colab T4 GPU.
4.   Sample results on the test set with out-of-distribution noise levels after the model is trained. Keep the same image and vary $\sigma = [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]$ (see Figure 7).
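
The noising process behind deliverables 1 and 4 can be sketched as follows. This assumes the additive Gaussian form $z = x + \sigma\epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$; the instruction page remains the authoritative definition. `add_noise` is a hypothetical helper name, and a random tensor stands in for an MNIST image.

```python
import torch


def add_noise(x: torch.Tensor, sigma: float, generator: torch.Generator) -> torch.Tensor:
    """Return z = x + sigma * eps with eps ~ N(0, I) (assumed noising form)."""
    eps = torch.randn(x.shape, generator=generator)
    return x + sigma * eps


# Stand-in for one MNIST image scaled to [0, 1], shape (C, H, W).
x = torch.rand(1, 28, 28)
sigmas = [0.0, 0.25, 0.5, 0.75, 1.0]
gen = torch.Generator().manual_seed(0)

# One noisy copy of the same image per noise level: (5, 1, 28, 28).
noisy = torch.stack([add_noise(x, s, gen) for s in sigmas])
```

After clamping to $[0, 1]$ and converting to NumPy, the stacked images could be passed to `media.show_images` for display.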

### Hint

Since training can take a while, we strongly recommend checkpointing your model to your personal Google Drive every epoch. Colab sessions are not persistent: if you are idle for a while, you will be disconnected and lose your training progress. Checkpointing consists of:

- Google Drive mounting.
- Epoch-wise model & optimizer checkpointing.
- Model & optimizer resuming from checkpoints.
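
The three steps above might look like the minimal sketch below. The helper names `save_checkpoint` and `load_checkpoint`, the file name `latest.pt`, and the checkpoint dictionary keys are all assumptions made for illustration; a temporary directory stands in for a mounted Drive folder so the snippet runs outside Colab.

```python
import os
import tempfile

import torch
from torch import nn

# In Colab, Drive would be mounted first, e.g.:
#   from google.colab import drive
#   drive.mount("/content/drive")
# and ckpt_dir would point somewhere under /content/drive/MyDrive.
# A temporary directory stands in for it here so the sketch runs anywhere.
ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, "latest.pt")

model = nn.Linear(4, 4)  # stand-in for the UNet
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def save_checkpoint(path, model, optimizer, epoch):
    """Save everything needed to resume after `epoch` has finished."""
    torch.save(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Restore model/optimizer state in place and return the next epoch index."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1


save_checkpoint(ckpt_path, model, optimizer, epoch=0)
start_epoch = load_checkpoint(ckpt_path, model, optimizer)
```

Saving the optimizer state alongside the model matters for Adam, whose running moment estimates would otherwise restart from zero after a resume.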

## 1.1 Implementing the UNet

### Implementing Simple and Composed Ops
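
As a shape-level sanity check for these ops, the sketch below assumes a 3×3 stride-1 convolution that preserves resolution, a 3×3 stride-2 convolution for downsampling, and a 4×4 stride-2 transposed convolution for upsampling. These hyperparameters are assumptions; the instruction page specifies the actual layers, including any normalization and activation.

```python
import torch
from torch import nn

x = torch.randn(2, 1, 28, 28)  # (N, C, H, W) batch of MNIST-sized inputs

# Resolution-preserving convolution: 28x28 stays 28x28.
same = nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1)
# Strided convolution halves the resolution: 28x28 -> 14x14.
down = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
# Transposed convolution doubles the resolution: 14x14 -> 28x28.
up = nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1)

h_same = same(x)
h_down = down(h_same)
h_up = up(h_down)
```

Checking shapes like this before wiring the blocks together makes mismatches easier to catch than debugging them inside the full UNet.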

In [None]:
class Conv(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()


class DownConv(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()


class UpConv(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()


class Flatten(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()


class Unflatten(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()


class ConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()


class DownBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()


class UpBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()

### Implementing Unconditional UNet

In [None]:
class UnconditionalUNet(nn.Module):
    def __init__(
        self,
        in_channels: int,
        num_hiddens: int,
    ):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[-2:] == (28, 28), "Expect input shape to be (28, 28)."
        raise NotImplementedError()

## 1.2 Using the UNet to Train a Denoiser
In this part, you will train your UNet to perform single-step denoising on MNIST images. You will define your configuration, prepare data, implement training, and visualize results.
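
A single-step denoising update could be sketched as below. It assumes an L2 loss between the denoiser output and the clean image, and a fixed training noise level of $\sigma = 0.5$; both are assumptions to check against the instruction page. A tiny convolutional network and a random batch stand in for `UnconditionalUNet` and MNIST.

```python
import torch
from torch import nn

torch.manual_seed(0)
denoiser = nn.Sequential(  # tiny stand-in for UnconditionalUNet
    nn.Conv2d(1, 8, 3, padding=1), nn.GELU(), nn.Conv2d(8, 1, 3, padding=1)
)
optimizer = torch.optim.Adam(denoiser.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
sigma = 0.5  # assumed training noise level

x = torch.rand(16, 1, 28, 28)  # stand-in for a batch of clean images
losses = []
for step in range(20):
    z = x + sigma * torch.randn_like(x)  # fresh noise every iteration
    loss = loss_fn(denoiser(z), x)       # the target is the clean image
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

Recording `loss.item()` each iteration gives the values needed for the loss-curve deliverable.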

### 1.2.1 Training

In [None]:
# @title Configuration and setup
# Define your configuration using dataclass
# Include: device (e.g., "cuda"), data_dir, batch_size, num_epochs, lr, num_hiddens
# === CODE TODO BEGIN ===
@dataclasses.dataclass
class SingleStepConfig(object):
    device: str = None  # TODO
    data_dir: str = None  # TODO
    batch_size: int = None  # TODO
    num_epochs: int = None  # TODO
    lr: float = None  # TODO
    num_hiddens: int = None  # TODO
# === CODE TODO END ===

# Set random seeds for reproducibility using random, numpy, and torch
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_everything(0)
cfg = SingleStepConfig()

In [None]:
#@title Load dataset and initialize model/optimizer
# === CODE TODO BEGIN ===
# Load the MNIST dataset (from torchvision.datasets)
# Prepare DataLoader with shuffling for training and no shuffling for testing
train_dataset = None  # TODO
test_dataset = None  # TODO

train_loader = None  # TODO
test_loader = None  # TODO

# Instantiate your UnconditionalUNet model (from 1.1) and move it to cfg.device
model = None  # TODO: Replace with model definition

# Define the Adam optimizer
optimizer = None  # TODO: Replace with optimizer definition
# === CODE TODO END ===

In [None]:
# @title Deliverable 1-1
# A visualization of the noising process over sigma = [0.0, 0.25, 0.5, 0.75, 1.0].
# Hint: you may need to reshape or arrange the image tensors,
# and use media.show_images to display results

# === CODE TODO BEGIN ===



# === CODE TODO END ===

In [None]:
#@title Training loop
# Assumes you have:
# - model: your UNet denoiser
# - optimizer: Adam optimizer
# - train_loader: batches of MNIST images
# - device: "cuda" or "cpu"
# - loss_fn: a loss function
# Implement the training loop to train your model for several epochs
# Remember to use your configuration, e.g. cfg.batch_size and cfg.device
# === CODE TODO BEGIN ===



# === CODE TODO END ===

In [None]:
# @title Deliverable 1-2
# A training loss curve plot every few iterations during the whole training process.
# === CODE TODO BEGIN ===



# === CODE TODO END ===

In [None]:
# @title Deliverable 1-3
# Sample results on the test set after the first and the 5-th epoch
# Hint: You can either visualize results immediately during training (recommended for easier comparison),
#  or store outputs and visualize them later after training.
# === CODE TODO BEGIN ===



# === CODE TODO END ===

### 1.2.2 Out-of-Distribution Testing

In [None]:
# @title Deliverable 1-4
# Sample results on the test set with out-of-distribution noise levels after the model is trained.
# Keep the same image and vary \sigma = [0.0, 0.25, 0.5, 0.75, 1.0, 1.0].
# === CODE TODO BEGIN ===



# === CODE TODO END ===

# Part 2: Training a Diffusion Model

## Deliverables for Time-conditioned UNet
- A training loss curve plot for the time-conditioned UNet over the whole training process (figure 10).
- Sampling results for the time-conditioned UNet for 5 and 20 epochs.
Note: providing a gif is optional.

## 2.1 Implementing a Time-conditioned UNet
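
One way to inject the scalar time `t` into a convolutional feature map is sketched below. `TimeEmbed` is a hypothetical helper, and the Linear-GELU-Linear layout and additive injection are assumptions; the instruction page defines the actual `FCBlock` and where its output is applied.

```python
import torch
from torch import nn


class TimeEmbed(nn.Module):
    """Hypothetical helper: map one scalar time per sample to a D-dim vector."""

    def __init__(self, out_channels: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, out_channels),
            nn.GELU(),
            nn.Linear(out_channels, out_channels),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t: (N,) -> (N, 1) -> (N, D) -> (N, D, 1, 1) so it broadcasts over H and W.
        return self.net(t[:, None])[:, :, None, None]


N, D = 4, 32
embed = TimeEmbed(D)
features = torch.randn(N, D, 7, 7)       # some intermediate UNet feature map
t = torch.randint(1, 301, (N,)) / 300.0  # integer timesteps normalized to (0, 1]
t_emb = embed(t)
conditioned = features + t_emb           # assumed additive injection
```

The trailing singleton dimensions are what let a per-sample vector modulate every spatial location of the feature map.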

In [None]:
class FCBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError()


class TimeConditionalUNet(nn.Module):
    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        num_hiddens: int,
    ):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.
            t: (N,) normalized time tensor.

        Returns:
            (N, C, H, W) output tensor.
        """
        assert x.shape[-2:] == (28, 28), "Expect input shape to be (28, 28)."
        raise NotImplementedError()

## 2.2 Training the UNet and 2.3 Sampling from the UNet

### Implementing DDPM Forward and Inverse Process for Time-conditioned Denoising
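
For orientation, the quantities named in the `ddpm_schedule` docstring and the closed-form forward process $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon$ can be sketched as follows. `linear_schedule` and `q_sample` are hypothetical names, and the 0-based timestep indexing is an assumption; the graded functions should follow the instruction page.

```python
import torch


def linear_schedule(beta1: float, beta2: float, num_ts: int) -> dict:
    """Illustrative linear schedule (hypothetical name, not the graded function)."""
    betas = torch.linspace(beta1, beta2, num_ts)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    return {"betas": betas, "alphas": alphas, "alpha_bars": alpha_bars}


def q_sample(x_0, t, alpha_bars, eps):
    """Closed-form forward process: sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps."""
    abar = alpha_bars[t][:, None, None, None]  # (N, 1, 1, 1) for broadcasting
    return abar.sqrt() * x_0 + (1.0 - abar).sqrt() * eps


sched = linear_schedule(1e-4, 0.02, 300)
x_0 = torch.rand(8, 1, 28, 28)   # stand-in for a batch of clean images
t = torch.randint(0, 300, (8,))  # assumed 0-based indices into the schedule
eps = torch.randn_like(x_0)
x_t = q_sample(x_0, t, sched["alpha_bars"], eps)
```

In Algorithm 1 of the DDPM paper, the network is trained to predict `eps` from `x_t` and the timestep, so the loss is a mean squared error between the prediction and `eps`.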

In [None]:
def ddpm_schedule(beta1: float, beta2: float, num_ts: int) -> dict:
    """Constants for DDPM training and sampling.

    Arguments:
        beta1: float, starting beta value.
        beta2: float, ending beta value.
        num_ts: int, number of timesteps.

    Returns:
        dict with keys:
            betas: linear schedule of betas from beta1 to beta2.
            alphas: 1 - betas.
            alpha_bars: cumulative product of alphas.
    """
    assert beta1 < beta2 < 1.0, "Expect beta1 < beta2 < 1.0."
    raise NotImplementedError()

In [None]:
def ddpm_forward(
    unet: TimeConditionalUNet,
    ddpm_schedule: dict,
    x_0: torch.Tensor,
    num_ts: int,
) -> torch.Tensor:
    """Algorithm 1 of the DDPM paper.

    Args:
        unet: TimeConditionalUNet
        ddpm_schedule: dict
        x_0: (N, C, H, W) input tensor.
        num_ts: int, number of timesteps.
    Returns:
        (,) diffusion loss.
    """
    unet.train()
    # YOUR CODE HERE.
    raise NotImplementedError()

In [None]:
@torch.inference_mode()
def ddpm_sample(
    unet: TimeConditionalUNet,
    ddpm_schedule: dict,
    img_wh: tuple[int, int],
    num_ts: int,
    seed: int = 0,
) -> torch.Tensor:
    """Algorithm 2 of the DDPM paper.

    Args:
        unet: TimeConditionalUNet
        ddpm_schedule: dict
        img_wh: (H, W) output image width and height.
        num_ts: int, number of timesteps.
        seed: int, random seed.

    Returns:
        (N, C, H, W) final sample.
    """
    unet.eval()
    # YOUR CODE HERE.
    raise NotImplementedError()

In [None]:
class DDPM(nn.Module):
    def __init__(
        self,
        unet: TimeConditionalUNet,
        betas: tuple[float, float] = (1e-4, 0.02),
        num_ts: int = 300,
        p_uncond: float = 0.1,
    ):
        super().__init__()
        self.unet = unet
        self.betas = betas
        self.num_ts = num_ts
        self.p_uncond = p_uncond
        self.ddpm_schedule = ddpm_schedule(betas[0], betas[1], num_ts)

        for k, v in ddpm_schedule(betas[0], betas[1], num_ts).items():
            self.register_buffer(k, v, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.

        Returns:
            (,) diffusion loss.
        """
        return ddpm_forward(
            self.unet, self.ddpm_schedule, x, self.num_ts
        )

    @torch.inference_mode()
    def sample(
        self,
        img_wh: tuple[int, int],
        seed: int = 0,
    ):
        return ddpm_sample(
            self.unet, self.ddpm_schedule, img_wh, self.num_ts, seed
        )

### Training
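
The optimizer and learning-rate scheduler setup might look like the sketch below. It assumes the learning rate should shrink by a total factor of 0.1 over training, giving `gamma = 0.1 ** (1 / num_epochs)`; that factor, the epoch count, and the initial learning rate are illustrative assumptions, and a linear layer stands in for the UNet.

```python
import torch
from torch import nn

num_epochs, lr, total_decay = 20, 1e-3, 0.1  # assumed values for illustration
model = nn.Linear(4, 4)  # stand-in for the UNet
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Per-epoch multiplier chosen so that gamma ** num_epochs == total_decay.
gamma = total_decay ** (1.0 / num_epochs)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

for epoch in range(num_epochs):
    # ... one pass over train_loader would go here ...
    optimizer.step()  # called before scheduler.step(), the order PyTorch expects
    scheduler.step()  # decay once per epoch

final_lr = optimizer.param_groups[0]["lr"]
```

Stepping the scheduler once per epoch rather than once per batch keeps the total decay tied to `num_epochs`.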

In [None]:
# @title Configuration and setup
# Define your configuration class for DDPM
# Include: device, data_dir, work_dir, num_hiddens, num_ts, p_uncond, betas, num_epochs, batch_size, lr, lr_decay

# === CODE TODO BEGIN ===
@dataclasses.dataclass
class DDPMConfig(object):
    device: str = None  # TODO
    data_dir: str = None  # TODO
    work_dir: str = None  # TODO
    num_hiddens: int = None  # TODO
    num_ts: int = None  # TODO  # number of diffusion steps
    p_uncond: float = None  # TODO  # classifier-free guidance dropout probability
    betas: tuple[float, float] = None  # TODO  # linear noise schedule
    num_epochs: int = None  # TODO
    batch_size: int = None  # TODO
    lr: float = None  # TODO
    lr_decay: float = None  # TODO
# === CODE TODO END ===

# Set random seed
seed_everything(0)
# Create config instance
cfg = DDPMConfig()

In [None]:
#@title Load dataset and initialize model/optimizer/schedule
# === CODE TODO BEGIN ===
# Load the MNIST dataset (from torchvision.datasets)
# Prepare DataLoader with shuffling for training and no shuffling for testing
train_dataset = None  # TODO
test_dataset = None  # TODO

train_loader = None  # TODO
test_loader = None  # TODO

# Create the beta schedule for DDPM (use ddpm_schedule helper)
schedule = None  # TODO

# Instantiate your TimeConditionalUNet model (from 2.1) and move it to cfg.device
model = None  # TODO

# Create optimizer (Adam) and learning rate scheduler (ExponentialLR)
optimizer = None  # TODO
scheduler = None  # TODO
# === CODE TODO END ===

In [None]:
#@title Training loop
# Assumes you have:
# - model: a time-conditioned UNet on cfg.device
# - optimizer: Adam optimizer with learning rate cfg.lr
# - scheduler: ExponentialLR to decay learning rate over cfg.num_epochs
# - train_loader: batches of MNIST images and labels
# - schedule: output of ddpm_schedule
# - cfg: your DDPMConfig instance containing training parameters
# - device: "cuda" or "cpu"
# Implement the training loop to train your model for several epochs
# === CODE TODO BEGIN ===




# === CODE TODO END ===

In [None]:
# @title Deliverable 2-1
# A training loss curve plot for the time-conditioned UNet over the whole training process.
# === CODE TODO BEGIN ===



# === CODE TODO END ===

In [None]:
# @title Deliverable 2-2
# Sampling results for the time-conditioned UNet for 5 and 20 epochs. Note: providing a gif is optional.
# Hint: You can either visualize results immediately during training (recommended for easier comparison),
#  or store outputs and visualize them later after training.
# === CODE TODO BEGIN ===



# === CODE TODO END ===

## Deliverables for Class-conditioned UNet
- A training loss curve plot for the class-conditioned UNet over the whole training process.
- Sampling results for the class-conditioned UNet for 5 and 20 epochs. Generate 4 instances of each digit as shown above.
Note: providing a gif is optional.

## 2.4 Implementing class-conditioned UNet
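
The `c` and `mask` arguments of the forward pass in this section can be illustrated with the sketch below: labels become one-hot vectors, and rows where `mask == 0` are zeroed so the model sees no class information for those samples. Representing the dropped condition as an all-zero vector is an assumption, as are the example labels.

```python
import torch
import torch.nn.functional as F

num_classes, p_uncond = 10, 0.1
c = torch.tensor([3, 1, 4, 1])               # (N,) int64 labels
one_hot = F.one_hot(c, num_classes).float()  # (N, num_classes)

# mask == 0 drops the condition for that sample; each sample is dropped
# independently with probability p_uncond.
mask = (torch.rand(c.shape[0]) >= p_uncond).float()  # (N,), 1 = keep, 0 = drop
masked = one_hot * mask[:, None]                     # dropped rows become all zeros
```

Training with occasional dropped conditions is what lets a single network produce both conditional and unconditional noise estimates at sampling time.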

In [None]:
class ClassConditionalUNet(nn.Module):
    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        num_hiddens: int,
    ):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        t: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.
            c: (N,) int64 condition tensor.
            t: (N,) normalized time tensor.
            mask: (N,) mask tensor. If not None, mask out condition when mask == 0.

        Returns:
            (N, C, H, W) output tensor.
        """
        assert x.shape[-2:] == (28, 28), "Expect input shape to be (28, 28)."
        raise NotImplementedError()

## 2.5 Training and Sampling from the Class-Conditioned UNet
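
At sampling time, classifier-free guidance combines a conditional and an unconditional noise estimate. The sketch below shows only that combination step; `cfg_combine` is a hypothetical helper name, and constant tensors stand in for the two network outputs.

```python
import torch


def cfg_combine(
    eps_uncond: torch.Tensor, eps_cond: torch.Tensor, guidance_scale: float
) -> torch.Tensor:
    """Extrapolate from the unconditional estimate toward the conditional one."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


eps_u = torch.zeros(2, 1, 28, 28)  # stand-in for the unconditional estimate
eps_c = torch.ones(2, 1, 28, 28)   # stand-in for the conditional estimate
eps = cfg_combine(eps_u, eps_c, guidance_scale=5.0)
no_guidance = cfg_combine(eps_u, eps_c, guidance_scale=1.0)
```

In the sampler, the two estimates would come from two forward passes of the UNet on the same `x_t` and `t`, one with the condition kept and one with it masked out. A scale of 1 recovers the purely conditional estimate, and larger scales push samples harder toward the requested class.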

In [None]:
def ddpm_forward(
    unet: ClassConditionalUNet,
    ddpm_schedule: dict,
    x_0: torch.Tensor,
    c: torch.Tensor,
    p_uncond: float,
    num_ts: int,
) -> torch.Tensor:
    """Algorithm 1 of the DDPM paper.

    Args:
        unet: ClassConditionalUNet
        ddpm_schedule: dict
        x_0: (N, C, H, W) input tensor.
        c: (N,) int64 condition tensor.
        p_uncond: float, probability of unconditioning the condition.
        num_ts: int, number of timesteps.

    Returns:
        (,) diffusion loss.
    """
    unet.train()
    # YOUR CODE HERE.
    raise NotImplementedError()

In [None]:
@torch.inference_mode()
def ddpm_sample(
    unet: ClassConditionalUNet,
    ddpm_schedule: dict,
    c: torch.Tensor,
    img_wh: tuple[int, int],
    num_ts: int,
    guidance_scale: float = 5.0,
    seed: int = 0,
) -> torch.Tensor:
    """Algorithm 2 of the DDPM paper with classifier-free guidance.

    Args:
        unet: ClassConditionalUNet
        ddpm_schedule: dict
        c: (N,) int64 condition tensor. Only for class-conditional
        img_wh: (H, W) output image width and height.
        num_ts: int, number of timesteps.
        guidance_scale: float, CFG scale.
        seed: int, random seed.

    Returns:
        (N, C, H, W) final sample.
        (N, T_animation, C, H, W) caches.
    """
    unet.eval()
    # YOUR CODE HERE.
    raise NotImplementedError()

In [None]:
class DDPM(nn.Module):
    def __init__(
        self,
        unet: ClassConditionalUNet,
        betas: tuple[float, float] = (1e-4, 0.02),
        num_ts: int = 300,
        p_uncond: float = 0.1,
    ):
        super().__init__()
        self.unet = unet
        self.betas = betas
        self.num_ts = num_ts
        self.p_uncond = p_uncond
        self.ddpm_schedule = ddpm_schedule(betas[0], betas[1], num_ts)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (N, C, H, W) input tensor.
            c: (N,) int64 condition tensor.

        Returns:
            (,) diffusion loss.
        """
        return ddpm_forward(
            self.unet, self.ddpm_schedule, x, c, self.p_uncond, self.num_ts
        )

    @torch.inference_mode()
    def sample(
        self,
        c: torch.Tensor,
        img_wh: tuple[int, int],
        guidance_scale: float = 5.0,
        seed: int = 0,
    ):
        return ddpm_sample(
            self.unet, self.ddpm_schedule, c, img_wh, self.num_ts, guidance_scale, seed
        )

### Training

In [None]:
# @title Configuration and setup
# Define your configuration class for DDPM
# Include: device, data_dir, work_dir, num_hiddens, num_ts, p_uncond, betas, num_epochs, batch_size, lr, lr_decay

# === CODE TODO BEGIN ===
@dataclasses.dataclass
class DDPMConfig(object):
    device: str = None  # TODO
    data_dir: str = None  # TODO
    work_dir: str = None  # TODO
    num_hiddens: int = None  # TODO
    num_ts: int = None  # TODO  # number of diffusion steps
    p_uncond: float = None  # TODO  # classifier-free guidance dropout probability
    betas: tuple[float, float] = None  # TODO  # linear noise schedule
    num_epochs: int = None  # TODO
    batch_size: int = None  # TODO
    lr: float = None  # TODO
    lr_decay: float = None  # TODO
# === CODE TODO END ===

# Set random seed
seed_everything(0)
# Create config instance
cfg = DDPMConfig()

In [None]:
#@title Load dataset and initialize model/optimizer/schedule
# === CODE TODO BEGIN ===
# Load the MNIST dataset (from torchvision.datasets)
# Prepare DataLoader with shuffling for training and no shuffling for testing
train_dataset = None  # TODO
test_dataset = None  # TODO

train_loader = None  # TODO
test_loader = None  # TODO

# Create the beta schedule for DDPM (use ddpm_schedule helper)
schedule = None  # TODO

# Instantiate your ClassConditionalUNet model (from 2.4) and move it to cfg.device
model = None  # TODO

# Create optimizer (Adam) and learning rate scheduler (ExponentialLR)
optimizer = None  # TODO
scheduler = None  # TODO
# === CODE TODO END ===

In [None]:
#@title Training loop
# Assumes you have:
# - model: a Conditional UNet (time- and class-conditioned) on cfg.device
# - optimizer: Adam optimizer with learning rate cfg.lr
# - scheduler: ExponentialLR to decay learning rate over cfg.num_epochs
# - train_loader: batches of MNIST images and labels
# - schedule: output of ddpm_schedule
# - cfg: your DDPMConfig instance containing training parameters
# - device: "cuda" or "cpu"
# Implement the training loop to train your model for several epochs
# === CODE TODO BEGIN ===




# === CODE TODO END ===

In [None]:
# @title Deliverable 2-3
# A training loss curve plot for the class-conditioned UNet over the whole training process.
# === CODE TODO BEGIN ===



# === CODE TODO END ===

In [None]:
# @title Deliverable 2-4
# Sampling results for the class-conditioned UNet for 5 and 20 epochs.
# Generate 4 instances of each digit. Note: providing a gif is optional.
# Hint: You can either visualize results immediately during training (recommended for easier comparison),
#  or store outputs and visualize them later after training.
# === CODE TODO BEGIN ===



# === CODE TODO END ===

# Extra Credit

## Improve the UNet Architecture for time-conditional generation
For ease of explanation and implementation, our UNet architecture above is pretty simple. Modify the UNet (e.g. with skip connections) so that it fits the training data better and produces better samples.

## Implement Rectified Flow
- Implement [rectified flow](https://arxiv.org/abs/2209.03003), a state-of-the-art diffusion model.
- You can reference any code on github, but your implementation needs to follow the same code structure as our DDPM implementation.
- In other words, the code change required should be minimal: only changing the forward and sample functions.
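
To make the "only forward and sample change" point concrete, here is a minimal sketch of the two pieces rectified flow needs: a straight-line interpolation between noise and data with a velocity target, and a fixed-step Euler sampler. The function names are hypothetical, the convention that $t = 0$ is noise and $t = 1$ is data is an assumption, and an oracle velocity field stands in for a trained network.

```python
import torch


def rf_training_pair(x_1: torch.Tensor, x_0: torch.Tensor, t: torch.Tensor):
    """Return the interpolated point x_t and the velocity target x_1 - x_0."""
    t = t[:, None, None, None]       # (N, 1, 1, 1) for broadcasting
    x_t = (1.0 - t) * x_0 + t * x_1  # straight line from noise to data
    return x_t, x_1 - x_0


@torch.inference_mode()
def rf_euler_sample(velocity_fn, x_0: torch.Tensor, num_steps: int) -> torch.Tensor:
    """Integrate dx/dt = v(x, t) from t = 0 to t = 1 with fixed-step Euler."""
    x = x_0.clone()
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((x.shape[0],), i * dt)
        x = x + dt * velocity_fn(x, t)
    return x


# Toy check with an oracle velocity field standing in for a trained network.
data = torch.rand(4, 1, 28, 28)
noise = torch.randn_like(data)
x_t, target = rf_training_pair(data, noise, torch.rand(4))
sample = rf_euler_sample(lambda x, t: data - noise, noise, num_steps=10)
```

In a real implementation, the training loss would be the mean squared error between the network's output at `(x_t, t)` and `target`, and `velocity_fn` would call the trained UNet.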

# Generating a PDF for CMSX

You can just use `File > Print` to get a PDF of this page. Please double-check that no outputs are cut off!