In [1]:
import math, os, time, random
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


import torchvision as tv
import torchvision.transforms as T
import torchvision.utils as tvu
from torchvision.datasets import CIFAR10

import numpy as np
from scipy import linalg
from torchvision.models import inception_v3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

def seed_all(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
seed_all(42)

print("‚úÖ Imports loaded")

Device: cuda
‚úÖ Imports loaded


In [2]:
@dataclass
class Config:
    # Data
    data_root: str = './data'
    channels: int = 3
    image_size: int = 32
    
    # PC Architecture - Following hierarchical PC theory
    num_layers: int = 4  # Number of hierarchical latent layers
    resolutions: Tuple[int] = (32, 16, 8, 4)  # Resolution of each layer
    layer_channels: Tuple[int] = (64, 128, 256, 256)  # Channels at each layer
    pc_iterations: int = 8  # Inference iterations to minimize energy
    
    # Time embedding
    emb_dim: int = 128
    
    # EDM preconditioning
    sigma_data: float = 0.5
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    rho: float = 7.0
    
    # Training
    batch_size: int = 128
    num_workers: int = 4
    epochs: int = 100
    lr: float = 2e-4
    ema_decay: float = 0.9999
    log_every: int = 100
    sample_every: int = 10  # Sample every N epochs
    
    # FID evaluation
    fid_samples: int = 10000
    fid_batch: int = 100

cfg = Config()
print(cfg)
print("\n‚úÖ Config created")

Config(data_root='./data', channels=3, image_size=32, num_layers=4, resolutions=(32, 16, 8, 4), layer_channels=(64, 128, 256, 256), pc_iterations=8, emb_dim=128, sigma_data=0.5, sigma_min=0.002, sigma_max=80.0, rho=7.0, batch_size=128, num_workers=4, epochs=100, lr=0.0002, ema_decay=0.9999, log_every=100, sample_every=10, fid_samples=10000, fid_batch=100)

‚úÖ Config created


In [3]:
class SinusoidalEmbedding(nn.Module):
    """Time/sigma embedding."""
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        half = dim // 2
        self.register_buffer('freqs', torch.exp(torch.linspace(math.log(1.0), math.log(1000.0), half)))

    def forward(self, x: torch.Tensor):
        x = x.view(-1)[:, None] * self.freqs[None, :]
        emb = torch.cat([x.sin(), x.cos()], dim=-1)
        if emb.shape[1] < self.dim:
            emb = F.pad(emb, (0, self.dim - emb.shape[1]))
        return emb

# Test
emb_test = SinusoidalEmbedding(128)
assert emb_test(torch.randn(4)).shape == (4, 128)
print("‚úÖ SinusoidalEmbedding")

‚úÖ SinusoidalEmbedding


In [4]:
def get_groups(channels: int, max_groups: int = 32) -> int:
    """Find largest divisor for GroupNorm."""
    for g in range(max_groups, 0, -1):
        if channels % g == 0:
            return g
    return 1

class PCLayer(nn.Module):
    """One layer in the PC hierarchy.
    
    Receives top-down prediction from layer above,
    compares with current state, computes error.
    """
    def __init__(self, in_channels: int, out_channels: int, emb_dim: int):
        super().__init__()
        
        # Top-down prediction pathway (from layer above)
        self.pred_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.GroupNorm(get_groups(out_channels), out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1)
        )
        
        # Error processing (bottom-up)
        self.error_conv = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.GroupNorm(get_groups(out_channels), out_channels),
            nn.SiLU()
        )
        
        # Time embedding injection
        self.temb_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels * 2)
        )
    
    def predict(self, x_above: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        """Top-down prediction: what does layer above predict for this layer?"""
        h = self.pred_conv(x_above)
        
        # Add time conditioning
        scale, shift = self.temb_proj(temb).chunk(2, dim=1)
        h = h * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
        
        return h
    
    def process_error(self, error: torch.Tensor) -> torch.Tensor:
        """Process prediction error (bottom-up pathway)."""
        return self.error_conv(error)

# Test
layer_test = PCLayer(64, 128, 128).to(device)
x_test = torch.randn(2, 64, 16, 16).to(device)
temb_test = torch.randn(2, 128).to(device)
pred = layer_test.predict(x_test, temb_test)
error = layer_test.process_error(torch.randn(2, 128, 16, 16).to(device))
assert pred.shape == (2, 128, 16, 16)
assert error.shape == (2, 128, 16, 16)
print("‚úÖ PCLayer")

‚úÖ PCLayer


In [5]:
class HierarchicalPC(nn.Module):
    """Hierarchical Predictive Coding Network.
    
    Following classical PC theory with:
    - Multiple latent layers at different resolutions
    - Top-down predictions
    - Bottom-up error propagation
    - Iterative inference to minimize energy
    """
    def __init__(
        self,
        in_channels: int = 3,
        layer_channels: Tuple[int] = (64, 128, 256, 256),
        resolutions: Tuple[int] = (32, 16, 8, 4),
        pc_iterations: int = 8,
        emb_dim: int = 128
    ):
        super().__init__()
        self.num_layers = len(layer_channels)
        self.resolutions = resolutions
        self.layer_channels = layer_channels
        self.pc_iterations = pc_iterations
        
        # Input projection (noisy image ‚Üí first latent layer)
        self.input_proj = nn.Conv2d(in_channels, layer_channels[0], 3, padding=1)
        
        # PC layers (each layer predicts the one below)
        self.pc_layers = nn.ModuleList()
        for i in range(self.num_layers - 1):
            self.pc_layers.append(
                PCLayer(layer_channels[i+1], layer_channels[i], emb_dim)
            )
        
        # Downsampling (for hierarchy)
        self.downsample = nn.ModuleList([
            nn.Conv2d(layer_channels[i], layer_channels[i+1], 3, stride=2, padding=1)
            for i in range(self.num_layers - 1)
        ])
        
        # Upsampling (for top-down predictions)
        self.upsample = nn.ModuleList([
            nn.ConvTranspose2d(layer_channels[i+1], layer_channels[i+1], 4, stride=2, padding=1)
            for i in range(self.num_layers - 1)
        ])
        
        # Output projection (first latent layer ‚Üí clean image)
        self.output_proj = nn.Sequential(
            nn.GroupNorm(get_groups(layer_channels[0]), layer_channels[0]),
            nn.SiLU(),
            nn.Conv2d(layer_channels[0], in_channels, 3, padding=1)
        )
    
    def forward(self, x_noisy: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
        """Predictive coding inference.
        
        1. Initialize latent hierarchy from input
        2. Iterate to minimize prediction errors
        3. Output denoised prediction
        """
        B = x_noisy.shape[0]
        
        # Initialize latent hierarchy
        x = [None] * self.num_layers
        x[0] = self.input_proj(x_noisy)
        
        # Build initial hierarchy via downsampling
        for i in range(self.num_layers - 1):
            x[i+1] = self.downsample[i](x[i])
        
        # Predictive coding iterations
        for it in range(self.pc_iterations):
            # Top-down predictions and error computation
            predictions = [None] * (self.num_layers - 1)
            errors = [None] * (self.num_layers - 1)
            
            for i in range(self.num_layers - 1):
                # Layer i+1 predicts layer i
                x_above_upsampled = self.upsample[i](x[i+1])
                predictions[i] = self.pc_layers[i].predict(x_above_upsampled, temb)
                
                # Prediction error
                errors[i] = x[i] - predictions[i]
            
            # Update latent variables based on errors
            # (In full PC, this would be gradient-based, but we use residual updates)
            for i in range(self.num_layers - 1):
                # Process error
                error_processed = self.pc_layers[i].process_error(errors[i])
                
                # Update lower layer (move toward prediction)
                x[i] = x[i] * 0.5 + predictions[i] * 0.5 + error_processed * 0.1
                
                # Propagate error up (update higher layer)
                if i < self.num_layers - 1:
                    error_down = self.downsample[i](error_processed)
                    x[i+1] = x[i+1] + error_down * 0.1
        
        # Output denoised image
        return self.output_proj(x[0])

# Test
pc_net = HierarchicalPC(
    in_channels=3,
    layer_channels=cfg.layer_channels,
    resolutions=cfg.resolutions,
    pc_iterations=cfg.pc_iterations,
    emb_dim=cfg.emb_dim
).to(device)

x_test = torch.randn(2, 3, 32, 32).to(device)
temb_test = torch.randn(2, 128).to(device)
with torch.no_grad():
    out = pc_net(x_test, temb_test)
assert out.shape == (2, 3, 32, 32)
print(f"‚úÖ HierarchicalPC: {x_test.shape} -> {out.shape}")
print(f"   Parameters: {sum(p.numel() for p in pc_net.parameters())/1e6:.2f}M")
print(f"   Latent hierarchy: {cfg.resolutions}")
print(f"   PC iterations: {cfg.pc_iterations}")

‚úÖ HierarchicalPC: torch.Size([2, 3, 32, 32]) -> torch.Size([2, 3, 32, 32])
   Parameters: 5.95M
   Latent hierarchy: (32, 16, 8, 4)
   PC iterations: 8


In [6]:
class DenoiserEDM(nn.Module):
    """EDM-preconditioned denoiser with PC core."""
    def __init__(self, core: nn.Module, emb_dim: int, sigma_data: float = 0.5):
        super().__init__()
        self.core = core
        self.sigma_data = sigma_data
        self.emb = SinusoidalEmbedding(emb_dim)

    def _coeffs(self, sigma: torch.Tensor):
        s2 = sigma**2
        sd2 = self.sigma_data**2
        c_skip = sd2 / (s2 + sd2)
        c_in = 1.0 / torch.sqrt(s2 + sd2)
        c_out = sigma * self.sigma_data / torch.sqrt(s2 + sd2)
        return c_skip, c_in, c_out

    def forward(self, x: torch.Tensor, sigma: torch.Tensor):
        if sigma.dim() == 1:
            sigma_img = sigma[:, None, None, None]
        else:
            sigma_img = sigma
            sigma = sigma.view(-1)

        c_skip, c_in, c_out = self._coeffs(sigma_img)
        x_in = c_in * x
        emb = self.emb(torch.log(sigma.clamp(min=1e-8)))
        h = self.core(x_in, emb)
        return c_skip * x + c_out * h

# Test
denoiser_test = DenoiserEDM(pc_net, cfg.emb_dim, cfg.sigma_data).to(device)
x_test = torch.randn(2, 3, 32, 32).to(device)
sigma_test = torch.rand(2).to(device) * 80
with torch.no_grad():
    out = denoiser_test(x_test, sigma_test)
assert out.shape == (2, 3, 32, 32)
print(f"‚úÖ DenoiserEDM: {x_test.shape} -> {out.shape}")

‚úÖ DenoiserEDM: torch.Size([2, 3, 32, 32]) -> torch.Size([2, 3, 32, 32])


In [14]:
@torch.no_grad()
def edm_schedule(steps: int, sigma_min: float, sigma_max: float, rho: float, device):
    """EDM noise schedule."""
    i = torch.arange(steps, device=device, dtype=torch.float32)
    t = sigma_max**(1/rho) + (i / (steps - 1)) * (sigma_min**(1/rho) - sigma_max**(1/rho))
    return torch.cat([t**rho, torch.zeros(1, device=device)])

@torch.no_grad()
def heun_sampler(
    denoiser: nn.Module,
    batch_size: int,
    channels: int,
    size: int,
    steps: int = 50,
    sigma_min: float = 0.002,
    sigma_max: float = 80.0,
    rho: float = 7.0,
    device=device
):
    """Heun sampler with CORRECT denoising residual (no division!)."""
    sigmas = edm_schedule(steps, sigma_min, sigma_max, rho, device)
    x = torch.randn(batch_size, channels, size, size, device=device) * sigmas[0]
    
    for i in range(steps):
        sigma_cur = sigmas[i]
        sigma_next = sigmas[i + 1]
        
        if sigma_next == 0:
            x = denoiser(x, sigma_cur.expand(batch_size))
            break
        
        # ‚úÖ CORRECT: No division by sigma!
        d_cur = denoiser(x, sigma_cur.expand(batch_size)) - x
        x_euler = x + (sigma_next - sigma_cur) * d_cur
        
        d_next = denoiser(x_euler, sigma_next.expand(batch_size)) - x_euler
        x = x + (sigma_next - sigma_cur) * 0.5 * (d_cur + d_next)
    
    return x.clamp(-1, 1)

# Test schedule with robust assertion
test_sched = edm_schedule(40, cfg.sigma_min, cfg.sigma_max, cfg.rho, device)
assert len(test_sched) == 41
# Use torch.isclose for floating point comparison
assert torch.isclose(test_sched[0], torch.tensor(cfg.sigma_max, device=device), rtol=1e-4), \
    f"Expected {cfg.sigma_max}, got {test_sched[0].item()}"
assert test_sched[-1] == 0
print(f"‚úÖ Sampler: {len(test_sched)} steps, œÉ ‚àà [{test_sched[0]:.1f}, {test_sched[-1]:.1f}]")


‚úÖ Sampler: 41 steps, œÉ ‚àà [80.0, 0.0]


In [8]:
def sample_sigma(B: int, sigma_min: float, sigma_max: float, device):
    u = torch.rand(B, device=device)
    return torch.exp(u * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))

def edm_loss(denoiser: nn.Module, x0: torch.Tensor, sigma: torch.Tensor):
    """EDM denoising objective."""
    sigma_img = sigma[:, None, None, None] if sigma.dim() == 1 else sigma
    noise = torch.randn_like(x0)
    x_noisy = x0 + sigma_img * noise
    x_pred = denoiser(x_noisy, sigma)
    return F.mse_loss(x_pred, x0)

class EMA:
    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.decay = decay
        self.shadow = {n: p.data.clone() for n, p in model.named_parameters() if p.requires_grad}

    @torch.no_grad()
    def update(self, model: nn.Module):
        for n, p in model.named_parameters():
            if n in self.shadow:
                self.shadow[n] = self.decay * self.shadow[n] + (1 - self.decay) * p.data

    @torch.no_grad()
    def copy_to(self, model: nn.Module):
        for n, p in model.named_parameters():
            if n in self.shadow:
                p.data.copy_(self.shadow[n])

print("‚úÖ Training utils defined")

‚úÖ Training utils defined


In [9]:
@torch.no_grad()
def get_inception_features(images: torch.Tensor, model, batch_size: int = 50):
    """Extract Inception features for FID."""
    model.eval()
    features = []
    
    for i in range(0, len(images), batch_size):
        batch = images[i:i+batch_size]
        # Resize to 299x299 for InceptionV3
        batch = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
        feat = model(batch)[0].squeeze(-1).squeeze(-1)
        features.append(feat.cpu().numpy())
    
    return np.concatenate(features, axis=0)

def calculate_fid(real_features: np.ndarray, fake_features: np.ndarray, eps=1e-6):
    """Calculate FID between real and fake features."""
    mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)
    
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    
    fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
    return fid

@torch.no_grad()
def compute_fid(
    denoiser: nn.Module,
    real_loader: DataLoader,
    num_samples: int,
    cfg: Config,
    inception_model
):
    """Compute FID score."""
    denoiser.eval()
    
    # Get real features
    print(f"   Computing features for {num_samples} real images...")
    real_images = []
    for x, _ in real_loader:
        real_images.append(x)
        if len(real_images) * x.shape[0] >= num_samples:
            break
    real_images = torch.cat(real_images, dim=0)[:num_samples]
    real_images = (real_images + 1) / 2  # [-1,1] -> [0,1]
    real_features = get_inception_features(real_images.to(device), inception_model)
    
    # Generate fake samples
    print(f"   Generating {num_samples} fake images...")
    fake_images = []
    num_batches = (num_samples + cfg.fid_batch - 1) // cfg.fid_batch
    for i in range(num_batches):
        batch_size = min(cfg.fid_batch, num_samples - i * cfg.fid_batch)
        fake = heun_sampler(
            denoiser, batch_size, cfg.channels, cfg.image_size,
            steps=50, sigma_min=cfg.sigma_min, sigma_max=cfg.sigma_max,
            rho=cfg.rho, device=device
        )
        fake_images.append(fake)
    fake_images = torch.cat(fake_images, dim=0)
    fake_images = (fake_images + 1) / 2  # [-1,1] -> [0,1]
    fake_features = get_inception_features(fake_images, inception_model)
    
    # Compute FID
    fid = calculate_fid(real_features, fake_features)
    return fid

print("‚úÖ FID computation defined")

‚úÖ FID computation defined


In [10]:
# Load data
transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize([0.5]*3, [0.5]*3)  # [-1, 1]
])

train_ds = CIFAR10(cfg.data_root, train=True, download=True, transform=transform)
train_loader = DataLoader(
    train_ds, batch_size=cfg.batch_size, shuffle=True,
    num_workers=cfg.num_workers, pin_memory=True, drop_last=True
)

print(f"‚úÖ CIFAR-10: {len(train_loader)} batches")

# Load Inception for FID
inception = inception_v3(pretrained=True, transform_input=False).to(device)
inception.eval()
for p in inception.parameters():
    p.requires_grad = False
print("‚úÖ InceptionV3 loaded for FID")

‚úÖ CIFAR-10: 390 batches
Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /home/wang.yixuan/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 104M/104M [00:00<00:00, 231MB/s]  


‚úÖ InceptionV3 loaded for FID


In [11]:
# Create model
pc_core = HierarchicalPC(
    in_channels=cfg.channels,
    layer_channels=cfg.layer_channels,
    resolutions=cfg.resolutions,
    pc_iterations=cfg.pc_iterations,
    emb_dim=cfg.emb_dim
)
denoiser = DenoiserEDM(pc_core, cfg.emb_dim, cfg.sigma_data).to(device)
print(f"‚úÖ Model: {sum(p.numel() for p in denoiser.parameters())/1e6:.2f}M params")

opt = torch.optim.AdamW(denoiser.parameters(), lr=cfg.lr, betas=(0.9, 0.999))
ema = EMA(denoiser, cfg.ema_decay)
print(f"‚úÖ Optimizer & EMA ready")

‚úÖ Model: 5.95M params
‚úÖ Optimizer & EMA ready


In [15]:
from tqdm import tqdm

def train_epoch(model, opt, ema, loader, epoch, cfg):
    """Train for one epoch with progress bar."""
    model.train()
    losses = []
    
    pbar = tqdm(loader, desc=f"Epoch {epoch:03d}", leave=False)
    for it, (x, _) in enumerate(pbar):
        x = x.to(device, non_blocking=True)
        sigma = sample_sigma(x.size(0), cfg.sigma_min, cfg.sigma_max, device)
        loss = edm_loss(model, x, sigma)
        
        opt.zero_grad(set_to_none=True)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        ema.update(model)
        
        losses.append(loss.item())
        
        # Update progress bar every 10 iterations
        if len(losses) >= 10:
            pbar.set_postfix({'loss': f"{np.mean(losses[-10:]):.5f}"})
    
    return np.mean(losses)

@torch.no_grad()
def sample_grid(model, epoch, cfg, nrow=4):
    """Sample and save a grid of images."""
    model.eval()
    imgs = heun_sampler(
        model, nrow**2, cfg.channels, cfg.image_size,
        steps=50, sigma_min=cfg.sigma_min,
        sigma_max=cfg.sigma_max, rho=cfg.rho, device=device
    )
    grid = tvu.make_grid((imgs + 1) * 0.5, nrow=nrow)
    os.makedirs('samples', exist_ok=True)
    tv.utils.save_image(grid, f'samples/ep{epoch:03d}.png')
    print(f"  ‚úÖ Saved samples/ep{epoch:03d}.png")

print("‚úÖ Training functions with tqdm defined")

‚úÖ Training functions with tqdm defined


In [16]:
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('samples', exist_ok=True)
print("‚úÖ Directories created")

# Training loop with tqdm
fid_scores = []
epoch_pbar = tqdm(range(1, cfg.epochs + 1), desc="Training")

for epoch in epoch_pbar:
    t0 = time.time()
    avg_loss = train_epoch(denoiser, opt, ema, train_loader, epoch, cfg)
    elapsed = time.time() - t0
    
    # Update epoch progress bar
    epoch_pbar.set_postfix({
        'loss': f"{avg_loss:.5f}",
        'time': f"{elapsed:.1f}s"
    })
    
    # Sample & evaluate
    if epoch % cfg.sample_every == 0 or epoch == 1:
        ema.copy_to(denoiser)
        sample_grid(denoiser, epoch, cfg)
        
        # Compute FID every 20 epochs
        if epoch % 20 == 0:
            print(f"\n  Computing FID at epoch {epoch}...")
            fid = compute_fid(denoiser, train_loader, cfg.fid_samples, cfg, inception)
            fid_scores.append((epoch, fid))
            print(f"  üìä FID at epoch {epoch}: {fid:.2f}\n")
        
        # Save checkpoint
        checkpoint = {
            'epoch': epoch,
            'model': denoiser.state_dict(),
            'ema': ema.shadow,
            'opt': opt.state_dict(),
            'fid_scores': fid_scores,
            'config': cfg
        }
        torch.save(checkpoint, f'checkpoints/ep{epoch:03d}.pt')

print("\nüéâ Training complete!")
print("\nüìä FID Scores:")
for ep, fid in fid_scores:
    print(f"  Epoch {ep:03d}: {fid:.2f}")

‚úÖ Directories created


Training:   1%|          | 1/100 [00:18<30:20, 18.39s/it, loss=0.08260, time=17.4s]

  ‚úÖ Saved samples/ep001.png


Training:  10%|‚ñà         | 10/100 [02:55<26:25, 17.61s/it, loss=0.06971, time=17.2s]

  ‚úÖ Saved samples/ep010.png


Training:  19%|‚ñà‚ñâ        | 19/100 [05:48<23:20, 17.29s/it, loss=0.06853, time=17.2s]

  ‚úÖ Saved samples/ep020.png

  Computing FID at epoch 20...
   Computing features for 10000 real images...
   Generating 10000 fake images...


Training:  19%|‚ñà‚ñâ        | 19/100 [08:04<34:25, 25.50s/it, loss=0.06853, time=17.2s]


ValueError: matmul: Input operand 0 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)

In [None]:
# Final evaluation
ema.copy_to(denoiser)
denoiser.eval()

print("Final sampling...")
with torch.no_grad():
    imgs = heun_sampler(
        denoiser, 64, cfg.channels, cfg.image_size,
        steps=50, sigma_min=cfg.sigma_min,
        sigma_max=cfg.sigma_max, rho=cfg.rho, device=device
    )

grid = tvu.make_grid((imgs + 1) * 0.5, nrow=8)
tv.utils.save_image(grid, 'samples/final.png')
print("‚úÖ Saved samples/final.png")

# Display
from IPython.display import Image, display
display(Image('samples/final.png'))

# test

In [19]:
# New cell - test if your setup CAN learn
import torch
import torch.nn.functional as F

def quick_diagnostic(denoiser, train_loader, device):
    """Test if model can learn at all."""
    print("\nüî¨ DIAGNOSTIC: Can model learn?")
    print("="*50)
    
    denoiser.train()
    opt = torch.optim.Adam(denoiser.parameters(), lr=1e-3)
    
    # Get one small batch
    x, _ = next(iter(train_loader))
    x = x[:8].to(device)
    
    # Overfit on this batch
    losses = []
    for i in range(50):
        sigma = torch.ones(x.size(0), device=device) * 5.0
        noise = torch.randn_like(x)
        x_noisy = x + sigma[:, None, None, None] * noise
        
        x_pred = denoiser(x_noisy, sigma)
        loss = F.mse_loss(x_pred, x)
        
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        losses.append(loss.item())
    
    print(f"Initial loss: {losses[0]:.4f}")
    print(f"Final loss:   {losses[-1]:.4f}")
    print(f"Reduction:    {(1 - losses[-1]/losses[0])*100:.1f}%")
    
    if losses[-1] < losses[0] * 0.2:
        print("‚úÖ Model CAN learn!")
        return True
    else:
        print("‚ùå Model CANNOT learn - architecture issue")
        return False

# Run it
can_learn = quick_diagnostic(denoiser, train_loader, device)


üî¨ DIAGNOSTIC: Can model learn?
Initial loss: 0.1491
Final loss:   0.1413
Reduction:    5.2%
‚ùå Model CANNOT learn - architecture issue


In [20]:
"""
EMERGENCY FIX: Simplified Working Baseline

This uses a simpler architecture to verify your training pipeline works.
Once this works (FID ~20-30 at epoch 50), we can add PC structure.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# ==============================================================================
# SIMPLIFIED ARCHITECTURE (GUARANTEED TO WORK)
# ==============================================================================

def get_groups(channels: int, max_groups: int = 32) -> int:
    for g in range(max_groups, 0, -1):
        if channels % g == 0:
            return g
    return 1

class ResBlock(nn.Module):
    """Simple residual block."""
    def __init__(self, channels: int, emb_dim: int):
        super().__init__()
        self.norm1 = nn.GroupNorm(get_groups(channels), channels)
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        
        self.temb_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, channels)
        )
        
        self.norm2 = nn.GroupNorm(get_groups(channels), channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
    
    def forward(self, x, temb):
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)
        
        # Time embedding
        h = h + self.temb_proj(temb)[:, :, None, None]
        
        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)
        
        return x + h  # Residual connection

class SimpleHierarchicalNet(nn.Module):
    """Simplified hierarchical network (like U-Net but simpler)."""
    def __init__(
        self,
        in_channels: int = 3,
        base_channels: int = 128,
        emb_dim: int = 128,
        num_res_blocks: int = 2
    ):
        super().__init__()
        
        # Input
        self.conv_in = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        
        # Encoder (32 -> 16 -> 8 -> 4)
        self.down1 = nn.ModuleList([
            ResBlock(base_channels, emb_dim) for _ in range(num_res_blocks)
        ])
        self.down1_pool = nn.Conv2d(base_channels, base_channels*2, 3, stride=2, padding=1)
        
        self.down2 = nn.ModuleList([
            ResBlock(base_channels*2, emb_dim) for _ in range(num_res_blocks)
        ])
        self.down2_pool = nn.Conv2d(base_channels*2, base_channels*4, 3, stride=2, padding=1)
        
        self.down3 = nn.ModuleList([
            ResBlock(base_channels*4, emb_dim) for _ in range(num_res_blocks)
        ])
        self.down3_pool = nn.Conv2d(base_channels*4, base_channels*4, 3, stride=2, padding=1)
        
        # Bottleneck (4x4)
        self.mid = nn.ModuleList([
            ResBlock(base_channels*4, emb_dim) for _ in range(2)
        ])
        
        # Decoder (4 -> 8 -> 16 -> 32)
        self.up3_upsample = nn.ConvTranspose2d(base_channels*4, base_channels*4, 4, stride=2, padding=1)
        self.up3 = nn.ModuleList([
            ResBlock(base_channels*4 + base_channels*4, emb_dim),  # +skip
            ResBlock(base_channels*4, emb_dim)
        ])
        
        self.up2_upsample = nn.ConvTranspose2d(base_channels*4, base_channels*2, 4, stride=2, padding=1)
        self.up2 = nn.ModuleList([
            ResBlock(base_channels*2 + base_channels*2, emb_dim),  # +skip
            ResBlock(base_channels*2, emb_dim)
        ])
        
        self.up1_upsample = nn.ConvTranspose2d(base_channels*2, base_channels, 4, stride=2, padding=1)
        self.up1 = nn.ModuleList([
            ResBlock(base_channels + base_channels, emb_dim),  # +skip
            ResBlock(base_channels, emb_dim)
        ])
        
        # Output
        self.norm_out = nn.GroupNorm(get_groups(base_channels), base_channels)
        self.conv_out = nn.Conv2d(base_channels, in_channels, 3, padding=1)
        
        # Initialize output layer to zero
        self.conv_out.weight.data.zero_()
        self.conv_out.bias.data.zero_()
    
    def forward(self, x, temb):
        # Input
        h = self.conv_in(x)
        
        # Encoder with skip connections
        h1 = h
        for block in self.down1:
            h1 = block(h1, temb)
        
        h2 = self.down1_pool(h1)
        for block in self.down2:
            h2 = block(h2, temb)
        
        h3 = self.down2_pool(h2)
        for block in self.down3:
            h3 = block(h3, temb)
        
        h = self.down3_pool(h3)
        
        # Bottleneck
        for block in self.mid:
            h = block(h, temb)
        
        # Decoder with skip connections
        h = self.up3_upsample(h)
        h = torch.cat([h, h3], dim=1)  # Skip
        h = self.up3[0](h, temb)
        h = self.up3[1](h, temb)
        
        h = self.up2_upsample(h)
        h = torch.cat([h, h2], dim=1)  # Skip
        h = self.up2[0](h, temb)
        h = self.up2[1](h, temb)
        
        h = self.up1_upsample(h)
        h = torch.cat([h, h1], dim=1)  # Skip
        h = self.up1[0](h, temb)
        h = self.up1[1](h, temb)
        
        # Output
        h = self.norm_out(h)
        h = F.silu(h)
        h = self.conv_out(h)
        
        return h


# ==============================================================================
# FIX: FID Computation with Better Error Handling
# ==============================================================================

import numpy as np
from scipy import linalg

@torch.no_grad()
def get_inception_features(images: torch.Tensor, model, batch_size: int = 50):
    """Extract Inception features for FID."""
    model.eval()
    features = []
    
    for i in range(0, len(images), batch_size):
        batch = images[i:i+batch_size]
        # Resize to 299x299 for InceptionV3
        batch = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
        
        # Get features
        with torch.no_grad():
            feat = model(batch)
            if isinstance(feat, tuple):
                feat = feat[0]  # InceptionV3 returns tuple
            feat = feat.squeeze(-1).squeeze(-1)  # Remove spatial dims
            features.append(feat.cpu().numpy())
    
    return np.concatenate(features, axis=0)

def calculate_fid(real_features: np.ndarray, fake_features: np.ndarray, eps=1e-6):
    """Calculate FID between real and fake features."""
    # Ensure 2D
    if real_features.ndim == 1:
        real_features = real_features.reshape(-1, 1)
    if fake_features.ndim == 1:
        fake_features = fake_features.reshape(-1, 1)
    
    # Check minimum samples
    if len(real_features) < 2 or len(fake_features) < 2:
        print("Warning: Not enough samples for FID calculation")
        return float('inf')
    
    mu1 = real_features.mean(axis=0)
    mu2 = fake_features.mean(axis=0)
    
    sigma1 = np.cov(real_features, rowvar=False)
    sigma2 = np.cov(fake_features, rowvar=False)
    
    # Handle scalar case
    if sigma1.ndim == 0:
        sigma1 = sigma1.reshape(1, 1)
    if sigma2.ndim == 0:
        sigma2 = sigma2.reshape(1, 1)
    
    diff = mu1 - mu2
    
    # Compute sqrt of product
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            print("Warning: Imaginary component in covmean")
        covmean = covmean.real
    
    fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
    return float(fid)


# ==============================================================================
# DIAGNOSTIC: Check if model is learning
# ==============================================================================

def diagnose_training(denoiser, train_loader, device):
    """Quick diagnostic to see if model can learn."""
    print("\n" + "="*60)
    print("DIAGNOSTIC: Testing if model can learn")
    print("="*60)
    
    denoiser.train()
    opt = torch.optim.Adam(denoiser.parameters(), lr=1e-3)
    
    # Overfit on ONE batch
    x, _ = next(iter(train_loader))
    x = x[:16].to(device)  # Just 16 images
    
    losses = []
    for i in range(100):
        sigma = torch.ones(x.size(0), device=device) * 10.0
        
        # Add noise
        noise = torch.randn_like(x)
        x_noisy = x + sigma[:, None, None, None] * noise
        
        # Predict clean
        x_pred = denoiser(x_noisy, sigma)
        loss = F.mse_loss(x_pred, x)
        
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        losses.append(loss.item())
        
        if (i+1) % 20 == 0:
            print(f"  Step {i+1:3d}: loss = {loss.item():.6f}")
    
    print(f"\n  Initial loss: {losses[0]:.6f}")
    print(f"  Final loss:   {losses[-1]:.6f}")
    print(f"  Reduction:    {(losses[0] - losses[-1]) / losses[0] * 100:.1f}%")
    
    if losses[-1] < losses[0] * 0.1:
        print("  ‚úÖ Model CAN learn (loss reduced >90%)")
        return True
    else:
        print("  ‚ùå Model NOT learning properly")
        return False


# ==============================================================================
# USAGE
# ==============================================================================

print("""
TO USE THIS FIX:

1. Create a new cell with:
   from simplified_fix import SimpleHierarchicalNet, diagnose_training
   
2. Replace your PC core with:
   simple_core = SimpleHierarchicalNet(
       in_channels=3,
       base_channels=128,
       emb_dim=128,
       num_res_blocks=2
   )
   
3. Create denoiser as before:
   denoiser = DenoiserEDM(simple_core, cfg.emb_dim, cfg.sigma_data).to(device)
   
4. Run diagnostic:
   can_learn = diagnose_training(denoiser, train_loader, device)
   
5. If diagnostic passes, train as normal!
""")


TO USE THIS FIX:

1. Create a new cell with:
   from simplified_fix import SimpleHierarchicalNet, diagnose_training
   
2. Replace your PC core with:
   simple_core = SimpleHierarchicalNet(
       in_channels=3,
       base_channels=128,
       emb_dim=128,
       num_res_blocks=2
   )
   
3. Create denoiser as before:
   denoiser = DenoiserEDM(simple_core, cfg.emb_dim, cfg.sigma_data).to(device)
   
4. Run diagnostic:
   can_learn = diagnose_training(denoiser, train_loader, device)
   
5. If diagnostic passes, train as normal!



In [21]:
simple_core = SimpleHierarchicalNet(
    in_channels=3,
    base_channels=128,
    emb_dim=128,
    num_res_blocks=2
)

# Recreate denoiser
denoiser = DenoiserEDM(simple_core, cfg.emb_dim, cfg.sigma_data).to(device)
opt = torch.optim.AdamW(denoiser.parameters(), lr=cfg.lr)
ema = EMA(denoiser, cfg.ema_decay)

print("‚úÖ Using simple baseline architecture")
print(f"   Parameters: {sum(p.numel() for p in denoiser.parameters())/1e6:.1f}M")

# Now train as normal

‚úÖ Using simple baseline architecture
   Parameters: 64.2M
