In [2]:
# ==========================
# Align from cache -> A, labels, groups
# ==========================
import os, json, numpy as np
import pandas as pd
import scanpy as sc

# ---- paths / config
CACHE_DIR = "/maiziezhou_lab2/yunfei/Projects/interpTFM/activations_cosmx_lung_cancer/cached/layer4"
A_CACHE   = os.path.join(CACHE_DIR, "A_cls.npy")     # scGPT <CLS> activations
J_CACHE   = os.path.join(CACHE_DIR, "J_embed.npy")
IDS_CACHE = os.path.join(CACHE_DIR, "cell_order.txt")
H5AD_PATH = "/maiziezhou_lab2/yunfei/Projects/FM_temp/InterPLM/interplm/ge_shards/cosmx_human_lung_sec8.h5ad"
GROUP_COL = "patient"   # set to None if you don't want within-patient shuffles

# ---- load cache + AnnData
A_cls = np.load(A_CACHE, mmap_mode="r")  # shape [N_all, U]
with open(IDS_CACHE, "r") as f:
    cell_order_cached = [ln.strip() for ln in f]
# Trailing blank lines are not cell IDs; drop them before validating the length.
while cell_order_cached and not cell_order_cached[-1]:
    cell_order_cached.pop()

# Row i of A_cls is described by line i of the ID file; a length mismatch would
# silently pair activations with the wrong cells (or raise IndexError later).
if len(cell_order_cached) != A_cls.shape[0]:
    raise ValueError(
        f"{IDS_CACHE} lists {len(cell_order_cached)} cell IDs but {A_CACHE} "
        f"has {A_cls.shape[0]} rows"
    )

adata = sc.read_h5ad(H5AD_PATH)

# ---- build intersection in A's order
ids_A = pd.Index(cell_order_cached)
ids_J = pd.Index(adata.obs_names.astype(str))
idx_J_for_A = ids_J.get_indexer(ids_A)  # -1 if not present in adata
keep_mask   = idx_J_for_A >= 0

idx_A_keep = np.nonzero(keep_mask)[0]    # rows in A order to keep
idx_J_keep = idx_J_for_A[keep_mask]      # corresponding rows in adata
if idx_A_keep.size == 0:
    raise ValueError("No cached cell IDs were found in adata.obs_names; check IDS_CACHE / H5AD_PATH")

J_embed = np.load(J_CACHE, mmap_mode="r")   # shape [N, F]

# ---- slice and name the arrays EXACTLY as used downstream
A_for_mis       = np.asarray(A_cls[idx_A_keep], dtype=np.float32)      # [M, U]
labels_aligned  = adata.obs["author_cell_type"].to_numpy()[idx_J_keep]        # [M]

# expose the names expected by the justification script:
A      = A_for_mis
labels = labels_aligned

# optional: groups for within-group permutations/splits (patient/slide/etc.)
if GROUP_COL and (GROUP_COL in adata.obs.columns):
    groups = adata.obs[GROUP_COL].to_numpy()[idx_J_keep]
else:
    if GROUP_COL:
        # A configured-but-missing column would otherwise disable grouping silently.
        print(f"[warn] GROUP_COL={GROUP_COL!r} not found in adata.obs; groups set to None")
    groups = None

# ---- quick sanity
print(f"[sanity] A shape: {A.shape}  | labels: {labels.shape}  | groups: {None if groups is None else groups.shape}")

[sanity] A shape: (81236, 512)  | labels: (81236,)  | groups: None


In [3]:
from scipy.stats import binomtest, hypergeom
from statsmodels.stats.multitest import multipletests

def hypergeo_enrichment_for_unit(u, A, labels, min_in_high=50):
    """Test which labels are over-represented among the high-activation cells of unit ``u``.

    A cell counts as "high" when its activation ``A[:, u]`` is at or above the
    unit's median. For every distinct label a one-sided hypergeometric test
    gives P(X >= k), where k is the number of high cells carrying that label.
    P-values are adjusted with Benjamini-Hochberg across labels.

    Parameters
    ----------
    u : int
        Column index of the unit in ``A``.
    A : array-like, shape [cells, units]
        Activation matrix.
    labels : array-like of length ``cells``, or None
        Per-cell label. Lists, NumPy arrays and pandas objects are accepted;
        they are converted positionally with ``np.asarray``.
    min_in_high : int
        Minimum number of high-activation cells required to run the test.

    Returns
    -------
    pandas.DataFrame or None
        One row per label with columns ``label, k_in_high, K_in_pop, n_high,
        N_total, p, p_fdr, enrichment_ratio``, sorted by ascending ``p_fdr``
        then descending ``enrichment_ratio``. ``None`` when ``labels`` is None,
        when there are fewer than ``min_in_high`` high cells, or when no
        non-missing label exists.

    Raises
    ------
    ValueError
        If ``labels`` and ``A`` disagree on the number of cells.
    """
    if labels is None:
        return None

    # Positional NumPy view of the labels: avoids label-based lookup on a
    # pandas Series and makes plain lists indexable with a boolean mask.
    labels = np.asarray(labels)
    a = np.asarray(A[:, u])
    if labels.shape[0] != a.shape[0]:
        raise ValueError(
            f"labels has {labels.shape[0]} entries but A has {a.shape[0]} rows"
        )

    high = a >= np.median(a)
    n = int(high.sum())
    if n < min_in_high:
        return None

    # Label counts in the whole population and in the high-activation subset,
    # kept in the same (value_counts) order.
    pop_counts = pd.Series(labels).value_counts()
    if pop_counts.empty:
        return None
    high_counts = (
        pd.Series(labels[high]).value_counts().reindex(pop_counts.index, fill_value=0)
    )

    N = len(labels)
    K = pop_counts.to_numpy()
    k = high_counts.to_numpy()

    # overlap k or more among n draws from population N with K "successes"
    p = hypergeom.sf(k - 1, N, K, n)

    df = pd.DataFrame({
        "label": pop_counts.index.to_numpy(),
        "k_in_high": k,
        "K_in_pop": K,
        "n_high": n,
        "N_total": N,
        "p": p,
    })
    df["p_fdr"] = multipletests(df["p"], method="fdr_bh")[1]
    df["enrichment_ratio"] = (df["k_in_high"]/df["n_high"]) / (df["K_in_pop"]/df["N_total"])
    return df.sort_values(["p_fdr","enrichment_ratio"], ascending=[True,False])

# --- Use the ALIGNED activations so that row i of A and labels[i] are the same cell ---
# (A_cls is the full cache in cache order; A_for_mis is the subset matched to adata.)
A = np.asarray(A_for_mis)
labels = np.asarray(labels_aligned)
if A.shape[0] != labels.shape[0]:
    raise ValueError(f"A has {A.shape[0]} rows but labels has {labels.shape[0]} entries")

# --- Base directory for the new run ---
out_dir = "./out"

# --- Load MIS results ---
MIS = np.load(f"{out_dir}/MIS_expr_zcos.npy")

with open(f"{out_dir}/summary_expr_zcos.json") as f:
    summary = json.load(f)

# Example: top 10 units
for u in [r["unit"] for r in summary["top_units"][:10]]:
    print("Unit", u, "MIS", MIS[u])
    enrich_u = hypergeo_enrichment_for_unit(u, A, labels)
    if enrich_u is None:
        # The helper returns None when it declines to test a unit.
        print("  (no enrichment table for this unit)")
        continue
    display(enrich_u.head(8))


Unit 246 MIS 0.845


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
2,fibroblast,11921,13609,40618,81236,0.0,0.0,1.751929
3,T CD4 memory,7411,9344,40618,81236,0.0,0.0,1.586259
7,plasmablast,1513,1638,40618,81236,1.267523e-308,9.295167e-308,1.847375
8,mast,1391,1487,40618,81236,1.0474280000000001e-299,5.760856e-299,1.870881
6,Treg,1886,2244,40618,81236,1.8145529999999998e-256,7.984032e-256,1.680927
5,pDC,2135,2702,40618,81236,1.248822e-219,4.5790140000000005e-219,1.580311
10,neutrophil,352,412,40618,81236,6.895189e-52,2.1670589999999998e-51,1.708738
12,T CD8 naive,175,184,40618,81236,1.981493e-41,5.449105e-41,1.902174


Unit 437 MIS 0.825


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
0,tumor 13,20770,26030,40618,81236,0.0,0.0,1.595851
9,endothelial,791,1449,40618,81236,0.000232,0.002547,1.091787
1,macrophage,9147,17989,40618,81236,0.005102,0.037418,1.016955
4,T CD8 memory,1162,3067,40618,81236,1.0,1.0,0.757744
3,T CD4 memory,3172,9344,40618,81236,1.0,1.0,0.678938
11,mDC,106,338,40618,81236,1.0,1.0,0.627219
2,fibroblast,4085,13609,40618,81236,1.0,1.0,0.600338
5,pDC,664,2702,40618,81236,1.0,1.0,0.491488


Unit 467 MIS 0.805


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
2,fibroblast,12058,13609,40618,81236,0.0,0.0,1.772063
3,T CD4 memory,8212,9344,40618,81236,0.0,0.0,1.757705
8,mast,1385,1487,40618,81236,9.444565999999999e-293,6.926015000000001e-292,1.862811
6,Treg,1729,2244,40618,81236,7.236646e-157,3.9801560000000003e-156,1.540998
5,pDC,1833,2702,40618,81236,3.699197e-81,1.627647e-80,1.356773
7,plasmablast,1190,1638,40618,81236,2.149129e-79,7.880141e-79,1.452991
10,neutrophil,332,412,40618,81236,4.924135999999999e-38,1.547586e-37,1.61165
12,T CD8 naive,172,184,40618,81236,8.158654999999999e-38,2.24363e-37,1.869565


Unit 55 MIS 0.78


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
0,tumor 13,25346,26030,40618,81236,0.0,0.0,1.947445
6,Treg,1624,2244,40618,81236,4.178101e-106,4.595912e-105,1.447415
4,T CD8 memory,2003,3067,40618,81236,4.118266e-68,3.020061e-67,1.306162
7,plasmablast,1109,1638,40618,81236,1.279085e-48,7.034968999999999e-48,1.35409
10,neutrophil,333,412,40618,81236,1.1699579999999999e-38,5.147813999999999e-38,1.616505
13,tumor 9,148,170,40618,81236,1.793582e-24,6.5764669999999995e-24,1.741176
14,epithelial,110,132,40618,81236,1.364076e-15,4.287095e-15,1.666667
12,T CD8 naive,130,184,40618,81236,9.875154e-09,2.715667e-08,1.413043


Unit 213 MIS 0.7775


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
0,tumor 13,22213,26030,40618,81236,0.0,0.0,1.706723
1,macrophage,9958,17989,40618,81236,6.117135000000001e-60,6.728849e-59,1.107121
4,T CD8 memory,1869,3067,40618,81236,1.658207e-35,1.216019e-34,1.218781
11,mDC,191,338,40618,81236,0.009478861,0.05213373,1.130178
8,mast,560,1487,40618,81236,1.0,1.0,0.753194
6,Treg,843,2244,40618,81236,1.0,1.0,0.751337
17,B-cell,24,70,40618,81236,0.9972188,1.0,0.685714
3,T CD4 memory,2919,9344,40618,81236,1.0,1.0,0.624786


Unit 196 MIS 0.775


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
0,tumor 13,24391,26030,40618,81236,0.0,0.0,1.874068
4,T CD8 memory,2320,3067,40618,81236,5.1202839999999995e-193,5.632313e-192,1.512879
6,Treg,1467,2244,40618,81236,2.7348919999999995e-50,2.005587e-49,1.307487
12,T CD8 naive,153,184,40618,81236,6.279403e-21,3.4536709999999996e-20,1.663043
15,NK,85,105,40618,81236,4.912775e-11,2.161621e-10,1.619048
13,tumor 9,110,170,40618,81236,7.624572e-05,0.0002795676,1.294118
14,epithelial,87,132,40618,81236,0.0001597967,0.0005022182,1.318182
19,T CD4 naive,41,62,40618,81236,0.007547907,0.02075675,1.322581


Unit 292 MIS 0.7725


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
2,fibroblast,11095,13609,40618,81236,0.0,0.0,1.630539
1,macrophage,14578,17989,40618,81236,0.0,0.0,1.620768
3,T CD4 memory,6539,9344,40618,81236,0.0,0.0,1.399615
5,pDC,1953,2702,40618,81236,5.458871e-127,3.002379e-126,1.445596
8,mast,1002,1487,40618,81236,1.099372e-42,4.837237e-42,1.34768
11,mDC,286,338,40618,81236,9.795671e-41,3.591746e-40,1.692308
16,monocyte,62,93,40618,81236,0.0008512247,0.002675278,1.333333
7,plasmablast,848,1638,40618,81236,0.07738806,0.2128172,1.035409


Unit 346 MIS 0.7725


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
0,tumor 13,22185,26030,40618,81236,0.0,0.0,1.704572
13,tumor 9,72,170,40618,81236,0.981069,1.0,0.847059
1,macrophage,6732,17989,40618,81236,1.0,1.0,0.748457
2,fibroblast,4612,13609,40618,81236,1.0,1.0,0.677787
3,T CD4 memory,3162,9344,40618,81236,1.0,1.0,0.676798
8,mast,478,1487,40618,81236,1.0,1.0,0.642905
4,T CD8 memory,974,3067,40618,81236,1.0,1.0,0.635148
5,pDC,850,2702,40618,81236,1.0,1.0,0.629164


Unit 382 MIS 0.77


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
7,plasmablast,1527,1638,40618,81236,0.0,0.0,1.864469
5,pDC,2347,2702,40618,81236,0.0,0.0,1.737232
3,T CD4 memory,5821,9344,40618,81236,2.386328e-142,1.749974e-141,1.245933
6,Treg,1668,2244,40618,81236,5.597477e-126,3.078612e-125,1.486631
2,fibroblast,8036,13609,40618,81236,3.286175e-119,1.445917e-118,1.180983
1,macrophage,10228,17989,40618,81236,5.744488e-97,2.106312e-96,1.137139
11,mDC,323,338,40618,81236,4.998691e-77,1.571017e-76,1.911243
8,mast,1079,1487,40618,81236,1.444315e-71,3.971865e-71,1.451244


Unit 187 MIS 0.77


Unnamed: 0,label,k_in_high,K_in_pop,n_high,N_total,p,p_fdr,enrichment_ratio
0,tumor 13,21833,26030,40618,81236,0.0,0.0,1.677526
1,macrophage,10232,17989,40618,81236,1.386802e-97,1.5254830000000001e-96,1.137584
4,T CD8 memory,2033,3067,40618,81236,5.623388e-77,4.123818e-76,1.325725
9,endothelial,762,1449,40618,81236,0.02488772,0.1368824,1.05176
6,Treg,842,2244,40618,81236,1.0,1.0,0.750446
5,pDC,802,2702,40618,81236,1.0,1.0,0.593634
15,NK,31,105,40618,81236,0.9999935,1.0,0.590476
21,tumor 5,6,23,40618,81236,0.9946943,1.0,0.521739


In [6]:
# ============================================
# Justify 3 archetypes from MIS enrichment
# Tumor / TME (immune+stromal) / Interface
# ============================================
import numpy as np, pandas as pd
from math import log10
from scipy.stats import ks_2samp
from sklearn.metrics import r2_score
from scipy.stats import pearsonr
from typing import Dict 

# ---- Assumed available in your session ----
# - hypergeo_enrichment_for_unit(u, A, labels, min_in_high=50)
# - MIS (np.ndarray length = #units in A)
# - summary (dict from summary_expr_*.json) with "top_units" list[{unit,MIS},...]
# - A (cells x units), labels (array[str] for each cell)
# - adata (optional, only if you want patient/slide-aware permutations/splits)

# ============================================
# 0) CONFIG
# ============================================
TOP_UNITS = 50           # number of leading entries of summary["top_units"] to analyse
ALPHA_FDR = 0.05         # a label counts as significant when its p_fdr <= this value
USE_ER_POWER = 1.0       # exponent p applied to the enrichment ratio inside the score (1.0 = linear)
EPS_FDR = 1e-300         # floor applied to p_fdr so -log10(p_fdr) stays finite (at most 300)
MIN_HIGH = 200           # passed as min_in_high to hypergeo_enrichment_for_unit (guard)
GROUP_COL = "patient"    # for within-group permutation/split; set None if unavailable. NOTE(review): not referenced in the visible part of this cell

# label binning rules (coarse-grain masks)
def is_tumor(lbl: str) -> bool:
    """Return True when a label looks like a tumor/malignant population.

    The label is stringified and lower-cased; it matches when it begins with
    "tumor" or contains "malig" or "carcin" anywhere.
    """
    text = str(lbl).lower()
    if text.startswith("tumor"):
        return True
    return any(fragment in text for fragment in ("malig", "carcin"))

def is_tme(lbl: str) -> bool:
    """Return True for every label that `is_tumor` rejects.

    All non-tumor labels are treated as tumor microenvironment (immune + stromal).
    """
    tumor_like = is_tumor(lbl)
    return not tumor_like

# ============================================
# 1) Collect per-unit enrichment tables
# ============================================
units_sorted = [rec["unit"] for rec in summary["top_units"][:TOP_UNITS]]

# columns carried forward from each enrichment table
keep = ["label","k_in_high","K_in_pop","n_high","N_total","p","p_fdr","enrichment_ratio"]

enrich_by_unit: Dict[int, pd.DataFrame] = {}
for u in units_sorted:
    df = hypergeo_enrichment_for_unit(u, A, labels, min_in_high=MIN_HIGH)
    if df is not None and not df.empty:
        # new frame with float dtypes; p_fdr is floored so -log10 stays finite later
        enrich_by_unit[u] = df[keep].assign(
            p_fdr=lambda t: t["p_fdr"].astype(float).clip(lower=EPS_FDR),
            enrichment_ratio=lambda t: t["enrichment_ratio"].astype(float),
        )

if len(enrich_by_unit) == 0:
    raise RuntimeError("No enrichment tables computed; relax MIN_HIGH or TOP_UNITS?")

# ============================================
# 2) Build archetype scores per unit
#   Tumor score T_u = sum_{tumor labels with FDR<=α} [-log10(FDR) * ER^p]
#   TME   score E_u = sum_{tme   labels with FDR<=α} [-log10(FDR) * ER^p]
#   Interface I_u   = min(T_u, E_u)  (only if both sides have ≥1 sig label; else 0)
# ============================================
def score(part):
    """Sum of -log10(p_fdr) * ER**USE_ER_POWER over the rows of `part` (0.0 if empty)."""
    if part.empty:
        return 0.0
    # weight by ER^p; clamp ER>=0 to be safe
    er = np.clip(part["enrichment_ratio"].to_numpy(float), 0, None) ** USE_ER_POWER
    neg_log_fdr = -np.log10(part["p_fdr"].to_numpy(float))
    return float(np.sum(neg_log_fdr * er))


rows = []
for u, df in enrich_by_unit.items():
    tumor_mask = df["label"].map(is_tumor)
    sig_mask = df["p_fdr"] <= ALPHA_FDR

    # significant labels split by side (either may be empty)
    T_comp = df[sig_mask & tumor_mask]
    E_comp = df[sig_mask & ~tumor_mask]

    T_u = score(T_comp)
    E_u = score(E_comp)

    # interface only when both sides show significant labels
    # (conservative min; alternative: np.sqrt(T_u*E_u))
    both_sides = (len(T_comp) > 0) and (len(E_comp) > 0)
    I_u = min(T_u, E_u) if both_sides else 0.0

    mis_value = float(MIS[u]) if (0 <= u < len(MIS)) else np.nan
    rows.append({
        "unit": u,
        "MIS": mis_value,
        "T_score": T_u,
        "E_score": E_u,
        "I_score": I_u,
        "n_sig_tumor": int(len(T_comp)),
        "n_sig_tme":   int(len(E_comp)),
    })

unit_scores = (
    pd.DataFrame(rows)
    .set_index("unit")
    .sort_values(["I_score","T_score","E_score"], ascending=False)
)
display(unit_scores.head(10))

# ============================================
# 3) Call archetype per unit (crisp, conservative)
# ============================================
# Define thresholds relative to distribution to avoid hand-tuning:
# 75th percentile of each score across the scored units
T_thr, E_thr, I_thr = (
    unit_scores[score_col].quantile(0.75)
    for score_col in ("T_score", "E_score", "I_score")
)

def call_archetype(row):
    """Assign one of "Tumor", "TME", "Interface" or "Unclear" to a unit.

    `row` provides T_score, E_score, I_score, n_sig_tumor and n_sig_tme.
    A side is "strong" when it has at least one significant label and its score
    reaches the module-level threshold (T_thr / E_thr). Exactly one strong side
    yields that side's name; otherwise "Interface" requires significant labels
    on both sides and I_score >= I_thr.
    """
    tumor_sig = row["n_sig_tumor"] > 0
    tme_sig = row["n_sig_tme"] > 0
    tumor_strong = tumor_sig and (row["T_score"] >= T_thr)
    tme_strong = tme_sig and (row["E_score"] >= E_thr)

    if tumor_strong and not tme_strong:
        return "Tumor"
    if tme_strong and not tumor_strong:
        return "TME"
    if tumor_sig and tme_sig and (row["I_score"] >= I_thr):
        return "Interface"
    return "Unclear"

# one archetype call per unit, stored in row order
unit_scores["call"] = [
    call_archetype(unit_row) for unit_id, unit_row in unit_scores.iterrows()
]
summary_cols = ["MIS","T_score","E_score","I_score","call"]
display(unit_scores[summary_cols].head(20))

Unnamed: 0_level_0,MIS,T_score,E_score,I_score,n_sig_tumor,n_sig_tme
unit,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
285,0.7325,560.031625,564.973682,560.031625,4,5
348,0.74,441.044948,546.594147,441.044948,1,5
55,0.78,642.177196,418.220004,418.220004,5,9
176,0.73,460.24587,414.343449,414.343449,1,1
196,0.775,566.819179,407.597398,407.597398,2,6
149,0.7525,308.64969,502.925608,308.64969,1,2
187,0.77,503.257779,208.938848,208.938848,1,2
478,0.745,413.684211,194.808768,194.808768,1,2
499,0.745,536.49795,154.894758,154.894758,4,6
213,0.7775,512.016904,105.738525,105.738525,1,2


Unnamed: 0_level_0,MIS,T_score,E_score,I_score,call
unit,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
285,0.7325,560.031625,564.973682,560.031625,Tumor
348,0.74,441.044948,546.594147,441.044948,Interface
55,0.78,642.177196,418.220004,418.220004,Tumor
176,0.73,460.24587,414.343449,414.343449,Interface
196,0.775,566.819179,407.597398,407.597398,Tumor
149,0.7525,308.64969,502.925608,308.64969,Interface
187,0.77,503.257779,208.938848,208.938848,Interface
478,0.745,413.684211,194.808768,194.808768,Interface
499,0.745,536.49795,154.894758,154.894758,Tumor
213,0.7775,512.016904,105.738525,105.738525,Interface


In [7]:
import numpy as np
import pandas as pd
import scanpy as sc

# ------------------------
# Inputs you already have:
# ------------------------
# - unit_scores: DataFrame indexed by unit with column "call" in {"Tumor","TME","Interface",...}
# - A_for_mis:  np.ndarray [N_cells_used, U]  (scGPT <CLS> activations for MIS units)
# - Optionally: ids_common: list/array of cell IDs for rows of A_for_mis (same length as A_for_mis.shape[0])

# --- Load target AnnData with interpretable features ---
# Target AnnData: receives the compartment calls, probabilities and scores
# written into its obs/obsm later in this cell.
adata_interp = sc.read_h5ad(
    "/maiziezhou_lab2/yunfei/Projects/interpTFM/evaluation/ccc/data/adata_interpretable_concepts_sec8.h5ad"
)

# (Optional) source AnnData to pull spatial coords from, if adata_interp lacks them
# NOTE(review): this is the same file path as H5AD_PATH, which an earlier cell
# already reads into `adata` — confirm whether a second read is needed.
adata_src = sc.read_h5ad(
    "/maiziezhou_lab2/yunfei/Projects/FM_temp/InterPLM/interplm/ge_shards/cosmx_human_lung_sec8.h5ad"
)

# ------------------------
# 1) Align rows (A_for_mis -> adata_interp)
# ------------------------
# Row IDs for A_for_mis are resolved in this order:
#   1. `ids_common`, if it is already defined in memory;
#   2. IDS_FOR_A_PATH, a text file with one cell_id per line (same row order as A_for_mis);
#   3. `cell_order_cached` + `idx_A_keep` from the cache-alignment cell, which are
#      exactly the IDs of the rows that were sliced into A_for_mis;
#   4. last resort: assume A rows are in the same order as adata_interp.obs_names.
IDS_FOR_A_PATH = None  # e.g., "/.../cached/layer4/ids_common.txt"

def _load_ids_for_A():
    """Return the cell IDs (as a list of str) for the rows of A_for_mis.

    Sources are tried in the order documented above. The final fallback cannot
    be verified and makes the subsequent "matched N / N" message trivially
    true, so it prints a warning.
    """
    g = globals()
    if g.get("ids_common") is not None:
        return list(map(str, ids_common))
    if IDS_FOR_A_PATH:
        with open(IDS_FOR_A_PATH, "r") as f:
            return [ln.strip() for ln in f]
    if g.get("cell_order_cached") is not None and g.get("idx_A_keep") is not None:
        # A_for_mis was built as A_cls[idx_A_keep], and cell_order_cached lists
        # the IDs of A_cls rows, so these are the true row IDs.
        cached_ids = g["cell_order_cached"]
        return [str(cached_ids[i]) for i in g["idx_A_keep"]]
    # fallback: assume order already matches adata_interp
    print("[warn] no ID source for A_for_mis; ASSUMING its rows follow adata_interp.obs_names")
    return list(map(str, adata_interp.obs_names))

ids_A = _load_ids_for_A()
# The ID list must describe A_for_mis row-for-row; otherwise the indices
# computed below would pick the wrong rows (or run past the end of A).
if len(ids_A) != A_for_mis.shape[0]:
    raise ValueError(
        f"got {len(ids_A)} cell IDs for A_for_mis with {A_for_mis.shape[0]} rows"
    )
idx_A = pd.Index(ids_A)
if not idx_A.is_unique:
    raise ValueError("cell IDs for A_for_mis contain duplicates; cannot align unambiguously")
idx_I = pd.Index(adata_interp.obs_names.astype(str))

take_A_for_I = idx_A.get_indexer(idx_I)   # for each adata_interp row, which row in A? -1 if missing
mask_keep = take_A_for_I >= 0
n_common = mask_keep.sum()
print(f"[align] matched {n_common} / {adata_interp.n_obs} cells from A_for_mis to adata_interp")

# ------------------------
# 2) Build archetype unit lists
# ------------------------
# integer unit ids per archetype call
call_col = unit_scores["call"]
tumor_units, tme_units, iface_units = (
    unit_scores.index[call_col == archetype].to_numpy(int)
    for archetype in ("Tumor", "TME", "Interface")
)

print(f"[units] Tumor:{len(tumor_units)}  TME:{len(tme_units)}  Interface:{len(iface_units)}")

# ------------------------
# 3) Compute per-cell archetype scores from A_for_mis (z-scored per unit)
# ------------------------
A = np.asarray(A_for_mis, dtype=np.float32)

# Standardise every unit (column) over the rows of A_for_mis.
mu = A.mean(axis=0, keepdims=True)
sd = A.std(axis=0, keepdims=True)
np.putmask(sd, sd < 1e-8, 1.0)   # (near-)constant units: divide by 1 instead of ~0
Az = (A - mu) / sd               # shape [N_A_rows, U]

def avg_score(Az, cols):
    """Row-wise mean of the selected columns of `Az`.

    Returns a float32 zero vector with one entry per row when `cols` is None
    or empty.
    """
    n_selected = 0 if cols is None else len(cols)
    if n_selected:
        return Az[:, cols].mean(axis=1)
    return np.zeros(Az.shape[0], dtype=np.float32)

# one score vector per archetype, each of length N_A_rows
S_T_all, S_E_all, S_I_all = (
    avg_score(Az, unit_ids) for unit_ids in (tumor_units, tme_units, iface_units)
)

# Reindex these scores onto adata_interp rows (pad with NaN where we lack A)
S_T, S_E, S_I = (
    np.full(adata_interp.n_obs, np.nan, dtype=np.float32) for _slot in range(3)
)
matched_A_rows = take_A_for_I[mask_keep]
for dst, src in ((S_T, S_T_all), (S_E, S_E_all), (S_I, S_I_all)):
    dst[mask_keep] = src[matched_A_rows]

# ------------------------
# 4) Softmax normalization to probabilities
# ------------------------
def softmax(X, tau=1.0):
    """Row-wise softmax of `X / tau` that tolerates NaN entries.

    The row maximum (ignoring NaN) is subtracted before exponentiating.
    NaN entries get weight 0; a row whose weights sum to 0 is divided by 1,
    so it comes back as all zeros.
    """
    scaled = X / tau
    shifted = scaled - np.nanmax(scaled, axis=1, keepdims=True)
    weights = np.exp(shifted)
    weights[np.isnan(weights)] = 0.0
    totals = weights.sum(axis=1, keepdims=True)
    np.putmask(totals, totals <= 0, 1.0)
    return weights / totals

scores = np.stack([S_T, S_E, S_I], axis=1)  # [cells, 3]
P = softmax(scores, tau=1.0)

# ------------------------
# 5) Assign compartments + write to obs/obsm
# ------------------------
# dtype=object is required: a fixed-width unicode array built from these three
# names is <U9, and assigning the 10-character "Unassigned" into it would be
# silently truncated to "Unassigne", which pd.Categorical then turns into NaN.
archetypes = np.array(["Tumor", "TME", "Interface"], dtype=object)
hard_call = archetypes[np.nanargmax(P, axis=1)]  # rows with all-NaN become 0; guard next:
all_nan = ~np.isfinite(scores).any(axis=1)
hard_call[all_nan] = "Unassigned"

adata_interp.obs["compartment"] = pd.Categorical(
    hard_call,
    categories=["Tumor","TME","Interface","Unassigned"]
)
adata_interp.obs["compartment_conf"] = np.nanmax(P, axis=1)

# store probabilities
adata_interp.obs["prob_Tumor"]     = P[:, 0]
adata_interp.obs["prob_TME"]       = P[:, 1]
adata_interp.obs["prob_Interface"] = P[:, 2]

# store raw scores too
adata_interp.obs["score_Tumor"]     = S_T
adata_interp.obs["score_TME"]       = S_E
adata_interp.obs["score_Interface"] = S_I

# also put compact matrices in obsm (useful for plotting)
adata_interp.obsm["X_compartment_scores"] = scores.astype(np.float32)
adata_interp.obsm["X_compartment_probs"]  = P.astype(np.float32)

print(adata_interp.obs["compartment"].value_counts(dropna=False))

# ------------------------
# 6) Copy spatial (if adata_interp lacks it but adata_src has it)
# ------------------------
# Fill in coordinates only when the target lacks them and the source has them.
target_lacks_spatial = "spatial" not in adata_interp.obsm
source_has_spatial = "spatial" in adata_src.obsm
if target_lacks_spatial and source_has_spatial:
    idx_src = pd.Index(adata_src.obs_names.astype(str))
    take_src_for_I = idx_src.get_indexer(idx_I)  # reuse idx_I (same as adata_interp order)
    mask_src = take_src_for_I >= 0

    # NaN-padded coordinate matrix in adata_interp row order
    src_coords = adata_src.obsm["spatial"]
    S = np.full((adata_interp.n_obs, src_coords.shape[1]), np.nan, dtype=np.float32)
    S[mask_src] = src_coords[take_src_for_I[mask_src]]
    adata_interp.obsm["spatial"] = S

    # convenience columns
    for axis_pos, axis_name in enumerate(("x", "y")):
        adata_interp.obs[f"spatial_{axis_name}"] = S[:, axis_pos]
    if S.shape[1] > 2:
        adata_interp.obs["spatial_z"] = S[:, 2]

# ------------------------
# 7) Save updated AnnData
# ------------------------
# out_path = "/maiziezhou_lab2/yunfei/Projects/interpTFM/evaluation/ccc/data/adata_interpretable_sec8_final.h5ad"
# adata_interp.write_h5ad(out_path, compression="lzf")
# print(f"[saved] {out_path}")

[align] matched 81236 / 81236 cells from A_for_mis to adata_interp
[units] Tumor:13  TME:13  Interface:8
compartment
TME           40827
Tumor         25315
Interface     15094
Unassigned        0
Name: count, dtype: int64
