# JEPA from Pretrained OmniField — CIFAR-10

Load a **reconstruction** checkpoint (OmniField or AblationCIFAR10), then run **JEPA training** from that initialization.

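- Discovers a suitable pretrained checkpoint automatically (prefers `checkpoints_omnifield_cifar10/omnifield_cifar10_best.pt`, else `checkpoints/checkpoint_best.pt`; if both exist, the one with the lower recorded validation loss wins, and the matching `_last.pt` files are used as fallbacks). Set `PRETRAINED_CKPT_PATH` to bypass auto-detection.
- Loads pretrained weights into the full Perceiver IO field model (encoder and RGB decoder) and the Fourier feature encoder; `semantic_head`, `predictor`, and the learnable `mask_rgb` token are initialized from scratch.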
- Runs the same JEPA loop as `JEPA_NeuralField_CIFAR10.ipynb` (VICReg, EMA, block masking, RGB aux, etc.).
- Saves checkpoints as `checkpoints/checkpoint_jepa_from_pretrained_best.pt` / `_last.pt`; the "best" checkpoint is selected by RGB reconstruction MSE on the CIFAR-10 test set.

In [None]:
import os
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from einops import rearrange, repeat

from nf_feature_models import (
    CascadedPerceiverIO,
    GaussianFourierFeatures,
    create_coordinate_grid,
    prepare_model_input,
)

# Train on the current CUDA device when one is available, otherwise on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Output directory for the JEPA checkpoints written by the training loop (created if missing).
# NOTE: this is the same directory that checkpoint discovery searches for the AblationCIFAR10
# checkpoint; the JEPA files use distinct names (checkpoint_jepa_from_pretrained_*.pt) so they
# do not overwrite it.
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print("Device:", DEVICE)

## Config

In [None]:
IMAGE_SIZE = 32
CHANNELS = 3
FOURIER_MAPPING_SIZE = 96
POS_EMBED_DIM = FOURIER_MAPPING_SIZE * 2  # presumably sin + cos per Fourier frequency — confirm in nf_feature_models
INPUT_DIM = CHANNELS + POS_EMBED_DIM  # encoder token = [RGB value ‖ positional features]
QUERIES_DIM = POS_EMBED_DIM  # decoder queries are the positional features alone
LOGITS_DIM = CHANNELS  # to_logits output: one value per colour channel
BATCH_SIZE = 64
EPOCHS = 30
N_FULL = IMAGE_SIZE * IMAGE_SIZE  # pixels per image

# JEPA: context vs target coord split
MASK_STYLE = "ijepa_block"
# N_CONTEXT / N_TARGET are ignored when MASK_STYLE == "ijepa_block" (block sizes then come
# from the I-JEPA settings below); in the simple block-mask mode only N_CONTEXT is used.
N_CONTEXT = 512
N_TARGET = 256
# NOTE(review): USE_BLOCK_MASK is only consulted when MASK_STYLE != "ijepa_block", but with this
# definition it is True only when MASK_STYLE == "ijepa_block" — so the simple block-mask path
# of sample_context_target_indices is never reached.
USE_BLOCK_MASK = MASK_STYLE == "ijepa_block"

# I-JEPA-style block masking
PATCH_SIZE = 4  # masks are sampled on a (IMAGE_SIZE // PATCH_SIZE)^2 patch grid
ENC_MASK_SCALE = (0.85, 1.0)  # fraction of patches for the (square) context block; side clamped to grid - 1
PRED_MASK_SCALE = (0.15, 0.25)  # fraction of patches per target block
ASPECT_RATIO = (0.75, 1.5)  # aspect-ratio range for target blocks only
N_ENC_MASKS = 1
N_PRED_MASKS = 2
ALLOW_OVERLAP = False  # False: context patches are restricted to outside the target blocks
MIN_KEEP = 4  # minimum number of patches a sampled block must contain

# Verbose training
VERBOSE = True
LOG_EVERY = 20  # print a progress line every LOG_EVERY batches (0 disables)

# I-JEPA faithful: target = EMA(full image) at target coords
# (False: the EMA encoder sees the same masked input as the online encoder.)
TARGET_FROM_FULL_CONTEXT = True

# EMA target
EMA_MOMENTUM = 0.996  # used only when EMA_MOMENTUM_RAMP is False
EMA_MOMENTUM_RAMP = True  # linear ramp START -> END over all training steps
EMA_MOMENTUM_START = 0.996
EMA_MOMENTUM_END = 0.999

# Losses
USE_JEPA_LOSS = True
USE_LAYERNORM_JEPA = True  # when True: smooth-L1 between layer-normed prediction/target; JEPA_LOSS_TYPE is ignored
JEPA_LOSS_TYPE = "smooth_l1"  # "cosine" | "l2" | "smooth_l1"
USE_RGB_AUX = True
LAMBDA_RGB = 0.15
RGB_QUERY_RES = 16  # side of the query grid for the RGB aux loss; >= IMAGE_SIZE uses the full pixel grid
USE_VICREG = True  # VICReg variance/covariance penalty on the online context tokens
LAMBDA_VICREG = 0.02
VICREG_SIM_WEIGHT = 25.0  # passed to vicreg_loss but not used by it (no invariance term)
VICREG_VAR_WEIGHT = 25.0
VICREG_COV_WEIGHT = 1.0

PHI_DIM = 128  # width of the semantic tokens produced by semantic_head
PREDICTOR_DIM = 256  # predictor attention width; must be divisible by PREDICTOR_HEADS
PREDICTOR_HEADS = 4

# Override: set to a specific path to skip auto-detection
PRETRAINED_CKPT_PATH = None  # e.g. "checkpoints_omnifield_cifar10/omnifield_cifar10_best.pt"

# Full-resolution coordinate grid, indexed as coords_32[idx] below —
# presumably shape (IMAGE_SIZE * IMAGE_SIZE, 2); confirm in nf_feature_models.
coords_32 = create_coordinate_grid(IMAGE_SIZE, IMAGE_SIZE, DEVICE)
print("N_FULL:", N_FULL, "| MASK_STYLE:", MASK_STYLE, "| TARGET_FROM_FULL_CONTEXT:", TARGET_FROM_FULL_CONTEXT)
print("JEPA_LOSS_TYPE:", JEPA_LOSS_TYPE, "| USE_LAYERNORM_JEPA:", USE_LAYERNORM_JEPA, "| USE_VICREG:", USE_VICREG, "| EMA_ramp:", EMA_MOMENTUM_RAMP, end="")
if EMA_MOMENTUM_RAMP:
    print(" [%s, %s]" % (EMA_MOMENTUM_START, EMA_MOMENTUM_END))
else:
    print()
if MASK_STYLE == "ijepa_block":
    print("  I-JEPA blocks: patch_size=%d enc_scale=%s pred_scale=%s aspect=%s n_enc=%d n_pred=%d" % (PATCH_SIZE, ENC_MASK_SCALE, PRED_MASK_SCALE, ASPECT_RATIO, N_ENC_MASKS, N_PRED_MASKS))

## Checkpoint Discovery

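Auto-detect the best pretrained reconstruction checkpoint. Preference order:
1. `checkpoints_omnifield_cifar10/omnifield_cifar10_best.pt` (OmniField)
2. `checkpoints_omnifield_cifar10/omnifield_cifar10_last.pt` (OmniField, last epoch)
3. `checkpoints/checkpoint_best.pt` (AblationCIFAR10)
4. `checkpoints/checkpoint_last.pt` (AblationCIFAR10, last epoch)

If both `_best` checkpoints exist, the one with the lower recorded validation loss (`val_loss` / `best_val_loss`) is used; if either file cannot be read for that comparison, the preference order above applies. Set `PRETRAINED_CKPT_PATH` to bypass auto-detection.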

In [None]:
def discover_pretrained_checkpoint(override_path=None):
    """Find the best pretrained reconstruction checkpoint.

    Args:
        override_path: explicit checkpoint path; when given, it is returned as-is
            (source "user-specified") without any auto-detection.

    Returns:
        (path, source_name) tuple.

    Raises:
        FileNotFoundError: if ``override_path`` does not exist, or if no candidate
            checkpoint exists when auto-detecting.

    Auto-detection walks a fixed preference list (OmniField best/last, then Ablation
    best/last). If both "best" checkpoints exist, their recorded validation losses are
    compared and the lower one wins; if that comparison fails for any reason, a warning
    is printed and the preference order is used instead.
    """
    if override_path is not None:
        if os.path.isfile(override_path):
            return override_path, "user-specified"
        raise FileNotFoundError(f"User-specified checkpoint not found: {override_path}")

    omni_best_path = "checkpoints_omnifield_cifar10/omnifield_cifar10_best.pt"
    abl_best_path = "checkpoints/checkpoint_best.pt"
    candidates = [
        (omni_best_path, "OmniFieldCifar10"),
        ("checkpoints_omnifield_cifar10/omnifield_cifar10_last.pt", "OmniFieldCifar10 (last)"),
        (abl_best_path, "AblationCIFAR10"),
        ("checkpoints/checkpoint_last.pt", "AblationCIFAR10 (last)"),
    ]
    found = [(p, name) for p, name in candidates if os.path.isfile(p)]
    if not found:
        raise FileNotFoundError(
            "No pretrained reconstruction checkpoint found!\n"
            "Please run OmniFieldCifar10.ipynb or AblationCIFAR10.ipynb first to produce a checkpoint.\n"
            f"Searched: {[p for p, _ in candidates]}"
        )

    # If we have both OmniField-best and Ablation-best, compare their recorded val losses.
    found_names = dict(found)
    if omni_best_path in found_names and abl_best_path in found_names:
        try:
            ckpt_o = torch.load(omni_best_path, map_location="cpu")
            ckpt_a = torch.load(abl_best_path, map_location="cpu")
            # The two training notebooks store the metric under different keys.
            loss_o = ckpt_o.get("val_loss", ckpt_o.get("best_val_loss", float("inf")))
            loss_a = ckpt_a.get("best_val_loss", ckpt_a.get("val_loss", float("inf")))
            print(f"  OmniField val_loss={loss_o:.6f}, Ablation best_val_loss={loss_a:.6f}")
            winner = omni_best_path if loss_o <= loss_a else abl_best_path
        except Exception as err:  # best-effort comparison: unreadable/odd checkpoints fall back
            print(
                f"  WARNING: could not compare checkpoint val losses "
                f"({type(err).__name__}: {err}); falling back to preference order."
            )
        else:
            return winner, found_names[winner]

    return found[0]  # first available in preference order


# Resolve the checkpoint once; PRETRAINED_PATH / PRETRAINED_SOURCE are read by the loading cell below.
PRETRAINED_PATH, PRETRAINED_SOURCE = discover_pretrained_checkpoint(PRETRAINED_CKPT_PATH)
print(f"Selected pretrained checkpoint: {PRETRAINED_PATH} (source: {PRETRAINED_SOURCE})")

## Helper Functions

Field output (RGB + φ), context/target split, I-JEPA block masking, predictor, losses, EMA — same as `JEPA_NeuralField_CIFAR10.ipynb`.

In [None]:
# ── Field output helpers ──

def sample_gt_at_coords(images, coords):
    """Bilinearly sample `images` at continuous coordinates, one value vector per point.

    images: (B, C, H, W). coords: (B, N, 2) whose last dim is read as (row, col) and swapped
    into grid_sample's (x, y) order; presumably normalised to [-1, 1] (align_corners=True,
    border padding) — confirm against create_coordinate_grid.
    Returns a (B, N, C) tensor.
    """
    n_batch = images.shape[0]
    n_points = coords.shape[1]
    xy = torch.stack((coords[..., 1], coords[..., 0]), dim=-1)
    grid = xy.reshape(n_batch, 1, n_points, 2)
    sampled = F.grid_sample(images, grid, mode="bilinear", padding_mode="border", align_corners=True)
    # grid_sample returns (B, C, 1, N); drop the dummy height axis and move channels last.
    return sampled[:, :, 0, :].transpose(1, 2)

def get_residual(model, context):
    """Encode input tokens into the latent representation the decoder attends to.

    Runs every block in ``model.encoder_blocks`` in order (the first is called with
    ``x=None, residual=None``), then each entry of ``model.self_attn_blocks`` as two
    residual sub-layers (``sa_block[0]`` then ``sa_block[1]``). The result is used as the
    ``context`` of the decoder cross-attention in ``decoder_forward``.

    NOTE(review): this re-implements part of CascadedPerceiverIO's forward pass from the
    outside; presumably sa_block[0] is attention and sa_block[1] a feed-forward layer —
    confirm against nf_feature_models so the two stay in sync.

    Args:
        model: a CascadedPerceiverIO-like module exposing encoder_blocks / self_attn_blocks.
        context: encoder input tokens (e.g. from prepare_model_input or
            prepare_jepa_encoder_input).
    """
    residual = None
    for block in model.encoder_blocks:
        residual = block(x=residual, context=context, mask=None, residual=residual)
    for sa_block in model.self_attn_blocks:
        residual = sa_block[0](residual) + residual
        residual = sa_block[1](residual) + residual
    return residual

def decoder_forward(model, queries, residual):
    """Apply the model's decoder to `queries`, attending over the encoded `residual`.

    Cross-attention output is added back onto the queries; if ``model.decoder_ff`` is not
    None it is applied as a second residual step. Returns the pre-logits decoder features.
    """
    hidden = model.decoder_cross_attn(queries, context=residual) + queries
    feed_forward = model.decoder_ff
    if feed_forward is None:
        return hidden
    return hidden + feed_forward(hidden)

def get_rgb(model, queries, residual):
    """Decode colour values at `queries`: decoder features projected by ``model.to_logits``."""
    decoded = decoder_forward(model, queries, residual)
    return model.to_logits(decoded)

def get_phi_raw(model, queries, residual):
    """Return the raw decoder features φ at `queries` (no RGB projection applied)."""
    features = decoder_forward(model, queries, residual)
    return features

def get_semantic_tokens(phi_raw, semantic_head):
    """Project raw features through `semantic_head` and L2-normalise along the last axis."""
    projected = semantic_head(phi_raw)
    return F.normalize(projected, dim=-1)

In [None]:
# ── I-JEPA block masking ──

def _patch_to_pixel_indices(patch_indices_flat, h_patch, w_patch, patch_size, image_size, device):
    """Expand flat patch indices (row-major, w_patch wide) into flat pixel indices.

    Patch p covers image rows [pi*patch_size, (pi+1)*patch_size) and columns
    [pj*patch_size, (pj+1)*patch_size) with pi, pj = divmod(p, w_patch). Output order is
    patch by patch, row-major inside each patch. ``h_patch`` is unused (kept for the
    existing call signature). Returns a 1-D long tensor on ``device``.
    """
    patches = torch.as_tensor(list(patch_indices_flat), dtype=torch.long, device=device)
    offsets = torch.arange(patch_size, device=device)
    rows = (patches // w_patch).view(-1, 1, 1) * patch_size + offsets.view(1, -1, 1)
    cols = (patches % w_patch).view(-1, 1, 1) * patch_size + offsets.view(1, 1, -1)
    return (rows * image_size + cols).reshape(-1)

def sample_ijepa_block_indices(image_size, patch_size, enc_mask_scale, pred_mask_scale, aspect_ratio,
                               n_enc_masks, n_pred_masks, allow_overlap, min_keep, device, generator=None):
    """Sample I-JEPA-style context/target pixel indices for one batch.

    The image is viewed as an (image_size // patch_size)^2 grid of patches.

    1. Target: one block size is drawn from ``pred_mask_scale`` / ``aspect_ratio``, then
       ``n_pred_masks`` blocks of that size are placed at random; their union is the target.
    2. Context: one square block size is drawn from ``enc_mask_scale``; ``n_enc_masks``
       blocks are placed at random and, unless ``allow_overlap``, intersected with the
       complement of the target patches (the constraint used by the reference I-JEPA
       multi-block collator), so context and target are disjoint.

    A placement is retried (up to 30 times) when it keeps fewer than ``min_keep`` patches;
    small fallback sets are used if sampling still fails.

    Args:
        generator: optional ``torch.Generator``. ``None`` uses torch's global RNG, so
            ``torch.manual_seed`` controls sampling and successive calls differ.

    Returns:
        (idx_c, idx_t): 1-D long tensors of flat pixel indices (row-major over image_size^2).
    """
    h_patch = image_size // patch_size
    w_patch = image_size // patch_size
    n_patches = h_patch * w_patch
    # NOTE: never create a fresh torch.Generator here — a new generator always starts from
    # the same default seed, which made every call return the identical mask.

    def uniform(lo, hi):
        return lo + torch.rand(1, device=device, generator=generator).item() * (hi - lo)

    def sample_block_size(scale_lo, scale_hi, aspect_lo, aspect_hi):
        s = uniform(scale_lo, scale_hi)
        ar = uniform(aspect_lo, aspect_hi)
        max_keep = int(n_patches * s)
        # Clamp each side to [1, grid - 1] so a block never spans a full grid dimension.
        h = max(1, min(h_patch - 1, int(round((max_keep * ar) ** 0.5))))
        w = max(1, min(w_patch - 1, int(round((max_keep / ar) ** 0.5))))
        return h, w

    def sample_one_block(bh, bw, acceptable_region=None):
        """Place a bh x bw block; return flat patch indices of block ∩ acceptable_region, or None."""
        for _ in range(30):
            top = torch.randint(0, h_patch - bh + 1, (1,), device=device, generator=generator).item()
            left = torch.randint(0, w_patch - bw + 1, (1,), device=device, generator=generator).item()
            block = torch.zeros(h_patch, w_patch, dtype=torch.bool, device=device)
            block[top : top + bh, left : left + bw] = True
            if acceptable_region is not None:
                # Intersect rather than require full containment: a large context block can
                # essentially never avoid the target blocks entirely.
                block &= acceptable_region
            idx = torch.nonzero(block.flatten(), as_tuple=False).squeeze(-1)
            if idx.numel() >= min_keep:
                return idx.cpu().tolist()
        return None

    # 1) Target (pred) blocks: union of n_pred_masks equally sized blocks.
    ph, pw = sample_block_size(pred_mask_scale[0], pred_mask_scale[1], aspect_ratio[0], aspect_ratio[1])
    pred_set = set()
    for _ in range(n_pred_masks):
        idx = sample_one_block(ph, pw)
        if idx is not None:
            pred_set.update(idx)
    pred_patches_flat = list(pred_set)
    if len(pred_patches_flat) < min_keep:
        pred_patches_flat = list(range(n_patches))[: min_keep + 1]
    pred_set = set(pred_patches_flat)  # includes fallback patches, if any
    idx_t = _patch_to_pixel_indices(pred_patches_flat, h_patch, w_patch, patch_size, image_size, device)

    # 2) Context (enc) blocks: only in complement of target (unless overlap is allowed).
    acceptable = None
    if not allow_overlap:
        acceptable = torch.ones(n_patches, dtype=torch.bool, device=device)
        if pred_patches_flat:
            acceptable[pred_patches_flat] = False
        acceptable = acceptable.view(h_patch, w_patch)
    eh, ew = sample_block_size(enc_mask_scale[0], enc_mask_scale[1], 1.0, 1.0)
    enc_set = set()
    for _ in range(n_enc_masks):
        idx = sample_one_block(eh, ew, acceptable)
        if idx is not None:
            enc_set.update(idx)
    enc_patches_flat = list(enc_set)
    if len(enc_patches_flat) < min_keep:
        enc_patches_flat = [k for k in range(n_patches) if k not in pred_set][: min_keep + 1]
    if not enc_patches_flat:
        enc_patches_flat = [k for k in range(n_patches) if k not in pred_set]
    idx_c = _patch_to_pixel_indices(enc_patches_flat, h_patch, w_patch, patch_size, image_size, device)

    return idx_c, idx_t

def sample_context_target_indices(coords_full, n_context, n_target, device, block_mask=False,
                                  mask_style="random", image_size=32, patch_size=4,
                                  enc_mask_scale=(0.85, 1.0), pred_mask_scale=(0.15, 0.25),
                                  aspect_ratio=(0.75, 1.5), n_enc_masks=1, n_pred_masks=2,
                                  allow_overlap=False, min_keep=4):
    """Split the flat pixel indices of one image into context and target sets.

    Modes:
      * ``mask_style == "ijepa_block"``: delegate to ``sample_ijepa_block_indices``
        (``n_context`` / ``n_target`` are ignored).
      * ``block_mask=True``: the target is one square block of side ``max(1, side // 4)`` at
        a random position (``n_target`` ignored); the context is all remaining pixels,
        randomly subsampled to at most ``n_context``. Assumes a square grid.
      * otherwise: disjoint random subsets of size ``n_context`` and ``n_target``.

    Returns:
        (idx_c, idx_t): 1-D long tensors of indices into ``coords_full`` on ``device``.
    """
    n_total = coords_full.size(0)
    if mask_style == "ijepa_block":
        return sample_ijepa_block_indices(
            image_size, patch_size, enc_mask_scale, pred_mask_scale, aspect_ratio,
            n_enc_masks, n_pred_masks, allow_overlap, min_keep, device)
    if block_mask:
        h, w = int(math.sqrt(n_total)), int(math.sqrt(n_total))
        nh, nw = max(1, h // 4), max(1, w // 4)
        top = random.randint(0, h - nh)
        left = random.randint(0, w - nw)
        target_flat = [i * w + j for i in range(top, top + nh) for j in range(left, left + nw)]
        target_set = set(target_flat)  # built once: O(1) membership in the filter below
        idx_t = torch.tensor(target_flat, device=device, dtype=torch.long)
        idx_c = torch.tensor(
            [k for k in range(n_total) if k not in target_set], device=device, dtype=torch.long
        )
        if idx_c.size(0) > n_context:
            idx_c = idx_c[torch.randperm(idx_c.size(0), device=device)[:n_context]]
    else:
        perm = torch.randperm(n_total, device=device)
        idx_c = perm[:n_context]
        idx_t = perm[n_context : n_context + n_target]
    return idx_c, idx_t

In [None]:
# ── Encoder input preparation ──

def prepare_context_input(images, coords_full, fourier_encoder, idx_c, device):
    """Build encoder tokens for the context coordinates: [sampled pixel values ‖ Fourier features].

    Every image in the batch uses the same coordinates ``coords_full[idx_c]``. The two
    parts are concatenated along the last dimension. ``device`` is unused (kept for the
    existing call signature).
    """
    coords_c = coords_full[idx_c]
    n_batch = images.size(0)
    pixel_values = sample_gt_at_coords(images, coords_c.unsqueeze(0).expand(n_batch, -1, -1))
    positional = fourier_encoder(coords_c.unsqueeze(0).expand(n_batch, -1, -1))
    return torch.cat([pixel_values, positional], dim=-1)

def prepare_jepa_encoder_input(images, coords_full, fourier_encoder, idx_c, idx_t, device, mask_rgb):
    """Build the online encoder input: context tokens followed by masked target tokens.

    Context tokens carry real sampled pixels. Target tokens carry the learnable ``mask_rgb``
    value in place of pixels, plus the Fourier features of their coordinates. The encoder
    therefore knows where the targets are but not their colours. The two token groups are
    concatenated along the token axis (dim=1).
    """
    n_batch = images.size(0)
    context_tokens = prepare_context_input(images, coords_full, fourier_encoder, idx_c, device)
    target_coords = coords_full[idx_t]
    target_positional = fourier_encoder(target_coords.unsqueeze(0).expand(n_batch, -1, -1))
    target_colour = mask_rgb.expand(n_batch, target_coords.size(0), -1)
    target_tokens = torch.cat([target_colour, target_positional], dim=-1)
    return torch.cat([context_tokens, target_tokens], dim=1)

In [None]:
# ── Cross-attention predictor ──

class JEPAPredictor(nn.Module):
    """Predict target embeddings from context: z_hat_t = Attn(q=embed(X_t), k=z_c, v=z_c).

    Multi-head cross-attention. Each target query (the positional embedding of a target
    coordinate) attends over the context tokens ``z_c``; the result is projected back to
    ``phi_dim``.

    Args:
        coord_embed_dim: last-dim size of the target coordinate embeddings (queries).
        phi_dim: last-dim size of the context tokens and of the predictions.
        pred_dim: internal attention width; must be divisible by ``num_heads``.
        num_heads: number of attention heads (> 0).

    Raises:
        ValueError: if ``num_heads <= 0`` or ``pred_dim`` is not divisible by ``num_heads``.
            Without this check the mismatch only surfaces later, as an opaque shape error
            in ``forward``.
    """

    def __init__(self, coord_embed_dim, phi_dim, pred_dim, num_heads=4):
        super().__init__()
        if num_heads <= 0 or pred_dim % num_heads != 0:
            raise ValueError(
                f"pred_dim ({pred_dim}) must be a positive multiple of num_heads ({num_heads})"
            )
        self.to_q = nn.Linear(coord_embed_dim, pred_dim)
        self.to_kv = nn.Linear(phi_dim, pred_dim * 2)
        self.to_out = nn.Linear(pred_dim, phi_dim)
        self.num_heads = num_heads
        self.d_head = pred_dim // num_heads

    def forward(self, coords_t_embed, z_c):
        """coords_t_embed: (B, N_t, coord_embed_dim); z_c: (B, N_c, phi_dim) -> (B, N_t, phi_dim)."""
        B, N_t, _ = coords_t_embed.shape
        _, N_c, _ = z_c.shape
        q = self.to_q(coords_t_embed).view(B, N_t, self.num_heads, self.d_head).transpose(1, 2)
        kv = self.to_kv(z_c).view(B, N_c, 2, self.num_heads, self.d_head)
        k, v = kv[:, :, 0].transpose(1, 2), kv[:, :, 1].transpose(1, 2)
        scale = self.d_head ** -0.5
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).reshape(B, N_t, -1)
        return self.to_out(out)

In [None]:
# ── JEPA losses + VICReg ──

def jepa_cosine_loss(z_pred, z_target):
    """Mean cosine distance 1 - <normalize(z_pred), z_target> over all tokens.

    Only ``z_pred`` is normalised here; ``z_target`` is expected to be unit-norm already
    (e.g. from get_semantic_tokens).
    """
    cos_sim = torch.sum(F.normalize(z_pred, dim=-1) * z_target, dim=-1)
    return torch.mean(1 - cos_sim)

def jepa_l2_loss(z_pred, z_target):
    """Mean squared L2 distance between unit vectors, written as 2 - 2 * cosine similarity.

    Only ``z_pred`` is normalised here; ``z_target`` is expected to be unit-norm already.
    """
    cos_sim = torch.sum(F.normalize(z_pred, dim=-1) * z_target, dim=-1)
    return torch.mean(2 - 2 * cos_sim)

def jepa_smooth_l1_loss(z_pred, z_target_raw):
    """Element-wise smooth-L1 between raw (unnormalised) prediction and target, mean-reduced."""
    # F.smooth_l1_loss already averages (reduction="mean"), so the result is a scalar.
    return F.smooth_l1_loss(z_pred, z_target_raw)

def vicreg_loss(z, sim_weight=25.0, var_weight=25.0, cov_weight=1.0):
    """VICReg variance + covariance penalty over all tokens of a (B, N, D) tensor.

    Tokens are flattened to (B*N, D). The variance term is a hinge on each dimension's
    std (+1e-4) at 1. The covariance term is the sum of squared off-diagonal covariances
    divided by D. ``sim_weight`` is accepted for signature compatibility but unused: no
    invariance term is computed here.
    """
    n_batch, n_tokens, dim = z.shape
    flat = z.reshape(n_batch * n_tokens, dim)
    per_dim_std = flat.std(dim=0) + 1e-4
    variance_term = F.relu(1 - per_dim_std).mean()
    centered = flat - flat.mean(dim=0)
    covariance = (centered.T @ centered) / (flat.size(0) - 1)
    off_diagonal_sq = covariance.pow(2).sum() - covariance.diag().pow(2).sum()
    covariance_term = off_diagonal_sq / dim
    return var_weight * variance_term + cov_weight * covariance_term

def copy_ema(source, target, momentum=0.996):
    """In-place EMA update: target ← momentum·target + (1 − momentum)·source.

    Only ``parameters()`` are blended, paired positionally by ``zip``; buffers are not
    touched.
    """
    mix = 1 - momentum
    for src_param, dst_param in zip(source.parameters(), target.parameters()):
        dst_param.data.mul_(momentum).add_(src_param.data, alpha=mix)

## Build Models + Load Pretrained Checkpoint

Build the same architecture as `JEPA_NeuralField_CIFAR10.ipynb`, then load pretrained weights into `model` and `fourier_encoder` from the discovered checkpoint. `semantic_head`, `predictor`, and `mask_rgb` are created fresh.

In [None]:
# ── Build models (same constructors as JEPA notebook) ──
# The online modules must match the pretrained checkpoint's architecture for the
# load_state_dict calls in the next cell to succeed.

# NOTE(review): constructor args presumably are (input coord dim=2, mapping size, scale) — confirm in nf_feature_models.
fourier_encoder = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
model = CascadedPerceiverIO(
    input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
).to(DEVICE)
semantic_head = nn.Linear(QUERIES_DIM, PHI_DIM).to(DEVICE)  # new, not from pretrained

# EMA copies (target branch). Their weights are overwritten from the online modules after the
# checkpoint is loaded; afterwards they change only through copy_ema (they are not in the optimizer).
model_ema = CascadedPerceiverIO(
    input_dim=INPUT_DIM, queries_dim=QUERIES_DIM, logits_dim=LOGITS_DIM,
    latent_dims=(256, 384, 512), num_latents=(256, 256, 256), decoder_ff=True,
).to(DEVICE)
fourier_ema = GaussianFourierFeatures(2, FOURIER_MAPPING_SIZE, 15.0).to(DEVICE)
semantic_head_ema = nn.Linear(QUERIES_DIM, PHI_DIM).to(DEVICE)

# Learnable [MASK] for target positions: replaces the RGB part of target tokens in the
# online encoder input (see prepare_jepa_encoder_input). Trained by the optimizer.
mask_rgb = nn.Parameter(torch.zeros(1, 1, CHANNELS).to(DEVICE))

# Predictor (new, not from pretrained): context tokens + target positions -> target embeddings.
predictor = JEPAPredictor(POS_EMBED_DIM, PHI_DIM, PREDICTOR_DIM, num_heads=PREDICTOR_HEADS).to(DEVICE)

print("Models built (same architecture as JEPA notebook).")

In [None]:
# ── Load pretrained checkpoint ──

# map_location=DEVICE puts the loaded tensors directly on the training device.
ckpt = torch.load(PRETRAINED_PATH, map_location=DEVICE)
print(f"Checkpoint keys: {list(ckpt.keys())}")

# Handle two key conventions:
# OmniFieldCifar10: model_state / fourier_state
# AblationCIFAR10:  model_state_dict / fourier_encoder_state_dict
if "model_state_dict" in ckpt:
    model_sd = ckpt["model_state_dict"]
    fourier_sd = ckpt["fourier_encoder_state_dict"]
    key_style = "model_state_dict / fourier_encoder_state_dict"
elif "model_state" in ckpt:
    model_sd = ckpt["model_state"]
    fourier_sd = ckpt["fourier_state"]
    key_style = "model_state / fourier_state"
else:
    raise KeyError(f"Checkpoint has neither 'model_state_dict' nor 'model_state'. Keys: {list(ckpt.keys())}")

# Load into model and fourier_encoder (strict=False: pretrained has no semantic_head, may differ slightly)
# NOTE(review): strict=False only tolerates missing/unexpected keys; per the PyTorch docs a
# tensor shape mismatch still raises. If model_sd shares no keys with the model, nothing is
# loaded and every key is reported as missing — the diagnostic cell below is the safety net.
missing_m, unexpected_m = model.load_state_dict(model_sd, strict=False)
missing_f, unexpected_f = fourier_encoder.load_state_dict(fourier_sd, strict=False)

print(f"\nLoaded pretrained from: {PRETRAINED_PATH} (source: {PRETRAINED_SOURCE})")
print(f"  Key style: {key_style}")
if "best_val_loss" in ckpt:
    print(f"  Pretrained best_val_loss: {ckpt['best_val_loss']:.6f}")
elif "val_loss" in ckpt:
    print(f"  Pretrained val_loss: {ckpt['val_loss']:.6f}")
if "epoch" in ckpt:
    print(f"  Pretrained epoch: {ckpt['epoch']}")
if missing_m:
    print(f"  Model missing keys (expected for new arch): {missing_m}")
if unexpected_m:
    print(f"  Model unexpected keys (ignored): {unexpected_m}")
if missing_f:
    print(f"  Fourier missing keys: {missing_f}")
if unexpected_f:
    print(f"  Fourier unexpected keys: {unexpected_f}")

# Init EMA from loaded online model + new semantic_head, so the target branch starts
# identical to the (pretrained) online branch.
model_ema.load_state_dict(model.state_dict())
fourier_ema.load_state_dict(fourier_encoder.state_dict())
semantic_head_ema.load_state_dict(semantic_head.state_dict())

print("\nEMA initialized from loaded model + new semantic_head. Predictor + mask_rgb init from scratch.")

## Diagnostic: Verify Pretrained Weights Are Active

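Compute the initial RGB reconstruction MSE on the first 10 test batches **before any training**. If the pretrained weights loaded correctly, it should be far lower than for a randomly initialized model (random ≈ 0.08–0.12; pretrained ≈ 0.001–0.01); the cell prints a warning when the MSE exceeds 0.05. This is the most direct check that the pretrained weights are actually in use.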

In [None]:
# ── Diagnostic: initial reconstruction quality (before any JEPA training) ──
# If pretrained weights loaded correctly, val_loss should be MUCH lower than random init.
# Random init typically gives val_loss ≈ 0.08–0.12; a good pretrained model ≈ 0.001–0.01.

model.eval()
fourier_encoder.eval()

# This cell sits above the "Data Loaders" cell, so on a top-to-bottom run `test_loader` does not
# exist yet (previously a NameError). Reuse it when available; otherwise build a local
# CIFAR-10 test loader with the same ToTensor-only preprocessing.
try:
    _diag_loader = test_loader
except NameError:
    _diag_loader = DataLoader(
        torchvision.datasets.CIFAR10(root="./data", train=False, download=True,
                                     transform=transforms.ToTensor()),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
    )

# Quick check on a few test batches
_diag_loss = 0.0
_diag_n = 0
with torch.no_grad():
    for _i, (imgs, _) in enumerate(_diag_loader):
        if _i >= 10:  # 10 batches is enough to see the difference
            break
        imgs = imgs.to(DEVICE)
        B = imgs.size(0)
        # Full-image context -> latent -> RGB decoded at every pixel coordinate.
        full_input, _, _ = prepare_model_input(imgs, coords_32, fourier_encoder)
        residual = get_residual(model, full_input)
        queries_full = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
        rgb = get_rgb(model, queries_full, residual)
        target_pixels = rearrange(imgs, "b c h w -> b (h w) c")
        _diag_loss += F.mse_loss(rgb, target_pixels).item()
        _diag_n += 1

if _diag_n == 0:
    raise RuntimeError("Diagnostic loader yielded no batches; cannot measure reconstruction MSE.")
_diag_loss /= _diag_n
print(f"=== DIAGNOSTIC: Initial RGB reconstruction MSE = {_diag_loss:.6f} ===")
if _diag_loss > 0.05:
    print("  WARNING: This is in the 'random init' range (>0.05).")
    print("  Pretrained weights may NOT have loaded correctly!")
    print("  Check the missing/unexpected keys printed above.")
else:
    print(f"  GOOD: pretrained encoder is active (MSE {_diag_loss:.6f} << random ~0.08–0.12).")

## Data Loaders

In [None]:
# CIFAR-10 converted with ToTensor only (no augmentation, no normalisation).
transform = transforms.Compose([transforms.ToTensor()])
train_ds = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)
test_ds = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    transform=transform,
    download=True,
)
# Shuffled, pinned-memory loader for training; fixed-order loader for evaluation.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=2)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

## Representation Quality Monitoring

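Every epoch, compute collapse metrics (mean and minimum per-dimension std of the mean-pooled φ features) and k-NN accuracy (k = `K_NN`) on fixed subsets: the first `N_MONITOR_TRAIN` training images form the k-NN database and the first `N_MONITOR_TEST` test images are the queries.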

In [None]:
N_MONITOR_TRAIN, N_MONITOR_TEST = 1000, 500  # sizes of the fixed monitoring subsets
MONITOR_GRID = 16  # side of the query grid used to pool φ per image
K_NN = 5

# Deterministic subsets: the first N images of each split (capped at the split size).
monitor_train_indices = [*range(min(N_MONITOR_TRAIN, len(train_ds)))]
monitor_test_indices = [*range(min(N_MONITOR_TEST, len(test_ds)))]
monitor_train_subset = Subset(train_ds, monitor_train_indices)
monitor_test_subset = Subset(test_ds, monitor_test_indices)
monitor_train_loader, monitor_test_loader = (
    DataLoader(subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    for subset in (monitor_train_subset, monitor_test_subset)
)

def _pooled_phi_features(model, semantic_head, fourier_encoder, coords_32, coords_monitor, device, loader):
    """Return (features, labels) on CPU: per image, semantic tokens φ mean-pooled over coords_monitor."""
    feats, all_labels = [], []
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            B = imgs.size(0)
            full_input, _, _ = prepare_model_input(imgs, coords_32, fourier_encoder)
            residual = get_residual(model, full_input)
            queries = fourier_encoder(coords_monitor.unsqueeze(0).expand(B, -1, -1))
            phi = get_semantic_tokens(get_phi_raw(model, queries, residual), semantic_head)
            feats.append(phi.mean(dim=1).cpu())
            all_labels.append(labels)
    return torch.cat(feats, dim=0), torch.cat(all_labels, dim=0)


def compute_representation_metrics(model, semantic_head, fourier_encoder, coords_32, device,
                                  monitor_train_loader, monitor_test_loader, grid_size=16, k=5):
    """Collapse and k-NN metrics for the online encoder's pooled semantic features.

    Features are φ tokens decoded on a ``grid_size`` x ``grid_size`` query grid from the
    full-image latent and mean-pooled per image. Leaves model, fourier_encoder and
    semantic_head in eval mode.

    Returns:
        dict with
          - "phi_std_mean" / "phi_std_min": mean / min over dimensions of the per-dimension
            std of the pooled features (train + test). Values near 0 indicate collapse.
          - "knn_acc": accuracy of a k-NN majority vote (Euclidean distance) that classifies
            the test subset against the train subset. ``k`` is clamped to the train-subset size.
    """
    model.eval()
    fourier_encoder.eval()
    semantic_head.eval()
    coords_monitor = create_coordinate_grid(grid_size, grid_size, device)
    feats_train, labels_train = _pooled_phi_features(
        model, semantic_head, fourier_encoder, coords_32, coords_monitor, device, monitor_train_loader)
    feats_test, labels_test = _pooled_phi_features(
        model, semantic_head, fourier_encoder, coords_32, coords_monitor, device, monitor_test_loader)

    all_feats = torch.cat([feats_train, feats_test], dim=0)
    std_per_dim = all_feats.std(dim=0)
    phi_std_mean = std_per_dim.mean().item()
    phi_std_min = std_per_dim.min().item()

    # topk raises if k exceeds the number of candidates, so cap it at the train-subset size.
    k = min(k, feats_train.size(0))
    dist = torch.cdist(feats_test, feats_train)
    _, idx = dist.topk(k, dim=1, largest=False)
    neighbor_labels = labels_train[idx]
    # torch.mode breaks ties by returning the smallest tied label.
    votes = torch.mode(neighbor_labels, dim=1).values
    knn_acc = (votes == labels_test).float().mean().item()
    return {"phi_std_mean": phi_std_mean, "phi_std_min": phi_std_min, "knn_acc": knn_acc}

## Training Loop — JEPA + RGB Aux + VICReg + EMA Update

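Same loop as `JEPA_NeuralField_CIFAR10.ipynb`. Saves to `checkpoint_jepa_from_pretrained_best.pt` / `_last.pt`. The `_best` checkpoint is chosen by RGB reconstruction MSE on the CIFAR-10 test set (not by JEPA loss or k-NN accuracy); `_last` is overwritten every epoch.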

In [None]:
# Everything in the online branch is optimised; the EMA copies are updated only via copy_ema.
params = [
    *model.parameters(),
    *fourier_encoder.parameters(),
    *semantic_head.parameters(),
    *predictor.parameters(),
    mask_rgb,
]
optimizer = torch.optim.Adam(params, lr=1e-3)
best_val_loss = float("inf")  # lowest RGB val MSE seen so far (selects the _best checkpoint)
CKPT_BEST, CKPT_LAST = (
    os.path.join(CHECKPOINT_DIR, "checkpoint_jepa_from_pretrained_best.pt"),
    os.path.join(CHECKPOINT_DIR, "checkpoint_jepa_from_pretrained_last.pt"),
)
repr_history = []  # one compute_representation_metrics dict per epoch

In [None]:
# Fail fast on an unsupported loss configuration. z_t_target_raw (used by the smooth-L1
# branch below) is only computed when JEPA_LOSS_TYPE == "smooth_l1" or USE_LAYERNORM_JEPA,
# so any other value would otherwise crash with a NameError on the first batch.
if USE_JEPA_LOSS and not USE_LAYERNORM_JEPA and JEPA_LOSS_TYPE not in ("cosine", "l2", "smooth_l1"):
    raise ValueError(
        f"Unsupported JEPA_LOSS_TYPE {JEPA_LOSS_TYPE!r}; expected 'cosine', 'l2' or 'smooth_l1'."
    )

# Loop invariants, hoisted out of the per-batch loop.
total_steps = EPOCHS * len(train_loader)  # denominator of the EMA momentum ramp
if USE_RGB_AUX and RGB_QUERY_RES < IMAGE_SIZE:
    # Fixed low-resolution query grid for the RGB auxiliary loss (identical every batch).
    coords_rgb = create_coordinate_grid(RGB_QUERY_RES, RGB_QUERY_RES, DEVICE)

for epoch in range(EPOCHS):
    model.train()
    fourier_encoder.train()
    semantic_head.train()
    predictor.train()
    model_ema.eval()
    fourier_ema.eval()
    semantic_head_ema.eval()
    total_loss = 0.0
    total_jepa = 0.0
    total_rgb = 0.0
    total_vic = 0.0
    if VERBOSE:
        print(f"Epoch {epoch+1}/{EPOCHS} (batches: {len(train_loader)})")
    for batch_ix, (imgs, _) in enumerate(train_loader):
        imgs = imgs.to(DEVICE)
        B = imgs.size(0)

        # One context/target split per batch, shared by every image in the batch.
        idx_c, idx_t = sample_context_target_indices(
            coords_32, N_CONTEXT, N_TARGET, DEVICE, block_mask=USE_BLOCK_MASK, mask_style=MASK_STYLE,
            image_size=IMAGE_SIZE, patch_size=PATCH_SIZE, enc_mask_scale=ENC_MASK_SCALE,
            pred_mask_scale=PRED_MASK_SCALE, aspect_ratio=ASPECT_RATIO, n_enc_masks=N_ENC_MASKS,
            n_pred_masks=N_PRED_MASKS, allow_overlap=ALLOW_OVERLAP, min_keep=MIN_KEEP)
        coords_c = coords_32[idx_c]
        coords_t = coords_32[idx_t]

        if VERBOSE and batch_ix == 0:
            print(f"  [epoch {epoch+1}] batch 0: N_context={len(idx_c)}, N_target={len(idx_t)} (mask_style={MASK_STYLE})")

        # Online branch: encoder sees context pixels + mask_rgb placeholders at target coords.
        encoder_input = prepare_jepa_encoder_input(imgs, coords_32, fourier_encoder, idx_c, idx_t, DEVICE, mask_rgb)
        residual = get_residual(model, encoder_input)

        queries_c = fourier_encoder(coords_c.unsqueeze(0).expand(B, -1, -1))
        queries_t = fourier_encoder(coords_t.unsqueeze(0).expand(B, -1, -1))

        z_c = get_semantic_tokens(get_phi_raw(model, queries_c, residual), semantic_head)

        # Target branch: STOP-GRAD (EMA weights, no autograd graph).
        with torch.no_grad():
            if TARGET_FROM_FULL_CONTEXT:
                full_input, _, _ = prepare_model_input(imgs, coords_32, fourier_ema)
                residual_ema = get_residual(model_ema, full_input)
            else:
                residual_ema = get_residual(model_ema, encoder_input)
            phi_t_raw = get_phi_raw(model_ema, queries_t, residual_ema)
            z_t_target = get_semantic_tokens(phi_t_raw, semantic_head_ema)
            if JEPA_LOSS_TYPE == "smooth_l1" or USE_LAYERNORM_JEPA:
                z_t_target_raw = semantic_head_ema(phi_t_raw)

        z_pred = predictor(queries_t, z_c)
        if not USE_JEPA_LOSS:
            loss_jepa = torch.tensor(0.0, device=DEVICE)
        elif USE_LAYERNORM_JEPA:
            z_pred_ln = F.layer_norm(z_pred, (z_pred.size(-1),))
            z_t_ln = F.layer_norm(z_t_target_raw, (z_t_target_raw.size(-1),))
            loss_jepa = F.smooth_l1_loss(z_pred_ln, z_t_ln).mean()
        elif JEPA_LOSS_TYPE == "cosine":
            loss_jepa = jepa_cosine_loss(z_pred, z_t_target)
        elif JEPA_LOSS_TYPE == "l2":
            loss_jepa = jepa_l2_loss(z_pred, z_t_target)
        else:
            loss_jepa = jepa_smooth_l1_loss(z_pred, z_t_target_raw)
        loss = loss_jepa
        loss_rgb = torch.tensor(0.0, device=DEVICE)
        loss_vic = torch.tensor(0.0, device=DEVICE)

        if USE_VICREG:
            # Anti-collapse regulariser on the online context tokens.
            loss_vic = vicreg_loss(z_c, VICREG_SIM_WEIGHT, VICREG_VAR_WEIGHT, VICREG_COV_WEIGHT)
            loss = loss + LAMBDA_VICREG * loss_vic
            total_vic += loss_vic.item()

        if USE_RGB_AUX:
            # Reconstruct colours from the masked-input latent so the decoder stays grounded.
            if RGB_QUERY_RES >= IMAGE_SIZE:
                queries_rgb = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
                rgb = get_rgb(model, queries_rgb, residual)
                target_pixels = rearrange(imgs, "b c h w -> b (h w) c")
            else:
                queries_rgb = fourier_encoder(coords_rgb.unsqueeze(0).expand(B, -1, -1))
                rgb = get_rgb(model, queries_rgb, residual)
                target_pixels = sample_gt_at_coords(imgs, coords_rgb.unsqueeze(0).expand(B, -1, -1))
            # NOTE: loss_rgb (and the logged/accumulated rgb value) already includes LAMBDA_RGB.
            loss_rgb = F.mse_loss(rgb, target_pixels) * LAMBDA_RGB
            loss = loss + loss_rgb
            total_rgb += loss_rgb.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_jepa += loss_jepa.item()

        if VERBOSE and LOG_EVERY > 0 and (batch_ix + 1) % LOG_EVERY == 0:
            print(f"  [epoch {epoch+1}] batch {batch_ix+1}/{len(train_loader)}: loss={loss.item():.4f} jepa={loss_jepa.item():.4f}" +
                  (f" rgb={loss_rgb.item():.4f}" if USE_RGB_AUX else "") +
                  (f" vic={loss_vic.item():.4f}" if USE_VICREG else "") +
                  f" (avg so far: loss={total_loss/(batch_ix+1):.4f})")

        # EMA update of the target branch (linear momentum ramp over all steps if enabled).
        if EMA_MOMENTUM_RAMP:
            step = epoch * len(train_loader) + batch_ix
            current_momentum = EMA_MOMENTUM_START + (EMA_MOMENTUM_END - EMA_MOMENTUM_START) * min(1.0, step / total_steps)
        else:
            current_momentum = EMA_MOMENTUM
        copy_ema(model, model_ema, current_momentum)
        copy_ema(fourier_encoder, fourier_ema, current_momentum)
        copy_ema(semantic_head, semantic_head_ema, current_momentum)

    avg = total_loss / len(train_loader)

    # Validation: full-image RGB reconstruction MSE on the test set (selects the _best checkpoint).
    model.eval()
    fourier_encoder.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, _ in test_loader:
            imgs = imgs.to(DEVICE)
            B = imgs.size(0)
            full_input, _, _ = prepare_model_input(imgs, coords_32, fourier_encoder)
            residual = get_residual(model, full_input)
            queries_full = fourier_encoder(coords_32.unsqueeze(0).expand(B, -1, -1))
            rgb = get_rgb(model, queries_full, residual)
            target_pixels = rearrange(imgs, "b c h w -> b (h w) c")
            val_loss += F.mse_loss(rgb, target_pixels).item()
    val_loss /= len(test_loader)
    repr_metrics = compute_representation_metrics(
        model, semantic_head, fourier_encoder, coords_32, DEVICE,
        monitor_train_loader, monitor_test_loader, grid_size=MONITOR_GRID, k=K_NN)
    repr_history.append(repr_metrics)
    if VERBOSE:
        print(f"  repr: phi_std_mean={repr_metrics['phi_std_mean']:.4f} phi_std_min={repr_metrics['phi_std_min']:.4f} knn_acc={repr_metrics['knn_acc']:.4f}")

    ckpt_data = {
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "fourier_encoder_state_dict": fourier_encoder.state_dict(),
        "semantic_head_state_dict": semantic_head.state_dict(),
        "predictor_state_dict": predictor.state_dict(),
        "model_ema_state_dict": model_ema.state_dict(),
        "fourier_ema_state_dict": fourier_ema.state_dict(),
        "semantic_head_ema_state_dict": semantic_head_ema.state_dict(),
        # mask_rgb is a learned part of the online encoder's input; without it the JEPA
        # encoder inputs cannot be reproduced from the checkpoint.
        "mask_rgb": mask_rgb.detach(),
        "optimizer_state_dict": optimizer.state_dict(),  # needed to resume training
        "val_loss": val_loss,
        "repr_metrics": repr_metrics,
        "best_val_loss": best_val_loss,
    }
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        ckpt_data["best_val_loss"] = best_val_loss
        torch.save(ckpt_data, CKPT_BEST)
    torch.save(ckpt_data, CKPT_LAST)

    vic_str = f" vic: {total_vic/len(train_loader):.4f}" if USE_VICREG else ""
    rgb_str = f" rgb: {total_rgb/len(train_loader):.4f}" if USE_RGB_AUX else ""
    print(f"Epoch {epoch+1}/{EPOCHS} loss: {avg:.4f} jepa: {total_jepa/len(train_loader):.4f}{vic_str}{rgb_str} val_loss: {val_loss:.4f}")

In [None]:
# Plot representation-quality metrics (φ spread and k-NN accuracy) per epoch.
if repr_history:
    fig, ax = plt.subplots(1, 3, figsize=(12, 3))
    epochs_arr = np.arange(1, len(repr_history) + 1)
    panels = (
        ("phi_std_mean", "phi std (mean over dims)"),
        ("phi_std_min", "phi std (min dim)"),
        ("knn_acc", "k-NN accuracy"),
    )
    for axis, (metric_key, panel_title) in zip(ax, panels):
        axis.plot(epochs_arr, [m[metric_key] for m in repr_history], "o-")
        axis.set_title(panel_title)
        axis.set_xlabel("Epoch")
    plt.tight_layout()
    plt.show()