### Gather mouse and human datasets

In [23]:
from metaspace import SMInstance
import time

# Organism labels accepted as an exact match of resolve_organism(ds) in the full-listing pass.
TARGET_ORGS = {"Homo sapiens", "Mus musculus"}
# Seconds to pause after each per-organism query in iter_datasets_mouse_human.
SLEEP_BETWEEN   = 0.15 

def _from_metadata(ds):
    md = getattr(ds, "metadata", None) or {}
    try:
        for key in ("organism", "Organism", "sample_organism", "Sample organism"):
            if key in md and md[key]:
                return str(md[key])
        for sec in ("Sample_Information", "Sample Information", "sample"):
            if isinstance(md.get(sec), dict):
                for key in ("organism", "Organism"):
                    val = md[sec].get(key)
                    if val:
                        return str(val)
    except Exception:
        pass
    return None

def resolve_organism(ds):
    """Best-effort organism label: `ds.organism`, else uploaded metadata, else "-"."""
    return getattr(ds, "organism", None) or _from_metadata(ds) or "-"

def is_mouse_or_human_texty(ds):
    """Heuristic text match: True if organism, name, description or projects mention
    human/mouse (common or Latin name), case-insensitively, as a substring."""
    fields = (
        resolve_organism(ds),
        getattr(ds, "name", ""),
        getattr(ds, "description", ""),
        getattr(ds, "projects", ""),
    )
    haystack = " ".join(str(field) for field in fields).lower()
    needles = ("human", "homo sapiens", "mouse", "mus musculus")
    return any(needle in haystack for needle in needles)

def iter_datasets_mouse_human(sm):
    """Yield each mouse/human METASPACE dataset at most once (deduplicated by `ds.id`).

    Pass 1 queries the server per organism (pausing SLEEP_BETWEEN seconds after each).
    Pass 2 walks the full dataset listing and keeps datasets whose resolved organism is in
    TARGET_ORGS or whose free text looks mouse/human. A failure in either pass is printed
    as a warning and ends only that pass.
    """
    yielded_ids = set()

    for organism in ("Mus musculus", "Homo sapiens"):
        try:
            for ds in sm.datasets(organism=organism):
                if ds.id in yielded_ids:
                    continue
                yielded_ids.add(ds.id)
                yield ds
        except Exception as e:
            print(f"[WARN] iter(organism={organism}) failed: {e}")
        time.sleep(SLEEP_BETWEEN)

    try:
        for ds in sm.datasets():
            if ds.id in yielded_ids:
                continue
            matches = resolve_organism(ds) in TARGET_ORGS or is_mouse_or_human_texty(ds)
            if not matches:
                continue
            yielded_ids.add(ds.id)
            yield ds
    except Exception as e:
        print(f"[WARN] iter(all) failed: {e}")

sm = SMInstance()

def count_mouse_human(sm):
    """Return how many distinct dataset ids iter_datasets_mouse_human(sm) yields.

    Consumes the whole iterator (one query per organism plus a walk of the full listing).
    """
    distinct_ids = {ds.id for ds in iter_datasets_mouse_human(sm)}
    return len(distinct_ids)

n_datasets = count_mouse_human(sm)
print(f"Total mouse/human datasets available: {n_datasets}")

Total mouse/human datasets available: 8854


In [None]:
# metaspace_ion_images_dump_fm.py
# ------------------------------------------------------------
# Fetch mouse/human datasets from METASPACE and save ion (isotope) images
# as arrays (uint16 .npy / .npz) plus an FM-ready manifest (per tile).
#
# Upgrades vs. original:
#  - Deterministic tiling with blank-tile filtering
#  - Optional float32 tile saves alongside uint16
#  - Per-tile manifest rows with QC stats and a stable train/val/test split
#  - Keeps per-annotation CSV + dataset metadata
#
# Notes:
#  - For a true FM, also build a parallel imzML->cube pipeline. This script
#    is ideal for weakly-labeled 2D supervision at scale.

from metaspace import SMInstance
from pathlib import Path
import os, json, time, traceback, hashlib
import numpy as np
import pandas as pd
from tqdm import tqdm

# ---------------- CONFIG ----------------
OUT_ROOT        = Path("metaspace_images_dump")
DB              = ("HMDB", "v4")
FDR_MAX         = 0.10
TOP_K_ANN       = 48          # cap per-dataset so a few huge sets don’t dominate
MAX_ISOTOPES    = 3           # M, M+1, M+2 are usually enough
MAX_DATASETS    = 5000
SAVE_FORMAT     = "npz"       # compressed archive (np.savez_compressed) instead of raw .npy
SLEEP_BETWEEN   = 0.15        # be polite; lower if you parallelize clients

# --- Tiling & QC ---
DO_TILING       = True
TILE_SIZE       = 256         # great balance of context vs count
TILE_STRIDE     = 256         # start with no-overlap to keep volume sane
MIN_NNZ_PCT     = 3.0         # % of tile pixels that must be > 0 *after* uint16 normalization
                              # (values at/below the 1st percentile become 0), else tile is dropped
SAVE_FLOAT_TILES= False       # save only uint16; floats can be approximated from norm_lo/norm_hi
                              # (values outside that range were clipped)
USE_WHOLE_ROWS  = False       # train on tiles; whole images for QA only

# --- Splits (grouped by dataset to avoid leakage) ---
TRAIN_FRAC      = 0.85
VAL_FRAC        = 0.10        # test ends up ~0.05

# ----------------------------------------

OUT_ROOT.mkdir(parents=True, exist_ok=True)

# ------------- AUTH / LOGIN (optional) -------------
sm = SMInstance()
API_KEY = os.getenv("METASPACE_API_KEY", "").strip()
if API_KEY and hasattr(sm, "login"):
    try:
        # Pass the key by keyword: SMInstance.login's first positional parameter is the
        # account email, so sm.login(API_KEY) does not authenticate with an API key.
        sm.login(api_key=API_KEY)
        print("[AUTH] Logged in via API key")
    except Exception as e:
        print(f"[WARN] sm.login failed: {e}  (you can run sm.save_login() once)")
else:
    print("[AUTH] Proceeding with current access (public or saved login).")

# --------- DATASET DISCOVERY ----------
TARGET_ORGS = {"Homo sapiens", "Mus musculus"}

def _from_metadata(ds):
    md = getattr(ds, "metadata", None) or {}
    try:
        for key in ("organism", "Organism", "sample_organism", "Sample organism"):
            if key in md and md[key]:
                return str(md[key])
        for sec in ("Sample_Information", "Sample Information", "sample"):
            if isinstance(md.get(sec), dict):
                for key in ("organism", "Organism"):
                    val = md[sec].get(key)
                    if val:
                        return str(val)
    except Exception:
        pass
    return None

def resolve_organism(ds):
    """Best-effort organism label: `ds.organism`, else uploaded metadata, else "-"."""
    return getattr(ds, "organism", None) or _from_metadata(ds) or "-"

def is_mouse_or_human_texty(ds):
    """Heuristic text match: True if organism, name, description or projects mention
    human/mouse (common or Latin name), case-insensitively, as a substring."""
    fields = (
        resolve_organism(ds),
        getattr(ds, "name", ""),
        getattr(ds, "description", ""),
        getattr(ds, "projects", ""),
    )
    haystack = " ".join(str(field) for field in fields).lower()
    needles = ("human", "homo sapiens", "mouse", "mus musculus")
    return any(needle in haystack for needle in needles)

def iter_datasets_mouse_human(sm):
    """Yield each mouse/human METASPACE dataset at most once (deduplicated by `ds.id`).

    Pass 1 queries the server per organism (pausing SLEEP_BETWEEN seconds after each).
    Pass 2 walks the full dataset listing and keeps datasets whose resolved organism is in
    TARGET_ORGS or whose free text looks mouse/human. A failure in either pass is printed
    as a warning and ends only that pass.
    """
    yielded_ids = set()

    for organism in ("Mus musculus", "Homo sapiens"):
        try:
            for ds in sm.datasets(organism=organism):
                if ds.id in yielded_ids:
                    continue
                yielded_ids.add(ds.id)
                yield ds
        except Exception as e:
            print(f"[WARN] iter(organism={organism}) failed: {e}")
        time.sleep(SLEEP_BETWEEN)

    try:
        for ds in sm.datasets():
            if ds.id in yielded_ids:
                continue
            matches = resolve_organism(ds) in TARGET_ORGS or is_mouse_or_human_texty(ds)
            if not matches:
                continue
            yielded_ids.add(ds.id)
            yield ds
    except Exception as e:
        print(f"[WARN] iter(all) failed: {e}")

# --------- HELPERS ----------
def safe_results(ds, db):
    """Fetch `ds.results(database=db)`, or None if it is missing, empty, or the call fails.

    If the frame is not already MultiIndexed and has "formula"/"adduct" columns, its index
    is replaced (in place) by a (formula, adduct) MultiIndex so rows can be keyed by pair.
    """
    try:
        table = ds.results(database=db)
        if table is None or len(table) == 0:
            return None
        already_indexed = isinstance(table.index, pd.MultiIndex)
        if not already_indexed and {"formula", "adduct"} <= set(table.columns):
            table.index = pd.MultiIndex.from_frame(table[["formula", "adduct"]])
        return table
    except Exception as e:
        print(f"[WARN] results() failed for {getattr(ds, 'id', '?')}: {e}")
        return None

def to_uint16_robust(img: np.ndarray):
    """Percentile-based normalization -> uint16 for compact storage.

    The 1st/99th percentiles of the *finite* pixels map to 0 and 65535; values outside that
    range are clipped. If those percentiles coincide, min/max of the finite pixels are used
    instead (a constant image therefore maps to all zeros). NaN/inf pixels are mapped to 0;
    previously they made the percentiles NaN and were cast to uint16 from NaN (undefined).

    Returns (uint16_img, (lo, hi)) with lo/hi as Python floats. Empty or entirely
    non-finite input yields an all-zero image and (0.0, 1.0).
    """
    img = np.asarray(img, dtype=np.float32)
    if img.size == 0:
        return np.zeros_like(img, dtype=np.uint16), (0.0, 1.0)

    finite = np.isfinite(img)
    if not finite.any():
        return np.zeros(img.shape, dtype=np.uint16), (0.0, 1.0)

    vals = img[finite]
    lo, hi = np.percentile(vals, [1, 99])
    if hi <= lo:
        lo = float(np.min(vals))
        hi = float(np.max(vals))
        if hi <= lo:
            hi = lo + 1e-6

    if not finite.all():
        # Non-finite pixels take the low anchor, i.e. they normalize to 0.
        img = np.where(finite, img, np.float32(lo))

    img_n = np.clip((img - lo) / (hi - lo), 0, 1)
    return (img_n * 65535).astype(np.uint16), (float(lo), float(hi))

def save_array(arr_uint16: np.ndarray, out_path: Path, meta: dict, fmt: str = "npy"):
    """Write an image array plus its metadata.

    fmt="npy": the array goes to `out_path` and `meta` to a pretty-printed sibling `.json`.
    fmt="npz": one compressed archive with members `image` (the array) and `meta`
    (the JSON text as UTF-8 bytes). Any other fmt raises ValueError. The parent directory
    is created first in every case.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    if fmt not in ("npy", "npz"):
        raise ValueError("SAVE_FORMAT must be 'npy' or 'npz'")

    if fmt == "npz":
        meta_bytes = json.dumps(meta).encode("utf-8")
        np.savez_compressed(out_path, image=arr_uint16, meta=meta_bytes)
        return

    np.save(out_path, arr_uint16)
    sidecar_path = out_path.with_suffix(".json")
    sidecar_path.write_text(json.dumps(meta, indent=2))

def stable_split(dataset_id: str, p_train=TRAIN_FRAC, p_val=VAL_FRAC):
    """Deterministically assign a dataset to "train", "val" or "test".

    The MD5 of the id is reduced to one of 10,000 buckets, giving a value u in [0, 1);
    u < p_train -> train, u < p_train + p_val -> val, otherwise test.
    """
    digest = hashlib.md5(dataset_id.encode("utf-8")).hexdigest()
    bucket = int(digest, 16) % 10_000
    u = bucket / 10_000.0
    if u < p_train:
        return "train"
    if u < p_train + p_val:
        return "val"
    return "test"

def tile_and_save(img_uint16: np.ndarray,
                  img_float32: np.ndarray,
                  base_out: Path,
                  tile=256,
                  stride=256,
                  min_nnz_pct=0.5,
                  save_float=True):
    """Cut `img_uint16` into `tile`x`tile` windows every `stride` pixels and save the non-blank ones.

    Only full windows are kept: a right/bottom remainder smaller than `tile` is dropped, and
    an image smaller than `tile` yields no tiles. A window is skipped when its percentage of
    non-zero uint16 pixels is below `min_nnz_pct`. Kept windows are written next to
    `base_out` as `<stem>_rRRRR_cCCCC.npy`, plus a `.float32.npy` twin cut from
    `img_float32` when `save_float` is truthy and `img_float32` is not None.

    Returns one dict per saved tile (coords, size, nnz_pct, and file paths as str).
    """
    n_rows, n_cols = img_uint16.shape[:2]
    out_dir = base_out.parent
    out_dir.mkdir(parents=True, exist_ok=True)

    window_area = float(tile * tile)
    want_float = save_float and img_float32 is not None
    records = []

    row_starts = range(0, max(1, n_rows - tile + 1), stride)
    col_starts = range(0, max(1, n_cols - tile + 1), stride)
    for r0 in row_starts:
        for c0 in col_starts:
            window = (slice(r0, r0 + tile), slice(c0, c0 + tile))
            patch = img_uint16[window]
            if patch.shape[:2] != (tile, tile):
                continue

            filled_pct = 100.0 * (np.count_nonzero(patch) / window_area)
            if filled_pct < min_nnz_pct:
                continue

            stem = f"{base_out.stem}_r{r0:04d}_c{c0:04d}"
            u16_path = out_dir / f"{stem}.npy"
            np.save(u16_path, patch)

            f32_path = None
            if want_float:
                f32_path = out_dir / f"{stem}.float32.npy"
                np.save(f32_path, img_float32[window].astype(np.float32, copy=False))

            records.append({
                "tile_r": r0, "tile_c": c0, "tile_h": tile, "tile_w": tile,
                "nnz_pct": filled_pct, "path_u16": str(u16_path),
                "path_f32": (str(f32_path) if f32_path else None),
            })
    return records

import json, math
import numpy as np

def _jsonify(obj):
    # Make numpy / sets / NaNs JSON-friendly
    if isinstance(obj, (np.integer,)):  return int(obj)
    if isinstance(obj, (np.floating,)): return float(obj)
    if isinstance(obj, (np.ndarray,)):  return obj.tolist()
    if isinstance(obj, set):            return list(obj)
    if isinstance(obj, dict):           return {str(k): _jsonify(v) for k,v in obj.items()}
    if isinstance(obj, (list, tuple)):  return [_jsonify(v) for v in obj]
    if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)): return None
    return obj

def collect_full_metadata(ds, db=("HMDB","v4")):
    """Gather a JSON-oriented snapshot of one dataset.

    Includes public attributes (missing ones become None), the uploaded metadata, the
    client's raw record when exposed as a dict in `_ds`, selected diagnostics, counts of
    annotations passing several FDR thresholds, and the shape of one probe isotope image.
    Every optional part is best-effort: a failure leaves that key out
    (`metadata_uploaded` falls back to {}).
    """
    attr_for_key = (
        ("dataset_id", "id"),
        ("name", "name"),
        ("description", "description"),
        ("organism", "organism"),
        ("polarity", "polarity"),
        ("analyzerType", "analyzerType"),
        ("ionisationSource", "ionisationSource"),
        ("maldiMatrix", "maldiMatrix"),
        ("submitter", "submitter"),            # may be dict-like or str depending on client
        ("principalInvestigator", "principalInvestigator"),
        ("group", "group"),
        ("projects", "projects"),
        ("license", "license"),
        ("status", "status"),
    )
    snapshot = {key: getattr(ds, attr, None) for key, attr in attr_for_key}
    snapshot["uploaded_dt"] = getattr(ds, "uploadDT", None) or getattr(ds, "uploadedDT", None)
    snapshot["db_used"] = db

    # Free-form metadata JSON attached by the uploader.
    try:
        snapshot["metadata_uploaded"] = _jsonify(getattr(ds, "metadata", {}) or {})
    except Exception:
        snapshot["metadata_uploaded"] = {}

    # Raw GraphQL record; `_ds` is private to the client and may change across versions.
    try:
        raw_record = getattr(ds, "_ds", None)
        if isinstance(raw_record, dict):
            snapshot["graphql_record"] = _jsonify(raw_record)
    except Exception:
        pass

    diagnostics = {}
    for diag_type in ("IMZML_METADATA", "DATASET_SUMMARY", "ANNOTATION_COUNTS", "OFF_SAMPLE_MASK"):
        try:
            payload = ds.diagnostic(diag_type)
            if payload is not None:
                diagnostics[diag_type] = _jsonify(payload)
        except Exception:
            continue
    if diagnostics:
        snapshot["diagnostics"] = diagnostics

    # How many annotations pass each FDR threshold.
    try:
        res = ds.results(database=db)
        if res is not None and "fdr" in res.columns:
            snapshot["results_summary"] = {
                f"n_fdr_le_{thr}": int((res["fdr"] <= thr).sum())
                for thr in (0.01, 0.05, 0.1, 0.2, 0.5)
            }
    except Exception:
        pass

    # Spatial size hint from one isotope image of a probe annotation (often absent).
    try:
        probe = ds.isotope_images("C6H12O6", "+Na")
        if probe and len(probe) > 0:
            snapshot["spatial_shape_probe"] = list(np.asarray(probe[0]).shape)
    except Exception:
        pass

    return snapshot

def save_full_metadata(ds, out_dir: Path, db=("HMDB","v4")):
    """Write collect_full_metadata(ds, db) as pretty-printed JSON to out_dir/metadata_full.json.

    `out_dir` must already exist (process_one_dataset creates it before calling this).
    """
    info = collect_full_metadata(ds, db=db)
    # default=str: objects that the JSON conversion passes through unchanged (e.g. whatever
    # ds.diagnostic() returns, or client-specific attribute types) would otherwise make
    # json.dumps raise TypeError. In process_one_dataset that happens after the dataset
    # folder is created, leaving behind a folder that later runs skip as "folder exists".
    (out_dir / "metadata_full.json").write_text(json.dumps(_jsonify(info), indent=2, default=str))

# --------- CORE: per-dataset processing ----------
def _rows_marker(ds_dir: Path) -> Path:
    """Path of the per-dataset rows file; its presence marks the dataset as fully processed."""
    return ds_dir / "manifest_rows.json"

def _write_rows_marker(ds_dir: Path, rows: list) -> list:
    """Persist this dataset's manifest rows as JSON for later runs; returns `rows` unchanged."""
    # NaN floats are written as the JSON token NaN, which json.loads reads back as float('nan');
    # default=str keeps the write from failing on value types json cannot encode.
    _rows_marker(ds_dir).write_text(json.dumps(rows, default=str))
    return rows

def _load_rows_marker(ds_dir: Path) -> list:
    """Load rows written by _write_rows_marker, restoring the tuple form of the 'db' field."""
    rows = json.loads(_rows_marker(ds_dir).read_text())
    for row in rows:
        if isinstance(row.get("db"), list):
            row["db"] = tuple(row["db"])
    return rows

def process_one_dataset(ds, reprocess_incomplete: bool = False) -> list:
    """Download, normalize, save and tile the top annotations of one dataset.

    Returns a list of manifest rows: one per kept tile, plus one whole-image row per isotope
    image when USE_WHOLE_ROWS is set, tiling is off, or no tile survived the blank filter.

    When processing finishes (including the "no usable results" and "nothing passes FDR"
    outcomes), the rows are also written to OUT_ROOT/<id>/manifest_rows.json. A later call
    that finds that file returns the stored rows instead of an empty list. Previously,
    re-running main() wrote a manifest containing only the datasets processed in that run.

    Args:
        ds: METASPACE dataset object; `ds.id` is required, other attributes are best-effort.
        reprocess_incomplete: when True, an existing dataset folder without a readable
            manifest_rows.json (older runs, or a run that crashed mid-dataset) is processed
            again. The default False keeps the old "folder exists -> skip, no rows" behavior.
    """
    rows = []
    ds_id = getattr(ds, "id", None)
    if not ds_id:
        return rows

    ds_dir = OUT_ROOT / ds_id

    if ds_dir.exists():
        marker = _rows_marker(ds_dir)
        if marker.exists():
            try:
                cached_rows = _load_rows_marker(ds_dir)
                print(f"[SKIP] {ds_id}: folder exists (reusing {len(cached_rows)} manifest rows)")
                return cached_rows
            except Exception as e:
                print(f"[WARN] {ds_id}: could not read {marker.name} -> {e}")
        if not reprocess_incomplete:
            print(f"[SKIP] {ds_id}: folder exists")
            return []
        print(f"[REDO] {ds_id}: folder exists but has no completion marker; reprocessing")

    ds_dir.mkdir(parents=True, exist_ok=True)

    save_full_metadata(ds, ds_dir, db=DB)

    org = resolve_organism(ds)
    nm  = getattr(ds, "name", None)
    split = stable_split(ds_id)

    print(f"[INFO] Processing {ds_id} | org={org} | name={nm or '-'} | split={split}")

    # Save basic metadata (dataset-level) together with the download settings used
    meta = {
        "dataset_id": ds_id, "name": nm, "organism": org,
        "polarity": getattr(ds, "polarity", None),
        "analyzerType": getattr(ds, "analyzerType", None),
        "ionisationSource": getattr(ds, "ionisationSource", None),
        "db": DB, "fdr_max": FDR_MAX, "top_k_ann": TOP_K_ANN,
        "max_isotopes": MAX_ISOTOPES,
        "save_format": SAVE_FORMAT, "tiling": DO_TILING,
        "tile_size": TILE_SIZE, "tile_stride": TILE_STRIDE,
    }
    (ds_dir / "metadata.json").write_text(json.dumps(meta, indent=2))

    # Get annotations. "No results" and "nothing passes FDR" are final outcomes, so they
    # are recorded as complete (with zero rows) instead of being retried.
    res = safe_results(ds, DB)
    if res is None or "fdr" not in res.columns:
        print(f"[SKIP] {ds_id}: no usable results/FDR")
        return _write_rows_marker(ds_dir, rows)

    keep = res[res["fdr"] <= FDR_MAX].copy()
    if len(keep) == 0:
        print(f"[SKIP] {ds_id}: 0 annotations pass FDR ≤ {FDR_MAX}")
        return _write_rows_marker(ds_dir, rows)

    # Sort and cap: highest MSM first when available, otherwise lowest FDR first
    if TOP_K_ANN is not None and len(keep) > TOP_K_ANN:
        if "msm" in keep.columns:
            keep = keep.sort_values("msm", ascending=False).head(TOP_K_ANN)
        else:
            keep = keep.sort_values("fdr", ascending=True).head(TOP_K_ANN)

    # Save a copy of kept annotations (useful later)
    keep.to_csv(ds_dir / "annotations_kept.csv")

    # Iterate annotations. Scores are read from the iterated row itself: keep.loc[(sf, adduct)]
    # returns a DataFrame (and float() then fails) if a (formula, adduct) pair appears twice.
    # Assumes a 2-level (formula, adduct) index, which safe_results tries to ensure.
    for (sf, adduct), ann in tqdm(list(keep.iterrows()), desc=f"{ds_id}: pulling images", leave=False):
        ann_fdr = float(ann.get("fdr", np.nan))
        ann_msm = float(ann.get("msm", np.nan))

        try:
            images = ds.isotope_images(sf, adduct)  # list-like of 2D arrays
        except Exception as e:
            print(f"[WARN] {ds_id} {sf} {adduct}: isotope_images failed -> {e}")
            continue

        n = min(len(images), MAX_ISOTOPES)
        ann_dir = ds_dir / "images" / f"{sf}_{adduct}".replace("/", "_")

        for j in range(n):
            try:
                im = np.asarray(images[j], dtype=np.float32)
                arr16, (lo, hi) = to_uint16_robust(im)
                try:
                    peak_mz = float(images.peak(index=j))
                except Exception:
                    peak_mz = float("nan")

                # Whole-image save (uint16)
                base = ann_dir / f"peak{j}"
                out_path = base.with_suffix(".npy") if SAVE_FORMAT == "npy" else base.with_suffix(".npz")
                img_meta = {
                    "dataset_id": ds_id,
                    "name": nm,
                    "organism": org,
                    "split": split,
                    "db": DB,
                    "fdr": ann_fdr,
                    "msm": ann_msm,
                    "sum_formula": sf,
                    "adduct": adduct,
                    "isotope_index": j,
                    "peak_mz": peak_mz,
                    "shape": list(arr16.shape),
                    "dtype": "uint16",
                    "norm_lo": lo,
                    "norm_hi": hi,
                    "polarity": getattr(ds, "polarity", None),
                    "analyzerType": getattr(ds, "analyzerType", None),
                    "ionisationSource": getattr(ds, "ionisationSource", None),
                }

                save_array(arr16, out_path, img_meta, fmt=SAVE_FORMAT)

                # QC for whole image (optional fields)
                whole_mean = float(np.mean(arr16)) if arr16.size else 0.0
                whole_std  = float(np.std(arr16)) if arr16.size else 0.0
                nnz_pct_whole = 100.0 * (np.count_nonzero(arr16) / float(arr16.size)) if arr16.size else 0.0

                # Tiling for FM; "n_tiles" is 1 on tile rows and 0 on whole-image rows
                tile_count = 0
                if DO_TILING:
                    tiles = tile_and_save(
                        arr16, im, out_path,
                        tile=TILE_SIZE,
                        stride=TILE_STRIDE,
                        min_nnz_pct=MIN_NNZ_PCT,
                        save_float=SAVE_FLOAT_TILES
                    )
                    for t in tiles:
                        rows.append({
                            **img_meta,
                            "path": t["path_u16"],
                            "path_float32": t["path_f32"],
                            "tile_r": t["tile_r"], "tile_c": t["tile_c"],
                            "tile_h": t["tile_h"], "tile_w": t["tile_w"],
                            "nnz_pct": t["nnz_pct"],
                            "whole_mean_u16": whole_mean,
                            "whole_std_u16": whole_std,
                            "whole_nnz_pct": nnz_pct_whole,
                            "n_tiles": 1
                        })
                    tile_count = len(tiles)

                # Whole-image row: on request, or when no tile represents this image
                if USE_WHOLE_ROWS or (not DO_TILING or tile_count == 0):
                    rows.append({
                        **img_meta,
                        "path": str(out_path),
                        "path_float32": None,
                        "tile_r": None, "tile_c": None,
                        "tile_h": arr16.shape[0], "tile_w": arr16.shape[1],
                        "nnz_pct": nnz_pct_whole,
                        "whole_mean_u16": whole_mean,
                        "whole_std_u16": whole_std,
                        "whole_nnz_pct": nnz_pct_whole,
                        "n_tiles": 0
                    })

            except Exception as e:
                print(f"[WARN] {ds_id} {sf} {adduct} peak{j}: save failed -> {e}")
                continue

    # Every row built here belongs to this dataset, so the count is simply len(rows).
    print(f"[DONE] {ds_id}: wrote {len(rows)} manifest rows")
    return _write_rows_marker(ds_dir, rows)

# ---------------- MAIN -----------------
def main():
    """Discover mouse/human datasets, process up to MAX_DATASETS distinct ones, and write
    the combined manifest to OUT_ROOT as manifest.parquet and manifest.csv.

    A dataset that raises is reported (with traceback) and skipped.
    """
    manifest_rows = []
    processed_ids = set()

    for ds in iter_datasets_mouse_human(sm):
        dsid = getattr(ds, "id", None)
        if not dsid or dsid in processed_ids:
            continue
        processed_ids.add(dsid)

        if MAX_DATASETS is not None and len(processed_ids) > MAX_DATASETS:
            print(f"[STOP] Reached MAX_DATASETS={MAX_DATASETS}")
            break

        try:
            manifest_rows.extend(process_one_dataset(ds))
        except Exception as e:
            print(f"[WARN] {dsid}: processing crashed -> {e}")
            traceback.print_exc()

    if not manifest_rows:
        print("\n[SUMMARY] No images saved")
        return

    man_df = pd.DataFrame(manifest_rows)

    # Put the most useful columns first; anything else keeps its original order after them.
    preferred = [
        "dataset_id","name","organism","split",
        "db","fdr","msm","sum_formula","adduct","isotope_index","peak_mz",
        "path","path_float32",
        "tile_r","tile_c","tile_h","tile_w","nnz_pct",
        "whole_mean_u16","whole_std_u16","whole_nnz_pct",
        "polarity","analyzerType","ionisationSource",
        "dtype","shape","norm_lo","norm_hi","n_tiles"
    ]
    leading = [c for c in preferred if c in man_df.columns]
    trailing = [c for c in man_df.columns if c not in preferred]
    man_df = man_df[leading + trailing]

    OUT_ROOT.mkdir(parents=True, exist_ok=True)
    man_df.to_parquet(OUT_ROOT / "manifest.parquet", index=False)
    man_df.to_csv(OUT_ROOT / "manifest.csv", index=False)

    n_ds = man_df['dataset_id'].nunique()
    n_tiles = (man_df['n_tiles'] == 1).sum()
    print(f"\n[SUMMARY] Wrote {len(man_df)} manifest rows "
          f"({n_tiles} tiles) across {n_ds} datasets")

# In a Jupyter/IPython cell __name__ is "__main__", so executing this cell starts the
# full download immediately (not only when run as a script).
if __name__ == "__main__":
    main()

### Expand manifest with metadata

In [None]:
# --- Enrich manifest using metadata_full.json files (adds/updates analyzerType & ionisationSource) ---
import os, re, json
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm

# ---------------- CONFIG ----------------
# Manifest written by the download cell. The METASPACE_MANIFEST environment variable
# overrides the author's default path so Restart & Run All works on other machines
# (an empty value falls back to the default).
MANIFEST_PATH = Path(
    os.environ.get("METASPACE_MANIFEST")
    or r"Y:\coskun-lab\Efe\MSI Foundation Model\metaspace_images_dump\manifest.parquet"
)
SIDECAR_ROOT = MANIFEST_PATH.parent   # base folder that contains <dataset_id>/metadata_full.json

# For Organism_Part / Condition we read from Sample_Information;
# For analyzerType / ionisationSource we read from MS_Analysis (robust to key variants).
FIELD_MAP = {
    "Organism_Part":    ("metadata_uploaded", "Sample_Information", "Organism_Part"),
    "Condition":        ("metadata_uploaded", "Sample_Information", "Condition"),
    "analyzerType":     ("metadata_uploaded", "MS_Analysis", ("Analyzer","Analyser","Analyzer Type","Analyzer_Type","analyzerType")),
    "ionisationSource": ("metadata_uploaded", "MS_Analysis", ("Ionisation_Source","Ionization_Source","Ion Source","Ion_Source","ionisationSource","ionization source")),
}
# ----------------------------------------

# --- Robust, case/format-insensitive nested getter ---
def _norm_key(s: str) -> str:
    s = s.lower()
    s = s.replace("ionization", "ionisation")  # US→UK normalization
    s = s.replace("analyser", "analyzer")      # UK→US normalization
    return re.sub(r"[^a-z0-9]+", "", s)        # strip spaces/_/-

def _ci_get(d: dict, candidates, default=np.nan):
    """Case/format-insensitive key fetch from dict using list/tuple of candidate keys."""
    if not isinstance(d, dict) or not d:
        return default
    lut = {_norm_key(k): v for k, v in d.items()}
    for c in (candidates if isinstance(candidates, (list, tuple)) else [candidates]):
        v = lut.get(_norm_key(str(c)))
        if v is not None:
            return v
    return default

def _get_nested_ci(d, keys, default=np.nan):
    """
    Keys can be strings or tuples/lists of variants.
    Also tolerant to section-name variants:
      - 'metadata_uploaded' / 'metadata uploaded'
      - 'Sample_Information' / 'Sample Information' / 'Sample'
      - 'MS_Analysis' / 'MS Analysis' / 'MS'
    """
    cur = d
    section_variants = {
        "metadata_uploaded": ("metadata_uploaded", "metadata uploaded"),
        "Sample_Information": ("Sample_Information", "Sample Information", "Sample"),
        "MS_Analysis": ("MS_Analysis", "MS Analysis", "MS"),
    }
    for k in keys:
        if not isinstance(cur, dict):
            return default
        if isinstance(k, (list, tuple)):
            cur = _ci_get(cur, list(k), default)
        else:
            cand = section_variants.get(k, (k,))
            cur = _ci_get(cur, cand, default)
        if cur is default:
            return default
    return cur

def load_sidecar_metadata(dataset_id: str) -> dict:
    """Read SIDECAR_ROOT/<dataset_id>/metadata_full.json.

    Returns the parsed JSON, or {} if the file is missing or cannot be read/parsed
    (read failures are printed as a warning).
    """
    json_path = SIDECAR_ROOT / dataset_id / "metadata_full.json"
    if not json_path.exists():
        return {}
    try:
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        print(f"[warn] Failed to read {json_path}: {e}")
    return {}

# ---- Load manifest ----
manifest_suffix = MANIFEST_PATH.suffix
if manifest_suffix not in (".parquet", ".csv"):
    raise ValueError(f"Unsupported manifest format: {MANIFEST_PATH.suffix}")
if manifest_suffix == ".parquet":
    df = pd.read_parquet(MANIFEST_PATH, engine="pyarrow")
else:
    df = pd.read_csv(MANIFEST_PATH)

print(f"[info] manifest loaded with {len(df)} rows, {df['dataset_id'].nunique()} datasets")

# Columns present before enrichment; decides "fill existing" vs "add new" below
pre_cols = set(df.columns)

# ---- Collect FIELD_MAP values from each dataset's sidecar ----
meta_records = {}
for dsid in tqdm(df["dataset_id"].unique(), desc="Loading metadata"):
    sidecar = load_sidecar_metadata(str(dsid))
    if sidecar:
        meta_records[dsid] = {
            out_col: _get_nested_ci(sidecar, key_path, default=np.nan)
            for out_col, key_path in FIELD_MAP.items()
        }

if meta_records:
    meta_df = (
        pd.DataFrame.from_dict(meta_records, orient="index")
        .reset_index()
        .rename(columns={"index": "dataset_id"})
    )
    # merge only suffixes names present on both sides, so only pre-existing
    # columns (e.g. analyzerType, ionisationSource) come back as "<col>_sidecar".
    df = df.merge(meta_df, on="dataset_id", how="left", suffixes=("", "_sidecar"))
    for col in FIELD_MAP:
        sidecar_col = f"{col}_sidecar"
        if sidecar_col not in df.columns:
            continue
        if col in pre_cols:
            # keep manifest values; fill only the gaps from the sidecar
            df[col] = df[col].where(df[col].notna(), df[sidecar_col])
            df = df.drop(columns=[sidecar_col])
        else:
            df = df.rename(columns={sidecar_col: col})
else:
    print("[warn] No sidecar metadata found; nothing to enrich.")

# ---- Save enriched manifest ----
OUT_PATH = MANIFEST_PATH.with_name(MANIFEST_PATH.stem + "_expanded").with_suffix(".parquet")
df.to_parquet(OUT_PATH, index=False, engine="pyarrow")

# Report what changed/added
added_now = [c for c in FIELD_MAP if c in df.columns and c not in pre_cols]
print(f"[OK] Enriched manifest saved → {OUT_PATH}")
print("Columns added (new):", added_now)
print("Rows:", len(df), "| Datasets:", df["dataset_id"].nunique())

# Quick peek
cols_to_show = ["dataset_id"] + [c for c in FIELD_MAP if c in df.columns]
display_cols = [c for c in cols_to_show if c in df.columns]
df.head(3)[display_cols]

[info] manifest loaded with 384403 rows, 3774 datasets


Loading metadata: 100%|██████████| 3774/3774 [01:58<00:00, 31.94it/s]


[OK] Enriched manifest saved → Y:\coskun-lab\Efe\MSI Foundation Model\metaspace_images_dump\manifest_expanded.parquet
Columns added (new): ['Organism_Part', 'Condition']
Rows: 384403 | Datasets: 3774


Unnamed: 0,dataset_id,Organism_Part,Condition,analyzerType,ionisationSource
0,2025-08-27_01h37m31s,Kidney,biopsy,timsTOF fleX,MALDI
1,2025-08-27_01h37m31s,Kidney,biopsy,timsTOF fleX,MALDI
2,2025-08-27_01h37m31s,Kidney,biopsy,timsTOF fleX,MALDI


In [None]:
# assemble_kronos_samples_npz_aware.py
# Build (patch, mz) .npz samples for KRONOS/MSI-FM pretrain from your tiling manifest.
# - Single process, no pin memory / multiprocessing
# - Supports .npy and .npz: picks first 2D array or squeezes (1,H,W)->(H,W)
# - Ignores `n_tiles`; uses tile coords if present; optional dataset-level fallback
# - Vectorized ranking (fdr asc, msm desc), mmap reads, pre-allocation

from pathlib import Path
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import Counter
import os

# ====================== CONFIG ======================
# NOTE(review): this reads manifest_expanded.parquet relative to the working directory,
# while the enrichment cell writes it next to its MANIFEST_PATH; confirm both point to
# the same folder.
DUMP_ROOT   = Path("metaspace_images_dump")
SRC_MANIFEST= DUMP_ROOT / "manifest_expanded.parquet"    # or .csv
OUT_SAMPLES = DUMP_ROOT / "msi_fm_samples"
OUT_LIST    = DUMP_ROOT / "msi_fm_samples.parquet"

C_TARGET = 64                     # max channels per sample
MIN_CHANNELS_PER_SAMPLE = 8       # skip groups with fewer channels
N_SAMPLES_PER_DATASET_CAP = None  # e.g. 200; None disables

USE_UINT16_INPUT = True           # cast channels to uint16 for compact storage
USE_COMPRESSION  = False          # np.savez (False) is faster; use True for smaller files
WRITE_CSV_LIST   = True           # also write CSV index

# Rows without tile coords:
PROCESS_UNTILED = True            # try dataset-level fallback groups
UNTILED_PATH_PATTERN = None       # e.g. "peak" to include only certain files in fallback
UNTILED_ACCEPT_3D_SQUEEZE = True  # accept (1,H,W)->(H,W) by squeeze
# ====================================================

OUT_SAMPLES.mkdir(parents=True, exist_ok=True)

# ---------- Load manifest ----------
# Any suffix other than .parquet is read as CSV.
if SRC_MANIFEST.suffix == ".parquet":
    df = pd.read_parquet(SRC_MANIFEST, engine="pyarrow")
else:
    df = pd.read_csv(SRC_MANIFEST)

# ---------- Required columns ----------
if "dataset_id" not in df.columns or "path" not in df.columns:
    raise ValueError("Manifest must contain at least ['dataset_id','path'].")

# ---------- Vectorized ranking: lower fdr better; higher msm better ----------
# Non-finite fdr is ranked as fdr = 1.0; a missing column ranks every row equally.
if "fdr" in df.columns:
    fdr = df["fdr"].to_numpy()
    fdr = np.where(np.isfinite(fdr), fdr, 1.0)
else:
    fdr = np.ones(len(df), dtype=np.float32)

# msm is negated so an ascending sort puts high msm first; non-finite msm becomes 0.0,
# the same rank as msm = 0.
if "msm" in df.columns:
    msm = df["msm"].to_numpy()
    msm = np.where(np.isfinite(msm), -msm, 0.0)  # negative so asc == original desc
else:
    msm = np.zeros(len(df), dtype=np.float32)

# np.lexsort uses the LAST key as primary: fdr first, then -msm as tie-break.
order_key = np.lexsort((msm, fdr))  # stable global order
df = df.iloc[order_key].reset_index(drop=True)

has_mz = "peak_mz" in df.columns

# ---------- Tiled vs untiled split ----------
# A row counts as tiled only if all four tile fields are present. Whole-image rows from
# the download cell store tile_r/tile_c as None, so they land in untiled_df.
tile_cols = ["tile_r", "tile_c", "tile_h", "tile_w"]
tile_cols_exist = all(c in df.columns for c in tile_cols)
mask_complete = df[tile_cols].notna().all(axis=1) if tile_cols_exist else pd.Series(False, index=df.index)

tiled_df   = df[mask_complete].copy()
untiled_df = df[~mask_complete].copy()

# Optional filter for untiled paths
if PROCESS_UNTILED and UNTILED_PATH_PATTERN:
    untiled_df = untiled_df[untiled_df["path"].astype(str).str.contains(UNTILED_PATH_PATTERN, case=False, na=False)]

# ---------- NPZ/Numpy loader helpers ----------
def _pick_first_2d_from_npz(npz_file: np.lib.npyio.NpzFile):
    # Prefer common names, else first 2D
    preferred = ("tile", "image", "img", "data", "arr", "arr_0")
    keys = list(npz_file.keys())
    for k in preferred:
        if k in keys:
            a = npz_file[k]
            if getattr(a, "ndim", 0) == 2:
                return a
    for k in keys:
        a = npz_file[k]
        if getattr(a, "ndim", 0) == 2:
            return a
    return None

def load_2d_array(path: str):
    """
    Load one 2-D channel image from ``path``; return an (H, W) ndarray or None.

    - .npy: accepted when 2-D. A 3-D array with a singleton axis is squeezed to
      2-D when UNTILED_ACCEPT_3D_SQUEEZE is true. Opened with mmap_mode='r', so
      the result may be a read-only memmap.
    - .npz: the first 2-D member (see _pick_first_2d_from_npz). Otherwise, when
      UNTILED_ACCEPT_3D_SQUEEZE is true, the first 3-D member with a singleton axis
      that squeezes to 2-D. The archive is closed before returning.

    Errors while opening the file or reading an archive member yield None (or skip
    that member) instead of propagating, so one bad file cannot abort the run.
    """
    try:
        obj = np.load(path, mmap_mode='r', allow_pickle=False)
    except Exception:
        return None

    if isinstance(obj, np.lib.npyio.NpzFile):
        # obj[key] returns an in-memory array, so closing the archive (and its
        # file handle) on exit is safe for whatever we return.
        with obj:
            try:
                a2 = _pick_first_2d_from_npz(obj)
            except Exception:
                a2 = None  # unreadable member; fall through to the squeeze scan
            if a2 is not None:
                return a2
            if UNTILED_ACCEPT_3D_SQUEEZE:
                for k in obj.files:
                    try:
                        a = obj[k]
                    except Exception:
                        continue
                    if getattr(a, "ndim", 0) == 3 and 1 in a.shape:
                        a = np.squeeze(a)
                        if a.ndim == 2:
                            return a
        return None

    # Regular .npy
    a = obj
    if getattr(a, "ndim", 0) == 2:
        return a
    if UNTILED_ACCEPT_3D_SQUEEZE and getattr(a, "ndim", 0) == 3 and 1 in a.shape:
        a = np.squeeze(a)
        if a.ndim == 2:
            return a
    return None

# ---------- Processing ----------
skip_reasons = Counter()   # skip reason -> count (groups skipped, or channels dropped for shape)
records = []               # one dict per written sample; becomes the output index table
by_ds_counts = {}          # dataset_id -> number of samples written so far

def process_group(dsid, r, c, th, tw, grp):
    """
    Build one multi-channel sample from a group of single-channel manifest rows and save it.

    Parameters
    ----------
    dsid : dataset id, used in the output filename and in the index record.
    r, c : tile row / column used in the filename (the untiled loop passes 0, 0).
    th, tw : tile height / width, stored in the index record only.
    grp : manifest rows for this group, already in the global ranking order
        (fdr ascending, then msm descending).

    Steps:
    1. Take the top ``C_TARGET`` rows.
    2. Keep rows whose ``path`` exists and ends in .npy/.npz, together with each
       row's own m/z.
    3. Load each kept path as a 2-D image and keep images matching the shape of
       the first loadable one.
    4. If at least ``MIN_CHANNELS_PER_SAMPLE`` remain, write
       ``OUT_SAMPLES/{dsid}_r{r}_c{c}_C{C}.npz`` with ``patch`` (C, H, W) and
       ``mz`` (C,) float32, append a record to ``records`` and bump
       ``by_ds_counts[dsid]``.

    A skipped group increments one ``skip_reasons`` entry. Each dropped channel
    with a mismatched shape increments ``skip_reasons["shape_mismatch"]``.

    NOTE(review): the filename ignores th/tw. Two tiled groups that differ only in
    tile size, or an untiled group (written as r0_c0) for a dataset that also has a
    tiled (0, 0) tile, map to the same file when C matches and overwrite each other.
    """
    # Per-dataset cap on the number of samples written
    if N_SAMPLES_PER_DATASET_CAP is not None and by_ds_counts.get(dsid, 0) >= N_SAMPLES_PER_DATASET_CAP:
        skip_reasons["per_dataset_cap"] += 1
        return

    # grp inherits the global ranking, so its first C_TARGET rows are the best-ranked
    if len(grp) > C_TARGET:
        grp = grp.iloc[:C_TARGET]

    # One m/z per row; NaN when there is no peak_mz column or the value is missing/non-numeric
    if has_mz:
        mz_all = pd.to_numeric(grp["peak_mz"], errors="coerce").to_numpy(dtype=np.float64, na_value=np.nan)
    else:
        mz_all = np.full(len(grp), np.nan, dtype=np.float64)

    # Filter (path, m/z) pairs together so each m/z stays attached to its own row.
    # Filtering paths alone and then indexing m/z by position in the filtered list
    # shifts m/z values onto the wrong channels whenever an earlier row is dropped.
    candidates = [
        (p, mz)
        for p, mz in zip(grp["path"].astype(str).to_numpy(), mz_all)
        if os.path.exists(p) and (p.endswith(".npy") or p.endswith(".npz"))
    ]
    if not candidates:
        skip_reasons["no_valid_paths"] += 1
        return

    # Single pass over the candidates: the first loadable image fixes (H, W) (and the
    # output dtype unless USE_UINT16_INPUT); later images with another shape are dropped.
    valid_arrays = []
    valid_mz = []
    ref_shape = None
    for p, mz in candidates:
        a = load_2d_array(p)
        if a is None:
            continue
        if ref_shape is None:
            ref_shape = a.shape
        elif a.shape != ref_shape:
            skip_reasons["shape_mismatch"] += 1
            continue
        valid_arrays.append(a)
        valid_mz.append(mz)
        if len(valid_arrays) == C_TARGET:
            break

    if not valid_arrays:
        skip_reasons["load_error_or_not_2d"] += 1
        return

    C = len(valid_arrays)
    if C < MIN_CHANNELS_PER_SAMPLE:
        skip_reasons["too_few_channels"] += 1
        return

    H, W = int(ref_shape[0]), int(ref_shape[1])
    dtype = np.uint16 if USE_UINT16_INPUT else valid_arrays[0].dtype

    # Stack into (C, H, W); astype casts channels whose dtype differs from the target
    patch = np.empty((C, H, W), dtype=dtype)
    for i, a in enumerate(valid_arrays):
        patch[i] = a.astype(dtype, copy=False) if a.dtype != dtype else np.asarray(a)

    mzes = np.asarray(valid_mz, dtype=np.float32)

    # Save
    stem = f"{dsid}_r{int(r)}_c{int(c)}_C{C}"
    out_path = OUT_SAMPLES / f"{stem}.npz"
    if USE_COMPRESSION:
        np.savez_compressed(out_path, patch=patch, mz=mzes)
    else:
        np.savez(out_path, patch=patch, mz=mzes)

    records.append({
        "sample_path": str(out_path),
        "dataset_id": dsid,
        "tile_r": int(r), "tile_c": int(c),
        "tile_h": int(th), "tile_w": int(tw),
        "channels": int(C),
    })
    by_ds_counts[dsid] = by_ds_counts.get(dsid, 0) + 1

# ---------- Grouping and main loop ----------
# Tiled rows: one sample per (dataset, tile_r, tile_c, tile_h, tile_w) group
if tile_cols_exist and not tiled_df.empty:
    groups_tiled = tiled_df.groupby(["dataset_id", "tile_r", "tile_c", "tile_h", "tile_w"], sort=False)
    n_tiled = groups_tiled.ngroups
else:
    groups_tiled = None
    n_tiled = 0

# Untiled rows: one sample per dataset. Group by the bare column name, not a
# one-element list. For ["dataset_id"], pandas < 2.0 iterates with scalar keys while
# pandas >= 2.0 yields 1-tuples, so `(dsid,)` unpacking only works on newer pandas.
# A bare column name yields scalar keys on every version.
if PROCESS_UNTILED and not untiled_df.empty:
    groups_untiled = untiled_df.groupby("dataset_id", sort=False)
    n_untiled = groups_untiled.ngroups
else:
    groups_untiled = None
    n_untiled = 0

total_groups = n_tiled + n_untiled

with tqdm(total=total_groups, desc="Assembling samples") as pbar:
    if groups_tiled is not None:
        for (dsid, r, c, th, tw), grp in groups_tiled:
            process_group(dsid, r, c, th, tw, grp)
            pbar.update(1)
    if groups_untiled is not None:
        for dsid, grp in groups_untiled:
            # Nominal tile size for the index record (best-effort): the first row's
            # tile_h / tile_w when the column exists and is non-null, else 0.
            th = int(grp["tile_h"].iloc[0]) if "tile_h" in grp.columns and pd.notna(grp["tile_h"].iloc[0]) else 0
            tw = int(grp["tile_w"].iloc[0]) if "tile_w" in grp.columns and pd.notna(grp["tile_w"].iloc[0]) else 0
            process_group(dsid, r=0, c=0, th=th, tw=tw, grp=grp)
            pbar.update(1)

# ---------- Write index ----------
if not records:
    print("[WARN] No samples assembled. Check manifest/grouping.")
else:
    # Index of written samples: parquet always, CSV copy (same stem) on request.
    out_df = pd.DataFrame(records)
    out_df.to_parquet(OUT_LIST, index=False, engine="pyarrow")
    if WRITE_CSV_LIST:
        csv_list_path = OUT_LIST.with_suffix(".csv")
        out_df.to_csv(csv_list_path, index=False)
    print(f"[OK] Wrote {len(out_df)} samples to {OUT_SAMPLES} and list to {OUT_LIST}")

# ---------- Diagnostics ----------
# Summary counts for the run; tile statistics only when the tile columns exist.
print("---- Diagnostics ----")
for diag_label, diag_value in (
    ("Rows in manifest:", len(df)),
    ("Datasets:", df["dataset_id"].nunique()),
):
    print(diag_label, diag_value)
if not tile_cols_exist:
    print("Tile columns missing; all rows treated as untiled.")
else:
    print("Rows with complete tile coords:", int(mask_complete.sum()))
    print("Tiled groups:", n_tiled)
for diag_label, diag_value in (
    ("Untiled (dataset) groups:", n_untiled),
    ("Skip reasons:", dict(skip_reasons)),
):
    print(diag_label, diag_value)

Assembling samples: 100%|██████████| 4471/4471 [29:32<00:00,  2.52it/s] 

[OK] Wrote 3938 samples to metaspace_images_dump\msi_fm_samples3 and list to metaspace_images_dump\msi_fm_samples3.parquet
---- Diagnostics ----
Rows in manifest: 384403
Datasets: 3774
Rows with complete tile coords: 31348
Tiled groups: 763
Untiled (dataset) groups: 3708
Skip reasons: {'too_few_channels': 533}



