# PixNerd CIFAR-10 Training (Heavy Decoder for Super-Resolution)

This notebook trains the PixNerd class-conditional model on CIFAR-10, using the **text-to-image heavy decoder** architecture as a neural-field super-resolution head. Diffusion runs on the 32×32 training grid, while the decoder can be evaluated at larger resolutions via coordinate scaling.

## Environment setup

- Requires CUDA for practical training speed.
- Package installation is omitted; the notebook imports directly from the repository source tree.

In [None]:
import math
import os
import random
import time
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

# Reproducibility: seed the torch, NumPy and stdlib RNGs with the same value.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Prefer the GPU when one is visible; everything below is placed on `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

## Data: CIFAR-10 loaders

Images are normalized to [-1, 1] for diffusion training.

In [None]:
BATCH_SIZE = 128
NUM_WORKERS = 4
IMAGE_SIZE = 32
NUM_CLASSES = 10

# ToTensor yields values in [0, 1]; Normalize with mean = std = 0.5 per channel
# maps them to [-1, 1], the range the diffusion model is trained on.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


def _cifar10_split(train):
    """Build one CIFAR-10 split (downloading it if needed) and a DataLoader over it.

    Only the training split is shuffled.
    """
    dataset = torchvision.datasets.CIFAR10(
        root='./data', train=train, download=True, transform=transform
    )
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=train,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    return dataset, loader


train_dataset, train_loader = _cifar10_split(train=True)
test_dataset, test_loader = _cifar10_split(train=False)
len(train_dataset), len(test_dataset)

## Noise schedule helpers (DDPM)

In [None]:
def make_beta_schedule(T=1000, beta_start=0.0001, beta_end=0.02):
    """Return the linear DDPM beta schedule: ``T`` values evenly spaced from ``beta_start`` to ``beta_end``."""
    return torch.linspace(beta_start, beta_end, T)


@dataclass
class DiffusionSchedule:
    """Precomputed DDPM noise-schedule tensors, each 1-D of length T.

    The fields are plain tensors (not module buffers), so build the schedule
    directly on the device where it will be indexed via ``create(device=...)``.
    """

    betas: torch.Tensor       # per-step noise variance beta_t
    alphas: torch.Tensor      # 1 - beta_t
    alpha_bars: torch.Tensor  # cumulative product of alphas[0..t] (inclusive)

    @classmethod
    def create(cls, T=1000, beta_start=0.0001, beta_end=0.02, device='cpu'):
        """Build a linear schedule of length ``T`` on ``device``."""
        betas = make_beta_schedule(T, beta_start, beta_end).to(device)
        alphas = 1.0 - betas
        alpha_bars = torch.cumprod(alphas, dim=0)
        return cls(betas, alphas, alpha_bars)

    def sample_xt(self, x0, t, noise=None):
        """Draw x_t from q(x_t | x_0) = N(sqrt(a_bar_t) * x0, (1 - a_bar_t) * I).

        Args:
            x0: clean batch of shape (B, ...); any rank >= 1 is supported.
            t: integer timesteps of shape (B,) indexing the schedule.
            noise: optional noise shaped like ``x0``; drawn with
                ``torch.randn_like`` when omitted.

        Returns:
            Tuple ``(x_t, noise)``.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        # Reshape the per-sample coefficient to (B, 1, ..., 1) so it broadcasts
        # over every non-batch dimension of x0, whatever its rank.
        a_bar = self.alpha_bars[t].view(-1, *([1] * (x0.dim() - 1)))
        return (a_bar.sqrt() * x0) + ((1 - a_bar).sqrt() * noise), noise

## Heavy decoder wrapper for class conditioning

We reuse the text-to-image heavy decoder and condition it on learned class embeddings, passed as a length-1 token sequence. Coordinate interpolation for super-resolution is controlled via `decoder_patch_scaling_*`.

In [None]:
from src.models.transformer.pixnerd_t2i_heavydecoder import PixNerDiT


class ClassConditionalPixNerDiT(nn.Module):
    """Class-conditional wrapper around the text-to-image heavy-decoder PixNerDiT.

    Class labels are looked up in a learned ``nn.Embedding`` of width
    ``txt_embed_dim`` and handed to the DiT as a length-1 token sequence in
    place of text embeddings.

    Args:
        num_classes: number of class labels (rows of the embedding table).
        in_channels, patch_size, hidden_size, decoder_hidden_size,
        num_encoder_blocks, num_decoder_blocks, num_text_blocks,
        txt_embed_dim, txt_max_length: forwarded unchanged to ``PixNerDiT``.
        num_groups: forwarded to ``PixNerDiT`` (previously hard-coded to 12).
            NOTE(review): presumably the attention-head count, in which case
            ``hidden_size`` must be divisible by it — confirm in PixNerDiT.
    """

    def __init__(self, num_classes=10, in_channels=3, patch_size=2, hidden_size=1152, decoder_hidden_size=64,
                 num_encoder_blocks=18, num_decoder_blocks=4, num_text_blocks=4, txt_embed_dim=1024, txt_max_length=1,
                 num_groups=12):
        super().__init__()
        self.txt_max_length = txt_max_length
        self.class_embed = nn.Embedding(num_classes, txt_embed_dim)
        self.dit = PixNerDiT(
            in_channels=in_channels,
            num_groups=num_groups,
            hidden_size=hidden_size,
            decoder_hidden_size=decoder_hidden_size,
            num_encoder_blocks=num_encoder_blocks,
            num_decoder_blocks=num_decoder_blocks,
            num_text_blocks=num_text_blocks,
            patch_size=patch_size,
            txt_embed_dim=txt_embed_dim,
            txt_max_length=txt_max_length,
        )

    def forward(self, x, t, labels):
        """Run the DiT on ``x`` at timestep ``t``, conditioned on ``labels``.

        Args:
            x: input batch for the DiT.
            t: timesteps, forwarded as-is. NOTE(review): the training code
                passes integer DDPM indices; confirm this matches the timestep
                convention PixNerDiT expects.
            labels: integer class ids of shape (B,).
        """
        # (B,) -> (B, 1, txt_embed_dim). Always a single token, even if
        # txt_max_length > 1 — NOTE(review): confirm PixNerDiT accepts that.
        cls_tokens = self.class_embed(labels).unsqueeze(1)
        return self.dit(x, t, cls_tokens)

## Diffusion model tying everything together

In [None]:
class PixNerdDDPM(nn.Module):
    """Epsilon-prediction DDPM built around a class-conditional PixNerD model.

    Args:
        model: network called as ``model(x_t, t, labels)``; ``loss`` trains
            its output to match the injected noise. ``superres`` additionally
            expects it to expose ``.dit``.
        schedule: object with 1-D ``betas``, ``alphas`` and ``alpha_bars``
            tensors of length T and a ``sample_xt(x0, t)`` method.

    ``schedule`` is stored as a plain attribute (not a registered buffer), so
    ``.to(device)`` does not move it; create it on the target device.
    """

    def __init__(self, model: "ClassConditionalPixNerDiT", schedule: "DiffusionSchedule"):
        super().__init__()
        self.model = model
        self.schedule = schedule

    def loss(self, x0, labels):
        """Noise ``x0`` at uniformly drawn timesteps and return the MSE between
        the model output and the injected noise."""
        b = x0.size(0)
        t = torch.randint(0, self.schedule.betas.size(0), (b,), device=x0.device)
        xt, noise = self.schedule.sample_xt(x0, t)
        pred = self.model(xt, t, labels)
        return F.mse_loss(pred, noise)

    @torch.no_grad()
    def p_sample(self, xt, t, labels, t_prev=None):
        """Take one ancestral sampling step from timestep ``t`` to ``t_prev``.

        Args:
            xt: current noisy batch.
            t: current integer timestep (Python int).
            labels: class labels forwarded to the model.
            t_prev: target timestep. Defaults to ``t - 1``, the standard
                full-chain DDPM step. A smaller value takes a respaced step
                that skips intermediate timesteps. A negative value means the
                final step to x0, which adds no noise.
        """
        if t_prev is None:
            t_prev = t - 1
        alpha_bar = self.schedule.alpha_bars[t]
        if t_prev == t - 1:
            # Consecutive step: use the schedule's own per-step coefficients.
            beta = self.schedule.betas[t]
            alpha = self.schedule.alphas[t]
        else:
            # Respaced step: the effective alpha of the jump t -> t_prev is
            # alpha_bar_t / alpha_bar_{t_prev}, with alpha_bar_{-1} = 1.
            if t_prev >= 0:
                alpha_bar_prev = self.schedule.alpha_bars[t_prev]
            else:
                alpha_bar_prev = torch.ones_like(alpha_bar)
            alpha = alpha_bar / alpha_bar_prev
            beta = 1 - alpha
        t_batch = torch.full((xt.size(0),), t, device=xt.device, dtype=torch.long)
        noise_pred = self.model(xt, t_batch, labels)
        coef1 = 1 / alpha.sqrt()
        coef2 = beta / ((1 - alpha_bar).sqrt())
        mean = coef1 * (xt - coef2 * noise_pred)
        if t_prev < 0:
            return mean
        noise = torch.randn_like(xt)
        # Variance sigma^2 = beta (the same choice as the full-chain step).
        return mean + (beta.sqrt() * noise)

    @torch.no_grad()
    def sample(self, batch_size, labels, timesteps=1000, img_size=32):
        """Generate ``batch_size`` 3-channel images of size ``img_size``.

        ``timesteps`` evenly spaced timesteps from T-1 down to 0 are visited.
        With ``timesteps == T`` this is the full DDPM chain; fewer steps give a
        respaced chain that still starts from the fully-noised timestep.

        Raises:
            ValueError: if ``timesteps`` is outside [1, T] or ``labels`` does
                not contain ``batch_size`` entries.
        """
        num_train_steps = self.schedule.betas.size(0)
        if not 1 <= timesteps <= num_train_steps:
            raise ValueError(f"timesteps must be in [1, {num_train_steps}], got {timesteps}")
        if labels.size(0) != batch_size:
            raise ValueError(f"expected {batch_size} labels, got {labels.size(0)}")
        steps = (
            torch.linspace(num_train_steps - 1, 0, timesteps)
            .round()
            .long()
            .unique(sorted=True)
            .flip(0)
            .tolist()
        )
        xt = torch.randn(batch_size, 3, img_size, img_size, device=labels.device)
        for t, t_prev in zip(steps, steps[1:] + [-1]):
            xt = self.p_sample(xt, t, labels, t_prev=t_prev)
        return xt

    @torch.no_grad()
    def superres(self, base_imgs, labels, target_size):
        """Decode ``base_imgs`` (B, 3, H, W) at ``target_size`` (an int or an (H', W') pair).

        The inputs are bilinearly upsampled to the target size, and
        ``model.dit.decoder_patch_scaling_{h,w}`` are set to the scale factors
        for this call only; the previous values are restored afterwards. The
        model's noise prediction at t=0 is converted to an x0 estimate via
        x0 = (x_t - sqrt(1 - a_bar_0) * eps) / sqrt(a_bar_0).

        NOTE(review): assumes the model output has the same shape as its
        (upsampled) input, as ``loss`` requires at training resolution —
        confirm for scaled decoding.
        """
        if isinstance(target_size, int):
            target_size = (target_size, target_size)
        _, _, h, w = base_imgs.shape
        dit = self.model.dit
        missing = object()
        saved = {
            name: getattr(dit, name, missing)
            for name in ('decoder_patch_scaling_h', 'decoder_patch_scaling_w')
        }
        dit.decoder_patch_scaling_h = target_size[0] / h
        dit.decoder_patch_scaling_w = target_size[1] / w
        try:
            upsampled = F.interpolate(base_imgs, size=target_size, mode='bilinear', align_corners=False)
            t = torch.zeros(base_imgs.size(0), device=base_imgs.device, dtype=torch.long)
            eps = self.model(upsampled, t, labels)
        finally:
            # Do not leak the scaling into later sampling/training calls.
            for name, value in saved.items():
                if value is missing:
                    delattr(dit, name)
                else:
                    setattr(dit, name, value)
        a_bar0 = self.schedule.alpha_bars[0]
        return (upsampled - (1 - a_bar0).sqrt() * eps) / a_bar0.sqrt()

## Instantiate model, optimizer, EMA

In [None]:
schedule = DiffusionSchedule.create(T=1000, device=device)
model = ClassConditionalPixNerDiT(num_classes=NUM_CLASSES).to(device)
diffusion = PixNerdDDPM(model, schedule).to(device)
optimizer = torch.optim.AdamW(diffusion.parameters(), lr=2e-4, weight_decay=1e-4)

# Exponential moving average (EMA) of the weights, used for sampling.
ema_decay = 0.999
ema_model = ClassConditionalPixNerDiT(num_classes=NUM_CLASSES).to(device)
ema_model.load_state_dict(model.state_dict())
# The EMA weights are only ever changed by update_ema, never by autograd.
ema_model.requires_grad_(False)
ema_diffusion = PixNerdDDPM(ema_model, schedule)
# The EMA copy is only used for inference, so keep it in eval mode.
ema_diffusion.eval()


def update_ema(model_src, model_dst, decay):
    """Update ``model_dst`` in place as an EMA of ``model_src``.

    Parameters: dst = decay * dst + (1 - decay) * src.
    Buffers are copied verbatim so the EMA copy tracks any non-parameter
    state of the source model.

    Both models are assumed to have identical architectures (parameters and
    buffers are paired positionally).
    """
    with torch.no_grad():
        for p_src, p_dst in zip(model_src.parameters(), model_dst.parameters()):
            p_dst.mul_(decay).add_(p_src, alpha=1 - decay)
        for b_src, b_dst in zip(model_src.buffers(), model_dst.buffers()):
            b_dst.copy_(b_src)

## Training loop

This is a minimal loop; adjust the number of epochs/steps and add checkpointing as needed.

In [None]:
EPOCHS = 1  # increase for real training
log_interval = 100
num_preview = 16  # samples shown in the 4x4 preview grid

for epoch in range(EPOCHS):
    diffusion.train()
    for step, (imgs, labels) in enumerate(train_loader):
        # non_blocking pairs with pin_memory=True on the DataLoader.
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        loss = diffusion.loss(imgs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        update_ema(model, ema_model, ema_decay)

        if step % log_interval == 0:
            print(f"epoch {epoch} step {step}: loss {loss.item():.4f}")

    # Quick qualitative check with the EMA weights. Calling eval() on
    # `diffusion` does not affect `ema_diffusion`, so switch it explicitly.
    diffusion.eval()
    ema_diffusion.eval()
    with torch.no_grad():
        # Cycle through the classes to fill the grid; arange(NUM_CLASSES)[:16]
        # would only yield NUM_CLASSES samples.
        sample_labels = torch.arange(num_preview, device=device) % NUM_CLASSES
        samples = ema_diffusion.sample(batch_size=sample_labels.size(0), labels=sample_labels, timesteps=50, img_size=IMAGE_SIZE)
        grid = torchvision.utils.make_grid(samples, nrow=4, normalize=True, value_range=(-1, 1))
        plt.figure(figsize=(6,6))
        plt.axis('off')
        plt.imshow(grid.permute(1,2,0).cpu().numpy())
        plt.show()

## Super-resolution sampling

Run the decoder at a higher resolution (e.g., 64×64) while keeping the diffusion grid at 32×32.

In [None]:
@torch.no_grad()
def sample_superres(num_imgs=4, target_size=(64,64)):
    """Sample random-class images with the EMA model, then decode them at ``target_size``.

    Returns:
        Tuple ``(base, sr)``: the IMAGE_SIZE diffusion samples and the
        output of ``ema_diffusion.superres`` at ``target_size``.
    """
    for module in (diffusion, ema_diffusion):
        module.eval()
    random_labels = torch.randint(0, NUM_CLASSES, (num_imgs,), device=device)
    base = ema_diffusion.sample(
        batch_size=num_imgs, labels=random_labels, timesteps=50, img_size=IMAGE_SIZE
    )
    return base, ema_diffusion.superres(base, random_labels, target_size=target_size)


base_imgs, sr_imgs = sample_superres(num_imgs=4, target_size=(64,64))
# Tile each batch into a 4-wide grid, mapping the [-1, 1] range to [0, 1] for display.
base_grid, sr_grid = (
    torchvision.utils.make_grid(batch, nrow=4, normalize=True, value_range=(-1,1))
    for batch in (base_imgs, sr_imgs)
)

fig, axes = plt.subplots(1,2, figsize=(10,5))
panels = ((base_grid, '32×32 diffusion'), (sr_grid, '64×64 super-res'))
for panel_ax, (panel_grid, panel_title) in zip(axes, panels):
    panel_ax.imshow(panel_grid.permute(1,2,0).cpu().numpy())
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')
plt.show()