In [None]:
from pathlib import Path
import os

# Run identity: by default outputs go to <PROJECT_ROOT>/runs/<DAY>.<Version>.
DAY = "20251201"
Version = "v1"


def _resolved_env_path(var_name, fallback):
    """Return ``Path(os.getenv(var_name, fallback)).resolve()``."""
    return Path(os.getenv(var_name, fallback)).resolve()


# Each location can be overridden with the matching LLMSC_* environment variable.
PROJECT_ROOT = _resolved_env_path("LLMSC_ROOT", ".")
DATA_DIR = _resolved_env_path("LLMSC_DATA_DIR", PROJECT_ROOT / "input")
OUT_DIR = _resolved_env_path("LLMSC_OUT_DIR", PROJECT_ROOT / "runs" / f"{DAY}.{Version}")
OUT_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
import os
import random

import numpy as np

RANDOM_SEED = 42

# Seed every RNG the notebook relies on.
# NOTE: assigning PYTHONHASHSEED here does not change hash randomization of the
# already-running interpreter; it only affects child processes started later.
random.seed(RANDOM_SEED)
os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

print(f"üîí Random seed set to {RANDOM_SEED} for reproducibility.")


üîí Random seed set to 42 for reproducibility.


In [None]:
import os

# The Gemini key comes from the environment; USE_LLM records whether one is present.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
USE_LLM = len(GEMINI_API_KEY) > 0

if not USE_LLM:
    print("‚ö†Ô∏è GEMINI_API_KEY not set ‚Üí LLM inference cells will be skipped.")

In [None]:
import google.generativeai as genai
import scanpy as sc
import pandas as pd
import numpy as np
import scipy
import scipy.sparse
from scipy import sparse
from scipy import io
import adjustText
from adjustText import adjust_text
import gc
import re
import time
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from sklearn.metrics import confusion_matrix
from google.colab import userdata

import logging
# Root logger at INFO with a compact "LEVEL:logger:message" format.
logging.basicConfig(format="%(levelname)s:%(name)s:%(message)s", level=logging.INFO)

import llm_sc_curator
from llm_sc_curator import LLMscCurator
from llm_sc_curator.masking import FeatureDistiller

import warnings
warnings.filterwarnings("ignore")

In [None]:
# ==========================================
# 1. API
# ==========================================
# Configure the Gemini client and list the models that support generateContent.
# Guarded by USE_LLM so that, as the key-check cell promises, a run without
# GEMINI_API_KEY skips the API instead of calling it with an empty key.
if USE_LLM:
    genai.configure(api_key=GEMINI_API_KEY)
    print("--- Available Models for your Key ---")
    try:
        for m in genai.list_models():
            if 'generateContent' in m.supported_generation_methods:
                print(m.name)
    except Exception as e:
        # Best-effort listing: report the problem but keep the notebook running.
        print(f"Error: {e}")
else:
    print("Skipping model listing: GEMINI_API_KEY not set.")

--- Available Models for your Key ---
models/gemini-2.5-flash
models/gemini-2.5-pro
models/gemini-2.0-flash-exp
models/gemini-2.0-flash
models/gemini-2.0-flash-001
models/gemini-2.0-flash-exp-image-generation
models/gemini-2.0-flash-lite-001
models/gemini-2.0-flash-lite
models/gemini-2.0-flash-lite-preview-02-05
models/gemini-2.0-flash-lite-preview
models/gemini-exp-1206
models/gemini-2.5-flash-preview-tts
models/gemini-2.5-pro-preview-tts
models/gemma-3-1b-it
models/gemma-3-4b-it
models/gemma-3-12b-it
models/gemma-3-27b-it
models/gemma-3n-e4b-it
models/gemma-3n-e2b-it
models/gemini-flash-latest
models/gemini-flash-lite-latest
models/gemini-pro-latest
models/gemini-2.5-flash-lite
models/gemini-2.5-flash-image-preview
models/gemini-2.5-flash-image
models/gemini-2.5-flash-preview-09-2025
models/gemini-2.5-flash-lite-preview-09-2025
models/gemini-3-pro-preview
models/gemini-3-pro-image-preview
models/nano-banana-pro-preview
models/gemini-robotics-er-1.5-preview
models/gemini-2.5-computer-

In [None]:
# Gemini model passed to LLMscCurator below.
# Alternative model IDs that can be swapped in:
#   'models/gemini-3-pro-preview', 'models/gemini-2.5-flash', 'models/gemini-2.0-flash'
MODEL_NAME = 'models/gemini-2.5-pro'

print(f"Using Model: {MODEL_NAME}")
model = genai.GenerativeModel(MODEL_NAME)

Using Model: models/gemini-2.5-pro


In [None]:
# Load the CD8 benchmark AnnData stored in the run directory (OUT_DIR).
save_path = OUT_DIR / "cd8_benchmark_data.h5ad"
print(f"\n Loading data...: {save_path}")
adata = sc.read_h5ad(save_path)

# Sanity check: number of cells per meta.cluster (top 10).
cells_per_cluster = adata.obs['meta.cluster'].value_counts()
print("\nCell counts per cluster (Should be balanced approx ~300 if 3 datasets merged):")
print(cells_per_cluster.head(10))


 Loading data...: //runs/20251201.v1/cd8_benchmark_data.h5ad

Cell counts per cluster (Should be balanced approx ~300 if 3 datasets merged):
meta.cluster
CD8.c01.Tn.MAL          300
CD8.c02.Tm.IL7R         300
CD8.c03.Tm.RPS12        300
CD8.c04.Tm.CD52         300
CD8.c05.Tem.CXCR5       300
CD8.c06.Tem.GZMK        300
CD8.c07.Temra.CX3CR1    300
CD8.c11.Tex.PDCD1       300
CD8.c10.Trm.ZNF683      300
CD8.c12.Tex.CXCL13      300
Name: count, dtype: int64


In [None]:
# Preview per-cell metadata. 'meta.cluster' holds the Zheng et al. cluster label used
# for GT mapping and grouping; 'Cancer_Type' is used as the HVG batch key later on.
adata.obs

Unnamed: 0,cancerType,patient,libraryID,loc,meta.cluster,platform,Cancer_Type,Sample_ID
TTGAACGCACGGATAG.13,ESCA,ESCA.P20181123,ESCA-P20181123-N,N,CD8.c07.Temra.CX3CR1,10X,ESCA,TTGAACGCACGGATAG.13
CCTAGCTGTTTCCACC.7,ESCA,ESCA.P20190410,ESCA-P20190410-N,N,CD8.c07.Temra.CX3CR1,10X,ESCA,CCTAGCTGTTTCCACC.7
TGATTTCCACCCTATC.5,ESCA,ESCA.P20190404,ESCA-P20190404-N,N,CD8.c07.Temra.CX3CR1,10X,ESCA,TGATTTCCACCCTATC.5
GGAGCAACAATCTACG.7,ESCA,ESCA.P20190410,ESCA-P20190410-N,N,CD8.c07.Temra.CX3CR1,10X,ESCA,GGAGCAACAATCTACG.7
GCACATAAGGAACTGC.10,ESCA,ESCA.P20190411,ESCA-P20190411-T,T,CD8.c07.Temra.CX3CR1,10X,ESCA,GCACATAAGGAACTGC.10
...,...,...,...,...,...,...,...,...
GTCACGGGTGGTAACG.51,RC,RC.P20190923,RC-P20190923-T,T,CD8.c15.ISG.IFIT1,10X,RC,GTCACGGGTGGTAACG.51
TCAGCTCTCTATCCCG.51,RC,RC.P20190923,RC-P20190923-T,T,CD8.c15.ISG.IFIT1,10X,RC,TCAGCTCTCTATCCCG.51
TGATTTCAGTATTGGA.51,RC,RC.P20190923,RC-P20190923-T,T,CD8.c15.ISG.IFIT1,10X,RC,TGATTTCAGTATTGGA.51
TTGCCGTCAGCCTATA.51,RC,RC.P20190923,RC-P20190923-T,T,CD8.c15.ISG.IFIT1,10X,RC,TTGCCGTCAGCCTATA.51


In [None]:
# ==========================================
# Correct Label Generation Function
# Definition based on cluster IDs from Zheng et al. 2021
# ==========================================

def get_cd8_ground_truth(cluster_name: str) -> str:
    """
    Map Zheng et al. CD8 meta.cluster names to GT categories.

    The GT labels are intentionally slightly finer (Naive / Effector / EffectorMemory /
    Exhausted / ISG / MAIT / NK_killer / Cycling); per the original design notes they
    are later collapsed into coarse (major, state) pairs by CD8_HIER_CFG.gt_rules
    (defined outside this section).

    Matching is a case-insensitive substring test on ``str(cluster_name)``. Rules are
    evaluated in the numbered order below and the first hit wins, so e.g. a name
    containing both "ISG" and "Tex" maps to "CD8_ISG". Unmatched names return
    "CD8_Other".

    Args:
        cluster_name: A meta.cluster value such as "CD8.c01.Tn.MAL"; non-strings are
            converted with ``str``.

    Returns:
        One of "CD8_MAIT", "CD8_ISG", "CD8_Cycling", "CD8_NK_Killer", "CD8_Exhausted",
        "CD8_EffectorMemory", "CD8_Naive", "CD8_Effector" or "CD8_Other".
    """
    s = str(cluster_name).lower()

    # 1) Distinct functional states
    if "mait" in s:
        return "CD8_MAIT"

    if any(k in s for k in ["isg", "interferon", "ifit1"]):
        return "CD8_ISG"

    # "cycling" is listed explicitly: it is the usual label spelling and does not
    # contain the substring "cycle", so without it "*.Cycling" clusters fell through.
    if any(k in s for k in ["proliferating", "cycling", "cycle", "mki67", "top2a"]):
        return "CD8_Cycling"

    # 2) NK-like killer pool (exclude explicit T cell labels)
    if "nk" in s and "t cell" not in s:
        return "CD8_NK_Killer"

    # 3) Exhausted pool
    if any(k in s for k in ["tex", "exhausted", "pdcd1"]):
        return "CD8_Exhausted"

    # 4) TRM / resident memory -> treated as EffectorMemory in GT
    if any(k in s for k in ["trm", "resident", "znf683", "itgae", "cd69"]):
        return "CD8_EffectorMemory"

    # 5) Naive pool (true naive; use 'tn.' to avoid Tn/Tm confusion)
    if "tn." in s or "naive" in s:
        return "CD8_Naive"

    # 6) Temra / CX3CR1-high killers
    if any(k in s for k in ["temra", "cx3cr1", "klrg1"]):
        return "CD8_Effector"

    # 7) Tem / Tm / GZMK+ effector-memory clusters
    if any(k in s for k in ["tem.", "tm.", "memory", "gzmk", "aqp3", "ltb"]):
        return "CD8_EffectorMemory"

    # 8) Tk / killer T clusters (Zheng's Tk)
    if "tk" in s or "killer" in s:
        return "CD8_Effector"

    # 9) Fallback
    return "CD8_Other"



print("Applying Ground Truth Mapping...")
# One GT category per cell, derived from the cell's meta.cluster name.
cluster_labels = adata.obs['meta.cluster']
adata.obs['GT_Category'] = cluster_labels.apply(get_cd8_ground_truth)
print(adata.obs['GT_Category'].value_counts())


Applying Ground Truth Mapping...
GT_Category
CD8_EffectorMemory    1983
CD8_Exhausted          928
CD8_Effector           753
CD8_Naive              300
CD8_MAIT               300
CD8_ISG                202
Name: count, dtype: int64


In [None]:
# Canonical marker genes for each GT category; recorded alongside each benchmark row
# and filtered against adata.var_names.
GT_KEYWORDS_CD8 = dict(
    CD8_Exhausted=["PDCD1", "HAVCR2", "LAG3", "TOX"],
    CD8_ISG=["ISG15", "IFIT1", "MX1", "STAT1"],
    CD8_EffectorMemory=["GZMK", "LTB", "AQP3"],
    CD8_Effector=["GZMB", "PRF1", "GNLY", "CX3CR1", "KLRG1"],
    CD8_MAIT=["SLC4A10", "KLRB1"],
    CD8_Naive=["TCF7", "LEF1", "CCR7", "SELL"],
    CD8_Cycling=["MKI67", "TOP2A"],
)


# Keep only the GT marker genes that are present in this dataset's features.
var_names = adata.var_names
filtered_markers = {
    category: list(filter(lambda gene: gene in var_names, markers))
    for category, markers in GT_KEYWORDS_CD8.items()
}


In [None]:
# Compute global, batch-aware (Cancer_Type) HVGs once, unless they already exist.
if "highly_variable" not in adata.var.columns:
    print("[Setup] Computing global HVGs for CD8 benchmark...")

    # Use flavor='seurat_v3' on raw counts when a 'counts' layer exists;
    # otherwise fall back to flavor='seurat' on .X (assumed log1p-normalized).
    use_counts_layer = "counts" in adata.layers
    hvg_kwargs = {"n_top_genes": 2000, "subset": False, "batch_key": "Cancer_Type"}
    if use_counts_layer:
        hvg_kwargs["flavor"] = "seurat_v3"
        hvg_kwargs["layer"] = "counts"
        print("[HVG] Using flavor='seurat_v3' on layers['counts'] with batch_key='Cancer_Type'.")
    else:
        hvg_kwargs["flavor"] = "seurat"
        print("[HVG] `layers['counts']` not found ‚Üí using flavor='seurat' on log1p .X.")

    print(hvg_kwargs)
    sc.pp.highly_variable_genes(adata, **hvg_kwargs)

# Curator initialization. set_global_context presumably registers dataset-level
# context for later annotate(..., use_auto_context=True) calls -- confirm in llm_sc_curator.
curator = LLMscCurator(api_key=GEMINI_API_KEY, model_name=MODEL_NAME)
curator.set_global_context(adata)

[Setup] Computing global HVGs for CD8 benchmark...
[HVG] Using flavor='seurat_v3' on layers['counts'] with batch_key='Cancer_Type'.
{'n_top_genes': 2000, 'subset': False, 'batch_key': 'Cancer_Type', 'flavor': 'seurat_v3', 'layer': 'counts'}


In [None]:
N_GENES = 50
benchmark_results = []

# Clusters whose GT label is ambiguous are kept out of the main benchmark.
EXCLUDE_GT_LABELS = {"CD8_Other", "Other", "Unknown"}

# 1) Cluster list from meta.cluster, sorted so the run order is reproducible.
unique_clusters = sorted(adata.obs["meta.cluster"].unique())

# Pair each cluster with its GT label, then split into kept / dropped.
labelled_clusters = [(name, get_cd8_ground_truth(name)) for name in unique_clusters]
cluster_meta = [pair for pair in labelled_clusters if pair[1] not in EXCLUDE_GT_LABELS]
dropped_clusters = [pair for pair in labelled_clusters if pair[1] in EXCLUDE_GT_LABELS]

print(f"üöÄ Starting CD8 benchmark for {len(cluster_meta)} clusters (Top {N_GENES} genes)...")
if dropped_clusters:
    print("‚ö†Ô∏è Excluded ambiguous clusters from main benchmark:")
    print("\n".join(f"   - {c} (GT={lab})" for c, lab in dropped_clusters))


def ensure_json_result(x):
    """Normalize a raw LLM annotation into a dict with three string fields.

    Always returns ``{"cell_type": ..., "confidence": ..., "reasoning": ...}``:

    * dict  -> the three keys are read from it; a missing key *or* a ``None`` value
      falls back to "Unknown" / "Low" / "" respectively.
    * str   -> treated as a bare cell-type answer with "Low" confidence.
    * other -> an "Error" result whose reasoning is ``repr(x)``.

    Non-string values are coerced with ``str`` so callers can safely slice them
    (e.g. ``res["cell_type"][:40]``) and write them to CSV.
    """
    defaults = {"cell_type": "Unknown", "confidence": "Low", "reasoning": ""}

    if isinstance(x, dict):
        normalized = {}
        for key, fallback in defaults.items():
            value = x.get(key)
            normalized[key] = fallback if value is None else str(value)
        return normalized

    if isinstance(x, str):
        return {"cell_type": x, "confidence": "Low", "reasoning": ""}

    return {"cell_type": "Error", "confidence": "Low", "reasoning": repr(x)}


# 2) Main loop: 1 row = 1 meta.cluster
# For each cluster, annotate (A) the top-N one-vs-rest Wilcoxon DE genes and
# (B) the LLM-scCurator curated gene list, and record both answers next to the GT.

# Pause (seconds) after each LLM call and after each cluster -- crude rate limiting.
API_PAUSE_SEC = 3
# Write a checkpoint CSV every PROGRESS_EVERY clusters so partial results survive a crash.
PROGRESS_EVERY = 5
# Kept separate from `save_path`: the old chained `tmp_path = save_path = ...`
# silently repointed the global `save_path` at the progress file.
progress_path = OUT_DIR / "cd8_benchmark_progress.csv"


def _annotate_genes(genes, *, use_auto_context, empty_reason, stage, cluster_name):
    """Annotate ``genes`` with ``curator`` and return a normalized result dict.

    An empty gene list yields a "NoGenes" placeholder (reasoning=``empty_reason``)
    without calling the LLM. Any exception is printed as
    ``[WARN] {stage} annotate failed for {cluster_name}: ...`` and converted into an
    "Error" result, so one failing cluster does not abort the benchmark.
    """
    try:
        if genes:
            raw = curator.annotate(genes, use_auto_context=use_auto_context)
        else:
            raw = {
                "cell_type": "NoGenes",
                "confidence": "Low",
                "reasoning": empty_reason,
            }
        return ensure_json_result(raw)
    except Exception as e:
        print(f"[WARN] {stage} annotate failed for {cluster_name}: {e}")
        return ensure_json_result(
            {"cell_type": "Error", "confidence": "Low", "reasoning": str(e)}
        )


for i, (cluster_name, gt_label) in enumerate(cluster_meta):
    gt_keywords = GT_KEYWORDS_CD8.get(gt_label, [])

    print(f"\n[{i+1}/{len(cluster_meta)}] Processing: {cluster_name} ‚Üí {gt_label}")

    # -------------------------------------------------
    # A. Standard pipeline (no masking)
    # -------------------------------------------------
    # Temporary Target/Rest label for a one-vs-rest Wilcoxon test on this cluster.
    adata.obs["binary_group"] = "Rest"
    adata.obs.loc[adata.obs["meta.cluster"] == cluster_name, "binary_group"] = "Target"

    try:
        sc.tl.rank_genes_groups(
            adata,
            groupby="binary_group",
            groups=["Target"],
            reference="Rest",
            method="wilcoxon",
            use_raw=False,
        )
        df_std = sc.get.rank_genes_groups_df(adata, group="Target")
        genes_std = df_std["names"].head(N_GENES).tolist()
    except Exception as e:
        print(f"[WARN] Standard DE failed for {cluster_name}: {e}")
        genes_std = []

    res_std = _annotate_genes(
        genes_std,
        use_auto_context=False,
        empty_reason="Empty DEG list",
        stage="Standard",
        cluster_name=cluster_name,
    )
    time.sleep(API_PAUSE_SEC)

    # -------------------------------------------------
    # B. Curated pipeline (LLM-scCurator, masking ON)
    # -------------------------------------------------
    try:
        genes_cur = curator.curate_features(
            adata,
            group_col="meta.cluster",
            target_group=cluster_name,
            n_top=N_GENES,
            use_statistics=True,   # Gini + Regex Masking ON
        )
    except Exception as e:
        print(f"[WARN] curate_features failed for {cluster_name}: {e}")
        genes_cur = []

    res_cur = _annotate_genes(
        genes_cur,
        use_auto_context=True,
        empty_reason="Curated gene list empty",
        stage="Curated",
        cluster_name=cluster_name,
    )
    time.sleep(API_PAUSE_SEC)

    # -------------------------------------------------
    # C. Save Results
    # -------------------------------------------------
    # str() guards the slice in case a result ever carries a non-string cell_type.
    print(f"  üëâ Std Ans: {str(res_std['cell_type'])[:40]} ({res_std['confidence']})")
    print(f"  üëâ Cur Ans: {str(res_cur['cell_type'])[:40]} ({res_cur['confidence']})")

    benchmark_results.append(
        {
            "Cluster_ID": cluster_name,
            # Ground truth
            "Ground_Truth": gt_label,
            "Ground_Truth_Label": gt_label,
            "Ground_Truth_Keywords": ", ".join(gt_keywords),
            # Standard
            "Standard_Genes": ";".join(genes_std),
            "Standard_Answer": res_std["cell_type"],
            "Standard_CellType": res_std["cell_type"],
            "Standard_Confidence": res_std["confidence"],
            "Standard_Reasoning": res_std["reasoning"],
            # Curated
            "Curated_Genes": ";".join(genes_cur),
            "Curated_Answer": res_cur["cell_type"],
            "Curated_CellType": res_cur["cell_type"],
            "Curated_Confidence": res_cur["confidence"],
            "Curated_Reasoning": res_cur["reasoning"],
        }
    )

    if (i + 1) % PROGRESS_EVERY == 0:
        pd.DataFrame(benchmark_results).to_csv(progress_path, index=False)
        print(f"  üíæ Progress saved to {progress_path}")
    time.sleep(API_PAUSE_SEC)

df_results = pd.DataFrame(benchmark_results)
save_path = OUT_DIR / "cd8_benchmark_results.csv"
df_results.to_csv(save_path, index=False)
print(f"\n‚úÖ CD8 Benchmark Complete! Saved to {save_path}")

üöÄ Starting CD8 benchmark for 17 clusters (Top 50 genes)...

[1/17] Processing: CD8.c01.Tn.MAL ‚Üí CD8_Naive
  üëâ Std Ans: Naive T cell (High)
  üëâ Cur Ans: Naive T cell (High)

[2/17] Processing: CD8.c02.Tm.IL7R ‚Üí CD8_EffectorMemory
  üëâ Std Ans: Early activated T cell (High)
  üëâ Cur Ans: CD8+ Tissue-Resident Memory T cell (Acti (High)

[3/17] Processing: CD8.c03.Tm.RPS12 ‚Üí CD8_EffectorMemory
  üëâ Std Ans: GZMK+ CD8+ T cell (High)
  üëâ Cur Ans: GZMK+ CD8+ T cell (High)

[4/17] Processing: CD8.c04.Tm.CD52 ‚Üí CD8_EffectorMemory
  üëâ Std Ans: CD8+ Effector T cell (High)
  üëâ Cur Ans: CD8+ Tissue-Resident Memory T cell (High)

[5/17] Processing: CD8.c05.Tem.CXCR5 ‚Üí CD8_EffectorMemory
  üëâ Std Ans: GZMK+ Effector Memory T cell (High)
  üëâ Cur Ans: GZMK+ CD8+ Effector T cell (High)
  üíæ Progress saved to //runs/20251201.v1/cd8_benchmark_progress.csv

[6/17] Processing: CD8.c06.Tem.GZMK ‚Üí CD8_EffectorMemory
  üëâ Std Ans: CD8+ Exhausted T cell (High)
  üëâ