In [1]:
import pandas as pd
import json
import torch
import re
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# ==========================================
# 1. SETUP & MODEL LOADING
# ==========================================
# Directory holding the locally downloaded Qwen2.5-1.5B-Instruct checkpoint.
LOCAL_QWEN_PATH = "**/Phase2/models_phase2/Qwen2.5-1.5B-Instruct/Qwen2.5-1.5B-Instruct_downloaded"

# DEVICE (upper-case) is the name the inference function reads.
# Use the MPS backend when torch reports it as available, otherwise the CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"üöÄ Device: {DEVICE}")

# local_files_only=True: load strictly from LOCAL_QWEN_PATH.
tokenizer = AutoTokenizer.from_pretrained(LOCAL_QWEN_PATH, local_files_only=True)

# If the tokenizer defines no padding token, reuse the end-of-sequence token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Full-precision weights, placed directly on the selected device.
model = AutoModelForCausalLM.from_pretrained(
    LOCAL_QWEN_PATH,
    device_map=DEVICE,
    torch_dtype=torch.float32,
    trust_remote_code=True,
    local_files_only=True,
)


`torch_dtype` is deprecated! Use `dtype` instead!


üöÄ Device: mps


In [3]:
# ==========================================
# 1. TOP 20 MARKER GENES (EVIDENCE 0)
# ==========================================

# Maps an integer cluster ID (1-23) to one comma-separated string of gene
# symbols. The inference loop iterates over this dict and substitutes each
# string, unchanged, into the prompt's {top_20_genes} placeholder.
# NOTE(review): how these genes were ranked/selected is not shown in this
# notebook - confirm the upstream source before reusing them.
TOP_20_MARKERS = {
    1: "IL7R, CCL19, CD2, CCR7, ITK, TRBC2, GZMK, CD3E, FCN1, CD40LG, TRAC, CTLA4, SELL, FOXP3, CD163, GIMAP7, TIGIT, CD4, CXCR4, CCL5",
    2: "ITGB6, KRT17, LCN2, MLPH, CTTNBP2, CTNND2, CXCL1, TACSTD2, DEFB1, AGR2, TSPAN13, KIF26B, ANXA13, RARRES1, TMC5, FGG, UGT1A1, RERG, CXCL3, SNAP25",
    3: "ADH4, SULT2A1, ARG1, SERPINC1, HAMP, AHSG, AGXT, CYP8B1, SERPINA11, AFM, HAO2, SPP2, HAO1, PCK1, ADH1A, ALDOB, APOF, F13B, F2, ITIH1",
    4: "APCDD1, CHST4, TACSTD2, ALKAL2, KIF12, SLC44A3, F5, FGFR1, FGFR2, CXCL8, NUAK2, PKHD1, FAM171A1, GLIS3, TESC, DCDC2, ANXA9, FXYD2, BICC1, ITGB8",
    5: "KIT, NGFR, ARHGAP24, MFAP4, DSEL, WNK2, FGFR3, EVC, CD44, CHST9, F5, SLC12A2, AQP1, ANXA4, ATP1A1, FTH1, ITGB1, TMSB10, HSP90B1, SOX4",
    6: "CFHR5, COCH, CFHR4, CALB2, GDF15, RGS4, F2, EPHB6, FGFR1, MMP15, CFH, CXCL8, FBLN1, F5, SLC12A2, TACSTD2, FGFR4, ANXA9, HOMER2, GJB1",
    7: "KCNJ16, PDGFRA, CTNND2, HNF1B, CX3CL1, APCDD1, UGT1A1, SOX6, CDH2, ANXA9, WNK2, CDH6, C5, PDGFD, NEO1, FXYD2, CCL2, ABCC4, CHST9, S100A10",
    8: "AKR1B10, TMC5, UGT1A1, GDF15, MUC20, GPX2, CXCL8, TSPAN13, CRYAB, AGR2, S100A13, MAP2, FGA, AKR1C1, TALDO1, TACSTD2, FGG, CHST4, SLPI, POSTN",
    9: "RERG, ADH6, ARSE, RGS4, COCH, CYP4A11, SMIM24, GCNT3, HUNK, NOVA1, GCNT4, KCNJ16, UNC5CL, CCL20, MAP2, OGDHL, FAM149A, BAMBI, SNAP25, CX3CL1",
    10: "IER5, CRYAB, NUAK2, PDGFRA, TSPAN1, FXYD2, C4BPA, CYP4A11, VEPH1, ANXA13, SLPI, CYP3A5, CYP3A4, S100A13, CX3CL1, ARSE, ANXA9, PAQR5, ANPEP, CHST9",
    11: "ACE2, FGB, CCL20, MLPH, FGA, LOXL4, TMEM156, FGG, ALKAL2, TFR2, CHST4, GPX2, SMIM24, MUC1, VEPH1, C5, HOMER2, CXCL2, TMC5, SERPINA6",
    12: "C9ORF152, CALB2, ANXA3, FGA, ALKAL2, TMEM156, ARHGAP24, C5, SLC12A2, CHST9, SH3YL1, WNK2, ABCC4, TMC4, HOMER2, DCDC2, VCAN, ALDH1A1, FGG, FGFR3",
    13: "FBLN1, TRIM45, CXCL3, TMC5, CRYAB, VEPH1, HUNK, HKDC1, SLC44A3, SEMA6A, SPNS2, CXCL2, TUT4, CXCL12, CXCL1, SLPI, SERPINA6, EVC, ELF3, SEMA4G",
    14: "ALDOB, SMIM24, TRIM45, RGS4, TFR2, OGDHL, GCNT4, G6PC, GPT, SLC5A9, AFF3, CDH6, MUC13, SMAD5, IFI27, SDCBP2, AKR1B10, RIC8B, GDF15, SLC17A4",
    15: "CALCA, AKR1B10, SPNS2, UGT1A1, COCH, RGS4, NCAM1, CYP4A11, ELOVL7, MYRF, ALDH1A1, SLC5A9, ANXA9, CTTNBP2, GCNT4, CX3CL1, DSEL, FXYD2, PDGFD, HKDC1",
    16: "HMGCS2, NCAM1, VSIG1, IL1RN, CAPN8, TFF1, APOA5, BCL2L15, DUOX2, HPGD, ACSL5, UGT2B10, G6PC, MMP15, RORC, TM4SF5, GJB2, BATF, PPARG, RAP1GAP",
    17: "CALB1, AGR3, MYB, SDR16C5, PLAC8, S100P, GALNT5, ST6GALNAC1, SLC2A1, AGR2, CAPN8, MSLN, TFF2, CDA, SLCO1B3, TM4SF5, GCNT3, C15ORF48, GJB2, IFI27",
    18: "MUC3A, CALCA, DUOX2, HPGD, GCNT3, ANXA13, KRT19, CDH6, ITGB1, BAMBI, VSIG1, IGFBP5, CFH, TSPAN1, MUC1, FTH1, HK2, TMC5, JPT1, TCN1",
    19: "CXCR4, TESC, CXCL8, FTH1, FGG, CDH1, FCGR3A, CD24, TACSTD2, VIM, C15ORF48, FGA, C1QC, CLU, FGFR2, FGB, IGFBP5, SERPINA6, ANPEP, CTSB",
    20: "DUOX2, LCN2, AGXT2, C4BPA, LIPC, APOD, PROM1, CYP4A11, ACE2, FCGBP, KCNJ16, MUC20, DSEL, CHST4, AKR1C3, CYP3A5, UNC5CL, ANPEP, FXYD2, CCL2",
    21: "ERN2, MS4A8, ST6GALNAC1, SLCO1B3, AGR3, S100P, MYB, KCNN4, TFF2, GALNT5, SDR16C5, TFF1, AGR2, GCNT3, DUOX2, RARRES1, FUT2, SLC2A1, TFF3, CAPN8",
    22: "MSLN, KRT17, ITGB6, VSIG1, KIF26B, PTGDS, MUC1, MRS2, CDH2, CFH, MAOB, GPX2, FAM171A1, S100A10, PMEPA1, CLU, CD24, PDGFRA, LAD1, KCNJ16",
    23: "C15ORF48, FAM83E, CANX, ATP1A1, ITGB1, PSAP, TMSB10, CTSB, CLU, VIM, CD63, ANXA4, MYL6, ANPEP, AQP1, TIMP1, DCDC2, LGALS3BP, CDH1, CD24",
}

In [4]:

# ==========================================
# 1. CLEANED CANDIDATE LIST (Top 20, "Normal cell" removed)
# ==========================================
# Maps an integer cluster ID (1-23, the same keys as TOP_20_MARKERS) to a list
# of candidate cell-type names. The inference loop looks a cluster up with
# .get(cluster_id, ["Unknown"]) and injects str(list) into the prompt's
# {candidate_list} placeholder, so the names reach the model exactly as
# spelled here.
# NOTE(review): the header says "Top 20", but several lists are much shorter
# (e.g. cluster 3 has 7 entries) - confirm whether that is only an upper bound.
CANDIDATE_DICT = {
    1: ['T cells', 'B cells', 'Astrocytes', 'B cells naive', 'Cancer cell', 'Dendritic cells', 'Endothelial cells', 'Eosinophils', 'Macrophages', 'Monocytes', 'NK cells', 'Nuocytes', 'Plasmacytoid dendritic cells', 'Platelets', 'T helper cells', 'T regulatory cells'],
    2: ['Basal cells', 'Cholangiocytes', 'Hepatocytes', 'Acinar cells', 'Airway goblet cells', 'Astrocytes', 'Chromaffin cells', 'Epithelial cells', 'Epsilon cells', 'Fibroblasts', 'Germ cells', 'Mast cells', 'Neurons', 'Pulmonary alveolar type II cells'],
    3: ['Adipocytes', 'Dendritic cells', 'Acinar cells', 'Airway epithelial cells', 'Beta cells', 'Erythroid-like and erythroid precursor cells', 'Proximal tubule cells'],
    4: ['Beta cells', 'Cholangiocytes', 'Chondrocytes', 'Dendritic cells', 'Distal tubule cells', 'Ductal cells', 'Epithelial cells', 'Luminal epithelial cells', 'Neurons', 'Osteoblasts', 'Proximal tubule cells', 'T regulatory cells'],
    5: ['Airway epithelial cells', 'Astrocytes', 'Basophils', 'Bergmann glia', 'Cancer cell', 'Cardiac stem and precursor cells', 'Chromaffin cells', 'Ductal cells', 'Epithelial cells', 'Hepatic stellate cells', 'Podocytes', 'Retinal ganglion cells', 'Sertoli cells'],
    6: ['Cajal-Retzius cells', 'Acinar cells', 'Alpha cells', 'Cholangiocytes', 'Chromaffin cells', 'Dendritic cells', 'Embryonic stem cells', 'Epithelial cells', 'Fibroblasts', 'Hepatocytes', 'Osteoblasts', 'Stromal cells'],
    7: ['Ductal cells', 'Hematopoietic stem cells', 'Beta cells', 'Cholangiocytes', 'Chondrocytes', 'Distal tubule cells', 'Fibroblasts', 'Hepatocytes', 'Loop of Henle cells', 'Osteoblasts', 'Platelets', 'Pulmonary alveolar type II cells', 'Radial glia cells', 'Sertoli cells'],
    8: ['Hepatocytes', 'Airway goblet cells', 'Acinar cells', 'Adipocytes', 'Cholangiocytes', 'Dendritic cells', 'Ductal cells', 'Endothelial cells', 'Epithelial cells', 'Foveolar cells', 'Paneth cells', 'Pyramidal cells', 'Schwann cells'],
    9: ['Alpha cells', 'Acinar cells', 'Cajal-Retzius cells', 'Chromaffin cells', 'Distal tubule cells', 'Ductal cells', 'Epithelial cells', 'Oligodendrocytes', 'Osteoblasts', 'Pyramidal cells', 'Sebocytes', 'Transient cells'],
    10: ['Epithelial cells', 'Acinar cells', 'Beta cells', 'Cancer cell', 'Ductal cells', 'Endothelial cells', 'Enterocytes', 'Epsilon cells', 'Fibroblasts', 'Oligodendrocytes', 'Osteoblasts', 'Schwann cells', 'Sebocytes', 'Sertoli cells'],
    11: ['Hepatocytes', 'Alpha cells', 'Acinar cells', 'Alveolar macrophages', 'B cells memory', 'Basal cells', 'Cajal-Retzius cells', 'Epithelial cells', 'Erythroblasts', 'Gamma (PP) cells', 'Mast cells', 'Paneth cells'],
    12: ['Cajal-Retzius cells', 'Hepatocytes', 'Alpha cells', 'Astrocytes', 'B cells memory', 'Chromaffin cells', 'Dendritic cells', 'Ductal cells', 'Enterocytes', 'Epsilon cells', 'Germ cells', 'Platelets', 'Podocytes', 'Sertoli cells'],
    13: ['Cholangiocytes', 'Fibroblasts', 'Alveolar macrophages', 'Distal tubule cells', 'Ductal cells', 'Endothelial cells', 'Epiblast cells', 'Hepatocytes', 'Peri-islet Schwann cells', 'Retinal ganglion cells', 'Schwann cells', 'Sertoli cells'],
    14: ['Acinar cells', 'Alpha cells', 'Hepatocytes', 'Cajal-Retzius cells', 'Embryonic stem cells', 'Endothelial cells', 'Enterocytes', 'Epsilon cells', 'Erythroblasts', 'Foveolar cells', 'Loop of Henle cells', 'Olfactory epithelial cells', 'Reticulocytes', 'Transient cells'],
    15: ['Alpha cells', 'Oligodendrocytes', 'Sebocytes', 'Beta cells', 'Endothelial cells', 'Foveolar cells', 'Hepatocytes', 'Natural killer T cells', 'Neurons', 'Osteoblasts', 'Radial glia cells', 'Reticulocytes'],
    16: ['Enterocytes', 'Hepatocytes', 'Acinar cells', 'Adipocyte progenitor cells', 'Cholangiocytes', 'Dendritic cells', 'Epsilon cells', 'Erythroid-like and erythroid precursor cells', 'Fibroblasts', 'Germ cells', 'Macrophages', 'Monocytes', 'Natural killer T cells', 'Stromal cells', 'T helper cells'],
    17: ['Airway goblet cells', 'Dendritic cells', 'Endothelial cells', 'Acinar cells', 'Basophils', 'Distal tubule cells', 'Enteric glia cells', 'Enterocytes', 'Ependymal cells', 'Fibroblasts', 'Trophoblast cells'],
    18: ['Acinar cells', 'Bergmann glia', 'Cholangiocytes', 'Ductal cells', 'Epithelial cells', 'Epsilon cells', 'Erythroid-like and erythroid precursor cells', 'Gamma delta T cells', 'Leydig cells', 'Loop of Henle cells', 'Macrophages', 'Pluripotent stem cells', 'Sebocytes'],
    19: ['Hepatocytes', 'Acinar cells', 'B cells', 'Dendritic cells', 'Airway epithelial cells', 'Alpha cells', 'Astrocytes', 'Cholangiocytes', 'Gamma (PP) cells', 'Kupffer cells', 'Leydig cells', 'Luminal epithelial cells', 'M√ºller cells', 'Neurons'],
    20: ['Acinar cells', 'Airway goblet cells', 'Astrocytes', 'Basal cells', 'Beta cells', 'Cancer cell', 'Crypt cells', 'Distal tubule cells', 'Ductal cells', 'Goblet cells', 'Hepatocytes', 'Proximal tubule cells', 'Schwann cells', 'Sebocytes'],
    21: ['Airway goblet cells', 'Cholangiocytes', 'Acinar cells', 'Basal cells', 'Basophils', 'Delta cells', 'Endothelial cells', 'Enteric glia cells', 'Ependymal cells', 'Erythroid-like and erythroid precursor cells', 'Trophoblast cells'],
    22: ['Basal cells', 'Acinar cells', 'Adipocytes', 'Airway goblet cells', 'B cells', 'Distal tubule cells', 'Ductal cells', 'Endothelial cells', 'Enterocytes', 'Erythroid-like and erythroid precursor cells', 'Fibroblasts', 'Germ cells', 'Hematopoietic stem cells', 'Macrophages', 'M√ºller cells', 'Oligodendrocyte progenitor cells', 'Paneth cells'],
    23: ['Acinar cells', 'Airway epithelial cells', 'Astrocytes', 'Bergmann glia', 'Ductal cells', 'Adipocytes', 'Alpha cells', 'B cells', 'Embryonic stem cells', 'Monocytes', 'M√ºller cells']
}

In [5]:
import torch
import pandas as pd
import re
from tqdm import tqdm

# ==========================================
# 1. FIXED PROMPT (Logic Corrected)
# ==========================================
# str.format() template for one cluster-annotation request.
#   {top_20_genes}   - filled with the cluster's marker-gene string
#   {candidate_list} - filled with str() of the cluster's candidate list
# The doubled braces {{ }} in the OUTPUT FORMAT section are format() escapes
# and render as single literal braces.
#
# NOTE(review): the INSTRUCTIONS lines start with "# " but sit inside the
# triple-quoted string, so they are NOT commented out - they are sent to the
# model with the leading "#" characters. Confirm that is intended.
# NOTE(review): the task sentence mentions "functional terms", yet only genes
# and candidates are injected; rule 4 mentions a "Normal cell" candidate that
# the candidate-list header says was removed. Confirm the wording is current.
PROMPT_TEMPLATE = """
You are an expert Cell Biologist Annotator.
Your task is to identify the **specific cell type** of a cluster based on its functional terms and marker genes.

---

---
**EVIDENCE 0: TOP 20 MARKER GENES (Raw Data):**
{top_20_genes}

**EVIDENCE 1: DATA-DRIVEN CANDIDATES (From Database):**
{candidate_list}


---
**ALLOWED TAXONOMY (You MUST output one of these exact strings):**
- Hepatocyte
- Cholangiocyte (Tumor)
- Cholangiocyte (Reactive/EMT-like)
- Fibroblast / Stroma
- Mesenchymal progenitors
- T Cell
- Macrophage / Monocyte


---

# **INSTRUCTIONS:**
# 1. Analyze the **top 20 genes** and cross-reference with the **Candidates**.
# 2. **Map the Candidate** to the **Allowed Taxonomy**:
#    - If Candidate is "Cancer cell" or "Cholangiocyte" AND genes show tumor markers -> Map to **Cholangiocyte (Tumor)**.
#    - If Candidate is "Hepatic Stellate Cell" -> Map to **Fibroblast / Stroma**.
# 3. **Immune Check:** If you see **CD3D, CD3E, CD4, CD8, TRAC**, you MUST label as **T Cell**.
# 4. **Specificity Rule:** If genes support **Hepatocyte** (FGA, FGB, FGG, ALB, CYP), choose that over generic "Normal cell".



---
**OUTPUT FORMAT (Strict JSON):**
{{
    "reasoning": "Explain why you chose this taxonomy label based on the evidence.",
    "label": "EXACT STRING FROM ALLOWED TAXONOMY"
}}
"""

In [6]:
# ==========================================
# 3. YOUR INFERENCE FUNCTION (Integrated)
# ==========================================
def ask_local_qwen(raw_prompt):
    """Run one greedy chat completion against the locally loaded model.

    The prompt is wrapped as a single user turn with the tokenizer's chat
    template, tokenized, moved to ``DEVICE`` and passed to the module-level
    ``model``. Only the tokens produced after the prompt are decoded.

    Args:
        raw_prompt: Complete prompt text to send as the user message.

    Returns:
        The decoded reply with surrounding whitespace stripped. If any step
        raises, the exception is not propagated; a string of the form
        ``"Error: <message>"`` is returned instead.
    """
    try:
        # 1. Format input using the tokenizer's chat template
        messages = [{"role": "user", "content": raw_prompt}]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # 2. Tokenize and move every tensor (ids and mask) to the target device
        model_inputs = tokenizer([text], return_tensors="pt").to(DEVICE)

        # 3. Generate. The whole encoding is unpacked so the attention mask
        #    reaches generate(); because pad_token may equal eos_token the mask
        #    cannot be inferred from input_ids alone. `temperature` is omitted
        #    on purpose: it is a sampling-only flag and is ignored (with a
        #    warning) when do_sample=False.
        with torch.no_grad():
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=512,
                do_sample=False,         # Greedy, deterministic decoding
                repetition_penalty=1.1,  # Discourages looping output
            )

        # 4. Decode ONLY the continuation (drop the echoed prompt tokens)
        input_length = model_inputs["input_ids"].shape[1]
        new_tokens = generated_ids[0][input_length:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)

        return response.strip()

    except Exception as e:
        # Best-effort contract kept for callers: report failure as text.
        return f"Error: {str(e)}"

In [7]:
# ==========================================
# 4. ROBUST PARSER (THIS WAS MISSING)
# ==========================================
def extract_fields_with_regex(text):
    """Pull the 'label' and 'reasoning' string fields out of a model reply.

    The reply is scanned with regular expressions rather than parsed as a
    whole JSON document, so it still works when the surrounding JSON is
    malformed (missing braces, trailing text, Markdown fences).

    Escaped characters inside a value (for example ``\\"``) do not terminate
    the match, and JSON escape sequences are decoded when the captured value
    is a valid JSON string body.

    Args:
        text: Raw model output. ``None`` is treated as an empty reply.

    Returns:
        dict with two keys:
            "label": the extracted label, or "Error" if none was found.
            "reasoning": the extracted reasoning with newlines replaced by
                spaces and double quotes replaced by single quotes, or
                "No reasoning found" if none was found.
    """
    if text is None:
        text = ""
    elif not isinstance(text, str):
        text = str(text)

    # Clean Markdown fences
    text = text.replace("```json", "").replace("```", "").strip()

    def _grab(field):
        # Value body: any run of non-quote/non-backslash characters or
        # backslash-escaped characters, so \" inside the value is skipped
        # instead of ending the capture early.
        pattern = r'"' + field + r'"\s*:\s*"((?:[^"\\]|\\.)+)"'
        match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
        if match is None:
            return None
        raw = match.group(1)
        try:
            # strict=False tolerates literal newlines/tabs inside the value.
            return json.loads('"' + raw + '"', strict=False)
        except ValueError:
            # Not a valid JSON string body (e.g. unknown escape): keep as-is.
            return raw

    label = _grab("label")
    if label is None:
        label = "Error"

    reasoning = _grab("reasoning")
    if reasoning is None:
        reasoning = "No reasoning found"

    # Clean newlines and double quotes for CSV safety
    reasoning = reasoning.replace("\n", " ").replace('"', "'")

    return {"label": label, "reasoning": reasoning}

In [8]:
# ==========================================
# 2. MODIFIED EXECUTION LOOP (EV0 + EV1)
# ==========================================
# For every cluster: build the prompt from its marker genes (Evidence 0) and
# candidate list (Evidence 1), query the local model, parse the reply, and
# collect one record per cluster.
results = []
print("‚è≥ Starting Inference (Evidence 0 + Evidence 1)...")

# Iterate over TOP_20_MARKERS
for cluster_id, gene_string in tqdm(TOP_20_MARKERS.items()):

    # 1. Retrieve Candidates (Evidence 1); clusters without an entry get ["Unknown"]
    current_candidates = CANDIDATE_DICT.get(cluster_id, ["Unknown"])
    candidate_str = str(current_candidates)

    # 2. Format Prompt (Inject EV0 & EV1)
    # Note: We do NOT include 'input_data' (Evidence 2) here
    final_prompt = PROMPT_TEMPLATE.format(
        top_20_genes=gene_string,
        candidate_list=candidate_str
    )

    # 3. Inference
    raw_output = ask_local_qwen(final_prompt)

    # 4. Parsing
    parsed = extract_fields_with_regex(raw_output)

    results.append({
        "Cluster": cluster_id,
        "Predicted_Label": parsed["label"],
        "Reasoning": parsed["reasoning"],
        "Genes_Used": gene_string,
        "Candidates_Used": candidate_str,
        # Keep the unparsed reply: when inference fails or the reply cannot be
        # parsed, the label collapses to "Error" and this column is the only
        # place the actual message survives.
        "Raw_Output": raw_output,
    })

# ==========================================
# 3. SAVE RESULTS
# ==========================================
df_results = pd.DataFrame(results)
print("\n‚úÖ Final Results:")
print(df_results[['Cluster', 'Predicted_Label']].head())

# The formatting cell further down writes the two-column table to
# "All_cluster_v5_baseline_knoweldge_A1_EV0_EV1.csv". Saving the detailed
# results under that same name would be overwritten there (losing Reasoning,
# Genes_Used and Candidates_Used), so the detailed table gets its own file.
df_results.to_csv("All_cluster_v5_baseline_knoweldge_A1_EV0_EV1_detailed.csv", index=False)

‚è≥ Starting Inference (Evidence 0 + Evidence 1)...


  0%|          | 0/23 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 23/23 [10:46<00:00, 28.13s/it]


‚úÖ Final Results:
   Cluster                    Predicted_Label
0        1                             T Cell
1        2                Fibroblast / Stroma
2        3                         Hepatocyte
3        4  Cholangiocyte (Reactive/EMT-like)
4        5                   Epithelial cells





In [9]:
# ==========================================
# 7. FORMAT & SAVE FINAL OUTPUT
# ==========================================

# Destination for the two-column deliverable.
output_filename = "All_cluster_v5_baseline_knoweldge_A1_EV0_EV1.csv"

# Keep only the cluster ID and its predicted label, exposing the label under
# the column name 'Cell Type'. Labels are left exactly as predicted (no
# shortening such as "Cholangiocyte (Tumor)" -> "Tumor").
final_output = df_results[['Cluster', 'Predicted_Label']].rename(
    columns={'Predicted_Label': 'Cell Type'}
)

# Preview the first rows before writing.
print("\nüìù Final Formatted Table:")
print(final_output.head())

# Write without the DataFrame index.
final_output.to_csv(output_filename, index=False)

print(f"\n‚úÖ Successfully saved results to: {output_filename}")


üìù Final Formatted Table:
   Cluster                          Cell Type
0        1                             T Cell
1        2                Fibroblast / Stroma
2        3                         Hepatocyte
3        4  Cholangiocyte (Reactive/EMT-like)
4        5                   Epithelial cells

‚úÖ Successfully saved results to: All_cluster_v5_baseline_knoweldge_A1_EV0_EV1.csv
