In [3]:
from kaggle_secrets import UserSecretsClient
import json, os, pathlib, subprocess, sys

# --- 1. Load your secret (full JSON from kaggle.json) ---
secret_name = "kaggle_json"  # change if you used another name
user_secrets = UserSecretsClient()
raw = user_secrets.get_secret(secret_name)
creds = json.loads(raw)

# --- 2. Forcefully recreate ~/.kaggle/kaggle.json ---
kaggle_dir = pathlib.Path.home() / ".kaggle"
kaggle_dir.mkdir(parents=True, exist_ok=True)
cred_path = kaggle_dir / "kaggle.json"
cred_path.write_text(json.dumps(creds))
os.chmod(cred_path, 0o600)

# --- 3. Double-check the file actually exists and is readable ---
print("✅ Credentials written to:", cred_path)
!ls -la ~/.kaggle/
#!cat ~/.kaggle/kaggle.json | head -1

# --- 4. Reinstall Kaggle CLI cleanly ---
#!pip install --upgrade --force-reinstall kaggle --quiet



✅ Credentials written to: /root/.kaggle/kaggle.json
total 16
drwxr-xr-x 2 root root 4096 Oct 27 07:38 .
drwx------ 1 root root 4096 Oct 27 07:38 ..
-rw------- 1 root root   72 Oct 27 07:38 kaggle.json


In [4]:
!kaggle competitions files -c aml-competition
!kaggle competitions download -c aml-competition -p /kaggle/working --force
!mkdir -p /kaggle/working/data
!unzip -o /kaggle/working/aml-competition.zip -d /kaggle/working/data
!ls -lah /kaggle/working/data

Next Page Token = CfDJ8IaGWDgvvrBFtGGva9hUIY4Sz5Tx-ApsKEfRQmI9nBAWjNraQqgt0RP_CeydXVePWILeLEQEHZkdxSUnLYbYLOs
name                                     size  creationDate                
---------------------------------  ----------  --------------------------  
test/test/captions.txt                  90426  2025-10-25 16:32:52.931000  
test/test/test.clean.npz              5765331  2025-10-25 16:32:52.931000  
train/train/Images/1000092795.jpg      218143  2025-10-25 16:32:52.931000  
train/train/Images/10002456.jpg        113525  2025-10-25 16:32:52.931000  
train/train/Images/1000268201.jpg      199606  2025-10-25 16:32:52.931000  
train/train/Images/1000344755.jpg      154005  2025-10-25 16:32:52.931000  
train/train/Images/1000366164.jpg      103316  2025-10-25 16:32:52.931000  
train/train/Images/1000919630.jpg      117183  2025-10-25 16:32:52.931000  
train/train/Images/10010052.jpg         44514  2025-10-25 16:32:52.931000  
train/train/Images/1001465944.jpg      141082  2025-10

In [5]:
!git clone https://github.com/Mamiglia/challenge.git

Cloning into 'challenge'...
remote: Enumerating objects: 98, done.[K
remote: Counting objects: 100% (98/98), done.[K
remote: Compressing objects: 100% (69/69), done.[K
remote: Total 98 (delta 39), reused 72 (delta 26), pack-reused 0 (from 0)[K
Receiving objects: 100% (98/98), 21.03 MiB | 22.17 MiB/s, done.
Resolving deltas: 100% (39/39), done.


What we learned:

- text_dim=1024 | image_dim=1536 we using the fixed encoders (roberta-large-nli-stsb-mean-tokens and dinov2-giant).
-  leakage=False, val split by image id → new validation pipeline is officially aligned with the challenge spec.
- Cosine retrieval (F.normalize, dot product) matches the public LB similarity.

The LB expects already-normalized embeddings.

- Raw outputs distort similarity magnitudes (pred norms ≈ 21 vs image norms ≈ 26 → biased dot products).
-  Always normalize before writing the submission.

mean||pred|| = 21.213 and mean||image|| = 25.939
he LB similarity is cosine on normalized vectors, i.e. sim = (z_pred / ||z_pred||) @ (z_img / ||z_img||).T

In [None]:
# runner.py
import argparse, json, random, re, time
from pathlib import Path
from os.path import basename

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from challenge.src.common import load_data, generate_submission

# ---------- Paths ----------
DATA_ROOT = Path("/kaggle/working/data")
TRAIN_DIR = DATA_ROOT / "train" / "train"
TEST_DIR  = DATA_ROOT / "test"  / "test"
TRAIN_NPZ = TRAIN_DIR / "train.npz"
TEST_NPZ  = TEST_DIR  / "test.clean.npz"
TRAIN_CAPTIONS = TRAIN_DIR / "captions.txt"
TEST_CAPTIONS  = TEST_DIR  / "captions.txt"

# ---------- Determinism ----------
def seed_all(s: int):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ---------- Metadata/safety ----------
def _assert_metadata_dims(npz_path: Path, expect_text=1024, expect_img=1536):
    d = np.load(npz_path, allow_pickle=True)
    md = {k: d[k][0] for k in d.files if k.startswith("metadata/")}
    tdim = md.get("metadata/embedding_dim_text", None)
    idim = md.get("metadata/embedding_dim_image", None)
    print(f"[meta] text_dim={tdim} | image_dim={idim}")
    assert tdim == expect_text and idim == expect_img, (
        f"Encoder dims mismatch: expected text={expect_text}, image={expect_img} "
        f"but got text={tdim}, image={idim}. Regenerate NPZ with the fixed encoders."
    )

# ---------- Caption→image matching from captions.txt ----------
def _build_image_index(image_ids):
    idx_exact = {}; idx_base={}; idx_stem={}
    for i, v in enumerate(image_ids):
        s = str(v); idx_exact[s]=i
        b = basename(s); idx_base[b]=i
        stem = b.rsplit(".",1)[0] if "." in b else b
        idx_stem[stem]=i
    return idx_exact, idx_base, idx_stem

def _iter_targets_from_captions(path: Path, n_text: int, idx_exact, idx_base, idx_stem):
    targets=[]
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            if len(targets)>=n_text: break
            line=raw.strip()
            if not line: continue
            # be robust to delimiters
            parts=re.split(r'\||\t|,| {2,}', line)
            if len(parts)==1: parts=line.split(" ",1)
            tok=parts[0].strip().strip('"').strip("'")
            if not tok: continue
            def _match(tok_):
                if tok_ in idx_exact: return idx_exact[tok_]
                b=basename(tok_)
                if b in idx_base: return idx_base[b]
                stem=b.rsplit(".",1)[0] if "." in b else b
                if stem in idx_stem: return idx_stem[stem]
                return None
            idx=_match(tok)
            if idx is None:
                tok2=tok.replace("Images/","").replace("./","")
                idx=_match(tok2)
            if idx is None:
                if "." not in tok: continue
                raise AssertionError(f"Could not match image token '{tok}'")
            targets.append(idx)
    assert len(targets)==n_text, f"matched {len(targets)} vs N_text {n_text}"
    return np.asarray(targets, dtype=np.int64)

# ---------- Data loaders ----------
def load_train(out_dir: Path):
    _assert_metadata_dims(TRAIN_NPZ, 1024, 1536)
    d=np.load(TRAIN_NPZ, allow_pickle=True)
    X=d["captions/embeddings"].astype(np.float32)  # (N_text,1024)
    I=d["images/embeddings"].astype(np.float32)    # (N_img,1536)
    cap_ids=d.get("captions/ids", np.arange(len(X)).astype(str))
    img_names=d.get("images/names", np.arange(len(I)).astype(str))

    ex,ba,st=_build_image_index(img_names)
    assert TRAIN_CAPTIONS.exists(), f"Missing {TRAIN_CAPTIONS}"
    targets=_iter_targets_from_captions(TRAIN_CAPTIONS, len(X), ex,ba,st)

    Y=I[targets]                       # (N_text, 1536) GT image vec per caption
    img_ids_row = img_names[targets]   # image name per caption row (for splitting)
    meta={"n_text":int(len(X)),"n_images":int(len(I))}
    (out_dir/"train_detect.json").write_text(json.dumps(meta, indent=2))
    return X,Y,cap_ids,img_ids_row,I,img_names

def load_test_npz():
    d_test=np.load(TEST_NPZ, allow_pickle=True)
    Q=d_test["captions/embeddings"].astype(np.float32)
    q_ids=d_test.get("captions/ids", np.arange(len(Q)).astype(str))
    if "images/embeddings" in d_test.files:
        G=d_test["images/embeddings"].astype(np.float32)
        g_ids=d_test.get("images/names", np.arange(len(G)).astype(str))
    else:
        d_tr=np.load(TRAIN_NPZ, allow_pickle=True)
        G=d_tr["images/embeddings"].astype(np.float32)
        g_ids=d_tr.get("images/names", np.arange(len(G)).astype(str))
    return Q,G,q_ids,g_ids

# ---------- Dataset ----------
class PairDS(Dataset):
    def __init__(self, X, Y): self.X=torch.from_numpy(X); self.Y=torch.from_numpy(Y)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i): return self.X[i], self.Y[i]

# ---------- Pooling (no-op per spec) ----------
def apply_pooling(x: torch.Tensor, mode: str, n_patches=None):
    return x

# ---------- Base Models ----------
class LinearProj(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.fc=nn.Linear(din,dout)
        nn.init.xavier_normal_(self.fc.weight); nn.init.zeros_(self.fc.bias)
    def forward(self,x): return self.fc(x)

class MLP1(nn.Module):
    def __init__(self, din, dout, hidden=512, pdrop=0.1):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(din,hidden), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(hidden,dout)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.zeros_(m.bias)
    def forward(self,x): return self.net(x)

class MLP2(nn.Module):
    # Spec: 1024→1024→512→1536, dropout=0.1 on hidden layers
    def __init__(self, din, dout, h1=1024, h2=512, pdrop=0.1):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(din,h1), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(h1,h2), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(h2,dout)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.zeros_(m.bias)
    def forward(self,x): return self.net(x)

# ---------- Geometry-Preserving Linear (Whiten → Procrustes → Re-color) ----------
def _cov_eigh(zc, eps):
    # zc: (N,D), zero-mean
    N = zc.shape[0]
    C = (zc.T @ zc) / max(1, N)
    # symmetric PSD
    S, U = np.linalg.eigh(C)
    S = np.clip(S, 0.0, None)
    inv_sqrt = 1.0 / np.sqrt(S + eps)
    sqrt = np.sqrt(S + eps)
    C_mhalf = (U * inv_sqrt) @ U.T     # C^{-1/2}
    C_phalf = (U * sqrt) @ U.T         # C^{ 1/2}
    return C_mhalf.astype(np.float32), C_phalf.astype(np.float32)

def procrustes_closed_form(X, Y, eps=1e-5):
    """
    X: (N, 1024) text, Y: (N, 1536) targets (per-caption image vectors)
    Returns A (1536x1024), b (1536,) such that y_hat = A x + b
    """
    # center
    mu_x = X.mean(0, dtype=np.float64)
    mu_y = Y.mean(0, dtype=np.float64)
    Xc = X - mu_x
    Yc = Y - mu_y

    # whiten both
    Cx_mh, _ = _cov_eigh(Xc, eps)      # 1024x1024
    Cy_mh, Cy_ph = _cov_eigh(Yc, eps)  # 1536x1536 (only Cy_ph used)
    Xw = Xc @ Cx_mh.T                  # (N,1024)
    Yw = Yc @ Cy_mh.T                  # (N,1536)

    # orthogonal Procrustes: maximize Tr(R^T Xw^T Yw)
    M = Xw.T @ Yw                      # (1024,1536)
    # SVD on M; for rectangular, do SVD and build R = U V^T in common subspace
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    R = U @ Vt                         # (1024,1536) @ (1536,1536) -> (1024,1536) OK since full_matrices=False
    # we need R as (1536x1024) mapping whitened X to whitened Y; above is (1024x1536)
    R = R.T                            # (1536,1024)

    # re-color to Y space
    A = (Cy_ph @ R @ Cx_mh).astype(np.float32)     # (1536,1536)*(1536,1024)*(1024,1024) = (1536,1024)
    b = (mu_y - (A @ mu_x)).astype(np.float32)     # (1536,)
    return A, b

class GeomLinear(nn.Module):
    """
    Linear layer with weights initialized from closed-form geometry mapping.
    Optionally fine-tunes with tiny LR.
    """
    def __init__(self, A: np.ndarray, b: np.ndarray):
        super().__init__()
        D_in = A.shape[1]; D_out = A.shape[0]
        self.fc = nn.Linear(D_in, D_out, bias=True)
        with torch.no_grad():
            self.fc.weight.copy_(torch.from_numpy(A))
            self.fc.bias.copy_(torch.from_numpy(b))
    def forward(self, x):
        return self.fc(x)

# ---------- Loss pieces (for optional fine-tune) ----------
def moment_align(pred, tgt):
    mu_p, mu_t = pred.mean(0), tgt.mean(0)
    sd_p, sd_t = pred.std(0, unbiased=False), tgt.std(0, unbiased=False)
    return F.mse_loss(mu_p, mu_t) + F.mse_loss(sd_p, sd_t)

def info_nce(pred, tgt):
    p = F.normalize(pred, dim=-1)
    t = F.normalize(tgt, dim=-1)
    logits = p @ t.t()
    labels = torch.arange(pred.size(0), device=pred.device)
    return F.cross_entropy(logits, labels)

# ---------- Split by image id & mapping ----------
def build_image_id_split(img_ids_row, all_img_names, full_img, X, Y, val_ratio, seed, out_dir: Path):
    uniq_img_names = np.array(sorted(set(map(str, all_img_names))))
    rng = np.random.default_rng(seed); rng.shuffle(uniq_img_names)
    n_val = max(1, int(len(uniq_img_names) * val_ratio))
    val_images = set(uniq_img_names[:n_val])
    tr_images  = set(uniq_img_names[n_val:])
    assert len(val_images & tr_images) == 0, "Leakage in image id split!"

    # Caption masks
    cap_is_val = np.array([str(iid) in val_images for iid in img_ids_row], dtype=bool)
    cap_is_tr  = ~cap_is_val

    # Build VAL gallery (unique images in VAL) as sorted names → local indices
    val_img_names_sorted = np.array(sorted(val_images))
    name2local = {name:i for i,name in enumerate(val_img_names_sorted)}
    name2global = {str(n):i for i,n in enumerate(all_img_names)}
    val_img_indices = np.array([name2global[n] for n in val_img_names_sorted], dtype=np.int64)
    val_gallery = full_img[val_img_indices]  # (M,1536)

    # For each VAL caption row, get local gallery index (no remap later)
    cap2gal_local = np.array([name2local[str(n)] for n in img_ids_row[cap_is_val]], dtype=np.int64)

    # Persist mapping for debugging/acceptance
    (out_dir/"val_indices.json").write_text(json.dumps({
        "val_img_indices": val_img_indices.tolist(),
        "val_caption_to_gallery_index": cap2gal_local.tolist(),
        "n_val_captions": int(cap_is_val.sum()),
        "n_val_unique_images": int(len(val_images))
    }, indent=2))

    return cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices

# ---------- Metrics / evaluator (full gallery, cosine on L2) ----------
@torch.no_grad()
def validate_retrieval(model, Xv, val_gallery, cap2gal_local, pooling, n_patches, bs=1024):
    device=next(model.parameters()).device
    Gi = torch.from_numpy(val_gallery).to(device)
    Gi = F.normalize(Gi, dim=-1)

    ranks=[]
    for i in range(0, len(Xv), bs):
        xb=torch.from_numpy(Xv[i:i+bs]).to(device)
        xb=apply_pooling(xb, pooling, n_patches)   # no-op
        pred=model(xb)
        pred=F.normalize(pred, dim=-1)
        sims=pred @ Gi.t()                         # (b, M)
        for j in range(sims.size(0)):
            true_idx = int(cap2gal_local[i+j])     # local gallery index
            order = torch.argsort(sims[j], descending=True)
            rank = (order==true_idx).nonzero(as_tuple=False).item() + 1
            ranks.append(rank)

    ranks = np.array(ranks)
    mrr = float(np.mean(1.0 / ranks))
    r1  = float(np.mean(ranks<=1))
    r5  = float(np.mean(ranks<=5))
    r10 = float(np.mean(ranks<=10))
    return {
        "MRR": mrr,
        "R1": r1,
        "R5": r5,
        "R10": r10,
        "rank_median": int(np.median(ranks)),
        "rank_p75": int(np.percentile(ranks, 75))
    }

def count_params_mb(model):
    params=sum(p.numel() for p in model.parameters())
    mb = params * 4 / (1024**2)
    return params, mb

def time_ms_per_query(model, din, pooling, n_patches):
    device=next(model.parameters()).device
    x=torch.randn(2048, din, device=device)
    x=apply_pooling(x, pooling, n_patches)
    if device.type=="cuda":
        torch.cuda.synchronize()
    t0=time.time()
    with torch.no_grad(): _=model(x)
    if device.type=="cuda":
        torch.cuda.synchronize()
    ms_gpu=(time.time()-t0)*1000/len(x)
    # CPU
    mcpu=model.to("cpu"); xcpu=x.to("cpu")
    t1=time.time()
    with torch.no_grad(): _=mcpu(xcpu)
    ms_cpu=(time.time()-t1)*1000/len(xcpu)
    model.to(device)
    return ms_gpu, ms_cpu

# ---------- Train (for optional fine-tune) ----------
def train_one(model, loader, opt, alpha, beta, gamma, moment_w, pooling, n_patches, device):
    model.train(); total=0.0
    for xb,yb in loader:
        xb,yb=xb.to(device), yb.to(device)
        xb=apply_pooling(xb, pooling, n_patches)
        opt.zero_grad(set_to_none=True)
        pred=model(xb)
        # α·(1−cos) + β·MSE + λ·moment_align + γ·InfoNCE
        cos = 1 - F.cosine_similarity(pred, yb, dim=-1).mean()
        mse = F.mse_loss(pred, yb)
        a_loss = moment_w * moment_align(pred, yb) if moment_w>0 else pred.new_tensor(0.0)
        ce = info_nce(pred, yb) if gamma>0 else pred.new_tensor(0.0)
        loss = alpha*cos + beta*mse + gamma*ce + a_loss
        loss.backward(); opt.step()
        total += loss.item()*xb.size(0)
    return total/len(loader.dataset)

# ---------- Factory ----------
def make_model(arch:str, din:int, dout:int, geom_init_data=None):
    a=arch.lower()
    if a=="linear": return LinearProj(din,dout)
    if a=="mlp1":   return MLP1(din,dout,hidden=512, pdrop=0.1)
    if a=="mlp2":   return MLP2(din,dout,h1=1024,h2=512,pdrop=0.1)
    if a=="geom":
        assert geom_init_data is not None, "geom requires (Xtr, Ytr, eps)"
        Xtr_np, Ytr_np, eps = geom_init_data
        A,b = procrustes_closed_form(Xtr_np, Ytr_np, eps=eps)
        return GeomLinear(A, b)
    if a=="auto":   return MLP2(din,dout)
    raise ValueError(f"Unknown arch {arch}")



In [None]:
# ---------- Main ----------
def main(args):
    # Parse flags
    out_dir=args.out_dir; seed=args.seed; epochs=args.epochs; batch=args.batch
    lr=args.lr; wd=args.wd
    pooling=args.pooling; n_patches=None
    alpha=args.alpha; beta=args.beta; gamma=args.gamma; moment_w=args.moment
    arch=args.arch; do_train=not args.eval_only; val_ratio=args.val_ratio
    geom_eps=args.geom_eps; geom_ft_epochs=args.geom_finetune_epochs; geom_ft_lr=args.geom_finetune_lr; geom_ft_wd=args.geom_finetune_wd

    seed_all(seed)
    OUT = Path(f"/kaggle/working/outputs/{out_dir}"); OUT.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- load data
    X,Y,cap_ids,img_ids_row,full_img,all_img_ids = load_train(OUT)

    # --- split by image id & mappings
    cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices = build_image_id_split(
        img_ids_row, all_img_ids, full_img, X, Y, val_ratio, seed, OUT
    )
    # Safety: no leakage
    val_set = set(all_img_ids[val_img_indices])
    train_set = set(all_img_ids) - val_set
    assert val_set.isdisjoint(train_set), "Leakage detected between TRAIN and VAL image sets!"

    # Propagate masks to captions
    Xtr,Ytr = X[cap_is_tr], Y[cap_is_tr]
    Xva     = X[cap_is_val]

    din, dout = X.shape[1], Y.shape[1]
    # Dim safety
    assert din==1024 and dout==1536, f"Dimension mismatch: text={din}, image={dout} (expected 1024→1536)"

    # --- model (geom has closed-form init)
    geom_init = None
    if arch.lower()=="geom":
        geom_init = (Xtr.astype(np.float32), Ytr.astype(np.float32), float(geom_eps))
    model = make_model(arch, din, dout, geom_init).to(device)

    # Output-dim safety
    with torch.no_grad():
        _probe = model(torch.zeros(2,din,device=device))
    assert _probe.shape[-1]==1536, f"Translator output dim { _probe.shape[-1] } != 1536"

    params, mb = count_params_mb(model)
    print(f"[model] {arch} | params={params:,} (~{mb:.2f} MB) | pooling={pooling} | α={alpha} β={beta} λ_moment={moment_w} γ={gamma}")

    # --- training strategy
    best_stats=None; best=-1.0; best_ep=0

    if arch.lower()=="geom" and geom_ft_epochs>0 and do_train:
        # tiny fine-tune on linear head
        dl = DataLoader(PairDS(Xtr,Ytr), batch_size=batch, shuffle=True, num_workers=2, pin_memory=True, drop_last=False)
        opt = torch.optim.AdamW(model.parameters(), lr=geom_ft_lr, weight_decay=geom_ft_wd)
        for ep in range(1, geom_ft_epochs+1):
            tr = train_one(model, dl, opt, alpha,beta,gamma,moment_w, pooling, n_patches, device)
            stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
            print(f"[geom-ft {ep:02d}] train_loss={tr:.6f} | val_MRR={stats['MRR']:.4f} "
                  f"| R@1={stats['R1']:.3f} R@5={stats['R5']:.3f} R@10={stats['R10']:.3f} "
                  f"| median={stats['rank_median']} p75={stats['rank_p75']}")
            if stats["MRR"] > best:
                best, best_ep, best_stats = stats["MRR"], ep, stats
                torch.save({"model":model.state_dict(),"epoch":ep,"val":stats}, OUT/"best.pt")
        if best_stats is None:
            best_stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
            torch.save({"model":model.state_dict(),"epoch":0,"val":best_stats}, OUT/"best.pt")
        (OUT/"val_metrics.json").write_text(json.dumps(dict(best_epoch=best_ep, **best_stats), indent=2))

    elif do_train and arch.lower()!="geom":
        # normal training for linear/mlp1/mlp2
        dl = DataLoader(PairDS(Xtr,Ytr), batch_size=batch, shuffle=True, num_workers=2, pin_memory=True, drop_last=False)
        opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
        for ep in range(1, epochs+1):
            tr = train_one(model, dl, opt, alpha,beta,gamma,moment_w, pooling, n_patches, device)
            stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
            print(f"[{ep:02d}] train_loss={tr:.6f} | val_MRR={stats['MRR']:.4f} "
                  f"| R@1={stats['R1']:.3f} R@5={stats['R5']:.3f} R@10={stats['R10']:.3f} "
                  f"| median={stats['rank_median']} p75={stats['rank_p75']}")
            if stats["MRR"] > best:
                best, best_ep, best_stats = stats["MRR"], ep, stats
                torch.save({"model":model.state_dict(),"epoch":ep,"val":stats}, OUT/"best.pt")
        if best_stats is None:
            best_stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
            torch.save({"model":model.state_dict(),"epoch":0,"val":best_stats}, OUT/"best.pt")
        (OUT/"val_metrics.json").write_text(json.dumps(dict(best_epoch=best_ep, **best_stats), indent=2))
    else:
        # eval-only: load best.pt if present, else just evaluate the freshly-built model
        try:
            ckpt = torch.load(OUT/"best.pt", map_location="cpu")
            print(f"[resume] loaded epoch={ckpt.get('epoch','?')} MRR={ckpt.get('val',{}).get('MRR','?')}")
            model.load_state_dict(ckpt["model"])
            best_stats = ckpt.get("val", None)
        except FileNotFoundError:
            best_stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)

    # --- efficiency logging
    ms_gpu, ms_cpu = time_ms_per_query(model, din, pooling, n_patches)
    eff = {"params":params,"mb_fp32":mb,"ms_per_query_gpu":ms_gpu,"ms_per_query_cpu":ms_cpu}
    (OUT/"efficiency.json").write_text(json.dumps(eff, indent=2))

    # --- submission (normalized only; test captions → L2 outputs)
    test_data = load_data(TEST_NPZ)
    Q   = test_data["captions/embeddings"].astype(np.float32)
    ids = test_data.get("captions/ids", np.arange(len(Q)).astype(str))
    model.eval()
    BS = 1024
    outs = []
    with torch.no_grad():
        for i in range(0, len(Q), BS):
            q = torch.from_numpy(Q[i:i+BS]).to(device)
            q = apply_pooling(q, pooling, n_patches)
            z = model(q)
            z = F.normalize(z, dim=-1)
            outs.append(z.detach().cpu().numpy())
    pred_embds = np.concatenate(outs, axis=0)
    sub = OUT / "submission.csv"
    generate_submission(ids, pred_embds, str(sub))
    print(f"[ok] submission written → {sub}")

    # --- one-file sanity printout
    sanity = {
        "dims": {"text": int(din), "image": int(dout)},
        "split": {
            "train_captions": int(cap_is_tr.sum()),
            "val_captions": int(cap_is_val.sum()),
            "val_unique_images": int(val_gallery.shape[0]),
            "leakage": False
        },
        "val_metrics": best_stats if best_stats is not None else validate_retrieval(
            model, Xva, val_gallery, cap2gal_local, pooling, n_patches
        ),
        "efficiency": eff
    }
    print(json.dumps(sanity, indent=2))

In [72]:
# ---------- CLI ----------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="geom_mrr_first")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--val_ratio", type=float, default=0.1)
    p.add_argument("--pooling", type=str, default="CLS")  # no-op (kept for compatibility)
    p.add_argument("--alpha", type=float, default=1.0)
    p.add_argument("--beta", type=float, default=1.0)
    p.add_argument("--moment", type=float, default=0.05)  # λ_moment
    p.add_argument("--gamma", type=float, default=0.0)    # InfoNCE (optional, default off)
    p.add_argument("--arch", type=str, default="geom", choices=["geom","linear","mlp1","mlp2","auto"])
    # Geometry step knobs
    p.add_argument("--geom_eps", type=float, default=1e-5, help="Covariance epsilon for whitening")
    p.add_argument("--geom_finetune_epochs", type=int, default=0, help="Tiny linear fine-tune epochs (0 = off)")
    p.add_argument("--geom_finetune_lr", type=float, default=1e-4)
    p.add_argument("--geom_finetune_wd", type=float, default=1e-5)
    p.add_argument("--eval_only", action="store_true", help="Skip training; eval+submission using best.pt or closed-form")
    args, _ = p.parse_known_args()  # notebook-friendly
    main(args)


[meta] text_dim=1024 | image_dim=1536
[model] geom | params=1,574,400 (~6.01 MB) | pooling=CLS | α=1.0 β=1.0 λ_moment=0.05 γ=0.0
Generating submission file...
✓ Saved submission to /kaggle/working/outputs/geom_mrr_first/submission.csv
[ok] submission written → /kaggle/working/outputs/geom_mrr_first/submission.csv
{
  "dims": {
    "text": 1024,
    "image": 1536
  },
  "split": {
    "train_captions": 112500,
    "val_captions": 12500,
    "val_unique_images": 2500,
    "leakage": false
  },
  "val_metrics": {
    "MRR": 0.3191292292010856,
    "R1": 0.20912,
    "R5": 0.4364,
    "R10": 0.54536,
    "rank_median": 8,
    "rank_p75": 40
  },
  "efficiency": {
    "params": 1574400,
    "mb_fp32": 6.005859375,
    "ms_per_query_gpu": 0.00049639493227005,
    "ms_per_query_cpu": 0.02339109778404236
  }
}


fine-tune is hurting because it breaks the beautiful geometry ùù just solved in closed-form (whitening + orthogonal Procrustes + re-color). Mini-batch losses (MSE/moment) nudge A,b away from the orthogonal mapping in whitened space, so the output distribution drifts and MRR drops.

In [None]:
class GeomCalib(nn.Module):
    """
    y_hat = mu_y + Cy^{1/2} * R * ( D * ( Cx^{-1/2} * (x - mu_x) ) ) + b
    - D: diagonal (1024), optional (diag=identity if off)
    - b: bias in Y space (1536), optional
    """
    def __init__(self, mu_x, mu_y, Cx_mh, Cy_ph, R, use_diag=True, use_bias=True):
        super().__init__()
        self.register_buffer("mu_x", torch.from_numpy(mu_x))     # (1024,)
        self.register_buffer("mu_y", torch.from_numpy(mu_y))     # (1536,)
        self.register_buffer("Cx_mh", torch.from_numpy(Cx_mh))   # (1024,1024)
        self.register_buffer("Cy_ph", torch.from_numpy(Cy_ph))   # (1536,1536)
        self.register_buffer("R", torch.from_numpy(R))           # (1536,1024)
        self.use_diag = use_diag
        self.use_bias = use_bias
        if use_diag:
            self.logD = nn.Parameter(torch.zeros(1024))  # D = exp(logD) ~ 1 at init
        else:
            self.register_parameter("logD", None)
        if use_bias:
            self.bias_y = nn.Parameter(torch.zeros(1536))
        else:
            self.register_parameter("bias_y", None)

    def forward(self, x):
        # center + whiten X
        z = x - self.mu_x
        z = F.linear(z, self.Cx_mh)  # (.,1024)
        # optional diagonal scale in whitened space
        if self.use_diag:
            z = z * torch.exp(self.logD)
        # rotate and re-color to Y
        z = F.linear(z, self.R)      # (.,1536)
        z = F.linear(z, self.Cy_ph)  # (.,1536)
        z = z + self.mu_y
        # optional bias in Y
        if self.use_bias:
            z = z + self.bias_y
        return z

# ---------- Losses ----------
def listwise_rank_loss_soft(sims, tau):
    """
    sims: (B, M) cosine similarities where column 0 is the positive.
    Approximate rank r ≈ 1 + sum_j σ((s_j - s_pos)/τ), minimize log(r).
    """
    s_pos = sims[:, :1]          # (B,1)
    s_neg = sims[:, 1:]          # (B,M-1)
    rank = 1.0 + torch.sigmoid((s_neg - s_pos) / tau).sum(dim=1)  # (B,)
    return torch.log(rank).mean()

# ---------- Train loop (calibration only) ----------
def train_calibration(model, loader, opt, alpha, beta, moment_w,
                      rank_lambda, rank_m, rank_tau,
                      train_gallery_unique, device):
    """
    Train only the calibration params (D and/or b).
    train_gallery_unique: (G,1536) np array of unique train images for sampling negatives.
    """
    model.train(); total=0.0
    G = torch.from_numpy(train_gallery_unique).to(device)
    G = F.normalize(G, dim=-1)
    G_size = G.shape[0]

    for xb, yb in loader:
        xb = xb.to(device); yb = yb.to(device)
        opt.zero_grad(set_to_none=True)
        pred = model(xb)

        # Base pairwise losses
        cos = 1 - F.cosine_similarity(pred, yb, dim=-1).mean()
        mse = F.mse_loss(pred, yb)
        a_loss = moment_w * moment_align(pred, yb) if moment_w>0 else pred.new_tensor(0.0)

        # Listwise: sample shared negatives for the batch (fast & stable)
        loss_rank = pred.new_tensor(0.0)
        if rank_lambda > 0 and rank_m > 1:
            with torch.no_grad():
                neg_idx = torch.randint(low=0, high=G_size, size=(rank_m-1,), device=device)
                Gs = G[neg_idx]  # (M-1,1536)
            pred_n = F.normalize(pred, dim=-1)
            yb_n = F.normalize(yb, dim=-1)
            # sims: column 0 = positive, others = shared negatives
            s_pos = (pred_n * yb_n).sum(dim=1, keepdim=True)        # (B,1)
            s_neg = pred_n @ Gs.t()                                 # (B,M-1)
            sims = torch.cat([s_pos, s_neg], dim=1)                 # (B,M)
            loss_rank = listwise_rank_loss_soft(sims, rank_tau) * rank_lambda

        # Regularizers for calibration params
        reg = pred.new_tensor(0.0)
        if getattr(model, "logD", None) is not None and model.logD.requires_grad:
            # D ≈ I  -> penalize (logD)^2
            reg = reg + (model.logD**2).mean() * train_calibration.diag_l2
        if getattr(model, "bias_y", None) is not None and model.bias_y.requires_grad:
            reg = reg + (model.bias_y**2).mean() * train_calibration.bias_l2

        loss = alpha*cos + beta*mse + a_loss + loss_rank + reg
        loss.backward(); opt.step()
        total += loss.item()*xb.size(0)
    return total/len(loader.dataset)

In [74]:
# ---------- CLI ----------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    # Core
    p.add_argument("--out_dir", type=str, default="geom_calib_rank")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--val_ratio", type=float, default=0.10)
    p.add_argument("--batch", type=int, default=1024)
    p.add_argument("--pooling", type=str, default="CLS")
    # Geometry
    p.add_argument("--geom_eps", type=float, default=1e-5)
    # Calibration choices
    p.add_argument("--calib", type=str, default="diagbias", choices=["none","diag","bias","diagbias"])
    p.add_argument("--calib_epochs", type=int, default=3)
    p.add_argument("--calib_lr", type=float, default=8e-5)
    p.add_argument("--calib_wd", type=float, default=1e-5)
    p.add_argument("--diag_l2", type=float, default=1e-3)   # penalty on (D−I)
    p.add_argument("--bias_l2", type=float, default=1e-4)   # penalty on b
    # Loss weights
    p.add_argument("--alpha", type=float, default=1.0)      # (1 - cos)
    p.add_argument("--beta", type=float, default=0.2)       # MSE (small)
    p.add_argument("--moment", type=float, default=0.01)    # very small; geometry already matches moments
    # Listwise rank surrogate (sampled gallery)
    p.add_argument("--rank_lambda", type=float, default=0.08)
    p.add_argument("--rank_m", type=int, default=512)       # gallery size = 1 pos + M-1 negs
    p.add_argument("--rank_tau", type=float, default=0.03)
    args, _ = p.parse_known_args()
    main(args)

[meta] text_dim=1024 | image_dim=1536
[model] geom+diagbias | learnable=2,560 (~0.01 MB) | α=1.0 β=0.2 λ_moment=0.01 λ_rank=0.08
[calib 01] train_loss=0.447032 | val_MRR=0.3204 | R@1=0.210 R@5=0.437 R@10=0.546 | median=8 p75=39
[calib 02] train_loss=0.446627 | val_MRR=0.3216 | R@1=0.211 R@5=0.439 R@10=0.549 | median=8 p75=39
[calib 03] train_loss=0.445920 | val_MRR=0.3229 | R@1=0.213 R@5=0.440 R@10=0.549 | median=8 p75=39
Generating submission file...
✓ Saved submission to /kaggle/working/outputs/geom_calib_rank/submission.csv
[ok] submission written → /kaggle/working/outputs/geom_calib_rank/submission.csv
{
  "dims": {
    "text": 1024,
    "image": 1536
  },
  "split": {
    "train_captions": 112500,
    "val_captions": 12500,
    "val_unique_images": 2500,
    "leakage": false
  },
  "val_metrics": {
    "MRR": 0.3229257931186908,
    "R1": 0.21264,
    "R5": 0.4396,
    "R10": 0.54912,
    "rank_median": 8,
    "rank_p75": 39
  },
  "efficiency": {
    "params": 2560,
    "mb_fp32"

In [None]:
# ---------- KMeans (no external libs) ----------
def kmeans_pp_init(X, K, rng):
    N, D = X.shape
    centroids = np.empty((K, D), dtype=np.float32)
    # pick first randomly
    i0 = rng.integers(0, N)
    centroids[0] = X[i0]
    # distances
    d2 = np.sum((X - centroids[0])**2, axis=1)
    for k in range(1, K):
        probs = d2 / (d2.sum() + 1e-12)
        idx = rng.choice(N, p=probs)
        centroids[k] = X[idx]
        d2 = np.minimum(d2, np.sum((X - centroids[k])**2, axis=1))
    return centroids

def kmeans(X, K, iters=20, seed=42):
    rng = np.random.default_rng(seed)
    X = X.astype(np.float32)
    cent = kmeans_pp_init(X, K, rng)
    for _ in range(iters):
        # assign
        d2 = ((X[:, None, :] - cent[None, :, :])**2).sum(axis=2)  # (N,K)
        ids = d2.argmin(axis=1)
        # recompute
        new_cent = np.zeros_like(cent)
        counts = np.bincount(ids, minlength=K).astype(np.int64)
        for k in range(K):
            if counts[k] > 0:
                new_cent[k] = X[ids==k].mean(axis=0)
            else:
                # re-init empty cluster
                ridx = rng.integers(0, X.shape[0])
                new_cent[k] = X[ridx]
        if np.allclose(new_cent, cent, atol=1e-5): 
            break
        cent = new_cent
    return cent, ids

# ---------- Model with Cluster Adapters ----------
class W2WBase(nn.Module):
    """Fixed closed-form W2W mapping."""
    def __init__(self, mu_x, mu_y, Cx_mh, Cy_ph, R):
        super().__init__()
        self.register_buffer("mu_x", torch.from_numpy(mu_x))     # (1024,)
        self.register_buffer("mu_y", torch.from_numpy(mu_y))     # (1536,)
        self.register_buffer("Cx_mh", torch.from_numpy(Cx_mh))   # (1024,1024)
        self.register_buffer("Cy_ph", torch.from_numpy(Cy_ph))   # (1536,1536)
        self.register_buffer("R", torch.from_numpy(R))           # (1536,1024)

    def forward_base(self, x):
        z = x - self.mu_x
        z = F.linear(z, self.Cx_mh)  # whiten
        z = F.linear(z, self.R)      # rotate
        z = F.linear(z, self.Cy_ph)  # re-color
        z = z + self.mu_y
        return z

class ClusterAdapters(nn.Module):
    """
    y_base = W2W(x)
    Hard gate to nearest centroid c_k using y_base (cosine), then apply:
      y_hat = y_base ⊙ exp(diag_k) + bias_k
    """
    def __init__(self, base: W2WBase, centroids: np.ndarray, use_diag=True, use_bias=True):
        super().__init__()
        self.base = base
        K, D = centroids.shape
        self.K = K; self.D = D
        self.register_buffer("centroids", torch.from_numpy(centroids.astype(np.float32)))  # (K, D)
        self.use_diag = use_diag
        self.use_bias = use_bias
        if use_diag:
            self.logD = nn.Parameter(torch.zeros(K, D))   # diag ~ 1 via exp(logD)
        else:
            self.register_parameter("logD", None)
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(K, D))
        else:
            self.register_parameter("bias", None)

    def _assign_hard(self, y_base):
        # cosine nearest centroid (both normalized)
        y_n = F.normalize(y_base, dim=-1)
        c_n = F.normalize(self.centroids, dim=-1)
        sims = y_n @ c_n.t()         # (B,K)
        idx = sims.argmax(dim=1)     # (B,)
        return idx

    def forward_with_idx(self, x, idx):
        yb = self.base.forward_base(x)
        if self.use_diag:
            d = torch.exp(self.logD[idx])   # (B,D)
            yb = yb * d
        if self.use_bias:
            b = self.bias[idx]              # (B,D)
            yb = yb + b
        return yb

    def forward(self, x, idx=None):
        yb = self.base.forward_base(x)
        if idx is None:
            idx = self._assign_hard(yb)
        if self.use_diag:
            d = torch.exp(self.logD[idx])
            yb = yb * d
        if self.use_bias:
            b = self.bias[idx]
            yb = yb + b
        return yb

# ---------- Train loop for cluster adapters ----------
def train_cluster_adapters(model, loader, opt, alpha, beta, moment_w,
                           rank_lambda, rank_m, rank_tau,
                           train_gallery_unique, device,
                           diag_l2, bias_l2):
    """
    model: ClusterAdapters (base fixed; adapters learnable)
    For each batch, we use image cluster ids to pick adapters.
    """
    model.train(); total=0.0
    G = torch.from_numpy(train_gallery_unique).to(device)
    G = F.normalize(G, dim=-1)
    G_size = G.shape[0]

    for pack in loader:
        xb, yb, cids = pack
        xb = xb.to(device); yb = yb.to(device); cids = cids.to(device)
        opt.zero_grad(set_to_none=True)

        pred = model.forward_with_idx(xb, cids)

        # Base pairwise losses
        cos = 1 - F.cosine_similarity(pred, yb, dim=-1).mean()
        mse = F.mse_loss(pred, yb)
        a_loss = moment_w * moment_align(pred, yb) if moment_w>0 else pred.new_tensor(0.0)

        # Listwise: sample shared negatives once per batch
        loss_rank = pred.new_tensor(0.0)
        if rank_lambda > 0 and rank_m > 1:
            with torch.no_grad():
                neg_idx = torch.randint(low=0, high=G_size, size=(rank_m-1,), device=device)
                Gs = G[neg_idx]  # (M-1,1536)
            pred_n = F.normalize(pred, dim=-1)
            yb_n = F.normalize(yb, dim=-1)
            s_pos = (pred_n * yb_n).sum(dim=1, keepdim=True)        # (B,1)
            s_neg = pred_n @ Gs.t()                                 # (B,M-1)
            sims = torch.cat([s_pos, s_neg], dim=1)                 # (B,M)
            loss_rank = listwise_rank_loss_soft(sims, rank_tau) * rank_lambda

        # Regularizers on adapters (only those used in batch to be cheap)
        reg = pred.new_tensor(0.0)
        if model.use_diag:
            reg = reg + (model.logD[cids]**2).mean() * diag_l2
        if model.use_bias:
            reg = reg + (model.bias[cids]**2).mean() * bias_l2

        loss = alpha*cos + beta*mse + a_loss + loss_rank + reg
        loss.backward(); opt.step()
        total += loss.item()*xb.size(0)
    return total/len(loader.dataset)


In [None]:
# ---------- Main ----------
def main(args):
    seed_all(args.seed)
    OUT = Path(f"/kaggle/working/outputs/{args.out_dir}"); OUT.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data
    X,Y,cap_ids,img_ids_row,full_img,all_img_ids = load_train(OUT)
    cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices = build_image_id_split(
        img_ids_row, all_img_ids, full_img, X, Y, args.val_ratio, args.seed, OUT
    )
    val_set = set(all_img_ids[val_img_indices])
    train_set = set(all_img_ids) - val_set
    assert val_set.isdisjoint(train_set), "Leakage detected between TRAIN and VAL image sets!"

    # Masks
    Xtr,Ytr = X[cap_is_tr], Y[cap_is_tr]
    Xva     = X[cap_is_val]

    din, dout = X.shape[1], Y.shape[1]
    assert din==1024 and dout==1536, f"Dimension mismatch: text={din}, image={dout} (expected 1024→1536)"

    # Closed-form base W2W on TRAIN captions
    mu_x, mu_y, Cx_mh, Cy_ph, R = procrustes_factors(Xtr, Ytr, eps=args.geom_eps)
    base = W2WBase(mu_x, mu_y, Cx_mh, Cy_ph, R).to(device)
    # ---- KMeans on image latents (ALL images), then per-caption cluster ids (TRAIN only)
    K = args.k_clusters
    print(f"[kmeans] K={K} iters={args.k_iters}")
    centroids, img_cluster_ids = kmeans(full_img, K=K, iters=args.k_iters, seed=args.seed)
    # Map caption rows to image clusters via target image name
    name2global = {str(n):i for i,n in enumerate(all_img_ids)}
    cap_img_global_idx = np.array([name2global[str(n)] for n in img_ids_row], dtype=np.int64)
    cap_cluster_ids = img_cluster_ids[cap_img_global_idx]
    cap_cluster_ids_tr = cap_cluster_ids[cap_is_tr]  # used for training adapters

    # Model with cluster adapters (diag+bias)
    model = ClusterAdapters(base, centroids, use_diag=True, use_bias=True).to(device)

    # Report learnable params (only adapters)
    learnable_params, mb = count_params_mb(model), None
    learnable_params, mb = count_params_mb(model)
    print(f"[model] w2w+cluster-adapters | learnable={learnable_params:,} (~{mb:.2f} MB) "
          f"| α={args.alpha} β={args.beta} λ_moment={args.moment} λ_rank={args.rank_lambda}")

    # Train adapters
    best_stats=None
    if args.epochs > 0:
        dl = DataLoader(PairDS(Xtr, Ytr, cids=cap_cluster_ids_tr),
                        batch_size=args.batch, shuffle=True, num_workers=2, pin_memory=True, drop_last=False)
        opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=args.lr, weight_decay=args.wd)
        # Build a TRAIN gallery of unique images for rank negatives (exclude VAL images)
        all_idx = np.arange(len(full_img))
        train_gallery_idx = np.setdiff1d(all_idx, val_img_indices)
        train_gallery_unique = full_img[train_gallery_idx]
        for ep in range(1, args.epochs+1):
            tr = train_cluster_adapters(model, dl, opt, args.alpha, args.beta, args.moment,
                                        args.rank_lambda, args.rank_m, args.rank_tau,
                                        train_gallery_unique, device,
                                        args.diag_l2, args.bias_l2)
            stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, args.pooling, None)
            print(f"[adapters {ep:02d}] train_loss={tr:.6f} | val_MRR={stats['MRR']:.4f} "
                  f"| R@1={stats['R1']:.3f} R@5={stats['R5']:.3f} R@10={stats['R10']:.3f} "
                  f"| median={stats['rank_median']} p75={stats['rank_p75']}")
            if best_stats is None or stats["MRR"] > best_stats["MRR"]:
                best_stats = stats
    else:
        best_stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, args.pooling, None)

    # Efficiency
    ms_gpu, ms_cpu = time_ms_per_query(model, din, args.pooling, None)
    eff = {"params":learnable_params,"mb_fp32":mb,"ms_per_query_gpu":ms_gpu,"ms_per_query_cpu":ms_cpu}
    (OUT/"efficiency.json").write_text(json.dumps(eff, indent=2))

    # Submission (normalized; gating inside model)
    test_data = load_data(TEST_NPZ)
    Q   = test_data["captions/embeddings"].astype(np.float32)
    ids = test_data.get("captions/ids", np.arange(len(Q)).astype(str))
    model.eval()
    BS = 1024; outs=[]
    with torch.no_grad():
        for i in range(0, len(Q), BS):
            q = torch.from_numpy(Q[i:i+BS]).to(device)
            z = model(q)               # base + adapter with hard gating
            z = F.normalize(z, dim=-1)
            outs.append(z.detach().cpu().numpy())
    pred_embds = np.concatenate(outs, axis=0)
    sub = OUT / "submission.csv"
    generate_submission(ids, pred_embds, str(sub))
    print(f"[ok] submission written → {sub}")

    # One-file sanity
    sanity = {
        "dims": {"text": int(din), "image": int(dout)},
        "kmeans": {"K": int(K), "iters": int(args.k_iters)},
        "split": {
            "train_captions": int(cap_is_tr.sum()),
            "val_captions": int(cap_is_val.sum()),
            "val_unique_images": int(val_gallery.shape[0]),
            "leakage": False
        },
        "val_metrics": best_stats,
        "efficiency": eff
    }
    print(json.dumps(sanity, indent=2))


In [78]:
# ---------- CLI ----------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    # Core
    p.add_argument("--out_dir", type=str, default="w2w_cluster_adapters")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--val_ratio", type=float, default=0.10)
    p.add_argument("--batch", type=int, default=1024)
    p.add_argument("--pooling", type=str, default="CLS")  # no-op
    # Geometry
    p.add_argument("--geom_eps", type=float, default=1e-5)
    # KMeans
    p.add_argument("--k_clusters", type=int, default=64, help="K in KMeans over image latents Y")
    p.add_argument("--k_iters", type=int, default=20)
    # Training (adapters only)
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=8e-5)
    p.add_argument("--wd", type=float, default=1e-5)
    # Loss weights
    p.add_argument("--alpha", type=float, default=1.0)      # (1 - cos)
    p.add_argument("--beta", type=float, default=0.2)       # MSE (small)
    p.add_argument("--moment", type=float, default=0.01)    # tiny
    p.add_argument("--rank_lambda", type=float, default=0.08)
    p.add_argument("--rank_m", type=int, default=512)       # 1 pos + M-1 negs
    p.add_argument("--rank_tau", type=float, default=0.03)
    # Regularization on adapters
    p.add_argument("--diag_l2", type=float, default=1e-3)   # ||d_k - 1||^2
    p.add_argument("--bias_l2", type=float, default=1e-4)   # ||b_k||^2
    args, _ = p.parse_known_args()
    main(args)


[meta] text_dim=1024 | image_dim=1536
[kmeans] K=64 iters=20
[model] w2w+cluster-adapters | learnable=196,608 (~0.75 MB) | α=1.0 β=0.2 λ_moment=0.01 λ_rank=0.08
[adapters 01] train_loss=0.444437 | val_MRR=0.3194 | R@1=0.210 R@5=0.434 R@10=0.543 | median=8 p75=39
[adapters 02] train_loss=0.438978 | val_MRR=0.3180 | R@1=0.208 R@5=0.433 R@10=0.545 | median=8 p75=40
[adapters 03] train_loss=0.433501 | val_MRR=0.3157 | R@1=0.205 R@5=0.432 R@10=0.543 | median=8 p75=40
Generating submission file...
✓ Saved submission to /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
[ok] submission written → /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
{
  "dims": {
    "text": 1024,
    "image": 1536
  },
  "kmeans": {
    "K": 64,
    "iters": 20
  },
  "split": {
    "train_captions": 112500,
    "val_captions": 12500,
    "val_unique_images": 2500,
    "leakage": false
  },
  "val_metrics": {
    "MRR": 0.3193919736741112,
    "R1": 0.21032,
    "R5": 0.434,
    "R10": 0.54344,

In [80]:
def build_image_centroid_targets(img_ids_row, Y, use_mask=None):
    """
    Returns per-caption targets replaced by the centroid of its image over duplicates.
    Optionally apply a boolean mask (e.g., train mask) before computing centroids.
    """
    img_ids_row = np.asarray(img_ids_row).astype(str)
    if use_mask is None:
        use_mask = np.ones(len(img_ids_row), dtype=bool)

    # compute centroids on the masked subset
    keys = img_ids_row[use_mask]
    vals = Y[use_mask]
    uniq, inv = np.unique(keys, return_inverse=True)
    sums = np.zeros((len(uniq), Y.shape[1]), dtype=np.float32)
    cnts = np.zeros(len(uniq), dtype=np.int64)
    for i, k in enumerate(inv):
        sums[k] += vals[i]; cnts[k] += 1
    cent = sums / np.maximum(1, cnts)[:, None]

    # map all rows (full set) to their centroid (computed from masked subset when present)
    name2idx = {name:i for i, name in enumerate(uniq)}
    Ycent = np.empty_like(Y)
    for i, nm in enumerate(img_ids_row):
        j = name2idx.get(nm, None)
        if j is None:
            # if a name exists only outside mask, fallback to its own vector
            Ycent[i] = Y[i]
        else:
            Ycent[i] = cent[j]
    return Ycent

# ---------- Main ----------
def main(args):
    seed_all(args.seed)
    OUT = Path(f"/kaggle/working/outputs/{args.out_dir}"); OUT.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data
    X,Y,cap_ids,img_ids_row,full_img,all_img_ids = load_train(OUT)
    cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices = build_image_id_split(
        img_ids_row, all_img_ids, full_img, X, Y, args.val_ratio, args.seed, OUT
    )
    val_set = set(all_img_ids[val_img_indices])
    train_set = set(all_img_ids) - val_set
    assert val_set.isdisjoint(train_set), "Leakage detected between TRAIN and VAL image sets!"

    # Masks
    Xtr,Ytr = X[cap_is_tr], Y[cap_is_tr]
    Xva     = X[cap_is_val]
    Ycent_all = build_image_centroid_targets(img_ids_row, Y, use_mask=cap_is_tr)
    Ytr_cent  = Ycent_all[cap_is_tr]  # only train slice

    din, dout = X.shape[1], Y.shape[1]
    assert din==1024 and dout==1536, f"Dimension mismatch: text={din}, image={dout} (expected 1024→1536)"

    # Closed-form base W2W on TRAIN captions
    mu_x, mu_y, Cx_mh, Cy_ph, R = procrustes_factors(Xtr, Ytr, eps=args.geom_eps)
    base = W2WBase(mu_x, mu_y, Cx_mh, Cy_ph, R).to(device)

    # ---- KMeans on image latents (ALL images), then per-caption cluster ids (TRAIN only)
    K = args.k_clusters
    print(f"[kmeans] K={K} iters={args.k_iters}")
    centroids, img_cluster_ids = kmeans(full_img, K=K, iters=args.k_iters, seed=args.seed)
    # Map caption rows to image clusters via target image name
    name2global = {str(n):i for i,n in enumerate(all_img_ids)}
    cap_img_global_idx = np.array([name2global[str(n)] for n in img_ids_row], dtype=np.int64)
    cap_cluster_ids = img_cluster_ids[cap_img_global_idx]
    cap_cluster_ids_tr = cap_cluster_ids[cap_is_tr]  # used for training adapters

    # Model with cluster adapters (diag+bias)
    model = ClusterAdapters(base, centroids, use_diag=True, use_bias=True).to(device)

    # Report learnable params (only adapters)
    learnable_params, mb = count_params_mb(model), None
    learnable_params, mb = count_params_mb(model)
    print(f"[model] w2w+cluster-adapters | learnable={learnable_params:,} (~{mb:.2f} MB) "
          f"| α={args.alpha} β={args.beta} λ_moment={args.moment} λ_rank={args.rank_lambda}")

    # Train adapters
    best_stats=None
    if args.epochs > 0:
        # use centroid targets if requested (defined in section C below)
        Ytr_used = Ytr_cent if args.train_centroids else Ytr
        dl = DataLoader(PairDS(Xtr, Ytr_used, cids=cap_cluster_ids_tr),
                        batch_size=args.batch, shuffle=True, num_workers=2, pin_memory=True, drop_last=False)
        opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=args.lr, weight_decay=args.wd)
    
        # negatives from TRAIN gallery (exclude VAL images) — reuse as-is
        all_idx = np.arange(len(full_img))
        train_gallery_idx = np.setdiff1d(all_idx, val_img_indices)
        train_gallery_unique = full_img[train_gallery_idx]
    
        for ep in range(1, args.epochs+1):
            # curriculum: increase M after ep >= rank_curriculum_ep
            rank_m_curr = args.rank_m2 if (args.rank_curriculum_ep>0 and ep>=args.rank_curriculum_ep) else args.rank_m
            tr = train_cluster_adapters(
                model, dl, opt,
                args.alpha, args.beta, args.moment,
                args.rank_lambda, rank_m_curr, args.rank_tau,
                train_gallery_unique, device,
                args.diag_l2, args.bias_l2
            )
            stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, args.pooling, None)
            print(f"[adapters {ep:02d}] M={rank_m_curr} train_loss={tr:.6f} | val_MRR={stats['MRR']:.4f} "
                  f"| R@1={stats['R1']:.3f} R@5={stats['R5']:.3f} R@10={stats['R10']:.3f} "
                  f"| median={stats['rank_median']} p75={stats['rank_p75']}")
            if best_stats is None or stats["MRR"] > best_stats["MRR"]:
                best_stats = stats
    else:
        best_stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, args.pooling, None)

    # Efficiency
    ms_gpu, ms_cpu = time_ms_per_query(model, din, args.pooling, None)
    eff = {"params":learnable_params,"mb_fp32":mb,"ms_per_query_gpu":ms_gpu,"ms_per_query_cpu":ms_cpu}
    (OUT/"efficiency.json").write_text(json.dumps(eff, indent=2))

    # Submission (normalized; gating inside model)
    test_data = load_data(TEST_NPZ)
    Q   = test_data["captions/embeddings"].astype(np.float32)
    ids = test_data.get("captions/ids", np.arange(len(Q)).astype(str))
    model.eval()
    BS = 1024; outs=[]
    with torch.no_grad():
        for i in range(0, len(Q), BS):
            q = torch.from_numpy(Q[i:i+BS]).to(device)
            z = model(q)               # base + adapter with hard gating
            z = F.normalize(z, dim=-1)
            outs.append(z.detach().cpu().numpy())
    pred_embds = np.concatenate(outs, axis=0)
    sub = OUT / "submission.csv"
    generate_submission(ids, pred_embds, str(sub))
    print(f"[ok] submission written → {sub}")

    # One-file sanity
    sanity = {
        "dims": {"text": int(din), "image": int(dout)},
        "kmeans": {"K": int(K), "iters": int(args.k_iters)},
        "split": {
            "train_captions": int(cap_is_tr.sum()),
            "val_captions": int(cap_is_val.sum()),
            "val_unique_images": int(val_gallery.shape[0]),
            "leakage": False
        },
        "val_metrics": best_stats,
        "efficiency": eff
    }
    print(json.dumps(sanity, indent=2))


In [82]:
# ---------- CLI ----------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    # Core
    p.add_argument("--out_dir", type=str, default="w2w_cluster_adapters")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--val_ratio", type=float, default=0.10)
    p.add_argument("--batch", type=int, default=1024)
    p.add_argument("--pooling", type=str, default="CLS")  # no-op
    # Geometry
    p.add_argument("--geom_eps", type=float, default=1e-5)
    # KMeans
    p.add_argument("--k_clusters", type=int, default=64, help="K in KMeans over image latents Y")
    p.add_argument("--k_iters", type=int, default=20)
    # Training (adapters only)
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=8e-5)
    p.add_argument("--wd", type=float, default=1e-5)
    # Loss weights
    p.add_argument("--alpha", type=float, default=1.0)      # (1 - cos)
    p.add_argument("--beta", type=float, default=0.2)       # MSE (small)
    p.add_argument("--moment", type=float, default=0.01)    # tiny
    p.add_argument("--rank_lambda", type=float, default=0.12)   # stronger but still gentle
    p.add_argument("--rank_m", type=int, default=512)
    p.add_argument("--rank_m2", type=int, default=1024)         # increased gallery
    p.add_argument("--rank_curriculum_ep", type=int, default=3) # epochs 1–2: 512, then 1024
    p.add_argument("--rank_tau", type=float, default=0.03)
    # Regularization on adapters
    p.add_argument("--diag_l2", type=float, default=1e-3)   # ||d_k - 1||^2
    p.add_argument("--bias_l2", type=float, default=1e-4)   # ||b_k||^2
    # added
    p.add_argument("--train_centroids", action="store_true",
               help="Replace training targets with per-image centroids in Y")
    args, _ = p.parse_known_args()
    main(args)

[meta] text_dim=1024 | image_dim=1536
[kmeans] K=64 iters=20
[model] w2w+cluster-adapters | learnable=196,608 (~0.75 MB) | α=1.0 β=0.2 λ_moment=0.01 λ_rank=0.12
[adapters 01] M=512 train_loss=0.531569 | val_MRR=0.3198 | R@1=0.210 R@5=0.434 R@10=0.544 | median=8 p75=39
[adapters 02] M=512 train_loss=0.524246 | val_MRR=0.3186 | R@1=0.208 R@5=0.434 R@10=0.546 | median=8 p75=40
[adapters 03] M=1024 train_loss=0.587098 | val_MRR=0.3168 | R@1=0.206 R@5=0.433 R@10=0.544 | median=8 p75=39
Generating submission file...
✓ Saved submission to /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
[ok] submission written → /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
{
  "dims": {
    "text": 1024,
    "image": 1536
  },
  "kmeans": {
    "K": 64,
    "iters": 20
  },
  "split": {
    "train_captions": 112500,
    "val_captions": 12500,
    "val_unique_images": 2500,
    "leakage": false
  },
  "val_metrics": {
    "MRR": 0.3197658671172457,
    "R1": 0.21048,
    "R5": 0.43416

In [86]:
# ---------- CLI ----------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    # Core
    p.add_argument("--exp", type=str, default="rank_ft", choices=["rank_ft","adapters","geom_only"])
    p.add_argument("--out_dir", type=str, default="w2w_cluster_adapters")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--val_ratio", type=float, default=0.10)
    p.add_argument("--batch", type=int, default=1024)
    p.add_argument("--pooling", type=str, default="CLS")  # no-op
    # Geometry
    p.add_argument("--geom_eps", type=float, default=1e-5)
    # KMeans
    p.add_argument("--k_clusters", type=int, default=64, help="K in KMeans over image latents Y")
    p.add_argument("--k_iters", type=int, default=20)
    # Training (adapters only)
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=8e-5)
    p.add_argument("--wd", type=float, default=1e-5)
    # Loss weights
    p.add_argument("--alpha", type=float, default=1.0)      # (1 - cos)
    p.add_argument("--beta", type=float, default=0.2)       # MSE (small)
    p.add_argument("--moment", type=float, default=0.01)    # tiny
    p.add_argument("--rank_loss", type=str, default="soft", choices=["soft","hinge"])
    p.add_argument("--rank_lambda", type=float, default=0.15)
    p.add_argument("--rank_tau", type=float, default=0.03)
    p.add_argument("--rank_margin", type=float, default=0.07)
    p.add_argument("--rank_m", type=int, default=1024)
    p.add_argument("--rank_m2", type=int, default=2048)
    p.add_argument("--rank_curriculum_ep", type=int, default=3)
    
    # regularization on diag/bias
    p.add_argument("--diag_l2", type=float, default=1e-3)
    p.add_argument("--bias_l2", type=float, default=1e-4)
    p.add_argument("--train_centroids", action="store_true",
               help="Replace training targets with per-image centroids in Y")
    args, _ = p.parse_known_args()
    main(args)

[meta] text_dim=1024 | image_dim=1536
[model] rank-ft (geom frozen) | learnable=3,072 (~0.01 MB) | α=1.0 β=0.2 λ_moment=0.01 λ_rank=0.15 | rank_loss=soft
[rank-ft 01] M=1024 train_loss=0.689595 | val_MRR=0.3228 | R@1=0.212 R@5=0.439 R@10=0.549 | median=8 p75=39
[rank-ft 02] M=1024 train_loss=0.687079 | val_MRR=0.3262 | R@1=0.215 R@5=0.442 R@10=0.555 | median=8 p75=38
[rank-ft 03] M=2048 train_loss=0.779317 | val_MRR=0.3289 | R@1=0.217 R@5=0.446 R@10=0.561 | median=7 p75=37
Generating submission file...
✓ Saved submission to /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
[ok] submission written → /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
{
  "dims": {
    "text": 1024,
    "image": 1536
  },
  "exp": "rank_ft",
  "split": {
    "train_captions": 112500,
    "val_captions": 12500,
    "val_unique_images": 2500,
    "leakage": false
  },
  "val_metrics": {
    "MRR": 0.32888943540530263,
    "R1": 0.2168,
    "R5": 0.44608,
    "R10": 0.56056,
    "rank_medi

In [89]:
class TextMoEResidual(nn.Module):
    """
    Frozen W2W base + per-text-cluster residuals in Y:
      y_base = W2W(x)
      y_hat  = y_base ⊙ exp(logD_k) + b_k   where k = nearest text centroid
    """
    def __init__(self, mu_x, mu_y, Cx_mh, Cy_ph, R, text_centroids: np.ndarray,
                 use_diag=True, use_bias=True):
        super().__init__()
        # --- frozen W2W base (reuse geometry buffers)
        self.register_buffer("mu_x", torch.from_numpy(mu_x))
        self.register_buffer("mu_y", torch.from_numpy(mu_y))
        self.register_buffer("Cx_mh", torch.from_numpy(Cx_mh))
        self.register_buffer("Cy_ph", torch.from_numpy(Cy_ph))
        self.register_buffer("R", torch.from_numpy(R))
        # --- text KMeans centroids (K, 1024)
        self.register_buffer("tcent", torch.from_numpy(text_centroids.astype(np.float32)))
        self.K = self.tcent.shape[0]
        self.D = mu_y.shape[0]  # 1536
        # --- residual params
        self.use_diag = use_diag
        self.use_bias = use_bias
        if use_diag:
            self.logD = nn.Parameter(torch.zeros(self.K, self.D))
        else:
            self.register_parameter("logD", None)
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(self.K, self.D))
        else:
            self.register_parameter("bias", None)

    def forward_base(self, x):
        z = x - self.mu_x
        z = F.linear(z, self.Cx_mh)  # whiten
        z = F.linear(z, self.R)      # rotate
        z = F.linear(z, self.Cy_ph)  # re-color
        z = z + self.mu_y
        return z

    def _assign_text(self, x):
        # cosine nearest text centroid (normalize both)
        x_n = F.normalize(x, dim=-1)
        c_n = F.normalize(self.tcent, dim=-1)
        sims = x_n @ c_n.t()           # (B,K)
        return sims.argmax(dim=1)      # (B,)

    def forward_with_idx(self, x, tidx):
        yb = self.forward_base(x)
        if self.use_diag:
            yb = yb * torch.exp(self.logD[tidx])
        if self.use_bias:
            yb = yb + self.bias[tidx]
        return yb

    def forward(self, x):
        tidx = self._assign_text(x)
        return self.forward_with_idx(x, tidx)

def train_text_moe(model, loader, opt, *,
                   alpha, beta, moment_w,
                   rank_lambda, rank_tau, rank_margin, rank_loss,
                   rank_m, train_gallery_unique, device,
                   diag_l2, bias_l2):
    model.train(); total=0.0
    # negatives pool
    G = torch.from_numpy(train_gallery_unique).to(device)
    G = F.normalize(G, dim=-1)
    G_size = G.shape[0]

    for xb, yb, tcluster in loader:
        xb = xb.to(device); yb = yb.to(device); tcluster = tcluster.to(device)
        opt.zero_grad(set_to_none=True)

        pred = model.forward_with_idx(xb, tcluster)

        # base pairwise
        cos = 1 - F.cosine_similarity(pred, yb, dim=-1).mean()
        mse = F.mse_loss(pred, yb)
        a_loss = moment_w * moment_align(pred, yb) if moment_w>0 else pred.new_tensor(0.0)

        # listwise: shared negatives for batch
        with torch.no_grad():
            neg_idx = torch.randint(low=0, high=G_size, size=(rank_m-1,), device=device)
            Gs = G[neg_idx]  # (M-1,1536)
        pred_n = F.normalize(pred, dim=-1)
        yb_n   = F.normalize(yb,   dim=-1)
        s_pos = (pred_n * yb_n).sum(dim=1, keepdim=True)  # (B,1)
        s_neg = pred_n @ Gs.t()                           # (B,M-1)
        sims = torch.cat([s_pos, s_neg], dim=1)          # (B,M)

        if rank_loss == "soft":
            l_rank = listwise_rank_loss_soft(sims, rank_tau)
        else:
            l_rank = listwise_rank_loss_hinge(sims, margin=rank_margin)
        l_rank = l_rank * rank_lambda

        # regularize only the used clusters for efficiency
        reg = pred.new_tensor(0.0)
        if getattr(model, "logD", None) is not None and model.logD.requires_grad:
            reg = reg + (model.logD[tcluster]**2).mean() * diag_l2
        if getattr(model, "bias", None) is not None and model.bias.requires_grad:
            reg = reg + (model.bias[tcluster]**2).mean() * bias_l2

        loss = alpha*cos + beta*mse + a_loss + l_rank + reg
        loss.backward(); opt.step()
        total += loss.item()*xb.size(0)
    return total/len(loader.dataset)

In [92]:
# ---------- CLI ----------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    # Core
    p.add_argument("--exp", type=str, default="text_moe",
                   choices=["rank_ft","adapters","geom_only","text_moe"])
    p.add_argument("--out_dir", type=str, default="w2w_cluster_adapters")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--val_ratio", type=float, default=0.10)
    p.add_argument("--batch", type=int, default=1024)
    p.add_argument("--pooling", type=str, default="CLS")  # no-op
    # Geometry
    p.add_argument("--geom_eps", type=float, default=1e-5)
    # KMeans
    p.add_argument("--k_clusters", type=int, default=64, help="K in KMeans over image latents Y")
    p.add_argument("--k_iters", type=int, default=20)
    # Training (adapters only)
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=8e-5)
    p.add_argument("--wd", type=float, default=1e-5)
    # Loss weights
    p.add_argument("--alpha", type=float, default=1.0)      # (1 - cos)
    p.add_argument("--beta", type=float, default=0.2)       # MSE (small)
    p.add_argument("--moment", type=float, default=0.01)    # tiny
    p.add_argument("--rank_loss", type=str, default="soft", choices=["soft","hinge"])
    p.add_argument("--rank_lambda", type=float, default=0.15)
    p.add_argument("--rank_tau", type=float, default=0.03)
    p.add_argument("--rank_margin", type=float, default=0.07)
    p.add_argument("--rank_m", type=int, default=1024)
    p.add_argument("--rank_m2", type=int, default=2048)
    p.add_argument("--rank_curriculum_ep", type=int, default=3)
    
    # regularization on diag/bias
    p.add_argument("--diag_l2", type=float, default=1e-3)
    p.add_argument("--bias_l2", type=float, default=1e-4)
    p.add_argument("--train_centroids", action="store_true",
               help="Replace training targets with per-image centroids in Y")
    p.add_argument("--text_k", type=int, default=12, help="K for KMeans over TRAIN text embeddings")
    p.add_argument("--text_k_iters", type=int, default=20)

    args, _ = p.parse_known_args()
    main(args)

[meta] text_dim=1024 | image_dim=1536
[text-kmeans] K=12 iters=20
[model] w2w + text-MoE residuals | learnable=36,864 (~0.14 MB) | α=1.0 β=0.2 λ_moment=0.01 λ_rank=0.15 | rank_loss=soft
[text-moe 01] M=1024 train_loss=0.689330 | val_MRR=0.3218 | R@1=0.211 R@5=0.439 R@10=0.549 | median=8 p75=39
[text-moe 02] M=1024 train_loss=0.686281 | val_MRR=0.3248 | R@1=0.214 R@5=0.442 R@10=0.553 | median=8 p75=38
[text-moe 03] M=2048 train_loss=0.778095 | val_MRR=0.3275 | R@1=0.216 R@5=0.443 R@10=0.558 | median=8 p75=37
Generating submission file...
✓ Saved submission to /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
[ok] submission written → /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
{
  "dims": {
    "text": 1024,
    "image": 1536
  },
  "exp": "text_moe",
  "text_kmeans": {
    "K": 12,
    "iters": 20
  },
  "split": {
    "train_captions": 112500,
    "val_captions": 12500,
    "val_unique_images": 2500,
    "leakage": false
  },
  "val_metrics": {
    "MRR": 0.3

In [98]:
@torch.no_grad()
def mine_hard_negatives_offline(Xtr, cap_is_tr, img_ids_row, all_img_names,
                                full_img, val_img_indices, mu_x, mu_y, Cx_mh, Cy_ph, R,
                                topn=50, device=torch.device("cpu")):
    """
    Returns:
      G_train: (G,1536) np float32 of unique TRAIN images (VAL excluded)
      cap2gal_pos: (N_tr,) local indices of the positive in G_train for each TRAIN caption
      hard_top: list of np.arrays; for each train row i, length `topn` with local NEG indices in G_train
    """
    # TRAIN-only gallery (exclude VAL)
    all_idx = np.arange(len(full_img))
    train_gallery_idx = np.setdiff1d(all_idx, val_img_indices)
    G_train = full_img[train_gallery_idx].astype(np.float32)

    # map image name -> global -> local gallery index
    name2global = {str(n): i for i, n in enumerate(all_img_names)}
    gal_global_to_local = {int(g): li for li, g in enumerate(train_gallery_idx)}
    cap_img_global_idx = np.array([name2global[str(n)] for n in img_ids_row], dtype=np.int64)
    cap_img_local_idx  = np.array([gal_global_to_local.get(int(g), -1) for g in cap_img_global_idx], dtype=np.int64)
    cap_img_local_idx_tr = cap_img_local_idx[cap_is_tr]
    assert np.all(cap_img_local_idx_tr >= 0), "Some TRAIN captions point to VAL-only images; check split."

    # Base W2W forward for all TRAIN texts — keep on `device` (NO .cpu())
    Xtr = Xtr.astype(np.float32)
    mu_x_t = torch.from_numpy(mu_x).to(device)
    mu_y_t = torch.from_numpy(mu_y).to(device)
    Cx_mh_t= torch.from_numpy(Cx_mh).to(device)
    Cy_ph_t= torch.from_numpy(Cy_ph).to(device)
    R_t    = torch.from_numpy(R).to(device)

    B = 2048
    preds = []
    for i in range(0, len(Xtr), B):
        x = torch.from_numpy(Xtr[i:i+B]).to(device)
        z = x - mu_x_t
        z = F.linear(z, Cx_mh_t)
        z = F.linear(z, R_t)
        z = F.linear(z, Cy_ph_t) + mu_y_t
        preds.append(F.normalize(z, dim=-1))  # stay on device
    Qn = torch.cat(preds, dim=0)  # (N_tr,1536) on `device`

    # Normalize gallery on same device
    Gn = F.normalize(torch.from_numpy(G_train).to(device), dim=-1)  # (G,1536)

    # Cosine sims & top-N negatives (mask out positives)
    hard_top = []
    cap2gal_pos = cap_img_local_idx_tr.copy()
    kmax = max(1, min(topn, Gn.size(0) - 1))  # cannot pick the positive
    for i in range(0, Qn.size(0), B):
        q = Qn[i:i+B]                           # (b,1536) device
        sims = q @ Gn.t()                       # (b,G)
        # mask each row's positive
        for j in range(q.size(0)):
            pos = int(cap2gal_pos[i+j])
            sims[j, pos] = -float("inf")
        idx = torch.topk(sims, k=kmax, dim=1).indices.cpu().numpy()  # indices → CPU/NumPy
        hard_top.extend([row for row in idx])

    return G_train, cap2gal_pos, hard_top


# ---- OT barycentric target (offline teacher) ----
@torch.no_grad()
def compute_ot_barycentric_targets(Xtr, mu_x, mu_y, Cx_mh, Cy_ph, R,
                                   G_train, knn=32, tau=0.05, device=torch.device("cpu")):
    """
    Teacher target per train caption:
      y_ot = sum_j softmax(cos(y_base, g_j)/tau) * g_j, over KNN in train gallery
    Returns: Y_ot (N_tr, 1536) float32
    """
    # Base predictions
    mu_x_t = torch.from_numpy(mu_x).to(device)
    mu_y_t = torch.from_numpy(mu_y).to(device)
    Cx_mh_t= torch.from_numpy(Cx_mh).to(device)
    Cy_ph_t= torch.from_numpy(Cy_ph).to(device)
    R_t    = torch.from_numpy(R).to(device)

    Xtr = Xtr.astype(np.float32)
    B=2048; preds=[]
    for i in range(0, len(Xtr), B):
        x = torch.from_numpy(Xtr[i:i+B]).to(device)
        z = x - mu_x_t
        z = F.linear(z, Cx_mh_t)
        z = F.linear(z, R_t)
        z = F.linear(z, Cy_ph_t) + mu_y_t
        preds.append(F.normalize(z, dim=-1))
    Qn = torch.cat(preds, dim=0)  # (N,1536)

    Gn = F.normalize(torch.from_numpy(G_train).to(device), dim=-1)  # (G,1536)

    Y_ot = []
    for i in range(0, Qn.size(0), B):
        q = Qn[i:i+B]                              # (b,1536)
        sims = q @ Gn.t()                          # (b,G)
        topk = min(knn, Gn.size(0))
        idx = torch.topk(sims, k=topk, dim=1).indices  # (b,topk)
        gather = Gn[idx]                                # (b,topk,1536)
        s = torch.gather(sims, 1, idx)                  # (b,topk)
        w = torch.softmax(s / tau, dim=1).unsqueeze(-1) # (b,topk,1)
        yb = (w * gather).sum(dim=1)                    # (b,1536)
        Y_ot.append(yb.cpu())
    return torch.cat(Y_ot, dim=0).numpy().astype(np.float32)

# ---- Hard-negative margin loss ----
def hard_margin_loss(pred_n, pos_n, negs_n, margin=0.05):
    """
    pred_n: (B,1536) L2-normalized predictions
    pos_n:  (B,1536) L2-normalized positives
    negs_n: (B,K,1536) L2-normalized negatives
    """
    s_pos = (pred_n * pos_n).sum(dim=1, keepdim=True)        # (B,1)
    s_neg = torch.einsum("bd,bkd->bk", pred_n, negs_n)       # (B,K)
    loss = F.relu(margin - s_pos + s_neg)                    # (B,K)
    return loss.mean()


In [94]:
def train_hardneg_ot(model, loader, opt, *,
                     alpha, beta, moment_w,
                     hard_lambda, hard_margin, hard_k,
                     ot_lambda, Y_ot_tr,
                     G_train, hard_top_lists, pos_local_idx,
                     device):
    model.train(); total=0.0

    # normalized TRAIN gallery once
    G = F.normalize(torch.from_numpy(G_train).to(device), dim=-1)

    for batch_idx, (xb, yb, row_ids) in enumerate(loader):
        xb = xb.to(device); yb = yb.to(device)
        opt.zero_grad(set_to_none=True)

        pred = model(xb)    # residual head on top of frozen W2W
        # base pairwise pieces
        cos = 1 - F.cosine_similarity(pred, yb, dim=-1).mean()
        mse = F.mse_loss(pred, yb)
        a_loss = moment_w * moment_align(pred, yb) if moment_w>0 else pred.new_tensor(0.0)

        # OT teacher (pull toward barycentric target)
        y_ot = torch.from_numpy(Y_ot_tr[row_ids.numpy()]).to(device)
        loss_ot = F.mse_loss(pred, y_ot) * ot_lambda

        # Hard negatives hinge
        # build per-sample K hard negatives (from pre-mined topN)
        negs = []
        pred_n = F.normalize(pred, dim=-1)
        yb_n   = F.normalize(yb, dim=-1)

        # sample K per row
        for rid in row_ids.numpy().tolist():
            pool = hard_top_lists[rid]
            if len(pool) == 0:
                pool = [pos_local_idx[rid]]  # fallback; will be masked by margin anyway
            choose = np.random.choice(pool, size=min(hard_k, len(pool)), replace=False)
            negs.append(G[torch.as_tensor(choose, device=device)])
        negs = torch.stack(negs, dim=0)  # (B,K,1536)

        loss_hard = hard_margin_loss(pred_n, yb_n, negs, margin=hard_margin) * hard_lambda

        loss = alpha*cos + beta*mse + a_loss + loss_ot + loss_hard
        loss.backward(); opt.step()
        total += loss.item()*xb.size(0)

    return total/len(loader.dataset)


In [95]:
class PairDSWithIds(Dataset):
    def __init__(self, X, Y):
        self.X = torch.from_numpy(X)
        self.Y = torch.from_numpy(Y)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i): return self.X[i], self.Y[i], torch.tensor(i, dtype=torch.long)


In [99]:
# ---------- CLI ----------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    # Core
    p.add_argument("--exp", type=str, default="hardneg_ot",
               choices=["rank_ft","adapters","geom_only","text_moe","hardneg_ot"])
    p.add_argument("--out_dir", type=str, default="w2w_cluster_adapters")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--val_ratio", type=float, default=0.10)
    p.add_argument("--batch", type=int, default=1024)
    p.add_argument("--pooling", type=str, default="CLS")  # no-op
    # Geometry
    p.add_argument("--geom_eps", type=float, default=1e-5)
    # KMeans
    p.add_argument("--k_clusters", type=int, default=64, help="K in KMeans over image latents Y")
    p.add_argument("--k_iters", type=int, default=20)
    # Training (adapters only)
    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--lr", type=float, default=8e-5)
    p.add_argument("--wd", type=float, default=1e-5)
    # Loss weights
    p.add_argument("--alpha", type=float, default=1.0)      # (1 - cos)
    p.add_argument("--beta", type=float, default=0.2)       # MSE (small)
    p.add_argument("--moment", type=float, default=0.01)    # tiny
    p.add_argument("--rank_loss", type=str, default="soft", choices=["soft","hinge"])
    p.add_argument("--rank_lambda", type=float, default=0.15)
    p.add_argument("--rank_tau", type=float, default=0.03)
    p.add_argument("--rank_margin", type=float, default=0.07)
    p.add_argument("--rank_m", type=int, default=1024)
    p.add_argument("--rank_m2", type=int, default=2048)
    p.add_argument("--rank_curriculum_ep", type=int, default=3)
    
    # regularization on diag/bias
    p.add_argument("--diag_l2", type=float, default=1e-3)
    p.add_argument("--bias_l2", type=float, default=1e-4)
    p.add_argument("--train_centroids", action="store_true",
               help="Replace training targets with per-image centroids in Y")
    p.add_argument("--text_k", type=int, default=12, help="K for KMeans over TRAIN text embeddings")
    p.add_argument("--text_k_iters", type=int, default=20)
    # Hard-negative knobs
    p.add_argument("--hard_topn", type=int, default=50, help="Top-N mined negatives per train sample")
    p.add_argument("--hard_k", type=int, default=8, help="Negatives sampled per batch example")
    p.add_argument("--hard_margin", type=float, default=0.05)
    p.add_argument("--hard_lambda", type=float, default=0.15)
    
    # OT teacher knobs
    p.add_argument("--ot_lambda", type=float, default=0.10)
    p.add_argument("--ot_knn", type=int, default=32)
    p.add_argument("--ot_tau", type=float, default=0.05)

    args, _ = p.parse_known_args()
    main(args)

[meta] text_dim=1024 | image_dim=1536
[hardneg] mining top50 negatives on TRAIN gallery...
[ot] computing barycentric OT targets (knn=32, tau=0.05)...
[model] hardneg+OT (geom frozen) | learnable=3,072 (~0.01 MB) | α=1.0 β=0.2 λ_moment=0.01 | λ_hard=0.15 m=0.05 | λ_OT=0.1
[hardneg-ot 01] train_loss=0.317054 | val_MRR=0.3170 | R@1=0.208 R@5=0.434 R@10=0.543 | median=8 p75=41
[hardneg-ot 02] train_loss=0.315040 | val_MRR=0.3150 | R@1=0.206 R@5=0.431 R@10=0.539 | median=8 p75=42
[hardneg-ot 03] train_loss=0.313203 | val_MRR=0.3132 | R@1=0.205 R@5=0.429 R@10=0.537 | median=9 p75=43
Generating submission file...
✓ Saved submission to /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
[ok] submission written → /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
{
  "dims": {
    "text": 1024,
    "image": 1536
  },
  "exp": "hardneg_ot",
  "split": {
    "train_captions": 112500,
    "val_captions": 12500,
    "val_unique_images": 2500,
    "leakage": false
  },
  "val_metri

In [9]:
def build_image_centroid_targets(img_ids_row, Y, use_mask=None):
    img_ids_row = np.asarray(img_ids_row).astype(str)
    if use_mask is None:
        use_mask = np.ones(len(img_ids_row), dtype=bool)

    keys = img_ids_row[use_mask]
    vals = Y[use_mask]
    uniq, inv = np.unique(keys, return_inverse=True)
    sums = np.zeros((len(uniq), Y.shape[1]), dtype=np.float32)
    cnts = np.zeros(len(uniq), dtype=np.int64)
    for i, k in enumerate(inv):
        sums[k] += vals[i]; cnts[k] += 1
    cent = sums / np.maximum(1, cnts)[:, None]

    name2idx = {name:i for i, name in enumerate(uniq)}
    Ycent = np.empty_like(Y)
    for i, nm in enumerate(img_ids_row):
        j = name2idx.get(nm, None)
        Ycent[i] = cent[j] if j is not None else Y[i]
    return Ycent


In [10]:
class RankFTHead(nn.Module):
    # Frozen W2W → tiny Y-space diag+bias
    def __init__(self, mu_x, mu_y, Cx_mh, Cy_ph, R, use_diag=True, use_bias=True):
        super().__init__()
        self.register_buffer("mu_x", torch.from_numpy(mu_x))
        self.register_buffer("mu_y", torch.from_numpy(mu_y))
        self.register_buffer("Cx_mh", torch.from_numpy(Cx_mh))
        self.register_buffer("Cy_ph", torch.from_numpy(Cy_ph))
        self.register_buffer("R", torch.from_numpy(R))
        D_out = mu_y.shape[0]
        self.use_diag, self.use_bias = use_diag, use_bias
        if use_diag: self.logD = nn.Parameter(torch.zeros(D_out))
        else: self.register_parameter("logD", None)
        if use_bias: self.bias = nn.Parameter(torch.zeros(D_out))
        else: self.register_parameter("bias", None)

    def forward_base(self, x):
        z = x - self.mu_x
        z = F.linear(z, self.Cx_mh)
        z = F.linear(z, self.R)
        z = F.linear(z, self.Cy_ph) + self.mu_y
        return z

    def forward(self, x):
        yb = self.forward_base(x)
        if self.use_diag: yb = yb * torch.exp(self.logD)
        if self.use_bias: yb = yb + self.bias
        return yb


In [11]:
def listwise_rank_loss_soft(sims, tau):
    s_pos = sims[:, :1]; s_neg = sims[:, 1:]
    rank = 1.0 + torch.sigmoid((s_neg - s_pos)/tau).sum(dim=1)
    return torch.log(rank).mean()

def train_rank_ft(model, loader, opt, *,
                  alpha, beta, moment_w,
                  rank_lambda, rank_tau, rank_margin, rank_loss,
                  rank_m, train_gallery_unique, device,
                  diag_l2, bias_l2):
    model.train(); total=0.0
    G = F.normalize(torch.from_numpy(train_gallery_unique).to(device), dim=-1)
    G_size = G.shape[0]

    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad(set_to_none=True)
        pred = model(xb)

        cos = 1 - F.cosine_similarity(pred, yb, dim=-1).mean()
        mse = F.mse_loss(pred, yb)
        a_loss = moment_w * (moment_align(pred, yb) if moment_w>0 else pred.new_tensor(0.0))

        with torch.no_grad():
            neg_idx = torch.randint(0, G_size, (rank_m-1,), device=device)
            Gs = G[neg_idx]
        pred_n, yb_n = F.normalize(pred, dim=-1), F.normalize(yb, dim=-1)
        s_pos = (pred_n * yb_n).sum(dim=1, keepdim=True)
        s_neg = pred_n @ Gs.t()
        sims = torch.cat([s_pos, s_neg], dim=1)
        l_rank = listwise_rank_loss_soft(sims, rank_tau) * rank_lambda

        reg = pred.new_tensor(0.0)
        if getattr(model, "logD", None) is not None and model.logD.requires_grad: reg += (model.logD**2).mean() * diag_l2
        if getattr(model, "bias", None) is not None and model.bias.requires_grad: reg += (model.bias**2).mean() * bias_l2

        loss = alpha*cos + beta*mse + a_loss + l_rank + reg
        loss.backward(); opt.step()
        total += loss.item()*xb.size(0)
    return total/len(loader.dataset)


In [108]:
# ---------- CLI ----------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    # Core
    p.add_argument("--exp", type=str, default="expA_rankft_prototypes",
               choices=["rank_ft","adapters","geom_only","text_moe","hardneg_ot","expA_rankft_prototypes"])
    p.add_argument("--out_dir", type=str, default="w2w_cluster_adapters")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--val_ratio", type=float, default=0.10)
    p.add_argument("--batch", type=int, default=1024)
    p.add_argument("--pooling", type=str, default="CLS")  # no-op
    # Geometry
    p.add_argument("--geom_eps", type=float, default=1e-5)
    # KMeans
    p.add_argument("--k_clusters", type=int, default=64, help="K in KMeans over image latents Y")
    p.add_argument("--k_iters", type=int, default=20)
    # rank-first config
    p.add_argument("--epochs", type=int, default=4)
    p.add_argument("--lr", type=float, default=8e-5)
    p.add_argument("--wd", type=float, default=1e-5)
    p.add_argument("--alpha", type=float, default=1.0)
    p.add_argument("--beta", type=float, default=0.2)
    p.add_argument("--moment", type=float, default=0.01)
    p.add_argument("--rank_lambda", type=float, default=0.18)  # 0.15–0.20; start at 0.18
    p.add_argument("--rank_tau", type=float, default=0.03)
    p.add_argument("--rank_m", type=int, default=1024)
    p.add_argument("--rank_m2", type=int, default=2048)
    p.add_argument("--rank_curriculum_ep", type=int, default=3)
    p.add_argument("--diag_l2", type=float, default=1e-3)
    p.add_argument("--bias_l2", type=float, default=1e-4)
    p.add_argument("--rank_margin", type=float, default=0.07)

    args, _ = p.parse_known_args()
    main(args)

[meta] text_dim=1024 | image_dim=1536
[model] EXP-A (Prototypes + Rank-FT) | learnable=3,072 (~0.01 MB) | α=1.0 β=0.2 moment=0.01 λ_rank=0.18 τ=0.03
[expA 01] M=1024 train_loss=0.773057 | val_MRR=0.3228 | R@1=0.212 R@5=0.439 R@10=0.549 | median=8 p75=39
[expA 02] M=1024 train_loss=0.769534 | val_MRR=0.3265 | R@1=0.215 R@5=0.443 R@10=0.555 | median=8 p75=37
[expA 03] M=2048 train_loss=0.879633 | val_MRR=0.3292 | R@1=0.217 R@5=0.446 R@10=0.560 | median=7 p75=36
[expA 04] M=2048 train_loss=0.876565 | val_MRR=0.3316 | R@1=0.219 R@5=0.450 R@10=0.564 | median=7 p75=35
Generating submission file...
✓ Saved submission to /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
[ok] submission written → /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
{
  "dims": {
    "text": 1024,
    "image": 1536
  },
  "exp": "expA_rankft_prototypes",
  "split": {
    "train_captions": 112500,
    "val_captions": 12500,
    "val_unique_images": 2500,
    "leakage": false
  },
  "val_metrics":

In [109]:
class LowRankResidualHead(nn.Module):
    """
    Frozen geometry + low-rank residual in the *whitened* spaces.

    Pipeline:
      z = Cx^{-1/2} (x - mu_x)          # text whitened (1024)
      y_w = R z + U (V^T z)              # whitened image space (1536), rank-r residual
      y   = Cy^{1/2} y_w + mu_y          # re-color, add mean  -> image space (1536)

    Trainable: U (1536 x r), V (1024 x r)
    Frozen   : mu_x, mu_y, Cx_mh, Cy_ph, R
    """
    def __init__(self, mu_x, mu_y, Cx_mh, Cy_ph, R, rank=8):
        super().__init__()
        self.register_buffer("mu_x", torch.from_numpy(mu_x))
        self.register_buffer("mu_y", torch.from_numpy(mu_y))
        self.register_buffer("Cx_mh", torch.from_numpy(Cx_mh))  # (1024,1024)
        self.register_buffer("Cy_ph", torch.from_numpy(Cy_ph))  # (1536,1536)
        self.register_buffer("R", torch.from_numpy(R))          # (1536,1024)

        self.rank = int(rank)
        # U,V init small so we start near W2W solution
        self.U = nn.Parameter(torch.zeros(1536, self.rank))
        self.V = nn.Parameter(torch.zeros(1024, self.rank))
        nn.init.xavier_uniform_(self.U, gain=0.01)
        nn.init.xavier_uniform_(self.V, gain=0.01)

    def forward(self, x):
        # whiten text
        z = x - self.mu_x
        z = F.linear(z, self.Cx_mh)        # (B,1024)

        # base rotation + low-rank residual in whitened image space
        base_w = F.linear(z, self.R)       # (B,1536)
        # V^T z = (B, r); U (V^T z) = (B,1536)
        low = (z @ self.V) @ self.U.T
        y_w = base_w + low

        # re-color + mean
        y = F.linear(y_w, self.Cy_ph) + self.mu_y
        return y


In [112]:
# ---------- CLI ----------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    # Core
    p.add_argument("--exp", type=str, default="lowrank_ft",
                   choices=["rank_ft","adapters","geom_only","text_moe","hardneg_ot","expA_rankft_prototypes","lowrank_ft"])
    p.add_argument("--out_dir", type=str, default="w2w_cluster_adapters")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--val_ratio", type=float, default=0.10)
    p.add_argument("--batch", type=int, default=1024)
    p.add_argument("--pooling", type=str, default="CLS")  # no-op
    # Geometry
    p.add_argument("--geom_eps", type=float, default=1e-5)
    # KMeans
    p.add_argument("--k_clusters", type=int, default=64, help="K in KMeans over image latents Y")
    p.add_argument("--k_iters", type=int, default=20)
    # rank-first config
    p.add_argument("--epochs", type=int, default=4)
    p.add_argument("--lr", type=float, default=8e-5)
    p.add_argument("--wd", type=float, default=1e-5)
    p.add_argument("--alpha", type=float, default=1.0)
    p.add_argument("--beta", type=float, default=0.2)
    p.add_argument("--moment", type=float, default=0.01)
    p.add_argument("--rank_lambda", type=float, default=0.18)  # 0.15–0.20; start at 0.18
    p.add_argument("--rank_tau", type=float, default=0.03)
    p.add_argument("--rank_m", type=int, default=1024)
    p.add_argument("--rank_m2", type=int, default=2048)
    p.add_argument("--rank_curriculum_ep", type=int, default=3)
    p.add_argument("--diag_l2", type=float, default=1e-3)
    p.add_argument("--bias_l2", type=float, default=1e-4)
    p.add_argument("--rank_margin", type=float, default=0.07)
    p.add_argument("--lowrank_r", type=int, default=8)

    args, _ = p.parse_known_args()
    main(args)

[meta] text_dim=1024 | image_dim=1536
[model] low-rank residual (r=8) | learnable=20,480 (~0.08 MB) | α=1.0 β=0.2 moment=0.01 λ_rank=0.18 τ=0.03
[lowrank 01] M=1024 train_loss=0.773705 | val_MRR=0.3191 | R@1=0.208 R@5=0.437 R@10=0.547 | median=8 p75=39
[lowrank 02] M=1024 train_loss=0.767150 | val_MRR=0.3206 | R@1=0.208 R@5=0.439 R@10=0.552 | median=8 p75=39
[lowrank 03] M=2048 train_loss=0.875079 | val_MRR=0.3210 | R@1=0.208 R@5=0.441 R@10=0.553 | median=8 p75=37
[lowrank 04] M=2048 train_loss=0.870636 | val_MRR=0.3233 | R@1=0.211 R@5=0.444 R@10=0.557 | median=8 p75=36
Generating submission file...
✓ Saved submission to /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
[ok] submission written → /kaggle/working/outputs/w2w_cluster_adapters/submission.csv
{
  "dims": {
    "text": 1024,
    "image": 1536
  },
  "exp": "lowrank_ft",
  "rank": 8,
  "split": {
    "train_captions": 112500,
    "val_captions": 12500,
    "val_unique_images": 2500,
    "leakage": false
  },
  "val_

In [17]:
class CodebookTranslator(nn.Module):
    """
    Frozen W2W backbone + learned codebook C \in R^{K x 1536}.
    Predictor maps W2W features -> soft weights over K (softmax),
    output y_hat = sum_k w_k * C_k (lives in image space).
    """
    def __init__(self, mu_x, mu_y, Cx_mh, Cy_ph, R, codebook_init: np.ndarray,
                 predictor="mlp", hidden=512, dropout=0.1):
        super().__init__()
        # Frozen geometry
        self.register_buffer("mu_x", torch.from_numpy(mu_x))
        self.register_buffer("mu_y", torch.from_numpy(mu_y))
        self.register_buffer("Cx_mh", torch.from_numpy(Cx_mh))
        self.register_buffer("Cy_ph", torch.from_numpy(Cy_ph))
        self.register_buffer("R", torch.from_numpy(R))

        K, D = codebook_init.shape
        assert D == self.mu_y.shape[0] == 1536
        self.K = K; self.D = D

        # Codebook is learnable
        self.codebook = nn.Parameter(torch.from_numpy(codebook_init.astype(np.float32)))  # (K,1536)

        # Predictor over W2W features (image space)
        if predictor == "linear":
            self.pred = nn.Linear(D, K)
        else:
            self.pred = nn.Sequential(
                nn.Linear(D, hidden), nn.GELU(), nn.Dropout(dropout),
                nn.Linear(hidden, K)
            )

    def forward_base(self, x):
        z = x - self.mu_x
        z = F.linear(z, self.Cx_mh)
        z = F.linear(z, self.R)
        z = F.linear(z, self.Cy_ph) + self.mu_y
        return z  # (B,1536) in image space

    def forward_logits(self, x):
        yb = self.forward_base(x)          # (B,1536)
        logits = self.pred(yb)             # (B,K)
        return yb, logits

    def forward(self, x, return_weights=False):
        yb, logits = self.forward_logits(x)
        w = torch.softmax(logits, dim=-1)  # (B,K)
        y_hat = w @ self.codebook          # (B,1536)
        return (y_hat, w) if return_weights else y_hat


In [18]:
from torch.cuda.amp import autocast, GradScaler

def train_codebook_rank(model, loader, opt, *,
                        alpha, beta, moment_w,
                        rank_lambda, rank_tau, rank_m,
                        ent_lambda,
                        train_gallery_unique, device,
                        amp=False):
    model.train(); total=0.0
    scaler = GradScaler(enabled=(amp and device.type=="cuda"))

    G = F.normalize(torch.from_numpy(train_gallery_unique).to(device), dim=-1)
    G_size = G.shape[0]

    for xb, y_proto in loader:
        xb = xb.to(device); y_proto = y_proto.to(device)
        opt.zero_grad(set_to_none=True)

        with autocast(device_type="cuda", enabled=(amp and device.type=="cuda")):
            pred, w = model(xb, return_weights=True)            # (B,1536), (B,K)
            # base + prototype loss
            cos = 1 - F.cosine_similarity(pred, y_proto, dim=-1).mean()
            mse = F.mse_loss(pred, y_proto)
            a_loss = moment_w * moment_align(pred, y_proto) if moment_w>0 else pred.new_tensor(0.0)

            # listwise (sampled negatives)
            with torch.no_grad():
                neg_idx = torch.randint(0, G_size, (rank_m-1,), device=device)
                Gs = G[neg_idx]
            pred_n  = F.normalize(pred,     dim=-1)
            yprot_n = F.normalize(y_proto,  dim=-1)
            s_pos = (pred_n * yprot_n).sum(dim=1, keepdim=True)
            s_neg = pred_n @ Gs.t()
            sims = torch.cat([s_pos, s_neg], dim=1)
            l_rank = listwise_rank_loss_soft(sims, rank_tau) * rank_lambda

            # entropy reg on weights
            H = -(w * (w.clamp_min(1e-12).log())).sum(dim=1).mean()
            ent = - ent_lambda * H

            loss = alpha*cos + beta*mse + a_loss + l_rank + ent

        if scaler.is_enabled():
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
        else:
            loss.backward(); opt.step()

        total += loss.item()*xb.size(0)
    return total/len(loader.dataset)


In [None]:
# ---------- CLI ----------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    # Core
    p.add_argument("--exp", type=str, default="codebook_ft",
                   choices=["rank_ft","adapters","geom_only","text_moe","hardneg_ot","expA_rankft_prototypes","lowrank_ft","codebook_ft"])
    p.add_argument("--out_dir", type=str, default="w2w_cluster_adapters")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--val_ratio", type=float, default=0.10)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--pooling", type=str, default="CLS")  # no-op
    # Geometry
    p.add_argument("--geom_eps", type=float, default=1e-5)
    # KMeans
    p.add_argument("--k_clusters", type=int, default=64, help="K in KMeans over image latents Y")
    p.add_argument("--k_iters", type=int, default=20)
    # rank-first config
    p.add_argument("--epochs", type=int, default=4)
    p.add_argument("--lr", type=float, default=5e-5)
    p.add_argument("--wd", type=float, default=1e-5)
    p.add_argument("--alpha", type=float, default=1.0)
    p.add_argument("--beta", type=float, default=0.2)
    p.add_argument("--moment", type=float, default=0.01)
    p.add_argument("--rank_lambda", type=float, default=0.18)  # 0.15–0.20; start at 0.18
    p.add_argument("--rank_tau", type=float, default=0.03)
    p.add_argument("--rank_m", type=int, default=1024)
    p.add_argument("--rank_m2", type=int, default=2048)
    p.add_argument("--rank_curriculum_ep", type=int, default=3)
    p.add_argument("--diag_l2", type=float, default=1e-3)
    p.add_argument("--bias_l2", type=float, default=1e-4)
    p.add_argument("--rank_margin", type=float, default=0.07)
    p.add_argument("--lowrank_r", type=int, default=8)
    p.add_argument("--cb_k", type=int, default=128, help="Codebook size (K)")
    p.add_argument("--cb_k_iters", type=int, default=15, help="KMeans iters for codebook init")
    p.add_argument("--cb_pred", type=str, default="mlp", choices=["linear","mlp"])
    p.add_argument("--cb_hidden", type=int, default=384)
    p.add_argument("--cb_drop", type=float, default=0.1)
    p.add_argument("--cb_ent", type=float, default=1e-3, help="Entropy reg on weights")

    args, _ = p.parse_known_args()
    main(args)

[meta] text_dim=1024 | image_dim=1536
[codebook] KMeans init K=128 iters=15


Notebook died even with this cheap paraameters