In [2]:
# =====================================================
# OFF-THE-SHELF FEATURES (MEAN CLS ONLY):
# DINOv2 / DeiT / MAE on MSI via per-channel replication
# - Each MSI channel -> replicate to RGB -> frozen ViT/CNN
# - Extract CLS per channel, then MEAN across channels (no PCA)
# - Caches per-channel mean/std to disk (no re-compute every run)
# - Skips models that already have image_feats.npy
# - Also outputs PCA(256) pixel baseline (optional, cached)
# =====================================================

import os, json, hashlib
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import timm
from sklearn.decomposition import IncrementalPCA

from fm_utils import (
    CFG, NPZDataset, collate_simple, nullcontext
)

# --------------- SPEED / BEHAVIOR KNOBS ---------------
AMP = True                        # automatic mixed precision on GPU
CHANNELS_PER_VIEW = 64            # load up to this many MSI channels per tile
CHANNELS_PER_STEP = 16            # process this many channels per chunk (tune for memory)
BATCH_SIZE = 8                    # MSI tiles per batch; tune with CHANNELS_PER_STEP
NUM_WORKERS = 0                   # dataloader workers
TORCH_BENCHMARK = True            # cudnn autotune for fixed shapes
SEED = 6740

# -------------- PATHS / DATA --------------
IDX_PARQUET = "metaspace_images_dump/msi_fm_samples3.parquet"
OUT_DIR     = os.path.join("fm_ssl_run", "pretrained_feats2")
STATS_DIR   = os.path.join(OUT_DIR, "_stats_cache")
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(STATS_DIR, exist_ok=True)

df = pd.read_parquet(IDX_PARQUET)
all_paths = df["sample_path"].tolist()

# -------------- BASE CONFIG --------------
BASE_CFG = CFG(
    channels_per_view=CHANNELS_PER_VIEW,
    input_size=256,   # dataset crop size, independent from backbone target
    crop_size=256,
    patch_size=16,    # overridden per model to 14 for ViT/14
    batch_size=BATCH_SIZE,
    seed=SEED
)

# -------------- TORCH SETUP --------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available() and TORCH_BENCHMARK:
    torch.backends.cudnn.benchmark = True
torch.manual_seed(SEED)
np.random.seed(SEED)

# -------------- DATALOADER --------------
def make_loader(paths, cfg, shuffle=False):
    ds = NPZDataset(
        paths,
        target_h=int(cfg.input_size),
        target_w=int(cfg.input_size),
        k_target=int(cfg.channels_per_view),
        scale_u16=True,             # returns float in [0,1]
        pad_mode="repeat",
        sort_by_mz=False
    )
    ld = DataLoader(
        ds,
        batch_size=int(cfg.batch_size),
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=torch.cuda.is_available(),
        persistent_workers=(NUM_WORKERS > 0),
        collate_fn=collate_simple,
    )
    return ds, ld

# -------------- BACKBONE (FROZEN) --------------
class FrozenBackbone(nn.Module):
    """
    timm backbone with num_classes=0; returns a single [B, D] embedding.
    Robust to different timm return types (dict, [B,N,D], [B,D]).
    """
    def __init__(self, timm_name: str, pretrained: bool = True):
        super().__init__()
        self.m = timm.create_model(timm_name, pretrained=pretrained, num_classes=0)
        for p in self.m.parameters():
            p.requires_grad_(False)
        self.m.eval()

    @torch.no_grad()
    def forward(self, x3):  # [B,3,H,W]
        feats = self.m.forward_features(x3)

        # Common DINOv2 path (dict with normalized CLS)
        if isinstance(feats, dict):
            if 'x_norm_clstoken' in feats:         # DINOv2 ViT-*/14
                return feats['x_norm_clstoken']    # [B, D]
            if 'cls_token' in feats:               # some ViT variants
                return feats['cls_token']          # [B, D]
            if 'avgpool' in feats:                 # DeiT/MAE heads-off fallback
                return feats['avgpool']            # [B, D]
            for k in ('last_hidden_state', 'tokens', 'x'):
                if k in feats and torch.is_tensor(feats[k]):
                    t = feats[k]
                    if t.dim() == 3:               # [B, N, D]
                        return t[:, 0]             # CLS at index 0
                    if t.dim() == 2:               # [B, D]
                        return t

        # Tensor outputs
        if torch.is_tensor(feats):
            if feats.dim() == 3:                   # [B, N, D]
                return feats[:, 0]                 # CLS
            if feats.dim() == 2:                   # [B, D]
                return feats

        # Fallback: global average over spatial/token dim if ambiguous
        if torch.is_tensor(feats) and feats.dim() >= 3:
            return feats.mean(dim=-2)              # [B, D] assuming [-2] is token dim

        raise RuntimeError("Unsupported backbone forward_features() return type")

# -------------- CENTER-CROP TO PATCH MULTIPLE + RESIZE --------------
def crop_resize_to_target(x3, target=224, patch_multiple=16):
    """
    x3: [N,3,H,W] float; returns [N,3,target,target]
    Center-crop H,W to be divisible by patch_multiple, then bilinear resize to target.
    """
    _, _, H, W = x3.shape
    Hc = (H // patch_multiple) * patch_multiple
    Wc = (W // patch_multiple) * patch_multiple
    dh = (H - Hc) // 2
    dw = (W - Wc) // 2
    if Hc > 0 and Wc > 0:
        x3 = x3[:, :, dh:dh+Hc, dw:dw+Wc]
    if (Hc, Wc) != (target, target):
        x3 = F.interpolate(x3, size=(target, target), mode="bilinear", align_corners=False)
    return x3

# -------------- PER-CHANNEL STATS (z-score) --------------
@torch.no_grad()
def compute_channel_stats(paths, cfg):
    """
    Compute per-channel mean/std over provided paths.
    In your real pipeline, compute on TRAIN split only.
    """
    _, ld = make_loader(paths, cfg, shuffle=False)
    sum_c, sumsq_c, total = None, None, 0
    for batch in tqdm(ld, desc="Compute per-channel mean/std"):
        x = batch["patch"].float()   # [B,C,H,W] in [0,1]
        B, C, H, W = x.shape
        x = x.view(B, C, -1)
        if sum_c is None:
            sum_c   = x.sum(dim=(0,2))            # [C]
            sumsq_c = (x**2).sum(dim=(0,2))       # [C]
        else:
            sum_c   += x.sum(dim=(0,2))
            sumsq_c += (x**2).sum(dim=(0,2))
        total += B * H * W
    mu  = (sum_c / (total + 1e-8)).cpu().numpy()
    var = (sumsq_c / (total + 1e-8)).cpu().numpy() - mu**2
    std = np.sqrt(np.maximum(var, 1e-12))
    return mu, std

def _paths_signature(paths):
    """
    Stable, short signature for the given path list to key the stats cache.
    Uses count + SHA1 over first/last 100 paths to avoid huge strings.
    """
    n = len(paths)
    head = paths[:100]
    tail = paths[-100:]
    sig_src = json.dumps([n, head, tail], separators=(',', ':')).encode('utf-8')
    return hashlib.sha1(sig_src).hexdigest()

def _stats_cache_path(paths, cfg):
    sig = _paths_signature(paths)
    fname = f"mu_std_c{int(cfg.channels_per_view)}_in{int(cfg.input_size)}_{sig}.npz"
    return os.path.join(STATS_DIR, fname)

def load_or_compute_stats(paths, cfg, force_recompute=False):
    """
    Load cached mu/std if present; otherwise compute once and cache.
    """
    cache_path = _stats_cache_path(paths, cfg)
    if (not force_recompute) and os.path.exists(cache_path):
        z = np.load(cache_path)
        mu, std = z["mu"], z["std"]
        print(f"[STATS] Loaded cached mu/std from: {cache_path}")
        return mu, std, cache_path

    print("[STATS] Computing mu/std (once) ...")
    mu, std = compute_channel_stats(paths, cfg)
    np.savez_compressed(cache_path, mu=mu, std=std, meta=dict(
        channels_per_view=int(cfg.channels_per_view),
        input_size=int(cfg.input_size),
        crop_size=int(cfg.crop_size),
        n_paths=len(paths)
    ))
    print(f"[STATS] Saved mu/std to: {cache_path}")
    return mu, std, cache_path

# -------------- EXTRACTION: MEAN CLS ONLY (FAST) --------------
@torch.no_grad()
def extract_offtheshelf_mean(paths, cfg, timm_id, patch_multiple, pretrained=True,
                             mu=None, std=None, target_size=224):
    """
    Mean-CLS aggregation:
      - [B,C,H,W] -> z-score per channel (using cached mu/std)
      - reshape to [B*C,1,H,W], replicate to RGB -> [B*C,3,H,W]
      - center-crop to patch multiple & resize to target_size
      - run frozen backbone in channel chunks
      - get CLS: [B*C,D] -> reshape [B,C,D] -> mean over C -> [B,D]
    """
    # Loader
    ds, ld = make_loader(paths, cfg, shuffle=False)

    # Enforce mu/std provided
    assert mu is not None and std is not None, "mu/std must be provided (use load_or_compute_stats)."
    mu_t = torch.tensor(mu, device=device).view(1, -1, 1, 1)
    sd_t = torch.tensor(std, device=device).view(1, -1, 1, 1).clamp_min(1e-6)

    # Backbone + AMP context
    bb = FrozenBackbone(timm_id, pretrained=pretrained).to(device)
    autocast_ctx = (torch.autocast(device_type="cuda", dtype=torch.float16)
                    if (AMP and torch.cuda.is_available()) else nullcontext())

    # Adjust channel stepping for large targets to save VRAM
    step = CHANNELS_PER_STEP if target_size == 224 else max(4, CHANNELS_PER_STEP // 2)

    all_embs, all_ids = [], []

    for batch in tqdm(ld, desc=f"Extract ({timm_id})"):
        x = batch["patch"].to(device=device).float()  # [B,C,H,W] in [0,1]
        ids = batch["path"]
        B, C, H, W = x.shape

        # Z-score per channel
        x = (x - mu_t[:, :C]) / sd_t[:, :C]

        # Flatten channels into batch dimension: [B*C,1,H,W]
        x_flat = x.permute(0, 2, 3, 1).contiguous().view(B * C, 1, H, W)
        # Replicate to RGB
        x_rgb = x_flat.repeat(1, 3, 1, 1)  # [B*C,3,H,W]

        # Chunk over channels for memory control
        cls_chunks = []
        with autocast_ctx:
            for s in range(0, B * C, step):
                e = min(s + step, B * C)
                xr = crop_resize_to_target(x_rgb[s:e], target=target_size, patch_multiple=patch_multiple)
                cls = bb(xr)                             # [N,D]
                cls_chunks.append(cls.float())

        cls_all = torch.cat(cls_chunks, dim=0)          # [B*C,D]
        D = cls_all.shape[1]
        cls_all = cls_all.view(B, C, D)                 # [B,C,D]
        mean_emb = cls_all.mean(dim=1)                  # [B,D]
        mean_emb = F.normalize(mean_emb, dim=-1)        # L2 normalize
        all_embs.append(mean_emb.cpu().numpy())
        all_ids.extend(ids)

    feats = np.concatenate(all_embs, axis=0)
    return feats, all_ids

# -------------- MODEL MENU --------------
# key: output folder; timm: timm model id; patch: 14 or 16; pretrained: True/False; target: input resolution
MODELS = [
    # ImageNet-supervised
    dict(key="imagenet_deit_s16", timm="deit_small_distilled_patch16_224",         patch=16, pretrained=True,  target=224),
    dict(key="imagenet_deit_b16", timm="deit_base_distilled_patch16_224",         patch=16, pretrained=True,  target=224),

    # DINOv2 (reg4 expects 518)
    dict(key="dinov2_vits14",     timm="vit_small_patch14_reg4_dinov2.lvd142m",    patch=14, pretrained=True,  target=518),
    dict(key="dinov2_vitb14",     timm="vit_base_patch14_reg4_dinov2.lvd142m",     patch=14, pretrained=True,  target=518),

    # MAE
    dict(key="mae_vitb16",        timm="vit_base_patch16_224.mae",                 patch=16, pretrained=True,  target=224),

    # Random-init controls
    dict(key="deit_s16_random",   timm="deit_small_distilled_patch16_224",         patch=16, pretrained=False, target=224),
    dict(key="imagenet_deit_b16_random",   timm="deit_base_distilled_patch16_224",         patch=16, pretrained=False, target=224),
    dict(key="dinov2_vits14_random",timm="vit_small_patch14_reg4_dinov2",            patch=14, pretrained=False, target=518),
    dict(key="dinov2_vitb14_random",timm="vit_base_patch14_reg4_dinov2",             patch=14, pretrained=False, target=518),
    dict(key="mae_vitb16_random",        timm="vit_base_patch16_224.mae",                 patch=16, pretrained=False,  target=224),
]

# Uncomment to debug a subset
# MODELS = MODELS[:2]

# -------------- PRECOMPUTE / LOAD STATS ONCE --------------
mu_cached, std_cached, stats_path = load_or_compute_stats(all_paths, BASE_CFG, force_recompute=False)

# -------------- RUN EXTRACTION (MEAN ONLY) --------------
for m in MODELS:
    key, timm_id, pm, use_pt, tgt = m["key"], m["timm"], m["patch"], m["pretrained"], m["target"]
    out_dir = os.path.join(OUT_DIR, f"{key}")
    feats_path = os.path.join(out_dir, "image_feats.npy")

    # Skip if already extracted
    if os.path.exists(feats_path):
        print(f"[SKIP] Found existing feats for {key}: {feats_path}")
        continue

    print(f"[INFO] Extracting MEAN-CLS for {key} :: {timm_id} (patch-multiple={pm}, pretrained={use_pt}, target={tgt})")
    os.makedirs(out_dir, exist_ok=True)

    cfg_m = CFG(**BASE_CFG.__dict__)
    cfg_m.patch_size = pm

    try:
        feats, ids = extract_offtheshelf_mean(
            all_paths, cfg_m, timm_id,
            patch_multiple=pm, pretrained=use_pt,
            mu=mu_cached, std=std_cached,
            target_size=tgt
        )
    except Exception as e:
        print(f"[ERROR] Extraction failed for {key}: {e}")
        continue

    np.save(feats_path, feats)
    pd.DataFrame({"sample_path": ids}).to_csv(os.path.join(out_dir, "index.csv"), index=False)
    meta = dict(
        timm=timm_id, patch_multiple=int(pm), agg="mean",
        embed_dim=int(feats.shape[1]), pretrained=bool(use_pt),
        adapter="per-channel-replicate", input_size=int(tgt),
        channels_per_view=int(cfg_m.channels_per_view),
        batch_size=int(cfg_m.batch_size), channels_per_step=int(CHANNELS_PER_STEP),
        stats_path=os.path.relpath(stats_path, out_dir)
    )
    json.dump(meta, open(os.path.join(out_dir, "meta.json"), "w"), indent=2)
    print(f"[OK] Saved MEAN-CLS feats to: {out_dir}")

  from .autonotebook import tqdm as notebook_tqdm


[STATS] Loaded cached mu/std from: fm_ssl_run\pretrained_feats2\_stats_cache\mu_std_c64_in256_8f39161ffd68646249a4caf6a1d88d7330ab27e1.npz
[SKIP] Found existing feats for imagenet_deit_s16: fm_ssl_run\pretrained_feats2\imagenet_deit_s16\image_feats.npy
[SKIP] Found existing feats for imagenet_deit_b16: fm_ssl_run\pretrained_feats2\imagenet_deit_b16\image_feats.npy
[SKIP] Found existing feats for dinov2_vits14: fm_ssl_run\pretrained_feats2\dinov2_vits14\image_feats.npy
[SKIP] Found existing feats for dinov2_vitb14: fm_ssl_run\pretrained_feats2\dinov2_vitb14\image_feats.npy
[SKIP] Found existing feats for mae_vitb16: fm_ssl_run\pretrained_feats2\mae_vitb16\image_feats.npy
[SKIP] Found existing feats for deit_s16_random: fm_ssl_run\pretrained_feats2\deit_s16_random\image_feats.npy
[INFO] Extracting MEAN-CLS for imagenet_deit_b16_random :: deit_base_distilled_patch16_224 (patch-multiple=16, pretrained=False, target=224)


Extract (deit_base_distilled_patch16_224): 100%|██████████| 493/493 [06:51<00:00,  1.20it/s]


[OK] Saved MEAN-CLS feats to: fm_ssl_run\pretrained_feats2\imagenet_deit_b16_random
[SKIP] Found existing feats for dinov2_vits14_random: fm_ssl_run\pretrained_feats2\dinov2_vits14_random\image_feats.npy
[SKIP] Found existing feats for dinov2_vitb14_random: fm_ssl_run\pretrained_feats2\dinov2_vitb14_random\image_feats.npy
[INFO] Extracting MEAN-CLS for mae_vitb16_random :: vit_base_patch16_224.mae (patch-multiple=16, pretrained=False, target=224)


Extract (vit_base_patch16_224.mae): 100%|██████████| 493/493 [05:59<00:00,  1.37it/s]

[OK] Saved MEAN-CLS feats to: fm_ssl_run\pretrained_feats2\mae_vitb16_random





In [2]:
# ===================== FAST PIXEL BASELINES =====================
# Option A (default): Sparse Random Projection (RP) to 256 dims
#   - No training, super fast, JL guarantee
# Option B: Accelerated IncrementalPCA with spatial downsampling
#   - If you still want PCA, set USE_IPCA=True

from sklearn.random_projection import SparseRandomProjection

USE_IPCA = False          # << flip to True if you really want IPCA
RP_N_COMPONENTS = 256
IPCA_N_COMPONENTS = 256
DOWNSAMPLE_HW = 64        # << downsample H,W to this before PCA (was 256)

def _downsample_hw(x, hw=DOWNSAMPLE_HW):
    # x: [B,C,H,W] -> [B,C,hw,hw]
    if x.shape[-1] == hw and x.shape[-2] == hw:
        return x
    return F.interpolate(x, size=(hw, hw), mode="area")

# ---------- Option A: Random Projection (fast) ----------
@torch.no_grad()
def extract_random_projection_feats(paths, cfg, n_components=RP_N_COMPONENTS, pool_hw=DOWNSAMPLE_HW):
    """
    Fast, training-free JL projection baseline.
    Steps:
      - Load [B,C,H,W] in [0,1]
      - Downsample to [B,C,hw,hw] (area pooling)
      - Flatten to [B, C*hw*hw]
      - SparseRandomProjection to n_components, per-batch
    """
    ds, ld = make_loader(paths, cfg, shuffle=False)
    feats, ids = [], []
    rp = None  # instantiated after seeing input dim

    for batch in tqdm(ld, desc=f"RP({n_components}) transform"):
        x = batch["patch"].float()                  # [B,C,H,W]
        x = _downsample_hw(x, hw=pool_hw)           # [B,C,hw,hw]
        B, C, H, W = x.shape
        arr = x.permute(0,2,3,1).contiguous().view(B, -1).cpu().numpy()  # [B, C*H*W]

        if rp is None:
            rp = SparseRandomProjection(n_components=n_components)  # defines components from input dim
            z = rp.fit_transform(arr)  # first batch fits the random matrix shape
        else:
            z = rp.transform(arr)
        # L2 norm
        z = z / (np.linalg.norm(z, axis=1, keepdims=True) + 1e-8)
        feats.append(z.astype(np.float32))
        ids.extend(batch["path"])
    feats = np.concatenate(feats, 0)
    meta = dict(
        method="sparse_random_projection",
        n_components=int(n_components),
        downsample_hw=int(pool_hw),
        D_input=int(C*H*W),
        note="JL projection; no training"
    )
    return feats, ids, meta

# ---------- Option B: Accelerated IPCA (still reasonably fast) ----------
def extract_pca_feats_fast(paths, cfg, n_components=IPCA_N_COMPONENTS, batch_bytes_target=768*1024*1024):
    """
    Faster IPCA by (1) aggressive downsampling to DOWNSAMPLE_HW,
    (2) ensuring first partial_fit sees >= n_components samples,
    (3) larger warmup batches, and (4) keeping data on CPU float32.
    """
    ds, ld = make_loader(paths, cfg, shuffle=False)
    H, W, C = DOWNSAMPLE_HW, DOWNSAMPLE_HW, int(cfg.channels_per_view)
    D = H * W * C

    # Memory-based batch size (in samples)
    bytes_per_sample = D * 4  # float32
    mem_based = max(1, int(batch_bytes_target // bytes_per_sample))

    total_samples = len(paths)
    n_comp_eff = min(n_components, total_samples)
    if n_comp_eff < n_components:
        print(f"[IPCA] Reducing n_components {n_components} -> {n_comp_eff} (only {total_samples} samples).")
    batch_n = max(mem_based, n_comp_eff)

    ipca = IncrementalPCA(n_components=n_comp_eff, batch_size=batch_n)

    # -------- Pass 1: fit --------
    buf = []
    buf_rows = 0
    for batch in tqdm(ld, desc="IPCA Pass1 (fit, downsampled)"):
        x = batch["patch"].float()                  # [B,C,256,256]
        x = _downsample_hw(x, hw=DOWNSAMPLE_HW)     # [B,C,H,W] smaller
        B, C2, H2, W2 = x.shape
        arr = x.permute(0,2,3,1).contiguous().view(B, -1).cpu().numpy()  # [B, D]
        buf.append(arr); buf_rows += arr.shape[0]

        if buf_rows >= batch_n:
            ipca.partial_fit(np.vstack(buf)); buf, buf_rows = [], 0
    if buf_rows > 0:
        ipca.partial_fit(np.vstack(buf)); buf, buf_rows = [], 0

    # -------- Pass 2: transform --------
    feats, ids = [], []
    for batch in tqdm(ld, desc="IPCA Pass2 (transform)"):
        x = batch["patch"].float()
        x = _downsample_hw(x, hw=DOWNSAMPLE_HW)
        B, C2, H2, W2 = x.shape
        arr = x.permute(0,2,3,1).contiguous().view(B, -1).cpu().numpy()
        z = ipca.transform(arr)
        z = z / (np.linalg.norm(z, axis=1, keepdims=True) + 1e-8)
        feats.append(z.astype(np.float32))
        ids.extend(batch["path"])

    feats = np.concatenate(feats, 0)
    meta = dict(
        method="incremental_pca",
        n_components=int(ipca.n_components_),
        downsample_hw=int(DOWNSAMPLE_HW),
        D_input=int(D),
        batch_size_used=int(batch_n)
    )
    return feats, ids, meta

# ---------- Call: choose RP (fast) or IPCA (fallback) ----------
try:
    out_dir = os.path.join(OUT_DIR, "pixels_fast256")
    feats_path = os.path.join(out_dir, "image_feats.npy")
    if os.path.exists(feats_path):
        print(f"[SKIP] Fast pixel baseline already exists: {feats_path}")
    else:
        os.makedirs(out_dir, exist_ok=True)
        if USE_IPCA:
            print("[INFO] Extracting FAST IPCA pixel baseline ...")
            p_feats, p_ids, p_meta = extract_pca_feats_fast(all_paths, BASE_CFG, n_components=IPCA_N_COMPONENTS)
        else:
            print("[INFO] Extracting RANDOM PROJECTION pixel baseline ...")
            p_feats, p_ids, p_meta = extract_random_projection_feats(all_paths, BASE_CFG, n_components=RP_N_COMPONENTS)

        np.save(feats_path, p_feats)
        pd.DataFrame({"sample_path": p_ids}).to_csv(os.path.join(out_dir, "index.csv"), index=False)
        json.dump(p_meta, open(os.path.join(out_dir, "meta.json"), "w"), indent=2)
        print(f"[OK] Saved fast pixel baseline to: {out_dir} ({p_meta['method']}, n_components={p_meta['n_components']})")
except Exception as e:
    print(f"[WARN] Fast pixel baseline failed: {e}")

[INFO] Extracting RANDOM PROJECTION pixel baseline ...


RP(256) transform:   3%|▎         | 13/493 [00:21<13:23,  1.67s/it]


KeyboardInterrupt: 

### DOWNSTREAM EVALUATION ON HELD-OUT SPLITS with Baseline Models

In [None]:
# =====================================================
# DOWNSTREAM EVALUATION ON HELD-OUT SPLITS (multi-baseline)
# - Canonicalize/merge label variants
# - Evaluate many embedding sets with linear probe & k-NN
# - Extras: Few-shot curves & k-NN sensitivity
# - Save per-model results + combined CSV + plots
# - NEW: Skip guards to avoid recomputing existing baselines
# =====================================================

import os, copy, re, glob, json
import numpy as np
import pandas as pd

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import ParameterGrid

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

SEED = 6740
np.random.seed(SEED)

# -------------------------
# Speed / behavior knobs
# -------------------------
FAST_MODE = True                       # flip False for full evaluation
FEW_SHOT_GRID = [1, 5, 10, 25, None]   # None means "All train"
K_GRID = [1, 5, 10, 20, 50]
K_FOR_KNN = 10 if FAST_MODE else 20
C_GRID = [0.1, 1.0] if FAST_MODE else [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]
MAX_ITER_LR = 800 if FAST_MODE else 2000
MIN_CLASS_COUNT = 100 if FAST_MODE else 50
PCA_DIM = None                         # e.g., 256 to speed up; None disables
METABOFM_VERSION = "last"

# -------------------------
# Skips & guards
# -------------------------
ONLY_EVAL_EXISTING = True     # only run models that already have embeddings saved
SKIP_IF_RESULTS_EXIST = True  # don't recompute per-model results if CSVs already exist
FORCE_REEVAL = False          # set True to ignore SKIP_IF_RESULTS_EXIST

FEATS_PRIMARY = "image_feats.npy"
FEATS_ALT_LAST = "image_feats_last.npy"
INDEX_FILE = "index.csv"

# keep a simple constant for the index
REQUIRED_INDEX = INDEX_FILE

def get_feats_path(model_tag: str, emb_dir: str) -> str | None:
    """Return the feature path for the given model."""
    if model_tag == "metabofm":
        fname = FEATS_PRIMARY if METABOFM_VERSION == "best" else FEATS_ALT_LAST
        path = os.path.join(emb_dir, fname)
        if os.path.exists(path):
            print(f"[INFO] Using MetaboFM-{METABOFM_VERSION.upper()} embeddings: {path}")
            return path
        else:
            print(f"[WARN] MetaboFM-{METABOFM_VERSION.upper()} embeddings not found at {path}")
            return None
    else:
        path = os.path.join(emb_dir, FEATS_PRIMARY)
        return path if os.path.exists(path) else None

def embeddings_ready(model_tag: str, emb_dir: str) -> bool:
    feats_path = get_feats_path(model_tag, emb_dir)
    return feats_path is not None and os.path.exists(os.path.join(emb_dir, REQUIRED_INDEX))

def resolve_model_outdir(model_tag: str, emb_dir: str) -> str:
    if model_tag == "metabofm":
        return TRAIN_OUT
    return emb_dir

def results_already_done(model_tag: str, emb_dir: str) -> bool:
    primary_out_dir = resolve_model_outdir(model_tag, emb_dir)
    res_csv = os.path.join(primary_out_dir, "downstream_results.csv")
    return os.path.exists(res_csv)

def inputs_newer_than_results(model_tag: str, emb_dir: str) -> bool:
    """If inputs are newer than results, you may want to recompute."""
    primary_out_dir = resolve_model_outdir(model_tag, emb_dir)
    res_csv = os.path.join(primary_out_dir, "downstream_results.csv")
    if not os.path.exists(res_csv):
        return True
    res_mtime = os.path.getmtime(res_csv)
    latest_input_mtime = max(
        os.path.getmtime(os.path.join(emb_dir, f))
        for f in REQUIRED_FILES
        if os.path.exists(os.path.join(emb_dir, f))
    )
    return latest_input_mtime > res_mtime

def inputs_newer_than_results(model_tag: str, emb_dir: str) -> bool:
    primary_out_dir = resolve_model_outdir(model_tag, emb_dir)
    res_csv = os.path.join(primary_out_dir, "downstream_results.csv")
    if not os.path.exists(res_csv):
        return True
    res_mtime = os.path.getmtime(res_csv)

    feats_path = get_feats_path(model_tag, emb_dir)
    idx_path = os.path.join(emb_dir, REQUIRED_INDEX)
    inputs = [p for p in [feats_path, idx_path] if p and os.path.exists(p)]
    if not inputs:
        return True
    latest_input_mtime = max(os.path.getmtime(p) for p in inputs)
    return latest_input_mtime > res_mtime

# -------------------------
# Paths / run root
# -------------------------
TRAIN_OUT = os.path.join("fm_ssl_run", "20251007_235600")
print("[INFO] Using TRAIN_OUT:", TRAIN_OUT)

# Metadata + splits
SPLIT_CSV    = os.path.join(TRAIN_OUT, "splits_by_dataset_id.csv")
IDX_PARQUET  = "metaspace_images_dump/msi_fm_samples3.parquet"
MAN_PARQUET  = "metaspace_images_dump/manifest_expanded.parquet"

# Centralized benchmark outputs
BENCH_OUT = os.path.join(TRAIN_OUT, "baseline_eval")
PLOTS_OUT = os.path.join(BENCH_OUT, "plots")
os.makedirs(BENCH_OUT, exist_ok=True)
os.makedirs(PLOTS_OUT, exist_ok=True)

# -------------------------
# Which embedding sets to evaluate
# (ensure these folders contain image_feats.npy + index.csv)
# -------------------------
EMB_SETS = {
    # Off-the-shelf baselines
    "imagenet_deit_s16":      os.path.join("fm_ssl_run", "pretrained_feats2", "imagenet_deit_s16"),
    "imagenet_deit_b16":      os.path.join("fm_ssl_run", "pretrained_feats2", "imagenet_deit_b16"),
    "dinov2_vitb14":          os.path.join("fm_ssl_run", "pretrained_feats2", "dinov2_vitb14"),
    "mae_vitb16":             os.path.join("fm_ssl_run", "pretrained_feats2", "mae_vitb16"),

    # New recommended baselines
    "dinov2_vits14":          os.path.join("fm_ssl_run", "pretrained_feats2", "dinov2_vits14"),
    "deit_s16_random":        os.path.join("fm_ssl_run", "pretrained_feats2", "deit_s16_random"),
    "imagenet_deit_b16_random":        os.path.join("fm_ssl_run", "pretrained_feats2", "imagenet_deit_b16_random"),
    "dinov2_vits14_random":        os.path.join("fm_ssl_run", "pretrained_feats2", "dinov2_vits14_random"),
    "dinov2_vitb14_random":        os.path.join("fm_ssl_run", "pretrained_feats2", "dinov2_vitb14_random"),
    "mae_vitb16_random":             os.path.join("fm_ssl_run", "pretrained_feats2", "mae_vitb16_random"),
    #"pca256_pixels":          os.path.join("fm_ssl_run", "pretrained_feats2", "pixels_fast256"),
}

# -------------------------
# Canonicalization / merging rules
# -------------------------
def _clean(s):
    if pd.isna(s): return None
    s = str(s).strip()
    s = re.sub(r"\s+", " ", s)
    return s

def canonicalize_labels(df):
    df = df.copy()

    # 1) Polarity
    pol_map = {"pos":"Positive","positive":"Positive","+":"Positive",
               "neg":"Negative","negative":"Negative","-":"Negative"}
    def canon_polarity(s):
        if s is None: return None
        t = _clean(s).lower()
        t2 = pol_map.get(t, t)
        if t2 in ("positive","negative"):
            return t2.capitalize()
        if "pos" in t: return "Positive"
        if "neg" in t: return "Negative"
        return _clean(s)
    if "polarity" in df.columns:
        df["polarity"] = df["polarity"].map(canon_polarity)

    # 2) Ionisation Source (map DESI-MSI -> DESI, etc.)
    def canon_ion_src(s):
        if s is None: return None
        t_raw = _clean(s)
        t = t_raw.upper().replace("-", "").replace("_","")
        if "APSMALDI" in t: return "AP-SMALDI"
        if "IRMALDESI" in t or "IRMALDI" in t: return "IR-MALDESI"
        if "APMALDI" in t: return "AP-MALDI"
        if "DESIMSI" in t: return "DESI"
        if "DESI" in t: return "DESI"
        if "MALDI" in t: return "MALDI"
        return t_raw
    if "ionisationSource" in df.columns:
        df["ionisationSource"] = df["ionisationSource"].map(canon_ion_src)

    # 3) Analyzer Type
    def canon_analyzer(s):
        if s is None: return None
        t = _clean(s); tl = t.lower()
        if "timstof" in tl and "flex" in tl: return "timsTOF Flex"
        if "fticr" in tl:
            if "12t" in tl: return "12T FTICR"
            if "7t" in tl and "scimax" in tl: return "FTICR scimaX 7T"
            return "FTICR"
        if "orbitrap" in tl or "q-exactive" in tl: return "Orbitrap"
        if "tof" in tl and "reflector" in tl: return "TOF reflector"
        if tl.strip() == "qtof": return "qTOF"
        return t
    if "analyzerType" in df.columns:
        df["analyzerType"] = df["analyzerType"].map(canon_analyzer)

    # 4) Organism
    def canon_organism(s):
        if s is None: return None
        t = _clean(s); tl = t.lower()
        if "|" in t or "," in t:
            if ("human" in tl or "homo sapiens" in tl) and ("mouse" in tl or "mus musculus" in tl):
                return "Mixed"
        if "homo sapiens" in tl or tl.strip() in {"human","h. sapiens","homo"}:
            return "Homo sapiens"
        if "mus musculus" in tl or tl.strip() in {"mouse","m. musculus"}:
            return "Mus musculus"
        return t
    if "organism" in df.columns:
        df["organism"] = df["organism"].map(canon_organism)

    # 5) Organism_Part
    def canon_part(s):
        if s is None: return None
        t = _clean(s); tl = t.lower()
        if "kidney" in tl: return "Kidney"
        if "brain"  in tl: return "Brain"
        if "liver"  in tl: return "Liver"
        if "lung"   in tl: return "Lung"
        if "breast" in tl: return "Breast"
        if "skin"   in tl: return "Skin"
        if "heart"  in tl or "cardiac" in tl: return "Heart"
        return t
    if "Organism_Part" in df.columns:
        df["Organism_Part"] = df["Organism_Part"].map(canon_part)

    # 6) Condition (keep "NA" but we'll exclude it later)
    def canon_condition(s):
        if s is None: return None
        t = _clean(s); tl = t.lower()
        if tl in {"n/a","na","none","not available",""}: return "NA"
        if tl in {"biopsy","biopsies"}: return "Biopsy"
        if "fresh frozen" in tl or "frozen" in tl: return "Frozen"
        if "tumor" in tl or "tumour" in tl: return "Tumor"
        if "cancer" in tl: return "Cancer"
        if "wildtype" in tl or tl == "wt": return "Wildtype"
        if "healthy" in tl or "control" in tl: return "Healthy"
        if "diseased" in tl or "disease" in tl: return "Diseased"
        return t
    if "Condition" in df.columns:
        df["Condition"] = df["Condition"].map(canon_condition)

    return df

# -------------------------
# Load metadata + splits (once)
# -------------------------
idx = pd.read_parquet(IDX_PARQUET)
man = pd.read_parquet(MAN_PARQUET)

need_cols = [
    "dataset_id", "organism", "polarity", "Organism_Part", "Condition",
    "analyzerType", "ionisationSource"
]
man_sub = man[[c for c in need_cols if c in man.columns]].drop_duplicates("dataset_id")

df_meta = idx.merge(man_sub, on="dataset_id", how="left", suffixes=("", "_man"))
df_meta = df_meta.loc[:, ~df_meta.columns.duplicated()].copy().reset_index(drop=True)
splits = pd.read_csv(SPLIT_CSV)
df_meta = df_meta.merge(splits, on="dataset_id", how="left")

# Dedup metadata by sample_path and canonicalize labels
if df_meta.duplicated("sample_path").sum():
    print("[WARN] duplicate sample_path rows in metadata; keeping first.")
    df_meta = df_meta.drop_duplicates("sample_path", keep="first").reset_index(drop=True)
df_meta = canonicalize_labels(df_meta)

# -------------------------
# Common helpers
# -------------------------
TASKS = ["organism", "polarity", "Organism_Part", "Condition", "analyzerType", "ionisationSource"]
EXCLUDE_LABELS = {"Condition": {"NA"}}

def filter_valid(df_task, yname, min_count=5):
    x = df_task.dropna(subset=[yname]).copy()
    if yname in EXCLUDE_LABELS:
        x = x[~x[yname].isin(EXCLUDE_LABELS[yname])]
    x = x[x[yname].astype(str).str.len() > 0]
    vc = x[yname].value_counts()
    keep = vc[vc >= min_count].index
    x = x[x[yname].isin(keep)].copy()
    return x

def few_shot_subset(df_task, yname, shots_per_class=None, seed=SEED):
    if not shots_per_class or shots_per_class <= 0:
        return (df_task["split"] == "train").values
    rng = np.random.RandomState(seed)
    m_train = (df_task["split"] == "train").values
    keep = np.zeros(len(df_task), dtype=bool)
    labels = df_task.loc[m_train, yname].astype(str).values
    idxs   = np.where(m_train)[0]
    from collections import defaultdict
    per_class = defaultdict(list)
    for i, lbl in zip(idxs, labels):
        per_class[lbl].append(i)
    for lbl, arr in per_class.items():
        arr = np.array(arr)
        rng.shuffle(arr)
        keep[arr[:min(shots_per_class, len(arr))]] = True
    return keep

def get_mask(df_all, split_name):
    return (df_all["split"] == split_name).values

def run_linear_probe(X_tr, y_tr, X_va, y_va, X_te, y_te):
    pipe = Pipeline([
        ("scaler", StandardScaler(with_mean=True, with_std=True)),
        ("clf", LogisticRegression(
            solver="saga",
            max_iter=MAX_ITER_LR,
            class_weight="balanced",
            random_state=SEED,
            n_jobs=-1
        ))
    ])
    grid = {"clf__C": C_GRID}
    best = None; best_va = -np.inf
    for p in tqdm(list(ParameterGrid(grid)), desc="LinearProbe grid", leave=False):
        pipe.set_params(**p)
        pipe.fit(X_tr, y_tr)
        pred_va = pipe.predict(X_va)
        macro_f1 = f1_score(y_va, pred_va, average="macro")
        if macro_f1 > best_va:
            best_va = macro_f1
            best = copy.deepcopy(pipe)
    pred_te = best.predict(X_te)
    acc = accuracy_score(y_te, pred_te)
    f1m = f1_score(y_te, pred_te, average="macro")
    return acc, f1m, pred_te, best

def run_knn(X_tr, y_tr, X_te, y_te, k=20):
    n_fit = int(X_tr.shape[0])
    # need at least 1 sample and 2 classes to classify
    if n_fit < 1 or len(np.unique(y_tr)) < 2:
        return np.nan, np.nan, np.array([], dtype=object), None, 0

    k_eff = max(1, min(k, n_fit))
    knn = KNeighborsClassifier(n_neighbors=k_eff, metric="cosine", n_jobs=-1)
    knn.fit(X_tr, y_tr)

    if X_te.shape[0] == 0:
        return np.nan, np.nan, np.array([], dtype=object), knn, k_eff

    pred_te = knn.predict(X_te)
    acc = accuracy_score(y_te, pred_te)
    f1m = f1_score(y_te, pred_te, average="macro")
    return acc, f1m, pred_te, knn, k_eff

def per_class_f1(y_true, y_pred):
    rep = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
    out = {k: v["f1-score"] for k, v in rep.items() if k not in ("accuracy","macro avg","weighted avg")}
    return out

# -------------------------
# Evaluate ONE embedding set dir
# -------------------------
def evaluate_embeddings(emb_dir: str, model_tag: str):
    feats_path = get_feats_path(model_tag, emb_dir)
    index_path = os.path.join(emb_dir, INDEX_FILE)
    if not (feats_path and os.path.exists(feats_path) and os.path.exists(index_path)):
        print(f"[WARN] Missing feats or index for {model_tag} at {emb_dir}; "
              f"looked for {FEATS_PRIMARY}"
              f"{' or ' + FEATS_ALT_LAST if model_tag=='metabofm' else ''} and {INDEX_FILE}. Skipping.")
        return pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    # Optional: let the logs show which file was used
    if os.path.basename(feats_path) == FEATS_ALT_LAST:
        print(f"[INFO] ({model_tag}) Using LAST embeddings: {feats_path}")
    else:
        print(f"[INFO] ({model_tag}) Using BEST embeddings: {feats_path}")

    emb = np.load(feats_path)
    index = pd.read_csv(index_path)      # must include column 'sample_path'
    index = index.reset_index().rename(columns={"index": "row_id"})

    # Join metadata to index
    df_base = df_meta.copy()
    df_emb = df_base.merge(index, on="sample_path", how="inner")
    df_emb = df_emb.sort_values("row_id").reset_index(drop=True)

    if emb.shape[0] != len(df_emb):
        print(f"[INFO] ({model_tag}) Embedding count != joined rows; aligning to matched rows only.")
        emb = emb[df_emb["row_id"].values, :]
    if emb.shape[0] != len(df_emb):
        raise RuntimeError(f"({model_tag}) Embedding count and joined rows mismatch after alignment.")

    # Optional PCA
    if PCA_DIM is not None and PCA_DIM > 0 and PCA_DIM < emb.shape[1]:
        print(f"[INFO] ({model_tag}) Reducing dim to {PCA_DIM} via PCA.")
        from sklearn.decomposition import PCA
        pca = PCA(n_components=PCA_DIM, random_state=SEED)
        emb = pca.fit_transform(emb)

    df_emb["row_pos"] = np.arange(len(df_emb), dtype=int)

    # ---- Main evaluation (All-train or FEW_SHOT_TRAIN) ----
    results = []
    pcs_rows = []   # per-class F1 (linear probe) rows
    kgrid_rows = [] # k-sweep rows
    for yname in tqdm(TASKS, desc=f"Tasks[{model_tag}]"):
        if yname not in df_emb.columns:
            print(f"[WARN] ({model_tag}) Missing column {yname}; skipping.")
            continue

        print(f"\n==== [{model_tag}] Task: {yname} ====")
        df_task = filter_valid(df_emb, yname, min_count=MIN_CLASS_COUNT)
        if df_task.empty:
            print(f"[WARN] ({model_tag}) Skipping {yname}: no data after filtering.")
            continue
        if df_task["split"].isna().any():
            df_task = df_task[~df_task["split"].isna()].copy()

        row_pos = df_task["row_pos"].values
        X = emb[row_pos]
        y = df_task[yname].astype(str).values
        m_tr_all = get_mask(df_task, "train")
        m_va     = get_mask(df_task, "val")
        m_te     = get_mask(df_task, "test")

        def _ok(mask):
            return np.sum(mask) > 0 and (len(np.unique(y[mask])) > 1)

        # ---- Few-shot sweep (uses val for C tuning; test for final) ----
        for shots in FEW_SHOT_GRID:
            m_tr = few_shot_subset(df_task, yname, shots_per_class=shots)
            if not (_ok(m_tr) and _ok(m_va) and _ok(m_te)):
                print(f"[WARN] ({model_tag}) {yname} few-shot={shots}: insufficient classes.")
                continue

            # Linear probe (C grid on val)
            acc_lp, f1_lp, pred_lp, best_lp = run_linear_probe(
                X[m_tr], y[m_tr], X[m_va], y[m_va], X[m_te], y[m_te]
            )
            # k-NN with default K_FOR_KNN
            acc_knn, f1_knn, pred_knn, _, k_eff = run_knn(
                X[m_tr], y[m_tr], X[m_te], y[m_te], k=K_FOR_KNN
            )

            results.append({
                "model": model_tag,
                "task": yname,
                "shots_per_class": shots if shots else 0,
                "test_acc_linear": acc_lp,
                "test_macroF1_linear": f1_lp,
                "test_acc_knn": acc_knn,
                "test_macroF1_knn": f1_knn,
                "k_for_knn": int(k_eff),
                "pca_dim": PCA_DIM if PCA_DIM else X.shape[1],
                "n_train": int(m_tr.sum()),
                "n_val": int(m_va.sum()),
                "n_test": int(m_te.sum()),
                "n_classes": int(len(np.unique(y)))
            })

            # Per-class F1 (only for All-train to keep size modest)
            if shots is None:
                pcs = per_class_f1(y[m_te], pred_lp)
                for cls, f1v in pcs.items():
                    pcs_rows.append({
                        "model": model_tag,
                        "task": yname,
                        "label": cls,
                        "f1": f1v,
                        "shots_per_class": 0  # 0 = All-train
                    })

            # k-sweep on All-train only
            if shots is None:
                for k in K_GRID:
                    acc_k, f1_k, _, _, k_eff = run_knn(X[m_tr], y[m_tr], X[m_te], y[m_te], k=k)
                    if np.isnan(f1_k):
                        continue
                    kgrid_rows.append({
                        "model": model_tag,
                        "task": yname,
                        "k": int(k_eff),
                        "test_macroF1_knn": f1_k
                    })

    res_df = pd.DataFrame(results).sort_values(["task", "model", "shots_per_class"])
    pcs_df = pd.DataFrame(pcs_rows).sort_values(["task", "model", "label"])
    kgrid_df = pd.DataFrame(kgrid_rows).sort_values(["task", "model", "k"])

    # --------- SAVE: per-model to its corresponding location ---------
    primary_out_dir = resolve_model_outdir(model_tag, emb_dir)
    os.makedirs(primary_out_dir, exist_ok=True)

    res_path = os.path.join(primary_out_dir, "downstream_results.csv")
    res_df.to_csv(res_path, index=False)
    if len(pcs_df):
        pcs_path = os.path.join(primary_out_dir, "per_class_f1_linear.csv")
        pcs_df.to_csv(pcs_path, index=False)
    if len(kgrid_df):
        ks_path = os.path.join(primary_out_dir, "knn_k_sweep.csv")
        kgrid_df.to_csv(ks_path, index=False)

    print(f"[OK] ({model_tag}) Saved per-model results to: {primary_out_dir}")

    # Also mirror into BENCH_OUT/model_tag for centralized copy
    mirror_dir = os.path.join(BENCH_OUT, model_tag)
    os.makedirs(mirror_dir, exist_ok=True)
    res_df.to_csv(os.path.join(mirror_dir, "downstream_results.csv"), index=False)
    if len(pcs_df):
        pcs_df.to_csv(os.path.join(mirror_dir, "per_class_f1_linear.csv"), index=False)
    if len(kgrid_df):
        kgrid_df.to_csv(os.path.join(mirror_dir, "knn_k_sweep.csv"), index=False)

    return res_df, pcs_df, kgrid_df

# -------------------------
# Run all baselines + FM and aggregate (with skip guards)
# -------------------------
all_main, all_pcs, all_ks = [], [], []
for tag, emb_dir in EMB_SETS.items():
    # 1) Only run if embeddings are present (for baselines)
    if ONLY_EVAL_EXISTING and not embeddings_ready(tag, emb_dir):
        print(f"[SKIP] '{tag}': embeddings missing at {emb_dir} "
            f"(need {FEATS_PRIMARY} or {FEATS_ALT_LAST} for metabofm, plus {INDEX_FILE}); skipping.")
        continue


    # 2) Skip if results already exist (unless inputs are newer or you force)
    if SKIP_IF_RESULTS_EXIST and results_already_done(tag, emb_dir) and not FORCE_REEVAL:
        if inputs_newer_than_results(tag, emb_dir):
            print(f"[RERUN] '{tag}': inputs newer than results; re-evaluating.")
        else:
            print(f"[SKIP] '{tag}': results already exist and are up-to-date.")
            # Load existing and append to aggregates (and ensure mirror exists)
            try:
                primary_out_dir = resolve_model_outdir(tag, emb_dir)
                res_df = pd.read_csv(os.path.join(primary_out_dir, "downstream_results.csv"))
                all_main.append(res_df)

                pcs_path = os.path.join(primary_out_dir, "per_class_f1_linear.csv")
                if os.path.exists(pcs_path):
                    all_pcs.append(pd.read_csv(pcs_path))
                ks_path = os.path.join(primary_out_dir, "knn_k_sweep.csv")
                if os.path.exists(ks_path):
                    all_ks.append(pd.read_csv(ks_path))

                # Mirror to BENCH_OUT/tag if not already mirrored
                mirror_dir = os.path.join(BENCH_OUT, tag)
                os.makedirs(mirror_dir, exist_ok=True)
                res_df.to_csv(os.path.join(mirror_dir, "downstream_results.csv"), index=False)
                if os.path.exists(pcs_path):
                    pd.read_csv(pcs_path).to_csv(os.path.join(mirror_dir, "per_class_f1_linear.csv"), index=False)
                if os.path.exists(ks_path):
                    pd.read_csv(ks_path).to_csv(os.path.join(mirror_dir, "knn_k_sweep.csv"), index=False)
            except Exception as e:
                print(f"[WARN] Failed to load existing results for '{tag}': {e}")
            continue

    print(f"\n[RUN] Evaluating model '{tag}' from {emb_dir}")
    res_df, pcs_df, kgrid_df = evaluate_embeddings(emb_dir, tag)
    if len(res_df):  all_main.append(res_df)
    if len(pcs_df):  all_pcs.append(pcs_df)
    if len(kgrid_df): all_ks.append(kgrid_df)

[INFO] Using TRAIN_OUT: fm_ssl_run\20251007_235600
[WARN] duplicate sample_path rows in metadata; keeping first.
[SKIP] 'imagenet_deit_s16': results already exist and are up-to-date.
[SKIP] 'imagenet_deit_b16': results already exist and are up-to-date.
[SKIP] 'dinov2_vitb14': results already exist and are up-to-date.
[SKIP] 'mae_vitb16': results already exist and are up-to-date.
[SKIP] 'dinov2_vits14': results already exist and are up-to-date.
[SKIP] 'deit_s16_random': results already exist and are up-to-date.
[SKIP] 'imagenet_deit_b16_random': results already exist and are up-to-date.
[SKIP] 'dinov2_vits14_random': results already exist and are up-to-date.
[SKIP] 'dinov2_vitb14_random': results already exist and are up-to-date.

[RUN] Evaluating model 'mae_vitb16_random' from fm_ssl_run\pretrained_feats2\mae_vitb16_random
[INFO] (mae_vitb16_random) Using BEST embeddings: fm_ssl_run\pretrained_feats2\mae_vitb16_random\image_feats.npy


Tasks[mae_vitb16_random]:   0%|          | 0/6 [00:00<?, ?it/s]


==== [mae_vitb16_random] Task: organism ====


Tasks[mae_vitb16_random]:  17%|█▋        | 1/6 [00:25<02:08, 25.74s/it]


==== [mae_vitb16_random] Task: polarity ====


Tasks[mae_vitb16_random]:  33%|███▎      | 2/6 [00:47<01:34, 23.66s/it]


==== [mae_vitb16_random] Task: Organism_Part ====


Tasks[mae_vitb16_random]:  50%|█████     | 3/6 [01:43<01:54, 38.24s/it]


==== [mae_vitb16_random] Task: Condition ====


Tasks[mae_vitb16_random]:  67%|██████▋   | 4/6 [02:31<01:24, 42.14s/it]


==== [mae_vitb16_random] Task: analyzerType ====


Tasks[mae_vitb16_random]:  83%|████████▎ | 5/6 [03:18<00:43, 43.82s/it]


==== [mae_vitb16_random] Task: ionisationSource ====


Tasks[mae_vitb16_random]: 100%|██████████| 6/6 [04:05<00:00, 40.96s/it]

[OK] (mae_vitb16_random) Saved per-model results to: fm_ssl_run\pretrained_feats2\mae_vitb16_random





### Few-shot Evaluation

In [None]:
# =====================================================
# DOWNSTREAM EVALUATION ON HELD-OUT SPLITS (multi-baseline)
# - Canonicalize/merge label variants
# - Evaluate many embedding sets with linear probe & k-NN
# - Extras: Few-shot curves, k-NN sensitivity, label-efficiency AUC, win-rate, k-robustness
# - Save per-model results + centralized mirrors + combined plots
# - Skip guards to avoid recomputing existing baselines
# - NEW: Save CSV backing each plot
# =====================================================

import os, copy, re, glob, json, hashlib
import numpy as np
import pandas as pd

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import ParameterGrid

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

SEED = 6740
np.random.seed(SEED)

# -------------------------
# Speed / behavior knobs
# -------------------------
FAST_MODE = True                       # flip False for full evaluation
FEW_SHOT_GRID = [1, 5, 10, 25, None]   # None means "All train"
K_GRID = [1, 5, 10, 20, 50]
K_FOR_KNN = 10 if FAST_MODE else 20
C_GRID = [0.1, 1.0] if FAST_MODE else [0.01, 0.1, 0.5, 1.0, 2.0, 5.0]
MAX_ITER_LR = 800 if FAST_MODE else 2000
MIN_CLASS_COUNT = 100 if FAST_MODE else 50
PCA_DIM = None                         # e.g., 256 to speed up; None disables
METABOFM_VERSION = "last"              # "best" or "last"

# -------------------------
# Skips & guards
# -------------------------
ONLY_EVAL_EXISTING = True     # only run models that already have embeddings saved
SKIP_IF_RESULTS_EXIST = True  # don't recompute per-model results if CSVs already exist
FORCE_REEVAL = False          # set True to ignore SKIP_IF_RESULTS_EXIST

FEATS_PRIMARY = "image_feats.npy"
FEATS_ALT_LAST = "image_feats_last.npy"
INDEX_FILE = "index.csv"
REQUIRED_INDEX = INDEX_FILE

# -------------------------

COMBINED_BENCH_OUT = os.path.join("fm_ssl_run", "baseline_eval_combined2")
COMBINED_PLOTS_OUT = os.path.join(COMBINED_BENCH_OUT, "plots")
os.makedirs(COMBINED_PLOTS_OUT, exist_ok=True)

# Keep same task list
TASKS = ["organism", "polarity", "Organism_Part", "Condition", "analyzerType", "ionisationSource"]

# Off-the-shelf baselines to include in combined plots
BASELINE_EMB_SETS = {
    "PCA":                         os.path.join("fm_ssl_run", "pretrained_feats2", "pixels_fast256"),
    "DeiT distilled (Random)":    os.path.join("fm_ssl_run", "pretrained_feats2", "imagenet_deit_b16_random"),
    "Dinov2-VIT14 (Random)":      os.path.join("fm_ssl_run", "pretrained_feats2", "dinov2_vitb14_random"),
    "MAE-VIT16 (Random)":         os.path.join("fm_ssl_run", "pretrained_feats2", "mae_vitb16_random"),
    "MAE-VIT16 (ImageNet)":       os.path.join("fm_ssl_run", "pretrained_feats2", "mae_vitb16"),
    "DeiT distilled (ImageNet)":  os.path.join("fm_ssl_run", "pretrained_feats2", "imagenet_deit_b16"),
    "Dinov2-VIT14 (LVD-142M)":    os.path.join("fm_ssl_run", "pretrained_feats2", "dinov2_vitb14"),
}

# -------------------------
# Helpers
# -------------------------
def save_df_csv(df: pd.DataFrame, path: str):
    """Safely save dataframe to CSV (and ensure dir)."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    df.to_csv(path, index=False)

def load_first_existing_csv(paths):
    for p in paths:
        if p and os.path.exists(p):
            return pd.read_csv(p), p
    return None, None

def read_vit_name_from_run(run_dir: str) -> str:
    rp = os.path.join(run_dir, "run_params.json")
    if os.path.exists(rp):
        try:
            with open(rp, "r", encoding="utf-8") as f:
                data = json.load(f)
            if "cfg" in data and isinstance(data["cfg"], dict):
                vit = data["cfg"].get("vit_name", None)
                if vit:
                    return str(vit)
            vit = data.get("vit_name", None)
            if vit:
                return str(vit)
        except Exception as e:
            print(f"[WARN] Failed reading vit_name from {rp}: {e}")
    return os.path.basename(os.path.normpath(run_dir))

def make_unique_name(base_name: str, existing: set, disambig_hint: str) -> str:
    name = base_name
    if name not in existing:
        return name
    short = disambig_hint.replace("\\", "/").strip("/").split("/")[-1]
    suffix = f" ({short})"
    if name + suffix not in existing:
        return name + suffix
    tiny = hashlib.md5(disambig_hint.encode("utf-8")).hexdigest()[:6]
    alt = f"{name} [{tiny}]"
    if alt not in existing:
        return alt
    i = 2
    while True:
        cand = f"{alt}-{i}"
        if cand not in existing:
            return cand
        i += 1

def primary_results_dir_for_fm_run(run_dir: str) -> str:
    return run_dir

def primary_results_dir_for_baseline(baseline_emb_dir: str) -> str:
    return baseline_emb_dir

# -------------------------
# Build the model list (baselines)
# -------------------------
model_entries = []
seen_names = set()

# Add baselines
for tag, emb_dir in BASELINE_EMB_SETS.items():
    disp = make_unique_name(tag, seen_names, disambig_hint=emb_dir)
    seen_names.add(disp)
    model_entries.append({
        "name": disp,
        "kind": "baseline",
        "primary_dir": primary_results_dir_for_baseline(emb_dir),
        "mirror_dir": os.path.join(COMBINED_BENCH_OUT, disp.replace(os.sep, "_")),
    })

# -------------------------
# Load per-model CSVs
# -------------------------
all_main, all_pcs, all_ks = [], [], []
loaded = []

for m in model_entries:
    name = m["name"]
    primary_dir = m["primary_dir"]
    mirror_dir  = m["mirror_dir"]

    main_paths = [
        os.path.join(primary_dir, "downstream_results.csv"),
        os.path.join(mirror_dir,  "downstream_results.csv"),
    ]
    pcs_paths = [
        os.path.join(primary_dir, "per_class_f1_linear.csv"),
        os.path.join(mirror_dir,  "per_class_f1_linear.csv"),
    ]
    ks_paths = [
        os.path.join(primary_dir, "knn_k_sweep.csv"),
        os.path.join(mirror_dir,  "knn_k_sweep.csv"),
    ]

    df_main, mp = load_first_existing_csv(main_paths)
    if df_main is None:
        print(f"[WARN] No downstream_results.csv for '{name}'. Skipping this model.")
        continue

    df_pcs, pp = load_first_existing_csv(pcs_paths)
    df_ks,  kp = load_first_existing_csv(ks_paths)

    required_cols = {"model","task","shots_per_class","test_macroF1_linear","test_macroF1_knn"}
    missing = required_cols.difference(set(df_main.columns) | {"model"})  # model will be added below
    if missing:
        raise ValueError(f"[ERR] '{name}' main CSV missing columns: {missing} at {mp}")

    df_main = df_main.copy()
    df_main["model"] = name

    all_main.append(df_main)
    if df_pcs is not None:
        df_pcs = df_pcs.copy()
        df_pcs["model"] = name
        all_pcs.append(df_pcs)
    if df_ks is not None:
        if "k" not in df_ks.columns or "test_macroF1_knn" not in df_ks.columns:
            print(f"[WARN] k-sweep CSV for '{name}' missing expected columns; skipping.")
        else:
            df_ks = df_ks.copy()
            df_ks["model"] = name
            all_ks.append(df_ks)

    loaded.append((name, mp, pp, kp))

if not all_main:
    raise SystemExit("[ERR] No per-model CSVs could be loaded.")

bench   = pd.concat(all_main, ignore_index=True)
pcs_all = pd.concat(all_pcs, ignore_index=True) if all_pcs else pd.DataFrame()
ks_all  = pd.concat(all_ks,  ignore_index=True) if all_ks  else pd.DataFrame()

# Save the unified CSVs too
save_df_csv(bench,   os.path.join(COMBINED_PLOTS_OUT, "ALL_downstream_results_merged.csv"))
if len(pcs_all):
    save_df_csv(pcs_all, os.path.join(COMBINED_PLOTS_OUT, "ALL_per_class_f1_linear_merged.csv"))
if len(ks_all):
    save_df_csv(ks_all,  os.path.join(COMBINED_PLOTS_OUT, "ALL_knn_k_sweep_merged.csv"))

print("[OK] Loaded models:")
for name, mp, pp, kp in loaded:
    print(f" - {name}\n    main: {mp}\n    pcs : {pp or '—'}\n    k   : {kp or '—'}")

# -------------------------
# PLOTS (standard + narrative-pivot extras)
# -------------------------
sns.set(style="whitegrid")

# 1) Macro-F1 by task & model (Linear Probe) — barplot (All-train only)
df_bar = bench[bench["shots_per_class"] == 0].copy()
if len(df_bar):
    # Save raw plot data
    save_df_csv(df_bar, os.path.join(COMBINED_PLOTS_OUT, "PLOT_DATA_bar_macroF1_linear_alltrain_raw.csv"))

    plt.figure(figsize=(14, 6))
    order_tasks = [t for t in TASKS if t in df_bar["task"].unique()]
    sns.barplot(
        data=df_bar, x="task", y="test_macroF1_linear", hue="model",
        order=order_tasks, errorbar="ci", dodge=True
    )
    plt.title("Macro-F1 (Linear Probe, TEST) by Task & Model (All-train)")
    plt.ylabel("Macro-F1")
    plt.xlabel("Task")
    plt.xticks(rotation=30, ha="right")
    plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left", title="Model")
    plt.tight_layout()
    plt.savefig(os.path.join(COMBINED_PLOTS_OUT, "bar_macroF1_linear_alltrain.png"), dpi=200)
    plt.close()

# 2) Macro-F1 by task & model (k-NN default K) — barplot (All-train only)
if len(df_bar):
    # Save raw plot data
    save_df_csv(df_bar, os.path.join(COMBINED_PLOTS_OUT, f"PLOT_DATA_bar_macroF1_knn_k{K_FOR_KNN}_alltrain_raw.csv"))

    plt.figure(figsize=(14, 6))
    order_tasks = [t for t in TASKS if t in df_bar["task"].unique()]
    sns.barplot(
        data=df_bar, x="task", y="test_macroF1_knn", hue="model",
        order=order_tasks, errorbar="ci", dodge=True
    )
    plt.title(f"Macro-F1 (k-NN@{K_FOR_KNN}, TEST) by Task & Model (All-train)")
    plt.ylabel("Macro-F1")
    plt.xlabel("Task")
    plt.xticks(rotation=30, ha="right")
    plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left", title="Model")
    plt.tight_layout()
    plt.savefig(os.path.join(COMBINED_PLOTS_OUT, f"bar_macroF1_knn_k{K_FOR_KNN}_alltrain.png"), dpi=200)
    plt.close()

# ============================================
# Few-shot BARPLOTS per task (shots on x-axis)
# ============================================

# Metrics to plot: linear probe & k-NN (macro-F1 on TEST)
METRICS = [
    ("test_macroF1_linear", "Linear Probe", "linear"),
    ("test_macroF1_knn",    "k-NN",         "knn"),
]

df_fs = bench.copy()
if len(df_fs):
    # Map shots_per_class (0 means All-train)
    df_fs["shots_lbl"] = df_fs["shots_per_class"].replace({0: "All"}).astype(str)

    # Desired order of bars on the x-axis
    desired_order = ["1", "5", "10", "25", "All"]

    # Save the global few-shot merged dataframe used for all plots
    save_df_csv(df_fs, os.path.join(COMBINED_PLOTS_OUT, "PLOT_DATA_bar_fewshot_ALL_metrics_RAW.csv"))

    for metric_col, metric_title, metric_key in METRICS:
        if metric_col not in df_fs.columns:
            print(f"[WARN] Metric '{metric_col}' missing in bench; skipping all tasks for this metric.")
            continue

        for tsk in TASKS:
            cur = df_fs[df_fs["task"] == tsk].copy()
            if cur.empty:
                continue

            # Keep only present shot labels but preserve canonical order
            present = [lab for lab in desired_order if lab in set(cur["shots_lbl"])]
            if not present:
                print(f"[WARN] No few-shot settings present for task '{tsk}'. Skipping.")
                continue

            cur["shots_lbl"] = pd.Categorical(cur["shots_lbl"], categories=present, ordered=True)

            # Save per-task/metric plot data
            out_csv = os.path.join(COMBINED_PLOTS_OUT, f"PLOT_DATA_bar_fewshot_{metric_key}_{tsk}.csv")
            save_df_csv(cur[["model","task","shots_per_class","shots_lbl",metric_col]], out_csv)

            plt.figure(figsize=(12, 6))
            sns.barplot(
                data=cur,
                x="shots_lbl", y=metric_col, hue="model",
                order=present, dodge=True, errorbar="ci"
            )
            plt.title(f"Few-shot Macro-F1 ({metric_title}) — Task: {tsk}")
            plt.ylabel("Macro-F1 (TEST)")
            plt.xlabel("Shots per class")
            plt.xticks(rotation=0)
            plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left", title="Model")
            plt.tight_layout()

            out_name = f"bar_fewshot_{metric_key}_{tsk}.png"
            plt.savefig(os.path.join(COMBINED_PLOTS_OUT, out_name), dpi=200)
            plt.close()

print(f"[OK] Few-shot per-task barplots (Linear & k-NN) saved to: {COMBINED_PLOTS_OUT}")

# --- Label-efficiency AUC (Linear) across shots per task ---
shot_x_map = {1:1, 5:5, 10:10, 25:25, 0:50}  # 0 ("All") -> large x
df_auc = []
for (model, task), g in df_fs.groupby(["model", "task"]):
    gg = g.copy()
    gg["x"] = gg["shots_per_class"].map(shot_x_map)
    gg = gg.dropna(subset=["x", "test_macroF1_linear"]).sort_values("x")
    if gg["x"].nunique() >= 2:
        x = gg["x"].values.astype(float)
        y = gg["test_macroF1_linear"].values.astype(float)
        x_norm = (x - x.min()) / (x.max() - x.min())
        # Guard against zero division (shouldn't happen because nunique>=2)
        denom = (x_norm.max() - x_norm.min()) if (x_norm.max() - x_norm.min()) > 0 else 1.0
        auc = np.trapz(y, x_norm) / denom
        df_auc.append({"model": model, "task": task, "auc_linear": auc})
df_auc = pd.DataFrame(df_auc)

if len(df_auc):
    # Save per-task AUC values
    save_df_csv(df_auc, os.path.join(COMBINED_PLOTS_OUT, "PLOT_DATA_label_efficiency_auc_linear_by_task.csv"))

    plt.figure(figsize=(12, 6))
    auc_bar = df_auc.groupby("model", as_index=False)["auc_linear"].mean()
    # Save aggregated AUC means per model
    save_df_csv(auc_bar, os.path.join(COMBINED_PLOTS_OUT, "PLOT_DATA_label_efficiency_auc_linear_mean_over_tasks.csv"))

    sns.barplot(data=auc_bar, x="model", y="auc_linear")
    plt.title("Label-Efficiency AUC (Linear Probe) — Mean over tasks")
    plt.ylabel("AUC (0–1)"); plt.xlabel("")
    plt.xticks(rotation=20, ha="right")
    plt.tight_layout()
    plt.savefig(os.path.join(COMBINED_PLOTS_OUT, "bar_label_efficiency_auc_linear.png"), dpi=200)
    plt.close()

# --- Best-of (Linear vs k-NN) at All-train ---
df_alltrain = bench[bench["shots_per_class"] == 0].copy()
if len(df_alltrain):
    df_alltrain["best_macroF1"] = df_alltrain[["test_macroF1_linear", "test_macroF1_knn"]].max(axis=1)
    # Save raw + best-of
    save_df_csv(df_alltrain, os.path.join(COMBINED_PLOTS_OUT, "PLOT_DATA_bestof_alltrain_raw.csv"))
    save_df_csv(df_alltrain[["model","task","best_macroF1"]], os.path.join(COMBINED_PLOTS_OUT, "PLOT_DATA_bestof_alltrain_compact.csv"))

    plt.figure(figsize=(14, 6))
    order_tasks = [t for t in TASKS if t in df_alltrain["task"].unique()]
    sns.barplot(
        data=df_alltrain, x="task", y="best_macroF1", hue="model",
        order=order_tasks, errorbar="ci", dodge=True
    )
    plt.title("Best-of (Linear or k-NN) Macro-F1 — TEST, All-train")
    plt.ylabel("Macro-F1"); plt.xlabel("Task")
    plt.xticks(rotation=30, ha="right")
    plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left", title="Model")
    plt.tight_layout()
    plt.savefig(os.path.join(COMBINED_PLOTS_OUT, "bar_bestof_alltrain.png"), dpi=200)
    plt.close()

# --- Win-rate across tasks (All-train, Linear) ---
if len(df_alltrain):
    wr_rows = []
    for tsk, g in df_alltrain.groupby("task"):
        g2 = g.sort_values("test_macroF1_linear", ascending=False)
        if not g2.empty:
            top_val = g2["test_macroF1_linear"].iloc[0]
            winners = g2[g2["test_macroF1_linear"] >= top_val - 1e-9]["model"].unique()
            for w in winners:
                wr_rows.append({"model": w, "win_task": tsk})
    df_wr = pd.DataFrame(wr_rows)

    if len(df_wr):
        # Save raw win tasks + win counts
        save_df_csv(df_wr, os.path.join(COMBINED_PLOTS_OUT, "PLOT_DATA_winrate_tasks_linear_alltrain.csv"))
        winrate = df_wr.groupby("model")["win_task"].nunique().reset_index()
        winrate = winrate.rename(columns={"win_task": "win_count"})
        save_df_csv(winrate, os.path.join(COMBINED_PLOTS_OUT, "PLOT_DATA_winrate_counts_linear_alltrain.csv"))

        plt.figure(figsize=(10, 5))
        sns.barplot(data=winrate, x="model", y="win_count")
        plt.title("Win-rate (#tasks won) — Linear, All-train")
        plt.ylabel("#Tasks won"); plt.xlabel("")
        plt.xticks(rotation=20, ha="right")
        plt.tight_layout()
        plt.savefig(os.path.join(COMBINED_PLOTS_OUT, "bar_winrate_alltrain.png"), dpi=200)
        plt.close()

# --- k-robustness: lower variance across k is better ---
if len(ks_all):
    var_df = ks_all.groupby(["model", "task"])["test_macroF1_knn"].std(ddof=0).reset_index()
    var_df = var_df.rename(columns={"test_macroF1_knn": "std_over_k"})
    # Save robustness data
    save_df_csv(var_df, os.path.join(COMBINED_PLOTS_OUT, "PLOT_DATA_knn_robustness_std_over_k.csv"))

    plt.figure(figsize=(12, 6))
    sns.barplot(data=var_df, x="task", y="std_over_k", hue="model")
    plt.title("k-NN Robustness — Std of Macro-F1 over k (lower is better)")
    plt.ylabel("Std(F1) over k"); plt.xlabel("Task")
    plt.xticks(rotation=30, ha="right")
    plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left", title="Model")
    plt.tight_layout()
    plt.savefig(os.path.join(COMBINED_PLOTS_OUT, "bar_knn_robustness_std_over_k.png"), dpi=200)
    plt.close()

# 5) (Optional) Table preview in stdout: Macro-F1 (Linear, All-train)
if len(df_bar):
    print("\n=== Macro-F1 (TEST, Linear, All-train) by task & model ===")
    pivot = df_bar.pivot_table(index="task", columns="model", values="test_macroF1_linear", aggfunc="mean")
    print(pivot.round(3).to_string())
    # Save the pivot too
    save_df_csv(pivot.reset_index(), os.path.join(COMBINED_PLOTS_OUT, "TABLE_macroF1_linear_alltrain_pivot.csv"))

print(f"\n[OK] Combined plots + CSVs saved under: {COMBINED_PLOTS_OUT}")

[OK] Loaded models:
 - PCA
    main: fm_ssl_run\pretrained_feats2\pixels_fast256\downstream_results.csv
    pcs : fm_ssl_run\pretrained_feats2\pixels_fast256\per_class_f1_linear.csv
    k   : fm_ssl_run\pretrained_feats2\pixels_fast256\knn_k_sweep.csv
 - DeiT distilled (Random)
    main: fm_ssl_run\pretrained_feats2\imagenet_deit_b16_random\downstream_results.csv
    pcs : fm_ssl_run\pretrained_feats2\imagenet_deit_b16_random\per_class_f1_linear.csv
    k   : fm_ssl_run\pretrained_feats2\imagenet_deit_b16_random\knn_k_sweep.csv
 - Dinov2-VIT14 (Random)
    main: fm_ssl_run\pretrained_feats2\dinov2_vitb14_random\downstream_results.csv
    pcs : fm_ssl_run\pretrained_feats2\dinov2_vitb14_random\per_class_f1_linear.csv
    k   : fm_ssl_run\pretrained_feats2\dinov2_vitb14_random\knn_k_sweep.csv
 - MAE-VIT16 (Random)
    main: fm_ssl_run\pretrained_feats2\mae_vitb16_random\downstream_results.csv
    pcs : fm_ssl_run\pretrained_feats2\mae_vitb16_random\per_class_f1_linear.csv
    k   : fm_ss

  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) / denom
  auc = np.trapz(y, x_norm) 


=== Macro-F1 (TEST, Linear, All-train) by task & model ===
model             DeiT distilled (ImageNet)  DeiT distilled (Random)  Dinov2-VIT14 (LVD-142M)  Dinov2-VIT14 (Random)  MAE-VIT16 (ImageNet)  MAE-VIT16 (Random)    PCA
task                                                                                                                                                                 
Condition                             0.579                    0.406                    0.637                  0.186                 0.546               0.433  0.296
Organism_Part                         0.692                    0.469                    0.684                  0.256                 0.653               0.436  0.334
analyzerType                          0.774                    0.593                    0.799                  0.480                 0.779               0.593  0.467
ionisationSource                      0.783                    0.562                    0.725                 


### UMAP

In [None]:
# =====================================================
# UMAP visualization + Global Pie Charts (colors match)
# - One UMAP per model, recolored by multiple tasks
# - ONE pie chart per task (not per model), with counts on slices
# - Colors are consistent between UMAPs and pies
# - NEW: Robust outlier removal in UMAP PLOTS (MAD/IQR/percentile)
#   * Outliers are filtered ONLY for plotting; pies & colors unaffected
#   * Saves both raw and filtered UMAP coords for transparency
# =====================================================
import os
import re
import json
import hashlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import umap

# -----------------------------
# Config
# -----------------------------
FAST_MODE = True
MIN_CLASS_COUNT = 100 if FAST_MODE else 50
RANDOM_SEED = 6740

# Choose which FM embeddings to visualize:
#   "best" -> embeddings_all/image_feats.npy
#   "last" -> embeddings_all/image_feats_last.npy
FM_EMB_VERSION = "last"   # options: "best", "last"
# =========================================================================================================

# Off-the-shelf baselines (embedding folders must contain image_feats.npy + index.csv)
BASELINE_EMB_SETS = {
    # "PCA":                             os.path.join("fm_ssl_run", "pretrained_feats2", "pixels_fast256"),
    # "DeiT distilled (Random)":         os.path.join("fm_ssl_run", "pretrained_feats2", "imagenet_deit_b16_random"),
    "Dinov2-VIT14 (Random)":           os.path.join("fm_ssl_run", "pretrained_feats2", "dinov2_vitb14_random"),
    # "MAE-VIT16 (Random)":              os.path.join("fm_ssl_run", "pretrained_feats2", "mae_vitb16_random"),
    # "MAE-VIT16 (ImageNet)":            os.path.join("fm_ssl_run", "pretrained_feats2", "mae_vitb16"),
    # "DeiT distilled (ImageNet)":       os.path.join("fm_ssl_run", "pretrained_feats2", "imagenet_deit_b16"),
    "Dinov2-VIT14 (LVD-142M)":           os.path.join("fm_ssl_run", "pretrained_feats2", "dinov2_vitb14"),
}

# Metadata sources
IDX_PARQUET = "metaspace_images_dump/msi_fm_samples3.parquet"
MAN_PARQUET = "metaspace_images_dump/manifest_expanded.parquet"

# Output (single combined folder for all models)
OUT_DIR = os.path.join("fm_ssl_run", "baseline_eval_combined2", "umaps")
os.makedirs(OUT_DIR, exist_ok=True)

# UMAP params (per model)
N_NEIGHBORS = 15
MIN_DIST = 0.15
METRIC = "cosine"

# Tasks to color by (must exist after canonicalization)
TASKS = ["organism", "polarity", "Organism_Part", "Condition", "analyzerType", "ionisationSource"]

# ------------- Outlier Removal (for plotting only) -------------
OUTLIER_REMOVE = True
OUTLIER_METHOD = "mad"   # options: "mad", "iqr", "percentile"
MAD_K = 3             # higher = fewer removals (robust; good default 6–10)
IQR_K = 2.5              # multiplier for Q3 + k*IQR on radial distances
DIST_PERCENTILE = 99.7   # keep points within this distance percentile (if method="percentile")
# Safety rails: if we accidentally drop too much, we auto-relax once
MIN_KEEP_FRACTION = 0.85

sns.set_style("white")

# -----------------------------
# Helpers for FM-run naming & paths
# -----------------------------
def read_vit_name_from_run(run_dir: str) -> str:
    """Read cfg.vit_name from run_params.json; fallback to folder name."""
    rp = os.path.join(run_dir, "run_params.json")
    if os.path.exists(rp):
        try:
            with open(rp, "r", encoding="utf-8") as f:
                data = json.load(f)
            if "cfg" in data and isinstance(data["cfg"], dict):
                vit = data["cfg"].get("vit_name", None)
                if vit:
                    return str(vit)
            vit = data.get("vit_name", None)
            if vit:
                return str(vit)
        except Exception as e:
            print(f"[WARN] Failed reading vit_name from {rp}: {e}")
    return os.path.basename(os.path.normpath(run_dir))

def make_unique_name(base_name: str, existing: set, hint: str) -> str:
    """Ensure display name uniqueness across runs/baselines."""
    if base_name not in existing:
        return base_name
    short = hint.replace("\\", "/").strip("/").split("/")[-1]
    cand = f"{base_name} ({short})"
    if cand not in existing:
        return cand
    tiny = hashlib.md5(hint.encode("utf-8")).hexdigest()[:6]
    return f"{base_name} [{tiny}]"

def fm_embeddings_dir(run_dir: str) -> str:
    """Default FM embeddings location within each run."""
    return os.path.join(run_dir, "embeddings_all")

def fm_feats_filename() -> str:
    return "image_feats.npy" if FM_EMB_VERSION.lower() == "best" else "image_feats_last.npy"

def fm_embeddings_ready(emb_dir: str) -> bool:
    feats_file = fm_feats_filename()
    return (
        os.path.exists(os.path.join(emb_dir, feats_file)) and
        os.path.exists(os.path.join(emb_dir, "index.csv"))
    )

def baseline_embeddings_ready(emb_dir: str) -> bool:
    return (
        os.path.exists(os.path.join(emb_dir, "image_feats.npy")) and
        os.path.exists(os.path.join(emb_dir, "index.csv"))
    )

# -----------------------------
# Canonicalization (matches eval script)
# -----------------------------
def _clean(s):
    if pd.isna(s): return None
    s = str(s).strip()
    s = re.sub(r"\s+", " ", s)
    return s

def canonicalize_labels(df):
    df = df.copy()

    # 1) Polarity
    pol_map = {"pos":"Positive","positive":"Positive","+":"Positive",
               "neg":"Negative","negative":"Negative","-":"Negative"}
    def canon_polarity(s):
        if s is None: return None
        t = _clean(s).lower()
        t2 = pol_map.get(t, t)
        if t2 in ("positive","negative"):
            return t2.capitalize()
        if "pos" in t: return "Positive"
        if "neg" in t: return "Negative"
        return _clean(s)
    if "polarity" in df.columns:
        df["polarity"] = df["polarity"].map(canon_polarity)

    # 2) Ionisation Source
    def canon_ion_src(s):
        if s is None: return None
        t_raw = _clean(s)
        t = t_raw.upper().replace("-", "").replace("_","")
        if "APSMALDI" in t: return "AP-SMALDI"
        if "IRMALDESI" in t or "IRMALDI" in t: return "IR-MALDESI"
        if "APMALDI" in t: return "AP-MALDI"
        if "DESIMSI" in t: return "DESI"
        if "DESI" in t: return "DESI"
        if "MALDI" in t: return "MALDI"
        return t_raw
    if "ionisationSource" in df.columns:
        df["ionisationSource"] = df["ionisationSource"].map(canon_ion_src)

    # 3) Analyzer Type
    def canon_analyzer(s):
        if s is None: return None
        t = _clean(s); tl = t.lower()
        if "timstof" in tl and "flex" in tl: return "timsTOF Flex"
        if "fticr" in tl:
            if "12t" in tl: return "12T FTICR"
            if "7t" in tl and "scimax" in tl: return "FTICR scimaX 7T"
            return "FTICR"
        if "orbitrap" in tl or "q-exactive" in tl: return "Orbitrap"
        if "tof" in tl and "reflector" in tl: return "TOF reflector"
        if tl.strip() == "qtof": return "qTOF"
        return t
    if "analyzerType" in df.columns:
        df["analyzerType"] = df["analyzerType"].map(canon_analyzer)

    # 4) Organism
    def canon_organism(s):
        if s is None: return None
        t = _clean(s); tl = t.lower()
        if "|" in t or "," in t:
            if ("human" in tl or "homo sapiens" in tl) and ("mouse" in tl or "mus musculus" in tl):
                return "Mixed"
        if "homo sapiens" in tl or tl.strip() in {"human","h. sapiens","homo"}:
            return "Homo sapiens"
        if "mus musculus" in tl or tl.strip() in {"mouse","m. musculus"}:
            return "Mus musculus"
        return t
    if "organism" in df.columns:
        df["organism"] = df["organism"].map(canon_organism)

    # 5) Organism_Part
    def canon_part(s):
        if s is None: return None
        t = _clean(s); tl = t.lower()
        if "kidney" in tl: return "Kidney"
        if "brain"  in tl: return "Brain"
        if "liver"  in tl: return "Liver"
        if "lung"   in tl: return "Lung"
        if "breast" in tl: return "Breast"
        if "skin"   in tl: return "Skin"
        if "heart"  in tl or "cardiac" in tl: return "Heart"
        return t
    if "Organism_Part" in df.columns:
        df["Organism_Part"] = df["Organism_Part"].map(canon_part)

    # 6) Condition
    def canon_condition(s):
        if s is None: return None
        t = _clean(s); tl = t.lower()
        if tl in {"n/a","na","none","not available",""}: return "NA"
        if tl in {"biopsy","biopsies"}: return "Biopsy"
        if "fresh frozen" in tl or "frozen" in tl: return "Frozen"
        if "tumor" in tl or "tumour" in tl: return "Tumor"
        if "cancer" in tl: return "Cancer"
        if "wildtype" in tl or tl == "wt": return "Wildtype"
        if "healthy" in tl or "control" in tl: return "Healthy"
        if "diseased" in tl or "disease" in tl: return "Diseased"
        return t
    if "Condition" in df.columns:
        df["Condition"] = df["Condition"].map(canon_condition)

    return df

# -----------------------------
# Load metadata (+ canonicalize)
# -----------------------------
idx = pd.read_parquet(IDX_PARQUET)
man = pd.read_parquet(MAN_PARQUET)
meta = idx.merge(man, on="dataset_id", how="left", suffixes=("", "_man"))
meta = meta.drop_duplicates("sample_path", keep="first").reset_index(drop=True)
meta = canonicalize_labels(meta)

# -----------------------------
# Build model list (FM runs + baselines)
# -----------------------------
MODELS = []  # {"tag": display_name, "emb_dir": path, "feats_file": fname}
seen_names = set()

# Add baselines
for tag, emb_dir in BASELINE_EMB_SETS.items():
    name_disp = make_unique_name(tag, seen_names, hint=emb_dir)
    seen_names.add(name_disp)
    if not baseline_embeddings_ready(emb_dir):
        print(f"[WARN] Missing embeddings for baseline '{name_disp}' at {emb_dir} "
              f"(need image_feats.npy + index.csv); skipping.")
        continue
    MODELS.append({"tag": name_disp, "emb_dir": emb_dir, "feats_file": "image_feats.npy"})

if not MODELS:
    raise SystemExit("[ERR] No models to visualize.")

print("[INFO] Models to visualize:")
for m in MODELS:
    print(f" - {m['tag']}: {m['emb_dir']} [{m['feats_file']}]")

# -----------------------------
# Outlier detection helpers
# -----------------------------
def _mad(x):
    med = np.median(x)
    return np.median(np.abs(x - med))

def outlier_mask_umap(umap_xy: np.ndarray,
                      method: str = OUTLIER_METHOD,
                      mad_k: float = MAD_K,
                      iqr_k: float = IQR_K,
                      dist_pct: float = DIST_PERCENTILE) -> np.ndarray:
    """
    Return boolean mask of points to KEEP based on robust radial distance
    from the median center of the UMAP cloud.
    """
    if umap_xy.ndim != 2 or umap_xy.shape[1] != 2:
        return np.ones(len(umap_xy), dtype=bool)

    # robust center
    cx, cy = np.median(umap_xy[:, 0]), np.median(umap_xy[:, 1])
    r = np.sqrt((umap_xy[:, 0] - cx) ** 2 + (umap_xy[:, 1] - cy) ** 2)

    if method == "mad":
        mad = _mad(r)
        if mad == 0:
            # fallback to percentile if degenerate
            thr = np.percentile(r, dist_pct)
            keep = r <= thr
        else:
            thr = np.median(r) + mad_k * mad
            keep = r <= thr

    elif method == "iqr":
        q1, q3 = np.percentile(r, [25, 75])
        iqr = q3 - q1
        thr = q3 + iqr_k * iqr
        keep = r <= thr

    elif method == "percentile":
        thr = np.percentile(r, dist_pct)
        keep = r <= thr

    else:
        # unknown method -> keep all
        return np.ones(len(umap_xy), dtype=bool)

    # Safety rail: if we dropped too many, relax threshold once
    frac = keep.mean()
    if frac < MIN_KEEP_FRACTION:
        print(f"[WARN] Outlier removal kept only {frac:.1%}; relaxing...")
        # Relax by blending toward 100th percentile
        thr_relaxed = np.percentile(r, min(99.95, max(dist_pct, 99.0)))
        keep = r <= thr_relaxed
    return keep

# -----------------------------
# UMAP helper (fit/transform)
# -----------------------------
def run_umap_once(X, n_neighbors=30, min_dist=0.15, seed=6740, metric="cosine"):
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=seed,
        metric=metric
    )
    return reducer.fit_transform(X)

# -----------------------------
# 1) Build per-task kept-classes + color maps (global, used by UMAPs & pies)
# -----------------------------
def build_task_color_maps(meta_df, tasks, min_class_count=50):
    kept_by_task = {}
    cmap_by_task = {}
    for task in tasks:
        if task not in meta_df.columns:
            continue
        labels_full = meta_df[task].astype("string").fillna("NA")
        vc = labels_full.value_counts()
        kept = sorted(vc[vc >= min_class_count].index.tolist())
        palette = (
            sns.color_palette("tab10", n_colors=max(len(kept), 1))
            if len(kept) <= 10
            else sns.color_palette("husl", max(len(kept), 1))
        )
        color_map = {lab: palette[i] for i, lab in enumerate(kept)}
        kept_by_task[task] = kept
        cmap_by_task[task] = color_map
    return kept_by_task, cmap_by_task

# Build once (so all models/plots use identical task colors)
KEPT_BY_TASK, CMAP_BY_TASK = build_task_color_maps(meta, TASKS, MIN_CLASS_COUNT)

# -----------------------------
# 2) Global pies (one per task), using same colors as UMAPs
# -----------------------------
def save_global_pie_for_task(meta_df, task, kept_by_task, cmap_by_task, out_dir):
    if task not in meta_df.columns:
        print(f"[WARN] Global pie: task '{task}' not in metadata; skipping.")
        return

    labels = meta_df[task].astype("string").fillna("NA")
    vc_all = labels.value_counts()
    kept = kept_by_task.get(task, [])
    c_map = cmap_by_task.get(task, {})

    # Only plot kept classes
    sizes, colors, names = [], [], []
    total = 0
    for lab in kept:
        cnt = int(vc_all.get(lab, 0))
        if cnt > 0:
            sizes.append(cnt)
            colors.append(c_map[lab])
            names.append(str(lab))
            total += cnt

    if total == 0:
        print(f"[WARN] Global pie: '{task}' has no kept samples; skipping.")
        return

    def _count_autopct(pct):
        absolute = int(round(pct * total / 100.0))
        return f"{absolute}"

    fig, ax = plt.subplots(figsize=(6.4, 6.4))
    wedges, texts, autotexts = ax.pie(
        sizes,
        colors=colors,
        startangle=90,
        autopct=_count_autopct,
        pctdistance=0.72,
        textprops={"fontsize": 9},
        wedgeprops=dict(linewidth=0.5, edgecolor="white"),
    )
    ax.axis("equal")

    ax.legend(
        wedges, names,
        title="Classes",
        loc="center left",
        bbox_to_anchor=(1.02, 0.5),
        frameon=False,
        fontsize=8
    )

    plt.tight_layout()
    out_path = os.path.join(out_dir, f"pie_{task}.png")
    plt.savefig(out_path, dpi=250, bbox_inches="tight", pad_inches=0.15)
    plt.close()
    print(f"[OK] saved: {out_path}")

# -----------------------------
# 3) UMAP plotting using precomputed task color maps
#     (applies outlier removal to the scatter only)
# -----------------------------
def plot_same_coords_color_by_tasks(umap_2d, df_join, model_tag, out_dir,
                                    tasks, min_class_count=50,
                                    outlier_remove=True):
    assert len(df_join) == len(umap_2d), "UMAP coords and df_join length mismatch"

    # Save RAW coords for reuse
    raw_coords_out = os.path.join(out_dir, f"umap_coords_{model_tag}_raw.csv")
    pd.DataFrame({
        "sample_path": df_join["sample_path"].values,
        "umap_x": umap_2d[:,0],
        "umap_y": umap_2d[:,1],
    }).to_csv(raw_coords_out, index=False)

    # Compute outlier mask (for plotting)
    if outlier_remove:
        keep_mask = outlier_mask_umap(umap_2d, method=OUTLIER_METHOD,
                                      mad_k=MAD_K, iqr_k=IQR_K, dist_pct=DIST_PERCENTILE)
    else:
        keep_mask = np.ones(len(umap_2d), dtype=bool)

    kept_frac = keep_mask.mean()
    removed = (~keep_mask).sum()
    if removed > 0:
        print(f"[INFO] {model_tag}: removing {removed} outliers ({(1-kept_frac):.1%}) from plot")

    # Save FILTERED coords (for transparency)
    filt_coords_out = os.path.join(out_dir, f"umap_coords_{model_tag}_filtered.csv")
    pd.DataFrame({
        "sample_path": df_join.loc[keep_mask, "sample_path"].values,
        "umap_x": umap_2d[keep_mask, 0],
        "umap_y": umap_2d[keep_mask, 1],
    }).to_csv(filt_coords_out, index=False)

    # Use filtered arrays for plotting only
    U = umap_2d[keep_mask]
    DF = df_join.loc[keep_mask].reset_index(drop=True)

    for task in tasks:
        if task not in DF.columns:
            print(f"[WARN] {model_tag}: task '{task}' not in metadata; skipping.")
            continue

        labels = DF[task].astype("string").fillna("NA")
        kept = KEPT_BY_TASK.get(task, [])
        color_map = CMAP_BY_TASK.get(task, {})

        is_keep = labels.isin(kept).values

        fig, ax = plt.subplots(figsize=(7.8, 6.6))

        # gray background for non-kept classes (still filtered by outliers)
        m_bg = ~is_keep
        if m_bg.any():
            ax.scatter(
                U[m_bg, 0], U[m_bg, 1],
                s=6, alpha=0.25, linewidths=0, c="#cfcfcf"
            )

        # kept classes with fixed colors
        for lab in kept:
            sel = (labels.values == lab)
            if not sel.any():
                continue
            ax.scatter(
                U[sel, 0], U[sel, 1],
                s=6, alpha=0.8, linewidths=0,
                color=color_map[lab], label=str(lab)
            )

        ax.set_title(f"UMAP — {model_tag} [{task}] (outliers removed)")
        ax.set_xlabel("UMAP-1"); ax.set_ylabel("UMAP-2")

        if 0 < len(kept) <= 25:
            ax.legend(markerscale=3, bbox_to_anchor=(1.02, 1),
                      loc="upper left", fontsize=8, frameon=False, borderaxespad=0.0)

        plt.tight_layout()
        out_path = os.path.join(out_dir, f"umap_{model_tag}_{task}.png")
        plt.savefig(out_path, dpi=250, bbox_inches="tight", pad_inches=0.15)
        plt.close()
        print(f"[OK] saved: {out_path}")

# -----------------------------
# Main: loop models (UMAP once per model) + global pies once per task
# -----------------------------
for m in MODELS:
    tag = m["tag"]
    emb_dir = m["emb_dir"]
    feats_file = m["feats_file"]

    feats_path = os.path.join(emb_dir, feats_file)
    index_path = os.path.join(emb_dir, "index.csv")
    if not (os.path.exists(feats_path) and os.path.exists(index_path)):
        print(f"[WARN] Missing artifacts for {tag} at {emb_dir} "
              f"(need {feats_file} + index.csv)")
        continue

    emb = np.load(feats_path)
    index = pd.read_csv(index_path)  # must have 'sample_path'

    # Join & align
    df_join = meta.merge(index, on="sample_path", how="inner").reset_index(drop=True)
    if "index" in df_join.columns:
        df_join = df_join.drop(columns=["index"])

    # Align lengths robustly (index.csv order is assumed to match emb row order)
    if emb.shape[0] != len(df_join):
        if emb.shape[0] > len(df_join):
            print(f"[WARN] {tag}: more embeddings ({emb.shape[0]}) than metadata rows ({len(df_join)}); trimming embeddings.")
            emb = emb[:len(df_join)]
        else:
            print(f"[WARN] {tag}: fewer embeddings ({emb.shape[0]}) than metadata rows ({len(df_join)}); truncating metadata.")
            df_join = df_join.iloc[:emb.shape[0]].reset_index(drop=True)

    # Optional FAST subsampling for UMAP fit (fit on subset, transform all)
    SUB_FIT = None
    if FAST_MODE and emb.shape[0] > 120_000:
        rng = np.random.default_rng(RANDOM_SEED)
        sub_idx = rng.choice(emb.shape[0], size=120_000, replace=False)
        SUB_FIT = sub_idx

    reducer = umap.UMAP(
        n_neighbors=N_NEIGHBORS, min_dist=MIN_DIST,
        random_state=RANDOM_SEED, metric=METRIC
    )
    if SUB_FIT is None:
        umap_2d = reducer.fit_transform(emb)
    else:
        reducer.fit(emb[SUB_FIT])
        umap_2d = reducer.transform(emb)

    print(f"[INFO] UMAP {tag}: n={emb.shape[0]} → coords computed once")

    # Plot SAME coords, recolor by each task (colors fixed globally)
    plot_same_coords_color_by_tasks(
        umap_2d=umap_2d,
        df_join=df_join,
        model_tag=tag,
        out_dir=OUT_DIR,
        tasks=TASKS,
        min_class_count=MIN_CLASS_COUNT,
        outlier_remove=OUTLIER_REMOVE
    )

# -------- After UMAPs: ONE pie per task (colors match UMAPs) --------
for task in TASKS:
    save_global_pie_for_task(meta, task, KEPT_BY_TASK, CMAP_BY_TASK, OUT_DIR)

[INFO] Models to visualize:
 - Dinov2-VIT14 (Random): fm_ssl_run\pretrained_feats2\dinov2_vitb14_random [image_feats.npy]
 - Dinov2-VIT14 (LVD-142M): fm_ssl_run\pretrained_feats2\dinov2_vitb14 [image_feats.npy]


  warn(


[INFO] UMAP Dinov2-VIT14 (Random): n=3938 → coords computed once
[INFO] Dinov2-VIT14 (Random): removing 303 outliers (7.7%) from plot
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\umap_Dinov2-VIT14 (Random)_organism.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\umap_Dinov2-VIT14 (Random)_polarity.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\umap_Dinov2-VIT14 (Random)_Organism_Part.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\umap_Dinov2-VIT14 (Random)_Condition.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\umap_Dinov2-VIT14 (Random)_analyzerType.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\umap_Dinov2-VIT14 (Random)_ionisationSource.png


  warn(


[INFO] UMAP Dinov2-VIT14 (LVD-142M): n=3938 → coords computed once
[INFO] Dinov2-VIT14 (LVD-142M): removing 287 outliers (7.3%) from plot
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\umap_Dinov2-VIT14 (LVD-142M)_organism.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\umap_Dinov2-VIT14 (LVD-142M)_polarity.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\umap_Dinov2-VIT14 (LVD-142M)_Organism_Part.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\umap_Dinov2-VIT14 (LVD-142M)_Condition.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\umap_Dinov2-VIT14 (LVD-142M)_analyzerType.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\umap_Dinov2-VIT14 (LVD-142M)_ionisationSource.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\pie_organism.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\pie_polarity.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\pie_Organism_Part.png
[OK] saved: fm_ssl_run\baseline_eval_combined2\umaps\pie_Co