In [None]:
from pathlib import Path
import os

# Run identifiers; together they name this run's output folder ("<DAY>.<Version>").
DAY = "20251201"
Version = "v1"

# Each location can be overridden through an environment variable; when the
# variable is unset the default below is used. All paths are made absolute.
PROJECT_ROOT = Path(os.environ.get("LLMSC_ROOT", ".")).resolve()

DATA_DIR = Path(
    os.environ.get("LLMSC_DATA_DIR", PROJECT_ROOT / "input")
).resolve()

OUT_DIR = Path(
    os.environ.get("LLMSC_OUT_DIR", PROJECT_ROOT / "runs" / f"{DAY}.{Version}")
).resolve()

# Create the output folder (with any missing parents); no error if it already exists.
OUT_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
import os
import random

import numpy as np

# One seed shared by every random number generator seeded in this cell.
RANDOM_SEED = 42

# Seed Python's built-in generator and NumPy's global generator.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# NOTE(review): PYTHONHASHSEED is read when an interpreter starts, so setting it
# here is only inherited by processes launched afterwards; it presumably does
# not change string hashing in the already-running kernel -- confirm if needed.
os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)

print(f"üîí Random seed set to {RANDOM_SEED} for reproducibility.")

üîí Random seed set to 42 for reproducibility.


In [None]:
import os
# Read the Gemini key from the environment; an unset variable yields "".
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")

# True only when a non-empty key was found.
USE_LLM = len(GEMINI_API_KEY) > 0

if USE_LLM is False:
    print("‚ö†Ô∏è GEMINI_API_KEY not set ‚Üí LLM inference cells will be skipped.")

In [None]:
import google.generativeai as genai
import scanpy as sc
import pandas as pd
import numpy as np
import scipy
import scipy.sparse
from scipy import io
import adjustText
import gc
import gzip
import time
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from sklearn.metrics import confusion_matrix


import logging
# Configure the root logger: INFO level, records formatted as "LEVEL:logger-name:message".
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s:%(name)s:%(message)s"
)

from llm_sc_curator import LLMscCurator
from llm_sc_curator.masking import FeatureDistiller

import warnings
# NOTE(review): this installs a blanket "ignore" filter, so every warning
# category is suppressed from here on (including deprecation and numerical
# warnings). Consider narrowing it to the specific categories/modules that are noisy.
warnings.filterwarnings("ignore")

In [None]:
# ==========================================
# 1. API
# ==========================================

# List the models that support `generateContent` for the configured key.
# The listing is an API call, so it is skipped when no key is available; this
# matches the notice printed when USE_LLM is False ("LLM inference cells will
# be skipped").
if USE_LLM:
    genai.configure(api_key=GEMINI_API_KEY)
    print("--- Available Models for your Key ---")
    try:
        for m in genai.list_models():
            if 'generateContent' in m.supported_generation_methods:
                print(m.name)
    except Exception as e:
        # Best effort: the listing is informational only, so a failure is
        # reported and the notebook continues.
        print(f"Error: {e}")
else:
    print("[SKIP] GEMINI_API_KEY not set; model listing skipped.")

--- Available Models for your Key ---
models/gemini-2.5-flash
models/gemini-2.5-pro
models/gemini-2.0-flash-exp
models/gemini-2.0-flash
models/gemini-2.0-flash-001
models/gemini-2.0-flash-exp-image-generation
models/gemini-2.0-flash-lite-001
models/gemini-2.0-flash-lite
models/gemini-2.0-flash-lite-preview-02-05
models/gemini-2.0-flash-lite-preview
models/gemini-exp-1206
models/gemini-2.5-flash-preview-tts
models/gemini-2.5-pro-preview-tts
models/gemma-3-1b-it
models/gemma-3-4b-it
models/gemma-3-12b-it
models/gemma-3-27b-it
models/gemma-3n-e4b-it
models/gemma-3n-e2b-it
models/gemini-flash-latest
models/gemini-flash-lite-latest
models/gemini-pro-latest
models/gemini-2.5-flash-lite
models/gemini-2.5-flash-image-preview
models/gemini-2.5-flash-image
models/gemini-2.5-flash-preview-09-2025
models/gemini-2.5-flash-lite-preview-09-2025
models/gemini-3-pro-preview
models/gemini-3-pro-image-preview
models/nano-banana-pro-preview
models/gemini-robotics-er-1.5-preview
models/gemini-2.5-computer-

In [None]:
# Gemini model identifier used by the cells below.
# Alternatives kept here for quick switching:
#   'models/gemini-3-pro-preview'
#   'models/gemini-2.5-flash'
#   'models/gemini-2.0-flash'
MODEL_NAME = "models/gemini-2.5-pro"
print(f"Using Model: {MODEL_NAME}")

model = genai.GenerativeModel(MODEL_NAME)

Using Model: models/gemini-2.5-pro


In [None]:
# Despite its name, `save_path` is the file that is READ here: the prepared
# benchmark AnnData stored in the run's output folder.
save_path = OUT_DIR / "brca_msc_benchmark_data.h5ad"
print(f"\n Loading data...: {save_path}")
adata = sc.read_h5ad(save_path)

# Show the ten largest clusters as a quick sanity check of the per-cluster sizes.
print("\nCell counts per cluster (Should be balanced approx ~300 if 3 datasets merged):")
cluster_sizes = adata.obs['meta.cluster'].value_counts()
print(cluster_sizes.head(10))


 Loading data...: /runs/20251201.v1/brca_msc_benchmark_data.h5ad

Cell counts per cluster (Should be balanced approx ~300 if 3 datasets merged):
meta.cluster
CAFs MSC iCAF-like    300
CAFs myCAF-like       300
Endothelial ACKR1     300
Endothelial CXCL12    300
PVL Differentiated    300
Endothelial RGS5      300
PVL Immature          300
Cycling PVL            50
Name: count, dtype: int64


In [None]:
# ==========================================
# Correct Label Generation Function
# Definition based on cluster IDs from Wu et al., 2021
# ==========================================
def get_msc_ground_truth(cluster_name: str) -> str:
    """
    Map an original cluster name to its ground-truth category.

    The name is converted with ``str()`` and lower-cased, then checked against
    an ordered list of substring rules; the first rule with a hit decides the
    label:

        1. Endothelial    - "endothelial", "blood vessel"
        2. Fibro_PVL      - "pvl", "perivascular", "pericyte", "smooth muscle"
        3. Fibro_iCAF     - "icaf", "msc"
        4. Fibro_myCAF    - "mycaf", "myofibroblast"
        5. Fibro_Cycling  - "cycling", "proliferating", "mki67", "top2a"

    Because PVL is tested before the cycling rule, a name such as
    "Cycling PVL" is labelled Fibro_PVL. Names matching no rule are labelled
    "Fibro_Other".

    Parameters
    ----------
    cluster_name : str
        Cluster name (any object is accepted; it is converted with ``str()``).

    Returns
    -------
    str
        One of "Endothelial", "Fibro_PVL", "Fibro_iCAF", "Fibro_myCAF",
        "Fibro_Cycling" or "Fibro_Other".
    """
    # Order matters: earlier rules take priority over later ones.
    rules = (
        ("Endothelial", ("endothelial", "blood vessel")),
        ("Fibro_PVL", ("pvl", "perivascular", "pericyte", "smooth muscle")),
        ("Fibro_iCAF", ("icaf", "msc")),
        ("Fibro_myCAF", ("mycaf", "myofibroblast")),
        ("Fibro_Cycling", ("cycling", "proliferating", "mki67", "top2a")),
    )

    lowered = str(cluster_name).lower()
    for label, needles in rules:
        for needle in needles:
            if needle in lowered:
                return label

    # No rule matched.
    return "Fibro_Other"


# Derive the ground-truth category of every cell from its original cluster name.
print("Applying MSC/CAF Ground Truth Mapping...")
adata.obs['GT_Category'] = adata.obs['meta.cluster'].apply(get_msc_ground_truth)

# Inspect the label distribution; unexpected "Fibro_Other" entries indicate
# cluster names that matched none of the mapping rules.
print("\n--- Value Counts (Check for 'Other') ---")
print(adata.obs['GT_Category'].value_counts())

# Build the neighbour graph and UMAP embedding with fixed random states, then
# draw two stacked panels: original clusters and ground-truth categories.
print("  Generating UMAP...")
sc.pp.neighbors(adata, random_state=42)
sc.tl.umap(adata, random_state=42)
fig = sc.pl.umap(
    adata,
    color=['meta.cluster', 'GT_Category'],
    legend_fontsize=10,
    ncols=1,
    title=['Original clusters', 'Ground truth (Consensus)'],
    frameon=False,
    return_fig=True
)

# Layout adjustment: re-create each panel's legend to the right of its axes,
# in a single column. Axes without legend handles are left untouched.
for ax in fig.axes:
    handles, labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(handles, labels, loc='center left', bbox_to_anchor=(1.05, 0.5), ncol=1)

# Resize, reserve the right 15% of the canvas for the legends, and write the PDF
# into the run's output folder.
fig.set_size_inches(4, 6)
fig.tight_layout(rect=[0, 0, 0.85, 1])
fig.savefig(f"{OUT_DIR}/EDFig1d_MSC_GT_Check.pdf", bbox_inches='tight')


# Marker genes associated with each ground-truth category.
GT_KEYWORDS_MSC = {
    "Endothelial": ["PECAM1", "VWF", "CLDN5", "ACKR1"],
    "Fibro_iCAF": ["CXCL12", "CFD", "C7", "IL6"],
    "Fibro_myCAF": ["ACTA2", "TAGLN", "MMP11", "COL1A1"],
    "Fibro_Cycling": ["MKI67", "TOP2A"],
    "Fibro_PVL": ["RGS5", "MCAM", "NOTCH3"],
    # Catch-all category; its marker list is intentionally empty.
    "Fibro_Other": [],
}

# Keep, per category, only the markers present in this dataset's gene index.
var_names = adata.var_names
filtered_markers = {
    category: [gene for gene in markers if gene in var_names]
    for category, markers in GT_KEYWORDS_MSC.items()
}



In [None]:
# Global HVGs are computed only when the loaded object does not already carry
# a "highly_variable" column.
if "highly_variable" not in adata.var.columns:
    print("[Setup] Computing global HVGs for benchmark...")
    sc.pp.highly_variable_genes(adata, n_top_genes=2000, subset=False, flavor="seurat")

# Curator instance shared by both benchmark arms; it is given the full dataset
# as its global context.
curator = LLMscCurator(api_key=GEMINI_API_KEY, model_name=MODEL_NAME)
curator.set_global_context(adata)

N_GENES = 50            # number of top genes passed to the LLM per cluster
benchmark_results = []  # one record per benchmarked cluster

# 1) Cluster names from meta.cluster, sorted so the run order is reproducible.
unique_clusters = sorted(adata.obs["meta.cluster"].unique())

# Ground-truth labels regarded as ambiguous; clusters carrying them are left
# out of the main benchmark.
EXCLUDE_GT_LABELS = {"Fibro_Other", "Other", "Unknown"}

# Pair every cluster with its ground-truth label, then split the pairs into
# benchmarked and excluded ones (both keep the sorted order).
_labelled_clusters = [(name, get_msc_ground_truth(name)) for name in unique_clusters]
cluster_meta = [pair for pair in _labelled_clusters if pair[1] not in EXCLUDE_GT_LABELS]
dropped_clusters = [pair for pair in _labelled_clusters if pair[1] in EXCLUDE_GT_LABELS]

print(f"üöÄ Starting MSC benchmark for {len(cluster_meta)} clusters (Top {N_GENES} genes)...")
if len(dropped_clusters) > 0:
    print("‚ö†Ô∏è Excluded ambiguous clusters from main benchmark:")
    for dropped_name, dropped_label in dropped_clusters:
        print(f"   - {dropped_name} (GT={dropped_label})")

def ensure_json_result(x):
    """Normalize an LLM response into a dict with exactly three keys.

    The returned dict always contains ``cell_type``, ``confidence`` and
    ``reasoning``; any other keys of the input are dropped.

    Parameters
    ----------
    x : dict | str | object
        Raw annotation result.

        * dict  - the three keys are copied; a key that is missing *or whose
          value is None* falls back to its default ("Unknown", "Low", "").
          A non-string ``cell_type`` is converted with ``str()`` because the
          benchmark loop slices it (``[:40]``) when printing.
        * str   - used as ``cell_type`` with "Low" confidence and empty reasoning.
        * other - reported as ``cell_type="Error"`` with ``repr(x)`` as reasoning.

    Returns
    -------
    dict
        ``{"cell_type": ..., "confidence": ..., "reasoning": ...}``
    """
    defaults = {"cell_type": "Unknown", "confidence": "Low", "reasoning": ""}

    if isinstance(x, dict):
        result = {}
        for key, fallback in defaults.items():
            value = x.get(key)
            # Treat an explicit None the same as a missing key.
            result[key] = fallback if value is None else value
        if not isinstance(result["cell_type"], str):
            result["cell_type"] = str(result["cell_type"])
        return result

    if isinstance(x, str):
        return {"cell_type": x, "confidence": "Low", "reasoning": ""}

    return {"cell_type": "Error", "confidence": "Low", "reasoning": repr(x)}


# 2) Main loop: 1 row = 1 meta.cluster
def _annotate_with_fallback(genes, *, use_auto_context, pipeline, empty_reason, cluster_name):
    """Annotate one gene list and always return a normalized result dict.

    Parameters
    ----------
    genes : list
        Genes handed to ``curator.annotate``. When empty, no call is made and
        a "NoGenes" placeholder is returned instead.
    use_auto_context : bool
        Forwarded unchanged to ``curator.annotate``.
    pipeline : str
        Pipeline name used in the warning message ("Standard" / "Curated").
    empty_reason : str
        Reasoning text stored in the placeholder when ``genes`` is empty.
    cluster_name : str
        Cluster being processed; used only in the warning message.

    Returns
    -------
    dict
        Output of ``ensure_json_result``; on any exception an "Error" record
        whose reasoning is the exception text.
    """
    try:
        if genes:
            raw = curator.annotate(genes, use_auto_context=use_auto_context)
        else:
            raw = {
                "cell_type": "NoGenes",
                "confidence": "Low",
                "reasoning": empty_reason,
            }
        return ensure_json_result(raw)
    except Exception as e:
        # Best effort: one failed call must not abort the whole benchmark.
        print(f"[WARN] {pipeline} annotate failed for {cluster_name}: {e}")
        return ensure_json_result(
            {"cell_type": "Error", "confidence": "Low", "reasoning": str(e)}
        )


for i, (cluster_name, gt_label) in enumerate(cluster_meta):
    gt_keywords = GT_KEYWORDS_MSC.get(gt_label, [])

    print(f"\n[{i+1}/{len(cluster_meta)}] Processing: {cluster_name} ‚Üí {gt_label}")

    # -------------------------------------------------
    # A. Standard pipeline (no masking)
    # -------------------------------------------------
    # One-vs-rest grouping for this cluster, rebuilt on every iteration.
    adata.obs["binary_group"] = "Rest"
    adata.obs.loc[adata.obs["meta.cluster"] == cluster_name, "binary_group"] = "Target"

    try:
        sc.tl.rank_genes_groups(
            adata,
            groupby="binary_group",
            groups=["Target"],
            reference="Rest",
            method="wilcoxon",
            use_raw=False,
        )
        df_std = sc.get.rank_genes_groups_df(adata, group="Target")
        genes_std = df_std["names"].head(N_GENES).tolist()
    except Exception as e:
        print(f"[WARN] Standard DE failed for {cluster_name}: {e}")
        genes_std = []

    res_std = _annotate_with_fallback(
        genes_std,
        use_auto_context=False,
        pipeline="Standard",
        empty_reason="Empty DEG list",
        cluster_name=cluster_name,
    )

    # Pause between API calls (presumably to stay within rate limits).
    time.sleep(2)

    # -------------------------------------------------
    # B. Curated pipeline (LLM-scCurator, masking ON)
    # -------------------------------------------------
    try:
        curated = curator.curate_features(
            adata,
            group_col="meta.cluster",
            target_group=cluster_name,
            n_top=N_GENES,
            use_statistics=True,   # Gini + Regex Masking ON
        )
        # Normalize to a plain list so the emptiness check and the ";".join
        # below also work if None or an array-like is returned.
        genes_cur = [] if curated is None else list(curated)
    except Exception as e:
        print(f"[WARN] curate_features failed for {cluster_name}: {e}")
        genes_cur = []

    res_cur = _annotate_with_fallback(
        genes_cur,
        use_auto_context=True,
        pipeline="Curated",
        empty_reason="Curated gene list empty",
        cluster_name=cluster_name,
    )

    time.sleep(2)

    # -------------------------------------------------
    # C. Save Results
    # -------------------------------------------------
    # str() guards the slice: a non-string cell_type must not abort the run
    # here, outside any try/except, before the record is stored.
    print(f"  üëâ Std Ans: {str(res_std['cell_type'])[:40]} ({res_std['confidence']})")
    print(f"  üëâ Cur Ans: {str(res_cur['cell_type'])[:40]} ({res_cur['confidence']})")

    benchmark_results.append(
        {
            "Cluster_ID": cluster_name,
            # Ground truth
            "Ground_Truth": gt_label,
            "Ground_Truth_Label": gt_label,
            "Ground_Truth_Keywords": ", ".join(gt_keywords),
            # Standard
            "Standard_Genes": ";".join(map(str, genes_std)),
            "Standard_Answer": res_std["cell_type"],
            "Standard_CellType": res_std["cell_type"],
            "Standard_Confidence": res_std["confidence"],
            "Standard_Reasoning": res_std["reasoning"],
            # Curated
            "Curated_Genes": ";".join(map(str, genes_cur)),
            "Curated_Answer": res_cur["cell_type"],
            "Curated_CellType": res_cur["cell_type"],
            "Curated_Confidence": res_cur["confidence"],
            "Curated_Reasoning": res_cur["reasoning"],
        }
    )

    # Checkpoint every fifth cluster so a later failure keeps earlier rows.
    if (i + 1) % 5 == 0:
        tmp_path = f"{OUT_DIR}/msc_benchmark_progress.csv"
        pd.DataFrame(benchmark_results).to_csv(tmp_path, index=False)
        print(f"  üíæ Progress saved to {tmp_path}")
    time.sleep(1)

# 3) Final save
df_results = pd.DataFrame(benchmark_results)
save_path = f"{OUT_DIR}/msc_benchmark_results.csv"
df_results.to_csv(save_path, index=False)
print(f"\n‚úÖ MSC Benchmark Complete! Saved to {save_path}")

[Setup] Computing global HVGs for benchmark...
üöÄ Starting MSC benchmark for 8 clusters (Top 50 genes)...

[1/8] Processing: CAFs MSC iCAF-like ‚Üí Fibro_iCAF
  üëâ Std Ans: PDGFRA+ Fibroblast (High)
  üëâ Cur Ans: Lipofibroblast (High)

[2/8] Processing: CAFs myCAF-like ‚Üí Fibro_myCAF
  üëâ Std Ans: Activated fibroblast (High)
  üëâ Cur Ans: Cancer-Associated Fibroblast (High)

[3/8] Processing: Cycling PVL ‚Üí Fibro_PVL
  üëâ Std Ans: Myofibroblast (proliferating) (High)
  üëâ Cur Ans: Pericyte (proliferating) (High)

[4/8] Processing: Endothelial ACKR1 ‚Üí Endothelial
  üëâ Std Ans: Antigen-presenting endothelial cell (High)
  üëâ Cur Ans: Post-capillary Venular Endothelial Cell  (High)

[5/8] Processing: Endothelial CXCL12 ‚Üí Endothelial
  üëâ Std Ans: Arterial endothelial cell (High)
  üëâ Cur Ans: Arterial Endothelial Cell (ISG-high) (High)
  üíæ Progress saved to /runs/20251201.v1/msc_benchmark_progress.csv

[6/8] Processing: Endothelial RGS5 ‚Üí Endothelial
  üë