# DiT Inference Notebook

This notebook performs **inference (image generation)** with a trained DiT (Diffusion Transformer) model, generating images conditioned on a class label.

## Reverse Diffusion Process

The reverse process works by:
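1. Starting with pure random noise ($x_T$)
2. Iteratively denoising from timestep $T-1$ down to $0$
3. At each step, the model predicts the noise component $\epsilon_\theta(x_t, t)$ (conditioned on the class label)
4. The predicted noise is used to compute the mean of $x_{t-1}$; for $t > 0$, fresh Gaussian noise is added around that mean to sample $x_{t-1}$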

## DDPM Reverse Formula

$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right)$$

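$$x_{t-1} = \mu_\theta(x_t, t) + \sigma_t \cdot z, \quad z \sim \mathcal{N}(0, \mathbf{I})$$

Here $\bar{\alpha}_t = \prod_{s \le t} \alpha_s$ is the cumulative product of the $\alpha$ schedule. $\sigma_t^2$ is the per-step variance from the noise schedule (`variance` in `diffusion.py`). The noise term is dropped at the final step ($t = 0$).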

In [None]:
%load_ext autoreload
%autoreload 2
import torch
import matplotlib.pyplot as plt

from config import T
from diffusion import alphas, alphas_cumprod, variance
from dit import DiT

## 1. Configuration

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")
print(f"Total diffusion timesteps: {T}")

## 2. Define the Reverse Diffusion Function

`backward_denoise` iteratively denoises a batch of images, from pure noise to clean generated images. It returns every intermediate step so the process can be visualized.

In [None]:
def backward_denoise(model, x, y):
    """
    Run the DDPM reverse process, turning the noise batch ``x`` into images.

    For each timestep ``t`` from ``T - 1`` down to ``0`` the model predicts the
    noise ``eps = model(x_t, t, y)``. The code then computes the mean

        mu = 1 / sqrt(alpha_t) * (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps)

    and, for ``t > 0``, adds Gaussian noise with variance ``variance[t]`` to
    obtain ``x_{t-1}``. At ``t == 0`` the mean itself is the result.

    Args:
        model: Trained DiT model, called as ``model(x, t, y)``. It is switched
            to eval mode as a side effect.
        x: Initial noise tensor, shape (batch, channel, height, width).
        y: Class labels for conditional generation, shape (batch,).

    Returns:
        steps: List of ``T + 1`` tensors. ``steps[0]`` is a copy of the input
        noise ``x_T`` (on the input's original device). ``steps[k]`` is
        ``x_{T-k}`` on ``DEVICE``, so ``steps[-1]`` is the generated batch.
    """
    # Keep the starting noise for visualization.
    steps = [x.clone()]

    # The schedule tensors are indexed with a device-resident timestep tensor,
    # so they must live on the same device.
    alphas_device = alphas.to(DEVICE)
    alphas_cumprod_device = alphas_cumprod.to(DEVICE)
    variance_device = variance.to(DEVICE)

    x = x.to(DEVICE)
    y = y.to(DEVICE)

    model.eval()

    batch = x.size(0)
    # Per-sample schedule coefficients are reshaped to broadcast over (C, H, W).
    shape = (batch, 1, 1, 1)

    with torch.no_grad():
        # Iterate backwards from T-1 to 0.
        for time in range(T - 1, -1, -1):
            # Allocate the timestep tensor directly on the device instead of
            # creating it on the CPU and copying it over on every step.
            t = torch.full((batch,), time, dtype=torch.long, device=DEVICE)

            # Step 1: predict the noise component of x_t.
            noise = model(x, t, y)

            # Step 2: mean of x_{t-1} (DDPM reverse-process formula).
            # The predicted noise is not simply subtracted from x_t; it is used
            # to compute the mean of the Gaussian that x_{t-1} is drawn from.
            alpha_t = alphas_device[t].view(*shape)
            alpha_cumprod_t = alphas_cumprod_device[t].view(*shape)
            mean = (1 / torch.sqrt(alpha_t)) * (
                x - (1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t) * noise
            )

            # Step 3: sample x_{t-1}. The final step returns the mean with no
            # added noise.
            if time > 0:
                variance_t = variance_device[t].view(*shape)
                x = mean + torch.randn_like(x) * torch.sqrt(variance_t)
            else:
                x = mean

            # Clamp every sample to the data range [-1, 1].
            # NOTE(review): reference DDPM samplers usually clip only the
            # predicted x_0. Clipping the noisy intermediates is kept here to
            # preserve this notebook's existing behaviour.
            # clamp() returns a new tensor and x is never modified in place, so
            # the stored steps need no extra copy (and no_grad makes detach()
            # unnecessary).
            x = torch.clamp(x, -1.0, 1.0)
            steps.append(x)

    return steps

## 3. Load Trained Model

In [None]:
# Hyperparameters must mirror those used at training time. Otherwise the
# saved state_dict will not match this module's parameter shapes when it is
# loaded in the next cell.
dit_config = dict(
    img_size=28,
    patch_size=4,
    channel=1,
    emb_size=64,
    label_num=10,
    dit_num=3,
    head=4,
)
model = DiT(**dit_config).to(DEVICE)

n_params = sum(param.numel() for param in model.parameters())
print(f"Model architecture: DiT with {n_params:,} parameters")

In [None]:
# Load the trained weights.
# map_location remaps tensors that were saved from a GPU onto DEVICE, so a
# checkpoint trained with CUDA also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load("model.pth", map_location=DEVICE)
model.load_state_dict(state_dict)
print("Model weights loaded successfully!")

## 4. Generate Images

Start with pure noise and denoise to generate images of each digit (0-9).

In [None]:
# Number of images to generate. Labels cycle through the digit classes, so
# batch_size = 10 yields exactly one image of each digit 0-9.
batch_size = 10
num_classes = 10  # should match label_num used to build the model

# Start with pure random noise x_T, one (1, 28, 28) image per sample.
x = torch.randn(size=(batch_size, 1, 28, 28))

# Derive the labels from batch_size so that x and y always agree on the batch
# dimension (0, 1, ..., 9, then 0, 1, ... again for larger batches).
y = torch.arange(batch_size, dtype=torch.long) % num_classes

print(f"Generating {batch_size} images...")
print(f"Class labels: {y.tolist()}")
print(f"Initial noise shape: {x.shape}")

In [None]:
# Run the reverse diffusion process. steps[0] is the initial noise and
# steps[-1] is the final generated batch.
steps = backward_denoise(model, x, y)

# Plain string: this message has no placeholders, so it needs no f-prefix.
print("Denoising complete!")
print(f"Generated {len(steps)} intermediate steps (including initial noise)")

## 5. Visualize Final Results

Show the generated images for each digit.

In [None]:
# Display the final generated image of every sample in a single row.
plt.figure(figsize=(15, 2))

# Move the whole final batch to the CPU once, rather than per image.
final_batch = steps[-1].to("cpu")

for i in range(batch_size):
    # Convert [-1, 1] to [0, 1], then (C, H, W) -> (H, W, C) for matplotlib.
    final_img = ((final_batch[i] + 1) / 2).permute(1, 2, 0)

    plt.subplot(1, batch_size, i + 1)
    plt.imshow(final_img, cmap="gray")
    # Caption with the label the sample was conditioned on, not its position
    # in the batch, so the title stays correct if the labels change.
    plt.title(f"Digit {y[i].item()}")
    plt.axis("off")

plt.suptitle("Generated Digits (0-9)", fontsize=14)
plt.tight_layout()
plt.show()

## 6. Visualize Denoising Process

Show how each image evolves from noise to the final digit.

In [None]:
# Number of intermediate steps to show per row
num_imgs = 20

# Create a grid: rows = samples, columns = progress from noise to image.
plt.figure(figsize=(15, 15))

# steps holds T + 1 entries (x_T ... x_0). Choose num_imgs evenly spaced
# indices that end exactly at T, so the last column is always the final image,
# even when T is not a multiple of num_imgs.
# For T = 1000 and num_imgs = 20 this gives 50, 100, ..., 1000.
step_indices = [round(T * (i + 1) / num_imgs) for i in range(num_imgs)]

for b in range(batch_size):
    for i, idx in enumerate(step_indices):
        # Convert [-1, 1] to [0, 1], then (C, H, W) -> (H, W, C) for display.
        img = ((steps[idx][b].to("cpu") + 1) / 2).permute(1, 2, 0)

        plt.subplot(batch_size, num_imgs, b * num_imgs + i + 1)
        plt.imshow(img, cmap="gray")
        plt.axis("off")

# \u2192 is the right-arrow character. The escape keeps it intact even if this
# file is re-saved with the wrong encoding (which had garbled it before).
plt.suptitle("Denoising Process: Noise \u2192 Generated Digit\n(Each row is a different digit)", fontsize=14)
plt.tight_layout()
plt.show()

## 7. Single Digit Deep Dive

Let's look more closely at the denoising process for a single digit.

In [None]:
# Choose a digit (class label) to examine
digit_to_show = 9

# Find the batch position of the sample conditioned on that label.
# list.index raises ValueError if no sample in the batch has this label.
sample_idx = y.tolist().index(digit_to_show)

# Trajectory positions to show, in thousandths of T: dense at the start, then
# every 10%. Scaling by T keeps the selection meaningful for any number of
# diffusion steps; for T = 1000 this is 0, 50, 100, 200, ..., 900, 1000.
per_mille = [0, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
steps_to_show = [T * p // 1000 for p in per_mille]

plt.figure(figsize=(15, 2))

for i, step_idx in enumerate(steps_to_show):
    # Clamp step_idx to the valid range of the trajectory
    step_idx = min(step_idx, len(steps) - 1)

    img = ((steps[step_idx][sample_idx].to("cpu") + 1) / 2).permute(1, 2, 0)

    plt.subplot(1, len(steps_to_show), i + 1)
    plt.imshow(img, cmap="gray")
    # steps[k] holds x_{T-k}, so the label counts the timestep down to 0.
    plt.title(f"t={T - step_idx}")
    plt.axis("off")

plt.suptitle(f"Denoising Process for Digit {digit_to_show}", fontsize=14)
plt.tight_layout()
plt.show()

## 8. Save Results

Save the denoising-process grid (noise to final image for every sample) to `inference.png`.

In [None]:
# Save the denoising-process grid (rows = samples, columns = progress from
# noise to the final image) to disk.
plt.figure(figsize=(15, 15))

# Evenly spaced trajectory indices that end exactly at T (the final image),
# even when T is not a multiple of num_imgs.
step_indices = [round(T * (i + 1) / num_imgs) for i in range(num_imgs)]

for b in range(batch_size):
    for i, idx in enumerate(step_indices):
        # Convert [-1, 1] to [0, 1], then (C, H, W) -> (H, W, C) for display.
        img = ((steps[idx][b].to("cpu") + 1) / 2).permute(1, 2, 0)

        plt.subplot(batch_size, num_imgs, b * num_imgs + i + 1)
        plt.imshow(img, cmap="gray")
        plt.axis("off")

plt.savefig("inference.png", dpi=150, bbox_inches="tight")
# Close the figure once it is written. The same grid was already displayed in
# section 6, and an open figure would otherwise stay in memory (and, with
# Jupyter's inline backend, be rendered a second time).
plt.close()
print("Saved visualization to inference.png")