# PixNerd ImageNet-256 Training (Heavy Decoder, Class-Conditional)

This notebook trains PixNerd on ImageNet-256 without the Lightning CLI. It uses the **text-to-image heavy decoder** as a neural-field head, while diffusion runs on the 256×256 training grid. Images are loaded from `/pscratch/sd/k/kevinval/datasets/imagenet256` (one subfolder per class), and the model is conditioned on learned class embeddings instead of text embeddings.


## Environment setup
- Requires CUDA for meaningful speed.
- Imports modules directly from the repository source tree (no Lightning CLI).
- Adjust batch size/epochs to fit your GPUs.


In [None]:
import math, os, random, time
from dataclasses import dataclass
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

# Reproducibility: fix the Python, NumPy and PyTorch random generators.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when one is visible; CPU works but is very slow for this model.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)


## Data: ImageNet-256 loaders
Images are normalized to [-1, 1] for diffusion training. The dataset path should contain class subfolders with JPEGs.


In [None]:
# Paths and hyperparameters
IMAGENET_ROOT = "/pscratch/sd/k/kevinval/datasets/imagenet256"
IMAGE_SIZE = 256
BATCH_SIZE = 64
NUM_WORKERS = 8
NUM_CLASSES = 1000

transform = transforms.Compose([
    # Resize the shorter side, then center-crop: non-square images are cropped
    # instead of being stretched. Images that are already IMAGE_SIZE x IMAGE_SIZE
    # come out unchanged, exactly as with a fixed (IMAGE_SIZE, IMAGE_SIZE) resize.
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # ToTensor yields [0, 1]; (x - 0.5) / 0.5 maps that to [-1, 1].
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_dataset = torchvision.datasets.ImageFolder(
    root=IMAGENET_ROOT,
    transform=transform,
)
# ImageFolder labels classes 0..len(classes)-1. Any label >= NUM_CLASSES would
# index past the class-embedding table and only fail later, inside training.
if len(train_dataset.classes) > NUM_CLASSES:
    raise ValueError(
        f"Found {len(train_dataset.classes)} class folders under {IMAGENET_ROOT}, "
        f"but NUM_CLASSES={NUM_CLASSES}"
    )

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    # Keep every optimisation step at the full batch size.
    drop_last=True,
    # Avoid re-spawning worker processes at the start of every epoch.
    persistent_workers=NUM_WORKERS > 0,
)

print('Train images:', len(train_dataset))


## Noise schedule helpers (DDPM)


In [None]:
def make_beta_schedule(T=1000, beta_start=0.0001, beta_end=0.02):
    """Return the linear DDPM variance schedule.

    The result is a 1-D tensor of ``T`` betas evenly spaced from ``beta_start``
    to ``beta_end`` (both endpoints included), created without an explicit device.
    """
    betas = torch.linspace(start=beta_start, end=beta_end, steps=T)
    return betas

@dataclass
class DiffusionSchedule:
    """Per-timestep DDPM noise-schedule tensors, each of length ``T``.

    ``alphas[t] = 1 - betas[t]`` and ``alpha_bars[t] = alphas[0] * ... * alphas[t]``.
    """

    betas: torch.Tensor
    alphas: torch.Tensor
    alpha_bars: torch.Tensor

    @classmethod
    def create(cls, T=1000, beta_start=0.0001, beta_end=0.02, device='cpu'):
        """Build a linear schedule of ``T`` steps with every tensor on ``device``."""
        beta = make_beta_schedule(T, beta_start, beta_end).to(device)
        alpha = 1.0 - beta
        alpha_bar = alpha.cumprod(dim=0)
        return cls(betas=beta, alphas=alpha, alpha_bars=alpha_bar)


## Heavy decoder wrapper for class conditioning
We reuse the text-to-image heavy decoder and feed it learned class embeddings, passed as a length-1 conditioning sequence in place of text tokens. For super-resolution, coordinate interpolation is controlled through the decoder's `decoder_patch_scaling_*` attributes.


In [None]:
from src.models.transformer.pixnerd_t2i_heavydecoder import PixNerDiT

class ClassConditionalPixNerDiT(nn.Module):
    """Class-conditional adapter around the text-to-image heavy-decoder PixNerDiT.

    Each integer class label is looked up in a learned table of width
    ``txt_embed_dim``. The result is handed to the wrapped network as a one-token
    "text" sequence. All other constructor arguments are forwarded unchanged.

    NOTE(review): only one token per sample is ever produced, so values of
    ``txt_max_length`` other than 1 are presumably unsupported — confirm against
    PixNerDiT.
    """

    def __init__(self, num_classes=NUM_CLASSES, in_channels=3, patch_size=2, hidden_size=1152,
                 decoder_hidden_size=64, num_encoder_blocks=18, num_decoder_blocks=4, num_text_blocks=4,
                 txt_embed_dim=1024, txt_max_length=1):
        super().__init__()
        backbone_kwargs = dict(
            in_channels=in_channels,
            hidden_size=hidden_size,
            decoder_hidden_size=decoder_hidden_size,
            num_encoder_blocks=num_encoder_blocks,
            num_decoder_blocks=num_decoder_blocks,
            num_text_blocks=num_text_blocks,
            patch_size=patch_size,
            txt_embed_dim=txt_embed_dim,
            txt_max_length=txt_max_length,
        )
        # Registration order (backbone first, then embedding) is kept so the
        # state_dict layout and seeded parameter initialisation stay the same.
        self.model = PixNerDiT(**backbone_kwargs)
        self.class_embed = nn.Embedding(num_classes, txt_embed_dim)

    def forward(self, x, t, labels):
        """Run the backbone on ``x`` at timestep ``t``, conditioned on ``labels``."""
        # labels (B,) -> embeddings (B, D) -> one conditioning token (B, 1, D).
        cond_tokens = self.class_embed(labels)[:, None]
        return self.model(x, t, cond_tokens)


## Diffusion model tying everything together


In [None]:
class PixNerdDDPM(nn.Module):
    """Epsilon-prediction DDPM wrapper around a class-conditional PixNerd network.

    ``loss`` implements the simplified DDPM objective: MSE between the sampled
    noise and the predicted noise. ``sample`` runs a strided reverse process
    using the generalized DDIM update. With ``eta=1``, and when every timestep is
    visited, that update is DDPM ancestral sampling.

    NOTE(review): timesteps reach the network as integer indices in ``[0, T)``.
    This file does not show whether PixNerDiT's timestep embedder expects that
    scale (rather than, e.g., a continuous t in [0, 1]) — TODO confirm.
    """

    def __init__(self, model: ClassConditionalPixNerDiT, schedule: DiffusionSchedule):
        super().__init__()
        self.model = model
        # The schedule is a plain attribute, not registered buffers, so
        # ``self.to(device)`` does NOT move its tensors. Build it on the device
        # the model will run on.
        self.schedule = schedule

    def loss(self, x0, labels):
        """Return the noise-prediction MSE for a batch of clean images.

        Args:
            x0: clean images; the ``view(-1, 1, 1, 1)`` below assumes a 4-D
                (B, C, H, W) batch, presumably in [-1, 1].
            labels: integer class labels, one per image.
        """
        b = x0.size(0)
        t = torch.randint(0, self.schedule.betas.size(0), (b,), device=x0.device)
        noise = torch.randn_like(x0)
        alpha_bar_t = self.schedule.alpha_bars[t].view(-1, 1, 1, 1)
        # Closed-form forward process: x_t = sqrt(ab_t) * x_0 + sqrt(1 - ab_t) * eps.
        xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise
        pred = self.model(xt, t, labels)
        return F.mse_loss(pred, noise)

    @torch.no_grad()
    def sample(self, batch_size, labels, timesteps=50, img_size=IMAGE_SIZE, eta=1.0):
        """Generate images with a strided reverse diffusion process.

        Args:
            batch_size: number of images to draw; must equal ``len(labels)``.
            labels: integer class labels, on the device to sample on.
            timesteps: number of denoising steps, clamped to ``[1, T]``.
            img_size: side length of the square output.
            eta: DDIM stochasticity. ``1.0`` (the default) injects DDPM-style
                noise at every step; ``0.0`` gives deterministic DDIM.

        Returns:
            Tensor of shape (batch_size, 3, img_size, img_size), clamped to [-1, 1].
        """
        device = labels.device
        alpha_bars = self.schedule.alpha_bars
        T = alpha_bars.size(0)
        n_steps = max(1, min(int(timesteps), T))
        # Evenly spaced, strictly decreasing timestep indices from T-1 down to 0.
        seq = torch.linspace(T - 1, 0, n_steps, dtype=torch.float64).round().long().tolist()
        one = alpha_bars.new_tensor(1.0)

        x = torch.randn(batch_size, 3, img_size, img_size, device=device)
        # Pair each timestep with the next (less noisy) one; -1 means "clean image".
        for i, j in zip(seq, seq[1:] + [-1]):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            alpha_bar_t = alpha_bars[i]
            alpha_bar_prev = alpha_bars[j] if j >= 0 else one
            eps = self.model(x, t, labels)
            x0_pred = (x - torch.sqrt(1 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)
            # Generalized DDIM step (Song et al., 2021, eq. 12). With eta = 1 and a
            # full schedule, sigma^2 equals the DDPM posterior variance; on the
            # final step (alpha_bar_prev = 1) sigma is 0 and x becomes x0_pred.
            sigma = eta * torch.sqrt(
                (1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev)
            )
            direction = torch.sqrt((1 - alpha_bar_prev - sigma ** 2).clamp(min=0)) * eps
            x = torch.sqrt(alpha_bar_prev) * x0_pred + direction
            if j >= 0:
                x = x + sigma * torch.randn_like(x)
        return x.clamp(-1, 1)


## Instantiate model, optimizer, EMA


In [None]:
schedule = DiffusionSchedule.create(T=1000, device=device)
model = ClassConditionalPixNerDiT(num_classes=NUM_CLASSES).to(device)
diffusion = PixNerdDDPM(model, schedule).to(device)

optimizer = torch.optim.AdamW(diffusion.parameters(), lr=2e-4, weight_decay=1e-4)

# Exponential moving average (EMA) of the trained weights, used for sampling.
ema_decay = 0.999
ema_diffusion = PixNerdDDPM(ClassConditionalPixNerDiT(num_classes=NUM_CLASSES).to(device), schedule)
ema_diffusion.load_state_dict(diffusion.state_dict())
# The EMA copy is only updated in place (under no_grad) and used for
# inference. It therefore needs no autograd tracking, and eval() turns off any
# training-only layer behaviour in the wrapped network.
ema_diffusion.requires_grad_(False)
ema_diffusion.eval()

def update_ema(target, source, decay):
    """Move ``target``'s weights toward ``source``'s as an exponential moving average.

    Each parameter becomes ``decay * target + (1 - decay) * source``. Buffers
    (e.g. running statistics) are copied verbatim, because averaging them is not
    meaningful and leaving them at their initial values would be stale.

    Both modules must have the same architecture: parameters and buffers are
    paired positionally.
    """
    with torch.no_grad():
        for tgt, src in zip(target.parameters(), source.parameters()):
            # lerp_: tgt + w * (src - tgt) == (1 - w) * tgt + w * src, w = 1 - decay.
            tgt.lerp_(src, 1.0 - decay)
        for tgt_buf, src_buf in zip(target.buffers(), source.buffers()):
            tgt_buf.copy_(src_buf)


## Training loop
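Minimal loop; increase `EPOCHS` for real training. A checkpoint containing the model, EMA, and optimizer state is saved at the end of every epoch; add richer logging or mid-epoch checkpointing as needed.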


In [None]:
EPOCHS = 1   # increase to train properly
log_interval = 50
save_dir = Path("./pixnerd_imagenet_heavy_ckpts")
save_dir.mkdir(parents=True, exist_ok=True)

step = 0
for epoch in range(EPOCHS):
    diffusion.train()
    for imgs, labels in train_loader:
        # The loader uses pin_memory=True. non_blocking lets the host->device
        # copy overlap with GPU work instead of synchronising on every batch.
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        loss = diffusion.loss(imgs, labels)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        update_ema(ema_diffusion, diffusion, ema_decay)

        # loss.item() forces a device sync, so only call it when logging.
        if step % log_interval == 0:
            print(f"epoch {epoch} step {step} loss {loss.item():.4f}")
        step += 1

    # One checkpoint per epoch, with everything needed to resume or sample.
    torch.save({
        'diffusion': diffusion.state_dict(),
        'ema_diffusion': ema_diffusion.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'step': step,
    }, save_dir / f"epoch{epoch:03d}.pt")


## Super-resolution sampling
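Sample at the 256×256 training resolution and bilinearly upsample to a larger target (e.g., 512×512). Then run a single t = 0 pass of the EMA model, with the heavy decoder's patch scaling set to the upsampling factor.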


In [None]:
@torch.no_grad()
def sample_superres(num_imgs=4, target_size=(512, 512), timesteps=50, labels=None):
    """Sample class-conditional images at IMAGE_SIZE, refine at ``target_size``, and plot them.

    The EMA model first samples at the training resolution. The result is
    bilinearly upsampled, and one t = 0 pass is run with the decoder patch
    scaling set to the upsampling factor. The network predicts noise (it is
    trained with ``PixNerdDDPM.loss``), so its output is converted back to an
    image estimate. With the default schedule, sqrt(1 - alpha_bar_0) = 0.01, so
    this correction is small.

    NOTE(review): this assumes PixNerDiT reads ``decoder_patch_scaling_h/_w``
    during forward; that is not visible in this file — confirm.

    Args:
        num_imgs: number of images, used when ``labels`` is None.
        target_size: (height, width) of the refined output.
        timesteps: number of denoising steps for the base sample.
        labels: optional class labels; if given, their count overrides ``num_imgs``.

    Returns:
        Tensor of shape (num_imgs, 3, h, w) in [0, 1], on ``device``.
    """
    ema_diffusion.eval()
    h, w = target_size
    if labels is None:
        labels = torch.randint(0, NUM_CLASSES, (num_imgs,), device=device)
    else:
        labels = torch.as_tensor(labels, dtype=torch.long, device=device).reshape(-1)
        num_imgs = labels.numel()

    # Base sampling on the training grid.
    base = ema_diffusion.sample(batch_size=num_imgs, labels=labels, timesteps=timesteps, img_size=IMAGE_SIZE)
    up = F.interpolate(base, size=(h, w), mode='bilinear', align_corners=False)

    # Only the EMA network is used here, so only its decoder scaling is changed.
    # The previous values are restored even if the forward pass raises.
    net = ema_diffusion.model.model
    prev_scaling = (
        getattr(net, 'decoder_patch_scaling_h', 1.0),
        getattr(net, 'decoder_patch_scaling_w', 1.0),
    )
    net.decoder_patch_scaling_h = h / IMAGE_SIZE
    net.decoder_patch_scaling_w = w / IMAGE_SIZE
    try:
        t0 = torch.zeros(num_imgs, device=device, dtype=torch.long)
        eps = ema_diffusion.model(up, t0, labels)
    finally:
        net.decoder_patch_scaling_h, net.decoder_patch_scaling_w = prev_scaling

    # Convert the t = 0 noise prediction into a clean-image estimate.
    alpha_bar_0 = ema_diffusion.schedule.alpha_bars[0]
    refined = (up - torch.sqrt(1 - alpha_bar_0) * eps) / torch.sqrt(alpha_bar_0)

    # [-1, 1] -> [0, 1] for display.
    imgs = (refined.clamp(-1, 1) + 1) * 0.5
    grid = torchvision.utils.make_grid(imgs, nrow=2)
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.show()
    return imgs

# Example (commented to avoid accidental long runs)
# sample_superres(num_imgs=4, target_size=(512, 512), timesteps=50)
