# CIFAR-10 Multi-Scale NerfEmbedder: Super-Resolution Inference

This notebook demonstrates the **Multi-Scale NerfEmbedder** architecture, which decouples two settings:
- **Encoder patch size** → controls global coherence (smaller patches give more encoder tokens, which improves global coherence)
- **NerfEmbedder dense_samples** → controls super-resolution quality (more sampled positions give smoother upscaled output)

## Key Innovation

| Architecture | patch_size | Encoder Tokens | NF Positions | Global | Super-Res |
|--------------|------------|----------------|--------------|--------|----------|
| Original (ps=2) | 2 | 256 | 4 | ✅ Excellent | ❌ Poor |
| Original (ps=8) | 8 | 16 | 64 | ⚠️ Limited | ✅ Good |
| **Multi-Scale** | 2 | 256 | 256 | ✅ Excellent | ✅ Excellent |

The multi-scale architecture uses **dense position sampling** (e.g., 16×16=256) regardless of patch size,
plus **multi-octave Fourier features** for robust interpolation at arbitrary scales.
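The counts in the table follow from simple arithmetic. Below is a minimal sketch. It assumes, as the table implies, that the original NerfEmbedder samples `patch_size**2` positions per token, while the multi-scale variant samples `dense_samples**2` positions regardless of patch size. `count_tokens_and_positions` is a hypothetical helper and is not part of PixNerd.

```python
from typing import Optional, Tuple


def count_tokens_and_positions(
    base_res: int, patch_size: int, dense_samples: Optional[int] = None
) -> Tuple[int, int]:
    """Return (encoder_tokens, nf_positions) for a square image.

    Hypothetical helper: dense_samples=None models the original NerfEmbedder
    (positions tied to patch_size); an integer models the multi-scale variant.
    """
    if base_res % patch_size != 0:
        raise ValueError("base_res must be divisible by patch_size")
    tokens_per_side = base_res // patch_size
    encoder_tokens = tokens_per_side ** 2
    samples_per_side = patch_size if dense_samples is None else dense_samples
    nf_positions = samples_per_side ** 2
    return encoder_tokens, nf_positions


# Reproduce the three rows of the table at CIFAR-10 resolution (32x32)
for name, ps, ds in [("Original (ps=2)", 2, None), ("Original (ps=8)", 8, None), ("Multi-Scale", 2, 16)]:
    tokens, positions = count_tokens_and_positions(32, ps, ds)
    print(f"{name:16s} tokens={tokens:4d} nf_positions={positions:4d}")
```

The original design ties both numbers to one knob, so raising one lowers the other. The multi-scale variant adds a second knob.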

## Setup

- Requires GPU
- Assumes the model was trained using `train_cifar10_multiscale.py`
- Update `CKPT_PATH` to point to your trained checkpoint

### Training command:
```bash
python train_cifar10_multiscale.py --patch_size 2 --dense_samples 16 --max_steps 100000
```

### Model config (must match training - ~10M params):
| Parameter | Value | Why |
|-----------|-------|-----|
| patch_size | 2 | 16×16=256 encoder tokens (excellent global coherence) |
| dense_samples | 16 | 16×16=256 NF positions (excellent super-resolution) |
| hidden_size | 256 | Reduced for ~10M params |
| decoder_hidden_size | 32 | Reduced decoder capacity |
| num_encoder_blocks | 6 | Fewer blocks |
| num_decoder_blocks | 2 | Decoder depth |
| num_groups | 4 | Fewer attention heads |
| nerf_fusion | concat | Multi-scale feature fusion |
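Notebook settings that drift from training settings are an easy way to end up with mismatched checkpoint keys. One way to keep them in sync is to hold the table's values in a single object and generate the training command from it. The sketch below is one option. `MultiScaleConfig` is a hypothetical class, not part of PixNerd, and it only emits the four flags that appear in this notebook's own training commands. The `hidden_size % num_groups` check assumes that `num_groups` is the number of attention heads, which is how the table describes it.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MultiScaleConfig:
    """Hypothetical container mirroring the config table (not a PixNerd class)."""
    base_res: int = 32
    patch_size: int = 2
    dense_samples: int = 16
    hidden_size: int = 256
    decoder_hidden_size: int = 32
    num_encoder_blocks: int = 6
    num_decoder_blocks: int = 2
    num_groups: int = 4
    nerf_fusion: str = "concat"

    def __post_init__(self):
        # Basic consistency checks; the head-divisibility rule is an assumption
        if self.base_res % self.patch_size != 0:
            raise ValueError("base_res must be divisible by patch_size")
        if self.hidden_size % self.num_groups != 0:
            raise ValueError("hidden_size must be divisible by num_groups")
        if self.nerf_fusion not in ("concat", "add", "attention"):
            raise ValueError(f"unknown nerf_fusion: {self.nerf_fusion}")

    def training_command(self, max_steps: int = 100000) -> str:
        # Only flags that appear in this notebook's training commands are emitted
        return (
            "python train_cifar10_multiscale.py"
            f" --patch_size {self.patch_size}"
            f" --dense_samples {self.dense_samples}"
            f" --nerf_fusion {self.nerf_fusion}"
            f" --max_steps {max_steps}"
        )
```

`MultiScaleConfig().training_command()` reproduces the full command given in the Summary section.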

In [None]:
# Navigate to PixNerd folder where src/ is located
import os
import sys

NOTEBOOK_DIR = os.getcwd()
print(f"Starting directory: {NOTEBOOK_DIR}")

# Navigate to PixNerd folder (where src/ lives)
PIXNERD_DIR = os.path.join(NOTEBOOK_DIR, "PixNerd")
if os.path.exists(PIXNERD_DIR):
    os.chdir(PIXNERD_DIR)
    print(f"Changed to: {os.getcwd()}")
elif os.path.basename(NOTEBOOK_DIR) == "PixNerd":
    print(f"Already in PixNerd directory: {NOTEBOOK_DIR}")
else:
    parent = os.path.dirname(NOTEBOOK_DIR)
    pixnerd_in_parent = os.path.join(parent, "PixNerd")
    if os.path.exists(pixnerd_in_parent):
        os.chdir(pixnerd_in_parent)
        print(f"Changed to: {os.getcwd()}")
    else:
        print(f"WARNING: Could not find PixNerd folder. Current dir: {NOTEBOOK_DIR}")

if os.path.exists("src"):
    print("Found src/ directory")
else:
    print("ERROR: src/ directory not found!")
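
The same search can be written with `pathlib`, and it can return a result instead of only printing. The sketch below is hypothetical (`find_pixnerd_root` is not part of the repo). It checks the same three locations in the same order as the cell above. It differs from that cell in one way: a candidate folder only counts if it also contains `src/`.

```python
from pathlib import Path
from typing import Optional


def find_pixnerd_root(start: Path, folder_name: str = "PixNerd") -> Optional[Path]:
    """Return the first PixNerd folder that contains src/, or None.

    Search order matches the cell above: start/PixNerd, start itself,
    then start's parent/PixNerd.
    """
    start = Path(start).resolve()
    candidates = [start / folder_name, start, start.parent / folder_name]
    for candidate in candidates:
        # Require the folder name and a src/ subdirectory
        if candidate.name == folder_name and (candidate / "src").is_dir():
            return candidate
    return None
```

Typical use: `root = find_pixnerd_root(Path.cwd())`, then `os.chdir(root)` if `root` is not `None`. The extra `src/` check catches an empty `PixNerd` folder early, instead of failing later at import time.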

In [None]:
from pathlib import Path
import math
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image

# Paths
PIXNERD_ROOT = Path(os.getcwd())

# ============================================================
# CHECKPOINT PATH - UPDATE THIS TO YOUR TRAINED MODEL
# ============================================================
CKPT_PATH = PIXNERD_ROOT / "workdirs" / "exp_cifar10_multiscale_nerf" / "checkpoints" / "last.ckpt"
# ============================================================

OUTPUT_DIR = PIXNERD_ROOT / "outputs" / "cifar10_multiscale_superres"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
if DEVICE != "cuda":
    print("WARNING: Running on CPU will be very slow")

# CIFAR-10 class names
CIFAR10_CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# ============================================================
# MODEL CONFIG - Must match train_cifar10_multiscale.py (~10M params)
# ============================================================
NUM_CLASSES = 10
BASE_RES = 32  # CIFAR-10 native resolution

# KEY INNOVATION: These are INDEPENDENT!
PATCH_SIZE = 2           # Controls encoder tokens: 32/2 = 16x16 = 256 tokens
DENSE_SAMPLES = 16       # Controls NF positions: 16x16 = 256 positions

# Smaller model config (~10M parameters)
HIDDEN_SIZE = 256
DECODER_HIDDEN_SIZE = 32
NUM_ENCODER_BLOCKS = 6
NUM_DECODER_BLOCKS = 2
NUM_GROUPS = 4
NERF_FUSION = "concat"   # Options: "concat", "add", "attention"
# ============================================================

print(f"PixNerd root: {PIXNERD_ROOT}")
print(f"Checkpoint path: {CKPT_PATH}")
print(f"Checkpoint exists: {CKPT_PATH.exists()}")
if not CKPT_PATH.exists():
    print("\n⚠️  CHECKPOINT NOT FOUND!")
    print("   Please train a model first using:")
    print("   python train_cifar10_multiscale.py --patch_size 2 --dense_samples 16")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Device: {DEVICE}")
print()
print("=" * 60)
print("MULTI-SCALE ARCHITECTURE (~10M params)")
print("patch_size and dense_samples are INDEPENDENT!")
print("=" * 60)
print(f"  Encoder patch_size: {PATCH_SIZE}")
print(f"    → Encoder tokens: {BASE_RES//PATCH_SIZE}x{BASE_RES//PATCH_SIZE} = {(BASE_RES//PATCH_SIZE)**2}")
print(f"    → Controls: Global coherence")
print()
print(f"  NerfEmbedder dense_samples: {DENSE_SAMPLES}")
print(f"    → Position samples: {DENSE_SAMPLES}x{DENSE_SAMPLES} = {DENSE_SAMPLES**2}")
print(f"    → Controls: Super-resolution quality")
print("=" * 60)

## Build Model

Using `PixNerDiTMultiScale` with the Multi-Scale NerfEmbedder.
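The build cell prints one total for `model.parameters()`. The notebook accesses `model.ema_denoiser` as a module. If it is registered as a submodule of `LightningModel`, that total includes both the denoiser and its EMA copy, so it may come out at roughly twice the ~10M budget. A per-child breakdown makes this visible. The helper below is a hypothetical sketch that works on any `torch.nn.Module`.

```python
import torch.nn as nn


def count_params_by_child(module: nn.Module, trainable_only: bool = False) -> dict:
    """Count parameters under each direct child of `module`.

    Parameters shared between children are counted once per child.
    """
    counts = {}
    for name, child in module.named_children():
        counts[name] = sum(
            p.numel() for p in child.parameters()
            if p.requires_grad or not trainable_only
        )
    return counts
```

After the next cell has run, `count_params_by_child(model)` should show separate entries for `denoiser` and, if registered, `ema_denoiser`.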

In [None]:
# Import PixNerd components - NOTE: Using MULTISCALE model!
from src.models.autoencoder.pixel import PixelAE
from src.models.conditioner.class_label import LabelConditioner
from src.models.transformer.pixnerd_c2i_multiscale import PixNerDiTMultiScale  # Multi-scale version!
from src.diffusion.flow_matching.scheduling import LinearScheduler
from src.diffusion.flow_matching.sampling import EulerSampler, ode_step_fn
from src.diffusion.base.guidance import simple_guidance_fn
from src.diffusion.flow_matching.training import FlowMatchingTrainer
from src.callbacks.simple_ema import SimpleEMA
from src.lightning_model import LightningModel
from src.models.autoencoder.base import fp2uint8

print("Imports successful!")
print("Using: PixNerDiTMultiScale (Multi-Scale NerfEmbedder)")

In [None]:
print("Initializing model components...")

main_scheduler = LinearScheduler()

vae = PixelAE(scale=1.0)

conditioner = LabelConditioner(num_classes=NUM_CLASSES)

# Multi-Scale NerfEmbedder model - KEY DIFFERENCE!
denoiser = PixNerDiTMultiScale(
    in_channels=3,
    patch_size=PATCH_SIZE,           # Controls encoder tokens
    dense_samples=DENSE_SAMPLES,     # Controls NF positions (INDEPENDENT!)
    num_groups=NUM_GROUPS,
    hidden_size=HIDDEN_SIZE,
    decoder_hidden_size=DECODER_HIDDEN_SIZE,
    num_encoder_blocks=NUM_ENCODER_BLOCKS,
    num_decoder_blocks=NUM_DECODER_BLOCKS,
    num_classes=NUM_CLASSES,
    nerf_fusion=NERF_FUSION,
)

# Sampler with CFG
sampler = EulerSampler(
    num_steps=50,
    guidance=2.0,
    guidance_interval_min=0.0,
    guidance_interval_max=1.0,
    scheduler=main_scheduler,
    w_scheduler=LinearScheduler(),
    guidance_fn=simple_guidance_fn,
    step_fn=ode_step_fn,
)

# Trainer stub for checkpoint loading
trainer_stub = FlowMatchingTrainer(
    scheduler=main_scheduler,
    lognorm_t=True,
    timeshift=1.0,
)

ema_tracker = SimpleEMA(decay=0.9999)

model = LightningModel(
    vae=vae,
    conditioner=conditioner,
    denoiser=denoiser,
    diffusion_trainer=trainer_stub,
    diffusion_sampler=sampler,
    ema_tracker=ema_tracker,
    optimizer=None,
    lr_scheduler=None,
    eval_original_model=False,
)

model.eval()
model.to(DEVICE)
print(f"Model initialized and moved to {DEVICE}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print()
print(f"Architecture: PixNerDiTMultiScale")
print(f"  - patch_size={PATCH_SIZE} → {(BASE_RES//PATCH_SIZE)**2} encoder tokens")
print(f"  - dense_samples={DENSE_SAMPLES} → {DENSE_SAMPLES**2} NF positions")
print(f"  - nerf_fusion={NERF_FUSION}")

## Load Checkpoint
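
The cell below loads with `strict=False`. Any missing key therefore leaves its parameter at its freshly initialised value, and no error is raised. To see which component is affected, group the reported key names by prefix. The helper below is a hypothetical sketch that uses only the standard library.

```python
from collections import Counter


def summarize_keys(keys, depth: int = 1) -> Counter:
    """Count state-dict keys by their first `depth` dot-separated components."""
    return Counter(".".join(k.split(".")[:depth]) for k in keys)
```

For example, `summarize_keys(missing)` shows whether the missing weights belong to `denoiser`, `ema_denoiser`, or another component. `summarize_keys(unexpected, depth=2)` gives a finer breakdown.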

In [None]:
print(f"Loading checkpoint from: {CKPT_PATH}")
ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False)
missing, unexpected = model.load_state_dict(ckpt["state_dict"], strict=False)
print(f"Missing keys: {len(missing)} | Unexpected keys: {len(unexpected)}")
if missing:
    print(f"  Missing: {missing[:5]}..." if len(missing) > 5 else f"  Missing: {missing}")
if unexpected:
    print(f"  Unexpected: {unexpected[:5]}..." if len(unexpected) > 5 else f"  Unexpected: {unexpected}")
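if missing:
    print("Checkpoint loaded, but some keys are missing; check that the model config matches training.")
else:
    print("Checkpoint loaded successfully!")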

## Helper Functions
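
`sample_class_conditional` derives the decoder scale from the requested size and checks it with `assert`. Asserts are skipped when Python runs with `-O`. The sketch below factors the scale calculation into a hypothetical standalone function that raises `ValueError` instead. It also assumes the encoder needs the resolution to be a multiple of `patch_size`; that requirement is an assumption and is not confirmed by this notebook.

```python
def decoder_scale_for(height: int, width: int, base_res: int = 32, patch_size: int = 2) -> float:
    """Return the decoder patch scale for a square output resolution.

    Hypothetical helper; the patch_size divisibility rule is an assumption.
    """
    if height != width:
        raise ValueError("only square outputs are supported")
    if height % patch_size != 0:
        raise ValueError(f"resolution must be a multiple of patch_size={patch_size}")
    return height / base_res
```

Non-integer scales such as 48x48 (1.5x) pass this check. Whether they produce good images depends on the trained model.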

In [None]:
def set_decoder_scale(scale: float):
    """Set NF decoder patch scaling for super-resolution."""
    for net in [model.denoiser, getattr(model, "ema_denoiser", None)]:
        if net is None:
            continue
        net.decoder_patch_scaling_h = scale
        net.decoder_patch_scaling_w = scale


@torch.no_grad()
def sample_class_conditional(
    class_labels: list,
    height: int = 32,
    width: int = 32,
    seed: int = 42,
    num_steps: int = 50,
    guidance: float = 2.0,
    base_res: int = BASE_RES,
):
    """
    Generate class-conditional images.
    
    Args:
        class_labels: List of class indices (0-9) or names
        height: Output height (32 for native, 128 for 4x super-res)
        width: Output width
        seed: Random seed
        num_steps: ODE solver steps
        guidance: CFG guidance scale
        base_res: Training resolution
    
    Returns:
        Generated images as uint8 tensor
    """
    torch.manual_seed(seed)
    
    # Convert class names to indices if needed
    labels = []
    for label in class_labels:
        if isinstance(label, str):
            label = CIFAR10_CLASSES.index(label.lower())
        labels.append(label)
    
    batch_size = len(labels)
    
    # Set decoder scale for super-resolution
    if height == base_res and width == base_res:
        set_decoder_scale(1.0)
        print(f"Generating at native {base_res}x{base_res}")
    else:
        scale_h = height / float(base_res)
        scale_w = width / float(base_res)
        assert scale_h == scale_w, "Only square scaling supported"
        set_decoder_scale(scale_h)
        print(f"Generating at {height}x{width} ({scale_h:g}x super-resolution)")
    
    # Configure sampler
    model.diffusion_sampler.guidance = guidance
    model.diffusion_sampler.num_steps = num_steps
    
    # Generate noise
    noise = torch.randn(batch_size, 3, height, width, device=DEVICE)
    
    # Get condition and uncondition
    condition, uncondition = model.conditioner(labels)
    condition = condition.to(DEVICE)
    uncondition = uncondition.to(DEVICE)
    
    # Sample
    samples = model.diffusion_sampler(
        model.ema_denoiser,
        noise,
        condition,
        uncondition,
    )
    
    # Decode
    images = model.vae.decode(samples)
    images = torch.clamp(images, -1.0, 1.0)
    images_uint8 = fp2uint8(images)
    
    return images_uint8.cpu()


def show_images(images_uint8, labels=None, title="", cols=None):
    """Display a batch of images with labels."""
    if isinstance(images_uint8, torch.Tensor):
        imgs_np = images_uint8.permute(0, 2, 3, 1).cpu().numpy()
    else:
        imgs_np = np.transpose(images_uint8, (0, 2, 3, 1))
    
    n = len(imgs_np)
    if cols is None:
        cols = min(n, 5)
    rows = math.ceil(n / cols)
    
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))
    if rows == 1 and cols == 1:
        axes = np.array([axes])
    axes = axes.flatten()
    
    for i, (ax, img) in enumerate(zip(axes, imgs_np)):
        ax.imshow(img)
        ax.axis('off')
        if labels is not None:
            label = labels[i]
            if isinstance(label, int):
                label = CIFAR10_CLASSES[label]
            ax.set_title(label)
    
    for ax in axes[n:]:
        ax.axis('off')
    
    if title:
        plt.suptitle(title)
    plt.tight_layout()
    plt.show()


def save_grid(images_uint8, filename, labels=None, cols=None):
    """Save equal-sized images as a grid PNG in OUTPUT_DIR (`labels` is currently unused)."""
    if isinstance(images_uint8, torch.Tensor):
        imgs_np = images_uint8.permute(0, 2, 3, 1).cpu().numpy()
    else:
        imgs_np = np.transpose(images_uint8, (0, 2, 3, 1))
    
    imgs = [Image.fromarray(img) for img in imgs_np]
    
    n = len(imgs)
    if cols is None:
        cols = min(n, 5)
    rows = math.ceil(n / cols)
    
    w, h = imgs[0].size
    grid = Image.new("RGB", (cols * w, rows * h))
    for idx, img in enumerate(imgs):
        r, c = divmod(idx, cols)
        grid.paste(img, (c * w, r * h))
    
    out_path = OUTPUT_DIR / filename
    grid.save(out_path)
    print(f"Saved: {out_path}")
    return out_path


print("Helper functions defined.")

## Generate at Native Resolution (32x32)

Generate one sample per class at the native CIFAR-10 resolution.

With `patch_size=2`, we have **256 encoder tokens** for excellent global coherence.
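The sketch below shows where the 256 comes from: a 32x32 image splits into non-overlapping 2x2 patches, giving a 16x16 grid of tokens. `patchify` is a hypothetical helper. The model's real patch embedding may be a strided convolution or a linear layer applied after a reshape like this one. The sketch only illustrates the token count.

```python
import torch


def patchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split (B, C, H, W) images into non-overlapping patches.

    Returns (B, num_tokens, C * patch_size**2), one row per patch in raster order.
    """
    b, c, h, w = x.shape
    p = patch_size
    if h % p != 0 or w % p != 0:
        raise ValueError("H and W must be divisible by patch_size")
    x = x.reshape(b, c, h // p, p, w // p, p)   # (B, C, H/p, p, W/p, p)
    x = x.permute(0, 2, 4, 1, 3, 5)             # (B, H/p, W/p, C, p, p)
    return x.reshape(b, (h // p) * (w // p), c * p * p)
```

With `patch_size=2`, a `(1, 3, 32, 32)` input gives `(1, 256, 12)`. With `patch_size=8`, it gives only 16 tokens.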

In [None]:
print("=" * 50)
print("Generating all 10 classes at 32x32")
print("=" * 50)

# Generate one image per class
all_classes = list(range(10))

images_32 = sample_class_conditional(
    class_labels=all_classes,
    height=32,
    width=32,
    seed=42,
    num_steps=50,
    guidance=2.0,
)

print(f"Output shape: {images_32.shape}")
show_images(images_32, labels=CIFAR10_CLASSES, title="Multi-Scale NerfEmbedder: CIFAR-10 at 32x32")
save_grid(images_32, "multiscale_cifar10_32x32_all_classes.png", cols=5)

## Generate at 4x Super-Resolution (128x128)

The Multi-Scale NerfEmbedder enables smooth super-resolution because:
1. **Dense position sampling** (16×16=256) regardless of patch size
2. **Multi-octave Fourier features** for robust interpolation
3. **Cross-scale communication** (if using attention fusion)
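A neural field maps continuous coordinates to values. Rendering at a higher resolution therefore means evaluating the same field on a denser grid of coordinates that covers the same extent. The sketch below builds such a grid for any output size. It assumes pixel-centre coordinates normalised to [0, 1]; PixNerd's own coordinate convention may differ. `pixel_center_grid` is hypothetical.

```python
import torch


def pixel_center_grid(height: int, width: int) -> torch.Tensor:
    """Return (height * width, 2) pixel-centre coordinates (y, x) in [0, 1]."""
    ys = (torch.arange(height, dtype=torch.float32) + 0.5) / height
    xs = (torch.arange(width, dtype=torch.float32) + 0.5) / width
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    # Raster order: row by row, left to right
    return torch.stack([grid_y, grid_x], dim=-1).reshape(-1, 2)
```

At 2x2, the grid's four points sit at 0.25 and 0.75 along each axis. At 8x8, the same square is covered by 64 points, so a smooth field yields a smooth, denser rendering.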

In [None]:
print("=" * 50)
print("Generating all 10 classes at 128x128 (4x super-resolution)")
print("=" * 50)

images_128 = sample_class_conditional(
    class_labels=all_classes,
    height=128,
    width=128,
    seed=42,  # Same seed, but the noise shape differs, so content will not match the 32x32 samples
    num_steps=50,
    guidance=2.0,
)

print(f"Output shape: {images_128.shape}")
show_images(images_128, labels=CIFAR10_CLASSES, title="Multi-Scale NerfEmbedder: CIFAR-10 at 128x128 (4x Super-Res)")
save_grid(images_128, "multiscale_cifar10_128x128_4x_superres.png", cols=5)

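## Side-by-Side Comparison: 32x32 vs 128x128

The same seed does not give the same image at both resolutions. The initial noise is drawn at the target size, so the 128x128 sample starts from different noise than the 32x32 sample. The bilinear column upscales the 32x32 sample itself; the NF column is an independent 128x128 sample of the same class.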
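
Visual comparison can be backed by a rough number. Variance of the Laplacian is a common sharpness proxy: blurrier images have less high-frequency energy. The helper below is a hypothetical sketch. It is not a perceptual quality metric, and because the two 128x128 columns show different content, treat the scores only as a coarse indicator.

```python
import torch
import torch.nn.functional as F


def laplacian_variance(images_uint8: torch.Tensor) -> torch.Tensor:
    """Per-image variance of the Laplacian of a (B, 3, H, W) uint8 batch.

    Higher values mean more high-frequency content (a rough sharpness proxy).
    """
    x = images_uint8.float() / 255.0                 # scale to [0, 1]
    gray = x.mean(dim=1, keepdim=True)               # (B, 1, H, W) channel average
    kernel = torch.tensor([[0.0, 1.0, 0.0],
                           [1.0, -4.0, 1.0],
                           [0.0, 1.0, 0.0]], device=gray.device).view(1, 1, 3, 3)
    # Replicate padding avoids spurious edges at the image border
    padded = F.pad(gray, (1, 1, 1, 1), mode="replicate")
    lap = F.conv2d(padded, kernel)                   # (B, 1, H, W)
    return lap.flatten(1).var(dim=1)                 # (B,)
```

After the next cell, compare `laplacian_variance(imgs_32_upscaled)` with `laplacian_variance(imgs_128)`. Both are 128x128, so the scores are on the same pixel scale.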

In [None]:
# Compare specific classes
comparison_classes = ['cat', 'dog', 'airplane', 'ship']
comparison_labels = [CIFAR10_CLASSES.index(c) for c in comparison_classes]

print("Comparing 32x32 vs Bilinear 128x128 vs Multi-Scale NF 128x128")

# Generate at 32x32
imgs_32 = sample_class_conditional(
    class_labels=comparison_labels,
    height=32, width=32,
    seed=123,
)

# Generate at 128x128 (4x)
imgs_128 = sample_class_conditional(
    class_labels=comparison_labels,
    height=128, width=128,
    seed=123,
)

# Bilinear upscale 32->128
imgs_32_upscaled = F.interpolate(
    imgs_32.float() / 255.0,
    size=(128, 128),
    mode='bilinear',
    align_corners=False,
)
imgs_32_upscaled = (imgs_32_upscaled * 255).round().clamp(0, 255).to(torch.uint8)  # round instead of truncating

# Plot comparison
fig, axes = plt.subplots(len(comparison_classes), 3, figsize=(12, 4 * len(comparison_classes)))

for i, class_name in enumerate(comparison_classes):
    # 32x32 native
    axes[i, 0].imshow(imgs_32[i].permute(1, 2, 0).numpy())
    axes[i, 0].set_title(f"{class_name} - 32x32 Native")
    axes[i, 0].axis('off')
    
    # Bilinear upscale
    axes[i, 1].imshow(imgs_32_upscaled[i].permute(1, 2, 0).numpy())
    axes[i, 1].set_title(f"{class_name} - Bilinear 128x128")
    axes[i, 1].axis('off')
    
    # Multi-scale NF super-res
    axes[i, 2].imshow(imgs_128[i].permute(1, 2, 0).numpy())
    axes[i, 2].set_title(f"{class_name} - Multi-Scale NF 128x128")
    axes[i, 2].axis('off')

plt.suptitle("32x32 Native vs Bilinear Upscale vs Multi-Scale NF Super-Resolution", fontsize=14)
plt.tight_layout()
plt.show()

## Try Different Super-Resolution Scales

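The Multi-Scale NerfEmbedder is expected to produce smooth results at each of these scales. The noise is drawn at the target resolution, so each scale yields a different airplane; compare detail and smoothness rather than content.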
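`save_grid` expects a batch of equal-sized images, so the mixed-resolution `scale_images` list cannot be passed to it directly. The sketch below brings every image to one size with nearest-neighbour resizing, which keeps each source pixel as a solid block instead of smoothing it. `to_common_size` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def to_common_size(images, size: int) -> torch.Tensor:
    """Resize a list of (3, H, W) uint8 images to (3, size, size) with nearest
    neighbour and stack them into one (N, 3, size, size) uint8 batch."""
    resized = []
    for img in images:
        x = img.unsqueeze(0).float()                                # (1, 3, H, W)
        x = F.interpolate(x, size=(size, size), mode="nearest")
        resized.append(x.round().clamp(0, 255).to(torch.uint8))
    return torch.cat(resized, dim=0)
```

For example, `save_grid(to_common_size(scale_images, 256), "multiscale_airplane_scales.png", cols=4)` would write the comparison to disk (the filename is only an example).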

In [None]:
# Compare different scales: 1x, 2x, 4x, 8x
print("=" * 50)
print("Comparing different super-resolution scales")
print("=" * 50)

scales = [
    (32, "1x (Native)"),
    (64, "2x"),
    (128, "4x"),
    (256, "8x"),
]

class_label = 'airplane'
seed = 42

scale_images = []
scale_titles = []

for resolution, scale_name in scales:
    print(f"Generating at {resolution}x{resolution}...")
    img = sample_class_conditional(
        class_labels=[class_label],
        height=resolution, width=resolution,
        seed=seed,
    )
    scale_images.append(img[0])
    scale_titles.append(f"{resolution}x{resolution} ({scale_name})")

# Display at same visual size
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for i, (img, title) in enumerate(zip(scale_images, scale_titles)):
    axes[i].imshow(img.permute(1, 2, 0).numpy())
    axes[i].set_title(title)
    axes[i].axis('off')

plt.suptitle(f"{class_label.capitalize()} at Different Resolutions (Multi-Scale NerfEmbedder)", fontsize=14)
plt.tight_layout()
plt.show()

## Extreme Super-Resolution Test

Let's test the limits of the multi-scale architecture with even higher resolutions.
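If 512x512 does not fit in GPU memory, one option is to step down through a list of resolutions until one succeeds. The sketch below is hypothetical and uses only the standard library. In PyTorch 1.13 and later, `torch.cuda.OutOfMemoryError` is a subclass of `RuntimeError`, so the default `retry_on` catches it. That default also catches unrelated runtime errors such as shape mismatches, so narrowing it to `(torch.cuda.OutOfMemoryError,)` is safer in practice.

```python
def generate_with_fallback(generate, resolutions, retry_on=(RuntimeError,)):
    """Call generate(res) for each resolution in order.

    Returns (res, output) for the first call that succeeds; raises
    RuntimeError if every resolution fails with one of `retry_on`.
    """
    last_error = None
    for res in resolutions:
        try:
            return res, generate(res)
        except retry_on as err:
            print(f"{res}x{res} failed: {err}")
            last_error = err
    raise RuntimeError("all resolutions failed") from last_error
```

A possible call is `generate_with_fallback(lambda r: sample_class_conditional(['cat'], height=r, width=r, seed=42), [512, 384, 256])`. Calling `torch.cuda.empty_cache()` between attempts may also help.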

In [None]:
# Try extreme super-resolution: 16x (512x512)
print("=" * 50)
print("Extreme Super-Resolution Test: 16x (512x512)")
print("=" * 50)

try:
    img_512 = sample_class_conditional(
        class_labels=['cat'],
        height=512, width=512,
        seed=42,
        num_steps=50,
    )
    
    plt.figure(figsize=(10, 10))
    plt.imshow(img_512[0].permute(1, 2, 0).numpy())
    plt.title("Cat at 512x512 (16x Super-Resolution)")
    plt.axis('off')
    plt.show()
    
    save_grid(img_512, "multiscale_cat_512x512_16x_superres.png")
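except torch.cuda.OutOfMemoryError as e:
    # Release cached blocks so later cells can still run
    torch.cuda.empty_cache()
    print(f"512x512 generation ran out of GPU memory: {e}")
    print("Try a smaller resolution or free GPU memory first.")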

## Guidance Scale Comparison
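
Standard classifier-free guidance (CFG) combines the conditional and unconditional predictions as `uncond + g * (cond - uncond)`. With `g = 1`, the result is simply the conditional prediction; larger `g` extrapolates further toward the class. The sketch below shows that textbook rule in NumPy. Whether `simple_guidance_fn` uses exactly this form, and how `guidance_interval_min/max` restrict it, are assumptions not confirmed here.

```python
import numpy as np


def cfg_combine(cond_pred, uncond_pred, guidance: float) -> np.ndarray:
    """Textbook classifier-free guidance: start from the unconditional
    prediction and move `guidance` times the (cond - uncond) difference."""
    cond_pred = np.asarray(cond_pred, dtype=np.float64)
    uncond_pred = np.asarray(uncond_pred, dtype=np.float64)
    return uncond_pred + guidance * (cond_pred - uncond_pred)
```

Under this rule, the `guidance=1.0` panel below corresponds to no extra guidance beyond conditioning.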

In [None]:
# Compare different guidance scales
print("=" * 50)
print("Comparing different CFG guidance scales at 128x128")
print("=" * 50)

guidance_scales = [1.0, 1.5, 2.0, 3.0, 5.0]
class_label = 'horse'
seed = 42

guidance_images = []
for g in guidance_scales:
    print(f"Guidance = {g}...")
    img = sample_class_conditional(
        class_labels=[class_label],
        height=128, width=128,
        seed=seed,
        guidance=g,
    )
    guidance_images.append(img[0])

fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for i, (img, g) in enumerate(zip(guidance_images, guidance_scales)):
    axes[i].imshow(img.permute(1, 2, 0).numpy())
    axes[i].set_title(f"guidance={g}")
    axes[i].axis('off')

plt.suptitle(f"{class_label.capitalize()} with Different CFG Scales (Multi-Scale NerfEmbedder)", fontsize=14)
plt.tight_layout()
plt.show()

## Generate Variety of Samples
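
The cell below calls the sampler once per seed. `sample_class_conditional` seeds the global RNG and draws its noise internally, so batching requires a variant that accepts precomputed noise; such a variant is hypothetical and not defined in this notebook. The sketch below builds such a batch with one `torch.Generator` per seed, so each sample is reproducible from its own seed regardless of batch order. Its noise is not guaranteed to match what `sample_class_conditional` draws for the same seed.

```python
import torch


def per_seed_noise(seeds, shape=(3, 128, 128), device="cpu") -> torch.Tensor:
    """Stack one standard-normal noise tensor per seed into a (len(seeds), *shape) batch."""
    noise = []
    for seed in seeds:
        gen = torch.Generator(device=device).manual_seed(seed)
        noise.append(torch.randn(shape, generator=gen, device=device))
    return torch.stack(noise, dim=0)
```

With labels repeated `len(seeds)` times, a batch like this could go through a single sampler call instead of five.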

In [None]:
# Generate 5 different samples for each class at 128x128
print("=" * 50)
print("Generating 5 samples per class at 128x128")
print("=" * 50)

selected_classes = ['cat', 'dog', 'airplane']
seeds = [0, 42, 123, 456, 789]

for class_name in selected_classes:
    print(f"\nGenerating {class_name}s...")
    class_images = []
    for seed in seeds:
        img = sample_class_conditional(
            class_labels=[class_name],
            height=128, width=128,
            seed=seed,
        )
        class_images.append(img)
    
    class_images = torch.cat(class_images, dim=0)
    show_images(class_images, labels=[class_name] * len(seeds), title=f"{len(seeds)} {class_name.capitalize()}s at 128x128")
    save_grid(class_images, f"multiscale_{class_name}s_128x128.png", cols=5)

## Summary

This notebook demonstrated the **Multi-Scale NerfEmbedder** architecture:

### Key Innovation: Decoupled Parameters

| Parameter | Controls | Value | Effect |
|-----------|----------|-------|--------|
| `patch_size` | Encoder tokens | 2 | 256 tokens → excellent global coherence |
| `dense_samples` | NF positions | 16 | 256 positions → smooth super-resolution |

**These are now INDEPENDENT!**

### Multi-Scale Position Encoding
- **Global octave**: Low frequencies → coarse image structure
- **Region octave**: Mid frequencies → regional patterns
- **Local octave**: High frequencies → fine details
- **Scale fusion**: Combines information across scales
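The octave idea can be sketched as sin/cos features computed in several frequency bands and then concatenated, as in the `concat` fusion option. This is an illustration only. The band frequencies below are made-up "global", "region", and "local" values, not PixNerd's. The real model may also fuse bands with learned layers (`add` or `attention`) rather than plain concatenation.

```python
import math
import torch


def multi_octave_fourier(coords: torch.Tensor, octaves=((1, 2), (4, 8), (16, 32))) -> torch.Tensor:
    """Concatenate sin/cos features from several frequency bands.

    coords: (N, 2) positions in [0, 1].
    octaves: one tuple of frequencies per band (illustrative values only).
    Returns (N, 4 * total number of frequencies).
    """
    feats = []
    for band in octaves:
        freqs = torch.tensor(band, dtype=coords.dtype, device=coords.device)
        # (N, 2, F): each coordinate multiplied by each frequency in the band
        angles = 2 * math.pi * coords.unsqueeze(-1) * freqs
        band_feat = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (N, 2, 2F)
        feats.append(band_feat.flatten(1))                                      # (N, 4F)
    # "concat" fusion: keep every band's features side by side
    return torch.cat(feats, dim=-1)
```

Low-frequency bands vary slowly across the image, and high-frequency bands distinguish nearby positions. Together, the bands give a decoder both coarse and fine positional signals at any sampling density.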

### Comparison with Original Architecture

| Architecture | Encoder Tokens | NF Positions | Global | Super-Res |
|--------------|----------------|--------------|--------|----------|
| Original (ps=2) | 256 | 4 | ✅ | ❌ |
| Original (ps=8) | 16 | 64 | ⚠️ | ✅ |
| **Multi-Scale** | **256** | **256** | **✅** | **✅** |

### Training Command
```bash
python train_cifar10_multiscale.py \
    --patch_size 2 \
    --dense_samples 16 \
    --nerf_fusion concat \
    --max_steps 100000
```

In [None]:
print("Done!")
print(f"\nAll outputs saved to: {OUTPUT_DIR}")