In [None]:
# Show the visible NVIDIA GPUs plus driver/CUDA versions. This needs `nvidia-smi`
# on PATH; on CPU-only machines the shell prints an error instead of raising.
!nvidia-smi

# DiT Training Notebook

This notebook trains a class-conditional **Diffusion Transformer (DiT)** on the MNIST dataset of handwritten digits.

## Training Objective

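The model learns to predict the noise $\epsilon$ that was added to a clean image during the forward diffusion process at a randomly sampled timestep $t$. This is the standard **DDPM** (Denoising Diffusion Probabilistic Models) noise-prediction objective. Note that DDPM trains with a mean-squared-error loss, whereas this notebook uses an L1 (mean absolute error) loss (see Section 4).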

## Training Loop

1. Sample a batch of clean images from the dataset
2. Sample random timesteps for each image
3. Add noise to images according to the forward diffusion process
4. Use the DiT model, conditioned on the timestep and class label, to predict the noise
5. Compute the L1 loss between the predicted and actual noise
6. Backpropagate and update model weights

In [None]:
%load_ext autoreload
%autoreload 2
import os
import torch
from torch import nn
from torch.utils.data import DataLoader

from config import T
from dataset import MNIST
from diffusion import forward_add_noise
from dit import DiT

## 1. Configuration

Set up device and training hyperparameters.

In [None]:
# Device selection: use GPU if available, otherwise CPU
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
if DEVICE == "cuda":
    # torch.cuda.get_device_name() raises when given a non-CUDA device, so only
    # query GPU details when we actually selected CUDA (keeps the CPU fallback usable)
    print(f"Device name: {torch.cuda.get_device_name(DEVICE)}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")

# Training hyperparameters
EPOCH = 500  # Number of training epochs
BATCH_SIZE = 1000  # Number of images per batch
LEARNING_RATE = 1e-3  # AdamW optimizer learning rate

# Seed torch's global RNG so timestep/noise sampling and DataLoader shuffling
# are reproducible across Restart & Run All
SEED = 0
torch.manual_seed(SEED)

print(f"Epochs: {EPOCH}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE}")

## 2. Dataset

Load the MNIST dataset and create a DataLoader for batching.

In [None]:
# Load the MNIST training split. Items are (image_tensor, label) pairs; images are
# expected to be single-channel 28x28 with pixels in [0, 1] (the model config and
# the [-1, 1] rescaling in the training loop rely on this; see dataset.MNIST).
dataset = MNIST()
n_images = len(dataset)
print(f"Dataset size: {n_images} images")

In [None]:
# Create DataLoader for batching and shuffling
# - shuffle=True: randomize sample order every epoch
# - num_workers: up to 10 parallel loader processes, capped at the machine's CPU
#   count (DataLoader warns when more workers are requested than the system suggests)
# - persistent_workers=True: keep workers alive between epochs instead of
#   re-spawning them (requires num_workers > 0, which the cap below guarantees)
# - pin_memory: page-locked host batches make host->GPU copies faster; only
#   meaningful when training on CUDA
NUM_WORKERS = min(10, os.cpu_count() or 1)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    persistent_workers=True,
    pin_memory=(DEVICE == "cuda"),
)

print(f"Number of batches per epoch: {len(dataloader)}")

## 3. Model

Initialize the DiT model with MNIST-specific parameters.

In [None]:
# DiT hyperparameters for MNIST:
#   img_size=28  -> input images are 28x28 pixels
#   patch_size=4 -> 4x4 patches, i.e. a 7x7 grid = 49 patch tokens
#   channel=1    -> single-channel (grayscale) input
#   emb_size=128 -> 128-dimensional token embeddings
#   label_num=10 -> 10 conditioning classes (digits 0-9)
#   dit_num=6    -> 6 stacked DiT blocks
#   head=8       -> 8 attention heads per block
dit_config = dict(
    img_size=28,
    patch_size=4,
    channel=1,
    emb_size=128,
    label_num=10,
    dit_num=6,
    head=8,
)
model = DiT(**dit_config)
model = model.to(DEVICE)

# Total element count over every parameter tensor of the model
num_params = sum(map(torch.numel, model.parameters()))
print(f"Model parameters: {num_params:,}")

In [None]:
# Try to resume from a previously saved checkpoint.
# - map_location=DEVICE: remaps tensors onto the current device, so a checkpoint
#   written during GPU training can still be loaded on a CPU-only machine.
# - weights_only=True: the training loop saves a plain state_dict, so restrict
#   unpickling to tensors/primitive containers (no arbitrary code execution).
try:
    state_dict = torch.load("model.pth", map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    print("Loaded existing model checkpoint")
except FileNotFoundError:
    print("No checkpoint found, starting from scratch")
except Exception as e:
    # Best-effort: keep going with fresh weights, but make it obvious that the
    # training loop's periodic save will replace the unusable model.pth.
    print(f"Could not load checkpoint: {e}")
    print("WARNING: starting from scratch; model.pth will be overwritten at the next checkpoint save")

## 4. Optimizer and Loss Function

Set up the optimizer and loss function for training.

In [None]:
# L1 loss (mean absolute error) between predicted and true noise.
# nn.MSELoss() (mean squared error) is the other common choice.
loss_fn = nn.L1Loss()

# AdamW = Adam with decoupled weight-decay regularization, a common default for
# transformer training. Its learning rate is LEARNING_RATE from the config cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

print("Optimizer: AdamW")
print("Loss function: L1Loss (Mean Absolute Error)")

## 5. Training Loop

The main training loop that iterates over epochs and batches.

In [None]:
# Global optimizer-step counter read by the training loop for logging and
# checkpointing; re-running this cell resets it to 0.
iter_count = 0

# Switch the model to training mode (matters for any layers, e.g. dropout,
# that behave differently between training and evaluation).
model.train()

batches_per_epoch = len(dataloader)
print(f"Starting training for {EPOCH} epochs...")
print(f"Batch size: {BATCH_SIZE}, Batches per epoch: {batches_per_epoch}")

In [None]:
# Optimizer steps between progress logs / checkpoint saves.
CHECKPOINT_EVERY = 1000


def _save_checkpoint():
    """Atomically save the current model weights to ``model.pth``.

    The state_dict is first written to a hidden temp file and then moved into
    place with ``os.replace``, so an interrupted save never leaves a truncated
    ``model.pth`` behind.
    """
    torch.save(model.state_dict(), ".model.pth")
    os.replace(".model.pth", "model.pth")


# Re-assert training mode so this cell is correct even if another cell switched
# the model to eval mode after the setup cell ran.
model.train()

try:
    for epoch in range(EPOCH):
        for imgs, labels in dataloader:
            # ==== Step 1: Prepare the data ====

            # Rescale pixel values from [0, 1] to [-1, 1] so the data has the
            # same center and scale as the Gaussian noise (mean 0, std 1)
            x = imgs * 2 - 1  # Shape: (batch, 1, 28, 28)

            # Sample an independent random timestep in [0, T) for each image
            t = torch.randint(0, T, (imgs.size(0),))  # Shape: (batch,)

            # Class labels for conditional generation
            y = labels  # Shape: (batch,)

            # ==== Step 2: Forward diffusion (add noise) ====

            # Noise the images to timestep t; `noise` is the target the model must recover
            x, noise = forward_add_noise(x, t)

            # ==== Step 3: Model prediction ====

            # The model predicts the added noise, conditioned on timestep and label
            pred_noise = model(x.to(DEVICE), t.to(DEVICE), y.to(DEVICE))

            # ==== Step 4: Compute loss ====

            loss = loss_fn(pred_noise, noise.to(DEVICE))

            # ==== Step 5: Backpropagation and optimization ====

            optimizer.zero_grad()  # Clear gradients from the previous step
            loss.backward()  # Compute gradients
            optimizer.step()  # Update weights

            # ==== Step 6: Logging and checkpointing ====

            if iter_count % CHECKPOINT_EVERY == 0:
                print(f"Epoch: {epoch}, Iter: {iter_count}, Loss: {loss.item():.6f}")
                _save_checkpoint()

            iter_count += 1
except KeyboardInterrupt:
    # Interrupting the cell is the usual way to stop training early in a
    # notebook: keep the progress made since the last periodic save, then
    # re-raise so the interruption is still reported.
    _save_checkpoint()
    print(f"Interrupted at iter {iter_count}; checkpoint saved to model.pth")
    raise

# The periodic save only fires every CHECKPOINT_EVERY iterations, so save once
# more to keep the final (up to CHECKPOINT_EVERY - 1) updates.
_save_checkpoint()
print("Training complete!")