# Semantic Token Generator (Foundation-Field) — CIFAR-10

Treat the neural field as a **continuous, resolution-agnostic token generator**:

> *f(x) → [RGB, φ(x)]*

where **φ(x)** is an L2-normalized semantic feature vector. Each technique is implemented as a separate, individually identifiable component:

1. **Dual output** — RGB (reconstruction) + φ (semantic tokens)
2. **Reconstruction loss** — MSE on RGB
3. **Geometry-aware contrastive** — Soft InfoNCE φ_A(x) ↔ φ_B(T(x)) with known affine T
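4. **Multi-density context sampling** — the encoder context is either the full grid or a random sparse subset of 64/256/512/1024 pixels, with the mode chosen at random per batch (sampling-rate invariance)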
5. **Resolution-agnostic query** — after training, query φ at 16×16, 32×32, 64×64

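Checkpoints (saved under `checkpoints/`): **checkpoint_semantic_token_best.pt** (lowest validation reconstruction MSE) / **checkpoint_semantic_token_last.pt** (most recent epoch)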

In [None]:
import os
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from einops import rearrange, repeat

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

# Pick the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Checkpoints are written here; exist_ok makes re-running this cell harmless.
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print("Device:", DEVICE)

## Config — hyperparameters and technique toggles (`USE_RECON_LOSS`, `USE_CONTRASTIVE`)

In [None]:
IMAGE_SIZE = 32
CHANNELS = 3
FOURIER_MAPPING_SIZE = 96
# Two outputs per Fourier frequency (presumably sin and cos — confirm in GaussianFourierFeatures).
POS_EMBED_DIM = FOURIER_MAPPING_SIZE * 2
INPUT_DIM = CHANNELS + POS_EMBED_DIM  # context token = pixel RGB concatenated with its positional encoding
QUERIES_DIM = POS_EMBED_DIM  # decoder queries are positional encodings only
LOGITS_DIM = CHANNELS  # RGB head output size
BATCH_SIZE = 64
EPOCHS = 30

# 1) Reconstruction
USE_RECON_LOSS = True

# 2) Geometry-aware contrastive (Soft InfoNCE)
USE_CONTRASTIVE = True
NCE_RAMP_STEPS = 500  # contrastive weight ramps linearly from 0 to LAMBDA_NCE over this many steps
LAMBDA_NCE = 0.1
N_ANCHORS, N_CANDIDATES = 256, 1024  # random query points per image in view A / view B
TAU, SIGMA = 0.1, 0.08  # logit temperature / Gaussian width of the soft targets (normalized coords)

# 3) Semantic head (φ output)
PHI_DIM = 128

# 4) Multi-density context sampling: each training batch draws one of these modes at random.
SAMPLING_MODES = ["full", "sparse_64", "sparse_256", "sparse_512", "sparse_1024"]
N_FULL = IMAGE_SIZE * IMAGE_SIZE
# Every sparse budget is clamped to N_FULL, not just sparse_1024. The sparse context is a
# prefix of a permutation of the N_FULL grid points, so a larger budget can never be met;
# clamping keeps SAMPLING_N truthful if IMAGE_SIZE is reduced.
SAMPLING_N = {
    "full": N_FULL,
    "sparse_64": min(64, N_FULL),
    "sparse_256": min(256, N_FULL),
    "sparse_512": min(512, N_FULL),
    "sparse_1024": min(1024, N_FULL),
}

# Full IMAGE_SIZE × IMAGE_SIZE coordinate grid, reused for every full-grid context and query.
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
print("N_FULL:", N_FULL, "| SAMPLING_MODES:", SAMPLING_MODES)

## 1) Field output: RGB + φ (semantic tokens)

The decoder produces one hidden vector per query, which feeds two heads: `to_logits` → **RGB** (reconstruction) and `semantic_head` → **φ** (L2-normalized semantic token).

In [None]:
def sample_gt_at_coords(images, coords):
    """Bilinearly sample ``images`` at normalized coordinates.

    ``images`` is (B, C, H, W). ``coords`` is (B, N, 2) in [-1, 1], and its last axis is
    read as (row, col) = (y, x). It is swapped into the (x, y) order ``F.grid_sample``
    expects. Out-of-range points clamp to the border. Returns a (B, N, C) tensor.
    """
    batch, _, _, _ = images.shape
    num_points = coords.shape[1]
    # (y, x) -> (x, y), laid out as a 1 × N sampling grid per image.
    xy = torch.stack((coords[..., 1], coords[..., 0]), dim=-1)
    grid = xy.view(batch, 1, num_points, 2)
    picked = F.grid_sample(images, grid, mode="bilinear", padding_mode="border", align_corners=True)
    # (B, C, 1, N) -> (B, N, C)
    return picked[:, :, 0, :].transpose(1, 2)

def get_residual(model, context):
    """Run the encoder cascade and then the latent self-attention stack.

    Each encoder block receives the running latents (``None`` for the first block) as both
    ``x`` and ``residual``, cross-attending to ``context``. Each self-attention entry is an
    (attention, feed-forward) pair applied with residual additions. Returns the final
    latents (expected (B, num_latents, latent_dim)). Returns ``None`` if the model has no blocks.
    """
    latents = None
    for encoder_block in model.encoder_blocks:
        latents = encoder_block(x=latents, context=context, mask=None, residual=latents)
    for pair in model.self_attn_blocks:
        attention, feed_forward = pair[0], pair[1]
        latents = attention(latents) + latents
        latents = feed_forward(latents) + latents
    return latents

def decoder_forward(model, queries, residual):
    """Decode ``queries`` against the latent array ``residual``.

    Applies the decoder cross-attention with a residual connection back to the queries,
    then the optional decoder feed-forward (also residual). The output therefore has the
    same feature size as ``queries``.
    """
    hidden = model.decoder_cross_attn(queries, context=residual) + queries
    feed_forward = model.decoder_ff
    if feed_forward is None:
        return hidden
    return hidden + feed_forward(hidden)

def get_rgb(model, queries, residual):
    """Predict colours at each query: decoder output passed through ``model.to_logits``.

    The model is built with ``logits_dim=LOGITS_DIM`` (3), so this is (B, N, 3), the
    reconstruction target.
    """
    return model.to_logits(decoder_forward(model, queries, residual))

def get_phi_raw(model, queries, residual):
    """Return the raw decoder features (before the semantic head and normalization).

    Shape matches ``queries`` because the decoder adds its output back onto them:
    (B, N, QUERIES_DIM).
    """
    features = decoder_forward(model, queries, residual)
    return features

def get_semantic_tokens(phi_raw, semantic_head):
    """Project decoder features with the shared ``semantic_head`` and L2-normalize them.

    The result is φ(x), shape (B, N, PHI_DIM), with unit norm along the last axis.
    """
    projected = semantic_head(phi_raw)
    return F.normalize(projected, p=2.0, dim=-1)

## 2) Context builders — full vs sparse (multi-density)

In [None]:
def prepare_full_context(images, coords_full, fourier_encoder):
    """Build the encoder context from every grid point: (B, N_FULL, INPUT_DIM).

    Delegates to ``prepare_model_input`` and keeps only the first of its three outputs.
    """
    outputs = prepare_model_input(images, coords_full, fourier_encoder)
    model_input, _second, _third = outputs
    return model_input

def prepare_sparse_context(images, coords_full, fourier_encoder, num_sparse, device):
    """Build the encoder context from a random subset of grid points.

    One random subset of ``num_sparse`` coordinates (at most ``coords_full.size(0)``) is
    drawn and shared by every image in the batch. Each context token is the sampled pixel
    colour concatenated with the Fourier encoding of its coordinate:
    (B, num_sparse, INPUT_DIM).
    """
    batch = images.size(0)
    keep = torch.randperm(coords_full.size(0), device=device)[:num_sparse]
    batched_coords = coords_full[keep].unsqueeze(0).expand(batch, -1, -1)
    colours = sample_gt_at_coords(images, batched_coords)
    positions = fourier_encoder(batched_coords)
    return torch.cat((colours, positions), dim=-1)

def sample_context_for_batch(images, coords_full, fourier_encoder, device):
    """Draw a sampling mode uniformly from SAMPLING_MODES and build that context.

    Returns ``(context, num_points)``, where ``num_points`` is ``SAMPLING_N[mode]``.
    """
    mode = random.choice(SAMPLING_MODES)
    num_points = SAMPLING_N[mode]
    if mode != "full":
        context = prepare_sparse_context(images, coords_full, fourier_encoder, num_points, device)
    else:
        context = prepare_full_context(images, coords_full, fourier_encoder)
    return context, num_points

## 3) Geometry-aware contrastive — Soft InfoNCE

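View B is view A warped by a randomly sampled but known affine map T. Each anchor token φ_A(x) is scored against candidate tokens φ_B(y) from view B. The soft targets are a Gaussian kernel on the distance ‖y − T(x)‖ (width σ), normalized over the candidates.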

In [None]:
def sample_affine_params(batch_size, device, scale_range=(0.85, 1.0), max_translate=0.1, max_angle_deg=12):
    """Draw one random similarity transform per batch element.

    Returns ``R`` (batch_size, 2, 2) = scale * rotation(angle) and ``t`` (batch_size, 2).
    angle ~ U(-max_angle_deg, max_angle_deg) degrees, scale ~ U(*scale_range) and each
    translation component ~ U(-max_translate, max_translate). Random draws happen in the
    order angle, scale, tx, ty.
    """
    def _symmetric(bound):
        return (torch.rand(batch_size, device=device) * 2 - 1) * bound

    low, high = scale_range
    angle = _symmetric(max_angle_deg * math.pi / 180)
    scale = low + torch.rand(batch_size, device=device) * (high - low)
    tx = _symmetric(max_translate)
    ty = _symmetric(max_translate)

    cos_scaled = torch.cos(angle) * scale
    sin_scaled = torch.sin(angle) * scale
    top_row = torch.stack([cos_scaled, -sin_scaled], dim=-1)
    bottom_row = torch.stack([sin_scaled, cos_scaled], dim=-1)
    R = torch.stack([top_row, bottom_row], dim=1)
    t = torch.stack([tx, ty], dim=1)
    return R, t

def apply_affine_to_coords(coords, R, t):
    """Map each point p to R p + t, per batch element.

    ``coords`` (B, N, 2), ``R`` (B, 2, 2), ``t`` (B, 2); returns (B, N, 2).
    """
    rotated = coords @ R.transpose(1, 2)
    return rotated + t[:, None, :]

def apply_affine_to_image(images, R, t):
    """Warp ``images`` (B, C, H, W) by T(p) = R p + t, so that B(T(p)) = A(p).

    ``R`` (B, 2, 2) and ``t`` (B, 2) are expressed in the notebook's coordinate order,
    which ``sample_gt_at_coords`` treats as (row, col) = (y, x): it swaps the last axis
    before calling ``F.grid_sample``. ``F.affine_grid`` works in (x, y) order instead.
    The map is therefore conjugated by the axis swap P before theta is built:
    R_xy = P R P and t_xy = P t.

    Without this, the image would be rotated by -angle with tx/ty exchanged relative to
    ``apply_affine_to_coords``, so the soft-InfoNCE targets would point at the wrong
    locations.

    NOTE(review): relies on ``create_coordinate_grid`` producing (y, x)-ordered points, as
    ``sample_gt_at_coords`` assumes. Confirm against nf_feature_models.
    """
    # P R P swaps both rows and columns of R; P t swaps the two components of t.
    R_xy = R.flip(-1).flip(-2)
    t_xy = t.flip(-1)
    R_inv = torch.inverse(R_xy)
    # affine_grid samples the input at theta @ [x, y, 1] for each output pixel. Using
    # [R^-1 | -R^-1 t] reads A at T^-1(p_out), i.e. B(T(p)) = A(p).
    theta = torch.cat([R_inv, -(R_inv @ t_xy.unsqueeze(2))], dim=2)
    grid = F.affine_grid(theta, images.size(), align_corners=True)
    return F.grid_sample(images, grid, mode="bilinear", padding_mode="border", align_corners=True)

def soft_infonce_loss(phi_a, phi_b, coords_a, coords_b, R, t, tau=0.1, sigma=0.08):
    """Geometry-aware soft InfoNCE between anchors in view A and candidates in view B.

    ``phi_a`` (B, Na, D) are anchor tokens at ``coords_a`` (B, Na, 2). ``phi_b``
    (B, Nc, D) are candidate tokens at ``coords_b`` (B, Nc, 2). Each anchor x is mapped
    to T(x) = R x + t in view B.

    The target distribution over candidates is a Gaussian kernel (width ``sigma``) on the
    distance to T(x), normalized over candidates. The loss is the cross-entropy between
    these targets and softmax(phi_a · phi_b / tau), averaged over batch and anchors.
    """
    # Similarity logits between every anchor and every candidate: (B, Na, Nc).
    logits = torch.bmm(phi_a, phi_b.transpose(1, 2)) / tau
    # Where each anchor lands in view B under the known transform.
    coords_a_mapped = apply_affine_to_coords(coords_a, R, t)
    sqd = ((coords_b.unsqueeze(1) - coords_a_mapped.unsqueeze(2)) ** 2).sum(-1)
    # softmax(-d^2 / 2σ^2) equals exp(-d^2 / 2σ^2) / Σ exp(...), but it subtracts the row max
    # first. An anchor mapped far from every candidate (small sigma, or T(x) outside the
    # view) therefore still gets a valid near-one-hot target. A raw exp would underflow
    # to an all-zero row that silently adds nothing to the loss.
    w = torch.softmax(-sqd / (2 * sigma ** 2), dim=-1)
    return -(w * F.log_softmax(logits, dim=-1)).sum(-1).mean()

## Model + semantic head

In [None]:
# Random Fourier positional encoding of 2-D coordinates. Positional args are presumably
# (input_dim=2, mapping_size=FOURIER_MAPPING_SIZE, scale=15.0) — TODO confirm in nf_feature_models.
fourier_encoder = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
# Cascaded Perceiver IO. Context tokens have INPUT_DIM features (pixel RGB + positional
# encoding); decoder queries have QUERIES_DIM features (positional encoding only); to_logits
# emits LOGITS_DIM (= RGB) values. latent_dims / num_latents are presumably per cascade
# stage — TODO confirm.
# NOTE: construction order matters for reproducibility; each constructor may consume RNG state.
model = CascadedPerceiverIO(
    input_dim=INPUT_DIM,
    queries_dim=QUERIES_DIM,
    logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512),
    num_latents=(256, 256, 256),
    decoder_ff=True,
).to(DEVICE)
# Linear projection from decoder features (QUERIES_DIM) into the semantic token space
# (PHI_DIM); get_semantic_tokens L2-normalizes its output.
semantic_head = nn.Linear(QUERIES_DIM, PHI_DIM).to(DEVICE)
print("Model + semantic_head (φ dim =", PHI_DIM, ")")

In [None]:
# CIFAR-10 as float tensors in [0, 1], (C, H, W); no augmentation.
transform = transforms.Compose([transforms.ToTensor()])
# Build the train split first, then the test split (both downloaded on first use).
train_ds, test_ds = (
    torchvision.datasets.CIFAR10(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
print("Train batches:", len(train_loader), "Test batches:", len(test_loader))

## Training loop — reconstruction + contrastive, multi-density sampling

In [None]:
# One Adam optimizer over every trainable piece: Perceiver, Fourier encoder and semantic head.
params = [
    *model.parameters(),
    *fourier_encoder.parameters(),
    *semantic_head.parameters(),
]
optimizer = torch.optim.Adam(params, lr=1e-3)
best_val_loss = float("inf")
# Global step counter (single-element list) that drives the NCE weight ramp.
step = [0]
CKPT_BEST = os.path.join(CHECKPOINT_DIR, "checkpoint_semantic_token_best.pt")
CKPT_LAST = os.path.join(CHECKPOINT_DIR, "checkpoint_semantic_token_last.pt")

In [None]:
if not (USE_RECON_LOSS or USE_CONTRASTIVE):
    # With both objectives disabled, `loss` is a constant tensor with no autograd history,
    # and loss.backward() would raise on the first batch.
    raise ValueError("At least one of USE_RECON_LOSS / USE_CONTRASTIVE must be True.")


def _training_state(epoch_idx):
    """Build the checkpoint dict written to both the best and the last checkpoint files.

    Uses the same keys as before, plus the optimizer state so a run can be resumed.
    Reads the module-level ``best_val_loss`` at call time.
    """
    return {
        "epoch": epoch_idx + 1,
        "model_state_dict": model.state_dict(),
        "fourier_encoder_state_dict": fourier_encoder.state_dict(),
        "semantic_head_state_dict": semantic_head.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "best_val_loss": best_val_loss,
    }


for epoch in range(EPOCHS):
    model.train()
    fourier_encoder.train()
    semantic_head.train()
    total_loss = 0.0
    total_recon = 0.0
    total_nce = 0.0
    for imgs, _ in train_loader:
        imgs = imgs.to(DEVICE)
        B = imgs.size(0)

        # 4) Multi-density sampling: random context density per batch (view A).
        context_a, n_ctx = sample_context_for_batch(imgs, coords_32, fourier_encoder, DEVICE)
        residual_a = get_residual(model, context_a)

        # Reconstruction is always supervised on the full query grid, whatever the context
        # density. Assumes coords_32 is flattened row-major to match "(h w)" — TODO confirm.
        queries_full = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
        target_pixels = rearrange(imgs, "b c h w -> b (h w) c")

        # 1) RGB output + 2) Reconstruction loss
        rgb = get_rgb(model, queries_full, residual_a)
        loss_recon = F.mse_loss(rgb, target_pixels) if USE_RECON_LOSS else torch.tensor(0.0, device=DEVICE)
        loss = loss_recon

        # 3) Geometry-aware contrastive (Soft InfoNCE)
        if USE_CONTRASTIVE:
            # View B: the same images under a known random affine, with its own context density.
            R, t = sample_affine_params(B, DEVICE)
            imgs_b = apply_affine_to_image(imgs, R, t)
            context_b, _ = sample_context_for_batch(imgs_b, coords_32, fourier_encoder, DEVICE)
            residual_b = get_residual(model, context_b)
            # Continuous query points, uniform in [-1, 1]^2 (not restricted to the pixel grid).
            anchors_a = torch.rand(B, N_ANCHORS, 2, device=DEVICE) * 2 - 1
            candidates_b = torch.rand(B, N_CANDIDATES, 2, device=DEVICE) * 2 - 1
            q_a = fourier_encoder(anchors_a)
            q_b = fourier_encoder(candidates_b)
            phi_raw_a = get_phi_raw(model, q_a, residual_a)
            phi_raw_b = get_phi_raw(model, q_b, residual_b)
            phi_a = get_semantic_tokens(phi_raw_a, semantic_head)
            phi_b = get_semantic_tokens(phi_raw_b, semantic_head)
            # Linear warm-up of the contrastive weight over the first NCE_RAMP_STEPS steps.
            lam = LAMBDA_NCE if step[0] >= NCE_RAMP_STEPS else LAMBDA_NCE * (step[0] / NCE_RAMP_STEPS)
            loss_nce = soft_infonce_loss(phi_a, phi_b, anchors_a, candidates_b, R, t, TAU, SIGMA)
            loss = loss + lam * loss_nce
            total_nce += loss_nce.item()
            step[0] += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_recon += loss_recon.item()

    avg_loss = total_loss / len(train_loader)
    avg_recon = total_recon / len(train_loader)

    # Validation: full-grid context, reconstruction MSE only; it selects the "best" checkpoint.
    model.eval()
    fourier_encoder.eval()
    semantic_head.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, _ in test_loader:
            imgs = imgs.to(DEVICE)
            B = imgs.size(0)
            context = prepare_full_context(imgs, coords_32, fourier_encoder)
            residual = get_residual(model, context)
            queries_full = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
            rgb = get_rgb(model, queries_full, residual)
            target_pixels = rearrange(imgs, "b c h w -> b (h w) c")
            val_loss += F.mse_loss(rgb, target_pixels).item()
    val_loss /= len(test_loader)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(_training_state(epoch), CKPT_BEST)
    torch.save(_training_state(epoch), CKPT_LAST)

    nce_str = f" NCE: {total_nce/len(train_loader):.4f}" if USE_CONTRASTIVE else ""
    print(f"Epoch {epoch+1}/{EPOCHS} recon: {avg_recon:.4f}{nce_str} val_loss: {val_loss:.4f}")

## 5) Resolution-agnostic query — extract φ at arbitrary coordinates

Same trained model: encode the 32×32 image once, then query φ on 16×16, 32×32 and 64×64 coordinate grids (no retraining).

In [None]:
def query_semantic_tokens_at_resolution(model, semantic_head, fourier_encoder, residual, res_h, res_w, device):
    """Query φ on a res_h × res_w coordinate grid, reusing an existing encoder pass.

    ``residual`` is the latent array returned by ``get_residual`` (its first axis is the
    batch). Returns L2-normalized semantic tokens, presumably (B, res_h * res_w, PHI_DIM)
    — this depends on the point layout of ``create_coordinate_grid``.
    """
    grid = create_coordinate_grid(res_h, res_w, device)
    batched = grid.unsqueeze(0).expand(residual.size(0), -1, -1)
    features = get_phi_raw(model, fourier_encoder(batched), residual)
    return get_semantic_tokens(features, semantic_head)

model.eval()
fourier_encoder.eval()
semantic_head.eval()
imgs, _ = next(iter(test_loader))
imgs = imgs[:4].to(DEVICE)
with torch.no_grad():
    # Build the context inside no_grad too. fourier_encoder's parameters are handed to the
    # optimizer, so encoding outside this block could record an autograd graph nobody uses.
    context = prepare_full_context(imgs, coords_32, fourier_encoder)
    # One encoder pass; the latents are then reused for every query resolution.
    residual = get_residual(model, context)
    phi_16 = query_semantic_tokens_at_resolution(model, semantic_head, fourier_encoder, residual, 16, 16, DEVICE)
    phi_32 = query_semantic_tokens_at_resolution(model, semantic_head, fourier_encoder, residual, 32, 32, DEVICE)
    phi_64 = query_semantic_tokens_at_resolution(model, semantic_head, fourier_encoder, residual, 64, 64, DEVICE)

print("φ at 16×16:", phi_16.shape, "| 32×32:", phi_32.shape, "| 64×64:", phi_64.shape)
print("Same model, one encoder pass; resolution-agnostic token extraction.")