# Synthetic OmniField Experiments

Runnable proofs for the four properties in **SYNTHETIC_OMNIFIELD_EXPERIMENTS.md**:
1. **Geometry in representation** — correspondence under known warp T
2. **Multi-scale semantics** — probe at coarse vs. fine scale
3. **Decoupling from discretization** — train at one resolution, eval at others + random coords
4. **Continuous correspondence refinement** — gradient ascent on similarity in y

This notebook implements the **synthetic data** and **Experiment 3** first (the easiest); Experiments 1 and 4 are stubbed below, and Experiment 2 is not yet stubbed.

In [None]:
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from einops import rearrange, repeat
import matplotlib.pyplot as plt

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

# Prefer the GPU when one is visible to torch; everything below runs on DEVICE.
_use_cuda = torch.cuda.is_available()
DEVICE = torch.device('cuda') if _use_cuda else torch.device('cpu')
print('Device:', DEVICE)


## Synthetic datasets

In [None]:
def continuous_field_example(h, w, device):
    """Evaluate the synthetic scalar field on an h x w grid over [-1, 1]^2.

    The field is a fixed analytic function (it draws no random numbers):

        f(x, y) = exp(-(x^2 + y^2) / 0.5) + 0.5 * sin(3x) * cos(2y)

    Row index r corresponds to y and column index c to x, both sampled with
    ``linspace(-1, 1, n)`` (endpoints included), so row 0 is y = -1 and
    column 0 is x = -1.

    Returns:
        (h, w) tensor on ``device``.
    """
    # Broadcast a (h, 1) column of y values against a (1, w) row of x values
    # instead of materialising a meshgrid.
    ys = torch.linspace(-1, 1, h, device=device).unsqueeze(1)
    xs = torch.linspace(-1, 1, w, device=device).unsqueeze(0)
    bump = torch.exp(-(xs**2 + ys**2) / 0.5)
    ripple = 0.5 * torch.sin(3 * xs) * torch.cos(2 * ys)
    return bump + ripple

def field_to_rgb(f):
    """Min-max normalise a 2-D scalar field and replicate it into 3 channels.

    OmniField's context is an RGB image, so the (H, W) field is rescaled to
    approximately [0, 1] using its own min and max (the 1e-8 keeps a constant
    field from dividing by zero; it maps to all zeros). The result is then
    broadcast to (3, H, W). It is an ``expand`` view, not a copy.
    """
    lo, hi = f.min(), f.max()
    unit = (f - lo) / (hi - lo + 1e-8)
    return unit[None].expand(3, -1, -1)

class ContinuousFieldDataset(Dataset):
    """Exp 3 dataset: the synthetic field rendered on a fixed (res, res) grid.

    Each item is ``(image, field)``:
      * ``image`` -- (3, res, res) RGB rendering from ``field_to_rgb``, clamped to [0, 1];
      * ``field`` -- the raw (res, res) scalar field, used as ground truth at this resolution.

    NOTE(review): ``__getitem__`` calls ``torch.manual_seed(i)``, which resets the
    *global* torch RNG as a side effect. ``continuous_field_example`` draws no
    random numbers, so every index currently yields the same field.
    """

    def __init__(self, num_samples, res=32, device='cpu'):
        self.num_samples = num_samples
        self.res = res
        self.device = device

    def __len__(self):
        return self.num_samples

    def __getitem__(self, i):
        # Seed per index so an item can be regenerated later with the same seed.
        torch.manual_seed(i)
        field = continuous_field_example(self.res, self.res, self.device)
        image = field_to_rgb(field).clamp(0, 1)
        return image, field


In [None]:
def sample_affine_warp(device, scale_range=(0.85, 1.0), max_translate=0.1, max_angle_deg=10):
    """Draw a random similarity transform T(p) = R @ p + t.

    Four uniforms are drawn from torch's RNG on ``device``, always in the
    order angle, scale, tx, ty, so the result is reproducible under a seed.

    Args:
        device: device used for the RNG draws and for the returned tensors.
        scale_range: (low, high) bounds of the isotropic scale factor.
        max_translate: each translation component lies in [-max_translate, max_translate).
        max_angle_deg: the rotation angle lies in [-max_angle_deg, max_angle_deg) degrees.

    Returns:
        ``(R, t)`` where R is the (2, 2) matrix ``scale * rotation(angle)`` and
        t is the (2,) translation.
    """
    def draw():
        return torch.rand(1, device=device).item()

    lo, hi = scale_range[0], scale_range[1]
    angle = (draw() * 2 - 1) * (max_angle_deg * math.pi / 180)
    scale = lo + (hi - lo) * draw()
    tx = (draw() * 2 - 1) * max_translate
    ty = (draw() * 2 - 1) * max_translate

    cos_s = math.cos(angle) * scale
    sin_s = math.sin(angle) * scale
    R = torch.tensor([[cos_s, -sin_s], [sin_s, cos_s]], device=device)
    t = torch.tensor([tx, ty], device=device)
    return R, t

def apply_warp_to_coords(coords, R, t):
    """Map points through the affine warp T(p) = R @ p + t.

    Points are row vectors, so the transform is applied as ``coords @ R^T + t``.

    Supported shapes:
      * unbatched: coords (..., 2), R (2, 2), t (2,);
      * batched:   coords (B, N, 2) or (N, 2), R (B, 2, 2), t (B, 2), e.g. the
        stacked R/t a DataLoader collates from WarpedPairsDataset items.

    Returns:
        Warped coordinates. The shape is that of ``coords`` in the unbatched case,
        and (B, N, 2) in the batched case.
    """
    # transpose(-1, -2) rather than .T: on a 3-D R, .T reverses *all* dims
    # (and is deprecated), which would scramble batched matrices.
    R_t = R.transpose(-1, -2)
    if t.dim() > 1:
        # (B, 2) -> (B, 1, 2) so each batch's translation broadcasts over its N points.
        t = t.unsqueeze(-2)
    return coords @ R_t + t

def apply_warp_to_image(images, R, t):
    """Resample ``images`` so that output(T(p)) = input(p), with T(p) = R @ p + t.

    ``F.affine_grid`` expects the *inverse* map (output coords -> input coords),
    i.e. p = R^{-1} (q - t), so each theta is ``[R^{-1} | -R^{-1} t]``.

    Args:
        images: (N, C, H, W) tensor.
        R: a (2, 2) matrix shared by all images, or per-sample (B, 2, 2) with B in {1, N}.
        t: a (2,) translation, or (B, 2) matching R.

    Returns:
        (N, C, H, W) warped images.

    Coordinates follow ``affine_grid``'s normalised [-1, 1] (x, y) convention
    with ``align_corners=True``. Samples that fall outside the image take the
    border value.

    Raises:
        RuntimeError: if a batched R/t has B not equal to 1 or N.
    """
    # Promote the unbatched form to a batch of one so both forms share one path.
    if R.dim() == 2:
        R = R.unsqueeze(0)
    if t.dim() == 1:
        t = t.unsqueeze(0)
    R_inv = torch.inverse(R)                      # (B, 2, 2)
    shift = -(R_inv @ t.unsqueeze(-1))            # (B, 2, 1)
    theta = torch.cat([R_inv, shift], dim=-1)     # (B, 2, 3)
    # Broadcast a single transform over all N images; a B that is not 1 or N raises here.
    theta = theta.expand(images.size(0), -1, -1)
    grid = F.affine_grid(theta, images.size(), align_corners=True)
    return F.grid_sample(images, grid, mode='bilinear', padding_mode='border', align_corners=True)

class WarpedPairsDataset(Dataset):
    """For Exp 1 & 4: items ``(image_A, image_B, R, t)`` with ``image_B = warp(image_A)``.

    ``image_A`` is the (3, res, res) RGB rendering of the synthetic field.
    ``(R, t)`` is a random similarity transform from ``sample_affine_warp``:
    R is (2, 2) and t is (2,). ``image_B`` is resampled so that
    image_B(T(p)) = image_A(p) with T(p) = R @ p + t, i.e. a point p in A
    corresponds to ``apply_warp_to_coords(p, R, t)`` in B.

    NOTE(review): the correspondence assumes points use ``affine_grid``'s
    normalised (x, y) convention in [-1, 1]. Confirm that this matches
    ``create_coordinate_grid`` before using it for Recall@k.
    NOTE: ``torch.manual_seed(i)`` resets the global torch RNG on every access.
    """

    def __init__(self, num_samples, res=32, device='cpu'):
        self.res = res
        self.device = device
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, i):
        # Seed per index so the warp (and field) for item i are reproducible.
        torch.manual_seed(i)
        f = continuous_field_example(self.res, self.res, self.device)
        img_a = field_to_rgb(f).clamp(0, 1)
        R, t = sample_affine_warp(self.device)
        # Pass R (2, 2) and t (2,) unbatched. apply_warp_to_image's single-transform
        # path builds theta from exactly these shapes; pre-unsqueezed (1, 2, 2)/(1, 2)
        # inputs make its R_inv @ t product a mismatched (2, 2) @ (1, 2) matmul.
        img_b = apply_warp_to_image(img_a.unsqueeze(0), R, t).squeeze(0)
        return img_a, img_b, R, t


## Experiment 3: Decoupling representation from discretization

Train on 32×32; evaluate at 32×32, at 64×64, and at **random sub-pixel coordinates**. OmniField should maintain low error in all three settings, whereas a discrete baseline is only defined on the training grid.

In [None]:
# --- Experiment 3 configuration -------------------------------------------
IMAGE_SIZE = 32
CHANNELS = 3
FOURIER_MAPPING_SIZE = 96
# Presumably a sin and a cos per Fourier frequency -- TODO confirm against GaussianFourierFeatures.
POS_EMBED_DIM = 2 * FOURIER_MAPPING_SIZE
INPUT_DIM = POS_EMBED_DIM + CHANNELS   # presumably RGB value + positional embedding per context token
QUERIES_DIM = POS_EMBED_DIM            # decoder queries are positional embeddings only
LOGITS_DIM = CHANNELS                  # decoder predicts one value per RGB channel

# --- Model ----------------------------------------------------------------
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
# Keep this construction order: the modules presumably draw their initial
# weights from the global RNG, so reordering would change the initialisation.
fourier_encoder = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
model = CascadedPerceiverIO(
    input_dim=INPUT_DIM,
    queries_dim=QUERIES_DIM,
    logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512),
    num_latents=(256, 256, 256),
    decoder_ff=True,
).to(DEVICE)

# --- Data -----------------------------------------------------------------
train_ds = ContinuousFieldDataset(500, res=IMAGE_SIZE, device=DEVICE)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
print('Train batches:', len(train_loader))


In [None]:
# Fourier-feature parameters (if any) are optimised jointly with the Perceiver.
optimizer = torch.optim.Adam(
    [*model.parameters(), *fourier_encoder.parameters()], lr=1e-3
)


def _train_step(batch_imgs):
    """Run one reconstruction step on a (B, 3, H, W) image batch and return the loss value.

    The context is built from the full training grid ``coords_32``, and the
    decoder is queried at those same coordinates, so the target is simply the
    image flattened to (B, H*W, 3).
    NOTE(review): the flattening assumes ``create_coordinate_grid`` enumerates
    pixels in row-major (h, w) order -- confirm.
    """
    batch_imgs = batch_imgs.to(DEVICE)
    n = batch_imgs.size(0)
    context, _, _ = prepare_model_input(batch_imgs, coords_32, fourier_encoder)
    query_feats = fourier_encoder(coords_32.unsqueeze(0).expand(n, -1, -1))
    pred = model(context, queries=query_feats)
    loss = F.mse_loss(pred, rearrange(batch_imgs, 'b c h w -> b (h w) c'))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


for epoch in range(5):
    model.train()
    fourier_encoder.train()
    # The raw scalar field returned by the dataset is not needed for training.
    total = sum(_train_step(batch_imgs) for batch_imgs, _ in train_loader)
    print(f'Epoch {epoch+1} recon loss: {total/len(train_loader):.4f}')


In [None]:
def eval_at_resolution(model, fourier_encoder, coords_train, res_eval, train_ds, num_samples=20, device=DEVICE):
    """Mean reconstruction MSE when decoding on a ``res_eval`` x ``res_eval`` grid.

    The context is always the dataset image at its own (training) resolution,
    encoded with ``coords_train``. Only the decoder queries change resolution.

    Ground truth is the analytic field evaluated on the eval grid, min-max
    normalised with the *context* field's min/max, i.e. the same affine map
    ``field_to_rgb`` applied to the training image. Normalising every eval grid
    by its own extrema shifts the target whenever a grid samples the extrema
    differently, so the score would not measure the function the model was
    trained to reproduce. At ``res_eval`` equal to the training resolution this
    is exactly the training target. Values are intentionally not clamped.

    Args:
        model: decoder called as ``model(input_data, queries=...)``.
        fourier_encoder: positional encoder for both context and query coordinates.
        coords_train: coordinate grid matching the dataset image resolution.
        res_eval: side length of the evaluation grid.
        train_ds: dataset returning ``(image, raw_field)`` items.
        num_samples: number of leading dataset items to evaluate.
        device: device on which evaluation runs.

    Returns:
        The MSE averaged over samples (0.0 if ``num_samples`` <= 0).
    """
    model.eval()
    fourier_encoder.eval()
    coords_eval = create_coordinate_grid(res_eval, res_eval, device)
    mse_sum, n = 0.0, 0
    with torch.no_grad():
        # The queries depend only on the eval grid, so encode them once.
        queries = fourier_encoder(coords_eval.unsqueeze(0))
        for i in range(num_samples):
            img, f_ctx = train_ds[i]
            img = img.unsqueeze(0).to(device)
            f_ctx = f_ctx.to(device)
            input_data, _, _ = prepare_model_input(img, coords_train, fourier_encoder)
            recon = model(input_data, queries=queries)
            # Same seed as the dataset item, so the GT matches the sample if the field is ever randomised.
            torch.manual_seed(i)
            gt_at_eval = continuous_field_example(res_eval, res_eval, device)
            lo, hi = f_ctx.min(), f_ctx.max()
            gt_norm = (gt_at_eval - lo) / (hi - lo + 1e-8)
            # Row-major flatten + replicate to 3 channels == rearrange('b c h w -> b (h w) c') of the RGB image.
            target = gt_norm.reshape(1, -1, 1).expand(-1, -1, 3)
            mse_sum += F.mse_loss(recon, target).item()
            n += 1
    return mse_sum / max(n, 1)

def eval_at_random_coords(model, fourier_encoder, coords_train, train_ds, num_points=500, num_samples=20, device=DEVICE, gt_res=128):
    """Mean MSE when querying the decoder at random (sub-pixel) coordinates.

    For each sample, ``num_points`` coordinates are drawn uniformly in [-1, 1]^2.
    The ground truth is bilinearly interpolated from the analytic field
    evaluated on a ``gt_res`` x ``gt_res`` grid. It is normalised with the
    *context* field's min/max, the same affine map ``field_to_rgb`` applied to
    the training image. (Normalising by the min/max of the random samples
    themselves would give a different, draw-dependent target.)

    Args:
        model, fourier_encoder, coords_train, train_ds, device: as in ``eval_at_resolution``.
        num_points: number of random query coordinates per sample.
        num_samples: number of leading dataset items to evaluate.
        gt_res: resolution of the dense grid the GT is interpolated from.

    Returns:
        The MSE averaged over samples (0.0 if ``num_samples`` <= 0).
    """
    model.eval()
    fourier_encoder.eval()
    mse_sum, n = 0.0, 0
    with torch.no_grad():
        for i in range(num_samples):
            img, f_ctx = train_ds[i]
            img = img.unsqueeze(0).to(device)
            f_ctx = f_ctx.to(device)
            input_data, _, _ = prepare_model_input(img, coords_train, fourier_encoder)
            coords_rand = torch.rand(num_points, 2, device=device) * 2 - 1
            queries = fourier_encoder(coords_rand.unsqueeze(0))
            recon = model(input_data, queries=queries)
            torch.manual_seed(i)
            f_hr = continuous_field_example(gt_res, gt_res, device)
            # grid_sample reads grid[..., 0] as x (width) and grid[..., 1] as y (height),
            # matching continuous_field_example's layout (columns = x, rows = y).
            # NOTE(review): this assumes the model/create_coordinate_grid also treat
            # coords as (x, y) in [-1, 1] -- confirm.
            sample_grid = coords_rand.view(1, num_points, 1, 2)
            f_at_rand = F.grid_sample(
                f_hr[None, None], sample_grid, mode='bilinear', align_corners=True
            ).reshape(num_points)
            lo, hi = f_ctx.min(), f_ctx.max()
            f_norm = (f_at_rand - lo) / (hi - lo + 1e-8)
            target = f_norm.view(1, num_points, 1).expand(1, num_points, 3)
            mse_sum += F.mse_loss(recon, target).item()
            n += 1
    return mse_sum / max(n, 1)


In [None]:
# Evaluate at the training resolution (32), at double resolution (64), and at random sub-pixel coordinates.
mse_32, mse_64 = (
    eval_at_resolution(model, fourier_encoder, coords_32, res, train_ds) for res in (32, 64)
)
mse_rand = eval_at_random_coords(model, fourier_encoder, coords_32, train_ds)
for label, value in (
    ('MSE @ 32x32:', mse_32),
    ('MSE @ 64x64:', mse_64),
    ('MSE @ random coords:', mse_rand),
):
    print(label, value)
print('Decoupling: model was trained only at 32x32 but can query at 64x64 and arbitrary points.')


## Experiment 1 (stub): Geometry in representation

Use `WarpedPairsDataset`: encode A and B; for each anchor x in A compute φ_A(x), and compute φ_B(y) for y on a grid in B; rank the candidates y by similarity to φ_A(x) and measure Recall@k, where the true match is T(x). See SYNTHETIC_OMNIFIELD_EXPERIMENTS.md.

In [None]:
# TODO: load warped pairs; get_residual + projection head for phi; compute rank/Recall@k
warped_ds = WarpedPairsDataset(100, res=IMAGE_SIZE, device=DEVICE)
first_pair = warped_ds[0]
# image_A, image_B (3, H, W); R (2, 2) and t (2,) from sample_affine_warp.
img_a, img_b, R, t = first_pair
print('Warped pair shapes:', img_a.shape, img_b.shape, 'R', R.shape, 't', t.shape)


## Experiment 4 (stub): Continuous correspondence refinement

For an anchor x in A, maximize S(y) = similarity(φ_A(x), φ_B(y)) over continuous y by gradient ascent. This requires `y.requires_grad_(True)` and backpropagating through the decoder to y. See SYNTHETIC_OMNIFIELD_EXPERIMENTS.md.

In [None]:
# TODO: differentiable y; gradient ascent; report error before/after refinement
stub_note = (
    'Stub: implement refinement loop with decoder(queries=GFF(y), '
    'context=residual_b), y.requires_grad_(True)'
)
print(stub_note)
