# Metric-Structure Learning on OmniField (CIFAR-10)

The same **P1–P5** metric-structure add-ons as in `MetricStructure_CIFAR10.ipynb`, but built on the pretrained, better-performing **OmniField** model (checkpoint from `AblationCIFAR10.ipynb`, architecture from `nf_feature_models.py`) instead of the Linear/ImplicitMLP baseline. The field representation **φ** is the decoder hidden state (just before the RGB logits) at the query coordinates.

- **P1** Coordinate canonicalization (per-image code g from the mean-pooled encoder latents)
- **P2** Multi-scale feature heads on φ
- **P3** Invariance under coordinate jitter
- **P4** Soft InfoNCE (optional; not implemented in this notebook)
- **P5** Cycle-consistency (optional; not implemented in this notebook)


In [None]:
import math
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import transforms
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from einops import rearrange, repeat

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Prefer reproducible cuDNN kernels over the fastest ones.
    torch.backends.cudnn.deterministic = True

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_DIR = 'checkpoints'
# Use the best checkpoint if present, otherwise fall back to the last one.
CKPT_PATH = os.path.join(CHECKPOINT_DIR, 'checkpoint_best.pt')
if not os.path.isfile(CKPT_PATH):
    CKPT_PATH = os.path.join(CHECKPOINT_DIR, 'checkpoint_last.pt')
# Explicit raise instead of `assert`: asserts are stripped under `python -O`.
if not os.path.isfile(CKPT_PATH):
    raise FileNotFoundError(f'No checkpoint in {CHECKPOINT_DIR}. Train AblationCIFAR10 first.')
print(f'Device: {DEVICE}')


In [None]:
def sample_gt_at_coords(images, coords):
    '''Bilinearly sample ground-truth pixels at continuous coordinates.

    images: (B, C, H, W) tensor.
    coords: (B, N, 2) tensor in [-1, 1], ordered (y, x).
    Returns a (B, N, C) tensor (C = 3 for RGB). Coordinates outside [-1, 1]
    are clamped to the border; with align_corners=True, -1/+1 land on the
    centres of the outermost pixels.
    '''
    batch, _channels, _height, _width = images.shape
    num_pts = coords.shape[1]
    # grid_sample expects (x, y) ordering and a (B, H_out, W_out, 2) grid.
    xy = coords[..., [1, 0]]
    sample_grid = xy.view(batch, 1, num_pts, 2)
    picked = F.grid_sample(
        images,
        sample_grid,
        mode='bilinear',
        padding_mode='border',
        align_corners=True,
    )
    # (B, C, 1, N) -> (B, C, N) -> (B, N, C)
    return picked.squeeze(2).permute(0, 2, 1)

def make_grid_2d(h, w, device):
    '''Return an (h*w, 2) tensor of (y, x) coordinates spanning [-1, 1].

    Points are row-major: x varies fastest, then y.
    '''
    ys = torch.linspace(-1, 1, h, device=device)
    xs = torch.linspace(-1, 1, w, device=device)
    yy, xx = torch.meshgrid(ys, xs, indexing='ij')
    pairs = torch.stack((yy, xx), dim=-1)
    return pairs.reshape(-1, 2)


In [None]:
# OmniField config (must match AblationCIFAR10 / the checkpoint)
IMAGE_SIZE = 32
CHANNELS = 3
FOURIER_MAPPING_SIZE = 96
POS_EMBED_DIM = FOURIER_MAPPING_SIZE * 2  # sin + cos per Fourier feature
INPUT_DIM = CHANNELS + POS_EMBED_DIM
QUERIES_DIM = POS_EMBED_DIM
LOGITS_DIM = CHANNELS

fourier_encoder = GaussianFourierFeatures(in_features=2, mapping_size=FOURIER_MAPPING_SIZE, scale=15.0).to(DEVICE)
model = CascadedPerceiverIO(
    input_dim=INPUT_DIM,
    queries_dim=QUERIES_DIM,
    logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512),
    num_latents=(256, 256, 256),
    decoder_ff=True,
).to(DEVICE)

ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
# strict=False tolerates key mismatches. A silently partial (or empty) load would
# make every PSNR below meaningless, so report anything that did not line up.
_model_load = model.load_state_dict(ckpt['model_state_dict'], strict=False)
_fe_load = fourier_encoder.load_state_dict(ckpt['fourier_encoder_state_dict'], strict=False)
for _label, _res in (('model', _model_load), ('fourier_encoder', _fe_load)):
    if _res.missing_keys or _res.unexpected_keys:
        print(f'WARNING: {_label} checkpoint mismatch -- '
              f'missing={_res.missing_keys}, unexpected={_res.unexpected_keys}')
model.eval()
fourier_encoder.eval()

# Full coordinate grid used to build the encoder context for every image.
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
print(f'Loaded {CKPT_PATH}')


In [None]:
def get_residual(model, data):
    '''Run only the encoder and latent self-attention stages of OmniField.

    Each encoder block is called as block(x=..., context=data, mask=None,
    residual=...). The first block receives None for x and residual; every
    later block receives the previous block's output for both. Each entry of
    model.self_attn_blocks is applied as two residual sub-layers, [0] then [1].
    Returns the final latent array (per the original note,
    (B, num_latents, latent_dim)).
    '''
    latents = None
    for enc_block in model.encoder_blocks:
        latents = enc_block(x=latents, context=data, mask=None, residual=latents)
    for pair in model.self_attn_blocks:
        latents = latents + pair[0](latents)
        latents = latents + pair[1](latents)
    return latents

def get_rgb_and_phi(model, queries, residual):
    '''Run the OmniField decoder on query embeddings.

    Applies cross-attention from the queries to `residual` (with a skip
    connection), then the optional residual feed-forward. Returns
    (rgb, phi): phi is the hidden state just before the output projection
    (the field representation used for metric learning) and
    rgb = model.to_logits(phi).
    '''
    hidden = queries + model.decoder_cross_attn(queries, context=residual)
    feed_forward = model.decoder_ff
    if feed_forward is not None:
        hidden = hidden + feed_forward(hidden)
    return model.to_logits(hidden), hidden

def omnifield_forward(images, coords, model, fourier_encoder, coords_full, device):
    '''Encode images with OmniField and decode at arbitrary query coordinates.

    Args:
        images: (B, C, H, W) batch used as encoder context.
        coords: query coordinates, either per-image (B, N, 2) or one shared
            (N, 2) set that is broadcast over the batch.
        model: CascadedPerceiverIO-style model (see get_residual / get_rgb_and_phi).
        fourier_encoder: maps coordinates to query embeddings.
        coords_full: full grid used to build the encoder context (e.g. (32*32, 2)).
        device: unused; kept so existing call sites keep working.

    Returns:
        rgb (B, N, 3), phi (B, N, QUERIES_DIM), residual (B, L, D).
    '''
    input_data, _, _ = prepare_model_input(images, coords_full, fourier_encoder)
    residual = get_residual(model, input_data)
    if coords.ndim == 2:
        # Shared (N, 2) query set: broadcast to every image in the batch.
        coords = coords.unsqueeze(0).expand(residual.size(0), -1, -1)
    queries = fourier_encoder(coords)
    rgb, phi = get_rgb_and_phi(model, queries, residual)
    return rgb, phi, residual


In [None]:
class Canonicalizer(nn.Module):
    '''P1: per-sample diagonal-affine warp of query coordinates.

    A small MLP maps a code g to a per-axis scale in (0.1, 1.9) and a
    translation (raw output * 0.1). It applies
    x_canon = x @ diag(scale) + shift.

    The last layer is zero-initialised so the warp starts as the identity:
    sigmoid(0) * 1.8 + 0.1 == 1 and shift == 0. With a random init, an
    untrained canonicaliser would already distort the coordinates fed to the
    pretrained decoder, including in any "before training" evaluation.
    '''

    def __init__(self, code_dim=16, coord_dim=2):
        super().__init__()
        self.coord_dim = coord_dim
        self.mlp = nn.Sequential(
            nn.Linear(code_dim, 32),
            nn.ReLU(),
            nn.Linear(32, coord_dim * 2),
        )
        nn.init.zeros_(self.mlp[-1].weight)
        nn.init.zeros_(self.mlp[-1].bias)

    def forward(self, coords, g):
        '''coords (B, N, coord_dim), g (B, code_dim). Returns canonical coords (B, N, coord_dim).'''
        d = self.coord_dim
        params = self.mlp(g)
        # First d outputs -> per-axis scale; next d outputs -> translation.
        scale = torch.sigmoid(params[:, :d]) * 1.8 + 0.1
        shift = params[:, d:2 * d] * 0.1
        A = torch.diag_embed(scale)
        return torch.einsum('bnd,bde->bne', coords, A) + shift.unsqueeze(1)

class FeatureHeads(nn.Module):
    '''P2: multi-scale feature heads on the field representation phi.

    Each head receives phi concatenated with a sinusoidal positional encoding
    whose highest frequency is that head's band limit (one entry of
    pe_dims, e.g. 8/16/32). It outputs an L2-normalised head_dim embedding.
    '''

    def __init__(self, latent_dim, pe_dims=(8, 16, 32), head_dim=32, coord_dim=2):
        super().__init__()
        self.heads = nn.ModuleList()
        for pe_freq in pe_dims:
            # sin and cos of `pe_freq` frequencies on each coordinate axis.
            pe_size = 2 * coord_dim * pe_freq
            self.heads.append(nn.Sequential(
                nn.Linear(latent_dim + pe_size, 128),
                nn.ReLU(),
                nn.Linear(128, head_dim),
            ))
        self.head_dim = head_dim
        self.pe_dims = pe_dims
        self.coord_dim = coord_dim

    def _pe(self, coords, max_freq):
        '''Sinusoidal encoding: coords (B, N, D) -> (B, N, 2 * D * max_freq).

        Uses `max_freq` log-spaced frequencies from 1 up to `max_freq`.
        Frequencies must stay bounded: 2**k growth (up to 2**32 for the fine
        head) exceeds float32 precision once multiplied by pi * coord. The
        resulting features are effectively noise and far above the Nyquist
        limit of a 32-pixel image.
        '''
        b, n, _ = coords.shape
        top = math.log2(max(max_freq, 1))
        freqs = 2.0 ** torch.linspace(0.0, top, max_freq, device=coords.device, dtype=coords.dtype)
        x = coords.unsqueeze(-1) * freqs
        out = torch.cat([torch.sin(math.pi * x), torch.cos(math.pi * x)], dim=-1)
        return out.reshape(b, n, -1)

    def forward(self, z, coords):
        '''z (B,N,D), coords (B,N,coord_dim). Returns list of (B,N,head_dim) L2-normalized.'''
        out = []
        for head, mf in zip(self.heads, self.pe_dims):
            pe = self._pe(coords, mf)
            feat = head(torch.cat([z, pe], dim=-1))
            out.append(F.normalize(feat, dim=-1))
        return out


In [None]:
# Add-on experiment configuration.
cfg = dict(
    subset_size=10000,
    batch_size=32,
    coord_samples=512,   # random query points per image per step
    epochs_addon=3,
    lr=1e-3,
    P1=True, P2=True, P3=True, P4=False, P5=False,
    lambda_recon=1.0, lambda_inv=0.1, lambda_infonce=0.1, lambda_cycle=0.1,
    freeze_omnifield=True,
)

# Images as [0, 1] float tensors.
transform = transforms.Compose([transforms.ToTensor()])

train_ds = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# Train on the first `subset_size` images only.
n_train = min(cfg['subset_size'], len(train_ds))
train_sub = Subset(train_ds, list(range(n_train)))
train_loader = DataLoader(train_sub, batch_size=cfg['batch_size'], shuffle=True, num_workers=0)

val_ds = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)
print(f'Train batches: {len(train_loader)}, Val batches: {len(val_loader)}')


In [None]:
class OmniFieldMetricWrapper(nn.Module):
    '''Wraps pretrained OmniField + optional P1 canonicalizer and P2 feature heads.

    `model` and `fourier_encoder` are stored as attributes of this nn.Module,
    so they are registered submodules. wrapper.train()/eval() also switch
    their modes, and wrapper.parameters() includes their weights (exclude
    them from training via requires_grad).
    '''

    def __init__(self, model, fourier_encoder, coords_full, cfg, device, residual_dim=512, code_dim=16):
        super().__init__()
        self.model = model
        self.fourier_encoder = fourier_encoder
        self.coords_full = coords_full
        self.cfg = cfg
        self.device = device
        # Must equal the width of the model's final latent stage
        # (512 for latent_dims=(256, 384, 512)).
        self.residual_dim = residual_dim
        self.queries_dim = QUERIES_DIM
        use_p1 = bool(cfg.get('P1'))
        self.canon = Canonicalizer(code_dim=code_dim, coord_dim=2) if use_p1 else None
        self.g_proj = nn.Linear(self.residual_dim, code_dim) if use_p1 else None
        self.heads = (
            FeatureHeads(self.queries_dim, pe_dims=(8, 16, 32), head_dim=32)
            if cfg.get('P2') else None
        )

    def forward(self, images, coords):
        '''images (B,C,H,W), coords (B,N,2) in [-1,1].

        Returns (rgb, phi, phi_list, x):
        - rgb: decoder output at the queried coordinates;
        - phi: decoder hidden state;
        - phi_list: per-head normalised features ([] without P2);
        - x: the coordinates actually queried (canonicalised when P1 is on).
        '''
        input_data, _, _ = prepare_model_input(images, self.coords_full, self.fourier_encoder)
        residual = get_residual(self.model, input_data)
        x = coords
        if self.canon is not None and self.g_proj is not None:
            # P1: per-image code from the latents mean-pooled over dim 1.
            g = self.g_proj(residual.mean(1))
            x = self.canon(coords, g)
        queries = self.fourier_encoder(x)
        rgb, phi = get_rgb_and_phi(self.model, queries, residual)
        phi_list = self.heads(phi, x) if self.heads is not None else []
        return rgb, phi, phi_list, x

wrapper = OmniFieldMetricWrapper(model, fourier_encoder, coords_32, cfg, DEVICE).to(DEVICE)
if cfg.get('freeze_omnifield'):
    # Freeze the pretrained backbone and its Fourier encoder; only the add-ons train.
    model.requires_grad_(False)
    fourier_encoder.requires_grad_(False)
params = list(filter(lambda prm: prm.requires_grad, wrapper.parameters()))
n_trainable = sum(prm.numel() for prm in params)
print(f'Trainable params: {n_trainable}')


In [None]:
def eval_psnr_omnifield(wrapper, loader, device, grid_size=32):
    '''Return the PSNR (dB) of wrapper reconstructions over `loader`.

    Queries a grid_size x grid_size grid per image. Compares against ground
    truth bilinearly sampled at the coordinates the wrapper actually queried
    (after P1 canonicalisation, if enabled). Peak signal is 1.0, matching
    ToTensor()-scaled inputs. Leaves the wrapper in eval mode.

    NOTE(review): because GT is sampled at the canonicalised coordinates,
    with P1 on this measures fidelity on the warped domain rather than on
    the original pixel grid. Confirm this is the intended metric.
    '''
    # wrapper.eval() recurses into the wrapped OmniField model and Fourier encoder.
    wrapper.eval()
    grid = make_grid_2d(grid_size, grid_size, device).unsqueeze(0)
    sq_err_sum, n_values = 0.0, 0
    with torch.no_grad():
        for imgs, _ in loader:
            imgs = imgs.to(device)
            coords = grid.expand(imgs.size(0), -1, -1)
            rgb, _, _, x = wrapper(imgs, coords)
            gt = sample_gt_at_coords(imgs, x)
            sq_err_sum += F.mse_loss(rgb, gt, reduction='sum').item()
            # The sum above runs over B*N*C values, so count every channel
            # value, not only pixels (otherwise MSE is C times too large).
            n_values += gt.numel()
    mse = sq_err_sum / max(n_values, 1)
    return 10 * math.log10(1.0 / (mse + 1e-10))

print(f'OmniField (no add-ons) val PSNR: {eval_psnr_omnifield(wrapper, val_loader, DEVICE):.2f} dB')


## Train add-ons (P1–P3)


In [None]:
# P4 (Soft InfoNCE) and P5 (cycle-consistency) have no implementation in this
# loop; say so rather than silently ignoring the flags.
for _flag in ('P4', 'P5'):
    if cfg.get(_flag):
        print(f'WARNING: {_flag} is enabled in cfg but not implemented in this loop; ignoring it.')

opt = torch.optim.Adam(params, lr=cfg['lr'])
for ep in range(cfg['epochs_addon']):
    wrapper.train()
    if cfg.get('freeze_omnifield'):
        # wrapper.train() also flips the wrapped backbone (a registered submodule)
        # into train mode. Keep the frozen OmniField in eval mode so any
        # dropout/normalisation layers it contains behave as at inference.
        model.eval()
        fourier_encoder.eval()
    else:
        model.train()
        fourier_encoder.train()
    total = 0.0
    n_batches = 0
    for imgs, _ in train_loader:
        imgs = imgs.to(DEVICE)
        B = imgs.size(0)
        N = cfg['coord_samples']
        # Uniform random query coordinates in [-1, 1]^2.
        coords = torch.rand(B, N, 2, device=DEVICE) * 2 - 1
        rgb, phi, phi_list, x = wrapper(imgs, coords)
        gt = sample_gt_at_coords(imgs, x)
        loss = cfg['lambda_recon'] * F.mse_loss(rgb, gt)
        if cfg.get('P3') and phi_list:
            # P3: penalise drift of the first two head features (lowest PE
            # band limits) under small coordinate jitter.
            jitter = torch.randn_like(coords) * 0.05
            coords_j = (coords + jitter).clamp(-1, 1)
            _, _, phi_j, _ = wrapper(imgs, coords_j)
            for i in range(min(2, len(phi_list))):
                loss = loss + cfg['lambda_inv'] * F.mse_loss(phi_list[i], phi_j[i])
        opt.zero_grad()
        loss.backward()
        opt.step()
        total += loss.item()
        n_batches += 1
    print(f'Add-on epoch {ep+1} loss: {total/max(n_batches, 1):.4f}')

print(f'Full (OmniField + P1+P2+P3) val PSNR: {eval_psnr_omnifield(wrapper, val_loader, DEVICE):.2f} dB')


## Visualizations


In [None]:
# Reconstruct the first 8 validation images on the full 32x32 grid and plot
# GT / reconstruction / absolute error for the first 4 of them.
# NOTE: imgs, B, grid and phi_list defined here are reused by the next cells.
wrapper.eval()
imgs, _ = next(iter(val_loader))
imgs = imgs[:8].to(DEVICE)
B, C, H, W = imgs.shape
grid = make_grid_2d(32, 32, DEVICE).unsqueeze(0).expand(B, -1, -1)
with torch.no_grad():
    rgb, phi, phi_list, x = wrapper(imgs, grid)
# make_grid_2d is row-major (y outer, x inner): (B, 1024, 3) -> (B, 32, 32, 3) -> (B, 3, 32, 32).
rgb = rgb.view(B, 32, 32, 3).permute(0, 3, 1, 2)
recon_vis = rgb.cpu().float()
# Per-image min-max rescale to [0, 1] for display only (the else branch leaves the image unchanged).
for b in range(B):
    rb = recon_vis[b]
    lo, hi = rb.min().item(), rb.max().item()
    if hi > lo:
        recon_vis[b] = (rb - lo) / (hi - lo)
    else:
        recon_vis[b] = rb
recon_vis = recon_vis.clamp(0, 1)
fig, axs = plt.subplots(3, 4, figsize=(12, 9))
for i in range(4):
    axs[0, i].imshow(imgs[i].cpu().permute(1, 2, 0).clamp(0, 1).numpy())
    axs[0, i].set_title('GT')
    axs[0, i].axis('off')
    axs[1, i].imshow(recon_vis[i].permute(1, 2, 0).numpy())
    axs[1, i].set_title('Recon (norm)')
    axs[1, i].axis('off')
    # NOTE(review): error is against the display-rescaled recon, not the raw
    # model output, so it does not reflect the PSNR computed elsewhere. Confirm intent.
    diff = (imgs[i].cpu() - recon_vis[i]).abs().permute(1, 2, 0).numpy()
    axs[2, i].imshow(diff)
    axs[2, i].set_title('|GT - Recon|')
    axs[2, i].axis('off')
plt.suptitle('OmniField + metric add-ons: 32x32 Reconstruction')
plt.tight_layout()
plt.savefig('omnifield_metric_recon.png', dpi=100)
# NOTE(review): matplotlib.use('Agg') is set at import, so plt.show() presumably
# displays nothing inline; the saved PNG is the reliable output.
plt.show()


In [None]:
if phi_list:
    # PCA of the first (lowest-band) head's features over all pixels of the batch.
    z_np = phi_list[0].cpu().numpy()
    z_flat = z_np.reshape(-1, z_np.shape[-1])
    # Centre first; otherwise the leading "component" is just the mean feature direction.
    z_centered = z_flat - z_flat.mean(axis=0, keepdims=True)
    U, S, Vt = np.linalg.svd(z_centered, full_matrices=False)
    # Rows of Vt are the principal directions: project onto the top 3.
    proj = (z_centered @ Vt[:3].T).reshape(B, 32, 32, 3)
    # Robust global rescale to [0, 1] for RGB display.
    p1, p99 = np.percentile(proj, [1, 99])
    proj = np.clip((proj - p1) / (p99 - p1 + 1e-8), 0, 1)
    fig, axs = plt.subplots(1, 4, figsize=(12, 3))
    for i in range(4):
        axs[i].imshow(proj[i])
        axs[i].set_title(f'φ_L PCA #{i+1}')
        axs[i].axis('off')
    plt.suptitle('OmniField φ (coarse head) -> PCA')
    plt.tight_layout()
    plt.show()


In [None]:
if phi_list:
    # Mean cosine drift of the first head's features as query coordinates are jittered.
    jitter_stds = [0.0, 0.02, 0.05, 0.1]
    drifts = []
    imgs4, grid4 = imgs[:4], grid[:4]
    with torch.no_grad():
        _, _, phi0, _ = wrapper(imgs4, grid4)
        for sig in jitter_stds:
            noise = torch.randn_like(grid4) * sig
            _, _, phi_j, _ = wrapper(imgs4, (grid4 + noise).clamp(-1, 1))
            if phi0:
                # Features are L2-normalised, so the dot product is the cosine.
                mean_cos = (phi0[0] * phi_j[0]).sum(-1).mean().item()
                drifts.append(1 - mean_cos)
            else:
                drifts.append(0.0)
    plt.figure(figsize=(5, 3))
    plt.plot(jitter_stds, drifts, 'o-')
    plt.xlabel('Jitter std')
    plt.ylabel('Mean 1 - cos(φ, φ_j)')
    plt.title('Invariance: feature drift vs jitter (OmniField φ)')
    plt.tight_layout()
    plt.show()


## Checklist (P1–P5 on OmniField)


In [None]:
# Summary of which add-ons this notebook actually implements.
checklist = {
    'P1 canonicalizer': 'stable coordinate domain; g from pooled residual',
    'P2 feature heads': 'φ = decoder hidden state; L/M/H band-limited heads',
    'P3 invariance': 'drift under jitter penalized',
    'P4 Soft InfoNCE': 'not implemented in this notebook (setting cfg P4 has no training effect)',
    'P5 cycle': 'not implemented in this notebook (setting cfg P5 has no training effect)',
}
for k, v in checklist.items():
    print(f'- {k} -> {v}')
print('\nBaseline: OmniField (pretrained). Add-ons trained on top; '
      'compare the before/after val PSNR to confirm recon is preserved.')