# Sparse Conditioning Evaluation Notebook

This notebook evaluates the sparse conditioning capabilities of a trained PixNerDiT model.

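**Important Note**: If the model was trained with the buggy code (where `cond_mask` was ignored during training),
sparse conditioning will have limited effectiveness. The model must be retrained with the
fixed code to learn sparse pixel conditioning properly.

Throughout this notebook, `sparsity` denotes the fraction of pixels that are *observed* (given to the model as hints), so a higher value means more hints.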

In [None]:
import os
import sys
import warnings
from pathlib import Path

# Configuration - UPDATE THESE PATHS
# The checkpoint can also be selected without editing the notebook by setting the
# SPARSE_EVAL_CHECKPOINT environment variable; the default below is used otherwise.
DEFAULT_CHECKPOINT_PATH = "/home/idies/workspace/Temporary/dpark1/scratch/kevinhelp1/workdirs/exp_sfc_both/checkpoints/last-v1.ckpt"
# Alternative for local testing:
# DEFAULT_CHECKPOINT_PATH = "./workdirs/exp_sfc_both/checkpoints/last-v1.ckpt"
CHECKPOINT_PATH = os.environ.get("SPARSE_EVAL_CHECKPOINT", DEFAULT_CHECKPOINT_PATH)

# Put the current working directory first on sys.path so the project's `src`
# package can be imported (assumes the notebook is run from the repository root).
PROJECT_ROOT = Path.cwd()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")
print(f"Checkpoint: {CHECKPOINT_PATH}")

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import CIFAR10
from torchvision import transforms
from functools import partial

# Disable torch.compile for compatibility
torch._dynamo.config.disable = True

# Set precision for tensor cores
torch.set_float32_matmul_precision('high')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [None]:
# Import model components
from src.models.autoencoder.pixel import PixelAE
from src.models.conditioner.class_label import LabelConditioner
from src.models.transformer.pixnerd_c2i_heavydecoder import PixNerDiT
from src.diffusion.flow_matching.scheduling import LinearScheduler
from src.diffusion.flow_matching.sampling import EulerSampler, ode_step_fn
from src.diffusion.base.guidance import simple_guidance_fn
from src.diffusion.flow_matching.training import FlowMatchingTrainer
from src.callbacks.simple_ema import SimpleEMA
from src.lightning_model import LightningModel

# Reaching this line means every project import in this cell resolved.
print("All imports successful!")

## 1. Model Configuration

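These settings must match the training configuration; otherwise the checkpoint will either fail to load (shape mismatches) or load only partially (missing/unexpected keys).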

In [None]:
# Architecture hyper-parameters for PixNerDiT. These must mirror the training run.
MODEL_CONFIG = dict(
    in_channels=3,
    patch_size=4,
    num_groups=4,
    hidden_size=256,
    decoder_hidden_size=64,
    num_encoder_blocks=4,
    num_decoder_blocks=2,
    num_classes=10,
    encoder_type="sfc",
    # Space-filling-curve encoder settings.
    sfc_curve="hilbert",
    sfc_group_size=8,
    sfc_cross_depth=2,
    # Ablation flags; exp_sfc_both enables both of them.
    sfc_unified_coords=True,
    sfc_spatial_bias=True,
)

# Sampler settings (fewer Euler steps than usual to keep evaluation quick).
SAMPLING_CONFIG = dict(num_steps=50, guidance=2.0)

# Fraction of pixels revealed to the model as hints (0.4 -> 40% observed).
SPARSITY = 0.4

print("Configuration loaded.")
print(f"Encoder type: {MODEL_CONFIG['encoder_type']}")
print(f"SFC unified coords: {MODEL_CONFIG['sfc_unified_coords']}")
print(f"SFC spatial bias: {MODEL_CONFIG['sfc_spatial_bias']}")

## 2. Build Model and Load Checkpoint

In [None]:
def build_model(config, sampling_config):
    """
    Assemble the LightningModel used for evaluation (no checkpoint is loaded here).

    Wires together a pixel-space VAE, a class-label conditioner, a PixNerDiT
    denoiser configured from ``config``, an Euler flow-matching sampler
    configured from ``sampling_config``, a flow-matching trainer, an EMA tracker
    and an AdamW optimizer factory.

    Args:
        config: dict holding the PixNerDiT constructor arguments (see MODEL_CONFIG).
        sampling_config: dict with ``num_steps`` and ``guidance`` for the sampler.

    Returns:
        The assembled ``LightningModel``.
    """
    # Keys forwarded verbatim from `config` to the PixNerDiT constructor.
    denoiser_keys = (
        "in_channels", "patch_size", "num_groups", "hidden_size",
        "decoder_hidden_size", "num_encoder_blocks", "num_decoder_blocks",
        "num_classes", "encoder_type", "sfc_curve", "sfc_group_size",
        "sfc_cross_depth", "sfc_unified_coords", "sfc_spatial_bias",
    )

    # One scheduler instance is shared by the sampler and the trainer.
    schedule = LinearScheduler()

    pixel_vae = PixelAE(scale=1.0)
    label_cond = LabelConditioner(num_classes=config["num_classes"])
    net = PixNerDiT(**{key: config[key] for key in denoiser_keys})

    euler = EulerSampler(
        scheduler=schedule,
        w_scheduler=LinearScheduler(),
        num_steps=sampling_config["num_steps"],
        guidance=sampling_config["guidance"],
        guidance_interval_min=0.0,
        guidance_interval_max=1.0,
        guidance_fn=simple_guidance_fn,
        step_fn=ode_step_fn,
    )
    fm_trainer = FlowMatchingTrainer(scheduler=schedule, lognorm_t=True, timeshift=1.0)

    return LightningModel(
        vae=pixel_vae,
        conditioner=label_cond,
        denoiser=net,
        diffusion_trainer=fm_trainer,
        diffusion_sampler=euler,
        ema_tracker=SimpleEMA(decay=0.9999),
        optimizer=partial(torch.optim.AdamW, lr=1e-4, weight_decay=0.0),
        lr_scheduler=None,
        eval_original_model=False,
    )

# Instantiate the architecture; its weights are restored from the checkpoint in the next cell.
model = build_model(MODEL_CONFIG, SAMPLING_CONFIG)
total_params = sum(param.numel() for param in model.parameters())
print(f"Model built. Total parameters: {total_params:,}")

In [None]:
# Load checkpoint
print(f"Loading checkpoint from: {CHECKPOINT_PATH}")

if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {CHECKPOINT_PATH}")

# torch.load unpickles the file, which can execute arbitrary code:
# only load checkpoints from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
print(f"Checkpoint keys: {checkpoint.keys()}")

# strict=False tolerates key drift between the training and evaluation modules,
# so mismatches are reported (and checked) below instead of raising here.
state_dict = checkpoint["state_dict"]
missing, unexpected = model.load_state_dict(state_dict, strict=False)

print(f"\nMissing keys: {len(missing)}")
if missing:
    print(f"  Examples: {missing[:5]}")
print(f"Unexpected keys: {len(unexpected)}")
if unexpected:
    print(f"  Examples: {unexpected[:5]}")

# Every sample in this notebook is drawn with `model.ema_denoiser`. If any of its
# weights were not restored from the checkpoint, all results below are meaningless,
# so surface that loudly instead of relying on the key count above.
# (Assumes the submodule's state-dict keys carry the "ema_denoiser." prefix.)
missing_ema = [key for key in missing if key.startswith("ema_denoiser.")]
if missing_ema:
    warnings.warn(
        f"{len(missing_ema)} ema_denoiser weights were not found in the checkpoint "
        f"(e.g. {missing_ema[:3]}); check MODEL_CONFIG against the training config.",
        RuntimeWarning,
    )

# Get training step info if available
if "global_step" in checkpoint:
    print(f"\nCheckpoint from step: {checkpoint['global_step']}")
if "epoch" in checkpoint:
    print(f"Checkpoint from epoch: {checkpoint['epoch']}")

In [None]:
# Inference device: CUDA when present, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Module.to() and .eval() both return the module itself, so they can be chained.
model = model.to(device).eval()

# The EMA copy is the network passed to the sampler throughout; run it in float32.
model.ema_denoiser.to(torch.float32)

print(f"Model moved to: {device}")
print("Model set to eval mode.")

## 3. Load Test Data

In [None]:
# Human-readable names for CIFAR-10 label indices 0-9.
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# ToTensor yields values in [0, 1]; normalising with mean = std = 0.5 per channel
# maps them to [-1, 1], the range the visualisation and metric helpers assume.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

test_dataset = CIFAR10(root="./data", train=False, download=True, transform=transform)
print(f"Test dataset size: {len(test_dataset)}")

In [None]:
def get_test_batch(dataset, indices, device):
    """
    Gather ``dataset[idx]`` for each index into batched tensors.

    Args:
        dataset: indexable returning ``(image_tensor, int_label)`` pairs.
        indices: iterable of dataset indices, in the desired batch order.
        device: target device for both returned tensors.

    Returns:
        (images, labels): stacked images and a long tensor of labels, both on ``device``.
    """
    pairs = [dataset[idx] for idx in indices]
    batch_images = torch.stack([img for img, _ in pairs]).to(device)
    batch_labels = torch.tensor([lbl for _, lbl in pairs], dtype=torch.long, device=device)
    return batch_images, batch_labels

# Take the first test image of each class so the batch holds exactly one example
# per class. The CIFAR-10 test split is not ordered by class, so fixed strides
# (e.g. i * 1000) do not guarantee one image per class.
test_indices = [test_dataset.targets.index(c) for c in range(len(CIFAR10_CLASSES))]
test_images, test_labels = get_test_batch(test_dataset, test_indices, device)

print(f"Test batch shape: {test_images.shape}")
print(f"Test labels: {[CIFAR10_CLASSES[l] for l in test_labels.tolist()]}")

## 4. Utility Functions

In [None]:
def generate_sparse_mask(x, sparsity=0.4):
    """
    Generate a random sparse conditioning mask.

    Each sample in the batch independently gets exactly
    ``int(sparsity * H * W)`` observed pixels, chosen uniformly without
    replacement via ``torch.randperm`` (so results depend on the global RNG).

    Args:
        x: (B, C, H, W) input tensor; only its shape and device are used.
        sparsity: fraction of pixels to observe, in [0, 1].

    Returns:
        cond_mask: (B, 1, H, W) float mask on ``x.device`` where 1 = observed pixel.

    Raises:
        ValueError: if ``sparsity`` is outside [0, 1] (or NaN). A negative value
            would otherwise slice ``[:negative]`` and silently keep most pixels.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

    B, _, H, W = x.shape
    device = x.device
    n_pixels = H * W

    total_keep = int(sparsity * n_pixels)
    cond_mask = torch.zeros(B, 1, H, W, device=device)
    # Flat (B, H*W) view onto the freshly allocated, contiguous mask; writes go through.
    flat_mask = cond_mask.view(B, n_pixels)

    for b in range(B):
        keep = torch.randperm(n_pixels, device=device)[:total_keep]
        flat_mask[b, keep] = 1.0

    return cond_mask

def tensor_to_image(tensor):
    """
    Convert a (C, H, W) tensor with values in [-1, 1] into an (H, W, C) numpy
    array in [0, 1] that ``imshow`` can display. Out-of-range values are clipped.
    """
    scaled = tensor.detach().cpu().clamp(-1, 1).add(1).div(2)
    return scaled.permute(1, 2, 0).numpy()

def visualize_batch(images, titles=None, figsize=(15, 3)):
    """
    Draw ``images`` side by side in a single row, without axes.

    Args:
        images: sequence of (C, H, W) image tensors (expected in [-1, 1]).
        titles: optional per-image titles, indexed in step with ``images``.
        figsize: matplotlib figure size.

    Returns:
        The matplotlib Figure.
    """
    count = len(images)
    # squeeze=False always yields a 2-D array of axes, even for a single image.
    fig, axes = plt.subplots(1, count, figsize=figsize, squeeze=False)

    for idx, (ax, img) in enumerate(zip(axes[0], images)):
        ax.imshow(tensor_to_image(img))
        ax.axis("off")
        if titles:
            ax.set_title(titles[idx], fontsize=10)

    fig.tight_layout()
    return fig

# Marks the end of the utility-definition cell.
print("Utility functions defined.")

## 5. Class-Conditional Generation (Baseline)

First, test that the model can generate class-conditional images without sparse hints.

In [None]:
@torch.no_grad()
def generate_class_conditional(model, labels, num_samples=None, image_shape=(3, 32, 32)):
    """
    Generate images conditioned only on class labels (no sparse hints).

    Args:
        model: LightningModel-like object exposing ``conditioner``,
            ``diffusion_sampler``, ``ema_denoiser`` and ``vae``.
        labels: (B,) tensor of class indices; one sample is drawn per label.
        num_samples: number of samples. Must equal ``len(labels)`` because the
            conditioning batch is built from ``labels``. Defaults to ``len(labels)``.
        image_shape: (C, H, W) of the starting noise. Defaults to the CIFAR-10
            pixel shape (3, 32, 32).

    Returns:
        The sampler output passed through ``model.vae.decode`` (PixelAE decode
        is identity, so this is (num_samples, *image_shape) for PixelAE).

    Raises:
        ValueError: if ``num_samples`` disagrees with ``len(labels)``.
    """
    n_labels = len(labels)
    if num_samples is None:
        num_samples = n_labels
    elif num_samples != n_labels:
        # The conditioning tensors have one row per label; a different noise
        # batch size would mismatch them inside the sampler.
        raise ValueError(
            f"num_samples ({num_samples}) must equal len(labels) ({n_labels})"
        )

    device = next(model.parameters()).device

    # Get conditioning
    condition, uncondition = model.conditioner(labels)

    # Starting noise for the sampler
    noise = torch.randn(num_samples, *image_shape, device=device)

    samples = model.diffusion_sampler(
        model.ema_denoiser,
        noise,
        condition,
        uncondition,
    )

    # The sampler may return a tuple whose first item is either the final
    # sample or a list of intermediate samples (take the last one).
    if isinstance(samples, tuple):
        samples = samples[0][-1] if isinstance(samples[0], list) else samples[0]

    # Decode (PixelAE is identity)
    return model.vae.decode(samples)

print("Generating class-conditional samples...")
# Fix the RNG so the sampling noise (and therefore the figure below) is reproducible.
torch.manual_seed(42)
class_samples = generate_class_conditional(model, test_labels)
print(f"Generated {len(class_samples)} samples.")

In [None]:
# One class-conditional sample per test label, each titled with its class name.
titles = [CIFAR10_CLASSES[label_idx] for label_idx in test_labels.tolist()]
fig = visualize_batch(class_samples, titles, figsize=(20, 2.5))
fig.suptitle("Class-Conditional Generation (No Sparse Hints)", fontsize=14, y=1.05)
plt.show()

## 6. Sparse-Conditioned Generation

Now test generation with sparse pixel hints.

**Note**: If the model was trained with the buggy code, sparse conditioning will have limited effect.

In [None]:
@torch.no_grad()
def generate_sparse_conditioned(model, images, labels, sparsity=0.4, cond_mask=None):
    """
    Generate images conditioned on class labels AND sparse pixel hints.

    Args:
        model: LightningModel-like object exposing ``vae``, ``conditioner``,
            ``diffusion_sampler`` and ``ema_denoiser``.
        images: (B, C, H, W) ground-truth images; encoded with ``model.vae``
            to obtain the hint values passed as ``x_cond``.
        labels: (B,) class indices.
        sparsity: fraction of pixels observed when the mask is drawn here.
        cond_mask: optional precomputed (B, 1, H, W) mask (1 = observed).
            When given, ``sparsity`` is ignored. This lets callers reuse one
            mask across runs for controlled comparisons.

    Returns:
        (samples, sparse_input, cond_mask): decoded samples, ``images * cond_mask``
        for visualisation, and the mask that was used.
    """
    # Encode images to latent space
    x_latent = model.vae.encode(images)

    if cond_mask is None:
        cond_mask = generate_sparse_mask(x_latent, sparsity=sparsity)

    # Get conditioning
    condition, uncondition = model.conditioner(labels)

    # Starting noise matches the latent's shape, dtype and device.
    noise = torch.randn_like(x_latent)

    # Sample with sparse conditioning
    samples = model.diffusion_sampler(
        model.ema_denoiser,
        noise,
        condition,
        uncondition,
        cond_mask=cond_mask,
        x_cond=x_latent,
    )

    # The sampler may return a tuple whose first item is either the final
    # sample or a list of intermediate samples (take the last one).
    if isinstance(samples, tuple):
        samples = samples[0][-1] if isinstance(samples[0], list) else samples[0]

    samples = model.vae.decode(samples)

    # The mask lives in latent space; applying it to `images` assumes the VAE
    # preserves spatial size (true for the identity PixelAE).
    sparse_input = images * cond_mask

    return samples, sparse_input, cond_mask

print(f"Generating sparse-conditioned samples (sparsity={SPARSITY})...")
# Fix the RNG so both the random hint mask and the sampling noise are reproducible.
torch.manual_seed(42)
sparse_samples, sparse_inputs, masks = generate_sparse_conditioned(
    model, test_images, test_labels, sparsity=SPARSITY
)
print(f"Generated {len(sparse_samples)} samples.")

In [None]:
# Visualize comparison: Ground Truth vs Sparse Input vs Sparse-Conditioned Output
n_show = min(5, len(test_images))

# squeeze=False keeps `axes` 2-D even when n_show == 1.
fig, axes = plt.subplots(3, n_show, figsize=(3*n_show, 9), squeeze=False)

for i in range(n_show):
    # (image, title) for each row of column i.
    panels = (
        (test_images[i], f"GT: {CIFAR10_CLASSES[int(test_labels[i])]}"),
        (sparse_inputs[i], f"Sparse ({SPARSITY*100:.0f}%)"),
        (sparse_samples[i], "Reconstructed"),
    )
    for row, (img, title) in enumerate(panels):
        ax = axes[row, i]
        ax.imshow(tensor_to_image(img))
        ax.set_title(title)
        # Hide ticks and frame but keep the axis itself on: ax.axis("off") would
        # also hide the y-labels used as row headers below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

for row, label in enumerate(["Ground Truth", "Sparse Input", "Reconstruction"]):
    axes[row, 0].set_ylabel(label, fontsize=12)

plt.suptitle("Sparse-Conditioned Reconstruction", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 7. Comparison: Class-Only vs Sparse-Conditioned

In [None]:
# Generate class-only and sparse-conditioned samples from the *same* noise and a
# fixed mask, so the only difference between the two sets is the sparse hints.
NOISE_SEED = 42  # shared by both runs -> identical starting noise
MASK_SEED = 0    # separate seed so the mask is not drawn from the noise stream


def _sample_with_seed(seed, **sparse_kwargs):
    """Sample from seed-`seed` noise with the EMA denoiser and decode the result.

    Reads `x_latent`, `condition` and `uncondition` from the notebook namespace;
    `sparse_kwargs` (e.g. cond_mask / x_cond) are forwarded to the sampler.
    Call it inside `torch.no_grad()`.
    """
    torch.manual_seed(seed)
    noise = torch.randn_like(x_latent)
    out = model.diffusion_sampler(
        model.ema_denoiser, noise, condition, uncondition, **sparse_kwargs
    )
    # The sampler may return a tuple whose first item is either the final
    # sample or a list of intermediate samples (take the last one).
    if isinstance(out, tuple):
        out = out[0][-1] if isinstance(out[0], list) else out[0]
    return model.vae.decode(out)


with torch.no_grad():
    x_latent = model.vae.encode(test_images)
    torch.manual_seed(MASK_SEED)
    cond_mask = generate_sparse_mask(x_latent, sparsity=SPARSITY)
    condition, uncondition = model.conditioner(test_labels)

    # Class-only (same noise)
    class_only = _sample_with_seed(NOISE_SEED)
    # Sparse-conditioned (same noise)
    sparse_cond = _sample_with_seed(NOISE_SEED, cond_mask=cond_mask, x_cond=x_latent)

    sparse_input_vis = test_images * cond_mask

In [None]:
# Visualize 4-way comparison
n_show = min(5, len(test_images))

# squeeze=False keeps `axes` 2-D even when n_show == 1.
fig, axes = plt.subplots(4, n_show, figsize=(3*n_show, 12), squeeze=False)

row_labels = ["Ground Truth", f"Sparse Input ({SPARSITY*100:.0f}%)",
              "Class-Only", "Sparse-Conditioned"]
row_batches = [test_images, sparse_input_vis, class_only, sparse_cond]

for row, batch in enumerate(row_batches):
    for i in range(n_show):
        ax = axes[row, i]
        ax.imshow(tensor_to_image(batch[i]))
        # Hide ticks and frame but keep the axis itself on: ax.axis("off") would
        # also hide the y-labels used as row headers below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

for i in range(n_show):
    axes[0, i].set_title(f"{CIFAR10_CLASSES[int(test_labels[i])]}")

for row, label in enumerate(row_labels):
    axes[row, 0].set_ylabel(label, fontsize=11)

plt.suptitle("Comparison: Class-Only vs Sparse-Conditioned Generation", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 8. Quantitative Metrics

In [None]:
def compute_metrics(gt, pred, mask=None):
    """
    Compute reconstruction metrics (MSE and PSNR) on images mapped to [0, 1].

    Both inputs are clamped to [-1, 1] and rescaled to [0, 1] before comparison,
    so PSNR uses a peak value of 1.0. Errors are pooled over the whole batch
    (a single MSE and PSNR for all images), not averaged per image.

    Args:
        gt: Ground truth images (B, C, H, W) in [-1, 1]
        pred: Predicted images, same shape as ``gt``
        mask: Optional mask (B, 1, H, W) or (B, C, H, W); metrics are computed
            only where mask == 1 (the mask is broadcast across channels).

    Returns:
        dict with "mse" and "psnr" (floats). PSNR is ``inf`` when MSE is 0.

    Raises:
        ValueError: if ``gt`` and ``pred`` shapes differ (they would otherwise
            broadcast silently), or if ``mask`` selects no pixels (0/0 -> NaN).
    """
    if gt.shape != pred.shape:
        raise ValueError(
            f"gt and pred shapes differ: {tuple(gt.shape)} vs {tuple(pred.shape)}"
        )

    # Normalize to [0, 1]
    gt = (gt.clamp(-1, 1) + 1) / 2
    pred = (pred.clamp(-1, 1) + 1) / 2
    sq_err = (gt - pred) ** 2

    if mask is not None:
        # Expand to every channel so numerator and denominator count the same elements.
        weights = mask.to(sq_err.dtype).expand_as(sq_err)
        n_pixels = weights.sum()
        if n_pixels.item() == 0:
            raise ValueError("mask selects no pixels; masked metrics are undefined")
        mse = (sq_err * weights).sum() / n_pixels
    else:
        mse = sq_err.mean()

    psnr = 10 * torch.log10(1.0 / mse)

    return {
        "mse": mse.item(),
        "psnr": psnr.item(),
    }

# Score class-only vs sparse-conditioned samples over the full image, and the
# sparse-conditioned samples again restricted to the hint pixels in `cond_mask`.
metrics_class_only = compute_metrics(test_images, class_only)
metrics_sparse_full = compute_metrics(test_images, sparse_cond)
metrics_sparse_masked = compute_metrics(test_images, sparse_cond, cond_mask)

divider = "=" * 50
report_rows = (
    ("\nClass-Only Generation (no hints):", metrics_class_only),
    ("\nSparse-Conditioned (full image):", metrics_sparse_full),
    ("\nSparse-Conditioned (at hint locations only):", metrics_sparse_masked),
)

print(divider)
print("Reconstruction Metrics")
print(divider)
for heading, scores in report_rows:
    print(heading)
    print(f"  MSE:  {scores['mse']:.6f}")
    print(f"  PSNR: {scores['psnr']:.2f} dB")

print("\n" + divider)
# Thresholds on hint-location PSNR: >30 dB high, >20 dB moderate, otherwise low.
hint_psnr = metrics_sparse_masked['psnr']
if hint_psnr > 30:
    print("Hint fidelity is HIGH - sparse conditioning is working!")
elif hint_psnr > 20:
    print("Hint fidelity is MODERATE - sparse conditioning partially working.")
else:
    print("Hint fidelity is LOW - model may need retraining with fixed code.")

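## 9. Test Different Sparsity Levels

Sweep the fraction of observed pixels from 10% to 80%. For each level, record PSNR over the full image and at the hint locations only.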

In [None]:
# Sweep the observed-pixel fraction and record full-image and hint-location PSNR.
sparsity_levels = [0.1, 0.2, 0.4, 0.6, 0.8]
results = []

NOISE_SEED = 42        # same starting noise at every level, so only the hints change
MASK_SEED_BASE = 1000  # mask seed for level k is MASK_SEED_BASE + k

print("Testing different sparsity levels...")
with torch.no_grad():
    # Encoding and class conditioning do not depend on the level: compute them once.
    x_latent = model.vae.encode(test_images)
    condition, uncondition = model.conditioner(test_labels)

    for level_idx, sparsity in enumerate(sparsity_levels):
        # Seed the mask draw too, so re-running this cell reproduces the same
        # masks instead of depending on RNG state left behind by earlier cells.
        torch.manual_seed(MASK_SEED_BASE + level_idx)
        cond_mask = generate_sparse_mask(x_latent, sparsity=sparsity)

        torch.manual_seed(NOISE_SEED)
        noise = torch.randn_like(x_latent)
        samples = model.diffusion_sampler(
            model.ema_denoiser, noise, condition, uncondition,
            cond_mask=cond_mask, x_cond=x_latent
        )
        # The sampler may return a tuple whose first item is either the final
        # sample or a list of intermediate samples (take the last one).
        if isinstance(samples, tuple):
            samples = samples[0][-1] if isinstance(samples[0], list) else samples[0]
        samples = model.vae.decode(samples)

        metrics = compute_metrics(test_images, samples)
        metrics_hints = compute_metrics(test_images, samples, cond_mask)

        results.append({
            "sparsity": sparsity,
            "psnr_full": metrics["psnr"],
            "psnr_hints": metrics_hints["psnr"],
        })
        print(f"  Sparsity {sparsity*100:.0f}%: Full PSNR={metrics['psnr']:.2f}, Hint PSNR={metrics_hints['psnr']:.2f}")

print("Done.")

In [None]:
# PSNR over the full image and at hint pixels, as a function of the observed fraction.
sparsities = [100 * row["sparsity"] for row in results]
psnr_full, psnr_hints = ([row[key] for row in results] for key in ("psnr_full", "psnr_hints"))

fig, ax = plt.subplots(figsize=(8, 5))
series = (
    (psnr_full, 'b-o', 'Full Image PSNR'),
    (psnr_hints, 'r-s', 'Hint Locations PSNR'),
)
for values, fmt, label in series:
    ax.plot(sparsities, values, fmt, label=label, linewidth=2, markersize=8)

ax.set_xlabel('Sparsity (%)', fontsize=12)
ax.set_ylabel('PSNR (dB)', fontsize=12)
ax.set_title('Reconstruction Quality vs Sparsity Level', fontsize=14)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 10. Conclusion

### Interpretation Guide:

**If the model was trained with the FIXED code:**
- Sparse-conditioned samples should closely match ground truth at hint locations
- Hint PSNR should be high (>30 dB)
- Full image PSNR should improve with more hints

**If the model was trained with the BUGGY code:**
- Sparse conditioning will have minimal effect
- Class-only and sparse-conditioned will look similar
- Hint PSNR will be low
- **Solution**: Retrain with the fixed code!

In [None]:
banner = "=" * 60
print("\n" + banner)
print("EVALUATION SUMMARY")
print(banner)
print(f"\nCheckpoint: {CHECKPOINT_PATH}")
print(f"Encoder type: {MODEL_CONFIG['encoder_type']}")
print(f"SFC unified coords: {MODEL_CONFIG['sfc_unified_coords']}")
print(f"SFC spatial bias: {MODEL_CONFIG['sfc_spatial_bias']}")
print(f"\nTest sparsity: {SPARSITY*100:.0f}%")
print("\nMetrics:")
print(f"  Class-only PSNR: {metrics_class_only['psnr']:.2f} dB")
print(f"  Sparse-cond PSNR (full): {metrics_sparse_full['psnr']:.2f} dB")
print(f"  Sparse-cond PSNR (hints): {metrics_sparse_masked['psnr']:.2f} dB")
print("\n" + banner)

# Verdict is based on the full-image PSNR gain from adding sparse hints.
improvement = metrics_sparse_full['psnr'] - metrics_class_only['psnr']
if improvement > 3:
    verdict = (
        f"Sparse conditioning provides +{improvement:.1f} dB improvement!",
        "The model appears to have learned sparse conditioning.",
    )
elif improvement > 0:
    verdict = (
        f"Sparse conditioning provides modest +{improvement:.1f} dB improvement.",
        "Consider retraining with the fixed code for better results.",
    )
else:
    verdict = (
        f"Sparse conditioning shows no improvement ({improvement:.1f} dB).",
        "The model was likely trained with buggy code - RETRAIN REQUIRED.",
    )
print(*verdict, sep="\n")