# Sparse Context: Evaluation & Visualizations

Load **checkpoint_sparse_best.pt** (falling back to **checkpoint_sparse_last.pt** if the best checkpoint is missing), evaluate PSNR with **sparse** vs. **full** context, and visualize reconstructions and the sensitivity of PSNR to the context fraction.

In [None]:
import os
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from einops import rearrange, repeat

from nf_feature_models import CascadedPerceiverIO, GaussianFourierFeatures, create_coordinate_grid

# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_DIR = 'checkpoints'
CKPT_PREFIX = 'checkpoint_sparse'
# Prefer the "_best" checkpoint; fall back to "_last" when it does not exist.
CKPT_PATH = os.path.join(CHECKPOINT_DIR, CKPT_PREFIX + '_best.pt')
if not os.path.isfile(CKPT_PATH):
    CKPT_PATH = os.path.join(CHECKPOINT_DIR, CKPT_PREFIX + '_last.pt')
# Raise explicitly rather than `assert`: assertions are skipped under `python -O`,
# which would postpone the failure to a less helpful error at load time.
if not os.path.isfile(CKPT_PATH):
    raise FileNotFoundError(
        f'No {CKPT_PREFIX}_*.pt in {CHECKPOINT_DIR}. Run SparseContext_CIFAR10.ipynb first.'
    )
print('Device:', DEVICE, 'Checkpoint:', CKPT_PATH)


In [None]:
# Image geometry.
IMAGE_SIZE = 32
CHANNELS = 3

# Feature widths, all derived from the Fourier mapping size.
FOURIER_MAPPING_SIZE = 96
POS_EMBED_DIM = 2 * FOURIER_MAPPING_SIZE
INPUT_DIM = POS_EMBED_DIM + CHANNELS   # context token = pixel values + positional embedding
QUERIES_DIM = POS_EMBED_DIM            # query token = positional embedding only
LOGITS_DIM = CHANNELS

# Default fraction of grid positions supplied as sparse context.
# NOTE(review): presumably the fraction used during training — confirm against the training notebook.
CONTEXT_FRAC_TRAIN = 0.2

# Coordinate grid for the full image and the context sizes derived from it.
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
N_FULL = coords_32.shape[0]
N_SPARSE = max(int(CONTEXT_FRAC_TRAIN * N_FULL), 64)  # never fewer than 64 context points

def sample_gt_at_coords(images, coords):
    """Bilinearly sample ``images`` at the given coordinates.

    Args:
        images: Tensor of shape (B, C, H, W).
        coords: Tensor of shape (B, N, 2), or (N, 2) to reuse the same
            coordinates for every image in the batch. The two components are
            swapped before being handed to ``F.grid_sample`` (which expects
            (x, y)), so they are read here as (y, x). Sampling uses
            ``align_corners=True`` and border padding.

    Returns:
        Tensor of shape (B, N, C) holding the sampled values.
    """
    B, _, _, _ = images.shape
    if coords.dim() == 2:
        # Unbatched coordinates: share them across the batch.
        coords = coords.unsqueeze(0).expand(B, -1, -1)
    N = coords.shape[1]
    # reshape (not view): the indexed tensor's memory layout is not guaranteed
    # to be view-compatible, e.g. when ``coords`` is an expanded view.
    grid = coords[..., [1, 0]].reshape(B, 1, N, 2)
    sampled = F.grid_sample(images, grid, mode='bilinear', padding_mode='border', align_corners=True)
    # (B, C, 1, N) -> (B, N, C)
    return sampled.squeeze(2).permute(0, 2, 1)

def prepare_sparse_context(images, coords_full, fourier_encoder, num_sparse, device):
    """Build context tokens from a random subset of grid positions.

    One random subset of ``num_sparse`` rows of ``coords_full`` is drawn per
    call and shared by every image in the batch. Each token is the sampled
    pixel values concatenated with the encoded position along the last axis.
    """
    batch_size = images.size(0)
    shuffled = torch.randperm(coords_full.size(0), device=device)
    chosen_coords = coords_full[shuffled[:num_sparse]]
    # The same coordinates are used for all images, so broadcast them over the batch.
    batched_coords = chosen_coords.unsqueeze(0).expand(batch_size, -1, -1)
    pixel_feats = sample_gt_at_coords(images, batched_coords)
    pos_feats = fourier_encoder(batched_coords)
    return torch.cat((pixel_feats, pos_feats), dim=-1)

def prepare_full_context(images, coords_full, fourier_encoder):
    """Build context tokens for the full coordinate grid.

    Delegates to ``prepare_model_input`` and keeps only the first of the
    three values it returns.
    """
    outputs = prepare_model_input(images, coords_full, fourier_encoder)
    full_context, _, _ = outputs
    return full_context


In [None]:
from nf_feature_models import prepare_model_input

fourier_encoder = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
model = CascadedPerceiverIO(
    input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
).to(DEVICE)
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
# strict=False is kept (best-effort load), but any key mismatch is reported so
# that a partially initialised model is not evaluated unnoticed.
for _label, _module, _state_key in (
    ('model', model, 'model_state_dict'),
    ('fourier_encoder', fourier_encoder, 'fourier_encoder_state_dict'),
):
    _result = _module.load_state_dict(ckpt[_state_key], strict=False)
    _missing = list(getattr(_result, 'missing_keys', []))
    _unexpected = list(getattr(_result, 'unexpected_keys', []))
    if _missing or _unexpected:
        print(f'WARNING: {_label} state dict mismatch: '
              f'missing keys={_missing}, unexpected keys={_unexpected}')
    # Switch to inference mode for evaluation.
    _module.eval()
print('Loaded', CKPT_PATH)


In [None]:
def eval_psnr(model, fourier_encoder, loader, device, use_sparse=True, num_sparse=None):
    """Compute the PSNR of full-image reconstructions over ``loader``.

    Every image is reconstructed at all positions of the module-level
    ``coords_32`` grid; only the context given to the model differs.

    Args:
        model: Called as ``model(context, queries=...)``.
        fourier_encoder: Positional encoder applied to coordinates.
        loader: Iterable yielding ``(images, labels)`` batches; images are
            (B, C, H, W).
        device: Device the images are moved to.
        use_sparse: If True, the context is a random subset of grid positions
            (re-drawn for every batch); otherwise the full grid is used.
        num_sparse: Number of context positions when ``use_sparse`` is True.
            Defaults to the module-level ``N_SPARSE``.

    Returns:
        PSNR in dB, computed as ``10 * log10(1 / (mse + 1e-10))`` (peak value
        1.0), where ``mse`` is the mean squared error over all images.
    """
    model.eval()
    fourier_encoder.eval()
    n_sparse = num_sparse if num_sparse is not None else N_SPARSE
    sq_err_sum, n_values = 0.0, 0
    with torch.no_grad():
        for imgs, _ in loader:
            imgs = imgs.to(device)
            B = imgs.size(0)
            if use_sparse:
                input_data = prepare_sparse_context(imgs, coords_32, fourier_encoder, n_sparse, device)
            else:
                input_data = prepare_full_context(imgs, coords_32, fourier_encoder)
            # NOTE(review): flattening (h w) row-major assumes coords_32 lists
            # the grid in the same order — confirm against create_coordinate_grid.
            target_pixels = rearrange(imgs, 'b c h w -> b (h w) c')
            queries_full = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
            reconstructed = model(input_data, queries=queries_full)
            sq_err_sum += F.mse_loss(reconstructed, target_pixels, reduction='sum').item()
            # The summed error covers every channel value, so normalise by the
            # number of values (B * N * C). Counting only pixels (B * N) would
            # inflate the MSE by the channel count and understate the PSNR.
            n_values += target_pixels.numel()
    mse = sq_err_sum / max(n_values, 1)
    return 10 * math.log10(1.0 / (mse + 1e-10))

# CIFAR-10 test split; ToTensor is the only transform applied.
transform = transforms.Compose([transforms.ToTensor()])
test_ds = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

# Headline numbers: the default sparse context size vs. the full grid as context.
psnr_sparse = eval_psnr(
    model, fourier_encoder, test_loader, DEVICE, use_sparse=True,
)
psnr_full = eval_psnr(
    model, fourier_encoder, test_loader, DEVICE, use_sparse=False,
)
print(f'PSNR (sparse context {N_SPARSE}): {psnr_sparse:.2f} dB')
print(f'PSNR (full context): {psnr_full:.2f} dB')


## Reconstruction gallery: sparse vs full context

In [None]:
# Take the first eight images of the first test batch.
first_batch, _ = next(iter(test_loader))
imgs = first_batch[:8].to(DEVICE)
B = imgs.size(0)
with torch.no_grad():
    # Two contexts for the same images: a random sparse subset and the full grid.
    input_sparse = prepare_sparse_context(imgs, coords_32, fourier_encoder, N_SPARSE, DEVICE)
    input_full = prepare_full_context(imgs, coords_32, fourier_encoder)
    # Both reconstructions are queried at every grid position.
    batched_coords = coords_32.unsqueeze(0).expand(B, -1, -1)
    queries_full = fourier_encoder(batched_coords)
    recon_sparse = model(input_sparse, queries=queries_full)
    recon_full = model(input_full, queries=queries_full)
# Fold the flat per-position outputs back into channels-first images.
flat_to_image = 'b (h w) c -> b c h w'
recon_sparse = rearrange(recon_sparse, flat_to_image, h=IMAGE_SIZE, w=IMAGE_SIZE)
recon_full = rearrange(recon_full, flat_to_image, h=IMAGE_SIZE, w=IMAGE_SIZE)

def to_vis(x):
    """Convert a channels-first image tensor to a channels-last NumPy array.

    Values are clamped to [0, 1] so the result can be passed to ``imshow``.

    Args:
        x: A single image of shape (C, H, W) or a batch of shape (B, C, H, W).

    Returns:
        NumPy array of shape (H, W, C) or (B, H, W, C) respectively.

    Raises:
        ValueError: If ``x`` is neither 3-D nor 4-D.
    """
    # detach() so tensors that still track gradients can be converted too.
    x = x.detach().cpu().clamp(0, 1)
    if x.dim() == 3:
        # Single image: the callers index the batch before calling, so this is
        # the common case; a 4-axis permute would raise here.
        return x.permute(1, 2, 0).numpy()
    if x.dim() == 4:
        return x.permute(0, 2, 3, 1).numpy()
    raise ValueError(f'to_vis expects a 3-D or 4-D tensor, got {x.dim()} dimensions')
# One row per source: ground truth, sparse-context recon, full-context recon.
fig, axs = plt.subplots(3, 8, figsize=(16, 6))
gallery_rows = (
    (imgs, 'GT'),
    (recon_sparse, 'Sparse ctx'),
    (recon_full, 'Full ctx'),
)
for row_idx, (row_images, row_label) in enumerate(gallery_rows):
    for i in range(8):
        ax = axs[row_idx, i]
        ax.imshow(to_vis(row_images[i]))
        # Only the first column carries the row label.
        ax.set_title(row_label if i == 0 else '')
        ax.axis('off')
plt.suptitle('Sparse-context model: recon with sparse vs full context')
plt.tight_layout()
plt.savefig('sparse_context_gallery.png', dpi=100)
plt.show()


## PSNR vs context fraction

In [None]:
# Sweep the fraction of grid positions given as context at evaluation time.
fracs = [0.1, 0.15, 0.2, 0.3, 0.5, 1.0]
psnrs = []
for f in fracs:
    n = max(64, int(N_FULL * f))  # same 64-point floor as the default sparse size
    if f < 1.0:
        p = eval_psnr(model, fourier_encoder, test_loader, DEVICE, use_sparse=True, num_sparse=n)
    else:
        # A fraction of 1.0 is evaluated through the full-context path.
        p = eval_psnr(model, fourier_encoder, test_loader, DEVICE, use_sparse=False)
    psnrs.append(p)

plt.figure(figsize=(6, 4))
plt.plot(fracs, psnrs, 'o-')
plt.xlabel('Context fraction')
plt.ylabel('PSNR (dB)')
plt.title('Sparse-context model: PSNR vs context fraction at eval')
plt.tight_layout()
plt.savefig('sparse_context_psnr_vs_frac.png', dpi=100)
plt.show()
print('Context frac -> PSNR:', [(frac, round(value, 2)) for frac, value in zip(fracs, psnrs)])


## Which positions the model is given: sparse context positions (one sample)

In [None]:
torch.manual_seed(123)
img_one = imgs[:1]
n_vis = 128
# Draw the indices explicitly (rather than via prepare_sparse_context) so the
# mask below shows exactly the positions that are fed to the model.
idx_vis = torch.randperm(coords_32.size(0), device=DEVICE)[:n_vis]
coords_sparse = coords_32[idx_vis]
coords_batched = coords_sparse.unsqueeze(0).expand(1, -1, -1)
pixels_sparse = sample_gt_at_coords(img_one, coords_batched)
pos_sparse = fourier_encoder(coords_batched)
input_vis = torch.cat([pixels_sparse, pos_sparse], dim=-1)

# Mask: for a 32x32 row-major grid the linear index row * IMAGE_SIZE + col is
# also the index into the flattened mask, so all positions are set at once.
mask = torch.zeros(1, 1, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
mask.view(-1)[idx_vis] = 1.0

# Reconstruct the whole image from the sparse context.
with torch.no_grad():
    q = fourier_encoder(coords_32.unsqueeze(0))
    recon_one = model(input_vis, queries=q)
recon_one = rearrange(recon_one, 'b (h w) c -> b c h w', h=IMAGE_SIZE, w=IMAGE_SIZE)

fig, axs = plt.subplots(1, 3, figsize=(10, 4))
panels = (
    (to_vis(img_one[0]), 'Image', {}),
    (mask[0, 0].cpu().numpy(), f'Sparse context positions (n={n_vis})', {'cmap': 'hot'}),
    (to_vis(recon_one[0]), 'Reconstruction', {}),
)
for ax, (picture, panel_title, imshow_kwargs) in zip(axs, panels):
    ax.imshow(picture, **imshow_kwargs)
    ax.set_title(panel_title)
    ax.axis('off')
plt.suptitle('Example: which positions were given as context')
plt.tight_layout()
plt.savefig('sparse_context_positions.png', dpi=100)
plt.show()
