# JEPA (Joint-Embedding Predictive Architecture) — Neural Field on CIFAR-10

Learn representations by **predicting target embeddings from context embeddings** under masking, with **no contrastive negatives**. A continuity-aware neural field produces tokens φ(x) that can be queried at continuous coordinates; JEPA predicts φ at the masked coordinates from the visible ones.

**Distinguishable components:**

1. **Dual output** — f(x) → [RGB, φ(x)] (same as semantic-token notebook)
2. **Context / target split** — mask coords into visible (context) and predicted (target)
3. **EMA target encoder** — φ_target = stop_grad(φ_θ̄(X_t)); θ̄ updated by momentum
4. **Cross-attention predictor** — ẑ_t = P(z_c, X_c, X_t); predicts target embeddings from context
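5. **JEPA loss** — cosine distance between normalized ẑ_t and normalized z_t (only the cosine variant is implemented below; an L2 variant would be a drop-in replacement); no negatives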
6. **Optional VICReg** — variance + covariance on φ_online to avoid collapse
7. **Optional RGB auxiliary** — small λ reconstruction so semantics aren't dominated by pixels

Checkpoints: **checkpoint_jepa_best.pt** / **_last.pt**

**I-JEPA faithful**: With **TARGET_FROM_FULL_CONTEXT=True** (default), target = φ_EMA(**full image**) at target coords, matching `ijepa/src/train.py` (target_encoder sees full image; we use full grid). Context branch unchanged: online encoder sees context only; predictor predicts from context tokens. Loss in representation space; EMA update of target encoder.

**Speed**: Use `num_workers=2` and **RGB_QUERY_RES=16** to keep an epoch under roughly 5 minutes.

In [None]:
import os
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from einops import rearrange, repeat

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Directory that receives the best/last checkpoints written during training.
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print("Device:", DEVICE)

## Config

In [None]:
# --- Data / field dimensions ---
IMAGE_SIZE = 32
CHANNELS = 3
FOURIER_MAPPING_SIZE = 96
# Positional embedding width (presumably sin + cos per Fourier frequency; confirm in GaussianFourierFeatures).
POS_EMBED_DIM = FOURIER_MAPPING_SIZE * 2
# Encoder tokens are pixel values concatenated with the positional embedding.
INPUT_DIM = CHANNELS + POS_EMBED_DIM
# Decoder queries are positional embeddings only.
QUERIES_DIM = POS_EMBED_DIM
LOGITS_DIM = CHANNELS
BATCH_SIZE = 64
EPOCHS = 30
# Number of pixels (= coordinates) in the full-resolution grid.
N_FULL = IMAGE_SIZE * IMAGE_SIZE

# JEPA: context vs target coord split
# Mask style: "ijepa_block" = I-JEPA-style block-level (patch grid, pred blocks + enc blocks in complement); "random" = random pixel split
MASK_STYLE = "ijepa_block"
# N_CONTEXT / N_TARGET size the "random" split (N_CONTEXT also caps the legacy block split);
# sample_context_target_indices ignores both when MASK_STYLE == "ijepa_block".
N_CONTEXT = 512
N_TARGET = 256
# Passed as block_mask=...; only consulted when mask_style is not "ijepa_block",
# so with this definition the legacy block branch is never reached.
USE_BLOCK_MASK = MASK_STYLE == "ijepa_block"  # legacy flag for viz

# I-JEPA-style block masking (see ijepa/src/masks/multiblock.py)
# Masks are sampled on a (IMAGE_SIZE // PATCH_SIZE)^2 patch grid and expanded to pixels.
PATCH_SIZE = 4
# (low, high) fraction of all patches covered by one encoder / predictor block.
ENC_MASK_SCALE = (0.85, 1.0)
PRED_MASK_SCALE = (0.15, 0.25)
# (low, high) aspect ratio used when sizing predictor blocks.
ASPECT_RATIO = (0.75, 1.5)
N_ENC_MASKS = 1
N_PRED_MASKS = 2
# False: context patches must not coincide with target patches.
ALLOW_OVERLAP = False
# Minimum number of patches a sampled block must contain to be accepted.
MIN_KEEP = 4

# Verbose training: log every N batches (0 = only epoch summary)
VERBOSE = True
LOG_EVERY = 20

# I-JEPA faithful: target = EMA(full image) at target coords (True). Legacy: target = EMA(context) at target (False).
TARGET_FROM_FULL_CONTEXT = True

# EMA target (momentum for target encoder; same as ijepa)
EMA_MOMENTUM = 0.996

# Losses
USE_JEPA_LOSS = True
USE_RGB_AUX = True
# Weight of the auxiliary RGB reconstruction (MSE) term.
LAMBDA_RGB = 0.1
RGB_QUERY_RES = 16
# RGB aux at 16×16 (256 queries) instead of 32×32 (1024) cuts decoder cost ~4x; set 32 for full-res.
USE_VICREG = True
# Weight of the VICReg regularizer in the total loss.
LAMBDA_VICREG = 0.1
# NOTE: vicreg_loss accepts sim_weight but never uses it; only the var/cov weights affect the loss.
VICREG_SIM_WEIGHT = 25.0
VICREG_VAR_WEIGHT = 25.0
VICREG_COV_WEIGHT = 1.0

# Width of the semantic tokens φ produced by the semantic head.
PHI_DIM = 128
# Predictor attention width; JEPAPredictor splits it as PREDICTOR_DIM // PREDICTOR_HEADS per head,
# so it should be divisible by PREDICTOR_HEADS.
PREDICTOR_DIM = 256
PREDICTOR_HEADS = 4

# Full-resolution coordinate grid, built once and indexed by flat pixel index everywhere below.
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
print("N_FULL:", N_FULL, "| MASK_STYLE:", MASK_STYLE, "| TARGET_FROM_FULL_CONTEXT:", TARGET_FROM_FULL_CONTEXT)
if MASK_STYLE == "ijepa_block":
    print("  I-JEPA blocks: patch_size=%d enc_scale=%s pred_scale=%s aspect=%s n_enc=%d n_pred=%d" % (PATCH_SIZE, ENC_MASK_SCALE, PRED_MASK_SCALE, ASPECT_RATIO, N_ENC_MASKS, N_PRED_MASKS))

## 1) Field output: RGB + φ (shared with semantic-token design)

In [None]:
def sample_gt_at_coords(images, coords):
    """Bilinearly sample `images` (B, C, H, W) at `coords` (B, N, 2).

    The two coordinate components are swapped before the call to
    `F.grid_sample`, i.e. `coords` is stored in the opposite order to the
    (x, y) order grid_sample expects. Returns a (B, N, C) tensor.
    """
    batch, _channels, _height, _width = images.shape
    n_points = coords.shape[1]
    swapped = torch.stack((coords[..., 1], coords[..., 0]), dim=-1)
    # grid_sample needs a (B, H_out, W_out, 2) grid: treat the N points as a 1 x N image.
    sample_grid = swapped.view(batch, 1, n_points, 2)
    values = F.grid_sample(
        images,
        sample_grid,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
    # (B, C, 1, N) -> (B, C, N) -> (B, N, C)
    return values.squeeze(2).transpose(1, 2)

def get_residual(model, context):
    """Encode `context` into latents using `model`'s encoder and self-attention stacks.

    Each entry of `model.encoder_blocks` is called with the running latents as
    both `x` and `residual` (None for the first block). Each entry of
    `model.self_attn_blocks` is a pair of sub-layers applied with skip connections.
    """
    latents = None
    for encoder_block in model.encoder_blocks:
        latents = encoder_block(x=latents, context=context, mask=None, residual=latents)
    for sublayers in model.self_attn_blocks:
        first_sublayer, second_sublayer = sublayers[0], sublayers[1]
        latents = first_sublayer(latents) + latents
        latents = second_sublayer(latents) + latents
    return latents

def decoder_forward(model, queries, residual):
    """Decode `queries` against the latents `residual`.

    Applies the decoder cross-attention with a skip connection from the
    queries, then the optional decoder feed-forward (also with a skip).
    """
    attended = model.decoder_cross_attn(queries, context=residual)
    hidden = attended + queries
    if model.decoder_ff is None:
        return hidden
    return hidden + model.decoder_ff(hidden)

def get_rgb(model, queries, residual):
    """Decode `queries` and project the result with `model.to_logits` (the RGB output)."""
    decoded = decoder_forward(model, queries, residual)
    return model.to_logits(decoded)

def get_phi_raw(model, queries, residual):
    """Return the decoder features for `queries` before any output projection."""
    phi_raw = decoder_forward(model, queries, residual)
    return phi_raw

def get_semantic_tokens(phi_raw, semantic_head):
    """Project raw decoder features with `semantic_head` and L2-normalize the last axis."""
    projected = semantic_head(phi_raw)
    return F.normalize(projected, p=2.0, dim=-1)

## 2) Context/target split — I-JEPA-style block masking or random

In [None]:
def _patch_to_pixel_indices(patch_indices_flat, h_patch, w_patch, patch_size, image_size, device):
    """Convert flat patch indices to flat pixel indices (C order: row, then col)."""
    pixels_per_patch = patch_size * patch_size
    pixel_indices = []
    for p_flat in patch_indices_flat:
        pi, pj = p_flat // w_patch, p_flat % w_patch
        for di in range(patch_size):
            for dj in range(patch_size):
                i, j = pi * patch_size + di, pj * patch_size + dj
                pixel_indices.append(i * image_size + j)
    return torch.tensor(pixel_indices, device=device, dtype=torch.long)

def sample_ijepa_block_indices(image_size, patch_size, enc_mask_scale, pred_mask_scale, aspect_ratio,
                               n_enc_masks, n_pred_masks, allow_overlap, min_keep, device, generator=None):
    """
    I-JEPA-style block masking (ijepa/src/masks/multiblock.py).
    Returns (idx_c, idx_t) as flat pixel indices: context = encoder sees, target = predictor predicts.
    """
    h_patch = image_size // patch_size
    w_patch = image_size // patch_size
    n_patches = h_patch * w_patch
    if generator is None:
        generator = torch.Generator(device=device)

    def sample_block_size(scale_lo, scale_hi, aspect_lo, aspect_hi):
        s = scale_lo + torch.rand(1, device=device, generator=generator).item() * (scale_hi - scale_lo)
        ar = aspect_lo + torch.rand(1, device=device, generator=generator).item() * (aspect_hi - aspect_lo)
        max_keep = int(n_patches * s)
        h = max(1, min(h_patch - 1, int(round((max_keep * ar) ** 0.5))))
        w = max(1, min(w_patch - 1, int(round((max_keep / ar) ** 0.5))))
        return h, w

    def sample_one_block(bh, bw, acceptable_region=None):
        """acceptable_region: (H, W) bool, True = can place block. None = anywhere."""
        for _ in range(30):
            top = torch.randint(0, h_patch - bh + 1, (1,), device=device, generator=generator).item()
            left = torch.randint(0, w_patch - bw + 1, (1,), device=device, generator=generator).item()
            block = torch.zeros(h_patch, w_patch, dtype=torch.bool, device=device)
            block[top : top + bh, left : left + bw] = True
            if acceptable_region is not None:
                if not (block & acceptable_region == block).all():
                    continue
            idx = torch.nonzero(block.flatten(), as_tuple=False).squeeze(-1)
            if len(idx) >= min_keep:
                complement = torch.ones(h_patch, w_patch, dtype=torch.bool, device=device)
                complement[top : top + bh, left : left + bw] = False
                return idx, complement
        return None, None

    # 1) Target (pred) blocks
    ph, pw = sample_block_size(pred_mask_scale[0], pred_mask_scale[1], aspect_ratio[0], aspect_ratio[1])
    all_pred_patches = set()
    pred_complements = []
    for _ in range(n_pred_masks):
        idx, comp = sample_one_block(ph, pw, None)
        if idx is None:
            continue
        pred_complements.append(comp)
        for i in idx.cpu().tolist():
            all_pred_patches.add(i)
    pred_patches_flat = list(all_pred_patches)
    if len(pred_patches_flat) < min_keep:
        pred_patches_flat = list(range(n_patches))[: min_keep + 1]
    idx_t = _patch_to_pixel_indices(pred_patches_flat, h_patch, w_patch, patch_size, image_size, device)

    # 2) Context (enc) blocks: only in complement of target
    acceptable = None if allow_overlap else torch.ones(h_patch, w_patch, dtype=torch.bool, device=device)
    if not allow_overlap and pred_complements:
        for c in pred_complements:
            acceptable = acceptable & c
    eh, ew = sample_block_size(enc_mask_scale[0], enc_mask_scale[1], 1.0, 1.0)
    all_enc_patches = set()
    for _ in range(n_enc_masks):
        idx, _ = sample_one_block(eh, ew, acceptable)
        if idx is not None:
            for i in idx.cpu().tolist():
                all_enc_patches.add(i)
    enc_patches_flat = list(all_enc_patches)
    if len(enc_patches_flat) < min_keep:
        enc_patches_flat = [k for k in range(n_patches) if k not in all_pred_patches][: min_keep + 1]
    if not enc_patches_flat:
        enc_patches_flat = [k for k in range(n_patches) if k not in all_pred_patches]
    idx_c = _patch_to_pixel_indices(enc_patches_flat, h_patch, w_patch, patch_size, image_size, device)

    return idx_c, idx_t

def sample_context_target_indices(coords_full, n_context, n_target, device, block_mask=False,
                                  mask_style="random", image_size=32, patch_size=4,
                                  enc_mask_scale=(0.85, 1.0), pred_mask_scale=(0.15, 0.25),
                                  aspect_ratio=(0.75, 1.5), n_enc_masks=1, n_pred_masks=2,
                                  allow_overlap=False, min_keep=4):
    """Return (idx_c, idx_t): flat context and target indices into `coords_full`.

    Three modes, checked in this order:
      * mask_style == "ijepa_block": delegate to `sample_ijepa_block_indices`
        (`n_context`, `n_target` and `block_mask` are ignored).
      * block_mask: legacy single square target block of side sqrt(N) // 4;
        context is the rest, randomly subsampled to at most `n_context`
        (`n_target` is ignored). Assumes `coords_full` is a square grid.
      * otherwise: disjoint random split of sizes `n_context` / `n_target`;
        `idx_t` is shorter if fewer than n_context + n_target points exist.

    Both results are long tensors on `device`.
    """
    n_total = coords_full.size(0)
    if mask_style == "ijepa_block":
        return sample_ijepa_block_indices(
            image_size, patch_size, enc_mask_scale, pred_mask_scale, aspect_ratio,
            n_enc_masks, n_pred_masks, allow_overlap, min_keep, device)

    if not block_mask:
        perm = torch.randperm(n_total, device=device)
        return perm[:n_context], perm[n_context : n_context + n_target]

    h = w = math.isqrt(n_total)
    nh, nw = max(1, h // 4), max(1, w // 4)
    top = random.randint(0, h - nh)
    left = random.randint(0, w - nw)
    target_flat = [i * w + j for i in range(top, top + nh) for j in range(left, left + nw)]
    # Build the lookup set once; rebuilding it per candidate index is quadratic.
    target_set = set(target_flat)
    context_flat = [k for k in range(n_total) if k not in target_set]
    idx_t = torch.tensor(target_flat, device=device, dtype=torch.long)
    # Explicit dtype: an empty context list would otherwise yield a float tensor.
    idx_c = torch.tensor(context_flat, device=device, dtype=torch.long)
    if idx_c.size(0) > n_context:
        idx_c = idx_c[torch.randperm(idx_c.size(0), device=device)[:n_context]]
    return idx_c, idx_t

def prepare_context_input(images, coords_full, fourier_encoder, idx_c, device):
    """Build the encoder input (B, len(idx_c), INPUT_DIM) from context coordinates only.

    Each token is the image sampled at a context coordinate concatenated with
    that coordinate's Fourier embedding. `device` is accepted but not used.
    """
    batch_size = images.size(0)
    # The same context coordinates are shared (broadcast, not copied) across the batch.
    batched_coords = coords_full[idx_c].unsqueeze(0).expand(batch_size, -1, -1)
    pixel_values = sample_gt_at_coords(images, batched_coords)
    positional = fourier_encoder(batched_coords)
    return torch.cat((pixel_values, positional), dim=-1)

## 3) Cross-attention predictor — P(z_c, X_c, X_t) → ẑ_t

Query from target coords; key/value from context tokens. Field-native.

In [None]:
class JEPAPredictor(nn.Module):
    """Single multi-head cross-attention layer: ẑ_t = Attn(q=embed(X_t), k=z_c, v=z_c).

    Queries come from the embedded target coordinates; keys and values come
    from the context tokens. The attended values are projected back to `phi_dim`.
    """

    def __init__(self, coord_embed_dim, phi_dim, pred_dim, num_heads=4):
        super().__init__()
        self.to_q = nn.Linear(coord_embed_dim, pred_dim)
        # Keys and values share one projection; the first half of its output is K, the second V.
        self.to_kv = nn.Linear(phi_dim, pred_dim * 2)
        self.to_out = nn.Linear(pred_dim, phi_dim)
        self.num_heads = num_heads
        self.d_head = pred_dim // num_heads

    def _split_heads(self, x):
        """(B, N, heads * d_head) -> (B, heads, N, d_head)."""
        batch, n_tokens, _ = x.shape
        return x.reshape(batch, n_tokens, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, coords_t_embed, z_c):
        """Predict (B, N_t, phi_dim) embeddings from (B, N_t, coord_embed_dim) queries and (B, N_c, phi_dim) context."""
        batch, n_target, _ = coords_t_embed.shape
        keys, values = self.to_kv(z_c).chunk(2, dim=-1)
        q = self._split_heads(self.to_q(coords_t_embed))
        k = self._split_heads(keys)
        v = self._split_heads(values)
        # Scaled dot-product attention over the context tokens.
        scores = (q @ k.transpose(-2, -1)) * (self.d_head ** -0.5)
        weights = torch.softmax(scores, dim=-1)
        merged = (weights @ v).transpose(1, 2).reshape(batch, n_target, -1)
        return self.to_out(merged)

## 4) JEPA loss (no negatives) + optional VICReg

In [None]:
def jepa_cosine_loss(z_pred, z_target):
    """Mean of 1 - <normalize(z_pred), z_target> over all tokens.

    Only `z_pred` is normalized here; `z_target` is expected to be unit-norm already.
    """
    unit_pred = F.normalize(z_pred, dim=-1)
    cosine = torch.sum(unit_pred * z_target, dim=-1)
    return torch.mean(1 - cosine)

def vicreg_loss(z, sim_weight=25.0, var_weight=25.0, cov_weight=1.0):
    """Variance + covariance regularization to avoid collapse (no negatives).

    Args:
        z: (B, N, D) embeddings; all B * N tokens are pooled into one sample set.
        sim_weight: unused; kept so existing callers keep working.
        var_weight: weight of the variance hinge, mean(relu(1 - std_d)).
        cov_weight: weight of the off-diagonal covariance penalty (summed squares / D).

    Returns:
        Scalar tensor `var_weight * var_loss + cov_weight * cov_loss`.

    NOTE(review): the hinge targets a per-dimension std of 1. The training
    loop passes L2-normalized tokens, whose per-dimension std cannot reach 1,
    so the variance term stays active throughout — confirm this is intended.
    """
    B, N, D = z.shape
    z = z.reshape(B * N, D)
    # Epsilon goes inside the square root (as in the VICReg reference) so the
    # term stays differentiable when a dimension has zero variance.
    std = torch.sqrt(z.var(dim=0) + 1e-4)
    var_loss = torch.mean(F.relu(1 - std))
    z_centered = z - z.mean(dim=0)
    cov = (z_centered.T @ z_centered) / (z.size(0) - 1)
    # Penalize only off-diagonal entries: total squared mass minus the diagonal.
    cov_loss = (cov.pow(2).sum() - cov.diagonal().pow(2).sum()) / D
    return var_weight * var_loss + cov_weight * cov_loss

## 5) Online encoder + EMA target + predictor

In [None]:
def copy_ema(source, target, momentum=0.996):
    """In-place EMA update of `target`'s parameters towards `source`'s.

    For every parameter pair: target <- momentum * target + (1 - momentum) * source.
    Only parameters are visited (buffers are not); pairs are matched by position.
    """
    blend = 1 - momentum
    for online_param, ema_param in zip(source.parameters(), target.parameters()):
        ema_weights = ema_param.data
        ema_weights.mul_(momentum)
        ema_weights.add_(online_param.data, alpha=blend)

# Online and EMA networks must share one architecture, so build both from the same arguments.
_perceiver_kwargs = dict(
    input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
)

# Online branch: trained by the optimizer.
fourier_encoder = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
model = CascadedPerceiverIO(**_perceiver_kwargs).to(DEVICE)
semantic_head = nn.Linear(QUERIES_DIM, PHI_DIM).to(DEVICE)

# Target branch: starts as an exact copy of the online branch and is afterwards
# changed only by the EMA update (copy_ema), never by gradients.
model_ema = CascadedPerceiverIO(**_perceiver_kwargs).to(DEVICE)
fourier_ema = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
semantic_head_ema = nn.Linear(QUERIES_DIM, PHI_DIM).to(DEVICE)
model_ema.load_state_dict(model.state_dict())
fourier_ema.load_state_dict(fourier_encoder.state_dict())
semantic_head_ema.load_state_dict(semantic_head.state_dict())
# Freeze the target branch so it can never pick up gradients or end up in an
# optimizer by accident; copy_ema writes through .data and is unaffected.
for _ema_module in (model_ema, fourier_ema, semantic_head_ema):
    _ema_module.requires_grad_(False)
    _ema_module.eval()

predictor = JEPAPredictor(POS_EMBED_DIM, PHI_DIM, PREDICTOR_DIM, num_heads=PREDICTOR_HEADS).to(DEVICE)
print("Online + EMA target + predictor")

In [None]:
# ToTensor is the only transform applied; no further normalization is configured here.
transform = transforms.Compose([transforms.ToTensor()])

train_ds = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
test_ds = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)

# Training batches are shuffled and pinned; the test loader keeps dataset order.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

## Visualize inputs: full image, context encoder, target encoder, and masks

Below we take one batch and show **what each branch sees** so the data flow is transparent.

In [None]:
# Take a single batch and keep at most four images so the figure stays readable.
imgs_viz, _ = next(iter(train_loader))
B_viz = min(4, imgs_viz.size(0))
imgs_viz = imgs_viz.to(DEVICE)[:B_viz]
H_viz = W_viz = IMAGE_SIZE
N_full = H_viz * W_viz

# A single context/target split, shared by every image in the figure.
idx_c, idx_t = sample_context_target_indices(
    coords_32,
    N_CONTEXT,
    N_TARGET,
    DEVICE,
    block_mask=USE_BLOCK_MASK,
    mask_style=MASK_STYLE,
    image_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    enc_mask_scale=ENC_MASK_SCALE,
    pred_mask_scale=PRED_MASK_SCALE,
    aspect_ratio=ASPECT_RATIO,
    n_enc_masks=N_ENC_MASKS,
    n_pred_masks=N_PRED_MASKS,
    allow_overlap=ALLOW_OVERLAP,
    min_keep=MIN_KEEP,
)

# Boolean (H, W) numpy masks: True where a pixel belongs to the context / target set.
mask_c_2d = torch.zeros(N_full, device=DEVICE, dtype=torch.bool)
mask_t_2d = torch.zeros(N_full, device=DEVICE, dtype=torch.bool)
mask_c_2d[idx_c] = True
mask_t_2d[idx_t] = True
mask_c_2d = mask_c_2d.view(H_viz, W_viz).cpu().numpy()
mask_t_2d = mask_t_2d.view(H_viz, W_viz).cpu().numpy()

# What the EMA target encoder receives when TARGET_FROM_FULL_CONTEXT=True: the untouched images.
full_imgs = imgs_viz  # (B, 3, H, W)

# Pixels in flat (row-major) order so they can be indexed with idx_c / idx_t.
pixels_flat = rearrange(imgs_viz, "b c h w -> b (h w) c")

# Online-encoder view: real pixels at context positions, mid-gray everywhere else.
context_only = torch.full((B_viz, N_full, 3), 0.5, device=DEVICE)
context_only[:, idx_c] = pixels_flat[:, idx_c]
context_only = context_only.view(B_viz, H_viz, W_viz, 3).permute(0, 3, 1, 2)

# Positions where φ is predicted: real pixels at target positions, mid-gray everywhere else.
target_only = torch.full((B_viz, N_full, 3), 0.5, device=DEVICE)
target_only[:, idx_t] = pixels_flat[:, idx_t]
target_only = target_only.view(B_viz, H_viz, W_viz, 3).permute(0, 3, 1, 2)

# Colour-coded mask: red channel = target, green channel = context (identical for all samples).
m_t = torch.from_numpy(mask_t_2d).to(DEVICE).float().unsqueeze(0).expand(B_viz, -1, -1)
m_c = torch.from_numpy(mask_c_2d).to(DEVICE).float().unsqueeze(0).expand(B_viz, -1, -1)
overlay = torch.zeros(B_viz, 3, H_viz, W_viz, device=DEVICE)
overlay[:, 0] = m_t  # R where target
overlay[:, 1] = m_c  # G where context

def to_display(img_bchw):
    """(B,C,H,W) in [0,1] -> (H,W,C) numpy array in [0,1] for the first image.

    The data loaders apply only `transforms.ToTensor()`, so pixel values are
    already in [0, 1]; they are clipped, not rescaled. The tensor is detached
    so that tensors which track gradients can be displayed too.
    """
    x = img_bchw[0].detach().cpu().permute(1, 2, 0).numpy()
    return np.clip(x, 0.0, 1.0)

import textwrap

fig, axes = plt.subplots(5, B_viz, figsize=(3 * B_viz, 14))
if B_viz == 1:
    # plt.subplots returns a 1-D array for a single column; restore the 2-D layout.
    axes = axes[:, np.newaxis]
titles_row = [
    "Full image (input to target encoder / EMA when TARGET_FROM_FULL_CONTEXT=True)",
    "Context only (input to online context encoder)",
    "Target positions (pixels at coords where we predict φ)",
    "Mask: green = context, red = target",
    "Context (green) & target (red) coords on full image",
]

# Marker positions come from the flat indices (k = row * W + col), the same
# row-major layout the mask rows above use, so the scatter cannot drift from
# the masks. They are identical for every sample, so compute them once.
row_c, col_c = np.divmod(idx_c.cpu().numpy(), W_viz)
row_t, col_t = np.divmod(idx_t.cpu().numpy(), W_viz)

panels = (full_imgs, context_only, target_only, overlay, full_imgs)
for b in range(B_viz):
    for r, panel in enumerate(panels):
        ax = axes[r, b]
        ax.imshow(to_display(panel[b : b + 1]))
        # Hide ticks only: axis("off") would also hide the row labels set below.
        ax.set_xticks([])
        ax.set_yticks([])
    axes[0, b].set_title(f"Sample {b+1}")
    # imshow puts (0, 0) at the top-left with x = column and y = row.
    axes[4, b].scatter(col_c, row_c, c="lime", s=8, alpha=0.8, label="context")
    axes[4, b].scatter(col_t, row_t, c="red", s=8, alpha=0.8, label="target")
for r in range(5):
    # Wrap the long descriptions so they fit within one row's height.
    axes[r, 0].set_ylabel(textwrap.fill(titles_row[r], width=30), fontsize=10)
plt.suptitle("JEPA inputs: what each branch sees (one batch)", fontsize=12)
plt.tight_layout()
plt.show()

# Build both encoder inputs once, purely to report their shapes.
context_input_viz = prepare_context_input(imgs_viz, coords_32, fourier_encoder, idx_c, DEVICE)
full_input_viz, _, _ = prepare_model_input(imgs_viz, coords_32, fourier_encoder)

print("Shapes:")
print("  Full image (batch):", tuple(full_imgs.shape))
print(
    "  full_input (for EMA/target encoder when TARGET_FROM_FULL_CONTEXT):",
    tuple(full_input_viz.shape),
    "→ [B, N_FULL, CHANNELS+POS]",
)
print(
    "  context_input (for online encoder):",
    tuple(context_input_viz.shape),
    "→ [B, N_CONTEXT, CHANNELS+POS]",
)
print("  Context coords:", len(idx_c), "| Target coords:", len(idx_t))
print("  TARGET_FROM_FULL_CONTEXT:", TARGET_FROM_FULL_CONTEXT)
# Spell out which input the EMA target encoder receives under the current setting.
print(
    "    → Target encoder (EMA) sees: full image. Target = φ_EMA(full image) at target coords."
    if TARGET_FROM_FULL_CONTEXT
    else "    → Target encoder (EMA) sees: context only (same as context encoder input). Target = φ_EMA(context) at target coords."
)

**Target of the network (loss):** The predictor outputs **z_pred** (the predicted φ at the target coords); the target is **z_t_target** = φ_EMA at those same coords (no gradient flows into it). Loss = 1 − cos(z_pred, z_t_target), averaged over all target tokens. So we do *not* supervise pixels at the target positions — we supervise **representation space** at the target coordinates.

## Training loop — JEPA + RGB aux + VICReg + EMA update

In [None]:
# Everything on the online branch is trained jointly; the EMA modules are deliberately left out.
params = [
    p
    for module in (model, fourier_encoder, semantic_head, predictor)
    for p in module.parameters()
]
optimizer = torch.optim.Adam(params, lr=1e-3)

# Lowest validation loss seen so far; decides when the "best" checkpoint is overwritten.
best_val_loss = float("inf")
CKPT_BEST = os.path.join(CHECKPOINT_DIR, "checkpoint_jepa_best.pt")
CKPT_LAST = os.path.join(CHECKPOINT_DIR, "checkpoint_jepa_last.pt")

In [None]:
# The RGB auxiliary loss always queries the same grid, so build it once
# instead of once per batch.
if RGB_QUERY_RES >= IMAGE_SIZE:
    coords_rgb = coords_32
else:
    coords_rgb = create_coordinate_grid(RGB_QUERY_RES, RGB_QUERY_RES, DEVICE)


def _checkpoint_payload(epoch_number):
    """Collect every state dict needed to resume or evaluate this run."""
    return {
        "epoch": epoch_number,
        "model_state_dict": model.state_dict(),
        "fourier_encoder_state_dict": fourier_encoder.state_dict(),
        "semantic_head_state_dict": semantic_head.state_dict(),
        "predictor_state_dict": predictor.state_dict(),
        "model_ema_state_dict": model_ema.state_dict(),
        "fourier_ema_state_dict": fourier_ema.state_dict(),
        "semantic_head_ema_state_dict": semantic_head_ema.state_dict(),
        "best_val_loss": best_val_loss,
        # Additional key: lets a run resume with the Adam moments intact.
        "optimizer_state_dict": optimizer.state_dict(),
    }


for epoch in range(EPOCHS):
    model.train()
    fourier_encoder.train()
    semantic_head.train()
    predictor.train()
    # The target branch is never trained by gradients; keep it in eval mode.
    model_ema.eval()
    fourier_ema.eval()
    semantic_head_ema.eval()
    total_loss = 0.0
    total_jepa = 0.0
    total_rgb = 0.0
    total_vic = 0.0
    if VERBOSE:
        print(f"Epoch {epoch+1}/{EPOCHS} (batches: {len(train_loader)})")
    for batch_ix, (imgs, _) in enumerate(train_loader):
        imgs = imgs.to(DEVICE)
        B = imgs.size(0)

        # One context/target split per batch, shared by all images in it.
        idx_c, idx_t = sample_context_target_indices(
            coords_32, N_CONTEXT, N_TARGET, DEVICE, block_mask=USE_BLOCK_MASK, mask_style=MASK_STYLE,
            image_size=IMAGE_SIZE, patch_size=PATCH_SIZE, enc_mask_scale=ENC_MASK_SCALE,
            pred_mask_scale=PRED_MASK_SCALE, aspect_ratio=ASPECT_RATIO, n_enc_masks=N_ENC_MASKS,
            n_pred_masks=N_PRED_MASKS, allow_overlap=ALLOW_OVERLAP, min_keep=MIN_KEEP)
        coords_c = coords_32[idx_c]
        coords_t = coords_32[idx_t]
        coords_t_batched = coords_t.unsqueeze(0).expand(B, -1, -1)

        if VERBOSE and batch_ix == 0:
            print(f"  [epoch {epoch+1}] batch 0: N_context={len(idx_c)}, N_target={len(idx_t)} (mask_style={MASK_STYLE})")

        # Online branch: encode the context pixels only, then read φ at the context coords.
        context_input = prepare_context_input(imgs, coords_32, fourier_encoder, idx_c, DEVICE)
        residual = get_residual(model, context_input)

        queries_c = fourier_encoder(coords_c.unsqueeze(0).expand(B, -1, -1))
        queries_t = fourier_encoder(coords_t_batched)

        z_c = get_semantic_tokens(get_phi_raw(model, queries_c, residual), semantic_head)

        # Target branch: every stage (Fourier features, encoder, head) uses the
        # EMA weights, so the target comes from one consistent network.
        with torch.no_grad():
            queries_t_ema = fourier_ema(coords_t_batched)
            if TARGET_FROM_FULL_CONTEXT:
                full_input, _, _ = prepare_model_input(imgs, coords_32, fourier_ema)
                residual_ema = get_residual(model_ema, full_input)
            else:
                context_input_ema = prepare_context_input(imgs, coords_32, fourier_ema, idx_c, DEVICE)
                residual_ema = get_residual(model_ema, context_input_ema)
            z_t_target = get_semantic_tokens(get_phi_raw(model_ema, queries_t_ema, residual_ema), semantic_head_ema)

        # Predictor: target-coordinate queries attend over the context tokens.
        z_pred = predictor(queries_t, z_c)
        loss_jepa = jepa_cosine_loss(z_pred, z_t_target) if USE_JEPA_LOSS else torch.tensor(0.0, device=DEVICE)
        loss = loss_jepa
        loss_rgb = torch.tensor(0.0, device=DEVICE)
        loss_vic = torch.tensor(0.0, device=DEVICE)

        if USE_VICREG:
            loss_vic = vicreg_loss(z_c, VICREG_SIM_WEIGHT, VICREG_VAR_WEIGHT, VICREG_COV_WEIGHT)
            loss = loss + LAMBDA_VICREG * loss_vic
            total_vic += loss_vic.item()

        if USE_RGB_AUX:
            queries_rgb = fourier_encoder(coords_rgb.unsqueeze(0).expand(B, -1, -1))
            rgb = get_rgb(model, queries_rgb, residual)
            if RGB_QUERY_RES >= IMAGE_SIZE:
                # Full resolution: the grid is the pixel grid, so the pixels are the targets.
                target_pixels = rearrange(imgs, "b c h w -> b (h w) c")
            else:
                target_pixels = sample_gt_at_coords(imgs, coords_rgb.unsqueeze(0).expand(B, -1, -1))
            # loss_rgb (and the logged rgb value) already includes LAMBDA_RGB.
            loss_rgb = F.mse_loss(rgb, target_pixels) * LAMBDA_RGB
            loss = loss + loss_rgb
            total_rgb += loss_rgb.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_jepa += loss_jepa.item()

        if VERBOSE and LOG_EVERY > 0 and (batch_ix + 1) % LOG_EVERY == 0:
            print(f"  [epoch {epoch+1}] batch {batch_ix+1}/{len(train_loader)}: loss={loss.item():.4f} jepa={loss_jepa.item():.4f}" +
                  (f" rgb={loss_rgb.item():.4f}" if USE_RGB_AUX else "") +
                  (f" vic={loss_vic.item():.4f}" if USE_VICREG else "") +
                  f" (avg so far: loss={total_loss/(batch_ix+1):.4f})")

        # Momentum update of the target branch, after the online weights have stepped.
        copy_ema(model, model_ema, EMA_MOMENTUM)
        copy_ema(fourier_encoder, fourier_ema, EMA_MOMENTUM)
        copy_ema(semantic_head, semantic_head_ema, EMA_MOMENTUM)

    # Validation: full-image RGB reconstruction MSE of the online model.
    # This (not the JEPA loss) is what selects the "best" checkpoint.
    model.eval()
    fourier_encoder.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, _ in test_loader:
            imgs = imgs.to(DEVICE)
            B = imgs.size(0)
            full_input, _, _ = prepare_model_input(imgs, coords_32, fourier_encoder)
            residual = get_residual(model, full_input)
            queries_full = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
            rgb = get_rgb(model, queries_full, residual)
            target_pixels = rearrange(imgs, "b c h w -> b (h w) c")
            val_loss += F.mse_loss(rgb, target_pixels).item()
    val_loss /= len(test_loader)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(_checkpoint_payload(epoch + 1), CKPT_BEST)
    torch.save(_checkpoint_payload(epoch + 1), CKPT_LAST)
    vic_str = f" vic: {total_vic/len(train_loader):.4f}" if USE_VICREG else ""
    rgb_str = f" rgb: {total_rgb/len(train_loader):.4f}" if USE_RGB_AUX else ""
    print(f"Epoch {epoch+1}/{EPOCHS} jepa: {total_jepa/len(train_loader):.4f}{vic_str}{rgb_str} val_loss: {val_loss:.4f}")

## Resolution-agnostic query (same as semantic-token notebook)

In [None]:
def query_semantic_tokens_at_resolution(model, semantic_head, fourier_encoder, residual, res_h, res_w, device):
    """Decode normalized semantic tokens on a `res_h` x `res_w` coordinate grid.

    The latents in `residual` are reused as-is; only the query grid changes,
    so any output resolution can be requested from one encoding. The batch
    size is taken from `residual`.
    """
    grid = create_coordinate_grid(res_h, res_w, device)
    batch_size = residual.size(0)
    batched_grid = grid.unsqueeze(0).expand(batch_size, -1, -1)
    queries = fourier_encoder(batched_grid)
    return get_semantic_tokens(get_phi_raw(model, queries, residual), semantic_head)

# Switch the online branch to eval mode before querying it.
for _module in (model, fourier_encoder, semantic_head):
    _module.eval()

# Encode four test images once at full resolution...
imgs, _ = next(iter(test_loader))
imgs = imgs[:4].to(DEVICE)
full_input, _, _ = prepare_model_input(imgs, coords_32, fourier_encoder)
with torch.no_grad():
    residual = get_residual(model, full_input)
    # ...then read φ from the same latents on two different query grids.
    phi_16 = query_semantic_tokens_at_resolution(
        model, semantic_head, fourier_encoder, residual, 16, 16, DEVICE
    )
    phi_32 = query_semantic_tokens_at_resolution(
        model, semantic_head, fourier_encoder, residual, 32, 32, DEVICE
    )
print("φ at 16×16:", phi_16.shape, "| 32×32:", phi_32.shape)