# Soft InfoNCE on OmniField (CIFAR-10)

**Minimal change** on top of the original OmniField reconstruction: add a **geometry-aware contrastive objective (Soft InfoNCE)** on probe-derived tokens from the neural field.

- Sample anchor coords in view A and candidate coords in view B (B = affine augmentation of A with **known T**).
- Build L2-normalized feature tokens φ(x) from the field (decoder hidden state → projection head).
- Soft positives: \(w_{ij} \propto \exp(-\|x_j^B - T(x_i^A)\|^2 / (2\sigma^2))\) (Gaussian kernel).
- Loss: \(\mathcal{L} = \mathcal{L}_{recon} + \lambda_{ctr} \mathcal{L}_{softNCE}\).

Defaults: N_a=256, N_b=1024, d=128, τ=0.1, σ=0.08, λ_ctr=0.1 (ramp 0→0.1 over 500 steps).

In [None]:
import math
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import transforms
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from einops import rearrange, repeat

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

# Seed every RNG this notebook draws from so that runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    # Keep cuDNN autotuning off explicitly: it may select different kernels between runs.
    torch.backends.cudnn.benchmark = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_DIR = 'checkpoints'
# Prefer the best checkpoint; fall back to the last one if it is absent.
CKPT_PATH = os.path.join(CHECKPOINT_DIR, 'checkpoint_best.pt')
if not os.path.isfile(CKPT_PATH):
    CKPT_PATH = os.path.join(CHECKPOINT_DIR, 'checkpoint_last.pt')
# Raise explicitly instead of asserting: asserts are stripped under `python -O`.
if not os.path.isfile(CKPT_PATH):
    raise FileNotFoundError(f'No checkpoint in {CHECKPOINT_DIR}. Train AblationCIFAR10 first.')
print(f'Device: {DEVICE}')


In [None]:
# OmniField config (match checkpoint)
IMAGE_SIZE = 32
CHANNELS = 3
FOURIER_MAPPING_SIZE = 96
# Positional embedding width is twice the Fourier mapping size.
POS_EMBED_DIM = FOURIER_MAPPING_SIZE * 2
# Each input token is the pixel value concatenated with its positional embedding
# (presumably assembled by prepare_model_input - confirm against nf_feature_models).
INPUT_DIM = CHANNELS + POS_EMBED_DIM
QUERIES_DIM = POS_EMBED_DIM
LOGITS_DIM = CHANNELS

fourier_encoder = GaussianFourierFeatures(in_features=2, mapping_size=FOURIER_MAPPING_SIZE, scale=15.0).to(DEVICE)
model = CascadedPerceiverIO(
    input_dim=INPUT_DIM,
    queries_dim=QUERIES_DIM,
    logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512),
    num_latents=(256, 256, 256),
    decoder_ff=True,
).to(DEVICE)

ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
# strict=False tolerates key mismatches, which would otherwise leave layers at their
# random initialisation without any signal. Report every mismatch so it is visible.
for _label, _module, _state_key in (
    ('model', model, 'model_state_dict'),
    ('fourier_encoder', fourier_encoder, 'fourier_encoder_state_dict'),
):
    _load_result = _module.load_state_dict(ckpt[_state_key], strict=False)
    if _load_result.missing_keys or _load_result.unexpected_keys:
        print(
            f'WARNING: {_label} checkpoint mismatch - '
            f'missing keys: {list(_load_result.missing_keys)}, '
            f'unexpected keys: {list(_load_result.unexpected_keys)}'
        )
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
print(f'Loaded {CKPT_PATH}')


In [None]:
def get_residual(model, data):
    '''Run the encoder stack and the self-attention stack and return the resulting state.

    The state starts as None and is threaded through every entry of
    model.encoder_blocks (as both `x` and `residual`, with `data` as context and
    no mask). Each entry of model.self_attn_blocks is then indexed as a pair of
    sublayers, each applied with a skip connection.
    '''
    state = None
    for encoder_block in model.encoder_blocks:
        state = encoder_block(x=state, context=data, mask=None, residual=state)

    for pair in model.self_attn_blocks:
        first_sublayer, second_sublayer = pair[0], pair[1]
        state = first_sublayer(state) + state
        state = second_sublayer(state) + state
    return state

def get_rgb_and_phi_raw(model, queries, residual):
    '''Decode `queries` against `residual` and return (rgb, phi_raw).

    phi_raw is the decoder hidden state: cross-attention output plus the queries,
    followed by a skip-connected model.decoder_ff when that attribute is not None.
    rgb is model.to_logits applied to that hidden state.
    '''
    attended = model.decoder_cross_attn(queries, context=residual)
    hidden = attended + queries
    feed_forward = model.decoder_ff
    if feed_forward is not None:
        hidden = hidden + feed_forward(hidden)
    return model.to_logits(hidden), hidden

def sample_gt_at_coords(images, coords):
    '''Bilinearly sample `images` (B, C, H, W) at `coords` (B, N, 2); returns (B, N, C).

    The two coordinate components are swapped before sampling because
    F.grid_sample expects (x, y) order. Out-of-range coords use border padding,
    and -1 / +1 map to the centres of the corner pixels (align_corners=True).
    '''
    # Swap the last-axis components and add a singleton "height" axis: (B, 1, N, 2).
    swapped = torch.stack([coords[..., 1], coords[..., 0]], dim=-1)
    sample_grid = swapped.unsqueeze(1)
    picked = F.grid_sample(
        images,
        sample_grid,
        mode='bilinear',
        padding_mode='border',
        align_corners=True,
    )
    # (B, C, 1, N) -> (B, C, N) -> (B, N, C)
    return picked[:, :, 0, :].transpose(1, 2)

def make_grid_2d(h, w, device):
    '''Return an (h*w, 2) tensor of (y, x) coordinates spanning [-1, 1] in row-major order.'''
    rows = torch.linspace(-1, 1, h, device=device)
    cols = torch.linspace(-1, 1, w, device=device)
    # Row-major flattening: y changes slowest, x cycles fastest.
    y_flat = rows.repeat_interleave(w)
    x_flat = cols.repeat(h)
    return torch.stack([y_flat, x_flat], dim=-1)


## Affine augmentation with known T (A → B)

We apply an affine transform to get view B and store the forward map T so that for any coord in A we have T(x) in B space. Coords are in [-1, 1]².

In [None]:
def sample_affine_params(batch_size, device, scale_range=(0.8, 1.0), max_translate=0.1, max_angle_deg=15):
    '''Draw one random similarity transform T (A→B) per batch element.

    Returns R (B, 2, 2), a rotation scaled isotropically, and t (B, 2), with
    T(p) = p @ R.T + t for row-vector points. Angles are uniform in
    ±max_angle_deg, scales uniform in scale_range, and each translation
    component uniform in ±max_translate.
    '''
    def signed_uniform():
        # Uniform sample in [-1, 1), one value per batch element.
        return torch.rand(batch_size, device=device) * 2 - 1

    lo, hi = scale_range[0], scale_range[1]
    # Draw order (angle, scale, tx, ty) is kept fixed so seeded runs are reproducible.
    theta = signed_uniform() * (max_angle_deg * math.pi / 180)
    zoom = lo + torch.rand(batch_size, device=device) * (hi - lo)
    shift_0 = signed_uniform() * max_translate
    shift_1 = signed_uniform() * max_translate

    cos_s = torch.cos(theta) * zoom
    sin_s = torch.sin(theta) * zoom
    top_row = torch.stack([cos_s, -sin_s], dim=-1)
    bottom_row = torch.stack([sin_s, cos_s], dim=-1)
    R = torch.stack([top_row, bottom_row], dim=-2)
    t = torch.stack([shift_0, shift_1], dim=1)
    return R, t

def apply_affine_to_coords(coords, R, t):
    '''Map points through T(p) = R @ p + t, independently per batch element.

    coords is (B, N, 2), R is (B, 2, 2), t is (B, 2); the result is (B, N, 2),
    expressed in view B's coordinate space.
    '''
    # With points stored as rows, R @ p for every point is coords @ R^T.
    rotated = torch.bmm(coords, R.transpose(1, 2))
    return rotated + t[:, None, :]

def apply_affine_to_image(images, R, t, align_corners=True):
    '''Warp view A into view B = T(A), where T(p) = R @ p + t acts on (y, x) coords.

    images is (B, C, H, W), R is (B, 2, 2), t is (B, 2); the result has the same
    shape as images. Each output location q is filled with A sampled at
    T_inv(q) = R_inv @ (q - t), using bilinear interpolation and border padding.

    Coordinates in this file are stored as (y, x): make_grid_2d stacks (y, x) and
    sample_gt_at_coords swaps the components before calling grid_sample.
    F.affine_grid, however, builds its sampling grid in (x, y) order. R and t are
    therefore re-expressed in (x, y) order first, so the warp applied to the
    pixels is the same map that apply_affine_to_coords applies to (y, x) points.
    Without that step the image would be rotated in the opposite direction and
    translated along the wrong axes relative to the mapped coordinates.
    '''
    # Swapping the axis order is conjugation by the 2x2 permutation: flip both
    # matrix axes of R and the two entries of t.
    R_xy = R.flip(dims=(-2, -1))
    t_xy = t.flip(dims=(-1,))
    R_inv = torch.inverse(R_xy)
    # theta = [R_inv | -R_inv t] maps an output location to its source location.
    theta = torch.cat([R_inv, -(R_inv @ t_xy.unsqueeze(2))], dim=2)
    grid = F.affine_grid(theta, images.size(), align_corners=align_corners)
    return F.grid_sample(images, grid, mode='bilinear', padding_mode='border', align_corners=align_corners)


In [None]:
class ProjectionHead(nn.Module):
    '''Single linear layer followed by L2 normalisation along the last axis.'''

    def __init__(self, in_dim, out_dim=128):
        super().__init__()
        self.out_dim = out_dim
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, z):
        '''Project z (..., in_dim) to unit-norm vectors of shape (..., out_dim).'''
        projected = self.proj(z)
        return F.normalize(projected, dim=-1)

# Width of the contrastive embedding produced by the projection head.
PROJ_DIM = 128
# The head consumes the decoder hidden state returned by get_rgb_and_phi_raw; that
# state is added to the queries there, so its width is taken to be QUERIES_DIM.
projection_head = ProjectionHead(QUERIES_DIM, PROJ_DIM).to(DEVICE)
print(projection_head)


In [None]:
def soft_infonce_loss(phi_a, phi_b, coords_a, coords_b, R, t, tau=0.1, sigma=0.08):
    '''
    Soft-target InfoNCE between anchor tokens of view A and candidate tokens of view B.

    phi_a (B, N_a, d) and phi_b (B, N_b, d) are expected to be L2-normalized.
    coords_a (B, N_a, 2), coords_b (B, N_b, 2). T: A->B is given by R (B,2,2), t (B,2).

    For anchor i the target distribution over candidates j is a Gaussian kernel
    on the distance between x_j^B and T(x_i^A):
        w_ij = exp(-||x_j^B - T(x_i^A)||^2 / (2*sigma^2)), normalized over j.
    The loss is the cross-entropy between w and softmax(phi_a . phi_b / tau),
    averaged over anchors and batch elements.

    The normalizer carries a 1e-8 epsilon, so an anchor whose mapped position is
    far from every candidate ends up with near-zero total weight and contributes
    almost nothing to the loss.
    '''
    # Temperature-scaled similarities and the model's log-distribution over candidates.
    similarity = torch.bmm(phi_a, phi_b.transpose(1, 2)) / tau
    log_q = F.log_softmax(similarity, dim=-1)

    # Where each anchor lands in view B, and its squared distance to every candidate.
    targets_in_b = apply_affine_to_coords(coords_a, R, t)
    offsets = coords_b[:, None, :, :] - targets_in_b[:, :, None, :]
    sq_dist = (offsets ** 2).sum(-1)

    kernel = torch.exp(-sq_dist / (2 * sigma ** 2))
    soft_targets = kernel / (kernel.sum(dim=2, keepdim=True) + 1e-8)

    per_anchor = -(soft_targets * log_q).sum(-1)
    return per_anchor.mean()


In [None]:
# Hyper-parameters for the contrastive fine-tuning run.
cfg = dict(
    subset_size=10000,
    batch_size=32,
    N_a=256,
    N_b=1024,
    tau=0.1,
    sigma=0.08,
    lambda_ctr=0.1,
    ramp_steps=500,
    epochs=5,
    lr=1e-3,
    freeze_backbone=False,
)

transform = transforms.Compose([transforms.ToTensor()])

# Training uses the first `subset_size` CIFAR-10 training images.
train_ds = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
_n_train = min(cfg['subset_size'], len(train_ds))
train_sub = Subset(train_ds, list(range(_n_train)))
train_loader = DataLoader(train_sub, batch_size=cfg['batch_size'], shuffle=True, num_workers=0)

# Validation uses the full CIFAR-10 test split in a fixed order.
val_ds = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)

print(f'Train batches: {len(train_loader)}, Val batches: {len(val_loader)}')


## Training loop: recon + Soft InfoNCE (λ ramp)

In [None]:
# The projection head is always trained; the backbone and Fourier encoder are
# added unless cfg['freeze_backbone'] is set.
params = [*projection_head.parameters()]
if not cfg.get('freeze_backbone'):
    params.extend(model.parameters())
    params.extend(fourier_encoder.parameters())
optimizer = torch.optim.Adam(params, lr=cfg['lr'])
# One-element list used as a mutable step counter.
global_step = [0]

def get_lambda_ctr(step):
    '''Contrastive weight at `step`: linear ramp from 0 to cfg['lambda_ctr'] over cfg['ramp_steps'] steps.'''
    ramp = cfg['ramp_steps']
    if step < ramp:
        return cfg['lambda_ctr'] * (step / ramp)
    return cfg['lambda_ctr']


In [None]:
def eval_psnr(model, fourier_encoder, loader, device, grid_size=32):
    '''Return reconstruction PSNR in dB over every batch of `loader`.

    Each image is encoded, then decoded at every point of a grid_size x grid_size
    coordinate grid, and compared with the image sampled at the same points.
    The PSNR formula uses a peak value of 1.0 and the mean squared error over
    every predicted value (batch x points x channels).

    Side effect: puts `model` and `fourier_encoder` into eval mode.
    '''
    model.eval()
    fourier_encoder.eval()
    grid = make_grid_2d(grid_size, grid_size, device)
    sq_err_sum, n_values = 0.0, 0
    with torch.no_grad():
        for imgs, _ in loader:
            imgs = imgs.to(device)
            B = imgs.size(0)
            input_data, _, _ = prepare_model_input(imgs, grid, fourier_encoder)
            residual = get_residual(model, input_data)
            coords_batch = grid.unsqueeze(0).expand(B, -1, -1)
            queries = fourier_encoder(coords_batch)
            rgb, _ = get_rgb_and_phi_raw(model, queries, residual)
            gt = sample_gt_at_coords(imgs, coords_batch)
            sq_err_sum += F.mse_loss(rgb, gt, reduction='sum').item()
            # The summed error covers every element (including channels), so the
            # denominator must count elements too; counting only B * N points
            # inflates the MSE by the channel count and understates PSNR.
            n_values += gt.numel()
    mse = sq_err_sum / max(n_values, 1)
    # The small epsilon keeps log10 finite for a perfect reconstruction.
    return 10 * math.log10(1.0 / (mse + 1e-10))

print(f'Baseline val PSNR: {eval_psnr(model, fourier_encoder, val_loader, DEVICE):.2f} dB')


In [None]:
# Fine-tune with reconstruction + ramped Soft InfoNCE, logging validation PSNR per epoch.
psnr_log = []
for ep in range(cfg['epochs']):
    # eval_psnr switches the model and encoder to eval mode, so re-enable train mode every epoch.
    model.train()
    fourier_encoder.train()
    projection_head.train()
    total_recon = 0.0
    total_ctr = 0.0
    for imgs, _ in train_loader:
        imgs = imgs.to(DEVICE)
        B, C, H, W = imgs.shape
        # View B is view A warped by a freshly sampled affine map T = (R, t).
        R, t = sample_affine_params(B, DEVICE)
        imgs_b = apply_affine_to_image(imgs, R, t)
        N_a, N_b = cfg['N_a'], cfg['N_b']
        # Anchors (view A) and candidates (view B) are drawn uniformly from [-1, 1)^2.
        anchors_a = (torch.rand(B, N_a, 2, device=DEVICE) * 2 - 1)
        candidates_b = (torch.rand(B, N_b, 2, device=DEVICE) * 2 - 1)
        # Encode both views on the same 32x32 input grid.
        input_a, _, _ = prepare_model_input(imgs, coords_32, fourier_encoder)
        input_b, _, _ = prepare_model_input(imgs_b, coords_32, fourier_encoder)
        residual_a = get_residual(model, input_a)
        residual_b = get_residual(model, input_b)
        # Probe each view's field at its sampled coordinates.
        queries_a = fourier_encoder(anchors_a)
        queries_b = fourier_encoder(candidates_b)
        # NOTE(review): rgb_a is not used below; only the hidden states feed the contrastive loss.
        rgb_a, phi_raw_a = get_rgb_and_phi_raw(model, queries_a, residual_a)
        _, phi_raw_b = get_rgb_and_phi_raw(model, queries_b, residual_b)
        phi_a = projection_head(phi_raw_a)
        phi_b = projection_head(phi_raw_b)
        loss_ctr = soft_infonce_loss(phi_a, phi_b, anchors_a, candidates_b, R, t, tau=cfg['tau'], sigma=cfg['sigma'])
        # Reconstruction loss: decode view A on the full grid and compare with the image.
        coords_full = coords_32.unsqueeze(0).expand(B, -1, -1)
        queries_full = fourier_encoder(coords_full)
        rgb_full, _ = get_rgb_and_phi_raw(model, queries_full, residual_a)
        gt_full = sample_gt_at_coords(imgs, coords_full)
        loss_recon = F.mse_loss(rgb_full, gt_full)
        # The contrastive weight follows the step-based ramp from get_lambda_ctr.
        lam = get_lambda_ctr(global_step[0])
        loss = loss_recon + lam * loss_ctr
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        global_step[0] += 1
        total_recon += loss_recon.item()
        total_ctr += loss_ctr.item()
    psnr = eval_psnr(model, fourier_encoder, val_loader, DEVICE)
    psnr_log.append(psnr)
    print(f'Epoch {ep+1} recon: {total_recon/len(train_loader):.4f} ctr: {total_ctr/len(train_loader):.4f} val PSNR: {psnr:.2f} dB')

# NOTE(review): the title mentions a baseline, but only the per-epoch fine-tuning curve is plotted.
plt.figure(figsize=(6, 3))
plt.plot(psnr_log, 'o-')
plt.xlabel('Epoch')
plt.ylabel('Val PSNR (dB)')
plt.title('Reconstruction PSNR (baseline vs +Soft InfoNCE)')
plt.tight_layout()
plt.savefig('softnce_psnr.png', dpi=100)
plt.show()


## Visualizations

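1. **Soft-weight heatmap** for a few anchors: the Gaussian target weights w_ij over B's grid (these are the geometric targets, not learned attention).
2. **Coordinate error**: for each anchor, take the candidate with the highest similarity (argmax_j) and measure the distance from its coordinate to T(x_i^A).
3. **2D embedding** (PCA) of tokens, colored by class.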

In [None]:
# Visualise the Gaussian soft-positive weights for a few anchors of one validation image.
# The names defined here (logits, xi_mapped, coords_b_batch, b_show, ...) are read by the next cell.
model.eval()
fourier_encoder.eval()
projection_head.eval()
imgs, labels = next(iter(val_loader))
imgs = imgs[:4].to(DEVICE)
labels = labels[:4]
B = 4
R, t = sample_affine_params(B, DEVICE)
imgs_b = apply_affine_to_image(imgs, R, t)
# Candidates are the full regular grid of view B, so the weights can be shown as an image.
grid_b = make_grid_2d(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
N_a_vis = 64
anchors_a = (torch.rand(B, N_a_vis, 2, device=DEVICE) * 2 - 1)
with torch.no_grad():
    input_a, _, _ = prepare_model_input(imgs, coords_32, fourier_encoder)
    input_b, _, _ = prepare_model_input(imgs_b, coords_32, fourier_encoder)
    residual_a = get_residual(model, input_a)
    residual_b = get_residual(model, input_b)
    coords_b_batch = grid_b.unsqueeze(0).expand(B, -1, -1)
    queries_a = fourier_encoder(anchors_a)
    queries_b = fourier_encoder(coords_b_batch)
    _, phi_raw_a = get_rgb_and_phi_raw(model, queries_a, residual_a)
    _, phi_raw_b = get_rgb_and_phi_raw(model, queries_b, residual_b)
    phi_a = projection_head(phi_raw_a)
    phi_b = projection_head(phi_raw_b)
    # Learned similarities between anchors and grid candidates, scaled by the temperature.
    logits = torch.bmm(phi_a, phi_b.transpose(1, 2)) / cfg['tau']
    # Same soft-target construction as in soft_infonce_loss: Gaussian kernel around T(x_i^A).
    xi_mapped = apply_affine_to_coords(anchors_a, R, t)
    sqd = ((coords_b_batch.unsqueeze(1) - xi_mapped.unsqueeze(2)) ** 2).sum(-1)
    w = torch.exp(-sqd / (2 * cfg['sigma'] ** 2))
    w = w / (w.sum(dim=2, keepdim=True) + 1e-8)
# Index of the image within the batch that is shown, and how many anchors to plot for it.
b_show = 0
n_anchors_show = 4
fig, axs = plt.subplots(2, n_anchors_show, figsize=(12, 5))
for i in range(n_anchors_show):
    # Top row: weights of anchor i reshaped onto the grid. Bottom row: view B of the same image.
    heat = w[b_show, i].cpu().numpy().reshape(IMAGE_SIZE, IMAGE_SIZE)
    axs[0, i].imshow(heat, cmap='hot')
    axs[0, i].set_title(f'Anchor {i} soft weights')
    axs[0, i].axis('off')
    axs[1, i].imshow(imgs_b[b_show].cpu().permute(1, 2, 0).clamp(0, 1).numpy())
    axs[1, i].axis('off')
plt.suptitle('Soft positive weights w_j over view B (one image)')
plt.tight_layout()
plt.savefig('softnce_heatmap.png', dpi=100)
plt.show()


In [None]:
# Coordinate error: distance between the best-matching candidate and the mapped anchor position.
with torch.no_grad():
    sims_show = logits[b_show]
    best_j = torch.argmax(sims_show, dim=1)
    pred_coords = coords_b_batch[b_show].index_select(0, best_j)
    gt_coords = xi_mapped[b_show]
    err = torch.linalg.vector_norm(pred_coords - gt_coords, dim=-1).cpu().numpy()

fig, ax = plt.subplots(figsize=(5, 3))
ax.violinplot([err], positions=[0], showmeans=True)
ax.set_ylabel('Coord error (argmax candidate vs T(x_i^A))')
ax.set_title('Coordinate error (normalized space)')
ax.set_xticks([0])
ax.set_xticklabels(['Soft InfoNCE'])
fig.tight_layout()
fig.savefig('softnce_coord_error.png', dpi=100)
plt.show()
print(f'Mean coord error: {err.mean():.4f}')


In [None]:
# 2D PCA of projected tokens for up to n_embed validation images, colored by class.
n_embed = 500
grid_flat = make_grid_2d(8, 8, DEVICE)
n_tokens = grid_flat.size(0)

# val_loader yields batches smaller than n_embed, so gather images across batches.
# Queries and labels are sized from the images actually loaded; sizing them from
# n_embed would mismatch the batch dimension of the encoded images.
feat_chunks, label_chunks = [], []
n_seen = 0
with torch.no_grad():
    for imgs_e, labels_e in val_loader:
        remaining = n_embed - n_seen
        imgs_e = imgs_e[:remaining].to(DEVICE)
        labels_e = labels_e[:remaining].cpu().numpy()
        coords_embed = grid_flat.unsqueeze(0).expand(imgs_e.size(0), -1, -1)
        input_embed, _, _ = prepare_model_input(imgs_e, coords_32, fourier_encoder)
        residual_embed = get_residual(model, input_embed)
        queries_embed = fourier_encoder(coords_embed)
        _, phi_raw_embed = get_rgb_and_phi_raw(model, queries_embed, residual_embed)
        phi_embed = projection_head(phi_raw_embed)
        feat_chunks.append(phi_embed.cpu().numpy().reshape(-1, PROJ_DIM))
        # Every token of an image carries that image's class label.
        label_chunks.append(np.repeat(labels_e, n_tokens))
        n_seen += imgs_e.size(0)
        if n_seen >= n_embed:
            break
feats = np.concatenate(feat_chunks, axis=0)
label_tile = np.concatenate(label_chunks, axis=0)
try:
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    feats_2d = pca.fit_transform(feats)
except ImportError:
    # Best-effort fallback so the cell still renders without scikit-learn.
    feats_2d = np.random.randn(feats.shape[0], 2)
    print('Install sklearn for PCA; using random 2D for demo.')
plt.figure(figsize=(6, 5))
sc = plt.scatter(feats_2d[:, 0], feats_2d[:, 1], c=label_tile, cmap='tab10', s=1, alpha=0.6)
plt.colorbar(sc, label='Class')
plt.title('Token embedding PCA (colored by CIFAR-10 class)')
plt.tight_layout()
plt.savefig('softnce_embedding.png', dpi=100)
plt.show()


## Summary

- Single intervention: **Soft InfoNCE** on probe-derived tokens (decoder hidden → projection head).
- Geometry-aware soft positives with known affine T; λ_ctr ramped to avoid destabilizing recon.
- Visualizations: PSNR curve, soft-weight heatmap, coordinate error, 2D PCA embedding.