# Classification comparison: five representations with a small transformer

Compare CIFAR-10 classification using **one small transformer** on five inputs:
1. **Original** – raw image pixels split into non-overlapping 4×4 patches (same transformer; its linear input projection serves as the patch embedding).
2. **Sparse (full ctx)** – last latent of sparse-context model with **full** context.
3. **Full baseline** – last latent of full-context baseline.
4. **Full + NCE** – last latent of full-context NCE model.
5. **Sparse + NCE finetuned (full ctx)** – last latent of sparse-context NCE finetune with **full** context.

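All neural-field variants use the **last latent** \((B, N, D)\) (encoder + processor output) as a frozen feature sequence. No input sparsification is applied. Every representation uses the same data and the same small transformer, trained from scratch for only a few epochs, so the results are meant for relative comparison only.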

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from einops import rearrange, repeat

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
CHECKPOINT_DIR = "checkpoints"

# CIFAR-10 image geometry and label count.
IMAGE_SIZE = 32
CHANNELS = 3
NUM_CLASSES = 10

# Fourier positional features. The positional width is declared as twice the
# mapping size (presumably sin + cos per frequency -- confirm in nf_feature_models).
FOURIER_MAPPING_SIZE = 96
POS_EMBED_DIM = 2 * FOURIER_MAPPING_SIZE

# Perceiver I/O widths. The assumed input token layout is pixel channels
# followed by positional features; queries are positional features only.
INPUT_DIM = POS_EMBED_DIM + CHANNELS
QUERIES_DIM = POS_EMBED_DIM
LOGITS_DIM = CHANNELS

# Last-stage latent shape; must match the final entries of `latent_dims` /
# `num_latents` passed to CascadedPerceiverIO when the checkpoints are loaded.
LATENT_DIM = 512
NUM_LATENTS = 256

def prepare_full_context(images, coords_full, fourier_encoder):
    """Build the full-context model input for `images`.

    Calls ``prepare_model_input`` with every coordinate in `coords_full` and
    keeps only the first of its three outputs, discarding the other two.
    """
    model_input, _unused_a, _unused_b = prepare_model_input(images, coords_full, fourier_encoder)
    return model_input

def get_residual(model, data):
    """Run the encoder blocks and self-attention processor; return the final latent.

    The encoder blocks are chained. Each receives the running latent as both
    ``x`` and ``residual`` (``None`` for the first block) and attends to `data`
    as context. Each processor stage then applies its two sub-layers
    (``stage[0]`` and ``stage[1]``), each with a residual skip.

    Output: (B, num_latents, latent_dim).
    """
    latent = None
    for enc in model.encoder_blocks:
        latent = enc(x=latent, context=data, mask=None, residual=latent)
    for stage in model.self_attn_blocks:
        first, second = stage[0], stage[1]
        latent = first(latent) + latent
        latent = second(latent) + latent
    return latent

# IMAGE_SIZE x IMAGE_SIZE coordinate grid: the full context used for every neural-field model.
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
print(f"Device: {DEVICE}")

In [None]:
# Load sparse-context model (required)


def _nf_checkpoint_path(best_name, fallback_name):
    """Return CHECKPOINT_DIR/best_name if it exists, else CHECKPOINT_DIR/fallback_name.

    The fallback path is returned even if it does not exist; callers decide
    what to do with ``os.path.isfile``.
    """
    best = os.path.join(CHECKPOINT_DIR, best_name)
    return best if os.path.isfile(best) else os.path.join(CHECKPOINT_DIR, fallback_name)


def _load_nf_checkpoint(path):
    """Build a Fourier encoder + CascadedPerceiverIO pair and load their weights from `path`.

    Both modules are moved to DEVICE and switched to eval mode. Loading is
    non-strict, but any missing keys are printed as a warning: they would
    otherwise silently stay at their random initialization.

    Returns:
        (model, fourier_encoder, checkpoint_dict)
    """
    fourier = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
    model = CascadedPerceiverIO(
        input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
        latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
    ).to(DEVICE)
    checkpoint = torch.load(path, map_location=DEVICE)
    for module, key in ((model, "model_state_dict"), (fourier, "fourier_encoder_state_dict")):
        incompatible = module.load_state_dict(checkpoint[key], strict=False)
        if incompatible.missing_keys:
            print(f"WARNING: {path} [{key}] missing keys: {incompatible.missing_keys}")
    model.eval()
    fourier.eval()
    return model, fourier, checkpoint


CKPT_SPARSE = _nf_checkpoint_path("checkpoint_sparse_best.pt", "checkpoint_sparse_last.pt")
if not os.path.isfile(CKPT_SPARSE):
    # Raise instead of assert: asserts are stripped under `python -O`.
    raise FileNotFoundError(f"{CKPT_SPARSE} not found. Run SparseContext_CIFAR10.ipynb first.")
sparse_model, fourier_encoder, ckpt = _load_nf_checkpoint(CKPT_SPARSE)
print("Loaded sparse:", CKPT_SPARSE)

# Optional: full-context baseline
CKPT_BASELINE = _nf_checkpoint_path("checkpoint_best.pt", "checkpoint_last.pt")
baseline_model = baseline_fourier = None
if os.path.isfile(CKPT_BASELINE):
    baseline_model, baseline_fourier, ckpt_b = _load_nf_checkpoint(CKPT_BASELINE)
    print("Loaded baseline:", CKPT_BASELINE)

# Optional: full-context NCE
CKPT_NCE = _nf_checkpoint_path("checkpoint_nce_best.pt", "softnce_best.pt")
nce_model = nce_fourier = None
if os.path.isfile(CKPT_NCE):
    nce_model, nce_fourier, ckpt_n = _load_nf_checkpoint(CKPT_NCE)
    print("Loaded NCE:", CKPT_NCE)

# Optional: sparse-context NCE finetuned
CKPT_SFN = _nf_checkpoint_path(
    "checkpoint_sparse_finetune_nce_best.pt", "checkpoint_sparse_finetune_nce_last.pt"
)
sfn_model = sfn_fourier = None
if os.path.isfile(CKPT_SFN):
    sfn_model, sfn_fourier, ckpt_sfn = _load_nf_checkpoint(CKPT_SFN)
    print("Loaded sparse finetune NCE:", CKPT_SFN)

In [None]:
# CIFAR-10 with the same transform as eval: full images, no sparsification.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
_cifar_kwargs = dict(root="./data", download=True, transform=transform)
train_dataset = torchvision.datasets.CIFAR10(train=True, **_cifar_kwargs)
test_dataset = torchvision.datasets.CIFAR10(train=False, **_cifar_kwargs)

BATCH_SIZE = 64
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

In [None]:
# Patch geometry for the "original" representation: (B,C,H,W) -> (B, n_patches, patch_dim)
PATCH_SIZE = 4
PATCH_DIM = CHANNELS * PATCH_SIZE ** 2  # values in one flattened patch
N_PATCHES = (IMAGE_SIZE // PATCH_SIZE) * (IMAGE_SIZE // PATCH_SIZE)  # patches per image

def patchify(images, patch_size=None):
    """Split images into flattened, non-overlapping square patches.

    Equivalent to einops ``"b c (h ph) (w pw) -> b (h w) (ph pw c)"``.
    Patches are ordered row-major over the patch grid. Each patch vector is
    laid out as (row-in-patch, col-in-patch, channel).

    Args:
        images: tensor of shape (B, C, H, W).
        patch_size: patch side length; defaults to the module-level PATCH_SIZE.

    Returns:
        Tensor of shape (B, (H // p) * (W // p), p * p * C).

    Raises:
        ValueError: if H or W is not divisible by the patch size.
    """
    p = PATCH_SIZE if patch_size is None else patch_size
    batch, chans, height, width = images.shape
    if height % p or width % p:
        raise ValueError(f"Image size {height}x{width} is not divisible by patch size {p}.")
    grid_h, grid_w = height // p, width // p
    x = images.reshape(batch, chans, grid_h, p, grid_w, p)
    x = x.permute(0, 2, 4, 3, 5, 1)  # -> (B, grid_h, grid_w, ph, pw, C)
    return x.reshape(batch, grid_h * grid_w, p * p * chans)

def get_representation(images, rep_name):
    """Return the input sequence for the small transformer for one image batch.

    Args:
        images: normalized image batch of shape (B, C, H, W); callers move it to DEVICE first.
        rep_name: one of "original", "sparse", "baseline", "nce",
            "sparse_finetune_nce".

    Returns:
        - "original": flattened pixel patches, (B, N_PATCHES, PATCH_DIM), float.
        - neural-field names: last latent (encoder + processor output) computed
          from the full context, (B, num_latents, latent_dim).
        - None if `rep_name` is unknown or its checkpoint was not loaded.

    Each neural-field model is fed positional features from the Fourier
    encoder loaded from its *own* checkpoint. Reusing one model's encoder for
    another would feed that model inputs it was not trained on. Everything runs
    under ``torch.no_grad()``, so the feature models stay frozen.
    """
    with torch.no_grad():
        if rep_name == "original":
            return patchify(images).to(DEVICE).float()
        nf_models = {
            "sparse": (sparse_model, fourier_encoder),
            "baseline": (baseline_model, baseline_fourier),
            "nce": (nce_model, nce_fourier),
            "sparse_finetune_nce": (sfn_model, sfn_fourier),
        }
        model, fourier = nf_models.get(rep_name, (None, None))
        if model is None or fourier is None:
            return None
        input_full = prepare_full_context(images, coords_32, fourier)
        return get_residual(model, input_full)

In [None]:
class SmallTransformerClassifier(nn.Module):
    """Small pre-norm transformer classifier over a token sequence.

    The input (B, L, d_in) is linearly projected to d_model and a learned
    positional embedding is added. The sequence then passes through
    ``num_layers`` pre-norm ``nn.TransformerEncoderLayer``s. Tokens are
    mean-pooled and a linear head produces (B, num_classes) logits.

    Args:
        d_in: feature size of each input token.
        d_model: transformer width.
        num_heads: attention heads per layer (must divide d_model).
        num_layers: number of encoder layers.
        mlp_ratio: feed-forward width as a multiple of d_model.
        num_classes: number of output logits.
        max_len: longest sequence length supported by the positional embedding.
    """

    def __init__(self, d_in, d_model=256, num_heads=4, num_layers=2, mlp_ratio=2, num_classes=10, max_len=256):
        super().__init__()
        self.d_in = d_in
        self.d_model = d_model
        self.proj = nn.Linear(d_in, d_model)
        # Learned positional embedding with small random init (std 0.02), so
        # positions start out distinguishable.
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, d_model) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=d_model * mlp_ratio,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Map (B, L, d_in) tokens to (B, num_classes) logits.

        Raises:
            ValueError: if L exceeds the positional embedding's max_len.
        """
        seq_len = x.shape[1]
        max_len = self.pos_embed.shape[1]
        if seq_len > max_len:
            raise ValueError(f"Sequence length {seq_len} exceeds max_len={max_len}.")
        x = self.proj(x) + self.pos_embed[:, :seq_len, :]
        x = self.transformer(x)
        return self.head(x.mean(dim=1))


def _require_representation(images, rep_name):
    """Call get_representation, raising instead of returning None for an unavailable rep."""
    x = get_representation(images, rep_name)
    if x is None:
        raise ValueError(
            f"Representation {rep_name!r} is unavailable (unknown name or checkpoint not loaded)."
        )
    return x


def _evaluate(model, rep_name):
    """Return the test-set accuracy of `model` on representation `rep_name`."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            logits = model(_require_representation(images, rep_name))
            correct += (logits.argmax(1) == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0


def train_and_eval(rep_name, d_in, seq_len, epochs=5, lr=1e-3):
    """Train a fresh SmallTransformerClassifier on one frozen representation.

    Trains with Adam + cross-entropy on ``train_loader`` and evaluates on
    ``test_loader`` after each epoch, printing the test accuracy.

    Args:
        rep_name: representation name understood by get_representation.
        d_in: feature size of each token of that representation.
        seq_len: number of tokens per example (sets the positional-embedding length, min 256).
        epochs: number of training epochs.
        lr: Adam learning rate.

    Returns:
        Test accuracy after the final epoch (0.0 if ``epochs == 0``).

    Raises:
        ValueError: if the representation is unknown or its model was not loaded.
    """
    model = SmallTransformerClassifier(d_in=d_in, d_model=256, num_heads=4, num_layers=2, num_classes=NUM_CLASSES, max_len=max(seq_len, 256)).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    ce = nn.CrossEntropyLoss()
    acc = 0.0  # returned unchanged when epochs == 0
    for epoch in range(epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            x = _require_representation(images, rep_name)
            loss = ce(model(x), labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
        acc = _evaluate(model, rep_name)
        print(f"  {rep_name} epoch {epoch+1}/{epochs} test acc: {acc:.4f}")
    return acc


EPOCHS = 5
print(f"Small transformer: d_model=256, 2 layers, 4 heads. Epochs: {EPOCHS}")

In [None]:
results = {}

print("1. Original (patched pixels)")
results["original"] = train_and_eval("original", d_in=PATCH_DIM, seq_len=N_PATCHES, epochs=EPOCHS)

print("2. Sparse model, full context (last latent)")
results["sparse"] = train_and_eval("sparse", d_in=LATENT_DIM, seq_len=NUM_LATENTS, epochs=EPOCHS)

# Optional runs: (results key == rep name, loaded model or None, heading printed before training).
for rep_key, rep_model, heading in (
    ("baseline", baseline_model, "3. Full baseline (last latent)"),
    ("nce", nce_model, "4. Full + NCE (last latent)"),
    ("sparse_finetune_nce", sfn_model, "5. Sparse + NCE finetuned, full context (last latent)"),
):
    if rep_model is None:
        results[rep_key] = None
        continue
    print(heading)
    results[rep_key] = train_and_eval(rep_key, d_in=LATENT_DIM, seq_len=NUM_LATENTS, epochs=EPOCHS)

print("\n--- Summary (test accuracy) ---")
for name, acc in results.items():
    if acc is None:
        print(f"  {name}: (not loaded)")
    else:
        print(f"  {name}: {acc:.4f}")

import matplotlib.pyplot as plt
# Only representations that were actually trained get a bar.
plotted = {k: v for k, v in results.items() if v is not None}
names = list(plotted)
accs = list(plotted.values())
if names:
    plt.bar(names, accs, color='steelblue')
    plt.ylabel('Test accuracy')
    plt.title('Classification comparison (small transformer, full data, few epochs)')
    plt.ylim(0, 1)
    plt.xticks(rotation=15)
    plt.tight_layout()
    plt.savefig('latent_classification_comparison.png', dpi=100)
    plt.show()