In [None]:
from pathlib import Path
import os

# Run identifier: date stamp plus a version tag; together they name the output folder.
DAY = "20251201"
Version = "v1"

# Project root; override with the LLMSC_ROOT environment variable (defaults to the CWD).
PROJECT_ROOT = Path(os.environ.get("LLMSC_ROOT", ".")).resolve()

# Defaults used when the corresponding environment overrides are absent.
_default_data_dir = PROJECT_ROOT / "input"
_default_out_dir = PROJECT_ROOT / "runs" / f"{DAY}.{Version}"

# Input and output locations; each can be redirected via its own environment variable.
DATA_DIR = Path(os.environ.get("LLMSC_DATA_DIR", _default_data_dir)).resolve()
OUT_DIR = Path(os.environ.get("LLMSC_OUT_DIR", _default_out_dir)).resolve()

# Create the run folder up front so later cells can write into it unconditionally.
OUT_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
import os
import random

import numpy as np

# One seed shared by every random number generator seeded in this notebook.
RANDOM_SEED = 42

# NOTE(review): hash randomization is decided when the interpreter starts, so
# setting PYTHONHASHSEED here presumably only reaches subprocesses launched
# afterwards, not the running kernel -- confirm if hash ordering matters.
os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)

# Seed the stdlib and the legacy NumPy global generators.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

print(f"üîí Random seed set to {RANDOM_SEED} for reproducibility.")

üîí Random seed set to 42 for reproducibility.


In [None]:
import os
# The Gemini key comes from the environment only; an empty string means "not configured".
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")

# Flag for cells that need the LLM: True only when a non-empty key is present.
USE_LLM = GEMINI_API_KEY != ""

if USE_LLM is False:
    print("‚ö†Ô∏è GEMINI_API_KEY not set ‚Üí LLM inference cells will be skipped.")

In [None]:
import google.generativeai as genai
import scanpy as sc
import pandas as pd
import numpy as np
import scipy
import scipy.sparse
from scipy import io
import adjustText
import gc
import time
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from sklearn.metrics import confusion_matrix


import logging
# Configure root logging at INFO level with a compact "LEVEL:logger:message" format.
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s:%(name)s:%(message)s"
)

# Project package: the curator front-end and its feature-masking helper.
from llm_sc_curator import LLMscCurator
from llm_sc_curator.masking import FeatureDistiller

import warnings
# Blanket suppression of every warning for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings from
# scanpy/pandas -- consider narrowing to specific categories.
warnings.filterwarnings("ignore")

In [None]:
# ==========================================
# 1. API
# ==========================================

# Configure the Gemini client and print every model that supports
# `generateContent` for this key. The whole step is skipped when no key is
# configured (USE_LLM is False), so a key-less run does not issue a request
# that is bound to fail.
if USE_LLM:
    genai.configure(api_key=GEMINI_API_KEY)
    print("--- Available Models for your Key ---")
    try:
        for m in genai.list_models():
            if 'generateContent' in m.supported_generation_methods:
                print(m.name)
    except Exception as e:
        # Listing is informational only; report the failure and keep going.
        print(f"Error: {e}")
else:
    print("Skipping model listing: GEMINI_API_KEY is not set.")

--- Available Models for your Key ---
models/gemini-2.5-flash
models/gemini-2.5-pro
models/gemini-2.0-flash-exp
models/gemini-2.0-flash
models/gemini-2.0-flash-001
models/gemini-2.0-flash-exp-image-generation
models/gemini-2.0-flash-lite-001
models/gemini-2.0-flash-lite
models/gemini-2.0-flash-lite-preview-02-05
models/gemini-2.0-flash-lite-preview
models/gemini-exp-1206
models/gemini-2.5-flash-preview-tts
models/gemini-2.5-pro-preview-tts
models/gemma-3-1b-it
models/gemma-3-4b-it
models/gemma-3-12b-it
models/gemma-3-27b-it
models/gemma-3n-e4b-it
models/gemma-3n-e2b-it
models/gemini-flash-latest
models/gemini-flash-lite-latest
models/gemini-pro-latest
models/gemini-2.5-flash-lite
models/gemini-2.5-flash-image-preview
models/gemini-2.5-flash-image
models/gemini-2.5-flash-preview-09-2025
models/gemini-2.5-flash-lite-preview-09-2025
models/gemini-3-pro-preview
models/gemini-3-pro-image-preview
models/nano-banana-pro-preview
models/gemini-robotics-er-1.5-preview
models/gemini-2.5-computer-

In [None]:
# Gemini model identifier used by this notebook; the same name is handed to
# LLMscCurator when the benchmark is set up.
MODEL_NAME = 'models/gemini-2.5-pro'

# Alternative models, kept for quick switching:
# MODEL_NAME = 'models/gemini-3-pro-preview'
# MODEL_NAME = 'models/gemini-2.5-flash'
# MODEL_NAME = 'models/gemini-2.0-flash'

print(f"Using Model: {MODEL_NAME}")
# Client handle bound to the selected model.
model = genai.GenerativeModel(MODEL_NAME)

Using Model: models/gemini-2.5-pro


In [None]:
# Tabula Muris Senis droplet dataset with official annotations, read from DATA_DIR.
file_path = DATA_DIR / "tabula-muris-senis-droplet-processed-official-annotations.h5ad"
adata = sc.read_h5ad(file_path)

In [None]:
# Display the per-cell metadata table (pandas truncates the rendering for large frames).
adata.obs

Unnamed: 0_level_0,age,cell,cell_ontology_class,cell_ontology_id,free_annotation,method,mouse.id,n_genes,sex,subtissue,tissue,tissue_free_annotation,n_counts,louvain,leiden
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
AAACCTGCAGGGTACA-1-0-0-0,24m,MACA_24m_M_TONGUE_60_AAACCTGCAGGGTACA,keratinocyte,,filiform,droplet,24-M-60,2107.0,male,,Tongue,Tongue,5482.0,5,8
AAACCTGCAGTAAGCG-1-0-0-0,24m,MACA_24m_M_TONGUE_60_AAACCTGCAGTAAGCG,keratinocyte,,suprabasal,droplet,24-M-60,3481.0,male,,Tongue,Tongue,21855.0,19,15
AAACCTGTCATTATCC-1-0-0-0,24m,MACA_24m_M_TONGUE_60_AAACCTGTCATTATCC,keratinocyte,,suprabasal,droplet,24-M-60,2599.0,male,,Tongue,Tongue,10943.0,19,15
AAACGGGGTACAGTGG-1-0-0-0,24m,MACA_24m_M_TONGUE_60_AAACGGGGTACAGTGG,keratinocyte,,suprabasal differentiating,droplet,24-M-60,3468.0,male,,Tongue,Tongue,20665.0,12,11
AAACGGGGTCTTCTCG-1-0-0-0,24m,MACA_24m_M_TONGUE_60_AAACGGGGTCTTCTCG,keratinocyte,,suprabasal differentiating,droplet,24-M-60,3189.0,male,,Tongue,Tongue,12925.0,5,28
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10X_P8_15_TTTGTCAGTACATGTC-1,3m,10X_P8_15_TTTGTCAGTACATGTC,basal epithelial cell of tracheobronchial tree,CL:0000066,,droplet,3-M-7/8,,male,,Trachea,Trachea,5000.0,51,59
10X_P8_15_TTTGTCAGTGCGCTTG-1,3m,10X_P8_15_TTTGTCAGTGCGCTTG,mesenchymal progenitor cell,CL:0008019,,droplet,3-M-7/8,,male,,Trachea,Trachea,5984.0,11,33
10X_P8_15_TTTGTCAGTTGTCGCG-1,3m,10X_P8_15_TTTGTCAGTTGTCGCG,endothelial cell,CL:0000115,,droplet,3-M-7/8,,male,,Trachea,Trachea,6507.0,40,32
10X_P8_15_TTTGTCATCGGCTTGG-1,3m,10X_P8_15_TTTGTCATCGGCTTGG,endothelial cell,CL:0000115,,droplet,3-M-7/8,,male,,Trachea,Trachea,2589.0,40,32


In [None]:
# Distinct cell-type labels present in the `cell_ontology_class` column.
set(adata.obs["cell_ontology_class"])

{'B cell',
 'CD4-positive, alpha-beta T cell',
 'CD8-positive, alpha-beta T cell',
 'DN3 thymocyte',
 'DN4 thymocyte',
 'Kupffer cell',
 'Langerhans cell',
 'NK cell',
 'Schwann cell',
 'T cell',
 'adventitial cell',
 'alveolar macrophage',
 'basal cell',
 'basal cell of epidermis',
 'basal epithelial cell of tracheobronchial tree',
 'basophil',
 'bladder cell',
 'bladder urothelial cell',
 'blood cell',
 'bronchial smooth muscle cell',
 'brush cell',
 'cardiac neuron',
 'cardiomyocyte',
 'chondrocyte',
 'ciliated columnar cell of tracheobronchial tree',
 'classical monocyte',
 'club cell of bronchiole',
 'dendritic cell',
 'double negative T cell',
 'duct epithelial cell',
 'endocardial cell',
 'endothelial cell',
 'endothelial cell of coronary artery',
 'endothelial cell of hepatic sinusoid',
 'endothelial cell of lymphatic vessel',
 'enterocyte of epithelium of large intestine',
 'epidermal cell',
 'epithelial cell',
 'epithelial cell of large intestine',
 'epithelial cell of proxim

In [None]:
# Select B-lineage cells by label: keep anything containing the whole phrase
# "B cell", then drop pancreatic B cells, which are not lymphocytes.
_cell_class = adata.obs["cell_ontology_class"]
_is_b_labelled = _cell_class.str.contains(r"\bB cell\b", regex=True, na=False)
_is_pancreatic_b = _cell_class.str.contains(r"pancreatic B cell", na=False)
mask_b = _is_b_labelled & ~_is_pancreatic_b

# Independent copy so later edits do not touch the full dataset.
adata_b = adata[mask_b].copy()

# Show which labels survived the filter.
set(adata_b.obs["cell_ontology_class"])

{'B cell',
 'immature B cell',
 'late pro-B cell',
 'naive B cell',
 'precursor B cell'}

In [None]:
# Location of the balanced B-lineage benchmark subset inside the run folder.
# Defined here explicitly so this cell works on a fresh kernel; it previously
# relied on a `save_path` variable that no earlier cell defines and that the
# benchmark cell later rebinds to the results CSV.
# NOTE(review): the cell that builds and writes this .h5ad file is not part of
# the notebook as shown -- confirm the file exists before Run All.
benchmark_data_path = OUT_DIR / "mouse_b_benchmark_data.h5ad"

print(f"\n Loading data...: {benchmark_data_path}")
adata_sub = sc.read_h5ad(benchmark_data_path)
print("\nCell counts per cluster (Should be balanced approx ~300 if 3 datasets merged):")
print(adata_sub.obs['meta.cluster'].value_counts().head(10))


 Loading data...: /runs/20251201.v1/mouse_b_benchmark_data.h5ad

Cell counts per cluster (Should be balanced approx ~300 if 3 datasets merged):
meta.cluster
B cell              300
immature B cell     300
late pro-B cell     300
naive B cell        300
precursor B cell    300
Name: count, dtype: int64


In [None]:
# ==========================================
# Correct Label Generation Function
# ==========================================

def get_bcell_ground_truth(cluster_name: str) -> str:
    """
    Translate a Tabula Muris Senis B-lineage ``meta.cluster`` name into its
    consensus ground-truth (GT) category.

    Matching is case-insensitive and ignores leading/trailing whitespace; the
    input is passed through ``str()`` first, so non-string values are accepted.

    Mapping (annotation notes carried over from the original curation):
        immature B cell   -> "Erythrocyte_like"   erythrocyte-like contamination
        naive B cell      -> "Mast_like"          mast cell-like contamination
        precursor B cell  -> "pDC_Myeloid_like"   pDC / myeloid-like contamination
        B cell            -> "Mature_B"           bona fide mature B cell
        late pro-B cell   -> "B_Other"            mixed / ambiguous (real B + pDC +
                                                  plasma); noted as excluded from
                                                  evaluation
        anything else     -> "B_Other"

    Args:
        cluster_name: Cluster label to classify.

    Returns:
        One of the GT category strings listed above.
    """
    normalized = str(cluster_name).lower().strip()

    # Substring rules, checked in order. They must run before the "b cell"
    # prefix test so the more specific names are never shadowed by it.
    substring_rules = (
        ("immature b cell", "Erythrocyte_like"),
        ("naive b cell", "Mast_like"),
        ("precursor b cell", "pDC_Myeloid_like"),
    )
    for needle, category in substring_rules:
        if needle in normalized:
            return category

    # Only the plain "B cell" cluster starts with the phrase itself.
    if normalized.startswith("b cell"):
        return "Mature_B"

    # "late pro-B cell" and every unrecognised label share the catch-all bucket.
    return "B_Other"


print("Applying B-cell consensus Ground Truth mapping...")
# Derive the GT category for every cell from its meta.cluster label.
adata_sub.obs["GT_Category"] = adata_sub.obs["meta.cluster"].apply(get_bcell_ground_truth)

# Sanity check: report how many cells landed in each GT category.
print("\n--- Value counts for GT_Category ---")
print(adata_sub.obs["GT_Category"].value_counts())


Applying B-cell consensus Ground Truth mapping...

--- Value counts for GT_Category ---
GT_Category
Mature_B            300
Erythrocyte_like    300
B_Other             300
Mast_like           300
pDC_Myeloid_like    300
Name: count, dtype: int64


In [None]:
# Marker genes listed for each ground-truth category. Categories with an empty
# list have no markers and drop out of `markers_b` below.
GT_KEYWORDS_MOUSE_B = {
    "Erythrocyte_like": ["Hbb-bs", "Hbb-bt", "Hba-a1", "Gypa"],
    "Mast_like": ["Cpa3", "Mcpt8", "Gata2", "Fcer1a"],
    "pDC_Myeloid_like": ["Siglech", "Bst2", "Irf8"],
    "Mature_B": ["Cd79a", "Ms4a1", "Ebf1"],
    "Precursor_B": [],
    "B_Other": [],
}

var_names = adata_sub.var_names

# Keep only markers actually present in the dataset, and only categories that
# still have at least one marker left.
markers_b = {}
for _category, _candidate_genes in GT_KEYWORDS_MOUSE_B.items():
    _present_genes = [g for g in _candidate_genes if g in var_names]
    if _present_genes:
        markers_b[_category] = _present_genes

In [None]:
# Compute highly variable genes once, unless the dataset already carries them.
# NOTE(review): the "seurat" flavor presumably expects log-normalized data --
# confirm adata_sub.X is in that form.
if "highly_variable" not in adata_sub.var.columns:
    print("[Setup] Computing global HVGs for benchmark...")
    sc.pp.highly_variable_genes(
        adata_sub,
        n_top_genes=2000,
        subset=False,
        flavor="seurat",
    )

# Curator instance shared by both benchmark pipelines; it is given the whole
# benchmark dataset as its global context.
curator = LLMscCurator(api_key=GEMINI_API_KEY, model_name=MODEL_NAME)
curator.set_global_context(adata_sub)

# Number of top genes handed to the LLM per cluster.
N_GENES = 50
benchmark_results = []

# 1) Retrieve cluster list from meta.cluster (sorted for reproducibility)
unique_clusters = sorted(adata_sub.obs["meta.cluster"].unique())

# Clusters with ambiguous GT values are excluded from the main benchmark.
# "B_Other" is the catch-all label get_bcell_ground_truth returns for mixed or
# unrecognised clusters; without it in this set the filter never matched and
# ambiguous clusters were benchmarked anyway. "Other" and "Unknown" are kept
# for compatibility with other label schemes.
EXCLUDE_GT_LABELS = {"B_Other", "Other", "Unknown"}

cluster_meta = []       # (cluster, GT label) pairs that will be benchmarked
dropped_clusters = []   # (cluster, GT label) pairs skipped as ambiguous

for c in unique_clusters:
    gt_label = get_bcell_ground_truth(c)
    if gt_label in EXCLUDE_GT_LABELS:
        dropped_clusters.append((c, gt_label))
        continue
    cluster_meta.append((c, gt_label))

print(f"üöÄ Starting B cell benchmark for {len(cluster_meta)} clusters (Top {N_GENES} genes)...")
if dropped_clusters:
    print("‚ö†Ô∏è Excluded ambiguous clusters from main benchmark:")
    for c, lab in dropped_clusters:
        print(f"   - {c} (GT={lab})")


def ensure_json_result(x):
    """
    Normalize an LLM annotation response into a dict of three string fields.

    The result always has exactly the keys ``cell_type``, ``confidence`` and
    ``reasoning``, and every value is a ``str``, so callers can slice, join
    and serialise the fields without type checks.

    Args:
        x: Raw response. Accepted forms:
            * dict -- missing keys and ``None`` values fall back to the
              defaults ("Unknown", "Low", ""); other values are converted
              with ``str()``. Extra keys are discarded.
            * str  -- used as the cell type, with low confidence and empty
              reasoning.
            * anything else -- reported as an "Error" result whose reasoning
              is ``repr(x)``.

    Returns:
        dict with keys ``cell_type``, ``confidence``, ``reasoning``.
    """
    defaults = {"cell_type": "Unknown", "confidence": "Low", "reasoning": ""}

    if isinstance(x, dict):
        normalized = {}
        for key, fallback in defaults.items():
            value = x.get(key)
            # A JSON `null` (None) is as uninformative as a missing key; let it
            # through and the caller's string slicing raises a TypeError.
            normalized[key] = fallback if value is None else str(value)
        return normalized

    if isinstance(x, str):
        return {
            "cell_type":  x,
            "confidence": "Low",
            "reasoning":  "",
        }

    return {
        "cell_type":  "Error",
        "confidence": "Low",
        "reasoning":  repr(x),
    }


# 2) Main loop: 1 row = 1 meta.cluster
def _annotate_or_placeholder(genes, pipeline_label, empty_reason, cluster_name):
    """
    Annotate one gene list with the curator and return a normalized result.

    Shared by the standard and curated pipelines so both handle empty gene
    lists and annotation failures the same way.

    Args:
        genes: Gene names to send to ``curator.annotate``. When falsy, no call
            is made and a "NoGenes" placeholder is returned instead.
        pipeline_label: "Standard" or "Curated"; used in the warning message.
        empty_reason: Reasoning text stored in the "NoGenes" placeholder.
        cluster_name: Cluster being processed; used in the warning message.

    Returns:
        The dict produced by ``ensure_json_result``. Annotation failures are
        caught, printed as a warning, and returned as an "Error" result
        rather than raised, so one failing cluster does not stop the run.
    """
    try:
        if genes:
            raw = curator.annotate(genes, cell_type="Mouse B cell")
        else:
            raw = {
                "cell_type": "NoGenes",
                "confidence": "Low",
                "reasoning": empty_reason,
            }
        return ensure_json_result(raw)
    except Exception as e:
        print(f"[WARN] {pipeline_label} annotate failed for {cluster_name}: {e}")
        return ensure_json_result(
            {"cell_type": "Error", "confidence": "Low", "reasoning": str(e)}
        )


for i, (cluster_name, gt_label) in enumerate(cluster_meta):
    gt_keywords = GT_KEYWORDS_MOUSE_B.get(gt_label, [])

    print(f"\n[{i+1}/{len(cluster_meta)}] Processing: {cluster_name} ‚Üí {gt_label}")

    # -------------------------------------------------
    # A. Standard pipeline (no masking)
    # -------------------------------------------------
    # One-vs-rest labelling: the current cluster is "Target", all others "Rest".
    adata_sub.obs["binary_group"] = "Rest"
    adata_sub.obs.loc[adata_sub.obs["meta.cluster"] == cluster_name, "binary_group"] = "Target"

    try:
        sc.tl.rank_genes_groups(
            adata_sub,
            groupby="binary_group",
            groups=["Target"],
            reference="Rest",
            method="wilcoxon",
            use_raw=False,
        )
        df_std = sc.get.rank_genes_groups_df(adata_sub, group="Target")
        genes_std = df_std["names"].head(N_GENES).tolist()
    except Exception as e:
        # Best effort: a failed DE test yields an empty list, recorded as "NoGenes".
        print(f"[WARN] Standard DE failed for {cluster_name}: {e}")
        genes_std = []

    res_std = _annotate_or_placeholder(
        genes_std, "Standard", "Empty DEG list", cluster_name
    )

    # Pause between LLM calls -- presumably to stay under API rate limits.
    time.sleep(2)

    # -------------------------------------------------
    # B. Curated pipeline (LLM-scCurator, masking ON)
    # -------------------------------------------------
    try:
        genes_cur = curator.curate_features(
            adata_sub,
            group_col="meta.cluster",
            target_group=cluster_name,
            n_top=N_GENES,
            use_statistics=True,   # Gini + Regex Masking ON
        )
    except Exception as e:
        print(f"[WARN] curate_features failed for {cluster_name}: {e}")
        genes_cur = []

    res_cur = _annotate_or_placeholder(
        genes_cur, "Curated", "Curated gene list empty", cluster_name
    )

    time.sleep(2)

    # -------------------------------------------------
    # C. Save Results
    # -------------------------------------------------
    print(f"  üëâ Std Ans: {res_std['cell_type'][:40]} ({res_std['confidence']})")
    print(f"  üëâ Cur Ans: {res_cur['cell_type'][:40]} ({res_cur['confidence']})")

    # The *_Answer / *_CellType and Ground_Truth / Ground_Truth_Label column
    # pairs intentionally hold the same value.
    benchmark_results.append(
        {
            "Cluster_ID": cluster_name,
            # Ground truth
            "Ground_Truth": gt_label,
            "Ground_Truth_Label": gt_label,
            "Ground_Truth_Keywords": ", ".join(gt_keywords),
            # Standard
            "Standard_Genes": ";".join(genes_std),
            "Standard_Answer": res_std["cell_type"],
            "Standard_CellType": res_std["cell_type"],
            "Standard_Confidence": res_std["confidence"],
            "Standard_Reasoning": res_std["reasoning"],
            # Curated
            "Curated_Genes": ";".join(genes_cur),
            "Curated_Answer": res_cur["cell_type"],
            "Curated_CellType": res_cur["cell_type"],
            "Curated_Confidence": res_cur["confidence"],
            "Curated_Reasoning": res_cur["reasoning"],
        }
    )

    # Checkpoint every fifth cluster so an interrupted run keeps partial results.
    if (i + 1) % 5 == 0:
        tmp_path = f"{OUT_DIR}/mouse_b_benchmark_progress.csv"
        pd.DataFrame(benchmark_results).to_csv(tmp_path, index=False)
        print(f"  üíæ Progress saved to {tmp_path}")
    time.sleep(1)

# Final table: one row per benchmarked cluster.
df_results = pd.DataFrame(benchmark_results)
save_path = f"{OUT_DIR}/mouse_b_benchmark_results.csv"
df_results.to_csv(save_path, index=False)
print(f"\n‚úÖ B cell Benchmark Complete! Saved to {save_path}")

üöÄ Starting B cell benchmark for 5 clusters (Top 50 genes)...

[1/5] Processing: B cell ‚Üí Mature_B
  üëâ Std Ans: B cell (High)
  üëâ Cur Ans: Follicular B cell (High)

[2/5] Processing: immature B cell ‚Üí Erythrocyte_like
  üëâ Std Ans: Neutrophil-Erythroid Doublet (ISG-high) (High)
  üëâ Cur Ans: Neutrophil-Erythroid doublet (ISG-high) (High)

[3/5] Processing: late pro-B cell ‚Üí B_Other
  üëâ Std Ans: Plasma cell (High)
  üëâ Cur Ans: Plasma cell (High)

[4/5] Processing: naive B cell ‚Üí Mast_like
  üëâ Std Ans: Mast cell (High)
  üëâ Cur Ans: Mast cell (High)

[5/5] Processing: precursor B cell ‚Üí pDC_Myeloid_like
  üëâ Std Ans: Plasmacytoid dendritic cell (pDC) (High)
  üëâ Cur Ans: Plasmacytoid dendritic cell (High)
  üíæ Progress saved to /runs/20251201.v1/mouse_b_benchmark_progress.csv

‚úÖ B cell Benchmark Complete! Saved to /runs/20251201.v1/mouse_b_benchmark_results.csv
