# JEPA (Joint-Embedding Predictive Architecture) — Neural Field on CIFAR-10

Learn representations by **predicting target embeddings from context embeddings** under masking, with **no contrastive negatives**. A continuity-aware neural field yields domain-continuous tokens φ(x); JEPA predicts φ at the masked coordinates from the visible ones.

**Distinguishable components:**

1. **Dual output** — f(x) → [RGB, φ(x)] (same as semantic-token notebook)
2. **Context / target split** — mask coords into visible (context) and predicted (target)
3. **EMA target encoder** — φ_target = stop_grad(φ_θ̄(X_t)); θ̄ updated by momentum
4. **Cross-attention predictor** — ẑ_t = P(z_c, X_c, X_t); predicts target embeddings from context
5. **JEPA loss** — cosine or L2 on norm(ẑ_t) vs norm(z_t); no negatives (this notebook implements the cosine variant, `jepa_cosine_loss`)
6. **Optional VICReg** — variance + covariance on φ_online to avoid collapse
7. **Optional RGB auxiliary** — small λ reconstruction so semantics aren't dominated by pixels

Checkpoints: **checkpoint_jepa_best.pt** / **checkpoint_jepa_last.pt** (written to the `checkpoints/` directory)

In [None]:
import os
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from einops import rearrange, repeat

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Directory that receives the checkpoint files; created up front
# (exist_ok=True makes re-running this cell harmless).
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print("Device:", DEVICE)

## Config

In [None]:
# Image geometry used for the coordinate grid (CIFAR-10 images are 32x32 RGB).
IMAGE_SIZE = 32
CHANNELS = 3
# Positional embedding is twice the Fourier mapping size
# (presumably sin + cos per frequency — confirm in GaussianFourierFeatures).
FOURIER_MAPPING_SIZE = 96
POS_EMBED_DIM = FOURIER_MAPPING_SIZE * 2
# Encoder input = pixel channels concatenated with the positional embedding.
INPUT_DIM = CHANNELS + POS_EMBED_DIM
# Decoder queries are positional embeddings only.
QUERIES_DIM = POS_EMBED_DIM
# The logits head predicts one value per colour channel.
LOGITS_DIM = CHANNELS
BATCH_SIZE = 64
EPOCHS = 30
# Total number of grid coordinates per image.
N_FULL = IMAGE_SIZE * IMAGE_SIZE

# JEPA: context vs target coord split
# Number of visible (context) and predicted (target) coordinates per step.
N_CONTEXT = 512
N_TARGET = 256
# Forwarded as `block_mask` to sample_context_target_indices.
USE_BLOCK_MASK = False

# EMA target
# Momentum passed to copy_ema after every optimizer step.
EMA_MOMENTUM = 0.996

# Losses
USE_JEPA_LOSS = True
USE_RGB_AUX = True
# Weight applied to the RGB reconstruction MSE.
LAMBDA_RGB = 0.1
USE_VICREG = True
# Weight applied to the VICReg regulariser.
LAMBDA_VICREG = 0.1
# NOTE(review): VICREG_SIM_WEIGHT is passed to vicreg_loss as `sim_weight`,
# but that function does not use the argument.
VICREG_SIM_WEIGHT = 25.0
VICREG_VAR_WEIGHT = 25.0
VICREG_COV_WEIGHT = 1.0

# Output width of the semantic head (dimension of the φ tokens).
PHI_DIM = 128
# Internal width and head count of the cross-attention predictor.
PREDICTOR_DIM = 256
PREDICTOR_HEADS = 4

# Coordinate grid for the full image, created once and reused everywhere.
# Indexed along dim 0 by the sampling code, so presumably (N_FULL, 2) — confirm
# against create_coordinate_grid.
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
print("N_FULL:", N_FULL, "| N_CONTEXT:", N_CONTEXT, "| N_TARGET:", N_TARGET)

## 1) Field output: RGB + φ (shared with semantic-token design)

In [None]:
def sample_gt_at_coords(images, coords):
    """Bilinearly sample ``images`` at the given normalized coordinates.

    ``images`` is a (B, C, H, W) tensor and ``coords`` is (B, N, 2). The two
    coordinate components are swapped before sampling because
    ``F.grid_sample`` expects (x, y) order, so ``coords`` is presumably stored
    as (y, x) — confirm against ``create_coordinate_grid``.

    Returns a (B, N, C) tensor of sampled values.
    """
    batch, _, _, _ = images.shape
    n_points = coords.shape[1]
    # Swap the last-axis components, then add a singleton "height" axis so the
    # sampling grid has shape (B, 1, N, 2).
    swapped = torch.stack((coords[..., 1], coords[..., 0]), dim=-1)
    grid = swapped.view(batch, 1, n_points, 2)
    values = F.grid_sample(
        images,
        grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
    # (B, C, 1, N) -> (B, C, N) -> (B, N, C)
    return values.squeeze(2).permute(0, 2, 1)

def get_residual(model, context):
    """Run the encoder stack of ``model`` over ``context`` and return the latent state.

    Every entry of ``model.encoder_blocks`` is called with the running state
    passed as both ``x`` and ``residual`` (``None`` for the first block). Then,
    for each entry of ``model.self_attn_blocks``, its first and second
    sub-layers are applied in order, each with a skip connection.
    """
    state = None
    for encoder_block in model.encoder_blocks:
        state = encoder_block(x=state, context=context, mask=None, residual=state)
    for pair in model.self_attn_blocks:
        # Presumably (attention, feed-forward) — confirm in CascadedPerceiverIO.
        for sublayer in (pair[0], pair[1]):
            state = sublayer(state) + state
    return state

def decoder_forward(model, queries, residual):
    """Decode ``queries`` against the latent ``residual``.

    Applies ``model.decoder_cross_attn`` with ``residual`` as context, adds the
    queries back as a skip connection and, when ``model.decoder_ff`` is not
    ``None``, adds the feed-forward output on top. Returns the decoded
    features (before any output head).
    """
    attended = model.decoder_cross_attn(queries, context=residual)
    hidden = attended + queries
    if model.decoder_ff is None:
        return hidden
    return hidden + model.decoder_ff(hidden)

def get_rgb(model, queries, residual):
    """Decode ``queries`` and map the result through ``model.to_logits`` (the RGB head)."""
    decoded = decoder_forward(model, queries, residual)
    return model.to_logits(decoded)

def get_phi_raw(model, queries, residual):
    """Return the decoder features for ``queries`` without applying any output head."""
    phi_raw = decoder_forward(model, queries, residual)
    return phi_raw

def get_semantic_tokens(phi_raw, semantic_head):
    """Project ``phi_raw`` with ``semantic_head`` and L2-normalize along the last axis."""
    projected = semantic_head(phi_raw)
    return F.normalize(projected, dim=-1)

## 2) Context/target split — random or block mask

In [None]:
def sample_context_target_indices(coords_full, n_context, n_target, device, block_mask=False):
    """Return disjoint index tensors ``(idx_c, idx_t)`` into ``coords_full``.

    Args:
        coords_full: tensor whose first dimension enumerates every coordinate.
        n_context: number of context indices to return (an upper bound in
            block mode, where fewer may be available).
        n_target: number of target indices in random mode. Ignored in block
            mode, where the block size ``(h // 4) x (w // 4)`` fixes the count.
        device: device on which the index tensors are created.
        block_mask: when True, the target is a randomly placed contiguous
            block and the context is drawn from everything outside it; when
            False, both sets are consecutive slices of one random permutation.

    Returns:
        ``(idx_c, idx_t)``, both ``torch.long``. In random mode the slices are
        simply shorter if ``n_context + n_target`` exceeds the coordinate count.
    """
    n_total = coords_full.size(0)
    if block_mask:
        # The flat index is interpreted as row-major on a square grid.
        h = w = int(math.sqrt(n_total))
        nh, nw = max(1, h // 4), max(1, w // 4)
        top = random.randint(0, h - nh)
        left = random.randint(0, w - nw)
        rows = torch.arange(top, top + nh, device=device)
        cols = torch.arange(left, left + nw, device=device)
        # Row-major flat indices of the block, in the order (row, col).
        idx_t = (rows.unsqueeze(1) * w + cols.unsqueeze(0)).reshape(-1)
        # A boolean keep-mask replaces the per-element membership test and
        # guarantees a long index tensor even when no context is left.
        keep = torch.ones(n_total, dtype=torch.bool, device=device)
        keep[idx_t] = False
        idx_c = torch.nonzero(keep, as_tuple=False).squeeze(1)
        if idx_c.size(0) > n_context:
            idx_c = idx_c[torch.randperm(idx_c.size(0), device=device)[:n_context]]
    else:
        perm = torch.randperm(n_total, device=device)
        idx_c = perm[:n_context]
        idx_t = perm[n_context : n_context + n_target]
    return idx_c, idx_t

def prepare_context_input(images, coords_full, fourier_encoder, idx_c, device):
    """Build the encoder input from the pixels at the context coordinates only.

    Selects ``coords_full[idx_c]``, broadcasts it over the batch, samples the
    image values there and concatenates them with the Fourier positional
    encoding of the same coordinates along the last axis, giving a
    (B, len(idx_c), INPUT_DIM) tensor. ``device`` is accepted for interface
    compatibility but not used.
    """
    batch_size = images.size(0)
    context_coords = coords_full[idx_c].unsqueeze(0).expand(batch_size, -1, -1)
    rgb_values = sample_gt_at_coords(images, context_coords)
    positional = fourier_encoder(context_coords)
    return torch.cat([rgb_values, positional], dim=-1)

## 3) Cross-attention predictor — P(z_c, X_c, X_t) → ẑ_t

Query from target coords; key/value from context tokens. Field-native.

In [None]:
class JEPAPredictor(nn.Module):
    """Predict target embeddings from context: ẑ_t = Attn(q=embed(X_t), k=z_c, v=z_c).

    A single multi-head cross-attention layer: queries come from the target
    coordinate embeddings, keys and values from the context tokens.

    Args:
        coord_embed_dim: width of the target coordinate embeddings.
        phi_dim: width of the context tokens and of the predicted tokens.
        pred_dim: internal attention width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``pred_dim``.
    """

    def __init__(self, coord_embed_dim, phi_dim, pred_dim, num_heads=4):
        super().__init__()
        # Fail fast: otherwise d_head is silently truncated and the mismatch
        # only surfaces later as an opaque `view` shape error in forward().
        if num_heads <= 0 or pred_dim % num_heads != 0:
            raise ValueError(
                f"pred_dim ({pred_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.to_q = nn.Linear(coord_embed_dim, pred_dim)
        # Keys and values are produced by one projection and split in forward().
        self.to_kv = nn.Linear(phi_dim, pred_dim * 2)
        self.to_out = nn.Linear(pred_dim, phi_dim)
        self.num_heads = num_heads
        self.d_head = pred_dim // num_heads

    def forward(self, coords_t_embed, z_c):
        """Return predicted target tokens of shape (B, N_t, phi_dim).

        Args:
            coords_t_embed: (B, N_t, coord_embed_dim) target coordinate embeddings.
            z_c: (B, N_c, phi_dim) context tokens.
        """
        B, N_t, _ = coords_t_embed.shape
        _, N_c, _ = z_c.shape
        # (B, heads, N_t, d_head)
        q = self.to_q(coords_t_embed).view(B, N_t, self.num_heads, self.d_head).transpose(1, 2)
        # First half of the projection is the keys, second half the values.
        kv = self.to_kv(z_c).view(B, N_c, 2, self.num_heads, self.d_head)
        k, v = kv[:, :, 0].transpose(1, 2), kv[:, :, 1].transpose(1, 2)
        scale = self.d_head ** -0.5
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        # Normalise over the context tokens.
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        # Merge the heads back: (B, heads, N_t, d_head) -> (B, N_t, pred_dim)
        out = out.transpose(1, 2).reshape(B, N_t, -1)
        return self.to_out(out)

## 4) JEPA loss (no negatives) + optional VICReg

In [None]:
def jepa_cosine_loss(z_pred, z_target):
    """Mean of ``1 - cos(z_pred, z_target)`` over all tokens.

    Only ``z_pred`` is normalized here; ``z_target`` is expected to be
    unit-norm along the last axis already.
    """
    unit_pred = F.normalize(z_pred, dim=-1)
    cosine = torch.sum(unit_pred * z_target, dim=-1)
    return torch.mean(1 - cosine)

def vicreg_loss(z, sim_weight=25.0, var_weight=25.0, cov_weight=1.0):
    """Variance + covariance regularization to discourage collapse (no negatives).

    ``z`` is (B, N, D); all B*N tokens are pooled into one sample set. The
    variance term is a hinge pushing each dimension's standard deviation
    towards at least 1; the covariance term is the sum of squared
    off-diagonal covariance entries divided by D.

    ``sim_weight`` is accepted for interface compatibility but is not used:
    no invariance term is computed here.
    """
    batch, tokens, dim = z.shape
    flat = z.reshape(batch * tokens, dim)
    n_samples = flat.size(0)

    # Hinge on the per-dimension standard deviation (epsilon added to the std).
    per_dim_std = flat.std(dim=0) + 1e-4
    variance_term = F.relu(1 - per_dim_std).mean()

    # Sample covariance; the off-diagonal energy is the total minus the diagonal.
    centered = flat - flat.mean(dim=0)
    cov = (centered.T @ centered) / (n_samples - 1)
    squared_total = cov.pow(2).sum()
    squared_diag = cov.diag().pow(2).sum()
    covariance_term = (squared_total - squared_diag) / dim

    return var_weight * variance_term + cov_weight * covariance_term

## 5) Online encoder + EMA target + predictor

In [None]:
def copy_ema(source, target, momentum=0.996):
    """Update ``target``'s parameters in place as an EMA of ``source``'s.

    For each parameter pair (matched by iteration order):
    ``p_t <- momentum * p_t + (1 - momentum) * p_s``.

    Only ``parameters()`` are updated; buffers are left untouched.

    Args:
        source: module providing the current (online) parameters.
        target: module whose parameters are overwritten with the moving average.
        momentum: weight kept from the previous target value.
    """
    # torch.no_grad() is the supported way to mutate parameters outside of
    # autograd; writing through `.data` bypasses autograd's bookkeeping.
    with torch.no_grad():
        for p_s, p_t in zip(source.parameters(), target.parameters()):
            p_t.mul_(momentum).add_(p_s, alpha=1 - momentum)

# --- Online branch (trained by the optimizer) ---
# Fourier positional encoder for 2-D coordinates. The third positional argument
# (15.0) is presumably the Gaussian frequency scale — confirm in nf_feature_models.
fourier_encoder = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
model = CascadedPerceiverIO(
    input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
).to(DEVICE)
# Linear head mapping decoder features (QUERIES_DIM) to φ tokens (PHI_DIM).
semantic_head = nn.Linear(QUERIES_DIM, PHI_DIM).to(DEVICE)

# --- EMA target branch ---
# Same architecture as the online branch; initialised as an exact copy below
# and afterwards changed only through copy_ema.
model_ema = CascadedPerceiverIO(
    input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
).to(DEVICE)
fourier_ema = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
semantic_head_ema = nn.Linear(QUERIES_DIM, PHI_DIM).to(DEVICE)
# Start the target branch from the online weights.
model_ema.load_state_dict(model.state_dict())
fourier_ema.load_state_dict(fourier_encoder.state_dict())
semantic_head_ema.load_state_dict(semantic_head.state_dict())

# Cross-attention predictor: target coordinate embeddings attend to context φ tokens.
predictor = JEPAPredictor(POS_EMBED_DIM, PHI_DIM, PREDICTOR_DIM, num_heads=PREDICTOR_HEADS).to(DEVICE)
print("Online + EMA target + predictor")

In [None]:
# Only ToTensor is applied: no augmentation or normalisation.
transform = transforms.Compose([transforms.ToTensor()])
# CIFAR-10 train/test splits, downloaded to ./data when missing.
train_ds = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_ds = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
# Training batches are shuffled and loaded in the main process (num_workers=0);
# the test loader keeps dataset order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

## Training loop — JEPA + RGB aux + VICReg + EMA update

In [None]:
# Only the online branch and the predictor are optimised; the EMA modules are
# deliberately left out and are updated through copy_ema instead.
params = list(model.parameters()) + list(fourier_encoder.parameters()) + list(semantic_head.parameters()) + list(predictor.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
# Lowest validation loss seen so far; decides when the "best" checkpoint is rewritten.
best_val_loss = float("inf")
CKPT_BEST = os.path.join(CHECKPOINT_DIR, "checkpoint_jepa_best.pt")
CKPT_LAST = os.path.join(CHECKPOINT_DIR, "checkpoint_jepa_last.pt")

In [None]:
# Training loop: JEPA loss + optional VICReg + optional RGB auxiliary, followed
# by an EMA update of the target branch after every optimizer step. Each epoch
# ends with an RGB-reconstruction validation pass and checkpointing.
for epoch in range(EPOCHS):
    # Online modules train; the EMA target branch always stays in eval mode.
    model.train()
    fourier_encoder.train()
    semantic_head.train()
    predictor.train()
    model_ema.eval()
    fourier_ema.eval()
    semantic_head_ema.eval()
    total_loss = 0.0
    total_jepa = 0.0
    total_rgb = 0.0
    total_vic = 0.0
    for imgs, _ in train_loader:
        imgs = imgs.to(DEVICE)
        B = imgs.size(0)

        # One context/target split per step, shared by every image in the batch.
        idx_c, idx_t = sample_context_target_indices(coords_32, N_CONTEXT, N_TARGET, DEVICE, block_mask=USE_BLOCK_MASK)
        coords_c = coords_32[idx_c]
        coords_t = coords_32[idx_t]

        # The encoder only sees pixels at the context coordinates.
        context_input = prepare_context_input(imgs, coords_32, fourier_encoder, idx_c, DEVICE)
        residual = get_residual(model, context_input)

        queries_c = fourier_encoder(coords_c.unsqueeze(0).expand(B, -1, -1))
        queries_t = fourier_encoder(coords_t.unsqueeze(0).expand(B, -1, -1))

        # Online φ tokens at the context coordinates.
        z_c = get_semantic_tokens(get_phi_raw(model, queries_c, residual), semantic_head)

        # Target φ tokens from the EMA branch; no gradient flows into it.
        # NOTE(review): the EMA branch is fed the same context-only input and the
        # online fourier_encoder's queries; fourier_ema is not used here.
        with torch.no_grad():
            residual_ema = get_residual(model_ema, context_input)
            z_t_target = get_semantic_tokens(get_phi_raw(model_ema, queries_t, residual_ema), semantic_head_ema)

        # Predict the target tokens from the context tokens and target coordinates.
        z_pred = predictor(queries_t, z_c)
        loss_jepa = jepa_cosine_loss(z_pred, z_t_target) if USE_JEPA_LOSS else torch.tensor(0.0, device=DEVICE)
        loss = loss_jepa

        if USE_VICREG:
            # Anti-collapse regulariser on the online context tokens.
            loss_vic = vicreg_loss(z_c, VICREG_SIM_WEIGHT, VICREG_VAR_WEIGHT, VICREG_COV_WEIGHT)
            loss = loss + LAMBDA_VICREG * loss_vic
            total_vic += loss_vic.item()

        if USE_RGB_AUX:
            # Reconstruct the full image from the context-only latent.
            queries_full = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
            rgb = get_rgb(model, queries_full, residual)
            target_pixels = rearrange(imgs, "b c h w -> b (h w) c")
            loss_rgb = F.mse_loss(rgb, target_pixels) * LAMBDA_RGB
            loss = loss + loss_rgb
            total_rgb += loss_rgb.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_jepa += loss_jepa.item()

        # Move the target branch towards the freshly updated online weights.
        copy_ema(model, model_ema, EMA_MOMENTUM)
        copy_ema(fourier_encoder, fourier_ema, EMA_MOMENTUM)
        copy_ema(semantic_head, semantic_head_ema, EMA_MOMENTUM)

    avg = total_loss / len(train_loader)

    # Validation: RGB reconstruction MSE with the full image as encoder input.
    model.eval()
    fourier_encoder.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, _ in test_loader:
            imgs = imgs.to(DEVICE)
            B = imgs.size(0)
            full_input, _, _ = prepare_model_input(imgs, coords_32, fourier_encoder)
            residual = get_residual(model, full_input)
            queries_full = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
            rgb = get_rgb(model, queries_full, residual)
            target_pixels = rearrange(imgs, "b c h w -> b (h w) c")
            val_loss += F.mse_loss(rgb, target_pixels).item()
    val_loss /= len(test_loader)

    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss

    # Build the payload once so the "best" and "last" files can never drift apart.
    checkpoint = {
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "fourier_encoder_state_dict": fourier_encoder.state_dict(),
        "semantic_head_state_dict": semantic_head.state_dict(),
        "predictor_state_dict": predictor.state_dict(),
        "model_ema_state_dict": model_ema.state_dict(),
        "fourier_ema_state_dict": fourier_ema.state_dict(),
        "semantic_head_ema_state_dict": semantic_head_ema.state_dict(),
        "best_val_loss": best_val_loss,
    }
    if improved:
        torch.save(checkpoint, CKPT_BEST)
    torch.save(checkpoint, CKPT_LAST)

    vic_str = f" vic: {total_vic/len(train_loader):.4f}" if USE_VICREG else ""
    rgb_str = f" rgb: {total_rgb/len(train_loader):.4f}" if USE_RGB_AUX else ""
    print(f"Epoch {epoch+1}/{EPOCHS} jepa: {total_jepa/len(train_loader):.4f}{vic_str}{rgb_str} val_loss: {val_loss:.4f}")

## Resolution-agnostic query (same as semantic-token notebook)

In [None]:
def query_semantic_tokens_at_resolution(model, semantic_head, fourier_encoder, residual, res_h, res_w, device):
    """Query normalized φ tokens on a ``res_h`` x ``res_w`` coordinate grid.

    Builds a fresh grid with ``create_coordinate_grid``, broadcasts it over
    the batch size taken from ``residual``, encodes it with
    ``fourier_encoder`` and decodes it against the already-computed latent
    ``residual``. Because only the query grid changes, the same latent can be
    queried at any resolution.
    """
    grid = create_coordinate_grid(res_h, res_w, device)
    batch_size = residual.size(0)
    batched_grid = grid.unsqueeze(0).expand(batch_size, -1, -1)
    grid_queries = fourier_encoder(batched_grid)
    raw_tokens = get_phi_raw(model, grid_queries, residual)
    return get_semantic_tokens(raw_tokens, semantic_head)

# Demo: encode four test images once, then query φ tokens at two resolutions.
model.eval()
fourier_encoder.eval()
semantic_head.eval()
imgs, _ = next(iter(test_loader))
imgs = imgs[:4].to(DEVICE)
# Inference only: input preparation also runs the Fourier encoder, so it is
# kept inside no_grad along with the rest of the forward pass.
with torch.no_grad():
    full_input, _, _ = prepare_model_input(imgs, coords_32, fourier_encoder)
    residual = get_residual(model, full_input)
    phi_16 = query_semantic_tokens_at_resolution(model, semantic_head, fourier_encoder, residual, 16, 16, DEVICE)
    phi_32 = query_semantic_tokens_at_resolution(model, semantic_head, fourier_encoder, residual, 32, 32, DEVICE)
print("φ at 16×16:", phi_16.shape, "| 32×32:", phi_32.shape)