# CIFAR-10 Extended Boundaries Inference with Overlap Supervision

This notebook demonstrates inference with `PixNerDiTExtended`, a model trained with **overlap supervision**.

## Key Innovation: Overlap Supervision

Unlike simple coordinate rescaling, this model:
1. **Predicts extended regions**: each patch predicts pixels beyond its own boundary, into its neighbors' territory
2. **Uses real pixels**: the extended predictions are supervised against the actual pixel values of the neighboring region
3. **Enforces consistency**: where regions overlap, adjacent patches are trained to predict the same values
4. **Blends during inference**: overlapping predictions are averaged to produce a smooth output

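The supervision targets in points 1 and 2 can be pictured with a small NumPy sketch. It cuts the extended window around every core patch out of a ground-truth image and returns a mask marking which target pixels actually exist; pixels beyond the image border are zero-padded and masked out. The function name `extended_targets`, the zero padding, and the masking are assumptions for illustration, not the training script's actual code.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def extended_targets(image, patch_size, margin_pixels):
    """Cut the extended window around every core patch of an H×W×C image.

    Returns (targets, valid):
      targets: [H/p, W/p, E, E, C] ground-truth pixels, E = p + 2*margin
      valid:   [H/p, W/p, E, E] True where the pixel lies inside the image
    """
    h, w, c = image.shape
    m = margin_pixels
    ext = patch_size + 2 * m
    # Zero-pad so border patches also get full-size windows (assumed convention)
    padded = np.pad(image, ((m, m), (m, m), (0, 0)))
    inside = np.pad(np.ones((h, w), dtype=bool), m)  # False on the padding

    # All ext×ext windows, then keep one every patch_size pixels (one per core patch)
    win = sliding_window_view(padded, (ext, ext), axis=(0, 1))[::patch_size, ::patch_size]
    msk = sliding_window_view(inside, (ext, ext))[::patch_size, ::patch_size]
    # sliding_window_view appends window axes last: [nh, nw, C, E, E] -> [nh, nw, E, E, C]
    return win.transpose(0, 1, 3, 4, 2), msk

rng = np.random.default_rng(0)
img = rng.random((32, 32, 3))
targets, valid = extended_targets(img, patch_size=2, margin_pixels=1)
print(targets.shape, valid.shape)  # (16, 16, 4, 4, 3) (16, 16, 4, 4)
```

The centre 2×2 of each window is the patch's own core; the surrounding ring holds real pixels from its neighbors, which is what "supervised against actual neighbor pixel values" refers to.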
### Example: patch_size=2, margin=0.5
- `margin_pixels = 1`, `extended_size = 4`
- Each 2×2 core patch predicts a 4×4 region
- The 12 extra pixels (a 1-pixel ring around the core) fall inside neighboring patches
- During inference, overlapping 4×4 predictions are blended

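The blending step can be sketched as overlap-add averaging: each extended prediction is added onto a canvas at its window position, a counter tracks how many windows cover each pixel, and the sum is divided by the count. This is a minimal NumPy sketch assuming uniform (unweighted) averaging and that predictions falling outside the image are dropped; `blend_extended_patches` is a hypothetical helper, and the model's own blending may weight or crop differently. With `patch_size=2` and `margin_pixels=1`, interior pixels are covered by 4 windows (2 along each axis), edge pixels by 2, and corner pixels by 1.

```python
import numpy as np

def blend_extended_patches(preds, patch_size, margin_pixels, height, width):
    """Average overlapping extended-patch predictions into one image.

    preds: [H/p, W/p, E, E, C] per-patch predictions, E = p + 2*margin
    Returns the blended [H, W, C] image and the per-pixel coverage count.
    """
    m, p = margin_pixels, patch_size
    ext = p + 2 * m
    c = preds.shape[-1]
    # Accumulate on a canvas padded by the margin so border windows fit
    acc = np.zeros((height + 2 * m, width + 2 * m, c))
    cnt = np.zeros((height + 2 * m, width + 2 * m, 1))
    for i in range(preds.shape[0]):
        for j in range(preds.shape[1]):
            # On the padded canvas, window (i, j) starts at offset (i*p, j*p)
            acc[i * p:i * p + ext, j * p:j * p + ext] += preds[i, j]
            cnt[i * p:i * p + ext, j * p:j * p + ext] += 1
    # Crop the padding: predictions outside the image are discarded
    acc, cnt = acc[m:m + height, m:m + width], cnt[m:m + height, m:m + width]
    return acc / cnt, cnt[..., 0]

rng = np.random.default_rng(0)
img = rng.random((32, 32, 3))
padded = np.pad(img, ((1, 1), (1, 1), (0, 0)))
# Perfectly consistent predictions: every window is cut from the same image
preds = np.stack([np.stack([padded[2 * i:2 * i + 4, 2 * j:2 * j + 4] for j in range(16)])
                  for i in range(16)])
blended, coverage = blend_extended_patches(preds, 2, 1, 32, 32)
print(coverage[0, 0], coverage[0, 5], coverage[10, 10])  # 1.0 2.0 4.0
```

When neighboring predictions agree, as training is meant to encourage, averaging returns the same image; when they disagree, the average smooths the disagreement instead of leaving a hard seam.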
In [None]:
import os
import sys
from pathlib import Path

# Setup paths
NOTEBOOK_DIR = Path.cwd()
PIXNERD_DIR = NOTEBOOK_DIR / "PixNerd"

if PIXNERD_DIR.exists():
    os.chdir(PIXNERD_DIR)
    sys.path.insert(0, str(PIXNERD_DIR))
    print(f"Working directory: {os.getcwd()}")
else:
    print(f"ERROR: PixNerd directory not found at {PIXNERD_DIR}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Disable TorchDynamo so any torch.compile'd code falls back to eager execution
torch._dynamo.config.disable = True

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Model Configuration

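These values must match the training configuration in `train_cifar10_extended.py`. Architecture mismatches surface as shape errors or as missing/unexpected keys when the checkpoint is loaded.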

In [None]:
# =============================================================================
# MODEL CONFIGURATION - Must match training!
# =============================================================================

# Checkpoint path (v2 = fresh training with overlap supervision)
CHECKPOINT_PATH = str(NOTEBOOK_DIR / "workdirs/exp_cifar10_extended_v2/checkpoints/last.ckpt")

# Model architecture (~10M params)
PATCH_SIZE = 2
HIDDEN_SIZE = 256
DECODER_HIDDEN_SIZE = 32
NUM_ENCODER_BLOCKS = 6
NUM_DECODER_BLOCKS = 2
NUM_GROUPS = 4
NUM_CLASSES = 10

# Extended boundary config - KEY PARAMETERS
MARGIN = 0.5  # 0.5 = 1 pixel margin for patch_size=2

# Compute extended sizes
MARGIN_PIXELS = max(1, int(round(PATCH_SIZE * MARGIN)))
EXTENDED_SIZE = PATCH_SIZE + 2 * MARGIN_PIXELS
EFFECTIVE_MARGIN = MARGIN_PIXELS / PATCH_SIZE

print("Extended Boundary Configuration:")
print(f"  patch_size: {PATCH_SIZE}")
print(f"  margin: {MARGIN} → margin_pixels: {MARGIN_PIXELS}")
print(f"  extended_size: {EXTENDED_SIZE} (core {PATCH_SIZE} + {MARGIN_PIXELS} on each side)")
print(f"  effective_margin: {EFFECTIVE_MARGIN:.2f}")
print()
print("How overlap works:")
print(f"  • Each {PATCH_SIZE}×{PATCH_SIZE} patch predicts {EXTENDED_SIZE}×{EXTENDED_SIZE} pixels")
print("  • Overlapping predictions are blended during inference")
print("  • Consistency enforced during training → smooth boundaries")

# Sampling config
GUIDANCE_SCALE = 2.0
NUM_SAMPLING_STEPS = 50

# Device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

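The margin arithmetic in the configuration cell can be checked for other settings with a small pure-Python sketch that reuses the same three formulas (`max(1, round(p * margin))`, `p + 2 * margin_pixels`, `margin_pixels / p`). Two consequences are worth knowing: the margin is always at least one pixel, and Python's `round` rounds halves to the nearest even integer, so the effective margin can differ from the requested one. Whether `PixNerDiTExtended` computes the margin identically internally is an assumption here; `margin_geometry` is a hypothetical helper.

```python
def margin_geometry(patch_size, margin):
    """Same formulas as the configuration cell: (margin_pixels, extended_size, effective_margin)."""
    margin_pixels = max(1, int(round(patch_size * margin)))  # round() rounds halves to even
    extended_size = patch_size + 2 * margin_pixels
    return margin_pixels, extended_size, margin_pixels / patch_size

for p, mgn in [(2, 0.25), (2, 0.5), (2, 0.75), (4, 0.625)]:
    mp, ext, eff = margin_geometry(p, mgn)
    print(f"patch={p} margin={mgn}: margin_pixels={mp}, extended={ext}, effective={eff:.3f}")
```

For `patch_size=2`, `margin=0.25` and `margin=0.5` give the same 4×4 geometry (0.5 rounds to 0, then the floor of one pixel applies), and for `patch_size=4`, `margin=0.625` yields 2 pixels because 2.5 rounds down to the even value.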
## Load Model

In [None]:
from src.models.autoencoder.pixel import PixelAE
from src.models.conditioner.class_label import LabelConditioner
from src.models.transformer.pixnerd_c2i_extended import PixNerDiTExtended
from src.diffusion.flow_matching.scheduling import LinearScheduler
from src.diffusion.flow_matching.sampling import EulerSampler, ode_step_fn
from src.diffusion.base.guidance import simple_guidance_fn
from src.lightning_model import LightningModel

def build_model():
    """Build model with extended boundary architecture."""
    
    scheduler = LinearScheduler()
    vae = PixelAE(scale=1.0)
    conditioner = LabelConditioner(num_classes=NUM_CLASSES)
    
    # Extended NerfEmbedder model - no jittering (overlap must be consistent)
    denoiser = PixNerDiTExtended(
        in_channels=3,
        patch_size=PATCH_SIZE,
        num_groups=NUM_GROUPS,
        hidden_size=HIDDEN_SIZE,
        decoder_hidden_size=DECODER_HIDDEN_SIZE,
        num_encoder_blocks=NUM_ENCODER_BLOCKS,
        num_decoder_blocks=NUM_DECODER_BLOCKS,
        num_classes=NUM_CLASSES,
        margin=MARGIN,  # Extended boundaries with overlap supervision
    )
    
    sampler = EulerSampler(
        num_steps=NUM_SAMPLING_STEPS,
        guidance=GUIDANCE_SCALE,
        guidance_interval_min=0.0,
        guidance_interval_max=1.0,
        scheduler=scheduler,
        w_scheduler=LinearScheduler(),
        guidance_fn=simple_guidance_fn,
        step_fn=ode_step_fn,
    )
    
    model = LightningModel(
        vae=vae,
        conditioner=conditioner,
        denoiser=denoiser,
        diffusion_trainer=None,
        diffusion_sampler=sampler,
        ema_tracker=None,
        optimizer=None,
        lr_scheduler=None,
        eval_original_model=False,
    )
    
    return model

# Build and load
print("Building model...")
model = build_model()

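print(f"\nLoading checkpoint: {CHECKPOINT_PATH}")
if os.path.exists(CHECKPOINT_PATH):
    # Lightning checkpoints hold more than tensors, and recent PyTorch releases default
    # to weights_only=True, so opt out explicitly for this trusted local file
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
    # strict=False tolerates non-matching keys, so report them instead of hiding them
    result = model.load_state_dict(checkpoint['state_dict'], strict=False)
    print(f"Checkpoint loaded ({len(result.missing_keys)} missing, "
          f"{len(result.unexpected_keys)} unexpected keys)")
    if result.missing_keys:
        print("  First missing keys:", result.missing_keys[:5])
else:
    print(f"WARNING: Checkpoint not found at {CHECKPOINT_PATH}")
    print("Model will use random weights (for testing notebook structure)")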

model = model.to(DEVICE)
model.eval()

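# Count parameters. The LightningModel also holds `ema_denoiser`, a copy of the
# denoiser, so the overall total includes the denoiser weights twice.
denoiser_params = sum(p.numel() for p in model.denoiser.parameters())
total_params = sum(p.numel() for p in model.parameters())
print(f"\nDenoiser parameters: {denoiser_params:,}")
print(f"Total parameters (including EMA copy): {total_params:,}")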

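`EulerSampler` is configured above with a guidance scale, a guidance interval, `simple_guidance_fn`, and `ode_step_fn`; its internals are not shown in this notebook. As a rough mental model only, here is a NumPy sketch of Euler integration of a flow-matching ODE with classifier-free guidance. It assumes that t runs from 0 (noise) to 1 (data), that guidance combines predictions as `uncond + scale * (cond - uncond)`, and that the scale applies only while t lies inside the guidance interval; these conventions are assumptions, not read from the PixNerd source. The toy velocity functions point straight at a goal, a case in which Euler steps are exact.

```python
import numpy as np

def cfg_combine(v_cond, v_uncond, scale):
    # Classifier-free guidance: extrapolate away from the unconditional prediction
    return v_uncond + scale * (v_cond - v_uncond)

def euler_cfg_sample(velocity_fn, noise, num_steps, scale, interval=(0.0, 1.0)):
    """Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (data) with Euler steps."""
    x = noise
    ts = np.linspace(0.0, 1.0, num_steps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        v_cond = velocity_fn(x, t, conditional=True)
        v_uncond = velocity_fn(x, t, conditional=False)
        # Outside the guidance interval, use the plain conditional prediction
        s = scale if interval[0] <= t <= interval[1] else 1.0
        x = x + (t_next - t) * cfg_combine(v_cond, v_uncond, s)
    return x

# Toy model: the conditional velocity points at `target`, the unconditional one at 0
target = np.array([1.0, -2.0, 0.5])

def toy_velocity(x, t, conditional):
    goal = target if conditional else np.zeros_like(target)
    return (goal - x) / (1.0 - t)  # straight-line velocity toward `goal`

noise = np.random.default_rng(0).standard_normal(3)
print(euler_cfg_sample(toy_velocity, noise, num_steps=50, scale=1.0))  # close to target
print(euler_cfg_sample(toy_velocity, noise, num_steps=50, scale=2.0))  # close to 2 * target
```

The second call shows why large guidance scales exaggerate class features: the guided velocity overshoots the conditional prediction in the direction away from the unconditional one.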
## CIFAR-10 Class Labels

In [None]:
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

print("CIFAR-10 Classes:")
for i, name in enumerate(CIFAR10_CLASSES):
    print(f"  {i}: {name}")

## Generation Function with Super-Resolution

In [None]:
@torch.no_grad()
def generate_images(
    model,
    class_labels,
    output_size=32,
    guidance_scale=None,
    num_steps=None,
    device=DEVICE,
):
    """
    Generate images, optionally at a higher resolution than the training data.
    
    The extended boundary model supports output sizes other than 32 because
    its decoder predicts pixel values at continuous positions within each patch.
    
    Args:
        model: LightningModel wrapping the PixNerDiTExtended denoiser
        class_labels: class index, or list of class indices in [0, 9]
        output_size: output resolution (32 = native, larger = super-resolution)
        guidance_scale: CFG scale for this call only (None = keep the sampler's value)
        num_steps: sampling steps for this call only (None = keep the sampler's value)
        device: device on which the initial noise is created
    
    Returns:
        images: numpy array [N, H, W, 3] with values in [0, 1]
    """
    model.eval()
    
    # Accept a single label as well as a list
    if isinstance(class_labels, int):
        class_labels = [class_labels]
    
    batch_size = len(class_labels)
    labels = class_labels  # keep as a list; the conditioner handles conversion
    
    # Super-resolution scale relative to CIFAR-10's native 32×32
    base_size = 32
    scale_factor = output_size / base_size
    
    sampler = model.diffusion_sampler
    denoiser = model.ema_denoiser
    # Remember sampler settings so per-call overrides do not leak into later calls
    orig_guidance, orig_num_steps = sampler.guidance, sampler.num_steps
    
    try:
        # Set decoder scaling for super-resolution
        denoiser.decoder_patch_scaling_h = scale_factor
        denoiser.decoder_patch_scaling_w = scale_factor
        
        if guidance_scale is not None:
            sampler.guidance = guidance_scale
        if num_steps is not None:
            sampler.num_steps = num_steps
        
        # The conditioner returns BOTH the condition and the uncondition
        condition, uncondition = model.conditioner(labels)
        
        # Initial noise at the target resolution (the sampler expects a tensor, not a shape)
        noise = torch.randn(batch_size, 3, output_size, output_size, device=device)
        
        print(f"Generating {batch_size} images at {output_size}×{output_size} (scale: {scale_factor}x)...")
        
        samples = sampler(denoiser, noise, condition, uncondition)
        
        # Decode (identity for pixel space)
        images = model.vae.decode(samples)
    finally:
        # Always restore native decoder scaling and the original sampler settings,
        # even if sampling raised an exception
        denoiser.decoder_patch_scaling_h = 1.0
        denoiser.decoder_patch_scaling_w = 1.0
        sampler.guidance, sampler.num_steps = orig_guidance, orig_num_steps
    
    # Convert to numpy in [0, 1], channels last: [N, H, W, C].
    # This assumes decoded samples are already in [0, 1]; if the training data was
    # normalized to [-1, 1], convert with (images + 1) / 2 before clamping.
    images = images.clamp(0, 1).cpu().numpy()
    images = images.transpose(0, 2, 3, 1)
    
    return images

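The docstring's claim that the decoder works at continuous positions can be made concrete with a sketch of where one patch's samples would sit. It assumes, hypothetically, that each patch's extended region is parameterized in core-patch units, with the core spanning [0, 1] and the margin extending `effective_margin` beyond it on each side, and that a super-resolution scale `s` simply samples that region `s` times more densely. The helper `extended_sample_positions` is illustrative only; the actual coordinate normalization inside `PixNerDiTExtended` is not shown in this notebook.

```python
import numpy as np

def extended_sample_positions(patch_size, margin_pixels, scale=1):
    """1D sample-cell centres, in core-patch units, across one patch's extended region."""
    n = (patch_size + 2 * margin_pixels) * scale  # samples along one axis
    lo = -margin_pixels / patch_size              # left edge = -effective_margin
    width = (patch_size + 2 * margin_pixels) / patch_size
    return lo + (np.arange(n) + 0.5) * width / n

for s in (1, 4):
    pos = extended_sample_positions(2, 1, scale=s)
    # The next patch's samples are this patch's shifted by one core-patch width
    shared = np.intersect1d(np.round(pos, 6), np.round(pos + 1.0, 6))
    print(f"scale={s}: {len(pos)} samples, {len(shared)} shared with the next patch")
```

Under this assumption the sample spacing is `1 / (patch_size * scale)`, so the shift between neighbors is a whole number of samples at every integer scale and overlapping predictions land on identical positions, which is what makes averaging them well-defined.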
## Generate Native Resolution (32×32)

In [None]:
# Generate one image per class at native resolution
class_labels = list(range(10))

images_32 = generate_images(model, class_labels, output_size=32)

# Display
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, (ax, img) in enumerate(zip(axes.flat, images_32)):
    ax.imshow(img)
    ax.set_title(f"{i}: {CIFAR10_CLASSES[i]}")
    ax.axis('off')

plt.suptitle("Native Resolution (32×32)", fontsize=14)
plt.tight_layout()
plt.show()

## Super-Resolution Comparison

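Compare the extended boundary model's super-resolution output with bilinear upscaling of the native 32×32 sample.

The overlap supervision should produce smoother boundaries between patches. Note that `generate_images` draws fresh noise on every call, so at 2× and 4× the model output is a different sample from the native image; the difference column therefore mixes content changes with upsampling differences and is only a qualitative guide.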

In [None]:
def compare_superres(model, class_idx, scales=(1, 2, 4)):
    """
    Compare super-resolution at different scales.
    
    Shows, for each scale:
    - Model output at that resolution (a new sample: fresh noise is drawn per call)
    - Bilinear upscaling of the native-resolution sample for comparison
    - Their absolute difference, amplified 5x
    """
    # squeeze=False keeps `axes` 2D even when only one scale is requested
    fig, axes = plt.subplots(len(scales), 3, figsize=(12, 4 * len(scales)), squeeze=False)
    
    # Generate native resolution first
    img_native = generate_images(model, [class_idx], output_size=32)[0]
    
    for row, scale in enumerate(scales):
        output_size = 32 * scale
        
        # Model super-resolution
        if scale == 1:
            img_model = img_native
        else:
            img_model = generate_images(model, [class_idx], output_size=output_size)[0]
        
        # Bilinear upscaling of the native sample (rounded to the nearest uint8 level)
        img_bilinear = np.array(Image.fromarray(
            np.round(img_native * 255).astype(np.uint8)
        ).resize((output_size, output_size), Image.BILINEAR)) / 255.0
        
        # Display
        axes[row, 0].imshow(img_model)
        axes[row, 0].set_title(f"Extended Model ({output_size}×{output_size})")
        axes[row, 0].axis('off')
        
        axes[row, 1].imshow(img_bilinear)
        axes[row, 1].set_title(f"Bilinear ({output_size}×{output_size})")
        axes[row, 1].axis('off')
        
        # Difference (amplified for visibility)
        if scale > 1:
            diff = np.abs(img_model - img_bilinear)
            diff_amplified = np.clip(diff * 5, 0, 1)
            axes[row, 2].imshow(diff_amplified)
            axes[row, 2].set_title("Difference (5x amplified)")
        else:
            axes[row, 2].text(0.5, 0.5, "N/A", ha='center', va='center', transform=axes[row, 2].transAxes)
            axes[row, 2].set_title("Difference")
        axes[row, 2].axis('off')
    
    plt.suptitle(f"Super-Resolution: {CIFAR10_CLASSES[class_idx]}", fontsize=14)
    plt.tight_layout()
    plt.show()

# Compare for a few classes
for class_idx in [3, 5, 8]:  # cat, dog, ship
    compare_superres(model, class_idx, scales=[1, 2, 4])

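Each call to `generate_images` draws new noise, so in `compare_superres` the model panel and the bilinear panel at 2× and 4× come from different samples. A same-sample baseline can be built from the model's own high-resolution output: block-average it down to 32×32, then upscale that linearly. The sketch below is an assumption-level alternative, not part of the original notebook; `block_downsample`, `resize_linear`, and `same_sample_baseline` are hypothetical helpers, and `resize_linear` (pixel-centre alignment, clamped edges) is close to, but not guaranteed identical to, PIL's `Image.BILINEAR`.

```python
import numpy as np

def block_downsample(img, factor):
    """Average non-overlapping factor×factor blocks of an [H, W, C] image."""
    h, w, c = img.shape
    return img.reshape(h // factor, factor, w // factor, factor, c).mean(axis=(1, 3))

def resize_linear(img, out_h, out_w):
    """Separable linear interpolation with pixel-centre alignment; edges are clamped."""
    in_h, in_w = img.shape[:2]
    # Map output pixel centres back into input pixel coordinates
    ys = (np.arange(out_h) + 0.5) * in_h / out_h - 0.5
    xs = (np.arange(out_w) + 0.5) * in_w / out_w - 0.5
    rows = np.apply_along_axis(lambda col: np.interp(ys, np.arange(in_h), col), 0, img)
    return np.apply_along_axis(lambda row: np.interp(xs, np.arange(in_w), row), 1, rows)

def same_sample_baseline(img_hr, native=32):
    """Downsample a model output to the native size, then upscale it back linearly."""
    factor = img_hr.shape[0] // native
    low = block_downsample(img_hr, factor)
    return resize_linear(low, img_hr.shape[0], img_hr.shape[1])

# With a real sample: img_hr = generate_images(model, [3], output_size=128)[0]
img_hr = np.random.default_rng(0).random((128, 128, 3))  # stand-in for a model output
baseline = same_sample_baseline(img_hr)
diff = np.abs(img_hr - baseline)  # detail the model adds beyond linear upscaling
print(baseline.shape)  # (128, 128, 3)
```

Because both images now share the same content, `diff` isolates high-frequency detail rather than differences between two unrelated samples.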
## Boundary Analysis

Analyze whether the overlap supervision produces smoother patch boundaries.

We can visualize potential boundary artifacts by looking at the gradient magnitude.

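The band-based ratio computed in `analyze_boundaries` pools the combined gradient magnitude over whole rows and columns near boundaries. A finer-grained alternative, sketched here as an assumption rather than the notebook's metric, compares each axis separately: horizontal differences that straddle a vertical patch boundary, and vertical differences that straddle a horizontal one, against all other differences. It can be calibrated on synthetic images first: a smooth ramp should give a ratio of about 1, and a ramp with a constant offset added per patch should give a ratio well above 1. `boundary_gradient_ratio` is a hypothetical helper.

```python
import numpy as np

def boundary_gradient_ratio(gray, patch):
    """Mean |gradient| across patch boundaries divided by the mean elsewhere.

    gray: [H, W] image; patch: patch size in output pixels (>= 2).
    """
    gx = np.abs(np.diff(gray, axis=1))  # gx[:, j] compares columns j and j+1
    gy = np.abs(np.diff(gray, axis=0))  # gy[i, :] compares rows i and i+1
    # A difference straddles a boundary when j + 1 is a multiple of the patch size
    col_b = (np.arange(gx.shape[1]) + 1) % patch == 0
    row_b = (np.arange(gy.shape[0]) + 1) % patch == 0
    at = np.concatenate([gx[:, col_b].ravel(), gy[row_b, :].ravel()])
    away = np.concatenate([gx[:, ~col_b].ravel(), gy[~row_b, :].ravel()])
    return at.mean() / away.mean()

size, patch = 64, 8
ramp = np.tile(np.linspace(0.0, 1.0, size), (size, 1))          # smooth image
offsets = np.random.default_rng(0).random((size // patch, size // patch))
blocky = ramp + 0.5 * np.kron(offsets, np.ones((patch, patch)))  # per-patch jumps
print(f"smooth: {boundary_gradient_ratio(ramp, patch):.2f}")    # 1.00
print(f"blocky: {boundary_gradient_ratio(blocky, patch):.2f}")
```

For a generated sample, the call would be `boundary_gradient_ratio(np.mean(img, axis=2), PATCH_SIZE * (output_size // 32))`, using the same scaled patch size that `analyze_boundaries` assumes.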
In [None]:
def analyze_boundaries(model, class_idx, output_size=128):
    """
    Analyze boundary smoothness using gradient analysis.
    
    If overlap supervision works well, we should NOT see grid patterns
    in the gradient magnitude at patch boundaries.
    """
    # Generate high-res image
    img = generate_images(model, [class_idx], output_size=output_size)[0]
    
    # Convert to grayscale for gradient analysis
    gray = np.mean(img, axis=2)
    
    # Compute gradients
    grad_x = np.abs(np.diff(gray, axis=1))
    grad_y = np.abs(np.diff(gray, axis=0))
    
    # Gradient magnitude (pad to same size)
    grad_x_pad = np.pad(grad_x, ((0, 0), (0, 1)), mode='edge')
    grad_y_pad = np.pad(grad_y, ((0, 1), (0, 0)), mode='edge')
    grad_mag = np.sqrt(grad_x_pad**2 + grad_y_pad**2)
    
    # Expected patch boundaries (based on patch_size=2, scaled up)
    scale = output_size // 32
    patch_size_scaled = PATCH_SIZE * scale
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Original image
    axes[0].imshow(img)
    axes[0].set_title(f"Generated Image ({output_size}×{output_size})")
    axes[0].axis('off')
    
    # Gradient magnitude
    im = axes[1].imshow(grad_mag, cmap='hot')
    axes[1].set_title("Gradient Magnitude")
    axes[1].axis('off')
    plt.colorbar(im, ax=axes[1], fraction=0.046)
    
    # Gradient with patch grid overlay
    axes[2].imshow(grad_mag, cmap='hot')
    # Draw interior patch boundaries on pixel edges (imshow centres pixel k at coordinate k)
    for i in range(patch_size_scaled, output_size, patch_size_scaled):
        axes[2].axhline(y=i - 0.5, color='cyan', linewidth=0.5, alpha=0.5)
        axes[2].axvline(x=i - 0.5, color='cyan', linewidth=0.5, alpha=0.5)
    axes[2].set_title(f"Gradient with Patch Grid (patch={patch_size_scaled}px)")
    axes[2].axis('off')
    
    plt.suptitle(f"Boundary Analysis: {CIFAR10_CLASSES[class_idx]}\n(If overlap works, grid lines should NOT align with high gradients)", fontsize=12)
    plt.tight_layout()
    plt.show()
    
    # Quantitative analysis: mean gradient in 3-pixel bands around interior patch
    # boundaries vs. everywhere else
    if patch_size_scaled < 4:
        # With 2-pixel patches (e.g. native 32×32) the bands cover almost every pixel
        print("Patches too small for a boundary vs. non-boundary comparison; use a larger output_size")
        return
    
    boundary_mask = np.zeros_like(grad_mag, dtype=bool)
    for i in range(patch_size_scaled, output_size, patch_size_scaled):
        boundary_mask[i - 1:i + 2, :] = True
        boundary_mask[:, i - 1:i + 2] = True
    
    grad_at_boundaries = grad_mag[boundary_mask].mean()
    grad_elsewhere = grad_mag[~boundary_mask].mean()
    
    print(f"Average gradient at patch boundaries: {grad_at_boundaries:.4f}")
    print(f"Average gradient elsewhere: {grad_elsewhere:.4f}")
    print(f"Ratio (≈1.0 = no excess gradient at boundaries; >1 hints at grid artifacts): "
          f"{grad_at_boundaries / grad_elsewhere:.2f}")

# Analyze boundaries for a few classes
for class_idx in [0, 5, 9]:  # airplane, dog, truck
    analyze_boundaries(model, class_idx, output_size=128)

## Extreme Super-Resolution Test

Test higher resolution outputs to see how well the model generalizes.

In [None]:
# Test extreme super-resolution
class_idx = 7  # horse

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
scales = [1, 4, 8, 16]

for ax, scale in zip(axes, scales):
    output_size = 32 * scale
    img = generate_images(model, [class_idx], output_size=output_size)[0]
    ax.imshow(img)
    ax.set_title(f"{output_size}×{output_size} ({scale}x)")
    ax.axis('off')

plt.suptitle(f"Extreme Super-Resolution: {CIFAR10_CLASSES[class_idx]}", fontsize=14)
plt.tight_layout()
plt.show()

## Batch Generation at High Resolution

In [None]:
# Generate all classes at 4x resolution (128×128)
images_128 = generate_images(model, list(range(10)), output_size=128)

fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, (ax, img) in enumerate(zip(axes.flat, images_128)):
    ax.imshow(img)
    ax.set_title(f"{i}: {CIFAR10_CLASSES[i]}")
    ax.axis('off')

plt.suptitle("4× Super-Resolution (128×128)", fontsize=14)
plt.tight_layout()
plt.show()

## Interactive Generation

In [None]:
def generate_grid(class_idx, num_samples=4, output_size=64, guidance=2.0):
    """Generate a grid of samples for a single class."""
    images = generate_images(
        model, 
        [class_idx] * num_samples, 
        output_size=output_size,
        guidance_scale=guidance
    )
    
    cols = min(4, num_samples)
    rows = (num_samples + cols - 1) // cols
    
    # squeeze=False always returns a 2D array of Axes, whatever rows and cols are
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)
    
    for ax in axes.flat:
        ax.axis('off')  # also blanks unused cells in a partially filled last row
    for ax, img in zip(axes.flat, images):
        ax.imshow(img)
    
    plt.suptitle(f"{CIFAR10_CLASSES[class_idx]} @ {output_size}×{output_size} (CFG={guidance})", fontsize=14)
    plt.tight_layout()
    plt.show()

# Generate samples
generate_grid(class_idx=3, num_samples=4, output_size=128, guidance=2.0)  # cat
generate_grid(class_idx=1, num_samples=4, output_size=128, guidance=3.0)  # automobile

## Summary

This notebook demonstrates the **Extended Boundaries with Overlap Supervision** model:

### Key Features:
1. **Overlap Supervision**: Each patch predicts an extended region, and the overlapping predictions are supervised against ground-truth pixels
2. **Consistency**: Adjacent patches must predict same values in shared regions
3. **Blending**: During inference, overlapping predictions are averaged for smooth output
4. **No Jittering**: Position jittering disabled to enforce exact overlap matching

### Configuration:
- `patch_size=2`, `margin=0.5`
- `margin_pixels=1`, `extended_size=4`
- Each 2×2 core patch predicts a 4×4 region

### Expected Benefits:
- Smoother boundaries between patches during super-resolution
- Better consistency in overlapping regions
- Reduced checkerboard/grid artifacts