# I-JEPA on CIFAR-10

Self-contained implementation of **Image-based Joint-Embedding Predictive Architecture (I-JEPA)**  
([Assran et al., 2023](https://arxiv.org/abs/2301.08243)) adapted for CIFAR-10 (32×32).

**Architecture overview:**
- **Encoder** — ViT-Tiny that patches the image (patch\_size=4 → 8×8 = 64 tokens) and processes only the *context* (unmasked) patches.
- **Target encoder** — EMA copy of the encoder; sees the full image; provides prediction targets.
- **Predictor** — lightweight Transformer that takes context-encoder tokens + learnable mask tokens and predicts target-encoder embeddings at the masked positions.
- **Loss** — Smooth-L1 between predicted and target embeddings (no pixel reconstruction).

| Component | Details |
|---|---|
| Image size | 32 × 32 |
| Patch size | 4 → 8 × 8 grid = 64 patches |
| Encoder | dim=192, depth=6, heads=3 (~2.7 M params) |
| Predictor | dim=96, depth=4, heads=3 (~0.5 M params) |
| Target encoder | EMA (momentum 0.996 → 1.0) |
| Masking | 4 target blocks (15-25 %), 1 context block (85-100 %, non-overlapping) |

In [None]:
import copy
import math
import os
import time
from functools import partial
from multiprocessing import Value

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from IPython.display import clear_output
import matplotlib.pyplot as plt

# Report the installed PyTorch build, then prefer CUDA whenever a GPU is visible.
print(f"PyTorch {torch.__version__}")
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Device: {DEVICE}")

## 1 — Hyperparameters

In [None]:
# ---- Reproducibility: seed NumPy and every PyTorch generator ----
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# ---- Data pipeline ----
DATA_ROOT = "./data"
BATCH_SIZE = 256
NUM_WORKERS = 4
IMG_SIZE = 32

# ---- Encoder / predictor geometry ----
PATCH_SIZE = 4  # IMG_SIZE / PATCH_SIZE = 8 patches along each side
EMBED_DIM = 192
DEPTH = 6
NUM_HEADS = 3
PRED_EMBED_DIM = 96
PRED_DEPTH = 4
PRED_NUM_HEADS = 3

# ---- Block masking, expressed as fractions of the 8×8 patch grid ----
ENC_MASK_SCALE = (0.85, 1.0)    # requested area of the context block
PRED_MASK_SCALE = (0.15, 0.25)  # requested area of each target block
ASPECT_RATIO = (0.75, 1.5)      # aspect-ratio range of the target blocks
N_ENC_MASKS = 1
N_PRED_MASKS = 4
MIN_KEEP = 4                    # a sampled mask must keep more patches than this
ALLOW_OVERLAP = False           # when False, context avoids the target blocks

# ---- Optimization schedule ----
EPOCHS = 100
LR = 1e-3
START_LR = 1e-4
FINAL_LR = 1e-6
WARMUP_EPOCHS = 10
WEIGHT_DECAY = 0.05
EMA_START = 0.996
EMA_END = 1.0

# ---- Logging / checkpoints ----
LOG_FREQ = 50
CKPT_DIR = "./ckpt_cifar10_ijepa"
os.makedirs(CKPT_DIR, exist_ok=True)

# ---- Linear probe ----
PROBE_EPOCHS = 100
PROBE_LR = 0.1

## 2 — Model components

All modules are defined here so the notebook is fully self-contained.  
The code mirrors Meta's I-JEPA repo (`ijepa/src/models/vision_transformer.py`) with minor simplifications for CIFAR-10.

### 2.1 — Positional embeddings & utilities

In [None]:
def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """
    Build a fixed 2-D sine-cosine positional table for a square patch grid.

    embed_dim : total channels; half encode one grid axis, half the other,
                so `embed_dim // 2` must be even.
    grid_size : number of patches along each side of the grid.
    return    : float ndarray of shape (grid_size**2, embed_dim), rows in
                row-major patch order.
    """
    axis_dim = embed_dim // 2
    assert axis_dim % 2 == 0

    coords = np.arange(grid_size, dtype=float)
    # With "xy" indexing the first output varies along columns, the second along rows.
    col_pos, row_pos = np.meshgrid(coords, coords)

    # Geometric frequency ladder shared by both axes.
    freqs = np.arange(axis_dim // 2, dtype=float) / (axis_dim / 2.0)
    freqs = 1.0 / 10000**freqs

    def encode(positions):
        angles = np.outer(positions.reshape(-1), freqs)
        return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

    return np.concatenate([encode(col_pos), encode(row_pos)], axis=1)


def trunc_normal_(tensor, std=0.02):
    """Fill `tensor` in place from a zero-mean normal truncated at ±2·std."""
    bound = 2 * std
    nn.init.trunc_normal_(tensor, std=std, a=-bound, b=bound)


def apply_masks(x, masks):
    """
    Select the tokens listed in each mask and stack the selections batch-wise.

    x     : (B, N, D) token tensor
    masks : list of (B, K) integer index tensors into the token axis
    return: (B * len(masks), K, D), one B-sized slab per mask, in list order
    """
    feat_dim = x.size(-1)
    selected = [
        torch.gather(x, 1, mask.unsqueeze(-1).expand(-1, -1, feat_dim))
        for mask in masks
    ]
    return torch.cat(selected, dim=0)


def repeat_interleave_batch(x, B, repeat):
    """
    Split `x` along dim-0 into consecutive chunks of B rows and emit every
    chunk `repeat` times back-to-back (chunk0, chunk0, ..., chunk1, chunk1, ...).
    Rows beyond the last full chunk are dropped.
    """
    n_chunks = len(x) // B
    pieces = []
    for c in range(n_chunks):
        chunk = x[c * B:(c + 1) * B]
        # Each chunk is first concatenated with itself, exactly like the nested form.
        pieces.append(torch.cat([chunk] * repeat, dim=0))
    return torch.cat(pieces, dim=0)

### 2.2 — Transformer building blocks

In [None]:
class Attention(nn.Module):
    """Multi-head self-attention over a (B, N, C) token sequence."""

    def __init__(self, dim, heads=8, qkv_bias=False):
        super().__init__()
        head_dim = dim // heads
        self.heads = heads
        self.scale = head_dim ** -0.5          # 1 / sqrt(head_dim)
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)  # fused Q, K, V projection
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        batch, n_tokens, channels = x.shape
        head_dim = channels // self.heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim), then peel off Q, K, V.
        fused = self.qkv(x).reshape(batch, n_tokens, 3, self.heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        # Merge the heads back into the channel axis before the output projection.
        mixed = (weights @ v).transpose(1, 2).reshape(batch, n_tokens, channels)
        return self.proj(mixed)


class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU in between (hidden width = dim · mult)."""

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = dim * mult
        self.fc1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)


class Block(nn.Module):
    """Pre-norm Transformer layer: residual self-attention followed by a residual MLP."""

    def __init__(self, dim, heads, qkv_bias=True):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads, qkv_bias)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim)

    def forward(self, x):
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

### 2.3 — ViT Encoder

In [None]:
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each linearly."""

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=192):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        # A convolution with stride == kernel size is a per-patch linear projection.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        feature_map = self.proj(x)           # (B, D, H/ps, W/ps)
        tokens = feature_map.flatten(2)      # (B, D, N)
        return tokens.transpose(1, 2)        # (B, N, D)


class VisionTransformer(nn.Module):
    """ViT encoder that can run on all patch tokens or only on a selected subset."""

    def __init__(self, img_size=32, patch_size=4, embed_dim=192,
                 depth=6, heads=3):
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(img_size, patch_size, 3, embed_dim)
        n_patches = self.patch_embed.num_patches
        grid_side = int(n_patches ** 0.5)

        # Fixed (non-trainable) 2-D sin-cos positional table, one row per patch.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, n_patches, embed_dim), requires_grad=False)
        sincos = get_2d_sincos_pos_embed(embed_dim, grid_side)
        self.pos_embed.data.copy_(torch.from_numpy(sincos).float().unsqueeze(0))

        layers = [Block(embed_dim, heads) for _ in range(depth)]
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(embed_dim)
        self.apply(self._init)

    @staticmethod
    def _init(m):
        """Truncated-normal weights / zero biases; LayerNorm reset to identity."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x, masks=None):
        """
        x     : (B, 3, H, W) image batch
        masks : None            → every patch token is encoded
                list of (B, K)  → only the indexed patch tokens are encoded;
                                  the output batch grows to B · len(masks)
        """
        tokens = self.patch_embed(x) + self.pos_embed   # (B, N, D)
        if masks is not None:
            # Positions are added first, so dropped tokens still leave correct
            # positional information on the kept ones.
            tokens = apply_masks(tokens, masks)
        for layer in self.blocks:
            tokens = layer(tokens)
        return self.norm(tokens)

### 2.4 — Predictor

The predictor receives context-encoder tokens, projects them to a smaller dimension, concatenates **learnable mask tokens** (one per target position), adds positional embeddings for both sets, and runs a small Transformer. Only the mask-token outputs are returned (projected back to `embed_dim`).

In [None]:
class Predictor(nn.Module):
    """
    Narrow Transformer that predicts target-patch embeddings from context tokens.

    Context tokens are projected down to `pred_dim`, joined with one learnable
    mask token per target position (each tagged with that position's embedding),
    processed jointly, and the mask-token outputs are projected back to `embed_dim`.
    """

    def __init__(self, n_patches=64, embed_dim=192,
                 pred_dim=96, depth=4, heads=3):
        super().__init__()
        self.embed_proj = nn.Linear(embed_dim, pred_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, pred_dim))
        trunc_normal_(self.mask_token)

        # Fixed sin-cos positions in the predictor's own width.
        grid_side = int(n_patches ** 0.5)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, n_patches, pred_dim), requires_grad=False)
        sincos = get_2d_sincos_pos_embed(pred_dim, grid_side)
        self.pos_embed.data.copy_(torch.from_numpy(sincos).float().unsqueeze(0))

        layers = [Block(pred_dim, heads) for _ in range(depth)]
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(pred_dim)
        self.out_proj = nn.Linear(pred_dim, embed_dim)
        self.apply(self._init)

    @staticmethod
    def _init(m):
        """Truncated-normal Linear weights with zero bias; LayerNorm reset to identity."""
        if isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, ctx_tokens, masks_ctx, masks_tgt):
        """
        ctx_tokens : (B·nenc, K_ctx, embed_dim)  — output of the context encoder
        masks_ctx  : list of nenc tensors (B, K_ctx)  — context patch indices
        masks_tgt  : list of npred tensors (B, K_tgt) — target patch indices
        Returns    : (B·nenc·npred, K_tgt, embed_dim)
        """
        masks_ctx = masks_ctx if isinstance(masks_ctx, list) else [masks_ctx]
        masks_tgt = masks_tgt if isinstance(masks_tgt, list) else [masks_tgt]
        nenc, npred = len(masks_ctx), len(masks_tgt)
        batch = len(ctx_tokens) // nenc

        pos_table = self.pos_embed.expand(batch, -1, -1)

        # Context stream: shrink to predictor width and tag with its positions.
        ctx = self.embed_proj(ctx_tokens)                        # (B·nenc, K_ctx, pred_dim)
        ctx = ctx + apply_masks(pos_table, masks_ctx)
        n_ctx = ctx.size(1)

        # Query stream: the shared mask token plus each target position's embedding.
        tgt_pos = apply_masks(pos_table, masks_tgt)              # (B·npred, K_tgt, pred_dim)
        tgt_pos = repeat_interleave_batch(tgt_pos, batch, repeat=nenc)
        queries = self.mask_token.expand_as(tgt_pos) + tgt_pos

        # Every target block gets its own copy of the context sequence.
        seq = torch.cat([ctx.repeat(npred, 1, 1), queries], dim=1)

        for layer in self.blocks:
            seq = layer(seq)
        seq = self.norm(seq)

        # Drop the context outputs; only the query slots carry predictions.
        return self.out_proj(seq[:, n_ctx:])

### 2.5 — Block-mask collator

Generates disjoint context (encoder) and target (predictor) block masks on the 8×8 patch grid. Used as `collate_fn` in the dataloader.

In [None]:
class MaskCollator:
    """
    Dataloader `collate_fn` that collates a batch and samples block masks for it.

    For every sample it draws `npred` rectangular target blocks and `nenc`
    rectangular context blocks on the (h × w) patch grid. Unless
    `allow_overlap` is set, context blocks are intersected with the
    complements of the target blocks; that constraint is relaxed one target
    at a time if no valid context can be found after repeated attempts.

    Block sizes are drawn once per batch (from a generator seeded by a shared
    call counter), block positions once per sample.

    Returns `collated, masks_enc, masks_pred` where `collated` is the default
    collation of the batch, `masks_enc` is a list of `nenc` index tensors of
    shape (B, K_ctx) and `masks_pred` a list of `npred` tensors (B, K_tgt).

    Differences from the reference sampler this was derived from:
      * blocks may touch the last row/column of the grid and may span the whole
        grid (previously an exclusive `randint` bound plus an `h - 1` clamp meant
        the final row and column were never part of any mask);
      * a block whose full area cannot exceed `min_keep` raises `ValueError`
        instead of looping forever.
    """

    def __init__(self, input_size=32, patch_size=4,
                 enc_mask_scale=(0.85, 1.0),
                 pred_mask_scale=(0.15, 0.25),
                 aspect_ratio=(0.75, 1.5),
                 nenc=1, npred=4, min_keep=4,
                 allow_overlap=False):
        self.h = self.w = input_size // patch_size
        self.enc_scale = enc_mask_scale
        self.pred_scale = pred_mask_scale
        self.ar = aspect_ratio
        self.nenc = nenc
        self.npred = npred
        self.min_keep = min_keep
        self.allow_overlap = allow_overlap
        # Shared-memory counter so forked dataloader workers draw distinct seeds.
        self._ctr = Value("i", -1)

    # --- helpers ---
    def _step(self):
        """Atomically advance the call counter and return its new value."""
        with self._ctr.get_lock():
            self._ctr.value += 1
            return self._ctr.value

    def _block_size(self, gen, scale, ar_scale):
        """
        Draw a block (height, width) in patches.

        One uniform sample from `gen` interpolates both the area fraction
        (within `scale`) and the aspect ratio (within `ar_scale`). Each side is
        clamped to [1, grid side], so a scale of 1.0 yields the full grid.
        """
        r = torch.rand(1, generator=gen).item()
        s = scale[0] + r * (scale[1] - scale[0])
        n = int(self.h * self.w * s)
        ar = ar_scale[0] + r * (ar_scale[1] - ar_scale[0])
        bh = max(1, min(int(round(math.sqrt(n * ar))), self.h))
        bw = max(1, min(int(round(math.sqrt(n / ar))), self.w))
        return bh, bw

    def _sample_mask(self, bh, bw, acceptable=None):
        """
        Place a `bh × bw` block uniformly at random on the grid.

        acceptable : optional list of (h, w) 0/1 tensors; the block is
                     multiplied by each of them. After every 20 failed
                     attempts one fewer region (from the end of the list) is
                     enforced.
        Returns    : (idx, compl) — flat indices of the kept patches (more than
                     `min_keep` of them) and a 0/1 grid that is zero on the
                     unconstrained block.
        Raises     : ValueError if `bh * bw <= min_keep`, because no placement
                     could ever keep enough patches.
        """
        if bh * bw <= self.min_keep:
            raise ValueError(
                f"block of {bh}x{bw} patches can never keep more than "
                f"min_keep={self.min_keep} patches")

        tries, timeout, max_timeout = 0, 20, 20
        while True:
            # randint's upper bound is exclusive: +1 lets the block reach the
            # last row / column of the grid.
            top = int(torch.randint(0, max(self.h - bh, 0) + 1, (1,)))
            left = int(torch.randint(0, max(self.w - bw, 0) + 1, (1,)))
            m = torch.zeros(self.h, self.w, dtype=torch.int32)
            m[top:top + bh, left:left + bw] = 1
            if acceptable is not None:
                for k in range(max(len(acceptable) - tries, 0)):
                    m *= acceptable[k]
            idx = torch.nonzero(m.flatten()).squeeze(-1)
            if len(idx) > self.min_keep:
                break
            timeout -= 1
            if timeout == 0:
                # Too many rejections: enforce one fewer acceptable region.
                tries += 1
                timeout = max_timeout
        compl = torch.ones(self.h, self.w, dtype=torch.int32)
        compl[top:top + bh, left:left + bw] = 0
        return idx, compl

    # --- collate ---
    def __call__(self, batch):
        """Collate `batch` and return `(collated, masks_enc, masks_pred)`."""
        B = len(batch)
        collated = torch.utils.data.default_collate(batch)

        # Block sizes are shared by the whole batch so masks can be stacked.
        g = torch.Generator()
        g.manual_seed(self._step())
        p_sz = self._block_size(g, self.pred_scale, self.ar)
        e_sz = self._block_size(g, self.enc_scale, (1.0, 1.0))

        all_enc, all_pred = [], []
        mk_pred = mk_enc = self.h * self.w

        for _ in range(B):
            masks_p, compls = [], []
            for _ in range(self.npred):
                idx, compl = self._sample_mask(*p_sz)
                masks_p.append(idx)
                compls.append(compl)
                mk_pred = min(mk_pred, len(idx))
            all_pred.append(masks_p)

            # Context must avoid the target blocks unless overlap is allowed.
            acc = None if self.allow_overlap else compls
            masks_e = []
            for _ in range(self.nenc):
                idx, _ = self._sample_mask(*e_sz, acceptable=acc)
                masks_e.append(idx)
                mk_enc = min(mk_enc, len(idx))
            all_enc.append(masks_e)

        # Truncate to the shortest mask so every sample has the same length.
        all_pred = [[m[:mk_pred] for m in ms] for ms in all_pred]
        all_enc  = [[m[:mk_enc]  for m in ms] for ms in all_enc]
        masks_pred = torch.utils.data.default_collate(all_pred)  # list of npred (B, mk_pred)
        masks_enc  = torch.utils.data.default_collate(all_enc)   # list of nenc  (B, mk_enc)
        return collated, masks_enc, masks_pred

### 2.6 — Learning-rate / weight-decay / momentum schedules

In [None]:
class WarmupCosine:
    """
    Per-step learning-rate schedule: linear warmup from `start_lr` towards
    `ref_lr` for `warmup` steps, then cosine decay towards `final_lr` over the
    remaining `T - warmup` steps. Each `step()` writes the new rate into every
    param group of `opt` and returns it.
    """

    def __init__(self, opt, warmup, start_lr, ref_lr, final_lr, T):
        self.opt = opt
        self.warmup = warmup
        self.start_lr = start_lr
        self.ref_lr = ref_lr
        self.final_lr = final_lr
        self.T = T - warmup   # length of the cosine phase
        self.t = 0            # number of step() calls so far

    def step(self):
        self.t += 1
        in_warmup = self.t < self.warmup
        if in_warmup:
            frac = self.t / max(1, self.warmup)
            lr = self.start_lr + frac * (self.ref_lr - self.start_lr)
        else:
            frac = (self.t - self.warmup) / max(1, self.T)
            span = self.ref_lr - self.final_lr
            decayed = self.final_lr + span * 0.5 * (1 + math.cos(math.pi * frac))
            lr = max(self.final_lr, decayed)
        for group in self.opt.param_groups:
            group["lr"] = lr
        return lr


class CosineWD:
    """
    Per-step cosine weight-decay schedule from `ref_wd` to `final_wd` over `T` steps.

    Works in both directions: the value is clamped so it never passes
    `final_wd`, whether the schedule decreases (`final_wd <= ref_wd`) or
    increases (`final_wd > ref_wd`). Each `step()` writes the new value into
    every param group of `opt` that is not flagged with `"WD_exclude"` and
    returns it.
    """

    def __init__(self, opt, ref_wd, final_wd, T):
        self.opt, self.ref_wd, self.final_wd, self.T = opt, ref_wd, final_wd, T
        self.t = 0  # number of step() calls so far

    def step(self):
        self.t += 1
        # max(1, T) keeps a degenerate zero-length schedule from dividing by zero.
        p = self.t / max(1, self.T)
        wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1 + math.cos(math.pi * p))
        # Clamp on the side of final_wd that the schedule approaches from; a plain
        # max() would pin an increasing schedule to final_wd from the first step.
        if self.final_wd <= self.ref_wd:
            wd = max(self.final_wd, wd)
        else:
            wd = min(self.final_wd, wd)
        for g in self.opt.param_groups:
            if not g.get("WD_exclude", False):
                g["weight_decay"] = wd
        return wd

## 3 — Data

In [None]:
# Pre-training augmentation: padded random crop + horizontal flip, then
# per-channel normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(IMG_SIZE, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])

train_ds = torchvision.datasets.CIFAR10(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=train_transform,
)

# The collator batches the samples and draws context / target block masks.
mask_collator = MaskCollator(
    input_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    enc_mask_scale=ENC_MASK_SCALE,
    pred_mask_scale=PRED_MASK_SCALE,
    aspect_ratio=ASPECT_RATIO,
    nenc=N_ENC_MASKS,
    npred=N_PRED_MASKS,
    min_keep=MIN_KEEP,
    allow_overlap=ALLOW_OVERLAP,
)

train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=True,            # only full batches reach the training loop
    collate_fn=mask_collator,
)

print(f"Training samples : {len(train_ds):,}")
print(f"Batches / epoch  : {len(train_loader)}")

## 4 — Instantiate models & optimizer

In [None]:
# Online (context) encoder — trained by back-propagation.
encoder = VisionTransformer(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM,
    depth=DEPTH,
    heads=NUM_HEADS,
).to(DEVICE)

# Predictor operating on the encoder's patch grid.
predictor = Predictor(
    n_patches=encoder.patch_embed.num_patches,
    embed_dim=EMBED_DIM,
    pred_dim=PRED_EMBED_DIM,
    depth=PRED_DEPTH,
    heads=PRED_NUM_HEADS,
).to(DEVICE)

# Target encoder starts as an exact copy and never receives gradients.
target_encoder = copy.deepcopy(encoder).to(DEVICE)
target_encoder.requires_grad_(False)

n_enc_params = sum(p.numel() for p in encoder.parameters())
n_pred_params = sum(p.numel() for p in predictor.parameters())
n_grid_patches = encoder.patch_embed.num_patches
grid_side = int(n_grid_patches**0.5)

print(f"Encoder    : {n_enc_params:>10,} params")
print(f"Predictor  : {n_pred_params:>10,} params")
print(f"Patch grid : {grid_side}×{grid_side} = {n_grid_patches} patches")

In [None]:
# Weight decay is applied to weight tensors only; biases and every other 1-D
# parameter (e.g. norm scales) go into groups flagged with WD_exclude.
def _split_by_weight_decay(module):
    """Return (decayed, exempt) parameter lists for `module`, in definition order."""
    decayed, exempt = [], []
    for name, param in module.named_parameters():
        if "bias" in name or len(param.shape) == 1:
            exempt.append(param)
        else:
            decayed.append(param)
    return decayed, exempt


enc_decayed, enc_exempt = _split_by_weight_decay(encoder)
pred_decayed, pred_exempt = _split_by_weight_decay(predictor)

param_groups = [
    {"params": enc_decayed},
    {"params": pred_decayed},
    {"params": enc_exempt, "WD_exclude": True, "weight_decay": 0},
    {"params": pred_exempt, "WD_exclude": True, "weight_decay": 0},
]
optimizer = torch.optim.AdamW(param_groups, lr=LR, weight_decay=WEIGHT_DECAY)

# All schedules advance once per optimizer step.
ipe = len(train_loader)      # iterations per epoch
T = EPOCHS * ipe             # total optimizer steps

lr_sched = WarmupCosine(
    optimizer,
    warmup=WARMUP_EPOCHS * ipe,
    start_lr=START_LR,
    ref_lr=LR,
    final_lr=FINAL_LR,
    T=T,
)
wd_sched = CosineWD(
    optimizer,
    ref_wd=WEIGHT_DECAY,
    final_wd=WEIGHT_DECAY,
    T=T,
)

# EMA momentum: linear ramp from EMA_START to EMA_END, consumed one value per step.
mom_sched = iter(
    EMA_START + step * (EMA_END - EMA_START) / T for step in range(T + 1))

## 5 — Training loop

Each step:
1. **Target encoder** (no grad): full image → all 64 patch embeddings → layer-norm → select target positions.
2. **Context encoder**: full image → keep only context patches (via `masks_enc`).
3. **Predictor**: context tokens + mask tokens → predicted embeddings at target positions.
4. **Loss**: Smooth-L1 between predicted and target embeddings.
5. **EMA update** of target encoder.

In [None]:
# Mean training loss of every completed epoch (plotted in later cells).
loss_history = []

for epoch in range(EPOCHS):
    encoder.train()
    predictor.train()
    # Running loss sum and step count for this epoch's average.
    ep_loss, n = 0.0, 0
    t0 = time.time()

    # The collator yields ((images, labels), context masks, target masks).
    for itr, (batch, masks_enc, masks_pred) in enumerate(train_loader):
        imgs = batch[0].to(DEVICE, non_blocking=True)
        m_enc  = [m.to(DEVICE) for m in masks_enc]
        m_pred = [m.to(DEVICE) for m in masks_pred]

        # Schedules are advanced before the optimizer step; both write their
        # new values straight into the optimizer's param groups.
        cur_lr = lr_sched.step()
        cur_wd = wd_sched.step()

        # ---------- forward ----------
        # target
        with torch.no_grad():
            h = target_encoder(imgs)                        # (B, 64, D)
            # Parameter-free layer norm over the feature axis of every token.
            h = F.layer_norm(h, (h.size(-1),))
            B = h.size(0)
            h = apply_masks(h, m_pred)                      # (B·npred, K_tgt, D)
            # Duplicate targets once per context mask to line up with the predictor output.
            h = repeat_interleave_batch(h, B, repeat=N_ENC_MASKS)

        # context → predictor
        z = encoder(imgs, m_enc)                            # (B·nenc, K_ctx, D)
        z = predictor(z, m_enc, m_pred)                     # (B·nenc·npred, K_tgt, D)

        loss = F.smooth_l1_loss(z, h)

        # ---------- backward ----------
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # ---------- EMA ----------
        # target ← m · target + (1 − m) · online, using the freshly updated encoder.
        with torch.no_grad():
            m = next(mom_sched)
            for pq, pk in zip(encoder.parameters(),
                              target_encoder.parameters()):
                pk.data.mul_(m).add_((1 - m) * pq.detach().data)

        ep_loss += loss.item()
        n += 1

        # Progress line; "\r" makes each update overwrite the previous one.
        if itr % LOG_FREQ == 0:
            print(f"  [ep {epoch+1:>3d}/{EPOCHS}, itr {itr:>3d}/{ipe}] "
                  f"loss={loss.item():.4f}  lr={cur_lr:.2e}", end="\r")

    # Epoch summary: mean loss plus the LR / EMA momentum of the final step.
    avg = ep_loss / n
    loss_history.append(avg)
    dt = time.time() - t0
    print(f"Epoch {epoch+1:>3d}/{EPOCHS}  loss={avg:.4f}  "
          f"lr={cur_lr:.2e}  ema={m:.5f}  ({dt:.1f}s)")

    # Save checkpoint every 20 epochs + last
    # (model and optimizer state only; the schedule objects are not stored).
    if (epoch + 1) % 20 == 0 or (epoch + 1) == EPOCHS:
        torch.save({
            "epoch": epoch + 1,
            "encoder": encoder.state_dict(),
            "predictor": predictor.state_dict(),
            "target_encoder": target_encoder.state_dict(),
            "optimizer": optimizer.state_dict(),
            "loss": avg,
        }, os.path.join(CKPT_DIR, f"ckpt_ep{epoch+1}.pth"))

print("\nTraining complete.")

## 6 — Loss curve

In [None]:
# Per-epoch mean pre-training loss.
fig, ax = plt.subplots(figsize=(8, 4))
epochs_axis = range(1, len(loss_history) + 1)
ax.plot(epochs_axis, loss_history, linewidth=1.5)
ax.set_xlabel("Epoch")
ax.set_ylabel("Smooth-L1 Loss")
ax.set_title("I-JEPA Training Loss (CIFAR-10)")
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 7 — Linear probe evaluation

Freeze the (target) encoder, extract mean-pooled patch features for the whole dataset, and train a linear classifier on top.

In [None]:
# Evaluation data loaders (no masking, deterministic preprocessing).
#
# Probe features are extracted exactly once and cached, so random crops / flips
# cannot act as augmentation here: they would only make the cached training
# features noisy and preprocessed differently from the test features. Both
# splits therefore share the same deterministic transform.
eval_norm = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
eval_transform = transforms.Compose([
    transforms.ToTensor(), eval_norm,
])

# download=True is a no-op when the files already exist and keeps this cell
# runnable on its own.
probe_train_ds = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True,
    transform=eval_transform)
probe_test_ds = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True,
    transform=eval_transform)

# shuffle=False keeps extracted features aligned with dataset order.
probe_train_loader = torch.utils.data.DataLoader(
    probe_train_ds, batch_size=512, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True)
probe_test_loader = torch.utils.data.DataLoader(
    probe_test_ds, batch_size=512, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True)

In [None]:
@torch.no_grad()
def extract_features(enc, loader, device):
    """
    Run `enc` in eval mode over `loader` and average each output over its
    token axis (dim 1), giving one feature vector per image.

    Returns (features, labels): CPU features concatenated over all batches,
    and the labels concatenated in the same order.
    """
    enc.eval()
    pooled, targets = [], []
    for images, target in loader:
        tokens = enc(images.to(device))            # (B, 64, D)
        pooled.append(tokens.mean(dim=1).cpu())    # (B, D)
        targets.append(target)
    return torch.cat(pooled), torch.cat(targets)

# Cache frozen target-encoder features for both splits.
print("Extracting features (target encoder)...")
train_feats, train_labels = extract_features(
    target_encoder, probe_train_loader, DEVICE)
test_feats, test_labels = extract_features(
    target_encoder, probe_test_loader, DEVICE)
print(f"Train: {train_feats.shape}   Test: {test_feats.shape}")

In [None]:
# Fit a single linear layer on the cached (frozen) features.
linear = nn.Linear(EMBED_DIM, 10).to(DEVICE)
probe_opt = torch.optim.SGD(linear.parameters(), lr=PROBE_LR, momentum=0.9)
probe_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
    probe_opt, T_max=PROBE_EPOCHS)

feat_ds = torch.utils.data.TensorDataset(train_feats, train_labels)
feat_loader = torch.utils.data.DataLoader(feat_ds, batch_size=512, shuffle=True)

probe_acc_history = []   # test accuracy (%) after every probe epoch
best_acc = 0.0

for ep in range(PROBE_EPOCHS):
    # --- one pass over the cached training features ---
    linear.train()
    for ft, lb in feat_loader:
        ft = ft.to(DEVICE)
        lb = lb.to(DEVICE)
        probe_opt.zero_grad()
        loss = F.cross_entropy(linear(ft), lb)
        loss.backward()
        probe_opt.step()
    probe_sched.step()

    # --- test accuracy, computed in a single batch ---
    linear.eval()
    with torch.no_grad():
        predictions = linear(test_feats.to(DEVICE)).argmax(1)
        correct = (predictions == test_labels.to(DEVICE)).float()
        acc = correct.mean().item() * 100
    if acc > best_acc:
        best_acc = acc
    probe_acc_history.append(acc)

    is_report_epoch = ep == 0 or (ep + 1) % 20 == 0
    if is_report_epoch:
        print(f"  Probe ep {ep+1:>3d}/{PROBE_EPOCHS}  "
              f"acc={acc:.2f}%  best={best_acc:.2f}%")

print(f"\n>>> Best linear-probe accuracy: {best_acc:.2f}%")
print("    (random baseline = 10%)")

In [None]:
# Side-by-side summary: pre-training loss (left) and probe accuracy (right).
fig, axes = plt.subplots(1, 2, figsize=(13, 4))
loss_ax, acc_ax = axes

loss_ax.plot(range(1, len(loss_history) + 1), loss_history, linewidth=1.5)
loss_ax.set(xlabel="Epoch",
            ylabel="Smooth-L1 Loss",
            title="I-JEPA Pre-training Loss")
loss_ax.grid(True, alpha=0.3)

acc_ax.plot(range(1, len(probe_acc_history) + 1), probe_acc_history,
            linewidth=1.5, color="tab:orange")
# Reference lines: best accuracy reached and the 10-class chance level.
acc_ax.axhline(best_acc, ls="--", color="gray", alpha=0.5,
               label=f"best = {best_acc:.1f}%")
acc_ax.axhline(10, ls=":", color="red", alpha=0.4, label="chance (10%)")
acc_ax.set(xlabel="Probe Epoch",
           ylabel="Test Accuracy (%)",
           title="Linear Probe on Frozen Features")
acc_ax.legend()
acc_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 8 — Mask visualization (sanity check)

Show a sample image with context (visible) and target (predicted) patch regions.

In [None]:
def visualize_masks(img_tensor, masks_enc, masks_pred, patch_size=4, grid=8):
    """
    Overlay the context and target masks of the first sample on one image.

    Green = patches in any context mask (what the encoder sees),
    Red   = patches in any target mask (what the predictor predicts).
    Every mask in both lists is drawn, so no hole in the context region is
    left unexplained by a target block that was not shown.

    img_tensor : (3, H, W) image, normalized with the CIFAR-10 statistics below
    masks_enc  : list of (B, K_ctx) flat patch-index tensors
    masks_pred : list of (B, K_tgt) flat patch-index tensors
    patch_size : patch side length in pixels
    grid       : number of patches per image side (indices are row-major)
    """
    mean = np.array([0.4914, 0.4822, 0.4465])
    std  = np.array([0.2023, 0.1994, 0.2010])
    # Undo the normalization so the image displays with natural colours.
    img = img_tensor.permute(1, 2, 0).cpu().numpy() * std + mean
    img = np.clip(img, 0, 1)

    overlay = img.copy()

    def first_sample_union(masks):
        # Union over all masks of the first sample, so an index shared by
        # several blocks is tinted only once.
        per_mask = [m[0].cpu().numpy().reshape(-1) for m in masks]
        return np.unique(np.concatenate(per_mask))

    def tint(indices, channel):
        # Blend the chosen colour channel halfway towards full intensity.
        for idx in indices:
            r, c = divmod(int(idx), grid)
            y0, x0 = r * patch_size, c * patch_size
            region = overlay[y0:y0 + patch_size, x0:x0 + patch_size, channel]
            overlay[y0:y0 + patch_size, x0:x0 + patch_size, channel] = \
                0.5 * region + 0.5

    tint(first_sample_union(masks_enc), channel=1)    # green = context
    tint(first_sample_union(masks_pred), channel=0)   # red = target

    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    axes[0].imshow(img)
    axes[0].set_title("Original")
    axes[0].axis("off")
    axes[1].imshow(np.clip(overlay, 0, 1))
    axes[1].set_title("Green=context, Red=target")
    axes[1].axis("off")
    plt.tight_layout()
    plt.show()


# Pull a single batch from the training loader and display its first sample.
sample_batch = next(iter(train_loader))
sample_imgs, sample_enc, sample_pred = sample_batch
# `sample_imgs` is the collated (images, labels) pair: [0] picks the image
# tensor, the second [0] the first image of the batch.
first_image = sample_imgs[0][0]
visualize_masks(first_image, sample_enc, sample_pred)

---

## Summary

| Metric | Expected |
|---|---|
| Loss (epoch 1) | ~0.15 – 0.25 |
| Loss (epoch 100) | ~0.05 – 0.10 |
| Linear probe (100 ep pretrain) | ~30 – 45 % |

**Notes:**
- CIFAR-10 is 32×32 — very low resolution for ViT/JEPA (designed for 224×224 ImageNet). These numbers are for sanity-checking, not SOTA.
- Longer training (200–300 epochs) and larger models improve results.
- The target encoder (EMA) typically gives better linear-probe features than the online encoder.