# Diffusion Models: A Comprehensive Notebook

This notebook demonstrates a simplified implementation of a diffusion model (a DDPM – Denoising Diffusion Probabilistic Model) using PyTorch. We will:

- Introduce the forward (noise-adding) and reverse (denoising) diffusion processes.
- Build a simple convolutional model that learns to predict the noise added to images.
- Train the model on MNIST (for demonstration purposes) and log the training loss with wandb.
- Sample images from the learned model and visualize them using Matplotlib and Seaborn.

**Note:** Training full diffusion models (e.g., on high-resolution datasets) requires significant compute. This notebook is meant for educational purposes.


In [1]:
# Imports and Initialization
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import wandb

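# Initialize wandb. mode="disabled" makes every wandb call a no-op, so no login is needed;
# remove the argument (after running `wandb login`) to actually record metrics.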
wandb.init(project="diffusion-model-demo", mode="disabled")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


## Data Preparation

We will use the MNIST dataset for demonstration. Images are scaled to [0,1] and converted to tensors.


In [3]:
# Data Preparation: MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),  # Converts image to tensor (C x H x W) in [0,1]
])

batch_size = 128

train_dataset = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform)
train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)


## Model Definition

We define a simple CNN that takes a noisy image and its timestep as input. A small MLP embeds the timestep into a 28x28 map, which is added to the image before the convolutional layers. The model learns to predict the noise that was added during the forward diffusion process.

*Note: In practice, diffusion models use U-Net architectures with skip connections; here, we use a simplified architecture for clarity.*


In [None]:
# A simple time embedding module and CNN for noise prediction
class SimpleDiffusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Time embedding: embed a scalar time into a vector and reshape for broadcasting
        self.time_embed = nn.Sequential(
            nn.Linear(1, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28)
        )
        # A simple convolutional network to predict noise
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, 3, padding=1)
        )
        
    def forward(self, x, t):
        # x: (batch, 1, 28, 28), t: (batch,) time steps
        # Embed t and reshape to image size
        t = t.float().unsqueeze(1)  # shape (batch, 1); raw step index in [0, T-1], not normalized
        t_emb = self.time_embed(t)  # shape (batch, 28*28)
        t_emb = t_emb.view(-1, 1, 28, 28)
        # Combine image and time embedding (here, simply add them)
        x = x + t_emb
        return self.conv(x)

# Initialize model
model = SimpleDiffusionModel().to(device)
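
The model above feeds the raw integer timestep through a small MLP. A common alternative in DDPM-style models is a fixed sinusoidal embedding, which gives the network a bounded, multi-frequency representation of the timestep. The sketch below is not part of this notebook's model; the function name `sinusoidal_time_embedding` and its `dim` and `max_period` parameters are illustrative assumptions.

```python
import math
import torch

def sinusoidal_time_embedding(t, dim=128, max_period=10000.0):
    """Map integer timesteps of shape (batch,) to embeddings of shape (batch, dim).

    Assumes `dim` is even. Hypothetical helper, not used elsewhere in the notebook.
    """
    half = dim // 2
    # Geometrically spaced frequencies, from 1 down towards 1 / max_period
    exponents = torch.arange(half, dtype=torch.float32, device=t.device) / half
    freqs = torch.exp(-math.log(max_period) * exponents)
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0)  # (batch, half)
    # First half holds sines, second half holds cosines
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)

emb = sinusoidal_time_embedding(torch.tensor([0, 1, 50, 99]))
print(emb.shape)
```

To try it in `SimpleDiffusionModel`, the first layer of `time_embed` would need to accept `dim` inputs instead of 1 (for example `nn.Linear(128, 128)`), and `forward` would pass the embedding rather than the raw timestep.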


## Diffusion Process Setup

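We now define the forward diffusion (noise-adding) schedule. Noise is added gradually over \( T \) timesteps, and the variance added at step \( t \) is set by \( \beta_t \), which here increases linearly from `beta_start` to `beta_end`.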
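
In the standard DDPM formulation, a single forward step and its closed form can be written as follows, where \( \alpha_t = 1 - \beta_t \) and \( \bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s \) (stored as `alphas_hat` in the code):

\[
q(x_t \mid x_{t-1}) = \mathcal{N}\left( x_t;\ \sqrt{1 - \beta_t}\, x_{t-1},\ \beta_t I \right), \qquad
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
\]

The closed form is what lets training jump directly to any timestep without simulating the steps before it.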


In [None]:
# Hyperparameters for diffusion
T = 100  # total diffusion steps (kept small for speed; the original DDPM setup uses T = 1000)
beta_start = 1e-4
beta_end = 0.02

# Linear beta schedule
betas = torch.linspace(beta_start, beta_end, T).to(device)  # shape (T,)
alphas = 1 - betas
alphas_hat = torch.cumprod(alphas, dim=0)  # cumulative product for each timestep

# Helper function: forward diffusion (q_sample)
def q_sample(x0, t, noise=None):
    """
    Diffuse the data (add noise) for a given timestep.
    x0: original image, shape (batch, 1, 28, 28)
    t: timestep tensor, shape (batch,)
    noise: optional noise tensor; if None, sampled from standard normal
    """
    if noise is None:
        noise = torch.randn_like(x0)
    # Look up alphas_hat at each sample's timestep, then reshape to (batch, 1, 1, 1)
    # so the coefficients broadcast over the channel and spatial dimensions of x0
    sqrt_alphas_hat = torch.sqrt(alphas_hat[t]).view(-1, 1, 1, 1)
    sqrt_one_minus_alphas_hat = torch.sqrt(1 - alphas_hat[t]).view(-1, 1, 1, 1)
    return sqrt_alphas_hat * x0 + sqrt_one_minus_alphas_hat * noise
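
One thing worth checking with a short schedule is how much of the original image survives at the last step. The sketch below recomputes the final cumulative product \( \bar{\alpha}_T \) in plain Python for the linear schedule used above; `final_alpha_bar` is a hypothetical helper, not part of the notebook's code.

```python
import math

def final_alpha_bar(T, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) over a linear beta schedule with T steps."""
    alpha_bar = 1.0
    for i in range(T):
        # Same spacing as torch.linspace(beta_start, beta_end, T)
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        alpha_bar *= 1.0 - beta
    return alpha_bar

for steps in (100, 1000):
    a = final_alpha_bar(steps)
    print(f"T={steps}: alpha_bar_T={a:.4g}, "
          f"signal coeff={math.sqrt(a):.3f}, noise coeff={math.sqrt(1 - a):.3f}")
```

With `T = 100` the signal coefficient at the last step is still well above zero, so the most-noised training images are not pure Gaussian noise, while sampling starts from pure noise. Increasing `T` (or `beta_end`) narrows that gap.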


## Training Loop

For each training step:
- Sample a batch of images.
- Randomly choose a timestep \( t \) for each image.
- Compute the noised image using our forward diffusion function.
- Have the model predict the noise given the noised image and the timestep.
- Compute the loss (MSE) between the true noise and the model’s prediction.
- Log training loss to wandb.
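
Written as an objective, these steps correspond to the simplified DDPM noise-prediction loss, with \( t \) drawn uniformly over the timesteps and \( \epsilon \sim \mathcal{N}(0, I) \):

\[
L_{\text{simple}} = \mathbb{E}_{x_0,\, t,\, \epsilon} \left[ \left\| \epsilon - \hat{\epsilon}_\theta\left( \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,\ t \right) \right\|^2 \right]
\]

`F.mse_loss` averages over pixels instead of summing, which differs from the squared norm only by a constant factor.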


In [None]:
# Training parameters
num_epochs = 5  # For demonstration, keep epochs small
learning_rate = 1e-3

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch_idx, (x0, _) in enumerate(train_loader):
        x0 = x0.to(device)  # shape (batch, 1, 28, 28)
        batch_size_current = x0.size(0)
        
        # Sample a random timestep for each image in the batch
        t = torch.randint(0, T, (batch_size_current,), device=device)
        
        # Sample noise
        noise = torch.randn_like(x0)
        # Generate noised image at timestep t
        x_noisy = q_sample(x0, t, noise)
        
        # Predict the noise using our model
        noise_pred = model(x_noisy, t)
        
        # Compute loss: Mean Squared Error between actual noise and predicted noise
        loss = F.mse_loss(noise_pred, noise)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_losses.append(loss.item())
        wandb.log({"loss": loss.item()})
        
        # Print progress every 100 batches (MNIST has 469 batches per epoch at batch size 128,
        # so an interval of 500 would only ever print batch 0)
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")
    
    avg_loss = np.mean(epoch_losses)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")


## Sampling: Reverse Diffusion Process

After training, we can sample new images by starting from pure noise and iteratively denoising. The reverse update (for a given timestep \( t \)) is approximated by:
  
\[
x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \hat{\epsilon}_\theta(x_t, t) \right) + \sigma_t z
\]
  
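Here \( z \sim \mathcal{N}(0, I) \) for every step except the last, where no noise is added. Our implementation uses the fixed choice \( \sigma_t = \sqrt{\beta_t} \) rather than a learned or posterior-derived variance.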


In [None]:
@torch.no_grad()
def sample_image(model, image_size=28):
    # Start from pure Gaussian noise, created directly on the target device
    x = torch.randn(1, 1, image_size, image_size, device=device)
    for t in reversed(range(T)):
        t_tensor = torch.full((1,), t, device=device, dtype=torch.long)
        beta = betas[t]
        alpha = alphas[t]
        alpha_hat_t = alphas_hat[t]
        
        # Predict noise at current step
        pred_noise = model(x, t_tensor)
        
        # Compute the reverse update (simplified version)
        x = (1 / torch.sqrt(alpha)) * (x - (beta / torch.sqrt(1 - alpha_hat_t)) * pred_noise)
        
        # For t > 0, add noise
        if t > 0:
            noise = torch.randn_like(x)
            x = x + torch.sqrt(beta) * noise
    # Clamp values to [0,1] for visualization
    return torch.clamp(x, 0., 1.)

# Generate a sample image
sampled_img = sample_image(model)
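
To inspect how a sample evolves during denoising, the sampler can also record intermediate states. The sketch below mirrors the update rule above but is self-contained: `sample_with_snapshots` and its `every` parameter are hypothetical, and the zero-output `dummy_model` is only a stand-in so the example runs without a trained network.

```python
import torch

@torch.no_grad()
def sample_with_snapshots(model, betas, shape=(1, 1, 28, 28), every=25):
    """Run reverse diffusion and keep a copy of x whenever t is a multiple of `every`."""
    alphas = 1.0 - betas
    alphas_hat = torch.cumprod(alphas, dim=0)
    num_steps = betas.shape[0]
    x = torch.randn(shape, device=betas.device)
    snapshots = []
    for t in reversed(range(num_steps)):
        t_tensor = torch.full((shape[0],), t, device=betas.device, dtype=torch.long)
        pred_noise = model(x, t_tensor)
        # Same mean update as in sample_image
        x = (x - betas[t] / torch.sqrt(1 - alphas_hat[t]) * pred_noise) / torch.sqrt(alphas[t])
        if t > 0:
            x = x + torch.sqrt(betas[t]) * torch.randn_like(x)
        if t % every == 0:
            snapshots.append((t, x.clone().cpu()))
    return x, snapshots

def dummy_model(x, t):
    # Stand-in for a trained network: always predicts zero noise
    return torch.zeros_like(x)

demo_betas = torch.linspace(1e-4, 0.02, 100)
final, snaps = sample_with_snapshots(dummy_model, demo_betas)
print([t for t, _ in snaps])
```

Passing the trained `model` and the notebook's `betas` in place of the stand-ins yields snapshots of an actual sample, which can be clamped to [0, 1] and displayed one by one.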


## Visualization

We use Matplotlib and Seaborn to visualize the generated image.


In [None]:
# Visualize a generated sample image
def show_image(img_tensor, title="Generated Image"):
    img = img_tensor.squeeze().cpu().numpy()
    plt.figure(figsize=(3,3))
    sns.heatmap(img, cmap="gray", cbar=False, xticklabels=False, yticklabels=False)
    plt.title(title)
    plt.axis("off")
    plt.show()

show_image(sampled_img, title="Sampled Image from Diffusion Model")
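
The training loop prints the loss but does not plot it. One way to get a loss curve, assuming the per-batch losses were also collected in a single list across epochs (called `all_losses` here, a hypothetical name that the loop above does not define), is to plot a moving average over the raw values. The data below is a synthetic placeholder, not a result from this notebook.

```python
import numpy as np
import matplotlib.pyplot as plt

def moving_average(values, window=50):
    """Smooth a 1-D sequence with a box filter ('valid' mode, so no edge padding)."""
    values = np.asarray(values, dtype=float)
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")

def plot_loss_curve(all_losses, window=50):
    """Plot raw per-batch losses together with their moving average."""
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.plot(all_losses, alpha=0.3, label="per-batch loss")
    smoothed = moving_average(all_losses, window)
    # Align each smoothed point with the last step of its window
    ax.plot(np.arange(window - 1, len(all_losses)), smoothed,
            label=f"moving average ({window})")
    ax.set_xlabel("training step")
    ax.set_ylabel("MSE loss")
    ax.legend()
    return ax

# Synthetic placeholder standing in for real training losses
rng = np.random.default_rng(0)
all_losses = 1.0 / np.sqrt(np.arange(1, 501)) + 0.05 * rng.random(500)
ax = plot_loss_curve(all_losses)
```

In a notebook the figure renders when the cell runs; in a script, call `plt.show()` afterwards.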


## Conclusion and Next Steps

In this notebook, we:

- Introduced the basics of diffusion models.
- Built a simplified diffusion model on MNIST.
- Implemented the forward diffusion (noise addition) and reverse (denoising/sampling) processes.
- Trained the model while tracking loss with wandb.
- Sampled and visualized new images from our model.

**Next Steps:**
- Experiment with more complex architectures (e.g., U-Net) and datasets (CIFAR-10, CelebA).
- Explore conditional diffusion models and latent diffusion methods.
- Fine-tune hyperparameters and training strategies for improved sample quality.
- Review state-of-the-art implementations using the Hugging Face diffusers library.

Happy experimenting!
