In [19]:
# !pip install torch torchvision torch_fidelity lpips --quiet
import math, random, os, functools, itertools, time
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
try:
    from torch.amp import autocast, GradScaler
except ImportError:
    from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm

# Optional: LPIPS for RF pre-metric (installed above)
try:
    import lpips
    LPIPS_AVAILABLE = True
except Exception:
    LPIPS_AVAILABLE = False

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.backends.cudnn.benchmark = True


In [20]:
@dataclass
class Config:
    # Data
    dataset: str = 'CIFAR10'  # 'CIFAR10' or your dataset name
    data_root: str = './data'
    image_size: int = 32
    channels: int = 3

    # Training
    batch_size: int = 128  # per-device batch size
    scale_batch_by_gpus: bool = True
    auto_scale_lr: bool = True
    reference_batch_size: int = 128
    reference_num_gpus: int = 1
    num_steps: int = 300_000
    lr: float = 2e-4
    ema_decay: float = 0.999  # EMA for sampling stability
    grad_clip: float = 1.0
    use_scheduler: bool = True
    scheduler_min_lr: float = 2e-5
    mixed_precision: bool = True

    # Objective toggle
    objective: str = 'edm'  # 'edm' or 'flow' ('flow' = flow/rectified-flow matching)

    # EDM noise schedule (Table 1, EDM)
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    sigma_data: float = 0.5
    rho: float = 7.0
    P_mean: float = -1.2  # log-sigma mean
    P_std: float = 1.2    # log-sigma std

    # Sampler
    NFE: int = 35  # number of function evaluations (Heun steps)
    sample_bs: int = 64
    save_dir: str = './samples'

cfg = Config()
os.makedirs(cfg.save_dir, exist_ok=True)


In [21]:
def make_dataloaders(cfg: Config):
    device_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
    batch_size = cfg.batch_size * device_count if cfg.scale_batch_by_gpus else cfg.batch_size
    if device_count > 1 and cfg.scale_batch_by_gpus:
        print(f'Using effective batch size {batch_size} ({cfg.batch_size} per GPU × {device_count} GPUs)')
    tfm = transforms.Compose([
        transforms.Resize(cfg.image_size),
        transforms.CenterCrop(cfg.image_size),
        transforms.ToTensor(),  # [0,1]
        # map to [-1,1] for EDM / standard diffusion
        transforms.Lambda(lambda x: x * 2.0 - 1.0),
    ])
    if cfg.dataset.upper() == 'CIFAR10':
        trainset = datasets.CIFAR10(root=cfg.data_root, train=True, download=True, transform=tfm)
    else:
        raise NotImplementedError("Add your dataset here.")
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                             drop_last=True, num_workers=8, pin_memory=True)
    return trainloader

trainloader = make_dataloaders(cfg)


Using effective batch size 512 (128 per GPU × 4 GPUs)


In [4]:
class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        t is (B,) of log-sigma (EDM) or time t in [0,1] (Flow). We still feed a scalar with sin-cos.
        """
        half = self.dim // 2
        freqs = torch.exp(torch.arange(half, device=t.device) * (-math.log(10_000) / max(half - 1, 1)))
        args = t[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if self.dim % 2 == 1:  # pad if odd
            emb = F.pad(emb, (0,1))
        return emb


In [5]:
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, tdim, groups=8):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.act = nn.SiLU()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.emb = nn.Sequential(
            nn.SiLU(),
            nn.Linear(tdim, out_ch)
        )
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = (nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity())

    def forward(self, x, temb):
        h = self.conv1(self.act(self.norm1(x)))
        h = h + self.emb(temb)[:, :, None, None]
        h = self.conv2(self.act(self.norm2(h)))
        return h + self.skip(x)

class AttentionBlock(nn.Module):
    def __init__(self, ch, heads=4):
        super().__init__()
        self.norm = nn.GroupNorm(8, ch)
        self.qkv = nn.Conv2d(ch, ch*3, 1)
        self.proj = nn.Conv2d(ch, ch, 1)
        self.heads = heads

    def forward(self, x):
        b,c,h,w = x.shape
        qkv = self.qkv(self.norm(x))
        q,k,v = qkv.chunk(3, dim=1)
        q = q.view(b, self.heads, c//self.heads, h*w)
        k = k.view(b, self.heads, c//self.heads, h*w)
        v = v.view(b, self.heads, c//self.heads, h*w)
        attn = torch.softmax((q.transpose(-2,-1) @ k) / math.sqrt(c//self.heads), dim=-1)  # (B,H,HW,HW)
        out = (attn @ v.transpose(-2,-1)).transpose(-2,-1).contiguous()
        out = out.view(b, c, h, w)
        return x + self.proj(out)

class EDMUNet(nn.Module):
    """
    UNet backbone. Outputs raw network F_theta; EDM preconditioning is applied outside.
    Deep supervision: intermediate x0 heads to realize 'predictive coding' residual learning.
    """
    def __init__(self, cfg: Config, ch=128, ch_mult=(1,2,2,2), num_res=2, attn_res=(16,)):
        super().__init__()
        self.cfg = cfg
        self.attn_resolutions = set(attn_res)
        self.in_conv = nn.Conv2d(cfg.channels, ch, 3, padding=1)
        tdim = ch * 4
        self.time_mlp = nn.Sequential(
            SinusoidalTimeEmbedding(tdim),
            nn.Linear(tdim, tdim),
            nn.SiLU(),
            nn.Linear(tdim, tdim),
        )
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        chs = []
        curr_ch = ch
        in_size = cfg.image_size
        for level, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res):
                block = ResBlock(curr_ch, out_ch, tdim)
                self.downs.append(block)
                curr_ch = out_ch
                chs.append(curr_ch)  # track skip connection channels
                if in_size in self.attn_resolutions:
                    self.downs.append(AttentionBlock(curr_ch))
            if level != len(ch_mult) - 1:
                self.downs.append(nn.Conv2d(curr_ch, curr_ch, 3, stride=2, padding=1))
                in_size //= 2

        self.mid_block1 = ResBlock(curr_ch, curr_ch, tdim)
        self.mid_attn = AttentionBlock(curr_ch)
        self.mid_block2 = ResBlock(curr_ch, curr_ch, tdim)

        aux_channels = [ch * mult for level, mult in reversed(list(enumerate(ch_mult))) if level != 0]
        self.aux_heads = nn.ModuleList([
            nn.Conv2d(c, cfg.channels, 1) for c in aux_channels
        ])

        for level, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res):
                skip_ch = chs.pop()
                self.ups.append(ResBlock(curr_ch + skip_ch, out_ch, tdim))
                curr_ch = out_ch
                if in_size in self.attn_resolutions:
                    self.ups.append(AttentionBlock(curr_ch))
            if level != 0:
                self.ups.append(nn.Upsample(scale_factor=2, mode='nearest'))
                in_size *= 2

        self.out_norm = nn.GroupNorm(8, curr_ch)
        self.out = nn.Conv2d(curr_ch, cfg.channels, 3, padding=1)

    def forward(self, x, t_scalar):
        temb = self.time_mlp(t_scalar)
        hs = []
        h = self.in_conv(x)
        aux_preds = []
        for m in self.downs:
            if isinstance(m, ResBlock):
                h = m(h, temb)
                hs.append(h)
            elif isinstance(m, AttentionBlock):
                h = m(h)
            else:
                h = m(h)
        h = self.mid_block1(h, temb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, temb)
        up_idx = 0
        for m in self.ups:
            if isinstance(m, ResBlock):
                skip = hs.pop()
                h = torch.cat([h, skip], dim=1)
                h = m(h, temb)
            elif isinstance(m, AttentionBlock):
                h = m(h)
            else:
                h = m(h)
                if up_idx < len(self.aux_heads):
                    aux_preds.append(self.aux_heads[up_idx](h))
                    up_idx += 1
        h = self.out(self.out_norm(h))
        return h, aux_preds  # raw network output F_theta (no preconditioning)


In [6]:
def rand_log_normal_sigma(bs, P_mean, P_std, device):
    # Sample log sigma ~ N(P_mean, P_std^2)
    return torch.exp(P_mean + P_std * torch.randn(bs, device=device))

def karras_sigma_schedule(N, sigma_min, sigma_max, rho=7.0, device='cpu'):
    # Monotone decreasing σ sequence used for sampling; see EDM Table 1 & Alg 1
    ramp = torch.linspace(0, 1, N, device=device)
    min_inv_rho = sigma_min**(1/rho)
    max_inv_rho = sigma_max**(1/rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho
    return sigmas

class EDMPrecondWrapper(nn.Module):
    def __init__(self, net: EDMUNet, sigma_data: float):
        super().__init__()
        self.net = net
        self.sigma_data = sigma_data

    def forward(self, x, sigma):
        """
        x: (B,C,H,W) in [-1,1]
        sigma: (B,) positive noise level
        returns: denoised x0 estimate (EDM preconditioned output)
        """
        sigma = sigma.view(-1, 1, 1, 1)
        c_skip = (self.sigma_data**2) / (sigma**2 + self.sigma_data**2)
        c_out  = (sigma * self.sigma_data) / torch.sqrt(sigma**2 + self.sigma_data**2)
        c_in   = 1.0 / torch.sqrt(sigma**2 + self.sigma_data**2)
        # c_noise is log-sigma / 4 per EDM
        c_noise = 0.25 * torch.log(sigma.squeeze(-1).squeeze(-1).squeeze(-1))

        f_raw, aux_preds = self.net(c_in * x, c_noise)
        D = c_skip * x + c_out * f_raw
        return D, aux_preds


In [7]:
def sample_xt_rf(x0, eps, t):
    # xt = (1-t) x0 + t eps ; eps ~ N(0,I), t in [0,1]
    return (1.0 - t)[:, None, None, None] * x0 + t[:, None, None, None] * eps

def flow_matching_targets(x0, eps):
    # v = d/dt xt = eps - x0  (constant in t for linear path)
    return eps - x0

def u_shaped_t(bs, device):
    # U-shaped distribution (more mass near 0 and 1), improves few-step performance
    u = torch.rand(bs, device=device)
    t = 0.5 - 0.5 * torch.cos(math.pi * u)
    return t.clamp(1e-5, 1-1e-5)


In [22]:
def make_model(cfg: Config):
    net = EDMUNet(cfg)
    model = EDMPrecondWrapper(net, sigma_data=cfg.sigma_data).to(DEVICE)
    ema = EDMPrecondWrapper(EDMUNet(cfg), sigma_data=cfg.sigma_data).to(DEVICE)
    ema.load_state_dict(model.state_dict())
    for p in ema.parameters():
        p.requires_grad = False
    device_count = torch.cuda.device_count() if torch.cuda.is_available() else 1
    effective_batch = cfg.batch_size * device_count if cfg.scale_batch_by_gpus else cfg.batch_size
    cfg.effective_batch_size = effective_batch
    reference_batch = max(1, cfg.reference_batch_size * max(1, cfg.reference_num_gpus))
    scaled_lr = cfg.lr
    if cfg.auto_scale_lr:
        scaled_lr = cfg.lr * (effective_batch / reference_batch)
    cfg.effective_lr = scaled_lr
    if cfg.auto_scale_lr:
        print(f"Effective batch size: {effective_batch}. Scaled LR -> {scaled_lr:.2e}")
    opt = AdamW(model.parameters(), lr=scaled_lr, weight_decay=0.0, betas=(0.9, 0.999))
    return model, ema, opt

model, ema, opt = make_model(cfg)
global_step = 0


Effective batch size: 512. Scaled LR -> 8.00e-04


In [23]:
def setup_data_parallel(model, ema):
    if DEVICE != 'cuda':
        print('CUDA not available; keeping single-device model.')
        return model, ema
    num_devices = torch.cuda.device_count()
    if num_devices <= 1:
        print('Only one CUDA device detected; DataParallel not applied.')
        return model, ema
    print(f'Enabling DataParallel across {num_devices} GPUs.')
    model = nn.DataParallel(model)
    ema = ema.to(DEVICE)
    return model, ema

model, ema = setup_data_parallel(model, ema)

def _normalize_state_dict_keys(state_dict):
    if not any(key.startswith('module.') for key in state_dict.keys()):
        return state_dict
    return {key.replace('module.', '', 1): value for key, value in state_dict.items()}

@torch.no_grad()
def update_ema(ema, model, decay):
    source_state = _normalize_state_dict_keys(model.state_dict())
    target_state = ema.state_dict()
    for key, param in target_state.items():
        if param.dtype.is_floating_point:
            param.copy_(decay * param + (1.0 - decay) * source_state[key])

USE_BF16 = DEVICE == 'cuda' and torch.cuda.is_bf16_supported()
AMP_DTYPE = torch.bfloat16 if USE_BF16 else torch.float16
USE_MIXED_PRECISION = cfg.mixed_precision and DEVICE == 'cuda'
scaler = GradScaler(enabled=USE_MIXED_PRECISION and not USE_BF16) if USE_MIXED_PRECISION else None
if USE_MIXED_PRECISION:
    prec = 'bf16' if USE_BF16 else 'fp16'
    print(f'Using mixed precision ({prec}); GradScaler enabled={scaler is not None and scaler.is_enabled()}')



Enabling DataParallel across 4 GPUs.
Using mixed precision (bf16); GradScaler enabled=False


In [24]:
lpips_fn = lpips.LPIPS(net='vgg').to(DEVICE).eval() if LPIPS_AVAILABLE else None

def loss_lpips_huber(x_hat, x0, delta=0.01):
    # LPIPS + Huber (optional, RF low-NFE improvements)
    loss = 0.0
    if LPIPS_AVAILABLE:
        with torch.no_grad():
            pass
        p = lpips_fn(torch.clamp((x_hat+1)/2,0,1), torch.clamp((x0+1)/2,0,1))
        loss = p.mean()
    # small Huber in pixel space
    diff = x_hat - x0
    huber = torch.where(diff.abs() < delta, 0.5*diff**2, delta*(diff.abs()-0.5*delta)).mean()
    return loss + huber

def train_step(batch, cfg: Config, model, ema, opt, objective='edm', scaler: Optional[GradScaler] = None):
    global global_step
    x0 = batch[0].to(DEVICE, non_blocking=True)  # in [-1,1]
    B = x0.shape[0]
    use_amp = USE_MIXED_PRECISION
    if use_amp:
        amp_context = autocast
        amp_kwargs = {'device_type': 'cuda', 'dtype': AMP_DTYPE}
        if scaler is None and AMP_DTYPE == torch.float16:
            scaler = GradScaler(enabled=True)
    else:
        amp_context = nullcontext
        amp_kwargs = {}

    def edm_forward():
        sigma = rand_log_normal_sigma(B, cfg.P_mean, cfg.P_std, DEVICE).clamp(cfg.sigma_min, cfg.sigma_max)
        n = torch.randn_like(x0)
        x_noisy = x0 + sigma[:, None, None, None] * n
        xhat, aux_preds = model(x_noisy, sigma)
        lam = (sigma**2 + cfg.sigma_data**2) / ((sigma * cfg.sigma_data)**2)
        pix_loss = (lam[:, None, None, None] * (xhat - x0)**2).mean()
        aux_loss = 0.0
        for aux in aux_preds:
            aux_loss += F.mse_loss(aux, F.interpolate(x0, aux.shape[-2:], mode='bilinear', align_corners=False))
        aux_loss = 0.25 * aux_loss / max(1, len(aux_preds))
        return pix_loss + aux_loss

    def flow_forward():
        t = u_shaped_t(B, DEVICE)
        eps = torch.randn_like(x0)
        xt = sample_xt_rf(x0, eps, t)
        base_net = getattr(model, 'module', model).net
        f_raw, aux_preds = base_net(xt, t)
        v_target = flow_matching_targets(x0, eps)
        loss_main = F.mse_loss(f_raw, v_target)
        if LPIPS_AVAILABLE:
            xhat = xt + f_raw * (1.0 - t)[:, None, None, None]
            loss_main = loss_main + 0.05 * loss_lpips_huber(xhat, x0)
        aux_loss = 0.0
        for aux in aux_preds:
            v_down = F.interpolate(v_target, aux.shape[-2:], mode='bilinear', align_corners=False)
            aux_loss += F.mse_loss(aux, v_down)
        aux_loss = 0.25 * aux_loss / max(1, len(aux_preds))
        return loss_main + aux_loss

    with amp_context(**amp_kwargs):
        if objective == 'edm':
            loss = edm_forward()
        elif objective == 'flow':
            loss = flow_forward()
        else:
            raise ValueError("objective must be 'edm' or 'flow'")

    loss_value = float(loss.detach().cpu().item())
    opt.zero_grad(set_to_none=True)
    if scaler is not None and scaler.is_enabled():
        scaler.scale(loss).backward()
        if cfg.grad_clip is not None:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        scaler.step(opt)
        scaler.update()
    else:
        loss.backward()
        if cfg.grad_clip is not None:
            nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        opt.step()
    update_ema(ema, model, cfg.ema_decay)
    global_step += 1
    return loss_value


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: /blue/wdixon/wang.yixuan/.conda/envs/ly/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth
Loading model from: /blue/wdixon/wang.yixuan/.conda/envs/ly/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth


In [11]:
@torch.no_grad()
def heun_sampler_ema(ema, cfg: Config, num: int = 64):
    ema.eval()
    sigmas = karras_sigma_schedule(cfg.NFE, cfg.sigma_min, cfg.sigma_max, cfg.rho, device=DEVICE)
    x = torch.randn(num, cfg.channels, cfg.image_size, cfg.image_size, device=DEVICE) * sigmas[0]
    for i in range(len(sigmas)-1):
        sigma_i = sigmas[i]
        sigma_j = sigmas[i+1]
        # Derivative of x w.r.t. t (prob flow ODE form): dx/dt = -(σ̇/σ)[ x - D(x;σ) ]
        # In discrete form with σ as the "time" parameterizing, EDM uses Heun (improved Euler):
        #   d_i = ( (sigma_dot/sigma) * x - (sigma_dot) * D(x; sigma)/sigma )  -> implemented via EDM Alg 1
        # We use the simpler practical Euler predictor + trapezoid corrector based on D(x;σ).
        # (Exact EDM code uses the same structure; see Alg. 1 lines 4–8; guard at σ->0) :contentReference[oaicite:11]{index=11}

        # Compute derivative at i
        D_i, _ = ema(x, torch.full((num,), sigma_i, device=DEVICE))
        d_i = (D_i - x) / (sigma_i**2) * (- (sigma_j - sigma_i)) * (sigma_i)  # scaled step; equivalent form
        # Euler step
        x_euler = x + d_i

        if sigma_j > 0:
            # 2nd-order correction (Heun/trapezoid)
            D_j, _ = ema(x_euler, torch.full((num,), sigma_j, device=DEVICE))
            d_j = (D_j - x_euler) / (sigma_j**2) * (- (sigma_j - sigma_i)) * (sigma_j)
            x = x + 0.5 * (d_i + d_j)
        else:
            # At σ=0, use Euler (avoid division by zero)
            x = x_euler

    x = x.clamp(-1, 1)
    return x


In [None]:
def train(cfg: Config, model, ema, opt, trainloader, scaler: Optional[GradScaler] = None):
    model.train()
    losses = []
    data_iter = itertools.cycle(trainloader)
    scheduler = None
    if cfg.use_scheduler:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=cfg.num_steps, eta_min=cfg.scheduler_min_lr
        )
    loss_ema = None
    with tqdm(range(cfg.num_steps), desc="Training", unit="step") as progress:
        for step in progress:
            batch = next(data_iter)
            loss = train_step(batch, cfg, model, ema, opt, objective=cfg.objective, scaler=scaler)
            losses.append(loss)
            if scheduler is not None:
                scheduler.step()
            loss_ema = loss if loss_ema is None else 0.9 * loss_ema + 0.1 * loss
            postfix = {"loss": f"{loss_ema:.4f}"}
            if scheduler is not None:
                postfix["lr"] = f"{scheduler.get_last_lr()[0]:.2e}"
            progress.set_postfix(postfix)
            if (step + 1) % 5000 == 0:
                with torch.no_grad():
                    samples = heun_sampler_ema(ema, cfg, num=cfg.sample_bs)
                grid = vutils.make_grid((samples+1)/2, nrow=int(math.sqrt(cfg.sample_bs)))
                vutils.save_image(grid, os.path.join(cfg.save_dir, f"sample_{cfg.objective}_{step+1}.png"))
    return losses

train(cfg, model, ema, opt, trainloader, scaler)


Training:   0%|          | 0/300000 [00:00<?, ?step/s]

In [13]:
# After training (or loading a checkpoint), sample:
with torch.no_grad():
    samples = heun_sampler_ema(ema, cfg, num=cfg.sample_bs)
grid = vutils.make_grid((samples+1)/2, nrow=int(math.sqrt(cfg.sample_bs)))
vutils.save_image(grid, os.path.join(cfg.save_dir, f"final_{cfg.objective}.png"))
grid


tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.2449,  ..., 0.3200, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 1.0000,  ..., 0.1597, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.3612,  ..., 0.1081, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 1.0000,  ..., 0.1802, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.3442,  ..., 0.0201, 0.0000, 0.