# Kernel Mean Embedding (KME) Classification

Treat the **decoded field** as a distribution over (position, value). Embed each image's field into an RKHS via a product kernel (spatial × value), then use **kernel SVM** on the Gram matrix.

- **Field:** Query decoder at grid points → (x_i, y_i, z_i) with z_i = decoded RGB.
- **Kernel:** k((x,y,z), (x',y',z')) = k_space((x,y),(x',y')) × k_val(z,z'); both RBF.
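- **K(I, I')** = ⟨μ[I], μ[I']⟩ = (1/(n·m)) ∑_{i,j} k(p_i, p'_j), where n and m are the numbers of points sampled from I and I' (here every image uses the same n points, so the factor is 1/n²).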
- **Classification:** sklearn SVM with precomputed kernel (Gram matrix).

In [ ]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import transforms
from einops import repeat

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

# --- Runtime device and checkpoint discovery ---------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_DIR = "checkpoints"
# Prefer the "best" checkpoint and fall back to the "last" one.
CKPT_PATH = os.path.join(CHECKPOINT_DIR, "checkpoint_best.pt")
if not os.path.isfile(CKPT_PATH):
    CKPT_PATH = os.path.join(CHECKPOINT_DIR, "checkpoint_last.pt")
# Raise explicitly instead of asserting: asserts are stripped under `python -O`.
if not os.path.isfile(CKPT_PATH):
    raise FileNotFoundError(
        f"No checkpoint in {CHECKPOINT_DIR}. Train AblationCIFAR10 first."
    )

# --- Model hyper-parameters (must match the training run that produced the
# checkpoint; a mismatch is reported below when the weights are loaded) --------
IMAGE_SIZE = 32
CHANNELS = 3
FOURIER_MAPPING_SIZE = 96
POS_EMBED_DIM = FOURIER_MAPPING_SIZE * 2
INPUT_DIM = CHANNELS + POS_EMBED_DIM  # pixel channels + positional embedding
QUERIES_DIM = POS_EMBED_DIM           # decoder queries are positional embeddings only
LOGITS_DIM = CHANNELS                 # decoder outputs one value per channel

fourier_encoder = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, scale=15.0).to(DEVICE)
model = CascadedPerceiverIO(
    input_dim=INPUT_DIM,
    queries_dim=QUERIES_DIM,
    logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512),
    num_latents=(256, 256, 256),
    decoder_ff=True,
).to(DEVICE)

# NOTE: torch.load unpickles the file; only load checkpoints you trust.
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
# strict=False tolerates key mismatches, which would otherwise leave parts of
# the network at their random initialisation without any signal. Keep the
# tolerant load, but surface what was skipped.
_model_load = model.load_state_dict(ckpt["model_state_dict"], strict=False)
_fourier_load = fourier_encoder.load_state_dict(
    ckpt["fourier_encoder_state_dict"], strict=False
)
for _name, _result in (("model", _model_load), ("fourier_encoder", _fourier_load)):
    if _result.missing_keys or _result.unexpected_keys:
        print(
            f"WARNING: {_name} checkpoint mismatch: "
            f"missing={list(_result.missing_keys)}, "
            f"unexpected={list(_result.unexpected_keys)}"
        )
model.eval()
fourier_encoder.eval()

# Coordinate grid for an IMAGE_SIZE x IMAGE_SIZE image, reused by every cell below.
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
print(f"Loaded {CKPT_PATH}")
print(f"Device: {DEVICE}")

In [ ]:
def get_residual(model, data):
    """
    Run the encoder side of `model` on `data` and return the resulting latents.

    Each block in `model.encoder_blocks` is called with the running latents as
    both `x` and `residual` (None for the first block) and `data` as context.
    Each entry of `model.self_attn_blocks` then contributes two residual
    updates, using its element 0 and element 1 in that order.
    """
    latents = None
    for encoder_block in model.encoder_blocks:
        latents = encoder_block(x=latents, context=data, mask=None, residual=latents)

    for pair in model.self_attn_blocks:
        first_layer = pair[0]
        second_layer = pair[1]
        latents = first_layer(latents) + latents
        latents = second_layer(latents) + latents

    return latents


def decoder_forward(model, queries, context):
    """
    Decode `queries` against `context` using the decoder parts of `model`.

    Applies `model.decoder_cross_attn` with a skip connection from the
    queries, then (only if `model.decoder_ff` is not None) a residual
    feed-forward step, and finally maps the result through `model.to_logits`.
    """
    hidden = model.decoder_cross_attn(queries, context=context) + queries

    feed_forward = model.decoder_ff
    if feed_forward is not None:
        hidden = hidden + feed_forward(hidden)

    return model.to_logits(hidden)


def get_field_points(model, fourier_encoder, data, coords, device, point_indices=None):
    """
    Decode the field at grid locations and pack positions and values together.

    The latents are computed once from `data`; the decoder is then queried at
    `coords` (or at `coords[point_indices]` when `point_indices` is given).

    Args:
        data: encoder input, (B, N, input_dim).
        coords: grid coordinates, (N, 2).
        point_indices: optional indices into `coords` selecting the query points.

    Returns:
        float32 array of shape (B, n_pts, 5) whose last axis is
        (x, y, z_r, z_g, z_b); n_pts is len(point_indices) when a subset is
        requested, otherwise N.
    """
    use_all_points = point_indices is None
    with torch.no_grad():
        latents = get_residual(model, data)
        batch_size = data.size(0)
        # Either the full grid or the requested subset of it.
        query_coords = coords if use_all_points else coords[point_indices]
        n_pts = coords.size(0) if use_all_points else len(point_indices)
        # Same coordinates for every image in the batch.
        batched_coords = repeat(query_coords, "n d -> b n d", b=batch_size)
        queries = fourier_encoder(batched_coords).to(device)
        values = decoder_forward(model, queries, latents).cpu().numpy()
        xy = query_coords.cpu().numpy()

    points = np.zeros((batch_size, n_pts, 5), dtype=np.float32)
    points[:, :, :2] = xy[None, :, :]  # broadcast the shared positions over the batch
    points[:, :, 2:5] = values
    return points

In [ ]:
def rbf_kernel_2d(X, Y, sigma):
    """
    Pairwise RBF kernel between two sets of points.

    X is (n, 2) and Y is (m, 2); the result is the (n, m) matrix with entries
    exp(-||x - y||^2 / (2 * sigma^2)), computed in float64.
    """
    a = np.asarray(X, dtype=np.float64)
    b = np.asarray(Y, dtype=np.float64)
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>
    a_sq = np.sum(a**2, axis=1, keepdims=True)
    b_sq = np.sum(b**2, axis=1)
    cross = a @ b.T
    sq_dists = a_sq + b_sq - 2 * cross
    # Round-off can push tiny distances slightly below zero; clamp them.
    clamped = np.maximum(sq_dists, 0)
    return np.exp(-clamped / (2 * sigma**2))


def rbf_kernel_val(Z, W, sigma):
    """
    Pairwise RBF kernel in value space.

    Z is (n, d) and W is (m, d); the result is the (n, m) matrix with entries
    exp(-||z - w||^2 / (2 * sigma^2)), computed in float64.
    """
    z_arr = np.asarray(Z, dtype=np.float64)
    w_arr = np.asarray(W, dtype=np.float64)
    # Expand the squared distance as ||z||^2 + ||w||^2 - 2 <z, w>.
    z_norms = np.sum(z_arr**2, axis=1, keepdims=True)
    w_norms = np.sum(w_arr**2, axis=1)
    inner = z_arr @ w_arr.T
    sq_dists = z_norms + w_norms - 2 * inner
    # Guard against small negative values caused by floating-point round-off.
    non_negative = np.maximum(sq_dists, 0)
    return np.exp(-non_negative / (2 * sigma**2))


def product_kernel_gram_two(P, Q, sigma_space, sigma_val):
    """
    Kernel mean embedding inner product between two point sets.

    P is (n, 5) and Q is (m, 5), each row laid out as (x, y, z1, z2, z3).
    Returns K(I, I') = (1 / (n * m)) * sum_{i,j} k_space(p_i, q_j) * k_val(p_i, q_j),
    where k_space acts on columns 0-1 and k_val on columns 2-4.
    """
    # Split each point set into its position and value parts.
    pos_p, val_p = P[:, :2], P[:, 2:5]
    pos_q, val_q = Q[:, :2], Q[:, 2:5]

    spatial = rbf_kernel_2d(pos_p, pos_q, sigma_space)
    value = rbf_kernel_val(val_p, val_q, sigma_val)

    n_pairs = P.shape[0] * Q.shape[0]
    return np.sum(spatial * value) / n_pairs


def gram_matrix(points_list, sigma_space, sigma_val):
    """
    Build the symmetric (N, N) Gram matrix for a list of point sets.

    points_list holds N arrays of shape (n_i, 5). Only the upper triangle
    (including the diagonal) is evaluated; each value is mirrored below it.
    """
    count = len(points_list)
    Gram = np.zeros((count, count), dtype=np.float64)
    for row, row_points in enumerate(points_list):
        for col in range(row, count):
            value = product_kernel_gram_two(
                row_points, points_list[col], sigma_space, sigma_val
            )
            Gram[row, col] = value
            Gram[col, row] = value
    return Gram


def gram_test_train(test_points_list, train_points_list, sigma_space, sigma_val):
    """
    Cross-kernel matrix between test and train point sets.

    Returns an (n_test, n_train) float64 array whose entry (i, j) is
    K(test_i, train_j).
    """
    shape = (len(test_points_list), len(train_points_list))
    K = np.zeros(shape, dtype=np.float64)
    for row, test_points in enumerate(test_points_list):
        for col, train_points in enumerate(train_points_list):
            K[row, col] = product_kernel_gram_two(
                test_points, train_points, sigma_space, sigma_val
            )
    return K

In [ ]:
# CIFAR-10 preprocessing: ToTensor scales pixels to [0, 1]; Normalize with
# mean 0.5 / std 0.5 per channel then maps them to [-1, 1].
# NOTE(review): presumably this matches the preprocessing used when the
# checkpoint was trained — confirm against the training script.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Downloads the dataset into ./data on first use.
train_dataset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

# Subsample for tractability: the Gram matrix costs O(N^2 * n^2) kernel
# evaluations for N images with n points each. Increase if you have time.
N_TRAIN_KME = 1500
N_TEST_KME = 500
N_POINTS = 128  # points per image (subsample 32*32 grid)

# Fixed seed so the subsets are reproducible. The three draws below consume
# the global NumPy RNG in this order; reordering them changes the selection.
np.random.seed(42)
train_idx = np.random.choice(len(train_dataset), size=N_TRAIN_KME, replace=False)
test_idx = np.random.choice(len(test_dataset), size=N_TEST_KME, replace=False)
# One set of grid locations, drawn once here (indices into the flattened grid).
point_indices = np.random.choice(IMAGE_SIZE * IMAGE_SIZE, size=N_POINTS, replace=False)

train_sub = Subset(train_dataset, train_idx)
test_sub = Subset(test_dataset, test_idx)
# shuffle=False keeps loader order equal to train_idx / test_idx order.
train_loader = DataLoader(train_sub, batch_size=32, shuffle=False, num_workers=0)
test_loader = DataLoader(test_sub, batch_size=32, shuffle=False, num_workers=0)

print(f"N_train={N_TRAIN_KME}, N_test={N_TEST_KME}, n_points={N_POINTS}")

In [ ]:
def extract_points_from_loader(loader, model, fourier_encoder, coords_32, device, point_indices):
    """
    Decode field points for every image yielded by `loader`.

    Each batch is moved to `device`, turned into model input with
    `prepare_model_input`, and passed to `get_field_points`. Labels from the
    loader are ignored.

    Returns:
        A flat list with one per-image point array per sample, in loader order.
    """
    collected = []
    for batch_images, _ in loader:
        batch_images = batch_images.to(device)
        # Only the first of the three returned values is needed here.
        model_input, _, _ = prepare_model_input(batch_images, coords_32, fourier_encoder)
        batch_points = get_field_points(
            model, fourier_encoder, model_input, coords_32, device, point_indices
        )
        # Iterating the batch array yields one entry per image along axis 0.
        collected.extend(batch_points)
    return collected


# Decode field points for the training subset (one point array per image,
# in train_idx order because the loader is not shuffled).
print("Extracting field points for training...")
train_points_list = extract_points_from_loader(
    train_loader, model, fourier_encoder, coords_32, DEVICE, point_indices
)
# Read labels straight from the dataset's `targets` list. Indexing the dataset
# item by item (`train_dataset[i][1]`) would decode and transform every image
# only to throw the pixels away.
train_labels = np.asarray(train_dataset.targets)[train_idx]
print(f"Train: {len(train_points_list)} samples, each {train_points_list[0].shape}")

# Same procedure for the test subset.
print("Extracting field points for test...")
test_points_list = extract_points_from_loader(
    test_loader, model, fourier_encoder, coords_32, DEVICE, point_indices
)
test_labels = np.asarray(test_dataset.targets)[test_idx]
print(f"Test: {len(test_points_list)} samples")

In [ ]:
# RBF bandwidths passed to the spatial and value factors of the product kernel.
# NOTE(review): chosen by hand; suitable values depend on the coordinate range
# produced by create_coordinate_grid and on the decoded value range — confirm.
SIGMA_SPACE = 0.5
SIGMA_VAL = 1.0

# Symmetric train-vs-train Gram matrix, used below to fit the SVM.
print("Computing Gram matrix (train x train)...")
Gram_train = gram_matrix(train_points_list, SIGMA_SPACE, SIGMA_VAL)
print(f"Gram_train shape: {Gram_train.shape}")

# Test-vs-train kernel values: rows index test images, columns index train
# images, which is the layout a precomputed-kernel SVM expects at predict time.
print("Computing kernel matrix test x train...")
K_test_train = gram_test_train(test_points_list, train_points_list, SIGMA_SPACE, SIGMA_VAL)
print(f"K_test_train shape: {K_test_train.shape}")

In [ ]:
from sklearn.svm import SVC

# Fit an SVM directly on the precomputed train-vs-train Gram matrix.
print("Training kernel SVM (precomputed kernel)...")
clf = SVC(kernel="precomputed", C=1.0, class_weight="balanced").fit(Gram_train, train_labels)

# Prediction takes the test-vs-train kernel values; accuracy is the fraction
# of test samples whose predicted label matches the true one.
pred_test = clf.predict(K_test_train)
acc = (pred_test == test_labels).mean()
print(f"Test accuracy (KME + kernel SVM): {acc:.4f}")