# DDPM (Denoising Diffusion Probabilistic Models) - TPU Production Training

Google ColabのTPU環境向けの本番学習用notebookです。

**TPU最適化設定:**
- Dataset: CIFAR-10
- Total steps: 100,000
- Batch size: 512 (TPU最適化)
- Learning rate: 4e-4 (バッチサイズに合わせてスケール)
- Timesteps: 1000
- Runtime: TPU v2-8

**Reference:**
Ho et al. 2020 "Denoising Diffusion Probabilistic Models"

---

## Setup Instructions

1. Runtime > Change runtime type > TPU
2. Run all cells in order
3. Training is expected to run faster than on a GPU runtime thanks to TPU parallelism

## 1. Environment Setup

In [2]:
# Verify that a TPU runtime is attached before running anything else.
import os
tpu_detected = 'COLAB_TPU_ADDR' in os.environ or os.path.exists('/dev/accel0')
assert tpu_detected, "TPU not available! Change runtime to TPU."
print("TPU is available!")

AssertionError: TPU not available! Change runtime to TPU.

In [None]:
# Local output directory; created up front so later cells can write into it.
import os
OUTPUT_DIR = './outputs'
os.makedirs(name=OUTPUT_DIR, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")

In [None]:
# Install PyTorch/XLA and the helper packages for the TPU runtime.
# NOTE(review): torch/torchvision are installed unpinned, while the torch_xla
# wheel below is the 2.0 build for CPython 3.10 (per its file name) — confirm
# that the installed torch and the runtime's Python version match it.
!pip install -q torch torchvision
!pip install -q cloud-tpu-client https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl
!pip install -q einops tqdm scipy pillow

print("Package installation complete!")

In [None]:
# インポート
import os
import sys
import random
import math
from pathlib import Path
from dataclasses import dataclass, field
from typing import Literal, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
from tqdm.auto import tqdm
from einops import rearrange
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
import json
from datetime import datetime

# PyTorch/XLA imports for TPU
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

# Device setup (TPU): every later cell uses this single XLA device handle.
device = xm.xla_device()
print(f"Using device: {device}")
# The core count below is a fixed literal; it is not queried from the runtime.
print("TPU cores available: 8")

## 2. Configuration (TPU Optimized)

In [None]:
@dataclass
class ModelConfig:
    """U-Net model configuration.

    The fields mirror the keyword arguments of ``UNet.__init__``;
    ``create_model`` forwards each of them unchanged.
    """
    image_size: int = 32  # input resolution; starting value for the attention-resolution checks
    in_channels: int = 3  # channels of the network input
    out_channels: int = 3  # channels of the network output
    model_channels: int = 128  # base width; each level uses model_channels * channel_mult[level]
    channel_mult: tuple = (1, 2, 2, 2)  # width multiplier per resolution level
    num_res_blocks: int = 2  # residual blocks per encoder level (the decoder builds one more per level)
    attention_resolutions: tuple = (16,)  # resolutions at which an AttentionBlock is appended
    dropout: float = 0.0  # TPU: Disable dropout for faster training
    num_heads: int = 4  # heads per AttentionBlock
    use_scale_shift_norm: bool = True  # scale/shift time conditioning instead of additive conditioning

@dataclass
class DiffusionConfig:
    """Diffusion process configuration (consumed by ``GaussianDiffusion``)."""
    timesteps: int = 1000  # number of diffusion steps
    beta_schedule: Literal["linear", "cosine", "quadratic"] = "linear"  # selects the branch in get_beta_schedule
    beta_start: float = 1e-4  # first beta of the linear / quadratic schedules
    beta_end: float = 0.02  # last beta of the linear / quadratic schedules
    s: float = 0.008  # offset used only by the cosine schedule

@dataclass
class TrainingConfig:
    """Training configuration optimized for TPU."""
    batch_size: int = 512  # TPU: Large batch size for efficiency
    learning_rate: float = 4e-4  # TPU: Scaled for larger batch (linear scaling)
    total_steps: int = 100_000  # optimizer steps after which the training loop stops
    warmup_steps: int = 5000  # steps over which the LR multiplier ramps linearly from 0 to 1
    grad_clip: float = 1.0  # max gradient norm passed to clip_grad_norm_
    ema_decay: float = 0.9999  # decay of the parameter EMA
    save_every: int = 5000  # checkpoint interval, in steps
    sample_every: int = 5000  # sample-generation interval, in steps
    log_every: int = 100  # logging interval, in steps
    num_workers: int = 4  # TPU: More workers for data loading
    # NOTE(review): this flag is not read anywhere in the visible notebook code;
    # the original comment said "bfloat16 handled by XLA automatically" — confirm.
    mixed_precision: bool = False

@dataclass
class Config:
    """Main configuration bundling the model, diffusion and training settings."""
    model: ModelConfig = field(default_factory=ModelConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    output_dir: str = OUTPUT_DIR  # root for checkpoints, samples and the summary file
    data_dir: str = "./data"  # root directory handed to the CIFAR10 dataset
    # NOTE(review): the visible cells use the global `device` from xm.xla_device(),
    # not this string — confirm whether anything else reads it.
    device: str = "xla"  # TPU device
    seed: int = 42  # passed to set_seed
    exp_name: str = "ddpm_tpu_production"  # written into the training summary

# Instantiate the default configuration and echo the key hyperparameters.
config = Config()
print("\n".join([
    "TPU Configuration:",
    f"  Total steps: {config.training.total_steps:,}",
    f"  Batch size: {config.training.batch_size} (TPU optimized)",
    f"  Learning rate: {config.training.learning_rate} (scaled for batch size)",
    f"  Timesteps: {config.diffusion.timesteps}",
    f"  Dropout: {config.model.dropout} (disabled for TPU speed)",
]))

In [None]:
# Seed setup
def set_seed(seed: int):
    """Seed Python's `random`, NumPy and PyTorch, then pass `seed` to `xm.set_rng_state`."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, xm.set_rng_state):
        seed_fn(seed)

set_seed(config.seed)
print(f"Random seed set to {config.seed}")

## 3. Model Definition

In [None]:
class SinusoidalPositionEmbeddings(nn.Module):
    """Encode a batch of timesteps as fixed sine/cosine features.

    The output concatenates ``dim // 2`` sine features with ``dim // 2``
    cosine features along the last axis.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        """Return the embedding for a 1-D tensor of timesteps."""
        n_freqs = self.dim // 2
        # Frequencies fall geometrically from 1 down to 1/10000.
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device=time.device) * -log_step)
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(x)
        return x * gate


class GroupNorm32(nn.GroupNorm):
    """GroupNorm that normalizes in float32 and returns the input's dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)


def normalization(channels: int, num_groups: int = 32) -> nn.Module:
    """Build a GroupNorm32 over `channels`, using at most `num_groups` groups."""
    groups = num_groups if num_groups < channels else channels
    return GroupNorm32(groups, channels)


class ResidualBlock(nn.Module):
    """Two-convolution residual unit conditioned on a timestep embedding.

    The time embedding is projected by a small MLP and either applied as a
    scale/shift on the second normalization's output
    (``use_scale_shift_norm=True``) or added before that normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_emb_dim: int,
        dropout: float = 0.0,
        use_scale_shift_norm: bool = True,
    ):
        super().__init__()
        self.use_scale_shift_norm = use_scale_shift_norm

        # First stage: norm -> activation -> conv (changes the channel count).
        self.norm1 = normalization(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        # Second stage: norm -> activation -> dropout -> conv.
        self.norm2 = normalization(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        # Scale/shift mode needs two values per channel, additive mode one.
        emb_out_dim = out_channels * 2 if use_scale_shift_norm else out_channels
        self.time_mlp = nn.Sequential(Swish(), nn.Linear(time_emb_dim, emb_out_dim))

        self.dropout = nn.Dropout(dropout)

        # A 1x1 conv aligns the residual path only when the widths differ.
        needs_projection = in_channels != out_channels
        self.skip = nn.Conv2d(in_channels, out_channels, 1) if needs_projection else nn.Identity()

    def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
        """Apply the block to `x`, conditioned on `time_emb`."""
        h = self.conv1(F.silu(self.norm1(x)))

        # Project the embedding and give it trailing spatial axes for broadcasting.
        cond = rearrange(self.time_mlp(time_emb), "b c -> b c 1 1")

        if self.use_scale_shift_norm:
            scale, shift = cond.chunk(2, dim=1)
            h = self.norm2(h) * (1 + scale) + shift
        else:
            h = self.norm2(h + cond)

        h = self.conv2(self.dropout(F.silu(h)))
        return h + self.skip(x)


class AttentionBlock(nn.Module):
    """Multi-head self-attention over all spatial positions, with a residual connection."""

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.head_dim = channels // num_heads

        self.norm = normalization(channels)
        # One 1x1 conv produces queries, keys and values together.
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

        self.scale = self.head_dim**-0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x` plus the projected attention output."""
        _, _, height, width = x.shape

        # Normalize, project, and split into per-head q/k/v over flattened positions.
        packed = rearrange(
            self.qkv(self.norm(x)),
            "b (three heads d) h w -> three b heads (h w) d",
            three=3,
            heads=self.num_heads,
        )
        query, key, value = packed.unbind(0)

        # Scaled dot-product attention weights between every pair of positions.
        weights = (torch.einsum("bhid,bhjd->bhij", query, key) * self.scale).softmax(dim=-1)

        # Weighted sum of values, folded back into an image-shaped tensor.
        mixed = torch.einsum("bhij,bhjd->bhid", weights, value)
        mixed = rearrange(mixed, "b heads (h w) d -> b (heads d) h w", h=height, w=width)

        return x + self.proj(mixed)


class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        halved = self.conv(x)
        return halved


class Upsample(nn.Module):
    """Double the spatial resolution (nearest-neighbour), then apply a 3x3 convolution."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        enlarged = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(enlarged)


print("Model utilities defined.")

In [None]:
class UNet(nn.Module):
    """
    U-Net noise predictor for DDPM.

    The encoder stacks residual blocks (plus attention at the configured
    resolutions) and downsamples between levels; the decoder mirrors it and
    concatenates one stored encoder activation before every decoder stage.
    """
    def __init__(
        self,
        image_size: int = 32,
        in_channels: int = 3,
        out_channels: int = 3,
        model_channels: int = 128,
        channel_mult: tuple = (1, 2, 2, 2),
        num_res_blocks: int = 2,
        attention_resolutions: tuple = (16,),
        dropout: float = 0.0,
        num_heads: int = 4,
        use_scale_shift_norm: bool = True,
    ):
        super().__init__()

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels

        # Timestep -> sinusoidal features -> two-layer MLP.
        emb_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(model_channels),
            nn.Linear(model_channels, emb_dim),
            Swish(),
            nn.Linear(emb_dim, emb_dim),
        )

        # Lift the image to the base width.
        self.init_conv = nn.Conv2d(in_channels, model_channels, 3, padding=1)

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        def res_block(c_in: int, c_out: int) -> ResidualBlock:
            # All residual blocks share the embedding width and regularisation settings.
            return ResidualBlock(c_in, c_out, emb_dim, dropout, use_scale_shift_norm)

        last_level = len(channel_mult) - 1
        # Width of every activation the encoder stores for the decoder, in push order.
        skip_widths = [model_channels]
        width = model_channels
        res = image_size

        # ---- Encoder ----
        for depth, mult in enumerate(channel_mult):
            target = model_channels * mult

            for _ in range(num_res_blocks):
                stage = [res_block(width, target)]
                width = target

                if res in attention_resolutions:
                    stage.append(AttentionBlock(width, num_heads))

                self.downs.append(nn.ModuleList(stage))
                skip_widths.append(width)

            # Every level but the last ends with a downsample stage.
            if depth != last_level:
                self.downs.append(nn.ModuleList([Downsample(width)]))
                skip_widths.append(width)
                res //= 2

        # ---- Bottleneck ----
        self.mid = nn.ModuleList([
            res_block(width, width),
            AttentionBlock(width, num_heads),
            res_block(width, width),
        ])

        # ---- Decoder ----
        for depth, mult in enumerate(reversed(channel_mult)):
            target = model_channels * mult

            for block_idx in range(num_res_blocks + 1):
                # Input width = current features + the matching stored encoder activation.
                stage = [res_block(width + skip_widths.pop(), target)]
                width = target

                if res in attention_resolutions:
                    stage.append(AttentionBlock(width, num_heads))

                # Upsample after the final block of every level except the last level.
                if block_idx == num_res_blocks and depth != last_level:
                    stage.append(Upsample(width))
                    res *= 2

                self.ups.append(nn.ModuleList(stage))

        # Output head.
        self.final_norm = normalization(width)
        self.final_conv = nn.Conv2d(width, out_channels, 3, padding=1)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise every conv/linear layer, then zero the output conv."""
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity="relu")
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

        # With a zeroed output conv the freshly built network predicts all zeros.
        for tensor in (self.final_conv.weight, self.final_conv.bias):
            nn.init.zeros_(tensor)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Run the network.

        Args:
            x: Input images [B, C, H, W]
            t: Timesteps [B]

        Returns:
            Predicted noise [B, C, H, W]
        """
        temb = self.time_embed(t)

        def run(layer: nn.Module, feat: torch.Tensor) -> torch.Tensor:
            # Only residual blocks take the time embedding.
            if isinstance(layer, ResidualBlock):
                return layer(feat, temb)
            return layer(feat)

        h = self.init_conv(x)
        skips = [h]

        # Encoder: store the output of every stage for the decoder.
        for stage in self.downs:
            for layer in stage:
                h = run(layer, h)
            skips.append(h)

        # Bottleneck.
        for layer in self.mid:
            h = run(layer, h)

        # Decoder: consume the stored activations in reverse order.
        for stage in self.ups:
            h = torch.cat([h, skips.pop()], dim=1)
            for layer in stage:
                h = run(layer, h)

        return self.final_conv(F.silu(self.final_norm(h)))


def create_model(cfg: ModelConfig) -> UNet:
    """Instantiate a UNet whose constructor arguments are read from `cfg`."""
    arg_names = (
        "image_size",
        "in_channels",
        "out_channels",
        "model_channels",
        "channel_mult",
        "num_res_blocks",
        "attention_resolutions",
        "dropout",
        "num_heads",
        "use_scale_shift_norm",
    )
    return UNet(**{name: getattr(cfg, name) for name in arg_names})


print("UNet model defined.")

## 4. Diffusion Process

In [None]:
def get_beta_schedule(schedule_type: str, timesteps: int, beta_start: float = 1e-4,
                      beta_end: float = 0.02, s: float = 0.008):
    """
    Get beta schedule for diffusion process.

    Args:
        schedule_type: One of "linear", "quadratic" or "cosine".
        timesteps: Number of diffusion steps (length of the returned tensor).
        beta_start: First beta; used by the linear and quadratic schedules.
        beta_end: Last beta; used by the linear and quadratic schedules.
        s: Offset used by the cosine schedule only.

    Returns:
        1-D tensor of `timesteps` betas.

    Raises:
        ValueError: If `schedule_type` is not a supported name.
    """
    if schedule_type == "linear":
        return torch.linspace(beta_start, beta_end, timesteps)
    elif schedule_type == "quadratic":
        # Linear in sqrt(beta), then squared.
        return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
    elif schedule_type == "cosine":
        steps = timesteps + 1
        t = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((t / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        # Betas follow from the ratio of consecutive cumulative products.
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999)
    else:
        raise ValueError(f"Unknown schedule type: {schedule_type}")

class GaussianDiffusion(nn.Module):
    """Forward/reverse Gaussian diffusion wrapped around a noise-prediction model.

    Calling the module on a batch of clean images returns the MSE between the
    sampled noise and the model's prediction; `sample` runs the reverse chain
    from pure noise.
    """

    def __init__(self, model: nn.Module, config: "DiffusionConfig"):
        """
        Args:
            model: Network called as `model(x_t, t)` that returns predicted noise.
            config: Object providing `timesteps`, `beta_schedule`, `beta_start`,
                `beta_end` and `s`.
        """
        super().__init__()
        self.model = model
        self.config = config

        # Beta schedule
        betas = get_beta_schedule(
            config.beta_schedule,
            config.timesteps,
            config.beta_start,
            config.beta_end,
            config.s
        )

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted by one step, with 1.0 in front for t = 0.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # Register buffers so the tables follow the module across devices.
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # Calculations for diffusion q(x_t | x_0)
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        self.register_buffer('sqrt_alphas_cumprod', sqrt_alphas_cumprod)
        self.register_buffer('sqrt_one_minus_alphas_cumprod', sqrt_one_minus_alphas_cumprod)

        # Coefficients that recover x_0 from (x_t, predicted noise). They are
        # computed once here rather than on every reverse step, and registered
        # as non-persistent so the module's state_dict keys stay unchanged.
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             1.0 / sqrt_alphas_cumprod, persistent=False)
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             sqrt_one_minus_alphas_cumprod / sqrt_alphas_cumprod,
                             persistent=False)

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        # The variance is exactly 0 at t = 0, so clamp before taking the log.
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(torch.clamp(posterior_variance, min=1e-20)))
        self.register_buffer('posterior_mean_coef1',
                             betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod))

    @staticmethod
    def _extract(table, t):
        """Look up `table[t]` per sample and reshape to [B, 1, 1, 1] for broadcasting."""
        return table[t][:, None, None, None]

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data (forward process).

        Args:
            x_start: Clean images [B, C, H, W].
            t: Integer timesteps [B].
            noise: Optional noise shaped like `x_start`; drawn from N(0, 1) if omitted.

        Returns:
            Noised images x_t with the same shape as `x_start`.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        signal_scale = self._extract(self.sqrt_alphas_cumprod, t)
        noise_scale = self._extract(self.sqrt_one_minus_alphas_cumprod, t)

        return signal_scale * x_start + noise_scale * noise

    def p_mean_variance(self, x, t):
        """
        Predict mean and variance for reverse process.

        Args:
            x: Current noisy images x_t [B, C, H, W].
            t: Integer timesteps [B].

        Returns:
            Tuple `(posterior_mean, posterior_variance)`; the variance has shape
            [B, 1, 1, 1].
        """
        # Predict noise
        pred_noise = self.model(x, t)

        # Predict x_0 and keep it inside the [-1, 1] data range.
        pred_x0 = (
            self._extract(self.sqrt_recip_alphas_cumprod, t) * x
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t) * pred_noise
        )
        pred_x0 = torch.clamp(pred_x0, -1, 1)

        # Posterior mean is a fixed linear combination of the x_0 estimate and x_t.
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t) * pred_x0
            + self._extract(self.posterior_mean_coef2, t) * x
        )
        posterior_variance = self._extract(self.posterior_variance, t)

        return posterior_mean, posterior_variance

    def p_sample(self, x, t):
        """
        Sample from reverse process.

        Draws x_{t-1} from the predicted posterior; samples whose timestep is 0
        receive the posterior mean with no added noise.
        """
        mean, variance = self.p_mean_variance(x, t)
        noise = torch.randn_like(x)
        # No noise when t == 0
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        return mean + nonzero_mask * torch.sqrt(variance) * noise

    @torch.no_grad()
    def sample(self, batch_size: int, image_size: int, device, progress: bool = True):
        """
        Generate samples from the model.

        Args:
            batch_size: Number of images to generate.
            image_size: Height and width of the generated images.
            device: Device on which sampling runs.
            progress: Wrap the timestep loop in a tqdm progress bar.

        Returns:
            Tensor [batch_size, C, image_size, image_size], where C is the
            model's `in_channels` attribute (3 if the model does not define one).
        """
        # Match the channel count the wrapped model was built for instead of
        # assuming RGB.
        channels = getattr(self.model, 'in_channels', 3)

        # Start from pure noise
        x = torch.randn(batch_size, channels, image_size, image_size, device=device)

        # Denoise progressively
        timesteps = list(range(self.config.timesteps))[::-1]
        if progress:
            timesteps = tqdm(timesteps, desc="Sampling")

        on_xla = x.device.type == "xla"
        for t in timesteps:
            t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
            x = self.p_sample(x, t_batch)
            if on_xla:
                # Cut the lazily traced XLA graph after every reverse step so the
                # whole sampling chain is not compiled as one enormous graph.
                xm.mark_step()

        return x

    def forward(self, x_start):
        """
        Training loss (simplified MSE loss).

        Draws one uniform random timestep and one noise tensor per sample,
        noises `x_start` accordingly, and returns the scalar MSE between the
        model's prediction and that noise.
        """
        batch_size = x_start.shape[0]
        device = x_start.device

        # Sample random timesteps
        t = torch.randint(0, self.config.timesteps, (batch_size,), device=device).long()

        # Sample noise
        noise = torch.randn_like(x_start)

        # Add noise to images
        x_noisy = self.q_sample(x_start, t, noise)

        # Predict noise
        pred_noise = self.model(x_noisy, t)

        # Calculate loss
        loss = F.mse_loss(pred_noise, noise)

        return loss

print("Diffusion process defined.")

## 5. EMA (Exponential Moving Average)

In [None]:
class EMA:
    """
    Exponential Moving Average of model parameters.

    `shadow` holds the averaged copy of every trainable parameter, keyed by
    parameter name. `apply_shadow` / `restore` temporarily swap the averaged
    values into the model (e.g. for sampling) and back.
    """
    def __init__(self, model, decay=0.9999):
        """
        Args:
            model: Module whose trainable parameters are tracked.
            decay: Weight kept from the previous average on each update.
        """
        self.decay = decay
        self.original = {}

        # Initialize shadow parameters as independent copies of the current weights.
        self.shadow = {
            name: param.data.clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def update(self, model):
        """Update EMA parameters: shadow <- shadow - (1 - decay) * (shadow - param)."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] -= (1 - self.decay) * (self.shadow[name] - param.data)

    def apply_shadow(self, model):
        """Apply EMA parameters to model, backing up the current weights first.

        The averaged values are copied into the parameters rather than
        assigned, so later in-place changes to the model can never modify the
        stored average. If a backup already exists (a previous `apply_shadow`
        was not followed by `restore`), it is kept so the real training weights
        are not overwritten by EMA values.
        """
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name not in self.original:
                    self.original[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def restore(self, model):
        """Restore original parameters saved by `apply_shadow` and drop the backup.

        Raises:
            KeyError: If there is no backup for a trainable parameter, i.e.
                `apply_shadow` was not called first.
        """
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.original[name])
        self.original = {}

    def state_dict(self):
        """Return the averaged tensors and the decay for checkpointing."""
        return {'shadow': self.shadow, 'decay': self.decay}

    def load_state_dict(self, state_dict):
        """Load averaged tensors and decay from `state_dict`.

        Tensors whose name is already tracked are moved to the device of the
        tensor they replace, so a checkpoint stored on the CPU can be loaded
        while the model lives on an accelerator without breaking `update`.
        """
        loaded = {}
        for name, tensor in state_dict['shadow'].items():
            current = self.shadow.get(name)
            loaded[name] = tensor if current is None else tensor.to(current.device)
        self.shadow = loaded
        self.decay = state_dict['decay']

print("EMA class defined.")

## 6. Data Loading (TPU Optimized)

In [None]:
# Preprocessing: random horizontal flip, tensor conversion, then rescale to [-1, 1].
transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split (downloaded into config.data_dir if missing).
train_dataset = CIFAR10(config.data_dir, train=True, transform=transform, download=True)

# TPU: a regular DataLoader; the training loop wraps it in a ParallelLoader.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config.training.batch_size,
    shuffle=True,
    num_workers=config.training.num_workers,
    drop_last=True,
)

print("\n".join([
    "Dataset: CIFAR-10",
    f"Training samples: {len(train_dataset):,}",
    f"Batch size: {config.training.batch_size} (TPU optimized)",
    f"Batches per epoch: {len(train_loader)}",
]))

## 7. Training (TPU Optimized)

In [None]:
# Build the U-Net and its diffusion wrapper, placing both on the TPU device.
model = create_model(config.model).to(device)
diffusion = GaussianDiffusion(model, config.diffusion).to(device)

# Report the total parameter count.
num_params = sum(weight.numel() for weight in model.parameters())
print(f"Model parameters: {num_params:,}")

# EMA copy of the (already device-resident) parameters.
ema = EMA(model, decay=config.training.ema_decay)

# Optimizer
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=config.training.learning_rate,
    betas=(0.9, 0.999),
    weight_decay=0.0,
)

# Learning rate scheduler with warmup
def lr_lambda(step):
    """Return the LR multiplier: a linear ramp during warmup, then 1.0."""
    warmup = config.training.warmup_steps
    return step / warmup if step < warmup else 1.0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

print("Model initialized on TPU and ready for training!")

In [None]:
# Visualization utilities
def unnormalize(x):
    """Map values from the [-1, 1] training range to [0, 1]."""
    shifted = x + 1
    return shifted / 2

def plot_samples(samples, title="Generated Samples"):
    """Show `samples` as an image grid, 8 images per row.

    `samples` is expected in the [-1, 1] range; values are rescaled to [0, 1]
    and clamped before display.
    """
    images = torch.clamp(unnormalize(samples), 0, 1)

    # Tile the batch, then convert to a channels-last NumPy array for imshow.
    tiled = torchvision.utils.make_grid(images, nrow=8, padding=2)
    tiled_np = tiled.cpu().permute(1, 2, 0).numpy()

    plt.figure(figsize=(12, 12))
    plt.imshow(tiled_np)
    plt.axis('off')
    plt.title(title)
    plt.tight_layout()
    plt.show()

def save_checkpoint(step, model, optimizer, ema, scheduler, metrics, path):
    """Write a training checkpoint to `path` via `xm.save` (TPU compatible).

    The checkpoint holds the step, model weights, optimizer and scheduler
    state, the EMA shadow weights with their decay, and `metrics`.
    """
    def to_cpu(tensors):
        # Copy every tensor of a name -> tensor mapping to host memory.
        return {key: value.cpu() for key, value in tensors.items()}

    payload = {
        'step': step,
        'model': to_cpu(model.state_dict()),
        'optimizer': optimizer.state_dict(),
        'ema': {'shadow': to_cpu(ema.shadow), 'decay': ema.decay},
        'scheduler': scheduler.state_dict(),
        'metrics': metrics,
    }
    xm.save(payload, path)
    print(f"Checkpoint saved: {path}")

print("Utilities ready.")

In [None]:
# Training loop (TPU optimized)
print("Starting TPU training...")
print(f"Total steps: {config.training.total_steps:,}")
print(f"Batch size: {config.training.batch_size} (TPU optimized)")
print(f"Learning rate: {config.training.learning_rate}")
print("=" * 80)

# Setup directories
checkpoint_dir = Path(config.output_dir) / "checkpoints"
sample_dir = Path(config.output_dir) / "samples"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
sample_dir.mkdir(parents=True, exist_ok=True)

# Training state
step = 0
# Starts as a Python float and becomes an on-device tensor after the first
# step; it is read back to the host only when a log line is produced.
running_loss = 0.0
losses = []
start_time = datetime.now()

# Progress bar
pbar = tqdm(total=config.training.total_steps, desc="Training (TPU)")

# Training loop
model.train()
epoch = 0

while step < config.training.total_steps:
    epoch += 1

    # TPU: Wrap DataLoader with ParallelLoader for efficient data transfer
    para_loader = pl.ParallelLoader(train_loader, [device])

    for batch in para_loader.per_device_loader(device):
        if step >= config.training.total_steps:
            break

        images = batch[0]  # Already on TPU device

        # Forward / backward pass
        optimizer.zero_grad()
        loss = diffusion(images)
        loss.backward()

        # TPU: reduce gradients across cores once, then clip the reduced
        # gradients. A plain optimizer.step() follows because
        # xm.optimizer_step() would reduce the gradients a second time.
        xm.reduce_gradients(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), config.training.grad_clip)
        optimizer.step()

        scheduler.step()
        ema.update(model)

        step += 1
        # Accumulate on the device: calling loss.item() here would force a
        # host/device synchronisation on every single step.
        running_loss = running_loss + loss.detach()

        # Logging
        if step % config.training.log_every == 0:
            avg_loss = float(running_loss) / config.training.log_every
            losses.append(avg_loss)
            lr = scheduler.get_last_lr()[0]

            elapsed = (datetime.now() - start_time).total_seconds()
            steps_per_sec = step / elapsed
            remaining_steps = config.training.total_steps - step
            eta_seconds = remaining_steps / steps_per_sec
            eta_hours = eta_seconds / 3600

            pbar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'lr': f'{lr:.6f}',
                'ETA': f'{eta_hours:.1f}h',
                'step/s': f'{steps_per_sec:.2f}'
            })

            running_loss = 0.0

        # Generate samples
        if step % config.training.sample_every == 0 or step == config.training.total_steps:
            print(f"\n[Step {step}/{config.training.total_steps}] Generating samples...")

            # Use EMA model for sampling. The try/finally guarantees the
            # training weights are put back even if sampling is interrupted.
            ema.apply_shadow(model)
            model.eval()
            try:
                with torch.no_grad():
                    samples = diffusion.sample(
                        batch_size=64,
                        image_size=config.model.image_size,
                        device=device,
                        progress=False
                    )

                # Move samples to CPU for saving
                samples_cpu = samples.cpu()
            finally:
                ema.restore(model)
                model.train()

            # Save samples
            samples_unnorm = unnormalize(samples_cpu)
            samples_unnorm = torch.clamp(samples_unnorm, 0, 1)
            grid = torchvision.utils.make_grid(samples_unnorm, nrow=8, padding=2)
            save_path = sample_dir / f"samples_step_{step:08d}.png"
            torchvision.utils.save_image(grid, save_path)

            # Display samples
            clear_output(wait=True)
            plot_samples(samples_cpu, title=f"Samples at step {step}")

        # Save checkpoint
        if step % config.training.save_every == 0 or step == config.training.total_steps:
            checkpoint_path = checkpoint_dir / f"checkpoint_step_{step:08d}.pt"
            latest_path = checkpoint_dir / "checkpoint_latest.pt"
            step_metrics = {'loss': loss.item()}

            # Write the numbered checkpoint, then the same state as "latest".
            for target_path in (checkpoint_path, latest_path):
                save_checkpoint(
                    step=step,
                    model=model,
                    optimizer=optimizer,
                    ema=ema,
                    scheduler=scheduler,
                    metrics=step_metrics,
                    path=str(target_path)
                )

        pbar.update(1)

pbar.close()

# Training complete
total_time = (datetime.now() - start_time).total_seconds() / 3600
print("\n" + "=" * 80)
print("Training complete!")
print(f"Total time: {total_time:.2f} hours")
print(f"Final loss: {loss.item():.4f}")
print(f"Checkpoints saved to: {checkpoint_dir}")
print(f"Samples saved to: {sample_dir}")
print("=" * 80)

## 8. Results and Visualization

In [None]:
# Plot the logged average losses (one point per logging interval) and save the figure.
plt.figure(figsize=(12, 4))
plt.plot(losses)
plt.title('Training Loss')
plt.xlabel('Step (x100)')
plt.ylabel('Loss')
plt.grid(True)
plt.tight_layout()
plt.savefig(Path(config.output_dir) / 'training_loss.png', dpi=150)
plt.show()

In [None]:
# Generate final samples
print("Generating final samples with EMA model...")

ema.apply_shadow(model)
model.eval()

# try/finally: if sampling is interrupted, the training weights are still
# restored, so the model is never left holding the EMA weights.
try:
    with torch.no_grad():
        final_samples = diffusion.sample(
            batch_size=64,
            image_size=config.model.image_size,
            device=device,
            progress=True
        )

    final_samples_cpu = final_samples.cpu()
finally:
    ema.restore(model)

plot_samples(final_samples_cpu, title="Final Generated Samples")

# Save final samples
final_samples_unnorm = unnormalize(final_samples_cpu)
final_samples_unnorm = torch.clamp(final_samples_unnorm, 0, 1)
grid = torchvision.utils.make_grid(final_samples_unnorm, nrow=8, padding=2)
torchvision.utils.save_image(grid, Path(config.output_dir) / 'final_samples.png')

In [None]:
# Collect the key facts about this run and persist them as JSON.
summary = {
    'exp_name': config.exp_name,
    'total_steps': step,
    'total_hours': total_time,
    'timesteps': config.diffusion.timesteps,
    'beta_schedule': config.diffusion.beta_schedule,
    'batch_size': config.training.batch_size,
    'learning_rate': config.training.learning_rate,
    'final_loss': loss.item(),
    'num_parameters': num_params,
    'device': 'TPU',
}

summary_path = Path(config.output_dir) / 'training_summary.json'
with open(summary_path, 'w') as f:
    f.write(json.dumps(summary, indent=2))

print("\nTraining Summary:")
print(json.dumps(summary, indent=2))
print(f"\nSummary saved to: {summary_path}")

## 9. Inference - Generate More Samples

In [None]:
# Generate additional samples
print("Generating additional samples...")

ema.apply_shadow(model)
model.eval()

num_batches = 4  # Generate 4 batches of 64 samples each

# try/finally: the training weights are restored even if a batch is
# interrupted, so the model is never left holding the EMA weights.
try:
    for i in range(num_batches):
        with torch.no_grad():
            samples = diffusion.sample(
                batch_size=64,
                image_size=config.model.image_size,
                device=device,
                progress=True
            )

        samples_cpu = samples.cpu()
        plot_samples(samples_cpu, title=f"Generated Samples Batch {i+1}")
finally:
    ema.restore(model)

print(f"Generated {num_batches * 64} total samples!")

## Notes

### TPU Performance Expectations

- **Training speed**: TPUはGPU (T4)より大幅に高速
- **Batch size**: 512 (TPU最適化)
- **Learning rate**: 4e-4 (バッチサイズに合わせてスケール)

### TPU Optimization Tips

1. **Dropout**: 0に設定（TPUでは高速化のため）
2. **Batch size**: 大きいほど効率的（512推奨）
3. **ParallelLoader**: データ転送を最適化
4. **xm.optimizer_step()**: TPU用の勾配同期

### Resume Training

接続が切れた場合、以下のコードで再開:

```python
checkpoint = torch.load('./outputs/checkpoints/checkpoint_latest.pt')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
ema.load_state_dict(checkpoint['ema'])
scheduler.load_state_dict(checkpoint['scheduler'])
step = checkpoint['step']
# Move to TPU
model = model.to(device)
```

### References

- Paper: [Denoising Diffusion Probabilistic Models (Ho et al., 2020)](https://arxiv.org/abs/2006.11239)
- Original results: FID 3.17 on CIFAR-10 (800,000 steps)
