# NCE on top of Sparse-Context Recon (CIFAR-10)

Load a **pretrained sparse-context** model (e.g. `checkpoint_sparse_nce_off_best.pt`), add a **projection head**, and train **Soft InfoNCE** for a few epochs while keeping **sparse-input reconstruction** as the main task.

- **Recon:** Sparse context → full 32×32 (unchanged).
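- **NCE:** Two views related by a random affine warp; random anchor points in view A and candidate points in view B; features φ are the decoder's pre-logit outputs, passed through the projection head and L2-normalized; soft InfoNCE loss whose weight λ ramps up linearly from 0.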
- **Epochs:** 5 (short finetune). Saves to **checkpoint_sparse_finetune_nce_best.pt** / **_last.pt**.

In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from einops import rearrange

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_DIR = 'checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Pretrained sparse-context checkpoints to start from, in order of preference.
_CKPT_CANDIDATES = (
    'checkpoint_sparse_nce_off_best.pt',
    'checkpoint_sparse_best.pt',
    'checkpoint_sparse_nce_off_last.pt',
)
CKPT_LOAD = None
for _ckpt_name in _CKPT_CANDIDATES:
    _ckpt_path = os.path.join(CHECKPOINT_DIR, _ckpt_name)
    if os.path.isfile(_ckpt_path):
        CKPT_LOAD = _ckpt_path
        break
if CKPT_LOAD is None:
    # Explicit raise instead of `assert`: assertions are stripped under `python -O`.
    raise FileNotFoundError(f'No sparse checkpoint in {CHECKPOINT_DIR}. Run SparseContext_CIFAR10.ipynb with USE_NCE=False first.')
print('Device:', DEVICE, 'Load from:', CKPT_LOAD)


In [None]:
IMAGE_SIZE = 32
CHANNELS = 3
FOURIER_MAPPING_SIZE = 96
# Positional embedding width is twice the mapping size (presumably sin/cos pairs --
# confirm in nf_feature_models.GaussianFourierFeatures).
POS_EMBED_DIM = FOURIER_MAPPING_SIZE * 2
# Encoder input tokens are [RGB, Fourier(coord)] (see prepare_sparse_context).
INPUT_DIM = CHANNELS + POS_EMBED_DIM
# Decoder queries are the Fourier embeddings of the query coordinates.
QUERIES_DIM = POS_EMBED_DIM
LOGITS_DIM = CHANNELS
CONTEXT_FRAC = 0.2          # fraction of pixels handed to the encoder as sparse context
BATCH_SIZE = 64
EPOCHS = 5
NCE_RAMP_STEPS = 500        # optimizer steps over which the NCE weight ramps linearly from 0
LAMBDA_NCE = 0.1            # final weight of the NCE term relative to the recon loss
N_A, N_B = 256, 1024        # random anchor points (view A) / candidate points (view B) per image
TAU, SIGMA = 0.1, 0.08      # InfoNCE temperature / Gaussian width of the soft positive targets
PROJ_DIM = 128

coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
N_FULL = coords_32.size(0)
# At least 64 context points, but never more than the grid actually has.
N_SPARSE = min(N_FULL, max(64, int(N_FULL * CONTEXT_FRAC)))

# Args: 2-D input coords, mapping size, and 15.0 (presumably the Gaussian frequency
# scale -- confirm in nf_feature_models).
fourier_encoder = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
model = CascadedPerceiverIO(
    input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
).to(DEVICE)


def _report_incompatible_keys(name, result):
    """Print any keys that a non-strict ``load_state_dict`` call skipped.

    ``strict=False`` is kept so partially matching checkpoints still load, but a
    silent mismatch would leave the affected layers at their random initialization.
    """
    if result.missing_keys or result.unexpected_keys:
        print(f'[{name}] missing keys: {result.missing_keys} | unexpected keys: {result.unexpected_keys}')


ckpt = torch.load(CKPT_LOAD, map_location=DEVICE)
_report_incompatible_keys('model', model.load_state_dict(ckpt['model_state_dict'], strict=False))
_report_incompatible_keys(
    'fourier_encoder',
    fourier_encoder.load_state_dict(ckpt['fourier_encoder_state_dict'], strict=False),
)

# Maps decoder features (QUERIES_DIM) to the contrastive embedding space (PROJ_DIM).
projection_head = nn.Linear(QUERIES_DIM, PROJ_DIM).to(DEVICE)
if 'projection_head_state_dict' in ckpt:
    _report_incompatible_keys(
        'projection_head',
        projection_head.load_state_dict(ckpt['projection_head_state_dict'], strict=False),
    )
else:
    print('No projection head in checkpoint; starting from a fresh initialization.')

def proj_norm(z):
    """Embed decoder features for the contrastive loss.

    Applies the module-level ``projection_head`` to ``z``, then L2-normalizes the
    result along its last dimension, so dot products between outputs are cosines.
    """
    projected = projection_head(z)
    return F.normalize(projected, dim=-1)

# Summary of which checkpoint was loaded, the sparse-context size, and the epoch budget.
print('Loaded', CKPT_LOAD, '| Sparse ctx:', N_SPARSE, '| Epochs:', EPOCHS)


In [None]:
def sample_gt_at_coords(images, coords):
    """Bilinearly sample ``images`` at continuous coordinates.

    Args:
        images: (B, C, H, W) tensor.
        coords: (B, N, 2) points in (y, x) order, normalized to [-1, 1]; with
            ``align_corners=True`` the values -1 / +1 hit the centers of the edge
            pixels, and out-of-range points are clamped to the border.

    Returns:
        (B, N, C) sampled values.
    """
    batch = images.shape[0]
    num_points = coords.shape[1]
    # grid_sample wants (x, y), so reverse the last axis of the (y, x) coords.
    xy_grid = coords.flip(-1).reshape(batch, 1, num_points, 2)
    sampled = F.grid_sample(images, xy_grid, mode='bilinear', padding_mode='border', align_corners=True)
    # (B, C, 1, N) -> (B, N, C)
    return sampled[:, :, 0, :].transpose(1, 2)

def prepare_sparse_context(images, coords_full, fourier_encoder, num_sparse, device):
    """Build encoder tokens from a random sparse subset of pixel coordinates.

    One random subset of ``num_sparse`` rows of ``coords_full`` (N_full, 2) is drawn
    and shared by every image in the batch. Each token is the concatenation
    [pixel value sampled at the coordinate, ``fourier_encoder`` embedding of it].

    Returns:
        (B, num_sparse, C + embedding_dim) tensor (fewer tokens if ``num_sparse``
        exceeds the number of available coordinates).
    """
    batch = images.size(0)
    keep = torch.randperm(coords_full.size(0), device=device)[:num_sparse]
    batched_coords = coords_full[keep].unsqueeze(0).expand(batch, -1, -1)
    rgb = sample_gt_at_coords(images, batched_coords)
    pos = fourier_encoder(batched_coords)
    return torch.cat((rgb, pos), dim=-1)

def get_residual(model, data):
    """Encode ``data`` with ``model``'s encoder stack and return the latent array.

    Runs every block in ``model.encoder_blocks`` (cross-attending to ``data``), then
    each pair in ``model.self_attn_blocks`` with a residual connection around both
    sub-blocks. Kept separate from the decoder so the same latents can be decoded
    with several query sets. NOTE(review): presumably mirrors the encoder half of
    CascadedPerceiverIO.forward -- confirm in nf_feature_models.
    """
    latents = None  # the first encoder block is called with no incoming latents
    for encoder_block in model.encoder_blocks:
        latents = encoder_block(x=latents, context=data, mask=None, residual=latents)
    for sa_block in model.self_attn_blocks:
        # Presumably (attention, feed-forward); both wrapped in a residual add.
        attn_sub, ff_sub = sa_block[0], sa_block[1]
        latents = attn_sub(latents) + latents
        latents = ff_sub(latents) + latents
    return latents

def get_rgb_and_phi_raw(model, queries, residual):
    """Decode ``queries`` against the latents ``residual``.

    Returns:
        (rgb, phi): ``phi`` is the decoder feature (cross-attention output plus the
        query skip, plus the optional feed-forward residual) and ``rgb`` is
        ``model.to_logits(phi)``. NOTE(review): presumably mirrors the decoder half
        of CascadedPerceiverIO.forward -- confirm in nf_feature_models.
    """
    phi = model.decoder_cross_attn(queries, context=residual) + queries
    ff = model.decoder_ff
    if ff is not None:
        phi = phi + ff(phi)
    rgb = model.to_logits(phi)
    return rgb, phi

def sample_affine_params(batch_size, device, scale_range=(0.85, 1.0), max_translate=0.1, max_angle_deg=12):
    """Sample one random similarity transform (rotation, isotropic scale, shift) per item.

    Returns:
        R: (batch_size, 2, 2) matrices ``scale * [[cos a, -sin a], [sin a, cos a]]``
           with ``a`` uniform in [-max_angle_deg, max_angle_deg) degrees and
           ``scale`` uniform in [scale_range[0], scale_range[1]).
        t: (batch_size, 2) translations, each component uniform in
           [-max_translate, max_translate).
    """
    def symmetric_uniform(limit):
        # torch.rand is in [0, 1); stretch it to [-limit, limit).
        return (torch.rand(batch_size, device=device) * 2 - 1) * limit

    lo, hi = scale_range[0], scale_range[1]
    # Draw order (angle, scale, tx, ty) is kept fixed so seeded runs are reproducible.
    angle = symmetric_uniform(max_angle_deg * math.pi / 180)
    scale = lo + torch.rand(batch_size, device=device) * (hi - lo)
    tx = symmetric_uniform(max_translate)
    ty = symmetric_uniform(max_translate)

    cos_s = torch.cos(angle) * scale
    sin_s = torch.sin(angle) * scale
    top_row = torch.stack([cos_s, -sin_s], dim=-1)
    bottom_row = torch.stack([sin_s, cos_s], dim=-1)
    R = torch.stack([top_row, bottom_row], dim=-2)
    t = torch.stack([tx, ty], dim=-1)
    return R, t

def apply_affine_to_coords(coords, R, t):
    """Map each point through ``x -> R @ x + t``, per batch item.

    coords: (B, N, D); R: (B, D, D); t: (B, D). Returns (B, N, D).
    """
    transformed = torch.einsum('bij,bnj->bni', R, coords)
    return transformed + t[:, None, :]

def apply_affine_to_image(images, R, t):
    """Warp ``images`` so that view-A content at point ``p`` lands at ``R @ p + t``.

    ``R`` (B, 2, 2) and ``t`` (B, 2) act on normalized coordinates in the same
    (y, x) order used by ``sample_gt_at_coords`` and ``apply_affine_to_coords``.
    As a result, the warped image agrees with the coordinate mapping used for the
    NCE positives.

    ``F.affine_grid`` works in (x, y) order, so R and t are first conjugated by the
    axis swap (R_xy = P R P, t_xy = P t with P the 2x2 swap). Passing them through
    unchanged would rotate the image by the opposite angle and exchange the roles
    of the two translation components.
    """
    R_xy = R.flip((-2, -1))
    t_xy = t.flip(-1)
    # affine_grid maps *output* locations to *input* sample locations, so it needs
    # the inverse transform: p_a = R^-1 (p_b - t).
    R_inv = torch.inverse(R_xy)
    theta = torch.cat([R_inv, -(R_inv @ t_xy.unsqueeze(2))], dim=2)
    grid = F.affine_grid(theta, images.size(), align_corners=True)
    return F.grid_sample(images, grid, mode='bilinear', padding_mode='border', align_corners=True)

def soft_infonce_loss(phi_a, phi_b, coords_a, coords_b, R, t, tau=0.1, sigma=0.08):
    """Soft-target InfoNCE between anchor features in view A and candidates in view B.

    Each anchor at ``coords_a`` is mapped into view B with ``R @ x + t``. The target
    distribution over the candidates at ``coords_b`` is a normalized Gaussian
    (width ``sigma``) of their squared distance to that mapped point. The loss is the
    cross-entropy between these soft targets and ``softmax(phi_a . phi_b / tau)``,
    averaged over batch and anchors.

    Args:
        phi_a: (B, N_a, D) anchor features; phi_b: (B, N_b, D) candidate features.
        coords_a: (B, N_a, 2); coords_b: (B, N_b, 2).
        R, t: the A -> B transform applied by ``apply_affine_to_coords``.
    """
    scores = torch.bmm(phi_a, phi_b.transpose(1, 2)) / tau                    # (B, N_a, N_b)
    mapped = apply_affine_to_coords(coords_a, R, t)                            # (B, N_a, 2)
    dist2 = ((coords_b.unsqueeze(1) - mapped.unsqueeze(2)) ** 2).sum(-1)       # (B, N_a, N_b)
    soft_targets = torch.exp(-dist2 / (2 * sigma ** 2))
    # The epsilon keeps an anchor with no nearby candidates from dividing by zero.
    soft_targets = soft_targets / (soft_targets.sum(dim=2, keepdim=True) + 1e-8)
    log_probs = F.log_softmax(scores, dim=-1)
    expected_log_prob = (soft_targets * log_probs).sum(-1).mean()
    return -expected_log_prob


In [None]:
# ToTensor only: no augmentation, pixel values as float tensors in [0, 1].
transform = transforms.Compose([transforms.ToTensor()])
train_ds, test_ds = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
# Fixed evaluation order.
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
print('Train batches:', len(train_loader), 'Test batches:', len(test_loader))


In [None]:
# Backbone, positional encoder and projection head are all finetuned jointly.
params = [
    *model.parameters(),
    *fourier_encoder.parameters(),
    *projection_head.parameters(),
]
optimizer = torch.optim.Adam(params, lr=1e-3)
# Global step counter, boxed in a list so the training loop can bump it in place.
step = [0]
best_val_loss = float('inf')
CKPT_PREFIX = 'checkpoint_sparse_finetune_nce'
CKPT_BEST, CKPT_LAST = (
    os.path.join(CHECKPOINT_DIR, CKPT_PREFIX + suffix) for suffix in ('_best.pt', '_last.pt')
)
print('Saving to:', CKPT_BEST, CKPT_LAST)


In [None]:
# Validation draws its sparse contexts from a fixed seed, so val_loss measures the
# model rather than the particular random pixel subset. That keeps best-checkpoint
# selection comparable across epochs. fork_rng restores the global RNG state
# afterwards, so the training randomness is unaffected.
VAL_SEED = 0
_VAL_RNG_DEVICES = [DEVICE] if DEVICE.type == 'cuda' else []

for epoch in range(EPOCHS):
    model.train()
    fourier_encoder.train()
    projection_head.train()
    total_loss = 0.0
    total_recon = 0.0
    total_nce = 0.0
    for imgs, _ in train_loader:
        imgs = imgs.to(DEVICE)
        B = imgs.size(0)

        # Reconstruction: sparse context from view A -> every pixel of the image.
        input_sparse = prepare_sparse_context(imgs, coords_32, fourier_encoder, N_SPARSE, DEVICE)
        target_pixels = rearrange(imgs, 'b c h w -> b (h w) c')
        queries_full = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
        residual_a = get_residual(model, input_sparse)
        reconstructed, _ = get_rgb_and_phi_raw(model, queries_full, residual_a)
        loss_recon = F.mse_loss(reconstructed, target_pixels)

        # Soft InfoNCE: view B is a random affine warp of view A. Features at random
        # anchors in A should match candidates in B near the anchors' warped positions.
        R, t = sample_affine_params(B, DEVICE)
        imgs_b = apply_affine_to_image(imgs, R, t)
        input_sparse_b = prepare_sparse_context(imgs_b, coords_32, fourier_encoder, N_SPARSE, DEVICE)
        residual_b = get_residual(model, input_sparse_b)
        anchors_a = torch.rand(B, N_A, 2, device=DEVICE) * 2 - 1
        candidates_b = torch.rand(B, N_B, 2, device=DEVICE) * 2 - 1
        q_a = fourier_encoder(anchors_a)
        q_b = fourier_encoder(candidates_b)
        _, phi_raw_a = get_rgb_and_phi_raw(model, q_a, residual_a)
        _, phi_raw_b = get_rgb_and_phi_raw(model, q_b, residual_b)
        phi_a = proj_norm(phi_raw_a)
        phi_b = proj_norm(phi_raw_b)
        # Linear warm-up of the NCE weight over the first NCE_RAMP_STEPS optimizer steps.
        lam = LAMBDA_NCE if step[0] >= NCE_RAMP_STEPS else LAMBDA_NCE * (step[0] / NCE_RAMP_STEPS)
        loss_nce = soft_infonce_loss(phi_a, phi_b, anchors_a, candidates_b, R, t, TAU, SIGMA)
        loss = loss_recon + lam * loss_nce
        step[0] += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_recon += loss_recon.item()
        total_nce += loss_nce.item()

    n_train_batches = len(train_loader)
    avg_loss = total_loss / n_train_batches
    avg_recon = total_recon / n_train_batches
    avg_nce = total_nce / n_train_batches

    # Validation: reconstruction MSE only, averaged per image.
    model.eval()
    fourier_encoder.eval()
    projection_head.eval()
    val_sq_err = 0.0
    val_count = 0
    with torch.no_grad(), torch.random.fork_rng(devices=_VAL_RNG_DEVICES):
        torch.manual_seed(VAL_SEED)
        for imgs, _ in test_loader:
            imgs = imgs.to(DEVICE)
            B = imgs.size(0)
            input_sparse = prepare_sparse_context(imgs, coords_32, fourier_encoder, N_SPARSE, DEVICE)
            target_pixels = rearrange(imgs, 'b c h w -> b (h w) c')
            queries_full = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
            reconstructed = model(input_sparse, queries=queries_full)
            # Weight each batch mean by its size so a short final batch is not over-weighted.
            val_sq_err += F.mse_loss(reconstructed, target_pixels).item() * B
            val_count += B
    val_loss = val_sq_err / val_count

    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'fourier_encoder_state_dict': fourier_encoder.state_dict(),
        'projection_head_state_dict': projection_head.state_dict(),
        'best_val_loss': best_val_loss,
    }
    if is_best:
        torch.save(checkpoint, CKPT_BEST)
    torch.save(checkpoint, CKPT_LAST)
    print(
        f'Epoch {epoch+1}/{EPOCHS} loss: {avg_loss:.4f} recon: {avg_recon:.4f} '
        f'NCE: {avg_nce:.4f} val_loss: {val_loss:.4f}' + (' (best)' if is_best else '')
    )


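Done. Checkpoints are written to `checkpoints/`: **checkpoint_sparse_finetune_nce_best.pt** (lowest validation reconstruction loss) and **checkpoint_sparse_finetune_nce_last.pt** (most recent epoch). To evaluate, load one of them in SparseContext_Eval, e.g. by pointing its checkpoint path at `checkpoints/checkpoint_sparse_finetune_nce_best.pt`.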