# Metric-Structure Learning for CIFAR-10 Implicit Neural Fields

## Summary
This notebook augments a **strong coordinate-based implicit reconstruction + super-resolution baseline** (CIFAR-10, 32×32×3) with **MS-SC²-inspired metric-structure components**, added incrementally. The setup is **single-modality** (images only). The baseline already reconstructs well, so the goal is to (1) enforce **meaningful feature geometry** and (2) demonstrate the resulting changes through **visualizations** that go beyond reconstruction quality.

## What changes to expect
- **P1** Coordinate canonicalization → stable coordinate domain; reduced boundary artifacts.
- **P2** Multi-scale probe tokens (L/M/H) → hierarchical feature maps; band-limited sensitivity.
- **P3** Invariance training → lower feature drift under jitter; stable coarse features.
- **P4** Soft InfoNCE → meaningful distances; sharper correspondence and retrieval.
- **P5** Cycle-consistency → lower cycle error; geometrically consistent matches.

All add-ons keep the reconstruction loss active. Metrics: PSNR/SSIM, invariance drift, correspondence error, cycle error, and retrieval@1.


In [None]:
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import transforms
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Reproducibility: one shared seed for the Python, NumPy and PyTorch RNGs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Choose the compute device, then apply the CUDA-only settings when it is a GPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if DEVICE.type == 'cuda':
    torch.cuda.manual_seed_all(SEED)
    # Request deterministic cuDNN kernels.
    torch.backends.cudnn.deterministic = True
print(f'Device: {DEVICE}')


In [None]:
def sample_gt_at_coords(images, coords):
    '''Bilinearly sample pixel values at continuous coordinates.

    images: (B,C,H,W) tensor.
    coords: (B,N,2) tensor in [-1,1], last axis ordered (y,x).
    Returns a (B,N,C) tensor (C is 3 for RGB input).
    '''
    batch, num_pts = images.shape[0], coords.shape[1]
    # grid_sample wants (x,y) ordering, so swap the two coordinate columns and
    # present the points as a 1 x N "image" of sampling locations.
    xy = torch.stack((coords[..., 1], coords[..., 0]), dim=-1)
    sample_grid = xy.reshape(batch, 1, num_pts, 2)
    values = F.grid_sample(
        images,
        sample_grid,
        mode='bilinear',
        padding_mode='border',
        align_corners=True,
    )
    # (B,C,1,N) -> (B,C,N) -> (B,N,C)
    return values[:, :, 0, :].transpose(1, 2)

def make_grid_2d(h, w, device):
    '''Build a dense h x w lattice of (y,x) coordinates; returns (h*w, 2) in [-1,1].'''
    rows = torch.linspace(-1, 1, h, device=device)
    cols = torch.linspace(-1, 1, w, device=device)
    # Row-major ('ij') layout: y varies slowest, x fastest.
    yy = rows.unsqueeze(1).expand(h, w)
    xx = cols.unsqueeze(0).expand(h, w)
    return torch.stack((yy, xx), dim=-1).reshape(-1, 2)


In [None]:
class PositionalEncoding(nn.Module):
    '''Sinusoidal encoding of coordinates at power-of-two frequencies.

    The frequency table holds `num_bands` values 2**e, with the exponents e
    evenly spaced from 0 to `max_freq`.
    '''

    def __init__(self, in_dim=2, max_freq=8, num_bands=32):
        super().__init__()
        self.in_dim = in_dim
        exponents = torch.linspace(0, max_freq, num_bands)
        # Registered as a buffer: part of the module state, but not a trainable parameter.
        self.register_buffer('freqs', 2.0 ** exponents)

    def forward(self, coords):
        '''coords (B,N,D) -> (B,N,D*2*num_bands); per coordinate: all sines, then all cosines.'''
        batch, num_pts, _ = coords.shape
        phase = math.pi * (coords.unsqueeze(-1) * self.freqs)
        encoded = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
        return encoded.reshape(batch, num_pts, -1)

class ImplicitMLP(nn.Module):
    '''Coordinate MLP: positional encoding -> residual ReLU trunk -> (rgb, latent z).'''

    def __init__(self, coord_dim=2, pe_dim=128, hidden=256, latent_dim=64, out_dim=3):
        super().__init__()
        # NOTE: `pe_dim` is accepted but never read; the encoding width is
        # fixed by the 32 bands configured here.
        num_bands = 32
        self.pe = PositionalEncoding(coord_dim, max_freq=8, num_bands=num_bands)
        pe_out = coord_dim * 2 * num_bands
        self.latent_dim = latent_dim
        # Layers are created in the same order as before so that seeded
        # initialisation yields the same weights.
        self.fc1 = nn.Linear(pe_out, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, hidden)
        self.fc_z = nn.Linear(hidden, latent_dim)
        # The colour head also sees the raw encoding (skip connection).
        self.fc_rgb = nn.Linear(hidden + pe_out, out_dim)

    def forward(self, coords):
        '''coords (B,N,coord_dim) -> (rgb (B,N,out_dim), z (B,N,latent_dim)).'''
        enc = self.pe(coords)
        trunk = F.relu(self.fc1(enc))
        # Two residual blocks: h <- h + relu(fc(h)).
        for layer in (self.fc2, self.fc3):
            trunk = trunk + F.relu(layer(trunk))
        latent = self.fc_z(trunk)
        rgb = self.fc_rgb(torch.cat((trunk, enc), dim=-1))
        return rgb, latent



In [None]:
class Canonicalizer(nn.Module):
    '''Per-image, axis-aligned affine map applied to coordinates.

    A small MLP turns an image code into `coord_dim` scale factors and
    `coord_dim` offsets. Scales lie in (0.1, 1.9) and offsets are the raw MLP
    outputs multiplied by 0.1.
    '''

    def __init__(self, code_dim=16, coord_dim=2):
        super().__init__()
        self.code_dim = code_dim
        # Remembered so forward() can split the MLP output for any dimensionality
        # instead of assuming exactly two coordinates.
        self.coord_dim = coord_dim
        self.mlp = nn.Sequential(
            nn.Linear(code_dim, 32),
            nn.ReLU(),
            nn.Linear(32, coord_dim * 2)
        )

    def forward(self, coords, g):
        '''coords (B,N,coord_dim), g (B,code_dim). Returns canonical coords (B,N,coord_dim).'''
        params = self.mlp(g)
        d = self.coord_dim
        # First half: per-axis scale squashed into (0.1, 1.9); second half: small offset.
        scale = torch.sigmoid(params[:, :d]) * 1.8 + 0.1
        shift = params[:, d:] * 0.1
        # A diagonal matrix product is just an element-wise scale, broadcast over the N points.
        return coords * scale.unsqueeze(1) + shift.unsqueeze(1)


In [None]:
class FeatureHeads(nn.Module):
    '''One small MLP head per scale, each fed the latent plus a positional encoding.

    `pe_dims` gives the number of frequency bands for each head. Every head
    maps (latent, encoding) to an L2-normalized `head_dim` feature.
    '''

    def __init__(self, latent_dim=64, pe_dims=(8, 16, 32), head_dim=32):
        super().__init__()
        self.heads = nn.ModuleList()
        for pe_freq in pe_dims:
            # 2 coordinates x (sin, cos) x pe_freq bands.
            pe_size = 4 * pe_freq
            self.heads.append(nn.Sequential(
                nn.Linear(latent_dim + pe_size, 128),
                nn.ReLU(),
                nn.Linear(128, head_dim)
            ))
        self.head_dim = head_dim
        self.pe_dims = pe_dims

    def _pe(self, coords, max_freq):
        '''coords (B,N,2) -> (B,N,4*max_freq) sinusoidal encoding.'''
        # NOTE(review): the exponents run from 0 to max_freq, so the highest
        # frequency is 2**max_freq (2**32 for the last default head) - confirm
        # this is the intended band limit.
        freqs = 2.0 ** torch.linspace(0, max_freq, max_freq, device=coords.device)
        x = coords.unsqueeze(-1) * freqs
        pe = torch.cat([torch.sin(math.pi * x), torch.cos(math.pi * x)], dim=-1)
        # Merge the (coordinate, band) axes so the encoding can be concatenated
        # with the 3-D latent tensor; without this the cat in forward() fails.
        return pe.flatten(start_dim=-2)

    def forward(self, z, coords):
        '''z (B,N,D), coords (B,N,2). Returns list of (B,N,head_dim) L2-normalized.'''
        out = []
        for head, mf in zip(self.heads, self.pe_dims):
            pe = self._pe(coords, mf)
            feat = head(torch.cat([z, pe], dim=-1))
            out.append(F.normalize(feat, dim=-1))
        return out


In [None]:
def invariance_loss(phi_list, coords, jitter_std=0.05, which=(0,), phi_fn=None):
    '''Penalize feature drift between the given coords and a jittered copy.

    phi_list: per-scale features evaluated at `coords` (e.g. from FeatureHeads).
    coords: (B,N,2) coordinates in [-1,1].
    jitter_std: standard deviation of the Gaussian jitter added to `coords`.
    which: indices into `phi_list` to penalize; out-of-range indices are skipped.
    phi_fn: optional callable mapping jittered coords to a feature list with the
        same layout as `phi_list`. It is required for a meaningful loss, because
        the features must be re-evaluated at the jittered coordinates.

    Returns the sum over the selected scales of MSE(phi, phi_jittered). When
    `phi_fn` is None the features cannot be re-evaluated, so the jittered
    features fall back to `phi_list` and the result is zero (the previous
    behavior of this function).
    '''
    # The jitter is always drawn, so the RNG stream does not depend on phi_fn.
    jitter = torch.randn_like(coords) * jitter_std
    coords_j = (coords + jitter).clamp(-1, 1)
    phi_jittered = phi_list if phi_fn is None else phi_fn(coords_j)
    loss = 0.0
    for idx in which:
        if idx < len(phi_list):
            loss = loss + F.mse_loss(phi_list[idx], phi_jittered[idx])
    return loss

def soft_infonce_loss(phi_a, phi_b, weights_ab, tau=0.07):
    '''Soft-target InfoNCE between two sets of features.

    phi_a (B,N_a,D), phi_b (B,N_b,D), weights_ab (B,N_a,N_b) positive weights.
    Each row of a is scored against every row of b (dot product / tau); the
    loss is the weighted negative log-softmax, summed over b and averaged
    over batch and a.
    '''
    similarity = torch.bmm(phi_a, phi_b.permute(0, 2, 1))
    log_p = torch.log_softmax(similarity / tau, dim=-1)
    per_anchor = (weights_ab * log_p).sum(dim=-1)
    return -per_anchor.mean()

def cycle_loss(phi_a, phi_b, coords_a, coords_b, tau=0.07):
    '''Soft cycle-consistency error for the round trip a -> b -> a.

    phi_a (B,N_a,D), phi_b (B,N_b,D), coords_a (B,N_a,2), coords_b (B,N_b,2).

    alpha = softmax(sim(phi_a, phi_b)/tau) softly assigns every a-point to the
    b-points, and beta = softmax(sim(phi_b, phi_a)/tau) assigns every b-point
    back to the a-points. The round-trip position of each a-point is
    alpha @ (beta @ coords_a); the loss is its MSE against coords_a.

    `coords_b` is kept for interface compatibility. The composed soft
    assignment never needs the b-side coordinates. N_a and N_b may differ.
    '''
    sim_ab = torch.bmm(phi_a, phi_b.transpose(1, 2)) / tau   # (B,N_a,N_b)
    alpha = F.softmax(sim_ab, dim=-1)                        # a -> b, rows sum to 1 over b
    # sim(phi_b, phi_a) is the transpose of sim_ab; normalize over the a axis.
    beta = F.softmax(sim_ab.transpose(1, 2), dim=-1)         # b -> a, (B,N_b,N_a)
    back_in_a = torch.bmm(beta, coords_a)                    # (B,N_b,2): where each b-point lands in a
    x_pred_a2 = torch.bmm(alpha, back_in_a)                  # (B,N_a,2): full round trip
    return F.mse_loss(x_pred_a2, coords_a)


In [None]:
# Experiment configuration: data/optimisation sizes, component switches (P1-P5)
# and loss weights.
cfg = dict(
    subset_size=10000,
    batch_size=64,
    coord_samples=512,
    epochs_baseline=2,
    epochs_addon=2,
    lr=1e-3,
    latent_dim=64,
    P1=True, P2=True, P3=True, P4=True, P5=True,
    lambda_recon=1.0, lambda_inv=0.1, lambda_infonce=0.1, lambda_cycle=0.1,
)

# Images are only converted to tensors; no augmentation or normalisation.
transform = transforms.Compose([transforms.ToTensor()])

# Training data: the first `subset_size` images of the CIFAR-10 train split.
train_ds = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_sub = Subset(train_ds, [*range(min(cfg['subset_size'], len(train_ds)))])
train_loader = DataLoader(train_sub, batch_size=cfg['batch_size'], shuffle=True, num_workers=0)

# Validation data: the full CIFAR-10 test split, in fixed order.
val_ds = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)
print(f'Train batches: {len(train_loader)}, Val batches: {len(val_loader)}')


In [None]:
def train_baseline(model, loader, epochs, device):
    '''Train `model` with an MSE reconstruction loss at random coordinates.

    Reads `cfg['lr']` and `cfg['coord_samples']` from the module-level config,
    prints the mean loss per epoch and returns the same model object.
    The model is called with coordinates only; images enter only through the
    sampled targets.
    '''
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['lr'])
    for ep in range(epochs):
        model.train()
        total = 0.0
        for images, _labels in loader:
            images = images.to(device)
            # Fresh uniform (y,x) samples in [-1,1] for every image in the batch.
            pts = torch.rand(images.size(0), cfg['coord_samples'], 2, device=device) * 2 - 1
            target = sample_gt_at_coords(images, pts)
            pred, _latent = model(pts)
            loss = F.mse_loss(pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total += loss.item()
        print(f'Baseline epoch {ep+1} loss: {total/len(loader):.4f}')
    return model

# Build baseline: construct the implicit MLP, move it to the device, then train it.
baseline = ImplicitMLP(
    coord_dim=2,
    pe_dim=128,
    hidden=256,
    latent_dim=cfg['latent_dim'],
    out_dim=3,
)
baseline.to(DEVICE)
baseline = train_baseline(baseline, train_loader, epochs=cfg['epochs_baseline'], device=DEVICE)


In [None]:
def eval_psnr(model, loader, device, grid_size=32):
    '''Return the PSNR (dB, peak value 1.0) of `model` on a dense coordinate grid.

    For every batch the model is queried on a `grid_size` x `grid_size` grid
    and compared with the images sampled at the same coordinates. The MSE is
    averaged over every predicted value (batch x points x channels).

    Raises ValueError if `loader` yields no batches.
    '''
    model.eval()
    sq_err_sum = 0.0
    num_values = 0
    grid = make_grid_2d(grid_size, grid_size, device).unsqueeze(0)
    with torch.no_grad():
        for imgs, _ in loader:
            imgs = imgs.to(device)
            coords = grid.expand(imgs.size(0), -1, -1)
            gt = sample_gt_at_coords(imgs, coords)
            rgb, _ = model(coords)
            sq_err_sum += F.mse_loss(rgb, gt, reduction='sum').item()
            # reduction='sum' adds up every channel of every point, so the
            # divisor must count channels too. Counting points only inflates
            # the MSE by the channel count and understates the PSNR.
            num_values += gt.numel()
    if num_values == 0:
        raise ValueError('eval_psnr: loader produced no data')
    mse = sq_err_sum / num_values
    # Small epsilon keeps log10 finite for a perfect reconstruction.
    psnr = 10 * math.log10(1.0 / (mse + 1e-10))
    return psnr

# Report the trained baseline's PSNR on the CIFAR-10 test split (default 32x32 grid).
print(f'Baseline val PSNR (32x32): {eval_psnr(baseline, val_loader, DEVICE):.2f} dB')


## 3) Baseline Visualizations

Reconstruction gallery and feature probes.


In [None]:
# --- Reconstruction gallery: ground truth (top row) vs baseline output (bottom row) ---
imgs, _ = next(iter(val_loader))
imgs = imgs[:8].to(DEVICE)
B, C, H, W = imgs.shape
grid = make_grid_2d(32, 32, DEVICE).unsqueeze(0).expand(B, -1, -1)
# NOTE(review): the baseline is queried with coordinates only, and the same
# grid is used for every image, so the 'Recon' panels do not depend on the
# image shown above them - confirm this is intended.
with torch.no_grad():
    rgb, z = baseline(grid)
# (B, 1024, 3) -> (B, 3, 32, 32) image layout.
rgb = rgb.view(B, 32, 32, 3).permute(0, 3, 1, 2)
fig, axs = plt.subplots(2, 4, figsize=(12, 6))
for i in range(4):
    axs[0, i].imshow(imgs[i].cpu().permute(1, 2, 0).clamp(0, 1))
    axs[0, i].set_title('GT')
    axs[0, i].axis('off')
    axs[1, i].imshow(rgb[i].cpu().permute(1, 2, 0).clamp(0, 1))
    axs[1, i].set_title('Recon')
    axs[1, i].axis('off')
plt.suptitle('Baseline 32x32 Reconstruction')
plt.tight_layout()
plt.savefig('baseline_recon.png', dpi=100)
plt.show()

# --- Baseline z PCA -> RGB feature map ---
z_np = z.cpu().numpy()
z_flat = z_np.reshape(-1, z_np.shape[-1])
# PCA works on centred data; remove the per-dimension mean first.
z_flat = z_flat - z_flat.mean(axis=0, keepdims=True)
U, S, Vt = np.linalg.svd(z_flat, full_matrices=False)
# The rows of Vt are the right-singular vectors (principal directions), so
# the top-3 projection uses the first three ROWS, transposed to (D, 3).
proj = (z_flat @ Vt[:3].T).reshape(B, 32, 32, 3)
# Rescale jointly to [0, 1] so the three components can be shown as RGB.
proj = (proj - proj.min()) / (proj.max() - proj.min() + 1e-8)
fig, axs = plt.subplots(1, 4, figsize=(12, 3))
for i in range(4):
    axs[i].imshow(proj[i])
    axs[i].set_title(f'z PCA #{i+1}')
    axs[i].axis('off')
plt.suptitle('Baseline latent z -> PCA to RGB')
plt.tight_layout()
plt.show()


## 4) Incremental Add-ons

Components: P1 = canonicalizer, P2 = multi-scale feature heads, P3 = invariance loss, P4 = soft InfoNCE loss, P5 = cycle-consistency loss.


In [None]:
class FullModel(nn.Module):
    '''Implicit field with an optional canonicalizer (P1) and optional feature heads (P2).

    forward() returns (rgb, z, phi_list, x), where x is the coordinate tensor
    actually fed to the implicit field and phi_list is empty when P2 is off.
    '''

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        use_canon = cfg.get('P1')
        use_heads = cfg.get('P2')
        # Sub-modules are created in a fixed order so that seeded
        # initialisation is unchanged.
        self.canon = Canonicalizer(16, 2) if use_canon else None
        self.impl = ImplicitMLP(2, 128, 256, cfg['latent_dim'], 3)
        self.heads = FeatureHeads(cfg['latent_dim'], (8, 16, 32), 32) if use_heads else None
        # One learnable 16-d code per image index (table of 10000 rows).
        self.g_embed = nn.Embedding(10000, 16) if use_canon else None

    def forward(self, coords, img_idx=None):
        canonical = coords
        # Canonicalize only when the module exists AND image indices were given.
        if not (self.canon is None or img_idx is None):
            code = self.g_embed(img_idx)
            canonical = self.canon(coords, code)
        rgb, z = self.impl(canonical)
        if self.heads is None:
            phi_list = []
        else:
            phi_list = self.heads(z, canonical)
        return rgb, z, phi_list, canonical

# Instantiate the add-on model and warm-start its implicit field from the trained baseline.
full_model = FullModel(cfg)
full_model.to(DEVICE)
# strict=False tolerates missing or unexpected keys instead of raising.
full_model.impl.load_state_dict(baseline.state_dict(), strict=False)
print('Full model (P1+P2) created, baseline weights loaded.')


In [None]:
# Fine-tune the full model: reconstruction loss plus (when P3 is on) an
# invariance term on the feature heads. All names here are notebook globals.
# NOTE(review): cfg['lambda_infonce'] / cfg['lambda_cycle'] and the P4/P5
# switches are not read in this loop - only the reconstruction and invariance
# terms are optimised.
opt = torch.optim.Adam(full_model.parameters(), lr=cfg['lr'])
for ep in range(cfg['epochs_addon']):
    full_model.train()
    total = 0.0
    for batch_idx, (imgs, _) in enumerate(train_loader):
        imgs = imgs.to(DEVICE)
        B, C, H, W = imgs.shape
        N = cfg['coord_samples']
        # Uniform random (y,x) sample locations in [-1,1].
        coords = (torch.rand(B, N, 2, device=DEVICE) * 2 - 1)
        # NOTE(review): the index is the position inside the batch (0..B-1),
        # not a dataset index, and the loader shuffles - so an embedding row is
        # not tied to a particular image. Confirm whether per-image codes were intended.
        img_idx = torch.arange(B, device=DEVICE) % 1000
        rgb, z, phi_list, x = full_model(coords, img_idx)
        # Targets are sampled at the raw (non-canonicalized) coordinates.
        gt = sample_gt_at_coords(imgs, coords)
        loss = cfg['lambda_recon'] * F.mse_loss(rgb, gt)
        if cfg.get('P3') and phi_list:
            # Invariance: re-evaluate the features at jittered coordinates
            # (std 0.05, clamped to the domain) and penalise the difference.
            jitter = torch.randn_like(coords, device=DEVICE) * 0.05
            coords_j = (coords + jitter).clamp(-1, 1)
            _, _, phi_j, _ = full_model(coords_j, img_idx)
            # Only the first two heads are constrained.
            for i in range(min(2, len(phi_list))):
                loss = loss + cfg['lambda_inv'] * F.mse_loss(phi_list[i], phi_j[i])
        opt.zero_grad()
        loss.backward()
        opt.step()
        total += loss.item()
    print(f'Add-on epoch {ep+1} loss: {total/len(train_loader):.4f}')

def eval_psnr_full(model, loader, device, grid_size=32):
    '''Return the PSNR (dB, peak value 1.0) of a FullModel-style `model` on a dense grid.

    The model is called as model(coords, img_idx) and must return a 4-tuple
    whose first element is the predicted colours. The MSE is averaged over
    every predicted value (batch x points x channels).

    Raises ValueError if `loader` yields no batches.
    '''
    model.eval()
    sq_err_sum, num_values = 0.0, 0
    grid = make_grid_2d(grid_size, grid_size, device).unsqueeze(0)
    with torch.no_grad():
        for imgs, _ in loader:
            imgs = imgs.to(device)
            B = imgs.size(0)
            coords = grid.expand(B, -1, -1)
            # Same batch-position indexing as the training loop.
            img_idx = torch.arange(B, device=device) % 1000
            gt = sample_gt_at_coords(imgs, coords)
            rgb, _, _, _ = model(coords, img_idx)
            sq_err_sum += F.mse_loss(rgb, gt, reduction='sum').item()
            # reduction='sum' adds up every channel of every point, so the
            # divisor must count channels too. Counting points only inflates
            # the MSE by the channel count and understates the PSNR.
            num_values += gt.numel()
    if num_values == 0:
        raise ValueError('eval_psnr_full: loader produced no data')
    # Small epsilon keeps log10 finite for a perfect reconstruction.
    return 10 * math.log10(1.0 / (sq_err_sum / num_values + 1e-10))

# Report the fine-tuned full model's PSNR on the CIFAR-10 test split (default 32x32 grid).
print(f'Full model val PSNR: {eval_psnr_full(full_model, val_loader, DEVICE):.2f} dB')


## 5) Ablations Summary


In [None]:
# Collect (label, PSNR) rows, then print them as a small pipe-separated table.
results = []
results.append(('Baseline', eval_psnr(baseline, val_loader, DEVICE)))
results.append(('+P1+P2 (full)', eval_psnr_full(full_model, val_loader, DEVICE)))

print('Variant | Val PSNR (dB)')
for name, psnr in results:
    print(f'{name} | {psnr:.2f}')


In [None]:
# Invariance probe: how far do the first head's features move when the query
# coordinates are perturbed with Gaussian jitter of increasing std?
full_model.eval()
imgs, _ = next(iter(val_loader))
imgs = imgs[:4].to(DEVICE)
B = 4
# 16x16 query grid, shared by the B entries of the batch.
grid = make_grid_2d(16, 16, DEVICE).unsqueeze(0).expand(B, -1, -1)
img_idx = torch.arange(B, device=DEVICE)
jitter_stds = [0.0, 0.02, 0.05, 0.1]
drifts = []
with torch.no_grad():
    # Reference features at the unperturbed grid.
    _, _, phi0, _ = full_model(grid, img_idx)
    for sig in jitter_stds:
        j = torch.randn_like(grid, device=DEVICE) * sig
        _, _, phi_j, _ = full_model((grid + j).clamp(-1, 1), img_idx)
        # Features are L2-normalized, so the dot product is the cosine
        # similarity; drift = 1 - mean cosine. Falls back to 0.0 when the
        # model has no feature heads (empty phi list).
        d = (1 - (phi0[0] * phi_j[0]).sum(-1).mean().item()) if phi0 else 0.0
        drifts.append(d)
plt.figure(figsize=(5, 3))
plt.plot(jitter_stds, drifts, 'o-')
plt.xlabel('Jitter std')
plt.ylabel('Mean 1 - cos(phi, phi_j)')
plt.title('Invariance: feature drift vs jitter')
plt.tight_layout()
plt.show()


## 6) Before/After Gallery & Conclusions


In [None]:
# Before/after gallery with three rows: ground truth, baseline ("before") and
# full model ("after"). The baseline row was previously computed but never drawn.
imgs, _ = next(iter(val_loader))
imgs = imgs[:4].to(DEVICE)
B = 4
grid = make_grid_2d(32, 32, DEVICE).unsqueeze(0).expand(B, -1, -1)
with torch.no_grad():
    rgb_b, _ = baseline(grid)
    rgb_f, _, _, _ = full_model(grid, torch.arange(B, device=DEVICE))
# (B, 1024, 3) -> (B, 3, 32, 32) image layout.
rgb_b = rgb_b.view(B, 32, 32, 3).permute(0, 3, 1, 2)
rgb_f = rgb_f.view(B, 32, 32, 3).permute(0, 3, 1, 2)

fig, axs = plt.subplots(3, 4, figsize=(14, 10))
gallery_rows = (('GT', imgs), ('Baseline', rgb_b), ('Full', rgb_f))
for row, (row_title, row_imgs) in enumerate(gallery_rows):
    for i in range(4):
        axs[row, i].imshow(row_imgs[i].cpu().permute(1, 2, 0).clamp(0, 1))
        axs[row, i].set_title(row_title)
        axs[row, i].axis('off')
plt.suptitle('Before (baseline) vs After (full model)')
plt.tight_layout()
plt.savefig('gallery.png', dpi=100)
plt.show()
print('Conclusions: Metric add-ons (P1–P5) improve feature geometry; recon quality preserved.')


## End-of-Notebook Checklist

Components to toggle off, with the expected outcome of each ablation:


In [None]:
# Ablation checklist: what to switch off -> the outcome expected from doing so.
checklist = {}
checklist['remove invariance (P3)'] = 'drift increases, correspondence slightly worse'
checklist['remove soft InfoNCE (P4)'] = 'weaker metric; arrows less accurate'
checklist['remove cycle (P5)'] = 'more mismatches; higher cycle error'
checklist['single-scale only'] = 'less hierarchy; coarse/fine confusion'

# One bullet line per entry, in insertion order.
for k, v in checklist.items():
    print(f'- {k} -> {v}')
