# Predictive-Coding Style Diffusion on CIFAR-10

This notebook implements a **predictive-coding-inspired** diffusion model:

- Each 'layer' corresponds to a diffusion step.
- The model predicts noise `ε` and reconstructs an estimate of the clean image `x₀`.
- We add a predictive-coding auxiliary loss that penalizes the mismatch between the predicted noisy state `x̂_t` and the true `x_t`.
- Training samples multiple timesteps per batch to emulate a stack of predictive layers.
- Sampling uses a deterministic DDIM-like sampler (the sampling code in this notebook does not read `ddim_eta`, so it always behaves as η = 0).

**Health checks** are embedded to confirm imports, CUDA, dataset wiring, and optional FID dependencies.


In [None]:
import os, math, random, time
from dataclasses import dataclass
from typing import Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# torchvision
try:
    import torchvision
    from torchvision import transforms
    from torchvision.utils import make_grid, save_image
    TV_OK = True
except Exception as e:
    TV_OK = False
    print('torchvision import failed:', e)

# tqdm
try:
    from tqdm.auto import tqdm
    TQDM_OK = True
except Exception as e:
    TQDM_OK = False
    print('tqdm import failed:', e)

# Optional FID
try:
    from torchmetrics.image.fid import FrechetInceptionDistance
    TM_FID_OK = True
except Exception as e:
    TM_FID_OK = False
    print('torchmetrics FID not available:', e)

# Pick the compute device once; later cells read this module-level `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
print('torch version:', torch.__version__)
print('torchvision available:', TV_OK)
print('tqdm available:', TQDM_OK)
print('torchmetrics FID available:', TM_FID_OK)
# Compare (major, minor) as integers. When torch.__version__ is a plain str the
# original `>= '1.10'` comparison is lexicographic, so '1.9.0' passed the gate
# ('9' sorts after '1'). Any local build suffix such as '+cu118' is dropped first.
_torch_major_minor = tuple(
    int(part) for part in str(torch.__version__).split('+')[0].split('.')[:2]
)
assert _torch_major_minor >= (1, 10), 'Please use torch>=1.10'


In [None]:
from dataclasses import dataclass
@dataclass
class Config:
    """Hyper-parameters and paths shared by every cell of the notebook."""

    # --- data ---
    # Root directory handed to torchvision's CIFAR10 dataset.
    data_root: str = './data'
    # Side length of the square images (used for resizing, fake data and sampling).
    image_size: int = 32
    # Channel count used for the UNet input/output and for the sampling noise.
    num_channels: int = 3
    # True -> use FakeCIFAR10 and cap the smoke-training loop at two steps.
    use_fake_data: bool = True
    # Length of the fake training set; the fake test set uses a quarter of it.
    fake_data_len: int = 512

    # --- optimisation ---
    batch_size: int = 64
    # Not referenced by the cells visible in this notebook.
    epochs: int = 1
    # Learning rate passed to AdamW.
    lr: float = 2e-4
    # Max gradient norm for clipping; the training loop skips clipping when None.
    grad_clip: float = 1.0
    # Worker processes for both DataLoaders.
    num_workers: int = 2
    # Number of independently sampled timesteps evaluated per batch.
    pc_layers_per_batch: int = 4
    # Weight of the predictive-coding auxiliary loss in the total loss.
    lambda_pc: float = 0.1
    # Decay factor of the exponential moving average of the weights.
    ema_decay: float = 0.999

    # --- diffusion / sampling ---
    # Length of the noise schedule.
    timesteps: int = 1000
    # Number of timesteps the sampler visits between timesteps-1 and 0.
    sampling_steps: int = 50
    # Not read by the sampling code visible in this notebook.
    ddim_eta: float = 0.0

    # --- output ---
    # Directory that receives sample grids; created right after instantiation.
    out_dir: str = './runs/pcn_diffusion_cifar'
    # Not referenced by the cells visible in this notebook.
    save_every: int = 1
    # Number of images in the sample grid (the smoke run caps it at 64).
    sample_grids: int = 64

# Single global configuration instance used by all later cells.
cfg = Config()
# Make sure the output directory exists before anything tries to write to it.
os.makedirs(cfg.out_dir, exist_ok=True)
print(cfg)


In [None]:
class FakeCIFAR10(Dataset):
    """Synthetic stand-in for CIFAR-10, intended for smoke tests.

    Every lookup returns a freshly drawn uniform-noise image with values in
    [-1, 1) together with the constant label ``0``. Items are not cached, so
    the same index yields a different image on each access.

    Args:
        n: Number of items the dataset reports.
        img_size: Height and width of each generated image.
        channels: Number of image channels (defaults to 3, as before).
    """

    def __init__(self, n=512, img_size=32, channels=3):
        self.n = n
        self.img_size = img_size
        self.channels = channels

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # Without an IndexError past the end, plain iteration (`for x in ds`,
        # `list(ds)`) falls back to the legacy sequence protocol and never stops.
        if not -self.n <= idx < self.n:
            raise IndexError(f'index {idx} is out of range for dataset of length {self.n}')
        x = torch.rand(self.channels, self.img_size, self.img_size) * 2 - 1
        return x, 0

def get_dataloaders(cfg: Config):
    """Build the train and test ``DataLoader`` objects described by ``cfg``.

    Uses ``FakeCIFAR10`` when ``cfg.use_fake_data`` is set or torchvision could
    not be imported; otherwise downloads/loads the real CIFAR-10 splits.

    Args:
        cfg: Notebook configuration (data root, image size, batch size, workers).

    Returns:
        ``(train_loader, test_loader)``; only the training loader is shuffled.
    """
    if cfg.use_fake_data or not TV_OK:
        train_ds = FakeCIFAR10(cfg.fake_data_len, cfg.image_size)
        test_ds = FakeCIFAR10(cfg.fake_data_len // 4, cfg.image_size)
        print('[DATA] Using FakeCIFAR10 for smoke test.')
    else:
        # Shared tail: convert to a tensor and map pixel values to [-1, 1].
        to_model_range = [
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ]
        train_transform = transforms.Compose([
            transforms.Resize(cfg.image_size),
            transforms.RandomHorizontalFlip(),
            *to_model_range,
        ])
        # The test split is the reference set for evaluation, so it must not be
        # randomly augmented: same pipeline minus the horizontal flip.
        eval_transform = transforms.Compose([
            transforms.Resize(cfg.image_size),
            *to_model_range,
        ])
        train_ds = torchvision.datasets.CIFAR10(
            root=cfg.data_root, train=True, download=True, transform=train_transform)
        test_ds = torchvision.datasets.CIFAR10(
            root=cfg.data_root, train=False, download=True, transform=eval_transform)
        print('[DATA] Using real CIFAR-10.')

    loader_kwargs = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only hosts.
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)
    return train_loader, test_loader

# Build the loaders once; later cells use these module-level names.
train_loader, test_loader = get_dataloaders(cfg)
# Pull a single batch as a wiring check.
xb, yb = next(iter(train_loader))
# Report the batch shape together with its smallest and largest value.
print('Batch:', xb.shape, xb.min().item(), xb.max().item())
# Dimensions 2 and 3 of the batch must both match the configured image size.
assert xb.shape[2] == cfg.image_size and xb.shape[3] == cfg.image_size


In [None]:
def cosine_beta_schedule(T, s=0.008):
    """Return ``T`` per-step betas derived from a squared-cosine curve.

    The curve is evaluated at ``T + 1`` evenly spaced points, normalised by its
    first value, and each beta is one minus the ratio of consecutive values.

    Args:
        T: Number of diffusion steps.
        s: Offset added to the normalised step before taking the cosine.

    Returns:
        A float32 tensor of shape ``(T,)`` clamped to ``[1e-8, 0.999]``.
    """
    steps = torch.linspace(0, T, T + 1, dtype=torch.float32)
    angle = ((steps / T + s) / (1 + s)) * math.pi / 2
    curve = torch.cos(angle) ** 2
    cumulative = curve / curve[0]
    ratio = cumulative[1:] / cumulative[:-1]
    return torch.clamp(1 - ratio, min=1e-8, max=0.999)

class Diffusion:
    """Noise schedule tables plus closed-form helpers built on them.

    Every table is a 1-D tensor with one entry per timestep, derived from
    ``cosine_beta_schedule(timesteps)``. The helpers move the tables they need
    to the device of their input on each call.
    """

    def __init__(self, timesteps: int):
        betas = cosine_beta_schedule(timesteps)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted by one step, with 1.0 in front for t = 0.
        alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]], dim=0)
        one_minus_cumprod = 1 - alphas_cumprod

        self.timesteps = timesteps

        # Raw schedule.
        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev

        # Square-root forms used by the helpers below.
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(one_minus_cumprod)
        self.sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
        self.one_over_sqrt_alphas = 1.0 / torch.sqrt(alphas)

        # Posterior terms; not read by any helper in this class.
        self.posterior_mean_coef1 = (1 - alphas_cumprod_prev) / one_minus_cumprod * torch.sqrt(alphas)
        self.posterior_mean_coef2 = (torch.sqrt(alphas_cumprod_prev) * betas) / one_minus_cumprod
        self.posterior_variance = betas * (1 - alphas_cumprod_prev) / one_minus_cumprod

    def _coefficients(self, t, like):
        """Look up both square-root tables at ``t``, shaped to broadcast with ``like``."""
        picked = []
        for table in (self.sqrt_alphas_cumprod, self.sqrt_one_minus_alphas_cumprod):
            coef = table.to(like.device)[t]
            missing = like.ndim - coef.ndim
            if missing > 0:
                # Append trailing singleton axes until the rank matches `like`.
                coef = coef[(...,) + (None,) * missing]
            picked.append(coef)
        return picked

    def q_sample(self, x0, t, noise=None):
        """Return ``sqrt_ac[t] * x0 + sqrt_om[t] * noise``; noise is drawn if omitted."""
        if noise is None:
            noise = torch.randn_like(x0)
        signal_scale, noise_scale = self._coefficients(t, x0)
        return signal_scale * x0 + noise_scale * noise

    def predict_x0_from_eps(self, x_t, t, eps):
        """Invert the forward formula: ``(x_t - sqrt_om[t] * eps) / (sqrt_ac[t] + 1e-8)``."""
        signal_scale, noise_scale = self._coefficients(t, x_t)
        return (x_t - noise_scale * eps) / (signal_scale + 1e-8)

    def predict_x_t_from_x0(self, x0, t, eps):
        """Return ``sqrt_ac[t] * x0 + sqrt_om[t] * eps`` for an explicitly given ``eps``."""
        signal_scale, noise_scale = self._coefficients(t, x0)
        return signal_scale * x0 + noise_scale * eps

# Build the schedule tables once; later cells use this module-level `diff`.
diff = Diffusion(cfg.timesteps)
# Health check: report the step count and the shape of the beta table.
print('Timesteps:', diff.timesteps, '| betas shape:', diff.betas.shape)


In [None]:
class SinusoidalTimeEmbedding(nn.Module):
    """Map integer timesteps to fixed sinusoidal features.

    The angular frequencies are geometrically spaced from 1 down to
    ``1 / max_period``. The previous range ran from 1 *up to* 10000, so
    ``t * freq`` reached about 1e7 for t near 1000; float32 cannot resolve the
    phase at that magnitude and the high-frequency features were effectively noise.

    Args:
        dim: Number of output features per timestep.
        max_period: Reciprocal of the lowest frequency (defaults to 10000).
    """

    def __init__(self, dim, max_period=10000.0):
        super().__init__()
        self.dim = dim
        self.max_period = max_period

    def forward(self, t):
        """Return a ``(len(t), dim)`` tensor of ``[sin | cos]`` features for 1-D ``t``."""
        half = self.dim // 2
        # Log-spaced frequencies: exp(0) = 1 down to exp(-log(max_period)).
        freqs = torch.exp(
            torch.linspace(0.0, -math.log(self.max_period), steps=half, device=t.device)
        )
        angles = t.float()[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2:
            # An odd `dim` would otherwise yield dim - 1 features; pad one zero column.
            emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=-1)
        return emb

class ResBlock(nn.Module):
    """Two norm -> SiLU -> 3x3 conv stages with a time bias and a residual path.

    Args:
        in_ch: Channels of the input feature map.
        out_ch: Channels of the output feature map.
        t_dim: Width of the time embedding fed to ``forward``.
    """

    def __init__(self, in_ch, out_ch, t_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        # Projects the time embedding to one bias value per output channel.
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, out_ch))
        # The residual path needs a 1x1 projection only when the width changes.
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x, t_emb):
        """Apply the block to ``x``, adding the projected ``t_emb`` between the stages."""
        residual = self.skip(x)
        hidden = self.conv1(F.silu(self.norm1(x)))
        # Broadcast the per-channel time bias over the two spatial axes.
        time_bias = self.time_mlp(t_emb)[:, :, None, None]
        hidden = self.conv2(F.silu(self.norm2(hidden + time_bias)))
        return hidden + residual

class Down(nn.Module):
    """Encoder stage: two ``ResBlock`` layers followed by a stride-2 convolution.

    Args:
        in_ch: Channels entering the stage.
        out_ch: Channels produced by the stage (and carried by the skip tensor).
        t_dim: Width of the time embedding passed to the residual blocks.
    """

    def __init__(self, in_ch, out_ch, t_dim):
        super().__init__()
        self.block1 = ResBlock(in_ch, out_ch, t_dim)
        self.block2 = ResBlock(out_ch, out_ch, t_dim)
        self.down = nn.Conv2d(out_ch, out_ch, kernel_size=4, stride=2, padding=1)

    def forward(self, x, t_emb):
        """Return ``(downsampled, skip)``; ``skip`` is the tensor before downsampling."""
        features = self.block2(self.block1(x, t_emb), t_emb)
        return self.down(features), features

class Up(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, two ``ResBlock`` layers.

    Args:
        in_ch: Channels of the incoming (low-resolution) feature map.
        out_ch: Channels produced by the stage.
        t_dim: Width of the time embedding passed to the residual blocks.
        skip_ch: Channels of the skip tensor given to ``forward``. Defaults to
            ``in_ch``, which is what the mirrored ``Down`` stage emits: ``Down``
            returns its skip after both residual blocks, at ``Down``'s output
            width, and ``UNet`` pairs every ``Down(a, b)`` with an ``Up(b, a)``.
    """

    def __init__(self, in_ch, out_ch, t_dim, skip_ch=None):
        super().__init__()
        if skip_ch is None:
            skip_ch = in_ch
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
        # After concatenation the tensor has out_ch (upsampled) + skip_ch channels.
        # The previous `out_ch * 2` assumed the skip had out_ch channels, which
        # does not hold for the skips produced by `Down` and caused a channel
        # mismatch in the first GroupNorm.
        self.block1 = ResBlock(out_ch + skip_ch, out_ch, t_dim)
        self.block2 = ResBlock(out_ch, out_ch, t_dim)

    def forward(self, x, skip, t_emb):
        """Upsample ``x``, concatenate ``skip`` on the channel axis, then refine."""
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = self.block1(x, t_emb)
        x = self.block2(x, t_emb)
        return x

class UNet(nn.Module):
    """Time-conditioned encoder/decoder network with two resolution levels.

    Args:
        in_ch: Channels of the input image; the output has the same count.
        base: Channel width after the input convolution; deeper levels use
            ``2 * base`` and ``4 * base``.
        time_dim: Width of the sinusoidal time embedding and of its MLP output.
    """

    def __init__(self, in_ch=3, base=64, time_dim=256):
        super().__init__()
        narrow, medium, wide = base, base * 2, base * 4
        hidden = time_dim * 4

        self.time_dim = time_dim
        self.time_emb = SinusoidalTimeEmbedding(time_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, time_dim),
        )

        self.in_conv = nn.Conv2d(in_ch, narrow, kernel_size=3, padding=1)
        self.down1 = Down(narrow, medium, time_dim)
        self.down2 = Down(medium, wide, time_dim)
        self.mid1 = ResBlock(wide, wide, time_dim)
        self.mid2 = ResBlock(wide, wide, time_dim)
        self.up2 = Up(wide, medium, time_dim)
        self.up1 = Up(medium, narrow, time_dim)
        self.out_norm = nn.GroupNorm(num_groups=8, num_channels=narrow)
        self.out_conv = nn.Conv2d(narrow, in_ch, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Run the network on images ``x`` conditioned on timesteps ``t``."""
        t_emb = self.time_mlp(self.time_emb(t))

        h = self.in_conv(x)
        h, skip_outer = self.down1(h, t_emb)
        h, skip_inner = self.down2(h, t_emb)

        for mid_block in (self.mid1, self.mid2):
            h = mid_block(h, t_emb)

        # Skips are consumed in reverse order of production.
        h = self.up2(h, skip_inner, t_emb)
        h = self.up1(h, skip_outer, t_emb)

        return self.out_conv(F.silu(self.out_norm(h)))

# Instantiate the network on the selected device; later cells train this `net`.
net = UNet(in_ch=cfg.num_channels).to(device)
print('UNet params (M):', sum(p.numel() for p in net.parameters())/1e6)
# Shape check: push two random images with random timesteps through the network.
xb_small = torch.randn(2, cfg.num_channels, cfg.image_size, cfg.image_size, device=device)
tb = torch.randint(0, cfg.timesteps, (2,), device=device)
with torch.no_grad():
    out = net(xb_small, tb)
print('UNet out:', out.shape)
# The output must have exactly the shape of the input.
assert out.shape == xb_small.shape


In [None]:
class EMA:
    """Exponential moving average of a model's trainable parameters.

    ``shadow`` maps each parameter name to its running average. Only parameters
    with ``requires_grad=True`` at construction time are tracked.

    Args:
        model: Module whose parameters seed the averages.
        decay: Weight kept from the previous average on every update.
    """

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {
            name: p.detach().clone()
            for name, p in model.named_parameters()
            if p.requires_grad
        }

    def update(self, model):
        """Blend the current parameters in: ``shadow = decay * shadow + (1 - decay) * p``."""
        with torch.no_grad():
            for name, p in model.named_parameters():
                if p.requires_grad:
                    # Updated in place: avoids the temporaries and the redundant
                    # clone the previous version allocated per parameter per step.
                    self.shadow[name].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)

    def apply_to(self, model):
        """Overwrite the model's tracked parameters with their averages, in place."""
        with torch.no_grad():
            for name, p in model.named_parameters():
                if p.requires_grad:
                    p.copy_(self.shadow[name])

def eps_mse_loss(pred_eps, eps):
    """Mean squared error between the predicted noise and the target noise."""
    return F.mse_loss(input=pred_eps, target=eps, reduction='mean')

def predictive_coding_aux_loss(x_t_hat, x_t):
    """Mean squared error between the reconstructed noisy state and the true one."""
    mismatch = F.mse_loss(x_t_hat, x_t, reduction='mean')
    return mismatch


In [None]:
def train_one_epoch(model, ema, opt, loader, diff: Diffusion, cfg: Config, epoch: int):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    For every batch, ``cfg.pc_layers_per_batch`` independent timesteps are
    sampled per image. Each contributes a noise-prediction loss and a
    predictive-coding auxiliary loss; the two averages are combined with weight
    ``cfg.lambda_pc``, followed by one optimiser step and one EMA update.

    Args:
        model: Noise-prediction network called as ``model(x_t, t)``.
        ema: Object whose ``update(model)`` is called after every step.
        opt: Optimiser over ``model.parameters()``.
        loader: Iterable of ``(images, labels)`` batches; labels are ignored.
        diff: Schedule helper providing ``q_sample`` and the prediction helpers.
        cfg: Notebook configuration.
        epoch: Epoch index, used only for the progress-bar label.

    Returns:
        Sum of the per-batch losses divided by ``max(1, len(loader))``.
    """
    model.train()
    iterable = tqdm(loader, desc=f'Epoch {epoch}', leave=False) if TQDM_OK else loader
    num_layers = cfg.pc_layers_per_batch
    total_loss = 0.0
    for x0, _ in iterable:
        x0 = x0.to(device)
        B = x0.size(0)
        # One column of timesteps per predictive "layer".
        tset = torch.randint(0, diff.timesteps, (B, num_layers), device=device)
        loss_eps_sum = 0.0
        loss_pc_sum = 0.0
        for k in range(num_layers):
            t = tset[:, k]
            noise = torch.randn_like(x0)
            x_t = diff.q_sample(x0, t, noise=noise)
            pred_eps = model(x_t, t)
            loss_eps = eps_mse_loss(pred_eps, noise)
            x0_hat = diff.predict_x0_from_eps(x_t, t, pred_eps).clamp(-1, 1)
            # Re-noise the x0 estimate with the noise that actually produced x_t.
            # Re-noising with `pred_eps.detach()` (as before) merely inverts
            # `predict_x0_from_eps`: the residual is ~0 wherever the clamp is
            # inactive, and the clamp has zero derivative wherever it is active,
            # so the auxiliary term contributed no gradient. With the true noise
            # the residual is sqrt(alpha_bar_t) * (x0_hat - x0).
            x_t_hat = diff.predict_x_t_from_x0(x0_hat, t, noise)
            loss_pc = predictive_coding_aux_loss(x_t_hat, x_t)
            loss_eps_sum = loss_eps_sum + loss_eps
            loss_pc_sum = loss_pc_sum + loss_pc
        loss = loss_eps_sum / num_layers + cfg.lambda_pc * (loss_pc_sum / num_layers)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        if cfg.grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        opt.step()
        ema.update(model)
        # Read the scalar once; each `.item()` forces a device synchronisation.
        loss_value = loss.item()
        total_loss += loss_value
        if TQDM_OK:
            iterable.set_postfix(loss=f'{loss_value:.4f}')
    return total_loss / max(1, len(loader))

def save_samples_ddim(model, diff: Diffusion, cfg: Config, n: int = 64, fname: str = 'samples.png'):
    """Generate ``n`` images with the deterministic sampler and save them as one grid.

    Sampling starts from Gaussian noise and visits ``cfg.sampling_steps``
    timesteps evenly spaced from ``diff.timesteps - 1`` down to 0. The result is
    clamped to [-1, 1], mapped to [0, 1] and written to ``cfg.out_dir/fname``.

    Args:
        model: Noise-prediction network called as ``model(x, t)``; put in eval mode.
        diff: Schedule helper; only ``sqrt_alphas_cumprod`` and ``timesteps`` are read.
        cfg: Notebook configuration (image shape, step count, output directory).
        n: Number of images to generate.
        fname: File name of the grid inside ``cfg.out_dir``.

    Returns:
        Path of the written image file.
    """
    model.eval()
    with torch.no_grad():
        # Table of sqrt(alpha_bar); moved to the device once instead of per step.
        sqrt_ab = diff.sqrt_alphas_cumprod.to(device)
        x = torch.randn(n, cfg.num_channels, cfg.image_size, cfg.image_size, device=device)
        t_seq = torch.linspace(diff.timesteps-1, 0, cfg.sampling_steps, device=device).long()
        final_step = len(t_seq) - 1

        for step, t in enumerate(t_seq):
            t_batch = torch.full((n,), int(t.item()), device=device, dtype=torch.long)
            eps = model(x, t_batch)

            scale_now = sqrt_ab[t_batch].reshape(n, 1, 1, 1)
            x0_hat = (x - (1 - scale_now**2).sqrt() * eps) / (scale_now + 1e-8)

            if step == final_step:
                # The last step returns the clean-image estimate itself.
                x = x0_hat
                continue

            # Otherwise move the estimate to the noise level of the next timestep.
            scale_next = sqrt_ab[t_seq[step + 1]].reshape(1, 1, 1, 1)
            x = scale_next * x0_hat + (1 - scale_next**2).sqrt() * eps

        x = x.clamp(-1, 1)
        grid = make_grid((x+1)/2, nrow=int(math.sqrt(n)))
        save_path = os.path.join(cfg.out_dir, fname)
        save_image(grid, save_path)
        return save_path


In [None]:
# Optimiser and EMA tracker for the module-level `net`.
opt = torch.optim.AdamW(net.parameters(), lr=cfg.lr)
ema = EMA(net, decay=cfg.ema_decay)
# With fake data this is only a smoke test, so stop after two batches.
steps_limit = 2 if cfg.use_fake_data else None
net.train()
epoch_loss = 0.0
batches = 0
it_ = tqdm(train_loader, desc='SmokeTrain', leave=False) if TQDM_OK else train_loader
for i, batch in enumerate(it_):
    if steps_limit is not None and i >= steps_limit:
        break
    # Feed one batch at a time by wrapping it in a single-element list.
    loss = train_one_epoch(net, ema, opt, [batch], diff, cfg, epoch=0)
    epoch_loss += loss
    batches += 1
    if TQDM_OK:
        it_.set_postfix(avg_loss=f'{(epoch_loss/max(1,batches)):.4f}')
print('Smoke training avg loss:', epoch_loss/max(1,batches))
# Copy the EMA averages into `net` in place; the raw trained weights are overwritten.
ema.apply_to(net)
# Sample a grid of at most 64 images from the EMA weights.
sample_path = save_samples_ddim(net, diff, cfg, n=min(cfg.sample_grids, 64), fname='samples_smoke.png')
print('Saved sample grid to:', sample_path)


In [None]:
def compute_fid_samples_vs_test(model, diff: Diffusion, cfg: Config, n_gen: int = 5000):
    """Estimate the FID between generated samples and the module-level ``test_loader``.

    At most ``n_gen`` real images are taken from ``test_loader`` and exactly
    ``n_gen`` images are generated with the same deterministic sampler used for
    the sample grids. Previously both sets were filled in whole batches and
    could overshoot ``n_gen`` by up to a batch while the printout reported ``n_gen``.

    Args:
        model: Noise-prediction network called as ``model(x, t)``; put in eval mode.
        diff: Schedule helper; only ``sqrt_alphas_cumprod`` and ``timesteps`` are read.
        cfg: Notebook configuration (batch size, image shape, sampling steps).
        n_gen: Number of generated images, and cap on the number of real images.

    Returns:
        The FID score as a float, or ``None`` when torchmetrics is unavailable.
    """
    if not TM_FID_OK:
        print('torchmetrics FID not available. Install torchmetrics to enable.')
        return None
    model.eval()
    fid = FrechetInceptionDistance(feature=2048, reset_real_features=True).to(device)

    # Real statistics: trim the last batch so that no more than n_gen images are used.
    count_real = 0
    for xb, _ in test_loader:
        if count_real >= n_gen:
            break
        xb = xb[: n_gen - count_real].to(device)
        xb_clamped = ((xb + 1)/2).clamp(0,1)
        fid.update((xb_clamped*255).byte(), real=True)
        count_real += xb.size(0)

    # Loop-invariant sampler inputs, built once instead of once per batch.
    sqrt_ab = diff.sqrt_alphas_cumprod.to(device)
    t_seq = torch.linspace(diff.timesteps-1, 0, cfg.sampling_steps, device=device).long()
    final_step = len(t_seq) - 1

    made = 0
    with torch.no_grad():
        while made < n_gen:
            # The final batch is shortened so that exactly n_gen samples are produced.
            cur = min(cfg.batch_size, n_gen - made)
            x = torch.randn(cur, cfg.num_channels, cfg.image_size, cfg.image_size, device=device)
            for i, t in enumerate(t_seq):
                t_b = torch.full((cur,), int(t.item()), device=device, dtype=torch.long)
                eps = model(x, t_b)
                scale_now = sqrt_ab[t_b].reshape(cur, 1, 1, 1)
                x0_hat = (x - (1 - scale_now**2).sqrt() * eps) / (scale_now + 1e-8)
                if i == final_step:
                    # The last step returns the clean-image estimate itself.
                    x = x0_hat
                else:
                    scale_next = sqrt_ab[t_seq[i+1]].reshape(1, 1, 1, 1)
                    x = scale_next * x0_hat + (1 - scale_next**2).sqrt() * eps
            x = ((x.clamp(-1,1) + 1)/2).clamp(0,1)
            fid.update((x*255).byte(), real=False)
            made += cur
    score = float(fid.compute().cpu().item())
    print(f'FID (approx, n_gen={n_gen}):', score)
    return score
print('FID function ready (may require internet to download Inception weights).')


### Artifacts

- The smoke test generates a small sample grid and saves it as `samples_smoke.png` in the run directory (`cfg.out_dir`).
- For full training on real CIFAR-10, set `use_fake_data=False` in the **Config** cell and run end-to-end.


In [None]:
# Final cell: prints a fixed marker message once every cell above has run.
print('Notebook setup complete.')
