# Lyapunov-Adaptive Predictive Coding Diffusion

Unifies diffusion sampling with predictive coding and a Lyapunov-guided scheduler.

## Notebook Roadmap

- Review configuration and data pipeline tuned for CIFAR-10.
- Define the EDM-style predictive coding UNet backbone.
- Formulate Lyapunov energy and the adaptive scheduler with optional policy head.
- Implement the predictive-coding Heun sampler and training loop with AMP + EMA.
- Provide hooks for sampling, checkpointing, and CleanFID evaluation.

In [1]:
import math
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch_ema import ExponentialMovingAverage
from tqdm.auto import tqdm
from PIL import Image

try:
    from cleanfid import fid as cleanfid
except ImportError:
    cleanfid = None

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Running on {device}")

Running on cuda


In [2]:
@dataclass
class TrainingConfig:
    """Hyper-parameters for data loading, optimization, EDM noise levels and logging."""

    # --- data pipeline ---
    data_root: str = "./data"
    batch_size: int = 256
    num_workers: int = 8
    # --- optimization ---
    epochs: int = 400
    lr: float = 0.0002
    ema_decay: float = 0.9999
    # --- EDM noise parameterization / sampling schedule ---
    sigma_data: float = 0.5
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    steps: int = 80
    rho: float = 7.0
    # --- training-loop behaviour ---
    grad_clip: float = 1.0
    amp: bool = True
    log_interval: int = 50
    sample_interval: int = 5000
    # --- outputs ---
    out_dir: str = "runs/lyapunov_pc"

# Instantiate the default configuration and make sure its output folders exist.
config = TrainingConfig()
os.makedirs("runs", exist_ok=True)
os.makedirs(config.out_dir, exist_ok=True)
print(config)

TrainingConfig(data_root='./data', batch_size=256, num_workers=8, epochs=400, lr=0.0002, ema_decay=0.9999, sigma_data=0.5, sigma_min=0.002, sigma_max=80.0, steps=80, rho=7.0, grad_clip=1.0, amp=True, log_interval=50, sample_interval=5000, out_dir='runs/lyapunov_pc')


In [3]:
def build_dataloader(cfg: TrainingConfig) -> DataLoader:
    """Build the shuffled CIFAR-10 training loader described by ``cfg``.

    Images are randomly flipped horizontally, converted to tensors and
    normalized with mean/std 0.5 per channel (mapping [0, 1] to [-1, 1]).
    The dataset is downloaded into ``cfg.data_root`` if needed, and the last
    incomplete batch is dropped.
    """
    channel_stats = (0.5, 0.5, 0.5)
    pipeline = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(channel_stats, channel_stats),
        ]
    )
    train_set = datasets.CIFAR10(root=cfg.data_root, train=True, download=True, transform=pipeline)
    return DataLoader(
        train_set,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=True,
    )

train_loader = build_dataloader(config)
print("Training batches: {}".format(len(train_loader)))

Training batches: 195




In [4]:
class FourierFeatures(nn.Module):
    """Sinusoidal embedding of a per-sample scalar (here the noise level sigma).

    ``embedding_dim // 2`` frequencies are spaced log-uniformly from 1 to 1000
    and stored as the ``frequencies`` buffer. The output concatenates
    ``sin`` and ``cos`` of ``sigma * frequency * scale`` along the last axis.
    """

    def __init__(self, embedding_dim: int, scale: float = 1.0):
        super().__init__()
        n_freq = embedding_dim // 2
        log_freqs = torch.linspace(math.log(1.0), math.log(1000.0), n_freq)
        self.register_buffer("frequencies", log_freqs.exp())
        self.scale = scale

    def forward(self, sigma: torch.Tensor) -> torch.Tensor:
        # One row per sample, one column per frequency.
        phase = sigma.view(-1, 1) * self.frequencies.unsqueeze(0) * self.scale
        return torch.cat((phase.sin(), phase.cos()), dim=-1)

def zero_module(module: nn.Module) -> nn.Module:
    """Set every parameter of ``module`` to zero in place and return the same module."""
    with torch.no_grad():
        for weight in module.parameters():
            weight.zero_()
    return module

class ResidualBlock(nn.Module):
    """Pre-activation residual block conditioned on an embedding vector.

    Path: GroupNorm -> SiLU -> 3x3 conv, add a per-channel bias projected from
    ``temb``, then GroupNorm -> SiLU -> dropout -> 3x3 conv. The skip path is a
    1x1 conv when the channel count changes and the identity otherwise. With
    ``skip_rescale`` the sum is divided by sqrt(2).
    """

    def __init__(self, in_channels: int, out_channels: int, emb_channels: int, dropout: float = 0.0, skip_rescale: bool = True):
        super().__init__()
        self.skip_rescale = skip_rescale
        self.norm1 = nn.GroupNorm(num_groups=min(32, in_channels), num_channels=in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.emb_proj = nn.Linear(emb_channels, out_channels)
        self.norm2 = nn.GroupNorm(num_groups=min(32, out_channels), num_channels=out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.skip = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        out = self.conv1(F.silu(self.norm1(x)))
        # Broadcast the (B, C) embedding bias over the spatial dimensions.
        bias = self.emb_proj(F.silu(temb))
        out = out + bias.unsqueeze(-1).unsqueeze(-1)
        out = F.silu(self.norm2(out))
        out = self.conv2(self.dropout(out))
        merged = self.skip(x) + out
        return merged / math.sqrt(2.0) if self.skip_rescale else merged

class AttentionBlock(nn.Module):
    """Multi-head self-attention over all spatial positions, added residually to the input.

    Args:
        channels: Input/output channel count; must be a multiple of ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``channels``
            (previously this only surfaced later as an opaque reshape error in ``forward``).
    """

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        if num_heads <= 0 or channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be a positive multiple of num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(num_groups=min(32, channels), num_channels=channels)
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        head_dim = c // self.num_heads
        # The 3*C qkv channels are laid out as [q | k | v], each split into heads of head_dim.
        qkv = self.qkv(self.norm(x)).reshape(b, 3, self.num_heads, head_dim, h * w)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
        # Scaled dot-product scores between every pair of spatial positions i, j.
        attn = torch.einsum("bhci,bhcj->bhij", q, k) * (head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        out = torch.einsum("bhij,bhcj->bhci", attn, v).reshape(b, c, h, w)
        return x + self.proj(out)

class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution; channels are unchanged."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        downsampled = self.conv(x)
        return downsampled

class Upsample(nn.Module):
    """Double the spatial resolution (nearest-neighbour) and refine with a 3x3 convolution."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(F.interpolate(x, scale_factor=2, mode="nearest"))

In [5]:
class EDMUNet(nn.Module):
    """Convolutional UNet used as the raw network F inside the EDM preconditioning.

    Layout built in ``__init__``:
      * encoder: for each entry of ``channel_mult``, ``num_res_blocks`` ResidualBlocks.
        Each is followed by an AttentionBlock when the current resolution is listed
        in ``attn_resolutions``. A stride-2 Downsample follows every level except
        the last. Exactly one skip tensor is kept per level, taken before the
        downsample.
      * middle: ResidualBlock -> AttentionBlock -> ResidualBlock.
      * decoder: levels in reverse order, ``num_res_blocks + 1`` ResidualBlocks each.
        The first of these consumes the channel-wise concatenation with that
        level's skip tensor. A nearest-neighbour Upsample follows every level
        except level 0.
      * output: GroupNorm -> SiLU -> zero-initialized 3x3 conv, so an untrained
        network outputs zeros.

    ``img_resolution`` determines where attention is inserted.
    ``self.img_resolution`` and ``self.sigma_data`` are stored but not read by ``forward``.
    NOTE: module construction order below fixes parameter initialization and
    state_dict layout; reordering it changes both.
    """
    def __init__(self, img_resolution: int = 32, in_channels: int = 3, out_channels: int = 3, model_channels: int = 192,
                 channel_mult: Tuple[int, ...] = (1, 2, 2), num_res_blocks: int = 3, attn_resolutions: Tuple[int, ...] = (16,),
                 dropout: float = 0.0, sigma_data: float = 0.5):
        super().__init__()
        self.img_resolution = img_resolution
        self.sigma_data = sigma_data
        self.in_conv = nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=1)
        # Noise-level embedding: Fourier features of sigma -> 2-layer MLP of width 4 * model_channels.
        embed_dim = model_channels
        self.fourier = FourierFeatures(embed_dim)
        emb_channels = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(embed_dim, emb_channels),
            nn.SiLU(),
            nn.Linear(emb_channels, emb_channels)
        )
        # ---- encoder ----
        self.down_blocks = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        # Channel count of the single skip tensor saved at each level (read by the decoder).
        self.skip_channels: List[int] = []
        ch = model_channels
        resolution = img_resolution
        for level, mult in enumerate(channel_mult):
            out_ch = model_channels * mult
            block = nn.ModuleList()
            for _ in range(num_res_blocks):
                block.append(ResidualBlock(ch, out_ch, emb_channels, dropout))
                ch = out_ch
                if resolution in attn_resolutions:
                    block.append(AttentionBlock(ch))
            self.down_blocks.append(block)
            self.skip_channels.append(ch)
            # The last level keeps its resolution (Identity) so downsamplers aligns 1:1 with down_blocks.
            if level < len(channel_mult) - 1:
                self.downsamplers.append(Downsample(ch))
                resolution //= 2
            else:
                self.downsamplers.append(nn.Identity())
        # ---- bottleneck ----
        self.mid_block = nn.ModuleList([
            ResidualBlock(ch, ch, emb_channels, dropout),
            AttentionBlock(ch),
            ResidualBlock(ch, ch, emb_channels, dropout)
        ])
        # ---- decoder (levels visited deepest-first) ----
        self.up_blocks = nn.ModuleList()
        self.upsamplers = nn.ModuleList()
        for level, mult in reversed(list(enumerate(channel_mult))):
            out_ch = model_channels * mult
            block = nn.ModuleList()
            for idx in range(num_res_blocks + 1):
                # Only the first block of a level sees the concatenated skip tensor.
                skip_ch = self.skip_channels[level] if idx == 0 else 0
                block.append(ResidualBlock(ch + skip_ch, out_ch, emb_channels, dropout))
                ch = out_ch
                # Resolution of this level, computed directly rather than tracked as in the encoder.
                if (img_resolution // (2 ** level)) in attn_resolutions:
                    block.append(AttentionBlock(ch))
            self.up_blocks.append(block)
            if level > 0:
                self.upsamplers.append(Upsample(ch))
            else:
                self.upsamplers.append(nn.Identity())
        self.out_norm = nn.GroupNorm(num_groups=min(32, ch), num_channels=ch)
        # Zero init: the network starts by predicting zeros.
        self.out_conv = zero_module(nn.Conv2d(ch, out_channels, kernel_size=3, padding=1))

    def forward(self, x: torch.Tensor, c_noise: torch.Tensor) -> torch.Tensor:
        """Apply the UNet.

        Args:
            x: Input of shape (B, in_channels, H, W). In this file it is called with
                ``c_in * x_sigma``. H and W are assumed divisible by
                ``2 ** (len(channel_mult) - 1)`` so upsampled features line up with
                their skip tensors (TODO confirm for non-default configs).
            c_noise: Noise conditioning. ``exp(4 * c_noise)`` inverts the
                ``c_noise = log(sigma) / 4`` produced by ``edm_preconditioning``,
                so the Fourier embedding sees sigma itself.

        Returns:
            Tensor of shape (B, out_channels, H, W).
        """
        sigma = torch.exp(4.0 * c_noise)
        temb = self.time_embed(self.fourier(sigma))
        h = self.in_conv(x)
        residuals: List[torch.Tensor] = []
        for block, downsample in zip(self.down_blocks, self.downsamplers):
            for layer in block:
                # Residual blocks need the noise embedding; attention blocks take only h.
                if isinstance(layer, ResidualBlock):
                    h = layer(h, temb)
                else:
                    h = layer(h)
            # One skip per level, saved before downsampling.
            residuals.append(h)
            h = downsample(h)
        for layer in self.mid_block:
            if isinstance(layer, ResidualBlock):
                h = layer(h, temb)
            else:
                h = layer(h)
        for block, upsample in zip(self.up_blocks, self.upsamplers):
            # LIFO: the deepest level's skip is consumed first.
            skip = residuals.pop()
            h = torch.cat([h, skip], dim=1)
            for layer in block:
                if isinstance(layer, ResidualBlock):
                    h = layer(h, temb)
                else:
                    h = layer(h)
            h = upsample(h)
        h = self.out_norm(h)
        h = F.silu(h)
        return self.out_conv(h)

# Build the backbone on the active device and track an EMA copy of its weights.
model = EDMUNet().to(device)
ema = ExponentialMovingAverage(model.parameters(), decay=config.ema_decay)
print("Parameters: {:,}".format(sum(param.numel() for param in model.parameters())))

Parameters: 66,148,227


In [6]:
def get_karras_schedule(steps: int, sigma_min: float, sigma_max: float, rho: float, device: torch.device) -> torch.Tensor:
    """Return Karras et al. rho-spaced noise levels followed by a terminal 0.

    ``steps`` levels are interpolated linearly in ``sigma ** (1 / rho)`` space
    from ``sigma_max`` down to ``sigma_min``. A trailing ``0.0`` is appended,
    so the result has ``steps + 1`` entries.
    """
    inv_rho = 1 / rho
    top = sigma_max ** inv_rho
    bottom = sigma_min ** inv_rho
    t = torch.linspace(0, 1, steps, device=device)
    levels = (top + t * (bottom - top)) ** rho
    return torch.cat([levels, levels.new_zeros(1)])

def karras_sigma_distribution(batch_size: int, device: torch.device, p_mean: float = -1.2, p_std: float = 1.2) -> torch.Tensor:
    """Sample ``batch_size`` training noise levels with ``log(sigma) ~ N(p_mean, p_std ** 2)``."""
    log_sigma = torch.randn(batch_size, device=device).mul(p_std).add(p_mean)
    return log_sigma.exp()

def edm_preconditioning(sigma: torch.Tensor, sigma_data: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the EDM preconditioning coefficients for each noise level.

    Returns ``(c_in, c_out, c_skip, c_noise)``. The first three have shape
    (B, 1, 1, 1) for broadcasting against images; ``c_noise`` is
    ``log(sigma) / 4`` with shape (B, 1).
    """
    s = sigma.view(-1, 1, 1, 1)
    total_var = s.square() + sigma_data ** 2
    total_std = torch.sqrt(total_var)
    c_skip = sigma_data ** 2 / total_var
    c_out = s * sigma_data / total_std
    c_in = 1.0 / total_std
    c_noise = torch.log(s.view(-1, 1)) / 4.0
    return c_in, c_out, c_skip, c_noise

def predictive_coding_corrector(model: nn.Module, x_sigma: torch.Tensor, sigma: torch.Tensor, sigma_data: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the preconditioned network on ``x_sigma`` at per-sample noise levels ``sigma``.

    Returns:
        corrected: the denoised estimate ``c_skip * x_sigma + c_out * F(c_in * x_sigma, c_noise)``.
        residual: the raw network output minus ``x_sigma``. NOTE(review): this mixes the
            network's scaled output space with the unscaled input; it is only used as a
            logged metric in this file — confirm the intended definition.
        prediction: the raw network output ``F``.
    """
    sigma_local = sigma.to(x_sigma.device, x_sigma.dtype)
    c_in, c_out, c_skip, c_noise = edm_preconditioning(sigma_local, sigma_data)
    raw = model(c_in * x_sigma, c_noise)
    denoised = c_skip.view(-1, 1, 1, 1) * x_sigma + c_out.view(-1, 1, 1, 1) * raw
    return denoised, raw - x_sigma, raw

@torch.no_grad()
def edm_wrapper(model: nn.Module, x: torch.Tensor, sigma: torch.Tensor, sigma_data: float) -> torch.Tensor:
    """Return only the denoised estimate from ``predictive_coding_corrector``, without autograd."""
    return predictive_coding_corrector(model, x, sigma, sigma_data)[0]

sigma_schedule = get_karras_schedule(
    steps=config.steps,
    sigma_min=config.sigma_min,
    sigma_max=config.sigma_max,
    rho=config.rho,
    device=device,
)
print(f"Schedule head: {sigma_schedule[:5]}")

Schedule head: tensor([80.0000, 74.6325, 69.5766, 64.8172, 60.3397], device='cuda:0')


In [7]:
# AdamW without weight decay; the gradient scaler is only active for AMP on CUDA.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=0.0, betas=(0.9, 0.999))
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda" and config.amp)

def predictive_coding_flow_loss(model: nn.Module, x0: torch.Tensor, cfg: TrainingConfig) -> Tuple[torch.Tensor, dict]:
    """Denoising loss: corrupt ``x0`` at random noise levels and regress the denoised estimate onto it.

    Returns the mean-squared error and a dict with ``sigma_mean`` (mean sampled
    noise level) and ``residual_rms`` (RMS of the corrector's residual).
    NOTE(review): the MSE is not weighted per noise level; confirm this is intended.
    """
    sigma = karras_sigma_distribution(x0.shape[0], x0.device)
    eps = torch.randn_like(x0)
    x_noisy = x0 + sigma.view(-1, 1, 1, 1) * eps
    denoised, residual, _ = predictive_coding_corrector(model, x_noisy, sigma, cfg.sigma_data)
    loss = F.mse_loss(denoised, x0)
    stats = dict(
        sigma_mean=float(sigma.mean().item()),
        residual_rms=float(residual.pow(2).mean().sqrt().item()),
    )
    return loss, stats

  # NOTE(review): this repeats the GradScaler construction from the start of this cell, and its
  # two-space indent matches no enclosing block; it looks like a notebook-export artifact — confirm and remove.
  scaler = torch.cuda.amp.GradScaler(enabled=config.amp and device.type == "cuda")


In [8]:
class StepController(nn.Module):
    """Small MLP mapping 3 scheduler features per row to one adjustment in (-1, 1).

    Input has shape (N, 3); output has shape (N, 1) and is squashed by tanh.
    """

    def __init__(self, hidden_dim: int = 64):
        super().__init__()
        layers = [nn.Linear(3, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, 1)]
        self.net = nn.Sequential(*layers)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        logits = self.net(features)
        return logits.tanh()

class LyapunovAdaptiveScheduler:
    """Adapt step sizes along a base noise schedule using a Lyapunov-style energy.

    At each step the energy ``0.5 * mean((denoised / sigma_data) ** 2) + log(sigma)``
    of the current denoised estimate is compared with the previous step's energy.
    The step from the current level toward the next base level is then scaled
    either by ``1 + alpha * drop`` or, when a controller is supplied, by
    ``1 + controller(features)``. The scale is clipped to ``[min_scale, max_scale]``.

    The scheduler tracks the noise level the sampler actually reached. Each step
    therefore starts where the previous (possibly enlarged or shortened) step
    ended, rather than snapping back to the base grid. The final step always
    lands on the terminal level ``base_sigmas[-1]``, so no residual noise is left.

    Args:
        base_sigmas: 1-D, non-increasing noise schedule ending at the terminal level.
        sigma_data: Data scale used to normalize the energy.
        alpha: Gain applied to the energy drop when no controller is given.
        min_scale: Lower clip for the step scale.
        max_scale: Upper clip for the step scale.
        controller: Optional module mapping a (1, 3) feature row to an adjustment.
    """

    def __init__(self, base_sigmas: torch.Tensor, sigma_data: float, alpha: float = 0.3,
                 min_scale: float = 0.5, max_scale: float = 1.6, controller: Optional["StepController"] = None):
        self.base_sigmas = base_sigmas
        self.sigma_data = sigma_data
        self.alpha = alpha
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.device = base_sigmas.device
        self.dtype = base_sigmas.dtype
        self.controller = controller.to(self.device) if controller is not None else None
        self.reset()

    def reset(self) -> None:
        """Rewind to the start of the base schedule."""
        self.index = 0
        self.prev_energy: Optional[torch.Tensor] = None
        # Noise level reached by the last emitted step; None means "start of the schedule".
        self._sigma_state: Optional[torch.Tensor] = None

    def energy(self, x: torch.Tensor, sigma_value: torch.Tensor) -> torch.Tensor:
        """Per-sample energy ``0.5 * mean((x / sigma_data) ** 2) + log(sigma)`` for a 4-D batch ``x``."""
        sigma_batch = sigma_value.view(1).to(x.device, x.dtype).expand(x.shape[0])
        recon_term = (x / self.sigma_data).pow(2).mean(dim=(1, 2, 3))
        return 0.5 * recon_term + torch.log(sigma_batch + 1e-8)

    def scale_factor(self, drop: torch.Tensor, features: Optional[torch.Tensor]) -> float:
        """Return the clipped multiplier applied to the base step length."""
        if self.controller is not None and features is not None:
            features = features.to(self.device, dtype=self.dtype)
            raw = self.controller(features).mean()
            return float(torch.clamp(1.0 + raw, self.min_scale, self.max_scale))
        return float(torch.clamp(1.0 + self.alpha * drop.mean(), self.min_scale, self.max_scale))

    def current_sigma(self) -> torch.Tensor:
        """Return the noise level the next call to ``step`` starts from."""
        if self._sigma_state is not None:
            return self._sigma_state
        return self.base_sigmas[min(self.index, len(self.base_sigmas) - 1)]

    def step(self, denoised: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Advance one step given the denoised estimate at ``current_sigma()``.

        Returns:
            ``(sigma_curr, sigma_next)`` as 0-d tensors. ``sigma_curr`` equals
            ``current_sigma()`` before the call. Once the schedule is exhausted,
            both refer to the terminal level.
        """
        sigma_curr = self.current_sigma()
        last = len(self.base_sigmas) - 1
        if self.index >= last:
            return sigma_curr, self.base_sigmas[-1]
        energy_curr = self.energy(denoised, sigma_curr)
        if self.prev_energy is None:
            drop = torch.zeros_like(energy_curr)
        else:
            drop = self.prev_energy - energy_curr
        if self.index + 1 == last:
            # Final step: always reach the terminal level; a scale < 1 would otherwise leave noise behind.
            sigma_next = self.base_sigmas[-1]
        else:
            base_next = self.base_sigmas[self.index + 1]
            features = None
            if self.controller is not None:
                features = torch.stack([
                    energy_curr.mean(),
                    drop.mean(),
                    torch.log(sigma_curr.clamp_min(1e-8))
                ], dim=0).view(1, -1)
            scale = self.scale_factor(drop, features)
            sigma_next = sigma_curr + (base_next - sigma_curr) * scale
            sigma_next = torch.clamp(sigma_next, min=float(self.base_sigmas[-1].item()), max=float(self.base_sigmas[0].item()))
            # Never move upward in noise (possible if an earlier step overshot below base_next).
            sigma_next = torch.minimum(sigma_next, sigma_curr)
        self.prev_energy = energy_curr.detach()
        self._sigma_state = sigma_next
        self.index += 1
        return sigma_curr, sigma_next

controller = StepController()
scheduler = LyapunovAdaptiveScheduler(base_sigmas=sigma_schedule, sigma_data=config.sigma_data, controller=controller)


In [9]:
@torch.no_grad()
def predictive_coding_heun(model: nn.Module, scheduler: LyapunovAdaptiveScheduler, shape: Tuple[int, int, int, int],
                           sigma_data: float, disable_tqdm: bool = False) -> torch.Tensor:
    """Sample by integrating the probability-flow ODE with Heun's method along an adaptive schedule.

    Starts from Gaussian noise scaled by ``scheduler.base_sigmas[0]``. At every step:
      1. Denoise at ``scheduler.current_sigma()``.
      2. Let the scheduler choose the next level.
      3. Take an Euler step, then apply the Heun correction.
    The correction is skipped when the next level is zero.

    Each step costs two network evaluations (one for the final step). The
    previous version recomputed the first evaluation a second time with
    identical inputs.

    Returns:
        The final sample tensor of ``shape``.
    """
    scheduler.reset()
    sigmas = scheduler.base_sigmas
    batch = shape[0]
    x = torch.randn(shape, device=device) * sigmas[0]
    dtype = x.dtype

    def per_sample(value: float) -> torch.Tensor:
        # The corrector expects one noise level per batch element.
        return torch.full((batch,), value, device=device, dtype=dtype)

    iterator = range(len(sigmas) - 1)
    if not disable_tqdm:
        iterator = tqdm(iterator, desc="Lyapunov PC Sampling", leave=False)
    for _ in iterator:
        sigma_eval = float(scheduler.current_sigma().item())
        corrected, _, _ = predictive_coding_corrector(model, x, per_sample(sigma_eval), sigma_data)
        sigma_curr_tensor, sigma_next_tensor = scheduler.step(corrected)
        sigma_curr = float(sigma_curr_tensor.item())
        sigma_next = float(sigma_next_tensor.item())
        if sigma_curr <= 0.0:
            # Already at the terminal level; the ODE derivative is undefined here.
            break
        if sigma_curr != sigma_eval:
            # Scheduler reports a different start level than the one evaluated; re-evaluate there.
            corrected, _, _ = predictive_coding_corrector(model, x, per_sample(sigma_curr), sigma_data)
        delta = sigma_next - sigma_curr
        if delta == 0.0:
            # Zero-length step: x is unchanged, so skip the extra evaluation.
            continue
        d = (x - corrected) / sigma_curr
        x_euler = x + delta * d
        if abs(sigma_next) < 1e-12:
            x = x_euler
            break
        corrected_next, _, _ = predictive_coding_corrector(model, x_euler, per_sample(sigma_next), sigma_data)
        d_next = (x_euler - corrected_next) / sigma_next
        x = x + 0.5 * delta * (d + d_next)
    return x

@torch.no_grad()
def decode_images(samples: torch.Tensor) -> torch.Tensor:
    """Map model-space values in [-1, 1] to image space [0, 1], clipping anything outside."""
    return samples.add(1).div(2).clamp(0.0, 1.0)

In [10]:
# Training progress shared across calls to train(): optimizer-step counter and per-step losses.
global_step: int = 0
loss_history: List[float] = list()

def train(model: nn.Module, loader: DataLoader, cfg: TrainingConfig) -> None:
    """Run ``cfg.epochs`` epochs of denoising training with AMP, gradient clipping and EMA.

    Uses the module-level ``optimizer``, ``scaler``, ``ema`` and ``scheduler``
    objects and advances the global ``global_step`` counter. Every
    ``cfg.sample_interval`` steps, a 64-image grid is sampled with the EMA
    weights in eval mode, and the model is switched back to train mode
    afterwards. Previously the hook sampled in train mode, so dropout was
    active when ``dropout > 0``. After each epoch, the raw weights and the
    EMA weights are saved to ``cfg.out_dir``.
    """
    global global_step
    use_amp = cfg.amp and device.type == "cuda"
    model.train()
    for epoch in range(cfg.epochs):
        pbar = tqdm(loader, desc=f"Epoch {epoch + 1}/{cfg.epochs}")
        for images, _ in pbar:
            # Loader uses pin_memory, so the host-to-device copy can be asynchronous.
            images = images.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                loss, metrics = predictive_coding_flow_loss(model, images, cfg)
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            ema.update()
            loss_value = loss.item()
            loss_history.append(loss_value)
            global_step += 1
            if global_step % cfg.log_interval == 0:
                pbar.set_postfix(loss=f"{loss_value:.4f}", rms=f"{metrics['residual_rms']:.3f}")
            if global_step % cfg.sample_interval == 0:
                model.eval()
                try:
                    with ema.average_parameters():
                        samples = predictive_coding_heun(model, scheduler, (64, 3, 32, 32), cfg.sigma_data, disable_tqdm=True)
                finally:
                    model.train()
                decoded = decode_images(samples)
                save_path = os.path.join(cfg.out_dir, f"samples_step_{global_step}.png")
                save_image(decoded, save_path, nrow=8)
        torch.save(model.state_dict(), os.path.join(cfg.out_dir, f"model_epoch_{epoch + 1}.pth"))
        with ema.average_parameters():
            torch.save(model.state_dict(), os.path.join(cfg.out_dir, f"ema_epoch_{epoch + 1}.pth"))

### Training Entry Point
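- The next cell defines `run_training_and_evaluation`; call it (see the commented example usage at the end of that cell) to launch predictive-coding diffusion training with the Lyapunov adaptive scheduler and then score the EMA weights with CleanFID.
- Training minimizes the EDM-preconditioned denoising loss (the predictive-coding correction); the loop updates the EMA weights and runs the sampling and checkpoint hooks automatically.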

In [None]:
def run_training_and_evaluation(cfg: TrainingConfig, loader: DataLoader, epochs_override: Optional[int] = None,
                                  fid_total: int = 50000, fid_batch: int = 256) -> Optional[float]:
    """Train the global ``model`` and then report CleanFID for its EMA weights.

    ``epochs_override`` temporarily replaces ``cfg.epochs`` for this run; the
    original value is restored even if training raises. Returns the FID score,
    or ``None`` when CleanFID is unavailable.
    """
    saved_epochs = cfg.epochs
    cfg.epochs = saved_epochs if epochs_override is None else epochs_override
    try:
        train(model, loader, cfg)
    finally:
        cfg.epochs = saved_epochs
    # A fresh scheduler instance for evaluation, sharing the global controller.
    fid_scheduler = LyapunovAdaptiveScheduler(sigma_schedule, cfg.sigma_data, controller=controller)
    with ema.average_parameters():
        score = evaluate_fid(model, fid_scheduler, cfg, total=fid_total, batch=fid_batch)
    if score is None:
        return None
    print(f"Post-training FID (EMA): {score:.3f}")
    return score

# Example usage:
# fid = run_training_and_evaluation(config, train_loader, epochs_override=1, fid_total=1000, fid_batch=128)

In [12]:
@torch.no_grad()
def sample_grid(model: nn.Module, scheduler: LyapunovAdaptiveScheduler, cfg: TrainingConfig, filename: str = "preview.png", batch: int = 64) -> str:
    """Sample ``batch`` 32x32 RGB images and save them as an 8-column grid in ``cfg.out_dir``.

    The model is put in eval mode for sampling, and its previous train/eval
    mode is restored afterwards. Previously it was left in eval mode.

    Returns:
        Path of the written PNG.
    """
    was_training = model.training
    model.eval()
    try:
        samples = predictive_coding_heun(model, scheduler, (batch, 3, 32, 32), cfg.sigma_data)
    finally:
        model.train(was_training)
    decoded = decode_images(samples)
    out_path = os.path.join(cfg.out_dir, filename)
    save_image(decoded, out_path, nrow=8)
    return out_path

preview_path = sample_grid(model, scheduler, config)
print("Preview saved to " + preview_path)

Lyapunov PC Sampling:   0%|          | 0/80 [00:00<?, ?it/s]

Preview saved to runs/lyapunov_pc/preview.png


In [13]:
def compute_clean_fid(samples_dir: str, dataset: str = "cifar10", dataset_res: int = 32) -> Optional[float]:
    """Score the images in ``samples_dir`` against precomputed train-split statistics with CleanFID.

    Returns ``None`` (after printing an install hint) when clean-fid is not installed.
    """
    if cleanfid is not None:
        return cleanfid.compute_fid(
            samples_dir,
            dataset_name=dataset,
            mode="clean",
            dataset_res=dataset_res,
            dataset_split="train",
        )
    print("CleanFID not installed. Install via `pip install clean-fid` to enable.")
    return None

def evaluate_fid(model: nn.Module, scheduler: LyapunovAdaptiveScheduler, cfg: TrainingConfig, total: int = 50000, batch: int = 256) -> Optional[float]:
    """Generate ``total`` samples into ``<cfg.out_dir>/generated`` and score them with CleanFID.

    Fixes relative to the earlier version:
      * ``img_*.png`` files left over from previous runs are deleted first.
        CleanFID scores every image in the folder, so stale files would
        otherwise contaminate the score.
      * Pixels are rounded to the nearest 8-bit level instead of truncated,
        which biased every pixel down by half a level on average.
      * Progress is reported each time another 1000 images are crossed. The old
        ``produced % 1000 == 0`` test rarely fired for batch sizes like 256.
      * The model's train/eval mode is restored afterwards.

    Raises:
        ValueError: If ``batch`` is not positive (the loop would never terminate).

    Returns:
        The FID score, or ``None`` when CleanFID is unavailable.
    """
    if batch <= 0:
        raise ValueError(f"batch must be positive, got {batch}")
    was_training = model.training
    model.eval()
    gen_dir = os.path.join(cfg.out_dir, "generated")
    os.makedirs(gen_dir, exist_ok=True)
    for name in os.listdir(gen_dir):
        if name.startswith("img_") and name.endswith(".png"):
            os.remove(os.path.join(gen_dir, name))
    produced = 0
    next_report = 1000
    try:
        with torch.no_grad():
            while produced < total:
                current_batch = min(batch, total - produced)
                samples = predictive_coding_heun(model, scheduler, (current_batch, 3, 32, 32), cfg.sigma_data, disable_tqdm=True)
                pixels = (decode_images(samples) * 255).round().to(torch.uint8).cpu()
                for idx, img in enumerate(pixels):
                    Image.fromarray(img.permute(1, 2, 0).numpy()).save(
                        os.path.join(gen_dir, f"img_{produced + idx:05d}.png")
                    )
                produced += current_batch
                if produced >= next_report or produced == total:
                    print(f"Generated {produced}/{total} images")
                    next_report = (produced // 1000 + 1) * 1000
    finally:
        model.train(was_training)
    fid_score = compute_clean_fid(gen_dir)
    if fid_score is not None:
        print(f"FID: {fid_score:.3f}")
    return fid_score

## Next Steps

- Train the controller with REINFORCE or GFlowNet-style objectives keyed on FID or Lyapunov drop.
- Tune `alpha`, `steps`, and channel multipliers to align with the <2 FID target.
- Integrate distributed training and mixed precision logging to reproduce EDM reference runs.