### m/z importance 

In [None]:
# =====================================================
# m/z IMPORTANCE from model embeddings (surrogate via spectra)
# - Loads precomputed embeddings (image_feats.npy + index.csv)
# - Builds binned spectra X_mz from .npz (mz + patch)
# - Learns Z ≈ X_mz @ B (ridge), and class ← Z (L1-OvR)
# - Attributes class→m/z by: score_mz = B @ w_class
# - Outputs: fm_ssl_run/baseline_eval_combined2/importance_mz_from_embeddings/<model_tag>
# =====================================================

import os, re, math, glob, warnings
from typing import Dict, List, Tuple, Optional
import numpy as np
import pandas as pd
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.pipeline import Pipeline

# Silence every UserWarning for the whole run.
# NOTE(review): this likely also hides scikit-learn's ConvergenceWarning (a UserWarning
# subclass) from the saga fits below — confirm that is intended.
warnings.filterwarnings("ignore", category=UserWarning)

# -------------------------
# Config
# -------------------------
# Global seed; also passed explicitly to the sklearn estimators below.
SEED = 6740
np.random.seed(SEED)

# --- Data sources ---
# os.path.join with a single argument returns it unchanged: a path relative to the CWD.
SPLIT_CSV   = os.path.join("splits_by_dataset_id.csv")
IDX_PARQUET = r"metaspace_images_dump/msi_fm_samples3.parquet"
MAN_PARQUET = r"metaspace_images_dump/manifest_expanded.parquet"

# --- Embedding model directory (precomputed) ---
MODEL_TAG   = "dinov2_vitb14"
MODEL_DIR   = os.path.join("fm_ssl_run", "pretrained_feats2", MODEL_TAG)  # << your path
FEATS_NPY   = os.path.join(MODEL_DIR, "image_feats.npy")
INDEX_CSV   = os.path.join(MODEL_DIR, "index.csv")  # must include 'sample_path'

# --- Where the raw .npz live (roots for resolver) ---
# Searched in order by resolve_npz_path. The second root lies inside the first, so the
# resolver's recursive fallback search visits that subtree twice.
DATA_ROOTS = [
    r"Y:\coskun-lab\Efe\MSI Foundation Model",
    r"Y:\coskun-lab\Efe\MSI Foundation Model\metaspace_images_dump\msi_fm_samples3",
    os.getcwd(),
]

# --- Output dirs ---
# Created as a side effect as soon as this cell runs.
OUT_ROOT = os.path.join("fm_ssl_run", "baseline_eval_combined2", "importance_mz_from_embeddings", MODEL_TAG)
CACHE_DIR = os.path.join(OUT_ROOT, "_cache_mz")
os.makedirs(OUT_ROOT, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)

# --- Tasks & filtering ---
# Metadata columns to attribute; missing columns are skipped at run time.
TASKS = ["organism", "polarity", "Organism_Part", "Condition", "analyzerType", "ionisationSource"]
# Per-column label values removed before class counting (see filter_valid).
EXCLUDE_LABELS = {"Condition": {"NA"}}
# Classes with fewer labelled rows than this are dropped for a task.
MIN_CLASS_COUNT = 100

# --- Binning & preprocessing ---
BIN_DA  = 0.05   # Δ m/z (Da)
BIN_PPM = None   # must stay None: build_global_bins raises NotImplementedError for ppm binning
# Per-spectrum preprocessing applied before binning (TIC normalisation first, then log1p).
TIC_NORM    = False
APPLY_LOG1P = False

# --- Surrogate / attribution settings ---
RIDGE_ALPHA = 5.0       # for Z ≈ X_mz @ B
LR_C        = 0.5       # for class ← Z (L1 OvR); inverse regularisation strength (smaller = sparser)
LR_MAX_ITER = 2000

# --- Visualization ---
# Top-|score| m/z bins taken per class; their union is capped at MAX_MZ_IN_PANEL heatmap columns.
TOP_MZ_PER_CLASS = 10
MAX_MZ_IN_PANEL  = 120

# =====================================================
# Helpers
# =====================================================
def safe_name(s: str) -> str:
    """Turn *s* into a filesystem-friendly token.

    ``s`` is converted with ``str()``. Every run of characters outside
    ``[A-Za-z0-9._-]`` becomes one underscore, and leading/trailing underscores
    are then stripped.
    """
    text = str(s)
    underscored = re.sub(r"[^A-Za-z0-9._-]+", "_", text)
    return underscored.strip("_")

def _clean(s):
    if pd.isna(s): return None
    s = str(s).strip()
    s = re.sub(r"\s+", " ", s)
    return s

def canonicalize_labels(df):
    df = df.copy()
    pol_map = {"pos":"Positive","positive":"Positive","+":"Positive",
               "neg":"Negative","negative":"Negative","-":"Negative"}
    def canon_polarity(s):
        if s is None: return None
        t = _clean(s).lower()
        t2 = pol_map.get(t, t)
        if t2 in ("positive","negative"):
            return t2.capitalize()
        if "pos" in t: return "Positive"
        if "neg" in t: return "Negative"
        return _clean(s)
    if "polarity" in df.columns:
        df["polarity"] = df["polarity"].map(canon_polarity)

    def canon_ion_src(s):
        if s is None: return None
        t_raw = _clean(s)
        t = t_raw.upper().replace("-", "").replace("_","")
        if "APSMALDI" in t: return "AP-SMALDI"
        if "IRMALDESI" in t or "IRMALDI" in t: return "IR-MALDESI"
        if "APMALDI" in t: return "AP-MALDI"
        if "DESIMSI" in t: return "DESI"
        if "DESI" in t: return "DESI"
        if "MALDI" in t: return "MALDI"
        return t_raw
    if "ionisationSource" in df.columns:
        df["ionisationSource"] = df["ionisationSource"].map(canon_ion_src)

    def canon_analyzer(s):
        if s is None: return None
        t = _clean(s); tl = t.lower()
        if "timstof" in tl and "flex" in tl: return "timsTOF Flex"
        if "fticr" in tl:
            if "12t" in tl: return "12T FTICR"
            if "7t" in tl and "scimax" in tl: return "FTICR scimaX 7T"
            return "FTICR"
        if "orbitrap" in tl or "q-exactive" in tl: return "Orbitrap"
        if "tof" in tl and "reflector" in tl: return "TOF reflector"
        if tl.strip() == "qtof": return "qTOF"
        return t
    if "analyzerType" in df.columns:
        df["analyzerType"] = df["analyzerType"].map(canon_analyzer)

    def canon_organism(s):
        if s is None: return None
        t = _clean(s); tl = t.lower()
        if "|" in t or "," in t:
            if ("human" in tl or "homo sapiens" in tl) and ("mouse" in tl or "mus musculus" in tl):
                return "Mixed"
        if "homo sapiens" in tl or tl.strip() in {"human","h. sapiens","homo"}:
            return "Homo sapiens"
        if "mus musculus" in tl or tl.strip() in {"mouse","m. musculus"}:
            return "Mus musculus"
        return t
    if "organism" in df.columns:
        df["organism"] = df["organism"].map(canon_organism)

    def canon_part(s):
        if s is None: return None
        t = _clean(s); tl = t.lower()
        if "kidney" in tl: return "Kidney"
        if "brain"  in tl: return "Brain"
        if "liver"  in tl: return "Liver"
        if "lung"   in tl: return "Lung"
        if "breast" in tl: return "Breast"
        if "skin"   in tl: return "Skin"
        if "heart"  in tl or "cardiac" in tl: return "Heart"
        return t
    if "Organism_Part" in df.columns:
        df["Organism_Part"] = df["Organism_Part"].map(canon_part)

    def canon_condition(s):
        if s is None: return None
        t = _clean(s); tl = t.lower()
        if tl in {"n/a","na","none","not available",""}: return "NA"
        if tl in {"biopsy","biopsies"}: return "Biopsy"
        if "fresh frozen" in tl or "frozen" in tl: return "Frozen"
        if "tumor" in tl or "tumour" in tl: return "Tumor"
        if "cancer" in tl: return "Cancer"
        if "wildtype" in tl or tl == "wt": return "Wildtype"
        if "healthy" in tl or "control" in tl: return "Healthy"
        if "diseased" in tl or "disease" in tl: return "Diseased"
        return t
    if "Condition" in df.columns:
        df["Condition"] = df["Condition"].map(canon_condition)
    return df

def filter_valid(df_task, yname, min_count=5, *, exclude=None):
    """Keep rows of *df_task* whose label in column *yname* is usable.

    Steps: drop missing labels; drop labels listed for *yname* in *exclude*; drop
    empty-string labels; drop classes with fewer than *min_count* remaining rows.

    Args:
        df_task: input frame (not modified).
        yname: label column name.
        min_count: minimum rows a class needs to be kept.
        exclude: mapping ``column -> collection of label values to drop``.
            Defaults to the module-level ``EXCLUDE_LABELS``.

    Returns:
        A filtered copy with the original row order preserved.
    """
    if exclude is None:
        exclude = EXCLUDE_LABELS
    x = df_task.dropna(subset=[yname])
    banned = exclude.get(yname)
    if banned:
        x = x[~x[yname].isin(banned)]
    x = x[x[yname].astype(str).str.len() > 0]
    vc = x[yname].value_counts()
    keep = vc[vc >= min_count].index
    return x[x[yname].isin(keep)].copy()

def fit_l1_ovr_on_embeddings(Z: np.ndarray, y: np.ndarray, classes: List[str]) -> Dict[str, np.ndarray]:
    """Fit one L1-penalised one-vs-rest logistic model per class.

    Each model is a StandardScaler -> LogisticRegression(penalty="l1", solver="saga")
    pipeline. It uses the module-level LR_C / LR_MAX_ITER / SEED settings and
    balanced class weights.

    Args:
        Z: (N, D) feature matrix.
        y: length-N labels; compared with ``==`` against each entry of *classes*.
        classes: labels to fit a one-vs-rest model for.

    Returns:
        dict mapping class -> (D,) coefficient vector on the original (unscaled)
        feature scale. Intercepts are not returned. A class with no positive or no
        negative rows cannot be fit; it gets an all-zero vector and a printed
        warning instead of raising.
    """
    Z = np.asarray(Z)
    y = np.asarray(y)  # a plain list would make `y == cls` a single bool
    n_dims = Z.shape[1]
    out = {}
    for cls in classes:
        y_bin = (y == cls).astype(int)
        n_pos = int(y_bin.sum())
        if n_pos == 0 or n_pos == y_bin.size:
            print(f"[WARN] class {cls!r}: {n_pos}/{y_bin.size} positives; cannot fit, using zero weights.")
            out[cls] = np.zeros(n_dims, dtype=np.float64)
            continue
        pipe = Pipeline([
            ("scaler", StandardScaler(with_mean=True, with_std=True)),
            # n_jobs is omitted: it only parallelises over classes, and each fit here is binary.
            ("clf", LogisticRegression(
                penalty="l1", solver="saga", C=LR_C, max_iter=LR_MAX_ITER,
                class_weight="balanced", random_state=SEED
            ))
        ])
        pipe.fit(Z, y_bin)
        scaler: StandardScaler = pipe.named_steps["scaler"]
        w_std = pipe.named_steps["clf"].coef_.ravel()
        # The model is linear in (z - mean) / scale, so dividing by scale_ gives the
        # slope with respect to the raw embedding coordinates.
        out[cls] = w_std / (scaler.scale_ + 1e-12)
    return out

def ridge_embeddings_from_mz(X_mz: np.ndarray, Z: np.ndarray, alpha: float) -> np.ndarray:
    """Fit the multi-output linear map Z ≈ X_mz @ B (+ intercept) with Ridge.

    Args:
        X_mz: (N, M_mz) binned spectra.
        Z: (N, D_emb) embedding targets.
        alpha: Ridge penalty strength.

    Returns:
        B with shape (M_mz, D_emb): one row per m/z bin. The intercept is fitted
        but not returned.
    """
    regressor = Ridge(alpha=alpha, fit_intercept=True, random_state=SEED).fit(X_mz, Z)
    # sklearn stores coef_ as (n_targets, n_features) = (D_emb, M_mz); transpose it.
    return regressor.coef_.T

def select_union_top(feats_by_class: Dict[str, np.ndarray], top_k: int, max_total: int) -> np.ndarray:
    """Choose feature indices to display across several per-class score vectors.

    Takes the union of each class's *top_k* indices by absolute score. If the union
    has more than *max_total* members, it keeps the *max_total* with the largest
    mean |score| across classes; ties go to the lower index.

    Args:
        feats_by_class: class -> 1-D score vector; all vectors have the same length.
        top_k: indices taken per class.
        max_total: cap on the number of returned indices.

    Returns:
        Sorted integer index array. It is empty (but still integer-typed, so it can
        be used for indexing) when there are no classes.
    """
    if not feats_by_class:
        return np.array([], dtype=int)

    abs_scores = np.abs(np.stack([np.asarray(v) for v in feats_by_class.values()], axis=0))  # (C, M)

    picked = set()
    for row in abs_scores:
        picked.update(np.argsort(-row)[:top_k].tolist())
    union = np.array(sorted(picked), dtype=int)

    if union.size > max_total:
        agg = abs_scores.mean(axis=0)
        # Stable sort on an ascending union -> deterministic, lowest-index-first ties.
        keep = np.argsort(-agg[union], kind="stable")[:max_total]
        union = union[keep]
    return np.sort(union)

def format_mz(mz_vals: np.ndarray) -> List[str]:
    """Render each m/z value with two decimals (e.g. ``123.4567`` -> ``"123.46"``)."""
    return list(map("{:.2f}".format, mz_vals))

def plot_heatmap(df_vals: pd.DataFrame, title: str, out_png: str):
    """Draw *df_vals* (rows = classes, columns = m/z labels) as a diverging heatmap.

    The width grows with the column count (capped at 28 in) and the height with the
    row count. The colour map is centred on 0 so the sign of each score is visible.
    The PNG is written at 200 dpi and the figure is closed afterwards.
    """
    n_rows, n_cols = df_vals.shape
    fig, ax = plt.subplots(figsize=(min(28, 2 + 0.45 * n_cols), 0.6 + 0.55 * n_rows))
    sns.heatmap(df_vals, cmap="vlag", center=0.0, linewidths=0.3, linecolor='white', ax=ax)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel("m/z")
    ax.set_ylabel("Class")
    # Keep class names horizontal for readability.
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()
    fig.savefig(out_png, dpi=200)
    plt.close(fig)

# ---------- robust resolver + loader for npz ----------
LIKELY_SUBFOLDERS = [os.path.join("metaspace_images_dump", "msi_fm_samples3"), ""]


def resolve_npz_path(sp: str, roots: Optional[List[str]] = None) -> Optional[str]:
    """Locate the .npz file that *sp* refers to, or return ``None``.

    Lookup order (first hit wins):
      1. *sp* itself, if it is an absolute path to an existing file;
      2. ``<root>/<sp>`` for each root;
      3. ``<root>/<sub>/<basename(sp)>`` for each root and each ``LIKELY_SUBFOLDERS``
         entry, also trying a ``.npz`` suffix when *sp* has no extension;
      4. a recursive search for ``basename(sp)`` (then ``+ ".npz"`` when *sp* has no
         extension) under each root; the lexicographically first match is used.

    Args:
        sp: path as stored in the metadata/index tables. Non-path values (e.g. NaN
            coming from pandas) and empty strings resolve to ``None``.
        roots: directories to search; defaults to the module-level ``DATA_ROOTS``.

    Returns:
        Normalised path of an existing regular file (never a directory), or ``None``.
    """
    if not isinstance(sp, (str, os.PathLike)) or not os.fspath(sp):
        return None
    if roots is None:
        roots = DATA_ROOTS

    sp_norm = os.path.normpath(sp)
    base = os.path.basename(sp_norm)
    ext = os.path.splitext(base)[1]

    if os.path.isabs(sp_norm) and os.path.isfile(sp_norm):
        return sp_norm
    for root in roots:
        cand = os.path.normpath(os.path.join(root, sp_norm))
        if os.path.isfile(cand):
            return cand
    for root in roots:
        for sub in LIKELY_SUBFOLDERS:
            cand2 = os.path.normpath(os.path.join(root, sub, base))
            if os.path.isfile(cand2):
                return cand2
            if ext == "" and os.path.isfile(cand2 + ".npz"):
                return cand2 + ".npz"

    # Escape so names like "tile[01].npz" match literally instead of being read as
    # glob character classes.
    literal_base = glob.escape(base)
    for root in roots:
        pat = os.path.join(glob.escape(os.path.normpath(root)), "**", literal_base)
        hits = [h for h in glob.glob(pat, recursive=True) if os.path.isfile(h)]
        if not hits and ext == "":
            hits = [h for h in glob.glob(pat + ".npz", recursive=True) if os.path.isfile(h)]
        if hits:
            # glob order is filesystem-dependent; sort for a reproducible choice.
            return os.path.normpath(sorted(hits)[0])
    return None

# Keys probed, in priority order, for the m/z axis and for the intensity array
# ('patch' first for the tile dumps, then other common names).
MZ_KEYS = ["mz", "mzs", "mz_axis", "mass", "mass_axis", "MZ", "MZ_axis"]
POSSIBLE_IMAGE_KEYS = ["patch","tile","tiles","img","image","images","arr","X","cube","data",
                       "spec","spectra","intensity","intensities"]

# For each supported rank >= 2: the axes that may hold the m/z channels, tried in order.
# 2D (P, C) / (C, P); 3D (H, W, C) / (C, H, W); 4D (N, H, W, C) / (N, C, H, W).
_CHANNEL_AXES_BY_NDIM = {2: (1, 0), 3: (2, 0), 4: (3, 1)}


def load_npz_spectrum(npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
    """Load an .npz file and reduce its intensity array to one mean spectrum.

    The m/z axis is read from the first key in ``MZ_KEYS`` that is present, and the
    intensities from the first present key in ``POSSIBLE_IMAGE_KEYS``. For arrays of
    rank 2-4, the axis whose length equals ``len(mz)`` is the channel axis; it is
    located via ``_CHANNEL_AXES_BY_NDIM``, with the last axis preferred. All other
    axes are averaged.

    Returns:
        (mz, spec): 1-D float64 arrays of equal length.

    Raises:
        KeyError: no m/z key or no intensity key in the file.
        ValueError: no axis matches ``len(mz)``, or the rank is not 1-4.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can run
    # code from a crafted file — only point this at trusted data.
    with np.load(npz_path, allow_pickle=True) as npz:
        mz_key = next((k for k in MZ_KEYS if k in npz.files), None)
        if mz_key is None:
            raise KeyError(f"No m/z key in {npz_path}; keys={list(npz.files)}")
        mz = np.array(npz[mz_key], dtype=float).ravel()
        img_key = next((k for k in POSSIBLE_IMAGE_KEYS if k in npz.files), None)
        if img_key is None:
            raise KeyError(f"No intensity key in {npz_path}; keys={list(npz.files)}")
        arr = npz[img_key]  # fully read into memory, so it stays valid after close

    if arr.ndim == 1:
        if arr.size != mz.size: raise ValueError(f"1D {arr.size} != len(mz) {mz.size}")
        return mz, arr.astype(np.float64)

    candidates = _CHANNEL_AXES_BY_NDIM.get(arr.ndim)
    if candidates is None:
        raise ValueError(f"Unsupported ndim={arr.ndim}")
    ch_axis = next((ax for ax in candidates if arr.shape[ax] == mz.size), None)
    if ch_axis is None:
        raise ValueError(f"{arr.ndim}D {arr.shape} vs mz {mz.size}")
    other_axes = tuple(ax for ax in range(arr.ndim) if ax != ch_axis)
    # Accumulate in float64 so every branch returns the same dtype as the 1-D case.
    spec = arr.mean(axis=other_axes, dtype=np.float64)
    return mz, spec

def build_global_bins(min_mz: float, max_mz: float, bin_da: Optional[float], ppm: Optional[float]) -> np.ndarray:
    """Return the centres of fixed-width m/z bins covering ``[min_mz, max_mz]``.

    Edges start at *min_mz* and step by *bin_da*. One extra bin is appended so that
    *max_mz* always falls strictly inside the last bin.

    Raises:
        NotImplementedError: *ppm* is not None (ppm binning is not implemented).
        ValueError: *bin_da* is not a positive number, or the range is not finite
            with ``min_mz <= max_mz``.
    """
    if ppm is not None: raise NotImplementedError("PPM binning not implemented; use BIN_DA.")
    # `not (x > 0)` also rejects NaN; None would otherwise fail later with a TypeError.
    if bin_da is None or not (bin_da > 0):
        raise ValueError(f"bin_da must be a positive number, got {bin_da!r}")
    if not (np.isfinite(min_mz) and np.isfinite(max_mz)) or max_mz < min_mz:
        raise ValueError(f"invalid m/z range [{min_mz}, {max_mz}]")
    n_bins = int(math.ceil((max_mz - min_mz) / bin_da)) + 1
    edges = min_mz + np.arange(n_bins + 1) * bin_da
    return (edges[:-1] + edges[1:]) / 2.0

def bin_spectrum(mz: np.ndarray, spec: np.ndarray, centers: np.ndarray, bin_da: float) -> np.ndarray:
    """Sum intensities into the fixed-width bins whose centres are *centers*.

    Bin ``i`` covers ``[centers[i] - bin_da/2, centers[i] + bin_da/2)``. Peaks outside
    the axis, and NaN m/z values, are ignored.

    Args:
        mz: 1-D m/z values.
        spec: intensities aligned with *mz*.
        centers: evenly spaced bin centres (e.g. from ``build_global_bins``).
        bin_da: bin width; must match the spacing of *centers*.

    Returns:
        float64 array with ``centers.size`` summed intensities.
    """
    mz = np.asarray(mz, dtype=np.float64)
    spec = np.asarray(spec, dtype=np.float64)
    n_bins = centers.size
    b = np.zeros(n_bins, dtype=np.float64)
    if n_bins == 0 or mz.size == 0:
        return b
    min_edge = centers[0] - bin_da / 2
    max_edge = centers[-1] + bin_da / 2
    # The centres are midpoints of the original edges, so the edges recomputed here can
    # be off by a few ULPs. Without a tolerance, the peak sitting exactly on the lowest
    # edge (the global min m/z) could get index -1 and be dropped.
    tol = abs(bin_da) * 1e-6
    valid = (mz >= min_edge - tol) & (mz < max_edge + tol)  # NaN compares False
    idx = np.floor((mz[valid] - min_edge) / bin_da).astype(np.int64)
    np.clip(idx, 0, n_bins - 1, out=idx)
    np.add.at(b, idx, spec[valid])
    return b

# =====================================================
# Load metadata + splits
# =====================================================
# Builds df_meta: per-sample rows joined with dataset-level labels and the split
# assignment, de-duplicated on sample_path, with label spellings canonicalised.
idx_df = pd.read_parquet(IDX_PARQUET)          # contains sample_path + dataset_id
man_df = pd.read_parquet(MAN_PARQUET)          # dataset-level metadata
need_cols = ["dataset_id","sample_path","organism","polarity","Organism_Part","Condition","analyzerType","ionisationSource"]
# Keep only the wanted columns that the manifest actually has, one row per dataset_id
# (the first manifest row wins when a dataset_id repeats).
man_sub = man_df[[c for c in need_cols if c in man_df.columns]].drop_duplicates("dataset_id")
# Left join keeps every sample. For columns present in both frames, idx_df's values keep
# the plain name and the manifest's values get the "_man" suffix.
# NOTE(review): if idx_df already has a label column (e.g. "organism"), the manifest's
# value lands in "organism_man" and is not used below — confirm this is intended.
df_meta = idx_df.merge(man_sub, on="dataset_id", how="left", suffixes=("", "_man"))
# Defensive: drop any repeated column labels, keeping the first occurrence.
df_meta = df_meta.loc[:, ~df_meta.columns.duplicated()].copy().reset_index(drop=True)
splits = pd.read_csv(SPLIT_CSV)
# Samples whose dataset_id is absent from SPLIT_CSV get NaN in the columns it adds
# (the "split" column used later).
df_meta = df_meta.merge(splits, on="dataset_id", how="left")
# NOTE(review): repeated sample_path rows can come from idx_df itself or from a dataset_id
# listed more than once in SPLIT_CSV (the merge above then repeats rows). Keeping the
# first silently picks one split for such samples — verify SPLIT_CSV has unique dataset_ids.
if df_meta.duplicated("sample_path").sum():
    print("[WARN] duplicate sample_path rows; keeping first.")
    df_meta = df_meta.drop_duplicates("sample_path", keep="first").reset_index(drop=True)
df_meta = canonicalize_labels(df_meta)

# =====================================================
# Build/load m/z-binned spectra cache
# =====================================================
# The cache filenames must change whenever a setting that changes X_mz changes. The bin
# width was always encoded. The per-spectrum preprocessing flags are appended only when
# enabled, so caches built with the default settings keep their existing names.
_MZ_CACHE_TAG = f"da_{str(BIN_DA).replace('.','p')}"
if TIC_NORM:
    _MZ_CACHE_TAG += "_tic"
if APPLY_LOG1P:
    _MZ_CACHE_TAG += "_log1p"
CACHE_AXIS = os.path.join(CACHE_DIR, f"mz_axis_{_MZ_CACHE_TAG}.npy")
CACHE_X    = os.path.join(CACHE_DIR, f"X_mz_{_MZ_CACHE_TAG}.npy")
CACHE_IDX  = os.path.join(CACHE_DIR, f"index_used_{_MZ_CACHE_TAG}.csv")
need_rebuild = not (os.path.exists(CACHE_AXIS) and os.path.exists(CACHE_X) and os.path.exists(CACHE_IDX))

if need_rebuild:
    # Pass 1: find readable files and the global m/z range shared by all spectra.
    all_min, all_max = np.inf, -np.inf
    ok_paths = []
    readable = 0
    resolved_ct = exists_ct = open_ok_ct = 0
    failed_ct = 0

    for sp in tqdm(df_meta["sample_path"].tolist(), desc="Scan mz ranges"):
        npz_path = resolve_npz_path(sp)
        if npz_path is None: continue
        resolved_ct += 1
        if os.path.exists(npz_path): exists_ct += 1
        try:
            mz, _ = load_npz_spectrum(npz_path)
            open_ok_ct += 1
            if mz.size == 0: continue
            all_min = min(all_min, float(np.min(mz)))
            all_max = max(all_max, float(np.max(mz)))
            ok_paths.append(npz_path)
            readable += 1
        except Exception:
            # Best-effort scan: odd/unreadable files are skipped, but now counted.
            failed_ct += 1
            continue

    print(f"[INFO] mz-scan: resolved={resolved_ct}, exists={exists_ct}, opened_ok={open_ok_ct}, readable={readable}")
    if failed_ct:
        print(f"[WARN] mz-scan: {failed_ct} resolved files could not be read and were skipped.")
    if readable == 0:
        raise RuntimeError("No readable .npz for spectra binning.")

    # Different sample_path strings can resolve to the same file; bin each file once so
    # every path in the cached index maps to exactly one row of X.
    ok_paths = list(dict.fromkeys(ok_paths))

    centers = build_global_bins(all_min, all_max, BIN_DA, BIN_PPM)

    # Pass 2: bin every spectrum onto the shared axis; failures become NaN rows (dropped below).
    X = np.zeros((len(ok_paths), centers.size), dtype=np.float64)
    for i, npz_path in enumerate(tqdm(ok_paths, desc="Bin spectra")):
        try:
            mz, spec = load_npz_spectrum(npz_path)
            if TIC_NORM: spec = spec / (spec.sum() + 1e-12)
            if APPLY_LOG1P: spec = np.log1p(spec)
            X[i] = bin_spectrum(mz, spec, centers, BIN_DA)
        except Exception:
            X[i] = np.nan

    mask = ~np.isnan(X).any(axis=1)
    n_dropped = int((~mask).sum())
    if n_dropped:
        print(f"[WARN] dropped {n_dropped} spectra that failed to bin or contained NaN.")
    X = X[mask]
    used_paths = [os.path.normpath(p) for p, m in zip(ok_paths, mask) if m]

    np.save(CACHE_AXIS, centers); np.save(CACHE_X, X)
    pd.DataFrame({"sample_path": used_paths}).to_csv(CACHE_IDX, index=False)
    print(f"[OK] Cached X_mz: {X.shape}, axis: {centers.shape}")
else:
    centers = np.load(CACHE_AXIS)
    X = np.load(CACHE_X)
    used_paths = [os.path.normpath(p) for p in pd.read_csv(CACHE_IDX)["sample_path"].tolist()]
    print(f"[OK] Loaded cached X_mz: {X.shape}, axis: {centers.shape}")

# =====================================================
# Load embeddings + index
# =====================================================
if not (os.path.exists(FEATS_NPY) and os.path.exists(INDEX_CSV)):
    raise FileNotFoundError(f"Expected {FEATS_NPY} and {INDEX_CSV}")

Z = np.load(FEATS_NPY)               # (N_emb, D)
idx_emb = pd.read_csv(INDEX_CSV)     # must include 'sample_path'
if "sample_path" not in idx_emb.columns:
    raise KeyError(f"'sample_path' missing in {INDEX_CSV}")
# Row i of Z is assumed to describe row i of index.csv. Any mismatch would silently
# misalign every embedding with its spectrum and label, so fail loudly instead.
if Z.ndim != 2 or Z.shape[0] != len(idx_emb):
    raise ValueError(f"Embedding matrix {Z.shape} does not match {len(idx_emb)} rows in {INDEX_CSV}")

# =====================================================
# Strict 3-way alignment: X_mz ↔ Z ↔ metadata
# =====================================================
# 1) Resolve paths everywhere (same resolver as the X_mz cache, so keys are comparable)
df_meta["resolved_path"] = df_meta["sample_path"].apply(resolve_npz_path)
idx_emb["resolved_path"] = idx_emb["sample_path"].apply(resolve_npz_path)

# 2) Build row maps for each source (a repeated key maps to its last row)
x_row   = {p: i for i, p in enumerate(used_paths)}                         # path -> row in X
emb_row = {p: i for i, p in enumerate(idx_emb["resolved_path"].tolist())}  # path -> row in Z

# 3) Compute common set across all three (canonical order = X order)
meta_paths = set(df_meta["resolved_path"].dropna().tolist())
emb_paths  = set(k for k in emb_row.keys() if k is not None)
x_paths    = set(x_row.keys())
# dict.fromkeys de-duplicates while keeping X order, so each file contributes one row.
common = list(dict.fromkeys(p for p in used_paths if p in emb_paths and p in meta_paths))
if len(common) == 0:
    raise RuntimeError("No overlap between X_mz (npz), embeddings, and metadata.")

print(f"[ALIGN] |X paths|={len(x_paths)} |emb paths|={len(emb_paths)} |meta paths|={len(meta_paths)} |common|={len(common)}")

# 4) Subset & reorder all to common canonical order
X = X[[x_row[p] for p in common]]
Z = Z[[emb_row[p] for p in common]]

df_used = df_meta[df_meta["resolved_path"].isin(common)]
# Several metadata rows may resolve to the same file. Keep one per file so the metadata
# rows line up one-to-one with the rows of X and Z.
n_dup = int(df_used["resolved_path"].duplicated().sum())
if n_dup:
    print(f"[WARN] {n_dup} metadata rows resolve to an already-used file; keeping first.")
df_used = df_used.drop_duplicates("resolved_path", keep="first").copy()
order_map = {p: i for i, p in enumerate(common)}
df_used["__ord"] = df_used["resolved_path"].map(order_map)
df_used = df_used.sort_values("__ord").drop(columns="__ord").reset_index(drop=True)
df_used["x_row"] = np.arange(len(df_used), dtype=int)  # stable row index for slicing

if not (X.shape[0] == Z.shape[0] == len(df_used)):
    raise RuntimeError(f"Alignment mismatch: X {X.shape[0]}, Z {Z.shape[0]}, meta {len(df_used)} rows")

print(f"[ALIGN] Final shapes — X_mz: {X.shape}, Z: {Z.shape}, meta rows: {len(df_used)}")

# =====================================================
# Learn surrogate & attribute class→m/z
# =====================================================
PLOTS_DIR = os.path.join(OUT_ROOT, "plots"); os.makedirs(PLOTS_DIR, exist_ok=True)
CSV_DIR   = os.path.join(OUT_ROOT, "csv");   os.makedirs(CSV_DIR, exist_ok=True)

# Learn Z ≈ X_mz @ B on every aligned row. This regression never sees task labels, so
# including TEST rows here does not leak labels into the classifiers below.
B = ridge_embeddings_from_mz(X, Z, alpha=RIDGE_ALPHA)  # (M_mz, D_emb)
np.save(os.path.join(CSV_DIR, "B_mz_to_embedding.npy"), B)

# The m/z row labels of the full CSVs are identical for every task; format them once.
full_labels = format_mz(centers)

for task in TASKS:
    if task not in df_used.columns:
        print(f"[SKIP] {task}: missing in metadata.")
        continue
    if "split" not in df_used.columns:
        print(f"[SKIP] {task}: no 'split' column in metadata.")
        continue

    df_task = filter_valid(df_used, task, min_count=MIN_CLASS_COUNT)
    if df_task.empty:
        print(f"[SKIP] {task}: empty after filtering.")
        continue

    # Use TRAIN+VAL for the classifier on embeddings (rows without a split are excluded)
    fit_mask = df_task["split"].isin(["train", "val"]).values
    rows = df_task["x_row"].values
    y_task = df_task[task].astype(str).values
    Z_fit = Z[rows[fit_mask]]
    y_fit = y_task[fit_mask]

    # Take classes from the fitting subset. A class that appears only in TEST would give a
    # single-label one-vs-rest problem, which LogisticRegression rejects.
    classes = sorted(np.unique(y_fit).tolist())
    if fit_mask.sum() < 2 or len(classes) < 2:
        print(f"[SKIP] {task}: insufficient TRAIN+VAL.")
        continue
    absent = sorted(set(y_task.tolist()) - set(classes))
    if absent:
        print(f"[INFO] {task}: classes absent from TRAIN+VAL are not attributed: {absent}")

    # Fit class ← Z (L1 OvR) on TRAIN+VAL subset
    w_by_class = fit_l1_ovr_on_embeddings(Z_fit, y_fit, classes)  # dict[class] -> (D,)

    # Attribute to m/z: score_mz[class] = B @ w_class
    scores_mz = {cls: B @ w_by_class[cls] for cls in classes}  # dict[class] -> (M,)

    # Select columns by top |score| across classes
    sel_idx = select_union_top(scores_mz, TOP_MZ_PER_CLASS, MAX_MZ_IN_PANEL)
    sel_labels = format_mz(centers[sel_idx])

    # Panel (signed) heatmap dataframe
    df_panel = pd.DataFrame(
        np.vstack([scores_mz[c][sel_idx] for c in classes]),
        index=classes, columns=sel_labels
    )

    # Save CSVs
    df_panel.to_csv(os.path.join(CSV_DIR, f"{safe_name(task)}__class_to_mz_panel.csv"))
    df_full = pd.DataFrame({c: scores_mz[c] for c in classes}, index=full_labels)
    df_full.to_csv(os.path.join(CSV_DIR, f"{safe_name(task)}__class_to_mz_full.csv"))

    # Selected m/z map (exact centers)
    pd.DataFrame({"col": list(df_panel.columns), "mz": centers[sel_idx]}).to_csv(
        os.path.join(CSV_DIR, f"{safe_name(task)}__selected_mz_columns.csv"), index=False
    )

    # Plot heatmap (horizontal class labels)
    plot_heatmap(
        df_panel,
        title=f"{task} — m/z attribution via {MODEL_TAG} (B @ w_class)",
        out_png=os.path.join(PLOTS_DIR, f"{safe_name(task)}__mz_attr_heatmap.png"),
    )

print(f"\n[DONE] Outputs in: {OUT_ROOT}")


[WARN] duplicate sample_path rows; keeping first.
[OK] Loaded cached X_mz: (3928, 50026), axis: (50026,)
[ALIGN] |X paths|=3928 |emb paths|=3928 |meta paths|=3928 |common|=3928
[ALIGN] Final shapes — X_mz: (3928, 50026), Z: (3928, 768), meta rows: 3928

[DONE] Outputs in: fm_ssl_run\baseline_eval_combined2\importance_mz_from_embeddings\dinov2_vitb14
