In [2]:
from kaggle_secrets import UserSecretsClient
import json, os, pathlib, subprocess, sys

# --- 1. Load your secret (full JSON from kaggle.json) ---
secret_name = "kaggle_json"  # change if you used another name
user_secrets = UserSecretsClient()
raw = user_secrets.get_secret(secret_name)
creds = json.loads(raw)

# --- 2. Forcefully recreate ~/.kaggle/kaggle.json ---
kaggle_dir = pathlib.Path.home() / ".kaggle"
kaggle_dir.mkdir(parents=True, exist_ok=True)
cred_path = kaggle_dir / "kaggle.json"
cred_path.write_text(json.dumps(creds))
os.chmod(cred_path, 0o600)

# --- 3. Double-check the file actually exists and is readable ---
print("✅ Credentials written to:", cred_path)
!ls -la ~/.kaggle/
#!cat ~/.kaggle/kaggle.json | head -1

# --- 4. Reinstall Kaggle CLI cleanly ---
#!pip install --upgrade --force-reinstall kaggle --quiet

✅ Credentials written to: /root/.kaggle/kaggle.json
total 16
drwxr-xr-x 2 root root 4096 Oct 29 19:36 .
drwx------ 1 root root 4096 Oct 29 19:36 ..
-rw------- 1 root root   72 Oct 29 19:36 kaggle.json


In [9]:
%%capture
!kaggle competitions files -c aml-competition
!kaggle competitions download -c aml-competition -p /kaggle/working --force
!mkdir -p /kaggle/working/data
!unzip -o /kaggle/working/aml-competition.zip -d /kaggle/working/data
!ls -lah /kaggle/working/data

In [5]:
!git clone https://github.com/Mamiglia/challenge.git

fatal: destination path 'challenge' already exists and is not an empty directory.


In [10]:
# runner.py
import argparse, json, random, re, time
from pathlib import Path
from os.path import basename

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from challenge.src.common import load_data, generate_submission

# ---------- Paths ----------
DATA_ROOT = Path("/kaggle/working/data")
TRAIN_DIR = DATA_ROOT / "train" / "train"
TEST_DIR  = DATA_ROOT / "test"  / "test"
TRAIN_NPZ = TRAIN_DIR / "train.npz"
TEST_NPZ  = TEST_DIR  / "test.clean.npz"
TRAIN_CAPTIONS = TRAIN_DIR / "captions.txt"
TEST_CAPTIONS  = TEST_DIR  / "captions.txt"

# ---------- Determinism ----------
def seed_all(s: int):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ---------- Metadata/safety ----------
def _assert_metadata_dims(npz_path: Path, expect_text=1024, expect_img=1536):
    d = np.load(npz_path, allow_pickle=True)
    md = {k: d[k][0] for k in d.files if k.startswith("metadata/")}
    tdim = md.get("metadata/embedding_dim_text", None)
    idim = md.get("metadata/embedding_dim_image", None)
    print(f"[meta] text_dim={tdim} | image_dim={idim}")
    assert tdim == expect_text and idim == expect_img, (
        f"Encoder dims mismatch: expected text={expect_text}, image={expect_img} "
        f"but got text={tdim}, image={idim}. Regenerate NPZ with the fixed encoders."
    )

# ---------- Caption→image matching from captions.txt ----------
def _build_image_index(image_ids):
    idx_exact = {}; idx_base={}; idx_stem={}
    for i, v in enumerate(image_ids):
        s = str(v); idx_exact[s]=i
        b = basename(s); idx_base[b]=i
        stem = b.rsplit(".",1)[0] if "." in b else b
        idx_stem[stem]=i
    return idx_exact, idx_base, idx_stem

def _iter_targets_from_captions(path: Path, n_text: int, idx_exact, idx_base, idx_stem):
    targets=[]
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            if len(targets)>=n_text: break
            line=raw.strip()
            if not line: continue
            # be robust to delimiters
            parts=re.split(r'\||\t|,| {2,}', line)
            if len(parts)==1: parts=line.split(" ",1)
            tok=parts[0].strip().strip('"').strip("'")
            if not tok: continue
            def _match(tok_):
                if tok_ in idx_exact: return idx_exact[tok_]
                b=basename(tok_)
                if b in idx_base: return idx_base[b]
                stem=b.rsplit(".",1)[0] if "." in b else b
                if stem in idx_stem: return idx_stem[stem]
                return None
            idx=_match(tok)
            if idx is None:
                tok2=tok.replace("Images/","").replace("./","")
                idx=_match(tok2)
            if idx is None:
                if "." not in tok: continue
                raise AssertionError(f"Could not match image token '{tok}'")
            targets.append(idx)
    assert len(targets)==n_text, f"matched {len(targets)} vs N_text {n_text}"
    return np.asarray(targets, dtype=np.int64)

# ---------- Data loaders ----------
def load_train(out_dir: Path):
    _assert_metadata_dims(TRAIN_NPZ, 1024, 1536)
    d=np.load(TRAIN_NPZ, allow_pickle=True)
    X=d["captions/embeddings"].astype(np.float32)  # (N_text,1024)
    I=d["images/embeddings"].astype(np.float32)    # (N_img,1536)
    cap_ids=d.get("captions/ids", np.arange(len(X)).astype(str))
    img_names=d.get("images/names", np.arange(len(I)).astype(str))

    ex,ba,st=_build_image_index(img_names)
    assert TRAIN_CAPTIONS.exists(), f"Missing {TRAIN_CAPTIONS}"
    targets=_iter_targets_from_captions(TRAIN_CAPTIONS, len(X), ex,ba,st)

    Y=I[targets]                       # (N_text, 1536) GT image vec per caption
    img_ids_row = img_names[targets]   # image name per caption row (for splitting)
    meta={"n_text":int(len(X)),"n_images":int(len(I))}
    (out_dir/"train_detect.json").write_text(json.dumps(meta, indent=2))
    return X,Y,cap_ids,img_ids_row,I,img_names

def load_test_npz():
    d_test=np.load(TEST_NPZ, allow_pickle=True)
    Q=d_test["captions/embeddings"].astype(np.float32)
    q_ids=d_test.get("captions/ids", np.arange(len(Q)).astype(str))
    if "images/embeddings" in d_test.files:
        G=d_test["images/embeddings"].astype(np.float32)
        g_ids=d_test.get("images/names", np.arange(len(G)).astype(str))
    else:
        d_tr=np.load(TRAIN_NPZ, allow_pickle=True)
        G=d_tr["images/embeddings"].astype(np.float32)
        g_ids=d_tr.get("images/names", np.arange(len(G)).astype(str))
    return Q,G,q_ids,g_ids

# ---------- Dataset ----------
class PairDS(Dataset):
    def __init__(self, X, Y): self.X=torch.from_numpy(X); self.Y=torch.from_numpy(Y)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i): return self.X[i], self.Y[i]

# ---------- Pooling (no-op per spec) ----------
def apply_pooling(x: torch.Tensor, mode: str, n_patches=None):
    return x

# ---------- Base Models ----------
class LinearProj(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.fc=nn.Linear(din,dout)
        nn.init.xavier_normal_(self.fc.weight); nn.init.zeros_(self.fc.bias)
    def forward(self,x): return self.fc(x)

class MLP1(nn.Module):
    def __init__(self, din, dout, hidden=512, pdrop=0.1):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(din,hidden), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(hidden,dout)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.zeros_(m.bias)
    def forward(self,x): return self.net(x)

class MLP2(nn.Module):
    # Spec: 1024→1024→512→1536, dropout=0.1 on hidden layers
    def __init__(self, din, dout, h1=1024, h2=512, pdrop=0.1):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(din,h1), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(h1,h2), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(h2,dout)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.zeros_(m.bias)
    def forward(self,x): return self.net(x)

# ---------- Geometry-Preserving Linear (Whiten → Procrustes → Re-color) ----------
def _cov_eigh(zc, eps):
    # zc: (N,D), zero-mean
    N = zc.shape[0]
    C = (zc.T @ zc) / max(1, N)
    # symmetric PSD
    S, U = np.linalg.eigh(C)
    S = np.clip(S, 0.0, None)
    inv_sqrt = 1.0 / np.sqrt(S + eps)
    sqrt = np.sqrt(S + eps)
    C_mhalf = (U * inv_sqrt) @ U.T     # C^{-1/2}
    C_phalf = (U * sqrt) @ U.T         # C^{ 1/2}
    return C_mhalf.astype(np.float32), C_phalf.astype(np.float32)

def procrustes_closed_form(X, Y, eps=1e-5):
    """
    X: (N, 1024) text, Y: (N, 1536) targets (per-caption image vectors)
    Returns A (1536x1024), b (1536,) such that y_hat = A x + b
    """
    # center
    mu_x = X.mean(0, dtype=np.float64)
    mu_y = Y.mean(0, dtype=np.float64)
    Xc = X - mu_x
    Yc = Y - mu_y

    # whiten both
    Cx_mh, _ = _cov_eigh(Xc, eps)      # 1024x1024
    Cy_mh, Cy_ph = _cov_eigh(Yc, eps)  # 1536x1536 (only Cy_ph used)
    Xw = Xc @ Cx_mh.T                  # (N,1024)
    Yw = Yc @ Cy_mh.T                  # (N,1536)

    # orthogonal Procrustes: maximize Tr(R^T Xw^T Yw)
    M = Xw.T @ Yw                      # (1024,1536)
    # SVD on M; for rectangular, do SVD and build R = U V^T in common subspace
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    R = U @ Vt                         # (1024,1536) @ (1536,1536) -> (1024,1536) OK since full_matrices=False
    # we need R as (1536x1024) mapping whitened X to whitened Y; above is (1024x1536)
    R = R.T                            # (1536,1024)

    # re-color to Y space
    A = (Cy_ph @ R @ Cx_mh).astype(np.float32)     # (1536,1536)*(1536,1024)*(1024,1024) = (1536,1024)
    b = (mu_y - (A @ mu_x)).astype(np.float32)     # (1536,)
    return A, b

class GeomLinear(nn.Module):
    """
    Linear layer with weights initialized from closed-form geometry mapping.
    Optionally fine-tunes with tiny LR.
    """
    def __init__(self, A: np.ndarray, b: np.ndarray):
        super().__init__()
        D_in = A.shape[1]; D_out = A.shape[0]
        self.fc = nn.Linear(D_in, D_out, bias=True)
        with torch.no_grad():
            self.fc.weight.copy_(torch.from_numpy(A))
            self.fc.bias.copy_(torch.from_numpy(b))
    def forward(self, x):
        return self.fc(x)

# ---------- Loss pieces (for optional fine-tune) ----------
def moment_align(pred, tgt):
    mu_p, mu_t = pred.mean(0), tgt.mean(0)
    sd_p, sd_t = pred.std(0, unbiased=False), tgt.std(0, unbiased=False)
    return F.mse_loss(mu_p, mu_t) + F.mse_loss(sd_p, sd_t)

def info_nce(pred, tgt):
    p = F.normalize(pred, dim=-1)
    t = F.normalize(tgt, dim=-1)
    logits = p @ t.t()
    labels = torch.arange(pred.size(0), device=pred.device)
    return F.cross_entropy(logits, labels)

# ---------- Split by image id & mapping ----------
def build_image_id_split(img_ids_row, all_img_names, full_img, X, Y, val_ratio, seed, out_dir: Path):
    uniq_img_names = np.array(sorted(set(map(str, all_img_names))))
    rng = np.random.default_rng(seed); rng.shuffle(uniq_img_names)
    n_val = max(1, int(len(uniq_img_names) * val_ratio))
    val_images = set(uniq_img_names[:n_val])
    tr_images  = set(uniq_img_names[n_val:])
    assert len(val_images & tr_images) == 0, "Leakage in image id split!"

    # Caption masks
    cap_is_val = np.array([str(iid) in val_images for iid in img_ids_row], dtype=bool)
    cap_is_tr  = ~cap_is_val

    # Build VAL gallery (unique images in VAL) as sorted names → local indices
    val_img_names_sorted = np.array(sorted(val_images))
    name2local = {name:i for i,name in enumerate(val_img_names_sorted)}
    name2global = {str(n):i for i,n in enumerate(all_img_names)}
    val_img_indices = np.array([name2global[n] for n in val_img_names_sorted], dtype=np.int64)
    val_gallery = full_img[val_img_indices]  # (M,1536)

    # For each VAL caption row, get local gallery index (no remap later)
    cap2gal_local = np.array([name2local[str(n)] for n in img_ids_row[cap_is_val]], dtype=np.int64)

    # Persist mapping for debugging/acceptance
    (out_dir/"val_indices.json").write_text(json.dumps({
        "val_img_indices": val_img_indices.tolist(),
        "val_caption_to_gallery_index": cap2gal_local.tolist(),
        "n_val_captions": int(cap_is_val.sum()),
        "n_val_unique_images": int(len(val_images))
    }, indent=2))

    return cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices

# ---------- Metrics / evaluator (full gallery, cosine on L2) ----------
@torch.no_grad()
def validate_retrieval(model, Xv, val_gallery, cap2gal_local, pooling, n_patches, bs=1024):
    device=next(model.parameters()).device
    Gi = torch.from_numpy(val_gallery).to(device)
    Gi = F.normalize(Gi, dim=-1)

    ranks=[]
    for i in range(0, len(Xv), bs):
        xb=torch.from_numpy(Xv[i:i+bs]).to(device)
        xb=apply_pooling(xb, pooling, n_patches)   # no-op
        pred=model(xb)
        pred=F.normalize(pred, dim=-1)
        sims=pred @ Gi.t()                         # (b, M)
        for j in range(sims.size(0)):
            true_idx = int(cap2gal_local[i+j])     # local gallery index
            order = torch.argsort(sims[j], descending=True)
            rank = (order==true_idx).nonzero(as_tuple=False).item() + 1
            ranks.append(rank)

    ranks = np.array(ranks)
    mrr = float(np.mean(1.0 / ranks))
    r1  = float(np.mean(ranks<=1))
    r5  = float(np.mean(ranks<=5))
    r10 = float(np.mean(ranks<=10))
    return {
        "MRR": mrr,
        "R1": r1,
        "R5": r5,
        "R10": r10,
        "rank_median": int(np.median(ranks)),
        "rank_p75": int(np.percentile(ranks, 75))
    }

def count_params_mb(model):
    params=sum(p.numel() for p in model.parameters())
    mb = params * 4 / (1024**2)
    return params, mb

def time_ms_per_query(model, din, pooling, n_patches):
    device=next(model.parameters()).device
    x=torch.randn(2048, din, device=device)
    x=apply_pooling(x, pooling, n_patches)
    if device.type=="cuda":
        torch.cuda.synchronize()
    t0=time.time()
    with torch.no_grad(): _=model(x)
    if device.type=="cuda":
        torch.cuda.synchronize()
    ms_gpu=(time.time()-t0)*1000/len(x)
    # CPU
    mcpu=model.to("cpu"); xcpu=x.to("cpu")
    t1=time.time()
    with torch.no_grad(): _=mcpu(xcpu)
    ms_cpu=(time.time()-t1)*1000/len(xcpu)
    model.to(device)
    return ms_gpu, ms_cpu

# ---------- Train (for optional fine-tune) ----------
def train_one(model, loader, opt, alpha, beta, gamma, moment_w, pooling, n_patches, device):
    model.train(); total=0.0
    for xb,yb in loader:
        xb,yb=xb.to(device), yb.to(device)
        xb=apply_pooling(xb, pooling, n_patches)
        opt.zero_grad(set_to_none=True)
        pred=model(xb)
        # α·(1−cos) + β·MSE + λ·moment_align + γ·InfoNCE
        cos = 1 - F.cosine_similarity(pred, yb, dim=-1).mean()
        mse = F.mse_loss(pred, yb)
        a_loss = moment_w * moment_align(pred, yb) if moment_w>0 else pred.new_tensor(0.0)
        ce = info_nce(pred, yb) if gamma>0 else pred.new_tensor(0.0)
        loss = alpha*cos + beta*mse + gamma*ce + a_loss
        loss.backward(); opt.step()
        total += loss.item()*xb.size(0)
    return total/len(loader.dataset)


Captions: 125k, Images: 25k ⇒ exactly 5 captions/image (125k / 25k = 5).

Train: 112,500 captions over 22,500 images; Val: 12,500 captions over 2,500 images. Clean 90/10 split by unique image ids → good for avoiding leakage. 

Each listed image shows count = 5. Since all images have 5 captions, there’s no frequency skew in this dataset. Good news: balanced multi-caption coverage.

sampling is effectively uniform over images

> We can safely skip the frequency-aware sampler

In [60]:
# ---------- Residual Adapter + Geom (with Orthogonalization) ----------
class ResidualAdapter(nn.Module):
    def __init__(self, din=1024, hidden=1024, dout=1536, pdrop=0.1, init_scale=0.1):
        super().__init__()
        self.fc1 = nn.Linear(din, hidden)
        self.fc2 = nn.Linear(hidden, dout)
        self.drop = nn.Dropout(pdrop)
        nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
        nn.init.zeros_(self.fc1.bias)
        nn.init.kaiming_normal_(self.fc2.weight, nonlinearity='linear')
        nn.init.zeros_(self.fc2.bias)
        with torch.no_grad():
            self.fc1.weight.mul_(init_scale)
            self.fc2.weight.mul_(init_scale)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = self.drop(h)
        return self.fc2(h)


class GeomWithAdapter(nn.Module):
    """
    y_hat = GeomLinear(x) (+ optionally trainable) + Residual(x),
    where Residual can be *projected* to the orthogonal complement of span(A) or *penalized* to be orthogonal.

    residual_ortho: "project" | "penalty" | "none"
    ortho_eps: ridge for projector
    ortho_lambda: weight for penalty (if residual_ortho == "penalty")
    unfreeze_bias: whether to unfreeze geom bias together with weight
    """
    def __init__(
        self, A: np.ndarray, b: np.ndarray,
        din=1024, dout=1536, hidden=1024, pdrop=0.1,
        residual_ortho: str = "project",
        ortho_eps: float = 1e-4,
        ortho_lambda: float = 0.05,
        unfreeze_bias: bool = False,
    ):
        super().__init__()
        self.geom = GeomLinear(A, b)
        self.adapter = ResidualAdapter(din=din, hidden=hidden, dout=dout, pdrop=pdrop)

        # default: freeze geom at start; training loop may unfreeze weight later
        for p in self.geom.parameters():
            p.requires_grad = False

        # orthogonalization config
        assert residual_ortho in ("project", "penalty", "none")
        self.residual_ortho = residual_ortho
        self.ortho_lambda = float(ortho_lambda)
        self.ortho = ResidualOrthogonalizer(eps=float(ortho_eps))
        self._unfreeze_bias = bool(unfreeze_bias)

    def unfreeze_geom(self, weight_lr_scale=0.05):
        # In training we set per-parameter LRs in the optimizer; here just set requires_grad flags.
        self.geom.fc.weight.requires_grad = True
        if self._unfreeze_bias:
            self.geom.fc.bias.requires_grad = True

    def forward(self, x, return_ortho_penalty: bool = False):
        """
        If residual_ortho == "project": returns (yhat, 0.0) with hard projection (autocast OFF inside).
        If residual_ortho == "penalty": returns (yhat, penalty_scalar).
        If residual_ortho == "none": returns (yhat, 0.0).
        """
        g = self.geom(x)               # (B,1536)
        r = self.adapter(x)            # (B,1536)

        penalty = None
        if self.residual_ortho == "project":
            # Hard projection of residual into orth complement of span(A) in float32, autocast OFF
            A = self.geom.fc.weight     # (1536,1024)
            device_type = "cuda" if g.device.type == "cuda" else "cpu"
            with torch.autocast(device_type=device_type, enabled=False):
                r = self.ortho.project_residual(r, A)
            y = g + r
            penalty = r.new_tensor(0.0)

        elif self.residual_ortho == "penalty":
            # Soft orthogonality: mean squared cosine between normalized parts
            g_n = F.normalize(g, dim=-1)
            r_n = F.normalize(r, dim=-1)
            dot = torch.sum(g_n * r_n, dim=-1)  # (B,)
            penalty = torch.mean(dot * dot)
            y = g + r

        else:  # "none"
            y = g + r
            penalty = r.new_tensor(0.0)

        if return_ortho_penalty:
            return y, penalty
        return y

In [62]:
# ---------- Residual Orthogonalization Utilities (AMP- & device-safe) ----------
class ResidualOrthogonalizer:
    """
    Projects residuals onto the orthogonal complement of span(A) safely under AMP and across devices.

    Math (done in float32, autocast OFF):
      G = A^T A + eps I   (1024x1024)
      invG = G^{-1}       (or cholesky/pinv fallback)
      Proj_col(A)(R) = ((R @ A) @ invG) @ A^T
      (I - P_A)R = R - Proj_col(A)(R)

    We cache invG per A-hash; at use, we always move invG to A.device to avoid CPU/GPU mismatches.
    """
    def __init__(self, eps: float = 1e-4, device: torch.device | None = None):
        self.eps = float(eps)
        self.device = device
        self._cached_invG = None
        self._cached_A_hash = None

    @torch.no_grad()
    def _hash_weight(self, A: torch.Tensor) -> tuple[int, int]:
        h = (A.shape[0], A.shape[1])
        sample = A.reshape(-1)[:1024] if A.numel() >= 1024 else A.reshape(-1)
        checksum = int(torch.sum((sample.float() * 1e3).round()).item())
        return (h[0] * 10_000 + h[1], checksum)

    @torch.no_grad()
    def refresh(self, A: torch.Tensor):
        """
        Ensure invG (float32) is up to date for current A.
        Always computes with autocast DISABLED and in float32.
        Caches on A.device.
        """
        dev = A.device
        device_type = "cuda" if dev.type == "cuda" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            A32 = A.detach().to(dev, dtype=torch.float32)
            key = self._hash_weight(A32)
            # If cache hit AND cached tensor already on the right device, reuse
            if key == self._cached_A_hash and self._cached_invG is not None and self._cached_invG.device == dev:
                return

            G = A32.transpose(0, 1) @ A32
            G = G + self.eps * torch.eye(G.shape[0], device=dev, dtype=torch.float32)
            try:
                L = torch.linalg.cholesky(G)
                invG = torch.cholesky_inverse(L)
            except RuntimeError:
                try:
                    invG = torch.linalg.inv(G)
                except RuntimeError:
                    invG = torch.linalg.pinv(G)

            self._cached_invG = invG   # stored on A.device
            self._cached_A_hash = key

    def project_residual(self, R: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        """
        Return (I - P_A) R. Autocast OFF inside, math in float32, cast back to R.dtype at end.
        Ensures invG and all operands are on the SAME device as A/R.
        """
        dev = R.device
        device_type = "cuda" if dev.type == "cuda" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            # Make sure the cache matches current A AND is on correct device
            self.refresh(A)
            A32  = A.to(dtype=torch.float32, device=dev)
            R32  = R.to(dtype=torch.float32, device=dev)
            invG = self._cached_invG
            if invG.device != dev:
                invG = invG.to(dev)

            # Proj_col(A)(R) = ((R @ A) @ invG) @ A^T
            Z = R32 @ A32
            Z = Z @ invG
            R_proj = Z @ A32.transpose(0, 1)

            R_orth = R32 - R_proj
            return R_orth.to(dtype=R.dtype, device=dev)


In [63]:
def agreement_loss(names, preds, eps=1e-8):
    """
    names: list[str] length B (image ids per caption)
    preds: (B, D) unnormalized predictions
    Returns mean variance across groups (after L2-norm), scalar.
    """
    with torch.no_grad():
        groups = {}
        for i, n in enumerate(map(str, names)):
            groups.setdefault(n, []).append(i)
        valid = [idxs for idxs in groups.values() if len(idxs) >= 2]
    if not valid:
        return preds.new_tensor(0.0)
    z = F.normalize(preds, dim=-1)
    vars_ = []
    for idxs in valid:
        g = z[idxs]
        v = g.var(dim=0, unbiased=False).mean()
        vars_.append(v)
    return torch.stack(vars_).mean()


In [64]:
class ImgQueue:
    def __init__(self, dim: int, capacity: int, device: torch.device):
        self.capacity = int(capacity)
        self.device = device
        self.ptr = 0                  # write pointer into a circular buffer
        self.full = False
        self.bank = torch.zeros(self.capacity, dim, device=device)

    @torch.no_grad()
    def enqueue(self, feats: torch.Tensor):
        # store as negatives only; avoid holding computation graphs
        feats = feats.detach()
        n = feats.size(0)
        if n >= self.capacity:
            # keep only the most recent 'capacity' entries
            self.bank.copy_(feats[-self.capacity:])
            self.ptr = 0
            self.full = True
            return
        end = self.ptr + n
        if end <= self.capacity:
            self.bank[self.ptr:end] = feats
        else:
            cut = self.capacity - self.ptr
            self.bank[self.ptr:] = feats[:cut]
            self.bank[:end - self.capacity] = feats[cut:]
        self.ptr = (self.ptr + n) % self.capacity
        if self.ptr == 0:
            self.full = True

    def size(self) -> int:
        # correct empty vs full handling
        if self.full:
            return self.capacity
        return self.ptr  # 0 when empty

    def get(self) -> torch.Tensor:
        # return FIFO-ordered contents (oldest -> newest)
        if self.size() == 0:
            return self.bank[:0]  # empty tensor with correct shape
        if self.full:
            # [ptr:cap] then [0:ptr] so that tail is the most recent
            return torch.cat([self.bank[self.ptr:], self.bank[:self.ptr]], dim=0)
        # not full: contents are [0:ptr)
        return self.bank[:self.ptr]

    def recent(self, max_items: int) -> torch.Tensor:
        """
        Return only the freshest 'max_items' features (chronological).
        """
        q = self.get()
        if q is None or q.numel() == 0:
            return q
        if max_items is None or q.size(0) <= int(max_items):
            return q
        return q[-int(max_items):]


In [65]:
# =============== P×K Sampler (F1.1) ===============
class PKBatchSampler(torch.utils.data.Sampler[list[int]]):
    """
    Samples batches with P unique images, each with K captions.
    Use in DataLoader(batch_sampler=PKBatchSampler(...)).
    """
    def __init__(self, img_ids: list[str], P: int, K: int, drop_last: bool = True, seed: int = 42):
        # group dataset indices by image id
        self.by_img = {}
        for idx, iid in enumerate(map(str, img_ids)):
            self.by_img.setdefault(iid, []).append(idx)
        # shuffle within each image's index list
        g = torch.Generator()
        g.manual_seed(seed)
        for lst in self.by_img.values():
            perm = torch.randperm(len(lst), generator=g).tolist()
            lst[:] = [lst[i] for i in perm]

        self.P, self.K, self.drop_last = int(P), int(K), bool(drop_last)
        self.iids = list(self.by_img.keys())
        self.cur = {iid: 0 for iid in self.iids}
        self.pool = [iid for iid in self.iids if len(self.by_img[iid]) > 0]

    def __iter__(self):
        import random
        pool = self.pool[:]
        random.shuffle(pool)
        batch = []
        while len(pool) >= self.P:
            pick = pool[:self.P]
            pool = pool[self.P:]
            for iid in pick:
                start = self.cur[iid]
                end = start + self.K
                lst = self.by_img[iid]
                if start >= len(lst):
                    continue
                take = lst[start:min(end, len(lst))]
                self.cur[iid] = min(end, len(lst))
                batch.extend(take)
            if len(batch) == self.P * self.K:
                yield batch
            else:
                if not self.drop_last and len(batch) > 0:
                    yield batch
                break
            batch = []

    def __len__(self):
        total = sum(len(v) // self.K for v in self.by_img.values())
        return max(0, total // self.P)


In [66]:
# =============== Cross-batch Positives Buffer (F1.2) ===============
from collections import deque

class XBPBuffer:
    """
    Keeps recent predicted caption embeddings as extra positives per image.
    Read-only in loss; write after optimizer step.
    """
    def __init__(self, per_image_cap: int = 4, global_cap: int = 32000, device: torch.device = torch.device("cpu")):
        self.per_img = {}            # iid -> deque of normalized tensors
        self.order = deque()         # (iid, stamp) to enforce global cap
        self.per_image_cap = int(per_image_cap)
        self.global_cap = int(global_cap)
        self.device = device
        self._count = 0
        self._stamp = 0

    @torch.no_grad()
    def add(self, img_ids: list[str], pred_feats: torch.Tensor):
        feats = F.normalize(pred_feats.detach(), dim=-1)
        for iid, f in zip(map(str, img_ids), feats):
            dq = self.per_img.setdefault(iid, deque(maxlen=self.per_image_cap))
            dq.append(f.to(self.device))
            self.order.append((iid, self._stamp))
            self._stamp += 1
            self._count += 1
            # trim global
            while self._count > self.global_cap:
                old_iid, _ = self.order.popleft()
                if self.per_img.get(old_iid):
                    try:
                        self.per_img[old_iid].popleft()
                        self._count -= 1
                        if len(self.per_img[old_iid]) == 0:
                            self.per_img.pop(old_iid, None)
                    except IndexError:
                        pass

    def build_bank_and_mask(self, batch_img_names: list[str], device: torch.device):
        """
        Returns:
          xbp_bank: (M, D) stacked positives from XBP across batch (M can be 0)
          xbp_pos_cols_per_row: list[list[int]] with column indices into xbp_bank for each row
        """
        rows = [str(x) for x in batch_img_names]
        per_row_feats = []
        row_counts = []
        for iid in rows:
            dq = self.per_img.get(iid, None)
            if dq is None or len(dq) == 0:
                row_counts.append(0)
            else:
                row_counts.append(len(dq))
                per_row_feats.extend(list(dq))
        if len(per_row_feats) == 0:
            return torch.empty(0, 0, device=device), [[] for _ in rows]
        xbp_bank = torch.stack(per_row_feats, dim=0).to(device)
        xbp_pos_cols = []
        offset = 0
        for cnt in row_counts:
            if cnt == 0:
                xbp_pos_cols.append([])
            else:
                xbp_pos_cols.append(list(range(offset, offset + cnt)))
                offset += cnt
        return xbp_bank, xbp_pos_cols


In [67]:
# =============== Hard-subset Mining (F2.1) ===============
@torch.no_grad()
def mine_hard(q_recent: torch.Tensor, pred: torch.Tensor, H: int, exclude_feats: torch.Tensor = None):
    """
    q_recent: (Q,D) normalized features to mine from (negatives pool)
    pred:     (B,D) normalized anchor predictions
    Returns:
      idx: (B,H) indices into q_recent for each anchor
      flat_idx: (B*H,) flattened indices
      mined: (B*H,D) the mined subset
    """
    if q_recent is None or q_recent.numel() == 0 or H <= 0:
        return None, None, None
    B = pred.size(0)
    Q = q_recent.size(0)
    sims = pred @ q_recent.t()  # (B,Q)
    H = min(int(H), Q)
    _, idx = torch.topk(sims, k=H, dim=1, largest=True, sorted=False)  # (B,H)
    flat_idx = idx.reshape(-1)
    mined = q_recent.index_select(0, flat_idx)      # (B*H,D)
    return idx, flat_idx, mined


In [17]:
# =============== Debiased NT-Xent denominator (F2.2) ===============
def _debiased_logZ(logits: torch.Tensor, pos_mask: torch.Tensor, class_prior: float = 0.01):
    """
    Chuang et al., 2020 (Debiased Contrastive Learning) — practical variant:
      Z* = sum_k exp(l_ik) - π * sum_{j in P_i} exp(l_ij)
    """
    exp_all = torch.exp(logits)          # (B,K)
    exp_pos = exp_all * pos_mask
    Z = exp_all.sum(dim=1)               # (B,)
    corr = class_prior * exp_pos.sum(dim=1)
    Z_star = torch.clamp(Z - corr, min=1e-8)
    return torch.log(Z_star)


In [68]:
# =============== Multi-positive InfoNCE with XBP, Hard Mining, DCL (F1.1+F1.2+F2.1+F2.2) ===============
def info_nce_multi(
    pred,                                    # (B,D) unnormed
    batch_img_targets,                       # (B,D)
    batch_img_names,                         # list length B
    queue_feats=None,                        # (Q,D) negatives pool
    tau: float = 0.07,
    xbp_bank: torch.Tensor = None,           # (M,D) extra positives
    xbp_pos_cols_per_row: list[list[int]] = None,
    hard_subset_H: int = 0,
    use_dcl: bool = False,
    dcl_prior: float = 0.01,
):
    p = F.normalize(pred, dim=-1)
    t = F.normalize(batch_img_targets, dim=-1)

    # base bank block 0: in-batch targets (positives exist here)
    bank_blocks = [t]
    B = pred.size(0)

    # optional hard mining over queue
    if queue_feats is not None and queue_feats.numel() > 0:
        qn = F.normalize(queue_feats, dim=-1)
        if hard_subset_H and hard_subset_H > 0:
            _, _, mined_block = mine_hard(qn, p, int(hard_subset_H))
            bank_blocks.append(mined_block)
        else:
            bank_blocks.append(qn)

    # optional XBP block
    xbp_offsets = None
    if xbp_bank is not None and xbp_bank.numel() > 0:
        xbp_offsets = sum(b.size(0) for b in bank_blocks)  # start col for XBP block
        bank_blocks.append(xbp_bank)

    bank = torch.cat(bank_blocks, dim=0) if len(bank_blocks) > 1 else bank_blocks[0]  # (K,D)
    logits = (p @ bank.t()) / float(tau)                                              # (B,K)
    K = bank.size(0)

    # build pos mask: in-batch + XBP
    pos_mask = torch.zeros(B, K, dtype=torch.bool, device=pred.device)
    names = list(map(str, batch_img_names))
    for i in range(B):
        for j in range(B):
            if names[i] == names[j]:
                pos_mask[i, j] = True
    if xbp_offsets is not None and xbp_pos_cols_per_row is not None:
        for i, cols in enumerate(xbp_pos_cols_per_row):
            for c in cols:
                pos_mask[i, xbp_offsets + c] = True

    assert pos_mask.any(dim=1).all(), "Every row must have at least one positive."

    # denominator: standard vs debiased
    if not use_dcl:
        logZ = torch.logsumexp(logits, dim=1)
    else:
        logZ = _debiased_logZ(logits, pos_mask, class_prior=float(dcl_prior))

    logits_pos = logits.masked_fill(~pos_mask, float('-inf'))
    logPos = torch.logsumexp(logits_pos, dim=1)
    return (-(logPos - logZ)).mean()


In [69]:
# =============== Dual Queue: Prediction FIFO (F2.3) ===============
class PredQueue:
    def __init__(self, dim: int, capacity: int, device: torch.device):
        self.capacity = int(capacity)
        self.device = device
        self.ptr = 0
        self.full = False
        self.feats = torch.zeros(self.capacity, dim, device=device)  # normalized feats
        self.ids = torch.empty(self.capacity, dtype=torch.long, device=device)  # hashed img ids

    @torch.no_grad()
    def enqueue(self, feats: torch.Tensor, img_ids: list[str]):
        feats = F.normalize(feats.detach(), dim=-1)
        ids = torch.tensor([hash(str(i)) for i in img_ids], dtype=torch.long, device=self.device)
        n = feats.size(0)
        if n >= self.capacity:
            self.feats.copy_(feats[-self.capacity:])
            self.ids.copy_(ids[-self.capacity:])
            self.ptr = 0
            self.full = True
            return
        end = self.ptr + n
        if end <= self.capacity:
            self.feats[self.ptr:end] = feats
            self.ids[self.ptr:end] = ids
        else:
            cut = self.capacity - self.ptr
            self.feats[self.ptr:] = feats[:cut]
            self.ids[self.ptr:] = ids[:cut]
            self.feats[:end - self.capacity] = feats[cut:]
            self.ids[:end - self.capacity] = ids[cut:]
        self.ptr = (self.ptr + n) % self.capacity
        if self.ptr == 0:
            self.full = True

    def size(self) -> int:
        if self.full:
            return self.capacity
        return self.ptr

    def recent(self, max_items: int):
        n = self.size()
        if n == 0:
            return self.feats[:0], self.ids[:0]
        k = min(int(max_items), n) if max_items is not None else n
        if self.full:
            idx = torch.arange(self.ptr - n, self.ptr, device=self.device) % self.capacity
        else:
            idx = torch.arange(0, n, device=self.device)
        idx = idx[-k:]
        return self.feats.index_select(0, idx), self.ids.index_select(0, idx)


In [70]:
def make_model(
    arch: str, din: int, dout: int, geom_init_data=None,
    residual_ortho: str = "project",
    ortho_eps: float = 1e-4,
    ortho_lambda: float = 0.05,
    unfreeze_geom_bias: bool = False,
):
    a = arch.lower()
    if a == "linear":
        return LinearProj(din, dout)
    if a == "mlp1":
        return MLP1(din, dout, hidden=512, pdrop=0.1)
    if a == "mlp2":
        return MLP2(din, dout, h1=1024, h2=512, pdrop=0.1)
    if a == "geom":
        assert geom_init_data is not None, "geom_init_data required for arch='geom'."
        Xtr_np, Ytr_np, eps = geom_init_data
        A, b = procrustes_closed_form(Xtr_np, Ytr_np, eps=float(eps))
        return GeomLinear(A, b)
    if a in ("geom_adapter", "geom+adapter", "bigjump"):
        assert geom_init_data is not None, "geom_init_data required for geom+adapter/bigjump."
        Xtr_np, Ytr_np, eps = geom_init_data
        A, b = procrustes_closed_form(Xtr_np, Ytr_np, eps=float(eps))
        return GeomWithAdapter(
            A, b, din=din, dout=dout, hidden=1024, pdrop=0.1,
            residual_ortho=residual_ortho, ortho_eps=ortho_eps,
            ortho_lambda=ortho_lambda, unfreeze_bias=unfreeze_geom_bias
        )
    if a == "auto":
        return MLP2(din, dout)
    raise ValueError(f"Unknown arch {arch}")

In [71]:
# ---------- Centroid helpers ----------
def _centroids_from_img_ids(X_rows: np.ndarray, Y_rows: np.ndarray, img_names_rows: np.ndarray):
    """
    Build per-image centroids from per-caption rows.
    Returns:
      Xc: (N_img, 1024)   mean text per image
      Yc: (N_img, 1536)   mean image per image (usually equals the single image vector)
      order_names: list[str] image names in same order
    """
    from collections import defaultdict
    buckets_x = defaultdict(list)
    buckets_y = defaultdict(list)
    for x, y, n in zip(X_rows, Y_rows, map(str, img_names_rows)):
        buckets_x[n].append(x)
        buckets_y[n].append(y)
    order_names = sorted(buckets_x.keys())
    Xc = np.stack([np.mean(buckets_x[n], axis=0) for n in order_names], axis=0).astype(np.float32)
    Yc = np.stack([np.mean(buckets_y[n], axis=0) for n in order_names], axis=0).astype(np.float32)
    return Xc, Yc, order_names

# ---------- Geometry: closed-form (whiten→orthogonal align→recolor) on CENTROIDS ----------
def procrustes_closed_form_centroids(X_rows: np.ndarray, Y_rows: np.ndarray, img_names_rows: np.ndarray, eps: float = 1e-5):
    """
    Fit linear A,b using per-image centroids instead of per-caption rows.
    """
    Xc, Yc, _ = _centroids_from_img_ids(X_rows, Y_rows, img_names_rows)

    mu_x = Xc.mean(0, dtype=np.float64)
    mu_y = Yc.mean(0, dtype=np.float64)
    Xzc = Xc - mu_x
    Yzc = Yc - mu_y

    def _cov_eigh(zc, eps):
        S, U = np.linalg.eigh((zc.T @ zc) / max(1, zc.shape[0]))
        S = np.clip(S, 0.0, None)
        inv_sqrt = 1.0 / np.sqrt(S + eps)
        sqrt     = np.sqrt(S + eps)
        C_mhalf = (U * inv_sqrt) @ U.T
        C_phalf = (U * sqrt)     @ U.T
        return C_mhalf.astype(np.float32), C_phalf.astype(np.float32)

    Cx_mh, _     = _cov_eigh(Xzc, eps)      # 1024x1024
    Cy_mh, Cy_ph = _cov_eigh(Yzc, eps)      # 1536x1536
    Xw = Xzc @ Cx_mh.T                      # (Nimg,1024)
    Yw = Yzc @ Cy_mh.T                      # (Nimg,1536)

    # rectangular orthogonal alignment in whitened space
    M = Xw.T @ Yw                           # (1024,1536)
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    R = (U @ Vt).T                          # (1536,1024)

    # recolor to Y space
    A = (Cy_ph @ R @ Cx_mh).astype(np.float32)     # (1536,1024)
    b = (mu_y - (A @ mu_x)).astype(np.float32)     # (1536,)
    return A, b


In [72]:
# ---------- Losses: fixed-τ InfoNCE + Triplet (cosine) ----------
def info_nce_fixed(pred: torch.Tensor, tgt: torch.Tensor, tau: float):
    # both should be L2-normalized already
    logits = (pred @ tgt.t()) / float(tau)                 # (B,B)
    labels = torch.arange(pred.size(0), device=pred.device)
    return F.cross_entropy(logits, labels)

class TripletCosineLoss(nn.Module):
    def __init__(self, margin: float = 0.2):
        super().__init__()
        self.loss = nn.TripletMarginWithDistanceLoss(
            distance_function=lambda x, y: 1 - F.cosine_similarity(x, y),
            margin=margin
        )
    def forward(self, anchor, positive, negative):
        return self.loss(anchor, positive, negative)


In [None]:
# =========================
#   SIMPLE TRAINER 
# =========================
def train_simple(
    model,
    Xtr, Ytr, img_ids_row_tr,
    Xva, val_gallery, cap2gal_local,
    batch=512, epochs=25, lr=2e-4, wd=1e-4,
    tau_fixed=0.07, triplet_margin=0.2, triplet_w=0.3,
    alpha_cos=0.0,            # cosine aux off by default (script-style)
    moment_w=0.0,             # moment off by default
    use_pk=False, P=192, K_pk=3,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    pooling="none", n_patches=None,
    out_dir: "Path | str" = None, seed=42,
):
    """
    Matches the simple script dynamics:
      - outputs L2-normalized
      - InfoNCE with fixed tau + Triplet(cosine)
      - in-batch negatives only (no queues)
      - geometry kept frozen (handled by caller)
    """
    import math, json
    from pathlib import Path
    from torch.utils.data import DataLoader

    out_dir = Path(out_dir) if out_dir is not None else Path("./outputs/simple")
    out_dir.mkdir(parents=True, exist_ok=True)
    seed_all(seed)
    model.to(device)

    class TripletSet(torch.utils.data.Dataset):
        def __init__(self, X_np, Y_np, names_np):
            self.X = torch.from_numpy(X_np).float()
            self.Y = torch.from_numpy(Y_np).float()
            self.N = [str(n) for n in names_np.tolist()]
        def __len__(self):  return len(self.X)
        def __getitem__(self, i):  return self.X[i], self.Y[i], self.N[i]

    ds = TripletSet(Xtr, Ytr, img_ids_row_tr)

    if use_pk:
        pk_sampler = PKBatchSampler(img_ids_row_tr.tolist(), P=int(P), K=int(K_pk), drop_last=True, seed=seed)
        dl = DataLoader(ds, batch_sampler=pk_sampler, num_workers=2, pin_memory=True)
    else:
        dl = DataLoader(ds, batch_size=batch, shuffle=True, num_workers=2, pin_memory=True, drop_last=False)

    # freeze geometry fully (caller must set requires_grad=False on geom)
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=lr, weight_decay=wd)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, epochs*len(dl)))

    tri = TripletCosineLoss(margin=float(triplet_margin))

    best_stats, best_mrr, best_ep = None, -1.0, 0
    for ep in range(1, epochs+1):
        model.train()
        running = 0.0

        for xb, yb, _ in dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)

            pred = model(xb)                 # (B,1536)
            pred = F.normalize(pred, dim=-1)
            ybn  = F.normalize(yb,  dim=-1)

            # InfoNCE (fixed tau) with in-batch negatives
            loss_ce = info_nce_fixed(pred, ybn, tau=float(tau_fixed))

            # Triplet: pick negatives by shuffling the batch
            idx = torch.randperm(xb.size(0), device=device)
            loss_tri = tri(pred, ybn, ybn[idx])

            # Optional tiny auxiliaries (default 0)
            loss_cos = (1.0 - F.cosine_similarity(pred, yb, dim=-1)).mean() if alpha_cos > 0 else pred.new_tensor(0.0)
            loss_mu  = moment_align(pred, yb) if moment_w > 0 else pred.new_tensor(0.0)

            loss = (1.0 - float(triplet_w)) * loss_ce + float(triplet_w) * loss_tri \
                   + float(alpha_cos) * loss_cos + float(moment_w) * loss_mu

            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            opt.step()
            running += loss.item() * xb.size(0)

        sched.step()
        stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
        ep_loss = running / len(ds)
        print(f"[simple {ep:02d}] loss={ep_loss:.6f} | val_MRR={stats['MRR']:.4f} "
              f"| R@1={stats['R1']:.3f} R@5={stats['R5']:.3f} R@10={stats['R10']:.3f} "
              f"| median={stats['rank_median']} p75={stats['rank_p75']}")

        if stats["MRR"] > best_mrr:
            best_mrr, best_ep, best_stats = stats["MRR"], ep, stats
            torch.save({"model": model.state_dict(), "epoch": ep, "val": stats}, out_dir / "best.pt")

    if best_stats is None:
        best_stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
        torch.save({"model": model.state_dict(), "epoch": 0, "val": best_stats}, out_dir / "best.pt")

    (out_dir / "val_metrics.json").write_text(json.dumps(dict(best_epoch=best_ep, **best_stats), indent=2))
    return best_stats


In [74]:
def main(args):
    """
    Main (advanced, two-stage): 
      - Stage 1: Centroid-Procrustes + Frozen Base + Residual, trained with fixed-τ InfoNCE + Triplet (stable warm-up).
      - Stage 2: Continue from Stage 1 and switch to BigJump trainer with a *small, late* queue + light mining + tiny agreement. 
                 Geometry stays frozen. No hard residual projection.
      - Then: evaluate, log efficiency, generate submission.

    Assumes these already exist in your file:
      load_train, build_image_id_split, procrustes_closed_form_centroids
      ResidualAdapter, train_simple, train_bigjump
      validate_retrieval, count_params_mb, time_ms_per_query
      apply_pooling, load_data, generate_submission, TEST_NPZ
    """
    import json
    from pathlib import Path
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # -----------------------
    # Setup
    # -----------------------
    out_dir, seed = args.out_dir, args.seed
    pooling, n_patches = args.pooling, None
    OUT = Path(f"/kaggle/working/outputs/{out_dir}")
    OUT.mkdir(parents=True, exist_ok=True)
    STAGE1 = OUT / "stage1"
    STAGE2 = OUT / "stage2"
    STAGE1.mkdir(parents=True, exist_ok=True)
    STAGE2.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # -----------------------
    # Data
    # -----------------------
    X, Y, cap_ids, img_ids_row, full_img, all_img_ids = load_train(OUT)

    # Split by image id (NO LEAKAGE)
    cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices = build_image_id_split(
        img_ids_row, all_img_ids, full_img, X, Y, args.val_ratio, seed, OUT
    )
    val_set = set(all_img_ids[val_img_indices])
    train_set = set(all_img_ids) - val_set
    assert val_set.isdisjoint(train_set), "Leakage detected between TRAIN and VAL image sets!"

    # Masks to arrays
    Xtr, Ytr = X[cap_is_tr], Y[cap_is_tr]
    Xva      = X[cap_is_val]

    din, dout = X.shape[1], Y.shape[1]
    assert din == 1024 and dout == 1536, f"Dimension mismatch: text={din}, image={dout} (expected 1024→1536)"

    # -----------------------
    # Geometry init (CENTROIDS on TRAIN ONLY)
    # -----------------------
    A, b = procrustes_closed_form_centroids(
        Xtr.astype(np.float32),
        Ytr.astype(np.float32),
        img_ids_row[cap_is_tr],
        eps=float(args.geom_eps)
    )

    # -----------------------
    # Model: Frozen base (A,b) + residual adapter
    # -----------------------
    class GeomFromWeights(nn.Module):
        def __init__(self, A_np, b_np):
            super().__init__()
            self.fc = nn.Linear(A_np.shape[1], A_np.shape[0], bias=True)
            with torch.no_grad():
                self.fc.weight.copy_(torch.from_numpy(A_np))
                self.fc.bias.copy_(torch.from_numpy(b_np))
        def forward(self, x): return self.fc(x)

    geom = GeomFromWeights(A, b).to(device)
    for p in geom.parameters(): p.requires_grad = False  # keep base frozen in both stages

    # residual capacity like BigJump adapter
    adapter = ResidualAdapter(din=din, hidden=1024, dout=dout, pdrop=0.1)

    class GeomPlusAdapter(nn.Module):
        def __init__(self, g, a): super().__init__(); self.geom=g; self.adapter=a
        def forward(self, x): return self.geom(x) + self.adapter(x)

    model = GeomPlusAdapter(geom, adapter).to(device)

    # Probe
    with torch.no_grad():
        _probe = model(torch.zeros(2, din, device=device))
    assert _probe.shape[-1] == 1536, f"Translator output dim { _probe.shape[-1] } != 1536"

    # -----------------------
    # Stage 1: simple warm-up (fixed τ InfoNCE + Triplet, in-batch negatives)
    # -----------------------
    if not args.eval_only:
        print(f"[stage1] τ_fixed={args.tau:.3f} | triplet_w={args.triplet_w} | margin={args.triplet_margin} | epochs={args.stage1_epochs}")
        best_stats_s1 = train_simple(
            model,
            Xtr, Ytr, img_ids_row[cap_is_tr],
            Xva, val_gallery, cap2gal_local,
            batch=args.batch, epochs=int(args.stage1_epochs), lr=args.lr, wd=args.wd,
            tau_fixed=float(args.tau), triplet_margin=float(args.triplet_margin), triplet_w=float(args.triplet_w),
            alpha_cos=float(args.alpha), moment_w=float(args.moment),
            use_pk=bool(args.use_pk), P=int(args.P), K_pk=int(args.K_pk),
            device=device, pooling=pooling, n_patches=n_patches,
            out_dir=str(STAGE1), seed=seed,
        )
        # Load best from stage1 (ensure we continue from the best)
        try:
            ckpt1 = torch.load(STAGE1 / "best.pt", map_location=device)
            model.load_state_dict(ckpt1["model"])
            print(f"[stage1] resume best epoch={ckpt1.get('epoch','?')} MRR={ckpt1.get('val',{}).get('MRR','?')}")
        except FileNotFoundError:
            pass
    else:
        best_stats_s1 = None

    # -----------------------
    # Stage 2: BigJump+ (small queue, late; light mining; tiny agreement; geometry still frozen)
    # -----------------------
    if not args.eval_only and args.stage2_epochs > 0:
        print(f"[stage2] queue_warmup={args.queue_warmup_epochs} | recent_k={args.queue_recent_k} | mine_H={args.mine_H} | "
              f"τ_sched={args.tau_start:.3f}→{args.tau_end:.3f} | agree={args.lambda_agree}")
        # Keep geometry frozen in BigJump
        best_stats_s2 = train_bigjump(
            model,
            Xtr, Ytr, img_ids_row[cap_is_tr],
            Xva, val_gallery, cap2gal_local,
            batch=args.batch, epochs=int(args.stage2_epochs), base_lr=args.lr, wd=args.wd,
            # legacy knobs (we’ll let InfoNCE dominate; keep cosine light, moment tiny/off)
            tau=None,                                 # use τ curriculum in stage2
            alpha_cos=float(args.alpha_stage2),       # e.g., 0.3
            lambda_moment=float(args.moment_stage2),  # e.g., 0.01
            queue_size=int(args.queue),
            # schedules
            tau_start=float(args.tau_start), tau_end=float(args.tau_end),
            queue_warmup_epochs=int(args.queue_warmup_epochs),
            queue_recent_schedule=tuple(args.queue_recent_k),
            lambda_agree=float(args.lambda_agree),
            # keep geometry frozen in stage2 as well
            geom_unfreeze_epoch=0,
            geom_lr_scale=float(args.geom_lr_scale),
            # misc
            device=device, pooling=pooling, n_patches=n_patches,
            out_dir=str(STAGE2), seed=seed,
            # NEW knobs
            use_pk=bool(args.use_pk), P=int(args.P), K_pk=int(args.K_pk),
            xbp_per_img=int(args.xbp_per_img), xbp_global=int(args.xbp_global),
            use_dcl=bool(args.use_dcl), dcl_prior=float(args.dcl_prior),
            mine_H=int(args.mine_H),
            use_pred_queue=False, queue_pred_capacity=int(args.queue_pred_capacity),
        )
        # Load best from stage2 for final export
        try:
            ckpt2 = torch.load(STAGE2 / "best.pt", map_location=device)
            model.load_state_dict(ckpt2["model"])
            best_stats_final = ckpt2.get("val", None)
            print(f"[stage2] resume best epoch={ckpt2.get('epoch','?')} MRR={ckpt2.get('val',{}).get('MRR','?')}")
        except FileNotFoundError:
            best_stats_final = best_stats_s1
    else:
        # eval-only: try to load stage2 first, else stage1
        best_stats_final = None
        if args.eval_only:
            for cand in [STAGE2 / "best.pt", STAGE1 / "best.pt", OUT / "best.pt"]:
                try:
                    ckpt = torch.load(cand, map_location=device)
                    model.load_state_dict(ckpt["model"])
                    best_stats_final = ckpt.get("val", None)
                    print(f"[resume] loaded {cand} | epoch={ckpt.get('epoch','?')} MRR={ckpt.get('val',{}).get('MRR','?')}")
                    break
                except FileNotFoundError:
                    continue

    # -----------------------
    # Efficiency logging
    # -----------------------
    params, mb = count_params_mb(model)
    ms_gpu, ms_cpu = time_ms_per_query(model, din, pooling, n_patches)
    eff = {"params": params, "mb_fp32": mb, "ms_per_query_gpu": ms_gpu, "ms_per_query_cpu": ms_cpu}
    (OUT / "efficiency.json").write_text(json.dumps(eff, indent=2))

    # -----------------------
    # Submission
    # -----------------------
    test_data = load_data(TEST_NPZ)
    Q   = test_data["captions/embeddings"].astype(np.float32)
    ids = test_data.get("captions/ids", np.arange(len(Q)).astype(str))
    model.eval()
    BS, outs = 1024, []
    with torch.no_grad():
        for i in range(0, len(Q), BS):
            q = torch.from_numpy(Q[i:i+BS]).to(device)
            q = apply_pooling(q, pooling, n_patches)
            z = model(q)
            z = F.normalize(z, dim=-1)
            outs.append(z.detach().cpu().numpy())
    pred_embds = np.concatenate(outs, axis=0)
    sub = OUT / "submission.csv"
    generate_submission(ids, pred_embds, str(sub))
    print(f"[ok] submission written → {sub}")

    # -----------------------
    # Final sanity print
    # -----------------------
    if best_stats_final is None:
        best_stats_final = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)

    sanity = {
        "dims": {"text": int(din), "image": int(dout)},
        "split": {
            "train_captions": int(cap_is_tr.sum()),
            "val_captions": int(cap_is_val.sum()),
            "val_unique_images": int(val_gallery.shape[0]),
            "leakage": False
        },
        "val_metrics": best_stats_final,
        "efficiency": eff,
        "recipe": {
            "stage1": {
                "type": "simple_fixed_tau_infoNCE+triplet",
                "tau_fixed": float(args.tau),
                "triplet_margin": float(args.triplet_margin),
                "triplet_w": float(args.triplet_w),
                "alpha_cos": float(args.alpha),
                "moment_w": float(args.moment),
                "epochs": int(args.stage1_epochs),
                "use_pk": bool(args.use_pk), "P": int(args.P), "K": int(args.K_pk)
            },
            "stage2": {
                "type": "bigjump_plus_small_queue_light_mining",
                "epochs": int(args.stage2_epochs),
                "tau_sched": [float(args.tau_start), float(args.tau_end)],
                "queue_warmup_epochs": int(args.queue_warmup_epochs),
                "queue_recent_schedule": list(map(int, args.queue_recent_k)),
                "mine_H": int(args.mine_H),
                "lambda_agree": float(args.lambda_agree),
                "geom_unfreeze_epoch": 0
            },
            "centroid_procrustes": True,
            "geom_frozen": True,
            "residual_projection": "none"
        }
    }
    print(json.dumps(sanity, indent=2))


In [45]:
# =========================
#        ARGPARSE
# =========================
if __name__ == "__main__":
    import argparse
    p = argparse.ArgumentParser(description="RoBERTa→DINOv2 | 2-Stage (Centroid-Procrustes Warmup → BigJump+)")

    # IO & run mode
    p.add_argument("--out_dir", type=str, default="two_stage_push")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--eval_only", action="store_true")

    # data & split
    p.add_argument("--val_ratio", type=float, default=0.10)

    # architecture (we still build geom+adapter internally)
    p.add_argument("--arch", type=str, default="bigjump",
                   choices=["linear","mlp1","mlp2","geom","geom_adapter","geom+adapter","bigjump"])

    # optimization (shared)
    p.add_argument("--epochs", type=int, default=0, help="(unused; stages have their own epochs)")
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--wd", type=float, default=1e-4)

    # geometry init
    p.add_argument("--geom_eps", type=float, default=1e-5)

    # pooling (no-op)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS"])

    # ---------- STAGE 1 (Simple: fixed-τ InfoNCE + Triplet, in-batch negatives) ----------
    p.add_argument("--stage1_epochs", type=int, default=12)
    p.add_argument("--tau", type=float, default=0.07, help="Fixed InfoNCE temperature for Stage 1.")
    p.add_argument("--triplet_margin", type=float, default=0.2)
    p.add_argument("--triplet_w", type=float, default=0.3, help="Triplet weight; InfoNCE weight = 1 - triplet_w.")
    p.add_argument("--alpha", type=float, default=0.0, help="Cosine auxiliary weight in Stage 1.")
    p.add_argument("--moment", type=float, default=0.0, help="Moment alignment weight in Stage 1.")

    # Optional PK sampler (both stages)
    p.add_argument("--use_pk", action="store_true")
    p.add_argument("--P", type=int, default=192)
    p.add_argument("--K_pk", type=int, default=3)

    # ---------- STAGE 2 (BigJump+: small late queue, light mining, tiny agreement) ----------
    p.add_argument("--stage2_epochs", type=int, default=13)
    p.add_argument("--tau_start", type=float, default=0.10)
    p.add_argument("--tau_end",   type=float, default=0.06)

    p.add_argument("--queue", type=int, default=65536)
    p.add_argument("--queue_warmup_epochs", type=int, default=3)
    p.add_argument("--queue_recent_k", type=int, nargs=3, default=[8000, 16000, 65536])

    p.add_argument("--mine_H", type=int, default=32, help="Top-H hard negatives per anchor (recent queue slice).")
    p.add_argument("--lambda_agree", type=float, default=0.02, help="Caption-agreement weight (tiny).")
    p.add_argument("--alpha_stage2", type=float, default=0.3, help="Cosine auxiliary in Stage 2.")
    p.add_argument("--moment_stage2", type=float, default=0.01, help="Moment alignment in Stage 2.")
    p.add_argument("--geom_lr_scale", type=float, default=0.05)

    # XBP / DCL / dual-queue toggles (kept for completeness; default off/safe)
    p.add_argument("--xbp_per_img", type=int, default=4)
    p.add_argument("--xbp_global", type=int, default=32000)
    p.add_argument("--use_dcl", action="store_true")
    p.add_argument("--dcl_prior", type=float, default=0.01)
    p.add_argument("--queue_pred_capacity", type=int, default=65536)

    args, _ = p.parse_known_args()
    main(args)


[meta] text_dim=1024 | image_dim=1536
[stage1] τ_fixed=0.070 | triplet_w=0.3 | margin=0.2 | epochs=12
[simple 01] loss=2.166024 | val_MRR=0.3573 | R@1=0.232 R@5=0.496 R@10=0.620 | median=6 p75=23
[simple 02] loss=2.019182 | val_MRR=0.3671 | R@1=0.241 R@5=0.508 R@10=0.635 | median=5 p75=21
[simple 03] loss=1.967485 | val_MRR=0.3740 | R@1=0.246 R@5=0.514 R@10=0.643 | median=5 p75=19
[simple 04] loss=1.931230 | val_MRR=0.3808 | R@1=0.251 R@5=0.525 R@10=0.651 | median=5 p75=19
[simple 05] loss=1.901381 | val_MRR=0.3852 | R@1=0.256 R@5=0.531 R@10=0.658 | median=5 p75=18
[simple 06] loss=1.875517 | val_MRR=0.3887 | R@1=0.259 R@5=0.534 R@10=0.663 | median=5 p75=17
[simple 07] loss=1.854193 | val_MRR=0.3924 | R@1=0.261 R@5=0.541 R@10=0.669 | median=5 p75=17
[simple 08] loss=1.834112 | val_MRR=0.3955 | R@1=0.266 R@5=0.542 R@10=0.669 | median=4 p75=17
[simple 09] loss=1.816755 | val_MRR=0.3977 | R@1=0.267 R@5=0.546 R@10=0.674 | median=5 p75=16
[simple 10] loss=1.799566 | val_MRR=0.3992 | R@1=0.2

In [87]:
class ResidualAdapter(nn.Module):
    """
    Improved adapter with:
    - Larger initialization for faster convergence
    - Optional skip connection
    - Layernorm for stability
    """
    def __init__(self, din=1024, hidden=1024, dout=1536, pdrop=0.1, init_scale=0.25, use_layernorm=True):
        super().__init__()
        self.use_layernorm = use_layernorm
        if use_layernorm:
            self.ln = nn.LayerNorm(din)
        self.fc1 = nn.Linear(din, hidden)
        self.fc2 = nn.Linear(hidden, dout)
        self.drop = nn.Dropout(pdrop)
        
        # Better initialization
        nn.init.kaiming_normal_(self.fc1.weight, mode='fan_out', nonlinearity='relu')
        nn.init.zeros_(self.fc1.bias)
        nn.init.kaiming_normal_(self.fc2.weight, mode='fan_in', nonlinearity='linear')
        nn.init.zeros_(self.fc2.bias)
        
        # Scale up initialization for faster learning
        with torch.no_grad():
            self.fc1.weight.mul_(init_scale)
            self.fc2.weight.mul_(init_scale)
    
    def forward(self, x):
        if self.use_layernorm:
            x = self.ln(x)
        h = F.relu(self.fc1(x))
        h = self.drop(h)
        return self.fc2(h)


class GeomWithAdapter(nn.Module):
    """
    Enhanced geometry + residual model with:
    - Better orthogonalization (uses SVD-based projection)
    - Optional EMA for stability
    - Gradient checkpointing support
    """
    def __init__(
        self, A: np.ndarray, b: np.ndarray,
        din=1024, dout=1536, hidden=1024, pdrop=0.1,
        residual_ortho: str = "project",
        ortho_eps: float = 1e-4,
        ortho_lambda: float = 0.05,
        unfreeze_bias: bool = False,
        use_layernorm: bool = True,
        init_scale: float = 0.25,
    ):
        super().__init__()
        self.geom = GeomLinear(A, b)
        self.adapter = ResidualAdapter(
            din=din, hidden=hidden, dout=dout, pdrop=pdrop, 
            init_scale=init_scale, use_layernorm=use_layernorm
        )
        
        # Freeze geometry initially
        for p in self.geom.parameters():
            p.requires_grad = False
        
        # Orthogonalization config
        assert residual_ortho in ("project", "penalty", "none", "svd")
        self.residual_ortho = residual_ortho
        self.ortho_lambda = float(ortho_lambda)
        self.ortho = ResidualOrthogonalizer(eps=float(ortho_eps))
        self._unfreeze_bias = bool(unfreeze_bias)
        
        # Cache for SVD-based projection (faster than iterative)
        self._svd_projector = None
        self._last_A_hash = None
    
    def unfreeze_geom(self, weight_lr_scale=0.05):
        self.geom.fc.weight.requires_grad = True
        if self._unfreeze_bias:
            self.geom.fc.bias.requires_grad = True
    
    @torch.no_grad()
    def _build_svd_projector(self, A: torch.Tensor):
        """Build orthogonal projector using SVD (more stable than Gram inverse)"""
        A_hash = hash(A.data_ptr())
        if self._svd_projector is not None and self._last_A_hash == A_hash:
            return self._svd_projector
        
        device = A.device
        A32 = A.detach().to(dtype=torch.float32)
        
        # U @ S @ Vt = A, shape (1536, 1024)
        U, S, Vt = torch.linalg.svd(A32, full_matrices=False)
        # Projector onto column space: P = U @ U^T
        # Orthogonal complement: I - P
        # For residual r in R^1536, we want (I - UU^T)r
        
        self._svd_projector = U @ U.T  # (1536, 1536)
        self._last_A_hash = A_hash
        return self._svd_projector
    
    def forward(self, x, return_ortho_penalty: bool = False):
        g = self.geom(x)  # (B, 1536)
        r = self.adapter(x)  # (B, 1536)
        
        penalty = r.new_tensor(0.0)
        
        if self.residual_ortho == "svd":
            # Fast SVD-based projection
            A = self.geom.fc.weight
            P = self._build_svd_projector(A).to(r.device)
            device_type = "cuda" if r.device.type == "cuda" else "cpu"
            with torch.autocast(device_type=device_type, enabled=False):
                r32 = r.to(torch.float32)
                # Project out the column space: r_orth = r - P @ r
                r_orth = r32 - (P @ r32.T).T
                r = r_orth.to(r.dtype)
            y = g + r
            
        elif self.residual_ortho == "project":
            # Original Gram-based projection
            A = self.geom.fc.weight
            device_type = "cuda" if g.device.type == "cuda" else "cpu"
            with torch.autocast(device_type=device_type, enabled=False):
                r = self.ortho.project_residual(r, A)
            y = g + r
            
        elif self.residual_ortho == "penalty":
            # Soft orthogonality penalty
            g_n = F.normalize(g, dim=-1)
            r_n = F.normalize(r, dim=-1)
            dot = torch.sum(g_n * r_n, dim=-1)
            penalty = torch.mean(dot * dot)
            y = g + r
            
        else:  # "none"
            y = g + r
        
        if return_ortho_penalty:
            return y, penalty
        return y


# ---------- Residual Orthogonalization with SVD fallback ----------
class ResidualOrthogonalizer:
    """Enhanced orthogonalizer with SVD fallback for numerical stability"""
    def __init__(self, eps: float = 1e-4, use_svd_fallback: bool = True):
        self.eps = float(eps)
        self.use_svd_fallback = use_svd_fallback
        self._cached_invG = None
        self._cached_A_hash = None
    
    @torch.no_grad()
    def _hash_weight(self, A: torch.Tensor) -> tuple[int, int]:
        h = (A.shape[0], A.shape[1])
        sample = A.reshape(-1)[:1024] if A.numel() >= 1024 else A.reshape(-1)
        checksum = int(torch.sum((sample.float() * 1e3).round()).item())
        return (h[0] * 10_000 + h[1], checksum)
    
    @torch.no_grad()
    def refresh(self, A: torch.Tensor):
        dev = A.device
        device_type = "cuda" if dev.type == "cuda" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            A32 = A.detach().to(dev, dtype=torch.float32)
            key = self._hash_weight(A32)
            
            if key == self._cached_A_hash and self._cached_invG is not None and self._cached_invG.device == dev:
                return
            
            G = A32.transpose(0, 1) @ A32  # (1024, 1024)
            G = G + self.eps * torch.eye(G.shape[0], device=dev, dtype=torch.float32)
            
            try:
                L = torch.linalg.cholesky(G)
                invG = torch.cholesky_inverse(L)
            except RuntimeError:
                if self.use_svd_fallback:
                    # Use SVD-based pseudoinverse
                    U, S, Vt = torch.linalg.svd(G)
                    S_inv = torch.where(S > self.eps, 1.0 / S, torch.zeros_like(S))
                    invG = (Vt.T * S_inv) @ Vt
                else:
                    try:
                        invG = torch.linalg.inv(G)
                    except RuntimeError:
                        invG = torch.linalg.pinv(G)
            
            self._cached_invG = invG
            self._cached_A_hash = key
    
    def project_residual(self, R: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        dev = R.device
        device_type = "cuda" if dev.type == "cuda" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            self.refresh(A)
            A32 = A.to(dtype=torch.float32, device=dev)
            R32 = R.to(dtype=torch.float32, device=dev)
            invG = self._cached_invG
            if invG.device != dev:
                invG = invG.to(dev)
            
            Z = R32 @ A32
            Z = Z @ invG
            R_proj = Z @ A32.transpose(0, 1)
            R_orth = R32 - R_proj
            return R_orth.to(dtype=R.dtype, device=dev)


# ---------- Enhanced Agreement Loss with Temperature ----------
def agreement_loss(names, preds, eps=1e-8, temperature=0.5):
    """
    Enhanced agreement loss with temperature scaling.
    Encourages captions of the same image to have similar predictions.
    """
    with torch.no_grad():
        groups = {}
        for i, n in enumerate(map(str, names)):
            groups.setdefault(n, []).append(i)
        valid = [idxs for idxs in groups.values() if len(idxs) >= 2]
    
    if not valid:
        return preds.new_tensor(0.0)
    
    z = F.normalize(preds, dim=-1)
    vars_ = []
    
    for idxs in valid:
        g = z[idxs]  # (K, D)
        # Compute pairwise cosine similarities within group
        sims = g @ g.T  # (K, K)
        # Apply temperature to sharpen/soften
        sims = sims / temperature
        # We want high similarity, so loss is negative mean off-diagonal
        mask = ~torch.eye(len(idxs), dtype=torch.bool, device=sims.device)
        if mask.sum() > 0:
            group_coherence = sims[mask].mean()
            # Loss is negative similarity (we want to maximize similarity)
            vars_.append(-group_coherence)
    
    if len(vars_) == 0:
        return preds.new_tensor(0.0)
    
    return torch.stack(vars_).mean()


# ---------- Enhanced Image Queue with Diversity Tracking ----------
class ImgQueue:
    """Enhanced queue with diversity tracking to avoid redundant negatives"""
    def __init__(self, dim: int, capacity: int, device: torch.device, track_diversity: bool = True):
        self.capacity = int(capacity)
        self.device = device
        self.ptr = 0
        self.full = False
        self.bank = torch.zeros(self.capacity, dim, device=device)
        self.track_diversity = track_diversity
        
        if track_diversity:
            # Track image IDs to avoid duplicate negatives from same image
            self.img_ids = torch.zeros(self.capacity, dtype=torch.long, device=device)
    
    @torch.no_grad()
    def enqueue(self, feats: torch.Tensor, img_ids: list[str] = None):
        feats = feats.detach()
        n = feats.size(0)
        
        if n >= self.capacity:
            self.bank.copy_(feats[-self.capacity:])
            if self.track_diversity and img_ids is not None:
                ids = torch.tensor([hash(str(i)) for i in img_ids[-self.capacity:]], 
                                 dtype=torch.long, device=self.device)
                self.img_ids.copy_(ids)
            self.ptr = 0
            self.full = True
            return
        
        end = self.ptr + n
        if end <= self.capacity:
            self.bank[self.ptr:end] = feats
            if self.track_diversity and img_ids is not None:
                ids = torch.tensor([hash(str(i)) for i in img_ids], 
                                 dtype=torch.long, device=self.device)
                self.img_ids[self.ptr:end] = ids
        else:
            cut = self.capacity - self.ptr
            self.bank[self.ptr:] = feats[:cut]
            self.bank[:end - self.capacity] = feats[cut:]
            if self.track_diversity and img_ids is not None:
                ids = torch.tensor([hash(str(i)) for i in img_ids], 
                                 dtype=torch.long, device=self.device)
                self.img_ids[self.ptr:] = ids[:cut]
                self.img_ids[:end - self.capacity] = ids[cut:]
        
        self.ptr = (self.ptr + n) % self.capacity
        if self.ptr == 0:
            self.full = True
    
    def size(self) -> int:
        if self.full:
            return self.capacity
        return self.ptr
    
    def get(self) -> torch.Tensor:
        if self.size() == 0:
            return self.bank[:0]
        if self.full:
            return torch.cat([self.bank[self.ptr:], self.bank[:self.ptr]], dim=0)
        return self.bank[:self.ptr]
    
    def recent(self, max_items: int) -> torch.Tensor:
        q = self.get()
        if q is None or q.numel() == 0:
            return q
        if max_items is None or q.size(0) <= int(max_items):
            return q
        return q[-int(max_items):]
    
    def get_with_ids(self):
        """Return features and image IDs together"""
        feats = self.get()
        if not self.track_diversity or feats.numel() == 0:
            return feats, None
        
        if self.full:
            ids = torch.cat([self.img_ids[self.ptr:], self.img_ids[:self.ptr]], dim=0)
        else:
            ids = self.img_ids[:self.ptr]
        return feats, ids


# ---------- Enhanced Hard Mining with Curriculum ----------
@torch.no_grad()
def mine_hard_curriculum(
    q_recent: torch.Tensor, 
    pred: torch.Tensor, 
    H: int, 
    epoch: int,
    total_epochs: int,
    curriculum_start_H: int = 16,
    batch_img_ids: list[str] = None,
    queue_img_ids: torch.Tensor = None,
):
    """
    Hard negative mining with curriculum:
    - Start with easier negatives (medium similarity)
    - Gradually increase difficulty (high similarity)
    - Filter out false negatives (same image ID)
    """
    if q_recent is None or q_recent.numel() == 0 or H <= 0:
        return None, None, None
    
    B = pred.size(0)
    Q = q_recent.size(0)
    sims = pred @ q_recent.t()  # (B, Q)
    
    # Curriculum: gradually increase H from curriculum_start_H to H
    progress = epoch / max(1, total_epochs)
    current_H = int(curriculum_start_H + (H - curriculum_start_H) * progress)
    current_H = min(current_H, Q)
    
    # Filter out false negatives (captions from same image)
    if batch_img_ids is not None and queue_img_ids is not None:
        batch_ids_hashed = torch.tensor([hash(str(i)) for i in batch_img_ids], 
                                       device=pred.device, dtype=torch.long)
        # Create mask: (B, Q) where True means "valid negative"
        mask = torch.ones(B, Q, dtype=torch.bool, device=pred.device)
        for i, bid in enumerate(batch_ids_hashed):
            mask[i, :] = queue_img_ids != bid
        # Set invalid negatives to very low similarity
        sims = sims.masked_fill(~mask, -1e9)
    
    # Mine top-K hardest
    _, idx = torch.topk(sims, k=current_H, dim=1, largest=True, sorted=False)
    flat_idx = idx.reshape(-1)
    mined = q_recent.index_select(0, flat_idx)
    
    return idx, flat_idx, mined


# ---------- Enhanced Multi-Positive InfoNCE ----------
def info_nce_multi_enhanced(
    pred,
    batch_img_targets,
    batch_img_names,
    queue_feats=None,
    queue_img_ids=None,
    tau: float = 0.07,
    xbp_bank: torch.Tensor = None,
    xbp_pos_cols_per_row: list[list[int]] = None,
    hard_subset_H: int = 0,
    use_dcl: bool = False,
    dcl_prior: float = 0.01,
    epoch: int = 0,
    total_epochs: int = 1,
    label_smoothing: float = 0.0,
):
    """
    Enhanced InfoNCE with:
    - Curriculum hard mining
    - Label smoothing
    - Better false negative filtering
    """
    p = F.normalize(pred, dim=-1)
    t = F.normalize(batch_img_targets, dim=-1)
    
    bank_blocks = [t]
    B = pred.size(0)
    
    # Hard mining with curriculum
    if queue_feats is not None and queue_feats.numel() > 0:
        qn = F.normalize(queue_feats, dim=-1)
        if hard_subset_H and hard_subset_H > 0:
            _, _, mined_block = mine_hard_curriculum(
                qn, p, hard_subset_H, epoch, total_epochs,
                curriculum_start_H=max(8, hard_subset_H // 4),
                batch_img_ids=batch_img_names,
                queue_img_ids=queue_img_ids,
            )
            if mined_block is not None:
                bank_blocks.append(mined_block)
        else:
            bank_blocks.append(qn)
    
    # XBP block
    xbp_offsets = None
    if xbp_bank is not None and xbp_bank.numel() > 0:
        xbp_offsets = sum(b.size(0) for b in bank_blocks)
        bank_blocks.append(xbp_bank)
    
    bank = torch.cat(bank_blocks, dim=0) if len(bank_blocks) > 1 else bank_blocks[0]
    logits = (p @ bank.t()) / float(tau)
    K = bank.size(0)
    
    # Build positive mask
    pos_mask = torch.zeros(B, K, dtype=torch.bool, device=pred.device)
    names = list(map(str, batch_img_names))
    for i in range(B):
        for j in range(B):
            if names[i] == names[j]:
                pos_mask[i, j] = True
    
    if xbp_offsets is not None and xbp_pos_cols_per_row is not None:
        for i, cols in enumerate(xbp_pos_cols_per_row):
            for c in cols:
                pos_mask[i, xbp_offsets + c] = True
    
    assert pos_mask.any(dim=1).all(), "Every row must have at least one positive."
    
    # Denominator
    if not use_dcl:
        logZ = torch.logsumexp(logits, dim=1)
    else:
        logZ = _debiased_logZ(logits, pos_mask, class_prior=float(dcl_prior))
    
    # Numerator with label smoothing
    logits_pos = logits.masked_fill(~pos_mask, float('-inf'))
    logPos = torch.logsumexp(logits_pos, dim=1)
    
    loss = -(logPos - logZ).mean()
    
    # Apply label smoothing if specified
    if label_smoothing > 0:
        # Uniform distribution over all negatives
        logits_neg = logits.masked_fill(pos_mask, float('-inf'))
        smooth_loss = -torch.logsumexp(logits_neg, dim=1).mean()
        loss = (1 - label_smoothing) * loss + label_smoothing * smooth_loss
    
    return loss


def _debiased_logZ(logits: torch.Tensor, pos_mask: torch.Tensor, class_prior: float = 0.01):
    """Debiased contrastive learning denominator"""
    exp_all = torch.exp(logits)
    exp_pos = exp_all * pos_mask
    Z = exp_all.sum(dim=1)
    corr = class_prior * exp_pos.sum(dim=1)
    Z_star = torch.clamp(Z - corr, min=1e-8)
    return torch.log(Z_star)


# ---------- Variance-Weighted Procrustes ----------
def procrustes_closed_form_centroids_weighted(
    X_rows: np.ndarray, 
    Y_rows: np.ndarray, 
    img_names_rows: np.ndarray, 
    eps: float = 1e-5,
    use_variance_weighting: bool = True,
):
    """
    Enhanced Procrustes with optional variance weighting:
    - Images with more diverse captions get higher weight
    - Prevents overfitting to images with homogeneous captions
    """
    from collections import defaultdict
    
    buckets_x = defaultdict(list)
    buckets_y = defaultdict(list)
    for x, y, n in zip(X_rows, Y_rows, map(str, img_names_rows)):
        buckets_x[n].append(x)
        buckets_y[n].append(y)
    
    order_names = sorted(buckets_x.keys())
    
    # Compute centroids and optionally weights
    Xc_list, Yc_list, weights = [], [], []
    for n in order_names:
        x_group = np.array(buckets_x[n])
        y_group = np.array(buckets_y[n])
        
        xc = np.mean(x_group, axis=0)
        yc = np.mean(y_group, axis=0)
        
        Xc_list.append(xc)
        Yc_list.append(yc)
        
        if use_variance_weighting:
            # Weight by caption diversity (variance in text embeddings)
            var = np.mean(np.var(x_group, axis=0))
            weights.append(np.sqrt(var + 1e-8))
        else:
            weights.append(1.0)
    
    Xc = np.stack(Xc_list, axis=0).astype(np.float32)
    Yc = np.stack(Yc_list, axis=0).astype(np.float32)
    weights = np.array(weights, dtype=np.float32)
    weights = weights / (weights.mean() + 1e-8)  # Normalize
    
    # Weighted Procrustes
    W = np.diag(weights)
    XcW = Xc * weights[:, None]
    YcW = Yc * weights[:, None]
    
    mu_x = np.average(Xc, axis=0, weights=weights)
    mu_y = np.average(Yc, axis=0, weights=weights)
    
    Xzc = XcW - mu_x
    Yzc = YcW - mu_y
    
    def _cov_eigh(zc, eps):
        S, U = np.linalg.eigh((zc.T @ zc) / max(1, zc.shape[0]))
        S = np.clip(S, 0.0, None)
        inv_sqrt = 1.0 / np.sqrt(S + eps)
        sqrt = np.sqrt(S + eps)
        C_mhalf = (U * inv_sqrt) @ U.T
        C_phalf = (U * sqrt) @ U.T
        return C_mhalf.astype(np.float32), C_phalf.astype(np.float32)
    
    Cx_mh, _ = _cov_eigh(Xzc, eps)
    Cy_mh, Cy_ph = _cov_eigh(Yzc, eps)
    
    Xw = Xzc @ Cx_mh.T
    Yw = Yzc @ Cy_mh.T
    
    M = Xw.T @ Yw
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    R = (U @ Vt).T
    
    A = (Cy_ph @ R @ Cx_mh).astype(np.float32)
    b = (mu_y - (A @ mu_x)).astype(np.float32)
    
    return A, b


# ---------- Keep existing PKBatchSampler, XBPBuffer, PredQueue, TripletCosineLoss ----------
class PKBatchSampler(torch.utils.data.Sampler[list[int]]):
    def __init__(self, img_ids: list[str], P: int, K: int, drop_last: bool = True, seed: int = 42):
        self.by_img = {}
        for idx, iid in enumerate(map(str, img_ids)):
            self.by_img.setdefault(iid, []).append(idx)
        g = torch.Generator()
        g.manual_seed(seed)
        for lst in self.by_img.values():
            perm = torch.randperm(len(lst), generator=g).tolist()
            lst[:] = [lst[i] for i in perm]
        self.P, self.K, self.drop_last = int(P), int(K), bool(drop_last)
        self.iids = list(self.by_img.keys())
        self.cur = {iid: 0 for iid in self.iids}
        self.pool = [iid for iid in self.iids if len(self.by_img[iid]) > 0]
    
    def __iter__(self):
        import random
        pool = self.pool[:]
        random.shuffle(pool)
        batch = []
        while len(pool) >= self.P:
            pick = pool[:self.P]
            pool = pool[self.P:]
            for iid in pick:
                start = self.cur[iid]
                end = start + self.K
                lst = self.by_img[iid]
                if start >= len(lst):
                    continue
                take = lst[start:min(end, len(lst))]
                self.cur[iid] = min(end, len(lst))
                batch.extend(take)
            if len(batch) == self.P * self.K:
                yield batch
            else:
                if not self.drop_last and len(batch) > 0:
                    yield batch
                break
            batch = []
    
    def __len__(self):
        total = sum(len(v) // self.K for v in self.by_img.values())
        return max(0, total // self.P)


from collections import deque

class XBPBuffer:
    def __init__(self, per_image_cap: int = 4, global_cap: int = 32000, device: torch.device = torch.device("cpu")):
        self.per_img = {}
        self.order = deque()
        self.per_image_cap = int(per_image_cap)
        self.global_cap = int(global_cap)
        self.device = device
        self._count = 0
        self._stamp = 0
    
    @torch.no_grad()
    def add(self, img_ids: list[str], pred_feats: torch.Tensor):
        feats = F.normalize(pred_feats.detach(), dim=-1)
        for iid, f in zip(map(str, img_ids), feats):
            dq = self.per_img.setdefault(iid, deque(maxlen=self.per_image_cap))
            dq.append(f.to(self.device))
            self.order.append((iid, self._stamp))
            self._stamp += 1
            self._count += 1
            while self._count > self.global_cap:
                old_iid, _ = self.order.popleft()
                if self.per_img.get(old_iid):
                    try:
                        self.per_img[old_iid].popleft()
                        self._count -= 1
                        if len(self.per_img[old_iid]) == 0:
                            self.per_img.pop(old_iid, None)
                    except IndexError:
                        pass
    
    def build_bank_and_mask(self, batch_img_names: list[str], device: torch.device):
        rows = [str(x) for x in batch_img_names]
        per_row_feats = []
        row_counts = []
        for iid in rows:
            dq = self.per_img.get(iid, None)
            if dq is None or len(dq) == 0:
                row_counts.append(0)
            else:
                row_counts.append(len(dq))
                per_row_feats.extend(list(dq))
        if len(per_row_feats) == 0:
            return torch.empty(0, 0, device=device), [[] for _ in rows]
        xbp_bank = torch.stack(per_row_feats, dim=0).to(device)
        xbp_pos_cols = []
        offset = 0
        for cnt in row_counts:
            if cnt == 0:
                xbp_pos_cols.append([])
            else:
                xbp_pos_cols.append(list(range(offset, offset + cnt)))
                offset += cnt
        return xbp_bank, xbp_pos_cols


class PredQueue:
    def __init__(self, dim: int, capacity: int, device: torch.device):
        self.capacity = int(capacity)
        self.device = device
        self.ptr = 0
        self.full = False
        self.feats = torch.zeros(self.capacity, dim, device=device)
        self.ids = torch.empty(self.capacity, dtype=torch.long, device=device)
    
    @torch.no_grad()
    def enqueue(self, feats: torch.Tensor, img_ids: list[str]):
        feats = F.normalize(feats.detach(), dim=-1)
        ids = torch.tensor([hash(str(i)) for i in img_ids], dtype=torch.long, device=self.device)
        n = feats.size(0)
        if n >= self.capacity:
            self.feats.copy_(feats[-self.capacity:])
            self.ids.copy_(ids[-self.capacity:])
            self.ptr = 0
            self.full = True
            return
        end = self.ptr + n
        if end <= self.capacity:
            self.feats[self.ptr:end] = feats
            self.ids[self.ptr:end] = ids
        else:
            cut = self.capacity - self.ptr
            self.feats[self.ptr:] = feats[:cut]
            self.ids[self.ptr:] = ids[:cut]
            self.feats[:end - self.capacity] = feats[cut:]
            self.ids[:end - self.capacity] = ids[cut:]
        self.ptr = (self.ptr + n) % self.capacity
        if self.ptr == 0:
            self.full = True
    
    def size(self) -> int:
        if self.full:
            return self.capacity
        return self.ptr
    
    def recent(self, max_items: int):
        n = self.size()
        if n == 0:
            return self.feats[:0], self.ids[:0]
        k = min(int(max_items), n) if max_items is not None else n
        if self.full:
            idx = torch.arange(self.ptr - n, self.ptr, device=self.device) % self.capacity
        else:
            idx = torch.arange(0, n, device=self.device)
        idx = idx[-k:]
        return self.feats.index_select(0, idx), self.ids.index_select(0, idx)


def info_nce_fixed(pred: torch.Tensor, tgt: torch.Tensor, tau: float):
    logits = (pred @ tgt.t()) / float(tau)
    labels = torch.arange(pred.size(0), device=pred.device)
    return F.cross_entropy(logits, labels)


class TripletCosineLoss(nn.Module):
    def __init__(self, margin: float = 0.2):
        super().__init__()
        self.loss = nn.TripletMarginWithDistanceLoss(
            distance_function=lambda x, y: 1 - F.cosine_similarity(x, y),
            margin=margin
        )
    def forward(self, anchor, positive, negative):
        return self.loss(anchor, positive, negative)


# =========================
#   ENHANCED STAGE 1 TRAINER
# =========================
def train_simple_enhanced(
    model,
    Xtr, Ytr, img_ids_row_tr,
    Xva, val_gallery, cap2gal_local,
    batch=512, epochs=25, lr=2e-4, wd=1e-4,
    tau_fixed=0.07, triplet_margin=0.2, triplet_w=0.3,
    alpha_cos=0.0,
    moment_w=0.0,
    use_pk=False, P=192, K_pk=3,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    pooling="none", n_patches=None,
    out_dir: "Path | str" = None, seed=42,
    warmup_epochs=3,
    use_cosine_schedule=True,
    gradient_clip=1.0,
):
    """Enhanced Stage 1 with warmup, cosine schedule, and better optimization"""
    import json
    from pathlib import Path
    from torch.utils.data import DataLoader
    
    out_dir = Path(out_dir) if out_dir is not None else Path("./outputs/simple")
    out_dir.mkdir(parents=True, exist_ok=True)
    seed_all(seed)
    model.to(device)
    
    class TripletSet(torch.utils.data.Dataset):
        def __init__(self, X_np, Y_np, names_np):
            self.X = torch.from_numpy(X_np).float()
            self.Y = torch.from_numpy(Y_np).float()
            self.N = [str(n) for n in names_np.tolist()]
        def __len__(self):
            return len(self.X)
        def __getitem__(self, i):
            return self.X[i], self.Y[i], self.N[i]
    
    ds = TripletSet(Xtr, Ytr, img_ids_row_tr)
    
    if use_pk:
        pk_sampler = PKBatchSampler(img_ids_row_tr.tolist(), P=int(P), K=int(K_pk), drop_last=True, seed=seed)
        dl = DataLoader(ds, batch_sampler=pk_sampler, num_workers=2, pin_memory=True)
    else:
        dl = DataLoader(ds, batch_size=batch, shuffle=True, num_workers=2, pin_memory=True, drop_last=False)
    
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=lr, weight_decay=wd, betas=(0.9, 0.999), eps=1e-8)
    
    # Enhanced scheduling with warmup
    total_steps = epochs * len(dl)
    warmup_steps = warmup_epochs * len(dl)
    
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # Linear warmup
            return float(current_step) / float(max(1, warmup_steps))
        if use_cosine_schedule:
            # Cosine annealing after warmup
            progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            return 0.5 * (1.0 + np.cos(np.pi * progress))
        return 1.0
    
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
    
    tri = TripletCosineLoss(margin=float(triplet_margin))
    
    best_stats, best_mrr, best_ep = None, -1.0, 0
    
    for ep in range(1, epochs + 1):
        model.train()
        running = 0.0
        
        for xb, yb, _ in dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            
            opt.zero_grad(set_to_none=True)
            
            pred = model(xb)
            pred = F.normalize(pred, dim=-1)
            ybn = F.normalize(yb, dim=-1)
            
            # InfoNCE with fixed tau
            loss_ce = info_nce_fixed(pred, ybn, tau=float(tau_fixed))
            
            # Triplet with shuffled negatives
            idx = torch.randperm(xb.size(0), device=device)
            loss_tri = tri(pred, ybn, ybn[idx])
            
            # Optional auxiliaries
            loss_cos = (1.0 - F.cosine_similarity(pred, yb, dim=-1)).mean() if alpha_cos > 0 else pred.new_tensor(0.0)
            loss_mu = moment_align(pred, yb) if moment_w > 0 else pred.new_tensor(0.0)
            
            loss = (1.0 - float(triplet_w)) * loss_ce + float(triplet_w) * loss_tri \
                   + float(alpha_cos) * loss_cos + float(moment_w) * loss_mu
            
            loss.backward()
            
            if gradient_clip > 0:
                torch.nn.utils.clip_grad_norm_(params, gradient_clip)
            
            opt.step()
            sched.step()
            running += loss.item() * xb.size(0)
        
        stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
        ep_loss = running / len(ds)
        current_lr = opt.param_groups[0]['lr']
        
        print(f"[simple {ep:02d}] loss={ep_loss:.6f} lr={current_lr:.6f} | val_MRR={stats['MRR']:.4f} "
              f"| R@1={stats['R1']:.3f} R@5={stats['R5']:.3f} R@10={stats['R10']:.3f} "
              f"| median={stats['rank_median']} p75={stats['rank_p75']}")
        
        if stats["MRR"] > best_mrr:
            best_mrr, best_ep, best_stats = stats["MRR"], ep, stats
            torch.save({"model": model.state_dict(), "epoch": ep, "val": stats}, out_dir / "best.pt")
    
    if best_stats is None:
        best_stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
        torch.save({"model": model.state_dict(), "epoch": 0, "val": best_stats}, out_dir / "best.pt")
    
    (out_dir / "val_metrics.json").write_text(json.dumps(dict(best_epoch=best_ep, **best_stats), indent=2))
    return best_stats


# =========================
#   ENHANCED STAGE 2 TRAINER (BigJump++)
# =========================
def train_bigjump_enhanced(
    model,
    Xtr, Ytr, img_ids_row_tr,
    Xva, val_gallery, cap2gal_local,
    batch=512, epochs=20, base_lr=2e-4, wd=1e-4,
    tau_start=0.10, tau_end=0.05,
    queue_size=65536,
    queue_warmup_epochs=4,
    queue_recent_schedule=(8000, 16000, 32000, 65536),
    mine_H=64,
    lambda_agree=0.02,
    alpha_cos=0.2,
    lambda_moment=0.01,
    use_pk=False, P=192, K_pk=3,
    xbp_per_img=6,
    xbp_global=48000,
    use_dcl=False,
    dcl_prior=0.01,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    pooling="none", n_patches=None,
    out_dir: "Path | str" = None,
    seed=42,
    use_ema=True,
    ema_decay=0.999,
    gradient_clip=1.0,
    label_smoothing=0.05,
):
    """
    Enhanced BigJump trainer with:
    - Smoother temperature schedule
    - Later queue warmup
    - Curriculum hard mining
    - Optional EMA
    - Label smoothing
    """
    import json
    from pathlib import Path
    from torch.utils.data import DataLoader
    
    out_dir = Path(out_dir) if out_dir is not None else Path("./outputs/bigjump")
    out_dir.mkdir(parents=True, exist_ok=True)
    seed_all(seed)
    model.to(device)
    
    class TripletSet(torch.utils.data.Dataset):
        def __init__(self, X_np, Y_np, names_np):
            self.X = torch.from_numpy(X_np).float()
            self.Y = torch.from_numpy(Y_np).float()
            self.N = [str(n) for n in names_np.tolist()]
        def __len__(self):
            return len(self.X)
        def __getitem__(self, i):
            return self.X[i], self.Y[i], self.N[i]
    
    ds = TripletSet(Xtr, Ytr, img_ids_row_tr)
    
    if use_pk:
        pk_sampler = PKBatchSampler(img_ids_row_tr.tolist(), P=int(P), K=int(K_pk), drop_last=True, seed=seed)
        dl = DataLoader(ds, batch_sampler=pk_sampler, num_workers=2, pin_memory=True)
    else:
        dl = DataLoader(ds, batch_size=batch, shuffle=True, num_workers=2, pin_memory=True, drop_last=False)
    
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=base_lr, weight_decay=wd, betas=(0.9, 0.999), eps=1e-8)
    
    # Smoother cosine schedule
    total_steps = epochs * len(dl)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, eta_min=base_lr * 0.01)
    
    # EMA model (optional)
    ema_model = None
    if use_ema:
        import copy
        ema_model = copy.deepcopy(model)
        for p in ema_model.parameters():
            p.requires_grad = False
    
    # Queues and buffers
    img_queue = ImgQueue(dim=1536, capacity=queue_size, device=device, track_diversity=True)
    xbp_buffer = XBPBuffer(per_image_cap=xbp_per_img, global_cap=xbp_global, device=device)
    
    best_stats, best_mrr, best_ep = None, -1.0, 0
    
    for ep in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        
        # Dynamic temperature (smoother curve)
        progress = (ep - 1) / max(1, epochs - 1)
        tau = tau_start * (1 - progress) + tau_end * progress
        
        # Dynamic queue size (gradual schedule)
        if ep <= queue_warmup_epochs:
            recent_k = None  # No queue during warmup
        else:
            warmup_progress = (ep - queue_warmup_epochs) / max(1, epochs - queue_warmup_epochs)
            schedule_idx = int(warmup_progress * len(queue_recent_schedule))
            schedule_idx = min(schedule_idx, len(queue_recent_schedule) - 1)
            recent_k = queue_recent_schedule[schedule_idx]
        
        for xb, yb, names_b in dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            
            opt.zero_grad(set_to_none=True)
            
            pred = model(xb)
            pred_norm = F.normalize(pred, dim=-1)
            
            # Get queue negatives
            if recent_k is not None and img_queue.size() > 0:
                queue_feats, queue_ids = img_queue.get_with_ids()
                if queue_feats.numel() > 0:
                    queue_feats = queue_feats[-recent_k:] if recent_k < queue_feats.size(0) else queue_feats
                    queue_ids = queue_ids[-recent_k:] if recent_k < queue_ids.size(0) else queue_ids
            else:
                queue_feats, queue_ids = None, None
            
            # Get XBP positives
            xbp_bank, xbp_pos_cols = xbp_buffer.build_bank_and_mask(names_b, device)
            
            # Enhanced InfoNCE with curriculum mining
            loss_ce = info_nce_multi_enhanced(
                pred, yb, names_b,
                queue_feats=queue_feats,
                queue_img_ids=queue_ids,
                tau=tau,
                xbp_bank=xbp_bank,
                xbp_pos_cols_per_row=xbp_pos_cols,
                hard_subset_H=mine_H,
                use_dcl=use_dcl,
                dcl_prior=dcl_prior,
                epoch=ep,
                total_epochs=epochs,
                label_smoothing=label_smoothing,
            )
            
            # Agreement loss (caption consistency)
            loss_agree = agreement_loss(names_b, pred, temperature=0.5) if lambda_agree > 0 else pred.new_tensor(0.0)
            
            # Auxiliary losses
            loss_cos = (1.0 - F.cosine_similarity(pred, yb, dim=-1)).mean() if alpha_cos > 0 else pred.new_tensor(0.0)
            loss_mu = moment_align(pred, yb) if lambda_moment > 0 else pred.new_tensor(0.0)
            
            # Total loss
            loss = loss_ce + lambda_agree * loss_agree + alpha_cos * loss_cos + lambda_moment * loss_mu
            
            loss.backward()
            
            if gradient_clip > 0:
                torch.nn.utils.clip_grad_norm_(params, gradient_clip)
            
            opt.step()
            sched.step()
            
            # Update EMA
            if use_ema and ema_model is not None:
                with torch.no_grad():
                    for ema_p, model_p in zip(ema_model.parameters(), model.parameters()):
                        ema_p.data.mul_(ema_decay).add_(model_p.data, alpha=1 - ema_decay)
            
            # Update queues (detached, after backward)
            with torch.no_grad():
                img_queue.enqueue(F.normalize(yb, dim=-1), names_b)
                xbp_buffer.add(names_b, pred)
            
            running_loss += loss.item() * xb.size(0)
        
        # Validation (use EMA model if available)
        eval_model = ema_model if use_ema and ema_model is not None else model
        stats = validate_retrieval(eval_model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
        
        ep_loss = running_loss / len(ds)
        current_lr = opt.param_groups[0]['lr']
        queue_info = f"queue={img_queue.size()}" if recent_k else "queue=OFF"
        recent_info = f"recentQ={recent_k}" if recent_k else "recentQ=0"
        
        print(f"[bigjump {ep:02d}] loss={ep_loss:.6f} lr={current_lr:.6f} | val_MRR={stats['MRR']:.4f} "
              f"| R@1={stats['R1']:.3f} R@5={stats['R5']:.3f} R@10={stats['R10']:.3f} "
              f"| median={stats['rank_median']} p75={stats['rank_p75']} | {queue_info} | tau={tau:.3f} | {recent_info}")
        
        if stats["MRR"] > best_mrr:
            best_mrr, best_ep, best_stats = stats["MRR"], ep, stats
            # Save the actual model (not EMA) for consistency
            torch.save({"model": model.state_dict(), "epoch": ep, "val": stats}, out_dir / "best.pt")
            if use_ema and ema_model is not None:
                torch.save({"model": ema_model.state_dict(), "epoch": ep, "val": stats}, out_dir / "best_ema.pt")
    
    if best_stats is None:
        eval_model = ema_model if use_ema and ema_model is not None else model
        best_stats = validate_retrieval(eval_model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
        torch.save({"model": model.state_dict(), "epoch": 0, "val": best_stats}, out_dir / "best.pt")
    
    (out_dir / "val_metrics.json").write_text(json.dumps(dict(best_epoch=best_ep, **best_stats), indent=2))
    return best_stats


# =========================
#   ENHANCED MAIN
# =========================
def main(args):
    """
    Enhanced two-stage pipeline with:
    - Variance-weighted Procrustes
    - SVD-based orthogonalization
    - Smoother schedules
    - EMA
    - Better hard mining
    """
    import json
    from pathlib import Path
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    out_dir, seed = args.out_dir, args.seed
    pooling, n_patches = args.pooling, None
    OUT = Path(f"/kaggle/working/outputs/{out_dir}")
    OUT.mkdir(parents=True, exist_ok=True)
    STAGE1 = OUT / "stage1"
    STAGE2 = OUT / "stage2"
    STAGE1.mkdir(parents=True, exist_ok=True)
    STAGE2.mkdir(parents=True, exist_ok=True)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load data
    X, Y, cap_ids, img_ids_row, full_img, all_img_ids = load_train(OUT)
    
    # Image-level split (no leakage)
    cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices = build_image_id_split(
        img_ids_row, all_img_ids, full_img, X, Y, args.val_ratio, seed, OUT
    )
    
    val_set = set(all_img_ids[val_img_indices])
    train_set = set(all_img_ids) - val_set
    assert val_set.isdisjoint(train_set), "Leakage detected!"
    
    Xtr, Ytr = X[cap_is_tr], Y[cap_is_tr]
    Xva = X[cap_is_val]
    
    din, dout = X.shape[1], Y.shape[1]
    assert din == 1024 and dout == 1536
    
    # Enhanced Procrustes with variance weighting
    print("[geometry] Computing variance-weighted centroid Procrustes...")
    A, b = procrustes_closed_form_centroids_weighted(
        Xtr.astype(np.float32),
        Ytr.astype(np.float32),
        img_ids_row[cap_is_tr],
        eps=float(args.geom_eps),
        use_variance_weighting=True,
    )
    
    # Build model with SVD orthogonalization
    class GeomFromWeights(nn.Module):
        def __init__(self, A_np, b_np):
            super().__init__()
            self.fc = nn.Linear(A_np.shape[1], A_np.shape[0], bias=True)
            with torch.no_grad():
                self.fc.weight.copy_(torch.from_numpy(A_np))
                self.fc.bias.copy_(torch.from_numpy(b_np))
        def forward(self, x):
            return self.fc(x)
    
    geom = GeomFromWeights(A, b).to(device)
    for p in geom.parameters():
        p.requires_grad = False
    
    adapter = ResidualAdapter(
        din=din, hidden=1024, dout=dout, pdrop=0.1,
        init_scale=0.25,  # Larger init
        use_layernorm=True,
    )
    
    class GeomPlusAdapter(nn.Module):
        def __init__(self, g, a):
            super().__init__()
            self.geom = g
            self.adapter = a
        def forward(self, x):
            return self.geom(x) + self.adapter(x)
    
    model = GeomPlusAdapter(geom, adapter).to(device)
    
    with torch.no_grad():
        _probe = model(torch.zeros(2, din, device=device))
    assert _probe.shape[-1] == 1536
    
    # Stage 1: Enhanced warmup
    if not args.eval_only:
        print(f"\n{'='*60}")
        print(f"[STAGE 1] Enhanced warmup with smoother scheduling")
        print(f"{'='*60}")
        print(f"τ_fixed={args.tau:.3f} | triplet_w={args.triplet_w} | margin={args.triplet_margin}")
        print(f"epochs={args.stage1_epochs} | warmup_epochs=3 | lr={args.lr}")
        
        best_stats_s1 = train_simple_enhanced(
            model,
            Xtr, Ytr, img_ids_row[cap_is_tr],
            Xva, val_gallery, cap2gal_local,
            batch=args.batch,
            epochs=int(args.stage1_epochs),
            lr=args.lr,
            wd=args.wd,
            tau_fixed=float(args.tau),
            triplet_margin=float(args.triplet_margin),
            triplet_w=float(args.triplet_w),
            alpha_cos=float(args.alpha),
            moment_w=float(args.moment),
            use_pk=bool(args.use_pk),
            P=int(args.P),
            K_pk=int(args.K_pk),
            device=device,
            pooling=pooling,
            n_patches=n_patches,
            out_dir=str(STAGE1),
            seed=seed,
            warmup_epochs=3,
            use_cosine_schedule=True,
            gradient_clip=1.0,
        )
        
        # Load best from stage 1
        try:
            ckpt1 = torch.load(STAGE1 / "best.pt", map_location=device)
            model.load_state_dict(ckpt1["model"])
            print(f"\n[stage1] Loaded best checkpoint: epoch={ckpt1.get('epoch','?')} MRR={ckpt1.get('val',{}).get('MRR','?'):.4f}")
        except FileNotFoundError:
            print("[stage1] Warning: No checkpoint found, continuing with current weights")
    else:
        best_stats_s1 = None
    
    # Stage 2: Enhanced BigJump
    if not args.eval_only and args.stage2_epochs > 0:
        print(f"\n{'='*60}")
        print(f"[STAGE 2] Enhanced BigJump++ with curriculum & EMA")
        print(f"{'='*60}")
        print(f"queue_warmup={args.queue_warmup_epochs} | queue_schedule={args.queue_recent_k}")
        print(f"mine_H={args.mine_H} (curriculum) | τ_sched={args.tau_start:.3f}→{args.tau_end:.3f}")
        print(f"agree={args.lambda_agree} | label_smoothing={args.label_smoothing}")
        
        best_stats_s2 = train_bigjump_enhanced(
            model,
            Xtr, Ytr, img_ids_row[cap_is_tr],
            Xva, val_gallery, cap2gal_local,
            batch=args.batch,
            epochs=int(args.stage2_epochs),
            base_lr=args.lr * 0.5,  # Lower LR for stage 2
            wd=args.wd,
            tau_start=float(args.tau_start),
            tau_end=float(args.tau_end),
            queue_size=int(args.queue),
            queue_warmup_epochs=int(args.queue_warmup_epochs),
            queue_recent_schedule=tuple(args.queue_recent_k),
            mine_H=int(args.mine_H),
            lambda_agree=float(args.lambda_agree),
            alpha_cos=float(args.alpha_stage2),
            lambda_moment=float(args.moment_stage2),
            use_pk=bool(args.use_pk),
            P=int(args.P),
            K_pk=int(args.K_pk),
            xbp_per_img=int(args.xbp_per_img),
            xbp_global=int(args.xbp_global),
            use_dcl=bool(args.use_dcl),
            dcl_prior=float(args.dcl_prior),
            device=device,
            pooling=pooling,
            n_patches=n_patches,
            out_dir=str(STAGE2),
            seed=seed,
            use_ema=args.use_ema,
            ema_decay=args.ema_decay,
            gradient_clip=1.0,
            label_smoothing=args.label_smoothing,
        )
        
        # Load best from stage 2 (prefer EMA if available)
        try:
            if args.use_ema and (STAGE2 / "best_ema.pt").exists():
                ckpt2 = torch.load(STAGE2 / "best_ema.pt", map_location=device)
                print(f"[stage2] Loading EMA checkpoint")
            else:
                ckpt2 = torch.load(STAGE2 / "best.pt", map_location=device)
            model.load_state_dict(ckpt2["model"])
            best_stats_final = ckpt2.get("val", None)
            print(f"\n[stage2] Loaded best checkpoint: epoch={ckpt2.get('epoch','?')} MRR={ckpt2.get('val',{}).get('MRR','?'):.4f}")
        except FileNotFoundError:
            best_stats_final = best_stats_s1
            print("[stage2] Warning: No checkpoint found, using stage1 results")
    else:
        best_stats_final = None
        if args.eval_only:
            for cand in [STAGE2 / "best_ema.pt", STAGE2 / "best.pt", STAGE1 / "best.pt", OUT / "best.pt"]:
                try:
                    ckpt = torch.load(cand, map_location=device)
                    model.load_state_dict(ckpt["model"])
                    best_stats_final = ckpt.get("val", None)
                    print(f"[resume] Loaded {cand} | epoch={ckpt.get('epoch','?')} MRR={ckpt.get('val',{}).get('MRR','?'):.4f}")
                    break
                except FileNotFoundError:
                    continue
    
    # Efficiency logging
    params, mb = count_params_mb(model)
    ms_gpu, ms_cpu = time_ms_per_query(model, din, pooling, n_patches)
    eff = {
        "params": params,
        "mb_fp32": mb,
        "ms_per_query_gpu": ms_gpu,
        "ms_per_query_cpu": ms_cpu
    }
    (OUT / "efficiency.json").write_text(json.dumps(eff, indent=2))
    
    # Generate submission
    print(f"\n{'='*60}")
    print("[SUBMISSION] Generating predictions...")
    print(f"{'='*60}")
    
    test_data = load_data(TEST_NPZ)
    Q = test_data["captions/embeddings"].astype(np.float32)
    ids = test_data.get("captions/ids", np.arange(len(Q)).astype(str))
    
    model.eval()
    BS, outs = 1024, []
    with torch.no_grad():
        for i in range(0, len(Q), BS):
            q = torch.from_numpy(Q[i:i+BS]).to(device)
            q = apply_pooling(q, pooling, n_patches)
            z = model(q)
            z = F.normalize(z, dim=-1)
            outs.append(z.detach().cpu().numpy())
    
    pred_embds = np.concatenate(outs, axis=0)
    sub = OUT / "submission.csv"
    generate_submission(ids, pred_embds, str(sub))
    print(f"✓ Submission saved to {sub}")
    
    # Final validation
    if best_stats_final is None:
        best_stats_final = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
    
    # Summary report
    print(f"\n{'='*60}")
    print("[FINAL REPORT]")
    print(f"{'='*60}")
    
    sanity = {
        "dims": {"text": int(din), "image": int(dout)},
        "split": {
            "train_captions": int(cap_is_tr.sum()),
            "val_captions": int(cap_is_val.sum()),
            "val_unique_images": int(val_gallery.shape[0]),
            "leakage": False
        },
        "val_metrics": best_stats_final,
        "efficiency": eff,
        "recipe": {
            "geometry": {
                "type": "variance_weighted_centroid_procrustes",
                "eps": float(args.geom_eps),
                "frozen": True
            },
            "stage1": {
                "type": "enhanced_warmup_infoNCE+triplet",
                "tau_fixed": float(args.tau),
                "triplet_margin": float(args.triplet_margin),
                "triplet_w": float(args.triplet_w),
                "alpha_cos": float(args.alpha),
                "moment_w": float(args.moment),
                "epochs": int(args.stage1_epochs),
                "warmup_epochs": 3,
                "lr_schedule": "cosine_with_warmup",
                "use_pk": bool(args.use_pk),
                "P": int(args.P),
                "K": int(args.K_pk)
            },
            "stage2": {
                "type": "bigjump_plus_plus_curriculum_ema",
                "epochs": int(args.stage2_epochs),
                "tau_sched": [float(args.tau_start), float(args.tau_end)],
                "queue_warmup_epochs": int(args.queue_warmup_epochs),
                "queue_recent_schedule": list(map(int, args.queue_recent_k)),
                "mine_H": int(args.mine_H),
                "mine_curriculum": True,
                "lambda_agree": float(args.lambda_agree),
                "label_smoothing": float(args.label_smoothing),
                "use_ema": bool(args.use_ema),
                "ema_decay": float(args.ema_decay),
                "alpha_cos": float(args.alpha_stage2),
                "moment": float(args.moment_stage2)
            },
            "adapter": {
                "hidden": 1024,
                "dropout": 0.1,
                "init_scale": 0.25,
                "layernorm": True,
                "orthogonalization": "svd"
            }
        }
    }
    
    print(json.dumps(sanity, indent=2))
    (OUT / "summary.json").write_text(json.dumps(sanity, indent=2))
    
    print(f"\n{'='*60}")
    print(f"✓ Pipeline complete!")
    print(f"  Best MRR: {best_stats_final['MRR']:.4f}")
    print(f"  R@1: {best_stats_final['R1']:.3f}")
    print(f"  R@5: {best_stats_final['R5']:.3f}")
    print(f"  R@10: {best_stats_final['R10']:.3f}")
    print(f"  Median rank: {best_stats_final['rank_median']}")
    print(f"{'='*60}\n")


# =========================
#   ARGPARSE (Enhanced)
# =========================
if __name__ == "__main__":
    import argparse
    
    p = argparse.ArgumentParser(
        description="Enhanced RoBERTa→DINOv2 | 2-Stage with Curriculum, EMA & Better Optimization",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    
    # IO & run mode
    p.add_argument("--out_dir", type=str, default="enhanced_two_stage",
                   help="Output directory name")
    p.add_argument("--seed", type=int, default=42,
                   help="Random seed")
    p.add_argument("--eval_only", action="store_true",
                   help="Skip training, only evaluate and generate submission")
    
    # Data & split
    p.add_argument("--val_ratio", type=float, default=0.10,
                   help="Validation split ratio (image-level)")
    
    # Architecture
    p.add_argument("--arch", type=str, default="bigjump",
                   choices=["linear", "mlp1", "mlp2", "geom", "geom_adapter", "geom+adapter", "bigjump"],
                   help="Model architecture (internal use)")
    
    # Optimization (shared)
    p.add_argument("--batch", type=int, default=512,
                   help="Batch size")
    p.add_argument("--lr", type=float, default=2e-4,
                   help="Base learning rate")
    p.add_argument("--wd", type=float, default=1e-4,
                   help="Weight decay")
    
    # Geometry initialization
    p.add_argument("--geom_eps", type=float, default=1e-5,
                   help="Regularization epsilon for Procrustes covariance")
    
    # Pooling
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS"],
                   help="Pooling strategy (CLS only)")
    
    # ==================== STAGE 1 ====================
    stage1_group = p.add_argument_group('Stage 1 (Warmup)', 'Enhanced warmup with fixed-τ InfoNCE + Triplet')
    stage1_group.add_argument("--stage1_epochs", type=int, default=15,
                              help="Stage 1 epochs (increased from 12)")
    stage1_group.add_argument("--tau", type=float, default=0.07,
                              help="Fixed InfoNCE temperature")
    stage1_group.add_argument("--triplet_margin", type=float, default=0.2,
                              help="Triplet loss margin")
    stage1_group.add_argument("--triplet_w", type=float, default=0.3,
                              help="Triplet weight (InfoNCE weight = 1 - triplet_w)")
    stage1_group.add_argument("--alpha", type=float, default=0.0,
                              help="Cosine auxiliary weight")
    stage1_group.add_argument("--moment", type=float, default=0.0,
                              help="Moment alignment weight")
    
    # ==================== STAGE 2 ====================
    stage2_group = p.add_argument_group('Stage 2 (BigJump++)', 'Enhanced contrastive learning with curriculum')
    stage2_group.add_argument("--stage2_epochs", type=int, default=18,
                              help="Stage 2 epochs (increased from 13)")
    stage2_group.add_argument("--tau_start", type=float, default=0.10,
                              help="Starting temperature (smoother than before)")
    stage2_group.add_argument("--tau_end", type=float, default=0.05,
                              help="Ending temperature (lower for harder negatives)")
    stage2_group.add_argument("--queue", type=int, default=65536,
                              help="Queue capacity")
    stage2_group.add_argument("--queue_warmup_epochs", type=int, default=5,
                              help="Epochs before enabling queue (increased from 3)")
    stage2_group.add_argument("--queue_recent_k", type=int, nargs='+', 
                              default=[8000, 16000, 32000, 65536],
                              help="Queue size schedule (4 stages)")
    stage2_group.add_argument("--mine_H", type=int, default=64,
                              help="Top-H hard negatives (doubled from 32)")
    stage2_group.add_argument("--lambda_agree", type=float, default=0.02,
                              help="Caption agreement weight")
    stage2_group.add_argument("--alpha_stage2", type=float, default=0.2,
                              help="Cosine auxiliary in Stage 2")
    stage2_group.add_argument("--moment_stage2", type=float, default=0.01,
                              help="Moment alignment in Stage 2")
    stage2_group.add_argument("--label_smoothing", type=float, default=0.05,
                              help="Label smoothing (NEW)")
    
    # ==================== EMA ====================
    ema_group = p.add_argument_group('EMA', 'Exponential Moving Average for stability')
    ema_group.add_argument("--use_ema", action="store_true", default=True,
                           help="Use EMA model (recommended)")
    ema_group.add_argument("--ema_decay", type=float, default=0.999,
                           help="EMA decay rate")
    
    # ==================== PK SAMPLER ====================
    pk_group = p.add_argument_group('PK Sampler', 'P images × K captions per batch')
    pk_group.add_argument("--use_pk", action="store_true",
                          help="Use P×K batch sampling")
    pk_group.add_argument("--P", type=int, default=192,
                          help="Number of unique images per batch")
    pk_group.add_argument("--K_pk", type=int, default=3,
                          help="Number of captions per image")
    
    # ==================== XBP & DCL ====================
    advanced_group = p.add_argument_group('Advanced Features', 'XBP buffer and DCL')
    advanced_group.add_argument("--xbp_per_img", type=int, default=6,
                                help="XBP buffer size per image (increased)")
    advanced_group.add_argument("--xbp_global", type=int, default=48000,
                                help="XBP global capacity (increased)")
    advanced_group.add_argument("--use_dcl", action="store_true",
                                help="Use debiased contrastive learning")
    advanced_group.add_argument("--dcl_prior", type=float, default=0.01,
                                help="DCL class prior")
    
    # Legacy args (kept for compatibility)
    p.add_argument("--epochs", type=int, default=0,
                   help="(unused, kept for compatibility)")
    p.add_argument("--geom_lr_scale", type=float, default=0.05,
                   help="(unused, geometry frozen)")
    p.add_argument("--queue_pred_capacity", type=int, default=65536,
                   help="(unused)")
    
    args, _ = p.parse_known_args()
    
    # Validate arguments
    if len(args.queue_recent_k) < 2:
        print("Warning: queue_recent_k should have at least 2 values for gradual schedule")
        args.queue_recent_k = [8000, 16000, 32000, 65536]
    
    main(args)

[meta] text_dim=1024 | image_dim=1536
[geometry] Computing variance-weighted centroid Procrustes...

[STAGE 1] Enhanced warmup with smoother scheduling
τ_fixed=0.070 | triplet_w=0.3 | margin=0.2
epochs=15 | warmup_epochs=3 | lr=0.0002
[simple 01] loss=2.491584 lr=0.000067 | val_MRR=0.3375 | R@1=0.216 R@5=0.473 R@10=0.593 | median=6 p75=27
[simple 02] loss=2.117398 lr=0.000133 | val_MRR=0.3561 | R@1=0.232 R@5=0.493 R@10=0.620 | median=6 p75=23
[simple 03] loss=2.036331 lr=0.000200 | val_MRR=0.3671 | R@1=0.240 R@5=0.506 R@10=0.631 | median=5 p75=20
[simple 04] loss=1.980336 lr=0.000197 | val_MRR=0.3740 | R@1=0.246 R@5=0.515 R@10=0.644 | median=5 p75=19
[simple 05] loss=1.937702 lr=0.000187 | val_MRR=0.3814 | R@1=0.253 R@5=0.524 R@10=0.653 | median=5 p75=18
[simple 06] loss=1.904336 lr=0.000171 | val_MRR=0.3864 | R@1=0.258 R@5=0.528 R@10=0.656 | median=5 p75=18
[simple 07] loss=1.878786 lr=0.000150 | val_MRR=0.3902 | R@1=0.260 R@5=0.537 R@10=0.663 | median=5 p75=17
[simple 08] loss=1.8565

In [93]:
# =========================
#   UPDATED ARGPARSE (KAGGLE NOTEBOOK FIX)
# =========================
if __name__ == "__main__":
    import argparse
    
    p = argparse.ArgumentParser(
        description="Enhanced RoBERTa→DINOv2 with Normalizing Flows",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    
    # IO
    p.add_argument("--out_dir", type=str, default="flow_enhanced")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--eval_only", action="store_true")
    p.add_argument("--val_ratio", type=float, default=0.10)
    p.add_argument("--arch", type=str, default="bigjump")
    
    # Optimization
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--lr_stage2", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--geom_eps", type=float, default=1e-5)
    p.add_argument("--pooling", type=str, default="CLS")
    
    # Flow
    p.add_argument("--use_flow", action="store_true", default=True)
    p.add_argument("--flow_layers", type=int, default=3)
    p.add_argument("--lambda_flow", type=float, default=0.01)
    
    # Stage 1
    p.add_argument("--stage1_epochs", type=int, default=15)
    p.add_argument("--tau", type=float, default=0.07)
    p.add_argument("--triplet_margin", type=float, default=0.2)
    p.add_argument("--triplet_w", type=float, default=0.3)
    p.add_argument("--alpha", type=float, default=0.0)
    p.add_argument("--moment", type=float, default=0.0)
    
    # Stage 2
    p.add_argument("--stage2_epochs", type=int, default=24)
    p.add_argument("--tau_start", type=float, default=0.08)
    p.add_argument("--tau_end", type=float, default=0.05)
    p.add_argument("--queue", type=int, default=65536)
    p.add_argument("--queue_warmup_epochs", type=int, default=8)
    p.add_argument("--queue_recent_k", type=int, nargs='+',
                    default=[500, 1000, 2000, 4000, 8000, 16000, 32000, 65536])
    p.add_argument("--mine_H", type=int, default=64)
    p.add_argument("--lambda_agree", type=float, default=0.015)
    p.add_argument("--alpha_stage2", type=float, default=0.15)
    p.add_argument("--moment_stage2", type=float, default=0.005)
    p.add_argument("--label_smoothing", type=float, default=0.03)
    
    # EMA
    p.add_argument("--use_ema", action="store_true", default=None)
    p.add_argument("--ema_decay", type=float, default=0)
    
    # PK Sampler
    p.add_argument("--use_pk", action="store_true")
    p.add_argument("--P", type=int, default=192)
    p.add_argument("--K_pk", type=int, default=3)
    
    # XBP & DCL
    p.add_argument("--xbp_per_img", type=int, default=6)
    p.add_argument("--xbp_global", type=int, default=48000)
    p.add_argument("--use_dcl", action="store_true")
    p.add_argument("--dcl_prior", type=float, default=0.01)
    
    # Legacy
    p.add_argument("--epochs", type=int, default=0)
    p.add_argument("--geom_lr_scale", type=float, default=0.05)
    p.add_argument("--queue_pred_capacity", type=int, default=65536)
    
    # FIXED: parse_known_args returns tuple (args, unknown)
    args, unknown = p.parse_known_args()
    
    # Validate and fix queue schedule if needed
    if not hasattr(args, 'queue_recent_k') or args.queue_recent_k is None:
        args.queue_recent_k = [500, 1000, 2000, 4000, 8000, 16000, 32000, 65536]
    elif len(args.queue_recent_k) < 4:
        print("Warning: queue_recent_k too short, using default 8-stage schedule")
        args.queue_recent_k = [500, 1000, 2000, 4000, 8000, 16000, 32000, 65536]
    
    main(args)

[meta] text_dim=1024 | image_dim=1536
[geometry] Computing variance-weighted centroid Procrustes...

[STAGE 1] Enhanced warmup
[simple 01] loss=2.471956 lr=0.000067 | val_MRR=0.3408 | R@1=0.219 R@5=0.475 R@10=0.598 | median=6 p75=27
[simple 02] loss=2.103471 lr=0.000133 | val_MRR=0.3584 | R@1=0.234 R@5=0.494 R@10=0.622 | median=6 p75=22
[simple 03] loss=2.011080 lr=0.000200 | val_MRR=0.3706 | R@1=0.244 R@5=0.511 R@10=0.637 | median=5 p75=20
[simple 04] loss=1.940317 lr=0.000197 | val_MRR=0.3772 | R@1=0.249 R@5=0.517 R@10=0.648 | median=5 p75=18
[simple 05] loss=1.881534 lr=0.000187 | val_MRR=0.3823 | R@1=0.251 R@5=0.530 R@10=0.659 | median=5 p75=17
[simple 06] loss=1.832789 lr=0.000171 | val_MRR=0.3914 | R@1=0.262 R@5=0.536 R@10=0.671 | median=5 p75=17
[simple 07] loss=1.793071 lr=0.000150 | val_MRR=0.3947 | R@1=0.263 R@5=0.543 R@10=0.675 | median=5 p75=16
[simple 08] loss=1.757707 lr=0.000126 | val_MRR=0.3995 | R@1=0.269 R@5=0.547 R@10=0.677 | median=4 p75=16
[simple 09] loss=1.728539

In [94]:
# =========================
#   ENSEMBLE PREDICTION + TEST-TIME AUGMENTATION
# =========================
def generate_submission_enhanced(model, test_data, pooling, n_patches, device, out_path, num_tta=5):
    """
    Enhanced submission with:
    - Test-Time Augmentation (TTA)
    - Multiple forward passes with dropout
    - Ensemble averaging
    """
    Q = test_data["captions/embeddings"].astype(np.float32)
    ids = test_data.get("captions/ids", np.arange(len(Q)).astype(str))
    
    model.eval()
    # Enable dropout during inference for TTA
    if hasattr(model, 'adapter'):
        for m in model.adapter.modules():
            if isinstance(m, nn.Dropout):
                m.train()
    
    BS = 1024
    all_predictions = []
    
    print(f"[TTA] Running {num_tta} augmented forward passes...")
    
    for tta_idx in range(num_tta):
        outs = []
        with torch.no_grad():
            for i in range(0, len(Q), BS):
                q = torch.from_numpy(Q[i:i+BS]).to(device)
                
                # Add slight noise for augmentation
                if tta_idx > 0:
                    noise = torch.randn_like(q) * 0.01
                    q = q + noise
                
                q = apply_pooling(q, pooling, n_patches)
                z = model(q)
                z = F.normalize(z, dim=-1)
                outs.append(z.cpu().numpy())
        
        pred = np.concatenate(outs, axis=0)
        all_predictions.append(pred)
    
    # Average all TTA predictions and re-normalize
    ensemble_pred = np.mean(all_predictions, axis=0)
    ensemble_pred = ensemble_pred / (np.linalg.norm(ensemble_pred, axis=1, keepdims=True) + 1e-8)
    
    generate_submission(ids, ensemble_pred, str(out_path))
    print(f"✓ Enhanced submission with {num_tta}x TTA: {out_path}")
    return ensemble_pred


# =========================
#   SELF-DISTILLATION LOSS
# =========================
def self_distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """
    Distill knowledge from teacher (EMA) to student (current model)
    """
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)


# =========================
#   MIXUP AUGMENTATION
# =========================
def mixup_data(x, y, alpha=0.4):
    """Mixup augmentation for robustness"""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    
    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)
    
    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_y = lam * y + (1 - lam) * y[index]
    
    return mixed_x, mixed_y, lam


# =========================
#   ULTRA-ENHANCED STAGE 2 (ALL TRICKS)
# =========================
def train_bigjump_ultra(
    model,
    Xtr, Ytr, img_ids_row_tr,
    Xva, val_gallery, cap2gal_local,
    batch=512, epochs=28, base_lr=8e-5, wd=1e-4,
    tau_start=0.09, tau_end=0.04,
    queue_size=98304,
    queue_warmup_epochs=6,
    queue_recent_schedule=(1000, 2000, 4000, 8000, 16000, 32000, 65536, 98304),
    mine_H=96,
    lambda_agree=0.02,
    alpha_cos=0.1,
    lambda_moment=0.003,
    lambda_flow=0.015,
    lambda_distill=0.1,
    use_mixup=True,
    mixup_alpha=0.3,
    use_pk=False, P=192, K_pk=3,
    xbp_per_img=8,
    xbp_global=65536,
    use_dcl=True,
    dcl_prior=0.015,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    pooling="none", n_patches=None,
    out_dir=None,
    seed=42,
    use_ema=True,
    ema_decay=0.9997,
    gradient_clip=1.5,
    label_smoothing=0.05,
):
    """
    ULTRA-ENHANCED trainer with ALL techniques:
    - Self-distillation from EMA
    - Mixup augmentation
    - Larger queue & harder mining
    - DCL enabled
    - Higher flow regularization
    """
    import json
    from pathlib import Path
    from torch.utils.data import DataLoader
    
    out_dir = Path(out_dir) if out_dir is not None else Path("./outputs/ultra")
    out_dir.mkdir(parents=True, exist_ok=True)
    seed_all(seed)
    model.to(device)
    
    has_flow = hasattr(model, 'adapter') and hasattr(model.adapter, 'flow') and model.adapter.flow is not None
    
    class TripletSet(torch.utils.data.Dataset):
        def __init__(self, X_np, Y_np, names_np):
            self.X = torch.from_numpy(X_np).float()
            self.Y = torch.from_numpy(Y_np).float()
            self.N = [str(n) for n in names_np.tolist()]
        def __len__(self):
            return len(self.X)
        def __getitem__(self, i):
            return self.X[i], self.Y[i], self.N[i]
    
    ds = TripletSet(Xtr, Ytr, img_ids_row_tr)
    
    if use_pk:
        pk_sampler = PKBatchSampler(img_ids_row_tr.tolist(), P=int(P), K=int(K_pk), drop_last=True, seed=seed)
        dl = DataLoader(ds, batch_sampler=pk_sampler, num_workers=2, pin_memory=True)
    else:
        dl = DataLoader(ds, batch_size=batch, shuffle=True, num_workers=2, pin_memory=True, drop_last=False)
    
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=base_lr, weight_decay=wd, betas=(0.9, 0.999), eps=1e-8)
    
    # Stable schedule
    total_steps = epochs * len(dl)
    def lr_lambda(step):
        warmup_steps = total_steps // 8
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.15 + 0.85 * 0.5 * (1 + np.cos(np.pi * progress))
    
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
    
    # EMA
    ema_model = None
    if use_ema:
        import copy
        ema_model = copy.deepcopy(model)
        for p in ema_model.parameters():
            p.requires_grad = False
    
    # Queues
    img_queue = ImgQueue(dim=1536, capacity=queue_size, device=device, track_diversity=True)
    xbp_buffer = XBPBuffer(per_image_cap=xbp_per_img, global_cap=xbp_global, device=device)
    
    best_stats, best_mrr, best_ep = None, -1.0, 0
    
    for ep in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        
        # Temperature schedule (slower decay)
        progress = (ep - 1) / max(1, epochs - 1)
        tau = tau_start * (1 - progress**0.7) + tau_end * (progress**0.7)
        
        # Queue schedule
        if ep <= queue_warmup_epochs:
            recent_k = None
        else:
            queue_progress = (ep - queue_warmup_epochs) / max(1, epochs - queue_warmup_epochs)
            schedule_idx = int(queue_progress * len(queue_recent_schedule))
            schedule_idx = min(schedule_idx, len(queue_recent_schedule) - 1)
            recent_k = queue_recent_schedule[schedule_idx]
        
        for xb, yb, names_b in dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            
            # Mixup augmentation
            if use_mixup and np.random.rand() < 0.4:
                xb, yb, lam = mixup_data(xb, yb, alpha=mixup_alpha)
            
            opt.zero_grad(set_to_none=True)
            
            # Forward
            if has_flow:
                pred, flow_logdet = model.adapter(xb, return_flow_logdet=True)
                if hasattr(model, 'geom'):
                    pred = model.geom(xb) + pred
            else:
                pred = model(xb)
                flow_logdet = None
            
            pred_norm = F.normalize(pred, dim=-1)
            
            # Teacher prediction for distillation
            if use_ema and ema_model is not None and ep > queue_warmup_epochs:
                with torch.no_grad():
                    if has_flow:
                        teacher_pred, _ = ema_model.adapter(xb, return_flow_logdet=True)
                        if hasattr(ema_model, 'geom'):
                            teacher_pred = ema_model.geom(xb) + teacher_pred
                    else:
                        teacher_pred = ema_model(xb)
                    teacher_pred_norm = F.normalize(teacher_pred, dim=-1)
            else:
                teacher_pred_norm = None
            
            # Queue negatives
            if recent_k is not None and img_queue.size() > 0:
                queue_feats, queue_ids = img_queue.get_with_ids()
                if queue_feats.numel() > 0:
                    queue_feats = queue_feats[-recent_k:] if recent_k < queue_feats.size(0) else queue_feats
                    queue_ids = queue_ids[-recent_k:] if recent_k < queue_ids.size(0) else queue_ids
            else:
                queue_feats, queue_ids = None, None
            
            # XBP
            xbp_bank, xbp_pos_cols = xbp_buffer.build_bank_and_mask(names_b, device)
            
            # Main loss
            loss_ce = info_nce_multi_enhanced(
                pred, yb, names_b,
                queue_feats=queue_feats,
                queue_img_ids=queue_ids,
                tau=tau,
                xbp_bank=xbp_bank,
                xbp_pos_cols_per_row=xbp_pos_cols,
                hard_subset_H=mine_H,
                use_dcl=use_dcl,
                dcl_prior=dcl_prior,
                epoch=ep,
                total_epochs=epochs,
                label_smoothing=label_smoothing,
            )
            
            # Self-distillation
            loss_distill = pred.new_tensor(0.0)
            if teacher_pred_norm is not None and lambda_distill > 0:
                student_logits = pred_norm @ pred_norm.t()
                teacher_logits = teacher_pred_norm @ teacher_pred_norm.t()
                loss_distill = self_distillation_loss(student_logits, teacher_logits, temperature=3.0)
            
            # Auxiliary losses
            loss_agree = agreement_loss(names_b, pred, temperature=0.4) if lambda_agree > 0 else pred.new_tensor(0.0)
            loss_cos = (1.0 - F.cosine_similarity(pred, yb, dim=-1)).mean() if alpha_cos > 0 else pred.new_tensor(0.0)
            loss_mu = moment_align(pred, yb) if lambda_moment > 0 else pred.new_tensor(0.0)
            
            # Flow regularization
            loss_flow = pred.new_tensor(0.0)
            if has_flow and flow_logdet is not None and lambda_flow > 0:
                loss_flow = torch.mean(torch.abs(flow_logdet)) * lambda_flow
            
            # Total loss
            loss = loss_ce + lambda_distill * loss_distill + lambda_agree * loss_agree + \
                   alpha_cos * loss_cos + lambda_moment * loss_mu + loss_flow
            
            loss.backward()
            
            if gradient_clip > 0:
                torch.nn.utils.clip_grad_norm_(params, gradient_clip)
            
            opt.step()
            sched.step()
            
            # Update EMA
            if use_ema and ema_model is not None:
                with torch.no_grad():
                    for ema_p, model_p in zip(ema_model.parameters(), model.parameters()):
                        ema_p.data.mul_(ema_decay).add_(model_p.data, alpha=1 - ema_decay)
            
            # Update queues
            with torch.no_grad():
                img_queue.enqueue(F.normalize(yb, dim=-1), names_b)
                xbp_buffer.add(names_b, pred)
            
            running_loss += loss.item() * xb.size(0)
        
        # Validation
        eval_model = ema_model if use_ema and ema_model is not None else model
        stats = validate_retrieval(eval_model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
        
        ep_loss = running_loss / len(ds)
        current_lr = opt.param_groups[0]['lr']
        
        print(f"[ultra {ep:02d}] loss={ep_loss:.6f} lr={current_lr:.6f} | MRR={stats['MRR']:.4f} "
              f"R@1={stats['R1']:.3f} R@5={stats['R5']:.3f} R@10={stats['R10']:.3f} "
              f"| med={stats['rank_median']} p75={stats['rank_p75']} | tau={tau:.3f} | recentK={recent_k or 0}")
        
        if stats["MRR"] > best_mrr:
            best_mrr, best_ep, best_stats = stats["MRR"], ep, stats
            torch.save({"model": model.state_dict(), "epoch": ep, "val": stats}, out_dir / "best.pt")
            if use_ema and ema_model is not None:
                torch.save({"model": ema_model.state_dict(), "epoch": ep, "val": stats}, out_dir / "best_ema.pt")
    
    if best_stats is None:
        eval_model = ema_model if use_ema and ema_model is not None else model
        best_stats = validate_retrieval(eval_model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
    
    (out_dir / "val_metrics.json").write_text(json.dumps(dict(best_epoch=best_ep, **best_stats), indent=2))
    return best_stats


# =========================
#   UPDATED MAIN (ULTRA MODE)
# =========================
def main(args):
    """ULTRA mode with all techniques"""
    import json
    from pathlib import Path
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    out_dir, seed = args.out_dir, args.seed
    pooling, n_patches = args.pooling, None
    OUT = Path(f"/kaggle/working/outputs/{out_dir}")
    OUT.mkdir(parents=True, exist_ok=True)
    STAGE1 = OUT / "stage1"
    STAGE2 = OUT / "stage2"
    STAGE1.mkdir(parents=True, exist_ok=True)
    STAGE2.mkdir(parents=True, exist_ok=True)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load data
    X, Y, cap_ids, img_ids_row, full_img, all_img_ids = load_train(OUT)
    
    cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices = build_image_id_split(
        img_ids_row, all_img_ids, full_img, X, Y, args.val_ratio, seed, OUT
    )
    
    Xtr, Ytr = X[cap_is_tr], Y[cap_is_tr]
    Xva = X[cap_is_val]
    
    din, dout = 1024, 1536
    
    # Geometry
    print("[geometry] Variance-weighted Procrustes...")
    A, b = procrustes_closed_form_centroids_weighted(
        Xtr.astype(np.float32),
        Ytr.astype(np.float32),
        img_ids_row[cap_is_tr],
        eps=float(args.geom_eps),
        use_variance_weighting=True,
    )
    
    # Model
    class GeomFromWeights(nn.Module):
        def __init__(self, A_np, b_np):
            super().__init__()
            self.fc = nn.Linear(A_np.shape[1], A_np.shape[0], bias=True)
            with torch.no_grad():
                self.fc.weight.copy_(torch.from_numpy(A_np))
                self.fc.bias.copy_(torch.from_numpy(b_np))
        def forward(self, x):
            return self.fc(x)
    
    geom = GeomFromWeights(A, b).to(device)
    for p in geom.parameters():
        p.requires_grad = False
    
    adapter = FlowEnhancedAdapter(
        din=din, hidden=1024, dout=dout, pdrop=0.1,
        init_scale=0.3,  # Even larger
        use_layernorm=True,
        use_flow=True,
        flow_layers=4,  # More layers
    )
    
    class GeomPlusAdapter(nn.Module):
        def __init__(self, g, a):
            super().__init__()
            self.geom = g
            self.adapter = a
        def forward(self, x):
            return self.geom(x) + self.adapter(x)
    
    model = GeomPlusAdapter(geom, adapter).to(device)
    
    # Stage 1
    if not args.eval_only:
        print(f"\n{'='*60}")
        print(f"[STAGE 1] Warmup")
        print(f"{'='*60}")
        
        best_stats_s1 = train_simple_enhanced(
            model, Xtr, Ytr, img_ids_row[cap_is_tr],
            Xva, val_gallery, cap2gal_local,
            batch=args.batch, epochs=15, lr=2e-4, wd=args.wd,
            tau_fixed=0.07, triplet_margin=0.2, triplet_w=0.3,
            device=device, pooling=pooling, n_patches=n_patches,
            out_dir=str(STAGE1), seed=seed,
        )
        
        try:
            ckpt1 = torch.load(STAGE1 / "best.pt", map_location=device)
            model.load_state_dict(ckpt1["model"])
            print(f"[stage1] Best MRR={ckpt1.get('val',{}).get('MRR','?'):.4f}")
        except:
            pass
    
    # Stage 2 ULTRA
    if not args.eval_only:
        print(f"\n{'='*60}")
        print(f"[STAGE 2] ULTRA MODE - ALL TECHNIQUES")
        print(f"{'='*60}")
        
        best_stats_s2 = train_bigjump_ultra(
            model, Xtr, Ytr, img_ids_row[cap_is_tr],
            Xva, val_gallery, cap2gal_local,
            batch=args.batch,
            epochs=28,
            base_lr=8e-5,
            wd=args.wd,
            tau_start=0.09,
            tau_end=0.04,
            queue_size=98304,
            queue_warmup_epochs=6,
            queue_recent_schedule=(1000, 2000, 4000, 8000, 16000, 32000, 65536, 98304),
            mine_H=96,
            lambda_agree=0.02,
            alpha_cos=0.1,
            lambda_moment=0.003,
            lambda_flow=0.015,
            lambda_distill=0.1,
            use_mixup=True,
            mixup_alpha=0.3,
            xbp_per_img=8,
            xbp_global=65536,
            use_dcl=True,
            dcl_prior=0.015,
            device=device,
            pooling=pooling,
            n_patches=n_patches,
            out_dir=str(STAGE2),
            seed=seed,
            use_ema=True,
            ema_decay=0.9997,
            gradient_clip=1.5,
            label_smoothing=0.05,
        )
        
        try:
            if (STAGE2 / "best_ema.pt").exists():
                ckpt2 = torch.load(STAGE2 / "best_ema.pt", map_location=device)
            else:
                ckpt2 = torch.load(STAGE2 / "best.pt", map_location=device)
            model.load_state_dict(ckpt2["model"])
            best_stats_final = ckpt2.get("val", None)
            print(f"[stage2] Best MRR={best_stats_final.get('MRR','?'):.4f}")
        except:
            best_stats_final = {}
    else:
        for cand in [STAGE2 / "best_ema.pt", STAGE2 / "best.pt", STAGE1 / "best.pt"]:
            try:
                ckpt = torch.load(cand, map_location=device)
                model.load_state_dict(ckpt["model"])
                best_stats_final = ckpt.get("val", {})
                break
            except:
                continue
    
    # Enhanced submission with TTA
    print(f"\n{'='*60}")
    print("[SUBMISSION] TTA Ensemble")
    print(f"{'='*60}")
    
    test_data = load_data(TEST_NPZ)
    sub = OUT / "submission.csv"
    generate_submission_enhanced(model, test_data, pooling, n_patches, device, sub, num_tta=5)
    
    print(f"\n✓ DONE! Best val MRR: {best_stats_final.get('MRR', 0):.4f}")


# Update argparse with ultra defaults
if __name__ == "__main__":
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="ultra_final")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--eval_only", action="store_true")
    p.add_argument("--val_ratio", type=float, default=0.10)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--geom_eps", type=float, default=1e-5)
    p.add_argument("--pooling", type=str, default="CLS")
    
    args, unknown = p.parse_known_args()
    main(args)

[meta] text_dim=1024 | image_dim=1536
[geometry] Variance-weighted Procrustes...

[STAGE 1] Warmup
[simple 01] loss=2.475841 lr=0.000067 | val_MRR=0.3401 | R@1=0.219 R@5=0.474 R@10=0.595 | median=6 p75=27
[simple 02] loss=2.108213 lr=0.000133 | val_MRR=0.3576 | R@1=0.233 R@5=0.493 R@10=0.620 | median=6 p75=22
[simple 03] loss=2.012987 lr=0.000200 | val_MRR=0.3701 | R@1=0.244 R@5=0.510 R@10=0.637 | median=5 p75=20
[simple 04] loss=1.939724 lr=0.000197 | val_MRR=0.3771 | R@1=0.249 R@5=0.518 R@10=0.648 | median=5 p75=18
[simple 05] loss=1.877276 lr=0.000187 | val_MRR=0.3838 | R@1=0.253 R@5=0.530 R@10=0.662 | median=5 p75=17
[simple 06] loss=1.825146 lr=0.000171 | val_MRR=0.3923 | R@1=0.263 R@5=0.538 R@10=0.672 | median=5 p75=17
[simple 07] loss=1.782118 lr=0.000150 | val_MRR=0.3947 | R@1=0.263 R@5=0.543 R@10=0.676 | median=5 p75=16
[simple 08] loss=1.743567 lr=0.000126 | val_MRR=0.3999 | R@1=0.269 R@5=0.548 R@10=0.678 | median=4 p75=16
[simple 09] loss=1.711573 lr=0.000100 | val_MRR=0.401

In [95]:
# =========================
#   FIX 1: BETTER VALIDATION SPLIT (Match Test Distribution)
# =========================
def build_stratified_image_split(img_ids_row, all_img_names, full_img, X, Y, val_ratio, seed, out_dir):
    """
    Stratified split that better matches test distribution:
    - Balance by caption count per image
    - Ensure diverse image types in validation
    """
    from collections import Counter
    
    # Count captions per image
    img_caption_counts = Counter(map(str, img_ids_row))
    
    # Group images by caption count (1, 2-3, 4-5, 6+)
    def get_bucket(count):
        if count == 1: return 0
        if count <= 3: return 1
        if count <= 5: return 2
        return 3
    
    buckets = {0: [], 1: [], 2: [], 3: []}
    for img_name in set(map(str, all_img_names)):
        count = img_caption_counts[img_name]
        bucket = get_bucket(count)
        buckets[bucket].append(img_name)
    
    # Stratified sampling from each bucket
    rng = np.random.default_rng(seed)
    val_images = []
    for bucket_imgs in buckets.values():
        bucket_imgs = np.array(bucket_imgs)
        rng.shuffle(bucket_imgs)
        n_val = max(1, int(len(bucket_imgs) * val_ratio))
        val_images.extend(bucket_imgs[:n_val])
    
    val_images = set(val_images)
    tr_images = set(map(str, all_img_names)) - val_images
    
    # Build masks
    cap_is_val = np.array([str(iid) in val_images for iid in img_ids_row], dtype=bool)
    cap_is_tr = ~cap_is_val
    
    # Val gallery
    val_img_names_sorted = np.array(sorted(val_images))
    name2local = {name: i for i, name in enumerate(val_img_names_sorted)}
    name2global = {str(n): i for i, n in enumerate(all_img_names)}
    val_img_indices = np.array([name2global[n] for n in val_img_names_sorted], dtype=np.int64)
    val_gallery = full_img[val_img_indices]
    cap2gal_local = np.array([name2local[str(n)] for n in img_ids_row[cap_is_val]], dtype=np.int64)
    
    # Save
    (out_dir/"val_indices.json").write_text(json.dumps({
        "val_img_indices": val_img_indices.tolist(),
        "val_caption_to_gallery_index": cap2gal_local.tolist(),
        "n_val_captions": int(cap_is_val.sum()),
        "n_val_unique_images": int(len(val_images)),
        "bucket_distribution": {k: len(v) for k, v in buckets.items()}
    }, indent=2))
    
    return cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices


# =========================
#   FIX 2: PSEUDO-LABELING ON TEST SET
# =========================
def pseudo_label_test_set(model, test_data, train_img_gallery, device, confidence_thresh=0.85):
    """
    Generate pseudo-labels for test set using confident predictions
    """
    Q = test_data["captions/embeddings"].astype(np.float32)
    
    model.eval()
    BS = 1024
    pseudo_targets = []
    confidence_scores = []
    
    train_gallery_tensor = torch.from_numpy(train_img_gallery).to(device)
    train_gallery_tensor = F.normalize(train_gallery_tensor, dim=-1)
    
    with torch.no_grad():
        for i in range(0, len(Q), BS):
            q = torch.from_numpy(Q[i:i+BS]).to(device)
            z = model(q)
            z = F.normalize(z, dim=-1)
            
            # Find most similar training image
            sims = z @ train_gallery_tensor.t()
            max_sims, max_idx = torch.max(sims, dim=1)
            
            pseudo_targets.append(max_idx.cpu().numpy())
            confidence_scores.append(max_sims.cpu().numpy())
    
    pseudo_targets = np.concatenate(pseudo_targets)
    confidence_scores = np.concatenate(confidence_scores)
    
    # Filter by confidence
    high_conf_mask = confidence_scores >= confidence_thresh
    
    print(f"[pseudo-label] {high_conf_mask.sum()}/{len(Q)} captions above {confidence_thresh:.2f} confidence")
    
    return Q, pseudo_targets, high_conf_mask, confidence_scores


# =========================
#   FIX 3: SEMI-SUPERVISED FINE-TUNING
# =========================
def train_semi_supervised(
    model, 
    Xtr, Ytr, img_ids_row_tr,  # Original training
    Q_pseudo, Y_pseudo, pseudo_mask,  # Pseudo-labeled test
    Xva, val_gallery, cap2gal_local,
    batch=512, epochs=10, lr=3e-5, wd=1e-4,
    tau=0.06, confidence_weight=True,
    device=torch.device("cuda"),
    pooling="none", n_patches=None,
    out_dir=None, seed=42,
):
    """
    Fine-tune on train + high-confidence pseudo-labeled test
    """
    from torch.utils.data import DataLoader, ConcatDataset
    
    out_dir = Path(out_dir) if out_dir is not None else Path("./outputs/semi")
    out_dir.mkdir(parents=True, exist_ok=True)
    seed_all(seed)
    
    # Train dataset
    class SimpleDS(torch.utils.data.Dataset):
        def __init__(self, X, Y, weight=1.0):
            self.X = torch.from_numpy(X).float()
            self.Y = torch.from_numpy(Y).float()
            self.weight = weight
        def __len__(self): return len(self.X)
        def __getitem__(self, i): return self.X[i], self.Y[i], self.weight
    
    # Pseudo dataset (only high-confidence)
    Q_conf = Q_pseudo[pseudo_mask]
    Y_conf = Y_pseudo[pseudo_mask]
    
    ds_train = SimpleDS(Xtr, Ytr, weight=1.0)
    ds_pseudo = SimpleDS(Q_conf, Y_conf, weight=0.5)  # Lower weight for pseudo
    
    # Combine
    combined_ds = ConcatDataset([ds_train, ds_pseudo])
    dl = DataLoader(combined_ds, batch_size=batch, shuffle=True, num_workers=2, pin_memory=True)
    
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=lr, weight_decay=wd)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs*len(dl), eta_min=lr*0.1)
    
    best_mrr = -1.0
    
    for ep in range(1, epochs+1):
        model.train()
        running_loss = 0.0
        
        for xb, yb, wb in dl:
            xb, yb, wb = xb.to(device), yb.to(device), wb.to(device)
            
            opt.zero_grad()
            pred = model(xb)
            pred_norm = F.normalize(pred, dim=-1)
            yb_norm = F.normalize(yb, dim=-1)
            
            # Simple InfoNCE
            logits = (pred_norm @ yb_norm.t()) / tau
            labels = torch.arange(len(xb), device=device)
            loss = F.cross_entropy(logits, labels, reduction='none')
            
            # Weight by confidence
            loss = (loss * wb).mean()
            
            loss.backward()
            opt.step()
            sched.step()
            
            running_loss += loss.item() * xb.size(0)
        
        stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
        print(f"[semi {ep:02d}] loss={running_loss/len(combined_ds):.4f} | MRR={stats['MRR']:.4f} R@1={stats['R1']:.3f}")
        
        if stats["MRR"] > best_mrr:
            best_mrr = stats["MRR"]
            torch.save({"model": model.state_dict(), "epoch": ep, "val": stats}, out_dir / "best.pt")
    
    return stats


# =========================
#   FIX 4: MULTI-CROP TEST-TIME AUGMENTATION
# =========================
def generate_submission_multicrop(model, test_data, pooling, n_patches, device, out_path, num_crops=7):
    """
    Multi-crop TTA with different noise patterns
    """
    Q = test_data["captions/embeddings"].astype(np.float32)
    ids = test_data.get("captions/ids", np.arange(len(Q)).astype(str))
    
    model.eval()
    # Keep some dropout active
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()
            m.p = 0.05  # Reduce dropout
    
    BS = 1024
    all_preds = []
    
    print(f"[MultiCrop TTA] {num_crops} crops...")
    
    for crop_idx in range(num_crops):
        outs = []
        with torch.no_grad():
            for i in range(0, len(Q), BS):
                q = torch.from_numpy(Q[i:i+BS]).to(device)
                
                # Different augmentation per crop
                if crop_idx > 0:
                    # Gaussian noise
                    noise_scale = 0.005 * (1 + crop_idx * 0.2)
                    q = q + torch.randn_like(q) * noise_scale
                    
                    # Random masking (dropout some dimensions)
                    if crop_idx % 2 == 0:
                        mask = torch.rand_like(q) > 0.05
                        q = q * mask
                
                z = model(q)
                z = F.normalize(z, dim=-1)
                outs.append(z.cpu().numpy())
        
        pred = np.concatenate(outs, axis=0)
        all_preds.append(pred)
    
    # Weighted average (later crops get slightly less weight)
    weights = np.array([1.0] + [0.9] * (num_crops - 1))
    weights = weights / weights.sum()
    
    ensemble = np.average(all_preds, axis=0, weights=weights)
    ensemble = ensemble / (np.linalg.norm(ensemble, axis=1, keepdims=True) + 1e-8)
    
    generate_submission(ids, ensemble, str(out_path))
    print(f"✓ MultiCrop TTA ({num_crops}x): {out_path}")
    return ensemble


# =========================
#   ULTIMATE MAIN - ALL FIXES
# =========================
def main(args):
    """
    ULTIMATE pipeline:
    1. Better validation split (stratified)
    2. Train Stage 1 + 2 (ULTRA)
    3. Pseudo-label test set
    4. Semi-supervised fine-tune
    5. Multi-crop TTA submission
    """
    import json
    from pathlib import Path
    import numpy as np
    import torch
    
    out_dir, seed = args.out_dir, args.seed
    pooling, n_patches = args.pooling, None
    OUT = Path(f"/kaggle/working/outputs/{out_dir}")
    OUT.mkdir(parents=True, exist_ok=True)
    STAGE1 = OUT / "stage1"
    STAGE2 = OUT / "stage2"
    SEMI = OUT / "semi"
    for d in [STAGE1, STAGE2, SEMI]:
        d.mkdir(parents=True, exist_ok=True)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load data
    X, Y, cap_ids, img_ids_row, full_img, all_img_ids = load_train(OUT)
    
    # BETTER VALIDATION SPLIT
    print("[split] Using stratified split for better test alignment...")
    cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices = build_stratified_image_split(
        img_ids_row, all_img_ids, full_img, X, Y, args.val_ratio, seed, OUT
    )
    
    Xtr, Ytr = X[cap_is_tr], Y[cap_is_tr]
    Xva = X[cap_is_val]
    
    din, dout = 1024, 1536
    
    # Geometry
    A, b = procrustes_closed_form_centroids_weighted(
        Xtr.astype(np.float32), Ytr.astype(np.float32),
        img_ids_row[cap_is_tr], eps=1e-5, use_variance_weighting=True,
    )
    
    # Model
    class GeomFromWeights(nn.Module):
        def __init__(self, A_np, b_np):
            super().__init__()
            self.fc = nn.Linear(A_np.shape[1], A_np.shape[0], bias=True)
            with torch.no_grad():
                self.fc.weight.copy_(torch.from_numpy(A_np))
                self.fc.bias.copy_(torch.from_numpy(b_np))
        def forward(self, x): return self.fc(x)
    
    geom = GeomFromWeights(A, b).to(device)
    for p in geom.parameters():
        p.requires_grad = False
    
    adapter = FlowEnhancedAdapter(
        din=din, hidden=1024, dout=dout, pdrop=0.08,  # Less dropout
        init_scale=0.3, use_layernorm=True,
        use_flow=True, flow_layers=4,
    )
    
    class GeomPlusAdapter(nn.Module):
        def __init__(self, g, a):
            super().__init__()
            self.geom, self.adapter = g, a
        def forward(self, x): return self.geom(x) + self.adapter(x)
    
    model = GeomPlusAdapter(geom, adapter).to(device)
    
    # Stage 1
    if not args.eval_only:
        print(f"\n{'='*60}\n[STAGE 1]\n{'='*60}")
        train_simple_enhanced(
            model, Xtr, Ytr, img_ids_row[cap_is_tr],
            Xva, val_gallery, cap2gal_local,
            batch=args.batch, epochs=12, lr=2.5e-4, wd=1e-4,
            tau_fixed=0.07, triplet_margin=0.2, triplet_w=0.3,
            device=device, pooling=pooling, n_patches=n_patches,
            out_dir=str(STAGE1), seed=seed,
        )
        ckpt1 = torch.load(STAGE1 / "best.pt", map_location=device)
        model.load_state_dict(ckpt1["model"])
    
    # Stage 2 ULTRA
    if not args.eval_only:
        print(f"\n{'='*60}\n[STAGE 2 ULTRA]\n{'='*60}")
        train_bigjump_ultra(
            model, Xtr, Ytr, img_ids_row[cap_is_tr],
            Xva, val_gallery, cap2gal_local,
            batch=args.batch, epochs=24, base_lr=1e-4, wd=1e-4,
            tau_start=0.085, tau_end=0.045,
            queue_size=98304, queue_warmup_epochs=7,
            queue_recent_schedule=(800, 1600, 3200, 6400, 12800, 25600, 51200, 98304),
            mine_H=80, lambda_agree=0.018, alpha_cos=0.12,
            lambda_moment=0.004, lambda_flow=0.012, lambda_distill=0.08,
            use_mixup=True, mixup_alpha=0.25,
            xbp_per_img=7, xbp_global=65536,
            use_dcl=True, dcl_prior=0.012,
            device=device, pooling=pooling, n_patches=n_patches,
            out_dir=str(STAGE2), seed=seed,
            use_ema=True, ema_decay=0.9998,
        )
        
        ckpt2_path = STAGE2 / "best_ema.pt" if (STAGE2 / "best_ema.pt").exists() else STAGE2 / "best.pt"
        ckpt2 = torch.load(ckpt2_path, map_location=device)
        model.load_state_dict(ckpt2["model"])
        print(f"Stage2 best: {ckpt2.get('val', {}).get('MRR', 0):.4f}")
    
    # PSEUDO-LABELING
    if not args.eval_only:
        print(f"\n{'='*60}\n[PSEUDO-LABELING TEST SET]\n{'='*60}")
        test_data = load_data(TEST_NPZ)
        
        Q_pseudo, targets_pseudo, conf_mask, conf_scores = pseudo_label_test_set(
            model, test_data, full_img, device, confidence_thresh=0.80
        )
        
        # Get pseudo target vectors
        Y_pseudo = full_img[targets_pseudo]
        
        # Semi-supervised fine-tune
        print(f"\n{'='*60}\n[SEMI-SUPERVISED FINE-TUNE]\n{'='*60}")
        train_semi_supervised(
            model,
            Xtr, Ytr, img_ids_row[cap_is_tr],
            Q_pseudo, Y_pseudo, conf_mask,
            Xva, val_gallery, cap2gal_local,
            batch=512, epochs=8, lr=2e-5, wd=5e-5,
            tau=0.05,
            device=device, pooling=pooling, n_patches=n_patches,
            out_dir=str(SEMI), seed=seed,
        )
        
        ckpt_semi = torch.load(SEMI / "best.pt", map_location=device)
        model.load_state_dict(ckpt_semi["model"])
        print(f"Semi-supervised best: {ckpt_semi.get('val', {}).get('MRR', 0):.4f}")
    
    # FINAL SUBMISSION with MultiCrop TTA
    print(f"\n{'='*60}\n[FINAL SUBMISSION - MultiCrop TTA]\n{'='*60}")
    test_data = load_data(TEST_NPZ)
    sub = OUT / "submission.csv"
    generate_submission_multicrop(model, test_data, pooling, n_patches, device, sub, num_crops=9)
    
    print(f"\n✓✓✓ ULTIMATE PIPELINE COMPLETE ✓✓✓")


# Quick argparse
if __name__ == "__main__":
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="ultimate_final")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--eval_only", action="store_true")
    p.add_argument("--val_ratio", type=float, default=0.10)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--pooling", type=str, default="CLS")
    args, _ = p.parse_known_args()
    main(args)

[meta] text_dim=1024 | image_dim=1536
[split] Using stratified split for better test alignment...

[STAGE 1]
[simple 01] loss=2.439839 lr=0.000083 | val_MRR=0.3263 | R@1=0.204 R@5=0.455 R@10=0.581 | median=7 p75=29
[simple 02] loss=2.083307 lr=0.000167 | val_MRR=0.3424 | R@1=0.215 R@5=0.478 R@10=0.611 | median=6 p75=24
[simple 03] loss=1.986402 lr=0.000250 | val_MRR=0.3575 | R@1=0.228 R@5=0.499 R@10=0.634 | median=6 p75=20
[simple 04] loss=1.908524 lr=0.000242 | val_MRR=0.3655 | R@1=0.235 R@5=0.512 R@10=0.642 | median=5 p75=19
[simple 05] loss=1.840707 lr=0.000221 | val_MRR=0.3742 | R@1=0.242 R@5=0.524 R@10=0.654 | median=5 p75=18
[simple 06] loss=1.784344 lr=0.000188 | val_MRR=0.3802 | R@1=0.246 R@5=0.534 R@10=0.667 | median=5 p75=17
[simple 07] loss=1.736089 lr=0.000147 | val_MRR=0.3862 | R@1=0.251 R@5=0.544 R@10=0.674 | median=5 p75=16
[simple 08] loss=1.694322 lr=0.000103 | val_MRR=0.3901 | R@1=0.254 R@5=0.547 R@10=0.678 | median=4 p75=16
[simple 09] loss=1.660397 lr=0.000063 | val