![image source: https://www.artbreeder.com/image/6b4df6c697078f0e2cda42348ec6](images/2025-02-15-diffusion-model-mnist-part2.jpeg)

## Introduction

Welcome back to Part 2 of our journey into diffusion models! In the [first part](https://hassaanbinaslam.github.io/myblog/posts/2025-02-10-diffusion-model-mnist-part1.html), we successfully built a basic Convolutional UNet from scratch and trained it to directly predict denoised MNIST digits.  We saw that it could indeed remove some noise, but the results were still a bit blurry, and it wasn't quite the "diffusion model magic" we were hoping for.

One of the key limitations we hinted at was the simplicity of our `BasicUNet` architecture.  In this second part, we'll address that by upgrading to a more powerful, feature-rich UNet.

To do this, we'll be leveraging the fantastic `diffusers` library from [Hugging Face](https://huggingface.co/).  [`diffusers`](https://huggingface.co/docs/diffusers/en/index) is a widely adopted toolkit in the world of diffusion models, providing pre-built and optimized components that can significantly simplify our development process and boost performance.

In this part, we'll replace our `BasicUNet` with a `UNet2DModel` from `diffusers`.  We'll keep the core task the same – direct image prediction – but with a more advanced UNet under the hood. This will allow us to see firsthand how architectural improvements can impact the quality of our denoising results, setting the stage for even more exciting explorations in future parts! Let's dive in!

### Credits

This post is inspired by the [Hugging Face Diffusion Course](https://huggingface.co/learn/diffusion-course/en/unit1/3).

### Environment Details

You can access and run this Jupyter Notebook from the GitHub repository at this link: [2025-02-15-diffusion-model-mnist-part2.ipynb](https://github.com/hassaanbinaslam/myblog/blob/main/posts/2025-02-15-diffusion-model-mnist-part2.ipynb)

* This notebook can be run with [Google Colab](https://colab.research.google.com/) T4 GPU runtime.
* I have also tested this notebook with AWS SageMaker Jupyter Notebook running on instance "ml.g5.xlarge" and image "SageMaker Distribution 2.3.0".

Run the following cell to install the required packages.

In [1]:
%%capture
!pip install datasets[vision]
!pip install diffusers
!pip install watermark
!pip install torchinfo
!pip install matplotlib

[`watermark`](https://github.com/rasbt/watermark) is an IPython magic extension for printing date and time stamps, version numbers, and hardware information. Let's load this extension and print the environment details.

In [2]:
%load_ext watermark

In [3]:
%watermark -v -m -p torch,torchvision,datasets,diffusers,matplotlib,watermark,torchinfo

Python implementation: CPython
Python version       : 3.11.11
IPython version      : 7.34.0

torch      : 2.5.1+cu124
torchvision: 0.20.1+cu124
datasets   : 3.2.0
diffusers  : 0.32.2
matplotlib : 3.10.0
watermark  : 2.5.0
torchinfo  : 1.8.0

Compiler    : GCC 11.4.0
OS          : Linux
Release     : 6.1.85+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit



## Diving into `diffusers` and `UNet2DModel`

So, what exactly *is* this [`diffusers`](https://huggingface.co/docs/diffusers/en/index) library we're so excited about?  Think of `diffusers` as a comprehensive, community-driven library in [PyTorch](https://pytorch.org/) specifically designed for working with diffusion models. It's maintained by Hugging Face, the same team behind the popular [Transformers](https://huggingface.co/docs/transformers/en/index) library, so you know it's built with quality and ease of use in mind.

Why are we using `diffusers` now?  Several reasons!  First, it provides well-tested and optimized implementations of various diffusion model components, saving us from writing everything from scratch.  Second, it's a vibrant ecosystem, constantly evolving with the latest research and techniques in diffusion models.  By using `diffusers`, we're standing on the shoulders of giants!

For Part 2, the star of the show is the [`UNet2DModel`](https://huggingface.co/docs/diffusers/main/en/api/models/unet2d) class from `diffusers`. This is a more sophisticated UNet architecture compared to our `BasicUNet`.  It's like upgrading from a standard bicycle to a mountain bike – both are bikes, but the mountain bike is built for more challenging terrain and better performance.

What makes `UNet2DModel` more advanced? Let's look at some key architectural improvements under the hood:

*   **Configurable Block Types:** `UNet2DModel` is designed to be flexible. Instead of being fixed to a single type of block, it lets you choose different block types for its downsampling and upsampling paths through the `down_block_types` and `up_block_types` parameters.  The plain `DownBlock2D` and `UpBlock2D` blocks are convolutional and built from ResNet-style layers, while variants such as `AttnDownBlock2D` and `AttnUpBlock2D` add self-attention on top.

*   **Attention Mechanisms:**  `UNet2DModel` incorporates attention mechanisms, specifically "Attention Blocks," in its architecture.  Attention is a powerful concept in deep learning that allows the model to focus on the most relevant parts of the input when processing information.  In image generation, attention can help the model selectively focus on different regions of the image, potentially leading to finer details and more coherent structures.

*   **Group Normalization:**  Instead of Batch Normalization, `UNet2DModel` uses Group Normalization. Group Normalization is often favored in generative models, especially when working with smaller batch sizes, as it tends to be more stable and perform better in those scenarios.

*   **Timestep Embedding:**  Even though we are still doing direct image prediction in this part,  `UNet2DModel` is designed with diffusion models in mind.  It includes a `TimestepEmbedding` layer, which is a standard component in diffusion models to handle the timestep information (which we'll explore in later parts!).  For now, we'll just be passing in a timestep of 0, but this layer is there, ready for when we move to true diffusion.
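To give a feel for what a timestep embedding does, here is a simplified sketch of the general idea: turn an integer timestep into sinusoidal features, then pass them through a small MLP. This is an illustration only and not the exact `diffusers` implementation; the function name `sinusoidal_features` and the sizes `feature_dim` and `embed_dim` are hypothetical choices.

```python
import math

import torch
import torch.nn as nn


def sinusoidal_features(timesteps, dim):
    """Map integer timesteps of shape (batch,) to (batch, dim) sin/cos features."""
    half = dim // 2
    # Geometrically spaced frequencies, from 1 down towards 1/10000
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    angles = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


# Hypothetical sizes, chosen only for illustration
feature_dim, embed_dim = 32, 128
mlp = nn.Sequential(
    nn.Linear(feature_dim, embed_dim),
    nn.SiLU(),
    nn.Linear(embed_dim, embed_dim),
)

timesteps = torch.tensor([0, 10, 500])
features = sinusoidal_features(timesteps, feature_dim)
embedding = mlp(features)

print(features.shape, embedding.shape)
print(features[0])  # timestep 0: the sin half is all zeros, the cos half is all ones
```

Because we always pass a timestep of 0 in this part, the network sees the same embedding for every input, so it effectively behaves like an unconditional denoiser.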

These architectural enhancements in `UNet2DModel` give it a greater capacity to learn and potentially denoise images more effectively than our `BasicUNet`. Let's see if it lives up to the hype!
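To make the Group Normalization point concrete, here is a minimal sketch using plain PyTorch layers rather than `UNet2DModel` itself. The tensor shape and group count are arbitrary values picked for illustration. It checks that `nn.GroupNorm` normalizes each sample on its own, so a sample's output does not depend on what else is in the batch, whereas `nn.BatchNorm2d` in training mode relies on batch-wide statistics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes only (not taken from UNet2DModel's configuration)
x = torch.randn(4, 32, 16, 16)  # (batch, channels, height, width)

group_norm = nn.GroupNorm(num_groups=8, num_channels=32)
batch_norm = nn.BatchNorm2d(32)  # freshly created modules are in training mode

# Normalize the first sample as part of the full batch, then on its own
gn_same = torch.allclose(group_norm(x)[:1], group_norm(x[:1]), atol=1e-5)
bn_same = torch.allclose(batch_norm(x)[:1], batch_norm(x[:1]), atol=1e-5)

print("GroupNorm output independent of batch:", gn_same)
print("BatchNorm output independent of batch:", bn_same)
```

This batch independence is one reason Group Normalization tends to behave more predictably when batches are small.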

## Data Preparation and Preprocessing for MNIST

As we are building upon the foundations laid in Part 1, we will reuse the same data preparation and preprocessing steps for the MNIST dataset. For a more in-depth explanation of these steps, please refer back to the first part of this guide. Here, we will quickly outline the process to ensure our data is ready for training our enhanced UNet.

First, we load the MNIST dataset using the `datasets` library:

In [None]:
from datasets import load_dataset
dataset = load_dataset("mnist")
print(dataset)

This code snippet downloads and loads the MNIST dataset.  As we saw in Part 1, this dataset is provided as a `DatasetDict` with 'train' and 'test' splits, each containing 'image' and 'label' features.

Next, we define our preprocessing pipeline using `torchvision.transforms`:

In [None]:
import torch
from torchvision import transforms

image_size = 32  # Define the target image size

preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

This `preprocess` pipeline consists of two transformations:

*   `transforms.Resize((image_size, image_size))`: Resizes each image to a fixed size of 32x32 pixels. This ensures consistent input dimensions for our UNet model.
*   `transforms.ToTensor()`: Converts the images to PyTorch tensors and scales the pixel values to the range [0, 1]. This normalization is crucial for training deep learning models effectively.

To apply this preprocessing to the dataset efficiently, we define a `transform` function and set it for our dataset:

In [None]:
# Define the transform function
def transform(examples):
    # Use a separate name so the input batch dict is not shadowed
    images = [preprocess(image) for image in examples["image"]]
    return {"images": images}

# Apply the transform to the dataset
dataset.set_transform(transform)

This `transform` function applies our `preprocess` pipeline to each image in the dataset on-the-fly, meaning preprocessing happens only when the data is accessed, saving memory and keeping our dataset efficient.

With the MNIST dataset loaded and preprocessed, we are now ready to introduce our enhanced UNet architecture from the `diffusers` library! Let's move on to explore the `UNet2DModel`.

## Implementing `UNet2DModel`

Now, let's see how to put the `diffusers` `UNet2DModel` into action for our MNIST digit denoising task.  Here's the code snippet we'll use to instantiate the model:

In [None]:
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=32,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(32, 64, 64),
    down_block_types=(
        "DownBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
    ),
)

print(model)  # Or `from torchinfo import summary; summary(model)` for a more detailed view

Let's break down the parameters we've used when creating our `UNet2DModel` instance:

*   `sample_size=32`:  This specifies the size of the input images. We're still working with 32x32 MNIST images after preprocessing, so we set this to 32.

*   `in_channels=1`: MNIST images are grayscale, meaning they have a single color channel.  Therefore, `in_channels` is set to 1.

*   `out_channels=1`:  We want our UNet to output denoised grayscale images, so `out_channels` is also 1.

*   `layers_per_block=2`: This parameter controls the number of ResNet layers within each UNet block (both downsampling and upsampling blocks). We've chosen 2, meaning each block will have two ResNet layers. Increasing this would make the model deeper and potentially more powerful, but also increase training time.

*   `block_out_channels=(32, 64, 64)`: This is a crucial parameter that defines the number of output channels for each block in the downsampling path (the upsampling path uses the same values in reverse order).
    *   The first value, `32`, corresponds to the output channels of the first downsampling block.
    *   The second value, `64`, is for the next downsampling block, and so on.
    *   We've chosen `(32, 64, 64)` to roughly match our basic UNet example from Part 1. This is a deliberate choice to keep the channel widths comparable to our `BasicUNet` while still benefiting from the architectural improvements of `UNet2DModel`.

*   `down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`: This list specifies the type of downsampling blocks to use in the encoder path.
    *   `"DownBlock2D"`:  A standard ResNet downsampling block.
    *   `"AttnDownBlock2D"`: A ResNet downsampling block with added attention mechanisms.
    *   We're using a mix of standard and attention-based downsampling blocks to leverage the benefits of attention in capturing important image features.

*   `up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`:  Similarly, this list defines the types of upsampling blocks in the decoder path, mirroring the downsampling path and also incorporating attention blocks in the upsampling process.

By carefully configuring these parameters, we've created a `UNet2DModel` tailored for our MNIST denoising task, leveraging the power of `diffusers` and incorporating more advanced architectural components compared to our `BasicUNet`.  The `print(model)` output (or `summary(model)`) will show the detailed architecture and confirm the parameter settings we've defined.  You'll likely notice a significantly larger number of parameters compared to `BasicUNet`, hinting at the increased capacity of this enhanced model.

## Training the Enhanced UNet

With our `UNet2DModel` defined, the next step is to train it!  The training process for this enhanced UNet will be remarkably similar to what we did in Part 1 with our `BasicUNet`.  This is intentional! By keeping the training process consistent, we can isolate the impact of the architectural changes we've made by switching to `UNet2DModel`.

We will still be using:

*   **Direct Image Prediction:**  Our model will still be trained to directly predict the denoised version of a noisy MNIST image in a single forward pass.
*   **Mean Squared Error (MSE) Loss:** We'll continue to use MSE loss (`F.mse_loss`) to measure the difference between the predicted denoised image and the clean target image.
*   **Adam Optimizer:**  We'll stick with the Adam optimizer (`torch.optim.Adam`) to update the model's weights during training.
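For readers who want to see what `F.mse_loss` computes, here is a tiny worked example. The 2x2 tensors are made-up values standing in for a predicted image and its clean target.

```python
import torch
import torch.nn.functional as F

# Made-up 2x2 "images" for illustration
predicted = torch.tensor([[0.0, 0.5], [1.0, 0.5]])
target = torch.tensor([[0.0, 1.0], [1.0, 0.0]])

# Mean of the squared per-pixel differences
loss = F.mse_loss(predicted, target)
print(loss.item())  # (0 + 0.25 + 0 + 0.25) / 4 = 0.125
```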

Here's a snippet of the training loop code. You'll notice it's almost identical to the training loop from Part 1:

In [None]:
import torch.nn.functional as F
from torch.optim import Adam
import matplotlib.pyplot as plt

# --- Setup (Device, Model, Optimizer, Loss History, Hyperparameters) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device) # Our UNet2DModel from diffusers
optimizer = Adam(model.parameters(), lr=1e-3) # Same learning rate as Part 1
losses = []
num_epochs = 5 # Same number of epochs as Part 1
batch_size = 128 # Same batch size as Part 1

train_dataloader = torch.utils.data.DataLoader(
    dataset["train"], batch_size=batch_size, shuffle=True
)

# --- Training Loop ---
for epoch in range(num_epochs):
    for batch in train_dataloader:
        clean_images = batch["images"].to(device)
        noise = torch.randn_like(clean_images)  # Gaussian noise, created on the same device
        noise_amount = torch.rand(clean_images.shape[0], device=device)  # Per-image noise level in [0, 1)
        noisy_images = corrupt(clean_images, noise, noise_amount)  # `corrupt` is the helper defined in Part 1

        predicted_images = model(noisy_images, 0).sample # Still passing timestep 0

        loss = F.mse_loss(predicted_images, clean_images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f"Finished epoch {epoch+1}. Average loss: {avg_loss:.4f}")

# --- Plotting Loss Curve ---
plt.figure(figsize=(8, 4))
plt.plot(losses, label="Training Loss")
plt.title("Training Loss Curve (UNet2DModel - Direct Prediction)") # Updated title
plt.xlabel("Iteration")
plt.ylabel("MSE Loss")
plt.legend()
plt.grid(True)
plt.show()

As you can see, the core training logic remains the same. We load batches, generate noise, corrupt images, feed them to the `UNet2DModel` (still with a timestep of 0), calculate MSE loss, and update the weights using Adam.  We've also kept the hyperparameters (learning rate, batch size, number of epochs) consistent with Part 1 for a direct comparison.
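The training loop relies on the `corrupt` helper from Part 1, which is not repeated in this post. As a rough sketch, and assuming it linearly blends each image with noise according to a per-image amount in [0, 1], it could look like the following. Treat this as a hypothetical stand-in and refer to Part 1 for the actual definition.

```python
import torch


def corrupt(clean_images, noise, noise_amount):
    """Hypothetical stand-in for Part 1's helper: blend each image with noise.

    noise_amount has shape (batch,); 0 keeps the image intact, 1 gives pure noise.
    """
    amount = noise_amount.view(-1, 1, 1, 1)  # broadcast over channel, height, width
    return clean_images * (1 - amount) + noise * amount


clean = torch.ones(3, 1, 32, 32)
noise = torch.randn_like(clean)
amount = torch.tensor([0.0, 0.5, 1.0])

noisy = corrupt(clean, noise, amount)
print(torch.allclose(noisy[0], clean[0]))  # amount 0: image unchanged
print(torch.allclose(noisy[2], noise[2]))  # amount 1: pure noise
```

Under this assumption, sampling a different amount for every image exposes the model to the whole range from nearly clean to fully noisy inputs within each batch.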

After running this training code, we obtain the following loss curve:

[Figure: training loss curve (MSE loss vs. iteration) for `UNet2DModel` with direct prediction]

This loss curve shows the training progress of our `UNet2DModel`.

Now that our enhanced UNet is trained, let's see how it performs in denoising MNIST digits!