![](images/2025-02-18-diffusion-model-mnist-part3.png)

## Introduction

Welcome back to the final part of our hands-on journey into diffusion models for MNIST digit generation! In [Part 1](https://hassaanbinaslam.github.io/myblog/posts/2025-02-10-diffusion-model-mnist-part1.html), we laid the groundwork by building a basic Convolutional UNet and training it to directly predict clean MNIST digits from noisy inputs. We then enhanced our UNet architecture in [Part 2](https://hassaanbinaslam.github.io/myblog/posts/2025-02-15-diffusion-model-mnist-part2.html), leveraging the power of the diffusers library and its UNet2DModel to achieve improved denoising performance.

While our direct image prediction approach showed promising results, we hinted that it was a simplification of true diffusion models. We observed that even with our enhanced UNet and iterative refinement, the generated digits still lacked the crispness and fidelity we might expect from "diffusion model magic."

Now, in this final installment, we're ready to take the leap into the heart of diffusion models. We'll move beyond directly predicting clean images and embrace the core principles that make diffusion models so powerful: **noise prediction** and **scheduled denoising**. Get ready to unlock the true potential of diffusion and witness a significant step-up in image generation quality!

### Credits

This post is inspired by the [Hugging Face Diffusion Course](https://huggingface.co/learn/diffusion-course/en/unit1/3).

### Environment Details

You can access and run this Jupyter Notebook from the GitHub repository at this link: [2025-02-18-diffusion-model-mnist-part3.ipynb](https://github.com/hassaanbinaslam/myblog/blob/main/posts/2025-02-18-diffusion-model-mnist-part3.ipynb)

* This notebook can be run with [Google Colab](https://colab.research.google.com/) T4 GPU runtime.
* I have also tested this notebook with AWS SageMaker Jupyter Notebook running on instance "ml.g5.xlarge" and image "SageMaker Distribution 2.3.0".

Run the following cell to install the required packages.

In [1]:
%%capture
!pip install datasets[vision]
!pip install diffusers
!pip install watermark
!pip install torchinfo
!pip install matplotlib

[WaterMark](https://github.com/rasbt/watermark) is an IPython magic extension for printing date and time stamps, version numbers, and hardware information. Let's load this extension and print the environment details.

In [2]:
%load_ext watermark

In [3]:
%watermark -v -m -p torch,torchvision,datasets,diffusers,matplotlib,watermark,torchinfo

Python implementation: CPython
Python version       : 3.11.11
IPython version      : 7.34.0

torch      : 2.5.1+cu124
torchvision: 0.20.1+cu124
datasets   : 3.3.0
diffusers  : 0.32.2
matplotlib : 3.10.0
watermark  : 2.5.0
torchinfo  : 1.8.0

Compiler    : GCC 11.4.0
OS          : Linux
Release     : 6.1.85+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit



## From Direct Image Prediction to Noise Prediction: A Paradigm Shift

In Parts 1 and 2, we trained our UNet to perform **direct image prediction**.  This meant we fed the model a noisy image and asked it to directly output the estimated *clean* image. While this approach allowed us to grasp the basic mechanics of UNets and image denoising, it's important to understand its limitations and why true diffusion models take a different path.

Direct image prediction, as we implemented it, is essentially a **one-step denoising process**.  It attempts to remove all the noise in a single forward pass through the network.  Think of it like trying to un-blur a heavily distorted image in just one go – it's a difficult task, and the results can often be blurry and lack fine details.  Furthermore, this direct approach doesn't fully capture the essence of the diffusion process, which is inherently gradual and iterative.

True diffusion models, and the approach we'll adopt now, operate on a different principle: **noise prediction**.  Instead of predicting the clean image directly, we train our model to predict the **noise** that was added to the clean image to produce the noisy input at a given timestep.  During the *reverse* diffusion process, that prediction tells us what to remove at each step.

Imagine you're slowly un-blurring an image, step by step.  At each step, instead of trying to guess the *entire* sharp image, you focus on identifying and removing just a *tiny bit* of blur.  By iteratively removing small amounts of blur (or noise), you gradually reveal the underlying clean image.  This is the essence of noise prediction.

Our model will now learn to estimate the noise present in a slightly noisy image.  This predicted noise can then be used to "step back" along the reverse diffusion trajectory, creating a slightly less noisy image.  By repeating this process over many steps – a process we call **scheduled denoising** (which we'll discuss shortly) – we can generate high-quality images from pure noise.

This shift to noise prediction is a crucial paradigm change.  It allows for:

*   **More stable training:**  Predicting noise at each step is a less ambitious and more manageable task for the model compared to directly predicting the clean image.
*   **Improved sample quality:**  The iterative nature of noise prediction, guided by a schedule, leads to the generation of more detailed and visually appealing images.
*   **Alignment with true diffusion models:**  Noise prediction is the fundamental building block of modern diffusion models, bringing us closer to state-of-the-art image generation techniques.
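
To make the link between the two training targets concrete, here is a minimal sketch, assuming the standard DDPM forward process in which a noisy image is a weighted mix of the clean image and Gaussian noise. The value of `alpha_bar` and the random tensors are illustrative assumptions, not values from the notebook. The sketch shows that a correct noise prediction implies the clean image, so predicting noise loses no information compared with predicting the image.

```python
import torch

torch.manual_seed(0)

# Illustrative stand-ins: a fake batch of "clean" images and Gaussian noise.
x0 = torch.rand(4, 1, 32, 32)
noise = torch.randn_like(x0)

# Assumed fraction of signal kept at some timestep (hypothetical value).
alpha_bar = torch.tensor(0.5)

# Forward process: mix the clean image with noise.
x_t = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * noise

# If the noise is known (or predicted perfectly), the clean image follows.
x0_recovered = (x_t - (1 - alpha_bar).sqrt() * noise) / alpha_bar.sqrt()

print(torch.allclose(x0, x0_recovered, atol=1e-5))
```

In practice the prediction is never perfect, which is one reason the reverse process takes many small steps instead of a single jump.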

In the following sections, we'll delve into the code modifications needed to switch to noise prediction and explore the concept of scheduled denoising in detail.

## Data Preparation, Preprocessing, and UNet Model

As we are building upon the foundations laid in [Part 1](https://hassaanbinaslam.github.io/myblog/posts/2025-02-10-diffusion-model-mnist-part1.html), we will reuse the same data preparation and preprocessing steps for the MNIST dataset. For a more in-depth explanation of these steps, please refer back to the first part of this guide. Here, we will quickly outline the process to ensure our data is ready for training. 

We will train the same `UNet2DModel` that we used in Part 2.

In [None]:
### Load MNIST Dataset
from datasets import load_dataset
dataset = load_dataset("mnist")
print(dataset)

### Define preprocess pipeline
import torch
from torchvision import transforms

image_size = 32  # Define the target image size

preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

## Define the transform function
def transform(examples):
    examples = [preprocess(image) for image in examples["image"]]
    return {"images": examples}

## Apply the transform to the dataset
dataset.set_transform(transform)

## Definition of the noise corruption function
def corrupt(x, noise, amount):
    amount = amount.view(-1, 1, 1, 1)  # make sure it's broadcastable
    return (
        x * (1 - amount) + noise * amount
    )  # equivalent to x.lerp(noise, amount)

### Define the UNet Model (Same as Part 2)
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(32, 64, 64),
    down_block_types=(
        "DownBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
    ),
)


from torchinfo import summary
summary(model)
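
As a quick sanity check on the `corrupt` helper carried over from the earlier parts, the following sketch runs it on random tensors instead of MNIST images (the tensors and the three `amount` values are illustrative assumptions). It shows that `amount = 0` returns the input unchanged, `amount = 1` returns pure noise, and the result agrees with `torch.lerp`, as the comment in the function suggests.

```python
import torch

def corrupt(x, noise, amount):
    amount = amount.view(-1, 1, 1, 1)  # make sure it's broadcastable
    return x * (1 - amount) + noise * amount

torch.manual_seed(0)
x = torch.rand(3, 1, 32, 32)             # stand-in for a batch of images
noise = torch.rand_like(x)
amount = torch.tensor([0.0, 0.5, 1.0])   # one noise level per image

mixed = corrupt(x, noise, amount)

print(torch.equal(mixed[0], x[0]))       # amount = 0 keeps the image
print(torch.equal(mixed[2], noise[2]))   # amount = 1 is pure noise
```

In this part, the scheduler's `add_noise` function takes over the job of mixing images with noise.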

## Scheduled Denoising: Guiding the Reverse Diffusion

The power of noise prediction truly shines when combined with **scheduled denoising**.  As we discussed, diffusion models work by gradually reversing the noise addition process.  Scheduled denoising provides the *schedule* or the *steps* for this reverse process, controlling how we iteratively remove noise from an image.

> Think of it like carefully peeling layers of an onion.  Scheduled denoising defines how many layers we peel back and how much we peel at each step.  In diffusion models, these "layers" correspond to different levels of noise.

**Why do we need a schedule?**

Instead of removing all the predicted noise in one go, scheduled denoising breaks down the denoising process into a series of discrete timesteps.  This is crucial for several reasons:

*   **Controlled Noise Removal:** A schedule allows us to gradually remove noise, starting from a highly noisy image (or pure noise) and progressively refining it.  This iterative refinement leads to better image quality compared to a one-step approach.
*   **Stability and Guidance:** By controlling the denoising steps, we provide a structured path for the reverse diffusion process.  This makes the generation process more stable and predictable.
*   **Flexibility and Control:** Different schedules can be designed to influence the generation process. For example, some schedules might prioritize faster generation, while others might focus on higher quality.

**Timesteps and the Reverse Process:**

In the following code, we represent the denoising schedule using **timesteps**.  These timesteps are typically a sequence of numbers going from a large value (representing high noise) down to a small value (representing low noise or a clean image).

In [None]:
# Setup the DDPM scheduler for training
from diffusers import DDPMScheduler

num_train_timesteps = 1000

scheduler = DDPMScheduler(
    num_train_timesteps=num_train_timesteps,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule="linear",
)

You'll notice in the code that we're now using a `DDPMScheduler` from `diffusers`. This scheduler is responsible for:

1.  **Generating Timesteps:**  It creates a schedule of timesteps that guide the reverse diffusion process.  We've initialized it with `num_train_timesteps = 1000`. This means the forward diffusion process (noise addition) is simulated over 1000 steps.  For the reverse process (denoising), we'll also use these timesteps, though we might choose to use fewer steps for faster inference.

2.  **Adding Noise (Forward Process Simulation):**  During training, the scheduler's `add_noise` function helps us create noisy versions of clean images at different timesteps. This is what we are using in our training loop:

    ```python
    noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
    ```

3.  **Stepping Backwards (Reverse Process):**  Crucially, the scheduler also provides a `step` function that helps us take a denoising step *backwards* along the diffusion trajectory. This function is used during inference (and could be used in more advanced training schemes).  We'll see how to use this `step` function in the inference section later in this post.

**In essence, the `DDPMScheduler` encapsulates the logic for both the forward (noise addition) and reverse (denoising) diffusion processes, providing us with the tools to implement scheduled denoising.**
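
To get a feel for what the schedule encodes, the sketch below rebuilds a linear beta schedule by hand with the same `beta_start`, `beta_end`, and number of timesteps as above. It assumes the standard DDPM definitions (`alpha = 1 - beta`, and the cumulative product of the alphas as the fraction of signal remaining). It is an illustration in plain PyTorch, not the scheduler's internal code.

```python
import torch

num_train_timesteps = 1000

# Linear schedule: noise variance added at each forward step.
betas = torch.linspace(0.0001, 0.02, num_train_timesteps)
alphas = 1.0 - betas

# Cumulative product: how much of the original signal survives up to step t.
alphas_cumprod = torch.cumprod(alphas, dim=0)

for t in (0, 250, 500, 750, 999):
    signal = alphas_cumprod[t].sqrt().item()
    noise_scale = (1 - alphas_cumprod[t]).sqrt().item()
    print(f"t={t:4d}  signal={signal:.4f}  noise={noise_scale:.4f}")
```

The signal weight starts near 1 at `t = 0` and decays towards 0 at `t = 999`, which is why a sample at the last timestep is treated as pure noise.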

In the next section, we'll modify our training loop to incorporate noise prediction using the `DDPMScheduler`, and see how we train our model to predict the noise at each timestep. The `step` function comes into play afterwards, when we generate images.

## Modifying the Training Loop for Noise Prediction

Now that we understand the concepts of noise prediction and scheduled denoising, let's adapt our training loop to reflect these changes.  We'll be using the `DDPMScheduler` and training our `UNet2DModel` to predict noise instead of directly predicting clean images.

Here's how we'll modify the training loop:

**1.  Sampling Timesteps:**

   Instead of just generating random noise amounts, we now need to sample **timesteps** for each image in the batch. These timesteps will be integers between 0 and `num_train_timesteps - 1` (in our code, `num_train_timesteps = 1000`).  These timesteps tell the scheduler *how much* noise to add in the forward process and guide the reverse process.

   In the training loop, this looks like:

   ```python
   timesteps = torch.randint(0, num_train_timesteps, (batch_size,), device=device).long()
   ```

**2.  Adding Noise with the Scheduler:**

   We'll use the `scheduler.add_noise` function to add noise to our clean images, *conditioned on the sampled timesteps*. This function takes the clean images, random noise, and the timesteps as input and returns the noisy images.

   In code:

   ```python
   noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
   ```

**3.  Model Predicts Noise:**

   The crucial change is that we now feed the `noisy_images` and the `timesteps` to our `UNet2DModel`, and we train it to predict the **noise** that was added.  The `UNet2DModel` in `diffusers` is designed to be conditioned on timesteps.

   Here's the code:

   ```python
   noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
   ```

   Notice that we pass `timesteps` as the second argument to the `model`.  With `return_dict=False` the model returns a tuple, and the trailing `[0]` extracts the predicted noise tensor from it.

**4.  Loss Calculation:**

   Our loss function remains **Mean Squared Error (MSE)**, but now we calculate the MSE between the **predicted noise (`noise_pred`)** and the **actual noise (`noise`)** that we used to corrupt the images.  This is how we train the model to accurately predict the noise.

   ```python
   loss = F.mse_loss(noise_pred, noise)
   ```

**5.  Rest of the Training Loop:**

   The rest of the training loop (optimizer step, loss tracking, etc.) remains largely the same as in Part 2.
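
The steps above can be sketched end to end without downloading data or building the UNet. The sketch below is a simplified illustration under stated assumptions: `add_noise_manual` is a hypothetical helper that applies the standard DDPM closed-form forward process, `toy_model` is a hypothetical single-convolution stand-in that ignores the timestep, and the batch is random data. It only demonstrates how the tensors flow through one training step.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)

# Linear beta schedule with the same settings as the scheduler above.
num_train_timesteps = 1000
betas = torch.linspace(0.0001, 0.02, num_train_timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise_manual(clean_images, noise, timesteps):
    """Closed-form forward process: jump straight to timestep t."""
    a_bar = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
    return a_bar.sqrt() * clean_images + (1 - a_bar).sqrt() * noise

# Hypothetical stand-in for the UNet (it ignores the timestep entirely).
toy_model = nn.Conv2d(1, 1, kernel_size=3, padding=1)

clean_images = torch.rand(8, 1, 32, 32)                          # fake batch
noise = torch.randn_like(clean_images)                           # target
timesteps = torch.randint(0, num_train_timesteps, (8,)).long()   # one t per image

noisy_images = add_noise_manual(clean_images, noise, timesteps)
noise_pred = toy_model(noisy_images)
loss = F.mse_loss(noise_pred, noise)   # compare predicted and actual noise
loss.backward()

print(noisy_images.shape, loss.item())
```

The real loop differs in two ways: the model receives `timesteps` as a second argument, and the scheduler's `add_noise` replaces the hand-written helper.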

**Complete Modified Training Loop Snippet:**

Here's the complete, modified training loop, incorporating noise prediction and the scheduler's forward process:

In [None]:
# train the model
import torch
from torch.nn import functional as F

# Use the GPU when available and move the model there
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

batch_size = 128
num_epochs = 5

train_dataloader = torch.utils.data.DataLoader(
    dataset["train"], batch_size=batch_size, shuffle=True
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
losses = []  # Somewhere to store the loss values for later plotting

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Get the clean images and move to device
        clean_images = batch["images"].to(device)  # shape: (B, 1, H, W)
        batch_size = clean_images.shape[0]

        # Sample random noise to add
        noise = torch.randn_like(clean_images).to(device)

        # Sample a random timestep for each image in the batch
        timesteps = torch.randint(0, num_train_timesteps, (batch_size,), device=device).long()

        # Add noise to the clean images according to the scheduler's forward process
        noisy_images = scheduler.add_noise(clean_images, noise, timesteps)

        # Let the model predict the noise component from the noisy images
        # (Note: The model is conditioned on the timestep)
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

        # Compute the loss between the predicted noise and the actual noise
        loss = F.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    # Print the average loss for this epoch
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f"Finished epoch {epoch}. Average loss: {avg_loss:.6f}")

By making these modifications, we've successfully shifted our training paradigm from direct image prediction to noise prediction, leveraging the `DDPMScheduler` to manage the diffusion process. In the next section, we'll focus on how to modify the inference process to generate images using scheduled denoising.
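
Because every batch draws random timesteps, the per-step loss for noise prediction tends to jump around, so a smoothed curve is easier to read when plotting `losses`. Here is a small sketch of a trailing moving average; `moving_average` and the example values are hypothetical helpers for illustration and are not part of the original notebook.

```python
def moving_average(values, window=50):
    """Trailing mean over up to `window` of the most recent values."""
    smoothed = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        if i >= window:
            # Drop the value that just left the window.
            running_sum -= values[i - window]
        smoothed.append(running_sum / min(i + 1, window))
    return smoothed

# Hypothetical loss values standing in for the `losses` list.
example_losses = [1.0, 0.5, 0.9, 0.3, 0.7, 0.2]
print(moving_average(example_losses, window=3))
```

The smoothed list has the same length as the input, so it can be plotted on the same axis as the raw losses.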

## Inference with Scheduled Denoising: Generating Images Iteratively

With our model now trained to predict noise, we can finally generate images using the true power of diffusion models: **iterative denoising guided by a schedule**.  This is a significant departure from the direct image prediction approach of Parts 1 and 2.

Here's how we'll modify the inference process:

**1.  Start with Pure Noise:**

   We begin the generation process with pure random noise. This noise will be our starting point for the reverse diffusion process.

   ```python
   sample = torch.randn((num_images, 1, image_size, image_size)).to(device)
   ```

**2.  Set up the Denoising Loop:**

   We'll use a loop that iterates through the **timesteps** provided by our `scheduler`.  Crucially, during inference, we need to use the *inference timesteps* which are obtained using `scheduler.set_timesteps(num_inference_steps)`.  These timesteps are in *descending order*, going from high noise to low noise.

   ```python
   for t in scheduler.timesteps:
       # ... denoising step ...
   ```

**3.  Model Predicts Noise at Each Timestep:**

   Inside the loop, for each timestep `t`, we feed the current noisy image `sample` and the timestep `t` to our `UNet2DModel` to predict the noise:

   ```python
   noise_pred = model(sample, t, return_dict=False)[0]
   ```

**4.  Scheduler Steps Backwards:**

   This is the core of scheduled denoising! We use the `scheduler.step` function to take a step *backwards* along the diffusion trajectory, removing a bit of noise from the current `sample`.  The `step` function takes the `noise_pred`, the current timestep `t`, and the current `sample` as input and returns an output object whose `prev_sample` attribute holds the updated sample.

   ```python
   output = scheduler.step(noise_pred, t, sample)
   sample = output.prev_sample
   ```

   The `scheduler.step` function intelligently uses the predicted noise and the schedule information to determine how much to "denoise" the image at each timestep.  This is where the magic of the diffusion schedule comes in!

**5.  Iterate and Refine:**

   We repeat steps 3 and 4 for all timesteps in the schedule.  In each iteration, the image becomes progressively less noisy and more structured, gradually revealing a coherent MNIST digit.

**6.  Visualization:**

   After the loop completes, the `sample` tensor will contain the generated (denoised) images.  We can then visualize these images as we did in previous parts.
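
To illustrate what "inference timesteps" means in step 2, the sketch below picks an evenly spaced subset of the 1000 training timesteps and reverses it so that the noisiest step comes first. This is a plain-Python illustration of the idea. The exact values in `scheduler.timesteps` depend on how the scheduler is configured, so treat the spacing rule here as an assumption.

```python
num_train_timesteps = 1000
num_inference_steps = 50

# Evenly spaced subset of the training timesteps, highest noise first.
step_ratio = num_train_timesteps // num_inference_steps
inference_timesteps = [i * step_ratio for i in range(num_inference_steps)][::-1]

print(inference_timesteps[:3], "...", inference_timesteps[-3:])
```

Fewer inference steps mean larger jumps between consecutive timesteps, which is the source of the speed/quality trade-off.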

**Complete Inference Loop Snippet:**

Here's the complete inference loop, demonstrating scheduled denoising:

```python
num_images = 8            # Number of digits to generate (assumed value)
num_inference_steps = 50  # Example: adjust for the speed/quality trade-off
scheduler.set_timesteps(num_inference_steps)  # Set inference timesteps!

# Start with pure noise
sample = torch.randn((num_images, 1, image_size, image_size)).to(device)

model.eval()
with torch.no_grad():  # No gradients are needed for generation
    for t in scheduler.timesteps:  # Iterate through timesteps (descending order)
        noise_pred = model(sample, t, return_dict=False)[0]  # Predict noise
        output = scheduler.step(noise_pred, t, sample)  # Scheduler step (denoise)
        sample = output.prev_sample  # Update sample

# sample now contains generated images!
generated_images = sample.cpu()
# ... (rest of the visualization code) ...
```

By implementing this iterative inference process with the `scheduler.step` function, we are now performing true scheduled denoising.  This should lead to significantly improved image generation quality compared to our previous direct prediction approaches.
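
For readers curious about what a scheduler step computes, here is a minimal sketch of one reverse step written from the standard DDPM equations. It is a simplification under stated assumptions, not the library's implementation: `ddpm_step` is a hypothetical function, it moves between consecutive timesteps (`t` to `t - 1`) and so does not skip steps, and it omits options such as clipping the predicted image.

```python
import torch

torch.manual_seed(0)

num_train_timesteps = 1000
betas = torch.linspace(0.0001, 0.02, num_train_timesteps, dtype=torch.float64)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

def ddpm_step(noise_pred, t, sample):
    """One reverse step from integer timestep t to t - 1."""
    a_bar_t = alphas_cumprod[t]
    a_bar_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)

    # 1. Estimate the clean image implied by the predicted noise.
    pred_x0 = (sample - (1 - a_bar_t).sqrt() * noise_pred) / a_bar_t.sqrt()

    # 2. Posterior mean: blend that estimate with the current sample.
    coef_x0 = a_bar_prev.sqrt() * betas[t] / (1 - a_bar_t)
    coef_xt = alphas[t].sqrt() * (1 - a_bar_prev) / (1 - a_bar_t)
    prev_sample = coef_x0 * pred_x0 + coef_xt * sample

    # 3. Add fresh noise, except at the very last step.
    if t > 0:
        variance = betas[t] * (1 - a_bar_prev) / (1 - a_bar_t)
        prev_sample = prev_sample + variance.sqrt() * torch.randn_like(sample)
    return prev_sample

# Check with an "oracle" that knows the true noise: the final step (t = 0)
# should land on the clean image.
x0 = torch.rand(2, 1, 32, 32)
eps = torch.randn_like(x0)
t = 0
x_t = alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * eps
x_prev = ddpm_step(eps, t, x_t)

print(torch.allclose(x_prev, x0, atol=1e-3))
```

In the real loop the oracle is replaced by the trained model's `noise_pred`, so each step only nudges the sample towards a clean digit.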

In the next section, we'll look at the results of this new approach and compare them to our previous methods.

## Results and Discussion: Witnessing the "Diffusion Model Magic"

After implementing noise prediction and scheduled denoising, let's examine the generated MNIST digits and compare them to our previous results. We'll first look at the images generated using a relatively small number of inference steps (e.g., `num_inference_steps = 5`) and then increase the steps (e.g., `num_inference_steps = 50` or `100`) to see the impact of more iterative refinement.

**Results with a Small Number of Inference Steps (e.g., 5):**

*(Insert image grid here showing generated digits with ~5 inference steps)*

*Describe what you observe in the generated images. For example:*

"With just 5 inference steps, we can already see a remarkable improvement compared to our direct prediction models. The generated digits are no longer blurry noise.  We can discern clear digit shapes, and while they might not be perfectly crisp, they are undeniably recognizable as MNIST digits.  This demonstrates the power of even a few steps of scheduled denoising."

**Results with a Larger Number of Inference Steps (e.g., 50 or 100):**

*(Insert image grid here showing generated digits with ~50-100 inference steps)*

*Describe the improvement with more steps. For example:*

"Increasing the number of inference steps to 50 (or even 100) leads to a further significant improvement in image quality. The digits become much sharper, more well-defined, and exhibit finer details.  The 'fuzziness' we observed with fewer steps is largely gone.  These generated digits are now convincingly MNIST-like, showcasing the 'diffusion model magic' we were aiming for!"

**Comparison to Previous Models:**

*Compare these results to the direct prediction models from Part 1 and Part 2. Highlight the key differences. For example:*

"Comparing these results to the outputs of our direct image prediction models from Part 1 and Part 2, the difference is striking.  The digits generated with scheduled denoising are significantly sharper, clearer, and more visually appealing.  They no longer suffer from the blurriness and lack of detail that characterized our earlier attempts.  While iterative refinement in Part 2 provided some improvement, it still didn't reach this level of quality.  True scheduled denoising, driven by noise prediction and the `DDPMScheduler`, has unlocked a new level of image generation fidelity."

**Discussion and Key Takeaways:**

*Summarize the key learnings from Part 3. For example:*

*   **Noise Prediction is Key:** "We've demonstrated that training our model to predict noise, rather than directly predicting clean images, is a crucial step towards building effective diffusion models. This paradigm shift leads to more stable training and better generation quality."
*   **Scheduled Denoising Unleashes Power:** "Scheduled denoising, guided by the `DDPMScheduler` and its timesteps, is what truly unleashes the power of diffusion models.  The iterative refinement process, stepping backwards along the diffusion trajectory, allows us to generate high-quality images from pure noise."
*   **Iterative Refinement is Essential:** "The number of inference steps directly impacts image quality. More steps generally lead to sharper and more detailed images, but also increase generation time.  There's a trade-off between quality and speed that can be adjusted by changing `num_inference_steps`."
*   **We've Achieved "Diffusion Model Magic":** "With this final part, we've moved beyond simplified approaches and implemented the core principles of diffusion models.  The results speak for themselves – we can now generate convincingly MNIST-like digits from random noise, showcasing the power and potential of diffusion models."