# Conditional Sparse Reconstruction - Evaluation

Comprehensive evaluation with:
1. **Conditional Sampling**: Generate full reconstructions from sparse context
2. **MSE Evaluation**: Mean Squared Error on reconstructed images
3. **CRPS**: Continuous Ranked Probability Score for ensemble evaluation
4. **Visualization**: Compare context → reconstruction → ground truth

## CRPS Explanation

**Continuous Ranked Probability Score** generalizes MAE to probabilistic forecasts:

$$\text{CRPS}(F, y) = \int_{-\infty}^{\infty} [F(x) - H(x - y)]^2 dx$$

Where:
- $F$ = CDF of the forecast distribution (for an ensemble, the empirical CDF of its members)
- $y$ = Ground truth observation
- $H$ = Heaviside step function

**For ensemble of M samples**:
$$\text{CRPS} = \frac{1}{M} \sum_{i=1}^M |x_i - y| - \frac{1}{2M^2} \sum_{i=1}^M \sum_{j=1}^M |x_i - x_j|$$

**Key properties**:
- Rewards **sharpness** (a narrow distribution, as long as it is centered near the truth)
- Rewards **calibration** (GT should be plausible from ensemble)
- Collapses to MAE for deterministic forecasts (M=1)
- Lower is better

In [None]:
# Setup: stretch the classic-notebook `.container` to 98% of the window width
# so wide figures and long log lines fit.
from IPython.display import HTML, display

_wide_container_css = "<style>.container { width:98% !important; }</style>"
display(HTML(_wide_container_css))

In [None]:
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Change to parent directory
os.chdir('..')
sys.path.insert(0, '.')

%matplotlib inline

print(f"Working directory: {os.getcwd()}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Import utilities
from utils.sparse_datasets_fixed import (
    FixedSparseMaskDataset,
    create_context_image_batched,
    create_sparse_mask_image
)
from utils.visualize import get_grid_image
from utils.utils import Writer, count_parameters_in_M, load_checkpoint

# Import DDO components
from lib.diffusion import BlurringDiffusion, DenoisingDiffusion, diffuse
from lib.models.fourier_unet import FNOUNet2d
from lib.conditional_model import ConditionalDDOModel

# NOTE(review): tqdm.auto.tqdm was already imported in the previous cell, so
# this re-import is redundant (but harmless).
from tqdm.auto import tqdm

# Reaching this line means every import in this cell succeeded.
print("All imports successful!")

## 1. Configuration & Model Loading

In [None]:
import argparse

# All evaluation settings in one namespace. Sparsity, architecture and
# diffusion values must match the ones the checkpoint was trained with.
args = argparse.Namespace(
    # Paths
    exp_path='./experiments/conditional_sparse_recon',  # directory holding the trained model
    data='./data',
    seed=1,
    # Dataset
    dataset='cifar10',
    train_img_height=32,
    input_dim=3,
    coord_dim=2,
    # Sparse settings (must match training)
    context_ratio=0.1,
    query_ratio=0.1,
    mask_seed=42,
    # Model architecture (must match training)
    model='fnounet2d',
    ch=64,
    ch_mult=[1, 2, 2],
    num_res_blocks=2,
    modes=16,
    dropout=0.1,
    norm='group_norm',
    use_pos=True,
    use_pointwise_op=True,
    context_feature_dim=32,
    # Diffusion settings
    ns_method='vp_cosine',
    timestep_sampler='low_discrepancy',
    disp_method='sine',
    sigma_blur_min=0.05,
    sigma_blur_max=0.25,
    gp_type='exponential',
    gp_exponent=2.0,
    gp_length_scale=0.05,
    gp_sigma=1.0,
    # Sampling
    num_steps=250,
    sampler='denoise',
    s_min=0.0001,
    use_clip=False,
    weight_method=None,
    # Evaluation settings
    num_eval_samples=1000,  # how many test images to evaluate
    num_ensemble=10,        # samples drawn per image (used for CRPS)
    eval_batch_size=16,     # images per evaluation batch
    # Misc
    checkpoint_file='checkpoint.pt',
    ema_decay=0.999,
)

_banner = "=" * 60
print(_banner)
print("Evaluation Configuration")
print(_banner)
print(f"Model path: {args.exp_path}")
print(f"Context ratio: {args.context_ratio*100:.0f}%")
print(f"Query ratio: {args.query_ratio*100:.0f}%")
print(f"Num eval samples: {args.num_eval_samples}")
print(f"Ensemble size (for CRPS): {args.num_ensemble}")
print(_banner)

### Load Test Dataset

In [None]:
# CIFAR-10 test split; ToTensor converts images to float tensors in [0, 1].
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = torchvision.datasets.CIFAR10(
    root=args.data,
    train=False,
    download=True,
    transform=transform,
)
print(f"Test dataset: {len(test_dataset)} images")

# Attach fixed sparse context/query masks. The mask seed is offset from
# args.mask_seed so evaluation masks differ from the training masks.
sparse_test_dataset = FixedSparseMaskDataset(
    dataset=test_dataset,
    context_ratio=args.context_ratio,
    query_ratio=args.query_ratio,
    seed=args.mask_seed + 1000,
)
print(sparse_test_dataset)

# Unshuffled loader over the masked test set.
# NOTE(review): the evaluation cells visible here index sparse_test_dataset
# directly, so this loader does not appear to be consumed.
test_loader = torch.utils.data.DataLoader(
    sparse_test_dataset,
    batch_size=args.eval_batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)
print(f"\nTest loader: {len(test_loader)} batches")

### Load Trained Model

In [None]:
def get_mgrid(dim, img_height):
    """
    Build normalized pixel coordinates for a square image.

    Each axis holds ``k / img_height`` for ``k = 0 .. img_height - 1``.

    Args:
        dim: Coordinate dimensionality. Only ``2`` builds a grid; any other
            value returns the 1-D axis tensor unchanged.
        img_height: Side length of the (square) image.

    Returns:
        For ``dim == 2``: a ``(1, 2, img_height, img_height)`` tensor where
        channel 0 varies along rows and channel 1 varies along columns.
        Otherwise: the ``(img_height,)`` axis tensor.
    """
    axis = torch.linspace(0, img_height - 1, img_height) / img_height
    if dim != 2:
        return axis

    square = (img_height, img_height)
    row_coords = axis.view(img_height, 1).expand(*square)  # [i, j] -> axis[i]
    col_coords = axis.view(1, img_height).expand(*square)  # [i, j] -> axis[j]
    return torch.stack((row_coords, col_coords), dim=0).unsqueeze(0)


def load_trained_model(args):
    """
    Rebuild the conditional DDO model described by ``args`` and load its weights.

    The model is reconstructed from the same hyper-parameters used in
    training: a BlurringDiffusion forward process, an FNOUNet2d backbone
    wrapped in ConditionalDDOModel, and a DenoisingDiffusion wrapper moved to
    CUDA. Weights are read from ``os.path.join(args.exp_path,
    args.checkpoint_file)``.

    Args:
        args: Namespace carrying the architecture, diffusion, GP and path
            settings read below.

    Returns:
        The DenoisingDiffusion wrapper, switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    # Sub-configs consumed by the forward (blurring) diffusion process.
    gp_config = argparse.Namespace(
        device='cuda',
        exponent=args.gp_exponent,
        length_scale=args.gp_length_scale,
        sigma=args.gp_sigma,
    )
    disp_config = argparse.Namespace(
        sigma_blur_min=args.sigma_blur_min,
        sigma_blur_max=args.sigma_blur_max,
    )
    inf_sde = BlurringDiffusion(
        dim=args.coord_dim,
        ch=args.input_dim,
        ns_method=args.ns_method,
        disp_method=args.disp_method,
        disp_config=disp_config,
        gp_type=args.gp_type,
        gp_config=gp_config,
    )

    # Network: FNO-UNet backbone wrapped by the sparse-context conditioning layer.
    backbone = FNOUNet2d(
        modes_height=args.modes,
        modes_width=args.modes,
        in_channels=args.input_dim,
        in_height=args.train_img_height,
        ch=args.ch,
        ch_mult=tuple(args.ch_mult),
        num_res_blocks=args.num_res_blocks,
        dropout=args.dropout,
        norm=args.norm,
        use_pos=args.use_pos,
        use_pointwise_op=args.use_pointwise_op,
    )
    conditional_net = ConditionalDDOModel(
        backbone,
        input_dim=args.input_dim,
        context_feature_dim=args.context_feature_dim,
    )

    gen_sde = DenoisingDiffusion(
        inf_sde,
        model=conditional_net,
        timestep_sampler=args.timestep_sampler,
        use_clip=args.use_clip,
        weight_method=args.weight_method,
    ).cuda()

    ckpt_path = os.path.join(args.exp_path, args.checkpoint_file)
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"No checkpoint found at {ckpt_path}")

    print(f"Loading checkpoint from {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location='cuda')
    gen_sde.load_state_dict(checkpoint['gen_sde_state_dict'])

    # EMA weights, when present, overwrite the raw weights loaded just above.
    # NOTE(review): a standard torch optimizer state_dict only has 'state' and
    # 'param_groups' at its top level. This branch therefore fires only if the
    # training script stored a module-style state dict under an 'ema' key.
    # Confirm against the checkpoint writer; otherwise EMA weights are silently
    # skipped here.
    has_ema = (
        'gen_sde_optimizer' in checkpoint
        and 'ema' in checkpoint['gen_sde_optimizer']
    )
    if has_ema:
        print("Loading EMA weights")
        gen_sde.load_state_dict(checkpoint['gen_sde_optimizer']['ema'])

    print(f"Loaded model from iteration {checkpoint.get('global_step', 0)}")

    gen_sde.eval()
    return gen_sde


gen_sde = load_trained_model(args)  # DenoisingDiffusion wrapper, already in eval mode

# Size of the wrapped network (gen_sde._model), reported in millions.
num_params = count_parameters_in_M(gen_sde._model)
print(f"\nModel parameters: {num_params:.2f}M")

## 2. Conditional Sampling Functions

In [None]:
def sample_conditional(gen_sde, context_image, v, num_steps=250, sampler='denoise',
                       disable_tqdm=False):
    """
    Generate one reconstruction per batch element, conditioned on sparse context.

    Args:
        gen_sde: Trained conditional diffusion model.
        context_image: (B, C, H, W) sparse observations.
        v: (B, coord_dim, H, W) coordinate grid.
        num_steps: Number of reverse-diffusion steps.
        sampler: Sampling method name, forwarded to ``diffuse``.
        disable_tqdm: Forwarded to ``diffuse``. Set True to silence its
            per-step progress bar, e.g. when this is called many times in a loop.

    Returns:
        The last element of what ``diffuse`` returns (the final sample,
        expected shape (B, C, H, W)).
    """
    # Initial state for the reverse process: white noise shaped like the context.
    x_init = torch.randn_like(context_image)

    trajectory = diffuse(
        gen_sde,
        num_steps=num_steps,
        x_0=x_init,
        v=v,
        sampler=sampler,
        context_image=context_image,  # conditioning signal
        disable_tqdm=disable_tqdm,
    )
    # NOTE(review): assumes `diffuse` returns a sequence of states with the
    # final sample last. If it returned a single batched tensor, [-1] would
    # select the last batch element instead.
    return trajectory[-1]


def sample_ensemble_conditional(gen_sde, context_image, v, num_ensemble=10, num_steps=250,
                                sampler='denoise'):
    """
    Generate ``num_ensemble`` independent conditional samples for CRPS evaluation.

    Each member is a separate call to ``sample_conditional``, which starts
    from freshly drawn noise. All members share the same context and
    coordinate grid.

    Args:
        gen_sde: Trained conditional diffusion model.
        context_image: (B, C, H, W) sparse observations.
        v: (B, coord_dim, H, W) coordinate grid.
        num_ensemble: Ensemble size M (must be >= 1).
        num_steps: Reverse-diffusion steps per member.
        sampler: Sampling method forwarded to ``sample_conditional``.
            Previously it was always that function's default.

    Returns:
        Ensemble of shape (num_ensemble, B, C, H, W).

    Raises:
        ValueError: If ``num_ensemble < 1``. An empty ensemble cannot be stacked.
    """
    if num_ensemble < 1:
        raise ValueError(f"num_ensemble must be >= 1, got {num_ensemble}")

    members = [
        sample_conditional(gen_sde, context_image, v, num_steps=num_steps, sampler=sampler)
        for _ in tqdm(range(num_ensemble), desc='Generating ensemble')
    ]
    return torch.stack(members, dim=0)  # (M, B, C, H, W)


print("Sampling functions defined!")

## 3. Evaluation Metrics

In [None]:
def compute_mse(predictions, ground_truth):
    """
    Mean squared error per image and over the batch.

    Args:
        predictions: (B, C, H, W) tensor.
        ground_truth: (B, C, H, W) tensor.

    Returns:
        Tuple ``(per_image, batch_mean)``: a (B,) tensor of MSE averaged over
        channels and pixels of each image, and its scalar mean.
    """
    squared_error = (predictions - ground_truth).pow(2)
    per_image = squared_error.mean(dim=(1, 2, 3))
    return per_image, per_image.mean()


def compute_mae(predictions, ground_truth):
    """
    Mean absolute error per image and over the batch.

    Args:
        predictions: (B, C, H, W) tensor.
        ground_truth: (B, C, H, W) tensor.

    Returns:
        Tuple ``(per_image, batch_mean)``: a (B,) tensor of MAE averaged over
        channels and pixels of each image, and its scalar mean.
    """
    abs_error = torch.abs(predictions - ground_truth)
    per_image = abs_error.mean(dim=(1, 2, 3))
    return per_image, per_image.mean()


def compute_crps_ensemble(ensemble, ground_truth):
    """
    Compute Continuous Ranked Probability Score for ensemble forecasts.

    Args:
        ensemble: (M, B, ...) - M ensemble members, e.g. (M, B, C, H, W)
        ground_truth: (B, ...) - must equal ``ensemble.shape[1:]``

    Returns:
        CRPS per image (B,) and overall mean

    Raises:
        ValueError: if the ensemble is empty or its member shape does not
            match ``ground_truth``.

    Formula:
        CRPS = (1/M) * sum_i |x_i - y| - (1/2M^2) * sum_i sum_j |x_i - x_j|

    Where:
        x_i = ensemble member i
        y = ground truth
        M = ensemble size

    Implementation:
        The pairwise term uses the sorted-ensemble identity
            sum_i sum_j |x_i - x_j| = 2 * sum_k (2k - M - 1) * x_(k),
        where x_(1) <= ... <= x_(M) are the sorted members (k = 1..M).
        This costs one sort over the ensemble axis instead of an M x M
        Python double loop, with the same result up to float rounding.

    Properties:
        - Rewards sharpness (narrow distribution if correct)
        - Rewards calibration (GT should be plausible)
        - Collapses to MAE when M=1
        - Lower is better
    """
    M = ensemble.shape[0]  # Ensemble size
    if M < 1:
        raise ValueError("ensemble must contain at least one member")
    if ensemble.shape[1:] != ground_truth.shape:
        # Without this check, reshape() below could silently regroup values
        # from mismatched tensors into the wrong pixels.
        raise ValueError(
            f"ensemble member shape {tuple(ensemble.shape[1:])} does not match "
            f"ground_truth shape {tuple(ground_truth.shape)}"
        )
    B = ground_truth.shape[0]  # Batch size

    # Flatten everything after the batch axis: (M, B, D) and (B, D).
    ensemble_flat = ensemble.reshape(M, B, -1)
    gt_flat = ground_truth.reshape(B, -1)

    # Term 1: mean over members of |x_i - y|            -> (B, D)
    term1 = (ensemble_flat - gt_flat.unsqueeze(0)).abs().mean(dim=0)

    # Term 2: (1/(2M^2)) * sum_i sum_j |x_i - x_j|
    #       = (1/M^2)    * sum_k (2k - M - 1) * x_(k)    -> (B, D)
    sorted_members = torch.sort(ensemble_flat, dim=0).values
    ranks = torch.arange(1, M + 1, device=ensemble_flat.device, dtype=ensemble_flat.dtype)
    rank_weights = (2 * ranks - M - 1).view(M, 1, 1)
    term2 = (rank_weights * sorted_members).sum(dim=0) / (M * M)

    # CRPS per element, then averaged over all non-batch elements.
    crps_per_pixel = term1 - term2
    crps_per_image = crps_per_pixel.mean(dim=1)  # (B,)

    return crps_per_image, crps_per_image.mean()


def compute_psnr(predictions, ground_truth, max_val=1.0):
    """
    Peak signal-to-noise ratio (dB) per image and over the batch.

    Uses PSNR = 20 * log10(max_val / sqrt(MSE)), with the per-image MSE taken
    from ``compute_mse``. An exact match (MSE = 0) gives +inf.

    Returns:
        Tuple ``(per_image, batch_mean)``: a (B,) tensor of PSNR values and
        its scalar mean.
    """
    per_image_mse = compute_mse(predictions, ground_truth)[0]
    per_image_rmse = torch.sqrt(per_image_mse)
    per_image_psnr = 20 * torch.log10(max_val / per_image_rmse)
    return per_image_psnr, per_image_psnr.mean()


print("Evaluation metrics defined!")
print("\nAvailable metrics:")
for _metric_line in (
    "  - MSE: Mean Squared Error",
    "  - MAE: Mean Absolute Error (deterministic baseline)",
    "  - CRPS: Continuous Ranked Probability Score (ensemble)",
    "  - PSNR: Peak Signal-to-Noise Ratio",
):
    print(_metric_line)

## 4. Run Evaluation

### Quick Test: Single Image

In [None]:
# Smoke test: reconstruct a single test image before running the full evaluation.
# Image geometry, sampler and the context-ratio label come from `args` instead
# of hard-coded literals, so they stay correct if the config changes.
img_size = args.train_img_height

test_sample = sparse_test_dataset[0]

test_image = test_sample['image'].unsqueeze(0).cuda()  # (1, C, H, W)
test_context_indices = test_sample['context_indices'].unsqueeze(0).cuda()
test_context_values = test_sample['context_values'].unsqueeze(0).cuda()

# Create context image (dense image holding only the sparse context values)
test_context_image = create_context_image_batched(
    test_context_values,
    test_context_indices,
    img_size, img_size, args.input_dim
)

# Coordinate grid
v_grid = get_mgrid(args.coord_dim, img_size).cuda()

print("Generating single conditional sample...")
reconstruction = sample_conditional(
    gen_sde,
    test_context_image,
    v_grid,
    num_steps=args.num_steps,
    sampler=args.sampler,
)

# Compute metrics (on the raw, unclamped reconstruction)
_, mse = compute_mse(reconstruction, test_image)
_, psnr = compute_psnr(reconstruction, test_image)
_, mae = compute_mae(reconstruction, test_image)

print(f"\nSingle sample metrics:")
print(f"  MSE:  {mse:.6f}")
print(f"  MAE:  {mae:.6f}")
print(f"  PSNR: {psnr:.2f} dB")

# Visualize: context | reconstruction | ground truth
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

# Context (presumably non-context pixels are set to fill_value=0.5)
context_vis = create_sparse_mask_image(
    test_image[0].cpu(), test_sample['context_indices'], fill_value=0.5
)
axes[0].imshow(context_vis.permute(1, 2, 0).numpy())
axes[0].set_title(f'Context ({args.context_ratio*100:.0f}%)', fontsize=12)
axes[0].axis('off')

# Reconstruction (clamped to [0, 1] for display only)
axes[1].imshow(reconstruction[0].cpu().permute(1, 2, 0).clamp(0, 1).numpy())
axes[1].set_title(f'Reconstruction\nMSE={mse:.4f}, PSNR={psnr:.1f}dB', fontsize=12)
axes[1].axis('off')

# Ground truth
axes[2].imshow(test_image[0].cpu().permute(1, 2, 0).numpy())
axes[2].set_title('Ground Truth', fontsize=12)
axes[2].axis('off')

plt.tight_layout()
plt.show()

### Full Evaluation: Multiple Images with Ensemble

In [None]:
# Full evaluation with ensemble for CRPS.
# Image geometry comes from `args` rather than hard-coded 32x32x3 literals.
img_size = args.train_img_height
num_eval = min(args.num_eval_samples, len(sparse_test_dataset))
print(f"Evaluating on {num_eval} images...")
print(f"Ensemble size: {args.num_ensemble}")
print(f"This will take a while...\n")

all_mse = []
all_mae = []
all_psnr = []
all_crps = []

# Coordinate grid (shared; repeated to the batch size inside the loop)
v_grid = get_mgrid(args.coord_dim, img_size).cuda()

for idx in tqdm(range(0, num_eval, args.eval_batch_size), desc='Evaluation'):
    batch_end = min(idx + args.eval_batch_size, num_eval)
    batch_indices = range(idx, batch_end)

    # Gather batch by indexing the dataset directly (fixed, unshuffled order)
    batch_images = []
    batch_context_indices = []
    batch_context_values = []

    for i in batch_indices:
        sample = sparse_test_dataset[i]
        batch_images.append(sample['image'])
        batch_context_indices.append(sample['context_indices'])
        batch_context_values.append(sample['context_values'])

    batch_images = torch.stack(batch_images).cuda()
    batch_context_indices = torch.stack(batch_context_indices).cuda()
    batch_context_values = torch.stack(batch_context_values).cuda()

    batch_size = batch_images.shape[0]

    # Create context image
    context_image = create_context_image_batched(
        batch_context_values,
        batch_context_indices,
        img_size, img_size, args.input_dim
    )

    v_batch = v_grid.repeat(batch_size, 1, 1, 1)

    # Generate ensemble
    ensemble = sample_ensemble_conditional(
        gen_sde,
        context_image,
        v_batch,
        num_ensemble=args.num_ensemble,
        num_steps=args.num_steps
    )  # (M, B, C, H, W)

    # Use mean of ensemble as point estimate
    mean_prediction = ensemble.mean(dim=0)  # (B, C, H, W)

    # Point metrics on the ensemble mean; CRPS on the full ensemble
    mse_batch, _ = compute_mse(mean_prediction, batch_images)
    mae_batch, _ = compute_mae(mean_prediction, batch_images)
    psnr_batch, _ = compute_psnr(mean_prediction, batch_images)
    crps_batch, _ = compute_crps_ensemble(ensemble, batch_images)

    all_mse.extend(mse_batch.cpu().numpy())
    all_mae.extend(mae_batch.cpu().numpy())
    all_psnr.extend(psnr_batch.cpu().numpy())
    all_crps.extend(crps_batch.cpu().numpy())

# Convert to arrays
all_mse = np.array(all_mse)
all_mae = np.array(all_mae)
all_psnr = np.array(all_psnr)
all_crps = np.array(all_crps)

print("\n" + "="*60)
print("EVALUATION RESULTS")
print("="*60)
print(f"Evaluated on {len(all_mse)} images")
print(f"Ensemble size: {args.num_ensemble}\n")

print("Point Estimate Metrics (using ensemble mean):")
print(f"  MSE:  {all_mse.mean():.6f} ± {all_mse.std():.6f}")
print(f"  MAE:  {all_mae.mean():.6f} ± {all_mae.std():.6f}")
print(f"  PSNR: {all_psnr.mean():.2f} ± {all_psnr.std():.2f} dB\n")

print("Probabilistic Metric:")
print(f"  CRPS: {all_crps.mean():.6f} ± {all_crps.std():.6f}")
print(f"\nComparison:")
print(f"  MAE (deterministic baseline): {all_mae.mean():.6f}")
print(f"  CRPS (ensemble):              {all_crps.mean():.6f}")
print(f"  Ratio (CRPS/MAE):             {all_crps.mean()/all_mae.mean():.3f}")

# Heuristic comparison only: CRPS of the ensemble vs. MAE of its mean.
if all_crps.mean() < all_mae.mean():
    print("  → CRPS < MAE: Ensemble is well-calibrated and sharp!")
else:
    print("  → CRPS > MAE: Ensemble may be too spread out or miscalibrated")

print("="*60)

## 5. Visualization of Results

In [None]:
# Plot per-image metric distributions (2x2 grid) and save the figure.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

axes[0, 0].hist(all_mse, bins=50, alpha=0.7, color='blue')
axes[0, 0].axvline(all_mse.mean(), color='red', linestyle='--', label=f'Mean: {all_mse.mean():.4f}')
axes[0, 0].set_xlabel('MSE')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].set_title('MSE Distribution')
axes[0, 0].legend()

axes[0, 1].hist(all_psnr, bins=50, alpha=0.7, color='green')
axes[0, 1].axvline(all_psnr.mean(), color='red', linestyle='--', label=f'Mean: {all_psnr.mean():.2f} dB')
axes[0, 1].set_xlabel('PSNR (dB)')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].set_title('PSNR Distribution')
axes[0, 1].legend()

axes[1, 0].hist(all_mae, bins=50, alpha=0.7, color='orange', label='MAE')
axes[1, 0].axvline(all_mae.mean(), color='red', linestyle='--', linewidth=2)
axes[1, 0].set_xlabel('Error')
axes[1, 0].set_ylabel('Frequency')
axes[1, 0].set_title('MAE Distribution (Deterministic Baseline)')
axes[1, 0].legend()

# The two overlaid histograms share one set of bin edges. With independent
# `bins=50`, each series gets its own edges over its own range, so the bar
# heights would not be comparable.
crps_mae_bins = np.histogram_bin_edges(np.concatenate([all_crps, all_mae]), bins=50)
axes[1, 1].hist(all_crps, bins=crps_mae_bins, alpha=0.7, color='purple', label='CRPS')
axes[1, 1].axvline(all_crps.mean(), color='red', linestyle='--', linewidth=2)
axes[1, 1].hist(all_mae, bins=crps_mae_bins, alpha=0.3, color='orange', label='MAE (for comparison)')
axes[1, 1].set_xlabel('Error')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].set_title('CRPS vs MAE Distribution')
axes[1, 1].legend()

plt.tight_layout()
metrics_plot_path = os.path.join(args.exp_path, 'evaluation_metrics.png')
plt.savefig(metrics_plot_path, dpi=150, bbox_inches='tight')
plt.show()

# Report the exact path that was written.
print(f"Saved metrics plot to {metrics_plot_path}")

### Visualize Best and Worst Reconstructions

In [None]:
# Rank evaluated images by MSE, ascending: the first 8 are the best
# reconstructions and the last 8 the worst.
mse_order = np.argsort(all_mse)
best_indices = mse_order[:8]    # 8 best
worst_indices = mse_order[-8:]  # 8 worst

print(f"Best MSE: {all_mse[mse_order[0]]:.6f}")
print(f"Worst MSE: {all_mse[mse_order[-1]]:.6f}")

# Regenerate for visualization
# TODO: Add visualization code here

print("\nEvaluation complete!")

## 6. Save Results

In [None]:
# Save results: mean/std over evaluated images for each metric, plus run
# metadata, written as JSON into the experiment directory.
import json

per_metric_arrays = (
    ('mse', all_mse),
    ('mae', all_mae),
    ('psnr', all_psnr),
    ('crps', all_crps),
)

results = {'num_samples': len(all_mse), 'ensemble_size': args.num_ensemble}
for metric_name, metric_values in per_metric_arrays:
    results[f'{metric_name}_mean'] = float(metric_values.mean())
    results[f'{metric_name}_std'] = float(metric_values.std())
results['crps_to_mae_ratio'] = float(all_crps.mean() / all_mae.mean())

results_file = os.path.join(args.exp_path, 'evaluation_results.json')
with open(results_file, 'w') as out_file:
    json.dump(results, out_file, indent=2)

print(f"Results saved to {results_file}")
print("\n" + json.dumps(results, indent=2))