# Semantic learning capacity: NCE vs JEPA

Compare **per-token** semantic structure of:
- **NCE** (SemanticTokenGenerator_CIFAR10: Soft InfoNCE + reconstruction)
- **JEPA** (JEPA_NeuralField_CIFAR10: masked prediction of φ + EMA target + VICReg + RGB aux)

**Per-token t-SNE**: each point = one φ token at one spatial location from one image; color = CIFAR-10 class of that image. Better semantic learning → clearer class clustering.

Optional: **k-NN accuracy** on mean-pooled φ (one vector per image) as a scalar metric.

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from einops import rearrange, repeat

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_DIR = "checkpoints"

# Image geometry (CIFAR-10).
IMAGE_SIZE = 32
CHANNELS = 3

# Model widths. The positional embedding is twice the Fourier mapping size;
# the encoder input is pixels concatenated with that embedding.
FOURIER_MAPPING_SIZE = 96
POS_EMBED_DIM = 2 * FOURIER_MAPPING_SIZE
INPUT_DIM = CHANNELS + POS_EMBED_DIM
QUERIES_DIM = POS_EMBED_DIM
LOGITS_DIM = CHANNELS
PHI_DIM = 128

# Full-resolution coordinate grid used to build encoder inputs.
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
print("Device:", DEVICE)

In [None]:
def get_residual(model, context):
    """Run the encoder cascade, then the self-attention stack, over ``context``.

    Each encoder block receives the running latent state (``None`` before the
    first block) both as ``x`` and as ``residual``. Each self-attention stage
    applies its first and then its second sub-module with a skip connection.
    Returns the final latent state (``None`` if the model has no blocks).
    """
    latents = None
    for encoder_block in model.encoder_blocks:
        latents = encoder_block(x=latents, context=context, mask=None, residual=latents)
    for stage in model.self_attn_blocks:
        first, second = stage[0], stage[1]
        latents = first(latents) + latents
        latents = second(latents) + latents
    return latents

def decoder_forward(model, queries, residual):
    """Decode ``queries`` against the latent ``residual``.

    Applies the decoder cross-attention with a skip connection from the
    queries. If ``model.decoder_ff`` is set, it also applies that module with
    its own skip connection.
    """
    attended = model.decoder_cross_attn(queries, context=residual) + queries
    feed_forward = model.decoder_ff
    if feed_forward is None:
        return attended
    return attended + feed_forward(attended)

def get_phi_raw(model, queries, residual):
    """Return the decoder output for ``queries`` (features before the semantic head)."""
    phi_raw = decoder_forward(model, queries, residual)
    return phi_raw

def get_semantic_tokens(phi_raw, semantic_head):
    """Project ``phi_raw`` through ``semantic_head`` and L2-normalize the last axis."""
    projected = semantic_head(phi_raw)
    return F.normalize(projected, p=2, dim=-1)

def phi_at_resolution(model, semantic_head, fourier_encoder, full_input, res_h, res_w):
    """Query the field on a ``res_h`` x ``res_w`` coordinate grid and return semantic tokens.

    ``full_input`` is encoded once. The Fourier-encoded grid coordinates,
    repeated across the batch, are decoded against that encoding and then
    projected and normalized.
    Returns (B, res_h*res_w, PHI_DIM) L2-normalized.
    """
    batch_size = full_input.size(0)
    latents = get_residual(model, full_input)
    grid = create_coordinate_grid(res_h, res_w, full_input.device)
    batched_grid = repeat(grid, "n d -> b n d", b=batch_size)
    phi_raw = get_phi_raw(model, fourier_encoder(batched_grid), latents)
    return get_semantic_tokens(phi_raw, semantic_head)

## Load NCE model (semantic token checkpoint)

In [None]:
# Prefer the best NCE checkpoint and fall back to the last one.
CKPT_NCE = os.path.join(CHECKPOINT_DIR, "checkpoint_semantic_token_best.pt")
if not os.path.isfile(CKPT_NCE):
    CKPT_NCE = os.path.join(CHECKPOINT_DIR, "checkpoint_semantic_token_last.pt")
nce_model = nce_fourier = nce_semantic_head = None
if os.path.isfile(CKPT_NCE):
    nce_fourier = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
    nce_model = CascadedPerceiverIO(
        input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
        latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
    ).to(DEVICE)
    nce_semantic_head = nn.Linear(QUERIES_DIM, PHI_DIM).to(DEVICE)
    ckpt = torch.load(CKPT_NCE, map_location=DEVICE)
    # strict=False tolerates extra or missing entries. Ignored silently, though,
    # a key mismatch would leave layers at their random init and make the
    # NCE-vs-JEPA comparison meaningless, so report any keys that did not match.
    for part_name, part_module, ckpt_key in (
        ("model", nce_model, "model_state_dict"),
        ("fourier_encoder", nce_fourier, "fourier_encoder_state_dict"),
        ("semantic_head", nce_semantic_head, "semantic_head_state_dict"),
    ):
        load_result = part_module.load_state_dict(ckpt[ckpt_key], strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print(
                f"Warning (NCE {part_name}): "
                f"{len(load_result.missing_keys)} missing / "
                f"{len(load_result.unexpected_keys)} unexpected keys; "
                f"missing[:5]={load_result.missing_keys[:5]} "
                f"unexpected[:5]={load_result.unexpected_keys[:5]}"
            )
    nce_model.eval()
    nce_fourier.eval()
    nce_semantic_head.eval()
    print("Loaded NCE:", CKPT_NCE)
else:
    print("NCE checkpoint not found; run SemanticTokenGenerator_CIFAR10.ipynb first.")

## Load JEPA model

In [None]:
# Prefer the best JEPA checkpoint and fall back to the last one.
CKPT_JEPA = os.path.join(CHECKPOINT_DIR, "checkpoint_jepa_best.pt")
if not os.path.isfile(CKPT_JEPA):
    CKPT_JEPA = os.path.join(CHECKPOINT_DIR, "checkpoint_jepa_last.pt")
jepa_model = jepa_fourier = jepa_semantic_head = None
if os.path.isfile(CKPT_JEPA):
    jepa_fourier = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
    jepa_model = CascadedPerceiverIO(
        input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
        latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
    ).to(DEVICE)
    jepa_semantic_head = nn.Linear(QUERIES_DIM, PHI_DIM).to(DEVICE)
    ckpt = torch.load(CKPT_JEPA, map_location=DEVICE)
    # strict=False tolerates extra or missing entries. Ignored silently, though,
    # a key mismatch would leave layers at their random init and make the
    # NCE-vs-JEPA comparison meaningless, so report any keys that did not match.
    for part_name, part_module, ckpt_key in (
        ("model", jepa_model, "model_state_dict"),
        ("fourier_encoder", jepa_fourier, "fourier_encoder_state_dict"),
        ("semantic_head", jepa_semantic_head, "semantic_head_state_dict"),
    ):
        load_result = part_module.load_state_dict(ckpt[ckpt_key], strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print(
                f"Warning (JEPA {part_name}): "
                f"{len(load_result.missing_keys)} missing / "
                f"{len(load_result.unexpected_keys)} unexpected keys; "
                f"missing[:5]={load_result.missing_keys[:5]} "
                f"unexpected[:5]={load_result.unexpected_keys[:5]}"
            )
    jepa_model.eval()
    jepa_fourier.eval()
    jepa_semantic_head.eval()
    print("Loaded JEPA:", CKPT_JEPA)
else:
    print("JEPA checkpoint not found; run JEPA_NeuralField_CIFAR10.ipynb first.")

In [None]:
# CIFAR-10 test split as [0, 1] tensors, iterated in a fixed order.
transform = transforms.Compose([transforms.ToTensor()])
test_ds = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

# Evaluate at most 400 images, each queried on a TOKEN_RES x TOKEN_RES grid.
TOKEN_RES = 16
N_EVAL = min(400, len(test_ds))
N_TOKENS_PER_IMAGE = TOKEN_RES ** 2
print("N_EVAL images:", N_EVAL, "| Tokens per image:", N_TOKENS_PER_IMAGE)

## Collect per-token φ (and labels) for NCE and JEPA

In [None]:
# Collect per-token semantic features for the first N_EVAL test images.
# Each image contributes N_TOKENS_PER_IMAGE rows, and its label is repeated that
# many times so the rows of X_nce / X_jepa stay aligned with y_all.
all_phi_nce, all_phi_jepa, all_labels = [], [], []
n_done = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        if n_done >= N_EVAL:
            break
        # Trim the final batch so exactly N_EVAL images are used. Otherwise a
        # full 64-image batch overshoots N_EVAL (e.g. 448 images instead of 400).
        n_take = min(imgs.size(0), N_EVAL - n_done)
        imgs = imgs[:n_take].to(DEVICE)
        labels = labels[:n_take]
        B = imgs.size(0)
        # np.repeat keeps each image's copies contiguous, matching the
        # (B, tokens, D) -> (B * tokens, D) reshape below.
        labels_rep = labels.numpy().repeat(N_TOKENS_PER_IMAGE)
        all_labels.append(labels_rep)
        if nce_model is not None:
            full_input, _, _ = prepare_model_input(imgs, coords_32, nce_fourier)
            phi_nce = phi_at_resolution(nce_model, nce_semantic_head, nce_fourier, full_input, TOKEN_RES, TOKEN_RES)
            all_phi_nce.append(phi_nce.reshape(B * N_TOKENS_PER_IMAGE, PHI_DIM).cpu().numpy())
        if jepa_model is not None:
            full_input_j, _, _ = prepare_model_input(imgs, coords_32, jepa_fourier)
            phi_jepa = phi_at_resolution(jepa_model, jepa_semantic_head, jepa_fourier, full_input_j, TOKEN_RES, TOKEN_RES)
            all_phi_jepa.append(phi_jepa.reshape(B * N_TOKENS_PER_IMAGE, PHI_DIM).cpu().numpy())
        n_done += B

y_all = np.concatenate(all_labels, axis=0)
X_nce = np.concatenate(all_phi_nce, axis=0) if all_phi_nce else None
X_jepa = np.concatenate(all_phi_jepa, axis=0) if all_phi_jepa else None
print("Per-token features: NCE", X_nce.shape if X_nce is not None else None, "| JEPA", X_jepa.shape if X_jepa is not None else None)
print("Labels (one per token):", y_all.shape)

## Per-token t-SNE: NCE vs JEPA (colored by image class)

In [None]:
# Per-token t-SNE for each loaded model, colored by the class of the source image.
# Barnes-Hut t-SNE on every token (~N_EVAL * 256 ≈ 10^5 points) is very slow, so
# a random subset of at most TSNE_MAX_POINTS tokens is embedded. The same subset
# is used for every model, so the panels are comparable. Set it to None to use all tokens.
TSNE_MAX_POINTS = 20000

feature_sets = [(name, X) for name, X in (("NCE", X_nce), ("JEPA", X_jepa)) if X is not None]
if not feature_sets:
    # Checked up front: the shape lookups below would otherwise fail first.
    print("t-SNE failed:", "No model loaded")
else:
    try:
        from sklearn.decomposition import PCA
        from sklearn.manifold import TSNE

        n_points, n_dims = feature_sets[0][1].shape
        if TSNE_MAX_POINTS is not None and n_points > TSNE_MAX_POINTS:
            rng = np.random.default_rng(42)
            tsne_idx = np.sort(rng.choice(n_points, size=TSNE_MAX_POINTS, replace=False))
            print(f"t-SNE on a random subset of {TSNE_MAX_POINTS} / {n_points} tokens")
        else:
            tsne_idx = np.arange(n_points)
        y_tsne = y_all[tsne_idx]
        n_sub = len(tsne_idx)
        # PCA needs n_components < n_samples; perplexity must stay well below n_samples.
        n_comp = min(50, n_dims, n_sub - 1)
        perplexity = min(30, n_sub // 4)

        n_plots = len(feature_sets)
        # constrained_layout handles a colorbar shared across axes (tight_layout does not).
        fig, axs = plt.subplots(
            1, n_plots, figsize=(7 * n_plots, 6), squeeze=False, constrained_layout=True,
        )
        axs = list(axs[0])
        sc = None
        for ax, (name, X) in zip(axs, feature_sets):
            X_pca = PCA(n_components=n_comp).fit_transform(X[tsne_idx])
            X_emb = TSNE(n_components=2, random_state=42, perplexity=perplexity).fit_transform(X_pca)
            # Pin the color range so class k gets the same tab10 color in every panel,
            # even if a class is missing from the subset.
            sc = ax.scatter(
                X_emb[:, 0], X_emb[:, 1], c=y_tsne, cmap="tab10", vmin=0, vmax=9, s=2, alpha=0.5,
            )
            ax.set_title(f"{name} (per-token φ, color = image class)")
            ax.set_xlabel("t-SNE 1")
            ax.set_ylabel("t-SNE 2")
        plt.colorbar(sc, ax=axs, label="CIFAR-10 class", shrink=0.6)
        plt.suptitle("Semantics: NCE vs JEPA — per-token t-SNE (each point = one φ token)")
        plt.savefig("eval_nce_vs_jepa_tsne.png", dpi=100)
        plt.show()
    except Exception as e:
        # Best-effort visualization: report the failure but keep the notebook running.
        print("t-SNE failed:", e)

## Optional: k-NN accuracy on mean-pooled φ (one vector per image)

In [None]:
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# One vector per image: average its N_TOKENS_PER_IMAGE tokens. Labels were
# repeated per token, so every N_TOKENS_PER_IMAGE-th entry is one image's label.
n_imgs = len(y_all) // N_TOKENS_PER_IMAGE
y_pooled = y_all[::N_TOKENS_PER_IMAGE]


def knn_cv_accuracy(X_pooled, y, n_neighbors=20, n_splits=5, seed=0):
    """Mean stratified k-fold accuracy of a distance-weighted k-NN classifier.

    Scoring on held-out folds is essential here. If the classifier is fit and
    scored on the same points with ``weights="distance"``, each point matches
    itself at distance 0, so accuracy is trivially ~1.0.

    ``n_splits`` is reduced to the size of the smallest class, and
    ``n_neighbors`` is capped so it stays below the training-fold size.
    Returns nan when there is too little data to cross-validate.
    """
    if len(y) == 0:
        return float("nan")
    _, class_counts = np.unique(y, return_counts=True)
    n_splits = int(min(n_splits, class_counts.min()))
    if n_splits < 2:
        return float("nan")
    k = max(1, min(n_neighbors, len(y) // n_splits))
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    knn = KNeighborsClassifier(n_neighbors=k, weights="distance")
    return float(cross_val_score(knn, X_pooled, y, cv=cv).mean())


if X_nce is not None:
    X_nce_pooled = X_nce.reshape(n_imgs, N_TOKENS_PER_IMAGE, -1).mean(axis=1)
    acc_nce = knn_cv_accuracy(X_nce_pooled, y_pooled)
    print("NCE k-NN accuracy (mean-pooled φ, stratified k-fold CV):", f"{acc_nce:.4f}")
if X_jepa is not None:
    X_jepa_pooled = X_jepa.reshape(n_imgs, N_TOKENS_PER_IMAGE, -1).mean(axis=1)
    acc_jepa = knn_cv_accuracy(X_jepa_pooled, y_pooled)
    print("JEPA k-NN accuracy (mean-pooled φ, stratified k-fold CV):", f"{acc_jepa:.4f}")