# Soft InfoNCE on OmniField (CIFAR-10)

**Minimal change** on top of the original OmniField reconstruction: add a **geometry-aware contrastive objective (Soft InfoNCE)** on probe-derived tokens from the neural field.

- Sample anchor coords in view A and candidate coords in view B (B = affine augmentation of A with **known T**).
- Build L2-normalized feature tokens φ(x) from the field (decoder hidden state → projection head).
- Soft positives: \(w_{ij} \propto \exp(-\|x_j^B - T(x_i^A)\|^2 / (2\sigma^2))\) (Gaussian kernel).
- Loss: \(\mathcal{L} = \mathcal{L}_{recon} + \lambda_{ctr} \mathcal{L}_{softNCE}\).

Defaults: N_a=256, N_b=1024, d=128, τ=0.1, σ=0.08, λ_ctr=0.1 (ramp 0→0.1 over 500 steps).

In [None]:
import math
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import transforms
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from einops import rearrange, repeat

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

# Seed every RNG this notebook draws from so sampling and initialisation repeat.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different kernels between runs; disable it too.
    torch.backends.cudnn.benchmark = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_DIR = 'checkpoints'
# Prefer the best checkpoint and fall back to the last one.
CKPT_PATH = os.path.join(CHECKPOINT_DIR, 'checkpoint_best.pt')
if not os.path.isfile(CKPT_PATH):
    CKPT_PATH = os.path.join(CHECKPOINT_DIR, 'checkpoint_last.pt')
# Raise explicitly instead of using `assert`, which is removed under `python -O`.
if not os.path.isfile(CKPT_PATH):
    raise FileNotFoundError(f'No checkpoint in {CHECKPOINT_DIR}. Train AblationCIFAR10 first.')
print(f'Device: {DEVICE}')


In [None]:
# OmniField config (match checkpoint)
IMAGE_SIZE = 32
CHANNELS = 3
FOURIER_MAPPING_SIZE = 96
# The positional embedding is twice the mapping size.
POS_EMBED_DIM = FOURIER_MAPPING_SIZE * 2
INPUT_DIM = CHANNELS + POS_EMBED_DIM
QUERIES_DIM = POS_EMBED_DIM
LOGITS_DIM = CHANNELS

fourier_encoder = GaussianFourierFeatures(in_features=2, mapping_size=FOURIER_MAPPING_SIZE, scale=15.0).to(DEVICE)
model = CascadedPerceiverIO(
    input_dim=INPUT_DIM,
    queries_dim=QUERIES_DIM,
    logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512),
    num_latents=(256, 256, 256),
    decoder_ff=True,
).to(DEVICE)

ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
# strict=False tolerates key mismatches. Surface them, because a silently
# unmatched checkpoint would leave (part of) the network randomly initialised
# and make every "baseline" number below meaningless.
model_load_result = model.load_state_dict(ckpt['model_state_dict'], strict=False)
fourier_load_result = fourier_encoder.load_state_dict(ckpt['fourier_encoder_state_dict'], strict=False)
for load_label, load_result in (('model', model_load_result), ('fourier_encoder', fourier_load_result)):
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f'WARNING: {load_label} checkpoint mismatch: '
            f'missing={list(load_result.missing_keys)} unexpected={list(load_result.unexpected_keys)}'
        )
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
print(f'Loaded {CKPT_PATH}')


In [None]:
def get_residual(model, data):
    '''Run the encoder blocks, then the latent self-attention blocks, on `data`.

    Each encoder block receives the running latent state as both `x` and
    `residual` (None for the first block) and `data` as its context. Every
    entry of `model.self_attn_blocks` then contributes two residual updates,
    using its elements 0 and 1 in that order.

    Returns the final latent state (None when the model has no blocks at all).
    '''
    latents = None
    for encoder_block in model.encoder_blocks:
        latents = encoder_block(x=latents, context=data, mask=None, residual=latents)
    for pair in model.self_attn_blocks:
        first, second = pair[0], pair[1]
        latents = first(latents) + latents
        latents = second(latents) + latents
    return latents

def get_rgb_and_phi_raw(model, queries, residual):
    '''Decode `queries` against the latents and return (rgb, hidden state).

    The hidden state is the cross-attention output plus the queries, followed
    by a residual feed-forward update when `model.decoder_ff` is not None.
    `rgb` is `model.to_logits` applied to that hidden state; the hidden state
    itself is returned as the raw feature.
    '''
    hidden = model.decoder_cross_attn(queries, context=residual) + queries
    feed_forward = model.decoder_ff
    if feed_forward is not None:
        hidden = hidden + feed_forward(hidden)
    return model.to_logits(hidden), hidden

def sample_gt_at_coords(images, coords):
    '''Bilinearly sample `images` (B, C, H, W) at `coords` (B, N, 2); returns (B, N, C).

    The two coordinate components are swapped before sampling because
    F.grid_sample expects (x, y) per point, i.e. `coords` hold (y, x).
    Sampling uses align_corners=True and border padding.
    '''
    batch = images.shape[0]
    n_points = coords.shape[1]
    xy = torch.stack((coords[..., 1], coords[..., 0]), dim=-1)
    # A 1 x N "image" of sampling locations per batch element.
    grid = xy.reshape(batch, 1, n_points, 2)
    sampled = F.grid_sample(images, grid, mode='bilinear', padding_mode='border', align_corners=True)
    return sampled[:, :, 0, :].transpose(1, 2)

def make_grid_2d(h, w, device):
    '''Return the (h*w, 2) row-major list of (y, x) points of an h x w grid on [-1, 1]^2.'''
    rows = torch.linspace(-1, 1, h, device=device)
    cols = torch.linspace(-1, 1, w, device=device)
    yy = rows[:, None].expand(h, w)
    xx = cols[None, :].expand(h, w)
    return torch.stack((yy, xx), dim=-1).reshape(-1, 2)


## Affine augmentation with known T (A → B)

We apply an affine transform to view A to obtain view B and keep the forward map T, so that every coordinate x in A has a known location T(x) in B. Coordinates lie in [-1, 1]².

In [None]:
def sample_affine_params(batch_size, device, scale_range=(0.8, 1.0), max_translate=0.1, max_angle_deg=15):
    '''Draw one random rotation+scale+translation T (A→B) per batch element.

    Returns R (B, 2, 2) and t (B, 2) with T(p) = p @ R.T + t for row vectors p.
    The angle is uniform in ±max_angle_deg, the scale uniform in scale_range and
    each translation component uniform in ±max_translate.
    '''
    def signed_unit():
        # One uniform draw in [-1, 1) per batch element.
        return torch.rand(batch_size, device=device) * 2 - 1

    # Draw order (angle, scale, shift 0, shift 1) fixes the RNG stream.
    angle = signed_unit() * (max_angle_deg * math.pi / 180)
    scale = scale_range[0] + torch.rand(batch_size, device=device) * (scale_range[1] - scale_range[0])
    shift_0 = signed_unit() * max_translate
    shift_1 = signed_unit() * max_translate

    cos_a, sin_a = torch.cos(angle), torch.sin(angle)
    top_row = torch.stack([cos_a * scale, -sin_a * scale], dim=-1)
    bottom_row = torch.stack([sin_a * scale, cos_a * scale], dim=-1)
    R = torch.stack([top_row, bottom_row], dim=1)
    t = torch.stack([shift_0, shift_1], dim=1)
    return R, t

def apply_affine_to_coords(coords, R, t):
    '''Map coords (B, N, 2) through T(p) = R @ p + t with R (B, 2, 2), t (B, 2).

    R acts on the two components in the order they are stored in the last axis.
    Returns the transformed points, shape (B, N, 2).
    '''
    # Row-vector form: p @ R^T is the same as R @ p for each point.
    rotated = torch.matmul(coords, R.transpose(1, 2))
    return rotated + t[:, None, :]

def apply_affine_to_image(images, R, t, align_corners=True):
    '''Warp view A into view B = T(A), with T(p) = R @ p + t acting on (y, x) coords.

    images: (B, C, H, W). R: (B, 2, 2), t: (B, 2) describe the forward map T in
    the (y, x) component order used by make_grid_2d / sample_gt_at_coords and
    by apply_affine_to_coords. Returns the warped images, same shape as input.

    View B is produced by sampling A at T^-1(B coords). F.affine_grid and
    F.grid_sample work in (x, y) order, so T is first re-expressed in that
    order. Without this step the image is warped by a different transform than
    the one applied to the coordinates (translation axes swapped, rotation
    direction reversed), and the "known T" correspondences are wrong.
    '''
    # Swapping rows and columns of R, and the entries of t, converts (y, x) -> (x, y).
    R_xy = torch.flip(R, dims=(1, 2))
    t_xy = torch.flip(t, dims=(1,))
    # Inverse map in homogeneous 2x3 form: p_A = R^-1 @ p_B - R^-1 @ t.
    R_inv = torch.inverse(R_xy)
    theta = torch.cat([R_inv, -(R_inv @ t_xy.unsqueeze(2))], dim=2)
    grid = F.affine_grid(theta, images.size(), align_corners=align_corners)
    return F.grid_sample(images, grid, mode='bilinear', padding_mode='border', align_corners=align_corners)


In [None]:
class ProjectionHead(nn.Module):
    '''Single linear layer whose output is L2-normalized along the last axis.'''

    def __init__(self, in_dim, out_dim=128):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        self.out_dim = out_dim

    def forward(self, z):
        '''Project z (..., in_dim) to unit-norm vectors of shape (..., out_dim).'''
        projected = self.proj(z)
        return F.normalize(projected, dim=-1)

# Width of the contrastive embedding produced by the projection head.
PROJ_DIM = 128
# The head consumes the decoder hidden state, which get_rgb_and_phi_raw builds
# by adding the queries to the cross-attention output, so its width is QUERIES_DIM.
projection_head = ProjectionHead(QUERIES_DIM, PROJ_DIM).to(DEVICE)
print(projection_head)


In [None]:
def soft_infonce_loss(phi_a, phi_b, coords_a, coords_b, R, t, tau=0.1, sigma=0.08):
    '''
    Soft InfoNCE between anchor tokens of view A and candidate tokens of view B.

    phi_a (B, N_a, d), phi_b (B, N_b, d): L2-normalized tokens.
    coords_a (B, N_a, 2), coords_b (B, N_b, 2): where the tokens were probed.
    R (B, 2, 2), t (B, 2): the known map T (A->B).

    Target weights are w_ij = exp(-||x_j^B - T(x_i^A)||^2 / (2*sigma^2)),
    normalized over j. The loss is the cross-entropy between these weights and
    the softmax over j of the similarities divided by tau, averaged over all
    anchors and batch elements.
    '''
    similarity = torch.bmm(phi_a, phi_b.transpose(1, 2)) / tau

    # Where each anchor lands in view B, and its squared distance to every candidate.
    targets_in_b = apply_affine_to_coords(coords_a, R, t)
    offsets = coords_b[:, None, :, :] - targets_in_b[:, :, None, :]
    dist_sq = offsets.pow(2).sum(dim=-1)

    kernel = torch.exp(-dist_sq / (2 * sigma ** 2))
    # The epsilon keeps the division finite when a whole row underflows to zero.
    weights = kernel / (kernel.sum(dim=2, keepdim=True) + 1e-8)

    log_p = F.log_softmax(similarity, dim=-1)
    per_anchor = -(weights * log_p).sum(dim=-1)
    return per_anchor.mean()


In [None]:
# Hyper-parameters of the Soft InfoNCE fine-tuning run.
cfg = {
    'subset_size': 10000,  # number of leading CIFAR-10 training images used
    'batch_size': 32,
    'N_a': 256,  # anchors sampled per image in view A
    'N_b': 1024,  # candidates sampled per image in view B
    'tau': 0.1,  # temperature dividing the similarity logits
    'sigma': 0.08,  # width of the Gaussian soft-positive kernel, in normalized coords
    'lambda_ctr': 0.1,  # final weight of the contrastive term
    'ramp_steps': 500,  # optimizer steps over which the weight ramps linearly from 0
    'epochs': 5,
    'lr': 1e-3,
    'freeze_backbone': False,  # True: only the projection head is given to the optimizer
}

# Images are only converted to tensors; no normalization or augmentation here.
transform = transforms.Compose([transforms.ToTensor()])
train_ds = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# Train on the first `subset_size` images (capped at the dataset length).
train_sub = Subset(train_ds, list(range(min(cfg['subset_size'], len(train_ds)))))
train_loader = DataLoader(train_sub, batch_size=cfg['batch_size'], shuffle=True, num_workers=0)
# Validation uses the full CIFAR-10 test split in a fixed order.
val_ds = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)
print(f'Train batches: {len(train_loader)}, Val batches: {len(val_loader)}')


## Training loop: recon + Soft InfoNCE (λ ramp)

In [None]:
# The projection head is always optimized; the backbone and the Fourier encoder
# are added unless cfg['freeze_backbone'] is set.
# NOTE(review): "frozen" modules are only left out of the optimizer here; their
# requires_grad flags are not changed.
params = list(projection_head.parameters())
if not cfg.get('freeze_backbone'):
    params += list(model.parameters()) + list(fourier_encoder.parameters())
optimizer = torch.optim.Adam(params, lr=cfg['lr'])
# One-element list used as a mutable step counter by the training loop.
global_step = [0]

def get_lambda_ctr(step):
    '''Contrastive weight at `step`: linear ramp from 0 to cfg['lambda_ctr'] over cfg['ramp_steps'] steps, then constant.'''
    ramp = cfg['ramp_steps']
    # Once the ramp is over (including ramp == 0) the full weight applies.
    return cfg['lambda_ctr'] if step >= ramp else cfg['lambda_ctr'] * (step / ramp)


In [None]:
def eval_psnr(model, fourier_encoder, loader, device, grid_size=32):
    '''Reconstruction PSNR (dB, peak value 1.0) of the field over `loader`.

    Both modules are switched to eval mode and left there. Each batch is
    encoded and decoded on a grid_size x grid_size grid, then compared with
    the images bilinearly sampled at the same coordinates.

    The MSE is the mean over every compared value (batch x points x channels).
    The squared error is summed over all of those values, so the divisor has
    to count channels as well; counting only batch x points would inflate the
    MSE by the channel count and understate the PSNR by 10*log10(C) dB.
    '''
    model.eval()
    fourier_encoder.eval()
    grid = make_grid_2d(grid_size, grid_size, device)
    sq_err_sum, n_values = 0.0, 0
    with torch.no_grad():
        for imgs, _ in loader:
            imgs = imgs.to(device)
            B = imgs.size(0)
            input_data, _, _ = prepare_model_input(imgs, grid, fourier_encoder)
            residual = get_residual(model, input_data)
            coords_batch = grid.unsqueeze(0).expand(B, -1, -1)
            queries = fourier_encoder(coords_batch)
            rgb, _ = get_rgb_and_phi_raw(model, queries, residual)
            gt = sample_gt_at_coords(imgs, coords_batch)
            sq_err_sum += F.mse_loss(rgb, gt, reduction='sum').item()
            n_values += rgb.numel()
    # max(..., 1) guards against an empty loader.
    mse = sq_err_sum / max(n_values, 1)
    # The small epsilon avoids log10 of infinity for a perfect reconstruction.
    return 10 * math.log10(1.0 / (mse + 1e-10))

# PSNR of the loaded checkpoint before any contrastive fine-tuning.
print(f'Baseline val PSNR: {eval_psnr(model, fourier_encoder, val_loader, DEVICE):.2f} dB')


In [None]:
# Fine-tune with reconstruction + ramped Soft InfoNCE; log val PSNR after every epoch.
psnr_log = []
for ep in range(cfg['epochs']):
    # eval_psnr leaves the modules in eval mode, so re-enable train mode each epoch.
    model.train()
    fourier_encoder.train()
    projection_head.train()
    total_recon = 0.0
    total_ctr = 0.0
    for imgs, _ in train_loader:
        imgs = imgs.to(DEVICE)
        B, C, H, W = imgs.shape
        # View B is view A warped by a freshly sampled, known transform T = (R, t).
        R, t = sample_affine_params(B, DEVICE)
        imgs_b = apply_affine_to_image(imgs, R, t)
        N_a, N_b = cfg['N_a'], cfg['N_b']
        # Anchors (view A) and candidates (view B) are drawn uniformly in [-1, 1)^2.
        anchors_a = (torch.rand(B, N_a, 2, device=DEVICE) * 2 - 1)
        candidates_b = (torch.rand(B, N_b, 2, device=DEVICE) * 2 - 1)
        # Encode both views into latents.
        input_a, _, _ = prepare_model_input(imgs, coords_32, fourier_encoder)
        input_b, _, _ = prepare_model_input(imgs_b, coords_32, fourier_encoder)
        residual_a = get_residual(model, input_a)
        residual_b = get_residual(model, input_b)
        # Probe each view's field at its sampled coordinates.
        queries_a = fourier_encoder(anchors_a)
        queries_b = fourier_encoder(candidates_b)
        # rgb_a is not used below; only the hidden states feed the contrastive loss.
        rgb_a, phi_raw_a = get_rgb_and_phi_raw(model, queries_a, residual_a)
        _, phi_raw_b = get_rgb_and_phi_raw(model, queries_b, residual_b)
        phi_a = projection_head(phi_raw_a)
        phi_b = projection_head(phi_raw_b)
        loss_ctr = soft_infonce_loss(phi_a, phi_b, anchors_a, candidates_b, R, t, tau=cfg['tau'], sigma=cfg['sigma'])
        # Reconstruction loss: view A decoded on the full coords_32 grid.
        coords_full = coords_32.unsqueeze(0).expand(B, -1, -1)
        queries_full = fourier_encoder(coords_full)
        rgb_full, _ = get_rgb_and_phi_raw(model, queries_full, residual_a)
        gt_full = sample_gt_at_coords(imgs, coords_full)
        loss_recon = F.mse_loss(rgb_full, gt_full)
        # Contrastive weight follows the linear ramp of get_lambda_ctr.
        lam = get_lambda_ctr(global_step[0])
        loss = loss_recon + lam * loss_ctr
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step[0] += 1
        total_recon += loss_recon.item()
        total_ctr += loss_ctr.item()
    psnr = eval_psnr(model, fourier_encoder, val_loader, DEVICE)
    psnr_log.append(psnr)
    # Reported losses are per-batch averages over the epoch.
    print(f'Epoch {ep+1} recon: {total_recon/len(train_loader):.4f} ctr: {total_ctr/len(train_loader):.4f} val PSNR: {psnr:.2f} dB')

# Validation PSNR after each fine-tuning epoch.
# NOTE(review): only the fine-tuned curve is drawn; the baseline PSNR mentioned
# in the title is printed earlier but not plotted here.
plt.figure(figsize=(6, 3))
plt.plot(psnr_log, 'o-')
plt.xlabel('Epoch')
plt.ylabel('Val PSNR (dB)')
plt.title('Reconstruction PSNR (baseline vs +Soft InfoNCE)')
plt.tight_layout()
plt.savefig('softnce_psnr.png', dpi=100)
plt.show()

# Save the weights as they are after the final epoch, together with the projection head.
# NOTE(review): the file name says "best", but no best-epoch selection happens above.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
nce_ckpt_path = os.path.join(CHECKPOINT_DIR, 'checkpoint_nce_best.pt')
torch.save({'model_state_dict': model.state_dict(), 'fourier_encoder_state_dict': fourier_encoder.state_dict(), 'projection_head_state_dict': projection_head.state_dict()}, nce_ckpt_path)
print('Saved NCE-trained model to', nce_ckpt_path)


## Visualizations

1. **Soft-weight heatmap** for a few anchors: the target weights w_ij over B's grid.
2. **Coordinate error**: for each anchor, argmax_j similarity → coord; distance to T(x_i^A).
3. **2D embedding** (PCA/t-SNE) of tokens colored by class.

In [None]:
# Shared setup for the visualization cells below: two views of 4 validation
# images, random anchors in A, and the full pixel grid of B as candidates.
model.eval()
fourier_encoder.eval()
projection_head.eval()
imgs, labels = next(iter(val_loader))
imgs = imgs[:4].to(DEVICE)
labels = labels[:4]
# Matches the 4 images kept above.
B = 4
R, t = sample_affine_params(B, DEVICE)
imgs_b = apply_affine_to_image(imgs, R, t)
grid_b = make_grid_2d(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
N_a_vis = 64
anchors_a = (torch.rand(B, N_a_vis, 2, device=DEVICE) * 2 - 1)
with torch.no_grad():
    input_a, _, _ = prepare_model_input(imgs, coords_32, fourier_encoder)
    input_b, _, _ = prepare_model_input(imgs_b, coords_32, fourier_encoder)
    residual_a = get_residual(model, input_a)
    residual_b = get_residual(model, input_b)
    coords_b_batch = grid_b.unsqueeze(0).expand(B, -1, -1)
    queries_a = fourier_encoder(anchors_a)
    queries_b = fourier_encoder(coords_b_batch)
    _, phi_raw_a = get_rgb_and_phi_raw(model, queries_a, residual_a)
    _, phi_raw_b = get_rgb_and_phi_raw(model, queries_b, residual_b)
    phi_a = projection_head(phi_raw_a)
    phi_b = projection_head(phi_raw_b)
    # Model similarities between every anchor and every grid point of B.
    logits = torch.bmm(phi_a, phi_b.transpose(1, 2)) / cfg['tau']
    # Target soft weights: same Gaussian kernel as soft_infonce_loss, computed
    # from the geometry only (anchor location mapped by T versus grid points).
    xi_mapped = apply_affine_to_coords(anchors_a, R, t)
    sqd = ((coords_b_batch.unsqueeze(1) - xi_mapped.unsqueeze(2)) ** 2).sum(-1)
    w = torch.exp(-sqd / (2 * cfg['sigma'] ** 2))
    w = w / (w.sum(dim=2, keepdim=True) + 1e-8)
# Which batch element and how many anchors the figures show.
b_show = 0
n_anchors_show = 4
# Top row: target weights of one anchor over B's grid; bottom row: view B itself.
fig, axs = plt.subplots(2, n_anchors_show, figsize=(12, 5))
for i in range(n_anchors_show):
    heat = w[b_show, i].cpu().numpy().reshape(IMAGE_SIZE, IMAGE_SIZE)
    axs[0, i].imshow(heat, cmap='hot')
    axs[0, i].set_title(f'Anchor {i} soft weights')
    axs[0, i].axis('off')
    axs[1, i].imshow(imgs_b[b_show].cpu().permute(1, 2, 0).clamp(0, 1).numpy())
    axs[1, i].axis('off')
plt.suptitle('Soft positive weights w_j over view B (one image)')
plt.tight_layout()
plt.savefig('softnce_heatmap.png', dpi=100)
plt.show()


In [None]:
# Coordinate error for the displayed image: predicted match versus T(anchor).
with torch.no_grad():
    # Predicted match = grid point of B with the highest similarity logit per anchor.
    best_j = logits[b_show].argmax(dim=1)
    pred_coords = coords_b_batch[b_show][best_j]
    # Ground truth = anchor location pushed through the known affine map.
    gt_coords = xi_mapped[b_show]
    # Euclidean distance in normalized coordinates.
    err = (pred_coords - gt_coords).norm(dim=-1).cpu().numpy()
plt.figure(figsize=(5, 3))
plt.violinplot([err], positions=[0], showmeans=True)
plt.ylabel('Coord error (argmax candidate vs T(x_i^A))')
plt.title('Coordinate error (normalized space)')
plt.xticks([0], ['Soft InfoNCE'])
plt.tight_layout()
plt.savefig('softnce_coord_error.png', dpi=100)
plt.show()
print(f'Mean coord error: {err.mean():.4f}')


In [None]:
# 2D PCA of projected tokens probed on an 8x8 grid for one validation batch.
imgs_e, labels_e = next(iter(val_loader))
imgs_e = imgs_e.to(DEVICE)
n_embed = imgs_e.size(0)
labels_e = labels_e.cpu().numpy()
grid_flat = make_grid_2d(8, 8, DEVICE)
coords_embed = grid_flat.unsqueeze(0).expand(n_embed, -1, -1)
with torch.no_grad():
    input_embed, _, _ = prepare_model_input(imgs_e, coords_32, fourier_encoder)
    residual_embed = get_residual(model, input_embed)
    queries_embed = fourier_encoder(coords_embed)
    _, phi_raw_embed = get_rgb_and_phi_raw(model, queries_embed, residual_embed)
    phi_embed = projection_head(phi_raw_embed)
# One row per token; every token inherits the class label of its image.
feats = phi_embed.cpu().numpy().reshape(-1, PROJ_DIM)
# 64 = 8 * 8 tokens per image, matching the grid above.
label_tile = np.repeat(labels_e[:, None], 64, axis=1).reshape(-1)
try:
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    feats_2d = pca.fit_transform(feats)
except ImportError:
    # Without sklearn the plot shows random points, not the embedding.
    feats_2d = np.random.randn(feats.shape[0], 2)
    print('Install sklearn for PCA; using random 2D for demo.')
plt.figure(figsize=(6, 5))
sc = plt.scatter(feats_2d[:, 0], feats_2d[:, 1], c=label_tile, cmap='tab10', s=1, alpha=0.6)
plt.colorbar(sc, label='Class')
plt.title('Token embedding PCA (colored by CIFAR-10 class)')
plt.tight_layout()
plt.savefig('softnce_embedding.png', dpi=100)
plt.show()


## Summary

- Single intervention: **Soft InfoNCE** on probe-derived tokens (decoder hidden → projection head).
- Geometry-aware soft positives with known affine T; λ_ctr ramped to avoid destabilizing recon.
- **Visualizations**: PSNR curve; attention heatmap; coord error violin; **View A/B** (anchors → GT vs pred); **Retrieval @ ε**; **weight concentration** histogram; **heatmap overlay** (GT + pred on w_j); 2D token embedding PCA.

### What NCE adds: correspondence quality and weight concentration

The next cells show (1) **View A → View B**: anchors on A and their GT vs predicted match on B; (2) **Retrieval @ ε**: fraction of anchors matched within tolerance; (3) **Soft weight concentration**: how peaked the soft positives are; (4) **Heatmap overlay**: GT and predicted match on the weight heatmap.

In [None]:
# View A vs View B: anchors on A, GT and predicted match on B (reuses logits, anchors_a, xi_mapped, pred_coords from above)
def norm_to_pixel(coords_norm, h, w):
    '''Convert (y, x) coords in [-1, 1] to pixel positions on an h x w image.

    Returns (cols, rows) as NumPy arrays, i.e. in the (x, y) order that
    matplotlib's scatter expects.
    '''
    rows = (coords_norm[..., 0] + 1) / 2 * (h - 1)
    cols = (coords_norm[..., 1] + 1) / 2 * (w - 1)
    return cols.cpu().numpy(), rows.cpu().numpy()

# Show at most 8 anchors of the displayed image.
n_show = min(8, anchors_a.size(1))
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
# Left panel: view A with the anchor locations.
axs[0].imshow(imgs[b_show].cpu().permute(1, 2, 0).clamp(0, 1).numpy())
axs[0].set_title('View A (anchors)')
axs[0].axis('off')
cx_a, cy_a = norm_to_pixel(anchors_a[b_show, :n_show], IMAGE_SIZE, IMAGE_SIZE)
axs[0].scatter(cx_a, cy_a, c='lime', s=40, marker='o', edgecolors='black', linewidths=0.5, label='anchors')

# Right panel: view B with T(anchor) (GT) and the argmax prediction.
axs[1].imshow(imgs_b[b_show].cpu().permute(1, 2, 0).clamp(0, 1).numpy())
axs[1].set_title('View B (GT vs predicted match)')
axs[1].axis('off')
cx_gt, cy_gt = norm_to_pixel(xi_mapped[b_show, :n_show], IMAGE_SIZE, IMAGE_SIZE)
# pred_coords was already computed for b_show only, hence no batch index.
cx_pr, cy_pr = norm_to_pixel(pred_coords[:n_show], IMAGE_SIZE, IMAGE_SIZE)
axs[1].scatter(cx_gt, cy_gt, c='lime', s=60, marker='+', linewidths=2, label='GT T(x)')
axs[1].scatter(cx_pr, cy_pr, c='cyan', s=40, marker='x', linewidths=1.5, label='pred (argmax)')
axs[1].legend(loc='upper right', fontsize=8)
plt.suptitle('Correspondence: green = anchor/GT, cyan = model prediction')
plt.tight_layout()
plt.savefig('softnce_viewA_viewB.png', dpi=100)
plt.show()

In [None]:
# Retrieval @ ε: percentage of anchors of the displayed image (b_show) whose
# argmax match lies within ε of T(anchor), measured in normalized coordinates.
eps_values = [0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
err_all = (pred_coords - xi_mapped[b_show]).norm(dim=-1).cpu().numpy()
acc_at_eps = [(err_all < eps).mean() * 100 for eps in eps_values]
fig, ax = plt.subplots(figsize=(6, 3))
ax.bar(range(len(eps_values)), acc_at_eps, color='steelblue', edgecolor='black')
ax.set_xticks(range(len(eps_values)))
ax.set_xticklabels([str(e) for e in eps_values])
ax.set_xlabel('ε (normalized coord distance)')
ax.set_ylabel('% anchors with error < ε')
ax.set_title('Retrieval accuracy: how often does argmax fall near GT?')
plt.tight_layout()
plt.savefig('softnce_retrieval_at_eps.png', dpi=100)
plt.show()
# Indices 1 and 3 of eps_values are ε = 0.1 and ε = 0.2.
print('Retrieval @ 0.1: {:.1f}%  @ 0.2: {:.1f}%'.format(acc_at_eps[1], acc_at_eps[3]))

In [None]:
# Soft weight concentration: max_j w_ij per anchor, over all 4 images.
# NOTE(review): w is the target kernel built from sigma and the geometry only,
# so this histogram shows how peaked the targets are; it does not depend on
# the model's similarities.
w_max = w.max(dim=2).values.flatten().cpu().numpy()
plt.figure(figsize=(5, 3))
plt.hist(w_max, bins=30, color='steelblue', edgecolor='black', alpha=0.8)
plt.xlabel('max_j w_ij (soft weight on best candidate)')
plt.ylabel('Count (anchors)')
plt.title('Concentration: peaked weights = confident correspondence')
plt.axvline(w_max.mean(), color='red', linestyle='--', label=f'mean={w_max.mean():.3f}')
plt.legend()
plt.tight_layout()
plt.savefig('softnce_weight_concentration.png', dpi=100)
plt.show()

In [None]:
# Heatmap overlay: soft weights with GT (+) and predicted (x) match for each anchor.
# Top row: target weights over B's grid; bottom row: view B with the same markers.
fig, axs = plt.subplots(2, n_anchors_show, figsize=(12, 5))
for i in range(n_anchors_show):
    heat = w[b_show, i].cpu().numpy().reshape(IMAGE_SIZE, IMAGE_SIZE)
    axs[0, i].imshow(heat, cmap='hot')
    axs[0, i].set_title(f'Anchor {i}')
    # Slices of length 1 keep the array shape that scatter expects.
    cx_gt, cy_gt = norm_to_pixel(xi_mapped[b_show, i:i+1], IMAGE_SIZE, IMAGE_SIZE)
    cx_pr, cy_pr = norm_to_pixel(pred_coords[i:i+1], IMAGE_SIZE, IMAGE_SIZE)
    axs[0, i].scatter(cx_gt, cy_gt, c='lime', s=80, marker='+', linewidths=2, label='GT')
    axs[0, i].scatter(cx_pr, cy_pr, c='cyan', s=50, marker='x', linewidths=1.5, label='pred')
    axs[0, i].axis('off')
    axs[1, i].imshow(imgs_b[b_show].cpu().permute(1, 2, 0).clamp(0, 1).numpy())
    axs[1, i].scatter(cx_gt, cy_gt, c='lime', s=80, marker='+', linewidths=2)
    axs[1, i].scatter(cx_pr, cy_pr, c='cyan', s=50, marker='x', linewidths=1.5)
    axs[1, i].axis('off')
plt.suptitle('Soft weights: lime = GT T(x), cyan = argmax prediction')
plt.tight_layout()
plt.savefig('softnce_heatmap_overlay.png', dpi=100)
plt.show()

### Feature-level difference: what NCE learned

Below we compare **positive pairs** (anchor ↔ its true correspondence in B) vs **negative pairs** (anchor ↔ other candidates) in **cosine similarity**. NCE pulls positives up and pushes negatives down, so we expect a clear gap. We also compare to a **baseline** (same backbone, random projection head) to see the feature-level effect of contrastive training.

In [None]:
# Cosine similarity matrix (phi are L2-normalized, so logits * tau = cos sim)
S_nce = (logits[b_show] * cfg['tau']).detach().cpu().numpy()
N_a, N_b = S_nce.shape
# GT match index per anchor: candidate j closest to T(anchor_i)
sqd_b = ((coords_b_batch[b_show].unsqueeze(0) - xi_mapped[b_show].unsqueeze(1)) ** 2).sum(-1).cpu().numpy()
j_gt = np.argmin(sqd_b, axis=1)
# Positives: one similarity per anchor (its GT candidate); negatives: all other pairs.
pos_sims = S_nce[np.arange(N_a), j_gt]
neg_mask = np.ones((N_a, N_b), dtype=bool)
neg_mask[np.arange(N_a), j_gt] = False
neg_sims = S_nce[neg_mask]

fig, ax = plt.subplots(figsize=(6, 3.5))
ax.hist(neg_sims, bins=40, alpha=0.6, color='coral', label='negative pairs', density=True)
ax.hist(pos_sims, bins=30, alpha=0.6, color='green', label='positive pairs (GT corr.)', density=True)
ax.axvline(pos_sims.mean(), color='green', linestyle='--', linewidth=1.5, label=f'pos mean={pos_sims.mean():.3f}')
ax.axvline(neg_sims.mean(), color='coral', linestyle='--', linewidth=1.5, label=f'neg mean={neg_sims.mean():.3f}')
ax.set_xlabel('Cosine similarity φ(anchor) · φ(candidate)')
ax.set_ylabel('Density')
ax.set_title('Feature-level: NCE-trained model (pos vs neg)')
ax.legend(loc='upper left', fontsize=8)
plt.tight_layout()
plt.savefig('softnce_feature_pos_neg_hist.png', dpi=100)
plt.show()
print(f'Positive mean: {pos_sims.mean():.4f}  Negative mean: {neg_sims.mean():.4f}  Gap: {pos_sims.mean()-neg_sims.mean():.4f}')

In [None]:
# Per-anchor margin: sim to GT match minus mean sim to all candidates (higher = more discriminative)
# The mean runs over every candidate, the GT match included.
margin_nce = pos_sims - S_nce.mean(axis=1)
plt.figure(figsize=(5, 3))
plt.hist(margin_nce, bins=30, color='steelblue', edgecolor='black', alpha=0.8)
plt.axvline(margin_nce.mean(), color='red', linestyle='--', label=f'mean margin={margin_nce.mean():.3f}')
plt.xlabel('Margin (sim to GT − mean sim to candidates)')
plt.ylabel('Count')
plt.title('Feature margin: how much higher is sim to true correspondence?')
plt.legend()
plt.tight_layout()
plt.savefig('softnce_feature_margin.png', dpi=100)
plt.show()

In [None]:
# Baseline: same backbone, RANDOM projection head (no contrastive training) → feature-level comparison
# The backbone features (phi_raw_a / phi_raw_b) are reused from the fine-tuned
# model; only the projection head is freshly initialised.
proj_baseline = ProjectionHead(QUERIES_DIM, PROJ_DIM).to(DEVICE)
with torch.no_grad():
    phi_bl_a = proj_baseline(phi_raw_a)
    phi_bl_b = proj_baseline(phi_raw_b)
# No division by tau here, so S_bl already holds cosine similarities.
S_bl = (torch.bmm(phi_bl_a, phi_bl_b.transpose(1, 2))[b_show]).cpu().numpy()
# Same positive/negative split (j_gt, neg_mask) as for the NCE-trained head.
pos_sims_bl = S_bl[np.arange(N_a), j_gt]
neg_sims_bl = S_bl[neg_mask]

fig, ax = plt.subplots(figsize=(7, 4))
ax.hist(neg_sims, bins=40, alpha=0.4, color='coral', label='NCE neg', density=True)
ax.hist(pos_sims, bins=30, alpha=0.5, color='green', label='NCE pos', density=True)
ax.hist(neg_sims_bl, bins=40, alpha=0.4, color='gray', histtype='step', linewidth=2, label='baseline neg', density=True)
ax.hist(pos_sims_bl, bins=30, alpha=0.5, color='blue', histtype='step', linewidth=2, label='baseline pos', density=True)
ax.set_xlabel('Cosine similarity')
ax.set_ylabel('Density')
ax.set_title('Feature-level: NCE (filled) vs baseline / random proj (outline)')
ax.legend(loc='upper left', fontsize=8)
plt.tight_layout()
plt.savefig('softnce_feature_baseline_vs_nce.png', dpi=100)
plt.show()
print('NCE:    pos mean={:.4f}  neg mean={:.4f}  gap={:.4f}'.format(pos_sims.mean(), neg_sims.mean(), pos_sims.mean()-neg_sims.mean()))
print('Baseline: pos mean={:.4f}  neg mean={:.4f}  gap={:.4f}'.format(pos_sims_bl.mean(), neg_sims_bl.mean(), pos_sims_bl.mean()-neg_sims_bl.mean()))

### TDSM: token-decoded spatial map (baseline vs NCE)

Same as **TDSM_Classification.ipynb**: each **latent token** is used as the only context → decode on the full 32×32 grid → one "component" image per token (texture-like in the baseline). We compare the **baseline** (no NCE) with the **NCE-trained** model: do the per-token reconstructions look less texture-like and more structured with NCE?

In [None]:
# Load baseline model (same checkpoint, never NCE-trained) for comparison
# Same architecture and checkpoint file as the model loaded at the top of the notebook.
baseline_fourier = GaussianFourierFeatures(in_features=2, mapping_size=FOURIER_MAPPING_SIZE, scale=15.0).to(DEVICE)
baseline_model = CascadedPerceiverIO(
    input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
).to(DEVICE)
ckpt_baseline = torch.load(CKPT_PATH, map_location=DEVICE)
# strict=False: key mismatches between checkpoint and modules are not reported.
baseline_model.load_state_dict(ckpt_baseline['model_state_dict'], strict=False)
baseline_fourier.load_state_dict(ckpt_baseline['fourier_encoder_state_dict'], strict=False)
# The baseline is for inference only: eval mode and no gradients.
baseline_model.eval()
baseline_fourier.eval()
for p in baseline_model.parameters():
    p.requires_grad = False
for p in baseline_fourier.parameters():
    p.requires_grad = False
print('Baseline model (no NCE) loaded from checkpoint.')

In [None]:
# TDSM: one latent token -> full 32x32 decoded image (same as TDSM_Classification.ipynb)
def decoder_forward(model, queries, context):
    '''Decode `queries` against an arbitrary `context` and return the logits.

    Shapes per the original note: queries (B,N,qd), context (B,1 or B,L,ld) -> (B,N,3).
    The decoder path is cross-attention plus the queries, an optional residual
    feed-forward step (skipped when model.decoder_ff is None), then to_logits.
    '''
    hidden = model.decoder_cross_attn(queries, context=context) + queries
    feed_forward = model.decoder_ff
    if feed_forward is not None:
        hidden = hidden + feed_forward(hidden)
    return model.to_logits(hidden)

def get_tdsm(model, fourier_encoder, data, coords_32, device, num_tokens=256, token_step=4):
    '''Token-decoded spatial maps: decode the full grid from one latent token at a time.

    For every token index k in range(0, num_tokens, token_step) the decoder is
    run with latent token k as its only context, the output is reshaped to
    (B, IMAGE_SIZE, IMAGE_SIZE, 3), and the RGB axis is averaged.
    Returns a tensor of shape (B, n_tokens, IMAGE_SIZE, IMAGE_SIZE). Runs under no_grad.
    '''
    with torch.no_grad():
        latents = get_residual(model, data)
        batch = data.size(0)
        grid_queries = fourier_encoder(repeat(coords_32, 'n d -> b n d', b=batch)).to(device)
        per_token = [
            decoder_forward(model, grid_queries, latents[:, k:k + 1, :]).reshape(batch, IMAGE_SIZE, IMAGE_SIZE, 3)
            for k in range(0, num_tokens, token_step)
        ]
        # Stack along a new token axis, then average the RGB channels.
        tdsm = torch.stack(per_token, dim=1).mean(dim=-1)
    return tdsm

TDSM_TOKEN_STEP = 4
imgs_tdsm, _ = next(iter(val_loader))
imgs_tdsm = imgs_tdsm[:4].to(DEVICE)
with torch.no_grad():
    # Each model must see inputs built with its own Fourier encoder.
    # fourier_encoder is handed to the optimizer during NCE training, so using
    # it for the baseline would feed that model positional features it was
    # never paired with.
    input_tdsm, _, _ = prepare_model_input(imgs_tdsm, coords_32, fourier_encoder)
    input_tdsm_baseline, _, _ = prepare_model_input(imgs_tdsm, coords_32, baseline_fourier)
    tdsm_baseline = get_tdsm(baseline_model, baseline_fourier, input_tdsm_baseline, coords_32, DEVICE, token_step=TDSM_TOKEN_STEP)
    tdsm_nce = get_tdsm(model, fourier_encoder, input_tdsm, coords_32, DEVICE, token_step=TDSM_TOKEN_STEP)
# (B, 64, 32, 32) each

In [None]:
# TDSM maps: same token indices, baseline vs NCE (one sample). Baseline = texture-like; does NCE change it?
# Top row: baseline model; bottom row: NCE-trained model.
sample_idx = 0
token_indices = [0, 16, 32, 48]
n_show = len(token_indices)
# Map token_indices (0..255) to TDSM slice index (0..63 when token_step=4): k // TDSM_TOKEN_STEP
tdsm_slice_idx = [k // TDSM_TOKEN_STEP for k in token_indices]
fig, axs = plt.subplots(2, n_show, figsize=(12, 5))
for i, (k, sk) in enumerate(zip(token_indices, tdsm_slice_idx)):
    axs[0, i].imshow(tdsm_baseline[sample_idx, sk].cpu().numpy(), cmap='viridis')
    axs[0, i].set_title(f'Token {k} (baseline)')
    axs[0, i].axis('off')
    axs[1, i].imshow(tdsm_nce[sample_idx, sk].cpu().numpy(), cmap='viridis')
    axs[1, i].set_title(f'Token {k} (NCE)')
    axs[1, i].axis('off')
plt.suptitle('TDSM: per-token recon (baseline = texture-like; does NCE change structure?)')
plt.tight_layout()
plt.savefig('softnce_tdsm_baseline_vs_nce.png', dpi=100)
plt.show()

In [None]:
# RGB component images (full 32x32 per token) for same tokens: baseline vs NCE
with torch.no_grad():
    n_tdsm = imgs_tdsm.size(0)
    coords_tdsm = repeat(coords_32, 'n d -> b n d', b=n_tdsm)
    # The baseline model gets inputs and queries from its own (untrained-by-NCE)
    # Fourier encoder. fourier_encoder is optimized during NCE training, so
    # sharing it would not show the baseline as it was checkpointed.
    input_tdsm_bl, _, _ = prepare_model_input(imgs_tdsm, coords_32, baseline_fourier)
    residual_bl = get_residual(baseline_model, input_tdsm_bl)
    residual_nce = get_residual(model, input_tdsm)
    queries_32 = fourier_encoder(coords_tdsm).to(DEVICE)
    queries_32_bl = baseline_fourier(coords_tdsm).to(DEVICE)
    comps_bl, comps_nce = [], []
    for k in token_indices:
        # Decode the full grid with latent token k as the only context.
        ctx_bl = residual_bl[:, k:k+1, :]
        ctx_nce = residual_nce[:, k:k+1, :]
        comps_bl.append(decoder_forward(baseline_model, queries_32_bl, ctx_bl).reshape(n_tdsm, IMAGE_SIZE, IMAGE_SIZE, 3))
        comps_nce.append(decoder_forward(model, queries_32, ctx_nce).reshape(n_tdsm, IMAGE_SIZE, IMAGE_SIZE, 3))
    # Shape (n_tokens_shown, B, H, W, 3) each.
    comps_bl = torch.stack(comps_bl, dim=0)
    comps_nce = torch.stack(comps_nce, dim=0)
def to_display(t):
    """Return a CPU copy of `t` squeezed into [0, 1] for plotting.

    If the largest absolute value in `t` exceeds 1.5, the tensor is first
    remapped with `t / 2 + 0.5`; otherwise it is used as-is. In both cases the
    result is clamped to the [0, 1] range.
    """
    shown = t.cpu()
    if t.abs().max() > 1.5:
        shown = shown / 2 + 0.5
    return shown.clamp(0, 1)
fig, axs = plt.subplots(2, n_show + 1, figsize=(14, 5))
# Column 0: input image on the top row; the panel below it is intentionally left blank.
input_ax, blank_ax = axs[0, 0], axs[1, 0]
input_ax.imshow(to_display(imgs_tdsm[sample_idx]).permute(1, 2, 0).numpy())
input_ax.set_title('Input')
input_ax.axis('off')
blank_ax.axis('off')
# Columns 1..n_show: per-token components, baseline on top and NCE underneath.
component_rows = ((comps_bl, 'baseline'), (comps_nce, 'NCE'))
for col, tok in enumerate(token_indices, start=1):
    for row, (comps, tag) in enumerate(component_rows):
        ax = axs[row, col]
        ax.imshow(to_display(comps[col - 1, sample_idx]).numpy())
        ax.set_title(f'Token {tok} {tag}')
        ax.axis('off')
plt.suptitle('Per-token RGB component (baseline vs NCE): texture vs structure?')
plt.tight_layout()
plt.savefig('softnce_tdsm_components_baseline_vs_nce.png', dpi=100)
plt.show()

### What changed: visualizing object vs background differentiation

NCE adds differential patterns (object vs background) but token-level *class* separation stays weak. Below we visualize **where** the change is (spatial difference) and **which tokens** become more object- vs background-selective.

In [None]:
# 1) Spatial difference: where did NCE change the per-token reconstructions? (mean over tokens)
diff_spatial = (tdsm_nce - tdsm_baseline).abs().mean(dim=1).cpu().numpy()
# Show up to four samples; smaller batches get fewer columns instead of an IndexError.
n_cols = min(4, diff_spatial.shape[0])
# One colour scale for all panels so the single shared colorbar is valid for each of them
# (independently autoscaled panels would make the colorbar describe only the last one).
diff_vmax = float(diff_spatial[:n_cols].max())
# constrained_layout is used instead of tight_layout because the colorbar spans several Axes;
# squeeze=False keeps axs two-dimensional even when only one column is drawn.
fig, axs = plt.subplots(2, n_cols, figsize=(14, 6), squeeze=False, constrained_layout=True)
for i in range(n_cols):
    axs[0, i].imshow(imgs_tdsm[i].cpu().permute(1, 2, 0).clamp(0, 1).numpy())
    axs[0, i].set_title('Input' if i == 0 else '')
    axs[0, i].axis('off')
    im = axs[1, i].imshow(diff_spatial[i], cmap='hot', vmin=0.0, vmax=diff_vmax)
    axs[1, i].set_title('|NCE − Baseline| (mean over tokens)' if i == 0 else '')
    axs[1, i].axis('off')
plt.colorbar(im, ax=axs[1, :], shrink=0.6, label='Mean |diff|')
plt.suptitle('Where NCE changed TDSM: object vs background')
plt.savefig('softnce_tdsm_spatial_diff.png', dpi=100)
plt.show()

In [None]:
# 2) Foreground vs background: per-token mean activation inside a centered square
# (used as the "object" region) versus the surrounding border (used as "background").
H, W = IMAGE_SIZE, IMAGE_SIZE
margin = 8
# Binary masks shaped (1, 1, H, W) so they broadcast against the 4-D TDSM tensors.
obj_mask = torch.zeros(1, 1, H, W, dtype=torch.float32)
obj_mask[..., margin:H - margin, margin:W - margin] = 1.0
obj_mask = obj_mask.to(DEVICE)
bg_mask = 1.0 - obj_mask


def _masked_mean(tdsm, mask):
    """Mean of `tdsm` over the pixels selected by `mask`, reduced over the last two axes."""
    # The small epsilon guards the division if a mask selects no pixels.
    return (tdsm * mask).sum(dim=(2, 3)) / (mask.sum() + 1e-8)


# tdsm: (B, 64, 32, 32)
n_tokens = tdsm_baseline.size(1)
obj_bl, bg_bl = _masked_mean(tdsm_baseline, obj_mask), _masked_mean(tdsm_baseline, bg_mask)
obj_nce, bg_nce = _masked_mean(tdsm_nce, obj_mask), _masked_mean(tdsm_nce, bg_mask)
# Sensitivity = (center mean - border mean), averaged over the batch: one value per token.
sens_bl = (obj_bl - bg_bl).mean(dim=0).cpu().numpy()
sens_nce = (obj_nce - bg_nce).mean(dim=0).cpu().numpy()

fig, axs = plt.subplots(1, 2, figsize=(10, 4))
token_axis = np.arange(n_tokens)
bar_specs = (
    (sens_bl, 'steelblue', 'Baseline',
     'Object sensitivity (mean center − mean border)', 'Baseline: token object vs background'),
    (sens_nce, 'green', 'NCE',
     'Object sensitivity', 'NCE: token object vs background'),
)
for ax, (values, color, label, ylabel, title) in zip(axs, bar_specs):
    ax.bar(token_axis, values, color=color, alpha=0.8, label=label)
    ax.set_xlabel('Token index')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
plt.suptitle('Per-token object sensitivity (higher = more object-focused)')
plt.tight_layout()
plt.savefig('softnce_tdsm_object_sensitivity.png', dpi=100)
plt.show()

# Scatter: baseline sensitivity vs NCE sensitivity per token (above diagonal = NCE more object-focused)
plt.figure(figsize=(5, 5))
plt.scatter(sens_bl, sens_nce, alpha=0.7)
# Reference diagonal drawn across the range of the baseline values.
diag_lo, diag_hi = sens_bl.min(), sens_bl.max()
plt.plot([diag_lo, diag_hi], [diag_lo, diag_hi], 'r--', label='y=x')
plt.xlabel('Baseline object sensitivity')
plt.ylabel('NCE object sensitivity')
plt.title('Per-token: NCE vs baseline (above line = NCE more object-focused)')
plt.legend()
plt.tight_layout()
plt.savefig('softnce_tdsm_sensitivity_scatter.png', dpi=100)
plt.show()

### How to improve semantics on top of NCE

NCE gives **geometry and object vs background**; class separation needs an explicit **class signal**. Options:

1. **Auxiliary classifier**: Add a small head on pooled φ (or on pooled TDSM) → class logits; train with **λ_cls × cross-entropy** alongside recon + NCE. This keeps NCE dominant but pulls the representation toward class-discriminative features.
2. **Supervised contrastive**: In the contrastive loss, add **same-class** pairs as extra positives (e.g. in-batch: anchor from image A, positive from image B if class(A)=class(B)), with a smaller weight than the geometric positives. Pushes same-class features closer.
3. **Foreground-weighted NCE**: Mask or downweight contrastive pairs where the anchor/candidate fall in "background" (e.g. by a simple saliency or center prior), so the loss focuses on object regions; semantics can emerge more from object-level consistency.

Below: minimal **auxiliary classifier** snippet you can plug into the training loop.

In [None]:
# Optional: auxiliary classifier on pooled φ to add a class signal.
# It is only constructed here; wiring it into the training loop is sketched below.
NUM_CLASSES = 10
pooled_dim = PROJ_DIM  # or QUERIES_DIM if you pool phi_raw
# Two-layer MLP head: pooled_dim -> 128 -> NUM_CLASSES.
aux_layers = [
    nn.Linear(pooled_dim, 128),
    nn.ReLU(),
    nn.Linear(128, NUM_CLASSES),
]
aux_classifier = nn.Sequential(*aux_layers).to(DEVICE)
# Training-loop recipe (not executed in this cell):
#   phi_pool = phi_a.mean(dim=1)                      # phi_a: (B, N_a, PROJ_DIM) -> (B, PROJ_DIM)
#   logits_cls = aux_classifier(phi_pool)
#   loss_cls = F.cross_entropy(logits_cls, labels)
#   loss = loss_recon + lam_ctr * loss_softNCE + lambda_cls * loss_cls   # e.g. lambda_cls=0.05
print('Aux classifier (pooled φ → 10 classes):', aux_classifier)
print('To use: pool phi over anchors, add loss_cls with lambda_cls ~ 0.05 to total loss.')

### TDSM class/semantics: t-SNE (and PCA) colored by CIFAR-10 class

Pool TDSM per image (e.g. spatial mean per token → 64-dim), then t-SNE (or PCA) and color by class. Compare **baseline** vs **NCE**: does NCE yield better class separation in the reconstruction space?

In [None]:
# Collect pooled TDSM features for up to N_VAL_TDSM validation images:
# spatial mean per token -> one feature vector per image (64-dim per the notes above).
N_VAL_TDSM = min(500, len(val_ds))
batch_size = 32  # not read in this cell; val_loader is built elsewhere
all_feat_baseline, all_feat_nce, all_labels = [], [], []
n_done = 0
for net in (baseline_model, model, fourier_encoder, baseline_fourier):
    net.eval()


def _pooled_tdsm(net, encoder, input_data):
    """Run get_tdsm for one model/encoder pair and average each token map over space."""
    tdsm = get_tdsm(net, encoder, input_data, coords_32, DEVICE, token_step=TDSM_TOKEN_STEP)
    return tdsm.mean(dim=(2, 3)).cpu().numpy()


with torch.no_grad():
    for imgs, labels in val_loader:
        # Stop once enough images were gathered; the last batch may overshoot and is
        # trimmed after concatenation.
        if n_done >= N_VAL_TDSM:
            break
        imgs = imgs.to(DEVICE)
        # NOTE(review): the same input (encoded with fourier_encoder) is fed to both models;
        # confirm this is intended if baseline_fourier can differ from fourier_encoder.
        input_data, _, _ = prepare_model_input(imgs, coords_32, fourier_encoder)
        all_feat_baseline.append(_pooled_tdsm(baseline_model, baseline_fourier, input_data))
        all_feat_nce.append(_pooled_tdsm(model, fourier_encoder, input_data))
        all_labels.append(labels.numpy())
        n_done += imgs.size(0)

X_baseline = np.concatenate(all_feat_baseline, axis=0)[:N_VAL_TDSM]
X_nce = np.concatenate(all_feat_nce, axis=0)[:N_VAL_TDSM]
y_all = np.concatenate(all_labels, axis=0)[:N_VAL_TDSM]
print('TDSM features: baseline', X_baseline.shape, 'NCE', X_nce.shape, 'labels', y_all.shape)

In [None]:
# t-SNE (and PCA) of TDSM features, colored by class; baseline vs NCE.
# Only the optional scikit-learn imports are guarded, so an ImportError raised by
# anything else in this cell is not mistaken for a missing dependency.
try:
    from sklearn.decomposition import PCA
    from sklearn.manifold import TSNE
except ImportError as e:
    print('Install scikit-learn for PCA/t-SNE:', e)
else:
    n_samples = X_baseline.shape[0]
    # PCA cannot keep more components than min(n_features, n_samples - 1); cap at 50.
    n_components_pca = min(50, X_baseline.shape[1], n_samples - 1)
    pca_bl = PCA(n_components=n_components_pca).fit(X_baseline)
    pca_nce = PCA(n_components=n_components_pca).fit(X_nce)
    X_bl_pca = pca_bl.transform(X_baseline)
    X_nce_pca = pca_nce.transform(X_nce)
    # t-SNE requires perplexity < n_samples; keep 30 when possible, shrink for tiny subsets.
    tsne_perplexity = min(30, max(1, n_samples - 1))
    X_bl_tsne = TSNE(n_components=2, random_state=42, perplexity=tsne_perplexity).fit_transform(X_bl_pca)
    X_nce_tsne = TSNE(n_components=2, random_state=42, perplexity=tsne_perplexity).fit_transform(X_nce_pca)
    # constrained_layout replaces tight_layout here because the colorbar spans both Axes.
    fig, axs = plt.subplots(1, 2, figsize=(12, 5), constrained_layout=True)
    for ax, X_2d, title in [(axs[0], X_bl_tsne, 'Baseline (no NCE)'), (axs[1], X_nce_tsne, 'NCE-trained')]:
        sc = ax.scatter(X_2d[:, 0], X_2d[:, 1], c=y_all, cmap='tab10', s=12, alpha=0.7)
        ax.set_title(title)
        ax.set_xlabel('t-SNE 1')
        ax.set_ylabel('t-SNE 2')
    plt.colorbar(sc, ax=axs, label='Class', shrink=0.6)
    plt.suptitle('TDSM features (pooled): t-SNE colored by CIFAR-10 class')
    plt.savefig('softnce_tdsm_tsne_class.png', dpi=100)
    plt.show()
    # PCA 2D (faster, linear) for comparison: first two principal components of each fit.
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    X_bl_pca2 = X_bl_pca[:, :2]
    X_nce_pca2 = X_nce_pca[:, :2]
    for ax, X_2d, title in [(axs[0], X_bl_pca2, 'Baseline PCA'), (axs[1], X_nce_pca2, 'NCE PCA')]:
        ax.scatter(X_2d[:, 0], X_2d[:, 1], c=y_all, cmap='tab10', s=12, alpha=0.7)
        ax.set_title(title)
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')
    plt.suptitle('TDSM features: PCA first 2 components (class-colored)')
    plt.tight_layout()
    plt.savefig('softnce_tdsm_pca_class.png', dpi=100)
    plt.show()