In [1]:
# ==========================
# Align from cache -> A, labels, groups
# ==========================
import os, json, numpy as np, pandas as pd, scanpy as sc

# ---- paths / config
CACHE_DIR = "/maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/cached/layer4"
A_CACHE   = os.path.join(CACHE_DIR, "A_cls.npy")     # scGPT <CLS> activations
J_CACHE   = os.path.join(CACHE_DIR, "J_embed.npy")
IDS_CACHE = os.path.join(CACHE_DIR, "cell_order.txt")
H5AD_PATH = "/maiziezhou_lab2/yunfei/Projects/FM_temp/InterPLM/interplm/ge_shards/cosmx_human_lung_sec8.h5ad"
GROUP_COL = "patient"   # set to None if you don't want within-patient shuffles

# ---- load cache + AnnData
A_cls = np.load(A_CACHE, mmap_mode="r")  # shape [N_all, U]; row i <-> line i of cell_order.txt
with open(IDS_CACHE, "r") as f:
    cell_order_cached = [ln.strip() for ln in f]

# Rows of A_cls are matched to ids purely by position; a length mismatch would
# silently shift every label downstream, so fail loudly instead.
if A_cls.shape[0] != len(cell_order_cached):
    raise ValueError(
        f"A_cls has {A_cls.shape[0]} rows but {IDS_CACHE} lists {len(cell_order_cached)} cell ids"
    )

adata = sc.read_h5ad(H5AD_PATH)

# ---- build intersection in A's order
ids_A = pd.Index(cell_order_cached)
ids_J = pd.Index(adata.obs_names.astype(str))
idx_J_for_A = ids_J.get_indexer(ids_A)  # -1 if not present in adata
keep_mask   = idx_J_for_A >= 0
if not keep_mask.any():
    raise ValueError("No cached cell ids were found in adata.obs_names; check that the cache and H5AD match.")

idx_A_keep = np.nonzero(keep_mask)[0]    # rows in A order to keep
idx_J_keep = idx_J_for_A[keep_mask]      # corresponding rows in adata
print(f"[align] kept {idx_A_keep.size}/{len(ids_A)} cached cells that are present in the AnnData")

# NOTE(review): not used in this cell; its row order is presumably the same as
# A_cls (cell_order.txt) — confirm before slicing it with idx_A_keep.
J_embed = np.load(J_CACHE, mmap_mode="r")   # shape [N, F]

# ---- slice and name the arrays EXACTLY as used downstream
A_for_mis       = np.asarray(A_cls[idx_A_keep], dtype=np.float32)      # [M, U]
labels_aligned  = adata.obs["author_cell_type"].to_numpy()[idx_J_keep]        # [M]

# expose the names expected by the justification script:
A      = A_for_mis
labels = labels_aligned

# optional: groups for within-group permutations/splits (patient/slide/etc.)
if GROUP_COL and (GROUP_COL in adata.obs.columns):
    groups = adata.obs[GROUP_COL].to_numpy()[idx_J_keep]
else:
    groups = None

# ---- quick sanity
print(f"[sanity] A shape: {A.shape}  | labels: {labels.shape}  | groups: {None if groups is None else groups.shape}")

[sanity] A shape: (81236, 512)  | labels: (81236,)  | groups: None


In [2]:
from scipy.stats import binomtest, hypergeom
from statsmodels.stats.multitest import multipletests

def hypergeo_enrichment_for_unit(u, A, labels, min_in_high=50):
    """One-sided hypergeometric enrichment of labels among a unit's high-activation cells.

    Cells whose activation ``A[:, u]`` is >= the column median form the "high"
    set (ties at the median are included, so it can exceed half of the cells).
    For every label present in ``labels`` the upper-tail probability
    P(X >= k), X ~ Hypergeom(N_total, K_in_pop, n_high), is computed, then
    Benjamini-Hochberg adjusted across labels.

    Parameters
    ----------
    u : int
        Column (unit) index into ``A``.
    A : array-like, 2-D
        Activations, cells x units; row i must describe the same cell as labels[i].
    labels : array-like or None
        One label per row of ``A``. Missing labels (NaN/None) are not tested
        but still count toward N_total.
    min_in_high : int
        Minimum size of the high set; smaller sets return None.

    Returns
    -------
    pandas.DataFrame or None
        Columns label, k_in_high, K_in_pop, n_high, N_total, p, p_fdr,
        enrichment_ratio, sorted by p_fdr ascending then enrichment_ratio
        descending. None if ``labels`` is None, the high set is smaller than
        ``min_in_high``, or there is no non-missing label to test.

    Raises
    ------
    ValueError
        If ``A`` and ``labels`` do not have the same number of rows.
    """
    if labels is None:
        return None
    labels = np.asarray(labels)
    a = np.asarray(A[:, u])
    # Rows are paired purely by position; a length mismatch means misaligned inputs.
    if a.shape[0] != labels.shape[0]:
        raise ValueError(
            f"A has {a.shape[0]} rows but labels has {labels.shape[0]} entries; inputs are not aligned"
        )

    high = a >= np.median(a)
    n = int(high.sum())
    if n < min_in_high:
        return None

    lab = pd.Series(labels)
    pop_counts = lab.value_counts()  # K per label (missing labels dropped)
    if pop_counts.empty:
        return None
    # k per label in the high set, aligned to the population order (0 if absent)
    high_counts = pd.Series(labels[high]).value_counts().reindex(pop_counts.index, fill_value=0)

    N = len(lab)
    K = pop_counts.to_numpy()
    k = high_counts.to_numpy()
    # overlap k or more among n draws from population N with K "successes"
    p = hypergeom.sf(k - 1, N, K, n)

    df = pd.DataFrame({
        "label": pop_counts.index.to_numpy(),
        "k_in_high": k,
        "K_in_pop": K,
        "n_high": n,
        "N_total": N,
        "p": p,
    })
    df["p_fdr"] = multipletests(df["p"], method="fdr_bh")[1]
    df["enrichment_ratio"] = (df["k_in_high"]/df["n_high"]) / (df["K_in_pop"]/df["N_total"])
    return df.sort_values(["p_fdr","enrichment_ratio"], ascending=[True,False])

# --- Inputs for the enrichment analysis ---
# Use the *aligned* activations (A_for_mis) from the alignment cell, not the raw
# cache A_cls: A_cls keeps every cached row, including cells missing from the
# AnnData, so pairing it with labels_aligned would shift labels (or raise an
# IndexError) whenever any cell was dropped during alignment.
A = np.asarray(A_for_mis)
labels = np.asarray(labels_aligned)
if A.shape[0] != labels.shape[0]:
    raise ValueError(f"A has {A.shape[0]} rows but labels has {labels.shape[0]} entries")

# --- Base directory for the new run ---
out_dir = "./out"

# --- Load MIS results ---
MIS = np.load(f"{out_dir}/MIS_expr_zcos.npy")

with open(f"{out_dir}/summary_expr_zcos.json") as f:
    summary = json.load(f)

# Example: enrichment tables for the top 10 units
for u in [r["unit"] for r in summary["top_units"][:10]]:
    print("Unit", u, "MIS", MIS[u])
    enrichment = hypergeo_enrichment_for_unit(u, A, labels)
    if enrichment is None:
        # returned when the high-activation set is below min_in_high
        print("  skipped: too few high-activation cells")
        continue
    display(enrichment.head(8))


Unit 246 MIS 0.845


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
2,fibroblast,11921,13609,40618,81236,0.0,0.0,1.751929
3,T CD4 memory,7411,9344,40618,81236,0.0,0.0,1.586259
7,plasmablast,1513,1638,40618,81236,1.267523e-308,9.295167e-308,1.847375
8,mast,1391,1487,40618,81236,1.0474280000000001e-299,5.760856e-299,1.870881
6,Treg,1886,2244,40618,81236,1.8145529999999998e-256,7.984032e-256,1.680927
5,pDC,2135,2702,40618,81236,1.248822e-219,4.5790140000000005e-219,1.580311
10,neutrophil,352,412,40618,81236,6.895189e-52,2.1670589999999998e-51,1.708738
12,T CD8 naive,175,184,40618,81236,1.981493e-41,5.449105e-41,1.902174


Unit 437 MIS 0.825


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
0,tumor 13,20770,26030,40618,81236,0.0,0.0,1.595851
9,endothelial,791,1449,40618,81236,0.000232,0.002547,1.091787
1,macrophage,9147,17989,40618,81236,0.005102,0.037418,1.016955
4,T CD8 memory,1162,3067,40618,81236,1.0,1.0,0.757744
3,T CD4 memory,3172,9344,40618,81236,1.0,1.0,0.678938
11,mDC,106,338,40618,81236,1.0,1.0,0.627219
2,fibroblast,4085,13609,40618,81236,1.0,1.0,0.600338
5,pDC,664,2702,40618,81236,1.0,1.0,0.491488


Unit 467 MIS 0.805


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
2,fibroblast,12058,13609,40618,81236,0.0,0.0,1.772063
3,T CD4 memory,8212,9344,40618,81236,0.0,0.0,1.757705
8,mast,1385,1487,40618,81236,9.444565999999999e-293,6.926015000000001e-292,1.862811
6,Treg,1729,2244,40618,81236,7.236646e-157,3.9801560000000003e-156,1.540998
5,pDC,1833,2702,40618,81236,3.699197e-81,1.627647e-80,1.356773
7,plasmablast,1190,1638,40618,81236,2.149129e-79,7.880141e-79,1.452991
10,neutrophil,332,412,40618,81236,4.924135999999999e-38,1.547586e-37,1.61165
12,T CD8 naive,172,184,40618,81236,8.158654999999999e-38,2.24363e-37,1.869565


Unit 55 MIS 0.78


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
0,tumor 13,25346,26030,40618,81236,0.0,0.0,1.947445
6,Treg,1624,2244,40618,81236,4.178101e-106,4.595912e-105,1.447415
4,T CD8 memory,2003,3067,40618,81236,4.118266e-68,3.020061e-67,1.306162
7,plasmablast,1109,1638,40618,81236,1.279085e-48,7.034968999999999e-48,1.35409
10,neutrophil,333,412,40618,81236,1.1699579999999999e-38,5.147813999999999e-38,1.616505
13,tumor 9,148,170,40618,81236,1.793582e-24,6.5764669999999995e-24,1.741176
14,epithelial,110,132,40618,81236,1.364076e-15,4.287095e-15,1.666667
12,T CD8 naive,130,184,40618,81236,9.875154e-09,2.715667e-08,1.413043


Unit 213 MIS 0.7775


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
0,tumor 13,22213,26030,40618,81236,0.0,0.0,1.706723
1,macrophage,9958,17989,40618,81236,6.117135000000001e-60,6.728849e-59,1.107121
4,T CD8 memory,1869,3067,40618,81236,1.658207e-35,1.216019e-34,1.218781
11,mDC,191,338,40618,81236,0.009478861,0.05213373,1.130178
8,mast,560,1487,40618,81236,1.0,1.0,0.753194
6,Treg,843,2244,40618,81236,1.0,1.0,0.751337
17,B-cell,24,70,40618,81236,0.9972188,1.0,0.685714
3,T CD4 memory,2919,9344,40618,81236,1.0,1.0,0.624786


Unit 196 MIS 0.775


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
0,tumor 13,24391,26030,40618,81236,0.0,0.0,1.874068
4,T CD8 memory,2320,3067,40618,81236,5.1202839999999995e-193,5.632313e-192,1.512879
6,Treg,1467,2244,40618,81236,2.7348919999999995e-50,2.005587e-49,1.307487
12,T CD8 naive,153,184,40618,81236,6.279403e-21,3.4536709999999996e-20,1.663043
15,NK,85,105,40618,81236,4.912775e-11,2.161621e-10,1.619048
13,tumor 9,110,170,40618,81236,7.624572e-05,0.0002795676,1.294118
14,epithelial,87,132,40618,81236,0.0001597967,0.0005022182,1.318182
19,T CD4 naive,41,62,40618,81236,0.007547907,0.02075675,1.322581


Unit 292 MIS 0.7725


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
2,fibroblast,11095,13609,40618,81236,0.0,0.0,1.630539
1,macrophage,14578,17989,40618,81236,0.0,0.0,1.620768
3,T CD4 memory,6539,9344,40618,81236,0.0,0.0,1.399615
5,pDC,1953,2702,40618,81236,5.458871e-127,3.002379e-126,1.445596
8,mast,1002,1487,40618,81236,1.099372e-42,4.837237e-42,1.34768
11,mDC,286,338,40618,81236,9.795671e-41,3.591746e-40,1.692308
16,monocyte,62,93,40618,81236,0.0008512247,0.002675278,1.333333
7,plasmablast,848,1638,40618,81236,0.07738806,0.2128172,1.035409


Unit 346 MIS 0.7725


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
0,tumor 13,22185,26030,40618,81236,0.0,0.0,1.704572
13,tumor 9,72,170,40618,81236,0.981069,1.0,0.847059
1,macrophage,6732,17989,40618,81236,1.0,1.0,0.748457
2,fibroblast,4612,13609,40618,81236,1.0,1.0,0.677787
3,T CD4 memory,3162,9344,40618,81236,1.0,1.0,0.676798
8,mast,478,1487,40618,81236,1.0,1.0,0.642905
4,T CD8 memory,974,3067,40618,81236,1.0,1.0,0.635148
5,pDC,850,2702,40618,81236,1.0,1.0,0.629164


Unit 382 MIS 0.77


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
7,plasmablast,1527,1638,40618,81236,0.0,0.0,1.864469
5,pDC,2347,2702,40618,81236,0.0,0.0,1.737232
3,T CD4 memory,5821,9344,40618,81236,2.386328e-142,1.749974e-141,1.245933
6,Treg,1668,2244,40618,81236,5.597477e-126,3.078612e-125,1.486631
2,fibroblast,8036,13609,40618,81236,3.286175e-119,1.445917e-118,1.180983
1,macrophage,10228,17989,40618,81236,5.744488e-97,2.106312e-96,1.137139
11,mDC,323,338,40618,81236,4.998691e-77,1.571017e-76,1.911243
8,mast,1079,1487,40618,81236,1.444315e-71,3.971865e-71,1.451244


Unit 187 MIS 0.77


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
0,tumor 13,21833,26030,40618,81236,0.0,0.0,1.677526
1,macrophage,10232,17989,40618,81236,1.386802e-97,1.5254830000000001e-96,1.137584
4,T CD8 memory,2033,3067,40618,81236,5.623388e-77,4.123818e-76,1.325725
9,endothelial,762,1449,40618,81236,0.02488772,0.1368824,1.05176
6,Treg,842,2244,40618,81236,1.0,1.0,0.750446
5,pDC,802,2702,40618,81236,1.0,1.0,0.593634
15,NK,31,105,40618,81236,0.9999935,1.0,0.590476
21,tumor 5,6,23,40618,81236,0.9946943,1.0,0.521739


In [3]:
# -----------------------
# Illustrator-friendly style
# -----------------------
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's context/style FIRST: set_context() rewrites font.size,
# axes/tick/legend font sizes and line widths, and set_style() rewrites axes
# styling, so calling them after rcParams.update() would discard the explicit
# sizes below.
sns.set_context("paper")
sns.set_style("white")

# Explicit sizes and Type-42 (TrueType) font embedding so text stays editable
# in PDF/PS exports.
plt.rcParams.update({
    "figure.dpi": 120,
    "savefig.dpi": 300,
    "font.size": 10,
    "axes.labelsize": 10,
    "axes.titlesize": 11,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 9,
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

In [None]:
# ------------ Illustrator-friendly grayscale bar plots ------------
import os
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

# Global matplotlib settings for the grayscale, Illustrator-friendly figures.
mpl.rcParams.update({
    # resolution: on-screen preview vs. saved files; opaque background on save
    "figure.dpi": 120,
    "savefig.dpi": 300,
    "savefig.transparent": False,
    # text sizes (points)
    "font.size": 10,
    "axes.titlesize": 11,
    "axes.labelsize": 10,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 9,
    # Type-42 (TrueType) font embedding for PDF/PS output
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    # thin black axes frame, no grid
    "axes.edgecolor": "black",
    "axes.linewidth": 0.8,
    "axes.grid": False,
    # short outward-pointing ticks
    "xtick.direction": "out",
    "ytick.direction": "out",
    "xtick.major.size": 3,
    "ytick.major.size": 3,
    "xtick.major.width": 0.8,
    "ytick.major.width": 0.8,
})

def plot_unit_enrichment_grayscale(
    df_unit,
    unit,
    out_dir,
    score_col="enrichment_ratio",
    fdr_cutoff=None,        # keep None if you don't want filtering here
    min_count_in_high=0,
    top_n=20,
    width_in=7.2,           # << wider (≈ double-column); was 3.35
    height_per_bar=0.32,    # scale height by #bars
    annotate=False,         # << turn off k_in_high/n_high labels
    save_png=False,         # also write a 300-dpi PNG next to the PDF
):
    """Save a grayscale bar chart of one unit's per-label enrichment scores as PDF.

    Parameters
    ----------
    df_unit : pandas.DataFrame or None
        Per-label table with at least "label" and ``score_col``. Optional
        columns "k_in_high", "p_fdr" and "n_high" enable the matching filters
        and annotations (e.g. the output of hypergeo_enrichment_for_unit).
    unit : int or str
        Unit id, used in the title and the output file names.
    out_dir : str or Path
        Output directory; created if it does not exist.
    score_col : str
        Column used as bar height; rows with non-finite values are dropped.
    fdr_cutoff : float or None
        If set and "p_fdr" exists, keep rows with p_fdr <= fdr_cutoff.
    min_count_in_high : int
        If > 0 and "k_in_high" exists, keep rows with k_in_high >= this value.
    top_n : int
        Maximum number of bars, highest scores first.
    width_in : float
        Figure width in inches.
    height_per_bar : float
        Figure height per bar in inches (total height is at least 2.4 in).
    annotate : bool
        Write "k_in_high/n_high" above each bar when those columns exist.
    save_png : bool
        Also save a PNG; the default False keeps the original PDF-only output.

    Returns
    -------
    str or None
        Path of the saved PDF, or None if there is nothing to plot.

    Raises
    ------
    ValueError
        If ``score_col`` is not a column of ``df_unit``.
    """
    import numpy as np
    from pathlib import Path

    if df_unit is None or len(df_unit) == 0:
        return None

    d = df_unit.copy()
    if min_count_in_high > 0 and "k_in_high" in d.columns:
        d = d[d["k_in_high"] >= min_count_in_high]
    if (fdr_cutoff is not None) and ("p_fdr" in d.columns):
        d = d[d["p_fdr"] <= fdr_cutoff]

    if score_col not in d.columns:
        raise ValueError(f"'{score_col}' not found in the DataFrame.")
    d = d.replace([np.inf, -np.inf], np.nan).dropna(subset=[score_col])
    d["label"] = d["label"].astype(str)
    d = d.sort_values(score_col, ascending=False).head(top_n)
    if len(d) == 0:
        return None

    # Imported only once there is something to draw.
    import matplotlib.pyplot as plt

    height_in = max(2.4, height_per_bar * len(d))
    fig, ax = plt.subplots(figsize=(width_in, height_in))
    try:
        ax.bar(d["label"], d[score_col], color="0.2", edgecolor="0.1", linewidth=0.6)
        ax.set_xlabel("cell type")
        ax.set_ylabel(score_col.replace("_", " "))
        ax.set_title(f"Unit {unit} enrichment")

        # Rotate ticks and align right for readability
        ax.tick_params(axis="x", labelrotation=60)
        plt.setp(ax.get_xticklabels(), ha="right")

        # k/n annotations are off by default (annotate=False)
        if annotate and {"k_in_high", "n_high"}.issubset(d.columns):
            for x, y, k, n in zip(d["label"], d[score_col], d["k_in_high"], d["n_high"]):
                if np.isfinite(y):
                    ax.text(x, y, f"{int(k)}/{int(n)}", va="bottom", ha="center", fontsize=8)

        fig.tight_layout()
        # Give a little extra bottom margin for long labels
        fig.subplots_adjust(bottom=0.25)

        out_dir = Path(out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        stem = f"unit_{unit}_enrichment_{score_col}"
        pdf = out_dir / f"{stem}.pdf"
        if save_png:
            fig.savefig(out_dir / f"{stem}.png", dpi=300, bbox_inches="tight")
        fig.savefig(pdf, bbox_inches="tight", metadata={"Title": f"Unit {unit} enrichment"})
    finally:
        # Always release the figure, even if drawing/saving fails, so loops over
        # many units do not accumulate open pyplot figures.
        plt.close(fig)
    return str(pdf)

# ---------------- Save one enrichment figure per top-MIS unit ----------------
# Needs from earlier cells: hypergeo_enrichment_for_unit, A (cells x units),
# labels (one per row of A) and summary["top_units"]. MIS is not used here.

OUT_DIR = "/maiziezhou_lab2/yunfei/Projects/interpTFM/evaluation/TIS/finalized/figures_enrichment"
TOP_K = 15  # or 50

top_units = [r["unit"] for r in summary["top_units"][:TOP_K]]
for u in top_units:
    df_u = hypergeo_enrichment_for_unit(u, A, labels, min_in_high=50)
    path = plot_unit_enrichment_grayscale(
        df_u, u, OUT_DIR,
        score_col="enrichment_ratio",
        fdr_cutoff=None,        # no FDR filtering in the figure
        min_count_in_high=0,    # e.g. 5 to require minimum support
        top_n=20,               # max number of cell types per panel
        width_in=12,            # same width for every panel
    )
    print(f"[skip] unit {u}: nothing to plot after filters." if path is None else f"[ok] saved: {path}")

In [None]:
# Build a unit x rank table of top enriched cell types.
# hypergeo_enrichment_for_unit returns EVERY label (sorted by p_fdr, then ER),
# including depleted ones (p_fdr = 1, ER < 1); filter those out so the table
# only lists labels that are actually enriched.
TOP_K = 15             # max number of ranked cell types per unit
N_TABLE_UNITS = 20     # number of MIS-top units to tabulate (the plotting cell used its own TOP_K)
RANK_FDR_MAX = 0.05    # keep labels with p_fdr <= this; None disables the FDR filter
RANK_MIN_ER = 1.0      # keep labels with enrichment_ratio > this; None disables the ER filter

units = [r["unit"] for r in summary["top_units"][:N_TABLE_UNITS]]

rows_labels = {}
rows_formatted = {}

for u in units:
    df = hypergeo_enrichment_for_unit(u, A, labels)  # already sorted by p_fdr then enrichment_ratio
    if df is None or df.empty:
        continue
    if RANK_FDR_MAX is not None:
        df = df[df["p_fdr"] <= RANK_FDR_MAX]
    if RANK_MIN_ER is not None:
        df = df[df["enrichment_ratio"] > RANK_MIN_ER]
    # a unit with no enriched label keeps an (all-blank after padding) row
    top = df.head(TOP_K).reset_index(drop=True)

    # just the labels
    rows_labels[u] = top["label"].tolist()

    # label + scores (enrichment ratio and FDR)
    rows_formatted[u] = [
        f"{row.label} (ER={row.enrichment_ratio:.2f}, FDR={row.p_fdr:.1e})"
        for row in top.itertuples(index=False)
    ]

# pad to the same number of ranks
def pad_rows(d):
    """Right-pad every list value of ``d`` with "" so all lists share the longest length.

    Short lists are replaced by new padded lists (the original list objects are
    left untouched); the same dict object is returned. A falsy ``d`` is returned
    as-is.
    """
    if not d:
        return d
    width = max(map(len, d.values()))
    for key in list(d):
        missing = width - len(d[key])
        if missing > 0:
            d[key] = d[key] + [""] * missing
    return d

rows_labels = pad_rows(rows_labels)
rows_formatted = pad_rows(rows_formatted)

# all rows now share one length; name the columns rank1..rankK from it
rank_cols = [f"rank{i}" for i in range(1, len(next(iter(rows_labels.values()), [])) + 1)]

# (A) cell-type names only, (B) cell-type + (ER, FDR); both indexed by unit
enrich_table, enrich_table_with_scores = (
    pd.DataFrame.from_dict(src, orient="index", columns=rank_cols)
    .sort_index()
    .rename_axis("unit")
    for src in (rows_labels, rows_formatted)
)

display(enrich_table, enrich_table_with_scores)

In [None]:
# ============================================
# Justify 3 archetypes from MIS enrichment
# Tumor / TME (immune+stromal) / Interface
# ============================================
import numpy as np, pandas as pd
from math import log10
from scipy.stats import ks_2samp
from sklearn.metrics import r2_score
from scipy.stats import pearsonr
from typing import Dict 

# ---- Assumed available in your session ----
# - hypergeo_enrichment_for_unit(u, A, labels, min_in_high=50)
# - MIS (np.ndarray length = #units in A)
# - summary (dict from summary_expr_*.json) with "top_units" list[{unit,MIS},...]
# - A (cells x units), labels (array[str] for each cell)
# - adata (optional, only if you want patient/slide-aware permutations/splits)

# ============================================
# 0) CONFIG
# ============================================
TOP_UNITS = 50           # number of highest-MIS units to analyse
ALPHA_FDR = 0.05         # BH-FDR threshold for a label to count as significant
USE_ER_POWER = 1.0       # exponent p in ER**p inside the score (1.0 = linear)
EPS_FDR = 1e-300         # lower bound on FDR so -log10(FDR) stays finite
MIN_HIGH = 200           # minimum high-activation cells required to run enrichment
GROUP_COL = "patient"    # for within-group permutation/split; set None if unavailable


def is_tumor(lbl: str) -> bool:
    """True for tumour labels: lower-cased text starting with "tumor" or containing "malig"/"carcin"."""
    text = str(lbl).lower()
    if text.startswith("tumor"):
        return True
    return any(token in text for token in ("malig", "carcin"))


def is_tme(lbl: str) -> bool:
    """True for every non-tumour label (immune + stromal microenvironment)."""
    return not is_tumor(lbl)

# ============================================
# 1) Collect per-unit enrichment tables
# ============================================
units_sorted = [rec["unit"] for rec in summary["top_units"][:TOP_UNITS]]

enrich_by_unit: Dict[int, pd.DataFrame] = {}
for u in units_sorted:
    df = hypergeo_enrichment_for_unit(u, A, labels, min_in_high=MIN_HIGH)
    if df is not None and not df.empty:
        # trim to the columns used below and force float dtypes for the scoring math
        keep = ["label", "k_in_high", "K_in_pop", "n_high", "N_total", "p", "p_fdr", "enrichment_ratio"]
        df = df[keep].astype({"p_fdr": float, "enrichment_ratio": float})
        # floor FDR so -log10(FDR) stays finite when a p-value underflowed to 0
        df["p_fdr"] = df["p_fdr"].clip(lower=EPS_FDR)
        enrich_by_unit[u] = df

if len(enrich_by_unit) == 0:
    raise RuntimeError("No enrichment tables computed; relax MIN_HIGH or TOP_UNITS?")

# ============================================
# 2) Build archetype scores per unit
#   Tumor score T_u = sum_{tumor labels with FDR<=α} [-log10(FDR) * ER^p]
#   TME   score E_u = sum_{tme   labels with FDR<=α} [-log10(FDR) * ER^p]
#   Interface I_u   = min(T_u, E_u)  (only if both sides have ≥1 sig label; else 0)
# ============================================
def score(part):
    """Archetype score of a set of significant labels.

    Returns sum over rows of -log10(p_fdr) * ER**USE_ER_POWER, or 0.0 when
    ``part`` is empty. ER is clamped at 0 before the power (defensive: ER is a
    ratio of counts and should already be >= 0).
    """
    if part.empty:
        return 0.0
    er = np.clip(part["enrichment_ratio"].to_numpy(float), 0, None) ** USE_ER_POWER
    return float(np.sum((-np.log10(part["p_fdr"].to_numpy(float))) * er))


rows = []
for u, df in enrich_by_unit.items():
    df = df.copy()
    # Classify with the config-cell rules on BOTH sides, so an edit to is_tme
    # (e.g. excluding ambiguous labels) is reflected in the E score.
    df["is_tumor"] = df["label"].map(is_tumor)
    df["is_tme"]   = df["label"].map(is_tme)

    df_sig = df[df["p_fdr"] <= ALPHA_FDR]
    # components (allow empty -> 0)
    T_comp = df_sig[df_sig["is_tumor"]]
    E_comp = df_sig[df_sig["is_tme"]]

    T_u = score(T_comp)
    E_u = score(E_comp)

    # interface only when both sides show significant labels
    if (not T_comp.empty) and (not E_comp.empty):
        I_u = min(T_u, E_u)  # conservative; alternatives: np.sqrt(T_u*E_u)
    else:
        I_u = 0.0

    rows.append({
        "unit": u,
        "MIS": float(MIS[u]) if (0 <= u < len(MIS)) else np.nan,
        "T_score": T_u,
        "E_score": E_u,
        "I_score": I_u,
        "n_sig_tumor": int(len(T_comp)),
        "n_sig_tme":   int(len(E_comp)),
    })

unit_scores = pd.DataFrame(rows).set_index("unit").sort_values(["I_score","T_score","E_score"], ascending=False)
display(unit_scores.head(10))

# ============================================
# 3) Call archetype per unit (crisp, conservative)
# ============================================
# Each threshold is the upper quartile (0.75 quantile) of that score across the
# analysed units, so calls adapt to the score distribution instead of hand-tuned cut-offs.
T_thr, E_thr, I_thr = (
    unit_scores[col].quantile(0.75) for col in ("T_score", "E_score", "I_score")
)

def call_archetype(row, t_thr=None, e_thr=None, i_thr=None):
    """Assign one archetype label to a unit from its archetype scores.

    Parameters
    ----------
    row : pandas.Series or mapping
        Must provide "T_score", "E_score", "I_score", "n_sig_tumor", "n_sig_tme".
    t_thr, e_thr, i_thr : float, optional
        Thresholds for the tumour, TME and interface scores. When None, the
        module-level T_thr / E_thr / I_thr are used, so existing
        ``unit_scores.apply(call_archetype, axis=1)`` calls behave as before.

    Returns
    -------
    str
        "Tumor" if only the tumour side is significant and high; "TME" if only
        the TME side is; otherwise "Interface" if both sides have >= 1
        significant label and the interface score is >= i_thr; else "Unclear".
    """
    if t_thr is None:
        t_thr = T_thr
    if e_thr is None:
        e_thr = E_thr
    if i_thr is None:
        i_thr = I_thr

    T, E, I = row["T_score"], row["E_score"], row["I_score"]
    has_T = row["n_sig_tumor"] > 0
    has_E = row["n_sig_tme"]   > 0
    tumor_high = has_T and (T >= t_thr)
    tme_high = has_E and (E >= e_thr)

    if tumor_high and not tme_high:
        return "Tumor"
    if tme_high and not tumor_high:
        return "TME"
    if has_T and has_E and (I >= i_thr):
        return "Interface"
    return "Unclear"

# attach one archetype call per unit, then show the top 20 units
unit_scores = unit_scores.assign(call=unit_scores.apply(call_archetype, axis=1))
display(unit_scores.loc[:, ["MIS", "T_score", "E_score", "I_score", "call"]].head(20))