In [None]:
"""
Disease topic classification with LLMs.

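- Loads few-shot examples and test notes from 'data-note-with-diseases-label.xlsx'
- Runs GPT-4.1-mini (A), GPT-4.1 (B), and LLaMA 3 8B Instruct (C)
- For each model, tries four prompt variants: v0 (zero-shot), v1 (few-shot),
  v2 (few-shot + label definitions), v3 (few-shot + definitions + reasoning)
- Writes predictions to 'tests_with_predictions_all_models.csv'
- Evaluates every prediction column against the ground truth and writes
  'model_evaluation_summary.csv'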
"""

In [None]:
#%pip install openai scikit-learn huggingface_hub pandas numpy

In [None]:
import os
import time


import pandas as pd
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

from openai import OpenAI
from huggingface_hub import InferenceClient


# config / setup
#
# API clients, model id and input data shared by every cell below.


# OpenAI client; os.getenv returns None when OPENAI_API_KEY is not set.
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Hugging Face inference client used for the LLaMA runs (token read from HF_TOKEN).
llama_client = InferenceClient(
    api_key=os.getenv("HF_TOKEN"),
)

# NOTE(review): the ":novita" suffix presumably selects the inference provider -- confirm.
LLAMA_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct:novita"

# One workbook, two sheets: `few` holds the labelled few-shot examples,
# `tests` holds the notes to classify (prediction columns are added to it later).
file_path = "data-note-with-diseases-label.xlsx"
few = pd.read_excel(file_path, sheet_name="few-shot examples")
tests = pd.read_excel(file_path, sheet_name="tests-with-the-ground-truth")

# Label vocabulary offered to the models: every distinct label that appears in
# the few-shot sheet, sorted alphabetically. A set comprehension is used so a
# label shared by several examples is listed only once in the prompts (this
# matches how the evaluation cell builds its label set).
label_set = sorted({
    lab.strip()
    for s in few["diseases_label"]
    if isinstance(s, str)          # skip empty / NaN cells
    for lab in s.split(",")
})
print("Labels:", label_set)



In [None]:
# helpers 

def call_openai(prompt, model="gpt-4.1-mini", max_retries=3, sleep_sec=2,
                max_output_tokens=128):
    """Send `prompt` to the OpenAI Responses API and return the reply text.

    Args:
        prompt: Full prompt text passed as the request `input`.
        model: OpenAI model name.
        max_retries: Number of attempts before giving up.
        sleep_sec: Pause (seconds) between two attempts.
        max_output_tokens: Output token cap forwarded to the API
            (default 128, the value that used to be hard-coded).

    Returns:
        The stripped concatenation of all text parts found in the response,
        or "" when every attempt failed. Errors are printed, never raised.
    """
    for attempt in range(max_retries):
        try:
            resp = openai_client.responses.create(
                model=model,
                input=prompt,
                max_output_tokens=max_output_tokens,
            )
            # Scan every output item rather than only output[0]: the first
            # item is not guaranteed to be the one carrying the text.
            chunks = []
            saw_content = False
            for item in resp.output or []:
                content = getattr(item, "content", None)
                if content is None:
                    continue
                saw_content = True
                for part in content:
                    piece = getattr(part, "text", None)
                    if piece is not None:
                        chunks.append(piece)
            if not saw_content:
                # Logged and retried like any other failure.
                raise ValueError("response contained no message content")
            return "".join(chunks).strip()
        except Exception as e:
            print(f"[{model}] error {attempt+1}/{max_retries}: {e}")
            # No point in waiting once the last attempt has failed.
            if attempt + 1 < max_retries:
                time.sleep(sleep_sec)
    return ""


def call_llama(prompt, max_tokens=128):
    """Ask the LLaMA chat endpoint for a completion of `prompt`.

    A single user message is sent with temperature 0.0. The reply message may
    be dict-like or an object with a `content` attribute; both are supported.
    Any exception is printed and turned into an empty string.
    """
    try:
        request = dict(
            model=LLAMA_MODEL,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=0.0,
        )
        reply = llama_client.chat.completions.create(**request).choices[0].message
        # dict-like message vs. attribute-style message
        if isinstance(reply, dict):
            content = reply.get("content", "")
        else:
            content = getattr(reply, "content", "")
        return content.strip()
    except Exception as err:
        print(f"[llama] error: {err}")
        return ""


def parse_labels(raw_text, label_list):
    """Extract the allowed labels mentioned in a model reply.

    Args:
        raw_text: Raw model output (may be empty or None).
        label_list: Allowed labels; matching is case-insensitive and the
            spelling from this list is what gets returned.

    Returns:
        Allowed labels in order of first appearance, without duplicates.
        Anything that is not an allowed label is dropped.

    Besides the plain "a, b" format, this tolerates an optional leading
    "Labels:" prefix, newline- or semicolon-separated answers, and decoration
    around a label such as quotes, bullets, asterisks or a trailing period.
    """
    if not raw_text:
        return []

    text = raw_text.strip()
    if text.lower().startswith("labels:"):
        text = text[len("labels:"):]

    # Treat newlines and semicolons as separators too.
    for sep in (";", "\n"):
        text = text.replace(sep, ",")

    allowed = {lab.lower(): lab for lab in label_list}
    decoration = " \t\r\"'`*.-[]()"

    out = []
    for chunk in text.split(","):
        key = chunk.strip().lower()
        if not key:
            continue
        # Exact match first so a label that itself contains such characters
        # is never damaged; only then retry with the decoration removed.
        lab = allowed.get(key)
        if lab is None:
            lab = allowed.get(key.strip(decoration))
        if lab is not None:
            out.append(lab)

    # deduplicate, keep order
    return list(dict.fromkeys(out))


def parse_labels_v3(raw_text, label_list):
    """Parse a "Reasoning ... / Labels: ..." style reply.

    The last non-empty line that starts with "labels:" (any casing) is taken
    as the answer and its remainder is handed to `parse_labels`. When no such
    line exists, the whole reply is parsed with `parse_labels` instead.
    """
    if not raw_text:
        return []

    marker = "labels:"
    stripped_lines = (line.strip() for line in raw_text.splitlines())
    candidates = [line for line in stripped_lines if line]

    # search from the bottom: the final "Labels:" line wins
    answer_line = next(
        (line for line in reversed(candidates) if line.lower().startswith(marker)),
        None,
    )

    if answer_line is not None:
        return parse_labels(answer_line[len(marker):].strip(), label_list)
    return parse_labels(raw_text, label_list)


# One-line definition per disease label, passed as `label_defs` to the v2/v3
# prompt builders. Those builders look up every entry of the label list here,
# so the keys must match the labels found in the few-shot sheet exactly.
label_definitions = {
    "periodontitis": "inflammatory disease affecting the supporting tissues of the teeth",
    "cerebral aneurysms": "focal dilatations of intracranial arteries in the brain",
    "posterior parietal neoplasm": "tumour located in the posterior parietal region of the brain",
    "glioblastoma multiforme": "aggressive high-grade primary brain tumour",
    "status epilepticus": "seizure activity lasting more than 5 minutes or repeated seizures without recovery",
    "hemiparesis": "weakness of one side of the body (left or right)",
    "pneumonia": "infection of the lung parenchyma with cough, fever or respiratory findings",
    "sepsis": "life-threatening organ dysfunction caused by a dysregulated host response to infection",
    "shortness of breath": "dyspnoea or difficulty breathing",
    "cough": "expulsive reflex to clear airways, often associated with respiratory disease",
}



In [None]:
# prompt builders

def build_prompt_v0(note_text, label_list):
    """Zero-shot prompt: instructions, the allowed labels and the note.

    The prompt ends with "Labels:" so the model continues with a
    comma-separated list of labels.
    """
    allowed = ", ".join(label_list)
    note = note_text.strip()
    prompt_lines = [
        "You are a medical expert.",
        "",
        "Task:",
        "Read the following patient note and assign one or more disease labels from the allowed list.",
        "",
        "Allowed disease labels (use only these, separated by commas):",
        allowed,
        "",
        "Rules:",
        "- Use only labels from the allowed list.",
        "- Output must be a comma-separated list of labels, with no extra text.",
        "- If more than one disease is present, include all relevant labels.",
        '- If none clearly apply, choose the single most likely label instead of writing "none".',
        "",
        "Patient note:",
        note,
        "",
        "Labels:",
    ]
    return "\n".join(prompt_lines).strip()


def build_fewshot_prompt_v1(few_df, note_text, label_list):
    """Few-shot prompt: instructions, allowed labels, labelled examples, note.

    Args:
        few_df: DataFrame of examples with 'description' and 'diseases_label'
            columns.
        note_text: The patient note to classify.
        label_list: Allowed labels, listed comma-separated in the prompt.

    Returns:
        The prompt text, ending with "Labels:".
    """
    examples = []
    # Number examples by position (1, 2, 3, ...). The DataFrame index label is
    # not a reliable counter: it is arbitrary after filtering or sampling and
    # need not even be an integer.
    for position, (_, row) in enumerate(few_df.iterrows(), start=1):
        examples.append(
            f"Example {position}\n"
            f"Note: {row['description'].strip()}\n"
            f"Labels: {row['diseases_label']}\n"
        )
    examples_block = "\n\n".join(examples)

    return f"""
You are a medical expert.

Task:
Read a patient note and assign one or more disease labels from the allowed list.

Allowed disease labels (use only these, separated by commas):
{", ".join(label_list)}

Rules:
- Use only labels from the allowed list.
- Output must be a comma-separated list of labels, with no extra text.
- If more than one disease is present, include all relevant labels.
- If none clearly apply, choose the single most likely label.

Here are labeled examples:

{examples_block}

Now classify this new patient note.

Note: {note_text.strip()}
Labels:
""".strip()


def build_fewshot_prompt_v2(few_df, note_text, label_list, label_defs):
    """Few-shot prompt that also gives a short definition for each label.

    Args:
        few_df: DataFrame of examples with 'description' and 'diseases_label'
            columns.
        note_text: The patient note to classify.
        label_list: Allowed labels.
        label_defs: Mapping label -> one-line definition. A label without an
            entry is listed without a definition instead of raising KeyError.

    Returns:
        The prompt text, ending with "Labels:".
    """
    defs_block = "\n".join(
        f"- {lab}: {label_defs[lab]}" if lab in label_defs else f"- {lab}"
        for lab in label_list
    )

    examples = []
    # Number examples by position; the DataFrame index label is not a
    # reliable 0-based counter (filtered / non-integer indexes).
    for position, (_, row) in enumerate(few_df.iterrows(), start=1):
        examples.append(
            f"Example {position}\n"
            f"Note: {row['description'].strip()}\n"
            f"Labels: {row['diseases_label']}\n"
        )
    examples_block = "\n\n".join(examples)

    return f"""
You are a medical expert.

Task:
Read a patient note and assign one or more disease labels from the allowed list.

Allowed disease labels and brief definitions:
{defs_block}

Rules:
- Use only labels from the allowed list.
- ALWAYS output at least one label from the allowed list.
- Output must be a comma-separated list of labels, with no extra text.
- Do NOT output words like "none" or "no disease".
- If more than one disease is present, include all relevant labels.
- If uncertain, choose the single most likely label.

Here are labeled examples:

{examples_block}

Now classify this new patient note.

Note: {note_text.strip()}
Labels:
""".strip()


def build_fewshot_prompt_v3(few_df, note_text, label_list, label_defs):
    """Few-shot prompt with label definitions that asks for reasoning first.

    The model is told to answer with a "Reasoning:" line followed by a final
    "Labels:" line.

    Args:
        few_df: DataFrame of examples with 'description' and 'diseases_label'
            columns.
        note_text: The patient note to classify.
        label_list: Allowed labels.
        label_defs: Mapping label -> one-line definition. A label without an
            entry is listed without a definition instead of raising KeyError.

    Returns:
        The prompt text.
    """
    defs_block = "\n".join(
        f"- {lab}: {label_defs[lab]}" if lab in label_defs else f"- {lab}"
        for lab in label_list
    )

    examples = []
    # Number examples by position; the DataFrame index label is not a
    # reliable 0-based counter (filtered / non-integer indexes).
    for position, (_, row) in enumerate(few_df.iterrows(), start=1):
        examples.append(
            f"Example {position}\n"
            f"Note: {row['description'].strip()}\n"
            f"Reasoning: Briefly identify the key clinical features and which disease labels apply.\n"
            f"Labels: {row['diseases_label']}\n"
        )
    examples_block = "\n\n".join(examples)

    return f"""
You are a medical expert.

Task:
Read a patient note and assign one or more disease labels from the allowed list.

Allowed disease labels and brief definitions:
{defs_block}

Instructions:
- Think step by step about the key clinical findings in the note and which disease labels apply.
- Use only labels from the allowed list.
- ALWAYS output at least one label from the allowed list.
- Do NOT output words like "none", "no disease", or free text diagnoses outside the label list.
- In the final answer, strictly follow this format:

Reasoning: <1-3 sentences explaining the key findings and the chosen labels>
Labels: <comma-separated list of labels, with no extra text>

Here are labeled examples:

{examples_block}

Now classify this new patient note.

Note: {note_text.strip()}

First provide your reasoning, then the final labels on a separate line as specified.
""".strip()



In [None]:
# run models 

def run_openai_model(prompt_fn, parse_fn, base_model, col_name):
    """Classify every test note with an OpenAI model and store the result.

    Args:
        prompt_fn: Prompt builder. It is called as
            prompt_fn([few,] note, label_set[, label_definitions]); the
            optional arguments are passed only when the builder declares a
            parameter named `few_df` / `label_defs`.
        parse_fn: Parser called as parse_fn(raw_reply, label_set); must
            return a list of labels.
        base_model: OpenAI model name forwarded to `call_openai`.
        col_name: Column of the global `tests` frame that receives the
            predictions as ", "-joined strings.

    Side effects:
        Adds or overwrites `tests[col_name]`.
    """
    import inspect  # local import: only needed for the signature check

    # Inspect the declared parameters. `__code__.co_varnames` also lists local
    # variables, so a builder with a *local* called `few_df` or `label_defs`
    # would wrongly receive extra arguments.
    param_names = set(inspect.signature(prompt_fn).parameters)
    wants_few = "few_df" in param_names
    wants_defs = "label_defs" in param_names

    preds = []
    for _, row in tests.iterrows():
        note = row["description"]

        args = []
        # few-shot examples go first when the builder expects them
        if wants_few:
            args.append(few)
        # then always note_text, label_list
        args.append(note)
        args.append(label_set)
        # label definitions go last when the builder expects them
        if wants_defs:
            args.append(label_definitions)

        prompt = prompt_fn(*args)
        raw = call_openai(prompt, model=base_model)
        preds.append(parse_fn(raw, label_set))

    tests[col_name] = [", ".join(p) for p in preds]


def run_llama_model(prompt_fn, parse_fn, col_name):
    """Classify every test note with the LLaMA model and store the result.

    Args:
        prompt_fn: Prompt builder. It is called as
            prompt_fn([few,] note, label_set[, label_definitions]); the
            optional arguments are passed only when the builder declares a
            parameter named `few_df` / `label_defs`.
        parse_fn: Parser called as parse_fn(raw_reply, label_set); must
            return a list of labels.
        col_name: Column of the global `tests` frame that receives the
            predictions as ", "-joined strings.

    Side effects:
        Adds or overwrites `tests[col_name]`.
    """
    import inspect  # local import: only needed for the signature check

    # Inspect the declared parameters. `__code__.co_varnames` also lists local
    # variables, so a builder with a *local* called `few_df` or `label_defs`
    # would wrongly receive extra arguments.
    param_names = set(inspect.signature(prompt_fn).parameters)
    wants_few = "few_df" in param_names
    wants_defs = "label_defs" in param_names

    preds = []
    for _, row in tests.iterrows():
        note = row["description"]

        args = []
        if wants_few:
            args.append(few)
        args.append(note)
        args.append(label_set)
        if wants_defs:
            args.append(label_definitions)

        prompt = prompt_fn(*args)
        raw = call_llama(prompt)
        preds.append(parse_fn(raw, label_set))

    tests[col_name] = [", ".join(p) for p in preds]


# Prompt variants v0-v3, in run order: (prompt builder, reply parser).
_PROMPT_VARIANTS = (
    (build_prompt_v0, parse_labels),
    (build_fewshot_prompt_v1, parse_labels),
    (build_fewshot_prompt_v2, parse_labels),
    (build_fewshot_prompt_v3, parse_labels_v3),
)

# Prediction columns per OpenAI model, one per prompt variant.
# A0–A3: gpt-4.1-mini, B0–B3: gpt-4.1
_OPENAI_COLUMNS = {
    "gpt-4.1-mini": ("pred_A0_mini", "pred_A1_mini", "pred_A2_mini", "pred_A3_mini"),
    "gpt-4.1": ("pred_B0_gpt41", "pred_B1_gpt41", "pred_B2_gpt41", "pred_B3_gpt41"),
}

# C0–C3: LLaMA
_LLAMA_COLUMNS = ("pred_C0_llama", "pred_C1_llama", "pred_C2_llama", "pred_C3_llama")

for _model_name, _columns in _OPENAI_COLUMNS.items():
    for (_build_prompt, _parse_reply), _column in zip(_PROMPT_VARIANTS, _columns):
        run_openai_model(_build_prompt, _parse_reply, _model_name, _column)

for (_build_prompt, _parse_reply), _column in zip(_PROMPT_VARIANTS, _LLAMA_COLUMNS):
    run_llama_model(_build_prompt, _parse_reply, _column)


# save / quick check

_preview_columns = [
    "_id",
    *_OPENAI_COLUMNS["gpt-4.1-mini"],
    *_OPENAI_COLUMNS["gpt-4.1"],
    *_LLAMA_COLUMNS,
]
print(tests[_preview_columns].head(10).to_string(index=False))

tests.to_csv("tests_with_predictions_all_models.csv", index=False)




In [None]:
# Evaluation: compute metrics for all models
# Load the saved predictions and attach the ground-truth labels by note `_id`.
# NOTE(review): the ground-truth CSV below is not written by any cell visible
# in this notebook; it looks like a CSV export of the
# "tests-with-the-ground-truth" sheet -- confirm it exists before running.

pred = pd.read_csv("tests_with_predictions_all_models.csv")
gt = pd.read_csv("data-note-with-diseases-label - tests-with-the-ground-truth.csv")

# use the diseases_label column from the ground-truth file:
# any `diseases_label` already present in the predictions file is dropped
# first so the merge cannot produce suffixed duplicate columns. The left join
# keeps every prediction row.
df = (
    pred.drop(columns=["diseases_label"], errors="ignore")
        .merge(gt[["_id", "diseases_label"]], on="_id", how="left")
)

print("Merged shape:", df.shape)
print(df[["_id", "diseases_label"]].head())

# Build the evaluation label vocabulary: every distinct label used in the
# few-shot sheet, sorted alphabetically.

few = pd.read_excel("data-note-with-diseases-label.xlsx", sheet_name="few-shot examples")

_distinct_labels = set()
for _cell in few["diseases_label"]:
    # non-string cells (e.g. NaN) carry no labels
    if isinstance(_cell, str):
        _distinct_labels.update(_part.strip() for _part in _cell.split(","))
label_set = sorted(_distinct_labels)

print("Evaluation labels:", label_set)

# position of each label inside the multi-hot vectors
num_labels = len(label_set)
label_index = dict(zip(label_set, range(num_labels)))

#Helper functions

def split_labels_cell(cell):
    """Split a comma-separated cell such as 'sepsis, pneumonia' into labels.

    Each piece is stripped and empty pieces are dropped. Anything that is not
    a string (for example NaN from an empty cell) yields an empty list.
    """
    if isinstance(cell, str):
        pieces = (piece.strip() for piece in cell.split(","))
        return [piece for piece in pieces if piece]
    return []


def to_multihot(labels):
    """Encode a list of labels as a 0/1 integer vector of length `num_labels`.

    Positions come from `label_index`; labels that are not in it are ignored.
    """
    vec = np.zeros(num_labels, dtype=int)
    known_positions = [label_index[lab] for lab in labels if lab in label_index]
    if known_positions:
        vec[known_positions] = 1
    return vec


# Ground-truth matrix: one row per note in `df`, one column per entry of
# `label_set` (column order given by `label_index`). Ground-truth labels that
# are not in `label_set` are ignored by `to_multihot`.
true_lists = df["diseases_label"].apply(split_labels_cell).tolist()
Y_true = np.vstack([to_multihot(lbls) for lbls in true_lists])

print("Y_true shape:", Y_true.shape)
print("Positive count per label:", dict(zip(label_set, Y_true.sum(axis=0))))

# mask: notes whose ground truth contains at least one label from `label_set`
# (the printed text below says "10", the size the label set is expected to have)
mask_positive_notes = (Y_true.sum(axis=1) > 0)
print("Number of notes with ≥1 of the 10 labels:", int(mask_positive_notes.sum()))

#Evaluation helpers

def evaluate_column(col_name, model_name, mask=None, show_per_label=False):
    """Print evaluation metrics for one prediction column of `df`.

    Reports micro and macro precision/recall/F1 plus subset accuracy against
    `Y_true`, and optionally a per-label breakdown.

    col_name: prediction column to score; a missing column is reported and skipped.
    model_name: display name used in the printed header.
    mask: optional boolean row mask restricting the evaluation to a subset.
    show_per_label: also print precision/recall/F1/support for every label.
    """
    if col_name not in df.columns:
        print(f"\n[{model_name}] Column '{col_name}' not found, skipping.")
        return

    parsed_cells = df[col_name].apply(split_labels_cell).tolist()
    predictions = np.vstack(list(map(to_multihot, parsed_cells)))

    restricted = mask is not None
    truth = Y_true[mask] if restricted else Y_true
    guess = predictions[mask] if restricted else predictions

    banner = f"\n=== {model_name} ({col_name}) ==="
    if truth.shape[0] == 0:
        print(banner)
        print("No samples in this scope.")
        return

    def _averaged(kind):
        # precision, recall, F1 for one averaging mode, rounded for display
        p, r, f, _ = precision_recall_fscore_support(
            truth, guess, average=kind, zero_division=0
        )
        return round(p, 3), round(r, 3), round(f, 3)

    micro = _averaged("micro")
    macro = _averaged("macro")
    subset_acc = accuracy_score(truth, guess)

    print(banner)
    if restricted:
        print(f"Scope: notes with ≥1 of the 10 labels (n={int(mask.sum())})")
    else:
        print("Scope: all test notes")

    print("Micro  P/R/F1:", *micro)
    print("Macro  P/R/F1:", *macro)
    print("Subset acc   :", round(subset_acc, 3))

    if not show_per_label:
        return

    per_label = precision_recall_fscore_support(
        truth, guess, average=None, zero_division=0
    )
    print("\nPer-label metrics:")
    for lab, p, r, f1, sup in zip(label_set, *per_label):
        print(f"{lab:30s}  P={p:5.3f}  R={r:5.3f}  F1={f1:5.3f}  support={sup}")


# Same metrics as the printing helper, but collected as rows for a summary table.
summary_rows = []

def evaluate_column_to_summary(col_name, model_name, mask=None):
    """Append one metrics row for `col_name` to `summary_rows`.

    The row's scope is "all_notes" without a mask and "positive_only" with
    one. Nothing is appended when the column is missing from `df` or when the
    selected scope contains no rows.
    """
    if col_name not in df.columns:
        return

    parsed_cells = df[col_name].apply(split_labels_cell).tolist()
    predictions = np.vstack(list(map(to_multihot, parsed_cells)))

    if mask is None:
        scope, truth, guess = "all_notes", Y_true, predictions
    else:
        scope, truth, guess = "positive_only", Y_true[mask], predictions[mask]

    if truth.shape[0] == 0:
        return

    micro = precision_recall_fscore_support(
        truth, guess, average="micro", zero_division=0
    )
    macro = precision_recall_fscore_support(
        truth, guess, average="macro", zero_division=0
    )
    exact_match = accuracy_score(truth, guess)

    summary_rows.append(dict(
        model=model_name,
        column=col_name,
        scope=scope,
        micro_precision=micro[0],
        micro_recall=micro[1],
        micro_f1=micro[2],
        macro_precision=macro[0],
        macro_recall=macro[1],
        macro_f1=macro[2],
        subset_accuracy=exact_match,
    ))


# Models / columns to evaluate: (prediction column in `df`, display name).
# The letter is the model (A = gpt-4.1-mini, B = gpt-4.1, C = LLaMA) and the
# digit is the prompt variant (v0-v3). A column that is missing from `df` is
# skipped by the evaluation helpers.

models = [
    ("pred_A0_mini",   "A0  gpt-4.1-mini v0"),
    ("pred_A1_mini",   "A1  gpt-4.1-mini v1"),
    ("pred_A2_mini",   "A2  gpt-4.1-mini v2"),
    ("pred_A3_mini",   "A3  gpt-4.1-mini v3"),
    ("pred_B0_gpt41",  "B0  gpt-4.1 v0"),
    ("pred_B1_gpt41",  "B1  gpt-4.1 v1"),
    ("pred_B2_gpt41",  "B2  gpt-4.1 v2"),
    ("pred_B3_gpt41",  "B3  gpt-4.1 v3"),
    ("pred_C0_llama",  "C0  LLaMA v0"),
    ("pred_C1_llama",  "C1  LLaMA v1"),
    ("pred_C2_llama",  "C2  LLaMA v2"),
    ("pred_C3_llama",  "C3  LLaMA v3"),
]

# Print the evaluation for every model, first on all notes and then only on
# the notes that carry at least one target label.

for _heading, _scope_mask in (
    ("Evaluation on all notes", None),
    ("Evaluation only on notes with ≥1 target label", mask_positive_notes),
):
    print(_heading)
    for col, name in models:
        evaluate_column(col, name, mask=_scope_mask, show_per_label=False)

# Example detailed view for one model
evaluate_column("pred_B1_gpt41", "B1  gpt-4.1 v1", mask=None, show_per_label=True)

# Build and save the summary table: two rows per model (all notes, positive only)

for col, name in models:
    for _scope_mask in (None, mask_positive_notes):
        evaluate_column_to_summary(col, name, mask=_scope_mask)

summary_df = pd.DataFrame(summary_rows)
print("\nSummary table:")
print(summary_df)

summary_df.to_csv("model_evaluation_summary.csv", index=False)
print("\nSaved summary to 'model_evaluation_summary.csv'")
