In [2]:
from kaggle_secrets import UserSecretsClient
import json, os, pathlib, subprocess, sys

# --- 1. Load your secret (full JSON from kaggle.json) ---
secret_name = "kaggle_json"  # change if you used another name
user_secrets = UserSecretsClient()
raw = user_secrets.get_secret(secret_name)
creds = json.loads(raw)

# --- 2. Forcefully recreate ~/.kaggle/kaggle.json ---
kaggle_dir = pathlib.Path.home() / ".kaggle"
kaggle_dir.mkdir(parents=True, exist_ok=True)
cred_path = kaggle_dir / "kaggle.json"
cred_path.write_text(json.dumps(creds))
os.chmod(cred_path, 0o600)

# --- 3. Double-check the file actually exists and is readable ---
print("✅ Credentials written to:", cred_path)
!ls -la ~/.kaggle/
#!cat ~/.kaggle/kaggle.json | head -1

# --- 4. Reinstall Kaggle CLI cleanly ---
#!pip install --upgrade --force-reinstall kaggle --quiet

✅ Credentials written to: /root/.kaggle/kaggle.json
total 16
drwxr-xr-x 2 root root 4096 Oct 29 19:36 .
drwx------ 1 root root 4096 Oct 29 19:36 ..
-rw------- 1 root root   72 Oct 29 19:36 kaggle.json


In [9]:
%%capture
!kaggle competitions files -c aml-competition
!kaggle competitions download -c aml-competition -p /kaggle/working --force
!mkdir -p /kaggle/working/data
!unzip -o /kaggle/working/aml-competition.zip -d /kaggle/working/data
!ls -lah /kaggle/working/data

In [5]:
!git clone https://github.com/Mamiglia/challenge.git

fatal: destination path 'challenge' already exists and is not an empty directory.


In [10]:
# runner.py
import argparse, json, random, re, time
from pathlib import Path
from os.path import basename

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from challenge.src.common import load_data, generate_submission

# ---------- Paths ----------
DATA_ROOT = Path("/kaggle/working/data")
TRAIN_DIR = DATA_ROOT / "train" / "train"
TEST_DIR  = DATA_ROOT / "test"  / "test"
TRAIN_NPZ = TRAIN_DIR / "train.npz"
TEST_NPZ  = TEST_DIR  / "test.clean.npz"
TRAIN_CAPTIONS = TRAIN_DIR / "captions.txt"
TEST_CAPTIONS  = TEST_DIR  / "captions.txt"

# ---------- Determinism ----------
def seed_all(s: int):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ---------- Metadata/safety ----------
def _assert_metadata_dims(npz_path: Path, expect_text=1024, expect_img=1536):
    d = np.load(npz_path, allow_pickle=True)
    md = {k: d[k][0] for k in d.files if k.startswith("metadata/")}
    tdim = md.get("metadata/embedding_dim_text", None)
    idim = md.get("metadata/embedding_dim_image", None)
    print(f"[meta] text_dim={tdim} | image_dim={idim}")
    assert tdim == expect_text and idim == expect_img, (
        f"Encoder dims mismatch: expected text={expect_text}, image={expect_img} "
        f"but got text={tdim}, image={idim}. Regenerate NPZ with the fixed encoders."
    )

# ---------- Caption→image matching from captions.txt ----------
def _build_image_index(image_ids):
    idx_exact = {}; idx_base={}; idx_stem={}
    for i, v in enumerate(image_ids):
        s = str(v); idx_exact[s]=i
        b = basename(s); idx_base[b]=i
        stem = b.rsplit(".",1)[0] if "." in b else b
        idx_stem[stem]=i
    return idx_exact, idx_base, idx_stem

def _iter_targets_from_captions(path: Path, n_text: int, idx_exact, idx_base, idx_stem):
    targets=[]
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            if len(targets)>=n_text: break
            line=raw.strip()
            if not line: continue
            # be robust to delimiters
            parts=re.split(r'\||\t|,| {2,}', line)
            if len(parts)==1: parts=line.split(" ",1)
            tok=parts[0].strip().strip('"').strip("'")
            if not tok: continue
            def _match(tok_):
                if tok_ in idx_exact: return idx_exact[tok_]
                b=basename(tok_)
                if b in idx_base: return idx_base[b]
                stem=b.rsplit(".",1)[0] if "." in b else b
                if stem in idx_stem: return idx_stem[stem]
                return None
            idx=_match(tok)
            if idx is None:
                tok2=tok.replace("Images/","").replace("./","")
                idx=_match(tok2)
            if idx is None:
                if "." not in tok: continue
                raise AssertionError(f"Could not match image token '{tok}'")
            targets.append(idx)
    assert len(targets)==n_text, f"matched {len(targets)} vs N_text {n_text}"
    return np.asarray(targets, dtype=np.int64)

# ---------- Data loaders ----------
def load_train(out_dir: Path):
    _assert_metadata_dims(TRAIN_NPZ, 1024, 1536)
    d=np.load(TRAIN_NPZ, allow_pickle=True)
    X=d["captions/embeddings"].astype(np.float32)  # (N_text,1024)
    I=d["images/embeddings"].astype(np.float32)    # (N_img,1536)
    cap_ids=d.get("captions/ids", np.arange(len(X)).astype(str))
    img_names=d.get("images/names", np.arange(len(I)).astype(str))

    ex,ba,st=_build_image_index(img_names)
    assert TRAIN_CAPTIONS.exists(), f"Missing {TRAIN_CAPTIONS}"
    targets=_iter_targets_from_captions(TRAIN_CAPTIONS, len(X), ex,ba,st)

    Y=I[targets]                       # (N_text, 1536) GT image vec per caption
    img_ids_row = img_names[targets]   # image name per caption row (for splitting)
    meta={"n_text":int(len(X)),"n_images":int(len(I))}
    (out_dir/"train_detect.json").write_text(json.dumps(meta, indent=2))
    return X,Y,cap_ids,img_ids_row,I,img_names

def load_test_npz():
    d_test=np.load(TEST_NPZ, allow_pickle=True)
    Q=d_test["captions/embeddings"].astype(np.float32)
    q_ids=d_test.get("captions/ids", np.arange(len(Q)).astype(str))
    if "images/embeddings" in d_test.files:
        G=d_test["images/embeddings"].astype(np.float32)
        g_ids=d_test.get("images/names", np.arange(len(G)).astype(str))
    else:
        d_tr=np.load(TRAIN_NPZ, allow_pickle=True)
        G=d_tr["images/embeddings"].astype(np.float32)
        g_ids=d_tr.get("images/names", np.arange(len(G)).astype(str))
    return Q,G,q_ids,g_ids

# ---------- Dataset ----------
class PairDS(Dataset):
    def __init__(self, X, Y): self.X=torch.from_numpy(X); self.Y=torch.from_numpy(Y)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i): return self.X[i], self.Y[i]

# ---------- Pooling (no-op per spec) ----------
def apply_pooling(x: torch.Tensor, mode: str, n_patches=None):
    return x

# ---------- Base Models ----------
class LinearProj(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.fc=nn.Linear(din,dout)
        nn.init.xavier_normal_(self.fc.weight); nn.init.zeros_(self.fc.bias)
    def forward(self,x): return self.fc(x)

class MLP1(nn.Module):
    def __init__(self, din, dout, hidden=512, pdrop=0.1):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(din,hidden), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(hidden,dout)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.zeros_(m.bias)
    def forward(self,x): return self.net(x)

class MLP2(nn.Module):
    # Spec: 1024→1024→512→1536, dropout=0.1 on hidden layers
    def __init__(self, din, dout, h1=1024, h2=512, pdrop=0.1):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(din,h1), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(h1,h2), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(h2,dout)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.zeros_(m.bias)
    def forward(self,x): return self.net(x)

# ---------- Geometry-Preserving Linear (Whiten → Procrustes → Re-color) ----------
def _cov_eigh(zc, eps):
    # zc: (N,D), zero-mean
    N = zc.shape[0]
    C = (zc.T @ zc) / max(1, N)
    # symmetric PSD
    S, U = np.linalg.eigh(C)
    S = np.clip(S, 0.0, None)
    inv_sqrt = 1.0 / np.sqrt(S + eps)
    sqrt = np.sqrt(S + eps)
    C_mhalf = (U * inv_sqrt) @ U.T     # C^{-1/2}
    C_phalf = (U * sqrt) @ U.T         # C^{ 1/2}
    return C_mhalf.astype(np.float32), C_phalf.astype(np.float32)

def procrustes_closed_form(X, Y, eps=1e-5):
    """
    X: (N, 1024) text, Y: (N, 1536) targets (per-caption image vectors)
    Returns A (1536x1024), b (1536,) such that y_hat = A x + b
    """
    # center
    mu_x = X.mean(0, dtype=np.float64)
    mu_y = Y.mean(0, dtype=np.float64)
    Xc = X - mu_x
    Yc = Y - mu_y

    # whiten both
    Cx_mh, _ = _cov_eigh(Xc, eps)      # 1024x1024
    Cy_mh, Cy_ph = _cov_eigh(Yc, eps)  # 1536x1536 (only Cy_ph used)
    Xw = Xc @ Cx_mh.T                  # (N,1024)
    Yw = Yc @ Cy_mh.T                  # (N,1536)

    # orthogonal Procrustes: maximize Tr(R^T Xw^T Yw)
    M = Xw.T @ Yw                      # (1024,1536)
    # SVD on M; for rectangular, do SVD and build R = U V^T in common subspace
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    R = U @ Vt                         # (1024,1536) @ (1536,1536) -> (1024,1536) OK since full_matrices=False
    # we need R as (1536x1024) mapping whitened X to whitened Y; above is (1024x1536)
    R = R.T                            # (1536,1024)

    # re-color to Y space
    A = (Cy_ph @ R @ Cx_mh).astype(np.float32)     # (1536,1536)*(1536,1024)*(1024,1024) = (1536,1024)
    b = (mu_y - (A @ mu_x)).astype(np.float32)     # (1536,)
    return A, b

class GeomLinear(nn.Module):
    """
    Linear layer with weights initialized from closed-form geometry mapping.
    Optionally fine-tunes with tiny LR.
    """
    def __init__(self, A: np.ndarray, b: np.ndarray):
        super().__init__()
        D_in = A.shape[1]; D_out = A.shape[0]
        self.fc = nn.Linear(D_in, D_out, bias=True)
        with torch.no_grad():
            self.fc.weight.copy_(torch.from_numpy(A))
            self.fc.bias.copy_(torch.from_numpy(b))
    def forward(self, x):
        return self.fc(x)

# ---------- Loss pieces (for optional fine-tune) ----------
def moment_align(pred, tgt):
    mu_p, mu_t = pred.mean(0), tgt.mean(0)
    sd_p, sd_t = pred.std(0, unbiased=False), tgt.std(0, unbiased=False)
    return F.mse_loss(mu_p, mu_t) + F.mse_loss(sd_p, sd_t)

def info_nce(pred, tgt):
    p = F.normalize(pred, dim=-1)
    t = F.normalize(tgt, dim=-1)
    logits = p @ t.t()
    labels = torch.arange(pred.size(0), device=pred.device)
    return F.cross_entropy(logits, labels)

# ---------- Split by image id & mapping ----------
def build_image_id_split(img_ids_row, all_img_names, full_img, X, Y, val_ratio, seed, out_dir: Path):
    uniq_img_names = np.array(sorted(set(map(str, all_img_names))))
    rng = np.random.default_rng(seed); rng.shuffle(uniq_img_names)
    n_val = max(1, int(len(uniq_img_names) * val_ratio))
    val_images = set(uniq_img_names[:n_val])
    tr_images  = set(uniq_img_names[n_val:])
    assert len(val_images & tr_images) == 0, "Leakage in image id split!"

    # Caption masks
    cap_is_val = np.array([str(iid) in val_images for iid in img_ids_row], dtype=bool)
    cap_is_tr  = ~cap_is_val

    # Build VAL gallery (unique images in VAL) as sorted names → local indices
    val_img_names_sorted = np.array(sorted(val_images))
    name2local = {name:i for i,name in enumerate(val_img_names_sorted)}
    name2global = {str(n):i for i,n in enumerate(all_img_names)}
    val_img_indices = np.array([name2global[n] for n in val_img_names_sorted], dtype=np.int64)
    val_gallery = full_img[val_img_indices]  # (M,1536)

    # For each VAL caption row, get local gallery index (no remap later)
    cap2gal_local = np.array([name2local[str(n)] for n in img_ids_row[cap_is_val]], dtype=np.int64)

    # Persist mapping for debugging/acceptance
    (out_dir/"val_indices.json").write_text(json.dumps({
        "val_img_indices": val_img_indices.tolist(),
        "val_caption_to_gallery_index": cap2gal_local.tolist(),
        "n_val_captions": int(cap_is_val.sum()),
        "n_val_unique_images": int(len(val_images))
    }, indent=2))

    return cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices

# ---------- Metrics / evaluator (full gallery, cosine on L2) ----------
@torch.no_grad()
def validate_retrieval(model, Xv, val_gallery, cap2gal_local, pooling, n_patches, bs=1024):
    device=next(model.parameters()).device
    Gi = torch.from_numpy(val_gallery).to(device)
    Gi = F.normalize(Gi, dim=-1)

    ranks=[]
    for i in range(0, len(Xv), bs):
        xb=torch.from_numpy(Xv[i:i+bs]).to(device)
        xb=apply_pooling(xb, pooling, n_patches)   # no-op
        pred=model(xb)
        pred=F.normalize(pred, dim=-1)
        sims=pred @ Gi.t()                         # (b, M)
        for j in range(sims.size(0)):
            true_idx = int(cap2gal_local[i+j])     # local gallery index
            order = torch.argsort(sims[j], descending=True)
            rank = (order==true_idx).nonzero(as_tuple=False).item() + 1
            ranks.append(rank)

    ranks = np.array(ranks)
    mrr = float(np.mean(1.0 / ranks))
    r1  = float(np.mean(ranks<=1))
    r5  = float(np.mean(ranks<=5))
    r10 = float(np.mean(ranks<=10))
    return {
        "MRR": mrr,
        "R1": r1,
        "R5": r5,
        "R10": r10,
        "rank_median": int(np.median(ranks)),
        "rank_p75": int(np.percentile(ranks, 75))
    }

def count_params_mb(model):
    params=sum(p.numel() for p in model.parameters())
    mb = params * 4 / (1024**2)
    return params, mb

def time_ms_per_query(model, din, pooling, n_patches):
    device=next(model.parameters()).device
    x=torch.randn(2048, din, device=device)
    x=apply_pooling(x, pooling, n_patches)
    if device.type=="cuda":
        torch.cuda.synchronize()
    t0=time.time()
    with torch.no_grad(): _=model(x)
    if device.type=="cuda":
        torch.cuda.synchronize()
    ms_gpu=(time.time()-t0)*1000/len(x)
    # CPU
    mcpu=model.to("cpu"); xcpu=x.to("cpu")
    t1=time.time()
    with torch.no_grad(): _=mcpu(xcpu)
    ms_cpu=(time.time()-t1)*1000/len(xcpu)
    model.to(device)
    return ms_gpu, ms_cpu

# ---------- Train (for optional fine-tune) ----------
def train_one(model, loader, opt, alpha, beta, gamma, moment_w, pooling, n_patches, device):
    model.train(); total=0.0
    for xb,yb in loader:
        xb,yb=xb.to(device), yb.to(device)
        xb=apply_pooling(xb, pooling, n_patches)
        opt.zero_grad(set_to_none=True)
        pred=model(xb)
        # α·(1−cos) + β·MSE + λ·moment_align + γ·InfoNCE
        cos = 1 - F.cosine_similarity(pred, yb, dim=-1).mean()
        mse = F.mse_loss(pred, yb)
        a_loss = moment_w * moment_align(pred, yb) if moment_w>0 else pred.new_tensor(0.0)
        ce = info_nce(pred, yb) if gamma>0 else pred.new_tensor(0.0)
        loss = alpha*cos + beta*mse + gamma*ce + a_loss
        loss.backward(); opt.step()
        total += loss.item()*xb.size(0)
    return total/len(loader.dataset)


Captions: 125k, Images: 25k ⇒ exactly 5 captions/image (125k / 25k = 5).

Train: 112,500 captions over 22,500 images; Val: 12,500 captions over 2,500 images. Clean 90/10 split by unique image ids → good for avoiding leakage. 

Each listed image shows count = 5. Since all images have 5 captions, there’s no frequency skew in this dataset. Good news: balanced multi-caption coverage.

sampling is effectively uniform over images

> We can safely skip the frequency-aware sampler

In [25]:
# ---------- Residual Adapter + Geom (with Orthogonalization) ----------
class ResidualAdapter(nn.Module):
    def __init__(self, din=1024, hidden=1024, dout=1536, pdrop=0.1, init_scale=0.1):
        super().__init__()
        self.fc1 = nn.Linear(din, hidden)
        self.fc2 = nn.Linear(hidden, dout)
        self.drop = nn.Dropout(pdrop)
        nn.init.kaiming_normal_(self.fc1.weight, nonlinearity='relu')
        nn.init.zeros_(self.fc1.bias)
        nn.init.kaiming_normal_(self.fc2.weight, nonlinearity='linear')
        nn.init.zeros_(self.fc2.bias)
        with torch.no_grad():
            self.fc1.weight.mul_(init_scale)
            self.fc2.weight.mul_(init_scale)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = self.drop(h)
        return self.fc2(h)


class GeomWithAdapter(nn.Module):
    """
    y_hat = GeomLinear(x) (+ optionally trainable) + Residual(x),
    where Residual can be *projected* to the orthogonal complement of span(A) or *penalized* to be orthogonal.

    residual_ortho: "project" | "penalty" | "none"
    ortho_eps: ridge for projector
    ortho_lambda: weight for penalty (if residual_ortho == "penalty")
    unfreeze_bias: whether to unfreeze geom bias together with weight
    """
    def __init__(
        self, A: np.ndarray, b: np.ndarray,
        din=1024, dout=1536, hidden=1024, pdrop=0.1,
        residual_ortho: str = "project",
        ortho_eps: float = 1e-4,
        ortho_lambda: float = 0.05,
        unfreeze_bias: bool = False,
    ):
        super().__init__()
        self.geom = GeomLinear(A, b)
        self.adapter = ResidualAdapter(din=din, hidden=hidden, dout=dout, pdrop=pdrop)

        # default: freeze geom at start; training loop may unfreeze weight later
        for p in self.geom.parameters():
            p.requires_grad = False

        # orthogonalization config
        assert residual_ortho in ("project", "penalty", "none")
        self.residual_ortho = residual_ortho
        self.ortho_lambda = float(ortho_lambda)
        self.ortho = ResidualOrthogonalizer(eps=float(ortho_eps))
        self._unfreeze_bias = bool(unfreeze_bias)

    def unfreeze_geom(self, weight_lr_scale=0.05):
        # In training we set per-parameter LRs in the optimizer; here just set requires_grad flags.
        self.geom.fc.weight.requires_grad = True
        if self._unfreeze_bias:
            self.geom.fc.bias.requires_grad = True

    def forward(self, x, return_ortho_penalty: bool = False):
        """
        If residual_ortho == "project": returns (yhat, 0.0) with hard projection (autocast OFF inside).
        If residual_ortho == "penalty": returns (yhat, penalty_scalar).
        If residual_ortho == "none": returns (yhat, 0.0).
        """
        g = self.geom(x)               # (B,1536)
        r = self.adapter(x)            # (B,1536)

        penalty = None
        if self.residual_ortho == "project":
            # Hard projection of residual into orth complement of span(A) in float32, autocast OFF
            A = self.geom.fc.weight     # (1536,1024)
            device_type = "cuda" if g.device.type == "cuda" else "cpu"
            with torch.autocast(device_type=device_type, enabled=False):
                r = self.ortho.project_residual(r, A)
            y = g + r
            penalty = r.new_tensor(0.0)

        elif self.residual_ortho == "penalty":
            # Soft orthogonality: mean squared cosine between normalized parts
            g_n = F.normalize(g, dim=-1)
            r_n = F.normalize(r, dim=-1)
            dot = torch.sum(g_n * r_n, dim=-1)  # (B,)
            penalty = torch.mean(dot * dot)
            y = g + r

        else:  # "none"
            y = g + r
            penalty = r.new_tensor(0.0)

        if return_ortho_penalty:
            return y, penalty
        return y

In [28]:
# ---------- Residual Orthogonalization Utilities (AMP- & device-safe) ----------
class ResidualOrthogonalizer:
    """
    Projects residuals onto the orthogonal complement of span(A) safely under AMP and across devices.

    Math (done in float32, autocast OFF):
      G = A^T A + eps I   (1024x1024)
      invG = G^{-1}       (or cholesky/pinv fallback)
      Proj_col(A)(R) = ((R @ A) @ invG) @ A^T
      (I - P_A)R = R - Proj_col(A)(R)

    We cache invG per A-hash; at use, we always move invG to A.device to avoid CPU/GPU mismatches.
    """
    def __init__(self, eps: float = 1e-4, device: torch.device | None = None):
        self.eps = float(eps)
        self.device = device
        self._cached_invG = None
        self._cached_A_hash = None

    @torch.no_grad()
    def _hash_weight(self, A: torch.Tensor) -> tuple[int, int]:
        h = (A.shape[0], A.shape[1])
        sample = A.reshape(-1)[:1024] if A.numel() >= 1024 else A.reshape(-1)
        checksum = int(torch.sum((sample.float() * 1e3).round()).item())
        return (h[0] * 10_000 + h[1], checksum)

    @torch.no_grad()
    def refresh(self, A: torch.Tensor):
        """
        Ensure invG (float32) is up to date for current A.
        Always computes with autocast DISABLED and in float32.
        Caches on A.device.
        """
        dev = A.device
        device_type = "cuda" if dev.type == "cuda" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            A32 = A.detach().to(dev, dtype=torch.float32)
            key = self._hash_weight(A32)
            # If cache hit AND cached tensor already on the right device, reuse
            if key == self._cached_A_hash and self._cached_invG is not None and self._cached_invG.device == dev:
                return

            G = A32.transpose(0, 1) @ A32
            G = G + self.eps * torch.eye(G.shape[0], device=dev, dtype=torch.float32)
            try:
                L = torch.linalg.cholesky(G)
                invG = torch.cholesky_inverse(L)
            except RuntimeError:
                try:
                    invG = torch.linalg.inv(G)
                except RuntimeError:
                    invG = torch.linalg.pinv(G)

            self._cached_invG = invG   # stored on A.device
            self._cached_A_hash = key

    def project_residual(self, R: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        """
        Return (I - P_A) R. Autocast OFF inside, math in float32, cast back to R.dtype at end.
        Ensures invG and all operands are on the SAME device as A/R.
        """
        dev = R.device
        device_type = "cuda" if dev.type == "cuda" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            # Make sure the cache matches current A AND is on correct device
            self.refresh(A)
            A32  = A.to(dtype=torch.float32, device=dev)
            R32  = R.to(dtype=torch.float32, device=dev)
            invG = self._cached_invG
            if invG.device != dev:
                invG = invG.to(dev)

            # Proj_col(A)(R) = ((R @ A) @ invG) @ A^T
            Z = R32 @ A32
            Z = Z @ invG
            R_proj = Z @ A32.transpose(0, 1)

            R_orth = R32 - R_proj
            return R_orth.to(dtype=R.dtype, device=dev)


In [None]:
def agreement_loss(names, preds, eps=1e-8):
    """
    names: list[str] length B (image ids per caption)
    preds: (B, D) unnormalized predictions
    Returns mean variance across groups (after L2-norm), scalar.
    """
    with torch.no_grad():
        groups = {}
        for i, n in enumerate(map(str, names)):
            groups.setdefault(n, []).append(i)
        valid = [idxs for idxs in groups.values() if len(idxs) >= 2]
    if not valid:
        return preds.new_tensor(0.0)
    z = F.normalize(preds, dim=-1)
    vars_ = []
    for idxs in valid:
        g = z[idxs]
        v = g.var(dim=0, unbiased=False).mean()
        vars_.append(v)
    return torch.stack(vars_).mean()


In [13]:
class ImgQueue:
    def __init__(self, dim: int, capacity: int, device: torch.device):
        self.capacity = int(capacity)
        self.device = device
        self.ptr = 0                  # write pointer into a circular buffer
        self.full = False
        self.bank = torch.zeros(self.capacity, dim, device=device)

    @torch.no_grad()
    def enqueue(self, feats: torch.Tensor):
        # store as negatives only; avoid holding computation graphs
        feats = feats.detach()
        n = feats.size(0)
        if n >= self.capacity:
            # keep only the most recent 'capacity' entries
            self.bank.copy_(feats[-self.capacity:])
            self.ptr = 0
            self.full = True
            return
        end = self.ptr + n
        if end <= self.capacity:
            self.bank[self.ptr:end] = feats
        else:
            cut = self.capacity - self.ptr
            self.bank[self.ptr:] = feats[:cut]
            self.bank[:end - self.capacity] = feats[cut:]
        self.ptr = (self.ptr + n) % self.capacity
        if self.ptr == 0:
            self.full = True

    def size(self) -> int:
        # correct empty vs full handling
        if self.full:
            return self.capacity
        return self.ptr  # 0 when empty

    def get(self) -> torch.Tensor:
        # return FIFO-ordered contents (oldest -> newest)
        if self.size() == 0:
            return self.bank[:0]  # empty tensor with correct shape
        if self.full:
            # [ptr:cap] then [0:ptr] so that tail is the most recent
            return torch.cat([self.bank[self.ptr:], self.bank[:self.ptr]], dim=0)
        # not full: contents are [0:ptr)
        return self.bank[:self.ptr]

    def recent(self, max_items: int) -> torch.Tensor:
        """
        Return only the freshest 'max_items' features (chronological).
        """
        q = self.get()
        if q is None or q.numel() == 0:
            return q
        if max_items is None or q.size(0) <= int(max_items):
            return q
        return q[-int(max_items):]


 visualization

In [123]:
# ----------------- viz_utils.py (drop-in) -----------------
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from typing import List, Sequence, Optional, Tuple

@torch.inference_mode()
def visualize_retrieval(
    pred_emb: torch.Tensor,              # (D,) predicted caption->image vector (already in image space)
    gt_index: int,                       # ground truth image index (into image_files / image_embeddings)
    image_files: Sequence[str],
    caption_text: str,
    image_embeddings: torch.Tensor,      # (N, D) gallery in same space
    k: int = 5,
    dataset_path: str = "/kaggle/working/data/train/train",
) -> Tuple[bool, Optional[int]]:
    """
    Show ground-truth image + top-k retrieved images. Highlights correct match if present.
    Returns: (gt_in_topk, gt_rank or None)
    """
    # Ensure CPU + float
    pred = F.normalize(pred_emb.view(1, -1), dim=-1).cpu()
    gallery = F.normalize(image_embeddings, dim=-1).cpu()

    # cosine similarities
    sims = (pred @ gallery.T).squeeze(0).numpy()  # (N,)
    retrieved = np.argsort(-sims)[:k]
    gt_in_topk = int(gt_index) in retrieved
    gt_rank = int(np.where(retrieved == gt_index)[0][0] + 1) if gt_in_topk else None

    # figure
    fig, axes = plt.subplots(1, k + 1, figsize=(22, 4))

    def _load_show(ax, img_name, title, color=None, weight=None):
        ax.axis("off")
        p = Path(dataset_path) / "Images" / img_name
        try:
            img = Image.open(p).convert("RGB")
            ax.imshow(img)
        except Exception:
            ax.text(0.5, 0.5, f"Image not found:\n{img_name}", ha="center", va="center")
        if color or weight:
            ax.set_title(title, fontsize=10, color=color or "black", fontweight=weight or "normal")
        else:
            ax.set_title(title, fontsize=10)

    # GT on the left
    gt_name = image_files[int(gt_index)]
    _load_show(axes[0], gt_name, f"Ground Truth\n{gt_name[:28]}...", color="green", weight="bold")

    # Top-k
    for i, idx in enumerate(retrieved):
        name = image_files[int(idx)]
        is_gt = (int(idx) == int(gt_index))
        title = f"Rank {i+1}\nCos: {sims[idx]:.3f}" + ("\n✓ CORRECT" if is_gt else "")
        _load_show(axes[i+1], name, title, color=("green" if is_gt else None), weight=("bold" if is_gt else None))

    plt.suptitle(f"Query caption: “{caption_text}”\nStatus: {'✓ Found at rank '+str(gt_rank) if gt_in_topk else f'✗ Not in top-{k}'}",
                 fontsize=13, fontweight="bold")
    plt.tight_layout()
    plt.show()

    return bool(gt_in_topk), (int(gt_rank) if gt_rank is not None else None)


@torch.inference_mode()
def visualize_samples(
    model: torch.nn.Module,
    text_embeddings: torch.Tensor,       # (M, din) caption embeddings (VAL)
    image_embeddings: torch.Tensor,      # (N, D) gallery (VAL)
    image_index_of_caption: Sequence[int],# length M: for each caption row, the GT image index
    image_files: Sequence[str],          # length N image filenames
    caption_texts: Optional[Sequence[str]] = None,  # optional strings to show (len M). If None, uses "caption i"
    n_examples: int = 6,
    k: int = 5,
    device: Optional[torch.device] = None,
):
    """
    Picks n_examples random caption rows, runs them through `model`, and calls visualize_retrieval().
    Also prints the current gate value g (if your model has .gate_raw).
    """
    model_device = next(model.parameters()).device
    device = device or model_device

    # Report gating scalar (if present)
    g = None
    if hasattr(model, "gate_raw"):
        g = torch.sigmoid(model.gate_raw.detach().to("cpu")).item()
        print(f"[gate] adapter gate g = {g:.4f} (0=off, 1=full)")

    M = text_embeddings.shape[0]
    idxs = torch.randperm(M)[:n_examples].tolist()

    for i, row in enumerate(idxs, 1):
        x = text_embeddings[row:row+1].to(device)   # (1, din)
        pred = model(x).squeeze(0)                  # (D,)
        gt_idx = int(image_index_of_caption[row])
        cap_txt = caption_texts[row] if (caption_texts is not None) else f"caption {row}"
        print(f"\n[{i}/{n_examples}] cap_row={row} → GT image idx={gt_idx}")
        visualize_retrieval(
            pred_emb=pred,
            gt_index=gt_idx,
            image_files=image_files,
            caption_text=cap_txt,
            image_embeddings=image_embeddings,
            k=k
        )


@torch.inference_mode()
def compute_gt_ranks(
    model: torch.nn.Module,
    text_embeddings: torch.Tensor,         # (M, din)
    image_embeddings: torch.Tensor,        # (N, D)
    image_index_of_caption: Sequence[int], # (M,)
    batch_size: int = 256,
    device: Optional[torch.device] = None,
) -> np.ndarray:
    """
    Returns an array of size (M,) with 1-based ranks of the ground-truth image for each caption.
    """
    model_device = next(model.parameters()).device
    device = device or model_device

    # Pre-norm gallery on device for speed
    gallery = F.normalize(image_embeddings.to(device), dim=-1)      # (N, D)

    ranks = []
    M = text_embeddings.shape[0]
    for i in range(0, M, batch_size):
        xb = text_embeddings[i:i+batch_size].to(device)
        pred = model(xb)                                            # (B, D)
        pred = F.normalize(pred, dim=-1)
        sims = pred @ gallery.T                                     # (B, N)
        order = torch.argsort(sims, dim=1, descending=True)         # (B, N)
        gt = torch.as_tensor(image_index_of_caption[i:i+xb.size(0)], device=device).view(-1, 1)
        # find rank (1-based)
        # We locate GT in each row
        eq = (order == gt).nonzero(as_tuple=False)
        # eq has rows: [row_idx, col_idx], where col_idx is the rank-1
        pos = torch.zeros(xb.size(0), dtype=torch.long, device=device)
        pos[eq[:,0]] = eq[:,1] + 1
        ranks.append(pos.detach().cpu().numpy())

    return np.concatenate(ranks, axis=0)


def plot_rank_hist(ranks: np.ndarray, k_marks: Tuple[int, ...] = (1,5,10,50)):
    """
    Simple histogram of ranks with vertical lines for common top-k cutoffs.
    """
    plt.figure(figsize=(8,4))
    # do NOT set colors explicitly (keeps your style guide clean)
    plt.hist(ranks, bins=1000)
    for k in k_marks:
        plt.axvline(k, linestyle="--")
    top1 = (ranks <= 1).mean()*100
    top5 = (ranks <= 5).mean()*100
    top10 = (ranks <= 10).mean()*100
    plt.title(f"GT rank distribution | R@1={top1:.1f}%  R@5={top5:.1f}%  R@10={top10:.1f}%")
    plt.xlabel("Ground-truth rank (1=best)")
    plt.ylabel("# captions")
    plt.tight_layout()
    plt.show()
# -----------------------------------------------------------


# settings for new model

In [14]:
# =============== P×K Sampler (F1.1) ===============
class PKBatchSampler(torch.utils.data.Sampler[list[int]]):
    """
    Samples batches with P unique images, each with K captions.
    Use in DataLoader(batch_sampler=PKBatchSampler(...)).
    """
    def __init__(self, img_ids: list[str], P: int, K: int, drop_last: bool = True, seed: int = 42):
        # group dataset indices by image id
        self.by_img = {}
        for idx, iid in enumerate(map(str, img_ids)):
            self.by_img.setdefault(iid, []).append(idx)
        # shuffle within each image's index list
        g = torch.Generator()
        g.manual_seed(seed)
        for lst in self.by_img.values():
            perm = torch.randperm(len(lst), generator=g).tolist()
            lst[:] = [lst[i] for i in perm]

        self.P, self.K, self.drop_last = int(P), int(K), bool(drop_last)
        self.iids = list(self.by_img.keys())
        self.cur = {iid: 0 for iid in self.iids}
        self.pool = [iid for iid in self.iids if len(self.by_img[iid]) > 0]

    def __iter__(self):
        import random
        pool = self.pool[:]
        random.shuffle(pool)
        batch = []
        while len(pool) >= self.P:
            pick = pool[:self.P]
            pool = pool[self.P:]
            for iid in pick:
                start = self.cur[iid]
                end = start + self.K
                lst = self.by_img[iid]
                if start >= len(lst):
                    continue
                take = lst[start:min(end, len(lst))]
                self.cur[iid] = min(end, len(lst))
                batch.extend(take)
            if len(batch) == self.P * self.K:
                yield batch
            else:
                if not self.drop_last and len(batch) > 0:
                    yield batch
                break
            batch = []

    def __len__(self):
        total = sum(len(v) // self.K for v in self.by_img.values())
        return max(0, total // self.P)


In [15]:
# =============== Cross-batch Positives Buffer (F1.2) ===============
from collections import deque

class XBPBuffer:
    """
    Keeps recent predicted caption embeddings as extra positives per image.
    Read-only in loss; write after optimizer step.
    """
    def __init__(self, per_image_cap: int = 4, global_cap: int = 32000, device: torch.device = torch.device("cpu")):
        self.per_img = {}            # iid -> deque of normalized tensors
        self.order = deque()         # (iid, stamp) to enforce global cap
        self.per_image_cap = int(per_image_cap)
        self.global_cap = int(global_cap)
        self.device = device
        self._count = 0
        self._stamp = 0

    @torch.no_grad()
    def add(self, img_ids: list[str], pred_feats: torch.Tensor):
        feats = F.normalize(pred_feats.detach(), dim=-1)
        for iid, f in zip(map(str, img_ids), feats):
            dq = self.per_img.setdefault(iid, deque(maxlen=self.per_image_cap))
            dq.append(f.to(self.device))
            self.order.append((iid, self._stamp))
            self._stamp += 1
            self._count += 1
            # trim global
            while self._count > self.global_cap:
                old_iid, _ = self.order.popleft()
                if self.per_img.get(old_iid):
                    try:
                        self.per_img[old_iid].popleft()
                        self._count -= 1
                        if len(self.per_img[old_iid]) == 0:
                            self.per_img.pop(old_iid, None)
                    except IndexError:
                        pass

    def build_bank_and_mask(self, batch_img_names: list[str], device: torch.device):
        """
        Returns:
          xbp_bank: (M, D) stacked positives from XBP across batch (M can be 0)
          xbp_pos_cols_per_row: list[list[int]] with column indices into xbp_bank for each row
        """
        rows = [str(x) for x in batch_img_names]
        per_row_feats = []
        row_counts = []
        for iid in rows:
            dq = self.per_img.get(iid, None)
            if dq is None or len(dq) == 0:
                row_counts.append(0)
            else:
                row_counts.append(len(dq))
                per_row_feats.extend(list(dq))
        if len(per_row_feats) == 0:
            return torch.empty(0, 0, device=device), [[] for _ in rows]
        xbp_bank = torch.stack(per_row_feats, dim=0).to(device)
        xbp_pos_cols = []
        offset = 0
        for cnt in row_counts:
            if cnt == 0:
                xbp_pos_cols.append([])
            else:
                xbp_pos_cols.append(list(range(offset, offset + cnt)))
                offset += cnt
        return xbp_bank, xbp_pos_cols


In [16]:
# =============== Hard-subset Mining (F2.1) ===============
@torch.no_grad()
def mine_hard(q_recent: torch.Tensor, pred: torch.Tensor, H: int, exclude_feats: torch.Tensor = None):
    """
    q_recent: (Q,D) normalized features to mine from (negatives pool)
    pred:     (B,D) normalized anchor predictions
    Returns:
      idx: (B,H) indices into q_recent for each anchor
      flat_idx: (B*H,) flattened indices
      mined: (B*H,D) the mined subset
    """
    if q_recent is None or q_recent.numel() == 0 or H <= 0:
        return None, None, None
    B = pred.size(0)
    Q = q_recent.size(0)
    sims = pred @ q_recent.t()  # (B,Q)
    H = min(int(H), Q)
    _, idx = torch.topk(sims, k=H, dim=1, largest=True, sorted=False)  # (B,H)
    flat_idx = idx.reshape(-1)
    mined = q_recent.index_select(0, flat_idx)      # (B*H,D)
    return idx, flat_idx, mined


In [17]:
# =============== Debiased NT-Xent denominator (F2.2) ===============
def _debiased_logZ(logits: torch.Tensor, pos_mask: torch.Tensor, class_prior: float = 0.01):
    """
    Chuang et al., 2020 (Debiased Contrastive Learning) — practical variant:
      Z* = sum_k exp(l_ik) - π * sum_{j in P_i} exp(l_ij)
    """
    exp_all = torch.exp(logits)          # (B,K)
    exp_pos = exp_all * pos_mask
    Z = exp_all.sum(dim=1)               # (B,)
    corr = class_prior * exp_pos.sum(dim=1)
    Z_star = torch.clamp(Z - corr, min=1e-8)
    return torch.log(Z_star)


In [18]:
# =============== Multi-positive InfoNCE with XBP, Hard Mining, DCL (F1.1+F1.2+F2.1+F2.2) ===============
def info_nce_multi(
    pred,                                    # (B,D) unnormed
    batch_img_targets,                       # (B,D)
    batch_img_names,                         # list length B
    queue_feats=None,                        # (Q,D) negatives pool
    tau: float = 0.07,
    xbp_bank: torch.Tensor = None,           # (M,D) extra positives
    xbp_pos_cols_per_row: list[list[int]] = None,
    hard_subset_H: int = 0,
    use_dcl: bool = False,
    dcl_prior: float = 0.01,
):
    p = F.normalize(pred, dim=-1)
    t = F.normalize(batch_img_targets, dim=-1)

    # base bank block 0: in-batch targets (positives exist here)
    bank_blocks = [t]
    B = pred.size(0)

    # optional hard mining over queue
    if queue_feats is not None and queue_feats.numel() > 0:
        qn = F.normalize(queue_feats, dim=-1)
        if hard_subset_H and hard_subset_H > 0:
            _, _, mined_block = mine_hard(qn, p, int(hard_subset_H))
            bank_blocks.append(mined_block)
        else:
            bank_blocks.append(qn)

    # optional XBP block
    xbp_offsets = None
    if xbp_bank is not None and xbp_bank.numel() > 0:
        xbp_offsets = sum(b.size(0) for b in bank_blocks)  # start col for XBP block
        bank_blocks.append(xbp_bank)

    bank = torch.cat(bank_blocks, dim=0) if len(bank_blocks) > 1 else bank_blocks[0]  # (K,D)
    logits = (p @ bank.t()) / float(tau)                                              # (B,K)
    K = bank.size(0)

    # build pos mask: in-batch + XBP
    pos_mask = torch.zeros(B, K, dtype=torch.bool, device=pred.device)
    names = list(map(str, batch_img_names))
    for i in range(B):
        for j in range(B):
            if names[i] == names[j]:
                pos_mask[i, j] = True
    if xbp_offsets is not None and xbp_pos_cols_per_row is not None:
        for i, cols in enumerate(xbp_pos_cols_per_row):
            for c in cols:
                pos_mask[i, xbp_offsets + c] = True

    assert pos_mask.any(dim=1).all(), "Every row must have at least one positive."

    # denominator: standard vs debiased
    if not use_dcl:
        logZ = torch.logsumexp(logits, dim=1)
    else:
        logZ = _debiased_logZ(logits, pos_mask, class_prior=float(dcl_prior))

    logits_pos = logits.masked_fill(~pos_mask, float('-inf'))
    logPos = torch.logsumexp(logits_pos, dim=1)
    return (-(logPos - logZ)).mean()


In [19]:
# =============== Dual Queue: Prediction FIFO (F2.3) ===============
class PredQueue:
    def __init__(self, dim: int, capacity: int, device: torch.device):
        self.capacity = int(capacity)
        self.device = device
        self.ptr = 0
        self.full = False
        self.feats = torch.zeros(self.capacity, dim, device=device)  # normalized feats
        self.ids = torch.empty(self.capacity, dtype=torch.long, device=device)  # hashed img ids

    @torch.no_grad()
    def enqueue(self, feats: torch.Tensor, img_ids: list[str]):
        feats = F.normalize(feats.detach(), dim=-1)
        ids = torch.tensor([hash(str(i)) for i in img_ids], dtype=torch.long, device=self.device)
        n = feats.size(0)
        if n >= self.capacity:
            self.feats.copy_(feats[-self.capacity:])
            self.ids.copy_(ids[-self.capacity:])
            self.ptr = 0
            self.full = True
            return
        end = self.ptr + n
        if end <= self.capacity:
            self.feats[self.ptr:end] = feats
            self.ids[self.ptr:end] = ids
        else:
            cut = self.capacity - self.ptr
            self.feats[self.ptr:] = feats[:cut]
            self.ids[self.ptr:] = ids[:cut]
            self.feats[:end - self.capacity] = feats[cut:]
            self.ids[:end - self.capacity] = ids[cut:]
        self.ptr = (self.ptr + n) % self.capacity
        if self.ptr == 0:
            self.full = True

    def size(self) -> int:
        if self.full:
            return self.capacity
        return self.ptr

    def recent(self, max_items: int):
        n = self.size()
        if n == 0:
            return self.feats[:0], self.ids[:0]
        k = min(int(max_items), n) if max_items is not None else n
        if self.full:
            idx = torch.arange(self.ptr - n, self.ptr, device=self.device) % self.capacity
        else:
            idx = torch.arange(0, n, device=self.device)
        idx = idx[-k:]
        return self.feats.index_select(0, idx), self.ids.index_select(0, idx)


In [20]:
def make_model(
    arch: str, din: int, dout: int, geom_init_data=None,
    residual_ortho: str = "project",
    ortho_eps: float = 1e-4,
    ortho_lambda: float = 0.05,
    unfreeze_geom_bias: bool = False,
):
    a = arch.lower()
    if a == "linear":
        return LinearProj(din, dout)
    if a == "mlp1":
        return MLP1(din, dout, hidden=512, pdrop=0.1)
    if a == "mlp2":
        return MLP2(din, dout, h1=1024, h2=512, pdrop=0.1)
    if a == "geom":
        assert geom_init_data is not None, "geom_init_data required for arch='geom'."
        Xtr_np, Ytr_np, eps = geom_init_data
        A, b = procrustes_closed_form(Xtr_np, Ytr_np, eps=float(eps))
        return GeomLinear(A, b)
    if a in ("geom_adapter", "geom+adapter", "bigjump"):
        assert geom_init_data is not None, "geom_init_data required for geom+adapter/bigjump."
        Xtr_np, Ytr_np, eps = geom_init_data
        A, b = procrustes_closed_form(Xtr_np, Ytr_np, eps=float(eps))
        return GeomWithAdapter(
            A, b, din=din, dout=dout, hidden=1024, pdrop=0.1,
            residual_ortho=residual_ortho, ortho_eps=ortho_eps,
            ortho_lambda=ortho_lambda, unfreeze_bias=unfreeze_geom_bias
        )
    if a == "auto":
        return MLP2(din, dout)
    raise ValueError(f"Unknown arch {arch}")

In [38]:
# ---------- Centroid helpers ----------
def _centroids_from_img_ids(X_rows: np.ndarray, Y_rows: np.ndarray, img_names_rows: np.ndarray):
    """
    Build per-image centroids from per-caption rows.
    Returns:
      Xc: (N_img, 1024)   mean text per image
      Yc: (N_img, 1536)   mean image per image (usually equals the single image vector)
      order_names: list[str] image names in same order
    """
    from collections import defaultdict
    buckets_x = defaultdict(list)
    buckets_y = defaultdict(list)
    for x, y, n in zip(X_rows, Y_rows, map(str, img_names_rows)):
        buckets_x[n].append(x)
        buckets_y[n].append(y)
    order_names = sorted(buckets_x.keys())
    Xc = np.stack([np.mean(buckets_x[n], axis=0) for n in order_names], axis=0).astype(np.float32)
    Yc = np.stack([np.mean(buckets_y[n], axis=0) for n in order_names], axis=0).astype(np.float32)
    return Xc, Yc, order_names

# ---------- Geometry: closed-form (whiten→orthogonal align→recolor) on CENTROIDS ----------
def procrustes_closed_form_centroids(X_rows: np.ndarray, Y_rows: np.ndarray, img_names_rows: np.ndarray, eps: float = 1e-5):
    """
    Fit linear A,b using per-image centroids instead of per-caption rows.
    """
    Xc, Yc, _ = _centroids_from_img_ids(X_rows, Y_rows, img_names_rows)

    mu_x = Xc.mean(0, dtype=np.float64)
    mu_y = Yc.mean(0, dtype=np.float64)
    Xzc = Xc - mu_x
    Yzc = Yc - mu_y

    def _cov_eigh(zc, eps):
        S, U = np.linalg.eigh((zc.T @ zc) / max(1, zc.shape[0]))
        S = np.clip(S, 0.0, None)
        inv_sqrt = 1.0 / np.sqrt(S + eps)
        sqrt     = np.sqrt(S + eps)
        C_mhalf = (U * inv_sqrt) @ U.T
        C_phalf = (U * sqrt)     @ U.T
        return C_mhalf.astype(np.float32), C_phalf.astype(np.float32)

    Cx_mh, _     = _cov_eigh(Xzc, eps)      # 1024x1024
    Cy_mh, Cy_ph = _cov_eigh(Yzc, eps)      # 1536x1536
    Xw = Xzc @ Cx_mh.T                      # (Nimg,1024)
    Yw = Yzc @ Cy_mh.T                      # (Nimg,1536)

    # rectangular orthogonal alignment in whitened space
    M = Xw.T @ Yw                           # (1024,1536)
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    R = (U @ Vt).T                          # (1536,1024)

    # recolor to Y space
    A = (Cy_ph @ R @ Cx_mh).astype(np.float32)     # (1536,1024)
    b = (mu_y - (A @ mu_x)).astype(np.float32)     # (1536,)
    return A, b


In [39]:
# ---------- Losses: fixed-τ InfoNCE + Triplet (cosine) ----------
def info_nce_fixed(pred: torch.Tensor, tgt: torch.Tensor, tau: float):
    # both should be L2-normalized already
    logits = (pred @ tgt.t()) / float(tau)                 # (B,B)
    labels = torch.arange(pred.size(0), device=pred.device)
    return F.cross_entropy(logits, labels)

class TripletCosineLoss(nn.Module):
    def __init__(self, margin: float = 0.2):
        super().__init__()
        self.loss = nn.TripletMarginWithDistanceLoss(
            distance_function=lambda x, y: 1 - F.cosine_similarity(x, y),
            margin=margin
        )
    def forward(self, anchor, positive, negative):
        return self.loss(anchor, positive, negative)


In [43]:
def main(args):
    """
    Main (advanced, two-stage): 
      - Stage 1: Centroid-Procrustes + Frozen Base + Residual, trained with fixed-τ InfoNCE + Triplet (stable warm-up).
      - Stage 2: Continue from Stage 1 and switch to BigJump trainer with a *small, late* queue + light mining + tiny agreement. 
                 Geometry stays frozen. No hard residual projection.
      - Then: evaluate, log efficiency, generate submission.

    Assumes these already exist in your file:
      load_train, build_image_id_split, procrustes_closed_form_centroids
      ResidualAdapter, train_simple, train_bigjump
      validate_retrieval, count_params_mb, time_ms_per_query
      apply_pooling, load_data, generate_submission, TEST_NPZ
    """
    import json
    from pathlib import Path
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # -----------------------
    # Setup
    # -----------------------
    out_dir, seed = args.out_dir, args.seed
    pooling, n_patches = args.pooling, None
    OUT = Path(f"/kaggle/working/outputs/{out_dir}")
    OUT.mkdir(parents=True, exist_ok=True)
    STAGE1 = OUT / "stage1"
    STAGE2 = OUT / "stage2"
    STAGE1.mkdir(parents=True, exist_ok=True)
    STAGE2.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # -----------------------
    # Data
    # -----------------------
    X, Y, cap_ids, img_ids_row, full_img, all_img_ids = load_train(OUT)

    # Split by image id (NO LEAKAGE)
    cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices = build_image_id_split(
        img_ids_row, all_img_ids, full_img, X, Y, args.val_ratio, seed, OUT
    )
    val_set = set(all_img_ids[val_img_indices])
    train_set = set(all_img_ids) - val_set
    assert val_set.isdisjoint(train_set), "Leakage detected between TRAIN and VAL image sets!"

    # Masks to arrays
    Xtr, Ytr = X[cap_is_tr], Y[cap_is_tr]
    Xva      = X[cap_is_val]

    din, dout = X.shape[1], Y.shape[1]
    assert din == 1024 and dout == 1536, f"Dimension mismatch: text={din}, image={dout} (expected 1024→1536)"

    # -----------------------
    # Geometry init (CENTROIDS on TRAIN ONLY)
    # -----------------------
    A, b = procrustes_closed_form_centroids(
        Xtr.astype(np.float32),
        Ytr.astype(np.float32),
        img_ids_row[cap_is_tr],
        eps=float(args.geom_eps)
    )

    # -----------------------
    # Model: Frozen base (A,b) + residual adapter
    # -----------------------
    class GeomFromWeights(nn.Module):
        def __init__(self, A_np, b_np):
            super().__init__()
            self.fc = nn.Linear(A_np.shape[1], A_np.shape[0], bias=True)
            with torch.no_grad():
                self.fc.weight.copy_(torch.from_numpy(A_np))
                self.fc.bias.copy_(torch.from_numpy(b_np))
        def forward(self, x): return self.fc(x)

    geom = GeomFromWeights(A, b).to(device)
    for p in geom.parameters(): p.requires_grad = False  # keep base frozen in both stages

    # residual capacity like BigJump adapter
    adapter = ResidualAdapter(din=din, hidden=1024, dout=dout, pdrop=0.1)

    class GeomPlusAdapter(nn.Module):
        def __init__(self, g, a): super().__init__(); self.geom=g; self.adapter=a
        def forward(self, x): return self.geom(x) + self.adapter(x)

    model = GeomPlusAdapter(geom, adapter).to(device)

    # Probe
    with torch.no_grad():
        _probe = model(torch.zeros(2, din, device=device))
    assert _probe.shape[-1] == 1536, f"Translator output dim { _probe.shape[-1] } != 1536"

    # -----------------------
    # Stage 1: simple warm-up (fixed τ InfoNCE + Triplet, in-batch negatives)
    # -----------------------
    if not args.eval_only:
        print(f"[stage1] τ_fixed={args.tau:.3f} | triplet_w={args.triplet_w} | margin={args.triplet_margin} | epochs={args.stage1_epochs}")
        best_stats_s1 = train_simple(
            model,
            Xtr, Ytr, img_ids_row[cap_is_tr],
            Xva, val_gallery, cap2gal_local,
            batch=args.batch, epochs=int(args.stage1_epochs), lr=args.lr, wd=args.wd,
            tau_fixed=float(args.tau), triplet_margin=float(args.triplet_margin), triplet_w=float(args.triplet_w),
            alpha_cos=float(args.alpha), moment_w=float(args.moment),
            use_pk=bool(args.use_pk), P=int(args.P), K_pk=int(args.K_pk),
            device=device, pooling=pooling, n_patches=n_patches,
            out_dir=str(STAGE1), seed=seed,
        )
        # Load best from stage1 (ensure we continue from the best)
        try:
            ckpt1 = torch.load(STAGE1 / "best.pt", map_location=device)
            model.load_state_dict(ckpt1["model"])
            print(f"[stage1] resume best epoch={ckpt1.get('epoch','?')} MRR={ckpt1.get('val',{}).get('MRR','?')}")
        except FileNotFoundError:
            pass
    else:
        best_stats_s1 = None

    # -----------------------
    # Stage 2: BigJump+ (small queue, late; light mining; tiny agreement; geometry still frozen)
    # -----------------------
    if not args.eval_only and args.stage2_epochs > 0:
        print(f"[stage2] queue_warmup={args.queue_warmup_epochs} | recent_k={args.queue_recent_k} | mine_H={args.mine_H} | "
              f"τ_sched={args.tau_start:.3f}→{args.tau_end:.3f} | agree={args.lambda_agree}")
        # Keep geometry frozen in BigJump
        best_stats_s2 = train_bigjump(
            model,
            Xtr, Ytr, img_ids_row[cap_is_tr],
            Xva, val_gallery, cap2gal_local,
            batch=args.batch, epochs=int(args.stage2_epochs), base_lr=args.lr, wd=args.wd,
            # legacy knobs (we’ll let InfoNCE dominate; keep cosine light, moment tiny/off)
            tau=None,                                 # use τ curriculum in stage2
            alpha_cos=float(args.alpha_stage2),       # e.g., 0.3
            lambda_moment=float(args.moment_stage2),  # e.g., 0.01
            queue_size=int(args.queue),
            # schedules
            tau_start=float(args.tau_start), tau_end=float(args.tau_end),
            queue_warmup_epochs=int(args.queue_warmup_epochs),
            queue_recent_schedule=tuple(args.queue_recent_k),
            lambda_agree=float(args.lambda_agree),
            # keep geometry frozen in stage2 as well
            geom_unfreeze_epoch=0,
            geom_lr_scale=float(args.geom_lr_scale),
            # misc
            device=device, pooling=pooling, n_patches=n_patches,
            out_dir=str(STAGE2), seed=seed,
            # NEW knobs
            use_pk=bool(args.use_pk), P=int(args.P), K_pk=int(args.K_pk),
            xbp_per_img=int(args.xbp_per_img), xbp_global=int(args.xbp_global),
            use_dcl=bool(args.use_dcl), dcl_prior=float(args.dcl_prior),
            mine_H=int(args.mine_H),
            use_pred_queue=False, queue_pred_capacity=int(args.queue_pred_capacity),
        )
        # Load best from stage2 for final export
        try:
            ckpt2 = torch.load(STAGE2 / "best.pt", map_location=device)
            model.load_state_dict(ckpt2["model"])
            best_stats_final = ckpt2.get("val", None)
            print(f"[stage2] resume best epoch={ckpt2.get('epoch','?')} MRR={ckpt2.get('val',{}).get('MRR','?')}")
        except FileNotFoundError:
            best_stats_final = best_stats_s1
    else:
        # eval-only: try to load stage2 first, else stage1
        best_stats_final = None
        if args.eval_only:
            for cand in [STAGE2 / "best.pt", STAGE1 / "best.pt", OUT / "best.pt"]:
                try:
                    ckpt = torch.load(cand, map_location=device)
                    model.load_state_dict(ckpt["model"])
                    best_stats_final = ckpt.get("val", None)
                    print(f"[resume] loaded {cand} | epoch={ckpt.get('epoch','?')} MRR={ckpt.get('val',{}).get('MRR','?')}")
                    break
                except FileNotFoundError:
                    continue

    # -----------------------
    # Efficiency logging
    # -----------------------
    params, mb = count_params_mb(model)
    ms_gpu, ms_cpu = time_ms_per_query(model, din, pooling, n_patches)
    eff = {"params": params, "mb_fp32": mb, "ms_per_query_gpu": ms_gpu, "ms_per_query_cpu": ms_cpu}
    (OUT / "efficiency.json").write_text(json.dumps(eff, indent=2))

    # -----------------------
    # Submission
    # -----------------------
    test_data = load_data(TEST_NPZ)
    Q   = test_data["captions/embeddings"].astype(np.float32)
    ids = test_data.get("captions/ids", np.arange(len(Q)).astype(str))
    model.eval()
    BS, outs = 1024, []
    with torch.no_grad():
        for i in range(0, len(Q), BS):
            q = torch.from_numpy(Q[i:i+BS]).to(device)
            q = apply_pooling(q, pooling, n_patches)
            z = model(q)
            z = F.normalize(z, dim=-1)
            outs.append(z.detach().cpu().numpy())
    pred_embds = np.concatenate(outs, axis=0)
    sub = OUT / "submission.csv"
    generate_submission(ids, pred_embds, str(sub))
    print(f"[ok] submission written → {sub}")

    # -----------------------
    # Final sanity print
    # -----------------------
    if best_stats_final is None:
        best_stats_final = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)

    sanity = {
        "dims": {"text": int(din), "image": int(dout)},
        "split": {
            "train_captions": int(cap_is_tr.sum()),
            "val_captions": int(cap_is_val.sum()),
            "val_unique_images": int(val_gallery.shape[0]),
            "leakage": False
        },
        "val_metrics": best_stats_final,
        "efficiency": eff,
        "recipe": {
            "stage1": {
                "type": "simple_fixed_tau_infoNCE+triplet",
                "tau_fixed": float(args.tau),
                "triplet_margin": float(args.triplet_margin),
                "triplet_w": float(args.triplet_w),
                "alpha_cos": float(args.alpha),
                "moment_w": float(args.moment),
                "epochs": int(args.stage1_epochs),
                "use_pk": bool(args.use_pk), "P": int(args.P), "K": int(args.K_pk)
            },
            "stage2": {
                "type": "bigjump_plus_small_queue_light_mining",
                "epochs": int(args.stage2_epochs),
                "tau_sched": [float(args.tau_start), float(args.tau_end)],
                "queue_warmup_epochs": int(args.queue_warmup_epochs),
                "queue_recent_schedule": list(map(int, args.queue_recent_k)),
                "mine_H": int(args.mine_H),
                "lambda_agree": float(args.lambda_agree),
                "geom_unfreeze_epoch": 0
            },
            "centroid_procrustes": True,
            "geom_frozen": True,
            "residual_projection": "none"
        }
    }
    print(json.dumps(sanity, indent=2))


In [45]:
# =========================
#        ARGPARSE
# =========================
if __name__ == "__main__":
    import argparse
    p = argparse.ArgumentParser(description="RoBERTa→DINOv2 | 2-Stage (Centroid-Procrustes Warmup → BigJump+)")

    # IO & run mode
    p.add_argument("--out_dir", type=str, default="two_stage_push")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--eval_only", action="store_true")

    # data & split
    p.add_argument("--val_ratio", type=float, default=0.10)

    # architecture (we still build geom+adapter internally)
    p.add_argument("--arch", type=str, default="bigjump",
                   choices=["linear","mlp1","mlp2","geom","geom_adapter","geom+adapter","bigjump"])

    # optimization (shared)
    p.add_argument("--epochs", type=int, default=0, help="(unused; stages have their own epochs)")
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--wd", type=float, default=1e-4)

    # geometry init
    p.add_argument("--geom_eps", type=float, default=1e-5)

    # pooling (no-op)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS"])

    # ---------- STAGE 1 (Simple: fixed-τ InfoNCE + Triplet, in-batch negatives) ----------
    p.add_argument("--stage1_epochs", type=int, default=12)
    p.add_argument("--tau", type=float, default=0.07, help="Fixed InfoNCE temperature for Stage 1.")
    p.add_argument("--triplet_margin", type=float, default=0.2)
    p.add_argument("--triplet_w", type=float, default=0.3, help="Triplet weight; InfoNCE weight = 1 - triplet_w.")
    p.add_argument("--alpha", type=float, default=0.0, help="Cosine auxiliary weight in Stage 1.")
    p.add_argument("--moment", type=float, default=0.0, help="Moment alignment weight in Stage 1.")

    # Optional PK sampler (both stages)
    p.add_argument("--use_pk", action="store_true")
    p.add_argument("--P", type=int, default=192)
    p.add_argument("--K_pk", type=int, default=3)

    # ---------- STAGE 2 (BigJump+: small late queue, light mining, tiny agreement) ----------
    p.add_argument("--stage2_epochs", type=int, default=13)
    p.add_argument("--tau_start", type=float, default=0.10)
    p.add_argument("--tau_end",   type=float, default=0.06)

    p.add_argument("--queue", type=int, default=65536)
    p.add_argument("--queue_warmup_epochs", type=int, default=3)
    p.add_argument("--queue_recent_k", type=int, nargs=3, default=[8000, 16000, 65536])

    p.add_argument("--mine_H", type=int, default=32, help="Top-H hard negatives per anchor (recent queue slice).")
    p.add_argument("--lambda_agree", type=float, default=0.02, help="Caption-agreement weight (tiny).")
    p.add_argument("--alpha_stage2", type=float, default=0.3, help="Cosine auxiliary in Stage 2.")
    p.add_argument("--moment_stage2", type=float, default=0.01, help="Moment alignment in Stage 2.")
    p.add_argument("--geom_lr_scale", type=float, default=0.05)

    # XBP / DCL / dual-queue toggles (kept for completeness; default off/safe)
    p.add_argument("--xbp_per_img", type=int, default=4)
    p.add_argument("--xbp_global", type=int, default=32000)
    p.add_argument("--use_dcl", action="store_true")
    p.add_argument("--dcl_prior", type=float, default=0.01)
    p.add_argument("--queue_pred_capacity", type=int, default=65536)

    args, _ = p.parse_known_args()
    main(args)


[meta] text_dim=1024 | image_dim=1536
[stage1] τ_fixed=0.070 | triplet_w=0.3 | margin=0.2 | epochs=12
[simple 01] loss=2.166024 | val_MRR=0.3573 | R@1=0.232 R@5=0.496 R@10=0.620 | median=6 p75=23
[simple 02] loss=2.019182 | val_MRR=0.3671 | R@1=0.241 R@5=0.508 R@10=0.635 | median=5 p75=21
[simple 03] loss=1.967485 | val_MRR=0.3740 | R@1=0.246 R@5=0.514 R@10=0.643 | median=5 p75=19
[simple 04] loss=1.931230 | val_MRR=0.3808 | R@1=0.251 R@5=0.525 R@10=0.651 | median=5 p75=19
[simple 05] loss=1.901381 | val_MRR=0.3852 | R@1=0.256 R@5=0.531 R@10=0.658 | median=5 p75=18
[simple 06] loss=1.875517 | val_MRR=0.3887 | R@1=0.259 R@5=0.534 R@10=0.663 | median=5 p75=17
[simple 07] loss=1.854193 | val_MRR=0.3924 | R@1=0.261 R@5=0.541 R@10=0.669 | median=5 p75=17
[simple 08] loss=1.834112 | val_MRR=0.3955 | R@1=0.266 R@5=0.542 R@10=0.669 | median=4 p75=17
[simple 09] loss=1.816755 | val_MRR=0.3977 | R@1=0.267 R@5=0.546 R@10=0.674 | median=5 p75=16
[simple 10] loss=1.799566 | val_MRR=0.3992 | R@1=0.2