## 1. Environment Setup

In [None]:
# If you are on Colab, uncomment this to ensure all dependencies are there.
# In most recent Colab environments, torch/torchvision are already installed.
# !pip install torch torchvision matplotlib tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

import math
from tqdm.auto import tqdm
import matplotlib.pyplot as plt


## 2. Device and Random Seed

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)


def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed to every generator.
    """
    import random
    import numpy as np

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # The CUDA generators are seeded only when a CUDA device is available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)


## 3. Configuration

In [None]:
# --- Optimisation settings ---
batch_size = 128       # images per training batch
num_epochs = 50        # passes over the training set
lr = 0.0002            # optimizer learning rate

# --- Diffusion process settings ---
T = 250                # number of diffusion steps (kept small for speed)
beta_start = 0.0001    # first value of the linear beta schedule
beta_end = 0.02        # last value of the linear beta schedule

# --- Image geometry ---
image_size = 28        # images are image_size x image_size pixels
num_channels = 1       # MNIST is grayscale


## 4. MNIST Dataset and DataLoader

In [None]:
# Preprocessing: ToTensor yields values in [0, 1]; normalising with mean 0.5
# and std 0.5 then maps them to [-1, 1] (common for diffusion models).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # single channel
])

# MNIST training split, downloaded into ./data if it is not already there.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only speeds up host-to-GPU copies; asking for it on a
    # CPU-only machine just makes PyTorch emit a warning.
    pin_memory=torch.cuda.is_available(),
)

# Notebook cell output: (number of training images, number of batches).
len(train_dataset), len(train_loader)


## 5. Diffusion Schedule & Forward (Noising) Process

In [None]:
# Linear noise schedule: T evenly spaced betas from beta_start to beta_end.
betas = torch.linspace(beta_start, beta_end, T).to(device)   # (T,)
alphas = 1.0 - betas                                         # (T,)
alphas_cumprod = alphas.cumprod(dim=0)                       # alpha_bar_t
# alpha_bar_{t-1}: the cumulative product shifted right by one step, with 1.0
# filling the slot in front of the first step.
alphas_cumprod_prev = torch.cat(
    (torch.tensor([1.0], device=device), alphas_cumprod[:-1]),
    dim=0,
)

# Terms reused by the forward (q_sample) and reverse (p_sample) processes.
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
sqrt_recip_alphas = (1.0 / alphas).sqrt()
posterior_variance = (
    betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
)


In [None]:
def get_index_from_list(vals, t, x_shape):
    """Look up one schedule value per batch element, shaped for broadcasting.

    Args:
        vals: (T,) tensor of per-timestep values.
        t: (B,) integer tensor of timestep indices into ``vals``.
        x_shape: Shape of the tensor the result will be broadcast against;
            only its number of dimensions is used.

    Returns:
        Tensor of shape (B, 1, ..., 1) with ``len(x_shape)`` dimensions.
    """
    picked = torch.gather(vals, -1, t)            # (B,)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.view(t.shape[0], *trailing_ones)


In [None]:
def q_sample(x_start, t, noise=None):
    """Sample x_t from the forward process q(x_t | x_0).

    q(x_t | x_0) = N( sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) I )

    Args:
        x_start: (B, C, H, W) clean images x_0.
        t: (B,) timestep index for each image.
        noise: Optional noise with the shape of ``x_start``; fresh standard
            normal noise is drawn when omitted.

    Returns:
        The noised images x_t, same shape as ``x_start``.
    """
    eps = torch.randn_like(x_start) if noise is None else noise

    shape = x_start.shape
    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, shape)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, shape)

    x_t = signal_scale * x_start + noise_scale * eps
    return x_t


### 5.1 Visualizing the Forward (Noising) Process

In [None]:
def show_image_grid(images, title='', nrow=8):
    """Display a batch of images as a single grid figure.

    Args:
        images: Batch of image tensors with values in [-1, 1].
        title: Text shown above the grid.
        nrow: Passed to ``utils.make_grid`` as its ``nrow`` argument.
    """
    # Undo the [-1, 1] normalisation and clip anything outside [0, 1].
    pixels = ((images + 1) * 0.5).clamp(0.0, 1.0)
    canvas = utils.make_grid(pixels, nrow=nrow)

    plt.figure(figsize=(6, 6))
    plt.title(title)
    plt.axis('off')

    canvas_np = canvas.cpu().numpy()
    # A single-channel grid is repeated to three channels for display.
    if canvas_np.shape[0] == 1:
        canvas_np = canvas_np.repeat(3, axis=0)
    # Channels-first (C, H, W) -> channels-last (H, W, C) for imshow.
    plt.imshow(canvas_np.transpose(1, 2, 0))
    plt.show()

# Take a small batch and noise it at several timesteps.
examples = next(iter(train_loader))[0][:8].to(device)  # (8,1,28,28)

# Plain Python ints: torch.full takes a number as its fill value, so there is
# no reason to build a device tensor and iterate over its 0-dim elements.
t_steps = [0, T // 10, T // 5, T // 2, T - 1]

with torch.no_grad():
    imgs = []
    for t_scalar in t_steps:
        # Same timestep for every example in the batch.
        t = torch.full(
            (examples.shape[0],),
            t_scalar,
            device=device,
            dtype=torch.long,
        )
        noisy = q_sample(examples, t)
        imgs.append(noisy)

# One group of 8 images per timestep, in increasing order of t.
imgs_to_show = torch.cat(imgs, dim=0)
show_image_grid(
    imgs_to_show,
    # \u2248 is the "approximately equal" sign; written as an escape so the
    # title survives re-encoding of the source file.
    title='Forward noising at different timesteps (groups from t=0 to t\u2248T)',
)


## 6. U-Net Noise Predictor

In [None]:
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal timestep embedding (like in the original DDPM).

    Args:
        dim: Width of the embedding produced by ``forward``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed a batch of timesteps.

        Args:
            t: (B,) tensor of timesteps; cast to float internally.

        Returns:
            (B, dim) float tensor. The first ``dim // 2`` columns are
            sin(t * f_k) and the next ``dim // 2`` are cos(t * f_k), with
            f_k = exp(-k * log(10000) / (dim // 2 - 1)). When ``dim`` is odd a
            final zero column is appended so the width is exactly ``dim``.
        """
        device = t.device
        half_dim = self.dim // 2
        # With a single frequency (dim of 2 or 3) half_dim - 1 is zero. The
        # only exponent is then k = 0, so the divisor does not affect the
        # result; clamp it to 1 to avoid a ZeroDivisionError.
        emb_factor = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -emb_factor)
        angles = t.float().unsqueeze(1) * freqs.unsqueeze(0)    # (B, half_dim)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover dim - 1 columns; pad the last one.
            emb = F.pad(emb, (0, 1))
        return emb


In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a timestep-dependent bias and a skip path.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        time_emb_dim: Width of the timestep embedding given to ``forward``.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        self.act = nn.SiLU()
        # The skip path needs a 1x1 projection only when the channel count
        # changes; otherwise the input is added back unchanged.
        if in_ch == out_ch:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x, t_emb):
        """Apply the block.

        Args:
            x: (B, C, H, W) feature map.
            t_emb: (B, time_emb_dim) timestep embedding.

        Returns:
            (B, out_ch, H, W) feature map.
        """
        shortcut = self.res_conv(x)
        # Project the embedding to one bias per output channel: (B, out_ch, 1, 1).
        time_bias = self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)
        hidden = self.act(self.conv1(x) + time_bias)
        hidden = self.conv2(hidden)
        return self.act(hidden + shortcut)


In [None]:
class UNet(nn.Module):
    """Small U-Net that maps a noisy image and its timestep to a noise estimate.

    The encoder applies one ResidualBlock per entry of ``channel_mults``, each
    followed by a 2x max-pool. The decoder mirrors it with stride-2 transposed
    convolutions, concatenating the matching encoder output (skip connection)
    before each ResidualBlock.

    Args:
        in_channels: Channels of the input image; the output has the same count.
        base_channels: Channel width that ``channel_mults`` scales.
        channel_mults: Per-stage multipliers of ``base_channels``.
        time_emb_dim: Width of the timestep embedding.
    """

    def __init__(self,
                 in_channels=1,
                 base_channels=32,
                 channel_mults=(1, 2, 4),
                 time_emb_dim=128):
        super().__init__()

        # Time embedding MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim * 4),
            nn.SiLU(),
            nn.Linear(time_emb_dim * 4, time_emb_dim),
        )

        # Downsampling path
        dims = [base_channels * m for m in channel_mults]
        self.conv_in = nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1)

        self.downs = nn.ModuleList()
        self.downs_pools = nn.ModuleList()
        in_ch = base_channels
        for dim in dims:
            self.downs.append(ResidualBlock(in_ch, dim, time_emb_dim))
            self.downs_pools.append(nn.MaxPool2d(2))
            in_ch = dim

        # Middle blocks
        self.mid_block1 = ResidualBlock(in_ch, in_ch, time_emb_dim)
        self.mid_block2 = ResidualBlock(in_ch, in_ch, time_emb_dim)

        # Upsampling path
        self.ups_transpose = nn.ModuleList()
        self.ups = nn.ModuleList()
        for dim in reversed(dims):
            self.ups_transpose.append(
                nn.ConvTranspose2d(in_ch, dim, kernel_size=2, stride=2)
            )
            self.ups.append(
                ResidualBlock(dim * 2, dim, time_emb_dim)  # concat skip
            )
            in_ch = dim

        # The last decoder stage emits dims[0] channels, which equals
        # base_channels only when channel_mults[0] == 1. Use the tracked width
        # so any multiplier tuple yields a consistent network.
        self.conv_out = nn.Conv2d(in_ch, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timestep ``t``.

        Args:
            x: (B, in_channels, H, W) noisy images.
            t: (B,) timesteps.

        Returns:
            (B, in_channels, H, W) tensor.
        """
        t_emb = self.time_mlp(t)  # (B, time_emb_dim)

        # Down
        x = self.conv_in(x)
        skips = []
        for down, pool in zip(self.downs, self.downs_pools):
            x = down(x, t_emb)
            skips.append(x)
            x = pool(x)

        # Middle
        x = self.mid_block1(x, t_emb)
        x = self.mid_block2(x, t_emb)

        # Up
        for up_transpose, up, skip in zip(
            self.ups_transpose, self.ups, reversed(skips)
        ):
            x = up_transpose(x)
            # Max-pooling an odd size rounds down, so the upsampled map can be
            # smaller than its skip. Pad height and width independently
            # (bottom/right) so non-square inputs line up as well.
            diff_h = skip.shape[-2] - x.shape[-2]
            diff_w = skip.shape[-1] - x.shape[-1]
            if diff_h != 0 or diff_w != 0:
                x = F.pad(x, (0, diff_w, 0, diff_h))
            x = torch.cat([x, skip], dim=1)
            x = up(x, t_emb)

        x = self.conv_out(x)
        return x

# Build the noise-prediction network and move its parameters to `device`
# (nn.Module.to moves them in place).
model = UNet(in_channels=num_channels)
model.to(device)
# Total number of parameter elements, reported in millions.
print(
    'Model parameters (M):',
    sum(weights.numel() for weights in model.parameters()) / 1e6,
)


## 7. Training Objective (Noise Prediction Loss)

In [None]:
# AdamW over all model parameters, using the configured learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)


def p_losses(model, x_start, t, noise=None):
    """Simplified DDPM loss: MSE between the true and the predicted noise.

    Args:
        model: Callable ``model(x_noisy, t)`` returning a noise estimate.
        x_start: Clean images x_0.
        t: (B,) timestep for each image.
        noise: Optional noise to add; drawn from a standard normal if omitted.

    Returns:
        The mean-squared-error loss tensor.
    """
    target = torch.randn_like(x_start) if noise is None else noise
    noised = q_sample(x_start=x_start, t=t, noise=target)
    return F.mse_loss(model(noised, t), target)


## 8. Sampling (Reverse Diffusion)

In [None]:
@torch.no_grad()
def p_sample(model, x, t):
    """Draw x_{t-1} from x_t using the model's noise prediction.

    Args:
        model: Callable ``model(x, t)`` returning the predicted noise.
        x: Current images x_t.
        t: (B,) timestep for each image.

    Returns:
        Images with the same shape as ``x``: the model mean alone when every
        entry of ``t`` is 0, otherwise the mean plus scaled Gaussian noise.
    """
    shape = x.shape
    beta_t = get_index_from_list(betas, t, shape)
    sqrt_one_minus_abar_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, shape
    )
    sqrt_recip_alpha_t = get_index_from_list(sqrt_recip_alphas, t, shape)

    # Equation 11 from DDPM.
    eps_hat = model(x, t)
    mean = sqrt_recip_alpha_t * (
        x - beta_t / sqrt_one_minus_abar_t * eps_hat
    )

    # Noise is added unless the whole batch is at the final step t == 0.
    if not (t == 0).all():
        sigma_t = torch.sqrt(get_index_from_list(posterior_variance, t, shape))
        return mean + sigma_t * torch.randn_like(x)
    return mean


In [None]:
@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the reverse chain from x_T ~ N(0, I) down to x_0.

    Args:
        model: Noise-prediction model passed to ``p_sample``.
        shape: Shape of the batch to generate; ``shape[0]`` is the batch size.

    Returns:
        Tensor of the requested shape after all T reverse steps.
    """
    batch = shape[0]
    img = torch.randn(shape, device=device)
    # Timesteps are visited from T - 1 down to 0.
    progress = tqdm(range(T - 1, -1, -1), total=T, desc='Sampling', leave=False)
    for step in progress:
        t = torch.full((batch,), step, device=device, dtype=torch.long)
        img = p_sample(model, img, t)
    return img

@torch.no_grad()
def sample(model, n_samples=16):
    """Generate ``n_samples`` images with the reverse diffusion process.

    Switches ``model`` to eval mode (it is not switched back here).

    Args:
        model: Noise-prediction model.
        n_samples: Number of images to generate.

    Returns:
        Tensor of shape (n_samples, num_channels, image_size, image_size).
    """
    model.eval()
    out_shape = (n_samples, num_channels, image_size, image_size)
    return p_sample_loop(model, shape=out_shape)


## 9. Training Loop

In [None]:
def train(model, train_loader, num_epochs):
    """Train the noise predictor, then sample and checkpoint after every epoch.

    Uses the module-level ``optimizer``, ``device`` and ``T``. After each epoch
    it prints the average loss, displays 16 generated samples and writes the
    model weights to ``ddpm_mnist_epoch{N}.pth`` in the working directory.

    Args:
        model: Noise-prediction model to optimise.
        train_loader: Iterable yielding ``(images, labels)`` batches; the
            labels are ignored.
        num_epochs: Number of passes over ``train_loader``.
    """
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for x, _ in pbar:
            x = x.to(device)  # (B,1,28,28)
            # Sample a timestep for each example
            t = torch.randint(0, T, (x.shape[0],), device=device).long()

            loss = p_losses(model, x, t)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for both the running total
            # and the progress bar.
            loss_value = loss.item()
            epoch_loss += loss_value
            num_batches += 1
            pbar.set_postfix(loss=loss_value)

        # Average over the batches actually processed, so a loader without
        # __len__ works and an empty loader reports nan instead of dividing
        # by zero.
        avg_loss = epoch_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch+1}: avg loss = {avg_loss:.4f}')

        # Generate sample images at the end of each epoch
        with torch.no_grad():
            samples = sample(model, n_samples=16)
            show_image_grid(samples, title=f'Samples after epoch {epoch+1}')

        # Save a checkpoint for this epoch
        torch.save(model.state_dict(), f'ddpm_mnist_epoch{epoch+1}.pth')


# Training starts as soon as this cell runs. Comment the call out to skip
# training, e.g. when loading an existing checkpoint instead.
train(model, train_loader, num_epochs=num_epochs)


## 10. Load a Trained Model and Generate Samples (Optional)

In [None]:
# Example usage if you already have a checkpoint:
# model.load_state_dict(torch.load('ddpm_mnist_epoch10.pth', map_location=device))
# model.eval()

# with torch.no_grad():
#     samples = sample(model, n_samples=16)
#     show_image_grid(samples, title='Generated samples from trained DDPM (MNIST)')
