# Sinusoidal Neural Field Interpolation Evaluation

This notebook evaluates PixNerd's neural field interpolation capabilities for super-resolution.

## Benchmark Design

- **Training**: Model sees regularly sampled pixels (simulating low-res input)
- **Testing**: Model must predict ALL pixels (super-resolution)
- **Ground Truth**: Known sinusoidal patterns (smooth, continuous)
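
To make "known sinusoidal patterns" concrete, here is a minimal sketch of how such a ground-truth image could be built. It is not the repository's `generate_sinusoidal_image`; the helper name `sketch_sinusoidal_image`, the amplitude range, and the rescaling to [0, 1] are assumptions chosen for illustration.

```python
import numpy as np


def sketch_sinusoidal_image(resolution=64, num_components=5,
                            freq_range=(1.0, 8.0), seed=0):
    """Hypothetical stand-in: sum of random 2D sinusoids, rescaled to [0, 1]."""
    rng = np.random.default_rng(seed)

    # Normalized pixel coordinates in [0, 1)
    coords = np.arange(resolution) / resolution
    yy, xx = np.meshgrid(coords, coords, indexing="ij")

    img = np.zeros((resolution, resolution), dtype=np.float64)
    for _ in range(num_components):
        fx, fy = rng.uniform(*freq_range, size=2)  # cycles across the image
        phase = rng.uniform(0.0, 2.0 * np.pi)
        amp = rng.uniform(0.5, 1.0)                # assumed amplitude range
        img += amp * np.sin(2.0 * np.pi * (fx * xx + fy * yy) + phase)

    # Rescale to [0, 1] (assumption: matches the clamp(0, 1) used at evaluation)
    img = (img - img.min()) / (img.max() - img.min())
    return img.astype(np.float32)


img = sketch_sinusoidal_image(seed=1000)
print(img.shape, float(img.min()), float(img.max()))
```

Because the pattern is an analytic function of position, its value is known at every pixel, which is what allows the error on unseen positions to be measured exactly.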

### Super-Resolution Simulation
Regular grid sampling simulates a downsampled low-resolution input:
- `downsample_factor=4`: Every 4th pixel visible = 4x super-res (6.25% visible in 2D)
- `downsample_factor=2`: Every 2nd pixel visible = 2x super-res (25% visible in 2D)
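
The visible fractions above follow directly from keeping one pixel per `factor x factor` block. The sketch below checks that arithmetic with a hypothetical helper, `sketch_grid_mask`; it is an assumption about what a regular-grid mask looks like, not the repository's `create_visibility_mask`.

```python
import numpy as np


def sketch_grid_mask(resolution, downsample_factor):
    """Hypothetical helper: True at every `downsample_factor`-th pixel on both axes."""
    mask = np.zeros((resolution, resolution), dtype=bool)
    mask[::downsample_factor, ::downsample_factor] = True
    return mask


for factor in (2, 4):
    mask = sketch_grid_mask(64, factor)
    low_res = 64 // factor
    print(f"factor={factor}: {mask.sum()} visible pixels "
          f"({low_res}x{low_res} grid), visible ratio={mask.mean():.4f}")
```

For a 64x64 target this gives a 32x32 grid (25%) at factor 2 and a 16x16 grid (6.25%) at factor 4.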

### Key Metrics
1. **Visible MSE**: Error on sampled positions (training data)
2. **Invisible MSE**: Error on unseen positions (interpolation quality)
3. **Interpolation Ratio**: invisible_mse / visible_mse (lower = better; 1.0 means unseen positions are predicted as well as seen ones)
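
As a worked example of these three metrics, the following sketch computes them on a toy prediction. The function name `sketch_interpolation_metrics` and the `eps` guard are assumptions for illustration; the repository's `compute_interpolation_metrics` may differ in signature and details.

```python
import numpy as np


def sketch_interpolation_metrics(ground_truth, prediction, mask, eps=1e-8):
    """Hypothetical helper: MSE on visible vs. invisible pixels and their ratio."""
    sq_err = (ground_truth - prediction) ** 2
    visible_mse = float(sq_err[mask].mean())
    invisible_mse = float(sq_err[~mask].mean())
    return {
        "visible_mse": visible_mse,
        "invisible_mse": invisible_mse,
        # eps avoids division by zero when the visible error is exactly 0
        "interpolation_ratio": invisible_mse / (visible_mse + eps),
    }


# Toy case: prediction is off by 0.1 at visible pixels and 0.2 elsewhere
gt = np.zeros((8, 8))
mask = np.zeros((8, 8), dtype=bool)
mask[::4, ::4] = True
pred = np.where(mask, 0.1, 0.2)

metrics = sketch_interpolation_metrics(gt, pred, mask)
print(metrics)
```

Doubling the absolute error on unseen pixels quadruples the squared error, so the ratio here is close to 4.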

In [None]:
import os
import sys
from pathlib import Path

# Setup paths
NOTEBOOK_DIR = Path.cwd()
PIXNERD_DIR = NOTEBOOK_DIR / "PixNerd"

if PIXNERD_DIR.exists():
    os.chdir(PIXNERD_DIR)
    sys.path.insert(0, str(PIXNERD_DIR))
    print(f"Working directory: {os.getcwd()}")
else:
    print(f"ERROR: PixNerd directory not found at {PIXNERD_DIR}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

torch._dynamo.config.disable = True

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Configuration

In [None]:
# =============================================================================
# CONFIGURATION - Must match training!
# =============================================================================

CHECKPOINT_PATH = str(NOTEBOOK_DIR / "workdirs/exp_sinusoidal_nf_test/checkpoints/last.ckpt")

# Dataset config
RESOLUTION = 64  # High-res target
CHANNELS = 1
NUM_COMPONENTS = 5

# Super-resolution config (must match training)
DOWNSAMPLE_FACTOR = 4  # 4x super-res: every 4th pixel visible
MASK_MODE = "grid"  # 2D grid sampling

# Model config
PATCH_SIZE = 4
HIDDEN_SIZE = 256
DECODER_HIDDEN_SIZE = 64
NUM_ENCODER_BLOCKS = 6
NUM_DECODER_BLOCKS = 2
NUM_GROUPS = 4

# Sampling
NUM_SAMPLE_STEPS = 50

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Compute visible ratio
LOW_RES = RESOLUTION // DOWNSAMPLE_FACTOR
if MASK_MODE == "grid":
    VISIBLE_RATIO = 1.0 / (DOWNSAMPLE_FACTOR ** 2)
else:
    VISIBLE_RATIO = 1.0 / DOWNSAMPLE_FACTOR

print(f"Checkpoint: {CHECKPOINT_PATH}")
print(f"High-res target: {RESOLUTION}x{RESOLUTION}")
print(f"Low-res input: {LOW_RES}x{LOW_RES}")
print(f"Super-resolution: {DOWNSAMPLE_FACTOR}x")
print(f"Visible ratio: {VISIBLE_RATIO:.1%}")

## Load Model

In [None]:
from src.models.autoencoder.pixel import PixelAE
from src.models.transformer.pixnerd_c2i_heavydecoder import PixNerDiT
from src.diffusion.flow_matching.scheduling import LinearScheduler
from src.diffusion.flow_matching.sampling import EulerSampler, ode_step_fn
from src.diffusion.base.guidance import simple_guidance_fn

# Import dataset utilities
sys.path.insert(0, str(NOTEBOOK_DIR))
from train_sinusoidal import SinusoidalPLModule, SinusoidalLightningModel, SimpleEMA
from train_sinusoidal import UnconditionalConditioner, MaskedFlowMatchingTrainer
from PixNerd.src.data.dataset.sinusoidal import (
    generate_sinusoidal_image,
    create_visibility_mask,
    compute_interpolation_metrics,
)


def build_model():
    scheduler = LinearScheduler()
    vae = PixelAE(scale=1.0)
    conditioner = UnconditionalConditioner()
    
    denoiser = PixNerDiT(
        in_channels=CHANNELS,
        patch_size=PATCH_SIZE,
        num_groups=NUM_GROUPS,
        hidden_size=HIDDEN_SIZE,
        decoder_hidden_size=DECODER_HIDDEN_SIZE,
        num_encoder_blocks=NUM_ENCODER_BLOCKS,
        num_decoder_blocks=NUM_DECODER_BLOCKS,
        num_classes=1,
    )
    
    trainer = MaskedFlowMatchingTrainer(
        scheduler=scheduler,
        lognorm_t=True,
        timeshift=1.0,
    )
    
    sampler = EulerSampler(
        num_steps=NUM_SAMPLE_STEPS,
        guidance=1.0,
        guidance_interval_min=0.0,
        guidance_interval_max=1.0,
        scheduler=scheduler,
        w_scheduler=LinearScheduler(),
        guidance_fn=simple_guidance_fn,
        step_fn=ode_step_fn,
    )
    
    ema_tracker = SimpleEMA(decay=0.9999)
    
    model = SinusoidalLightningModel(
        vae=vae,
        conditioner=conditioner,
        denoiser=denoiser,
        diffusion_trainer=trainer,
        diffusion_sampler=sampler,
        ema_tracker=ema_tracker,
        optimizer=None,
    )
    
    return model


print("Building model...")
model = build_model()
model.setup_ema()

print(f"\nLoading checkpoint: {CHECKPOINT_PATH}")
if os.path.exists(CHECKPOINT_PATH):
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    # Extract model state from Lightning checkpoint
    state_dict = checkpoint.get('state_dict', checkpoint)
    # Remove 'model.' prefix if present
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('model.'):
            new_state_dict[k[6:]] = v
        else:
            new_state_dict[k] = v
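    # strict=False tolerates key mismatches, so report them instead of hiding them
    load_result = model.load_state_dict(new_state_dict, strict=False)
    print(f"Checkpoint loaded! Missing keys: {len(load_result.missing_keys)}, "
          f"unexpected keys: {len(load_result.unexpected_keys)}")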
else:
    print(f"WARNING: Checkpoint not found at {CHECKPOINT_PATH}; using randomly initialized weights!")

model = model.to(DEVICE)
model.eval()

print(f"\nDenoiser parameters: {sum(p.numel() for p in model.denoiser.parameters()):,}")

## Create Visibility Mask

In [None]:
# Create visibility mask (regular grid sampling)
visibility_mask = create_visibility_mask(
    RESOLUTION, DOWNSAMPLE_FACTOR, MASK_MODE
)

visible_ratio = visibility_mask.sum() / visibility_mask.size

print(f"Visibility mask shape: {visibility_mask.shape}")
print(f"Visible pixels: {visibility_mask.sum()} / {visibility_mask.size} ({visible_ratio:.1%})")
print(f"This simulates {LOW_RES}x{LOW_RES} → {RESOLUTION}x{RESOLUTION} super-resolution")

# Visualize mask
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Full mask
axes[0].imshow(visibility_mask, cmap='RdYlGn', interpolation='nearest')
axes[0].set_title(f"Visibility Mask ({visible_ratio:.1%} visible = green)")

# Zoomed view to show grid pattern
zoom_size = 16
axes[1].imshow(visibility_mask[:zoom_size, :zoom_size], cmap='RdYlGn', interpolation='nearest')
axes[1].set_title(f"Zoomed View (top-left {zoom_size}x{zoom_size})")
for ax in axes:
    ax.set_xlabel("x")
    ax.set_ylabel("y")

plt.tight_layout()
plt.show()

print(f"\nGrid pattern: Every {DOWNSAMPLE_FACTOR}th pixel in both x and y directions")

## Generate Ground Truth Samples

In [None]:
def generate_test_samples(num_samples=10, seed_start=1000):
    """Generate ground truth sinusoidal images for testing."""
    samples = []
    for i in range(num_samples):
        img = generate_sinusoidal_image(
            RESOLUTION, NUM_COMPONENTS,
            freq_range=(1.0, 8.0),
            seed=seed_start + i
        )
        samples.append(img)
    return np.stack(samples)


# Generate test samples
test_samples = generate_test_samples(10)
print(f"Generated {len(test_samples)} test samples")

# Display a few
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for i, (ax, img) in enumerate(zip(axes.flat, test_samples)):
    ax.imshow(img, cmap='viridis')
    ax.set_title(f"Sample {i}")
    ax.axis('off')
plt.suptitle("Ground Truth Sinusoidal Patterns")
plt.tight_layout()
plt.show()

## Reconstruction Test

Test 1: Starting from pure noise, can the model sample images that look like smooth sinusoidal patterns?

The model was trained only on visible regions, but must predict ALL pixels.

In [None]:
@torch.no_grad()
def reconstruct_from_noise(model, num_samples=4):
    """Generate samples from pure noise."""
    model.eval()
    
    # Random noise
    noise = torch.randn(num_samples, CHANNELS, RESOLUTION, RESOLUTION, device=DEVICE)
    
    # Get condition (dummy)
    condition, uncondition = model.conditioner(noise)
    
    # Sample using EMA model
    denoiser = model.ema_denoiser if model.ema_denoiser is not None else model.denoiser
    
    samples = model.diffusion_sampler(
        denoiser,
        noise,
        condition,
        uncondition,
    )
    
    # Decode
    images = model.vae.decode(samples)
    
    return images.clamp(0, 1).cpu().numpy()


# Generate samples
print("Generating samples from noise...")
generated = reconstruct_from_noise(model, num_samples=8)

# Display
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i, (ax, img) in enumerate(zip(axes.flat, generated)):
    ax.imshow(img[0], cmap='viridis')  # [C, H, W] -> [H, W]
    ax.set_title(f"Generated {i}")
    ax.axis('off')
plt.suptitle("Generated Samples from Noise")
plt.tight_layout()
plt.show()

## Core Evaluation: Reconstruction-Based Interpolation Test

The key test for neural field interpolation:
1. Take ground truth sinusoidal image
2. Add noise (simulate diffusion forward process)
3. Denoise with the model
4. Compare reconstruction quality on **visible** vs **invisible** positions

If the neural field interpolates well:
- MSE on visible positions ≈ MSE on invisible positions
- Interpolation ratio ≈ 1.0
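
The noise-then-denoise procedure in steps 1-3 can be sanity-checked without a trained model. The sketch below assumes the same straight-line convention as this notebook's denoising code (`x_t = (1 - t) * x0 + t * noise`, with t=0 clean and t=1 pure noise) and replaces the network with a hypothetical oracle that returns the exact velocity `noise - x0`. With a perfect velocity, Euler integration from `t_start` back to 0 should recover `x0` up to floating-point error.

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.uniform(0.0, 1.0, size=(1, 1, 8, 8))  # toy "clean image" [B, C, H, W]
noise = rng.standard_normal(x0.shape)

t_start, num_steps = 0.7, 40

# Step 2: forward process along the straight path between x0 and noise
x = (1.0 - t_start) * x0 + t_start * noise


def oracle_velocity(x_t, t):
    """Hypothetical perfect model: d(x_t)/dt on the straight path is noise - x0."""
    return noise - x0


# Step 3: Euler integration from t_start down to 0 (each dt is negative)
timesteps = np.linspace(t_start, 0.0, num_steps + 1)
for t_cur, t_next in zip(timesteps[:-1], timesteps[1:]):
    x = x + oracle_velocity(x, t_cur) * (t_next - t_cur)

max_err = float(np.abs(x - x0).max())
print(f"max reconstruction error with oracle velocity: {max_err:.2e}")
```

Any error a real model shows in this test therefore comes from its velocity predictions, which is what makes the visible/invisible split in step 4 informative.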

In [None]:
@torch.no_grad()
def denoise_from_timestep(model, x0, t_start=0.8, num_steps=40):
    """
    Noise a clean image up to `t_start`, then denoise it back to t=0.
    
    This simulates:
    1. Adding noise to x0 up to timestep t_start
    2. Denoising back to t=0
    
    Args:
        model: The diffusion model
        x0: Clean image [B, C, H, W]
        t_start: Starting timestep (0=clean, 1=pure noise)
        num_steps: Number of denoising steps
    
    Returns:
        Denoised image [B, C, H, W]
    """
    model.eval()
    device = next(model.parameters()).device
    
    # Move to device
    if not isinstance(x0, torch.Tensor):
        x0 = torch.from_numpy(x0).float()
    if x0.dim() == 3:
        x0 = x0.unsqueeze(0)  # Add batch dim
    if x0.dim() == 2:
        x0 = x0.unsqueeze(0).unsqueeze(0)  # Add batch and channel dims
    x0 = x0.to(device)
    
    batch_size = x0.shape[0]
    
    # Add noise: x_t = (1-t)*x0 + t*noise
    # (assumes t=0 is clean and t=1 is pure noise; this must match the training scheduler)
    noise = torch.randn_like(x0)
    x_t = (1 - t_start) * x0 + t_start * noise
    
    # Get condition
    condition, uncondition = model.conditioner(x_t)
    
    # Use EMA model
    denoiser = model.ema_denoiser if model.ema_denoiser is not None else model.denoiser
    
    # Create timesteps from t_start to 0
    timesteps = torch.linspace(t_start, 0, num_steps + 1, device=device)
    
    # Denoise step by step (Euler method)
    x = x_t
    for i in range(num_steps):
        t_cur = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_cur  # negative
        
        # Predict velocity
        t_batch = t_cur.repeat(batch_size)
        v_pred = denoiser(x, t_batch, condition)
        
        # Euler step: x_next = x + v * dt
        x = x + v_pred * dt
    
    return x.clamp(0, 1)


def evaluate_interpolation(model, ground_truth_images, mask, t_start=0.7, num_steps=40):
    """
    Evaluate interpolation quality by reconstruction.
    
    Args:
        model: Diffusion model
        ground_truth_images: numpy array [N, H, W] or [N, C, H, W]
        mask: Visibility mask [H, W]
        t_start: Noise level for reconstruction test
        num_steps: Denoising steps
    
    Returns:
        List of dicts, one per sample, with metrics and the compared images
    """
    results = []
    
    for i, gt in enumerate(ground_truth_images):
        # Ensure correct shape [1, C, H, W]
        if gt.ndim == 2:
            gt = gt[np.newaxis, np.newaxis, ...]  # [1, 1, H, W]
        elif gt.ndim == 3:
            gt = gt[np.newaxis, ...]  # [1, C, H, W]
        
        gt_tensor = torch.from_numpy(gt).float()
        
        # Denoise
        reconstructed = denoise_from_timestep(model, gt_tensor, t_start, num_steps)
        reconstructed = reconstructed.cpu().numpy()
        
        # Compute metrics
        gt_np = gt.squeeze()  # [H, W]
        recon_np = reconstructed.squeeze()  # [H, W]
        
        error = (gt_np - recon_np) ** 2
        
        visible_mse = error[mask].mean()
        invisible_mse = error[~mask].mean()
        full_mse = error.mean()
        ratio = invisible_mse / (visible_mse + 1e-8)
        
        results.append({
            'sample_idx': i,
            'visible_mse': visible_mse,
            'invisible_mse': invisible_mse,
            'full_mse': full_mse,
            'interpolation_ratio': ratio,
            'ground_truth': gt_np,
            'reconstructed': recon_np,
        })
        
        print(f"Sample {i}: visible_mse={visible_mse:.6f}, invisible_mse={invisible_mse:.6f}, ratio={ratio:.2f}")
    
    return results


print("Evaluation functions defined.")

In [None]:
# Run the interpolation evaluation
print("=" * 60)
print("NEURAL FIELD INTERPOLATION EVALUATION")
print("=" * 60)
print(f"\nTest setup:")
print(f"  - Ground truth: {len(test_samples)} sinusoidal patterns")
print(f"  - Visible positions: {VISIBLE_RATIO:.1%} (training data)")
print(f"  - Invisible positions: {1-VISIBLE_RATIO:.1%} (interpolation test)")
print(f"  - Noise level (t_start): 0.7")
print()

# Evaluate on test samples
eval_results = evaluate_interpolation(
    model, 
    test_samples, 
    visibility_mask, 
    t_start=0.7,  # Add 70% noise, then denoise
    num_steps=40
)

# Summary statistics
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
avg_visible = np.mean([r['visible_mse'] for r in eval_results])
avg_invisible = np.mean([r['invisible_mse'] for r in eval_results])
avg_ratio = np.mean([r['interpolation_ratio'] for r in eval_results])

print(f"\nAverage MSE on VISIBLE positions:   {avg_visible:.6f}")
print(f"Average MSE on INVISIBLE positions: {avg_invisible:.6f}")
print(f"Average Interpolation Ratio:        {avg_ratio:.2f}")
print()
if avg_ratio < 1.5:
    print("✓ GOOD: Interpolation ratio < 1.5 - Neural field generalizes well!")
elif avg_ratio < 2.0:
    print("~ OK: Interpolation ratio 1.5-2.0 - Decent interpolation")
else:
    print("✗ POOR: Interpolation ratio >= 2.0 - Model struggles with unseen positions")

In [None]:
# Visualize reconstruction results
def visualize_reconstruction(result, mask, sample_idx=0):
    """Visualize ground truth vs reconstruction with error analysis."""
    gt = result['ground_truth']
    recon = result['reconstructed']
    error = np.abs(gt - recon)
    
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    
    # Row 1: Ground truth, Reconstruction, Absolute Error, Error histogram
    axes[0, 0].imshow(gt, cmap='viridis')
    axes[0, 0].set_title("Ground Truth")
    axes[0, 0].axis('off')
    
    axes[0, 1].imshow(recon, cmap='viridis')
    axes[0, 1].set_title("Reconstructed")
    axes[0, 1].axis('off')
    
    im = axes[0, 2].imshow(error, cmap='hot', vmin=0, vmax=0.3)
    axes[0, 2].set_title(f"Absolute Error (MSE={result['full_mse']:.4f})")
    axes[0, 2].axis('off')
    plt.colorbar(im, ax=axes[0, 2])
    
    # Error histogram
    axes[0, 3].hist(error[mask].flatten(), bins=50, alpha=0.7, label=f"Visible (MSE={result['visible_mse']:.4f})", color='green')
    axes[0, 3].hist(error[~mask].flatten(), bins=50, alpha=0.7, label=f"Invisible (MSE={result['invisible_mse']:.4f})", color='red')
    axes[0, 3].set_xlabel("Absolute Error")
    axes[0, 3].set_ylabel("Count")
    axes[0, 3].legend()
    axes[0, 3].set_title("Error Distribution")
    
    # Row 2: Error on visible, Error on invisible, Mask overlay, Difference map
    visible_error = error.copy()
    visible_error[~mask] = 0
    im = axes[1, 0].imshow(visible_error, cmap='hot', vmin=0, vmax=0.3)
    axes[1, 0].set_title(f"Error on VISIBLE (MSE={result['visible_mse']:.4f})")
    axes[1, 0].axis('off')
    
    invisible_error = error.copy()
    invisible_error[mask] = 0
    im = axes[1, 1].imshow(invisible_error, cmap='hot', vmin=0, vmax=0.3)
    axes[1, 1].set_title(f"Error on INVISIBLE (MSE={result['invisible_mse']:.4f})")
    axes[1, 1].axis('off')
    
    # Show sampling grid overlay on ground truth
    axes[1, 2].imshow(gt, cmap='viridis')
    y_vis, x_vis = np.where(mask)
    axes[1, 2].scatter(x_vis, y_vis, c='red', s=1, alpha=0.3)
    axes[1, 2].set_title(f"Sampling Grid ({VISIBLE_RATIO:.1%} visible)")
    axes[1, 2].axis('off')
    
    # Signed difference
    diff = recon - gt
    im = axes[1, 3].imshow(diff, cmap='RdBu', vmin=-0.3, vmax=0.3)
    axes[1, 3].set_title("Signed Difference (blue=under, red=over)")
    axes[1, 3].axis('off')
    plt.colorbar(im, ax=axes[1, 3])
    
    plt.suptitle(f"Sample {sample_idx} - Interpolation Ratio: {result['interpolation_ratio']:.2f}x", fontsize=14)
    plt.tight_layout()
    plt.show()


# Visualize first few results
for i in range(min(4, len(eval_results))):
    visualize_reconstruction(eval_results[i], visibility_mask, i)

## Interpolation Quality vs Noise Level

Test how interpolation quality varies with different starting noise levels.
Lower t_start = easier task (less noise to remove).
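
To quantify "less noise to remove", the sketch below computes the signal-to-noise ratio of the noised input as a function of `t_start`. It assumes the interpolation used in this notebook's denoising code, `x_t = (1 - t) * x0 + t * noise` with unit-variance Gaussian noise; the helper name and the default `signal_var=1.0` are illustrative assumptions.

```python
def signal_to_noise_ratio(t_start, signal_var=1.0):
    """SNR of x_t = (1 - t) * x0 + t * noise, for unit-variance noise."""
    signal_power = (1.0 - t_start) ** 2 * signal_var
    noise_power = t_start ** 2
    return signal_power / noise_power


for t_start in (0.3, 0.5, 0.7, 0.9):
    print(f"t_start={t_start}: SNR={signal_to_noise_ratio(t_start):.3f}")
```

The SNR falls from roughly 5.4 at `t_start=0.3` to roughly 0.01 at `t_start=0.9`, so at the highest noise level the model has to rely almost entirely on what it learned about the data rather than on the input.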

In [None]:
# Test interpolation at different noise levels
noise_levels = [0.3, 0.5, 0.7, 0.9]
noise_level_results = {}

# Use first 5 test samples for this analysis
test_subset = test_samples[:5]

for t_start in noise_levels:
    print(f"\nTesting t_start = {t_start}...")
    results = evaluate_interpolation(model, test_subset, visibility_mask, t_start=t_start, num_steps=40)
    
    avg_visible = np.mean([r['visible_mse'] for r in results])
    avg_invisible = np.mean([r['invisible_mse'] for r in results])
    avg_ratio = np.mean([r['interpolation_ratio'] for r in results])
    
    noise_level_results[t_start] = {
        'visible_mse': avg_visible,
        'invisible_mse': avg_invisible,
        'ratio': avg_ratio,
    }

# Plot results
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

t_values = list(noise_level_results.keys())
visible_mses = [noise_level_results[t]['visible_mse'] for t in t_values]
invisible_mses = [noise_level_results[t]['invisible_mse'] for t in t_values]
ratios = [noise_level_results[t]['ratio'] for t in t_values]

axes[0].plot(t_values, visible_mses, 'g-o', label='Visible (trained)')
axes[0].plot(t_values, invisible_mses, 'r-o', label='Invisible (interpolated)')
axes[0].set_xlabel('Noise Level (t_start)')
axes[0].set_ylabel('MSE')
axes[0].set_title('MSE vs Noise Level')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(t_values, ratios, 'b-o')
axes[1].axhline(y=1.0, color='gray', linestyle='--', label='Ideal (ratio=1)')
axes[1].set_xlabel('Noise Level (t_start)')
axes[1].set_ylabel('Interpolation Ratio')
axes[1].set_title('Interpolation Ratio vs Noise Level')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Bar chart comparison
x = np.arange(len(t_values))
width = 0.35
axes[2].bar(x - width/2, visible_mses, width, label='Visible', color='green', alpha=0.7)
axes[2].bar(x + width/2, invisible_mses, width, label='Invisible', color='red', alpha=0.7)
axes[2].set_xlabel('Noise Level')
axes[2].set_ylabel('MSE')
axes[2].set_title('Visible vs Invisible MSE')
axes[2].set_xticks(x)
axes[2].set_xticklabels([f't={t}' for t in t_values])
axes[2].legend()

plt.tight_layout()
plt.show()

# Summary table
print("\n" + "=" * 60)
print("NOISE LEVEL ANALYSIS SUMMARY")
print("=" * 60)
print(f"{'Noise Level':<15} {'Visible MSE':<15} {'Invisible MSE':<15} {'Ratio':<10}")
print("-" * 55)
for t in t_values:
    r = noise_level_results[t]
    print(f"{t:<15.1f} {r['visible_mse']:<15.6f} {r['invisible_mse']:<15.6f} {r['ratio']:<10.2f}")

## Interpolation Quality Analysis

Key question: How well does the model predict UNSEEN regions?

We compare:
1. MSE on visible regions (training data)
2. MSE on invisible regions (interpolation quality)

In [None]:
def visualize_interpolation(ground_truth, generated, mask):
    """
    Visualize interpolation quality.
    
    Shows:
    - Ground truth
    - Generated image
    - Error map
    - Visible/invisible region errors
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
    # Ground truth
    axes[0, 0].imshow(ground_truth, cmap='viridis')
    axes[0, 0].set_title("Ground Truth")
    axes[0, 0].axis('off')
    
    # Generated
    axes[0, 1].imshow(generated, cmap='viridis')
    axes[0, 1].set_title("Generated")
    axes[0, 1].axis('off')
    
    # Error map
    error = np.abs(ground_truth - generated)
    im = axes[0, 2].imshow(error, cmap='hot', vmin=0, vmax=0.5)
    axes[0, 2].set_title("Absolute Error")
    axes[0, 2].axis('off')
    plt.colorbar(im, ax=axes[0, 2])
    
    # Visibility mask overlay
    overlay = np.stack([ground_truth, ground_truth, ground_truth], axis=-1)
    overlay[~mask] = [1, 0, 0]  # Red for invisible regions
    axes[1, 0].imshow(overlay * 0.5 + 0.5 * ground_truth[..., np.newaxis])
    axes[1, 0].set_title("Visibility Mask (red = unseen)")
    axes[1, 0].axis('off')
    
    # Error on visible regions only
    visible_error = error.copy()
    visible_error[~mask] = 0
    im = axes[1, 1].imshow(visible_error, cmap='hot', vmin=0, vmax=0.5)
    visible_mse = (error[mask] ** 2).mean()
    axes[1, 1].set_title(f"Error on VISIBLE ({visible_mse:.4f} MSE)")
    axes[1, 1].axis('off')
    
    # Error on invisible regions only
    invisible_error = error.copy()
    invisible_error[mask] = 0
    im = axes[1, 2].imshow(invisible_error, cmap='hot', vmin=0, vmax=0.5)
    invisible_mse = (error[~mask] ** 2).mean()
    axes[1, 2].set_title(f"Error on INVISIBLE ({invisible_mse:.4f} MSE)")
    axes[1, 2].axis('off')
    
    ratio = invisible_mse / (visible_mse + 1e-8)
    plt.suptitle(f"Interpolation Quality - Ratio: {ratio:.2f}x (lower = better)", fontsize=14)
    plt.tight_layout()
    plt.show()
    
    return visible_mse, invisible_mse, ratio

In [None]:
# Analyze interpolation quality on generated samples
# Visualize the generated samples with the grid mask overlay

for i in range(min(4, len(generated))):
    gen_img = generated[i, 0]  # [H, W]
    
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    
    # Generated image
    axes[0].imshow(gen_img, cmap='viridis')
    axes[0].set_title("Generated")
    axes[0].axis('off')
    
    # With mask overlay (show sampled grid points)
    overlay = gen_img.copy()
    axes[1].imshow(overlay, cmap='viridis')
    # Mark visible (sampled) positions
    y_vis, x_vis = np.where(visibility_mask)
    axes[1].scatter(x_vis, y_vis, c='red', s=2, alpha=0.5)
    axes[1].set_title(f"Sampled positions (red dots, {DOWNSAMPLE_FACTOR}x grid)")
    axes[1].axis('off')
    
    # Gradient magnitude (smoothness check); crop both differences to (H-1, W-1)
    grad_x = np.abs(np.diff(gen_img, axis=1))
    grad_y = np.abs(np.diff(gen_img, axis=0))
    grad_mag = np.hypot(grad_x[:-1, :], grad_y[:, :-1])
    axes[2].imshow(grad_mag, cmap='hot')
    axes[2].set_title("Gradient Magnitude (edges)")
    axes[2].axis('off')
    
    plt.suptitle(f"Sample {i} Analysis")
    plt.tight_layout()
    plt.show()

## Boundary Analysis

Check if there are visible artifacts at the boundaries between visible and invisible regions.
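
To show what this check is expected to detect, the sketch below applies the same idea (horizontal gradient at mask transitions vs. everywhere else) to two synthetic images: a smooth sinusoid, and the same sinusoid with an offset added at the sampled pixels to mimic a grid artifact. The helper name `sketch_boundary_ratio`, the test images, and the 0.3 offset are assumptions for illustration.

```python
import numpy as np


def sketch_boundary_ratio(img, mask, eps=1e-8):
    """Hypothetical helper: mean |d/dx| at mask transitions over mean elsewhere."""
    grad_x = np.abs(np.diff(img, axis=1))
    grad_x = np.pad(grad_x, ((0, 0), (0, 1)), mode="edge")

    # A boundary pixel is one where the mask value changes to its right neighbor
    mask_diff = np.abs(np.diff(mask.astype(float), axis=1))
    boundary = np.pad(mask_diff, ((0, 0), (0, 1)), mode="constant") > 0

    return grad_x[boundary].mean() / (grad_x[~boundary].mean() + eps)


size, factor = 64, 4
mask = np.zeros((size, size), dtype=bool)
mask[::factor, ::factor] = True

# Smooth horizontal sinusoid, identical in every row
xs = np.arange(size)
smooth = np.tile(0.5 + 0.4 * np.sin(2.0 * np.pi * xs / size), (size, 1))

# Same image with a spike at every sampled pixel (simulated grid artifact)
artifact = smooth.copy()
artifact[mask] += 0.3

ratio_smooth = sketch_boundary_ratio(smooth, mask)
ratio_artifact = sketch_boundary_ratio(artifact, mask)
print(f"smooth: {ratio_smooth:.2f}, with grid artifact: {ratio_artifact:.2f}")
```

The smooth image should give a ratio near 1, while the spiked image gives a much larger one, which is the signature to look for in generated samples.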

In [None]:
def analyze_boundaries(generated, mask):
    """
    Analyze gradient magnitude at visibility boundaries.
    
    If interpolation is poor, we expect high gradients at boundaries.
    """
    results = []
    
    for img in generated:
        img = img[0]  # [H, W]
        
        # Compute gradient
        grad_x = np.abs(np.diff(img, axis=1))
        grad_x = np.pad(grad_x, ((0, 0), (0, 1)), mode='edge')
        
        # Find boundary pixels (where mask changes)
        mask_diff = np.abs(np.diff(mask.astype(float), axis=1))
        mask_diff = np.pad(mask_diff, ((0, 0), (0, 1)), mode='constant')
        boundary_mask = mask_diff > 0
        
        # Gradient at boundaries vs elsewhere
        grad_at_boundary = grad_x[boundary_mask].mean() if boundary_mask.any() else 0
        grad_elsewhere = grad_x[~boundary_mask].mean()
        
        results.append({
            'grad_at_boundary': grad_at_boundary,
            'grad_elsewhere': grad_elsewhere,
            'ratio': grad_at_boundary / (grad_elsewhere + 1e-8)
        })
    
    return results


# Analyze boundaries
boundary_results = analyze_boundaries(generated, visibility_mask)

print("Boundary Analysis:")
print("(Ratio close to 1.0 = good interpolation, no boundary artifacts)")
print()
for i, r in enumerate(boundary_results):
    print(f"Sample {i}: boundary_grad={r['grad_at_boundary']:.4f}, "
          f"other_grad={r['grad_elsewhere']:.4f}, ratio={r['ratio']:.2f}")

avg_boundary_ratio = np.mean([r['ratio'] for r in boundary_results])
print(f"\nAverage boundary ratio: {avg_boundary_ratio:.2f}")

## Summary

This benchmark tests neural field interpolation for super-resolution:

1. **Training** on a regularly sampled grid (simulating low-res input)
2. **Evaluating** reconstruction quality on unseen positions
3. **Checking** for artifacts in interpolated regions

### Super-Resolution Setup
- `downsample_factor=4`: 4x super-res (16x16 → 64x64)
- Regular grid sampling (every 4th pixel in x and y)
- 6.25% visible pixels (256/4096)

### Key Metrics:
- **Interpolation ratio** < 1.5: Good interpolation (1.5-2.0 is acceptable)
- **Boundary ratio** ≈ 1.0: No grid artifacts
- **Visual smoothness**: Generated patterns should be smooth sinusoids

### Next Steps:
- Try different downsample factors (2, 4, 8)
- Test column-only or row-only sampling
- Compare different model architectures
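
For the row-only and column-only experiments, a mask along a single axis keeps `1 / downsample_factor` of the pixels, which matches the non-grid branch of the `VISIBLE_RATIO` computation in the configuration cell. The sketch below uses a hypothetical helper, `sketch_line_mask`; whether `create_visibility_mask` supports such modes, and under which names, is not shown in this notebook.

```python
import numpy as np


def sketch_line_mask(resolution, downsample_factor, axis):
    """Hypothetical helper: keep every N-th row (axis=0) or column (axis=1)."""
    mask = np.zeros((resolution, resolution), dtype=bool)
    if axis == 0:
        mask[::downsample_factor, :] = True   # whole rows are visible
    elif axis == 1:
        mask[:, ::downsample_factor] = True   # whole columns are visible
    else:
        raise ValueError("axis must be 0 (rows) or 1 (columns)")
    return mask


row_mask = sketch_line_mask(64, 4, axis=0)
col_mask = sketch_line_mask(64, 4, axis=1)
print(f"row-only visible ratio:    {row_mask.mean():.2%}")
print(f"column-only visible ratio: {col_mask.mean():.2%}")
```

With these masks the model only has to interpolate along one axis, so comparing their interpolation ratios against the 2D grid case would separate the difficulty of 1D and 2D interpolation.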