# Sparse Context: Evaluation & Visualizations

Load **checkpoint_sparse_best.pt** (falling back to **checkpoint_sparse_last.pt**), evaluate PSNR with **sparse** vs **full** context, and visualize reconstructions and sensitivity to the context fraction.

The comparison is between **sparse-context reconstruction** (a random subset of pixels → full image) and **full-context reconstruction** (all pixels → full image). If you trained with `USE_NCE=False`, both are reconstruction-only; the additional plots (View A/B, Retrieval@ε, TDSM, etc.) still contrast the two setups.

In [None]:
import os
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from einops import rearrange, repeat

from nf_feature_models import CascadedPerceiverIO, GaussianFourierFeatures, create_coordinate_grid, prepare_model_input

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_DIR = 'checkpoints'
CKPT_PREFIX = 'checkpoint_sparse'
# Prefer the best checkpoint; fall back to the most recent one.
CKPT_PATH = os.path.join(CHECKPOINT_DIR, CKPT_PREFIX + '_best.pt')
if not os.path.isfile(CKPT_PATH):
    CKPT_PATH = os.path.join(CHECKPOINT_DIR, CKPT_PREFIX + '_last.pt')
# Explicit raise instead of `assert`: asserts are stripped under `python -O`.
if not os.path.isfile(CKPT_PATH):
    raise FileNotFoundError(f'No {CKPT_PREFIX}_*.pt in {CHECKPOINT_DIR}. Run SparseContext_CIFAR10.ipynb first.')
print('Device:', DEVICE, 'Checkpoint:', CKPT_PATH)


In [None]:
# Image geometry and embedding widths
IMAGE_SIZE, CHANNELS = 32, 3
FOURIER_MAPPING_SIZE = 96
POS_EMBED_DIM = 2 * FOURIER_MAPPING_SIZE  # positional embedding width (2x mapping size; presumably sin + cos)
INPUT_DIM = POS_EMBED_DIM + CHANNELS      # context token = pixel values concatenated with positional embedding
QUERIES_DIM = POS_EMBED_DIM               # decoder queries are positional embeddings only
LOGITS_DIM = CHANNELS                     # decoder output: one value per colour channel
CONTEXT_FRAC_TRAIN = 0.2                  # context fraction (per the name, the one used in training)

# Dense coordinate grid and the resulting sparse-context size (never fewer than 64 points)
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
N_FULL = coords_32.size(0)
N_SPARSE = max(64, int(CONTEXT_FRAC_TRAIN * N_FULL))

def sample_gt_at_coords(images, coords):
    """Bilinearly sample ``images`` at normalized coordinates.

    Args:
        images: (B, C, H, W) tensor.
        coords: tensor holding B * N coordinate pairs in [-1, 1], laid out as
            (B, N, 2) with the vertical component first (it is swapped to the
            (x, y) order ``grid_sample`` expects).

    Returns:
        (B, N, C) tensor of sampled values (border padding, align_corners=True).
    """
    batch = images.shape[0]
    n_points = coords.shape[1]
    xy_grid = coords[..., [1, 0]].view(batch, 1, n_points, 2)
    sampled = F.grid_sample(
        images, xy_grid, mode='bilinear', padding_mode='border', align_corners=True
    )
    # (B, C, 1, N) -> (B, C, N) -> (B, N, C)
    sampled = sampled.squeeze(2)
    return sampled.permute(0, 2, 1)

def prepare_sparse_context(images, coords_full, fourier_encoder, num_sparse, device):
    """Build a sparse context from a random subset of grid positions.

    One random permutation selects ``num_sparse`` rows of ``coords_full``; the
    same subset is shared by every image in the batch. Each context token is the
    pixel value sampled at that position concatenated (last dim) with the
    Fourier embedding of the position.
    """
    batch = images.size(0)
    keep = torch.randperm(coords_full.size(0), device=device)[:num_sparse]
    chosen = coords_full[keep].unsqueeze(0).expand(batch, -1, -1)
    rgb = sample_gt_at_coords(images, chosen)
    pos = fourier_encoder(chosen)
    return torch.cat([rgb, pos], dim=-1)

def prepare_full_context(images, coords_full, fourier_encoder):
    """Build the dense (all-positions) context via ``prepare_model_input``.

    ``prepare_model_input`` returns three values; only the first is used here.
    """
    context, _second, _third = prepare_model_input(images, coords_full, fourier_encoder)
    return context


In [None]:
import torch.nn as nn
fourier_encoder = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
model = CascadedPerceiverIO(
    input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
).to(DEVICE)
PROJ_DIM = 128
projection_head = nn.Linear(QUERIES_DIM, PROJ_DIM).to(DEVICE)
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)


def _report_skipped_keys(label, load_result):
    """Print the keys a non-strict ``load_state_dict`` skipped.

    The loads below use ``strict=False``, which silently ignores missing and
    unexpected keys; without this report a mismatched checkpoint would leave
    parameters at their random initialization unnoticed.
    """
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f'[{label}] missing keys: {load_result.missing_keys} | '
              f'unexpected keys: {load_result.unexpected_keys}')


_report_skipped_keys('model', model.load_state_dict(ckpt['model_state_dict'], strict=False))
_report_skipped_keys('fourier_encoder',
                     fourier_encoder.load_state_dict(ckpt['fourier_encoder_state_dict'], strict=False))
if 'projection_head_state_dict' in ckpt:
    _report_skipped_keys('projection_head',
                         projection_head.load_state_dict(ckpt['projection_head_state_dict'], strict=False))
else:
    print('No projection_head_state_dict in checkpoint; projection head is randomly initialized.')
model.eval()
fourier_encoder.eval()
projection_head.eval()
print('Loaded', CKPT_PATH)


In [None]:
def eval_psnr(model, fourier_encoder, loader, device, use_sparse=True, num_sparse=None):
    """Dataset-level PSNR (dB) of full-image reconstructions.

    Each batch is encoded either as a sparse context (``num_sparse`` random
    positions, default ``N_SPARSE``) or as the full context. The model then
    decodes every grid position. The MSE is pooled over all pixels *and
    channels* of the whole loader before converting to PSNR with peak 1.0.
    Pixel values are therefore assumed to lie in [0, 1].

    Note: the sparse subset is redrawn per batch without a fixed seed, so
    sparse-context results vary slightly between runs.
    """
    model.eval()
    fourier_encoder.eval()
    n_sparse = num_sparse if num_sparse is not None else N_SPARSE
    mse_sum, n = 0.0, 0
    with torch.no_grad():
        for imgs, _ in loader:
            imgs = imgs.to(device)
            B = imgs.size(0)
            if use_sparse:
                input_data = prepare_sparse_context(imgs, coords_32, fourier_encoder, n_sparse, device)
            else:
                input_data = prepare_full_context(imgs, coords_32, fourier_encoder)
            target_pixels = rearrange(imgs, 'b c h w -> b (h w) c')
            queries_full = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
            reconstructed = model(input_data, queries=queries_full)
            mse_sum += F.mse_loss(reconstructed, target_pixels, reduction='sum').item()
            # reduction='sum' adds up B * H*W * C squared errors, so count every
            # element (not just B * H*W pixels) or MSE is inflated by C.
            n += target_pixels.numel()
    mse = mse_sum / max(n, 1)
    return 10 * math.log10(1.0 / (mse + 1e-10))

transform = transforms.Compose([transforms.ToTensor()])
test_ds = torchvision.datasets.CIFAR10('./data', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_ds, shuffle=False, batch_size=64)

# Same model, two context regimes: random pixel subset first, then every pixel
psnr_sparse, psnr_full = [
    eval_psnr(model, fourier_encoder, test_loader, DEVICE, use_sparse=sparse_flag)
    for sparse_flag in (True, False)
]
print(f'PSNR (sparse context {N_SPARSE}): {psnr_sparse:.2f} dB')
print(f'PSNR (full context): {psnr_full:.2f} dB')


## Reconstruction gallery: sparse vs full context

In [None]:
imgs = next(iter(test_loader))[0][:8].to(DEVICE)
B = imgs.size(0)
with torch.no_grad():
    input_sparse = prepare_sparse_context(imgs, coords_32, fourier_encoder, N_SPARSE, DEVICE)
    input_full = prepare_full_context(imgs, coords_32, fourier_encoder)
    queries_full = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
    # Decode the full query grid from each context type
    recon_sparse, recon_full = (model(ctx, queries=queries_full) for ctx in (input_sparse, input_full))
# (B, H*W, C) sequences back to (B, C, H, W) images
recon_sparse, recon_full = [
    rearrange(r, 'b (h w) c -> b c h w', h=IMAGE_SIZE, w=IMAGE_SIZE)
    for r in (recon_sparse, recon_full)
]

def to_vis(x):
    """Convert a (C, H, W) or (B, C, H, W) tensor into a clamped channels-last numpy array.

    Values are clamped to [0, 1]. A single image (3-D input, or a batch of one)
    is returned as (H, W, C); larger batches as (B, H, W, C).
    """
    batched = x[None] if x.dim() == 3 else x
    arr = batched.cpu().clamp(0, 1).permute(0, 2, 3, 1).numpy()
    if arr.shape[0] == 1:
        return arr[0]
    return arr
# One row per source: ground truth, sparse-context recon, full-context recon
gallery_rows = [('GT', imgs), ('Sparse ctx', recon_sparse), ('Full ctx', recon_full)]
fig, axs = plt.subplots(3, 8, figsize=(16, 6))
for r, (row_label, row_images) in enumerate(gallery_rows):
    for c in range(8):
        ax = axs[r, c]
        ax.imshow(to_vis(row_images[c]))
        ax.set_title(row_label if c == 0 else '')
        ax.axis('off')
plt.suptitle('Sparse-context model: recon with sparse vs full context')
plt.tight_layout()
plt.savefig('sparse_context_gallery.png', dpi=100)
plt.show()


## PSNR vs context fraction

In [None]:
fracs = [0.1, 0.15, 0.2, 0.3, 0.5, 1.0]
psnrs = []
for f in fracs:
    n = max(64, int(N_FULL * f))
    # A fraction of 1.0 goes through the dense-context path instead of a sparse subset
    use_sparse = f < 1.0
    p = eval_psnr(model, fourier_encoder, test_loader, DEVICE,
                  use_sparse=use_sparse, num_sparse=n if use_sparse else None)
    psnrs.append(p)

plt.figure(figsize=(6, 4))
plt.plot(fracs, psnrs, 'o-')
plt.xlabel('Context fraction')
plt.ylabel('PSNR (dB)')
plt.title('Sparse-context model: PSNR vs context fraction at eval')
plt.tight_layout()
plt.savefig('sparse_context_psnr_vs_frac.png', dpi=100)
plt.show()
print('Context frac -> PSNR:', [(fr, round(v, 2)) for fr, v in zip(fracs, psnrs)])


## Where the model "looks": sparse context positions (one sample)

In [None]:
torch.manual_seed(123)
img_one = imgs[:1]
n_vis = 128
# Draw the context indices once and reuse them for both the model input and the mask,
# so the mask shows exactly the positions the model receives.
idx_vis = torch.randperm(coords_32.size(0), device=DEVICE)[:n_vis]
coords_sparse = coords_32[idx_vis]
coords_vis_batch = coords_sparse.unsqueeze(0).expand(1, -1, -1)
pixels_sparse = sample_gt_at_coords(img_one, coords_vis_batch)
pos_sparse = fourier_encoder(coords_vis_batch)
input_vis = torch.cat([pixels_sparse, pos_sparse], dim=-1)
# Flat index k sits at (k // IMAGE_SIZE, k % IMAGE_SIZE) of a contiguous row-major grid,
# so writing through a flat view of the mask marks the same cells.
mask = torch.zeros(1, 1, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
mask.view(-1)[idx_vis] = 1.0

with torch.no_grad():
    q = fourier_encoder(coords_32.unsqueeze(0))
    recon_one = model(input_vis, queries=q)
recon_one = rearrange(recon_one, 'b (h w) c -> b c h w', h=IMAGE_SIZE, w=IMAGE_SIZE)

panels = [
    (to_vis(img_one[0]), {}, 'Image'),
    (mask[0, 0].cpu().numpy(), {'cmap': 'hot'}, f'Sparse context positions (n={n_vis})'),
    (to_vis(recon_one[0]), {}, 'Reconstruction'),
]
fig, axs = plt.subplots(1, 3, figsize=(10, 4))
for ax, (picture, imshow_kwargs, title) in zip(axs, panels):
    ax.imshow(picture, **imshow_kwargs)
    ax.set_title(title)
    ax.axis('off')
plt.suptitle('Example: which positions were given as context')
plt.tight_layout()
plt.savefig('sparse_context_positions.png', dpi=100)
plt.show()


## Full-context recon (baseline) vs Sparse-context recon

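Load the **full-image reconstruction** baseline (`checkpoint_best.pt`, falling back to `checkpoint_last.pt`), then run the same correspondence and TDSM visualizations to contrast it with the **sparse-context** model (partial image as input). No NCE is required: this compares partial- vs full-context reconstruction. Note that the baseline's correspondence features use a freshly initialized (untrained) projection head.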

In [None]:
# Load full-context recon baseline (whole image as input) and, optionally, a
# full-context + NCE model. Both use the same architecture as the sparse model.
def _load_full_context_model(ckpt_path, label):
    """Build a Fourier encoder + CascadedPerceiverIO and load weights from ``ckpt_path``.

    Args:
        ckpt_path: checkpoint holding 'model_state_dict' and 'fourier_encoder_state_dict'.
        label: name used when reporting keys skipped by the non-strict loads.

    Returns:
        (model, fourier_encoder), both on DEVICE and in eval mode.
    """
    enc = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
    net = CascadedPerceiverIO(
        input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
        latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
    ).to(DEVICE)
    state = torch.load(ckpt_path, map_location=DEVICE)
    for part, module, key in (('model', net, 'model_state_dict'),
                              ('fourier_encoder', enc, 'fourier_encoder_state_dict')):
        # strict=False silently skips mismatched keys; surface them instead of hiding them.
        result = module.load_state_dict(state[key], strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(f'[{label} {part}] missing keys: {result.missing_keys} | '
                  f'unexpected keys: {result.unexpected_keys}')
    net.eval()
    enc.eval()
    return net, enc


CKPT_BASELINE = os.path.join(CHECKPOINT_DIR, 'checkpoint_best.pt')
if not os.path.isfile(CKPT_BASELINE):
    CKPT_BASELINE = os.path.join(CHECKPOINT_DIR, 'checkpoint_last.pt')
baseline_model = baseline_fourier = None
if os.path.isfile(CKPT_BASELINE):
    baseline_model, baseline_fourier = _load_full_context_model(CKPT_BASELINE, 'baseline')
    print('Loaded baseline (full recon):', CKPT_BASELINE)
else:
    print('No baseline checkpoint found; comparison plots will be sparse-context only.')

# Optional: Full-context + NCE-trained (from SoftInfoNCE_OmniField_CIFAR10.ipynb; save that model to checkpoint_nce_best.pt)
CKPT_NCE = os.path.join(CHECKPOINT_DIR, 'checkpoint_nce_best.pt')
if not os.path.isfile(CKPT_NCE):
    CKPT_NCE = os.path.join(CHECKPOINT_DIR, 'softnce_best.pt')
nce_model = nce_fourier = None
if os.path.isfile(CKPT_NCE):
    nce_model, nce_fourier = _load_full_context_model(CKPT_NCE, 'nce')
    print('Loaded Full context + NCE:', CKPT_NCE)
else:
    print('No NCE checkpoint (checkpoint_nce_best.pt / softnce_best.pt); run SoftInfoNCE notebook and save to include.')

def get_residual(model, data):
    """Run the model's encoder stack on ``data`` and return the final latent tensor.

    Each entry of ``model.encoder_blocks`` is called with the running latents
    (``None`` for the first block) as both ``x`` and ``residual``, plus
    ``context=data``. Each entry of ``model.self_attn_blocks`` then applies its
    elements [0] and [1] in turn as residual sub-layers (presumably attention
    then feed-forward).
    """
    latents = None
    for enc_block in model.encoder_blocks:
        latents = enc_block(x=latents, context=data, mask=None, residual=latents)
    for pair in model.self_attn_blocks:
        first, second = pair[0], pair[1]
        latents = first(latents) + latents
        latents = second(latents) + latents
    return latents

def get_rgb_and_phi_raw(model, queries, residual):
    """Decode ``queries`` against the latents ``residual``.

    Returns ``(model.to_logits(h), h)``, where ``h`` is the decoder's hidden
    state (cross-attention + query skip, plus the optional feed-forward
    residual).
    """
    hidden = model.decoder_cross_attn(queries, context=residual) + queries
    ff = model.decoder_ff
    if ff is not None:
        hidden = hidden + ff(hidden)
    return model.to_logits(hidden), hidden

def decoder_forward(model, queries, context):
    """Decode ``queries`` against ``context`` and return only the output logits.

    The feed-forward residual is applied only when the model has a non-None
    ``decoder_ff`` attribute (a missing attribute counts as None).
    """
    hidden = model.decoder_cross_attn(queries, context=context) + queries
    ff = getattr(model, 'decoder_ff', None)
    if ff is not None:
        hidden = hidden + ff(hidden)
    return model.to_logits(hidden)

def sample_affine_params(batch_size, device, scale_range=(0.85, 1.0), max_translate=0.1, max_angle_deg=12):
    """Draw a random rotation+scale matrix and translation per batch element.

    Angle is uniform in [-max_angle_deg, max_angle_deg) degrees, scale is
    uniform in [scale_range[0], scale_range[1]), and tx, ty are each uniform in
    [-max_translate, max_translate).

    Returns:
        R: (batch_size, 2, 2) = scale * [[cos, -sin], [sin, cos]].
        t: (batch_size, 2) = (tx, ty).
    """
    def _symmetric(limit):
        # uniform sample in [-limit, limit)
        return (torch.rand(batch_size, device=device) * 2 - 1) * limit

    angle = _symmetric(max_angle_deg * math.pi / 180)
    lo, hi = scale_range
    scale = lo + torch.rand(batch_size, device=device) * (hi - lo)
    tx = _symmetric(max_translate)
    ty = _symmetric(max_translate)
    cos_scaled = torch.cos(angle) * scale
    sin_scaled = torch.sin(angle) * scale
    R = torch.stack([cos_scaled, -sin_scaled, sin_scaled, cos_scaled], dim=-1).view(batch_size, 2, 2)
    t = torch.stack([tx, ty], dim=1)
    return R, t

def apply_affine_to_coords(coords, R, t):
    """Map normalized coordinates through the same affine that ``apply_affine_to_image`` applies.

    Coordinates in this notebook are stored (y, x): ``sample_gt_at_coords``
    swaps them to (x, y) for ``grid_sample``, and ``norm_to_pixel`` reads
    index 0 as y. ``apply_affine_to_image`` builds its warp with
    ``F.affine_grid``, where ``R``/``t`` act on (x, y). The affine is therefore
    applied in (x, y) and the result swapped back to (y, x). Applying ``R``
    directly to (y, x) would flip the rotation sense and swap tx/ty, so the
    "ground-truth" points would not land on the warped content.

    Args:
        coords: (B, N, 2) coordinates in (y, x) order.
        R: (B, 2, 2) linear part acting on (x, y).
        t: (B, 2) translation (tx, ty).

    Returns:
        (B, N, 2) mapped coordinates in (y, x) order.
    """
    xy = coords[..., [1, 0]]
    mapped_xy = torch.einsum('bed,bnd->bne', R, xy) + t.unsqueeze(1)
    return mapped_xy[..., [1, 0]]

def apply_affine_to_image(images, R, t):
    """Warp ``images`` so that input point q (normalized (x, y)) appears at R @ q + t.

    ``affine_grid`` needs the output->input mapping, so the inverse transform
    [R^-1 | -R^-1 t] is passed. Sampling is bilinear with border padding and
    align_corners=True.
    """
    inv = torch.inverse(R)
    shift = -(inv @ t[..., None])
    theta = torch.cat([inv, shift], dim=2)
    sampling_grid = F.affine_grid(theta, images.size(), align_corners=True)
    return F.grid_sample(images, sampling_grid, mode='bilinear', padding_mode='border', align_corners=True)

def norm_to_pixel(coords_norm, h, w):
    """Convert normalized (y, x) coordinates in [-1, 1] to pixel positions.

    Uses the align_corners=True convention (-1 -> 0, 1 -> size - 1).
    Returns ``(cols, rows)`` as numpy arrays, i.e. matplotlib's (x, y) order.
    """
    rows = (coords_norm[..., 0] + 1) / 2 * (h - 1)
    cols = (coords_norm[..., 1] + 1) / 2 * (w - 1)
    return cols.cpu().numpy(), rows.cpu().numpy()

In [None]:
# Two-view batch and correspondence quantities (sparse-context vs full-context baseline)
N_ANCHORS, N_CANDIDATES = 128, 512
TAU, SIGMA = 0.1, 0.08
b_show = 0
n_anchors_show = 4

imgs_viz, _ = next(iter(test_loader))
imgs_viz = imgs_viz[:4].to(DEVICE)
B = imgs_viz.size(0)

# Inference only: no autograd graph is needed for any of the quantities below.
with torch.no_grad():
    # View B = view A warped by a random affine (R, t)
    R, t = sample_affine_params(B, DEVICE)
    imgs_b = apply_affine_to_image(imgs_viz, R, t)

    # Sparse model: sparse context for A and B
    input_sparse_a = prepare_sparse_context(imgs_viz, coords_32, fourier_encoder, N_SPARSE, DEVICE)
    input_sparse_b = prepare_sparse_context(imgs_b, coords_32, fourier_encoder, N_SPARSE, DEVICE)
    residual_s_a = get_residual(model, input_sparse_a)
    residual_s_b = get_residual(model, input_sparse_b)
    # Random anchors in view A and candidate points in view B (normalized coords in [-1, 1))
    anchors_a = (torch.rand(B, N_ANCHORS, 2, device=DEVICE) * 2 - 1)
    candidates_b = (torch.rand(B, N_CANDIDATES, 2, device=DEVICE) * 2 - 1)
    q_a = fourier_encoder(anchors_a)
    q_b = fourier_encoder(candidates_b)
    _, phi_raw_s_a = get_rgb_and_phi_raw(model, q_a, residual_s_a)
    _, phi_raw_s_b = get_rgb_and_phi_raw(model, q_b, residual_s_b)
    phi_s_a = F.normalize(projection_head(phi_raw_s_a), dim=-1)
    phi_s_b = F.normalize(projection_head(phi_raw_s_b), dim=-1)
    # (B, N_ANCHORS, N_CANDIDATES) cosine similarities divided by TAU
    logits_s = torch.bmm(phi_s_a, phi_s_b.transpose(1, 2)) / TAU
    # Ground-truth position of each anchor in view B, and Gaussian target weights over candidates
    xi_mapped = apply_affine_to_coords(anchors_a, R, t)
    sqd = ((candidates_b.unsqueeze(1) - xi_mapped.unsqueeze(2)) ** 2).sum(-1)
    w_s = torch.exp(-sqd / (2 * SIGMA ** 2))
    w_s = w_s / (w_s.sum(dim=2, keepdim=True) + 1e-8)
    # Predicted match = highest-similarity candidate
    pred_coords_s = candidates_b[b_show][logits_s[b_show].argmax(dim=1)]
    coords_b_batch = candidates_b

    # Baseline: full context for A and B; random projection for phi (same arch, no contrastive training)
    if baseline_model is not None:
        # The baseline was loaded with its own Fourier encoder, so both its contexts
        # and its queries must be encoded with baseline_fourier, not the sparse
        # model's encoder.
        input_full_a = prepare_full_context(imgs_viz, coords_32, baseline_fourier)
        input_full_b = prepare_full_context(imgs_b, coords_32, baseline_fourier)
        residual_b_a = get_residual(baseline_model, input_full_a)
        residual_b_b = get_residual(baseline_model, input_full_b)
        q_a_bl = baseline_fourier(anchors_a)
        q_b_bl = baseline_fourier(candidates_b)
        proj_bl = nn.Linear(QUERIES_DIM, PROJ_DIM).to(DEVICE)
        _, phi_raw_b_a = get_rgb_and_phi_raw(baseline_model, q_a_bl, residual_b_a)
        _, phi_raw_b_b = get_rgb_and_phi_raw(baseline_model, q_b_bl, residual_b_b)
        phi_b_a = F.normalize(proj_bl(phi_raw_b_a), dim=-1)
        phi_b_b = F.normalize(proj_bl(phi_raw_b_b), dim=-1)
        logits_b = torch.bmm(phi_b_a, phi_b_b.transpose(1, 2)) / TAU
        w_b = torch.exp(-sqd / (2 * SIGMA ** 2))
        w_b = w_b / (w_b.sum(dim=2, keepdim=True) + 1e-8)
        pred_coords_b = candidates_b[b_show][logits_b[b_show].argmax(dim=1)]

In [None]:
# View A vs View B: anchors on A, GT and predicted match on B (Sparse context vs Full-context baseline)
n_show = min(8, anchors_a.size(1))
has_baseline = baseline_model is not None
rows_data = [('Sparse context', pred_coords_s)]
if has_baseline:
    rows_data.append(('Full context (baseline)', pred_coords_b))
fig, axs = plt.subplots(len(rows_data), 2, figsize=(10, 5 if has_baseline else 2.5), squeeze=False)
# Anchor and GT positions are the same for every row; convert them once.
cx_a, cy_a = norm_to_pixel(anchors_a[b_show, :n_show], IMAGE_SIZE, IMAGE_SIZE)
cx_gt, cy_gt = norm_to_pixel(xi_mapped[b_show, :n_show], IMAGE_SIZE, IMAGE_SIZE)
for row, (name, pred_coords) in enumerate(rows_data):
    ax_a, ax_b = axs[row]
    ax_a.imshow(to_vis(imgs_viz[b_show]))
    ax_a.set_title(f'{name}: View A (anchors)')
    ax_a.axis('off')
    ax_a.scatter(cx_a, cy_a, c='lime', s=40, marker='o', edgecolors='black', linewidths=0.5)

    ax_b.imshow(to_vis(imgs_b[b_show]))
    ax_b.set_title(f'{name}: View B (GT vs pred)')
    ax_b.axis('off')
    cx_pr, cy_pr = norm_to_pixel(pred_coords[:n_show], IMAGE_SIZE, IMAGE_SIZE)
    ax_b.scatter(cx_gt, cy_gt, c='lime', s=60, marker='+', linewidths=2, label='GT T(x)')
    ax_b.scatter(cx_pr, cy_pr, c='cyan', s=40, marker='x', linewidths=1.5, label='pred')
    ax_b.legend(loc='upper right', fontsize=8)
plt.suptitle('Correspondence: Sparse context vs Full-context baseline')
plt.tight_layout()
plt.savefig('sparse_context_viewA_viewB_vs_baseline.png', dpi=100)
plt.show()

In [None]:
# Retrieval @ ε: % of anchors with error < ε (Sparse context vs Full-context baseline)
eps_values = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3]


def _pct_within(errors):
    """Percentage of anchors whose localization error is below each ε in eps_values."""
    return [(errors < eps).mean() * 100 for eps in eps_values]


err_s = (pred_coords_s - xi_mapped[b_show]).norm(dim=-1).cpu().numpy()
acc_s = _pct_within(err_s)
x = np.arange(len(eps_values))
w_bar = 0.35
fig, ax = plt.subplots(figsize=(6, 3))
ax.bar(x - w_bar/2, acc_s, width=w_bar, label='Sparse context', color='steelblue')
if baseline_model is not None:
    err_b = (pred_coords_b - xi_mapped[b_show]).norm(dim=-1).cpu().numpy()
    acc_b = _pct_within(err_b)
    ax.bar(x + w_bar/2, acc_b, width=w_bar, label='Baseline', color='coral', alpha=0.8)
ax.set_xticks(x)
ax.set_xticklabels([f'{e}' for e in eps_values])
ax.set_xlabel('ε (normalized coord distance)')
ax.set_ylabel('% anchors with error < ε')
ax.set_title('Retrieval accuracy: Sparse context vs Full-context baseline')
ax.legend()
plt.tight_layout()
plt.savefig('sparse_context_retrieval_at_eps_vs_baseline.png', dpi=100)
plt.show()

In [None]:
# Correspondence confidence (Sparse context vs Full-context baseline).
# w_s / w_b are the ground-truth Gaussian target weights built from `sqd`; they are
# identical for both models, so a histogram of them cannot distinguish the models.
# Each model's own confidence is the softmax of its similarity logits over candidates.
has_baseline = baseline_model is not None
fig, axs = plt.subplots(1, 2 if has_baseline else 1, figsize=(6 if has_baseline else 3, 3))
if not has_baseline:
    axs = [axs]
panels = [('Sparse context', logits_s, 'steelblue')]
if has_baseline:
    panels.append(('Baseline', logits_b, 'coral'))
for ax, (name, logits, color) in zip(axs, panels):
    # Peak of the predicted distribution p_ij = softmax_j(logits_ij), one value per anchor
    p_max = torch.softmax(logits, dim=2).max(dim=2).values.flatten().detach().cpu().numpy()
    ax.hist(p_max, bins=30, color=color, edgecolor='black', alpha=0.8)
    ax.set_xlabel('max_j p_ij')
    ax.set_ylabel('Count (anchors)')
    ax.set_title(f'{name}: weight concentration')
    ax.axvline(p_max.mean(), color='red', linestyle='--', label=f'mean={p_max.mean():.3f}')
    ax.legend()
plt.suptitle('Peaked weights = confident correspondence')
plt.tight_layout()
plt.savefig('sparse_context_weight_concentration_vs_baseline.png', dpi=100)
plt.show()

In [None]:
# Soft weights over view-B candidates with GT (+) and predicted (x) — Sparse context.
# w_s[b, i] holds one weight per random candidate (N_CANDIDATES of them), not one per
# pixel, so it cannot be reshaped to IMAGE_SIZE x IMAGE_SIZE. Draw it as a colored
# scatter at each candidate's pixel position on a black background instead.
cand_cx, cand_cy = norm_to_pixel(candidates_b[b_show], IMAGE_SIZE, IMAGE_SIZE)
blank = np.zeros((IMAGE_SIZE, IMAGE_SIZE))
fig, axs = plt.subplots(2, n_anchors_show, figsize=(12, 5))
for i in range(n_anchors_show):
    weights_i = w_s[b_show, i].cpu().numpy()
    cx_gt, cy_gt = norm_to_pixel(xi_mapped[b_show, i:i+1], IMAGE_SIZE, IMAGE_SIZE)
    cx_pr, cy_pr = norm_to_pixel(pred_coords_s[i:i+1], IMAGE_SIZE, IMAGE_SIZE)
    # imshow of a blank frame fixes the pixel extent/orientation to match the image row below
    axs[0, i].imshow(blank, cmap='gray', vmin=0, vmax=1)
    axs[0, i].scatter(cand_cx, cand_cy, c=weights_i, cmap='hot', s=15)
    axs[0, i].set_title(f'Anchor {i}')
    axs[0, i].scatter(cx_gt, cy_gt, c='lime', s=80, marker='+', linewidths=2)
    axs[0, i].scatter(cx_pr, cy_pr, c='cyan', s=50, marker='x', linewidths=1.5)
    axs[0, i].axis('off')
    axs[1, i].imshow(to_vis(imgs_b[b_show]))
    axs[1, i].scatter(cx_gt, cy_gt, c='lime', s=80, marker='+', linewidths=2)
    axs[1, i].scatter(cx_pr, cy_pr, c='cyan', s=50, marker='x', linewidths=1.5)
    axs[1, i].axis('off')
plt.suptitle('Sparse context: soft weights — lime=GT, cyan=argmax pred')
plt.tight_layout()
plt.savefig('sparse_context_heatmap_overlay.png', dpi=100)
plt.show()

In [None]:
# Feature-level: pos vs neg cosine similarity (Sparse context vs Full-context baseline)
# Logits were divided by TAU, so multiplying back recovers cosine similarities.
S_s = (logits_s[b_show] * TAU).detach().cpu().numpy()
N_a, N_b = S_s.shape
# The positive for each anchor is the candidate nearest to its mapped GT position.
gt_offsets = coords_b_batch[b_show].unsqueeze(0) - xi_mapped[b_show].unsqueeze(1)
sqd_b = gt_offsets.pow(2).sum(-1).cpu().numpy()
j_gt = sqd_b.argmin(axis=1)
anchor_rows = np.arange(N_a)
neg_mask = np.ones((N_a, N_b), dtype=bool)
neg_mask[anchor_rows, j_gt] = False
pos_sims_s, neg_sims_s = S_s[anchor_rows, j_gt], S_s[neg_mask]

hist_specs = [
    (neg_sims_s, dict(bins=40, alpha=0.5, color='coral', label='Sparse context neg')),
    (pos_sims_s, dict(bins=30, alpha=0.5, color='green', label='Sparse context pos')),
]
if baseline_model is not None:
    S_b = (logits_b[b_show] * TAU).detach().cpu().numpy()
    pos_sims_b = S_b[anchor_rows, j_gt]
    neg_sims_b = S_b[neg_mask]
    hist_specs += [
        (neg_sims_b, dict(bins=40, alpha=0.4, color='gray', histtype='step', linewidth=2, label='Baseline neg')),
        (pos_sims_b, dict(bins=30, alpha=0.5, color='blue', histtype='step', linewidth=2, label='Baseline pos')),
    ]
fig, ax = plt.subplots(figsize=(7, 4))
for values, hist_kwargs in hist_specs:
    ax.hist(values, density=True, **hist_kwargs)
ax.set_xlabel('Cosine similarity φ(anchor)·φ(candidate)')
ax.set_ylabel('Density')
ax.set_title('Feature-level: Sparse context (filled) vs Full-context baseline (outline)')
ax.legend(loc='upper left', fontsize=8)
plt.tight_layout()
plt.savefig('sparse_context_feature_pos_neg_vs_baseline.png', dpi=100)
plt.show()
print('Sparse context: pos mean={:.4f} neg mean={:.4f} gap={:.4f}'.format(pos_sims_s.mean(), neg_sims_s.mean(), pos_sims_s.mean()-neg_sims_s.mean()))
if baseline_model is not None:
    print('Baseline:   pos mean={:.4f} neg mean={:.4f} gap={:.4f}'.format(pos_sims_b.mean(), neg_sims_b.mean(), pos_sims_b.mean()-neg_sims_b.mean()))

## TDSM: per-token decoded spatial map (Full context vs Sparse context)

Same images: the baseline gets **full context** and the sparse model gets **sparse context**. Compare the per-token reconstructions.

In [None]:
def get_tdsm(model, fourier_enc, data, coords, device, num_tokens=256, token_step=4):
    """Token-decoded spatial maps: decode the full query grid from one latent token at a time.

    Latent tokens 0, token_step, 2*token_step, ... (< num_tokens) are each used
    alone as the decoder context. Queries are built from ``coords`` with
    ``fourier_enc``; ``coords`` must cover IMAGE_SIZE * IMAGE_SIZE positions for
    the reshape to succeed.

    Returns:
        Tensor of shape (B, len(range(0, num_tokens, token_step)), IMAGE_SIZE,
        IMAGE_SIZE): each token's reconstruction averaged over the 3 channels.
    """
    with torch.no_grad():
        latents = get_residual(model, data)
        batch = data.size(0)
        grid_queries = fourier_enc(coords.unsqueeze(0).expand(batch, -1, -1)).to(device)
        per_token = [
            decoder_forward(model, grid_queries, latents[:, k:k+1, :]).reshape(batch, IMAGE_SIZE, IMAGE_SIZE, 3)
            for k in range(0, num_tokens, token_step)
        ]
        return torch.stack(per_token, dim=1).mean(dim=-1)

TDSM_TOKEN_STEP = 4
imgs_tdsm = imgs_viz[:4]
with torch.no_grad():
    input_tdsm_sparse = prepare_sparse_context(imgs_tdsm, coords_32, fourier_encoder, N_SPARSE, DEVICE)
    if baseline_model is not None:
        # The baseline's context must be encoded with its own Fourier encoder (loaded from
        # its checkpoint), matching the queries get_tdsm builds with baseline_fourier.
        input_tdsm_full = prepare_full_context(imgs_tdsm, coords_32, baseline_fourier)
tdsm_sparse = get_tdsm(model, fourier_encoder, input_tdsm_sparse, coords_32, DEVICE, token_step=TDSM_TOKEN_STEP)
if baseline_model is not None:
    tdsm_baseline = get_tdsm(baseline_model, baseline_fourier, input_tdsm_full, coords_32, DEVICE, token_step=TDSM_TOKEN_STEP)
else:
    tdsm_baseline = None

token_indices = [0, 16, 32, 48]
# Map index in a TDSM tensor = latent token index // TDSM_TOKEN_STEP
tdsm_slice_idx = [k // TDSM_TOKEN_STEP for k in token_indices]
n_show_t = len(token_indices)
sample_idx = 0
n_rows = 2 if baseline_model is not None else 1
fig, axs = plt.subplots(n_rows, n_show_t, figsize=(12, 5 if n_rows == 2 else 2.5))
if n_rows == 1:
    axs = axs.reshape(1, -1)
for i, (k, sk) in enumerate(zip(token_indices, tdsm_slice_idx)):
    if baseline_model is not None:
        axs[0, i].imshow(tdsm_baseline[sample_idx, sk].cpu().numpy(), cmap='viridis')
        axs[0, i].set_title('Token %d Baseline' % k)
        axs[0, i].axis('off')
    axs[n_rows-1, i].imshow(tdsm_sparse[sample_idx, sk].cpu().numpy(), cmap='viridis')
    axs[n_rows-1, i].set_title('Token %d Sparse context' % k)
    axs[n_rows-1, i].axis('off')
plt.suptitle('TDSM: Full context vs Sparse context')
plt.tight_layout()
plt.savefig('sparse_context_tdsm_vs_baseline.png', dpi=100)
plt.show()

### TDSM t-SNE: class-colored token-level reconstruction (no pooling)

Each point is one token's **full spatial map** (H×W, flattened). Token maps (every `TDSM_TOKEN_STEP`-th latent token) from all evaluated images are pooled and colored by image class, then embedded with PCA + t-SNE. Compare **Sparse context**, **Full context** (baseline), and **Full context + NCE** (if `checkpoint_nce_best.pt` or `softnce_best.pt` exists; save it from the SoftInfoNCE notebook).

In [None]:
# No pooling: each point = one token's full spatial map (H*W), color by image class
N_VAL_TDSM = min(400, len(test_loader.dataset))
all_feat_sparse, all_feat_baseline, all_feat_nce, all_labels = [], [], [], []
n_done = 0
with torch.no_grad():
    for imgs_batch, labels_batch in test_loader:
        if n_done >= N_VAL_TDSM:
            break
        # Trim the final batch so exactly N_VAL_TDSM images contribute.
        n_take = N_VAL_TDSM - n_done
        imgs_batch = imgs_batch[:n_take].to(DEVICE)
        labels_batch = labels_batch[:n_take]
        input_sparse = prepare_sparse_context(imgs_batch, coords_32, fourier_encoder, N_SPARSE, DEVICE)
        tdsm_s = get_tdsm(model, fourier_encoder, input_sparse, coords_32, DEVICE, token_step=TDSM_TOKEN_STEP)
        # (B, num_tokens, H, W) -> (B*num_tokens, H*W)
        B, nt, h, w = tdsm_s.shape
        all_feat_sparse.append(tdsm_s.reshape(B * nt, h * w).cpu().numpy())
        all_labels.append(np.repeat(labels_batch.numpy(), nt))
        # Each full-context model sees its context encoded by its *own* Fourier encoder
        # (loaded from its checkpoint); get_tdsm encodes queries with the same encoder.
        if baseline_model is not None:
            input_full = prepare_full_context(imgs_batch, coords_32, baseline_fourier)
            tdsm_b = get_tdsm(baseline_model, baseline_fourier, input_full, coords_32, DEVICE, token_step=TDSM_TOKEN_STEP)
            all_feat_baseline.append(tdsm_b.reshape(B * nt, h * w).cpu().numpy())
        if nce_model is not None:
            input_full_nce = prepare_full_context(imgs_batch, coords_32, nce_fourier)
            tdsm_n = get_tdsm(nce_model, nce_fourier, input_full_nce, coords_32, DEVICE, token_step=TDSM_TOKEN_STEP)
            all_feat_nce.append(tdsm_n.reshape(B * nt, h * w).cpu().numpy())
        n_done += imgs_batch.size(0)

X_sparse = np.concatenate(all_feat_sparse, axis=0)
y_all = np.concatenate(all_labels, axis=0)
if baseline_model is not None:
    X_baseline = np.concatenate(all_feat_baseline, axis=0)
if nce_model is not None:
    X_nce = np.concatenate(all_feat_nce, axis=0)
print('TDSM points (no pooling):', X_sparse.shape[0], 'features dim', X_sparse.shape[1])

try:
    from sklearn.decomposition import PCA
    from sklearn.manifold import TSNE
except ImportError as e:
    print('Install sklearn for t-SNE/PCA:', e)
else:
    def _tsne_2d(features, n_components):
        """PCA to ``n_components`` dims, then a seeded 2-D t-SNE embedding."""
        reduced = PCA(n_components=n_components).fit_transform(features)
        return TSNE(n_components=2, random_state=42, perplexity=30).fit_transform(reduced)

    n_comp = min(50, X_sparse.shape[1], X_sparse.shape[0] - 1)
    panels = [('Sparse context (each point = one token map)', X_sparse)]
    if baseline_model is not None:
        panels.append(('Full context (each point = one token map)', X_baseline))
    if nce_model is not None:
        panels.append(('Full context + NCE (each point = one token map)', X_nce))
    fig, axs = plt.subplots(1, len(panels), figsize=(6 * len(panels), 5), squeeze=False)
    axs = list(axs[0])
    for ax, (title, feats) in zip(axs, panels):
        emb = _tsne_2d(feats, n_comp)
        # Every panel uses the same labels and cmap, so any scatter serves for the colorbar.
        sc = ax.scatter(emb[:, 0], emb[:, 1], c=y_all, cmap='tab10', s=8, alpha=0.6)
        ax.set_title(title)
        ax.set_xlabel('t-SNE 1'); ax.set_ylabel('t-SNE 2')
    plt.colorbar(sc, ax=axs, label='Class', shrink=0.6)
    plt.suptitle('TDSM: all token maps (no pooling), colored by CIFAR-10 class')
    plt.tight_layout()
    plt.savefig('sparse_context_tdsm_tsne_class.png', dpi=100)
    plt.show()

In [None]:
# Where Sparse context differs from Full-context baseline: spatial diff and object sensitivity
if baseline_model is not None and tdsm_baseline is not None:
    # tdsm_* are (B, n_maps, H, W); average |sparse - full| over token maps -> (B, H, W)
    diff_spatial = (tdsm_sparse - tdsm_baseline).abs().mean(dim=1).cpu().numpy()
    n_imgs = min(4, imgs_tdsm.size(0))
    fig, axs = plt.subplots(2, n_imgs, figsize=(14, 6), squeeze=False)
    for i in range(n_imgs):
        axs[0, i].imshow(to_vis(imgs_tdsm[i]))
        axs[0, i].set_title('Input' if i == 0 else ''); axs[0, i].axis('off')
        im = axs[1, i].imshow(diff_spatial[i], cmap='hot')
        axs[1, i].set_title('|Sparse − Full ctx|' if i == 0 else ''); axs[1, i].axis('off')
    plt.colorbar(im, ax=axs[1, :], shrink=0.6, label='Mean |diff|')
    plt.suptitle('Where Sparse context differs from Full-context baseline')
    plt.tight_layout()
    plt.savefig('sparse_context_tdsm_spatial_diff.png', dpi=100)
    plt.show()

    H, W = IMAGE_SIZE, IMAGE_SIZE
    margin = 8
    # Crude object prior: the central square counts as "object", the border band as "background".
    obj_mask = np.zeros((H, W), dtype=np.float32)
    obj_mask[margin:H-margin, margin:W-margin] = 1.0
    bg_mask = 1.0 - obj_mask
    obj_mask = torch.from_numpy(obj_mask).to(DEVICE).view(1, 1, H, W)
    bg_mask = torch.from_numpy(bg_mask).to(DEVICE).view(1, 1, H, W)
    n_tokens = tdsm_baseline.size(1)
    # Map m of a TDSM tensor was decoded from latent token m * TDSM_TOKEN_STEP (get_tdsm is
    # called with token_step=TDSM_TOKEN_STEP), so label bars with the real token index,
    # consistent with the 'Token k' titles of the TDSM figure.
    token_ids = np.arange(n_tokens) * TDSM_TOKEN_STEP
    obj_bl = (tdsm_baseline * obj_mask).sum(dim=(2, 3)) / (obj_mask.sum() + 1e-8)
    bg_bl = (tdsm_baseline * bg_mask).sum(dim=(2, 3)) / (bg_mask.sum() + 1e-8)
    # NOTE: the *_nce names hold the sparse-context model's values (names kept for compatibility).
    obj_nce = (tdsm_sparse * obj_mask).sum(dim=(2, 3)) / (obj_mask.sum() + 1e-8)
    bg_nce = (tdsm_sparse * bg_mask).sum(dim=(2, 3)) / (bg_mask.sum() + 1e-8)
    sens_bl = (obj_bl - bg_bl).mean(dim=0).cpu().numpy()
    sens_nce = (obj_nce - bg_nce).mean(dim=0).cpu().numpy()
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    for ax, sens, color, title in ((axs[0], sens_bl, 'steelblue', 'Baseline'),
                                   (axs[1], sens_nce, 'green', 'Sparse context')):
        ax.bar(token_ids, sens, width=0.8 * TDSM_TOKEN_STEP, color=color, alpha=0.8)
        ax.set_title(title); ax.set_ylabel('Object sensitivity'); ax.set_xlabel('Latent token index')
    plt.suptitle('Per-token object vs background (center − border)')
    plt.tight_layout()
    plt.savefig('sparse_context_tdsm_object_sensitivity.png', dpi=100)
    plt.show()
    plt.figure(figsize=(5, 5))
    plt.scatter(sens_bl, sens_nce, alpha=0.7)
    plt.plot([sens_bl.min(), sens_bl.max()], [sens_bl.min(), sens_bl.max()], 'r--', label='y=x')
    plt.xlabel('Full-context object sensitivity'); plt.ylabel('Sparse context object sensitivity')
    plt.title('Above line = Sparse context more object-focused'); plt.legend(); plt.tight_layout()
    plt.savefig('sparse_context_tdsm_sensitivity_scatter.png', dpi=100)
    plt.show()
else:
    print('Skipping spatial diff / object sensitivity (no baseline loaded).')