### Training

In [43]:
# ---------- train_vqa_fast.py (cached image embeddings + robust multitask heads) ----------
import os, csv, json, hashlib, math
from pathlib import Path
from datetime import datetime
from typing import List
from itertools import chain

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, WeightedRandomSampler

import timm
from transformers import AutoModel, AutoTokenizer

# Pull safe utils you already have
from fm_utils import CFG, NPZDataset, collate_simple, nullcontext

# =====================================================
# I/O (parquets)
# =====================================================
IDX_PARQUET = "metaspace_images_dump/msi_fm_samples3.parquet"
MAN_PARQUET = "metaspace_images_dump/manifest_expanded.parquet"

df_idx = pd.read_parquet(IDX_PARQUET)
df_man = pd.read_parquet(MAN_PARQUET)

need_cols = [
    "dataset_id", "organism", "polarity", "Organism_Part", "Condition",
    "analyzerType", "ionisationSource"
]
man_sub = df_man[[c for c in need_cols if c in df_man.columns]].drop_duplicates("dataset_id")

df = df_idx.merge(man_sub, on="dataset_id", how="left", suffixes=("", "_man"))
df = df.loc[:, ~df.columns.duplicated()]

if "sample_path" not in df.columns:
    raise ValueError("Index parquet must include 'sample_path' pointing to .npz files.")

# =====================================================
# Label normalization
# =====================================================
def _canon_base(x):
    if x is None:
        return "unknown"
    s = str(x).strip().lower()
    return "unknown" if s in ("", "none", "nan", "na", "null") else s

def normalize_organism(x):
    s = _canon_base(x)
    if s == "unknown": return s
    if "mus musculus" in s or s in {"mouse", "mouse brain", "mus", "m. musculus"}:
        return "mus musculus"
    if "homo sapiens" in s or s in {"human", "h. sapiens"}:
        return "homo sapiens"
    return "unknown"

def normalize_polarity(x):
    s = _canon_base(x)
    if s in {"pos", "+", "positive", "positive ion mode", "positive mode"}:  return "positive"
    if s in {"neg", "-", "negative", "negative ion mode", "negative mode"}:  return "negative"
    return "unknown"

def normalize_text(x):
    return _canon_base(x)

for col, fn in [
    ("organism", normalize_organism),
    ("polarity", normalize_polarity),
    ("Organism_Part", normalize_text),
    ("Condition", normalize_text),
    ("analyzerType", normalize_text),
    ("ionisationSource", normalize_text),
]:
    if col in df.columns:
        df[col] = df[col].map(fn)
    else:
        df[col] = "unknown"

# =====================================================
# RARE-CLASS FILTERING / NA REMOVAL
# =====================================================
# Any label mapped to "unknown" is **ignored** in loss. We also downmap rare classes to "unknown".
MIN_SAMPLES_PER_CLASS = 100      # <- change as needed
DROP_ROWS_ALL_UNKNOWN = True    # drop rows where ALL tasks are unknown (saves compute)

TASK_COLS = {
    "organism": "organism",
    "polarity": "polarity",
    "organ": "Organism_Part",
    "condition": "Condition",
    "analyzerType": "analyzerType",
    "ionisationSource": "ionisationSource",
}

def _apply_min_count_filter(df: pd.DataFrame, col: str, min_count: int) -> pd.Series:
    """Return a filtered label series: classes with freq < min_count -> 'unknown'."""
    s = df[col].astype(str)
    vc = s[s != "unknown"].value_counts()
    keep = set(vc[vc >= min_count].index.tolist())
    return s.where(s.isin(keep), other="unknown")

for tname, col in TASK_COLS.items():
    df[col] = _apply_min_count_filter(df, col, MIN_SAMPLES_PER_CLASS)

if DROP_ROWS_ALL_UNKNOWN:
    mask_known_any = False
    for col in TASK_COLS.values():
        mask_known_any |= (df[col] != "unknown")
    before = len(df)
    df = df[mask_known_any].reset_index(drop=True)
    print(f"[FILTER] Dropped {before - len(df)} rows with all tasks == 'unknown'.")

def _print_task_summary(df, task_cols):
    print("\n[SUMMARY] Class counts after filtering (unknown shown but ignored during loss):")
    for tname, col in task_cols.items():
        vc = df[col].value_counts().sort_values(ascending=False)
        print(f" - {tname}: total={int(vc.sum())}, classes={int((vc.index!='unknown').sum())} (+unknown)")
        print(vc.head(20).to_string())  # top-20 preview

_print_task_summary(df, TASK_COLS)

# Build class vocabularies **excluding 'unknown'** (we ignore unknown in loss)
def _build_vocab(series):
    vals = sorted([v for v in set(series.dropna().astype(str).tolist()) if v != "unknown"])
    # Can be empty if everything was filtered; handle later when building heads
    return vals

cls_spaces = {
    k: _build_vocab(df[col]) for k, col in TASK_COLS.items()
}

# =====================================================
# Config / knobs
# =====================================================
SEED = 6740
torch.manual_seed(SEED); np.random.seed(SEED)

# ---- Switches ----
USE_TEXT   = True     # set False for image-only (often a strong baseline)
REBALANCE  = False    # simple sampler example (organism only)
LABEL_SMOOTH = 0.05   # cross-entropy label smoothing

# ---- Frozen image backbone (timm ViT) ----
TIMM_ID         = "vit_small_patch14_reg4_dinov2.lvd142m"   # e.g., "deit_small_distilled_patch16_224"
PATCH_MULTIPLE  = 14                                        # 16 for DeiT/MAE, 14 for DINOv2
TARGET_SIZE     = 518                                       # 224 for DeiT/MAE; 518 for DINOv2 reg4
PRETRAINED      = True

# ---- Trainable text tower (HF) ----
HF_TEXT_MODEL      = "sentence-transformers/all-MiniLM-L6-v2"
TEXT_MAX_LEN       = 64
TEXT_OUT_DIM       = 384

# Partial tuning strategy for speed
# options: "none" (freeze all, only heads), "proj" (only projection), "last_k" (unfreeze last K blocks + proj)
TEXT_TRAIN_STRATEGY = "last_k"
TEXT_TRAIN_LAST_K   = 2
TEXT_FREEZE_EMBED   = True

# Optimizer specifics (separate LR for base vs projection)
LR_TXT_BASE   = 5e-5
LR_TXT_PROJ   = 1e-4
USE_FUSED_ADAMW = torch.cuda.is_available()

# ---- Data / train ----
AMP = True
CHANNELS_PER_VIEW  = 64
CHANNELS_PER_STEP  = 16
BATCH_SIZE         = 128
NUM_WORKERS        = 0

TARGET_H = 256
TARGET_W = 256

TOTAL_STEPS = 10000
LR_FUSION   = 5e-4
LR_HEADS    = 8e-4
WD          = 0.05
SAVE_EVERY  = 200
KEEP_EVERY  = 1000
EMA_BETA    = 0.98

torch.set_float32_matmul_precision('high')

# =====================================================
# Output / stats cache
# =====================================================
RUN_ROOT = Path("vqa"); RUN_ROOT.mkdir(parents=True, exist_ok=True)
RUN_DIR  = RUN_ROOT / datetime.now().strftime("%Y%m%d_%H%M%S"); RUN_DIR.mkdir(parents=True, exist_ok=True)

# --- PERSISTENT cache across runs (fix) ---
STATS_DIR = RUN_ROOT / "_stats_cache"
STATS_DIR.mkdir(parents=True, exist_ok=True)
print(f"[STATS] Using cache dir: {STATS_DIR.resolve()}")

# --- Image embedding cache across runs ---
IMG_CACHE_DIR = RUN_ROOT / "_img_cache"
IMG_CACHE_DIR.mkdir(parents=True, exist_ok=True)
print(f"[CACHE] Image embeddings dir: {IMG_CACHE_DIR.resolve()}")

cfg = CFG(
    channels_per_view=CHANNELS_PER_VIEW,
    input_size=TARGET_H,
    crop_size=TARGET_H,
    patch_size=PATCH_MULTIPLE,
    batch_size=max(8, min(64, BATCH_SIZE)),  # for the raw image pass during caching
    seed=SEED
)

# =====================================================
# DataLoader (tiles) for CACHING only
# =====================================================
def make_loader(paths, cfg, shuffle=False):
    ds = NPZDataset(
        paths,
        target_h=int(cfg.input_size),
        target_w=int(cfg.input_size),
        k_target=int(cfg.channels_per_view),
        scale_u16=True,             # -> float [0,1]
        pad_mode="repeat",
        sort_by_mz=False
    )
    ld = DataLoader(
        ds,
        batch_size=int(cfg.batch_size),
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=0,
        persistent_workers=(NUM_WORKERS > 0),
        collate_fn=collate_simple,
        drop_last=False
    )
    return ds, ld

all_paths = df["sample_path"].tolist()

# =====================================================
# Frozen image backbone (returns [B, D_img] CLS-like)
# =====================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class FrozenBackbone(nn.Module):
    def __init__(self, timm_name: str, pretrained: bool = True):
        super().__init__()
        self.m = timm.create_model(timm_name, pretrained=pretrained, num_classes=0)
        for p in self.m.parameters():
            p.requires_grad_(False)
        self.m.eval()

    @torch.no_grad()
    def forward(self, x3):  # [N,3,H,W]
        feats = self.m.forward_features(x3)
        if isinstance(feats, dict):
            if 'x_norm_clstoken' in feats:  return feats['x_norm_clstoken']
            if 'cls_token' in feats:        return feats['cls_token']
            if 'avgpool' in feats:          return feats['avgpool']
            for k in ('last_hidden_state', 'tokens', 'x'):
                if k in feats and torch.is_tensor(feats[k]):
                    t = feats[k]
                    return t[:,0] if t.dim() == 3 else t
        if torch.is_tensor(feats):
            return feats[:,0] if feats.dim() == 3 else feats
        return feats.mean(dim=-2)  # fallback

def crop_resize_to_target(x3, target=224, patch_multiple=16):
    _, _, H, W = x3.shape
    Hc = (H // patch_multiple) * patch_multiple
    Wc = (W // patch_multiple) * patch_multiple
    dh = (H - Hc) // 2; dw = (W - Wc) // 2
    if Hc > 0 and Wc > 0:
        x3 = x3[:, :, dh:dh+Hc, dw:dw+Wc]
    if (Hc, Wc) != (target, target):
        x3 = F.interpolate(x3, size=(target, target), mode="bilinear", align_corners=False)
    return x3

bb = FrozenBackbone(TIMM_ID, pretrained=PRETRAINED).to(device)
bb.eval()
autocast_ctx = (torch.autocast(device_type="cuda", dtype=torch.float16)
                if (AMP and torch.cuda.is_available()) else nullcontext())

# probe image dim
with torch.no_grad():
    dummy = torch.zeros(2, 3, TARGET_H, TARGET_W, device=device)
    d_img = bb(crop_resize_to_target(dummy, target=TARGET_SIZE, patch_multiple=PATCH_MULTIPLE)).shape[-1]

# =====================================================
# Per-channel mean/std (cached)
# =====================================================
def _paths_signature(paths):
    n = len(paths)
    head = paths[:100]; tail = paths[-100:]
    sig_src = json.dumps([n, head, tail], separators=(',', ':')).encode('utf-8')
    return hashlib.sha1(sig_src).hexdigest()

def _stats_cache_path(paths, cfg):
    sig = _paths_signature(paths)
    fname = f"mu_std_c{int(cfg.channels_per_view)}_in{int(cfg.input_size)}_{sig}.npz"
    return STATS_DIR / fname

@torch.no_grad()
def compute_channel_stats(paths, cfg):
    _, ld = make_loader(paths, cfg, shuffle=False)
    sum_c, sumsq_c, total = None, None, 0
    for batch in tqdm(ld, desc="Compute per-channel mean/std"):
        x = batch["patch"].float()   # [B,C,H,W]
        B, C, H, W = x.shape
        x = x.view(B, C, -1)
        if sum_c is None:
            sum_c   = x.sum(dim=(0,2))
            sumsq_c = (x**2).sum(dim=(0,2))
        else:
            sum_c   += x.sum(dim=(0,2))
            sumsq_c += (x**2).sum(dim=(0,2))
        total += B * H * W
    mu  = (sum_c / (total + 1e-8)).cpu().numpy()
    var = (sumsq_c / (total + 1e-8)).cpu().numpy() - mu**2
    std = np.sqrt(np.maximum(var, 1e-12))
    return mu, std

def load_or_compute_stats(paths, cfg, force_recompute=False):
    cache_path = _stats_cache_path(paths, cfg)
    if (not force_recompute) and cache_path.exists():
        z = np.load(cache_path, allow_pickle=True)
        mu, std = z["mu"], z["std"]
        print(f"[STATS] Loaded cached mu/std from: {cache_path}")
        return mu, std
    print("[STATS] Computing mu/std (once) ...")
    mu, std = compute_channel_stats(paths, cfg)
    np.savez_compressed(cache_path, mu=mu, std=std, meta=dict(
        channels_per_view=int(cfg.channels_per_view),
        input_size=int(cfg.input_size),
        n_paths=len(paths)
    ))
    print(f"[STATS] Saved mu/std to: {cache_path}")
    return mu, std

# stats once (now persists across runs)
mu_c, std_c = load_or_compute_stats(all_paths, cfg, force_recompute=False)
mu_t = torch.tensor(mu_c, device=device).view(1, -1, 1, 1)
sd_t = torch.tensor(std_c, device=device).view(1, -1, 1, 1).clamp_min(1e-6)
STEP = CHANNELS_PER_STEP if TARGET_SIZE == 224 else max(4, CHANNELS_PER_STEP // 2)

# =====================================================
# IMAGE EMBEDDING CACHE (mean-CLS per tile)
# =====================================================
def _embed_key(path):
    h = hashlib.sha1(f"{TIMM_ID}|{PATCH_MULTIPLE}|{TARGET_SIZE}|{path}".encode()).hexdigest()
    return IMG_CACHE_DIR / f"{h}.npy"

@torch.no_grad()
def encode_image_mean_cls_batch(patch_bchw):  # vectorized encode for a batch
    x = patch_bchw.to(device=device).float()   # [B,C,H,W] in [0,1]
    B, C, H, W = x.shape
    x = (x - mu_t[:, :C]) / sd_t[:, :C]
    x_flat = x.permute(0, 2, 3, 1).contiguous().view(B * C, 1, H, W)
    x_rgb  = x_flat.repeat(1, 3, 1, 1)

    outs = []
    with autocast_ctx, torch.inference_mode():
        for s in range(0, B * C, STEP):
            e = min(s + STEP, B * C)
            xr = crop_resize_to_target(x_rgb[s:e], target=TARGET_SIZE, patch_multiple=PATCH_MULTIPLE)
            cls = bb(xr).float()  # [N, D_img]
            outs.append(cls)
    cls_all = torch.cat(outs, dim=0)           # [B*C, D_img]
    D = cls_all.shape[1]
    z = cls_all.view(B, C, D).mean(dim=1)      # [B, D_img]
    return F.normalize(z, dim=-1)

def build_img_cache(paths, cfg):
    ds, ld = make_loader(paths, cfg, shuffle=False)
    done = 0
    for batch in tqdm(ld, desc="[CACHE] image embeddings"):
        z = encode_image_mean_cls_batch(batch["patch"])
        for pth, zi in zip(batch["path"], z):  # NPZDataset must return "path"
            out_p = _embed_key(pth)
            if not out_p.exists():
                np.save(out_p, zi.detach().cpu().numpy())
                done += 1
    print(f"[CACHE] Wrote {done} new embeddings.")

# Build cache once (skips existing files)
missing = [p for p in all_paths if not _embed_key(p).exists()]
if missing:
    print(f"[CACHE] {len(missing)} embeddings to compute...")
    build_img_cache(all_paths, cfg)
else:
    print("[CACHE] All embeddings already computed.")

# =====================================================
# Trainable HF text encoder + projection (partial tuning)
# =====================================================
class HFTextEnc(nn.Module):
    def __init__(self, name: str, out_dim: int,
                 train_strategy: str = "last_k",
                 last_k: int = 2,
                 freeze_embed: bool = True):
        super().__init__()
        self.model = AutoModel.from_pretrained(name)
        self.name  = name
        hid = self.model.config.hidden_size
        self.proj = nn.Linear(hid, out_dim)

        # default: freeze everything
        for p in self.model.parameters():
            p.requires_grad_(False)

        # optionally unfreeze embeddings
        if not freeze_embed:
            for p in self.model.get_input_embeddings().parameters():
                p.requires_grad_(True)

        if hasattr(self.model.config, "hidden_dropout_prob"):
            self.model.config.hidden_dropout_prob = 0.0
        if hasattr(self.model.config, "attention_probs_dropout_prob"):
            self.model.config.attention_probs_dropout_prob = 0.0

        # unfreeze per strategy
        if train_strategy == "none":
            pass
        elif train_strategy == "proj":
            for p in self.proj.parameters():
                p.requires_grad_(True)
        elif train_strategy == "last_k":
            enc = getattr(self.model, "encoder", None)
            if enc is None:
                enc = getattr(self.model, "transformer", None)
            layers = None
            for cand in ("layer", "layers", "block", "h"):
                if hasattr(enc, cand):
                    layers = getattr(enc, cand)
                    break
            if layers is None:
                raise RuntimeError("Cannot locate transformer blocks for last_k tuning.")
            k = max(1, min(last_k, len(layers)))
            for layer in layers[-k:]:
                for p in layer.parameters():
                    p.requires_grad_(True)
            for p in self.proj.parameters():
                p.requires_grad_(True)
        else:
            raise ValueError(f"Unknown TEXT_TRAIN_STRATEGY={train_strategy}")

    def forward(self, ids, mask):
        out = self.model(input_ids=ids, attention_mask=mask, return_dict=True)
        x = (out.last_hidden_state * mask.unsqueeze(-1)).sum(1) / (mask.sum(1, keepdim=True) + 1e-6)
        return self.proj(x)

    def param_groups(self, lr_base: float, lr_proj: float):
        base_params, proj_params = [], []
        for n, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if n.startswith("proj."):
                proj_params.append(p)
            else:
                base_params.append(p)
        groups = []
        if base_params:
            groups.append({"params": base_params, "lr": lr_base})
        if proj_params:
            groups.append({"params": proj_params, "lr": lr_proj})
        return groups

# =====================================================
# Fusion + Heads
# =====================================================
class Fusion(nn.Module):
    def __init__(self, d_img, d_txt, d_out):
        super().__init__()
        self.ln_img = nn.LayerNorm(d_img)
        self.ln_txt = nn.LayerNorm(d_txt)
        self.mlp = nn.Sequential(
            nn.Linear(d_img + d_txt, d_out),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(d_out, d_out),
        )
    def forward(self, zi, zt):
        x = torch.cat([self.ln_img(zi), self.ln_txt(zt)], dim=-1)
        return self.mlp(x)

class VQAHeads(nn.Module):
    def __init__(self, embed_dim: int, cls_spaces: dict):
        super().__init__()
        # Heads have ONLY the kept classes (unknown excluded).
        self.cls_spaces = {k: v for k, v in cls_spaces.items()}
        self.heads = nn.ModuleDict({k: nn.Linear(embed_dim, max(1, len(v))) for k, v in self.cls_spaces.items()})
        for lin in self.heads.values():
            nn.init.trunc_normal_(lin.weight, std=0.02); nn.init.zeros_(lin.bias)
    def forward(self, z):  # [B, D_fused]
        return {"cls": {k: h(z) for k, h in self.heads.items()}}

# =====================================================
# Text utils (pre-tokenize once)
# =====================================================
tok = AutoTokenizer.from_pretrained(HF_TEXT_MODEL)
FIXED_QUESTION = "what is the organism?"
_tok_once = tok(FIXED_QUESTION, padding=False, truncation=True, max_length=TEXT_MAX_LEN, return_tensors="pt")
_ids_base  = _tok_once["input_ids"]
_mask_base = _tok_once["attention_mask"]

def encode_text_fixed(batch_size: int, text_enc):
    ids  = _ids_base.to(device).expand(batch_size, -1).contiguous()
    mask = _mask_base.to(device).expand(batch_size, -1).contiguous()
    zt = text_enc(ids, mask)  # [B, d_txt]
    return F.normalize(zt, dim=-1)

# =====================================================
# Dataset over cached embeddings
# =====================================================
class EmbedDataset(torch.utils.data.Dataset):
    def __init__(self, df, cls_spaces, task_cols):
        self.df = df.reset_index(drop=True)
        self.task_cols = task_cols
        self.cls_spaces = {k: list(v) for k, v in cls_spaces.items()}
        # maps without 'unknown'
        self.label_maps = {k: {c:i for i,c in enumerate(v)} for k,v in self.cls_spaces.items()}

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        emb_path = _embed_key(row["sample_path"])
        if not emb_path.exists():
            raise FileNotFoundError(f"Missing cached embedding for {row['sample_path']}, run cache first.")
        z = np.load(emb_path, mmap_mode="r")
        out = {"z_img": torch.from_numpy(np.array(z, dtype=np.float32))}
        # encode labels; set -100 if unknown or filtered out
        for head, col in self.task_cols.items():
            val = str(row.get(col, "unknown"))
            if (val == "unknown") or (val not in self.label_maps[head]):
                y = -100
            else:
                y = self.label_maps[head][val]
            out[f"y_cls_{head}"] = torch.tensor(y, dtype=torch.long)
        out["question"] = FIXED_QUESTION
        return out

embed_ds = EmbedDataset(df, cls_spaces, TASK_COLS)

# Optional simple rebalancing example (organism): downweight unknown heavily
if REBALANCE:
    org_vals = df["organism"].astype(str).tolist()
    weights = np.array([0.05 if v=="unknown" else 1.0 for v in org_vals], dtype=np.float32)
    sampler = WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)
    train_loader = DataLoader(embed_ds, batch_size=BATCH_SIZE, sampler=sampler,
                              num_workers=NUM_WORKERS, pin_memory=False, drop_last=False)
else:
    train_loader = DataLoader(embed_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=False, drop_last=False)

# =====================================================
# Build trainable modules
# =====================================================
if USE_TEXT:
    text_enc = HFTextEnc(
        HF_TEXT_MODEL, TEXT_OUT_DIM,
        train_strategy=TEXT_TRAIN_STRATEGY,
        last_k=TEXT_TRAIN_LAST_K,
        freeze_embed=TEXT_FREEZE_EMBED
    ).to(device)
    d_txt = TEXT_OUT_DIM
    FUSED_DIM = d_img + d_txt
    fusion = Fusion(d_img, d_txt, FUSED_DIM).to(device)
else:
    text_enc = None
    d_txt = 0
    FUSED_DIM = d_img
    fusion = None

heads = VQAHeads(embed_dim=FUSED_DIM, cls_spaces=cls_spaces).to(device)

# Optimizer with per-group LRs (fused if available)
opt_groups = []
if USE_TEXT:
    opt_groups.extend(text_enc.param_groups(lr_base=LR_TXT_BASE, lr_proj=LR_TXT_PROJ))
    opt_groups.append({"params": fusion.parameters(),   "lr": LR_FUSION})
opt_groups.append({"params": heads.parameters(),          "lr": LR_HEADS})

try:
    opt = AdamW(opt_groups, weight_decay=WD, fused=USE_FUSED_ADAMW)
except TypeError:
    opt = AdamW(opt_groups, weight_decay=WD)

# Cosine scheduler (per-step)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=TOTAL_STEPS)

# =====================================================
# Logging / checkpoints
# =====================================================
config = {
    "timm_id": TIMM_ID,
    "patch_multiple": PATCH_MULTIPLE,
    "target_size": TARGET_SIZE,
    "pretrained": PRETRAINED,
    "hf_text_model": HF_TEXT_MODEL if USE_TEXT else None,
    "use_text": USE_TEXT,
    "text_max_len": TEXT_MAX_LEN,
    "text_out_dim": TEXT_OUT_DIM if USE_TEXT else 0,
    "text_tune": {
        "strategy": TEXT_TRAIN_STRATEGY,
        "last_k": TEXT_TRAIN_LAST_K,
        "freeze_embed": TEXT_FREEZE_EMBED,
        "lr_base": LR_TXT_BASE,
        "lr_proj": LR_TXT_PROJ,
        "fused_adamw": USE_FUSED_ADAMW
    },
    "channels_per_view": CHANNELS_PER_VIEW,
    "channels_per_step": CHANNELS_PER_STEP,
    "batch_size": BATCH_SIZE,
    "input_size": TARGET_H,
    "total_steps": TOTAL_STEPS,
    "lrs": {"fusion": LR_FUSION, "heads": LR_HEADS},
    "weight_decay": WD,
    "ema_beta": EMA_BETA,
    "embed_dim_image": int(d_img),
    "embed_dim_text": int(d_txt),
    "embed_dim_fused": int(FUSED_DIM),
    "cls_spaces": {k: list(v) for k, v in cls_spaces.items()},
    "rebalance_sampler": REBALANCE,
    "min_samples_per_class": MIN_SAMPLES_PER_CLASS,
    "drop_rows_all_unknown": DROP_ROWS_ALL_UNKNOWN,
}
RUN_DIR.mkdir(exist_ok=True, parents=True)
(RUN_DIR / "config.json").write_text(json.dumps(config, indent=2))

log_f = open(RUN_DIR / "train_log.csv", "w", newline="")
log_w = csv.writer(log_f); log_w.writerow(["step", "loss", "ema_loss", "lr_min", "lr_max"]); log_f.flush()

def _cpu_state_dict(m):
    return {k: v.detach().cpu() for k, v in m.state_dict().items()}

def save_ckpt(tag, step):
    ckpt = {
        "step": step,
        "heads": _cpu_state_dict(heads),
        "optimizer": opt.state_dict(),
        "scheduler": sched.state_dict(),
        "cls_spaces": cls_spaces,
        "embed_dim_image": int(d_img),
        "embed_dim_text": int(d_txt),
        "embed_dim_fused": int(FUSED_DIM),
        "backbone": dict(timm_id=TIMM_ID, patch_multiple=PATCH_MULTIPLE,
                         target_size=TARGET_SIZE, pretrained=PRETRAINED),
        "text_model": dict(id=HF_TEXT_MODEL, out_dim=TEXT_OUT_DIM, used=USE_TEXT),
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "min_samples_per_class": MIN_SAMPLES_PER_CLASS,
        "drop_rows_all_unknown": DROP_ROWS_ALL_UNKNOWN,
    }
    if USE_TEXT:
        ckpt["text_enc"] = _cpu_state_dict(text_enc)
        ckpt["fusion"]   = _cpu_state_dict(fusion)
    path = RUN_DIR / f"{tag}.pt"
    torch.save(ckpt, path)
    return path

# =====================================================
# Training loop (on cached embeddings)
# =====================================================
ema_loss = None
best_ema = float("inf")

if USE_TEXT:
    text_enc.train()
    fusion.train()
heads.train()

pbar = tqdm(range(TOTAL_STEPS), desc=f"VQA(train cached + heads {'+ text' if USE_TEXT else ''}) -> {RUN_DIR.name}")
loader_it = iter(train_loader)

for step in pbar:
    try:
        batch = next(loader_it)
    except StopIteration:
        loader_it = iter(train_loader)
        batch = next(loader_it)

    # image emb (already cached and normalized)
    z_img = batch["z_img"].to(device)  # [B, D_img]

    if USE_TEXT:
        # pre-tokenized fixed question path (fast)
        z_txt = encode_text_fixed(z_img.shape[0], text_enc)  # [B, d_txt]
        z = fusion(z_img, z_txt)
    else:
        z = z_img

    # heads
    out = heads(z)["cls"]

    # per-head masked loss (ignore unknown=-100)
    loss_list = []
    for field in cls_spaces.keys():
        key = f"y_cls_{field}"
        if key not in batch:
            continue
        y = batch[key].to(device)  # [B]
        logits = out[field]        # [B, num_classes]
        mask = (y != -100)
        if mask.any():
            loss_list.append(F.cross_entropy(logits[mask], y[mask], label_smoothing=LABEL_SMOOTH))

    if not loss_list:
        # No usable labels present for this batch; skip optimization/logging but keep the loop moving.
        continue

    loss = torch.stack(loss_list).sum()

    opt.zero_grad(set_to_none=True)
    loss.backward()
    if USE_TEXT:
        torch.nn.utils.clip_grad_norm_(list(chain(text_enc.parameters(), fusion.parameters(), heads.parameters())), 1.0)
    else:
        torch.nn.utils.clip_grad_norm_(list(heads.parameters()), 1.0)
    opt.step()
    sched.step()

    # logs / EMA / ckpts
    cur_loss = float(loss.detach().item())
    ema_loss = cur_loss if ema_loss is None else (EMA_BETA * ema_loss + (1 - EMA_BETA) * cur_loss)
    lrs = [pg["lr"] for pg in opt.param_groups]
    pbar.set_postfix(loss=f"{cur_loss:.4f}", ema=f"{ema_loss:.4f}", lr_min=f"{min(lrs):.2e}", lr_max=f"{max(lrs):.2e}")

    log_w.writerow([step, f"{cur_loss:.6f}", f"{ema_loss:.6f}", f"{min(lrs):.6e}", f"{max(lrs):.6e}"])
    if step % 20 == 0:
        log_f.flush()

    if (step + 1) % SAVE_EVERY == 0:
        save_ckpt("last", step + 1)
    if KEEP_EVERY and (step + 1) % KEEP_EVERY == 0:
        save_ckpt(f"step_{step + 1:06d}", step + 1)
    if ema_loss < best_ema:
        best_ema = ema_loss
        save_ckpt("best", step + 1)

# final save
save_ckpt("last", TOTAL_STEPS)
log_f.close()
print(f"Saved checkpoints & logs to: {RUN_DIR}")

[FILTER] Dropped 0 rows with all tasks == 'unknown'.

[SUMMARY] Class counts after filtering (unknown shown but ignored during loss):
 - organism: total=3938, classes=2 (+unknown)
organism
homo sapiens    2122
mus musculus    1799
unknown           17
 - polarity: total=3938, classes=2 (+unknown)
polarity
negative    2397
positive    1541
 - organ: total=3938, classes=6 (+unknown)
Organism_Part
kidney           1353
unknown          1020
brain             643
kidney cortex     368
liver             225
lung              212
breast            117
 - condition: total=3938, classes=8 (+unknown)
Condition
n/a             1259
unknown          781
biopsy           598
wildtype         587
frozen           194
tumor            164
diseased         141
cancer           109
fresh frozen     105
 - analyzerType: total=3938, classes=4 (+unknown)
analyzerType
orbitrap        1965
timstof flex     627
12t fticr        593
unknown          443
fticr            310
 - ionisationSource: total=3938, c

VQA(train cached + heads + text) -> 20251015_171627: 100%|██████████| 10000/10000 [2:35:29<00:00,  1.07it/s, ema=2.0028, loss=2.2207, lr_max=0.00e+00, lr_min=0.00e+00] 


Saved checkpoints & logs to: vqa\20251015_171627


### 5-Fold Cross Validation

In [44]:
# vqa_cv5.py
# 5-fold cross-validation for cached-embedding VQA (train_vqa_fast style)
# - Uses same normalization + class spaces
# - GroupKFold by dataset_id to avoid leakage
# - Trains heads (+ optional text-tower last_k) on train folds
# - Evaluates on held-out fold, saves CSV/JSON, and seaborn plots
# - tqdm progress per epoch/batch
# - Cleans label variants (N/A, n a, —, etc.) -> 'unknown'
# - Optional rare-class filtering (remap rare labels to 'unknown')

import os, csv, json, math, hashlib, warnings, re
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, Subset

from sklearn.model_selection import GroupKFold
from sklearn.metrics import accuracy_score, f1_score, classification_report

import seaborn as sns
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModel

warnings.filterwarnings("ignore", category=UserWarning)

# -------------------------
# Paths & I/O
# -------------------------
IDX_PARQUET = "metaspace_images_dump/msi_fm_samples3.parquet"
MAN_PARQUET = "metaspace_images_dump/manifest_expanded.parquet"

RUN_ROOT = Path("vqa")
IMG_CACHE_DIR = RUN_ROOT / "_img_cache"  # must exist from your previous caching step
STATS_DIR = RUN_ROOT / "_stats_cache"    # not used here but kept for parity

CV_DIR = RUN_ROOT / ("cv5_" + datetime.now().strftime("%Y%m%d_%H%M%S"))
CV_DIR.mkdir(parents=True, exist_ok=True)

# -------------------------
# Config (match train_vqa_fast)
# -------------------------
SEED = 6740
torch.manual_seed(SEED); np.random.seed(SEED)

USE_TEXT   = True      # set False to try image-only baseline
LABEL_SMOOTH = 0.05

# Timmed image encoder was already used to create cached embeddings
TIMM_ID        = "vit_small_patch14_reg4_dinov2.lvd142m"
PATCH_MULTIPLE = 14
TARGET_SIZE    = 518

# HF text tower partial tuning setup (same as fast script)
HF_TEXT_MODEL       = "sentence-transformers/all-MiniLM-L6-v2"
TEXT_MAX_LEN        = 64
TEXT_OUT_DIM        = 384
TEXT_TRAIN_STRATEGY = "last_k"  # ["none","proj","last_k"]
TEXT_TRAIN_LAST_K   = 2
TEXT_FREEZE_EMBED   = True
LR_TXT_BASE         = 5e-5
LR_TXT_PROJ         = 1e-4

# Heads / fusion
LR_FUSION = 5e-4
LR_HEADS  = 8e-4
WD        = 0.05
BATCH_SIZE = 512          # cached embeddings => large batch ok
EPOCHS     = 25           # keep modest for CV
AMP        = False        # cached embeddings: little gain; keep False for reproducibility
NUM_WORKERS = 0           # bump if your I/O can handle it

# Optional rare-class handling: remap labels with corpus count < N to 'unknown'
MIN_SAMPLES_PER_CLASS = {
    "organism": 100,
    "polarity": 100,
    "Organism_Part": 100,
    "Condition": 100,
    "analyzerType": 100,
    "ionisationSource": 100,
}

# -------------------------
# Utilities
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
autocast_ctx = torch.cuda.amp.autocast(enabled=False)

NA_LIKE = {
    "na", "n/a", "n\\a", "n.a.", "n . a .", "n a", "n - a", "not applicable",
    "none", "null", "nan", "-", "—", "unknown", ""
}

def _canon_base(x):
    if x is None:
        return "unknown"
    s = str(x).strip().lower()
    s = re.sub(r"\s+", " ", s)  # collapse whitespace
    s = s.replace("_", " ").replace(".", ".").strip()
    if s in NA_LIKE:
        return "unknown"
    # common OCR-ish or punctuation variants to unknown
    if re.fullmatch(r"n[\s\./\\-]*a", s):
        return "unknown"
    return s if s else "unknown"

def normalize_organism(x):
    s = _canon_base(x)
    if s == "unknown": return s
    if "mus musculus" in s or s in {"mouse", "mouse brain", "mus", "m. musculus"}:
        return "mus musculus"
    if "homo sapiens" in s or s in {"human", "h. sapiens"}:
        return "homo sapiens"
    return "unknown"

def normalize_polarity(x):
    s = _canon_base(x)
    if s in {"pos", "+", "positive", "positive ion mode", "positive mode"}:  return "positive"
    if s in {"neg", "-", "negative", "negative ion mode", "negative mode"}:  return "negative"
    return "unknown"

def normalize_text(x): 
    return _canon_base(x)

def _build_vocab(series):
    vals = sorted(set(series.dropna().astype(str).tolist()))
    if "unknown" not in vals:
        vals.append("unknown")
    return vals

def _embed_key(path: str, timm_id: str, patch_multiple: int, target_size: int) -> Path:
    h = hashlib.sha1(f"{timm_id}|{patch_multiple}|{target_size}|{path}".encode()).hexdigest()
    return IMG_CACHE_DIR / f"{h}.npy"

def _remap_rare_classes(df: pd.DataFrame, col: str, min_count: int) -> pd.Series:
    counts = df[col].value_counts()
    rare = set(counts[counts < min_count].index.tolist())
    rare.discard("unknown")  # always keep unknown
    if not rare:
        return df[col]
    return df[col].apply(lambda v: "unknown" if v in rare else v)

# -------------------------
# Model bits (match your fast script)
# -------------------------
class HFTextEnc(nn.Module):
    def __init__(self, name: str, out_dim: int,
                 train_strategy: str = "last_k",
                 last_k: int = 2,
                 freeze_embed: bool = True):
        super().__init__()
        self.model = AutoModel.from_pretrained(name)
        hid = self.model.config.hidden_size
        self.proj = nn.Linear(hid, out_dim)

        # freeze all
        for p in self.model.parameters():
            p.requires_grad_(False)
        if not freeze_embed:
            for p in self.model.get_input_embeddings().parameters():
                p.requires_grad_(True)

        # dropouts off
        if hasattr(self.model.config, "hidden_dropout_prob"):
            self.model.config.hidden_dropout_prob = 0.0
        if hasattr(self.model.config, "attention_probs_dropout_prob"):
            self.model.config.attention_probs_dropout_prob = 0.0

        if train_strategy == "none":
            pass
        elif train_strategy == "proj":
            for p in self.proj.parameters(): p.requires_grad_(True)
        elif train_strategy == "last_k":
            enc = getattr(self.model, "encoder", None) or getattr(self.model, "transformer", None)
            layers = None
            for cand in ("layer","layers","block","h"):
                if hasattr(enc, cand):
                    layers = getattr(enc, cand); break
            if layers is None:
                raise RuntimeError("Cannot locate transformer blocks for last_k tuning.")
            k = max(1, min(last_k, len(layers)))
            for layer in layers[-k:]:
                for p in layer.parameters():
                    p.requires_grad_(True)
            for p in self.proj.parameters(): p.requires_grad_(True)
        else:
            raise ValueError(f"Unknown train_strategy={train_strategy}")

    def forward(self, ids, mask):
        out = self.model(input_ids=ids, attention_mask=mask, return_dict=True)
        x = (out.last_hidden_state * mask.unsqueeze(-1)).sum(1) / (mask.sum(1, keepdim=True) + 1e-6)
        return self.proj(x)

    def param_groups(self, lr_base: float, lr_proj: float):
        base_params, proj_params = [], []
        for n, p in self.named_parameters():
            if not p.requires_grad: continue
            (proj_params if n.startswith("proj.") else base_params).append(p)
        groups = []
        if base_params: groups.append({"params": base_params, "lr": lr_base})
        if proj_params: groups.append({"params": proj_params, "lr": lr_proj})
        return groups

class Fusion(nn.Module):
    def __init__(self, d_img, d_txt, d_out):
        super().__init__()
        self.ln_img = nn.LayerNorm(d_img)
        self.ln_txt = nn.LayerNorm(d_txt)
        self.mlp = nn.Sequential(
            nn.Linear(d_img + d_txt, d_out),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(d_out, d_out),
        )
    def forward(self, zi, zt):
        x = torch.cat([self.ln_img(zi), self.ln_txt(zt)], dim=-1)
        return self.mlp(x)

class VQAHeads(nn.Module):
    def __init__(self, embed_dim: int, cls_spaces: dict):
        super().__init__()
        self.heads = nn.ModuleDict({k: nn.Linear(embed_dim, len(v)) for k, v in cls_spaces.items()})
        for lin in self.heads.values():
            nn.init.trunc_normal_(lin.weight, std=0.02); nn.init.zeros_(lin.bias)
    def forward(self, z):  # [B, D]
        return {"cls": {k: h(z) for k, h in self.heads.items()}}

# -------------------------
# Dataframe + normalization (same as training) + rare-class remap
# -------------------------
def build_eval_dataframe() -> pd.DataFrame:
    df_idx = pd.read_parquet(IDX_PARQUET)
    df_man = pd.read_parquet(MAN_PARQUET)

    need_cols = ["dataset_id","organism","polarity","Organism_Part","Condition","analyzerType","ionisationSource"]
    man_sub = df_man[[c for c in need_cols if c in df_man.columns]].drop_duplicates("dataset_id")

    df = df_idx.merge(man_sub, on="dataset_id", how="left", suffixes=("", "_man"))
    df = df.loc[:, ~df.columns.duplicated()]
    if "sample_path" not in df.columns:
        raise ValueError("Index parquet must include 'sample_path' pointing to .npz files.")

    # Normalize labels
    norm_map = [
        ("organism", normalize_organism),
        ("polarity", normalize_polarity),
        ("Organism_Part", normalize_text),
        ("Condition", normalize_text),
        ("analyzerType", normalize_text),
        ("ionisationSource", normalize_text),
    ]
    for col, fn in norm_map:
        if col in df.columns:
            df[col] = df[col].map(fn)
        else:
            df[col] = "unknown"

    # Rare-class remap to 'unknown' (to be ignored)
    for col, min_count in MIN_SAMPLES_PER_CLASS.items():
        if col in df.columns and min_count and min_count > 0:
            df[col] = _remap_rare_classes(df, col, min_count)

    return df

# -------------------------
# Dataset over cached embeddings
# -------------------------
class EmbedDataset(torch.utils.data.Dataset):
    def __init__(self, df, cls_spaces):
        self.df = df.reset_index(drop=True)
        self.cls_spaces = cls_spaces
        self.label_maps = {k: {c:i for i,c in enumerate(v)} for k,v in cls_spaces.items()}
        self.paths = self.df["sample_path"].tolist()

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        ep = _embed_key(row["sample_path"], TIMM_ID, PATCH_MULTIPLE, TARGET_SIZE)
        if not ep.exists():
            raise FileNotFoundError(f"Missing cached embedding for {row['sample_path']} -> {ep}")
        z = np.load(ep, mmap_mode="r")
        out = {"z_img": torch.from_numpy(np.array(z, dtype=np.float32))}
        # labels, -100 for unknown so we can ignore in loss/metrics
        for field, col in [
            ("organism","organism"),
            ("polarity","polarity"),
            ("organ","Organism_Part"),
            ("condition","Condition"),
            ("analyzerType","analyzerType"),
            ("ionisationSource","ionisationSource"),
        ]:
            val = str(row.get(col, "unknown"))
            y = self.label_maps[field].get(val, self.label_maps[field]["unknown"])
            if val == "unknown": y = -100
            out[f"y_cls_{field}"] = torch.tensor(y, dtype=torch.long)
        return out

# -------------------------
# Metrics (robust to absent classes)
# -------------------------
def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, labels_full: List[str]):
    if y_true.size == 0:
        return {"accuracy": None, "macro_f1": None, "per_class_f1": {}}
    acc = float(accuracy_score(y_true, y_pred))
    macro = float(f1_score(y_true, y_pred, average="macro"))
    present = sorted(set(np.unique(y_true)).union(set(np.unique(y_pred))))
    present_names = [labels_full[i] for i in present] if present else []
    report = classification_report(
        y_true, y_pred, labels=present,
        target_names=present_names, output_dict=True, zero_division=0
    ) if present else {}
    per_class = {cls: float(v["f1-score"]) for cls, v in report.items() if cls in present_names}
    return {"accuracy": acc, "macro_f1": macro, "per_class_f1": per_class}

# -------------------------
# Train / Eval loop (per fold)
# -------------------------
FIXED_QUESTION = "what is the organism?"
def build_modules(d_img: int, cls_spaces: Dict[str, List[str]]):
    if USE_TEXT:
        text_enc = HFTextEnc(
            HF_TEXT_MODEL, TEXT_OUT_DIM,
            train_strategy=TEXT_TRAIN_STRATEGY,
            last_k=TEXT_TRAIN_LAST_K,
            freeze_embed=TEXT_FREEZE_EMBED
        ).to(device)
        fusion = Fusion(d_img, TEXT_OUT_DIM, d_img + TEXT_OUT_DIM).to(device)
        heads = VQAHeads(embed_dim=d_img + TEXT_OUT_DIM, cls_spaces=cls_spaces).to(device)
    else:
        text_enc, fusion = None, None
        heads = VQAHeads(embed_dim=d_img, cls_spaces=cls_spaces).to(device)

    # params & opt
    groups = []
    if USE_TEXT:
        groups.extend(text_enc.param_groups(lr_base=LR_TXT_BASE, lr_proj=LR_TXT_PROJ))
        groups.append({"params": fusion.parameters(), "lr": LR_FUSION})
    groups.append({"params": heads.parameters(), "lr": LR_HEADS})

    try:
        opt = AdamW(groups, weight_decay=WD, fused=torch.cuda.is_available())
    except TypeError:
        opt = AdamW(groups, weight_decay=WD)
    return text_enc, fusion, heads, opt

def forward_batch(batch, text_ctx):
    z_img = batch["z_img"].to(device, non_blocking=True)
    if USE_TEXT:
        ids, mask, text_enc, fusion = text_ctx
        ids = ids.expand(z_img.shape[0], -1).contiguous()
        mask = mask.expand(z_img.shape[0], -1).contiguous()
        # normalize text only in eval; train lets gradients through proj/last_k
        z_txt = text_enc(ids, mask)
        if not text_enc.training:
            z_txt = F.normalize(z_txt, dim=-1)
        z = fusion(z_img, z_txt)
    else:
        z = z_img
    return z

def train_one_epoch(loader, heads, opt, text_ctx, fold_idx: int, ep: int):
    heads.train()
    if USE_TEXT:
        text_ctx[2].train()   # text_enc
        text_ctx[3].train()   # fusion

    ema = None
    pbar = tqdm(loader, total=len(loader), desc=f"Fold {fold_idx} | Epoch {ep}", leave=False)
    for batch in pbar:
        z = forward_batch(batch, text_ctx)
        out = heads(z)["cls"]
        # masked multi-head loss
        loss_list = []
        for field in heads.heads.keys():
            y = batch[f"y_cls_{field}"].to(device, non_blocking=True)
            mask = (y != -100)
            if mask.any():
                loss_list.append(F.cross_entropy(out[field][mask], y[mask], label_smoothing=LABEL_SMOOTH))
        if not loss_list:
            continue
        loss = torch.stack(loss_list).sum()

        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(heads.parameters(), 1.0)
        if USE_TEXT:
            torch.nn.utils.clip_grad_norm_(list(text_ctx[2].parameters()) + list(text_ctx[3].parameters()), 1.0)
        opt.step()

        cur = float(loss.detach().item())
        ema = cur if ema is None else 0.98 * ema + 0.02 * cur
        pbar.set_postfix(loss=f"{cur:.4f}", ema=f"{ema:.4f}")
    return ema if ema is not None else 0.0

@torch.no_grad()
def eval_loader(loader, heads, text_ctx, cls_spaces):
    heads.eval()
    if USE_TEXT:
        text_ctx[2].eval()
        text_ctx[3].eval()

    rows = []
    for batch in loader:
        z = forward_batch(batch, text_ctx)
        out = heads(z)["cls"]
        rec = {}
        for field in cls_spaces.keys():
            logits = out[field].detach().cpu().numpy()
            pred = logits.argmax(axis=1)
            true = batch[f"y_cls_{field}"].detach().cpu().numpy()
            rec[f"pred_{field}"] = pred
            rec[f"true_{field}"] = true
        rows.append(rec)

    # stack and compute metrics (skip unknown=-100)
    metrics = {}
    for field, labels in cls_spaces.items():
        pred = np.concatenate([r[f"pred_{field}"] for r in rows], axis=0)
        true = np.concatenate([r[f"true_{field}"] for r in rows], axis=0)
        m = (true != -100)
        y_true = true[m]
        y_pred = pred[m]
        mtr = compute_metrics(y_true, y_pred, labels)
        metrics[field] = {"num_eval": int(y_true.size), "accuracy": mtr["accuracy"], "macro_f1": mtr["macro_f1"]}
    return metrics

# -------------------------
# Main CV
# -------------------------
def main():
    # Build dataframe & class spaces
    df = build_eval_dataframe()

    # Class spaces MUST come from entire corpus (as in training)
    cls_spaces = {
        "organism":         _build_vocab(df["organism"]),
        "polarity":         _build_vocab(df["polarity"]),
        "organ":            _build_vocab(df["Organism_Part"]),
        "condition":        _build_vocab(df["Condition"]),
        "analyzerType":     _build_vocab(df["analyzerType"]),
        "ionisationSource": _build_vocab(df["ionisationSource"]),
    }

    # Check cached embeddings exist for all samples
    missing = [p for p in df["sample_path"].tolist()
               if not _embed_key(p, TIMM_ID, PATCH_MULTIPLE, TARGET_SIZE).exists()]
    if missing:
        raise FileNotFoundError(
            f"{len(missing)} embeddings missing under {IMG_CACHE_DIR}. "
            f"Example:\n  {missing[0]}\nPlease run your caching step first."
        )

    # Probe image embedding dim
    any_ep = _embed_key(df.iloc[0]["sample_path"], TIMM_ID, PATCH_MULTIPLE, TARGET_SIZE)
    d_img = int(np.load(any_ep, mmap_mode="r").shape[0])
    fused_dim = d_img + (TEXT_OUT_DIM if USE_TEXT else 0)

    # Build dataset
    ds_all = EmbedDataset(df, cls_spaces)

    # Tokenizer / fixed question once
    if USE_TEXT:
        tok = AutoTokenizer.from_pretrained(HF_TEXT_MODEL)
        toks_once = tok(FIXED_QUESTION, padding=False, truncation=True, max_length=TEXT_MAX_LEN, return_tensors="pt")
        ids_base  = toks_once["input_ids"].to(device)
        mask_base = toks_once["attention_mask"].to(device)
    else:
        ids_base = mask_base = None

    # GroupKFold by dataset_id (avoid leakage)
    groups = df["dataset_id"].astype(str).values
    gkf = GroupKFold(n_splits=5)

    all_fold_rows = []     # tidy rows for plotting
    fold_summaries = []    # per-fold JSON

    for fold_idx, (tr_idx, va_idx) in enumerate(gkf.split(df, groups=groups, y=None), start=1):
        print(f"\n===== Fold {fold_idx}/5 =====  train={len(tr_idx)}  val={len(va_idx)}")
        ds_tr = Subset(ds_all, tr_idx.tolist())
        ds_va = Subset(ds_all, va_idx.tolist())

        ld_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=False, drop_last=False)
        ld_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=False, drop_last=False)

        # Build fresh modules per fold
        text_enc, fusion, heads, opt = build_modules(d_img, cls_spaces)

        # Text context tuple for forward
        text_ctx = [ids_base, mask_base, text_enc, fusion] if USE_TEXT else [None, None, None, None]

        # Train
        for ep in range(1, EPOCHS + 1):
            ema = train_one_epoch(ld_tr, heads, opt, text_ctx, fold_idx=fold_idx, ep=ep)
            print(f"  Epoch {ep}/{EPOCHS} :: EMA loss ~ {ema:.4f}")

        # Eval
        metrics = eval_loader(ld_va, heads, text_ctx, cls_spaces)

        # Save fold artifacts
        fold_dir = CV_DIR / f"fold_{fold_idx}"
        fold_dir.mkdir(parents=True, exist_ok=True)
        (fold_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))

        # Flatten tidy rows for plotting
        for task, m in metrics.items():
            all_fold_rows.append({
                "fold": fold_idx,
                "task": task,
                "num_eval": m["num_eval"],
                "accuracy": m["accuracy"],
                "macro_f1": m["macro_f1"],
            })

        # Minimal state save (heads only)
        torch.save({"heads": {k: v.detach().cpu() for k, v in heads.state_dict().items()},
                    "embed_dim_image": d_img,
                    "embed_dim_text": (TEXT_OUT_DIM if USE_TEXT else 0),
                    "fused_dim": fused_dim,
                    "cls_spaces": cls_spaces},
                   fold_dir / "heads.pt")

        fold_summaries.append({"fold": fold_idx, "metrics": metrics})

    # -------------------------
    # Aggregate & plots (mean ± std)
    # -------------------------
    tidy_df = pd.DataFrame(all_fold_rows)
    tidy_csv = CV_DIR / "cv5_tidy_metrics.csv"
    tidy_df.to_csv(tidy_csv, index=False)
    (CV_DIR / "cv5_summary.json").write_text(json.dumps(fold_summaries, indent=2))

    # Compute mean and std per task
    agg_df = tidy_df.groupby("task", as_index=False).agg(
        mean_macro_f1=("macro_f1", "mean"),
        std_macro_f1=("macro_f1", "std"),
        mean_acc=("accuracy", "mean"),
        std_acc=("accuracy", "std")
    )

    # Plot helpers
    def _plot_bar_with_err(df, y_col_mean, y_col_std, title, ylabel, filename, color):
        plt.figure(figsize=(8, 5))
        sns.barplot(
            data=df,
            x="task",
            y=y_col_mean,
            color=color,
            edgecolor="black"
        )
        plt.errorbar(
            x=np.arange(len(df)),
            y=df[y_col_mean],
            yerr=df[y_col_std],
            fmt="none",
            ecolor="black",
            capsize=4,
            lw=1.2,
        )
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel("Task")
        plt.ylim(0, 1)
        plt.tight_layout()
        out = CV_DIR / filename
        plt.savefig(out, dpi=200)
        plt.close()
        return out

    out_f1 = _plot_bar_with_err(
        agg_df, "mean_macro_f1", "std_macro_f1",
        "5-Fold CV — Macro-F1 (mean ± std)", "Macro-F1",
        "cv5_macro_f1_mean_std.png", color="steelblue"
    )
    out_acc = _plot_bar_with_err(
        agg_df, "mean_acc", "std_acc",
        "5-Fold CV — Accuracy (mean ± std)", "Accuracy",
        "cv5_accuracy_mean_std.png", color="lightcoral"
    )

    print(f"\n[OK] Saved:")
    print(f"  Tidy CSV: {tidy_csv}")
    print(f"  Macro-F1 mean ± std: {out_f1}")
    print(f"  Accuracy mean ± std: {out_acc}")
    print(f"  Per-fold JSON/heads under: {CV_DIR}")

if __name__ == "__main__":
    main()

  autocast_ctx = torch.cuda.amp.autocast(enabled=False)



===== Fold 1/5 =====  train=3150  val=788


                                                                                        

  Epoch 1/25 :: EMA loss ~ 9.1557


                                                                                        

  Epoch 2/25 :: EMA loss ~ 6.6921


                                                                                        

  Epoch 3/25 :: EMA loss ~ 6.1200


                                                                                        

  Epoch 4/25 :: EMA loss ~ 5.7258


                                                                                        

  Epoch 5/25 :: EMA loss ~ 5.7220


                                                                                        

  Epoch 6/25 :: EMA loss ~ 5.4128


                                                                                        

  Epoch 7/25 :: EMA loss ~ 5.3807


                                                                                        

  Epoch 8/25 :: EMA loss ~ 5.3318


                                                                                        

  Epoch 9/25 :: EMA loss ~ 5.3659


                                                                                         

  Epoch 10/25 :: EMA loss ~ 5.2217


                                                                                         

  Epoch 11/25 :: EMA loss ~ 5.0434


                                                                                         

  Epoch 12/25 :: EMA loss ~ 5.1704


                                                                                         

  Epoch 13/25 :: EMA loss ~ 5.1135


                                                                                         

  Epoch 14/25 :: EMA loss ~ 4.9523


                                                                                         

  Epoch 15/25 :: EMA loss ~ 5.2142


                                                                                         

  Epoch 16/25 :: EMA loss ~ 5.0696


                                                                                         

  Epoch 17/25 :: EMA loss ~ 4.9953


                                                                                         

  Epoch 18/25 :: EMA loss ~ 4.9397


                                                                                         

  Epoch 19/25 :: EMA loss ~ 4.8209


                                                                                         

  Epoch 20/25 :: EMA loss ~ 4.7671


                                                                                         

  Epoch 21/25 :: EMA loss ~ 4.8943


                                                                                         

  Epoch 22/25 :: EMA loss ~ 4.9604


                                                                                         

  Epoch 23/25 :: EMA loss ~ 4.5940


                                                                                         

  Epoch 24/25 :: EMA loss ~ 4.5197


                                                                                         

  Epoch 25/25 :: EMA loss ~ 4.7664

===== Fold 2/5 =====  train=3150  val=788


                                                                                        

  Epoch 1/25 :: EMA loss ~ 9.0485


                                                                                        

  Epoch 2/25 :: EMA loss ~ 6.4557


                                                                                        

  Epoch 3/25 :: EMA loss ~ 6.1024


                                                                                        

  Epoch 4/25 :: EMA loss ~ 5.8416


                                                                                        

  Epoch 5/25 :: EMA loss ~ 5.6462


                                                                                        

  Epoch 6/25 :: EMA loss ~ 5.4164


                                                                                        

  Epoch 7/25 :: EMA loss ~ 5.5657


                                                                                        

  Epoch 8/25 :: EMA loss ~ 5.4107


                                                                                        

  Epoch 9/25 :: EMA loss ~ 5.3277


                                                                                         

  Epoch 10/25 :: EMA loss ~ 5.2410


                                                                                         

  Epoch 11/25 :: EMA loss ~ 5.1456


                                                                                         

  Epoch 12/25 :: EMA loss ~ 5.0761


                                                                                         

  Epoch 13/25 :: EMA loss ~ 5.1057


                                                                                         

  Epoch 14/25 :: EMA loss ~ 4.9724


                                                                                         

  Epoch 15/25 :: EMA loss ~ 5.0786


                                                                                         

  Epoch 16/25 :: EMA loss ~ 4.9341


                                                                                         

  Epoch 17/25 :: EMA loss ~ 5.0561


                                                                                         

  Epoch 18/25 :: EMA loss ~ 4.9017


                                                                                         

  Epoch 19/25 :: EMA loss ~ 4.8092


                                                                                         

  Epoch 20/25 :: EMA loss ~ 4.6349


                                                                                         

  Epoch 21/25 :: EMA loss ~ 4.6489


                                                                                         

  Epoch 22/25 :: EMA loss ~ 4.8116


                                                                                         

  Epoch 23/25 :: EMA loss ~ 4.6917


                                                                                         

  Epoch 24/25 :: EMA loss ~ 4.5260


                                                                                         

  Epoch 25/25 :: EMA loss ~ 4.7181

===== Fold 3/5 =====  train=3150  val=788


                                                                                        

  Epoch 1/25 :: EMA loss ~ 9.2823


                                                                                        

  Epoch 2/25 :: EMA loss ~ 6.6623


                                                                                        

  Epoch 3/25 :: EMA loss ~ 6.1251


                                                                                        

  Epoch 4/25 :: EMA loss ~ 6.0144


                                                                                        

  Epoch 5/25 :: EMA loss ~ 5.7137


                                                                                        

  Epoch 6/25 :: EMA loss ~ 5.5601


                                                                                        

  Epoch 7/25 :: EMA loss ~ 5.5490


                                                                                        

  Epoch 8/25 :: EMA loss ~ 5.3803


                                                                                        

  Epoch 9/25 :: EMA loss ~ 5.3598


                                                                                         

  Epoch 10/25 :: EMA loss ~ 5.3166


                                                                                         

  Epoch 11/25 :: EMA loss ~ 5.1868


                                                                                         

  Epoch 12/25 :: EMA loss ~ 5.1459


                                                                                         

  Epoch 13/25 :: EMA loss ~ 5.0024


                                                                                         

  Epoch 14/25 :: EMA loss ~ 4.9905


                                                                                         

  Epoch 15/25 :: EMA loss ~ 4.9319


                                                                                         

  Epoch 16/25 :: EMA loss ~ 4.8688


                                                                                         

  Epoch 17/25 :: EMA loss ~ 4.9494


                                                                                         

  Epoch 18/25 :: EMA loss ~ 4.7499


                                                                                         

  Epoch 19/25 :: EMA loss ~ 4.7407


                                                                                         

  Epoch 20/25 :: EMA loss ~ 4.7132


                                                                                         

  Epoch 21/25 :: EMA loss ~ 4.5270


                                                                                         

  Epoch 22/25 :: EMA loss ~ 4.5572


                                                                                         

  Epoch 23/25 :: EMA loss ~ 4.6234


                                                                                         

  Epoch 24/25 :: EMA loss ~ 4.5978


                                                                                         

  Epoch 25/25 :: EMA loss ~ 4.4452

===== Fold 4/5 =====  train=3151  val=787


                                                                                        

  Epoch 1/25 :: EMA loss ~ 9.1428


                                                                                        

  Epoch 2/25 :: EMA loss ~ 6.5644


                                                                                        

  Epoch 3/25 :: EMA loss ~ 6.1990


                                                                                        

  Epoch 4/25 :: EMA loss ~ 5.8578


                                                                                        

  Epoch 5/25 :: EMA loss ~ 5.7518


                                                                                        

  Epoch 6/25 :: EMA loss ~ 5.4681


                                                                                        

  Epoch 7/25 :: EMA loss ~ 5.3692


                                                                                        

  Epoch 8/25 :: EMA loss ~ 5.4518


                                                                                        

  Epoch 9/25 :: EMA loss ~ 5.3986


                                                                                         

  Epoch 10/25 :: EMA loss ~ 5.2169


                                                                                         

  Epoch 11/25 :: EMA loss ~ 5.2009


                                                                                         

  Epoch 12/25 :: EMA loss ~ 5.1468


                                                                                         

  Epoch 13/25 :: EMA loss ~ 5.1643


                                                                                         

  Epoch 14/25 :: EMA loss ~ 5.0149


                                                                                         

  Epoch 15/25 :: EMA loss ~ 5.1749


                                                                                         

  Epoch 16/25 :: EMA loss ~ 4.9642


                                                                                         

  Epoch 17/25 :: EMA loss ~ 5.3027


                                                                                         

  Epoch 18/25 :: EMA loss ~ 5.0913


                                                                                         

  Epoch 19/25 :: EMA loss ~ 4.8566


                                                                                         

  Epoch 20/25 :: EMA loss ~ 4.9365


                                                                                         

  Epoch 21/25 :: EMA loss ~ 4.9731


                                                                                         

  Epoch 22/25 :: EMA loss ~ 4.6993


                                                                                         

  Epoch 23/25 :: EMA loss ~ 4.6314


                                                                                         

  Epoch 24/25 :: EMA loss ~ 4.4896


                                                                                         

  Epoch 25/25 :: EMA loss ~ 4.4129

===== Fold 5/5 =====  train=3151  val=787


                                                                                        

  Epoch 1/25 :: EMA loss ~ 9.0954


                                                                                        

  Epoch 2/25 :: EMA loss ~ 6.4866


                                                                                        

  Epoch 3/25 :: EMA loss ~ 6.1261


                                                                                        

  Epoch 4/25 :: EMA loss ~ 5.8658


                                                                                        

  Epoch 5/25 :: EMA loss ~ 5.5485


                                                                                        

  Epoch 6/25 :: EMA loss ~ 5.4953


                                                                                        

  Epoch 7/25 :: EMA loss ~ 5.6248


                                                                                        

  Epoch 8/25 :: EMA loss ~ 5.2263


                                                                                        

  Epoch 9/25 :: EMA loss ~ 5.3466


                                                                                         

  Epoch 10/25 :: EMA loss ~ 5.2445


                                                                                         

  Epoch 11/25 :: EMA loss ~ 5.1479


                                                                                         

  Epoch 12/25 :: EMA loss ~ 5.1101


                                                                                         

  Epoch 13/25 :: EMA loss ~ 4.9894


                                                                                         

  Epoch 14/25 :: EMA loss ~ 4.9370


                                                                                         

  Epoch 15/25 :: EMA loss ~ 4.8379


                                                                                         

  Epoch 16/25 :: EMA loss ~ 4.8042


                                                                                         

  Epoch 17/25 :: EMA loss ~ 4.9296


                                                                                         

  Epoch 18/25 :: EMA loss ~ 4.7048


                                                                                         

  Epoch 19/25 :: EMA loss ~ 4.8873


                                                                                         

  Epoch 20/25 :: EMA loss ~ 5.0950


                                                                                         

  Epoch 21/25 :: EMA loss ~ 4.8355


                                                                                         

  Epoch 22/25 :: EMA loss ~ 4.7905


                                                                                         

  Epoch 23/25 :: EMA loss ~ 4.5785


                                                                                         

  Epoch 24/25 :: EMA loss ~ 4.4716


                                                                                         

  Epoch 25/25 :: EMA loss ~ 4.5308

[OK] Saved:
  Tidy CSV: vqa\cv5_20251015_195205\cv5_tidy_metrics.csv
  Macro-F1 mean ± std: vqa\cv5_20251015_195205\cv5_macro_f1_mean_std.png
  Accuracy mean ± std: vqa\cv5_20251015_195205\cv5_accuracy_mean_std.png
  Per-fold JSON/heads under: vqa\cv5_20251015_195205


### Gradio app

In [None]:
# vqa_app.py — MSI viewer + single-question CLS Q/A
# pip install gradio==4.* pandas scikit-learn timm transformers torch torchvision torchaudio

# --- Windows event-loop fix (MUST be before importing gradio) ---
import os, asyncio
if os.name == "nt":
    try:
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    except Exception:
        pass

import re, json, math, hashlib, glob
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
import gradio as gr

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# Heads we expose (CLS-only)
# =========================
CLS_HEADS = [
    "organism",
    "polarity",
    "organ",
    "condition",
    "analyzerType",
    "ionisationSource",
]

SEED = 6740
torch.manual_seed(SEED); np.random.seed(SEED)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =========================
# Fixed question (as trained)
# =========================
FIXED_QUESTION = "what is the organism?"

# ===========================================
# Image helpers (preview + robust normalization)
# ===========================================
def _safe_uint8(img: np.ndarray) -> np.ndarray:
    img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)
    lo, hi = np.percentile(img, [1, 99])
    if hi <= lo: hi = lo + 1e-6
    img = np.clip((img - lo) / (hi - lo), 0, 1)
    return (img * 255).astype(np.uint8)

def _rgb_from_patch_pca(patch: np.ndarray) -> np.ndarray:
    # patch: (C,H,W) in [0,1]
    C, H, W = patch.shape
    X = patch.reshape(C, -1).T  # (H*W, C)
    if C > 0:
        X = X - X.mean(axis=0, keepdims=True)
    comps = min(3, max(1, C))
    if C >= comps:
        pca = PCA(n_components=comps, svd_solver="randomized")
        Y = pca.fit_transform(X)
    else:
        Y = np.zeros((H * W, comps), dtype=np.float32)
    rgb = np.zeros((H * W, 3), dtype=np.float32)
    rgb[:, :comps] = Y
    rgb = rgb.reshape(H, W, 3)
    out = np.zeros_like(rgb, dtype=np.uint8)
    for i in range(3):
        out[..., i] = _safe_uint8(rgb[..., i])
    return out

def _single_channel_gray(patch: np.ndarray, idx: int) -> np.ndarray:
    idx = int(np.clip(int(idx), 0, max(0, patch.shape[0] - 1)))
    ch = patch[idx]
    return np.stack([_safe_uint8(ch)] * 3, axis=-1)

def _load_npz_patch(npz_file) -> Tuple[np.ndarray, np.ndarray, str]:
    path = npz_file.name if hasattr(npz_file, "name") else npz_file
    with np.load(path, mmap_mode="r") as z:
        patch = z["patch"].astype(np.float32)  # (C,H,W)
        mz    = z["mz"].astype(np.float32)
    # normalize if uint16 scale
    if patch.max() > 1.0:
        patch /= 65535.0
    return patch, mz, path

# ===========================================
# Intent detection (CLS focus; legacy fallbacks)
# ===========================================
INTENT = {
    "organism":         re.compile(r"\b(organism|species)\b", re.I),
    "polarity":         re.compile(r"\bpolari(?:ty)?\b", re.I),
    "organ":            re.compile(r"\b(organ(?:ism)?\s*part|organ\b|tissue)\b", re.I),
    "condition":        re.compile(r"\b(condition|status)\b", re.I),
    "analyzerType":     re.compile(r"\banaly[sz]er(?:\s*type)?\b", re.I),
    "ionisationSource": re.compile(r"\bion(i[sz]ation)?\s*source\b|\bion[i|z]iser\b", re.I),
    # legacy examples
    "left_right": re.compile(r"\b(left|right).*(bright|darker|brighter)\b", re.I),
    "count_5pct": re.compile(r"\bhow many\b.*(five|5)\s*percent.*(non[- ]?zero|nonzero)", re.I),
    "mz_yesno":   re.compile(r"\b(ion|peak).*(near|around)\s*m/?z\s*([0-9]+(?:\.[0-9]+)?)", re.I),
}

def detect_intent(question: str) -> Tuple[str, Optional[str]]:
    q = (question or "").strip()
    if not q:
        return ("none", None)
    for head in CLS_HEADS:
        if INTENT[head].search(q):
            return ("cls", head)
    if INTENT["left_right"].search(q):
        return ("legacy_left_right", None)
    if INTENT["count_5pct"].search(q):
        return ("legacy_count_5pct", None)
    m = INTENT["mz_yesno"].search(q)
    if m:
        return ("legacy_mz_yesno", float(m.group(3)))
    return ("cls", "auto")

# ===========================================
# Summarization helpers
# ===========================================
def _best_cls_head(cls_dict: dict):
    best_h, best_v = None, None
    for h, d in (cls_dict or {}).items():
        if not isinstance(d, dict):
            continue
        if best_v is None or float(d.get("confidence", 0.0)) > float(best_v.get("confidence", 0.0)):
            best_h, best_v = h, d
    return best_h, best_v

def _cls_summary(cls_dict: dict, target: Optional[str]):
    if not cls_dict:
        return "I couldn't infer a class from this model."
    if target and target != "auto":
        for k, d in cls_dict.items():
            if k.lower() == target.lower():
                return f"**{k}** → **{d.get('pred','?')}**."
    k, d = _best_cls_head(cls_dict)
    if k is None or d is None:
        return "I couldn't infer a class from this model."
    return f"**{k}** → **{d.get('pred','?')}**."

def summarize_filtered(result_item: dict, intent_kind: str, intent_target: Optional[str]):
    r = result_item.get("result", {})
    cls_dict = r.get("cls")
    if intent_kind == "cls":
        return _cls_summary(cls_dict, intent_target)
    yn = r.get("yesno", None)
    if intent_kind in ("legacy_left_right", "legacy_mz_yesno", "legacy_count_5pct"):
        if isinstance(yn, dict) and "pred" in yn:
            return f"**{str(yn['pred']).upper()}**."
        return _cls_summary(cls_dict, target=None)
    return _cls_summary(cls_dict, target=None)

# ===========================================
# Model components (match training/eval)
# ===========================================
import timm
from transformers import AutoModel, AutoTokenizer

class FrozenBackbone(nn.Module):
    """timm ViT backbone returning [B, D] CLS-like embedding; all params frozen."""
    def __init__(self, timm_name: str, pretrained: bool = True):
        super().__init__()
        self.m = timm.create_model(timm_name, pretrained=pretrained, num_classes=0)
        for p in self.m.parameters():
            p.requires_grad_(False)
        self.m.eval()

    @torch.no_grad()
    def forward(self, x3):  # [N,3,H,W]
        feats = self.m.forward_features(x3)
        if isinstance(feats, dict):
            if 'x_norm_clstoken' in feats:  return feats['x_norm_clstoken']
            if 'cls_token' in feats:        return feats['cls_token']
            if 'avgpool' in feats:          return feats['avgpool']
            for k in ('last_hidden_state', 'tokens', 'x'):
                if k in feats and torch.is_tensor(feats[k]):
                    t = feats[k]
                    return t[:, 0] if t.dim() == 3 else t
        if torch.is_tensor(feats):
            return feats[:, 0] if feats.dim() == 3 else feats
        return feats.mean(dim=-2)

def crop_resize_to_target(x3: torch.Tensor, target=224, patch_multiple=16) -> torch.Tensor:
    _, _, H, W = x3.shape
    Hc = (H // patch_multiple) * patch_multiple
    Wc = (W // patch_multiple) * patch_multiple
    dh = (H - Hc) // 2; dw = (W - Wc) // 2
    if Hc > 0 and Wc > 0:
        x3 = x3[:, :, dh:dh+Hc, dw:dw+Wc]
    if (Hc, Wc) != (target, target):
        x3 = F.interpolate(x3, size=(target, target), mode="bilinear", align_corners=False)
    return x3

class HFTextEnc(nn.Module):
    """Trainable HuggingFace encoder + linear projection (used during training)."""
    def __init__(self, name: str, out_dim: int):
        super().__init__()
        self.model = AutoModel.from_pretrained(name)
        hid = self.model.config.hidden_size
        self.proj = nn.Linear(hid, out_dim)
        self.name = name

    def forward(self, ids, mask):
        out = self.model(input_ids=ids, attention_mask=mask, return_dict=True)
        x = (out.last_hidden_state * mask.unsqueeze(-1)).sum(1) / (mask.sum(1, keepdim=True) + 1e-6)
        return self.proj(x)  # [B, out_dim]

class Fusion(nn.Module):
    """Exactly matches training/eval: LN(img), LN(txt), MLP on concat."""
    def __init__(self, d_img, d_txt, d_out):
        super().__init__()
        self.ln_img = nn.LayerNorm(d_img)
        self.ln_txt = nn.LayerNorm(d_txt)
        self.mlp = nn.Sequential(
            nn.Linear(d_img + d_txt, d_out),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(d_out, d_out),
        )
    def forward(self, zi, zt):
        x = torch.cat([self.ln_img(zi), self.ln_txt(zt)], dim=-1)
        return self.mlp(x)

class VQAHeads(nn.Module):
    """Per-task linear heads over fused embedding z_fused."""
    def __init__(self, embed_dim: int, cls_spaces: Dict[str, List[str]]):
        super().__init__()
        self.heads = nn.ModuleDict({k: nn.Linear(embed_dim, len(v)) for k, v in cls_spaces.items()})
        for lin in self.heads.values():
            nn.init.trunc_normal_(lin.weight, std=0.02); nn.init.zeros_(lin.bias)

    def forward(self, z):  # [B, D_fused]
        return {k: h(z) for k, h in self.heads.items()}

# ===========================================
# Loading & inference utils
# ===========================================
RUN_ROOT = Path("vqa")
IMG_CACHE_DIR = RUN_ROOT / "_img_cache"
STATS_DIR    = RUN_ROOT / "_stats_cache"

def _pick_ckpt(run_dir: str) -> str:
    best = os.path.join(run_dir, "best.pt")
    last = os.path.join(run_dir, "last.pt")
    if os.path.exists(best): return best
    if os.path.exists(last): return last
    cand = [os.path.join(run_dir, f) for f in os.listdir(run_dir) if f.endswith(".pt")]
    if not cand:
        raise FileNotFoundError(f"No checkpoint found in {run_dir}")
    return sorted(cand)[-1]

def _load_config(run_dir: str) -> dict:
    cfg_path = os.path.join(run_dir, "config.json")
    if not os.path.exists(cfg_path):
        raise FileNotFoundError(f"Missing config.json in {run_dir}")
    return json.load(open(cfg_path, "r"))

def _per_image_channel_zscore(x: torch.Tensor) -> torch.Tensor:
    # x: [1,C,H,W], return z-scored per-channel over H*W (fallback of last resort)
    B, C, H, W = x.shape
    xm = x.view(B, C, -1).mean(dim=-1, keepdim=True).view(B, C, 1, 1)
    xs = x.view(B, C, -1).std(dim=-1, keepdim=True).view(B, C, 1, 1).clamp_min(1e-6)
    return (x - xm) / xs

def _embed_key(path, timm_id, patch_multiple, target_size):
    h = hashlib.sha1(f"{timm_id}|{patch_multiple}|{target_size}|{path}".encode()).hexdigest()
    return IMG_CACHE_DIR / f"{h}.npy"

def _load_stats_or_none(channels_per_view: int, input_size: int):
    """
    Try to load a mu/std cache that matches channels_per_view & input_size.
    We pick the most recent matching file from STATS_DIR.
    """
    pattern = str(STATS_DIR / f"mu_std_c{int(channels_per_view)}_in{int(input_size)}_*.npz")
    files = sorted(glob.glob(pattern))
    if not files:
        return None, None
    z = np.load(files[-1], allow_pickle=True)
    mu, std = z["mu"], z["std"]
    return mu.astype(np.float32), np.maximum(std.astype(np.float32), 1e-6)

def _select_k_first(patch: np.ndarray, k: int) -> np.ndarray:
    """
    Match training behavior: keep channels in original order, truncate/pad to k.
    If C >= k: take first k. If C < k: repeat from start to reach k.
    """
    C, H, W = patch.shape
    if C >= k:
        return patch[:k]
    reps = int(np.ceil(k / max(C, 1)))
    tiled = np.tile(patch, (reps, 1, 1))[:k]
    return tiled

@torch.no_grad()
def _encode_image_mean_cls_fallback_with_stats(
    patch_chw: np.ndarray,
    bb: FrozenBackbone,
    *,
    target: int,
    patch_multiple: int,
    channels_per_view: int,
    input_size: int,
    mu: Optional[np.ndarray],
    std: Optional[np.ndarray],
    channels_per_step: int = 16
) -> torch.Tensor:
    """
    Fallback encoder that attempts to match training:
    - select top-variance channels -> exactly channels_per_view
    - normalize with dataset mu/std if available; else per-image z-score
    - crop/resize & mean-CLS over channels
    """
    # 1) channel selection
    x = _select_k_first(patch_chw, channels_per_view)  # [K,H,W]
    x = torch.from_numpy(x).unsqueeze(0)  # [1,K,H,W]

    # 2) normalization
    if mu is not None and std is not None and len(mu) >= channels_per_view and len(std) >= channels_per_view:
        mu_t = torch.from_numpy(mu[:channels_per_view]).view(1, -1, 1, 1)
        sd_t = torch.from_numpy(std[:channels_per_view]).view(1, -1, 1, 1)
        x = (x - mu_t) / sd_t
    else:
        x = _per_image_channel_zscore(x)

    B, K, H, W = x.shape

    # 3) flatten channels to batch, replicate to 3ch
    x_flat = x.permute(0, 2, 3, 1).contiguous().view(B * K, 1, H, W)
    x_rgb  = x_flat.repeat(1, 3, 1, 1)

    # 4) chunk over channels
    cls_chunks = []
    step = channels_per_step if target == 224 else max(4, channels_per_step // 2)
    for s in range(0, B * K, step):
        e = min(s + step, B * K)
        xr = crop_resize_to_target(x_rgb[s:e], target=target, patch_multiple=patch_multiple)
        cls = bb(xr.to(DEVICE))             # [N, D_img]
        cls_chunks.append(cls.float().cpu())
    cls_all = torch.cat(cls_chunks, dim=0)  # [K, D_img]
    z_img = cls_all.mean(dim=0, keepdim=True)  # [1, D_img]
    return F.normalize(z_img, dim=-1)

def _softmax_top(logits: torch.Tensor, vocab: List[str]) -> Tuple[str, float]:
    probs = logits.softmax(dim=-1)[0]  # [K]
    conf, idx = float(probs.max().item()), int(probs.argmax().item())
    pred = vocab[idx] if 0 <= idx < len(vocab) else "unknown"
    return pred, conf

def _encode_fixed_question(tok, text_enc, max_len):
    t = tok(FIXED_QUESTION, padding=False, truncation=True, max_length=max_len, return_tensors="pt")
    ids  = t["input_ids"].to(DEVICE)
    mask = t["attention_mask"].to(DEVICE)
    with torch.no_grad():
        z_txt = F.normalize(text_enc(ids, mask), dim=-1)
    return z_txt

def _normalize_question(q: str) -> str:
    q = (q or "").strip()
    if not q:
        return "What organism is this sample?"
    return q

def _apply_unknown_penalty(logits: torch.Tensor, vocab: List[str], penalty: float) -> torch.Tensor:
    """
    Subtract a constant from the 'unknown' logit (if present) to reduce its dominance.
    logits: [1, K]
    """
    if penalty <= 0:
        return logits
    try:
        unk_idx = vocab.index("unknown")
    except ValueError:
        return logits
    out = logits.clone()
    out[0, unk_idx] = out[0, unk_idx] - float(penalty)
    return out

# ===========================================
# Core: run one question on one .npz
# ===========================================
def vqa_on_npz_single(
    run_dir: str,
    npz_path: str,
    question: str
) -> Dict:
    """
    Loads ckpt + config from run_dir (frozen image backbone setup), runs one question on one .npz.
    Returns:
      { "result": { "cls": { head: {"pred": str, "confidence": float}, ... } },
        "meta": { "used_cached_embedding": bool, "used_dataset_stats": bool } }
    """
    # ----- Load cfg + ckpt
    ckpt_path = _pick_ckpt(run_dir)
    cfg = _load_config(run_dir)
    state = torch.load(ckpt_path, map_location="cpu")

    # Backbone config from training
    timm_id        = cfg.get("timm_id", "vit_small_patch14_reg4_dinov2.lvd142m")
    patch_multiple = int(cfg.get("patch_multiple", 14))
    target_size    = int(cfg.get("target_size", 518))
    text_model_id  = cfg.get("hf_text_model", "sentence-transformers/all-MiniLM-L6-v2")
    text_out_dim   = int(cfg.get("text_out_dim", cfg.get("embed_dim_text", 384)))
    channels_per_view = int(cfg.get("channels_per_view", 64))
    input_size     = int(cfg.get("input_size", 256))

    # Class spaces (from ckpt if available to ensure exact vocab)
    cls_spaces = state.get("cls_spaces", cfg.get("cls_spaces", {h: ["unknown"] for h in CLS_HEADS}))

    # ----- Build modules
    bb = FrozenBackbone(timm_id, pretrained=True).to(DEVICE).eval()
    tok = AutoTokenizer.from_pretrained(text_model_id)
    text_enc = HFTextEnc(text_model_id, out_dim=text_out_dim).to(DEVICE).eval()

    d_img     = int(cfg.get("embed_dim_image", 384))
    fused_dim = int(cfg.get("embed_dim_fused", d_img + text_out_dim))
    fusion    = Fusion(d_img, text_out_dim, fused_dim).to(DEVICE).eval()
    heads     = VQAHeads(embed_dim=fused_dim, cls_spaces=cls_spaces).to(DEVICE).eval()

    # ----- Load weights (strict for fusion/heads, relaxed for text_enc)
    if "text_enc" in state:
        text_enc.load_state_dict(state["text_enc"], strict=False)
    if "fusion" in state:
        fusion.load_state_dict(state["fusion"], strict=True)
    if "heads" in state:
        heads.load_state_dict(state["heads"], strict=True)

    # ----- Load sample
    with np.load(npz_path, mmap_mode="r") as z:
        patch = z["patch"].astype(np.float32)  # (C,H,W)
        if patch.max() > 1.0: patch /= 65535.0

    # ----- Prefer cached image embedding (identical to eval); fallback to train-like path
    embed_path = _embed_key(npz_path, timm_id, patch_multiple, target_size)
    meta_used_cache = False
    meta_used_stats = False

    if embed_path.exists():
        z_img = torch.from_numpy(np.load(embed_path, mmap_mode="r")).unsqueeze(0).to(DEVICE).float()
        z_img = F.normalize(z_img, dim=-1)
        meta_used_cache = True
    else:
        mu, std = _load_stats_or_none(channels_per_view=channels_per_view, input_size=input_size)
        meta_used_stats = mu is not None and std is not None
        z_img = _encode_image_mean_cls_fallback_with_stats(
            patch, bb,
            target=target_size,
            patch_multiple=patch_multiple,
            channels_per_view=channels_per_view,
            input_size=input_size,
            mu=mu,
            std=std,
            channels_per_step=16
        ).to(DEVICE)

    # ----- Encode text as fixed prompt (matches training)
    _ = _normalize_question(question)  # user text ignored to match training; kept for future multi-prompt training
    z_txt = _encode_fixed_question(tok, text_enc, max_len=int(cfg.get("text_max_len", 64)))

    # ----- Fuse & predict
    with torch.no_grad():
        z_fused = fusion(z_img, z_txt)            # [1, D_fused]
        logits_dict = heads(z_fused)              # dict of head_name -> [1, K]

    UNKNOWN_LOGIT_PENALTY = 0.7  # try 0.5..1.0 if 'unknown' is still over-predicted

    # ----- Build result
    res = {"cls": {}, "meta": {"used_cached_embedding": bool(meta_used_cache),
                            "used_dataset_stats": bool(meta_used_stats)}}
    for head, vocab in cls_spaces.items():
        if head not in logits_dict:
            continue
        # apply penalty before softmax
        adj = _apply_unknown_penalty(logits_dict[head], vocab, UNKNOWN_LOGIT_PENALTY)
        pred, conf = _softmax_top(adj, vocab)
        res["cls"][head] = {"pred": pred, "confidence": conf}

    return {"result": res}

# ===========================================
# Preview handler (no models)
# ===========================================
def _preview(npz_file, view_mode, ch_index):
    if npz_file is None:
        return None
    patch, mz, _ = _load_npz_patch(npz_file)
    if view_mode == "PCA RGB":
        return _rgb_from_patch_pca(patch)
    return _single_channel_gray(patch, int(ch_index))

# ===========================================
# Core run handler (loads run, answers one question)
# ===========================================
def run_basic(vqa_run, npz_file, view_mode, ch_index, question, state):
    if npz_file is None:
        return None, "Please upload a sample first.", state

    # preview image
    patch, mz, real_path = _load_npz_patch(npz_file)
    if view_mode == "PCA RGB":
        rgb = _rgb_from_patch_pca(patch)
    else:
        rgb = _single_channel_gray(patch, int(ch_index))

    if not question or not str(question).strip():
        return rgb, "Type a question like **What organism is this sample?**", state

    try:
        out = vqa_on_npz_single(run_dir=str(vqa_run), npz_path=real_path, question=str(question))
        ikind, itarget = detect_intent(question)
        answer_text = summarize_filtered(out, ikind, itarget)

        meta = out.get("result", {}).get("meta", {})
        used_cache = meta.get("used_cached_embedding", False)
        used_stats = meta.get("used_dataset_stats", False)

    except Exception as e:
        answer_text = f"Error while running the model: {e}"

    return rgb, answer_text, state

# ===========================================
# UI (Gradio)
# ===========================================
with gr.Blocks(title="metaboFM VQA") as demo:
    gr.Markdown(
        "## metaboFM\n"
        "Upload an MSI patch (`.npz` with arrays `patch` (C,H,W)), preview it, and ask:\n"
        "- *What organism is this sample?*\n"
        "- *What is the ionization polarity?*\n"
        "- *Which organ is this sample from?*\n"
        "- *What is the sample condition?*\n"
        "- *What analyzer type / ionisation source was used?*\n"
    )

    with gr.Row():
        with gr.Column(scale=2):
            npz_file  = gr.File(label="Upload .npz (must contain 'patch' and 'mz')", file_types=[".npz"])
            view_mode = gr.Radio(choices=["PCA RGB", "Single Channel"], value="PCA RGB", label="View")
            ch_index  = gr.Slider(label="Channel (for Single Channel view)", minimum=0, maximum=255, step=1, value=0)
            img_out   = gr.Image(label="Preview", type="numpy")

        with gr.Column(scale=1):
            gr.Image("metabofm.png", label="", show_label=False, container=False, interactive=False, height=220)

            question  = gr.Textbox(
                label="Ask a question",
                placeholder="e.g., What organism is this sample?",
            )
            with gr.Row():
                btn_org = gr.Button("What organism is this sample?")
                btn_pol = gr.Button("What is the ionization polarity?")
            with gr.Row():
                btn_orgn = gr.Button("Which organ is this sample from?")
                btn_cond = gr.Button("What is the sample condition?")
            with gr.Row():
                btn_an   = gr.Button("What analyzer type was used?")
                btn_ions = gr.Button("What ionisation source was used?")

            ask_btn   = gr.Button("Ask", variant="primary")
            answer    = gr.Markdown("")

            with gr.Accordion("Model run directory", open=False):
                vqa_run   = gr.Textbox(label="VQA run directory (contains best.pt/last.pt + config.json)", value="vqa/20251015_171627")

            state = gr.State(value=None)

    # Preview on change
    for ctrl in [npz_file, view_mode, ch_index]:
        ctrl.change(
            fn=_preview,
            inputs=[npz_file, view_mode, ch_index],
            outputs=img_out
        )

    # Quick-pick question helpers
    def _q_org():   return "What organism is this sample?"
    def _q_pol():   return "What is the ionization polarity?"
    def _q_orgn():  return "Which organ is this sample from?"
    def _q_cond():  return "What is the sample condition?"
    def _q_an():    return "What analyzer type was used?"
    def _q_ions():  return "What ionisation source was used?"

    btn_org.click(fn=_q_org, outputs=question).then(
        fn=run_basic,
        inputs=[vqa_run, npz_file, view_mode, ch_index, question, state],
        outputs=[img_out, answer, state],
    )
    btn_pol.click(fn=_q_pol, outputs=question).then(
        fn=run_basic,
        inputs=[vqa_run, npz_file, view_mode, ch_index, question, state],
        outputs=[img_out, answer, state],
    )
    btn_orgn.click(fn=_q_orgn, outputs=question).then(
        fn=run_basic,
        inputs=[vqa_run, npz_file, view_mode, ch_index, question, state],
        outputs=[img_out, answer, state],
    )
    btn_cond.click(fn=_q_cond, outputs=question).then(
        fn=run_basic,
        inputs=[vqa_run, npz_file, view_mode, ch_index, question, state],
        outputs=[img_out, answer, state],
    )
    btn_an.click(fn=_q_an, outputs=question).then(
        fn=run_basic,
        inputs=[vqa_run, npz_file, view_mode, ch_index, question, state],
        outputs=[img_out, answer, state],
    )
    btn_ions.click(fn=_q_ions, outputs=question).then(
        fn=run_basic,
        inputs=[vqa_run, npz_file, view_mode, ch_index, question, state],
        outputs=[img_out, answer, state],
    )

    # Ask
    ask_btn.click(
        fn=run_basic,
        inputs=[vqa_run, npz_file, view_mode, ch_index, question, state],
        outputs=[img_out, answer, state]
    )
    question.submit(
        fn=run_basic,
        inputs=[vqa_run, npz_file, view_mode, ch_index, question, state],
        outputs=[img_out, answer, state]
    )

if __name__ == "__main__":
    # Gradio queue helps keep a single event loop path for uploads/progress on Windows
    demo.queue(status_update_rate=1)
    demo.launch()

* Running on local URL:  http://127.0.0.1:7871
* To create a public link, set `share=True` in `launch()`.


ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "c:\Users\eozturk7\AppData\Local\miniconda3\envs\magic\lib\site-packages\uvicorn\protocols\http\h11_impl.py", line 403, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "c:\Users\eozturk7\AppData\Local\miniconda3\envs\magic\lib\site-packages\uvicorn\middleware\proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
  File "c:\Users\eozturk7\AppData\Local\miniconda3\envs\magic\lib\site-packages\fastapi\applications.py", line 1133, in __call__
    await super().__call__(scope, receive, send)
  File "c:\Users\eozturk7\AppData\Local\miniconda3\envs\magic\lib\site-packages\starlette\applications.py", line 113, in __call__
    await self.middleware_stack(scope, receive, send)
  File "c:\Users\eozturk7\AppData\Local\miniconda3\envs\magic\lib\site-packages\starlette\middleware\errors.py", line 186, in __call__
    raise exc
  File "c:\Users\eozturk7\Ap

In [56]:
from fm_utils import *
from vqa_utils import *

# ---------------- Example ----------------
SAMPLE_PATH = r"metaspace_images_dump\2021-10-11_19h38m30s\metadata_full.json"
df_gt = gt_answers_for_questions_json(
[
"What organism is this sample?",
"What is the ionization polarity?",
"Which organ is this sample from?",
"What is the sample condition?",
"What analyzer type was used?",
"What ionisation source was used?"
],
sample_path=SAMPLE_PATH
)
display(df_gt)

Unnamed: 0,dataset_id,question,head,gt_normalized,gt_original,source
0,2021-10-11_19h38m30s,What organism is this sample?,organism,homo sapiens,Homo sapiens (human),metaspace_images_dump/2021-10-11_19h38m30s/met...
1,2021-10-11_19h38m30s,What is the ionization polarity?,polarity,positive,Positive,metaspace_images_dump/2021-10-11_19h38m30s/met...
2,2021-10-11_19h38m30s,Which organ is this sample from?,organ,lung,Lung,metaspace_images_dump/2021-10-11_19h38m30s/met...
3,2021-10-11_19h38m30s,What is the sample condition?,condition,biopsy,biopsy,metaspace_images_dump/2021-10-11_19h38m30s/met...
4,2021-10-11_19h38m30s,What analyzer type was used?,analyzerType,fticr,FTICR,metaspace_images_dump/2021-10-11_19h38m30s/met...
5,2021-10-11_19h38m30s,What ionisation source was used?,ionisationSource,maldi,MALDI,metaspace_images_dump/2021-10-11_19h38m30s/met...
