# DDPM 完整实现：训练与采样

本 Notebook 实现完整的 DDPM，包括：
1. 带时间嵌入的简化 U-Net
2. 训练算法（预测噪声）
3. 采样算法（逆向去噪）

In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
IMG_SIZE = 28  # image height and width in pixels
CHANNELS = 1  # number of image channels
TIMESTEPS = 1000  # number of diffusion steps T
BATCH_SIZE = 128  # samples per training batch
EPOCHS = 30  # passes over the training set
LR = 1e-3  # learning rate passed to Adam

## 1. 噪声调度器 (DDPMScheduler)

In [None]:
class DDPMScheduler:
    """Linear-beta DDPM noise schedule with precomputed coefficients.

    Every coefficient attribute is a 1-D tensor of length ``timesteps``,
    indexed by the 0-based step ``t``. The tensors are created without an
    explicit device, so they live on PyTorch's default device.
    """

    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        """Build the schedule.

        Args:
            timesteps: Number of diffusion steps T.
            beta_start: Beta at the first step.
            beta_end: Beta at the last step.
        """
        self.timesteps = timesteps
        self.betas = torch.linspace(beta_start, beta_end, timesteps)
        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

        # Coefficients of the closed-form forward process q(x_t | x_0).
        self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1.0 - self.alpha_bars)

        # Extra per-step coefficients kept available for sampling code.
        self.sqrt_alphas = torch.sqrt(self.alphas)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

    def add_noise(self, x0, t, noise=None):
        """Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

        Args:
            x0: Clean batch; the first dimension is the batch dimension.
            t: Integer step indices, one per sample (tensor on any device,
                or a plain int applied to the whole batch).
            noise: Noise with the same shape as ``x0``. Drawn from a
                standard normal when omitted.

        Returns:
            Tuple ``(x_t, noise)`` where ``noise`` is the noise that was used.
        """
        if noise is None:
            noise = torch.randn_like(x0)

        # The lookup tables stay on the device they were created on, while `t`
        # usually lives on x0's device. Index with `t` moved to the table's
        # device (mixed-device indexing raises), then move the gathered values.
        table_device = self.sqrt_alpha_bars.device
        idx = t.to(table_device) if torch.is_tensor(t) else t

        # One coefficient per sample, broadcast over all non-batch dimensions.
        bcast_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_bar = self.sqrt_alpha_bars[idx].view(bcast_shape).to(x0.device)
        sqrt_one_minus_alpha_bar = (
            self.sqrt_one_minus_alpha_bars[idx].view(bcast_shape).to(x0.device)
        )

        x_t = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise
        return x_t, noise


# Shared schedule instance used by the training and sampling code below.
scheduler = DDPMScheduler(timesteps=TIMESTEPS)

## 2. 时间嵌入 (Sinusoidal Position Embedding)

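将标量时间步 $t$ 映射为高维向量，使用正弦位置编码：
$$PE(t, 2i) = \sin(t / 10000^{2i/d})$$
$$PE(t, 2i+1) = \cos(t / 10000^{2i/d})$$

注意：下面的实现没有把 $\sin$/$\cos$ 交错排列，而是把全部 $\sin$ 分量拼接在输出的前半部分、全部 $\cos$ 分量拼接在后半部分，并且第 $i$ 个频率取 $10000^{-i/(d/2-1)}$（$i = 0, \dots, d/2-1$）。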

In [None]:
class SinusoidalPositionEmbedding(nn.Module):
    """Encode scalar timesteps as fixed sinusoidal feature vectors.

    For a 1-D input of length ``batch`` the output has shape
    ``(batch, 2 * (dim // 2))``. The first half holds the sines and the
    second half the cosines of the timestep multiplied by ``dim // 2``
    geometrically spaced frequencies running from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        n_freqs = self.dim // 2

        # Log-spacing between consecutive frequencies.
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=t.device) * -log_step)

        # Outer product: one row per timestep, one column per frequency.
        angles = t.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

## 3. 简化 U-Net

结构：Encoder (下采样) -> Bottleneck -> Decoder (上采样) + Skip Connections

In [None]:
class SimpleUNet(nn.Module):
    """Small time-conditioned U-Net used as the noise predictor.

    Layout: three encoder convolutions (the last two use stride 2, so a
    28x28 input shrinks to 14x14 and then 7x7), a two-convolution
    bottleneck, and two transposed-convolution upsampling stages that each
    concatenate the matching encoder feature map as a skip connection.
    The timestep embedding is added as a per-channel bias after each
    encoder convolution. The output has the same channel count as the input.
    """

    def __init__(self, in_channels=1, time_dim=256):
        super().__init__()

        # Timestep -> sinusoidal features -> learned embedding.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbedding(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
        )

        # Encoder: full resolution, then two stride-2 downsampling convs.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)

        # Project the time embedding to each encoder stage's channel width.
        self.time_proj1 = nn.Linear(time_dim, 64)
        self.time_proj2 = nn.Linear(time_dim, 128)
        self.time_proj3 = nn.Linear(time_dim, 256)

        # Bottleneck at the lowest resolution.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # Decoder stage 1: upsample x2, then fuse with the 128-channel skip.
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128 + 128, 128, kernel_size=3, padding=1)

        # Decoder stage 2: upsample x2, then fuse with the 64-channel skip.
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.conv5 = nn.Conv2d(64 + 64, 64, kernel_size=3, padding=1)

        # 1x1 convolution back to the input channel count.
        self.out = nn.Conv2d(64, in_channels, kernel_size=1)

    def forward(self, x, t):
        """Predict the noise in ``x`` at timesteps ``t``.

        Args:
            x: Image batch; spatial size should be divisible by 4 so the
                skip connections line up (presumably ``(batch, C, H, W)``).
            t: One timestep per sample; cast to float before embedding.

        Returns:
            Tensor with the same channel count as the input.
        """
        temb = self.time_mlp(t.float())

        def as_bias(proj):
            # (batch, channels) -> (batch, channels, 1, 1) for broadcasting.
            return proj(temb)[:, :, None, None]

        # Encoder; keep the first two activations for the skip connections.
        skip_full = F.relu(self.conv1(x) + as_bias(self.time_proj1))
        skip_half = F.relu(self.conv2(skip_full) + as_bias(self.time_proj2))
        deepest = F.relu(self.conv3(skip_half) + as_bias(self.time_proj3))

        feat = self.bottleneck(deepest)

        # Decoder: upsample, concatenate the skip on the channel axis, fuse.
        feat = torch.cat([self.up1(feat), skip_half], dim=1)
        feat = F.relu(self.conv4(feat))

        feat = torch.cat([self.up2(feat), skip_full], dim=1)
        feat = F.relu(self.conv5(feat))

        return self.out(feat)


# Build the network on the selected device and print its total parameter count.
model = SimpleUNet(in_channels=CHANNELS).to(device)
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

## 4. 数据加载

In [None]:
# Convert images to tensors, then normalize with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]  # normalize to [-1, 1]
)

# MNIST training split stored under ./data (download=True is passed).
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
# Shuffled batches; drop_last=True discards a final batch smaller than BATCH_SIZE.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
print(f"训练集大小: {len(train_dataset)}")

## 5. 训练算法

**Algorithm 1 (Training)**:
1. 采样 $x_0 \sim q(x_0)$
2. 采样 $t \sim \text{Uniform}(\{1, \dots, T\})$（代码中使用从 0 开始的下标 $0, \dots, T-1$）
3. 采样 $\epsilon \sim \mathcal{N}(0, I)$
4. 计算 $x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon$
5. 梯度下降: $\nabla_\theta \|\epsilon - \epsilon_\theta(x_t, t)\|^2$

In [None]:
def train_step(model, x0, scheduler, optimizer):
    """Run one optimization step of the noise-prediction objective.

    Args:
        model: Callable as ``model(x_t, t)`` returning the predicted noise.
        x0: Batch of clean inputs; the first dimension is the batch size.
        scheduler: Object exposing ``timesteps`` and ``add_noise(x0, t, noise)``.
        optimizer: Optimizer over the model's parameters.

    Returns:
        The MSE loss of this step as a Python float.
    """
    n = x0.shape[0]

    # One random step index per sample, drawn from [0, timesteps).
    t = torch.randint(0, scheduler.timesteps, (n,), device=x0.device)

    # Target noise, then the corresponding noised input x_t.
    eps = torch.randn_like(x0)
    noisy, _ = scheduler.add_noise(x0, t, eps)

    # The model is trained to recover the noise that was mixed in.
    loss = F.mse_loss(model(noisy, t), eps)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

## 6. 采样算法

**Algorithm 2 (Sampling)**:
1. $x_T \sim \mathcal{N}(0, I)$
2. For $t = T, ..., 1$:
   - $z \sim \mathcal{N}(0, I)$ if $t > 1$, else $z = 0$
   - $x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t, t)\right) + \sigma_t z$

其中 $\sigma_t = \sqrt{\beta_t}$（即 $\sigma_t^2 = \beta_t$，这是下面代码采用的方差选择）。

In [None]:
@torch.no_grad()
def sample(model, scheduler, n_samples=64, img_size=28, channels=1):
    """Generate images with the DDPM reverse (ancestral) sampling loop.

    Starts from standard-normal noise and, for t = timesteps-1 down to 0,
    applies

        x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1-alpha_t)/sqrt(1-alpha_bar_t) * eps_pred)
                  + sigma_t * z,

    with sigma_t = sqrt(beta_t) and z ~ N(0, I); no noise is added at the
    final step (t == 0).

    Args:
        model: Noise predictor called as ``model(x, t_batch)``.
        scheduler: Object exposing ``timesteps``, ``alphas``, ``alpha_bars``
            and ``betas`` indexable by an integer step.
        n_samples: Number of images to generate.
        img_size: Height and width of each image.
        channels: Number of image channels.

    Returns:
        Tensor of shape ``(n_samples, channels, img_size, img_size)`` on the
        model's device.
    """
    # Sample on the device the model lives on rather than on a module-level
    # global; fall back to the CPU for a model without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # Remember the caller's mode so an eval-mode model is not flipped to train.
    was_training = model.training
    model.eval()
    try:
        # Start from pure noise.
        x = torch.randn(n_samples, channels, img_size, img_size).to(run_device)

        for t in reversed(range(scheduler.timesteps)):
            t_batch = torch.full((n_samples,), t, device=run_device, dtype=torch.long)
            noise_pred = model(x, t_batch)

            alpha = scheduler.alphas[t]
            alpha_bar = scheduler.alpha_bars[t]
            beta = scheduler.betas[t]

            # Posterior mean of x_{t-1} given x_t and the predicted noise.
            coef1 = 1 / torch.sqrt(alpha)
            coef2 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
            x = coef1 * (x - coef2 * noise_pred)

            # Inject fresh noise on every step except the last one.
            if t > 0:
                sigma = torch.sqrt(beta)
                x = x + sigma * torch.randn_like(x)
    finally:
        model.train(was_training)

    return x

In [None]:
def show_samples(samples, title="Generated Samples"):
    """Display a batch of images as a grid with 8 images per row.

    Args:
        samples: Image batch with values nominally in [-1, 1]
            (presumably ``(N, C, H, W)``, as expected by ``make_grid``).
        title: Title drawn above the figure.
    """
    # Rescale [-1, 1] to [0, 1] and clip anything outside the displayable range.
    images = ((samples + 1) / 2).clamp(0, 1)
    grid = make_grid(images, nrow=8, padding=2)

    # Move channels last and hand matplotlib a NumPy array.
    grid_hwc = grid.permute(1, 2, 0).cpu().numpy()

    plt.figure(figsize=(10, 10))
    plt.imshow(grid_hwc, cmap="gray")
    plt.title(title)
    plt.axis("off")
    plt.show()

## 7. 训练循环

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
losses = []  # mean training loss of each epoch, in order

print("开始训练 DDPM...")
print("-" * 50)

for epoch in range(1, EPOCHS + 1):
    epoch_loss = 0
    # The batch index was never used, so iterate the loader directly.
    # Labels are discarded: only the images are used for training.
    for x0, _ in train_loader:
        x0 = x0.to(device)
        loss = train_step(model, x0, scheduler, optimizer)
        epoch_loss += loss

    avg_loss = epoch_loss / len(train_loader)
    losses.append(avg_loss)
    print(f"Epoch {epoch:2d}/{EPOCHS} | Loss: {avg_loss:.4f}")

    # Generate and display a grid of samples every 5 epochs.
    if epoch % 5 == 0:
        samples = sample(model, scheduler, n_samples=64)
        show_samples(samples, title=f"Epoch {epoch}")

print("-" * 50)
print("训练完成!")

In [None]:
# Plot the per-epoch mean training loss collected in `losses`.
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("DDPM Training Loss")
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Final result: draw 64 samples from the trained model and show them as a grid.
print("生成最终样本...")
final_samples = sample(model, scheduler, n_samples=64)
show_samples(final_samples, title="Final Generated Samples")

## 8. 可视化去噪过程

In [None]:
@torch.no_grad()
def sample_with_trajectory(
    model, scheduler, n_samples=1, save_every=100, img_size=None, channels=None
):
    """Run DDPM sampling and keep snapshots of the intermediate states.

    The update rule is the same as in ``sample``: the posterior mean computed
    from the predicted noise, plus ``sqrt(beta_t)``-scaled Gaussian noise on
    every step except t == 0.

    Args:
        model: Noise predictor called as ``model(x, t_batch)``.
        scheduler: Object exposing ``timesteps``, ``alphas``, ``alpha_bars``
            and ``betas`` indexable by an integer step.
        n_samples: Number of images to generate in parallel.
        save_every: A snapshot is stored after each step ``t`` with
            ``t % save_every == 0``.
        img_size: Image height and width; defaults to the module-level
            ``IMG_SIZE`` when omitted.
        channels: Number of image channels; defaults to the module-level
            ``CHANNELS`` when omitted.

    Returns:
        List of tensors: the initial noise followed by the stored snapshots,
        ordered from most to least noisy.
    """
    if img_size is None:
        img_size = IMG_SIZE
    if channels is None:
        channels = CHANNELS

    # Sample on the device the model lives on rather than on a module-level
    # global; fall back to the CPU for a model without parameters.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # Remember the caller's mode so an eval-mode model is not flipped to train.
    was_training = model.training
    model.eval()
    try:
        x = torch.randn(n_samples, channels, img_size, img_size).to(run_device)
        trajectory = [x.clone()]

        for t in reversed(range(scheduler.timesteps)):
            t_batch = torch.full((n_samples,), t, device=run_device, dtype=torch.long)
            noise_pred = model(x, t_batch)

            alpha = scheduler.alphas[t]
            alpha_bar = scheduler.alpha_bars[t]
            beta = scheduler.betas[t]

            x = (1 / torch.sqrt(alpha)) * (x - (1 - alpha) / torch.sqrt(1 - alpha_bar) * noise_pred)

            # No noise is injected on the final step.
            if t > 0:
                x = x + torch.sqrt(beta) * torch.randn_like(x)

            if t % save_every == 0:
                trajectory.append(x.clone())
    finally:
        model.train(was_training)

    return trajectory


# Visualize the denoising trajectory of a single sample.
trajectory = sample_with_trajectory(model, scheduler, n_samples=1, save_every=100)

# One subplot per stored snapshot, left (noise) to right (final image).
fig, axes = plt.subplots(1, len(trajectory), figsize=(20, 2))
for i, img in enumerate(trajectory):
    img = (img + 1) / 2  # shift from [-1, 1] towards [0, 1] for display
    axes[i].imshow(img[0, 0].cpu().numpy(), cmap="gray")  # first sample, first channel
    axes[i].axis("off")
    # NOTE(review): the label hard-codes 999 and 100, i.e. it assumes
    # TIMESTEPS == 1000 and save_every == 100 as passed above; the last
    # snapshot is labelled t=0.
    t = 999 - i * 100 if i < len(trajectory) - 1 else 0
    axes[i].set_title(f"t={t}")

plt.suptitle("Denoising Trajectory: From Noise to Image", fontsize=14)
plt.tight_layout()
plt.show()