# Classification comparison: five representations with a small transformer

Compare CIFAR-10 classification using **one small transformer** on five inputs:
1. **Original** – patched image pixels (same transformer, patch embedding).
2. **Sparse (full ctx)** – last latent of sparse-context model with **full** context.
3. **Full baseline** – last latent of full-context baseline.
4. **Full + NCE** – last latent of full-context NCE model.
5. **Sparse + NCE finetuned (full ctx)** – last latent of sparse-context NCE finetune with **full** context.

All neural-field variants use the **last latent** $(B, N, D)$ (encoder + processor output). Same data and same small transformer; epochs kept small.

**Sparse-context-only mode** (`USE_SPARSE_CONTEXT = True` in code): every representation sees only **sparse observations** (no full image). NF models get their input from `prepare_sparse_context` (a random subset of coords), so in this mode variants 2 and 5 above also receive sparse rather than full context; "original" gets the same number of points, each encoded as (x, y, r, g, b). Classification is therefore from partial observations only — sparse-trained models may do better here.

---

**Why does t-SNE favor NCE/sparse while the transformer favors baseline?**

- **t-SNE** reflects *neighborhood / metric structure*; NCE and sparse-context often improve that, so t-SNE looks better for them.

- **Separability** (good linear/transformer accuracy) is often linked to **hierarchical concept learning / abstraction** — and **sparse context** and **masked-style** objectives are *meant* to encourage that (infer whole from parts, more semantic structure). So in principle sparse/MIM could do *better* on separability.

Here the **baseline** can still win on the class probe because: (1) Full-context reconstruction pushes the latent to hold *all* pixel-level information, so it becomes very *information-dense* and easy for a simple head to separate, even if not more "abstract." (2) CIFAR-10 class labels are one particular notion of "concept"; sparse/NCE may be learning abstraction that is more *view- or instance-level* (good for metric/t-SNE) and not aligned with those class boundaries. So we're not claiming "baseline is more abstract" — rather "baseline is more class-informative in a dense, linearly usable way" while sparse/NCE can excel on *metric* structure and possibly different abstraction. Below: **k-NN** checks the metric view; **linear/transformer** check separability.

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from einops import rearrange, repeat

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

# Run on GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_DIR = "checkpoints"  # directory searched for pretrained checkpoints
IMAGE_SIZE = 32
CHANNELS = 3
NUM_CLASSES = 10
FOURIER_MAPPING_SIZE = 96
# presumably sin + cos per Fourier frequency, hence the factor 2 — confirm in GaussianFourierFeatures
POS_EMBED_DIM = FOURIER_MAPPING_SIZE * 2
INPUT_DIM = CHANNELS + POS_EMBED_DIM  # pixel values concatenated with the positional encoding
QUERIES_DIM = POS_EMBED_DIM
LOGITS_DIM = CHANNELS
LATENT_DIM = 512  # passed as d_in for every latent-based representation
NUM_LATENTS = 256  # passed as seq_len for every latent-based representation

# Sparse-context-only mode: all representations see only sparse observations (no full image).
USE_SPARSE_CONTEXT = True
CONTEXT_FRAC = 0.2  # fraction of the coordinate grid kept as context
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
N_FULL = coords_32.size(0)  # number of coordinates in the full grid
N_SPARSE = max(64, int(N_FULL * CONTEXT_FRAC))  # never fewer than 64 context points

def sample_gt_at_coords(images, coords):
    """Bilinearly sample `images` at `coords` and return one value vector per point.

    `images` is unpacked as (B, C, H, W). `coords` is indexed as (B, N, 2); its
    last two components are swapped before being passed to `F.grid_sample`.
    The result has shape (B, N, C).
    """
    batch, _channels, _height, _width = images.shape
    num_points = coords.shape[1]
    # grid_sample receives the last axis in the opposite order from `coords`.
    swapped = coords[..., [1, 0]]
    grid = swapped.view(batch, 1, num_points, 2)
    sampled = F.grid_sample(
        images,
        grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
    # (B, C, 1, N) -> (B, C, N) -> (B, N, C)
    return sampled.squeeze(2).permute(0, 2, 1)

def prepare_sparse_context(images, coords_full, fourier_encoder, num_sparse, device):
    """Build a sparse context from a random subset of `coords_full`.

    One random subset of `num_sparse` coordinates is drawn per call and shared by
    every image in the batch. For each kept coordinate the sampled pixel values
    are concatenated (along the last axis) with `fourier_encoder`'s output for
    that coordinate.
    """
    batch_size = images.size(0)
    keep = torch.randperm(coords_full.size(0), device=device)[:num_sparse]
    # Broadcast the shared subset across the batch without copying.
    batched_coords = coords_full[keep].unsqueeze(0).expand(batch_size, -1, -1)
    pixel_values = sample_gt_at_coords(images, batched_coords)
    positional = fourier_encoder(batched_coords)
    return torch.cat((pixel_values, positional), dim=-1)

def prepare_full_context(images, coords_full, fourier_encoder):
    """Return only the first of the three values produced by `prepare_model_input`."""
    # The helper returns a 3-tuple; only the first element (the model input) is needed here.
    context, _second, _third = prepare_model_input(images, coords_full, fourier_encoder)
    return context

def get_residual(model, data):
    """Run the encoder blocks, then the self-attention blocks, and return the latents.

    Expected result: (B, num_latents, latent_dim). The first encoder block is
    called with `x=None` and `residual=None`; each later block receives the
    previous block's output for both arguments.
    """
    latents = None
    for encoder_block in model.encoder_blocks:
        latents = encoder_block(x=latents, context=data, mask=None, residual=latents)

    for pair in model.self_attn_blocks:
        # presumably (attention, feed-forward) — each sub-layer is applied with a skip connection
        first_layer, second_layer = pair[0], pair[1]
        latents = first_layer(latents) + latents
        latents = second_layer(latents) + latents
    return latents

# One-line summary of the run configuration; N_SPARSE is only meaningful in sparse mode.
sparse_count_label = N_SPARSE if USE_SPARSE_CONTEXT else "N/A"
print("Device:", DEVICE, "| USE_SPARSE_CONTEXT:", USE_SPARSE_CONTEXT, "| N_SPARSE:", sparse_count_label)

In [None]:
def _resolve_checkpoint(preferred_name, fallback_name):
    """Return the preferred checkpoint path, or the fallback path when the preferred file is absent."""
    preferred = os.path.join(CHECKPOINT_DIR, preferred_name)
    if os.path.isfile(preferred):
        return preferred
    return os.path.join(CHECKPOINT_DIR, fallback_name)


def _load_nf_model(ckpt_path, label):
    """Build a CascadedPerceiverIO + Fourier encoder and load their weights from `ckpt_path`.

    Both modules are moved to DEVICE and put in eval mode. Loading stays
    non-strict so that older checkpoints still load, but any missing or
    unexpected keys are printed instead of being silently ignored.

    Returns:
        (model, fourier_encoder, checkpoint_dict)
    """
    fourier = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
    model = CascadedPerceiverIO(
        input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
        latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
    ).to(DEVICE)
    # NOTE(review): torch.load unpickles the file; only load checkpoints you produced yourself.
    checkpoint = torch.load(ckpt_path, map_location=DEVICE)
    for module, key in ((model, "model_state_dict"), (fourier, "fourier_encoder_state_dict")):
        report = module.load_state_dict(checkpoint[key], strict=False)
        if report.missing_keys or report.unexpected_keys:
            # A mismatch here means part of the module keeps its random initialisation.
            print(
                f"WARNING [{label}] {key}: "
                f"{len(report.missing_keys)} missing, {len(report.unexpected_keys)} unexpected keys "
                f"(missing={list(report.missing_keys)[:5]}, unexpected={list(report.unexpected_keys)[:5]})"
            )
    model.eval()
    fourier.eval()
    return model, fourier, checkpoint


# Load sparse-context model (required)
CKPT_SPARSE = _resolve_checkpoint("checkpoint_sparse_best.pt", "checkpoint_sparse_last.pt")
if not os.path.isfile(CKPT_SPARSE):
    # Explicit raise rather than assert so the check survives `python -O`.
    raise FileNotFoundError("Run SparseContext_CIFAR10.ipynb first.")
sparse_model, fourier_encoder, ckpt = _load_nf_model(CKPT_SPARSE, "sparse")
print("Loaded sparse:", CKPT_SPARSE)

# Optional: full-context baseline
CKPT_BASELINE = _resolve_checkpoint("checkpoint_best.pt", "checkpoint_last.pt")
baseline_model = baseline_fourier = None
if os.path.isfile(CKPT_BASELINE):
    baseline_model, baseline_fourier, ckpt_b = _load_nf_model(CKPT_BASELINE, "baseline")
    print("Loaded baseline:", CKPT_BASELINE)

# Optional: full-context NCE
CKPT_NCE = _resolve_checkpoint("checkpoint_nce_best.pt", "softnce_best.pt")
nce_model = nce_fourier = None
if os.path.isfile(CKPT_NCE):
    nce_model, nce_fourier, ckpt_n = _load_nf_model(CKPT_NCE, "nce")
    print("Loaded NCE:", CKPT_NCE)

# Optional: sparse-context NCE finetuned
CKPT_SFN = _resolve_checkpoint(
    "checkpoint_sparse_finetune_nce_best.pt", "checkpoint_sparse_finetune_nce_last.pt"
)
sfn_model = sfn_fourier = None
if os.path.isfile(CKPT_SFN):
    sfn_model, sfn_fourier, ckpt_sfn = _load_nf_model(CKPT_SFN, "sparse_finetune_nce")
    print("Loaded sparse finetune NCE:", CKPT_SFN)

In [None]:
# CIFAR-10; train with augmentation so patched-pixel baseline can compete.
# Both splits end with the same tensor conversion + per-channel normalisation;
# only the training split gets random crop / horizontal flip in front of it.
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        *_tensor_and_normalize,
    ]
)
test_transform = transforms.Compose(list(_tensor_and_normalize))

train_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=train_transform
)
test_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=test_transform
)

BATCH_SIZE = 64
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Patchify images for "original" (full context); sparse original = (x,y,r,g,b) at N_SPARSE coords.
PATCH_SIZE = 4  # side length, in pixels, of each square patch
PATCH_DIM = PATCH_SIZE * PATCH_SIZE * CHANNELS  # flattened length of one patch
N_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2  # number of patches per image
SPARSE_ORIGINAL_DIM = 5  # (x, y, r, g, b) at each sparse point

def patchify(images):
    """Split (B, C, H, W) images into flattened square patches.

    Returns a (B, n_patches, patch_dim) tensor; each patch is flattened in
    (row-in-patch, col-in-patch, channel) order.
    """
    # Unpacking doubles as a check that the input is 4-D.
    _batch, _channels, _height, _width = images.shape
    pattern = "b c (h ph) (w pw) -> b (h w) (ph pw c)"
    return rearrange(images, pattern, ph=PATCH_SIZE, pw=PATCH_SIZE)

def get_representation(images, rep_name):
    """Return the (B, seq_len, d) representation of `images` named by `rep_name`.

    Supported names: "original", "sparse", "baseline", "nce",
    "sparse_finetune_nce". If USE_SPARSE_CONTEXT is set, every representation
    is computed from a fresh random subset of N_SPARSE coordinates; otherwise
    the full coordinate grid (or, for "original", the patchified image) is used.

    Returns None when `rep_name` is unknown or its model was not loaded.
    Everything runs under `torch.no_grad()`.
    """
    with torch.no_grad():
        if rep_name == "original":
            if USE_SPARSE_CONTEXT:
                B = images.size(0)
                idx = torch.randperm(coords_32.size(0), device=DEVICE)[:N_SPARSE]
                coords_sparse = coords_32[idx]
                xy = coords_sparse.unsqueeze(0).expand(B, -1, -1)
                pixels = sample_gt_at_coords(images, xy)
                # Each token is (x, y, r, g, b): raw coordinates followed by sampled pixel values.
                return torch.cat([xy, pixels], dim=-1).float()
            return patchify(images).to(DEVICE).float()

        # Every neural-field model is paired with the Fourier encoder from its own checkpoint.
        # Looked up at call time so that re-running the loading cell is picked up.
        variants = {
            "sparse": (sparse_model, fourier_encoder),
            "baseline": (baseline_model, baseline_fourier),
            "nce": (nce_model, nce_fourier),
            "sparse_finetune_nce": (sfn_model, sfn_fourier),
        }
        model, encoder = variants.get(rep_name, (None, None))
        if model is None:
            # Unknown name or optional checkpoint missing: skip before doing any work.
            return None

        # Build the context exactly once, with the encoder that belongs to `model`.
        if USE_SPARSE_CONTEXT:
            input_ctx = prepare_sparse_context(images, coords_32, encoder, N_SPARSE, DEVICE)
        else:
            input_ctx = prepare_full_context(images, coords_32, encoder)
        return get_residual(model, input_ctx)

In [None]:
class SmallTransformerClassifier(nn.Module):
    """Small transformer classifier over a token sequence.

    Input (B, L, d_in) is linearly projected to d_model, summed with a learned
    positional embedding, passed through a pre-norm TransformerEncoder,
    mean-pooled over the sequence axis and mapped to `num_classes` logits.

    Args:
        d_in: feature size of each input token.
        d_model: transformer width.
        num_heads: attention heads per layer.
        num_layers: number of encoder layers.
        mlp_ratio: feed-forward width as a multiple of d_model.
        num_classes: number of output logits.
        max_len: longest sequence the positional embedding supports.
    """

    def __init__(self, d_in, d_model=256, num_heads=4, num_layers=2, mlp_ratio=2, num_classes=10, max_len=256):
        super().__init__()
        self.d_in = d_in
        self.d_model = d_model
        self.max_len = max_len
        self.proj = nn.Linear(d_in, d_model)
        # Small random init (std 0.02). The previous `zeros * 0.02` was always exactly zero,
        # which made the scale factor meaningless.
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_model * mlp_ratio,
            batch_first=True,
            activation="gelu",
            norm_first=True,
            dropout=0.1,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Map (B, L, d_in) tokens to (B, num_classes) logits.

        Raises:
            ValueError: if L exceeds the `max_len` given at construction.
        """
        seq_len = x.shape[1]
        if seq_len > self.pos_embed.shape[1]:
            # Without this check the slice below silently truncates and the add
            # fails with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pos_embed.shape[1]}"
            )
        x = self.proj(x)
        x = x + self.pos_embed[:, :seq_len, :]
        x = self.transformer(x)
        x = x.mean(dim=1)  # mean-pool over the sequence axis
        return self.head(x)


def train_and_eval(rep_name, d_in, seq_len, epochs=5, lr=1e-3):
    """Train a SmallTransformerClassifier on one representation and report test accuracy.

    Args:
        rep_name: representation name passed to `get_representation`.
        d_in: feature size of each token of that representation.
        seq_len: token count; the positional table is sized to max(seq_len, 256).
        epochs: number of passes over `train_loader`; test accuracy is printed after each.
        lr: Adam learning rate.

    Returns:
        Test accuracy after the final epoch (0.0 if `epochs` is 0 or no batch
        produced a representation).
    """
    model = SmallTransformerClassifier(
        d_in=d_in,
        d_model=256,
        num_heads=4,
        num_layers=2,
        num_classes=NUM_CLASSES,
        max_len=max(seq_len, 256),
    ).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    ce = nn.CrossEntropyLoss()
    # Defined up front so that `return acc` is valid even when the epoch loop never runs.
    acc = 0.0
    for epoch in range(epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            x = get_representation(images, rep_name)
            if x is None:
                continue
            logits = model(x)
            loss = ce(logits, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                x = get_representation(images, rep_name)
                if x is None:
                    continue
                logits = model(x)
                correct += (logits.argmax(1) == labels).sum().item()
                total += labels.size(0)
        acc = correct / total if total else 0.0
        print(f"  {rep_name} epoch {epoch+1}/{epochs} test acc: {acc:.4f}")
    return acc


# Number of training epochs passed to train_and_eval for each representation.
EPOCHS = 8
# NOTE(review): the architecture values in this message are written by hand, not read from the model.
print("Small transformer: d_model=256, 2 layers, 4 heads, dropout=0.1. Epochs:", EPOCHS)

In [None]:
# Test accuracy per representation; None marks a representation whose checkpoint was not loaded.
results = {}

# The "original" tokens change shape with the context mode:
#   sparse mode -> (B, N_SPARSE, 5)          i.e. (x, y, r, g, b) per sampled point
#   full mode   -> (B, N_PATCHES, PATCH_DIM) i.e. flattened pixel patches
# Using the patch dimensions in sparse mode makes the input projection fail on 5-dim tokens.
if USE_SPARSE_CONTEXT:
    original_d_in, original_seq_len = SPARSE_ORIGINAL_DIM, N_SPARSE
else:
    original_d_in, original_seq_len = PATCH_DIM, N_PATCHES

print("1. Original (patched pixels)")
results["original"] = train_and_eval("original", d_in=original_d_in, seq_len=original_seq_len, epochs=EPOCHS)

print("2. Sparse model, full context (last latent)")
results["sparse"] = train_and_eval("sparse", d_in=LATENT_DIM, seq_len=NUM_LATENTS, epochs=EPOCHS)

if baseline_model is not None:
    print("3. Full baseline (last latent)")
    results["baseline"] = train_and_eval("baseline", d_in=LATENT_DIM, seq_len=NUM_LATENTS, epochs=EPOCHS)
else:
    results["baseline"] = None

if nce_model is not None:
    print("4. Full + NCE (last latent)")
    results["nce"] = train_and_eval("nce", d_in=LATENT_DIM, seq_len=NUM_LATENTS, epochs=EPOCHS)
else:
    results["nce"] = None

if sfn_model is not None:
    print("5. Sparse + NCE finetuned, full context (last latent)")
    results["sparse_finetune_nce"] = train_and_eval("sparse_finetune_nce", d_in=LATENT_DIM, seq_len=NUM_LATENTS, epochs=EPOCHS)
else:
    results["sparse_finetune_nce"] = None

print("\n--- Summary (test accuracy) ---")
for name, acc in results.items():
    print(f"  {name}: {acc:.4f}" if acc is not None else f"  {name}: (not loaded)")

import matplotlib.pyplot as plt
# Plot only the representations that were actually evaluated.
names = [k for k, v in results.items() if v is not None]
accs = [results[k] for k in names]
if names:
    plt.bar(names, accs, color='steelblue')
    plt.ylabel('Test accuracy')
    plt.title('Classification comparison (small transformer, full data, few epochs)')
    plt.ylim(0, 1)
    plt.xticks(rotation=15)
    plt.tight_layout()
    plt.savefig('latent_classification_comparison.png', dpi=100)
    plt.show()

### k-NN and linear probe (metric vs separability)

**k-NN** uses the raw distance in the representation space (no trained head). It matches what t-SNE is showing: good *neighborhood* structure → good k-NN. **Linear probe** trains a single linear layer; it often correlates with transformer accuracy. We mean-pool each representation to one vector per image, then evaluate.

In [None]:
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression

def get_pooled_representation(images, rep_name):
    """Mean-pool a (B, L, d) representation over its sequence axis to (B, d).

    Returns None when `get_representation` has nothing for `rep_name`.
    """
    tokens = get_representation(images, rep_name)
    return None if tokens is None else tokens.mean(dim=1)

def collect_features(loader, rep_name):
    """Gather pooled features and labels for every batch in `loader` as NumPy arrays.

    Returns (features, labels), each concatenated along axis 0, or
    (None, None) as soon as a batch yields no representation.
    """
    feature_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            pooled = get_pooled_representation(batch_images.to(DEVICE), rep_name)
            if pooled is None:
                # Representation unavailable: signal the caller to skip it.
                return None, None
            feature_chunks.append(pooled.cpu().numpy())
            label_chunks.append(batch_labels.numpy())
    features = np.concatenate(feature_chunks, axis=0)
    targets = np.concatenate(label_chunks, axis=0)
    return features, targets

KNN_K = 20  # neighbours used by the k-NN probe
rep_names = ["original", "sparse", "baseline", "nce", "sparse_finetune_nce"]
# Test accuracy per representation; None marks a representation that was skipped.
knn_results = {}
linear_results = {}
for name in rep_names:
    X_tr, y_tr = collect_features(train_loader, name)
    if X_tr is None:
        # Model for this representation was not loaded.
        knn_results[name] = None
        linear_results[name] = None
        continue
    X_te, y_te = collect_features(test_loader, name)
    knn = KNeighborsClassifier(n_neighbors=KNN_K, weights="distance")
    knn.fit(X_tr, y_tr)
    knn_results[name] = knn.score(X_te, y_te)
    # `multi_class` is deprecated in scikit-learn 1.5; with the default lbfgs solver a
    # multiclass problem is already fitted as multinomial, so the argument is not needed.
    lr = LogisticRegression(max_iter=500, C=0.1)
    lr.fit(X_tr, y_tr)
    linear_results[name] = lr.score(X_te, y_te)
    print(f"{name}: k-NN (k={KNN_K}) = {knn_results[name]:.4f}, linear = {linear_results[name]:.4f}")

print("\n--- k-NN (metric / neighborhood, aligns with t-SNE) ---")
for name in rep_names:
    v = knn_results.get(name)
    print(f"  {name}: {v:.4f}" if v is not None else f"  {name}: (skip)")
print("\n--- Linear probe (separability) ---")
for name in rep_names:
    v = linear_results.get(name)
    print(f"  {name}: {v:.4f}" if v is not None else f"  {name}: (skip)")