In [20]:
from kaggle_secrets import UserSecretsClient
import json, os, pathlib, subprocess, sys

# --- 1. Load your secret (full JSON from kaggle.json) ---
secret_name = "kaggle_json"  # change if you used another name
user_secrets = UserSecretsClient()
raw = user_secrets.get_secret(secret_name)
creds = json.loads(raw)

# --- 2. Forcefully recreate ~/.kaggle/kaggle.json ---
kaggle_dir = pathlib.Path.home() / ".kaggle"
kaggle_dir.mkdir(parents=True, exist_ok=True)
cred_path = kaggle_dir / "kaggle.json"
cred_path.write_text(json.dumps(creds))
os.chmod(cred_path, 0o600)

# --- 3. Double-check the file actually exists and is readable ---
print("✅ Credentials written to:", cred_path)
!ls -la ~/.kaggle/
#!cat ~/.kaggle/kaggle.json | head -1

# --- 4. Reinstall Kaggle CLI cleanly ---
#!pip install --upgrade --force-reinstall kaggle --quiet



✅ Credentials written to: /root/.kaggle/kaggle.json
total 16
drwxr-xr-x 2 root root 4096 Oct 26 19:54 .
drwx------ 1 root root 4096 Oct 26 20:01 ..
-rw------- 1 root root   72 Oct 26 20:42 kaggle.json


In [None]:
!kaggle competitions files -c aml-competition
!kaggle competitions download -c aml-competition -p /kaggle/working --force
!mkdir -p /kaggle/working/data
!unzip -o /kaggle/working/aml-competition.zip -d /kaggle/working/data
!ls -lah /kaggle/working/data

In [3]:
!git clone https://github.com/Mamiglia/challenge.git

Cloning into 'challenge'...
remote: Enumerating objects: 98, done.[K
remote: Counting objects: 100% (98/98), done.[K
remote: Compressing objects: 100% (69/69), done.[K
remote: Total 98 (delta 39), reused 72 (delta 26), pack-reused 0 (from 0)[K
Receiving objects: 100% (98/98), 21.03 MiB | 21.15 MiB/s, done.
Resolving deltas: 100% (39/39), done.


In [21]:
# runner.py
import json, random, re, time
from pathlib import Path
from os.path import basename
from collections import defaultdict

import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from challenge.src.common import load_data, generate_submission

# ---------- Paths ----------
DATA_ROOT = Path("/kaggle/working/data")
TRAIN_DIR = DATA_ROOT / "train" / "train"
TEST_DIR  = DATA_ROOT / "test"  / "test"
TRAIN_NPZ = TRAIN_DIR / "train.npz"
TEST_NPZ  = TEST_DIR / "test.clean.npz"
TRAIN_CAPTIONS = TRAIN_DIR / "captions.txt"
TEST_CAPTIONS  = TEST_DIR  / "captions.txt"

# ---------- Utils ----------
def seed_all(s: int):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

def _build_image_index(image_ids):
    idx_exact = {}; idx_base={}; idx_stem={}
    for i, v in enumerate(image_ids):
        s = str(v); idx_exact[s]=i
        b = basename(s); idx_base[b]=i
        stem = b.rsplit(".",1)[0] if "." in b else b
        idx_stem[stem]=i
    return idx_exact, idx_base, idx_stem

def _iter_targets_from_captions(path: Path, n_text: int, idx_exact, idx_base, idx_stem):
    targets=[]
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            if len(targets)>=n_text: break
            line=raw.strip()
            if not line: continue
            lower=line.lower()
            if ("image" in lower or "filename" in lower) and ("," in line or "|" in line or "\t" in line):
                first=re.split(r'\||\t|,| {2,}', line)[0].strip().strip('"').strip("'")
                if "." not in first: continue
            parts=re.split(r'\||\t|,| {2,}', line)
            if len(parts)==1: parts=line.split(" ",1)
            tok=parts[0].strip().strip('"').strip("'")
            if not tok: continue
            def _match(tok_):
                if tok_ in idx_exact: return idx_exact[tok_]
                b=basename(tok_)
                if b in idx_base: return idx_base[b]
                stem=b.rsplit(".",1)[0] if "." in b else b
                if stem in idx_stem: return idx_stem[stem]
                return None
            idx=_match(tok)
            if idx is None:
                tok2=tok.replace("Images/","").replace("./","")
                idx=_match(tok2)
            if idx is None:
                if "." not in tok: continue
                raise AssertionError(f"Could not match image token '{tok}'")
            targets.append(idx)
    assert len(targets)==n_text, f"matched {len(targets)} vs N_text {n_text}"
    return np.asarray(targets, dtype=np.int64)

def load_train(out_dir: Path):
    d=np.load(TRAIN_NPZ, allow_pickle=True)
    X=d["captions/embeddings"].astype(np.float32)  # (N_text, 1024 or P*D)
    I=d["images/embeddings"].astype(np.float32)    # (N_img, 1536)
    cap_ids=d.get("captions/ids", np.arange(len(X)).astype(str))
    img_names=d.get("images/names", np.arange(len(I)).astype(str))

    ex,ba,st=_build_image_index(img_names)
    assert TRAIN_CAPTIONS.exists(), f"Missing {TRAIN_CAPTIONS}"
    targets=_iter_targets_from_captions(TRAIN_CAPTIONS, len(X), ex,ba,st)

    Y=I[targets]                       # (N_text, 1536) ground-truth
    img_ids_row = img_names[targets]   # image id per row
    meta={"n_text":int(len(X)),"n_images":int(len(I))}
    (out_dir/"train_detect.json").write_text(json.dumps(meta, indent=2))
    return X,Y,cap_ids,img_ids_row,I,img_names

def load_test():
    d_test=np.load(TEST_NPZ, allow_pickle=True)
    Q=d_test["captions/embeddings"].astype(np.float32)
    q_ids=d_test.get("captions/ids", np.arange(len(Q)).astype(str))
    # gallery: prefer test's, else train's
    if "images/embeddings" in d_test.files:
        G=d_test["images/embeddings"].astype(np.float32)
        g_ids=d_test.get("images/names", np.arange(len(G)).astype(str))
    else:
        d_tr=np.load(TRAIN_NPZ, allow_pickle=True)
        G=d_tr["images/embeddings"].astype(np.float32)
        g_ids=d_tr.get("images/names", np.arange(len(G)).astype(str))
    return Q,G,q_ids,g_ids

def infer_arch_from_state_dict(sd):
    ks = set(sd.keys())
    if {"fc.weight","fc.bias"} <= ks: return "linear"
    if {"fc1.weight","fc1.bias","fc2.weight","fc2.bias"} <= ks: return "mlp1"
    if any(k.startswith("net.") for k in ks): return "mlp2"
    raise ValueError("Unrecognized checkpoint layout.")

# ---------- Dataset ----------
class PairDS(Dataset):
    def __init__(self, X, Y): self.X=torch.from_numpy(X); self.Y=torch.from_numpy(Y)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i): return self.X[i], self.Y[i]

# ---------- Pooling ----------
def apply_pooling(x: torch.Tensor, mode: str, n_patches: int|None):
    if mode.lower()=="cls":            # identity
        return x
    if mode.lower().startswith("mean"):
        if x.ndim==3: return x.mean(1)
        if n_patches and x.shape[1]%n_patches==0:
            P=n_patches; D=x.shape[1]//P
            return x.view(x.shape[0], P, D).mean(1)
        return x                       # fallback
    return x

# ---------- Models ----------
class LinearProj(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.fc=nn.Linear(din,dout)
        nn.init.xavier_normal_(self.fc.weight); nn.init.zeros_(self.fc.bias)
    def forward(self,x): return self.fc(x)

class MLP1(nn.Module):
    def __init__(self, din, dout, hidden=512, pdrop=0.0):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(din,hidden), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(hidden,dout)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.zeros_(m.bias)
    def forward(self,x): return self.net(x)

class MLP2(nn.Module):
    def __init__(self, din, dout, h1=1024, h2=512, pdrop=0.0):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(din,h1), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(h1,h2), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(h2,dout)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.zeros_(m.bias)
    def forward(self,x): return self.net(x)

def make_model(arch:str, din:int, dout:int):
    arch=arch.lower()
    if arch=="linear": return LinearProj(din,dout)
    if arch=="mlp1":   return MLP1(din,dout,hidden=512)
    if arch=="mlp2":   return MLP2(din,dout,h1=1024,h2=512)
    raise ValueError(f"Unknown arch {arch}")

# ---------- Loss pieces ----------
def loss_align(pred, tgt, kind:str):
    if kind=="none": return pred.new_tensor(0.0)
    if kind=="moment":
        # match batch mean and std (channel-wise)
        mu_p, mu_t = pred.mean(0), tgt.mean(0)
        sd_p, sd_t = pred.std(0, unbiased=False), tgt.std(0, unbiased=False)
        return F.mse_loss(mu_p, mu_t) + F.mse_loss(sd_p, sd_t)
    if kind=="normcal":
        # calibrate L2 norm distribution
        np_ = pred.norm(dim=-1).mean()
        nt_ = tgt.norm(dim=-1).mean()
        return F.mse_loss(np_, nt_)
    raise ValueError(kind)

def info_nce(pred, tgt):
    # in-batch InfoNCE with cosine sim
    p = F.normalize(pred, dim=-1)
    t = F.normalize(tgt, dim=-1)
    logits = p @ t.t()                      # (B,B)
    labels = torch.arange(pred.size(0), device=pred.device)
    return F.cross_entropy(logits, labels)

# ---------- Metrics ----------
@torch.no_grad()
def validate_retrieval(model, Xv, Yv, pooling, n_patches, bs=1024):
    device=next(model.parameters()).device
    # gallery = unique images in val
    Yv_np = Yv.copy()
    # build unique gallery
    _, uniq_idx = np.unique(Yv_np, axis=0, return_index=True)  # cheap, works since Yv rows are copies from I
    Gi = torch.from_numpy(Yv_np[sorted(uniq_idx)]).to(device)
    Gi = F.normalize(Gi, dim=-1)

    ranks=[]
    for i in range(0, len(Xv), bs):
        xb=torch.from_numpy(Xv[i:i+bs]).to(device)
        xb=apply_pooling(xb, pooling, n_patches)
        pred=model(xb)
        pred=F.normalize(pred, dim=-1)
        sims=pred @ Gi.t()                 # (b, M)
        # compute rank of the matching Y row within Gi
        for j in range(sims.size(0)):
            # find index in Gi that matches Yv row
            y = torch.from_numpy(Yv_np[i+j]).to(device)
            y = F.normalize(y, dim=-1)
            # true index = argmax sim with Gi
            true_idx = torch.argmax(Gi @ y, dim=0).item()
            order = torch.argsort(sims[j], descending=True)
            rank = (order==true_idx).nonzero(as_tuple=False).item() + 1
            ranks.append(rank)

    ranks = np.array(ranks)
    mrr = np.mean(1.0 / ranks)
    r1  = np.mean(ranks<=1)
    r5  = np.mean(ranks<=5)
    r10 = np.mean(ranks<=10)
    return dict(MRR=float(mrr), R1=float(r1), R5=float(r5), R10=float(r10),
                median=int(np.median(ranks)), p75=int(np.percentile(ranks,75)))

def count_params_mb(model):
    params=sum(p.numel() for p in model.parameters())
    mb = params * 4 / (1024**2)
    return params, mb

def time_ms_per_query(model, din, pooling, n_patches):
    device=next(model.parameters()).device
    x=torch.randn(2048, din, device=device)
    x=apply_pooling(x, pooling, n_patches)
    torch.cuda.synchronize(device) if device.type=="cuda" else None
    t0=time.time()
    with torch.no_grad(): _=model(x)
    torch.cuda.synchronize(device) if device.type=="cuda" else None
    ms = (time.time()-t0)*1000/len(x)
    # cpu timing
    mcpu=model.to("cpu")
    xcpu=x.to("cpu")
    t1=time.time()
    with torch.no_grad(): _=mcpu(xcpu)
    ms_cpu=(time.time()-t1)*1000/len(xcpu)
    model.to(device)
    return ms, ms_cpu

# ---------- Training ----------
def train_one(model, loader, opt, alpha, beta, gamma, align_kind, pooling, n_patches, device):
    model.train(); total=0.0
    for xb,yb in loader:
        xb,yb=xb.to(device), yb.to(device)
        xb=apply_pooling(xb, pooling, n_patches)
        opt.zero_grad(set_to_none=True)
        pred=model(xb)
        cos = 1 - F.cosine_similarity(pred, yb, dim=-1).mean()
        mse = F.mse_loss(pred, yb)
        a_loss = loss_align(pred, yb, align_kind)
        ce = info_nce(pred, yb) if gamma>0 else pred.new_tensor(0.0)
        loss = alpha*cos + beta*mse + gamma*ce + a_loss
        loss.backward(); opt.step()
        total += loss.item()*xb.size(0)
    return total/len(loader.dataset)

# ---------- Main ----------
def main(
    out_dir="baseline_ref", seed=42, epochs=20, batch=512, lr=1e-4, wd=1e-4,
    pooling="CLS", n_patches=None, alpha=1.0, beta=1.0, gamma=0.0,
    align_loss="none", arch="auto", train=True, val_ratio=0.1
):
    seed_all(seed)
    OUT = Path(f"/kaggle/working/outputs/{out_dir}"); OUT.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- data
    X,Y,cap_ids,img_ids_row,full_img,all_img_ids = load_train(OUT)
    uniq = np.unique(img_ids_row); rng = np.random.default_rng(seed); rng.shuffle(uniq)
    n_val = max(1, int(len(uniq)*val_ratio)); valset = set(uniq[:n_val])
    m = np.array([iid in valset for iid in img_ids_row])
    Xtr,Ytr,Xva,Yva = X[~m],Y[~m],X[m],Y[m]

    din, dout = X.shape[1], Y.shape[1]

    # --- build/load model
    if train:
        eff_arch = arch if arch != "auto" else "linear"  # default for fresh training
        model = make_model(eff_arch, din, dout).to(device)
    else:
        ckpt = torch.load(OUT/"best.pt", map_location="cpu")
        eff_arch = arch if arch != "auto" else infer_arch_from_state_dict(ckpt["model"])
        print(f"[ckpt] detected arch = {eff_arch}")
        model = make_model(eff_arch, din, dout).to(device)
        model.load_state_dict(ckpt["model"])

    params, mb = count_params_mb(model)
    print(f"[model] {eff_arch} | params={params:,} (~{mb:.2f} MB) | pooling={pooling} | align={align_loss} | α={alpha} β={beta} γ={gamma}")

    # --- train (if requested)
    if train:
        dl = DataLoader(PairDS(Xtr,Ytr), batch_size=batch, shuffle=True, num_workers=2, pin_memory=True)
        opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
        best = -1.0; best_ep = 0; best_stats = None
        for ep in range(1, epochs+1):
            tr = train_one(model, dl, opt, alpha,beta,gamma,align_loss, pooling, n_patches, device)
            stats = validate_retrieval(model, Xva, Yva, pooling, n_patches)
            print(f"[{ep:02d}] train_loss={tr:.6f} | val_MRR={stats['MRR']:.4f} | R@1={stats['R1']:.3f} R@5={stats['R5']:.3f}")
            if stats["MRR"] > best:
                best, best_ep, best_stats = stats["MRR"], ep, stats
                torch.save({"model":model.state_dict(),"epoch":ep,"val":stats}, OUT/"best.pt")
        print(f"[best] MRR={best:.4f} @ epoch {best_ep}")
        (OUT/"val_metrics.json").write_text(json.dumps(dict(best_epoch=best_ep, **best_stats), indent=2))
    else:
        # already loaded weights above
        ckpt = torch.load(OUT/"best.pt", map_location="cpu")
        print(f"[resume] loaded epoch={ckpt.get('epoch','?')} MRR={ckpt.get('val',{}).get('MRR','?')}")

    # --- efficiency
    ms_gpu, ms_cpu = time_ms_per_query(model, din, pooling, n_patches)
    (OUT/"efficiency.json").write_text(json.dumps(
        {"params":params,"mb_fp32":mb,"ms_per_query_gpu":ms_gpu,"ms_per_query_cpu":ms_cpu}, indent=2
    ))

    # --- submission (official helper)
    TEST_NPZ = (TEST_DIR / "test.clean.npz")
    test_data = load_data(TEST_NPZ)
    Q   = test_data["captions/embeddings"].astype(np.float32)
    ids = test_data.get("captions/ids", np.arange(len(Q)).astype(str))

    model.eval()
    BS = 1024
    outs = []
    with torch.no_grad():
        for i in range(0, len(Q), BS):
            q = torch.from_numpy(Q[i:i+BS]).to(device)
            q = apply_pooling(q, pooling, n_patches)
            z = model(q)
            z = F.normalize(z, dim=-1)
            outs.append(z.detach().cpu().numpy())
    pred_embds = np.concatenate(outs, axis=0)

    sub = OUT / "submission.csv"
    generate_submission(ids, pred_embds, str(sub))
    print(f"[ok] submission written → {sub}")

In [22]:
if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="baseline_ref")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1.0)
    p.add_argument("--beta", type=float, default=1.0)
    p.add_argument("--gamma", type=float, default=0.0)
    p.add_argument("--align_loss", type=str, default="none", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="auto", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))


[model] linear | params=1,574,400 (~6.01 MB) | pooling=CLS | align=none | α=1.0 β=1.0 γ=0.0
[01] train_loss=0.954751 | val_MRR=0.1206 | R@1=0.057 R@5=0.170
[02] train_loss=0.507948 | val_MRR=0.1398 | R@1=0.070 R@5=0.194
[03] train_loss=0.432700 | val_MRR=0.1506 | R@1=0.077 R@5=0.212
[04] train_loss=0.401670 | val_MRR=0.1607 | R@1=0.083 R@5=0.226
[05] train_loss=0.383644 | val_MRR=0.1665 | R@1=0.087 R@5=0.234
[06] train_loss=0.371423 | val_MRR=0.1701 | R@1=0.088 R@5=0.242
[07] train_loss=0.362467 | val_MRR=0.1754 | R@1=0.094 R@5=0.248
[08] train_loss=0.355643 | val_MRR=0.1786 | R@1=0.095 R@5=0.252
[09] train_loss=0.350284 | val_MRR=0.1810 | R@1=0.096 R@5=0.258
[10] train_loss=0.346007 | val_MRR=0.1829 | R@1=0.097 R@5=0.261
[11] train_loss=0.342501 | val_MRR=0.1846 | R@1=0.099 R@5=0.262
[12] train_loss=0.339609 | val_MRR=0.1882 | R@1=0.102 R@5=0.267
[13] train_loss=0.337181 | val_MRR=0.1888 | R@1=0.101 R@5=0.269
[14] train_loss=0.335115 | val_MRR=0.1915 | R@1=0.104 R@5=0.271
[15] train_l

In [17]:
# !python runner.py --out_dir mlp1_cls_cos_mse --train 1 --arch mlp1 --pooling CLS --alpha 1 --beta 1 --gamma 0 --align_loss none

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_cls_cos_mse")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1.0)
    p.add_argument("--beta", type=float, default=1.0)
    p.add_argument("--gamma", type=float, default=0.0)
    p.add_argument("--align_loss", type=str, default="none", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=none | α=1.0 β=1.0 γ=0.0
[01] train_loss=0.660932 | val_MRR=0.0333 | R@1=0.009 R@5=0.043
[02] train_loss=0.412790 | val_MRR=0.0618 | R@1=0.024 R@5=0.082
[03] train_loss=0.380159 | val_MRR=0.0828 | R@1=0.035 R@5=0.111
[04] train_loss=0.363856 | val_MRR=0.0970 | R@1=0.042 R@5=0.132
[05] train_loss=0.353517 | val_MRR=0.1121 | R@1=0.052 R@5=0.156
[06] train_loss=0.346098 | val_MRR=0.1212 | R@1=0.055 R@5=0.173
[07] train_loss=0.340426 | val_MRR=0.1313 | R@1=0.061 R@5=0.186
[08] train_loss=0.335868 | val_MRR=0.1379 | R@1=0.064 R@5=0.199
[09] train_loss=0.332036 | val_MRR=0.1467 | R@1=0.069 R@5=0.213
[10] train_loss=0.328716 | val_MRR=0.1541 | R@1=0.074 R@5=0.223
[11] train_loss=0.325812 | val_MRR=0.1613 | R@1=0.080 R@5=0.232
[12] train_loss=0.323231 | val_MRR=0.1674 | R@1=0.085 R@5=0.240
[13] train_loss=0.320927 | val_MRR=0.1739 | R@1=0.089 R@5=0.251
[14] train_loss=0.318824 | val_MRR=0.1789 | R@1=0.093 R@5=0.256
[15] train_los

In [18]:
# python runner.py --out_dir mlp1_patch_cos_mse --pooling mean-patch --n_patches None --alpha 1 --beta 1 --gamma 0 --align_loss none --arch mlp1

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_patch_cos_mse")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="mean-patch", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1.0)
    p.add_argument("--beta", type=float, default=1.0)
    p.add_argument("--gamma", type=float, default=0.0)
    p.add_argument("--align_loss", type=str, default="none", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=mean-patch | align=none | α=1.0 β=1.0 γ=0.0
[01] train_loss=0.660932 | val_MRR=0.0333 | R@1=0.009 R@5=0.043
[02] train_loss=0.412790 | val_MRR=0.0618 | R@1=0.024 R@5=0.082
[03] train_loss=0.380159 | val_MRR=0.0828 | R@1=0.035 R@5=0.111
[04] train_loss=0.363856 | val_MRR=0.0970 | R@1=0.042 R@5=0.132
[05] train_loss=0.353517 | val_MRR=0.1121 | R@1=0.052 R@5=0.156
[06] train_loss=0.346098 | val_MRR=0.1212 | R@1=0.055 R@5=0.173
[07] train_loss=0.340426 | val_MRR=0.1313 | R@1=0.061 R@5=0.186
[08] train_loss=0.335868 | val_MRR=0.1379 | R@1=0.064 R@5=0.199
[09] train_loss=0.332036 | val_MRR=0.1467 | R@1=0.069 R@5=0.213
[10] train_loss=0.328716 | val_MRR=0.1541 | R@1=0.074 R@5=0.223
[11] train_loss=0.325812 | val_MRR=0.1613 | R@1=0.080 R@5=0.232
[12] train_loss=0.323231 | val_MRR=0.1674 | R@1=0.085 R@5=0.240
[13] train_loss=0.320927 | val_MRR=0.1739 | R@1=0.089 R@5=0.251
[14] train_loss=0.318824 | val_MRR=0.1789 | R@1=0.093 R@5=0.256
[15] tr

In [58]:
#!/usr/bin/env python3
import argparse, json, re, time
from pathlib import Path
from os.path import basename

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# --------------------------
# I/O helpers
# --------------------------
def load_npz(path: Path):
    d = np.load(path, allow_pickle=True)
    return d

def assert_dims(d, expect_text=1024, expect_img=1536):
    md = {k: d[k][0] for k in d.files if k.startswith("metadata/")}
    tdim = md.get("metadata/embedding_dim_text", None)
    idim = md.get("metadata/embedding_dim_image", None)
    print(f"[meta] text_dim={tdim} | image_dim={idim}")
    ok = (tdim == expect_text) and (idim == expect_img)
    if not ok:
        print("[WARN] Metadata dims differ from expected fixed encoders "
              f"(expected text={expect_text}, image={expect_img}). "
              "Absolute metrics may shift; re-generate .npz with the official encoders.")
    return ok

def _build_image_index(image_ids):
    idx_exact = {}; idx_base={}; idx_stem={}
    for i, v in enumerate(image_ids):
        s = str(v); idx_exact[s]=i
        b = basename(s); idx_base[b]=i
        stem = b.rsplit(".",1)[0] if "." in b else b
        idx_stem[stem]=i
    return idx_exact, idx_base, idx_stem

def _iter_targets_from_captions(path: Path, n_text: int, idx_exact, idx_base, idx_stem):
    targets=[]
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            if len(targets)>=n_text: break
            line=raw.strip()
            if not line: continue
            parts=re.split(r'\||\t|,| {2,}', line)
            if len(parts)==1: parts=line.split(" ",1)
            tok=parts[0].strip().strip('"').strip("'")
            if not tok: continue
            def _match(tok_):
                if tok_ in idx_exact: return idx_exact[tok_]
                b=basename(tok_)
                if b in idx_base: return idx_base[b]
                stem=b.rsplit(".",1)[0] if "." in b else b
                if stem in idx_stem: return idx_stem[stem]
                return None
            idx=_match(tok)
            if idx is None:
                tok2=tok.replace("Images/","").replace("./","")
                idx=_match(tok2)
            if idx is None:
                if "." not in tok: continue
                raise AssertionError(f"Could not match image token '{tok}'")
            targets.append(idx)
    assert len(targets)==n_text, f"matched {len(targets)} vs N_text {n_text}"
    return np.asarray(targets, dtype=np.int64)

# --------------------------
# Val split by image id
# --------------------------
def make_image_id_split(train_npz: Path, captions_txt: Path, seed=42, val_ratio=0.1):
    d = load_npz(train_npz)
    X = d["captions/embeddings"].astype(np.float32)          # (N_text, 1024)
    I = d["images/embeddings"].astype(np.float32)            # (N_img, 1536)
    img_names = d["images/names"]
    ex,ba,st = _build_image_index(img_names)
    targets = _iter_targets_from_captions(captions_txt, len(X), ex,ba,st)  # image index per caption
    img_id_per_row = img_names[targets]

    uniq = np.unique(img_id_per_row)
    rng = np.random.default_rng(seed); rng.shuffle(uniq)
    n_val = max(1, int(len(uniq)*val_ratio))
    val_set = set(uniq[:n_val])

    mask_val = np.array([iid in val_set for iid in img_id_per_row])
    mask_trn = ~mask_val

    # (sanity) no leakage:
    trn_imgs = set(np.unique(img_id_per_row[mask_trn]).tolist())
    val_imgs = set(np.unique(img_id_per_row[mask_val]).tolist())
    leakage = len(trn_imgs.intersection(val_imgs))!=0

    print(f"[split] captions: train={mask_trn.sum()} | val={mask_val.sum()} | "
          f"unique images: train={len(trn_imgs)} | val={len(val_imgs)} | leakage={leakage}")

    # build gallery for val = unique val images (rows from I)
    val_img_indices = np.array(sorted({ex.get(str(x), None) or ba.get(basename(str(x)), None)
                                       for x in val_imgs if x is not None}))
    G_val = I[val_img_indices]
    # map each val caption to its true gallery index
    # (true image index within G_val)
    map_global_to_local = {gi:i for i,gi in enumerate(val_img_indices)}
    y_val_local = np.array([map_global_to_local[t] for t in targets[mask_val]], dtype=np.int64)

    return (X[mask_trn], targets[mask_trn]), (X[mask_val], y_val_local, G_val), (mask_trn, mask_val)

# --------------------------
# Retrieval metrics
# --------------------------
def compute_mrr(topk_idx: np.ndarray, gt: np.ndarray) -> float:
    rr=[]
    for i in range(len(gt)):
        pos = np.where(topk_idx[i]==gt[i])[0]
        rr.append(1.0/(pos[0]+1) if len(pos)>0 else 0.0)
    return float(np.mean(rr))

def recall_at_k(topk_idx: np.ndarray, gt: np.ndarray, k: int) -> float:
    return float(np.mean([gt[i] in topk_idx[i,:k] for i in range(len(gt))]))

@torch.no_grad()
def eval_full_gallery(pred: torch.Tensor, G: torch.Tensor, y_true_local: np.ndarray, k: int = 100):
    # Ensure both tensors live on the same device
    if G.device != pred.device:
        G = G.to(pred.device)

    # cosine on L2-normalized vectors (validation protocol)
    pred_n = F.normalize(pred, dim=-1)
    G_n    = F.normalize(G, dim=-1)
    sims   = pred_n @ G_n.t()                         # (Nq, Ng)
    topk   = min(k, G_n.size(0))
    idx    = sims.topk(k=topk, dim=1, largest=True, sorted=True).indices.cpu().numpy()

    # metrics
    def _mrr(topk_idx, gt):
        rr=[]
        for i in range(len(gt)):
            pos = np.where(topk_idx[i]==gt[i])[0]
            rr.append(1.0/(pos[0]+1) if len(pos)>0 else 0.0)
        return float(np.mean(rr))

    def _recall_k(topk_idx, gt, K):
        K = min(K, topk_idx.shape[1])
        return float(np.mean([gt[i] in topk_idx[i,:K] for i in range(len(gt))]))

    return {
        "MRR": _mrr(idx, y_true_local),
        "R1":  _recall_k(idx, y_true_local, 1),
        "R5":  _recall_k(idx, y_true_local, 5),
        "R10": _recall_k(idx, y_true_local, 10),
    }

# --------------------------
# Pooling no-op check (on text)
# --------------------------
def pooling_noop_check(X_val: np.ndarray):
    # Your text embeddings are already pooled (single vector).
    # Any "CLS vs mean-patch" toggle on text should do nothing.
    x = torch.from_numpy(X_val[:2048]).float()
    x_cls  = x.clone()
    # "mean" pretend: average over fake patches (not applicable)
    x_mean = x.mean(dim=1, keepdim=True) if x.ndim==3 else x  # will equal x if ndims==2
    diff = (x_cls - x_mean).abs().max().item()
    if x.ndim == 2 and diff == 0.0:
        print("[pooling] Text CLS vs mean is a NO-OP (identical). This explains equal MRR for CLS/mean in your logs.")
    else:
        print(f"[pooling] Text tensor ndim={x.ndim}; max|CLS-mean|={diff:.3e} (should be ~0 for your case).")

# --------------------------
# Write submissions A/B (normalized vs raw)
# --------------------------
def write_submission(ids, Z: np.ndarray, out_csv: Path):
    import pandas as pd
    df = pd.DataFrame({"id": ids, "embedding": Z.tolist()})
    df.to_csv(out_csv, index=False, float_format="%.17g")
    print(f"[submission] wrote {out_csv}")

# --------------------------
# Dummy model interface (plug your model here)
# --------------------------
class IdentityHead(torch.nn.Module):
    """Use this if you already have translated predictions saved externally.
       Otherwise replace this with your loaded MLP and correct dims.
    """
    def __init__(self, d): super().__init__(); self.d=d
    def forward(self, x): return x


def fit_linear_map_ridge(X_tr: np.ndarray, Y_tr: np.ndarray, l2: float = 1e-2, device=None):
    """
    Learn A s.t. X @ A ≈ Y (ridge-regularized least squares).
    X_tr: (N, dx) text; Y_tr: (N, dy) image
    Returns torch.Tensor A: (dx, dy) on `device`.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X = torch.from_numpy(X_tr).to(device)          # (N, dx)
    Y = torch.from_numpy(Y_tr).to(device)          # (N, dy)
    dx = X.shape[1]
    # Normal equations: A = (XᵀX + λI)⁻¹ Xᵀ Y
    Xt = X.transpose(0,1)                           # (dx, N)
    G  = Xt @ X                                     # (dx, dx)
    G = G + l2 * torch.eye(dx, device=device)
    A  = torch.linalg.solve(G, Xt @ Y)              # (dx, dy)
    return A

def run(args):
    train_npz = Path(args.train_npz)
    test_npz  = Path(args.test_npz)
    train_caps= Path(args.train_captions)
    out_dir   = Path(args.out_dir); out_dir.mkdir(parents=True, exist_ok=True)

    dtr = load_npz(train_npz)
    ok_dims = assert_dims(dtr, expect_text=1024, expect_img=1536)

    # Split by image id (no leakage)
    (Xtr,_), (Xv, yv_local, Gv_np), _ = make_image_id_split(train_npz, train_caps, seed=args.seed, val_ratio=args.val_ratio)

    # --- after: (Xtr,_), (Xv, yv_local, Gv_np), ...
    # Build Ytr from training rows (true image embeddings per caption)
    D = np.load(args.train_npz, allow_pickle=True)
    img_names = D["images/names"]
    I_all     = D["images/embeddings"].astype(np.float32)
    # Recompute train targets like we did for the split:
    ex,ba,st  = _build_image_index(img_names)
    targets_all = _iter_targets_from_captions(Path(args.train_captions), len(D["captions/embeddings"]), ex,ba,st)
    mask_trn = _[0]  # (mask_trn, mask_val) returned earlier; we kept it as "_"
    Ytr = I_all[targets_all[mask_trn]]               # (N_train_caps, 1536)
    
    # If your translator is identity or dims don't match, use the closed-form A
    use_ridge = True  # set False if you will load your MLP
    if use_ridge:
        print("[ridge] fitting closed-form linear map X→Y on train split...")
        A = fit_linear_map_ridge(Xtr, Ytr, l2=1e-2, device=device)   # (1024,1536)
    
        # VAL preds
        BS=2048
        preds_val=[]
        with torch.no_grad():
            for i in range(0, len(Xv), BS):
                q = torch.from_numpy(Xv[i:i+BS]).float().to(device)   # (b,1024)
                z = q @ A                                             # (b,1536)
                preds_val.append(z)
        preds_val = torch.cat(preds_val, dim=0)                       # torch
    
        Gv = torch.from_numpy(Gv_np).float().to(device)
        stats = eval_full_gallery(preds_val, Gv, yv_local, k=100)
        print("[val] full-gallery retrieval:", json.dumps(stats, indent=2))
    
        # TEST preds + A/B submissions
        dtst = load_npz(Path(args.test_npz))
        Q = dtst["captions/embeddings"].astype(np.float32)
        ids = dtst.get("captions/ids", np.arange(len(Q)))
        outs=[]
        with torch.no_grad():
            for i in range(0, len(Q), BS):
                q = torch.from_numpy(Q[i:i+BS]).float().to(device)
                z = q @ A
                outs.append(z)
        preds_test = torch.cat(outs, dim=0)
    
        Z_norm = F.normalize(preds_test, dim=-1).cpu().numpy()
        write_submission(ids, Z_norm, Path(args.out_dir)/"submission_norm.csv")
        Z_raw  = preds_test.cpu().numpy()
        write_submission(ids, Z_raw,  Path(args.out_dir)/"submission_raw.csv")
        return

    # quick magnitude diagnostics
    norms_pred = preds_test.norm(dim=-1).mean().item()
    norms_G    = torch.from_numpy(load_npz(train_npz)["images/embeddings"]).float().norm(dim=-1).mean().item()
    print(f"[diag] mean||pred||={norms_pred:.3f} vs mean||image||={norms_G:.3f} (useful if raw dot vs cosine matters)")

    report = {
        "dims_ok": ok_dims,
        "val_MRR": stats["MRR"],
        "val_R@1": stats["R1"],
        "val_R@5": stats["R5"],
        "val_R@10": stats["R10"],
        "notes": [
            "Text pooling toggle is a no-op (CLS vs mean identical) — expected.",
            "Use submission_norm.csv vs submission_raw.csv A/B to see leaderboard scorer’s sensitivity to normalization.",
            "Val split built by image id; metrics reflect official retrieval protocol (captions→val gallery)."
        ]
    }
    (out_dir/"sanity_report.json").write_text(json.dumps(report, indent=2))
    print(f"[ok] wrote {out_dir/'sanity_report.json'}")
    
if __name__ == "__main__":
    import sys
    # ignore unwanted Jupyter args like "-f ..."
    argv = [a for a in sys.argv if not a.startswith("-f")]
    p = argparse.ArgumentParser()
    p.add_argument("--train_npz", type=str, default="/kaggle/working/data/train/train/train.npz")
    p.add_argument("--test_npz",  type=str, default="/kaggle/working/data/test/test/test.clean.npz")
    p.add_argument("--train_captions", type=str, default="/kaggle/working/data/train/train/captions.txt")
    p.add_argument("--out_dir", type=str, default="checks_out")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    run(args)


[meta] text_dim=1024 | image_dim=1536
[split] captions: train=112500 | val=12500 | unique images: train=22500 | val=2500 | leakage=False
[ridge] fitting closed-form linear map X→Y on train split...


NameError: name 'device' is not defined

In [59]:
#!/usr/bin/env python3
import argparse, json, re, time
from pathlib import Path
from os.path import basename

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# --------------------------
# Simple MLP Translator
# --------------------------
class MLPTranslator(torch.nn.Module):
    def __init__(self, text_dim=1024, img_dim=1536, hidden_dim=2048):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(text_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(hidden_dim, img_dim)
        )
    
    def forward(self, x):
        return self.mlp(x)

# --------------------------
# I/O helpers
# --------------------------
def load_npz(path: Path):
    d = np.load(path, allow_pickle=True)
    return d

def assert_dims(d, expect_text=1024, expect_img=1536):
    md = {k: d[k][0] for k in d.files if k.startswith("metadata/")}
    tdim = md.get("metadata/embedding_dim_text", None)
    idim = md.get("metadata/embedding_dim_image", None)
    print(f"[meta] text_dim={tdim} | image_dim={idim}")
    ok = (tdim == expect_text) and (idim == expect_img)
    if not ok:
        print("[WARN] Metadata dims differ from expected fixed encoders "
              f"(expected text={expect_text}, image={expect_img}). "
              "Absolute metrics may shift; re-generate .npz with the official encoders.")
    return ok

def _build_image_index(image_ids):
    idx_exact = {}; idx_base={}; idx_stem={}
    for i, v in enumerate(image_ids):
        s = str(v); idx_exact[s]=i
        b = basename(s); idx_base[b]=i
        stem = b.rsplit(".",1)[0] if "." in b else b
        idx_stem[stem]=i
    return idx_exact, idx_base, idx_stem

def _iter_targets_from_captions(path: Path, n_text: int, idx_exact, idx_base, idx_stem):
    targets=[]
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            if len(targets)>=n_text: break
            line=raw.strip()
            if not line: continue
            parts=re.split(r'\||\t|,| {2,}', line)
            if len(parts)==1: parts=line.split(" ",1)
            tok=parts[0].strip().strip('"').strip("'")
            if not tok: continue
            def _match(tok_):
                if tok_ in idx_exact: return idx_exact[tok_]
                b=basename(tok_)
                if b in idx_base: return idx_base[b]
                stem=b.rsplit(".",1)[0] if "." in b else b
                if stem in idx_stem: return idx_stem[stem]
                return None
            idx=_match(tok)
            if idx is None:
                tok2=tok.replace("Images/","").replace("./","")
                idx=_match(tok2)
            if idx is None:
                if "." not in tok: continue
                raise AssertionError(f"Could not match image token '{tok}'")
            targets.append(idx)
    assert len(targets)==n_text, f"matched {len(targets)} vs N_text {n_text}"
    return np.asarray(targets, dtype=np.int64)

def eval_full_gallery(pred: torch.Tensor, G: torch.Tensor, y_true_local: np.ndarray, k: int = 100):
    # Ensure both tensors live on the same device
    device = pred.device
    if G.device != device:
        G = G.to(device)

    # cosine on L2-normalized vectors (validation protocol)
    pred_n = F.normalize(pred, dim=-1)    # (Nq, 1536)
    G_n = F.normalize(G, dim=-1)          # (Ng, 1536)
    sims = torch.mm(pred_n, G_n.t())      # (Nq, Ng)
    topk = min(k, G_n.size(0))
    idx = sims.topk(k=topk, dim=1, largest=True, sorted=True).indices.cpu().numpy()

    # metrics
    def _mrr(topk_idx, gt):
        rr = []
        for i in range(len(gt)):
            pos = np.where(topk_idx[i]==gt[i])[0]
            rr.append(1.0/(pos[0]+1) if len(pos)>0 else 0.0)
        return float(np.mean(rr))

    def _recall_k(topk_idx, gt, K):
        K = min(K, topk_idx.shape[1])
        return float(np.mean([gt[i] in topk_idx[i,:K] for i in range(len(gt))]))

    return {
        "MRR": _mrr(idx, y_true_local),
        "R1":  _recall_k(idx, y_true_local, 1),
        "R5":  _recall_k(idx, y_true_local, 5),
        "R10": _recall_k(idx, y_true_local, 10),
    }

def run(args):
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_npz = Path(args.train_npz)
    test_npz = Path(args.test_npz)
    train_caps = Path(args.train_captions)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    dtr = load_npz(train_npz)
    ok_dims = assert_dims(dtr, expect_text=1024, expect_img=1536)

    # Split by image id (no leakage)
    (Xtr, ytr), (Xv, yv_local, Gv_np), _ = make_image_id_split(
        train_npz, train_caps, seed=args.seed, val_ratio=args.val_ratio
    )

    # Pooling no-op sanity on text
    pooling_noop_check(Xv)

    # Create and move model to GPU
    model = MLPTranslator(text_dim=1024, img_dim=1536).to(device)
    model.train()

    # Quick training loop (you might want to expand this)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    batch_size = 512
    n_epochs = 5

    for epoch in range(n_epochs):
        total_loss = 0
        n_batches = 0
        
        # Training loop
        for i in range(0, len(Xtr), batch_size):
            batch_x = torch.from_numpy(Xtr[i:i+batch_size]).float().to(device)
            batch_y = torch.from_numpy(dtr['images/embeddings'][ytr[i:i+batch_size]]).float().to(device)
            
            optimizer.zero_grad()
            pred = model(batch_x)
            loss = F.mse_loss(pred, batch_y)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            n_batches += 1
        
        print(f"Epoch {epoch+1}/{n_epochs}, Avg Loss: {total_loss/n_batches:.6f}")

    # Evaluation mode
    model.eval()

    # Eval on VAL
    BS = 1024
    preds_val = []
    with torch.no_grad():
        for i in range(0, len(Xv), BS):
            q = torch.from_numpy(Xv[i:i+BS]).float().to(device)
            z = model(q)  # Now translates to image space (1536d)
            preds_val.append(z)
    preds_val = torch.cat(preds_val, dim=0)

    Gv = torch.from_numpy(Gv_np).float().to(device)
    stats = eval_full_gallery(preds_val, Gv, yv_local, k=100)
    print("[val] full-gallery retrieval:", json.dumps(stats, indent=2))

    # Test predictions
    dtst = load_npz(test_npz)
    Q = dtst["captions/embeddings"].astype(np.float32)
    ids = dtst.get("captions/ids", np.arange(len(Q)))

    preds_test = []
    with torch.no_grad():
        for i in range(0, len(Q), BS):
            q = torch.from_numpy(Q[i:i+BS]).float().to(device)
            z = model(q)
            preds_test.append(z)
    preds_test = torch.cat(preds_test, dim=0)

    # Normalized submission
    Z_norm = F.normalize(preds_test, dim=-1).cpu().numpy()
    write_submission(ids, Z_norm, out_dir/"submission_norm.csv")

    # Raw submission
    Z_raw = preds_test.cpu().numpy()
    write_submission(ids, Z_raw, out_dir/"submission_raw.csv")

    # Diagnostics
    norms_pred = preds_test.norm(dim=-1).mean().item()
    norms_G = torch.from_numpy(dtr["images/embeddings"]).float().norm(dim=-1).mean().item()
    print(f"[diag] mean||pred||={norms_pred:.3f} vs mean||image||={norms_G:.3f}")

    # Convert NumPy types to Python native types for JSON serialization
    report = {
        "dims_ok": bool(ok_dims),  # Convert np.bool_ to Python bool
        "val_MRR": float(stats["MRR"]),  # Convert np.float to Python float
        "val_R@1": float(stats["R1"]),
        "val_R@5": float(stats["R5"]),
        "val_R@10": float(stats["R10"]),
        "notes": [
            "Text pooling toggle is a no-op (CLS vs mean identical) — expected.",
            "Using MLP translator from text (1024d) to image space (1536d)",
            "Val split built by image id; metrics reflect official retrieval protocol."
        ]
    }
    
    # Write the report
    (out_dir/"sanity_report.json").write_text(json.dumps(report, indent=2))
    print(f"[ok] wrote {out_dir/'sanity_report.json'}")

if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--train_npz", type=str, default="/kaggle/working/data/train/train/train.npz")
    p.add_argument("--test_npz", type=str, default="/kaggle/working/data/test/test/test.clean.npz")
    p.add_argument("--train_captions", type=str, default="/kaggle/working/data/train/train/captions.txt")
    p.add_argument("--out_dir", type=str, default="checks_out")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    run(args)

Using device: cuda
[meta] text_dim=1024 | image_dim=1536
[split] captions: train=112500 | val=12500 | unique images: train=22500 | val=2500 | leakage=False
[pooling] Text CLS vs mean is a NO-OP (identical). This explains equal MRR for CLS/mean in your logs.
Epoch 1/5, Avg Loss: 0.198158
Epoch 2/5, Avg Loss: 0.159567
Epoch 3/5, Avg Loss: 0.152692
Epoch 4/5, Avg Loss: 0.148695
Epoch 5/5, Avg Loss: 0.145914
[val] full-gallery retrieval: {
  "MRR": 0.17877673524057414,
  "R1": 0.0916,
  "R5": 0.2632,
  "R10": 0.36432
}
[submission] wrote checks_out/submission_norm.csv
[submission] wrote checks_out/submission_raw.csv
[diag] mean||pred||=21.213 vs mean||image||=25.939


TypeError: Object of type bool_ is not JSON serializable

What we learned:

- text_dim=1024 | image_dim=1536 we using the fixed encoders (roberta-large-nli-stsb-mean-tokens and dinov2-giant).
-  leakage=False, val split by image id → new validation pipeline is officially aligned with the challenge spec.
- Cosine retrieval (F.normalize, dot product) matches the public LB similarity.

The LB expects already-normalized embeddings.

- Raw outputs distort similarity magnitudes (pred norms ≈ 21 vs image norms ≈ 26 → biased dot products).
-  Always normalize before writing the submission.

mean||pred|| = 21.213 and mean||image|| = 25.939
he LB similarity is cosine on normalized vectors, i.e. sim = (z_pred / ||z_pred||) @ (z_img / ||z_img||).T

In [19]:
# python runner.py --out_dir mlp1_cls_cos0.5_mse1.5 --pooling CLS --alpha 0.5 --beta 1.5 --gamma 0 --align_loss none --arch mlp1

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_cls_cos0.5_mse1.5")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=0.5)
    p.add_argument("--beta", type=float, default=1.5)
    p.add_argument("--gamma", type=float, default=0.0)
    p.add_argument("--align_loss", type=str, default="none", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=none | α=0.5 β=1.5 γ=0.0
[01] train_loss=0.609829 | val_MRR=0.0321 | R@1=0.009 R@5=0.040
[02] train_loss=0.387240 | val_MRR=0.0599 | R@1=0.023 R@5=0.079
[03] train_loss=0.357811 | val_MRR=0.0814 | R@1=0.035 R@5=0.110
[04] train_loss=0.342970 | val_MRR=0.0957 | R@1=0.041 R@5=0.131
[05] train_loss=0.333531 | val_MRR=0.1105 | R@1=0.050 R@5=0.155
[06] train_loss=0.326755 | val_MRR=0.1204 | R@1=0.054 R@5=0.172
[07] train_loss=0.321568 | val_MRR=0.1302 | R@1=0.059 R@5=0.186
[08] train_loss=0.317397 | val_MRR=0.1376 | R@1=0.064 R@5=0.196
[09] train_loss=0.313890 | val_MRR=0.1463 | R@1=0.070 R@5=0.211
[10] train_loss=0.310848 | val_MRR=0.1538 | R@1=0.075 R@5=0.221
[11] train_loss=0.308186 | val_MRR=0.1608 | R@1=0.079 R@5=0.231
[12] train_loss=0.305821 | val_MRR=0.1665 | R@1=0.083 R@5=0.240
[13] train_loss=0.303701 | val_MRR=0.1734 | R@1=0.088 R@5=0.249
[14] train_loss=0.301760 | val_MRR=0.1781 | R@1=0.092 R@5=0.258
[15] train_los

In [23]:
# python runner.py --out_dir mlp1_cls_cos1.5_mse0.5 --pooling CLS --alpha 1.5 --beta 0.5 --gamma 0 --align_loss none --arch mlp1

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_cls_cos1.5_mse0.5")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1.5)
    p.add_argument("--beta", type=float, default=0.5)
    p.add_argument("--gamma", type=float, default=0.0)
    p.add_argument("--align_loss", type=str, default="none", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=none | α=1.5 β=0.5 γ=0.0
[01] train_loss=0.712952 | val_MRR=0.0363 | R@1=0.010 R@5=0.046
[02] train_loss=0.438971 | val_MRR=0.0648 | R@1=0.025 R@5=0.088
[03] train_loss=0.403020 | val_MRR=0.0858 | R@1=0.037 R@5=0.114
[04] train_loss=0.385062 | val_MRR=0.0999 | R@1=0.044 R@5=0.136
[05] train_loss=0.373677 | val_MRR=0.1140 | R@1=0.052 R@5=0.158
[06] train_loss=0.365519 | val_MRR=0.1241 | R@1=0.058 R@5=0.175
[07] train_loss=0.359262 | val_MRR=0.1334 | R@1=0.063 R@5=0.189
[08] train_loss=0.354221 | val_MRR=0.1402 | R@1=0.066 R@5=0.201
[09] train_loss=0.349983 | val_MRR=0.1489 | R@1=0.072 R@5=0.214
[10] train_loss=0.346339 | val_MRR=0.1560 | R@1=0.077 R@5=0.225
[11] train_loss=0.343170 | val_MRR=0.1631 | R@1=0.081 R@5=0.233
[12] train_loss=0.340354 | val_MRR=0.1674 | R@1=0.083 R@5=0.241
[13] train_loss=0.337847 | val_MRR=0.1748 | R@1=0.089 R@5=0.252
[14] train_loss=0.335557 | val_MRR=0.1789 | R@1=0.091 R@5=0.259
[15] train_los

In [24]:
# python runner.py --out_dir mlp1_cls_cos+mse+moments --pooling CLS --alpha 1 --beta 1 --gamma 0 --align_loss moment --arch mlp1

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_cls_cos+mse+moments")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1)
    p.add_argument("--beta", type=float, default=1)
    p.add_argument("--gamma", type=float, default=0.0)
    p.add_argument("--align_loss", type=str, default="moment", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=moment | α=1 β=1 γ=0.0
[01] train_loss=0.740650 | val_MRR=0.0504 | R@1=0.017 R@5=0.067
[02] train_loss=0.465031 | val_MRR=0.0868 | R@1=0.038 R@5=0.116
[03] train_loss=0.426755 | val_MRR=0.1102 | R@1=0.050 R@5=0.150
[04] train_loss=0.407163 | val_MRR=0.1248 | R@1=0.057 R@5=0.176
[05] train_loss=0.394496 | val_MRR=0.1385 | R@1=0.064 R@5=0.197
[06] train_loss=0.385237 | val_MRR=0.1488 | R@1=0.070 R@5=0.211
[07] train_loss=0.378039 | val_MRR=0.1587 | R@1=0.075 R@5=0.228
[08] train_loss=0.372214 | val_MRR=0.1672 | R@1=0.081 R@5=0.240
[09] train_loss=0.367317 | val_MRR=0.1747 | R@1=0.086 R@5=0.252
[10] train_loss=0.363102 | val_MRR=0.1818 | R@1=0.092 R@5=0.263
[11] train_loss=0.359437 | val_MRR=0.1892 | R@1=0.097 R@5=0.272
[12] train_loss=0.356184 | val_MRR=0.1951 | R@1=0.102 R@5=0.282
[13] train_loss=0.353258 | val_MRR=0.2016 | R@1=0.107 R@5=0.290
[14] train_loss=0.350566 | val_MRR=0.2067 | R@1=0.111 R@5=0.297
[15] train_loss=

In [25]:
# python runner.py --out_dir mlp1_cls_cos+mse+normcal --pooling CLS --alpha 1 --beta 1 --gamma 0 --align_loss normcal --arch mlp1

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_cls_cos+mse+normcal")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1)
    p.add_argument("--beta", type=float, default=1)
    p.add_argument("--gamma", type=float, default=0.0)
    p.add_argument("--align_loss", type=str, default="normcal", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=normcal | α=1 β=1 γ=0.0
[01] train_loss=2.795950 | val_MRR=0.0046 | R@1=0.001 R@5=0.003
[02] train_loss=1.281909 | val_MRR=0.0047 | R@1=0.000 R@5=0.004
[03] train_loss=0.996945 | val_MRR=0.0055 | R@1=0.001 R@5=0.004
[04] train_loss=0.828152 | val_MRR=0.0065 | R@1=0.001 R@5=0.005
[05] train_loss=0.748381 | val_MRR=0.0079 | R@1=0.001 R@5=0.006
[06] train_loss=0.703146 | val_MRR=0.0096 | R@1=0.002 R@5=0.009
[07] train_loss=0.661552 | val_MRR=0.0113 | R@1=0.002 R@5=0.011
[08] train_loss=0.625856 | val_MRR=0.0132 | R@1=0.003 R@5=0.014
[09] train_loss=0.631494 | val_MRR=0.0155 | R@1=0.004 R@5=0.017
[10] train_loss=0.596262 | val_MRR=0.0179 | R@1=0.004 R@5=0.020
[11] train_loss=0.585320 | val_MRR=0.0202 | R@1=0.005 R@5=0.023
[12] train_loss=0.573368 | val_MRR=0.0227 | R@1=0.006 R@5=0.026
[13] train_loss=0.557340 | val_MRR=0.0255 | R@1=0.008 R@5=0.030
[14] train_loss=0.541786 | val_MRR=0.0278 | R@1=0.008 R@5=0.033
[15] train_loss

In [26]:
# python runner.py --out_dir mlp2_cls_cos+mse --pooling CLS --alpha 1 --beta 1 --gamma 0 --align_loss none --arch mlp2

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp2_cls_cos+mse")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1)
    p.add_argument("--beta", type=float, default=1)
    p.add_argument("--gamma", type=float, default=0.0)
    p.add_argument("--align_loss", type=str, default="moment", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp2", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp2 | params=2,362,368 (~9.01 MB) | pooling=CLS | align=moment | α=1 β=1 γ=0.0
[01] train_loss=0.636256 | val_MRR=0.0449 | R@1=0.014 R@5=0.057
[02] train_loss=0.438947 | val_MRR=0.0828 | R@1=0.032 R@5=0.115
[03] train_loss=0.406235 | val_MRR=0.1096 | R@1=0.046 R@5=0.154
[04] train_loss=0.389227 | val_MRR=0.1285 | R@1=0.056 R@5=0.181
[05] train_loss=0.377987 | val_MRR=0.1440 | R@1=0.066 R@5=0.207
[06] train_loss=0.369644 | val_MRR=0.1560 | R@1=0.073 R@5=0.229
[07] train_loss=0.363055 | val_MRR=0.1688 | R@1=0.081 R@5=0.247
[08] train_loss=0.357602 | val_MRR=0.1773 | R@1=0.087 R@5=0.259
[09] train_loss=0.353023 | val_MRR=0.1858 | R@1=0.092 R@5=0.273
[10] train_loss=0.348982 | val_MRR=0.1918 | R@1=0.096 R@5=0.281
[11] train_loss=0.345339 | val_MRR=0.1993 | R@1=0.103 R@5=0.291
[12] train_loss=0.341988 | val_MRR=0.2051 | R@1=0.108 R@5=0.296
[13] train_loss=0.338930 | val_MRR=0.2127 | R@1=0.113 R@5=0.310
[14] train_loss=0.336059 | val_MRR=0.2169 | R@1=0.116 R@5=0.317
[15] train_loss=

In [38]:
# python runner.py --out_dir mlp1_cls_cos+mse+infoNCE --pooling CLS --alpha 1 --beta 1 --gamma 0.1 --align_loss none --arch mlp1

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_cls_cos+mse+infoNCE")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1)
    p.add_argument("--beta", type=float, default=1)
    p.add_argument("--gamma", type=float, default=0.1)
    p.add_argument("--align_loss", type=str, default="none", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=none | α=1 β=1 γ=0.1
[01] train_loss=1.280466 | val_MRR=0.0384 | R@1=0.012 R@5=0.049
[02] train_loss=1.027822 | val_MRR=0.0682 | R@1=0.028 R@5=0.092
[03] train_loss=0.993354 | val_MRR=0.0901 | R@1=0.039 R@5=0.122
[04] train_loss=0.976056 | val_MRR=0.1044 | R@1=0.046 R@5=0.143
[05] train_loss=0.965049 | val_MRR=0.1202 | R@1=0.056 R@5=0.169
[06] train_loss=0.957130 | val_MRR=0.1304 | R@1=0.061 R@5=0.185
[07] train_loss=0.951046 | val_MRR=0.1404 | R@1=0.066 R@5=0.200
[08] train_loss=0.946120 | val_MRR=0.1482 | R@1=0.071 R@5=0.214
[09] train_loss=0.941967 | val_MRR=0.1561 | R@1=0.075 R@5=0.226
[10] train_loss=0.938371 | val_MRR=0.1637 | R@1=0.081 R@5=0.235
[11] train_loss=0.935229 | val_MRR=0.1720 | R@1=0.087 R@5=0.247
[12] train_loss=0.932428 | val_MRR=0.1769 | R@1=0.091 R@5=0.254
[13] train_loss=0.929919 | val_MRR=0.1842 | R@1=0.096 R@5=0.265
[14] train_loss=0.927631 | val_MRR=0.1885 | R@1=0.099 R@5=0.272
[15] train_loss=0.

In [28]:
# python runner.py --out_dir mlp1_cls_cos+mse+infoNCE --pooling CLS --lr 5e-5 --alpha 1 --beta 1 --gamma 0.1 --align_loss moment --arch mlp1

# smaller LR stabilizes combo losses

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_lr5e5_moment01")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=5e-5)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1)
    p.add_argument("--beta", type=float, default=1)
    p.add_argument("--gamma", type=float, default=0.1)
    p.add_argument("--align_loss", type=str, default="moment", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=moment | α=1 β=1 γ=0.1
[01] train_loss=1.554588 | val_MRR=0.0285 | R@1=0.008 R@5=0.034
[02] train_loss=1.149325 | val_MRR=0.0571 | R@1=0.021 R@5=0.075
[03] train_loss=1.089582 | val_MRR=0.0791 | R@1=0.033 R@5=0.105
[04] train_loss=1.059996 | val_MRR=0.0957 | R@1=0.042 R@5=0.130
[05] train_loss=1.041193 | val_MRR=0.1090 | R@1=0.050 R@5=0.148
[06] train_loss=1.027914 | val_MRR=0.1190 | R@1=0.054 R@5=0.163
[07] train_loss=1.017830 | val_MRR=0.1276 | R@1=0.058 R@5=0.178
[08] train_loss=1.009741 | val_MRR=0.1349 | R@1=0.062 R@5=0.191
[09] train_loss=1.003019 | val_MRR=0.1420 | R@1=0.066 R@5=0.200
[10] train_loss=0.997265 | val_MRR=0.1487 | R@1=0.070 R@5=0.210
[11] train_loss=0.992285 | val_MRR=0.1552 | R@1=0.074 R@5=0.222
[12] train_loss=0.987915 | val_MRR=0.1594 | R@1=0.075 R@5=0.230
[13] train_loss=0.984024 | val_MRR=0.1658 | R@1=0.080 R@5=0.239
[14] train_loss=0.980513 | val_MRR=0.1707 | R@1=0.084 R@5=0.247
[15] train_loss=

In [34]:
# python runner.py --out_dir mlp2_moment_infonce01 --arch mlp2 --gamma 0.1 --align_loss moment --lr 2e-4
# If underfitting, higher LR helps
if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_lr2e4_moment01")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1)
    p.add_argument("--beta", type=float, default=1)
    p.add_argument("--gamma", type=float, default=0.1)
    p.add_argument("--align_loss", type=str, default="moment", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=moment | α=1 β=1 γ=0.1
[01] train_loss=1.227705 | val_MRR=0.0838 | R@1=0.035 R@5=0.113
[02] train_loss=1.030644 | val_MRR=0.1229 | R@1=0.056 R@5=0.171
[03] train_loss=1.001965 | val_MRR=0.1477 | R@1=0.069 R@5=0.211
[04] train_loss=0.986564 | val_MRR=0.1646 | R@1=0.079 R@5=0.237
[05] train_loss=0.976255 | val_MRR=0.1809 | R@1=0.090 R@5=0.261
[06] train_loss=0.968573 | val_MRR=0.1918 | R@1=0.099 R@5=0.276
[07] train_loss=0.962537 | val_MRR=0.2036 | R@1=0.107 R@5=0.295
[08] train_loss=0.957577 | val_MRR=0.2120 | R@1=0.115 R@5=0.304
[09] train_loss=0.953321 | val_MRR=0.2189 | R@1=0.120 R@5=0.313
[10] train_loss=0.949557 | val_MRR=0.2258 | R@1=0.125 R@5=0.324
[11] train_loss=0.946282 | val_MRR=0.2343 | R@1=0.133 R@5=0.335
[12] train_loss=0.943318 | val_MRR=0.2376 | R@1=0.134 R@5=0.338
[13] train_loss=0.940671 | val_MRR=0.2444 | R@1=0.141 R@5=0.347
[14] train_loss=0.938257 | val_MRR=0.2475 | R@1=0.142 R@5=0.352
[15] train_loss=

In [31]:
# python runner.py --out_dir mlp2_wd5e4_moment01 --arch mlp2 --wd 1e-5 --gamma 0.1 --align_loss moment

# too much WD can hurt alignment; so we relax it
if __name__=="__main__": 
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_wd5e4_moment01")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-5)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1)
    p.add_argument("--beta", type=float, default=1)
    p.add_argument("--gamma", type=float, default=0.1)
    p.add_argument("--align_loss", type=str, default="moment", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=moment | α=1 β=1 γ=0.1
[01] train_loss=1.358940 | val_MRR=0.0541 | R@1=0.019 R@5=0.072
[02] train_loss=1.077363 | val_MRR=0.0915 | R@1=0.040 R@5=0.123
[03] train_loss=1.037254 | val_MRR=0.1146 | R@1=0.052 R@5=0.158
[04] train_loss=1.016725 | val_MRR=0.1294 | R@1=0.059 R@5=0.181
[05] train_loss=1.003400 | val_MRR=0.1438 | R@1=0.067 R@5=0.204
[06] train_loss=0.993644 | val_MRR=0.1539 | R@1=0.072 R@5=0.222
[07] train_loss=0.986059 | val_MRR=0.1641 | R@1=0.079 R@5=0.238
[08] train_loss=0.979914 | val_MRR=0.1723 | R@1=0.084 R@5=0.249
[09] train_loss=0.974755 | val_MRR=0.1802 | R@1=0.089 R@5=0.262
[10] train_loss=0.970305 | val_MRR=0.1876 | R@1=0.096 R@5=0.270
[11] train_loss=0.966435 | val_MRR=0.1949 | R@1=0.102 R@5=0.281
[12] train_loss=0.962991 | val_MRR=0.2001 | R@1=0.105 R@5=0.289
[13] train_loss=0.959890 | val_MRR=0.2073 | R@1=0.110 R@5=0.299
[14] train_loss=0.957038 | val_MRR=0.2125 | R@1=0.115 R@5=0.304
[15] train_loss=

In [32]:
# python runner.py --out_dir mlp2_wd5e4_moment01 --arch mlp2 --wd 5e-4 --gamma 0.1 --align_loss moment

# if overfitting, stronger WD may help.

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_wd5e4_moment01_vero")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=5e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1)
    p.add_argument("--beta", type=float, default=1)
    p.add_argument("--gamma", type=float, default=0.1)
    p.add_argument("--align_loss", type=str, default="moment", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=moment | α=1 β=1 γ=0.1
[01] train_loss=1.358938 | val_MRR=0.0541 | R@1=0.019 R@5=0.072
[02] train_loss=1.077359 | val_MRR=0.0915 | R@1=0.040 R@5=0.123
[03] train_loss=1.037249 | val_MRR=0.1146 | R@1=0.052 R@5=0.158
[04] train_loss=1.016719 | val_MRR=0.1294 | R@1=0.059 R@5=0.181
[05] train_loss=1.003395 | val_MRR=0.1438 | R@1=0.067 R@5=0.204
[06] train_loss=0.993638 | val_MRR=0.1540 | R@1=0.072 R@5=0.221
[07] train_loss=0.986053 | val_MRR=0.1641 | R@1=0.079 R@5=0.238
[08] train_loss=0.979908 | val_MRR=0.1724 | R@1=0.084 R@5=0.248
[09] train_loss=0.974748 | val_MRR=0.1801 | R@1=0.089 R@5=0.262
[10] train_loss=0.970300 | val_MRR=0.1875 | R@1=0.095 R@5=0.270
[11] train_loss=0.966431 | val_MRR=0.1949 | R@1=0.102 R@5=0.281
[12] train_loss=0.962989 | val_MRR=0.2002 | R@1=0.105 R@5=0.290
[13] train_loss=0.959889 | val_MRR=0.2071 | R@1=0.110 R@5=0.299
[14] train_loss=0.957037 | val_MRR=0.2125 | R@1=0.115 R@5=0.304
[15] train_loss=

In [33]:
# python runner.py --out_dir mlp2_bs256_moment01 --arch mlp2 --batch 256 --gamma 0.1 --align_loss moment

# more gradient noise can improve generalization

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_bs256_moment01")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=256)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1)
    p.add_argument("--beta", type=float, default=1)
    p.add_argument("--gamma", type=float, default=0.1)
    p.add_argument("--align_loss", type=str, default="moment", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))


[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=moment | α=1 β=1 γ=0.1
[01] train_loss=1.161027 | val_MRR=0.0819 | R@1=0.033 R@5=0.109
[02] train_loss=0.967359 | val_MRR=0.1204 | R@1=0.055 R@5=0.166
[03] train_loss=0.937974 | val_MRR=0.1434 | R@1=0.067 R@5=0.205
[04] train_loss=0.921858 | val_MRR=0.1607 | R@1=0.077 R@5=0.230
[05] train_loss=0.910998 | val_MRR=0.1756 | R@1=0.086 R@5=0.255
[06] train_loss=0.902919 | val_MRR=0.1878 | R@1=0.096 R@5=0.273
[07] train_loss=0.896514 | val_MRR=0.2001 | R@1=0.105 R@5=0.289
[08] train_loss=0.891194 | val_MRR=0.2069 | R@1=0.109 R@5=0.299
[09] train_loss=0.886670 | val_MRR=0.2169 | R@1=0.118 R@5=0.311
[10] train_loss=0.882740 | val_MRR=0.2237 | R@1=0.123 R@5=0.320
[11] train_loss=0.879313 | val_MRR=0.2327 | R@1=0.131 R@5=0.332
[12] train_loss=0.876230 | val_MRR=0.2368 | R@1=0.134 R@5=0.337
[13] train_loss=0.873493 | val_MRR=0.2432 | R@1=0.139 R@5=0.344
[14] train_loss=0.871009 | val_MRR=0.2470 | R@1=0.142 R@5=0.351
[15] train_loss=

In [35]:
# python runner.py --out_dir mlp2_bs1024_moment01 --arch mlp2 --batch 1024 --gamma 0.1 --align_loss moment

# more gradient noise can improve generalization

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_bs1024_moment01")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=1024)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1)
    p.add_argument("--beta", type=float, default=1)
    p.add_argument("--gamma", type=float, default=0.1)
    p.add_argument("--align_loss", type=str, default="moment", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=moment | α=1 β=1 γ=0.1
[01] train_loss=1.625906 | val_MRR=0.0281 | R@1=0.008 R@5=0.033
[02] train_loss=1.212432 | val_MRR=0.0584 | R@1=0.022 R@5=0.077
[03] train_loss=1.153054 | val_MRR=0.0812 | R@1=0.034 R@5=0.108
[04] train_loss=1.123816 | val_MRR=0.0984 | R@1=0.044 R@5=0.135
[05] train_loss=1.105431 | val_MRR=0.1119 | R@1=0.051 R@5=0.153
[06] train_loss=1.092532 | val_MRR=0.1224 | R@1=0.056 R@5=0.168
[07] train_loss=1.082776 | val_MRR=0.1311 | R@1=0.060 R@5=0.183
[08] train_loss=1.075003 | val_MRR=0.1383 | R@1=0.064 R@5=0.196
[09] train_loss=1.068588 | val_MRR=0.1455 | R@1=0.068 R@5=0.206
[10] train_loss=1.063140 | val_MRR=0.1517 | R@1=0.071 R@5=0.216
[11] train_loss=1.058438 | val_MRR=0.1585 | R@1=0.076 R@5=0.227
[12] train_loss=1.054333 | val_MRR=0.1637 | R@1=0.079 R@5=0.235
[13] train_loss=1.050660 | val_MRR=0.1691 | R@1=0.083 R@5=0.243
[14] train_loss=1.047374 | val_MRR=0.1744 | R@1=0.087 R@5=0.251
[15] train_loss=

In [37]:
# python runner.py --out_dir mlp2__infonce02 --arch mlp2 --gamma 0.2 --align_loss none

if __name__=="__main__":
    import argparse
    p=argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mlp1_moment02_loss_none")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=30)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--pooling", type=str, default="CLS", choices=["CLS","mean-patch"])
    p.add_argument("--n_patches", type=int, default=None)
    p.add_argument("--alpha", type=float, default=1)
    p.add_argument("--beta", type=float, default=1)
    p.add_argument("--gamma", type=float, default=0.2)
    p.add_argument("--align_loss", type=str, default="none", choices=["none","moment","normcal"])
    p.add_argument("--arch", type=str, default="mlp1", choices=["auto","linear","mlp1","mlp2"])
    p.add_argument("--train", type=int, default=1)
    p.add_argument("--val_ratio", type=float, default=0.1)
    args, _ = p.parse_known_args()
    main(**vars(args))

[model] mlp1 | params=1,312,768 (~5.01 MB) | pooling=CLS | align=none | α=1 β=1 γ=0.2
[01] train_loss=1.899691 | val_MRR=0.0433 | R@1=0.015 R@5=0.056
[02] train_loss=1.642304 | val_MRR=0.0746 | R@1=0.031 R@5=0.101
[03] train_loss=1.605984 | val_MRR=0.0969 | R@1=0.042 R@5=0.133
[04] train_loss=1.587693 | val_MRR=0.1122 | R@1=0.050 R@5=0.156
[05] train_loss=1.575995 | val_MRR=0.1278 | R@1=0.060 R@5=0.181
[06] train_loss=1.567547 | val_MRR=0.1381 | R@1=0.065 R@5=0.198
[07] train_loss=1.561022 | val_MRR=0.1484 | R@1=0.071 R@5=0.212
[08] train_loss=1.555711 | val_MRR=0.1556 | R@1=0.075 R@5=0.225
[09] train_loss=1.551231 | val_MRR=0.1649 | R@1=0.080 R@5=0.239
[10] train_loss=1.547342 | val_MRR=0.1725 | R@1=0.086 R@5=0.250
[11] train_loss=1.543942 | val_MRR=0.1813 | R@1=0.094 R@5=0.260
[12] train_loss=1.540911 | val_MRR=0.1858 | R@1=0.096 R@5=0.270
[13] train_loss=1.538198 | val_MRR=0.1929 | R@1=0.101 R@5=0.280
[14] train_loss=1.535719 | val_MRR=0.1980 | R@1=0.105 R@5=0.285
[15] train_loss=1.

In [63]:
# runner.py
import argparse, json, random, re, time
from pathlib import Path
from os.path import basename

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Uses your existing helpers (do not change)
from challenge.src.common import load_data, generate_submission

# ---------- Paths ----------
DATA_ROOT = Path("/kaggle/working/data")
TRAIN_DIR = DATA_ROOT / "train" / "train"
TEST_DIR  = DATA_ROOT / "test"  / "test"
TRAIN_NPZ = TRAIN_DIR / "train.npz"
TEST_NPZ  = TEST_DIR  / "test.clean.npz"
TRAIN_CAPTIONS = TRAIN_DIR / "captions.txt"
TEST_CAPTIONS  = TEST_DIR  / "captions.txt"

# ---------- Determinism ----------
def seed_all(s: int):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ---------- Metadata/safety ----------
def _assert_metadata_dims(npz_path: Path, expect_text=1024, expect_img=1536):
    d = np.load(npz_path, allow_pickle=True)
    md = {k: d[k][0] for k in d.files if k.startswith("metadata/")}
    tdim = md.get("metadata/embedding_dim_text", None)
    idim = md.get("metadata/embedding_dim_image", None)
    print(f"[meta] text_dim={tdim} | image_dim={idim}")
    assert tdim == expect_text and idim == expect_img, (
        f"Encoder dims mismatch: expected text={expect_text}, image={expect_img} "
        f"but got text={tdim}, image={idim}. Regenerate NPZ with the fixed encoders."
    )

# ---------- Caption→image matching from captions.txt ----------
def _build_image_index(image_ids):
    idx_exact = {}; idx_base={}; idx_stem={}
    for i, v in enumerate(image_ids):
        s = str(v); idx_exact[s]=i
        b = basename(s); idx_base[b]=i
        stem = b.rsplit(".",1)[0] if "." in b else b
        idx_stem[stem]=i
    return idx_exact, idx_base, idx_stem

def _iter_targets_from_captions(path: Path, n_text: int, idx_exact, idx_base, idx_stem):
    targets=[]
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            if len(targets)>=n_text: break
            line=raw.strip()
            if not line: continue
            # be robust to delimiters
            parts=re.split(r'\||\t|,| {2,}', line)
            if len(parts)==1: parts=line.split(" ",1)
            tok=parts[0].strip().strip('"').strip("'")
            if not tok: continue
            def _match(tok_):
                if tok_ in idx_exact: return idx_exact[tok_]
                b=basename(tok_)
                if b in idx_base: return idx_base[b]
                stem=b.rsplit(".",1)[0] if "." in b else b
                if stem in idx_stem: return idx_stem[stem]
                return None
            idx=_match(tok)
            if idx is None:
                tok2=tok.replace("Images/","").replace("./","")
                idx=_match(tok2)
            if idx is None:
                if "." not in tok: continue
                raise AssertionError(f"Could not match image token '{tok}'")
            targets.append(idx)
    assert len(targets)==n_text, f"matched {len(targets)} vs N_text {n_text}"
    return np.asarray(targets, dtype=np.int64)

# ---------- Data loaders ----------
def load_train(out_dir: Path):
    _assert_metadata_dims(TRAIN_NPZ, 1024, 1536)
    d=np.load(TRAIN_NPZ, allow_pickle=True)
    X=d["captions/embeddings"].astype(np.float32)  # (N_text,1024)
    I=d["images/embeddings"].astype(np.float32)    # (N_img,1536)
    cap_ids=d.get("captions/ids", np.arange(len(X)).astype(str))
    img_names=d.get("images/names", np.arange(len(I)).astype(str))

    ex,ba,st=_build_image_index(img_names)
    assert TRAIN_CAPTIONS.exists(), f"Missing {TRAIN_CAPTIONS}"
    targets=_iter_targets_from_captions(TRAIN_CAPTIONS, len(X), ex,ba,st)

    Y=I[targets]                       # (N_text, 1536) GT image vec per caption
    img_ids_row = img_names[targets]   # image name per caption row (for splitting)
    meta={"n_text":int(len(X)),"n_images":int(len(I))}
    (out_dir/"train_detect.json").write_text(json.dumps(meta, indent=2))
    return X,Y,cap_ids,img_ids_row,I,img_names

def load_test_npz():
    d_test=np.load(TEST_NPZ, allow_pickle=True)
    Q=d_test["captions/embeddings"].astype(np.float32)
    q_ids=d_test.get("captions/ids", np.arange(len(Q)).astype(str))
    if "images/embeddings" in d_test.files:
        G=d_test["images/embeddings"].astype(np.float32)
        g_ids=d_test.get("images/names", np.arange(len(G)).astype(str))
    else:
        d_tr=np.load(TRAIN_NPZ, allow_pickle=True)
        G=d_tr["images/embeddings"].astype(np.float32)
        g_ids=d_tr.get("images/names", np.arange(len(G)).astype(str))
    return Q,G,q_ids,g_ids

# ---------- Dataset ----------
class PairDS(Dataset):
    def __init__(self, X, Y): self.X=torch.from_numpy(X); self.Y=torch.from_numpy(Y)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i): return self.X[i], self.Y[i]

# ---------- Pooling (no-op per spec) ----------
def apply_pooling(x: torch.Tensor, mode: str, n_patches=None):
    return x

# ---------- Models ----------
class LinearProj(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.fc=nn.Linear(din,dout)
        nn.init.xavier_normal_(self.fc.weight); nn.init.zeros_(self.fc.bias)
    def forward(self,x): return self.fc(x)

class MLP1(nn.Module):
    def __init__(self, din, dout, hidden=512, pdrop=0.0):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(din,hidden), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(hidden,dout)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.zeros_(m.bias)
    def forward(self,x): return self.net(x)

class MLP2(nn.Module):
    # Spec: 1024→1024→512→1536, dropout=0.1 on hidden layers
    def __init__(self, din, dout, h1=1024, h2=512, pdrop=0.1):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(din,h1), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(h1,h2), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(h2,dout)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.zeros_(m.bias)
    def forward(self,x): return self.net(x)

def make_model(arch:str, din:int, dout:int):
    a=arch.lower()
    if a=="linear": return LinearProj(din,dout)
    if a=="mlp1":   return MLP1(din,dout,hidden=512, pdrop=0.1)
    if a=="mlp2":   return MLP2(din,dout,h1=1024,h2=512,pdrop=0.1)
    if a=="auto":   return MLP2(din,dout)
    raise ValueError(f"Unknown arch {arch}")

# ---------- Loss pieces ----------
def moment_align(pred, tgt):
    # Match batch mean & std channel-wise
    mu_p, mu_t = pred.mean(0), tgt.mean(0)
    sd_p, sd_t = pred.std(0, unbiased=False), tgt.std(0, unbiased=False)
    return F.mse_loss(mu_p, mu_t) + F.mse_loss(sd_p, sd_t)

def info_nce(pred, tgt):
    p = F.normalize(pred, dim=-1)
    t = F.normalize(tgt, dim=-1)
    logits = p @ t.t()                      # (B,B)
    labels = torch.arange(pred.size(0), device=pred.device)
    return F.cross_entropy(logits, labels)

# ---------- Split by image id & mapping ----------
def build_image_id_split(img_ids_row, all_img_names, full_img, X, Y, val_ratio, seed, out_dir: Path):
    uniq_img_names = np.array(sorted(set(map(str, all_img_names))))
    rng = np.random.default_rng(seed); rng.shuffle(uniq_img_names)
    n_val = max(1, int(len(uniq_img_names) * val_ratio))
    val_images = set(uniq_img_names[:n_val])
    tr_images  = set(uniq_img_names[n_val:])
    assert len(val_images & tr_images) == 0, "Leakage in image id split!"

    # Caption masks
    cap_is_val = np.array([str(iid) in val_images for iid in img_ids_row], dtype=bool)
    cap_is_tr  = ~cap_is_val

    # Build VAL gallery (unique images in VAL) as sorted names → local indices
    val_img_names_sorted = np.array(sorted(val_images))
    name2local = {name:i for i,name in enumerate(val_img_names_sorted)}
    name2global = {str(n):i for i,n in enumerate(all_img_names)}
    val_img_indices = np.array([name2global[n] for n in val_img_names_sorted], dtype=np.int64)
    val_gallery = full_img[val_img_indices]  # (M,1536)

    # For each VAL caption row, get local gallery index (no remap later)
    cap2gal_local = np.array([name2local[str(n)] for n in img_ids_row[cap_is_val]], dtype=np.int64)

    # Persist mapping for debugging/acceptance
    (out_dir/"val_indices.json").write_text(json.dumps({
        "val_img_indices": val_img_indices.tolist(),
        "val_caption_to_gallery_index": cap2gal_local.tolist(),
        "n_val_captions": int(cap_is_val.sum()),
        "n_val_unique_images": int(len(val_images))
    }, indent=2))

    return cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices

# ---------- Metrics / evaluator (full gallery, cosine on L2) ----------
@torch.no_grad()
def validate_retrieval(model, Xv, val_gallery, cap2gal_local, pooling, n_patches, bs=1024):
    device=next(model.parameters()).device
    Gi = torch.from_numpy(val_gallery).to(device)
    Gi = F.normalize(Gi, dim=-1)

    ranks=[]
    for i in range(0, len(Xv), bs):
        xb=torch.from_numpy(Xv[i:i+bs]).to(device)
        xb=apply_pooling(xb, pooling, n_patches)   # no-op
        pred=model(xb)
        pred=F.normalize(pred, dim=-1)
        sims=pred @ Gi.t()                         # (b, M)
        for j in range(sims.size(0)):
            true_idx = int(cap2gal_local[i+j])     # local gallery index
            order = torch.argsort(sims[j], descending=True)
            rank = (order==true_idx).nonzero(as_tuple=False).item() + 1
            ranks.append(rank)

    ranks = np.array(ranks)
    mrr = float(np.mean(1.0 / ranks))
    r1  = float(np.mean(ranks<=1))
    r5  = float(np.mean(ranks<=5))
    r10 = float(np.mean(ranks<=10))
    return {
        "MRR": mrr,
        "R1": r1,
        "R5": r5,
        "R10": r10,
        "rank_median": int(np.median(ranks)),
        "rank_p75": int(np.percentile(ranks, 75))
    }

def count_params_mb(model):
    params=sum(p.numel() for p in model.parameters())
    mb = params * 4 / (1024**2)
    return params, mb

def time_ms_per_query(model, din, pooling, n_patches):
    device=next(model.parameters()).device
    x=torch.randn(2048, din, device=device)
    x=apply_pooling(x, pooling, n_patches)
    if device.type=="cuda":
        torch.cuda.synchronize()
    t0=time.time()
    with torch.no_grad(): _=model(x)
    if device.type=="cuda":
        torch.cuda.synchronize()
    ms_gpu=(time.time()-t0)*1000/len(x)

    # CPU timing
    mcpu = model.to("cpu")
    xcpu = x.to("cpu")
    t1=time.time()
    with torch.no_grad(): _=mcpu(xcpu)
    ms_cpu=(time.time()-t1)*1000/len(xcpu)
    model.to(device)
    return ms_gpu, ms_cpu

# ---------- Train loop ----------
def train_one(model, loader, opt, alpha, beta, gamma, moment_w, pooling, n_patches, device):
    model.train(); total=0.0
    for xb,yb in loader:
        xb,yb=xb.to(device), yb.to(device)
        xb=apply_pooling(xb, pooling, n_patches)
        opt.zero_grad(set_to_none=True)
        pred=model(xb)

        # loss = α·(1 − cos) + β·MSE + λ·moment_align + γ·InfoNCE
        cos = 1 - F.cosine_similarity(pred, yb, dim=-1).mean()
        mse = F.mse_loss(pred, yb)
        a_loss = moment_w * moment_align(pred, yb) if moment_w>0 else pred.new_tensor(0.0)
        ce = info_nce(pred, yb) if gamma>0 else pred.new_tensor(0.0)
        loss = alpha*cos + beta*mse + gamma*ce + a_loss

        loss.backward(); opt.step()
        total += loss.item()*xb.size(0)
    return total/len(loader.dataset)

# ---------- Main ----------
def main(args):
    # Parse flags
    out_dir=args.out_dir; seed=args.seed; epochs=args.epochs; batch=args.batch
    lr=args.lr; wd=args.wd
    pooling=args.pooling; n_patches=None
    alpha=args.alpha; beta=args.beta; gamma=args.gamma; moment_w=args.moment
    arch=args.arch; do_train=not args.eval_only; val_ratio=args.val_ratio

    seed_all(seed)
    OUT = Path(f"/kaggle/working/outputs/{out_dir}"); OUT.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- load data
    X,Y,cap_ids,img_ids_row,full_img,all_img_ids = load_train(OUT)

    # --- split by image id & mappings
    cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices = build_image_id_split(
        img_ids_row, all_img_ids, full_img, X, Y, val_ratio, seed, OUT
    )
    # Safety: no leakage
    val_set = set(all_img_ids[val_img_indices])
    train_set = set(all_img_ids) - val_set
    assert val_set.isdisjoint(train_set), "Leakage detected between TRAIN and VAL image sets!"

    # Propagate masks to captions
    Xtr,Ytr = X[cap_is_tr], Y[cap_is_tr]
    Xva     = X[cap_is_val]

    din, dout = X.shape[1], Y.shape[1]
    # Dim safety
    assert din==1024 and dout==1536, f"Dimension mismatch: text={din}, image={dout} (expected 1024→1536)"

    # --- model
    eff_arch = arch if arch!="auto" else "mlp2"  # default per spec
    model = make_model(eff_arch, din, dout).to(device)

    # Output-dim safety
    with torch.no_grad():
        _probe = model(torch.zeros(2,din,device=device))
    assert _probe.shape[-1]==1536, f"Translator output dim { _probe.shape[-1] } != 1536"

    params, mb = count_params_mb(model)
    print(f"[model] {eff_arch} | params={params:,} (~{mb:.2f} MB) | pooling={pooling} | α={alpha} β={beta} λ_moment={moment_w} γ={gamma}")

    # --- train (best.pt chosen by MRR)
    best_stats=None; best=-1.0; best_ep=0
    if do_train:
        dl = DataLoader(PairDS(Xtr,Ytr), batch_size=batch, shuffle=True, num_workers=2, pin_memory=True, drop_last=False)
        opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
        for ep in range(1, epochs+1):
            tr = train_one(model, dl, opt, alpha,beta,gamma,moment_w, pooling, n_patches, device)
            stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
            print(f"[{ep:02d}] train_loss={tr:.6f} | val_MRR={stats['MRR']:.4f} "
                  f"| R@1={stats['R1']:.3f} R@5={stats['R5']:.3f} R@10={stats['R10']:.3f} "
                  f"| median={stats['rank_median']} p75={stats['rank_p75']}")
            if stats["MRR"] > best:
                best, best_ep, best_stats = stats["MRR"], ep, stats
                torch.save({"model":model.state_dict(),"epoch":ep,"val":stats}, OUT/"best.pt")
        print(f"[best] MRR={best:.4f} @ epoch {best_ep}")
        (OUT/"val_metrics.json").write_text(json.dumps(dict(best_epoch=best_ep, **best_stats), indent=2))
    else:
        ckpt = torch.load(OUT/"best.pt", map_location="cpu")
        print(f"[resume] loaded epoch={ckpt.get('epoch','?')} MRR={ckpt.get('val',{}).get('MRR','?')}")
        model.load_state_dict(ckpt["model"])
        best_stats = ckpt.get("val", None)

    # --- efficiency logging
    ms_gpu, ms_cpu = time_ms_per_query(model, din, pooling, n_patches)
    eff = {"params":params,"mb_fp32":mb,"ms_per_query_gpu":ms_gpu,"ms_per_query_cpu":ms_cpu}
    (OUT/"efficiency.json").write_text(json.dumps(eff, indent=2))

    # --- submission (normalized only; test captions → L2 outputs)
    test_data = load_data(TEST_NPZ)
    Q   = test_data["captions/embeddings"].astype(np.float32)
    ids = test_data.get("captions/ids", np.arange(len(Q)).astype(str))
    model.eval()
    BS = 1024
    outs = []
    with torch.no_grad():
        for i in range(0, len(Q), BS):
            q = torch.from_numpy(Q[i:i+BS]).to(device)
            q = apply_pooling(q, pooling, n_patches)
            z = model(q)
            z = F.normalize(z, dim=-1)
            outs.append(z.detach().cpu().numpy())
    pred_embds = np.concatenate(outs, axis=0)
    sub = OUT / "submission.csv"
    generate_submission(ids, pred_embds, str(sub))
    print(f"[ok] submission written → {sub}")

    # --- one-file sanity printout
    sanity = {
        "dims": {"text": int(din), "image": int(dout)},
        "split": {
            "train_captions": int(cap_is_tr.sum()),
            "val_captions": int(cap_is_val.sum()),
            "val_unique_images": int(val_gallery.shape[0]),
            "leakage": False
        },
        "val_metrics": best_stats if best_stats is not None else validate_retrieval(
            model, Xva, val_gallery, cap2gal_local, pooling, n_patches
        ),
        "efficiency": eff
    }
    print(json.dumps(sanity, indent=2))

In [64]:
# ---------- CLI ----------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="mrr_first")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--val_ratio", type=float, default=0.1)
    p.add_argument("--pooling", type=str, default="CLS")  # no-op (kept for compatibility)
    p.add_argument("--alpha", type=float, default=1.0)
    p.add_argument("--beta", type=float, default=1.0)
    p.add_argument("--moment", type=float, default=0.05)  # λ_moment
    p.add_argument("--gamma", type=float, default=0.0)    # InfoNCE (optional, default off)
    p.add_argument("--arch", type=str, default="mlp2", choices=["linear","mlp1","mlp2","auto"])
    p.add_argument("--eval_only", action="store_true", help="Skip training; eval+submission using best.pt")
    args, _ = p.parse_known_args()  # notebook-friendly
    main(args)


[meta] text_dim=1024 | image_dim=1536
[model] mlp2 | params=2,362,368 (~9.01 MB) | pooling=CLS | α=1.0 β=1.0 λ_moment=0.05 γ=0.0
[01] train_loss=0.538621 | val_MRR=0.0418 | R@1=0.014 R@5=0.052 R@10=0.087 | median=197 p75=621
[02] train_loss=0.386475 | val_MRR=0.0807 | R@1=0.031 R@5=0.108 R@10=0.172 | median=79 p75=296
[03] train_loss=0.358601 | val_MRR=0.1049 | R@1=0.045 R@5=0.146 R@10=0.226 | median=55 p75=220
[04] train_loss=0.344627 | val_MRR=0.1249 | R@1=0.055 R@5=0.180 R@10=0.268 | median=41 p75=176
[05] train_loss=0.335342 | val_MRR=0.1379 | R@1=0.062 R@5=0.200 R@10=0.300 | median=35 p75=155
[06] train_loss=0.328462 | val_MRR=0.1502 | R@1=0.071 R@5=0.218 R@10=0.315 | median=31 p75=140
[07] train_loss=0.323083 | val_MRR=0.1579 | R@1=0.076 R@5=0.228 R@10=0.331 | median=28 p75=127
[08] train_loss=0.318563 | val_MRR=0.1677 | R@1=0.083 R@5=0.245 R@10=0.346 | median=26 p75=121
[09] train_loss=0.314723 | val_MRR=0.1741 | R@1=0.086 R@5=0.255 R@10=0.358 | median=24 p75=113
[10] train_loss

In [65]:
# runner.py
import argparse, json, random, re, time
from pathlib import Path
from os.path import basename

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Uses your existing helpers (do not change)
from challenge.src.common import load_data, generate_submission

# ---------- Paths ----------
DATA_ROOT = Path("/kaggle/working/data")
TRAIN_DIR = DATA_ROOT / "train" / "train"
TEST_DIR  = DATA_ROOT / "test"  / "test"
TRAIN_NPZ = TRAIN_DIR / "train.npz"
TEST_NPZ  = TEST_DIR  / "test.clean.npz"
TRAIN_CAPTIONS = TRAIN_DIR / "captions.txt"
TEST_CAPTIONS  = TEST_DIR  / "captions.txt"

# ---------- Determinism ----------
def seed_all(s: int):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False

# ---------- Metadata/safety ----------
def _assert_metadata_dims(npz_path: Path, expect_text=1024, expect_img=1536):
    d = np.load(npz_path, allow_pickle=True)
    md = {k: d[k][0] for k in d.files if k.startswith("metadata/")}
    tdim = md.get("metadata/embedding_dim_text", None)
    idim = md.get("metadata/embedding_dim_image", None)
    print(f"[meta] text_dim={tdim} | image_dim={idim}")
    assert tdim == expect_text and idim == expect_img, (
        f"Encoder dims mismatch: expected text={expect_text}, image={expect_img} "
        f"but got text={tdim}, image={idim}. Regenerate NPZ with the fixed encoders."
    )

# ---------- Caption→image matching from captions.txt ----------
def _build_image_index(image_ids):
    idx_exact = {}; idx_base={}; idx_stem={}
    for i, v in enumerate(image_ids):
        s = str(v); idx_exact[s]=i
        b = basename(s); idx_base[b]=i
        stem = b.rsplit(".",1)[0] if "." in b else b
        idx_stem[stem]=i
    return idx_exact, idx_base, idx_stem

def _iter_targets_from_captions(path: Path, n_text: int, idx_exact, idx_base, idx_stem):
    targets=[]
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            if len(targets)>=n_text: break
            line=raw.strip()
            if not line: continue
            # be robust to delimiters
            parts=re.split(r'\||\t|,| {2,}', line)
            if len(parts)==1: parts=line.split(" ",1)
            tok=parts[0].strip().strip('"').strip("'")
            if not tok: continue
            def _match(tok_):
                if tok_ in idx_exact: return idx_exact[tok_]
                b=basename(tok_)
                if b in idx_base: return idx_base[b]
                stem=b.rsplit(".",1)[0] if "." in b else b
                if stem in idx_stem: return idx_stem[stem]
                return None
            idx=_match(tok)
            if idx is None:
                tok2=tok.replace("Images/","").replace("./","")
                idx=_match(tok2)
            if idx is None:
                if "." not in tok: continue
                raise AssertionError(f"Could not match image token '{tok}'")
            targets.append(idx)
    assert len(targets)==n_text, f"matched {len(targets)} vs N_text {n_text}"
    return np.asarray(targets, dtype=np.int64)

# ---------- Data loaders ----------
def load_train(out_dir: Path):
    _assert_metadata_dims(TRAIN_NPZ, 1024, 1536)
    d=np.load(TRAIN_NPZ, allow_pickle=True)
    X=d["captions/embeddings"].astype(np.float32)  # (N_text,1024)
    I=d["images/embeddings"].astype(np.float32)    # (N_img,1536)
    cap_ids=d.get("captions/ids", np.arange(len(X)).astype(str))
    img_names=d.get("images/names", np.arange(len(I)).astype(str))

    ex,ba,st=_build_image_index(img_names)
    assert TRAIN_CAPTIONS.exists(), f"Missing {TRAIN_CAPTIONS}"
    targets=_iter_targets_from_captions(TRAIN_CAPTIONS, len(X), ex,ba,st)

    Y=I[targets]                       # (N_text, 1536) GT image vec per caption
    img_ids_row = img_names[targets]   # image name per caption row (for splitting)
    meta={"n_text":int(len(X)),"n_images":int(len(I))}
    (out_dir/"train_detect.json").write_text(json.dumps(meta, indent=2))
    return X,Y,cap_ids,img_ids_row,I,img_names

def load_test_npz():
    d_test=np.load(TEST_NPZ, allow_pickle=True)
    Q=d_test["captions/embeddings"].astype(np.float32)
    q_ids=d_test.get("captions/ids", np.arange(len(Q)).astype(str))
    if "images/embeddings" in d_test.files:
        G=d_test["images/embeddings"].astype(np.float32)
        g_ids=d_test.get("images/names", np.arange(len(G)).astype(str))
    else:
        d_tr=np.load(TRAIN_NPZ, allow_pickle=True)
        G=d_tr["images/embeddings"].astype(np.float32)
        g_ids=d_tr.get("images/names", np.arange(len(G)).astype(str))
    return Q,G,q_ids,g_ids

# ---------- Dataset ----------
class PairDS(Dataset):
    def __init__(self, X, Y): self.X=torch.from_numpy(X); self.Y=torch.from_numpy(Y)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, i): return self.X[i], self.Y[i]

# ---------- Pooling (no-op per spec) ----------
def apply_pooling(x: torch.Tensor, mode: str, n_patches=None):
    return x

# ---------- Base Models ----------
class LinearProj(nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.fc=nn.Linear(din,dout)
        nn.init.xavier_normal_(self.fc.weight); nn.init.zeros_(self.fc.bias)
    def forward(self,x): return self.fc(x)

class MLP1(nn.Module):
    def __init__(self, din, dout, hidden=512, pdrop=0.1):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(din,hidden), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(hidden,dout)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.zeros_(m.bias)
    def forward(self,x): return self.net(x)

class MLP2(nn.Module):
    # Spec: 1024→1024→512→1536, dropout=0.1 on hidden layers
    def __init__(self, din, dout, h1=1024, h2=512, pdrop=0.1):
        super().__init__()
        self.net=nn.Sequential(
            nn.Linear(din,h1), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(h1,h2), nn.ReLU(), nn.Dropout(pdrop),
            nn.Linear(h2,dout)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight); nn.init.zeros_(m.bias)
    def forward(self,x): return self.net(x)

# ---------- Geometry-Preserving Linear (Whiten → Procrustes → Re-color) ----------
def _cov_eigh(zc, eps):
    # zc: (N,D), zero-mean
    N = zc.shape[0]
    C = (zc.T @ zc) / max(1, N)
    # symmetric PSD
    S, U = np.linalg.eigh(C)
    S = np.clip(S, 0.0, None)
    inv_sqrt = 1.0 / np.sqrt(S + eps)
    sqrt = np.sqrt(S + eps)
    C_mhalf = (U * inv_sqrt) @ U.T     # C^{-1/2}
    C_phalf = (U * sqrt) @ U.T         # C^{ 1/2}
    return C_mhalf.astype(np.float32), C_phalf.astype(np.float32)

def procrustes_closed_form(X, Y, eps=1e-5):
    """
    X: (N, 1024) text, Y: (N, 1536) targets (per-caption image vectors)
    Returns A (1536x1024), b (1536,) such that y_hat = A x + b
    """
    # center
    mu_x = X.mean(0, dtype=np.float64)
    mu_y = Y.mean(0, dtype=np.float64)
    Xc = X - mu_x
    Yc = Y - mu_y

    # whiten both
    Cx_mh, _ = _cov_eigh(Xc, eps)      # 1024x1024
    Cy_mh, Cy_ph = _cov_eigh(Yc, eps)  # 1536x1536 (only Cy_ph used)
    Xw = Xc @ Cx_mh.T                  # (N,1024)
    Yw = Yc @ Cy_mh.T                  # (N,1536)

    # orthogonal Procrustes: maximize Tr(R^T Xw^T Yw)
    M = Xw.T @ Yw                      # (1024,1536)
    # SVD on M; for rectangular, do SVD and build R = U V^T in common subspace
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    R = U @ Vt                         # (1024,1536) @ (1536,1536) -> (1024,1536) OK since full_matrices=False
    # we need R as (1536x1024) mapping whitened X to whitened Y; above is (1024x1536)
    R = R.T                            # (1536,1024)

    # re-color to Y space
    A = (Cy_ph @ R @ Cx_mh).astype(np.float32)     # (1536,1536)*(1536,1024)*(1024,1024) = (1536,1024)
    b = (mu_y - (A @ mu_x)).astype(np.float32)     # (1536,)
    return A, b

class GeomLinear(nn.Module):
    """
    Linear layer with weights initialized from closed-form geometry mapping.
    Optionally fine-tunes with tiny LR.
    """
    def __init__(self, A: np.ndarray, b: np.ndarray):
        super().__init__()
        D_in = A.shape[1]; D_out = A.shape[0]
        self.fc = nn.Linear(D_in, D_out, bias=True)
        with torch.no_grad():
            self.fc.weight.copy_(torch.from_numpy(A))
            self.fc.bias.copy_(torch.from_numpy(b))
    def forward(self, x):
        return self.fc(x)

# ---------- Loss pieces (for optional fine-tune) ----------
def moment_align(pred, tgt):
    mu_p, mu_t = pred.mean(0), tgt.mean(0)
    sd_p, sd_t = pred.std(0, unbiased=False), tgt.std(0, unbiased=False)
    return F.mse_loss(mu_p, mu_t) + F.mse_loss(sd_p, sd_t)

def info_nce(pred, tgt):
    p = F.normalize(pred, dim=-1)
    t = F.normalize(tgt, dim=-1)
    logits = p @ t.t()
    labels = torch.arange(pred.size(0), device=pred.device)
    return F.cross_entropy(logits, labels)

# ---------- Split by image id & mapping ----------
def build_image_id_split(img_ids_row, all_img_names, full_img, X, Y, val_ratio, seed, out_dir: Path):
    uniq_img_names = np.array(sorted(set(map(str, all_img_names))))
    rng = np.random.default_rng(seed); rng.shuffle(uniq_img_names)
    n_val = max(1, int(len(uniq_img_names) * val_ratio))
    val_images = set(uniq_img_names[:n_val])
    tr_images  = set(uniq_img_names[n_val:])
    assert len(val_images & tr_images) == 0, "Leakage in image id split!"

    # Caption masks
    cap_is_val = np.array([str(iid) in val_images for iid in img_ids_row], dtype=bool)
    cap_is_tr  = ~cap_is_val

    # Build VAL gallery (unique images in VAL) as sorted names → local indices
    val_img_names_sorted = np.array(sorted(val_images))
    name2local = {name:i for i,name in enumerate(val_img_names_sorted)}
    name2global = {str(n):i for i,n in enumerate(all_img_names)}
    val_img_indices = np.array([name2global[n] for n in val_img_names_sorted], dtype=np.int64)
    val_gallery = full_img[val_img_indices]  # (M,1536)

    # For each VAL caption row, get local gallery index (no remap later)
    cap2gal_local = np.array([name2local[str(n)] for n in img_ids_row[cap_is_val]], dtype=np.int64)

    # Persist mapping for debugging/acceptance
    (out_dir/"val_indices.json").write_text(json.dumps({
        "val_img_indices": val_img_indices.tolist(),
        "val_caption_to_gallery_index": cap2gal_local.tolist(),
        "n_val_captions": int(cap_is_val.sum()),
        "n_val_unique_images": int(len(val_images))
    }, indent=2))

    return cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices

# ---------- Metrics / evaluator (full gallery, cosine on L2) ----------
@torch.no_grad()
def validate_retrieval(model, Xv, val_gallery, cap2gal_local, pooling, n_patches, bs=1024):
    device=next(model.parameters()).device
    Gi = torch.from_numpy(val_gallery).to(device)
    Gi = F.normalize(Gi, dim=-1)

    ranks=[]
    for i in range(0, len(Xv), bs):
        xb=torch.from_numpy(Xv[i:i+bs]).to(device)
        xb=apply_pooling(xb, pooling, n_patches)   # no-op
        pred=model(xb)
        pred=F.normalize(pred, dim=-1)
        sims=pred @ Gi.t()                         # (b, M)
        for j in range(sims.size(0)):
            true_idx = int(cap2gal_local[i+j])     # local gallery index
            order = torch.argsort(sims[j], descending=True)
            rank = (order==true_idx).nonzero(as_tuple=False).item() + 1
            ranks.append(rank)

    ranks = np.array(ranks)
    mrr = float(np.mean(1.0 / ranks))
    r1  = float(np.mean(ranks<=1))
    r5  = float(np.mean(ranks<=5))
    r10 = float(np.mean(ranks<=10))
    return {
        "MRR": mrr,
        "R1": r1,
        "R5": r5,
        "R10": r10,
        "rank_median": int(np.median(ranks)),
        "rank_p75": int(np.percentile(ranks, 75))
    }

def count_params_mb(model):
    params=sum(p.numel() for p in model.parameters())
    mb = params * 4 / (1024**2)
    return params, mb

def time_ms_per_query(model, din, pooling, n_patches):
    device=next(model.parameters()).device
    x=torch.randn(2048, din, device=device)
    x=apply_pooling(x, pooling, n_patches)
    if device.type=="cuda":
        torch.cuda.synchronize()
    t0=time.time()
    with torch.no_grad(): _=model(x)
    if device.type=="cuda":
        torch.cuda.synchronize()
    ms_gpu=(time.time()-t0)*1000/len(x)
    # CPU
    mcpu=model.to("cpu"); xcpu=x.to("cpu")
    t1=time.time()
    with torch.no_grad(): _=mcpu(xcpu)
    ms_cpu=(time.time()-t1)*1000/len(xcpu)
    model.to(device)
    return ms_gpu, ms_cpu

# ---------- Train (for optional fine-tune) ----------
def train_one(model, loader, opt, alpha, beta, gamma, moment_w, pooling, n_patches, device):
    model.train(); total=0.0
    for xb,yb in loader:
        xb,yb=xb.to(device), yb.to(device)
        xb=apply_pooling(xb, pooling, n_patches)
        opt.zero_grad(set_to_none=True)
        pred=model(xb)
        # α·(1−cos) + β·MSE + λ·moment_align + γ·InfoNCE
        cos = 1 - F.cosine_similarity(pred, yb, dim=-1).mean()
        mse = F.mse_loss(pred, yb)
        a_loss = moment_w * moment_align(pred, yb) if moment_w>0 else pred.new_tensor(0.0)
        ce = info_nce(pred, yb) if gamma>0 else pred.new_tensor(0.0)
        loss = alpha*cos + beta*mse + gamma*ce + a_loss
        loss.backward(); opt.step()
        total += loss.item()*xb.size(0)
    return total/len(loader.dataset)

# ---------- Factory ----------
def make_model(arch:str, din:int, dout:int, geom_init_data=None):
    a=arch.lower()
    if a=="linear": return LinearProj(din,dout)
    if a=="mlp1":   return MLP1(din,dout,hidden=512, pdrop=0.1)
    if a=="mlp2":   return MLP2(din,dout,h1=1024,h2=512,pdrop=0.1)
    if a=="geom":
        assert geom_init_data is not None, "geom requires (Xtr, Ytr, eps)"
        Xtr_np, Ytr_np, eps = geom_init_data
        A,b = procrustes_closed_form(Xtr_np, Ytr_np, eps=eps)
        return GeomLinear(A, b)
    if a=="auto":   return MLP2(din,dout)
    raise ValueError(f"Unknown arch {arch}")

# ---------- Main ----------
def main(args):
    # Parse flags
    out_dir=args.out_dir; seed=args.seed; epochs=args.epochs; batch=args.batch
    lr=args.lr; wd=args.wd
    pooling=args.pooling; n_patches=None
    alpha=args.alpha; beta=args.beta; gamma=args.gamma; moment_w=args.moment
    arch=args.arch; do_train=not args.eval_only; val_ratio=args.val_ratio
    geom_eps=args.geom_eps; geom_ft_epochs=args.geom_finetune_epochs; geom_ft_lr=args.geom_finetune_lr; geom_ft_wd=args.geom_finetune_wd

    seed_all(seed)
    OUT = Path(f"/kaggle/working/outputs/{out_dir}"); OUT.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- load data
    X,Y,cap_ids,img_ids_row,full_img,all_img_ids = load_train(OUT)

    # --- split by image id & mappings
    cap_is_tr, cap_is_val, val_gallery, cap2gal_local, val_img_indices = build_image_id_split(
        img_ids_row, all_img_ids, full_img, X, Y, val_ratio, seed, OUT
    )
    # Safety: no leakage
    val_set = set(all_img_ids[val_img_indices])
    train_set = set(all_img_ids) - val_set
    assert val_set.isdisjoint(train_set), "Leakage detected between TRAIN and VAL image sets!"

    # Propagate masks to captions
    Xtr,Ytr = X[cap_is_tr], Y[cap_is_tr]
    Xva     = X[cap_is_val]

    din, dout = X.shape[1], Y.shape[1]
    # Dim safety
    assert din==1024 and dout==1536, f"Dimension mismatch: text={din}, image={dout} (expected 1024→1536)"

    # --- model (geom has closed-form init)
    geom_init = None
    if arch.lower()=="geom":
        geom_init = (Xtr.astype(np.float32), Ytr.astype(np.float32), float(geom_eps))
    model = make_model(arch, din, dout, geom_init).to(device)

    # Output-dim safety
    with torch.no_grad():
        _probe = model(torch.zeros(2,din,device=device))
    assert _probe.shape[-1]==1536, f"Translator output dim { _probe.shape[-1] } != 1536"

    params, mb = count_params_mb(model)
    print(f"[model] {arch} | params={params:,} (~{mb:.2f} MB) | pooling={pooling} | α={alpha} β={beta} λ_moment={moment_w} γ={gamma}")

    # --- training strategy
    best_stats=None; best=-1.0; best_ep=0

    if arch.lower()=="geom" and geom_ft_epochs>0 and do_train:
        # tiny fine-tune on linear head
        dl = DataLoader(PairDS(Xtr,Ytr), batch_size=batch, shuffle=True, num_workers=2, pin_memory=True, drop_last=False)
        opt = torch.optim.AdamW(model.parameters(), lr=geom_ft_lr, weight_decay=geom_ft_wd)
        for ep in range(1, geom_ft_epochs+1):
            tr = train_one(model, dl, opt, alpha,beta,gamma,moment_w, pooling, n_patches, device)
            stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
            print(f"[geom-ft {ep:02d}] train_loss={tr:.6f} | val_MRR={stats['MRR']:.4f} "
                  f"| R@1={stats['R1']:.3f} R@5={stats['R5']:.3f} R@10={stats['R10']:.3f} "
                  f"| median={stats['rank_median']} p75={stats['rank_p75']}")
            if stats["MRR"] > best:
                best, best_ep, best_stats = stats["MRR"], ep, stats
                torch.save({"model":model.state_dict(),"epoch":ep,"val":stats}, OUT/"best.pt")
        if best_stats is None:
            best_stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
            torch.save({"model":model.state_dict(),"epoch":0,"val":best_stats}, OUT/"best.pt")
        (OUT/"val_metrics.json").write_text(json.dumps(dict(best_epoch=best_ep, **best_stats), indent=2))

    elif do_train and arch.lower()!="geom":
        # normal training for linear/mlp1/mlp2
        dl = DataLoader(PairDS(Xtr,Ytr), batch_size=batch, shuffle=True, num_workers=2, pin_memory=True, drop_last=False)
        opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
        for ep in range(1, epochs+1):
            tr = train_one(model, dl, opt, alpha,beta,gamma,moment_w, pooling, n_patches, device)
            stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
            print(f"[{ep:02d}] train_loss={tr:.6f} | val_MRR={stats['MRR']:.4f} "
                  f"| R@1={stats['R1']:.3f} R@5={stats['R5']:.3f} R@10={stats['R10']:.3f} "
                  f"| median={stats['rank_median']} p75={stats['rank_p75']}")
            if stats["MRR"] > best:
                best, best_ep, best_stats = stats["MRR"], ep, stats
                torch.save({"model":model.state_dict(),"epoch":ep,"val":stats}, OUT/"best.pt")
        if best_stats is None:
            best_stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)
            torch.save({"model":model.state_dict(),"epoch":0,"val":best_stats}, OUT/"best.pt")
        (OUT/"val_metrics.json").write_text(json.dumps(dict(best_epoch=best_ep, **best_stats), indent=2))
    else:
        # eval-only: load best.pt if present, else just evaluate the freshly-built model
        try:
            ckpt = torch.load(OUT/"best.pt", map_location="cpu")
            print(f"[resume] loaded epoch={ckpt.get('epoch','?')} MRR={ckpt.get('val',{}).get('MRR','?')}")
            model.load_state_dict(ckpt["model"])
            best_stats = ckpt.get("val", None)
        except FileNotFoundError:
            best_stats = validate_retrieval(model, Xva, val_gallery, cap2gal_local, pooling, n_patches)

    # --- efficiency logging
    ms_gpu, ms_cpu = time_ms_per_query(model, din, pooling, n_patches)
    eff = {"params":params,"mb_fp32":mb,"ms_per_query_gpu":ms_gpu,"ms_per_query_cpu":ms_cpu}
    (OUT/"efficiency.json").write_text(json.dumps(eff, indent=2))

    # --- submission (normalized only; test captions → L2 outputs)
    test_data = load_data(TEST_NPZ)
    Q   = test_data["captions/embeddings"].astype(np.float32)
    ids = test_data.get("captions/ids", np.arange(len(Q)).astype(str))
    model.eval()
    BS = 1024
    outs = []
    with torch.no_grad():
        for i in range(0, len(Q), BS):
            q = torch.from_numpy(Q[i:i+BS]).to(device)
            q = apply_pooling(q, pooling, n_patches)
            z = model(q)
            z = F.normalize(z, dim=-1)
            outs.append(z.detach().cpu().numpy())
    pred_embds = np.concatenate(outs, axis=0)
    sub = OUT / "submission.csv"
    generate_submission(ids, pred_embds, str(sub))
    print(f"[ok] submission written → {sub}")

    # --- one-file sanity printout
    sanity = {
        "dims": {"text": int(din), "image": int(dout)},
        "split": {
            "train_captions": int(cap_is_tr.sum()),
            "val_captions": int(cap_is_val.sum()),
            "val_unique_images": int(val_gallery.shape[0]),
            "leakage": False
        },
        "val_metrics": best_stats if best_stats is not None else validate_retrieval(
            model, Xva, val_gallery, cap2gal_local, pooling, n_patches
        ),
        "efficiency": eff
    }
    print(json.dumps(sanity, indent=2))

In [72]:
# ---------- CLI ----------
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--out_dir", type=str, default="geom_mrr_first")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--wd", type=float, default=1e-4)
    p.add_argument("--val_ratio", type=float, default=0.1)
    p.add_argument("--pooling", type=str, default="CLS")  # no-op (kept for compatibility)
    p.add_argument("--alpha", type=float, default=1.0)
    p.add_argument("--beta", type=float, default=1.0)
    p.add_argument("--moment", type=float, default=0.05)  # λ_moment
    p.add_argument("--gamma", type=float, default=0.0)    # InfoNCE (optional, default off)
    p.add_argument("--arch", type=str, default="geom", choices=["geom","linear","mlp1","mlp2","auto"])
    # Geometry step knobs
    p.add_argument("--geom_eps", type=float, default=1e-5, help="Covariance epsilon for whitening")
    p.add_argument("--geom_finetune_epochs", type=int, default=0, help="Tiny linear fine-tune epochs (0 = off)")
    p.add_argument("--geom_finetune_lr", type=float, default=1e-4)
    p.add_argument("--geom_finetune_wd", type=float, default=1e-5)
    p.add_argument("--eval_only", action="store_true", help="Skip training; eval+submission using best.pt or closed-form")
    args, _ = p.parse_known_args()  # notebook-friendly
    main(args)


[meta] text_dim=1024 | image_dim=1536
[model] geom | params=1,574,400 (~6.01 MB) | pooling=CLS | α=1.0 β=1.0 λ_moment=0.05 γ=0.0
Generating submission file...
✓ Saved submission to /kaggle/working/outputs/geom_mrr_first/submission.csv
[ok] submission written → /kaggle/working/outputs/geom_mrr_first/submission.csv
{
  "dims": {
    "text": 1024,
    "image": 1536
  },
  "split": {
    "train_captions": 112500,
    "val_captions": 12500,
    "val_unique_images": 2500,
    "leakage": false
  },
  "val_metrics": {
    "MRR": 0.3191292292010856,
    "R1": 0.20912,
    "R5": 0.4364,
    "R10": 0.54536,
    "rank_median": 8,
    "rank_p75": 40
  },
  "efficiency": {
    "params": 1574400,
    "mb_fp32": 6.005859375,
    "ms_per_query_gpu": 0.00049639493227005,
    "ms_per_query_cpu": 0.02339109778404236
  }
}
