# Options A/B Ablation Study - Evaluation Notebook

This notebook evaluates the impact of:
- **Option A** (sfc_unified_coords): Shared coordinate embedder for tokens & queries
- **Option B** (sfc_spatial_bias): Spatial attention bias in cross-attention

We also explore:
- Different **sampling steps** (50, 100, 200, 500)
- Different **guidance scales** (1.0, 1.5, 2.0, 3.0, 4.0)
- Different **sparsity rates** (5%, 10%, 20%, 40%, 60%, 80% of pixels observed)

In [None]:
import os
import sys
from pathlib import Path

# Setup paths
# When a "PixNerd" checkout sits under the launch directory, work from inside
# it: relative paths used later (./workdirs, ./data) and the `src.*` imports
# are then resolved against that directory.
NOTEBOOK_DIR = Path.cwd()
PIXNERD_DIR = NOTEBOOK_DIR.joinpath("PixNerd")
if PIXNERD_DIR.exists():
    os.chdir(PIXNERD_DIR)
    sys.path.insert(0, str(PIXNERD_DIR))

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torchvision.utils import make_grid
from tqdm import tqdm  # Use regular tqdm instead of tqdm.auto
import warnings
# Suppress every Python warning for the rest of the session. This keeps the
# notebook output short, but it also hides deprecation and numerical warnings.
warnings.filterwarnings('ignore')

# Disable torch compile
# NOTE(review): this sets a private TorchDynamo config flag; presumably it
# makes torch.compile-wrapped code run eagerly -- confirm for the installed
# PyTorch version.
torch._dynamo.config.disable = True

# Report the runtime environment so results can be tied to a PyTorch/GPU setup.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [None]:
# Import model components
from src.models.autoencoder.pixel import PixelAE
from src.models.conditioner.class_label import LabelConditioner
from src.models.transformer.pixnerd_c2i_heavydecoder import PixNerDiT
from src.diffusion.flow_matching.scheduling import LinearScheduler
from src.diffusion.flow_matching.sampling import EulerSampler, ode_step_fn
from src.diffusion.base.guidance import simple_guidance_fn

print("Imports successful!")

## 1. Configuration

In [None]:
# Device configuration: CUDA with bfloat16 when a GPU is available,
# otherwise CPU with float32.
if torch.cuda.is_available():
    DEVICE, DTYPE = torch.device("cuda"), torch.bfloat16
else:
    DEVICE, DTYPE = torch.device("cpu"), torch.float32

# Model configuration (should match training). These keyword arguments are
# forwarded verbatim to the PixNerDiT constructor by build_model().
MODEL_CONFIG = dict(
    in_channels=3,
    hidden_size=512,
    decoder_hidden_size=64,
    num_encoder_blocks=8,
    num_decoder_blocks=2,
    num_classes=10,
    patch_size=8,
    num_groups=8,
    encoder_type="sfc",
    sfc_curve="hilbert",
    sfc_group_size=8,
    sfc_cross_depth=2,
    sfc_self_depth=2,
)

# Auto-discover checkpoints
def _classify_experiment(exp_name):
    """Map an experiment directory name to a model-variant key.

    Args:
        exp_name: name of the directory that contains ``checkpoints/``.

    Returns:
        One of "baseline", "A_only", "B_only", "A+B", or None when the name
        matches no known naming scheme.
    """
    import re

    name = exp_name.lower()

    def has_token(token):
        # Only match when delimited by non-alphanumerics, so "ab" does not
        # fire on "ablation" and "a_only" does not fire on "data_only".
        pattern = rf"(?<![a-z0-9]){re.escape(token)}(?![a-z0-9])"
        return re.search(pattern, name) is not None

    no_a = "no_option_a" in name
    no_b = "no_option_b" in name
    # Drop negated mentions before looking for positive ones:
    # "no_option_a" contains "option_a" as a substring.
    positive = name.replace("no_option_a", "").replace("no_option_b", "")
    opt_a = "option_a" in positive
    opt_b = "option_b" in positive

    # Most specific variants first; the combined A+B check comes last.
    if "baseline" in name or (no_a and no_b):
        return "baseline"
    if has_token("a_only") or (opt_a and no_b):
        return "A_only"
    if has_token("b_only") or (opt_b and no_a):
        return "B_only"
    if "a+b" in name or has_token("ab") or (opt_a and opt_b):
        return "A+B"
    return None


def find_checkpoints():
    """Scan workdirs for available checkpoints.

    Every search pattern has the shape ``.../workdirs/<exp>/checkpoints/*.ckpt``.
    A checkpoint is assigned to a model variant from ``<exp>`` alone, so
    unrelated path components (user name, workspace name) cannot influence
    the match. When an experiment holds several checkpoints, ``last.ckpt`` is
    preferred, matching the default paths used by this notebook; otherwise
    the alphabetically first checkpoint is kept.

    Returns:
        found: dict mapping variant key -> selected checkpoint path.
        all_ckpts: sorted list of every distinct checkpoint path discovered.
    """
    import glob

    # Common checkpoint locations to search
    search_paths = [
        "./workdirs/*/checkpoints/*.ckpt",
        "./workdirs/*/checkpoints/last.ckpt",
        "../workdirs/*/checkpoints/*.ckpt",
        "/home/*/workspace/*/workdirs/*/checkpoints/*.ckpt",  # SciServer pattern
    ]

    # Patterns overlap, so collect into a set to deduplicate.
    all_ckpts = sorted(
        {path for pattern in search_paths for path in glob.glob(pattern)}
    )

    found = {}
    print("Found checkpoints:")
    for ckpt in all_ckpts:
        print(f"  {ckpt}")

        # <exp>/checkpoints/<file>.ckpt -> <exp>
        exp_name = os.path.basename(os.path.dirname(os.path.dirname(ckpt)))
        variant = _classify_experiment(exp_name)
        if variant is None:
            continue

        is_last = os.path.basename(ckpt) == "last.ckpt"
        current = found.get(variant)
        if current is None:
            found[variant] = ckpt
        elif is_last and os.path.basename(current) != "last.ckpt":
            found[variant] = ckpt

    return found, all_ckpts

discovered_ckpts, all_ckpt_list = find_checkpoints()

# Default checkpoint locations, with any auto-discovered checkpoint taking
# precedence over the default for the same variant.
CHECKPOINT_PATHS = {
    **{
        "A+B": "./workdirs/exp_cifar10_ab/checkpoints/last.ckpt",
        "A_only": "./workdirs/exp_cifar10_a_only/checkpoints/last.ckpt",
        "B_only": "./workdirs/exp_cifar10_b_only/checkpoints/last.ckpt",
        "baseline": "./workdirs/exp_cifar10_baseline/checkpoints/last.ckpt",
    },
    **discovered_ckpts,
}

# Report which variants have a checkpoint on disk.
print("\n" + "=" * 60)
print("CHECKPOINT CONFIGURATION")
print("=" * 60)
for name, path in CHECKPOINT_PATHS.items():
    status = "✓ FOUND" if os.path.exists(path) else "✗ NOT FOUND"
    print(f"  {name:12s}: {status}")
    print(f"               {path}")

# If checkpoints not auto-discovered, user can manually set them here:
# CHECKPOINT_PATHS["A+B"] = "/path/to/your/checkpoint.ckpt"

# Evaluation parameters
EVAL_BATCH_SIZE = 8
NUM_EVAL_BATCHES = 4  # Total samples = EVAL_BATCH_SIZE * NUM_EVAL_BATCHES

print(f"\nDevice: {DEVICE}")
print(f"Dtype: {DTYPE}")

## 2. Helper Functions

In [None]:
def load_cifar10_test(num_samples=32, seed=42):
    """Load a fixed, seeded random subset of the CIFAR-10 test split.

    Images are converted with ToTensor and normalised with mean 0.5 / std 0.5
    per channel, i.e. mapped from [0, 1] to [-1, 1]. The dataset is downloaded
    into ./data if it is not already there.

    Args:
        num_samples: number of test images to draw.
        seed: value passed to torch.manual_seed before drawing the subset.
            Note that this reseeds the global torch RNG.

    Returns:
        images: stacked image tensor with one entry per drawn sample.
        labels: integer class labels in the same order.
    """
    torch.manual_seed(seed)

    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataset = CIFAR10(root="./data", train=False, transform=tfm, download=True)

    # Plain Python ints, so the dataset is not indexed with 0-dim tensors.
    indices = torch.randperm(len(dataset))[:num_samples].tolist()

    # Fetch each sample once; indexing runs the full transform pipeline, so
    # reading the image and the label separately would do that work twice.
    picked = [dataset[i] for i in indices]
    images = torch.stack([image for image, _ in picked])
    labels = torch.tensor([label for _, label in picked])

    return images, labels


def generate_sparsity_mask(batch_size, height, width, sparsity, device, dtype=torch.float32, seed=None):
    """
    Draw an independent random conditioning mask for every batch element.

    Args:
        batch_size: number of masks to generate
        height, width: spatial size of each mask
        sparsity: fraction of pixels to use as conditioning (0.0 to 1.0);
            the resulting pixel count is clamped to [1, height * width]
        device, dtype: placement and dtype of the returned tensor
        seed: when not None, the global torch RNG is reseeded with it first

    Returns:
        cond_mask: (B, 1, H, W) with 1s at conditioning pixels
    """
    if seed is not None:
        torch.manual_seed(seed)

    num_pixels = height * width
    # Always keep at least one conditioning pixel, and never more than all.
    num_cond = min(num_pixels, max(1, int(round(sparsity * num_pixels))))

    cond_mask = torch.zeros(batch_size, 1, height, width, device=device, dtype=dtype)
    flat_mask = cond_mask.view(batch_size, -1)  # shares storage with cond_mask

    for row in range(batch_size):
        # One fresh permutation per sample; its first num_cond entries are kept.
        chosen = torch.randperm(num_pixels, device=device)[:num_cond]
        flat_mask[row, chosen] = 1.0

    return cond_mask


def tensor_to_image(tensor, nrow=8):
    """Tile a batch of [-1, 1] images into one array that imshow can display.

    Args:
        tensor: batch of images; values outside [-1, 1] are clamped.
        nrow: number of images per grid row (passed to make_grid).

    Returns:
        NumPy array with the channel axis moved last.
    """
    # float32 on CPU first: NumPy cannot represent bfloat16.
    images = tensor.detach().cpu().float()
    # [-1, 1] -> [0, 1]
    images = images.clamp(-1, 1).add(1).div(2)
    tiled = make_grid(images, nrow=nrow, padding=2, normalize=False)
    return tiled.permute(1, 2, 0).numpy()


def compute_metrics(pred, target, mask=None):
    """Compute MSE and PSNR between a prediction and its target.

    Args:
        pred: predicted images.
        target: reference images, same shape as pred.
        mask: optional weights broadcastable to pred via expand_as; when
            given, the squared error is averaged over the masked entries only.

    Returns:
        dict with Python floats under "mse" and "psnr" (dB).
    """
    pred_f = pred.float()
    target_f = target.float()

    if mask is None:
        mse = F.mse_loss(pred_f, target_f)
    else:
        weights = mask.float().expand_as(pred_f)
        squared_error = (pred_f - target_f) ** 2
        mse = (squared_error * weights).sum() / weights.sum()

    # Peak signal is 2 for data in [-1, 1], hence peak^2 = 4.
    psnr = 10 * torch.log10(4.0 / (mse + 1e-8))

    return {"mse": mse.item(), "psnr": psnr.item()}


print("Helper functions defined!")

## 3. Model Loading

In [None]:
def build_model(option_a: bool, option_b: bool):
    """Instantiate PixNerDiT from MODEL_CONFIG with the two ablation switches.

    Args:
        option_a: forwarded as sfc_unified_coords.
        option_b: forwarded as sfc_spatial_bias.
    """
    return PixNerDiT(
        sfc_unified_coords=option_a,
        sfc_spatial_bias=option_b,
        **MODEL_CONFIG,
    )


def _strip_prefix(state_dict, prefix):
    """Return the entries whose key starts with `prefix`, with the prefix removed."""
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}


def load_checkpoint(model, checkpoint_path):
    """Load denoiser weights from a checkpoint file into `model`.

    Weights stored under the "ema_denoiser." prefix are preferred; if there
    are none, weights under "denoiser." are used. Loading is non-strict, but
    any missing or unexpected keys are reported so that a checkpoint trained
    with a different option set is not silently accepted.

    Args:
        model: module whose load_state_dict is called with the extracted weights.
        checkpoint_path: path to the checkpoint file.

    Returns:
        True if weights were found and loaded, False if the file does not
        exist or holds no entries under either prefix.
    """
    if not os.path.exists(checkpoint_path):
        print(f"Warning: Checkpoint not found: {checkpoint_path}")
        return False

    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source. Depending on the PyTorch version, the default
    # `weights_only` setting may reject checkpoints holding non-tensor objects.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Handle Lightning checkpoint format (weights nested under "state_dict")
    state_dict = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

    # EMA weights are preferred; fall back to the regular denoiser.
    model_state = _strip_prefix(state_dict, "ema_denoiser.")
    if not model_state:
        model_state = _strip_prefix(state_dict, "denoiser.")

    if not model_state:
        print(f"Warning: Could not find model weights in checkpoint")
        return False

    result = model.load_state_dict(model_state, strict=False)
    print(f"Loaded {len(model_state)} parameters from {checkpoint_path}")

    # strict=False hides mismatches; surface them so a wrong checkpoint for
    # this model variant is visible in the notebook output.
    missing = list(result.missing_keys)
    unexpected = list(result.unexpected_keys)
    if missing:
        print(f"  Warning: {len(missing)} missing keys (left at initial values), e.g. {missing[:5]}")
    if unexpected:
        print(f"  Warning: {len(unexpected)} unexpected keys (ignored), e.g. {unexpected[:5]}")

    return True


def build_sampler(scheduler, num_steps=200, guidance=2.0):
    """Create an EulerSampler for the given scheduler.

    Args:
        scheduler: scheduler instance handed to the sampler.
        num_steps: number of sampling steps.
        guidance: guidance scale; the guidance interval is fixed to [0.0, 1.0].
    """
    sampler_kwargs = dict(
        num_steps=num_steps,
        guidance=guidance,
        guidance_interval_min=0.0,
        guidance_interval_max=1.0,
        scheduler=scheduler,
        w_scheduler=LinearScheduler(),
        guidance_fn=simple_guidance_fn,
        step_fn=ode_step_fn,
    )
    return EulerSampler(**sampler_kwargs)


print("Model loading functions defined!")

In [None]:
# Draw the fixed evaluation subset used by every experiment below.
test_images, test_labels = load_cifar10_test(
    num_samples=EVAL_BATCH_SIZE * NUM_EVAL_BATCHES,
)
print(f"Loaded {len(test_images)} test images")

# Preview: the first 16 test images in a single row.
fig, ax = plt.subplots(figsize=(12, 3))
ax.imshow(tensor_to_image(test_images[:16], nrow=16))
ax.set_title("Test Images (first 16)")
ax.set_axis_off()
fig.tight_layout()
plt.show()

## 4. Load All Model Variants

In [None]:
# Define model variants: which of the two options each model was built with.
MODEL_VARIANTS = {
    "A+B": {"option_a": True, "option_b": True},
    "A_only": {"option_a": True, "option_b": False},
    "B_only": {"option_a": False, "option_b": True},
    "baseline": {"option_a": False, "option_b": False},
}

# Build conditioner
conditioner = LabelConditioner(num_classes=MODEL_CONFIG["num_classes"]).to(DEVICE)
conditioner.eval()

# Build VAE (identity for pixel space)
vae = PixelAE(scale=1.0)

# Scheduler
scheduler = LinearScheduler()

# Load models. A variant is registered only if its checkpoint exists AND
# weights were actually loaded from it; otherwise it is skipped, so a
# randomly initialised model can never be evaluated by mistake.
models = {}
for name, config in MODEL_VARIANTS.items():
    print(f"\nLoading {name}...")

    ckpt_path = CHECKPOINT_PATHS.get(name)
    if not ckpt_path or not os.path.exists(ckpt_path):
        print(f"  Checkpoint not found, skipping.")
        continue

    model = build_model(**config)
    # load_checkpoint returns False when the file holds no denoiser weights.
    if not load_checkpoint(model, ckpt_path):
        print("  No weights could be loaded, skipping.")
        continue

    model = model.to(DEVICE).to(DTYPE)
    model.eval()
    models[name] = model
    print(f"  Loaded successfully!")

print(f"\nLoaded {len(models)} model variants: {list(models.keys())}")

## 5. Sampling Function

In [None]:
@torch.no_grad()
def sample_reconstruction(
    model,
    images,
    labels,
    sparsity=0.2,
    num_steps=200,
    guidance=2.0,
    disable_spatial_bias=False,
    seed=42,
):
    """
    Reconstruct images from a random subset of their pixels.

    Args:
        model: PixNerDiT model
        images: (B, C, H, W) ground truth images
        labels: (B,) class labels
        sparsity: fraction of pixels to condition on
        num_steps: number of sampling steps
        guidance: classifier-free guidance scale
        disable_spatial_bias: whether to disable Option B during sampling
        seed: random seed for reproducibility

    Returns:
        samples: (B, C, H, W) reconstructed images
        cond_mask: (B, 1, H, W) conditioning mask used
    """
    batch, _, height, width = images.shape

    # The mask is drawn first: with a seed it reseeds the global RNG, which
    # also fixes the starting noise drawn further down.
    cond_mask = generate_sparsity_mask(
        batch, height, width, sparsity, images.device, images.dtype, seed=seed
    )

    # Class conditioning and its unconditional counterpart
    condition, uncondition = conditioner(labels)

    sampler = build_sampler(scheduler, num_steps=num_steps, guidance=guidance)

    # Sampling starts from pure noise shaped like the images.
    noise = torch.randn_like(images)

    def model_fn(x, t, y, cond_mask=None):
        # Thin wrapper that pins disable_spatial_bias for every model call.
        return model(x, t, y, cond_mask=cond_mask, disable_spatial_bias=disable_spatial_bias)

    samples = sampler(
        model_fn,
        noise,
        condition,
        uncondition,
        cond_mask=cond_mask,
        x_cond=images,
    )

    return samples, cond_mask


print("Sampling function defined!")

## 6. Ablation: Options A/B Comparison

In [None]:
def run_ablation_comparison(
    models_dict,
    images,
    labels,
    sparsity=0.2,
    num_steps=200,
    guidance=2.0,
):
    """Reconstruct the same batch with every model variant and score it.

    Returns:
        dict mapping variant name -> {"samples", "cond_mask", "metrics"},
        with tensors moved to CPU.
    """
    images = images.to(DEVICE).to(DTYPE)
    labels = labels.to(DEVICE)

    results = {}
    for name, model in tqdm(models_dict.items(), desc="Models"):
        samples, cond_mask = sample_reconstruction(
            model,
            images,
            labels,
            sparsity=sparsity,
            num_steps=num_steps,
            guidance=guidance,
        )

        # Score only the pixels the model was NOT given as conditioning.
        metrics = compute_metrics(samples, images, mask=1.0 - cond_mask)

        results[name] = dict(
            samples=samples.cpu(),
            cond_mask=cond_mask.cpu(),
            metrics=metrics,
        )

        print(f"{name}: PSNR={metrics['psnr']:.2f} dB, MSE={metrics['mse']:.4f}")

    return results


# Run ablation if models are loaded
if not models:
    print("No models loaded - please check checkpoint paths")
else:
    print("\n" + "=" * 60)
    print("Options A/B Ablation (sparsity=0.2, steps=200, guidance=2.0)")
    print("=" * 60)

    ablation_results = run_ablation_comparison(
        models,
        test_images[:EVAL_BATCH_SIZE],
        test_labels[:EVAL_BATCH_SIZE],
        sparsity=0.2,
        num_steps=200,
        guidance=2.0,
    )

In [None]:
def visualize_ablation_results(results, gt_images, num_show=8):
    """Plot ground truth, conditioning input and each variant's reconstruction.

    One row per entry; the figure is also saved as ablation_options_ab.png.
    """
    n_rows = len(results) + 2
    fig, axes = plt.subplots(n_rows, 1, figsize=(16, 3 * n_rows))
    gt_ax, mask_ax, *variant_axes = axes

    shown_gt = gt_images[:num_show]

    # Row 0: ground truth
    gt_ax.imshow(tensor_to_image(shown_gt, nrow=num_show))
    gt_ax.set_title("Ground Truth", fontsize=14)
    gt_ax.set_axis_off()

    # Row 1: observed pixels, taken from the first result's mask; every
    # pixel that was not given to the model is painted mid-gray.
    mask = list(results.values())[0]["cond_mask"][:num_show]
    cond_vis = shown_gt * mask + (1 - mask) * 0.5
    mask_ax.imshow(tensor_to_image(cond_vis, nrow=num_show))
    mask_ax.set_title(f"Conditioning Mask (sparsity={mask.mean():.1%})", fontsize=14)
    mask_ax.set_axis_off()

    # Remaining rows: one per model variant
    for ax, (name, data) in zip(variant_axes, results.items()):
        ax.imshow(tensor_to_image(data["samples"][:num_show], nrow=num_show))
        ax.set_title(f"{name}: PSNR={data['metrics']['psnr']:.2f} dB", fontsize=14)
        ax.set_axis_off()

    fig.tight_layout()
    fig.savefig("ablation_options_ab.png", dpi=150, bbox_inches="tight")
    plt.show()


if models and 'ablation_results' in dir():
    visualize_ablation_results(ablation_results, test_images[:EVAL_BATCH_SIZE])

## 7. Ablation: Sampling Steps

In [None]:
def run_steps_ablation(model, images, labels, steps_list=[50, 100, 200, 500], sparsity=0.2, guidance=2.0):
    """Reconstruct one batch with a single model at several sampler step counts.

    Returns:
        dict mapping step count -> {"samples", "cond_mask", "metrics"},
        with tensors moved to CPU.
    """
    images = images.to(DEVICE).to(DTYPE)
    labels = labels.to(DEVICE)

    results = {}
    for num_steps in tqdm(steps_list, desc="Steps"):
        samples, cond_mask = sample_reconstruction(
            model,
            images,
            labels,
            sparsity=sparsity,
            num_steps=num_steps,
            guidance=guidance,
        )

        # Score only the pixels that were not provided as conditioning.
        metrics = compute_metrics(samples, images, mask=1.0 - cond_mask)

        results[num_steps] = dict(
            samples=samples.cpu(),
            cond_mask=cond_mask.cpu(),
            metrics=metrics,
        )

        print(f"Steps={num_steps}: PSNR={metrics['psnr']:.2f} dB")

    return results


# The sweep is only run for the A+B model.
if "A+B" in models:
    print("\n" + "=" * 60)
    print("Sampling Steps Ablation (A+B model)")
    print("=" * 60)

    steps_results = run_steps_ablation(
        models["A+B"],
        test_images[:EVAL_BATCH_SIZE],
        test_labels[:EVAL_BATCH_SIZE],
        steps_list=[50, 100, 200, 500],
    )

In [None]:
def visualize_steps_ablation(results, gt_images, num_show=8):
    """Show reconstructions per step count, then plot PSNR against steps.

    Saves ablation_steps.png (image grid) and psnr_vs_steps.png (curve).
    """
    ordered = sorted(results.items())
    n_rows = len(ordered) + 1

    grid_fig, axes = plt.subplots(n_rows, 1, figsize=(16, 3 * n_rows))

    # Row 0: ground truth
    axes[0].imshow(tensor_to_image(gt_images[:num_show], nrow=num_show))
    axes[0].set_title("Ground Truth", fontsize=14)
    axes[0].set_axis_off()

    # One row per step count, in ascending order
    for ax, (num_steps, data) in zip(axes[1:], ordered):
        ax.imshow(tensor_to_image(data["samples"][:num_show], nrow=num_show))
        ax.set_title(
            f"Steps={num_steps}: PSNR={data['metrics']['psnr']:.2f} dB",
            fontsize=14,
        )
        ax.set_axis_off()

    grid_fig.tight_layout()
    grid_fig.savefig("ablation_steps.png", dpi=150, bbox_inches="tight")
    plt.show()

    # Plot PSNR vs steps
    curve_fig, ax = plt.subplots(figsize=(8, 5))
    steps = [num_steps for num_steps, _ in ordered]
    psnrs = [data["metrics"]["psnr"] for _, data in ordered]
    ax.plot(steps, psnrs, "o-", markersize=10, linewidth=2)
    ax.set_xlabel("Sampling Steps", fontsize=12)
    ax.set_ylabel("PSNR (dB)", fontsize=12)
    ax.set_title("PSNR vs Sampling Steps", fontsize=14)
    ax.grid(True, alpha=0.3)
    curve_fig.tight_layout()
    curve_fig.savefig("psnr_vs_steps.png", dpi=150, bbox_inches="tight")
    plt.show()


if 'steps_results' in dir():
    visualize_steps_ablation(steps_results, test_images[:EVAL_BATCH_SIZE])

## 8. Ablation: Guidance Scale

In [None]:
def run_guidance_ablation(model, images, labels, guidance_list=[1.0, 1.5, 2.0, 3.0, 4.0], sparsity=0.2, num_steps=200):
    """Reconstruct one batch with a single model at several guidance scales.

    Returns:
        dict mapping guidance scale -> {"samples", "cond_mask", "metrics"},
        with tensors moved to CPU.
    """
    images = images.to(DEVICE).to(DTYPE)
    labels = labels.to(DEVICE)

    results = {}
    for guidance in tqdm(guidance_list, desc="Guidance"):
        samples, cond_mask = sample_reconstruction(
            model,
            images,
            labels,
            sparsity=sparsity,
            num_steps=num_steps,
            guidance=guidance,
        )

        # Score only the pixels that were not provided as conditioning.
        metrics = compute_metrics(samples, images, mask=1.0 - cond_mask)

        results[guidance] = dict(
            samples=samples.cpu(),
            cond_mask=cond_mask.cpu(),
            metrics=metrics,
        )

        print(f"Guidance={guidance}: PSNR={metrics['psnr']:.2f} dB")

    return results


# The sweep is only run for the A+B model.
if "A+B" in models:
    print("\n" + "=" * 60)
    print("Guidance Scale Ablation (A+B model)")
    print("=" * 60)

    guidance_results = run_guidance_ablation(
        models["A+B"],
        test_images[:EVAL_BATCH_SIZE],
        test_labels[:EVAL_BATCH_SIZE],
        guidance_list=[1.0, 1.5, 2.0, 3.0, 4.0],
    )

In [None]:
def visualize_guidance_ablation(results, gt_images, num_show=8):
    """Show reconstructions per guidance scale, then plot PSNR against guidance.

    Saves ablation_guidance.png (image grid) and psnr_vs_guidance.png (curve).
    """
    ordered = sorted(results.items())
    n_rows = len(ordered) + 1

    grid_fig, axes = plt.subplots(n_rows, 1, figsize=(16, 3 * n_rows))

    # Row 0: ground truth
    axes[0].imshow(tensor_to_image(gt_images[:num_show], nrow=num_show))
    axes[0].set_title("Ground Truth", fontsize=14)
    axes[0].set_axis_off()

    # One row per guidance scale, in ascending order
    for ax, (guidance, data) in zip(axes[1:], ordered):
        ax.imshow(tensor_to_image(data["samples"][:num_show], nrow=num_show))
        ax.set_title(
            f"Guidance={guidance}: PSNR={data['metrics']['psnr']:.2f} dB",
            fontsize=14,
        )
        ax.set_axis_off()

    grid_fig.tight_layout()
    grid_fig.savefig("ablation_guidance.png", dpi=150, bbox_inches="tight")
    plt.show()

    # Plot PSNR vs guidance
    curve_fig, ax = plt.subplots(figsize=(8, 5))
    guidance_vals = [guidance for guidance, _ in ordered]
    psnrs = [data["metrics"]["psnr"] for _, data in ordered]
    ax.plot(guidance_vals, psnrs, "o-", markersize=10, linewidth=2, color="orange")
    ax.set_xlabel("Guidance Scale", fontsize=12)
    ax.set_ylabel("PSNR (dB)", fontsize=12)
    ax.set_title("PSNR vs Guidance Scale", fontsize=14)
    ax.grid(True, alpha=0.3)
    curve_fig.tight_layout()
    curve_fig.savefig("psnr_vs_guidance.png", dpi=150, bbox_inches="tight")
    plt.show()


if 'guidance_results' in dir():
    visualize_guidance_ablation(guidance_results, test_images[:EVAL_BATCH_SIZE])

## 9. Ablation: Sparsity Rate

In [None]:
def run_sparsity_ablation(model, images, labels, sparsity_list=[0.1, 0.2, 0.4, 0.6, 0.8], num_steps=200, guidance=2.0):
    """Reconstruct one batch with a single model at several observed-pixel fractions.

    Returns:
        dict mapping sparsity -> {"samples", "cond_mask", "metrics"},
        with tensors moved to CPU.
    """
    images = images.to(DEVICE).to(DTYPE)
    labels = labels.to(DEVICE)

    results = {}
    for sparsity in tqdm(sparsity_list, desc="Sparsity"):
        samples, cond_mask = sample_reconstruction(
            model,
            images,
            labels,
            sparsity=sparsity,
            num_steps=num_steps,
            guidance=guidance,
        )

        # Score only the pixels that were not provided as conditioning.
        metrics = compute_metrics(samples, images, mask=1.0 - cond_mask)

        results[sparsity] = dict(
            samples=samples.cpu(),
            cond_mask=cond_mask.cpu(),
            metrics=metrics,
        )

        print(f"Sparsity={sparsity:.0%}: PSNR={metrics['psnr']:.2f} dB")

    return results


# The sweep is only run for the A+B model.
if "A+B" in models:
    print("\n" + "=" * 60)
    print("Sparsity Rate Ablation (A+B model, trained with 20% input)")
    print("=" * 60)

    sparsity_results = run_sparsity_ablation(
        models["A+B"],
        test_images[:EVAL_BATCH_SIZE],
        test_labels[:EVAL_BATCH_SIZE],
        sparsity_list=[0.05, 0.1, 0.2, 0.4, 0.6, 0.8],
    )

In [None]:
def visualize_sparsity_ablation(results, gt_images, num_show=8):
    """Show input/output pairs per sparsity level, then plot PSNR against sparsity.

    Saves ablation_sparsity.png (image grid) and psnr_vs_sparsity.png (curve).
    """
    ordered = sorted(results.items())
    n_rows = len(ordered) + 1
    shown_gt = gt_images[:num_show]

    grid_fig, axes = plt.subplots(n_rows, 2, figsize=(20, 3 * n_rows))

    # Row 0: ground truth on the left, right cell left empty
    axes[0, 0].imshow(tensor_to_image(shown_gt, nrow=num_show))
    axes[0, 0].set_title("Ground Truth", fontsize=14)
    axes[0, 0].set_axis_off()
    axes[0, 1].set_axis_off()

    # One row per sparsity level: observed input (left), reconstruction (right)
    for (input_ax, output_ax), (sparsity, data) in zip(axes[1:], ordered):
        cond_mask = data["cond_mask"][:num_show]

        # Pixels that were not observed are painted mid-gray.
        cond_vis = shown_gt * cond_mask + (1 - cond_mask) * 0.5
        input_ax.imshow(tensor_to_image(cond_vis, nrow=num_show))
        input_ax.set_title(f"Input ({sparsity:.0%} observed)", fontsize=14)
        input_ax.set_axis_off()

        output_ax.imshow(tensor_to_image(data["samples"][:num_show], nrow=num_show))
        output_ax.set_title(f"Output: PSNR={data['metrics']['psnr']:.2f} dB", fontsize=14)
        output_ax.set_axis_off()

    grid_fig.tight_layout()
    grid_fig.savefig("ablation_sparsity.png", dpi=150, bbox_inches="tight")
    plt.show()

    # Plot PSNR vs sparsity (x axis in percent)
    curve_fig, ax = plt.subplots(figsize=(8, 5))
    percents = [sparsity * 100 for sparsity, _ in ordered]
    psnrs = [data["metrics"]["psnr"] for _, data in ordered]
    ax.plot(percents, psnrs, "o-", markersize=10, linewidth=2, color="green")
    ax.axvline(x=20, color="red", linestyle="--", label="Training sparsity (20%)")
    ax.set_xlabel("Conditioning Sparsity (%)", fontsize=12)
    ax.set_ylabel("PSNR (dB)", fontsize=12)
    ax.set_title("PSNR vs Conditioning Sparsity\n(Model trained with 20% input)", fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    curve_fig.tight_layout()
    curve_fig.savefig("psnr_vs_sparsity.png", dpi=150, bbox_inches="tight")
    plt.show()


if 'sparsity_results' in dir():
    visualize_sparsity_ablation(sparsity_results, test_images[:EVAL_BATCH_SIZE])

## 10. Super-Resolution Evaluation (4x)

In [None]:
@torch.no_grad()
def sample_superres(
    model,
    images,
    labels,
    scale=4,
    sparsity=0.2,
    num_steps=200,
    guidance=2.0,
    disable_spatial_bias=False,
    seed=42,
):
    """
    Run super-resolution: lift sparse LR conditioning onto an HR grid and
    sample at the higher resolution.

    A sparsity mask is drawn at the input resolution. The mask and the input
    pixels are then written to every `scale`-th row and column of an otherwise
    zero HR grid. The model's decoder patch scaling is set to `scale` for the
    duration of sampling and is always restored afterwards, even if sampling
    raises, so the shared model object is not left modified.

    Args:
        model: model exposing decoder_patch_scaling_h / decoder_patch_scaling_w
        images: (B, C, H, W) low-resolution images
        labels: class labels passed to the conditioner
        scale: integer upscaling factor
        sparsity: fraction of LR pixels to condition on
        num_steps: number of sampling steps
        guidance: guidance scale
        disable_spatial_bias: forwarded to the model on every call
        seed: seed for the conditioning mask (reseeds the global torch RNG)

    Returns:
        samples: sampler output at (H * scale, W * scale)
        cond_mask_hr: (B, 1, H * scale, W * scale) conditioning mask
        x_cond_hr: (B, C, H * scale, W * scale) conditioning pixels, zero elsewhere
    """
    B, C, H, W = images.shape
    device = images.device
    dtype = images.dtype

    H_hr, W_hr = H * scale, W * scale

    # Generate LR conditioning mask
    cond_mask_lr = generate_sparsity_mask(B, H, W, sparsity, device, dtype, seed=seed)

    # Lift to HR (place LR pixels on HR grid)
    cond_mask_hr = torch.zeros((B, 1, H_hr, W_hr), device=device, dtype=dtype)
    cond_mask_hr[:, :, ::scale, ::scale] = cond_mask_lr

    x_cond_hr = torch.zeros((B, C, H_hr, W_hr), device=device, dtype=dtype)
    x_cond_hr[:, :, ::scale, ::scale] = images

    # Remember the current decoder scale so it can be restored afterwards.
    old_scale_h = model.decoder_patch_scaling_h
    old_scale_w = model.decoder_patch_scaling_w
    model.decoder_patch_scaling_h = scale
    model.decoder_patch_scaling_w = scale

    try:
        # Get conditioning
        condition, uncondition = conditioner(labels)

        # Build sampler
        sampler = build_sampler(scheduler, num_steps=num_steps, guidance=guidance)

        # Start from noise at the HR resolution
        noise = torch.randn((B, C, H_hr, W_hr), device=device, dtype=dtype)

        def model_fn(x, t, y, cond_mask=None):
            return model(x, t, y, cond_mask=cond_mask, disable_spatial_bias=disable_spatial_bias)

        samples = sampler(
            model_fn,
            noise,
            condition,
            uncondition,
            cond_mask=cond_mask_hr,
            x_cond=x_cond_hr,
        )
    finally:
        # Restore scale even on failure (e.g. out-of-memory at HR), so later
        # cells do not run the model with a stale scaling factor.
        model.decoder_patch_scaling_h = old_scale_h
        model.decoder_patch_scaling_w = old_scale_w

    return samples, cond_mask_hr, x_cond_hr


print("Super-resolution function defined!")

In [None]:
def run_superres_comparison(models_dict, images, labels, scale=4, sparsity=0.2, num_steps=200, guidance=2.0):
    """Run super-resolution sampling with every model variant.

    Variants whose name contains "B" are sampled a second time with the
    spatial bias disabled; that run is stored under "<name>_no_bias" and
    reuses the mask and conditioning tensors from the first run.

    Returns:
        dict mapping name -> {"samples", "cond_mask", "x_cond"}, tensors on CPU.
    """
    images = images.to(DEVICE).to(DTYPE)
    labels = labels.to(DEVICE)

    shared_kwargs = dict(
        scale=scale,
        sparsity=sparsity,
        num_steps=num_steps,
        guidance=guidance,
    )

    results = {}
    for name, model in tqdm(models_dict.items(), desc="SR Models"):
        # Default run: the spatial bias stays enabled.
        samples, cond_mask, x_cond = sample_superres(
            model, images, labels, disable_spatial_bias=False, **shared_kwargs
        )
        results[name] = dict(
            samples=samples.cpu(),
            cond_mask=cond_mask.cpu(),
            x_cond=x_cond.cpu(),
        )

        # Extra run with the bias switched off ("A+B" also contains "B").
        if "B" in name:
            samples_no_bias, _, _ = sample_superres(
                model, images, labels, disable_spatial_bias=True, **shared_kwargs
            )
            results[f"{name}_no_bias"] = dict(
                samples=samples_no_bias.cpu(),
                cond_mask=cond_mask.cpu(),
                x_cond=x_cond.cpu(),
            )

    return results


if models:
    print("\n" + "=" * 60)
    print("Super-Resolution Comparison (4x upscale)")
    print("=" * 60)

    sr_results = run_superres_comparison(
        models,
        test_images[:4],  # Fewer samples for SR (memory)
        test_labels[:4],
        scale=4,
        sparsity=0.2,
        num_steps=200,
        guidance=2.0,
    )

In [None]:
def visualize_superres(results, gt_images, scale=4, num_show=4):
    """Visualize super-resolution results.

    Rows: the low-resolution input, a bicubic upscale for reference, then one
    row per entry in `results`. Resolutions in the titles are read from the
    tensors (height x width) instead of assuming 32x32 inputs. The figure is
    also saved as superres_comparison.png.

    Args:
        results: dict mapping name -> dict holding a "samples" tensor.
        gt_images: low-resolution input images, (N, C, H, W).
        scale: upscaling factor used for the bicubic reference.
        num_show: number of images shown per row.
    """
    num_variants = len(results)
    lr_images = gt_images[:num_show]
    lr_h, lr_w = lr_images.shape[-2:]

    fig, axes = plt.subplots(num_variants + 2, 1, figsize=(16, 4 * (num_variants + 2)))

    # LR input
    axes[0].imshow(tensor_to_image(lr_images, nrow=num_show))
    axes[0].set_title(f"LR Input ({lr_h}x{lr_w})", fontsize=14)
    axes[0].axis("off")

    # Bicubic upscale for reference
    bicubic = F.interpolate(lr_images, scale_factor=scale, mode="bicubic", align_corners=False)
    bic_h, bic_w = bicubic.shape[-2:]
    axes[1].imshow(tensor_to_image(bicubic, nrow=num_show))
    axes[1].set_title(f"Bicubic Upscale ({bic_h}x{bic_w})", fontsize=14)
    axes[1].axis("off")

    # Each variant, labelled with the resolution it was actually sampled at
    for idx, (name, data) in enumerate(results.items()):
        samples = data["samples"][:num_show]
        out_h, out_w = samples.shape[-2:]

        axes[idx + 2].imshow(tensor_to_image(samples, nrow=num_show))
        axes[idx + 2].set_title(f"{name} ({out_h}x{out_w})", fontsize=14)
        axes[idx + 2].axis("off")

    plt.tight_layout()
    plt.savefig("superres_comparison.png", dpi=150, bbox_inches="tight")
    plt.show()


if 'sr_results' in dir():
    visualize_superres(sr_results, test_images[:4], scale=4)

## 11. Summary Statistics

In [None]:
def print_summary():
    """Print summary of all ablation results.

    Results are looked up by name in the module (notebook) namespace, so any
    experiment that was skipped is simply left out. The lookup uses globals():
    a bare dir() inside a function lists the function's local names and would
    never find the result variables.
    """
    namespace = globals()
    ablation = namespace.get("ablation_results")
    steps = namespace.get("steps_results")
    guidance_sweep = namespace.get("guidance_results")
    sparsity_sweep = namespace.get("sparsity_results")

    print("\n" + "="*80)
    print("ABLATION STUDY SUMMARY")
    print("="*80)

    if ablation:
        print("\n1. OPTIONS A/B COMPARISON (sparsity=20%, steps=200, guidance=2.0)")
        print("-" * 60)
        for name, data in ablation.items():
            m = data["metrics"]
            print(f"  {name:15s}: PSNR = {m['psnr']:.2f} dB, MSE = {m['mse']:.4f}")

    if steps:
        print("\n2. SAMPLING STEPS (A+B model)")
        print("-" * 60)
        for num_steps, data in sorted(steps.items()):
            m = data["metrics"]
            print(f"  Steps={num_steps:4d}: PSNR = {m['psnr']:.2f} dB")

    if guidance_sweep:
        print("\n3. GUIDANCE SCALE (A+B model)")
        print("-" * 60)
        for guidance, data in sorted(guidance_sweep.items()):
            m = data["metrics"]
            print(f"  Guidance={guidance:.1f}: PSNR = {m['psnr']:.2f} dB")

    if sparsity_sweep:
        print("\n4. SPARSITY RATE (A+B model, trained with 20%)")
        print("-" * 60)
        for sparsity, data in sorted(sparsity_sweep.items()):
            m = data["metrics"]
            # Flag the row that matches the 20% training sparsity.
            marker = " <-- training" if abs(sparsity - 0.2) < 0.01 else ""
            print(f"  Sparsity={sparsity:5.0%}: PSNR = {m['psnr']:.2f} dB{marker}")

    print("\n" + "="*80)


print_summary()

## 12. Detailed Single-Image Analysis

In [None]:
def detailed_single_image_analysis(model, image, label, sparsities=[0.1, 0.2, 0.4], num_steps=200, guidance=2.0):
    """
    Detailed analysis of a single image across sparsity levels.
    Shows: GT, conditioning mask, reconstruction, error map.

    One row is drawn per sparsity level; the figure is also saved as
    detailed_analysis.png.

    Args:
        model: model passed to sample_reconstruction
        image: a single image without batch dimension
        label: its class label as a 0-dim tensor
        sparsities: fractions of observed pixels to evaluate
        num_steps: number of sampling steps
        guidance: guidance scale
    """
    image = image.unsqueeze(0).to(DEVICE).to(DTYPE)
    label = label.unsqueeze(0).to(DEVICE)

    # squeeze=False keeps `axes` two-dimensional even for a single sparsity
    # level, so the axes[row, col] indexing below always works.
    fig, axes = plt.subplots(
        len(sparsities), 4, figsize=(16, 4 * len(sparsities)), squeeze=False
    )

    for idx, sparsity in enumerate(sparsities):
        samples, cond_mask = sample_reconstruction(
            model, image, label,
            sparsity=sparsity,
            num_steps=num_steps,
            guidance=guidance,
        )

        # Cast to float32 on CPU: with DTYPE == bfloat16 the later .numpy()
        # call would fail, because NumPy has no bfloat16 dtype.
        gt = image[0].cpu().float()
        recon = samples[0].cpu().float()
        mask = cond_mask[0].cpu().float()

        # Conditioned input visualization (unobserved pixels in mid-gray)
        cond_vis = gt * mask + (1 - mask) * 0.5

        # Error map (absolute difference)
        error = (recon - gt).abs().mean(dim=0)  # Average over channels

        # Plot
        axes[idx, 0].imshow(tensor_to_image(gt.unsqueeze(0), nrow=1))
        axes[idx, 0].set_title("Ground Truth")
        axes[idx, 0].axis("off")

        axes[idx, 1].imshow(tensor_to_image(cond_vis.unsqueeze(0), nrow=1))
        axes[idx, 1].set_title(f"Input ({sparsity:.0%} observed)")
        axes[idx, 1].axis("off")

        axes[idx, 2].imshow(tensor_to_image(recon.unsqueeze(0), nrow=1))
        axes[idx, 2].set_title("Reconstruction")
        axes[idx, 2].axis("off")

        im = axes[idx, 3].imshow(error.numpy(), cmap="hot", vmin=0, vmax=0.5)
        axes[idx, 3].set_title(f"Error Map (MAE={error.mean():.3f})")
        axes[idx, 3].axis("off")
        plt.colorbar(im, ax=axes[idx, 3], fraction=0.046)

    plt.tight_layout()
    plt.savefig("detailed_analysis.png", dpi=150, bbox_inches="tight")
    plt.show()


if "A+B" in models:
    print("\nDetailed Single-Image Analysis (A+B model)")
    print("="*60)
    detailed_single_image_analysis(
        models["A+B"],
        test_images[0],
        test_labels[0],
        sparsities=[0.1, 0.2, 0.4, 0.6],
    )

In [None]:
# Closing banner: list every figure that the cells above may have written.
print("\nEvaluation notebook complete!")
print("\nGenerated files:")
print(
    "\n".join(
        f"  - {filename}"
        for filename in (
            "ablation_options_ab.png",
            "ablation_steps.png",
            "psnr_vs_steps.png",
            "ablation_guidance.png",
            "psnr_vs_guidance.png",
            "ablation_sparsity.png",
            "psnr_vs_sparsity.png",
            "superres_comparison.png",
            "detailed_analysis.png",
        )
    )
)