In [1]:
# Cell 1: imports and utils
import math, os, time, json, random
from pathlib import Path
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import datasets, transforms, utils as tvu

try:
    import torch_fidelity
    HAS_TORCH_FIDELITY = True
except Exception:
    HAS_TORCH_FIDELITY = False

SEED = 1234
torch.manual_seed(SEED); np.random.seed(SEED); random.seed(SEED)
torch.backends.cudnn.benchmark = True

def count_params(m: nn.Module):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

def to_device(x, device):
    return x.to(device, non_blocking=True)

def set_requires_grad(m: nn.Module, flag: bool):
    for p in m.parameters():
        p.requires_grad = flag


In [2]:
# Cell 2: CIFAR-10 dataloader
def get_cifar10(batch_size=256, num_workers=4, root="./data"):
    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x*2.0 - 1.0),  # [-1, 1]
    ])
    train = datasets.CIFAR10(root=root, train=True, transform=tfm, download=True)
    test  = datasets.CIFAR10(root=root, train=False, transform=tfm, download=True)
    train_loader = data.DataLoader(train, batch_size=batch_size, shuffle=True,
                                   num_workers=num_workers, pin_memory=True, drop_last=True)
    test_loader  = data.DataLoader(test,  batch_size=batch_size, shuffle=False,
                                   num_workers=num_workers, pin_memory=True, drop_last=False)
    return train_loader, test_loader


In [3]:
# Cell 3: embeddings and predictive-coding core (bug-fixed)

class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier features of log(sigma)."""
    def __init__(self, emb_dim=128, scale=16.0):
        super().__init__()
        self.W = nn.Parameter(torch.randn(emb_dim//2) * scale, requires_grad=False)

    def forward(self, sigma):  # sigma: (B,)
        # Use log-sigma for better conditioning
        x = sigma.clamp(min=1e-8).log().unsqueeze(-1)  # (B,1)
        f = x * self.W.view(1, -1)                    # (B, emb_dim//2)
        return torch.cat([f.sin(), f.cos()], dim=-1)  # (B, emb_dim)

class MLPFiLM(nn.Module):
    """Embed -> scale, shift for FiLM-like conditioning."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, out_dim*2),
            nn.SiLU(),
            nn.Linear(out_dim*2, out_dim*2),
        )
    def forward(self, emb):  # (B, E) -> (B, 2*C)
        return self.net(emb)

class PCBlock(nn.Module):
    """
    Predictive-coding update:
      state s (C,H,W)
      target features f (C,H,W)
      error e = f - s in the SAME channel space
      delta = g([s, e]) with FiLM from embedding, then s <- s + delta
    """
    def __init__(self, C, emb_dim, groups=16):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, 2*C)
        self.conv1 = nn.Conv2d(2*C, C, 3, padding=1)
        self.norm2 = nn.GroupNorm(groups, C)
        self.conv2 = nn.Conv2d(C, C, 3, padding=1)
        self.emb_proj = MLPFiLM(emb_dim, C)

    def forward(self, s, f, emb):
        # s, f: (B,C,H,W); emb: (B,E)
        e = f - s
        x = torch.cat([s, e], dim=1)
        x = self.conv1(self.norm1(x))
        x = F.silu(x)
        # FiLM
        gamma, beta = self.emb_proj(emb).chunk(2, dim=1)  # (B,C), (B,C)
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        beta  = beta.unsqueeze(-1).unsqueeze(-1)
        x = self.norm2(x) * (1 + gamma) + beta
        x = F.silu(x)
        delta = self.conv2(x)
        return s + delta

class PredictiveCodingCore(nn.Module):
    """
    Core: project RGB to C, run K PC blocks with shared target features f,
    then project back to RGB.
    """
    def __init__(self, in_ch=3, C=128, K=8, emb_dim=128, groups=16):
        super().__init__()
        self.in_proj  = nn.Conv2d(in_ch, C, 3, padding=1)
        self.blocks   = nn.ModuleList([PCBlock(C, emb_dim, groups=groups) for _ in range(K)])
        self.out_norm = nn.GroupNorm(groups, C)
        self.out_proj = nn.Conv2d(C, in_ch, 3, padding=1)

    def forward(self, x_cond, emb):
        # x_cond: preconditioned noisy input in RGB; emb: (B,E)
        f = self.in_proj(x_cond)        # target features
        s = f.clone()                   # predictive state starts equal to target projection
        for blk in self.blocks:
            s = blk(s, f, emb)
        y = self.out_proj(F.silu(self.out_norm(s)))
        return y


In [4]:
import math, torch, torch.nn as nn
import torch.nn.functional as F

# --- EDM config (CIFAR-10) ---
class EDMConfig:
    sigma_data = 0.5
    sigma_min  = 0.002
    sigma_max  = 80.0
    rho        = 7.0
    P_mean     = -1.2
    P_std      =  1.2

cfg = EDMConfig()

# --- sinusoidal embedding for log-sigma ---
class TimeEmbedding(nn.Module):
    def __init__(self, in_dim=1, hidden=256, out_dim=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.SiLU(),
            nn.Linear(hidden, out_dim), nn.SiLU()
        )
    def forward(self, log_sigma):  # log_sigma shape: (B, 1)
        return self.mlp(log_sigma)

# --- EDM coeffs ---
def edm_coeffs(sigma, sigma_data=0.5):
    # sigma: (B,1,1,1)
    s2 = sigma * sigma
    sd2 = sigma_data * sigma_data
    denom = torch.sqrt(s2 + sd2)
    c_in   = 1.0 / denom
    c_out  = (sigma * sigma_data) / denom
    c_skip = sd2 / (s2 + sd2)
    return c_skip, c_out, c_in

# --- Your predictive-coding core is used as the raw network F_theta ---
# It must map (x_in, x_hat0_init, emb) -> residual in data space.
# It should return a CxHxW tensor per example.
# (Keep your existing PredictiveCodingCore / PCBlock definitions.)
# Example: core = PredictiveCodingCore(C=3, width=128, ...)

class DenoiserEDM(nn.Module):
    def __init__(self, core, emb_dim=256):
        super().__init__()
        self.core = core
        self.t_embed = TimeEmbedding(in_dim=1, hidden=512, out_dim=emb_dim)

    def forward(self, x, sigma_scalar):
        """
        x: (B,C,H,W), sigma_scalar: (B,)   -> returns x_hat0 (B,C,H,W)
        """
        B = x.shape[0]
        sigma_4d = sigma_scalar.view(B, 1, 1, 1)
        log_sigma = 0.25 * torch.log(sigma_scalar).view(B, 1)  # cnoise(σ)=¼ log σ
        emb = self.t_embed(log_sigma)  # (B,emb_dim)

        c_skip, c_out, c_in = edm_coeffs(sigma_4d, cfg.sigma_data)  # (B,1,1,1)
        x_in = c_in * x

        # Your core takes (x_in, x_hat0_init, emb). Keep zeros as x_hat0_init.
        h = self.core(x_in, torch.zeros_like(x_in), emb)  # raw prediction in data space

        # Preconditioning wrapper: x_hat0 = c_skip * x + c_out * h
        x_hat0 = c_skip * x + c_out * h
        return x_hat0


In [5]:
# Cell 5: training utilities

@dataclass
class TrainCfg:
    lr: float = 2e-4
    wd: float = 0.0
    epochs: int = 300
    batch_size: int = 256
    num_workers: int = 4
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    rho: float = 7.0
    sigma_data: float = 0.5
    ema_decay: float = 0.999
    amp: bool = True
    ckpt_dir: str = "./pc_edm_ckpts"

class EMA:
    def __init__(self, m: nn.Module, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in m.state_dict().items()}
    @torch.no_grad()
    def update(self, m: nn.Module):
        for k, v in m.state_dict().items():
            self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
    @torch.no_grad()
    def copy_to(self, m: nn.Module):
        m.load_state_dict(self.shadow, strict=True)

def sample_log_uniform_sigma(bsz, sigma_min, sigma_max, device):
    u = torch.rand(bsz, device=device)
    return (sigma_min**2 * (sigma_max**2 / sigma_min**2) ** u).sqrt()

def mse_loss(x, y, reduction='mean'):
    return F.mse_loss(x, y, reduction=reduction)

def train_one_epoch(model, opt, ema, loader, cfg: TrainCfg, device):
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)
    loss_meter = 0.0
    n = 0
    for x, _ in loader:
        x = to_device(x, device)
        b = x.size(0)
        sigma = sample_log_uniform_sigma(b, cfg.sigma_min, cfg.sigma_max, x.device)
        noise = torch.randn_like(x)
        x_noisy = x + sigma.view(b,1,1,1) * noise

        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=cfg.amp):
            x_hat = model(x_noisy, sigma)        # predict x0
            # Plain L2 in x-space (EDM often uses simple L2 with log-uniform sigma sampling)
            loss = mse_loss(x_hat, x)

        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(opt)
        scaler.update()
        ema.update(model)

        loss_meter += loss.item() * b
        n += b
    return loss_meter / max(1, n)

def save_ckpt(model, ema, opt, epoch, cfg: TrainCfg):
    Path(cfg.ckpt_dir).mkdir(parents=True, exist_ok=True)
    ckpt = {
        'model': model.state_dict(),
        'ema': ema.shadow,
        'opt': opt.state_dict(),
        'epoch': epoch,
        'cfg': cfg.__dict__,
    }
    torch.save(ckpt, Path(cfg.ckpt_dir)/f"epoch_{epoch:04d}.pt")


In [6]:
# Cell 6: EDM Heun sampler (Karras schedule)

@torch.no_grad()
def karras_sigmas(n, sigma_min, sigma_max, rho, device):
    # Decreasing schedule with Heun; append final 0 as last node
    i = torch.arange(n, device=device, dtype=torch.float32)
    ramp = i / (n - 1)
    min_inv_rho = sigma_min ** (1.0 / rho)
    max_inv_rho = sigma_max ** (1.0 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    sigmas = torch.cat([sigmas, torch.zeros_like(sigmas[:1])], dim=0)  # add zero
    return sigmas  # shape: (n+1,)


@torch.no_grad()
def heun_sampler(denoiser, batch_size, channels, height, width,
                 steps=40, sigma_min=None, sigma_max=None, rho=None, device='cuda'):
    sigma_min = cfg.sigma_min if sigma_min is None else sigma_min
    sigma_max = cfg.sigma_max if sigma_max is None else sigma_max
    rho       = cfg.rho       if rho is None else rho

    sigmas = karras_sigmas(steps, sigma_min, sigma_max, rho, device)  # (steps+1,)
    # Start from N(0, sigma_max^2 I)
    x = torch.randn(batch_size, channels, height, width, device=device) * sigmas[0]

    for i in range(steps):
        s_i = sigmas[i].expand(batch_size)      # (B,)
        s_j = sigmas[i+1].expand(batch_size)    # (B,)

        # First slope di = (x - x_hat0)/σi
        x_hat0_i = denoiser(x, s_i)             # (B,C,H,W)
        di = (x - x_hat0_i) / s_i.view(-1,1,1,1)

        # Euler step
        dt = (s_j - s_i).view(-1,1,1,1)         # (B,1,1,1), negative
        x_euler = x + dt * di

        # Second slope dj at provisional state
        x_hat0_j = denoiser(x_euler, s_j)
        dj = (x_euler - x_hat0_j) / s_j.view(-1,1,1,1)

        # Heun update
        x = x + dt * 0.5 * (di + dj)

    return x.clamp(-1, 1)



In [7]:
# Cell 7: FID evaluation (requires torch-fidelity)

@torch.no_grad()
def generate_cifar10_samples(denoiser, n_samples=50000, batch=250, device='cuda',
                             steps=40, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    denoiser.eval()
    imgs = []
    n_batches = (n_samples + batch - 1) // batch
    for _ in range(n_batches):
        b = min(batch, n_samples - len(imgs))
        x = heun_sampler(denoiser, b, 3, 32, 32, steps=steps,
                         sigma_min=sigma_min, sigma_max=sigma_max, rho=rho, device=device)
        # to [0,1]
        x = (x + 1.0) * 0.5
        imgs.append(x.cpu())
    return torch.cat(imgs, dim=0)[:n_samples]

def save_images_tensor(imgs, out_dir):
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for i in range(imgs.size(0)):
        tvu.save_image(imgs[i], out/f"{i:06d}.png")

def compute_fid_with_torch_fidelity(gen_dir, ref_dataset='cifar10-train'):
    assert HAS_TORCH_FIDELITY, "torch-fidelity not available."
    metrics = torch_fidelity.calculate_metrics(
        input1=str(gen_dir),
        input2=ref_dataset,   # 'cifar10-train' or 'cifar10-test'
        fid=True, verbose=False)
    return float(metrics['frechet_inception_distance'])


In [8]:
# Cell 8: build model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cfg = TrainCfg()

core = PredictiveCodingCore(in_ch=3, C=192, K=10, emb_dim=128, groups=32)
denoiser = DenoiserEDM(core=core, sigma_data=cfg.sigma_data, emb_dim=128).to(device)
print("Params (M):", count_params(denoiser)/1e6)

opt = torch.optim.AdamW(denoiser.parameters(), lr=cfg.lr, weight_decay=cfg.wd, betas=(0.9, 0.999))
ema = EMA(denoiser, decay=cfg.ema_decay)

train_loader, test_loader = get_cifar10(batch_size=cfg.batch_size, num_workers=cfg.num_workers)


TypeError: DenoiserEDM.__init__() got an unexpected keyword argument 'sigma_data'

In [None]:
def edm_loss(denoiser, x0):
    """
    x0 in [-1,1]. Draw sigma ~ LogNormal(P_mean, P_std), add noise,
    weight MSE by lambda(sigma) = (σ²+σd²)/(σ² σd²).
    """
    B = x0.shape[0]
    device = x0.device
    # sample log-sigma
    rnd = torch.randn(B, device=device) * cfg.P_std + cfg.P_mean
    sigma = rnd.exp()  # (B,)

    n = torch.randn_like(x0)
    x_sigma = x0 + sigma.view(B,1,1,1) * n

    # weighted MSE to x0
    x_hat0 = denoiser(x_sigma, sigma)
    lam = (sigma*sigma + cfg.sigma_data*cfg.sigma_data) / ((sigma * cfg.sigma_data)**2)  # (B,)
    w = lam.view(B,1,1,1)
    return (w * (x_hat0 - x0)**2).mean()


In [None]:
import torchvision as tv, torchvision.transforms as T
from torch.utils.data import DataLoader

# --- data to [-1,1] ---
transform = T.Compose([T.ToTensor(), T.Lambda(lambda z: z*2.0 - 1.0)])
trainset = tv.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

# --- model, opt, ema ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
core = PredictiveCodingCore(...)  # keep your PC architecture
denoiser = DenoiserEDM(core).to(device)

opt = torch.optim.AdamW(denoiser.parameters(), lr=2e-4, betas=(0.9, 0.99), weight_decay=0.0)
# EMA
class EMA:
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = [p.detach().clone() for p in model.parameters() if p.requires_grad]
        self.params = [p for p in model.parameters() if p.requires_grad]
    @torch.no_grad()
    def update(self):
        for s, p in zip(self.shadow, self.params):
            s.mul_(self.decay).add_(p.detach(), alpha=(1.0 - self.decay))
    @torch.no_grad()
    def copy_to(self, model):
        for s, p in zip(self.shadow, self.params):
            p.data.copy_(s.data)

ema = EMA(denoiser, decay=0.9999)

# --- training ---
def train_edm_pc(denoiser, loader, epochs=50):
    denoiser.train()
    for ep in range(1, epochs+1):
        tot = 0.0
        for x,_ in loader:
            x = x.to(device, non_blocking=True)
            loss = edm_loss(denoiser, x)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(denoiser.parameters(), 1.0)
            opt.step()
            ema.update()
            tot += loss.item()
        print(f"epoch {ep:04d} | loss {tot/len(loader):.4f}")


In [None]:
# Cell 10: quick graph/sanity check (8 samples)
ema.copy_to(denoiser)
with torch.no_grad():
    imgs = heun_sampler(denoiser, batch_size=8, channels=3, height=32, width=32,
                        steps=40, sigma_min=cfg.sigma_min, sigma_max=cfg.sigma_max, rho=cfg.rho, device=device)
grid = tvu.make_grid((imgs+1)*0.5, nrow=4)
display(transforms.ToPILImage()(grid.cpu()))


In [None]:
# after some training:
ema.copy_to(denoiser)  # swap in EMA weights
denoiser.eval()
with torch.no_grad():
    imgs = heun_sampler(denoiser, batch_size=8, channels=3, height=32, width=32,
                        steps=40, sigma_min=cfg.sigma_min, sigma_max=cfg.sigma_max,
                        rho=cfg.rho, device=device)
# de-normalize for display
grid = tv.utils.make_grid((imgs + 1)/2, nrow=4, padding=2)
T.ToPILImage()(grid.cpu()).show()


In [None]:
# pip install torchmetrics torchvision
from torchmetrics.image.fid import FrechetInceptionDistance

def fid_from_sampler(denoiser, n_gen=5000, batch=100, steps=40):
    fid = FrechetInceptionDistance(feature=2048, normalize=True).to(device)
    # real stats
    for x,_ in DataLoader(trainset, batch_size=batch, shuffle=True):
        x = x.to(device)
        fid.update(((x+1)/2).clamp(0,1), real=True)
        if fid.real_features_num >= n_gen: break
    # fake stats
    ema.copy_to(denoiser); denoiser.eval()
    done = 0
    while done < n_gen:
        bs = min(batch, n_gen - done)
        imgs = heun_sampler(denoiser, batch_size=bs, channels=3, height=32, width=32, steps=steps, device=device)
        fid.update(((imgs+1)/2).clamp(0,1), real=False)
        done += bs
    return float(fid.compute().cpu())

print("Approx FID:", fid_from_sampler(denoiser, n_gen=10000, batch=100, steps=40))


In [None]:
# Cell 11: FID (needs time and disk)
ema.copy_to(denoiser)
gen_dir = "./gen_cifar10"
samples = generate_cifar10_samples(denoiser, n_samples=50000, batch=250, device=device,
                                   steps=40, sigma_min=cfg.sigma_min, sigma_max=cfg.sigma_max, rho=cfg.rho)
save_images_tensor(samples, gen_dir)
if HAS_TORCH_FIDELITY:
    fid = compute_fid_with_torch_fidelity(gen_dir, ref_dataset='cifar10-train')
    print("FID (vs CIFAR-10 train):", fid)
else:
    print("Install torch-fidelity to compute FID, or use pytorch-fid.")
