In [1]:
import math, os, time, random, functools
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision as tv
import torchvision.transforms as T
import torchvision.utils as tvu

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

def seed_all(seed: int = 42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
seed_all(42)

def count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


In [2]:
@dataclass
class TrainCfg:
    # data
    dataset: str = 'CIFAR10'
    data_root: str = './data'
    channels: int = 3
    H: int = 32
    W: int = 32
    # model
    emb_dim: int = 128
    C_hidden: int = 192        # divisible by 32 â†’ safe for GroupNorm
    K: int = 10                # predictive-coding iterations
    # edm preconditioning
    sigma_data: float = 0.5
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    rho: float = 7.0           # EDM schedule exponent
    # train
    batch_size: int = 256
    num_workers: int = 4
    epochs: int = 50
    lr: float = 2e-4
    wd: float = 0.0
    ema_decay: float = 0.999
    log_every: int = 100

cfg = TrainCfg()
print(cfg)


TrainCfg(dataset='CIFAR10', data_root='./data', channels=3, H=32, W=32, emb_dim=128, C_hidden=192, K=10, sigma_data=0.5, sigma_min=0.002, sigma_max=80.0, rho=7.0, batch_size=256, num_workers=4, epochs=50, lr=0.0002, wd=0.0, ema_decay=0.999, log_every=100)


In [3]:
class FourierEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        half = dim // 2
        # fixed frequencies
        self.register_buffer('freqs', torch.exp(torch.linspace(math.log(1.0), math.log(10000.0), half)))

    def forward(self, sigma: torch.Tensor):  # sigma: (B,) or (B,1,1,1)
        sigma = sigma.view(-1)  # (B,)
        x = sigma[:, None] * self.freqs[None, :]
        emb = torch.cat([x.sin(), x.cos()], dim=-1)
        if emb.shape[1] < self.dim:
            emb = F.pad(emb, (0, self.dim - emb.shape[1]))
        return emb  # (B, dim)


In [4]:
def _groups_for_channels(C: int, max_groups: int = 32) -> int:
    # largest divisor of C not exceeding max_groups
    for g in reversed(range(1, max_groups + 1)):
        if C % g == 0:
            return g
    return 1

class FiLM(nn.Module):
    def __init__(self, emb_dim: int, hidden: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, 2 * hidden)
        )

    def forward(self, emb: torch.Tensor):
        # returns gamma, beta of shape (B, hidden)
        gb = self.net(emb)
        g, b = gb.chunk(2, dim=1)
        return g, b

class PCBlock(nn.Module):
    """
    One predictive-coding update:
      e = x_sigma - x_hat
      h = Conv(cat[x_sigma, e]) -> GN -> SiLU -> FiLM -> Conv
      delta = Conv(h) -> add to x_hat
    All GroupNorm on hidden channels only.
    """
    def __init__(self, in_ch: int, hidden: int, emb_dim: int):
        super().__init__()
        self.conv_in = nn.Conv2d(in_ch * 2, hidden, 3, padding=1)
        g = _groups_for_channels(hidden, 32)
        self.norm_h = nn.GroupNorm(g, hidden)
        self.film = FiLM(emb_dim, hidden)
        self.act = nn.SiLU()
        self.conv_mid = nn.Conv2d(hidden, hidden, 3, padding=1)
        self.norm_mid = nn.GroupNorm(g, hidden)
        self.conv_out = nn.Conv2d(hidden, in_ch, 3, padding=1)

    def forward(self, x_sigma: torch.Tensor, x_hat: torch.Tensor, emb: torch.Tensor):
        e = x_sigma - x_hat
        h = self.conv_in(torch.cat([x_sigma, e], dim=1))
        h = self.norm_h(h)
        g, b = self.film(emb)  # (B, hidden)
        g = g[:, :, None, None]
        b = b[:, :, None, None]
        h = h * (1 + g) + b
        h = self.act(h)
        h = self.conv_mid(h)
        h = self.norm_mid(h)
        h = self.act(h)
        delta = self.conv_out(h)
        return delta

class PredictiveCodingCore(nn.Module):
    def __init__(self, in_ch: int, hidden: int, K: int, emb_dim: int):
        super().__init__()
        self.blocks = nn.ModuleList([PCBlock(in_ch, hidden, emb_dim) for _ in range(K)])
        self.final = nn.Conv2d(in_ch, in_ch, 1)

    def forward(self, x_sigma: torch.Tensor, emb: torch.Tensor):
        # start from zero estimate; update K times
        x_hat = torch.zeros_like(x_sigma)
        for blk in self.blocks:
            x_hat = x_hat + blk(x_sigma, x_hat, emb)
        return self.final(x_hat)


In [5]:
class DenoiserEDM(nn.Module):
    """
    Preconditioned network (Karras et al. 2022 EDM style):
      y = cskip * x + cout * F(cin * x, emb)
    The core F returns a residual in image space.
    """
    def __init__(self, core: nn.Module, emb_dim: int, sigma_data: float = 0.5):
        super().__init__()
        self.core = core
        self.sigma_data = sigma_data
        self.emb = FourierEmbedding(emb_dim)

    def _coeffs(self, sigma: torch.Tensor):
        # sigma: (B,1,1,1)
        s2 = sigma**2
        sd2 = self.sigma_data**2
        cskip = sd2 / (s2 + sd2)
        cin   = 1.0 / torch.sqrt(s2 + sd2)
        cout  = sigma * self.sigma_data / torch.sqrt(s2 + sd2)
        return cskip, cout, cin

    def forward(self, x: torch.Tensor, sigma: torch.Tensor):
        """
        x: (B,C,H,W)
        sigma: (B,) or (B,1,1,1)
        returns denoised prediction at noise level sigma
        """
        if sigma.dim() == 1:
            sigma_img = sigma[:, None, None, None]
        else:
            sigma_img = sigma
            sigma = sigma.view(-1)

        cskip, cout, cin = self._coeffs(sigma_img)
        x_in = cin * x
        emb = self.emb(torch.log(sigma + 1e-8))  # use log sigma
        h = self.core(x_in, emb)
        y = cskip * x + cout * h
        return y


In [6]:
@torch.no_grad()
def edm_sigma_schedule(steps: int, sigma_min: float, sigma_max: float, rho: float, device):
    i = torch.arange(steps, device=device, dtype=torch.float32)
    t = sigma_max**(1/rho) + (i / (steps - 1)) * (sigma_min**(1/rho) - sigma_max**(1/rho))
    sigmas = t**rho
    return sigmas  # decreasing

@torch.no_grad()
def heun_sampler(denoiser: nn.Module, batch_size: int, channels: int, height: int, width: int,
                 steps: int = 40, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0,
                 device=device):
    sigmas = edm_sigma_schedule(steps, sigma_min, sigma_max, rho, device)
    # start from N(0, sigma_max^2 I)
    x = torch.randn(batch_size, channels, height, width, device=device) * sigmas[0]
    for i in range(steps - 1):
        s_i = sigmas[i]
        s_j = sigmas[i + 1]
        s_i_img = s_i.view(1, 1, 1, 1)
        s_j_img = s_j.view(1, 1, 1, 1)
        # velocity at s_i
        d_i = (denoiser(x, s_i.expand(batch_size)) - x) / s_i_img
        x_euler = x + (s_j - s_i) * d_i
        # velocity at s_j
        d_j = (denoiser(x_euler, s_j.expand(batch_size)) - x_euler) / s_j_img
        # Heun (trapezoid)
        x = x + (s_j - s_i) * 0.5 * (d_i + d_j)
    # final clamp to [-1,1]
    return x.clamp(-1, 1)


In [7]:
def sample_sigma(batch_size: int, sigma_min: float, sigma_max: float, device):
    u = torch.rand(batch_size, device=device)
    log_min, log_max = math.log(sigma_min), math.log(sigma_max)
    sigma = torch.exp(u * (log_max - log_min) + log_min)
    return sigma  # (B,)

def edm_loss(denoiser: nn.Module, x0: torch.Tensor, sigma: torch.Tensor):
    """
    Train target is the clean image x0; denoiser(x_noisy, sigma) should predict clean x.
    """
    if sigma.dim() == 1:
        sigma_img = sigma[:, None, None, None]
    else:
        sigma_img = sigma
    noise = torch.randn_like(x0)
    x_noisy = x0 + sigma_img * noise
    x_pred = denoiser(x_noisy, sigma)
    loss = F.mse_loss(x_pred, x0, reduction='mean')
    return loss


In [8]:
class EMA:
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {}
        self.collected = {}
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name] = p.data.clone()

    @torch.no_grad()
    def update(self, model: nn.Module):
        for name, p in model.named_parameters():
            if not p.requires_grad: 
                continue
            assert name in self.shadow
            new_avg = self.decay * self.shadow[name] + (1.0 - self.decay) * p.data
            self.shadow[name] = new_avg.clone()

    @torch.no_grad()
    def copy_to(self, model: nn.Module):
        for name, p in model.named_parameters():
            if not p.requires_grad: 
                continue
            p.data.copy_(self.shadow[name])


In [9]:
def get_cifar10(root: str, batch_size: int, num_workers: int):
    tfm = T.Compose([
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])  # to [-1,1]
    ])
    ds = tv.datasets.CIFAR10(root=root, train=True, download=True, transform=tfm)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True)
    return dl

train_loader = get_cifar10(cfg.data_root, cfg.batch_size, cfg.num_workers)


In [10]:
core = PredictiveCodingCore(in_ch=cfg.channels, hidden=cfg.C_hidden, K=cfg.K, emb_dim=cfg.emb_dim)
denoiser = DenoiserEDM(core=core, emb_dim=cfg.emb_dim, sigma_data=cfg.sigma_data).to(device)
print('Params (M):', count_params(denoiser)/1e6)

opt = torch.optim.AdamW(denoiser.parameters(), lr=cfg.lr, weight_decay=cfg.wd, betas=(0.9, 0.999))
ema = EMA(denoiser, decay=cfg.ema_decay)


Params (M): 3.980202


In [11]:
def train(model: nn.Module, opt, ema: EMA, loader, epochs: int, cfg: TrainCfg):
    model.train()
    global_step = 0
    for ep in range(1, epochs+1):
        t0 = time.time()
        run_loss = 0.0
        for it, (x, _) in enumerate(loader):
            x = x.to(device, non_blocking=True)
            sigma = sample_sigma(x.size(0), cfg.sigma_min, cfg.sigma_max, device)
            loss = edm_loss(model, x, sigma)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            ema.update(model)

            run_loss += loss.item()
            global_step += 1
            if global_step % cfg.log_every == 0:
                print(f'ep {ep:03d} it {it:05d} | loss {run_loss/cfg.log_every:.5f}')
                run_loss = 0.0
        print(f'epoch {ep:03d} | time {time.time()-t0:.1f}s')

# Example: train for a few epochs to sanity-check (increase to cfg.epochs for full run)
train(denoiser, opt, ema, train_loader, epochs=50, cfg=cfg)


ep 001 it 00099 | loss 0.09555
epoch 001 | time 12.5s
ep 002 it 00004 | loss 0.00382
ep 002 it 00104 | loss 0.07575
epoch 002 | time 11.9s
ep 003 it 00009 | loss 0.00736
ep 003 it 00109 | loss 0.07374
epoch 003 | time 11.8s
ep 004 it 00014 | loss 0.01064
ep 004 it 00114 | loss 0.07252
epoch 004 | time 11.8s
ep 005 it 00019 | loss 0.01455
ep 005 it 00119 | loss 0.07025
epoch 005 | time 12.0s
ep 006 it 00024 | loss 0.01792
ep 006 it 00124 | loss 0.07039
epoch 006 | time 12.1s
ep 007 it 00029 | loss 0.02105
ep 007 it 00129 | loss 0.07088
epoch 007 | time 11.9s
ep 008 it 00034 | loss 0.02415
ep 008 it 00134 | loss 0.06960
epoch 008 | time 11.9s
ep 009 it 00039 | loss 0.02821
ep 009 it 00139 | loss 0.07117
epoch 009 | time 11.8s
ep 010 it 00044 | loss 0.03098
ep 010 it 00144 | loss 0.06960
epoch 010 | time 11.8s
ep 011 it 00049 | loss 0.03531
ep 011 it 00149 | loss 0.06883
epoch 011 | time 11.9s
ep 012 it 00054 | loss 0.03765
ep 012 it 00154 | loss 0.07001
epoch 012 | time 11.8s
ep 013 it 0

In [12]:
# Copy EMA weights into the model for sampling
ema.copy_to(denoiser)
denoiser.eval()
with torch.no_grad():
    imgs = heun_sampler(
        denoiser,
        batch_size=16,
        channels=cfg.channels,
        height=cfg.H,
        width=cfg.W,
        steps=40,
        sigma_min=cfg.sigma_min,
        sigma_max=cfg.sigma_max,
        rho=cfg.rho,
        device=device
    )
grid = tvu.make_grid((imgs + 1) * 0.5, nrow=4)  # back to [0,1]
os.makedirs('samples', exist_ok=True)
tv.utils.save_image(grid, 'samples/edm_pc_cifar10.png')
print('Saved to samples/edm_pc_cifar10.png')


Saved to samples/edm_pc_cifar10.png
