# DDPM (Denoising Diffusion Probabilistic Models) - Production Training

Google Colab無料版（12時間制限）で完了する本番環境用のnotebookです。

**設定:**
- Dataset: CIFAR-10
- Total steps: 100,000 (約10-11時間)
- Batch size: 128
- Timesteps: 1000
- GPU: T4 (Colab free tier)

**Reference:**
Ho et al. 2020 "Denoising Diffusion Probabilistic Models"

---

## Setup Instructions

1. Runtime > Change runtime type > GPU (T4)
2. Run all cells in order
3. Training will take approximately 10-11 hours
4. Checkpoints are saved to Google Drive (optional)

## 1. Environment Setup

In [None]:
# Check which GPU (if any) the runtime has been assigned
!nvidia-smi

In [None]:
# Mount Google Drive (used to persist checkpoints across sessions)
from google.colab import drive
drive.mount('/content/drive')

# Point the output directory at the mounted Drive and create it if missing
import os
OUTPUT_DIR = '/content/drive/MyDrive/DDPM_outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")

In [None]:
# Install the required packages (-q keeps pip output quiet)
!pip install -q torch torchvision torchaudio
!pip install -q einops tqdm scipy pillow

print("Package installation complete!")

In [None]:
# インポート
import os
import sys
import random
import math
from pathlib import Path
from dataclasses import dataclass, field
from typing import Literal, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
from tqdm.auto import tqdm
from einops import rearrange
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import json
from datetime import datetime

# Device selection: use CUDA when available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    # Report the name and total memory (in GiB) of GPU 0
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

## 2. Configuration

In [None]:
@dataclass
class ModelConfig:
    """U-Net model configuration."""
    # Input resolution; UNet compares image_size // 2**level against
    # attention_resolutions, and the sampling calls pass it as image_size.
    image_size: int = 32
    # Channel count of the tensor fed to the first convolution.
    in_channels: int = 3
    # Channel count produced by the final convolution.
    out_channels: int = 3
    # Base feature width; also the sinusoidal timestep embedding size
    # (the time MLP widens it to 4x).
    model_channels: int = 128
    # Per-level multipliers applied to model_channels; a Downsample follows
    # every level except the last.
    channel_mult: tuple = (1, 2, 2, 2)
    # ResBlocks per level on the down path (the up path uses one more).
    num_res_blocks: int = 2
    # Feature-map sizes at which an AttentionBlock follows each ResBlock.
    attention_resolutions: tuple = (16,)
    # Dropout probability used inside every ResBlock.
    dropout: float = 0.1
    # Number of heads in every AttentionBlock.
    num_heads: int = 4
    # True: the time embedding scales/shifts the normalized features;
    # False: it is added to the features before normalization.
    use_scale_shift_norm: bool = True

@dataclass
class DiffusionConfig:
    """Diffusion process configuration."""
    # Number of diffusion steps; training draws t uniformly from [0, timesteps).
    timesteps: int = 1000
    # Which branch of get_beta_schedule to use.
    beta_schedule: Literal["linear", "cosine", "quadratic"] = "linear"
    # First and last beta; read by the "linear" and "quadratic" schedules.
    beta_start: float = 1e-4
    beta_end: float = 0.02
    # Offset used only by the "cosine" schedule.
    s: float = 0.008

@dataclass
class TrainingConfig:
    """Training configuration for Colab free tier (12 hours)."""
    # Mini-batch size handed to the DataLoader.
    batch_size: int = 128
    # Base learning rate for AdamW (scaled by the warmup lambda).
    learning_rate: float = 2e-4
    # Number of optimizer steps after which the training loop stops.
    total_steps: int = 100_000  # ~10-11 hours on T4
    # Steps over which the LR multiplier ramps linearly from 0 to 1.
    warmup_steps: int = 5000
    # Max gradient norm passed to clip_grad_norm_.
    grad_clip: float = 1.0
    # Decay factor of the EMA of the model parameters.
    ema_decay: float = 0.9999
    # Checkpoint / sample-grid / progress-bar update intervals, in steps.
    save_every: int = 5000
    sample_every: int = 5000
    log_every: int = 100
    # Worker processes used by the DataLoader.
    num_workers: int = 2
    # When True a GradScaler is created and the forward pass runs under autocast.
    mixed_precision: bool = True

@dataclass
class Config:
    """Main configuration."""
    # Each sub-config gets its own fresh instance per Config object.
    model: ModelConfig = field(default_factory=ModelConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    # Root for checkpoints, sample grids and summaries; OUTPUT_DIR must already
    # be defined (by the Drive-mount cell) when this class body is executed.
    output_dir: str = OUTPUT_DIR
    # Directory where CIFAR-10 is downloaded / read from.
    data_dir: str = "./data"
    # NOTE(review): not read anywhere in the visible notebook; the module-level
    # `device` object is what the code actually uses.
    device: str = "cuda"
    # Seed handed to set_seed.
    seed: int = 42
    # Experiment label written into the training summary.
    exp_name: str = "ddpm_colab_production"

# Create config (all defaults) and echo the settings that drive run time
config = Config()
print("Configuration:")
print(f"  Total steps: {config.training.total_steps:,}")
print(f"  Batch size: {config.training.batch_size}")
print(f"  Timesteps: {config.diffusion.timesteps}")
# The time estimate below is a fixed string, not computed from the config
print(f"  Estimated time: ~10-11 hours on T4 GPU")

In [None]:
# Seed設定
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # CUDA generators are seeded explicitly, and only when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(config.seed)
print(f"Random seed set to {config.seed}")

## 3. Model Definition

In [None]:
# Utility functions for sinusoidal position embeddings
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Sinusoidal timestep embeddings.

    Args:
        timesteps: 1-D tensor of timestep values, one per batch element.
        embedding_dim: Width of the returned embedding.

    Returns:
        Float tensor of shape (len(timesteps), embedding_dim): the sine
        components first, then the cosine components, followed by a single
        zero column when embedding_dim is odd.
    """
    half_dim = embedding_dim // 2
    # Frequencies are spaced geometrically from 1 down to 1/10000. The
    # denominator is clamped so that half_dim == 1 (embedding_dim 2 or 3)
    # yields the single frequency 1 instead of dividing by zero.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step
    )
    # Outer product: one row per timestep, one column per frequency.
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if embedding_dim % 2 == 1:
        # Odd widths get one trailing zero column to reach embedding_dim.
        emb = F.pad(emb, (0, 1, 0, 0))
    return emb

class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) applied to a timestep embedding.

    Args:
        dim: Width of the incoming embedding.
        out_dim: Width of the hidden layer and of the output.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, out_dim)
        self.act = nn.SiLU()
        self.linear2 = nn.Linear(out_dim, out_dim)

    def forward(self, x):
        """Project x through linear1, the activation, then linear2."""
        return self.linear2(self.act(self.linear1(x)))

def normalization(channels):
    """Return the GroupNorm layer (32 groups) used throughout the network."""
    return nn.GroupNorm(num_groups=32, num_channels=channels)

class ResBlock(nn.Module):
    """Residual block of two 3x3 convolutions conditioned on a time embedding.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        time_channels: Width of the time embedding passed to forward().
        dropout: Dropout probability applied before the second convolution.
        use_scale_shift_norm: If True the projected time embedding is split
            into a scale and a shift applied after the second normalization;
            otherwise it is added to the features before that normalization.
    """

    def __init__(self, in_channels, out_channels, time_channels, dropout, use_scale_shift_norm=True):
        super().__init__()
        self.use_scale_shift_norm = use_scale_shift_norm

        self.norm1 = normalization(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        # Scale-shift conditioning needs two values per output channel.
        cond_width = out_channels * 2 if use_scale_shift_norm else out_channels
        self.time_emb = nn.Sequential(nn.SiLU(), nn.Linear(time_channels, cond_width))

        self.norm2 = normalization(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        # A 1x1 convolution matches channel counts on the residual path
        # whenever the block changes width.
        same_width = in_channels == out_channels
        self.shortcut = nn.Identity() if same_width else nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x, t):
        """Apply the block to feature map x using time embedding t."""
        h = self.conv1(F.silu(self.norm1(x)))

        # Add two trailing singleton axes so the conditioning broadcasts
        # over the spatial dimensions.
        cond = self.time_emb(t)[:, :, None, None]

        if not self.use_scale_shift_norm:
            h = self.norm2(h + cond)
        else:
            scale, shift = cond.chunk(2, dim=1)
            h = self.norm2(h) * (1 + scale) + shift

        h = self.conv2(self.dropout(F.silu(h)))

        return h + self.shortcut(x)

class AttentionBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Args:
        channels: Channels of the input (and output) feature map.
        num_heads: Number of attention heads; must divide `channels`.

    Raises:
        ValueError: If `channels` is not divisible by `num_heads`.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.norm = normalization(channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """Return x plus the projected attention output (residual connection)."""
        b, c, h, w = x.shape
        qkv = self.qkv(self.norm(x))
        # Split the 3*c channels into (q|k|v, head, per-head channel) and
        # flatten the spatial grid into a sequence of h*w positions.
        qkv = rearrange(qkv, 'b (three heads c) h w -> three b heads c (h w)', three=3, heads=self.num_heads)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Each head's dot product sums over c // num_heads channels, so the
        # logits are scaled by the per-head width rather than the total width.
        head_dim = c // self.num_heads
        attn = torch.einsum('bhci,bhcj->bhij', q, k) * (head_dim ** -0.5)
        attn = F.softmax(attn, dim=-1)

        out = torch.einsum('bhij,bhcj->bhci', attn, v)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', h=h, w=w)
        out = self.proj(out)

        return x + out

class Downsample(nn.Module):
    """Stride-2 3x3 convolution that keeps the channel count unchanged.

    Args:
        channels: Channels of the input and output feature maps.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Apply the strided convolution to x."""
        reduced = self.conv(x)
        return reduced

class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution.

    Args:
        channels: Channels of the input and output feature maps.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        """Double the spatial size of x, then apply the convolution."""
        return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))

print("Model utilities defined.")

In [None]:
class UNet(nn.Module):
    """U-Net noise predictor with timestep conditioning and skip connections.

    The down path stores one feature map per block group (plus the output of
    the input convolution); the up path consumes them in reverse order by
    channel-wise concatenation.

    Args:
        config: Model hyper-parameters (see ModelConfig).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        
        # Time embedding
        time_embed_dim = config.model_channels * 4
        self.time_embed = TimestepEmbedding(config.model_channels, time_embed_dim)
        
        # Initial convolution
        self.conv_in = nn.Conv2d(config.in_channels, config.model_channels, 3, padding=1)
        
        # Downsampling. `channels` records the width of every skip tensor that
        # forward() will store: one per block group, plus the conv_in output.
        self.down_blocks = nn.ModuleList()
        channels = [config.model_channels]
        ch = config.model_channels
        
        for level, mult in enumerate(config.channel_mult):
            out_ch = config.model_channels * mult
            for _ in range(config.num_res_blocks):
                layers = [ResBlock(ch, out_ch, time_embed_dim, config.dropout, config.use_scale_shift_norm)]
                ch = out_ch
                # Attention is added when the feature-map size at this level
                # is one of the configured attention resolutions.
                if config.image_size // (2 ** level) in config.attention_resolutions:
                    layers.append(AttentionBlock(ch, config.num_heads))
                self.down_blocks.append(nn.ModuleList(layers))
                channels.append(ch)
            
            # Every level except the last ends with a downsampling group.
            if level != len(config.channel_mult) - 1:
                self.down_blocks.append(nn.ModuleList([Downsample(ch)]))
                channels.append(ch)
        
        # Middle
        self.middle = nn.ModuleList([
            ResBlock(ch, ch, time_embed_dim, config.dropout, config.use_scale_shift_norm),
            AttentionBlock(ch, config.num_heads),
            ResBlock(ch, ch, time_embed_dim, config.dropout, config.use_scale_shift_norm),
        ])
        
        # Upsampling: each group consumes exactly one recorded skip width.
        self.up_blocks = nn.ModuleList()
        for level, mult in enumerate(reversed(config.channel_mult)):
            for i in range(config.num_res_blocks + 1):
                skip_ch = channels.pop()
                layers = [ResBlock(ch + skip_ch, config.model_channels * mult, time_embed_dim, 
                                   config.dropout, config.use_scale_shift_norm)]
                ch = config.model_channels * mult
                if config.image_size // (2 ** (len(config.channel_mult) - 1 - level)) in config.attention_resolutions:
                    layers.append(AttentionBlock(ch, config.num_heads))
                # The last group of every level but the final one upsamples.
                if level < len(config.channel_mult) - 1 and i == config.num_res_blocks:
                    layers.append(Upsample(ch))
                self.up_blocks.append(nn.ModuleList(layers))
        
        # Output
        self.norm_out = normalization(ch)
        self.conv_out = nn.Conv2d(ch, config.out_channels, 3, padding=1)

    @staticmethod
    def _run_blocks(blocks, h, t_emb):
        """Run h through a group of blocks; only ResBlocks receive t_emb."""
        for block in blocks:
            if isinstance(block, ResBlock):
                h = block(h, t_emb)
            else:  # AttentionBlock, Downsample or Upsample
                h = block(h)
        return h

    def forward(self, x, t):
        """Predict from input batch x and per-sample timesteps t.

        Args:
            x: Input image batch.
            t: 1-D tensor with one timestep per batch element.

        Returns:
            Tensor with config.out_channels channels.
        """
        # Time embedding
        t_emb = get_timestep_embedding(t, self.config.model_channels)
        t_emb = self.time_embed(t_emb)
        
        # Initial conv
        h = self.conv_in(x)
        hs = [h]
        
        # Downsampling. Store exactly ONE skip per group so the stack lines up
        # with the widths recorded in __init__; storing one per block would
        # push an extra entry for every ResBlock+AttentionBlock group and
        # misalign the up path.
        for blocks in self.down_blocks:
            h = self._run_blocks(blocks, h, t_emb)
            hs.append(h)
        
        # Middle
        h = self._run_blocks(self.middle, h, t_emb)
        
        # Upsampling: concatenate the most recent skip, then run the group.
        for blocks in self.up_blocks:
            skip = hs.pop()
            h = torch.cat([h, skip], dim=1)
            h = self._run_blocks(blocks, h, t_emb)
        
        # Output
        h = self.norm_out(h)
        h = F.silu(h)
        h = self.conv_out(h)
        
        return h

print("UNet model defined.")

## 4. Diffusion Process

In [None]:
def get_beta_schedule(schedule_type: str, timesteps: int, beta_start: float = 1e-4,
                      beta_end: float = 0.02, s: float = 0.008):
    """
    Build the per-step beta values of the diffusion process.

    Args:
        schedule_type: One of "linear", "quadratic" or "cosine".
        timesteps: Number of betas to produce.
        beta_start: First beta (linear and quadratic schedules).
        beta_end: Last beta (linear and quadratic schedules).
        s: Offset used by the cosine schedule.

    Returns:
        1-D tensor with `timesteps` entries.

    Raises:
        ValueError: If schedule_type is not recognised.
    """
    if schedule_type == "linear":
        return torch.linspace(beta_start, beta_end, timesteps)

    if schedule_type == "quadratic":
        # Interpolate linearly in sqrt(beta), then square.
        sqrt_betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps)
        return sqrt_betas ** 2

    if schedule_type == "cosine":
        # Evaluate the squared-cosine curve on timesteps + 1 grid points,
        # normalise by its first value, and turn consecutive ratios into betas.
        grid = torch.linspace(0, timesteps, timesteps + 1)
        alpha_bar = torch.cos(((grid / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alpha_bar = alpha_bar / alpha_bar[0]
        ratios = alpha_bar[1:] / alpha_bar[:-1]
        return torch.clip(1 - ratios, 0.0001, 0.9999)

    raise ValueError(f"Unknown schedule type: {schedule_type}")

class GaussianDiffusion(nn.Module):
    """Gaussian diffusion process wrapped around a noise-prediction model.

    The constructor precomputes the beta schedule and every derived
    coefficient table and registers them as buffers so they follow the module
    across devices.

    Args:
        model: Network called as model(x_t, t) that predicts the added noise.
        config: Diffusion hyper-parameters (see DiffusionConfig).
    """

    def __init__(self, model: nn.Module, config: DiffusionConfig):
        super().__init__()
        self.model = model
        self.config = config
        
        # Beta schedule
        betas = get_beta_schedule(
            config.beta_schedule, 
            config.timesteps, 
            config.beta_start, 
            config.beta_end,
            config.s
        )
        
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Shifted right by one step with a leading 1.0 (cumulative product
        # "before" the first step).
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # Register buffers
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        
        # Calculations for diffusion q(x_t | x_{t-1})
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))
        
        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        # The variance is exactly 0 at t == 0, hence the clamp before the log.
        self.register_buffer('posterior_log_variance_clipped', 
                             torch.log(torch.clamp(posterior_variance, min=1e-20)))
        self.register_buffer('posterior_mean_coef1', 
                             betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', 
                             (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod))

    @staticmethod
    def _extract(table, t, x):
        """Gather table[t] and reshape it to broadcast against batch-first x."""
        return table[t].view(-1, *([1] * (x.dim() - 1)))

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (forward process).

        Args:
            x_start: Clean batch.
            t: 1-D long tensor of timesteps, one per batch element.
            noise: Optional noise tensor shaped like x_start; drawn from a
                standard normal when omitted.

        Returns:
            The noised batch at timesteps t.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        
        signal_scale = self._extract(self.sqrt_alphas_cumprod, t, x_start)
        noise_scale = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start)
        
        return signal_scale * x_start + noise_scale * noise

    def p_mean_variance(self, x, t):
        """
        Predict mean and variance for reverse process.

        Returns:
            Tuple (posterior_mean, posterior_variance), both broadcastable
            against x.
        """
        # Predict noise
        pred_noise = self.model(x, t)
        
        # Index the tables first so only one coefficient per sample is
        # computed, instead of rebuilding whole tables on every call.
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x)
        sqrt_recip_alphas_cumprod_t = 1.0 / sqrt_alphas_cumprod_t
        sqrt_recipm1_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod_t / sqrt_alphas_cumprod_t
        
        # Predict x_0 from the noise estimate, then clamp to the data range
        pred_x0 = sqrt_recip_alphas_cumprod_t * x - sqrt_recipm1_alphas_cumprod_t * pred_noise
        pred_x0 = torch.clamp(pred_x0, -1, 1)
        
        # Get posterior mean and variance
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x) * pred_x0
            + self._extract(self.posterior_mean_coef2, t, x) * x
        )
        posterior_variance = self._extract(self.posterior_variance, t, x)
        
        return posterior_mean, posterior_variance

    def p_sample(self, x, t):
        """
        Sample from reverse process (one denoising step).
        """
        mean, variance = self.p_mean_variance(x, t)
        noise = torch.randn_like(x)
        # No noise when t == 0
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        return mean + nonzero_mask * torch.sqrt(variance) * noise

    @torch.no_grad()
    def sample(self, batch_size: int, image_size: int, progress: bool = True, channels: int = 3):
        """
        Generate samples from the model.

        Args:
            batch_size: Number of images to generate.
            image_size: Height and width of the generated images.
            progress: Show a tqdm progress bar over the denoising steps.
            channels: Number of image channels (defaults to 3, the previously
                hard-coded value).

        Returns:
            Tensor of shape (batch_size, channels, image_size, image_size).
        """
        device = next(self.model.parameters()).device
        
        # Start from pure noise
        x = torch.randn(batch_size, channels, image_size, image_size, device=device)
        
        # Denoise progressively, from the last timestep down to 0
        timesteps = list(range(self.config.timesteps))[::-1]
        if progress:
            timesteps = tqdm(timesteps, desc="Sampling")
        
        for t in timesteps:
            t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
            x = self.p_sample(x, t_batch)
        
        return x

    def forward(self, x_start):
        """
        Training loss (simplified MSE loss).

        Draws one random timestep and one noise tensor per sample, noises the
        batch, and returns the mean squared error between the model's output
        and the noise that was added.
        """
        batch_size = x_start.shape[0]
        device = x_start.device
        
        # Sample random timesteps
        t = torch.randint(0, self.config.timesteps, (batch_size,), device=device).long()
        
        # Sample noise
        noise = torch.randn_like(x_start)
        
        # Add noise to images
        x_noisy = self.q_sample(x_start, t, noise)
        
        # Predict noise
        pred_noise = self.model(x_noisy, t)
        
        # Calculate loss
        loss = F.mse_loss(pred_noise, noise)
        
        return loss

print("Diffusion process defined.")

## 5. EMA (Exponential Moving Average)

In [None]:
class EMA:
    """
    Exponential Moving Average of model parameters.

    Keeps a "shadow" copy of every trainable parameter. The shadow can be
    swapped into the model temporarily (apply_shadow) and the live weights
    put back afterwards (restore).

    Args:
        model: Module whose trainable parameters are tracked.
        decay: Weight kept from the previous average at each update.
    """
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {}
        self.original = {}
        
        # Initialize shadow parameters
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, model):
        """Update EMA parameters: shadow <- shadow - (1 - decay) * (shadow - param)."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] -= (1 - self.decay) * (self.shadow[name] - param.data)

    def apply_shadow(self, model):
        """Copy the EMA parameters into the model, backing up the live weights.

        Values are copied rather than rebinding `param.data` to the shadow
        tensor, so in-place changes to the model while the shadow is applied
        cannot corrupt the average. If the shadow is already applied (e.g. a
        previous sampling run was interrupted before restore), the existing
        backup is kept instead of being overwritten with EMA values.
        """
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name not in self.original:
                    self.original[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def restore(self, model):
        """Restore the parameters backed up by apply_shadow.

        Raises:
            KeyError: If apply_shadow has not been called since the last restore.
        """
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.original[name])
        self.original = {}

    def state_dict(self):
        """Return the shadow tensors (by reference) and the decay factor."""
        return {'shadow': self.shadow, 'decay': self.decay}

    def load_state_dict(self, state_dict):
        """Load shadow tensors and decay from a dict produced by state_dict().

        Tensors are cloned so the EMA never aliases the loaded checkpoint, and
        are moved to the device of the shadow entry they replace when one
        exists.
        """
        restored = {}
        for name, tensor in state_dict['shadow'].items():
            current = self.shadow.get(name)
            target_device = current.device if current is not None else tensor.device
            restored[name] = tensor.detach().clone().to(target_device)
        self.shadow = restored
        self.decay = state_dict['decay']

print("EMA class defined.")

## 6. Data Loading

In [None]:
# Data transforms: random horizontal flip for augmentation, conversion to a
# tensor, then per-channel normalization with mean 0.5 / std 0.5
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Scale to [-1, 1]
])

# Load CIFAR-10 (training split; downloaded into config.data_dir if missing)
train_dataset = CIFAR10(
    root=config.data_dir,
    train=True,
    download=True,
    transform=transform
)

# Shuffled loader; drop_last=True discards the final incomplete batch so
# every batch has exactly batch_size samples
train_loader = DataLoader(
    train_dataset,
    batch_size=config.training.batch_size,
    shuffle=True,
    num_workers=config.training.num_workers,
    pin_memory=True,
    drop_last=True
)

print(f"Dataset: CIFAR-10")
print(f"Training samples: {len(train_dataset):,}")
print(f"Batch size: {config.training.batch_size}")
print(f"Batches per epoch: {len(train_loader)}")

## 7. Training

In [None]:
# Create model and wrap it in the diffusion process (both moved to `device`)
model = UNet(config.model).to(device)
diffusion = GaussianDiffusion(model, config.diffusion).to(device)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")

# EMA of the U-Net parameters (the diffusion wrapper itself has no trainable
# parameters of its own beyond `model`)
ema = EMA(model, decay=config.training.ema_decay)

# Optimizer (weight decay disabled)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.training.learning_rate,
    betas=(0.9, 0.999),
    weight_decay=0.0
)

# Learning rate scheduler with warmup: the multiplier rises linearly from 0
# to 1 over warmup_steps scheduler steps and stays at 1 afterwards
def lr_lambda(step):
    if step < config.training.warmup_steps:
        return step / config.training.warmup_steps
    return 1.0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Mixed precision scaler (None disables the autocast path in the training loop)
scaler = GradScaler() if config.training.mixed_precision else None

print("Model initialized and ready for training!")

In [None]:
# Visualization utilities
def unnormalize(x):
    """Map values from the [-1, 1] range to the [0, 1] range."""
    shifted = x + 1
    return shifted / 2

def plot_samples(samples, title="Generated Samples"):
    """Display a batch of images as a grid with 8 images per row.

    Args:
        samples: Image batch in the [-1, 1] range.
        title: Title drawn above the grid.
    """
    # Bring the batch into [0, 1] and clip anything that falls outside.
    images = torch.clamp(unnormalize(samples), 0, 1)

    grid = torchvision.utils.make_grid(images, nrow=8, padding=2)
    # Move the channel axis last, as expected by imshow.
    grid_np = grid.cpu().permute(1, 2, 0).numpy()

    plt.figure(figsize=(12, 12))
    plt.imshow(grid_np)
    plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

def save_checkpoint(step, model, optimizer, ema, scheduler, metrics, path):
    """Save training checkpoint.

    Args:
        step: Number of completed training steps.
        model: Module whose state_dict is stored under 'model'.
        optimizer: Optimizer whose state_dict is stored under 'optimizer'.
        ema: EMA helper whose state_dict is stored under 'ema'.
        scheduler: LR scheduler whose state_dict is stored under 'scheduler'.
        metrics: Dict of scalar metrics stored as-is under 'metrics'.
        path: Destination file.

    The module-level `config` is stored under 'config' as a plain nested dict
    rather than as a dataclass instance: pickling the notebook-defined Config
    object would make the file unreadable with `torch.load(...,
    weights_only=True)` and would require the class definitions to exist at
    load time.
    """
    from dataclasses import asdict, is_dataclass

    checkpoint = {
        'step': step,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'ema': ema.state_dict(),
        'scheduler': scheduler.state_dict(),
        'metrics': metrics,
        'config': asdict(config) if is_dataclass(config) else config,
    }
    torch.save(checkpoint, path)
    print(f"Checkpoint saved: {path}")

print("Utilities ready.")

In [None]:
# Training loop
print("Starting training...")
print(f"Total steps: {config.training.total_steps:,}")
print(f"Estimated time: ~10-11 hours on T4 GPU")
print("=" * 80)

# Setup directories
checkpoint_dir = Path(config.output_dir) / "checkpoints"
sample_dir = Path(config.output_dir) / "samples"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
sample_dir.mkdir(parents=True, exist_ok=True)

# Training state
step = 0
running_loss = 0.0
losses = []
start_time = datetime.now()

# Progress bar
pbar = tqdm(total=config.training.total_steps, desc="Training")

# Training loop: iterate over epochs until total_steps optimizer steps are done
model.train()
epoch = 0

while step < config.training.total_steps:
    epoch += 1
    
    for batch in train_loader:
        if step >= config.training.total_steps:
            break
        
        # batch is (images, labels); labels are unused
        images = batch[0].to(device)
        
        # Forward pass
        optimizer.zero_grad()
        
        if scaler is not None:
            with autocast():
                loss = diffusion(images)
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss = diffusion(images)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip)
            optimizer.step()
        
        scheduler.step()
        ema.update(model)
        
        step += 1
        running_loss += loss.item()
        
        # Logging: average the loss over the last log_every steps
        if step % config.training.log_every == 0:
            avg_loss = running_loss / config.training.log_every
            losses.append(avg_loss)
            lr = scheduler.get_last_lr()[0]
            
            # ETA from the average throughput since the loop started
            elapsed = (datetime.now() - start_time).total_seconds()
            steps_per_sec = step / elapsed
            remaining_steps = config.training.total_steps - step
            eta_seconds = remaining_steps / steps_per_sec
            eta_hours = eta_seconds / 3600
            
            pbar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'lr': f'{lr:.6f}',
                'ETA': f'{eta_hours:.1f}h'
            })
            
            running_loss = 0.0
        
        # Generate samples
        if step % config.training.sample_every == 0 or step == config.training.total_steps:
            print(f"\n[Step {step}/{config.training.total_steps}] Generating samples...")
            
            # Use EMA model for sampling
            ema.apply_shadow(model)
            model.eval()
            
            try:
                with torch.no_grad():
                    samples = diffusion.sample(
                        batch_size=64,
                        image_size=config.model.image_size,
                        progress=False
                    )
                
                # Save samples
                samples_unnorm = unnormalize(samples)
                samples_unnorm = torch.clamp(samples_unnorm, 0, 1)
                grid = torchvision.utils.make_grid(samples_unnorm, nrow=8, padding=2)
                save_path = sample_dir / f"samples_step_{step:08d}.png"
                torchvision.utils.save_image(grid, save_path)
                
                # Display samples
                clear_output(wait=True)
                plot_samples(samples, title=f"Samples at step {step}")
            finally:
                # Always put the live training weights back, even when
                # sampling is interrupted or fails; otherwise the model would
                # be left holding EMA weights in eval mode.
                ema.restore(model)
                model.train()
        
        # Save checkpoint
        if step % config.training.save_every == 0 or step == config.training.total_steps:
            checkpoint_path = checkpoint_dir / f"checkpoint_step_{step:08d}.pt"
            save_checkpoint(
                step=step,
                model=model,
                optimizer=optimizer,
                ema=ema,
                scheduler=scheduler,
                metrics={'loss': loss.item()},
                path=checkpoint_path
            )
            
            # Also save as latest
            latest_path = checkpoint_dir / "checkpoint_latest.pt"
            save_checkpoint(
                step=step,
                model=model,
                optimizer=optimizer,
                ema=ema,
                scheduler=scheduler,
                metrics={'loss': loss.item()},
                path=latest_path
            )
        
        pbar.update(1)

pbar.close()

# Training complete
total_time = (datetime.now() - start_time).total_seconds() / 3600
print("\n" + "=" * 80)
print("Training complete!")
print(f"Total time: {total_time:.2f} hours")
print(f"Final loss: {loss.item():.4f}")
print(f"Checkpoints saved to: {checkpoint_dir}")
print(f"Samples saved to: {sample_dir}")
print("=" * 80)

## 8. Results and Visualization

In [None]:
# Plot training loss. Each entry of `losses` is the mean over log_every steps,
# so the x-axis unit is taken from the config instead of being hard-coded.
plt.figure(figsize=(12, 4))
plt.plot(losses)
plt.xlabel(f'Step (x{config.training.log_every})')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True)
plt.tight_layout()
plt.savefig(Path(config.output_dir) / 'training_loss.png', dpi=150)
plt.show()

In [None]:
# Generate final samples
print("Generating final samples with EMA model...")

ema.apply_shadow(model)
model.eval()

try:
    with torch.no_grad():
        final_samples = diffusion.sample(
            batch_size=64,
            image_size=config.model.image_size,
            progress=True
        )

    plot_samples(final_samples, title="Final Generated Samples")

    # Save final samples
    final_samples_unnorm = unnormalize(final_samples)
    final_samples_unnorm = torch.clamp(final_samples_unnorm, 0, 1)
    grid = torchvision.utils.make_grid(final_samples_unnorm, nrow=8, padding=2)
    torchvision.utils.save_image(grid, Path(config.output_dir) / 'final_samples.png')
finally:
    # Put the live weights back even if sampling is interrupted or fails.
    ema.restore(model)

In [None]:
# Save training summary: collect the headline numbers of the run, write them
# to a JSON file in the output directory, and echo them to the notebook
summary = {
    'exp_name': config.exp_name,
    'total_steps': step,
    'total_hours': total_time,
    'timesteps': config.diffusion.timesteps,
    'beta_schedule': config.diffusion.beta_schedule,
    'batch_size': config.training.batch_size,
    'learning_rate': config.training.learning_rate,
    'final_loss': loss.item(),
    'num_parameters': num_params,
}

summary_path = Path(config.output_dir) / 'training_summary.json'
# Serialize once and reuse the text for both the file and the console
summary_json = json.dumps(summary, indent=2)
with open(summary_path, 'w') as f:
    f.write(summary_json)

print("\nTraining Summary:")
print(summary_json)
print(f"\nSummary saved to: {summary_path}")

## 9. Inference - Generate More Samples

In [None]:
# Generate additional samples
print("Generating additional samples...")

ema.apply_shadow(model)
model.eval()

num_batches = 4  # Generate 4 batches of 64 samples each
samples_per_batch = 64

try:
    for i in range(num_batches):
        with torch.no_grad():
            samples = diffusion.sample(
                batch_size=samples_per_batch,
                image_size=config.model.image_size,
                progress=True
            )

        plot_samples(samples, title=f"Generated Samples Batch {i+1}")
finally:
    # Put the live weights back even if sampling is interrupted or fails.
    ema.restore(model)

print(f"Generated {num_batches * samples_per_batch} total samples!")

## Notes

### Performance Expectations

- **Training time**: 約10-11時間 (T4 GPU)
- **Steps**: 100,000 steps
- **Final FID**: 30-50程度を期待（フルトレーニングの800,000 stepsでFID ~3.17）

### Tips

1. **チェックポイント**: `/content/drive/MyDrive/DDPM_outputs/checkpoints/`に保存
2. **サンプル画像**: `/content/drive/MyDrive/DDPM_outputs/samples/`に保存
3. **接続が切れた場合**: 最新のcheckpoint (`checkpoint_latest.pt`) から再開可能

### Resume Training

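接続が切れた場合は、まずセクション1〜7のセルを「Model initialized」の表示まで（モデル・optimizer・EMA・schedulerの作成まで）再実行し、その後以下のコードで状態を復元します。Training loopのセルは先頭で `step = 0` と初期化するため、再開時はその行を削除またはコメントアウトしてから実行してください（そのままだと復元した `step` が0に戻ります）: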

```python
checkpoint = torch.load('/content/drive/MyDrive/DDPM_outputs/checkpoints/checkpoint_latest.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
ema.load_state_dict(checkpoint['ema'])
scheduler.load_state_dict(checkpoint['scheduler'])
step = checkpoint['step']
```

### References

- Paper: [Denoising Diffusion Probabilistic Models (Ho et al., 2020)](https://arxiv.org/abs/2006.11239)
- Original results: FID 3.17 on CIFAR-10 (800,000 steps)
