# PixNerd checkpoint inference (ImageNet class-conditional)

This notebook loads a trained PixNerd XL checkpoint, runs class-conditional sampling, and visualizes both 256×256 and higher-resolution (e.g., 512×512) outputs without relying on the Lightning CLI. Configure the checkpoint path and a few options at the top, then run the cells sequentially on a GPU runtime.


## Environment & paths
- Assumes the repository is already installed (dependencies from `requirements.txt`).
- Provide the checkpoint at `checkpoints/PixNerd-XL-P16-C2I/epoch=319-step=1600000.ckpt` relative to the repo root (adjustable below).
- Sampling expects a CUDA GPU; the sampler uses bfloat16 autocast on CUDA.


In [None]:
import os
from pathlib import Path
import math
import torch
import matplotlib.pyplot as plt
from PIL import Image

import sys  # imported here because only this cell needs it (to expose `src`)


def _find_repo_root(start):
    """Return the closest directory at or above *start* containing ``src/models``.

    The model cells import from ``src.models...``, so that folder marks the
    repository root. When no ancestor contains it, *start* is returned
    unchanged, which reproduces the plain "current working directory" default.
    """
    for candidate in (start, *start.parents):
        if (candidate / "src" / "models").is_dir():
            return candidate
    return start


# The kernel's working directory is often the notebook's own folder rather
# than the repo root, so look upwards for the repo instead of trusting cwd.
# Optional: overwrite REPO_ROOT by hand if auto-detection picks the wrong place.
REPO_ROOT = _find_repo_root(Path(os.getcwd()))

# Later cells do `from src... import ...`; make sure the repo root is importable.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Checkpoint location (adjust if yours lives elsewhere)
CKPT_PATH = REPO_ROOT / "checkpoints/PixNerd-XL-P16-C2I/epoch=319-step=1600000.ckpt"

# Where to save generated images
OUTPUT_DIR = REPO_ROOT / "notebooks" / "pixnerd_inference_outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Repo root: {REPO_ROOT}")
print(f"Checkpoint exists: {CKPT_PATH.exists()}")
print(f"Using device: {DEVICE}")


## Build the PixNerd XL model components
This mirrors the training configuration used for ImageNet 256×256:
- Pixel-space autoencoder (`PixelAE`)
- Label conditioner (1000 ImageNet classes)
- PixNerDiT XL denoiser (patch size 16, 30 blocks)
- Flow-matching Euler sampler with classifier-free guidance
- REPATrainer stub (only needed to load projection weights from the checkpoint)
- SimpleEMA tracker to hold EMA weights from training


In [None]:
from src.models.autoencoder.pixel import PixelAE
from src.models.conditioner.class_label import LabelConditioner
from src.models.transformer.pixnerd_c2i import PixNerDiT
from src.diffusion.flow_matching.scheduling import LinearScheduler
from src.diffusion.flow_matching.sampling import EulerSampler, ode_step_fn
from src.diffusion.base.guidance import simple_guidance_fn
from src.diffusion.flow_matching.training_repa import REPATrainer
from src.callbacks.simple_ema import SimpleEMA
from src.lightning_model import LightningModel
from src.models.encoder import IndentityMapping
from src.models.autoencoder.base import fp2uint8

# Hyperparameters follow configs_c2i/pix256std1_repa_pixnerd_xl.yaml
HIDDEN_SIZE = 1152
NUM_CLASSES = 1000
PATCH_SIZE = 16

# Keyword arguments are gathered first so every component's configuration can
# be read (or tweaked) in one place before anything is constructed.
_denoiser_cfg = dict(
    in_channels=3,
    patch_size=PATCH_SIZE,
    num_groups=16,
    hidden_size=HIDDEN_SIZE,
    hidden_size_x=64,
    num_blocks=30,
    num_cond_blocks=26,
    nerf_mlpratio=2,
    num_classes=NUM_CLASSES,
)
_sampler_cfg = dict(
    num_steps=100,
    guidance=3.5,
    guidance_interval_min=0.1,
    guidance_interval_max=1.0,
    guidance_fn=simple_guidance_fn,
    step_fn=ode_step_fn,
)
_repa_cfg = dict(
    lognorm_t=True,
    align_layer=8,
    proj_denoiser_dim=HIDDEN_SIZE,
    proj_hidden_dim=HIDDEN_SIZE,
    proj_encoder_dim=768,
)

# Two independent linear schedulers: one drives the main trajectory (and is
# shared with the trainer stub), the other is handed to the sampler as
# `w_scheduler`.
main_scheduler = LinearScheduler()
guide_scheduler = LinearScheduler()

# Core modules, built in the same order as the training configuration lists them
vae = PixelAE(scale=1.0)
conditioner = LabelConditioner(num_classes=NUM_CLASSES)
denoiser = PixNerDiT(**_denoiser_cfg)

# Euler sampler with the ImageNet guidance settings
sampler = EulerSampler(
    scheduler=main_scheduler,
    w_scheduler=guide_scheduler,
    **_sampler_cfg,
)

# The trainer exists only so that its projection MLP weights have somewhere to
# land when the checkpoint is loaded; its encoder is a placeholder that
# inference never uses.
trainer_stub = REPATrainer(
    scheduler=main_scheduler,
    encoder=IndentityMapping(),
    **_repa_cfg,
)

ema_tracker = SimpleEMA(decay=0.9999)

# LightningModel bundles all parts so one load_state_dict call restores them
model = LightningModel(
    vae=vae,
    conditioner=conditioner,
    denoiser=denoiser,
    diffusion_trainer=trainer_stub,
    diffusion_sampler=sampler,
    ema_tracker=ema_tracker,
    optimizer=None,
    lr_scheduler=None,
    eval_original_model=False,
)
model.eval()  # inference only: switch every submodule to eval mode
model.to(DEVICE)
print("Model initialized.")


## Load the checkpoint (EMA weights)
`strict=False` tolerates keys that are missing from, or unexpected in, the checkpoint's `state_dict` (for example, trainer-side weights that inference does not need). EMA weights are stored under `ema_denoiser.*` and are used for sampling.


In [None]:
# Fail early with a readable message instead of a low-level error from torch.load.
if not CKPT_PATH.is_file():
    raise FileNotFoundError(f"Checkpoint not found: {CKPT_PATH}")

# NOTE(security): a full (non weights-only) load unpickles arbitrary Python
# objects, so only load checkpoints from a source you trust.
# `weights_only=False` is spelled out because PyTorch >= 2.6 defaults to
# weights-only loading, which can reject the non-tensor entries that training
# checkpoints usually carry next to `state_dict`.
ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)
missing, unexpected = model.load_state_dict(ckpt["state_dict"], strict=False)
print(f"Missing keys: {len(missing)} | Unexpected keys: {len(unexpected)}")

# strict=False also hides a checkpoint that lacks the EMA weights used for
# sampling; surface that case instead of reporting success.
missing_ema = [key for key in missing if key.startswith("ema_denoiser.")]
if missing_ema:
    print(
        f"WARNING: {len(missing_ema)} ema_denoiser.* weights were not found in the "
        f"checkpoint (e.g. {missing_ema[:3]}); samples will not reflect the trained model."
    )
else:
    print("Ready for sampling with EMA denoiser.")


## Sampling helpers
- `sample_imagenet_classes` draws Gaussian noise, runs the Euler sampler, decodes to pixel space, and returns uint8 tensors.
- `save_grid` writes a simple grid PNG for quick inspection.


In [None]:
@torch.no_grad()
def sample_imagenet_classes(class_ids, height=256, width=256, seed=0, guidance=3.5, num_steps=100):
    """Generate one image per entry of *class_ids* with the EMA denoiser.

    Args:
        class_ids: Sequence of integer class labels; its length is the batch size.
        height: Height of the initial noise, in pixels.
        width: Width of the initial noise, in pixels.
        seed: Value passed to ``torch.manual_seed`` before the noise is drawn.
        guidance: Written to ``model.diffusion_sampler.guidance``.
        num_steps: Written to ``model.diffusion_sampler.num_steps``.

    Returns:
        The decoded batch after clamping to [-1, 1] and conversion through
        ``fp2uint8``, moved to the CPU. The layout presumably matches the
        noise, i.e. (batch, 3, height, width) -- TODO confirm against
        ``PixelAE.decode``.

    Note:
        Both sampler attributes stay modified after the call returns.
        NOTE(review): if the sampler precomputes its timestep schedule at
        construction time, reassigning ``num_steps`` here may not change the
        number of steps actually taken -- confirm against ``EulerSampler``.
    """
    torch.manual_seed(seed)
    n_samples = len(class_ids)

    # Per-call overrides are applied directly on the shared sampler instance.
    euler = model.diffusion_sampler
    euler.guidance = guidance
    euler.num_steps = num_steps

    # Starting point of the trajectory: pure Gaussian noise at target resolution.
    x_init = torch.randn(*(n_samples, 3, height, width), device=DEVICE)
    label_tensor = torch.tensor(class_ids, device=DEVICE)

    cond, uncond = model.conditioner(label_tensor)
    trajectory_end = euler(model.ema_denoiser, x_init, cond, uncond)

    decoded = model.vae.decode(trajectory_end)
    decoded = decoded.clamp(-1.0, 1.0)
    return fp2uint8(decoded).cpu()


def _to_pil_image(img):
    """Convert one image (PIL image, torch tensor or NumPy array) to a PIL image.

    Channel-first inputs of shape (C, H, W) are transposed to (H, W, C), and a
    trailing single channel is dropped, because ``Image.fromarray`` expects
    channel-last (or 2-D) arrays.
    """
    if isinstance(img, Image.Image):
        return img
    if torch.is_tensor(img):
        # PIL cannot read torch tensors directly; go through NumPy.
        img = img.detach().cpu().numpy()
    if img.ndim == 3 and img.shape[0] in (1, 3, 4) and img.shape[-1] not in (1, 3, 4):
        # Looks channel-first: move channels last and make the data contiguous.
        img = img.transpose(1, 2, 0).copy()
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    return Image.fromarray(img)


def save_grid(images_uint8, filename, cols=None):
    """Tile a batch of images into a single RGB grid and save it in OUTPUT_DIR.

    Args:
        images_uint8: Iterable of uint8 images. Each item may be a torch
            tensor or NumPy array in (C, H, W) or (H, W, C) layout, or a PIL
            image. Every tile is assumed to have the size of the first one.
        filename: File name (relative to ``OUTPUT_DIR``) of the saved grid.
        cols: Number of grid columns; defaults to ``ceil(sqrt(n))`` so the
            grid is roughly square.

    Returns:
        The path the grid was written to.

    Raises:
        ValueError: If there are no images or ``cols`` is not positive.
    """
    imgs = [_to_pil_image(img) for img in images_uint8]
    n = len(imgs)
    if n == 0:
        raise ValueError("save_grid needs at least one image")
    if cols is None:
        cols = math.ceil(math.sqrt(n))
    if cols < 1:
        raise ValueError(f"cols must be a positive integer, got {cols}")
    rows = math.ceil(n / cols)
    w, h = imgs[0].size
    grid = Image.new("RGB", (cols * w, rows * h))
    for idx, img in enumerate(imgs):
        # Fill row by row: idx -> (row, column) for a grid with `cols` columns.
        r, c = divmod(idx, cols)
        grid.paste(img, (c * w, r * h))
    out_path = OUTPUT_DIR / filename
    grid.save(out_path)
    return out_path


## Generate 256×256 ImageNet samples
Choose a few class IDs (0–999). The example below uses three diverse classes. Adjust `SEED`, `GUIDANCE`, or `NUM_STEPS` as needed.


In [None]:
CLASS_IDS = [207, 130, 340]  # golden retriever, flamingo, zebra (example ImageNet IDs)
SEED = 42
GUIDANCE = 3.5
NUM_STEPS = 100

images_256 = sample_imagenet_classes(
    class_ids=CLASS_IDS,
    height=256,
    width=256,
    seed=SEED,
    guidance=GUIDANCE,
    num_steps=NUM_STEPS,
)

print(f"Generated batch shape: {images_256.shape}")

# PIL and matplotlib need channel-last NumPy arrays, while the sampler returns
# a torch tensor (presumably (B, 3, H, W), matching the layout of the noise).
images_256_np = images_256
if images_256_np.ndim == 4 and images_256_np.shape[1] == 3:
    images_256_np = images_256_np.permute(0, 2, 3, 1)
images_256_np = images_256_np.contiguous().numpy()

# Preview the first sample only; the saved grid below contains the whole batch.
_ = plt.figure(figsize=(12, 4))
plt.imshow(Image.fromarray(images_256_np[0]))
plt.axis('off')
plt.show()

save_path_256 = save_grid(images_256_np, "samples_256.png")
print(f"Saved grid to {save_path_256}")


## Super-resolution sampling (e.g., 512×512)
Because the model works in pixel space with coordinate-aware embeddings, simply draw higher-resolution noise and run the same sampler. You can tweak guidance or steps separately for super-res runs.


In [None]:
HIGH_RES = 512  # set to any square size supported by your GPU
SUPERRES_SEED = 7
SUPERRES_GUIDANCE = 4.0
SUPERRES_STEPS = 120

images_sr = sample_imagenet_classes(
    class_ids=CLASS_IDS,
    height=HIGH_RES,
    width=HIGH_RES,
    seed=SUPERRES_SEED,
    guidance=SUPERRES_GUIDANCE,
    num_steps=SUPERRES_STEPS,
)

print(f"Super-res batch shape: {images_sr.shape}")

# PIL and matplotlib need channel-last NumPy arrays, while the sampler returns
# a torch tensor (presumably (B, 3, H, W), matching the layout of the noise).
images_sr_np = images_sr
if images_sr_np.ndim == 4 and images_sr_np.shape[1] == 3:
    images_sr_np = images_sr_np.permute(0, 2, 3, 1)
images_sr_np = images_sr_np.contiguous().numpy()

# Preview the first sample only; the saved grid below contains the whole batch.
_ = plt.figure(figsize=(12, 4))
plt.imshow(Image.fromarray(images_sr_np[0]))
plt.axis('off')
plt.show()

save_path_sr = save_grid(images_sr_np, f"samples_{HIGH_RES}.png")
print(f"Saved grid to {save_path_sr}")


## Notes
- If you see autocast errors on CPU, switch to a GPU runtime; the sampler decorates its forward pass with CUDA autocast.
- To reproduce training-time sampler behavior, keep `guidance_interval_min=0.1`, `guidance_interval_max=1.0`, and `num_steps=100`.
- For deterministic reruns, fix `seed` and the class list; GPU algorithms may still introduce minor nondeterminism.
