# CIFAR-10 Extended NerfEmbedder: Super-Resolution Inference

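This notebook demonstrates the **Extended NerfEmbedder** architecture, which aims to improve neural field (NF) quality at patch boundaries and under super-resolution through the two changes described below.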

## Key Innovations

### 1. Extended Patch Boundaries (`margin=0.25`)
- Standard: positions in `[0, 1]` → each patch only sees its own interior points
- Extended: positions in `[-0.25, 1.25]` → each patch reaches 25% into its neighbors, so adjacent patches overlap by half a patch width
- Reduces seam artifacts at patch boundaries
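
As a minimal sketch of the coordinate change, assuming an endpoint-inclusive grid that is simply stretched from `[0, 1]` to `[-margin, 1 + margin]` (the actual sampling inside `PixNerDiTExtended` may differ, and `extended_positions` and `overlap_fraction` are hypothetical helpers, not part of the codebase):

```python
def extended_positions(num_points: int, margin: float = 0.25) -> list[float]:
    """Return `num_points` evenly spaced samples on [-margin, 1 + margin].

    Hypothetical helper for illustration; margin=0 gives the standard
    endpoint-inclusive grid on [0, 1].
    """
    if num_points < 2:
        raise ValueError("need at least two points to span an interval")
    lo, hi = -margin, 1.0 + margin
    step = (hi - lo) / (num_points - 1)
    return [lo + i * step for i in range(num_points)]


def overlap_fraction(margin: float) -> float:
    """Width shared by two adjacent patches, in units of one patch width."""
    # Patch k spans [k - margin, k + 1 + margin] and its right neighbor
    # starts at k + 1 - margin, so the shared interval has width 2 * margin.
    return 2.0 * margin


print(extended_positions(5, margin=0.0))   # [0.0, 0.25, 0.5, 0.75, 1.0]
print(extended_positions(5, margin=0.25))  # [-0.25, 0.125, 0.5, 0.875, 1.25]
print(overlap_fraction(0.25))              # 0.5
```

With `margin=0.25` the first and last samples land outside the patch, which is what lets neighboring patches agree on the pixels near their shared edge.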

### 2. Position Jittering (`jitter_std=0.01`)
- Training: Gaussian noise with standard deviation 0.01 is added to each position (`coord + N(0, 0.01²)`)
- Forces model to learn continuous NF representations
- Improves interpolation at arbitrary coordinates
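
The jittering step can be sketched in a few lines. This is an illustration only: it assumes independent Gaussian noise per coordinate and a `training` flag that turns the noise off at inference, and `jitter_positions` is a hypothetical name rather than the function used in `train_cifar10_extended.py`.

```python
import random


def jitter_positions(positions, jitter_std=0.01, training=True, rng=None):
    """Add zero-mean Gaussian noise to each coordinate, during training only."""
    if not training or jitter_std == 0.0:
        return list(positions)  # inference: coordinates pass through unchanged
    rng = rng or random.Random()
    # random.gauss takes the standard deviation (not the variance) as sigma.
    return [p + rng.gauss(0.0, jitter_std) for p in positions]


grid = [0.0, 0.5, 1.0]
train_coords = jitter_positions(grid, training=True, rng=random.Random(0))
eval_coords = jitter_positions(grid, training=False)
print(train_coords)  # slightly perturbed copies of the grid
print(eval_coords)   # [0.0, 0.5, 1.0]
```

Because the network never sees exactly the same coordinates twice, it cannot memorize values at fixed grid points and has to fit a function that is sensible in between them.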

| Approach | Boundary Range | Position Jitter | Seam Artifacts |
|----------|---------------|-----------------|----------------|
| Original | [0, 1] | None | ❌ Visible |
| **Extended** | [-0.25, 1.25] | ✅ Yes | ✅ Reduced |

## Setup

- A GPU is strongly recommended; the notebook falls back to CPU with a warning, but sampling will be very slow
- Assumes model was trained using `train_cifar10_extended.py`
- Update `CKPT_PATH` to point to your trained checkpoint

### Training command:
```bash
python train_cifar10_extended.py --margin 0.25 --jitter_std 0.01 --max_steps 100000
```

### Model config (must match training - ~10M params):
| Parameter | Value | Why |
|-----------|-------|-----|
| patch_size | 2 | 16×16=256 encoder tokens |
| margin | 0.25 | Query positions in [-0.25, 1.25] so neighboring patches overlap |
| jitter_std | 0.01 | Position noise during training |
| hidden_size | 256 | ~10M params |
| decoder_hidden_size | 32 | Decoder capacity |
| num_encoder_blocks | 6 | Encoder depth |
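
The token count in the `patch_size` row follows from simple arithmetic. A small sketch (the helper name `encoder_token_grid` is hypothetical):

```python
def encoder_token_grid(image_size: int, patch_size: int) -> tuple[int, int]:
    """Return (tokens_per_side, total_tokens) for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side, per_side * per_side


tokens_per_side, total_tokens = encoder_token_grid(image_size=32, patch_size=2)
print(tokens_per_side, total_tokens)  # 16 256
```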

In [None]:
# Navigate to PixNerd folder where src/ is located
import os
import sys

NOTEBOOK_DIR = os.getcwd()
print(f"Starting directory: {NOTEBOOK_DIR}")

# Navigate to PixNerd folder (where src/ lives)
PIXNERD_DIR = os.path.join(NOTEBOOK_DIR, "PixNerd")
if os.path.exists(PIXNERD_DIR):
    os.chdir(PIXNERD_DIR)
    print(f"Changed to: {os.getcwd()}")
elif os.path.basename(NOTEBOOK_DIR) == "PixNerd":
    print(f"Already in PixNerd directory: {NOTEBOOK_DIR}")
else:
    parent = os.path.dirname(NOTEBOOK_DIR)
    pixnerd_in_parent = os.path.join(parent, "PixNerd")
    if os.path.exists(pixnerd_in_parent):
        os.chdir(pixnerd_in_parent)
        print(f"Changed to: {os.getcwd()}")
    else:
        print(f"WARNING: Could not find PixNerd folder. Current dir: {NOTEBOOK_DIR}")

if os.path.exists("src"):
    print("Found src/ directory")
else:
    print("ERROR: src/ directory not found!")

In [None]:
from pathlib import Path
import math
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image

# Paths
PIXNERD_ROOT = Path(os.getcwd())

# ============================================================
# CHECKPOINT PATH - UPDATE THIS TO YOUR TRAINED MODEL
# ============================================================
CKPT_PATH = PIXNERD_ROOT / "workdirs" / "exp_cifar10_extended_nerf" / "checkpoints" / "last.ckpt"
# ============================================================

OUTPUT_DIR = PIXNERD_ROOT / "outputs" / "cifar10_extended_superres"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE != "cuda":
    print("WARNING: Running on CPU will be very slow")

# CIFAR-10 class names
CIFAR10_CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# ============================================================
# MODEL CONFIG - Must match train_cifar10_extended.py (~10M params)
# ============================================================
NUM_CLASSES = 10
BASE_RES = 32  # CIFAR-10 native resolution

PATCH_SIZE = 2           # Controls encoder tokens: 32/2 = 16x16 = 256 tokens

# EXTENDED NERFEMBEDDER CONFIG - KEY INNOVATIONS!
MARGIN = 0.25            # Query positions in [-0.25, 1.25] so patches overlap
JITTER_STD = 0.01        # Position jittering during training

# Model dimensions (~10M parameters)
HIDDEN_SIZE = 256
DECODER_HIDDEN_SIZE = 32
NUM_ENCODER_BLOCKS = 6
NUM_DECODER_BLOCKS = 2
NUM_GROUPS = 4
# ============================================================

print(f"PixNerd root: {PIXNERD_ROOT}")
print(f"Checkpoint path: {CKPT_PATH}")
print(f"Checkpoint exists: {CKPT_PATH.exists()}")
if not CKPT_PATH.exists():
    print(f"\n⚠️  CHECKPOINT NOT FOUND!")
    print(f"   Please train a model first using:")
    print(f"   python train_cifar10_extended.py --margin {MARGIN} --jitter_std {JITTER_STD}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Device: {DEVICE}")
print()
print("=" * 60)
print("EXTENDED NERFEMBEDDER ARCHITECTURE")
print("=" * 60)
print(f"  Extended boundaries: margin={MARGIN}")
print(f"    → Positions span [{-MARGIN:.2f}, {1+MARGIN:.2f}] instead of [0, 1]")
print(f"    → Reduces seam artifacts at patch boundaries")
print()
print(f"  Position jittering: std={JITTER_STD}")
print(f"    → Gaussian noise during training")
print(f"    → Forces continuous NF learning")
print("=" * 60)

## Build Model

The denoiser is `PixNerDiTExtended`, configured with extended boundaries and position jittering (the jitter applies only during training).

In [None]:
# Import PixNerd components - NOTE: Using EXTENDED model!
from src.models.autoencoder.pixel import PixelAE
from src.models.conditioner.class_label import LabelConditioner
from src.models.transformer.pixnerd_c2i_extended import PixNerDiTExtended  # Extended version!
from src.diffusion.flow_matching.scheduling import LinearScheduler
from src.diffusion.flow_matching.sampling import EulerSampler, ode_step_fn
from src.diffusion.base.guidance import simple_guidance_fn
from src.diffusion.flow_matching.training import FlowMatchingTrainer
from src.callbacks.simple_ema import SimpleEMA
from src.lightning_model import LightningModel
from src.models.autoencoder.base import fp2uint8

print("Imports successful!")
print("Using: PixNerDiTExtended (Extended Boundaries + Position Jittering)")

In [None]:
print("Initializing model components...")

main_scheduler = LinearScheduler()

vae = PixelAE(scale=1.0)

conditioner = LabelConditioner(num_classes=NUM_CLASSES)

# Extended NerfEmbedder model - KEY DIFFERENCE!
denoiser = PixNerDiTExtended(
    in_channels=3,
    patch_size=PATCH_SIZE,
    num_groups=NUM_GROUPS,
    hidden_size=HIDDEN_SIZE,
    decoder_hidden_size=DECODER_HIDDEN_SIZE,
    num_encoder_blocks=NUM_ENCODER_BLOCKS,
    num_decoder_blocks=NUM_DECODER_BLOCKS,
    num_classes=NUM_CLASSES,
    margin=MARGIN,              # Extended boundaries!
    jitter_std=JITTER_STD,      # Position jittering (only during training)
)

# Sampler with CFG
sampler = EulerSampler(
    num_steps=50,
    guidance=2.0,
    guidance_interval_min=0.0,
    guidance_interval_max=1.0,
    scheduler=main_scheduler,
    w_scheduler=LinearScheduler(),
    guidance_fn=simple_guidance_fn,
    step_fn=ode_step_fn,
)

# Trainer stub for checkpoint loading
trainer_stub = FlowMatchingTrainer(
    scheduler=main_scheduler,
    lognorm_t=True,
    timeshift=1.0,
)

ema_tracker = SimpleEMA(decay=0.9999)

model = LightningModel(
    vae=vae,
    conditioner=conditioner,
    denoiser=denoiser,
    diffusion_trainer=trainer_stub,
    diffusion_sampler=sampler,
    ema_tracker=ema_tracker,
    optimizer=None,
    lr_scheduler=None,
    eval_original_model=False,
)

model.eval()
model.to(DEVICE)
print(f"Model initialized and moved to {DEVICE}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print()
print(f"Architecture: PixNerDiTExtended")
print(f"  - margin={MARGIN} → positions in [{-MARGIN:.2f}, {1+MARGIN:.2f}]")
print(f"  - jitter_std={JITTER_STD} (only during training)")

## Load Checkpoint

In [None]:
print(f"Loading checkpoint from: {CKPT_PATH}")
ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)
missing, unexpected = model.load_state_dict(ckpt["state_dict"], strict=False)
print(f"Missing keys: {len(missing)} | Unexpected keys: {len(unexpected)}")
if missing:
    print(f"  Missing: {missing[:5]}..." if len(missing) > 5 else f"  Missing: {missing}")
if unexpected:
    print(f"  Unexpected: {unexpected[:5]}..." if len(unexpected) > 5 else f"  Unexpected: {unexpected}")
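if missing or unexpected:
    # strict=False hides mismatches, so flag them instead of reporting success
    print("Checkpoint loaded with key mismatches; check that the model config matches training.")
else:
    print("Checkpoint loaded successfully!")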

## Helper Functions

In [None]:
def set_decoder_scale(scale: float):
    """Set NF decoder patch scaling for super-resolution."""
    for net in [model.denoiser, getattr(model, "ema_denoiser", None)]:
        if net is None:
            continue
        net.decoder_patch_scaling_h = scale
        net.decoder_patch_scaling_w = scale


@torch.no_grad()
def sample_class_conditional(
    class_labels: list,
    height: int = 32,
    width: int = 32,
    seed: int = 42,
    num_steps: int = 50,
    guidance: float = 2.0,
    base_res: int = BASE_RES,
):
    """
    Generate class-conditional images.
    
    Args:
        class_labels: List of class indices (0-9) or names
        height: Output height (32 for native, 128 for 4x super-res)
        width: Output width
        seed: Random seed
        num_steps: ODE solver steps
        guidance: CFG guidance scale
        base_res: Training resolution
    
    Returns:
        Generated images as uint8 tensor
    """
    torch.manual_seed(seed)
    
    # Convert class names to indices if needed
    labels = []
    for label in class_labels:
        if isinstance(label, str):
            label = CIFAR10_CLASSES.index(label.lower())
        labels.append(label)
    
    batch_size = len(labels)
    
    # Set decoder scale for super-resolution
    if height == base_res and width == base_res:
        set_decoder_scale(1.0)
        print(f"Generating at native {base_res}x{base_res}")
    else:
        scale_h = height / float(base_res)
        scale_w = width / float(base_res)
        assert scale_h == scale_w, "Only square scaling supported"
        set_decoder_scale(scale_h)
        print(f"Generating at {height}x{width} ({scale_h:g}x super-resolution)")
    
    # Configure sampler
    model.diffusion_sampler.guidance = guidance
    model.diffusion_sampler.num_steps = num_steps
    
    # Generate noise
    noise = torch.randn(batch_size, 3, height, width, device=DEVICE)
    
    # Get condition and uncondition
    condition, uncondition = model.conditioner(labels)
    condition = condition.to(DEVICE)
    uncondition = uncondition.to(DEVICE)
    
    # Sample
    samples = model.diffusion_sampler(
        model.ema_denoiser,
        noise,
        condition,
        uncondition,
    )
    
    # Decode
    images = model.vae.decode(samples)
    images = torch.clamp(images, -1.0, 1.0)
    images_uint8 = fp2uint8(images)
    
    return images_uint8.cpu()


def show_images(images_uint8, labels=None, title="", cols=None):
    """Display a batch of images with labels."""
    if isinstance(images_uint8, torch.Tensor):
        imgs_np = images_uint8.permute(0, 2, 3, 1).cpu().numpy()
    else:
        imgs_np = np.transpose(images_uint8, (0, 2, 3, 1))
    
    n = len(imgs_np)
    if cols is None:
        cols = min(n, 5)
    rows = math.ceil(n / cols)
    
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))
    if rows == 1 and cols == 1:
        axes = np.array([axes])
    axes = axes.flatten()
    
    for i, (ax, img) in enumerate(zip(axes, imgs_np)):
        ax.imshow(img)
        ax.axis('off')
        if labels is not None:
            label = labels[i]
            if isinstance(label, int):
                label = CIFAR10_CLASSES[label]
            ax.set_title(label)
    
    for ax in axes[n:]:
        ax.axis('off')
    
    if title:
        plt.suptitle(title)
    plt.tight_layout()
    plt.show()


def save_grid(images_uint8, filename, labels=None, cols=None):
    """Save images as a grid."""
    if isinstance(images_uint8, torch.Tensor):
        imgs_np = images_uint8.permute(0, 2, 3, 1).cpu().numpy()
    else:
        imgs_np = np.transpose(images_uint8, (0, 2, 3, 1))
    
    imgs = [Image.fromarray(img) for img in imgs_np]
    
    n = len(imgs)
    if cols is None:
        cols = min(n, 5)
    rows = math.ceil(n / cols)
    
    w, h = imgs[0].size
    grid = Image.new("RGB", (cols * w, rows * h))
    for idx, img in enumerate(imgs):
        r, c = divmod(idx, cols)
        grid.paste(img, (c * w, r * h))
    
    out_path = OUTPUT_DIR / filename
    grid.save(out_path)
    print(f"Saved: {out_path}")
    return out_path


print("Helper functions defined.")

## Generate at Native Resolution (32x32)

Generate one sample per class at the native CIFAR-10 resolution.

In [None]:
print("=" * 50)
print("Generating all 10 classes at 32x32")
print("=" * 50)

# Generate one image per class
all_classes = list(range(10))

images_32 = sample_class_conditional(
    class_labels=all_classes,
    height=32,
    width=32,
    seed=42,
    num_steps=50,
    guidance=2.0,
)

print(f"Output shape: {images_32.shape}")
show_images(images_32, labels=CIFAR10_CLASSES, title="Extended NerfEmbedder: CIFAR-10 at 32x32")
save_grid(images_32, "extended_cifar10_32x32_all_classes.png", cols=5)

## Generate at 4x Super-Resolution (128x128)

The Extended NerfEmbedder should produce smoother results due to:
1. **Extended boundaries** - patches overlap, reducing seam artifacts
2. **Position jittering during training** - learned continuous representations

In [None]:
print("=" * 50)
print("Generating all 10 classes at 128x128 (4x super-resolution)")
print("=" * 50)

images_128 = sample_class_conditional(
    class_labels=all_classes,
    height=128,
    width=128,
    seed=42,  # Same seed as the 32x32 run; the noise tensor has a different shape, so samples are not pixel-matched
    num_steps=50,
    guidance=2.0,
)

print(f"Output shape: {images_128.shape}")
show_images(images_128, labels=CIFAR10_CLASSES, title="Extended NerfEmbedder: CIFAR-10 at 128x128 (4x Super-Res)")
save_grid(images_128, "extended_cifar10_128x128_4x_superres.png", cols=5)

## Side-by-Side Comparison: 32x32 vs 128x128

In [None]:
# Compare specific classes
comparison_classes = ['cat', 'dog', 'airplane', 'ship']
comparison_labels = [CIFAR10_CLASSES.index(c) for c in comparison_classes]

print("Comparing 32x32 vs Bilinear 128x128 vs Extended NF 128x128")

# Generate at 32x32
imgs_32 = sample_class_conditional(
    class_labels=comparison_labels,
    height=32, width=32,
    seed=123,
)

# Generate at 128x128 (4x)
imgs_128 = sample_class_conditional(
    class_labels=comparison_labels,
    height=128, width=128,
    seed=123,
)

# Bilinear upscale 32->128
imgs_32_upscaled = F.interpolate(
    imgs_32.float() / 255.0,
    size=(128, 128),
    mode='bilinear',
    align_corners=False,
)
# Round before casting; a bare .to(torch.uint8) truncates toward zero
imgs_32_upscaled = (imgs_32_upscaled * 255).round().clamp(0, 255).to(torch.uint8)

# Plot comparison
fig, axes = plt.subplots(len(comparison_classes), 3, figsize=(12, 4 * len(comparison_classes)))

for i, class_name in enumerate(comparison_classes):
    # 32x32 native
    axes[i, 0].imshow(imgs_32[i].permute(1, 2, 0).numpy())
    axes[i, 0].set_title(f"{class_name} - 32x32 Native")
    axes[i, 0].axis('off')
    
    # Bilinear upscale
    axes[i, 1].imshow(imgs_32_upscaled[i].permute(1, 2, 0).numpy())
    axes[i, 1].set_title(f"{class_name} - Bilinear 128x128")
    axes[i, 1].axis('off')
    
    # Extended NF super-res
    axes[i, 2].imshow(imgs_128[i].permute(1, 2, 0).numpy())
    axes[i, 2].set_title(f"{class_name} - Extended NF 128x128")
    axes[i, 2].axis('off')

plt.suptitle("32x32 Native vs Bilinear Upscale vs Extended NF Super-Resolution", fontsize=14)
plt.tight_layout()
plt.show()

## Try Different Super-Resolution Scales

Generate the same class at 1x, 2x, 4x, and 8x to see how the extended boundaries hold up as the scale grows.
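
To see what changes between scales, the sketch below tabulates the per-patch workload. It assumes the scaled patch side is `PATCH_SIZE * scale`, which is how the boundary-analysis cell later in this notebook computes it; `decoder_patch_samples` is a hypothetical helper.

```python
def decoder_patch_samples(patch_size: int, scale: int) -> tuple[int, int]:
    """Return (pixels_per_patch_side, NF queries per patch) at a given scale."""
    side = patch_size * scale
    return side, side * side


BASE = 32   # CIFAR-10 native resolution
PATCH = 2   # patch_size used in training

for resolution in (32, 64, 128, 256):
    scale = resolution // BASE
    side, samples = decoder_patch_samples(PATCH, scale)
    tokens_per_side = resolution // side
    print(f"{resolution}px: {scale}x, {side}x{side} px per patch, "
          f"{samples} NF queries per patch, "
          f"{tokens_per_side}x{tokens_per_side} tokens")
```

Under this assumption the encoder always sees a 16x16 token grid, and only the number of positions queried inside each patch grows (4, 16, 64, then 256), so higher scales lean entirely on how well the neural field interpolates between training positions.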

In [None]:
# Compare different scales: 1x, 2x, 4x, 8x
print("=" * 50)
print("Comparing different super-resolution scales")
print("=" * 50)

scales = [
    (32, "1x (Native)"),
    (64, "2x"),
    (128, "4x"),
    (256, "8x"),
]

class_label = 'airplane'
seed = 42

scale_images = []
scale_titles = []

for resolution, scale_name in scales:
    print(f"Generating at {resolution}x{resolution}...")
    img = sample_class_conditional(
        class_labels=[class_label],
        height=resolution, width=resolution,
        seed=seed,
    )
    scale_images.append(img[0])
    scale_titles.append(f"{resolution}x{resolution} ({scale_name})")

# Display at same visual size
fig, axes = plt.subplots(1, len(scales), figsize=(4 * len(scales), 4))
for i, (img, title) in enumerate(zip(scale_images, scale_titles)):
    axes[i].imshow(img.permute(1, 2, 0).numpy())
    axes[i].set_title(title)
    axes[i].axis('off')

plt.suptitle(f"{class_label.capitalize()} at Different Resolutions (Extended NerfEmbedder)", fontsize=14)
plt.tight_layout()
plt.show()

## Boundary Artifact Analysis

Let's look closely at potential boundary artifacts.
The extended boundaries should reduce visible seams between patches.
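
Visual inspection can be complemented by a rough numeric check. The sketch below is one possible seam metric, not something defined by the PixNerd codebase: it compares the average horizontal pixel jump across patch boundaries with the average jump everywhere else. `seam_score` is a hypothetical helper and works on a plain 2D list of grayscale values to stay dependency-free.

```python
def seam_score(gray, patch_px):
    """Ratio of mean horizontal jumps at patch boundaries to jumps elsewhere.

    `gray` is a 2D list (rows of floats). A value near 1.0 means boundary
    columns look like the rest of the image; larger values suggest seams.
    """
    boundary, interior = [], []
    for row in gray:
        for x in range(1, len(row)):
            jump = abs(row[x] - row[x - 1])
            # Column x starts a new patch when x is a multiple of patch_px.
            (boundary if x % patch_px == 0 else interior).append(jump)
    if not boundary or not interior:
        raise ValueError("image too small for the given patch size")
    mean_boundary = sum(boundary) / len(boundary)
    mean_interior = sum(interior) / len(interior)
    return mean_boundary / mean_interior if mean_interior > 0 else float("inf")


smooth = [[float(x) for x in range(8)]] * 4         # plain ramp, no seam
seamed = [[0.0, 1.0, 2.0, 3.0, 9.0, 10.0, 11.0, 12.0]] * 4  # jump at x=4
print(seam_score(smooth, patch_px=4))  # 1.0
print(seam_score(seamed, patch_px=4))  # 6.0
```

For the 256x256 image generated in the next cell, one could pass `img_np.mean(axis=2).tolist()` with `patch_px=16`; the same function applied to the transposed image covers horizontal seams.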

In [None]:
# Generate a high-res image and zoom in on potential boundary regions
print("=" * 50)
print("Boundary artifact analysis at 256x256")
print("=" * 50)

img_256 = sample_class_conditional(
    class_labels=['cat'],
    height=256, width=256,
    seed=42,
)

img_np = img_256[0].permute(1, 2, 0).numpy()

# Show full image and zoomed regions
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Full image
axes[0, 0].imshow(img_np)
axes[0, 0].set_title("Full 256x256 Image")
axes[0, 0].axis('off')

# Add grid lines showing patch boundaries (every 16 pixels at 8x scale)
axes[0, 1].imshow(img_np)
patch_size_scaled = PATCH_SIZE * 8  # 2 * 8 = 16 at 256x256
for i in range(0, 257, patch_size_scaled):
    axes[0, 1].axhline(y=i, color='red', linewidth=0.5, alpha=0.5)
    axes[0, 1].axvline(x=i, color='red', linewidth=0.5, alpha=0.5)
axes[0, 1].set_title("With Patch Boundaries (red)")
axes[0, 1].axis('off')

# Zoom regions
zoom_regions = [
    (96, 96, 160, 160, "Center region"),
    (0, 0, 64, 64, "Top-left corner"),
    (192, 192, 256, 256, "Bottom-right corner"),
    (112, 112, 144, 144, "Patch boundary zone"),
]

for idx, (y1, x1, y2, x2, title) in enumerate(zoom_regions[:4]):
    ax = axes.flatten()[idx + 2]
    ax.imshow(img_np[y1:y2, x1:x2])
    ax.set_title(f"Zoom: {title}")
    ax.axis('off')

plt.suptitle("Boundary Artifact Analysis - Extended NerfEmbedder", fontsize=14)
plt.tight_layout()
plt.show()

save_grid(img_256, "extended_cat_256x256_8x_superres.png")

## Guidance Scale Comparison
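
Classifier-free guidance (CFG) is commonly written as an extrapolation from the unconditional prediction toward the conditional one. The sketch below shows that standard form on plain Python lists; whether `simple_guidance_fn` uses exactly this formula is an assumption, and `cfg_combine` is a hypothetical name.

```python
def cfg_combine(cond, uncond, guidance):
    """Standard CFG mix: uncond + guidance * (cond - uncond), elementwise."""
    return [u + guidance * (c - u) for c, u in zip(cond, uncond)]


cond = [1.0, 2.0]    # prediction given the class label
uncond = [0.0, 1.0]  # prediction with the label dropped

print(cfg_combine(cond, uncond, 0.0))  # [0.0, 1.0]  (unconditional only)
print(cfg_combine(cond, uncond, 1.0))  # [1.0, 2.0]  (conditional only)
print(cfg_combine(cond, uncond, 2.0))  # [2.0, 3.0]  (pushed past conditional)
```

In this form, `guidance=1.0` reproduces the conditional prediction and larger values push further in the class-specific direction, which is why high scales tend to trade diversity for stronger class features.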

In [None]:
# Compare different guidance scales
print("=" * 50)
print("Comparing different CFG guidance scales at 128x128")
print("=" * 50)

guidance_scales = [1.0, 1.5, 2.0, 3.0, 5.0]
class_label = 'horse'
seed = 42

guidance_images = []
for g in guidance_scales:
    print(f"Guidance = {g}...")
    img = sample_class_conditional(
        class_labels=[class_label],
        height=128, width=128,
        seed=seed,
        guidance=g,
    )
    guidance_images.append(img[0])

fig, axes = plt.subplots(1, len(guidance_scales), figsize=(3 * len(guidance_scales), 3))
for i, (img, g) in enumerate(zip(guidance_images, guidance_scales)):
    axes[i].imshow(img.permute(1, 2, 0).numpy())
    axes[i].set_title(f"guidance={g}")
    axes[i].axis('off')

plt.suptitle(f"{class_label.capitalize()} with Different CFG Scales (Extended NerfEmbedder)", fontsize=14)
plt.tight_layout()
plt.show()

## Generate Variety of Samples

In [None]:
# Generate 5 different samples for selected classes at 128x128
print("=" * 50)
print("Generating 5 samples per class at 128x128")
print("=" * 50)

selected_classes = ['cat', 'dog', 'airplane']
seeds = [0, 42, 123, 456, 789]

for class_name in selected_classes:
    print(f"\nGenerating {class_name}s...")
    class_images = []
    for seed in seeds:
        img = sample_class_conditional(
            class_labels=[class_name],
            height=128, width=128,
            seed=seed,
        )
        class_images.append(img)
    
    class_images = torch.cat(class_images, dim=0)
    show_images(class_images, labels=[class_name] * len(seeds), title=f"{len(seeds)} {class_name.capitalize()}s at 128x128 (Extended NF)")
    save_grid(class_images, f"extended_{class_name}s_128x128.png", cols=5)

## Summary

This notebook demonstrated the **Extended NerfEmbedder** architecture:

### Key Innovations

| Feature | Description | Benefit |
|---------|-------------|--------|
| **Extended Boundaries** | Positions in [-0.25, 1.25] | Overlapping patches reduce seams |
| **Position Jittering** | Gaussian noise with std 0.01 during training | Smoother NF interpolation |

### How It Works

**Training:**
- NerfEmbedder is queried at positions beyond [0,1] → each patch also learns the region just outside its borders
- Random jitter forces continuous representation learning

**Inference:**
- Same extended positions, no jittering
- Overlapping regions can be blended for seamless output
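
The notebook does not specify a blending scheme, so the following is only one possible sketch: a 1D weighted average in which each patch gets a triangular ("tent") weight that fades toward its extended edges. The function name `blend_patches_1d` and the choice of weights are assumptions for illustration.

```python
def blend_patches_1d(patches, starts, length):
    """Weighted-average overlapping 1D patches into one signal of `length`.

    `patches[k]` holds the values predicted by patch k over its extended
    range, and `starts[k]` is the output index of its first sample.
    """
    total = [0.0] * length
    weight_sum = [0.0] * length
    for values, start in zip(patches, starts):
        center = (len(values) - 1) / 2.0
        for i, v in enumerate(values):
            pos = start + i
            if not 0 <= pos < length:
                continue  # extended samples that fall outside the image
            # Tent weight: largest at the patch center, still positive at edges.
            w = 1.0 - abs(i - center) / (center + 1.0)
            total[pos] += w * v
            weight_sum[pos] += w
    return [t / w if w > 0 else 0.0 for t, w in zip(total, weight_sum)]


# Two 4-pixel patches, each extended by one sample per side (margin 0.25).
left = [0.0] * 6    # covers output indices -1..4
right = [1.0] * 6   # covers output indices 3..8
blended = blend_patches_1d([left, right], starts=[-1, 3], length=8)
print([round(x, 3) for x in blended])
# [0.0, 0.0, 0.0, 0.333, 0.667, 1.0, 1.0, 1.0]
```

Instead of a hard step between index 3 and 4, the two disagreeing patches are cross-faded over the overlap, which is the effect the extended margin is meant to enable.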

### Comparison with Other Approaches

| Approach | Boundary | Jitter | Seam Artifacts | Interpolation |
|----------|----------|--------|----------------|---------------|
| Original | [0,1] | None | ❌ Visible | ⚠️ Grid-locked |
| Multi-scale | [0,1] + dense | None | ⚠️ Some | ✅ Better |
| **Extended** | [-0.25,1.25] | ✅ Yes | ✅ Reduced | ✅ Smooth |

### Training Command
```bash
python train_cifar10_extended.py \
    --margin 0.25 \
    --jitter_std 0.01 \
    --max_steps 100000
```

In [None]:
print("Done!")
print(f"\nAll outputs saved to: {OUTPUT_DIR}")