# TDSM Classification: Token-Decoded Spatial Map for CIFAR-10

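Uses the **pretrained** OmniField-style model (checkpoint from `AblationCIFAR10.ipynb`), kept frozen. We extract a **TDSM** (token-decoded spatial map). For each latent token, the decoder reconstructs a 32×32 "component" image from that token alone, and its RGB channels are averaged, giving one map per token. We then train a **small CNN** on the TDSM for classification. The notebook also includes visualizations of the extracted features.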

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from einops import rearrange, repeat

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_DIR = "checkpoints"
# Prefer the best checkpoint; fall back to the most recent one.
CKPT_PATH = os.path.join(CHECKPOINT_DIR, "checkpoint_best.pt")
if not os.path.isfile(CKPT_PATH):
    CKPT_PATH = os.path.join(CHECKPOINT_DIR, "checkpoint_last.pt")
if not os.path.isfile(CKPT_PATH):
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    raise FileNotFoundError(f"No checkpoint found in {CHECKPOINT_DIR}. Train AblationCIFAR10.ipynb first.")

# Config (must match training notebook)
IMAGE_SIZE = 32
CHANNELS = 3
FOURIER_MAPPING_SIZE = 96
POS_EMBED_DIM = FOURIER_MAPPING_SIZE * 2
INPUT_DIM = CHANNELS + POS_EMBED_DIM
QUERIES_DIM = POS_EMBED_DIM
LOGITS_DIM = CHANNELS
NUM_LATENTS = 256

fourier_encoder = GaussianFourierFeatures(in_features=2, mapping_size=FOURIER_MAPPING_SIZE, scale=15.0).to(DEVICE)
model = CascadedPerceiverIO(
    input_dim=INPUT_DIM,
    queries_dim=QUERIES_DIM,
    logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512),
    num_latents=(256, 256, 256),
    decoder_ff=True,
).to(DEVICE)

ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
# strict=False tolerates key mismatches between checkpoint and model. A silent
# mismatch (e.g. renamed modules) would leave weights randomly initialised and
# invalidate everything downstream, so report whatever did not line up.
for _name, _module, _key in (
    ("model", model, "model_state_dict"),
    ("fourier_encoder", fourier_encoder, "fourier_encoder_state_dict"),
):
    _incompatible = _module.load_state_dict(ckpt[_key], strict=False)
    if _incompatible.missing_keys or _incompatible.unexpected_keys:
        print(
            f"WARNING: {_name} checkpoint mismatch: "
            f"missing={_incompatible.missing_keys}, unexpected={_incompatible.unexpected_keys}"
        )
# The pretrained backbone is only used for feature extraction.
model.eval()
fourier_encoder.eval()

coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
print(f"Loaded {CKPT_PATH}")
print(f"Model and fourier_encoder on {DEVICE}")

In [None]:
def get_residual(model, data):
    """
    Run the encoder and latent processor only (no decoder) and return the latents.

    Each encoder block receives the running latents (``None`` for the first
    block) as both ``x`` and ``residual``, with ``data`` as cross-attention
    context. Each self-attention block then applies its two sub-layers
    (``block[0]`` then ``block[1]``), each wrapped in a skip connection.

    Returns:
        The final latent field; per the original note, (B, num_latents, latent_dim).
    """
    def _with_skip(layer, hidden):
        # Pre-existing residual connection: layer output added back onto its input.
        return layer(hidden) + hidden

    latents = None
    for encoder in model.encoder_blocks:
        latents = encoder(x=latents, context=data, mask=None, residual=latents)
    for block in model.self_attn_blocks:
        latents = _with_skip(block[1], _with_skip(block[0], latents))
    return latents


def decoder_forward(model, queries, context):
    """
    Apply the model's decoder head to ``queries`` attending over ``context``.

    Cross-attention output is added back onto the queries. If the model has a
    decoder feed-forward, it is applied with its own skip connection. The result
    is projected by ``to_logits``.

    Shapes (per the original note): queries (B, N, qd), context (B, 1, ld) or
    (B, L, ld) -> (B, N, 3).
    """
    hidden = model.decoder_cross_attn(queries, context=context) + queries
    feed_forward = model.decoder_ff
    if feed_forward is not None:
        hidden = hidden + feed_forward(hidden)
    return model.to_logits(hidden)


def get_tdsm(model, fourier_encoder, data, coords_32, device, num_tokens=256):
    """
    Build the token-decoded spatial map (TDSM) for a batch.

    For each of the first ``num_tokens`` latent tokens, the decoder is run with
    that single token as its only context. This yields an IMAGE_SIZE x IMAGE_SIZE
    component image, which is then averaged over its output channels (RGB).

    Args:
        model: pretrained CascadedPerceiverIO, used frozen.
        fourier_encoder: maps coordinates to decoder queries.
        data: encoder input, as produced by ``prepare_model_input``.
        coords_32: coordinate grid with IMAGE_SIZE * IMAGE_SIZE rows
            (presumably shape (N, 2) — it is repeated as "n d -> b n d").
        device: device the decoder queries are moved to.
        num_tokens: number of leading latent tokens to decode; ``None`` decodes
            every token in the latent field.

    Returns:
        Tensor of shape (B, num_tokens, IMAGE_SIZE, IMAGE_SIZE), channels first,
        carrying no autograd history.

    Raises:
        ValueError: if ``num_tokens`` is not in [1, number of latent tokens].
    """
    # The backbone is frozen and only the downstream classifier is trained, so
    # the whole extraction (encoder *and* the per-token decoder passes) runs
    # without autograd. Otherwise every call would build a graph through one
    # decoder pass per token and back-propagate into the frozen weights.
    with torch.no_grad():
        residual = get_residual(model, data)
        B = data.size(0)
        available = residual.size(1)
        if num_tokens is None:
            num_tokens = available
        if not 0 < num_tokens <= available:
            raise ValueError(
                f"num_tokens={num_tokens} must be between 1 and the number of latent tokens ({available})"
            )
        queries_32 = fourier_encoder(repeat(coords_32, "n d -> b n d", b=B)).to(device)
        token_maps = []
        for k in range(num_tokens):
            # A single token as the entire decoder context.
            logits_k = decoder_forward(model, queries_32, residual[:, k : k + 1, :])
            # Reduce over channels immediately, so a (B, num_tokens, H, W, C)
            # tensor is never materialised.
            token_maps.append(logits_k.reshape(B, IMAGE_SIZE, IMAGE_SIZE, -1).mean(dim=-1))
        tdsm = torch.stack(token_maps, dim=1)
    return tdsm

In [None]:
class TDSMClassifier(nn.Module):
    """
    Compact CNN classifier over a TDSM of shape (B, num_tokens, 32, 32).

    Each latent token's map is treated as an input channel. Three
    conv(3x3) -> batch-norm -> ReLU stages narrow the channels to 128, 64 and 32.
    The first two stages are followed by 2x2 max pooling, the last by global
    average pooling. A linear head then produces ``num_classes`` logits.
    """

    def __init__(self, num_tokens=256, num_classes=10):
        super().__init__()
        channels = [num_tokens, 128, 64, 32]
        last_stage = len(channels) - 2
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            layers.extend([
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
            # Downsample between stages; collapse space entirely after the last.
            layers.append(nn.AdaptiveAvgPool2d(1) if stage == last_stage else nn.MaxPool2d(2))
        layers.extend([nn.Flatten(), nn.Linear(channels[-1], num_classes)])
        # One Sequential named `net`, so state_dict keys (net.0.weight, ...) are stable.
        self.net = nn.Sequential(*layers)

    def forward(self, tdsm):
        """Return class logits of shape (B, num_classes)."""
        logits = self.net(tdsm)
        return logits


# CIFAR-10 has ten classes; one TDSM input channel per latent token.
NUM_CLASSES = 10
classifier = TDSMClassifier(num_tokens=NUM_LATENTS, num_classes=NUM_CLASSES)
classifier = classifier.to(DEVICE)
print(classifier)

In [None]:
# ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps pixels to [-1, 1]
# (presumably matching the training notebook — confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
BATCH_SIZE = 32
# Pinned host memory only helps host-to-GPU copies (and warns on CPU-only
# machines), so enable it for both loaders exactly when running on CUDA.
PIN_MEMORY = DEVICE == "cuda"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)

# Only the classifier's parameters are optimised; the backbone stays frozen.
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
EPOCHS = 15

In [None]:
for epoch in range(EPOCHS):
    # Backbone stays frozen in eval mode; only the classifier trains.
    model.eval()
    fourier_encoder.eval()
    classifier.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        # Feature extraction uses the frozen backbone. Run it without autograd
        # so no graph is built through the Fourier encoder and the per-token
        # decoder passes. The optimizer only updates `classifier` anyway.
        with torch.no_grad():
            input_data, _, _ = prepare_model_input(images, coords_32, fourier_encoder)
            tdsm = get_tdsm(model, fourier_encoder, input_data, coords_32, DEVICE)
        logits = classifier(tdsm)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so the epoch average is per-sample.
        total_loss += loss.item() * images.size(0)
        pred = logits.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += images.size(0)
    train_acc = correct / total
    print(f"Epoch {epoch+1}/{EPOCHS}  Train Loss: {total_loss/total:.4f}  Train Acc: {train_acc:.4f}")

# Final evaluation on the held-out test split.
classifier.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        input_data, _, _ = prepare_model_input(images, coords_32, fourier_encoder)
        tdsm = get_tdsm(model, fourier_encoder, input_data, coords_32, DEVICE)
        logits = classifier(tdsm)
        pred = logits.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
test_acc = correct / total
print(f"Test Accuracy: {test_acc:.4f}")

## Feature visualizations

1. **TDSM slices**: the decoded maps (32×32) of a few latent tokens for one sample.
2. **Component images**: RGB "component" images for a few tokens (what each token contributes to the reconstruction).
3. **TDSM channel stats**: the spatial mean and std of each token's map (one bar per token), to see which tokens are most active.

In [None]:
# Everything is inference-only from here on.
for _module in (model, fourier_encoder, classifier):
    _module.eval()

# Take the first test batch and keep 8 images for inspection.
images, labels = next(iter(test_loader))
images, labels = images[:8].to(DEVICE), labels[:8]
input_data, _, _ = prepare_model_input(images, coords_32, fourier_encoder)
with torch.no_grad():
    tdsm = get_tdsm(model, fourier_encoder, input_data, coords_32, DEVICE)

sample_idx = 0
tdsm_one = tdsm[sample_idx].cpu().numpy()
# num_show token indices, evenly spaced from 0 with stride NUM_LATENTS // num_show.
num_show = 16
step = max(1, NUM_LATENTS // num_show)
indices = list(range(0, NUM_LATENTS, step)[:num_show])

fig, axs = plt.subplots(4, 4, figsize=(10, 10))
for slot, ax in enumerate(axs.flat):
    token = indices[slot]
    ax.imshow(tdsm_one[token], cmap="viridis")
    ax.set_title(f"Token {token}")
    ax.axis("off")
plt.suptitle("TDSM: Token-decoded spatial maps (one sample, 16 tokens)", fontsize=12)
plt.tight_layout()
plt.show()

In [None]:
# Decode a handful of individual tokens at full RGB resolution for the batch.
token_indices = [0, 32, 64, 128, 192, 255]
with torch.no_grad():
    residual = get_residual(model, input_data)
    n_batch = images.size(0)
    queries_32 = fourier_encoder(repeat(coords_32, "n d -> b n d", b=n_batch)).to(DEVICE)
    # comps: (len(token_indices), B, IMAGE_SIZE, IMAGE_SIZE, 3)
    comps = torch.stack(
        [
            decoder_forward(model, queries_32, residual[:, k : k + 1, :]).reshape(
                n_batch, IMAGE_SIZE, IMAGE_SIZE, 3
            )
            for k in token_indices
        ],
        dim=0,
    )

def to_display(t):
    """Move ``t`` to the CPU and map it from the normalised [-1, 1] range to [0, 1] for imshow."""
    rescaled = t.cpu().mul(0.5).add(0.5)
    return rescaled.clamp(min=0, max=1)

fig, axs = plt.subplots(2, 4, figsize=(12, 6))
panels = axs.ravel()
# First panel: the input image (CHW -> HWC for imshow).
input_ax = panels[0]
input_ax.imshow(to_display(images[sample_idx]).permute(1, 2, 0).numpy())
input_ax.set_title("Input")
input_ax.axis("off")
# Following panels, in row-major order: one component image per token.
for ax, k, comp in zip(panels[1:], token_indices, comps):
    ax.imshow(to_display(comp[sample_idx]).numpy())
    ax.set_title(f"Token {k}")
    ax.axis("off")
# Blank out any panels left over.
for ax in panels[1 + len(token_indices):]:
    ax.axis("off")
plt.suptitle("Input and token component images (same sample)", fontsize=11)
plt.tight_layout()
plt.show()

In [None]:
# Per-token spatial statistics of one sample's TDSM (reduce over H and W).
token_means = tdsm_one.mean(axis=(1, 2))
token_stds = tdsm_one.std(axis=(1, 2))
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 4))
token_axis = range(NUM_LATENTS)
panel_specs = [
    (ax1, token_means, "steelblue", "Mean activation", "TDSM: Mean value per token (spatial mean)"),
    (ax2, token_stds, "coral", "Std", "TDSM: Spatial std per token"),
]
for ax, values, color, ylabel, title in panel_specs:
    ax.bar(token_axis, values, color=color, alpha=0.8)
    ax.set_xlabel("Token index")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
plt.tight_layout()
plt.show()

In [None]:
try:
    from sklearn.decomposition import PCA
    from sklearn.manifold import TSNE
except ImportError:
    print("Install sklearn for t-SNE: pip install scikit-learn")
else:
    # Only the sklearn import is guarded. An ImportError raised during feature
    # extraction or plotting must surface rather than be reported as "sklearn missing".
    N_VAL = 500
    all_tdsm = []
    all_labels = []
    n_collected = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(DEVICE)
            input_data, _, _ = prepare_model_input(images, coords_32, fourier_encoder)
            tdsm = get_tdsm(model, fourier_encoder, input_data, coords_32, DEVICE)
            # Pool each token's map over space: one scalar per token per image.
            feat = tdsm.mean(dim=(2, 3)).cpu().numpy()
            all_tdsm.append(feat)
            all_labels.append(labels.numpy())
            n_collected += feat.shape[0]
            if n_collected >= N_VAL:
                break
    X = np.concatenate(all_tdsm, axis=0)[:N_VAL]
    y = np.concatenate(all_labels, axis=0)[:N_VAL]
    # PCA cannot keep more than min(n_samples, n_features) components, and
    # t-SNE requires perplexity < n_samples.
    pca = PCA(n_components=min(50, *X.shape))
    Xp = pca.fit_transform(X)
    X_tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, X.shape[0] - 1)).fit_transform(Xp)
    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y, cmap="tab10", s=10, alpha=0.7)
    plt.colorbar(scatter, label="Class")
    plt.title("t-SNE of TDSM pooled features (global mean over space)")
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")
    plt.tight_layout()
    plt.show()