In [5]:
import sys
import os, json, re
import pandas as pd
from sqlalchemy import create_engine
from datetime import datetime
import ollama
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

# Connect to database
PG_URL = "postgresql+psycopg2://postgres:4030@localhost:5432/omop_sandbox"
engine = create_engine(PG_URL)

# ============================================================================
# STEP 1: Sample patients (ICD+ and ICD- balanced)
# ============================================================================

patients = pd.read_sql("""
    WITH all_patients_with_notes AS (
    SELECT DISTINCT subject_id
    FROM mimic_omop.notes_norm
    WHERE text IS NOT NULL AND LENGTH(text) > 50
),
icd_positive AS (
    SELECT DISTINCT subject_id
    FROM mimic_omop.insomnia_cohort
    WHERE insomnia_flag = 1
)
SELECT 
    a.subject_id,
    CASE 
        WHEN p.subject_id IS NOT NULL THEN 1 
        ELSE 0 
    END AS icd_insomnia
FROM all_patients_with_notes a
LEFT JOIN icd_positive p
    ON a.subject_id = p.subject_id
ORDER BY a.subject_id;

""", engine)

In [6]:

# Sample 25 ICD+ and 25 ICD- patients for balanced evaluation
icd_positive = patients[patients["icd_insomnia"] == 1].sample(40, random_state=42)
icd_negative = patients[patients["icd_insomnia"] == 0].sample(40, random_state=42)

sample_patients = pd.concat([icd_positive, icd_negative])
sample_patient_ids = sample_patients["subject_id"].tolist()

print(f"Sampled {len(sample_patients)} patients:")
print(f"   - ICD+ (insomnia): {len(icd_positive)}")
print(f"   - ICD- (no insomnia): {len(icd_negative)}")

# ============================================================================
# STEP 2: Load notes for sampled patients
# ============================================================================

notes = pd.read_sql(f"""
    SELECT subject_id, hadm_id, text AS note_text
    FROM mimic_omop.notes_norm
    WHERE subject_id IN ({",".join(map(str, sample_patient_ids))})
      AND text IS NOT NULL AND LENGTH(text) > 50;
""", engine)

notes = notes.reset_index().rename(columns={"index": "note_rowid"})
print(f"Loaded {len(notes)} notes from {len(sample_patients)} patients")

# ============================================================================
# STEP 3: Define vocabulary and detection functions
# ============================================================================

SLEEP_TERMS = [
    "insomnia", "sleep onset", "sleep maintenance", "early awakening",
    "trouble sleeping", "difficulty sleeping", "can't sleep", "cant sleep",
    "sleep latency", "sleeplessness", "not sleeping", "poor sleep",
    "restless sleep", "hard to fall asleep", "sleep problem"
]

IMPAIR_TERMS = [
    "fatigue", "tired", "daytime sleepiness", "somnolence", "malaise",
    "irritable", "irritability", "poor concentration", "attention",
    "memory", "impaired performance", "decreased motivation",
    "errors", "accidents", "dissatisfaction with sleep",
    "low energy", "hard to concentrate", "sleepy", "tiredness", "can't concentrate"
]

PRIMARY_MED_TERMS = [
    "zolpidem", "zaleplon", "eszopiclone", "temazepam",
    "triazolam", "ramelteon", "suvorexant", "lemborexant"
]

SECONDARY_MED_TERMS = [
    "trazodone", "mirtazapine", "melatonin",
    "hydroxyzine", "doxepin"
]

def split_sentences(t):
    """Split text into sentences"""
    sents = re.split(r'(?<=[.!?])\s+', t.strip())
    return [s.strip()[:1000] for s in sents if 5 < len(s) < 1000]

def is_candidate(sent):
    s = sent.lower()
    
    #For a more inclusive coverage, we add extra terms 
    # Core insomnia-specific terms
    insomnia_specific = [
        "sleep", "awake", "awakening", "insomnia",
        "sleepless", "sleeping"
    ]
    
    # More targeted rest terms
    rest_specific = [
        "restless", "resting difficulty", "unable to rest",
        "difficulty resting", "can't rest"
    ]
    
    # Night-specific 
    night_sleep = [
        "at night", "during the night", "night time",
        "nighttime", "nightly"
    ]
    
    return (
        any(term in s for term in insomnia_specific) or
        any(term in s for term in rest_specific) or
        any(term in s for term in night_sleep) or
        any(w in s for w in SLEEP_TERMS) or
        any(w in s for w in IMPAIR_TERMS) or
        any(w in s for w in PRIMARY_MED_TERMS) or
        any(w in s for w in SECONDARY_MED_TERMS)
    )

# ============================================================================
# STEP 4: Extract candidate sentences
# ============================================================================

rows = []
for _, r in notes.iterrows():
    sents = split_sentences(r["note_text"])
    for i, s in enumerate(sents):
        if is_candidate(s):
            rows.append({
                "subject_id": r["subject_id"],
                "hadm_id": r["hadm_id"],
                "note_rowid": r["note_rowid"],
                "sent_id": i,
                "text_span": s
            })

cands = pd.DataFrame(rows)
print(f"Extracted {len(cands)} candidate sentences")



# ============================================================================
# STEP 5: LLM classification
# ============================================================================

SYSTEM_PROMPT = """
You are a clinical NLP assistant identifying evidence of insomnia in clinical notes. 
Your task is to classify a SINGLE SENTENCE. Do NOT infer anything not explicitly stated.

---------------------------------------
INSOMNIA CRITERIA DEFINITIONS
---------------------------------------

1. Sleep Difficulty (asserts_sleep_difficulty = true):
   Indicates difficulty initiating, maintaining, or restoring sleep, OR dissatisfaction with sleep.
   Examples include:
     - "trouble sleeping", "difficulty falling asleep", "early awakening"
     - "poor sleep", "restless sleep", "can't sleep", "insomnia"

2. Daytime Impairment (asserts_daytime_impairment = true):
   Evidence that poor sleep causes daytime consequences:
     - fatigue, tiredness, sleepiness, impaired concentration, irritability, low energy
   Only mark true if the impairment is PRESENT in the sentence.

3. Primary Insomnia Medications (asserts_primary_med = true):
   Medications prescribed PRIMARILY for insomnia:
     zolpidem (Ambien), zaleplon (Sonata), eszopiclone (Lunesta),
     temazepam (Restoril), triazolam (Halcion), suvorexant (Belsomra),
     lemborexant (Dayvigo), ramelteon (Rozerem).

4. Secondary Medications (asserts_secondary_med = true):
   Medications SOMETIMES used for insomnia, **even if prescribed for another condition**:
     trazodone, mirtazapine, melatonin, hydroxyzine,
     doxepin, gabapentin, quetiapine (Seroquel),
     olanzapine (Zyprexa), clonazepam (Klonopin),
     lorazepam (Ativan), diazepam (Valium).

---------------------------------------
NEGATION HANDLING
---------------------------------------
negated = true if the sentence explicitly states the absence of sleep difficulty,
impairment, or medication usage.
Examples:
  - "denies insomnia", "no difficulty sleeping", "not taking Ambien"
Otherwise negated = false.

---------------------------------------
TEMPORALITY
---------------------------------------
Choose ONE:
  - "current": symptoms/medication clearly present now
  - "historical": symptoms/medication were in the past
  - "uncertain": unclear timing or general statements

Be inclusive: when uncertain, lean toward "current".

---------------------------------------
STRICT JSON RESPONSE
---------------------------------------
Respond ONLY with valid JSON in this exact format:

{
  "asserts_sleep_difficulty": bool,
  "asserts_daytime_impairment": bool,
  "asserts_primary_med": bool,
  "asserts_secondary_med": bool,
  "negated": bool,
  "temporality": "current" | "historical" | "uncertain"
}

---------------------------------------
EXAMPLE
---------------------------------------
Sentence: "Patient reports difficulty falling asleep and feels tired during the day."

JSON:
{
  "asserts_sleep_difficulty": true,
  "asserts_daytime_impairment": true,
  "asserts_primary_med": false,
  "asserts_secondary_med": false,
  "negated": false,
  "temporality": "current"
}

"""

def extract_json(text):
    m = re.search(r"\{.*\}", text, re.DOTALL)
    if not m:
        return {"error": "no JSON"}
    try:
        return json.loads(m.group(0))
    except:
        return {"error": "bad JSON", "raw": m.group(0)}

def classify_sentence_ollama(text: str):
    resp = ollama.chat(
        model="llama3:8b",
        messages=[{"role": "user", "content": SYSTEM_PROMPT + f'\nSentence: "{text}"'}]
    )
    parsed = extract_json(resp["message"]["content"])

    if "temporality" not in parsed:
        parsed["temporality"] = "uncertain"

    return parsed

# Run LLM classification
out = []
for _, r in cands.iterrows():
    y = classify_sentence_ollama(r["text_span"])
    out.append({**r, **y})

ev = pd.DataFrame(out)
print(f"Classified {len(ev)} sentences")

# ============================================================================
# STEP 6: Patient-level aggregation
# ============================================================================

# ============================================================================
# STEP 6: Patient-level aggregation (Updated according to full definitions)
# ============================================================================

ev["is_sleep"] = ev["asserts_sleep_difficulty"] & ~ev["negated"] & (ev["temporality"] == "current")
ev["is_impair"] = ev["asserts_daytime_impairment"] & ~ev["negated"] & (ev["temporality"] == "current")
ev["is_primary"] = ev["asserts_primary_med"] & ~ev["negated"] & (ev["temporality"] == "current")
ev["is_secondary"] = ev["asserts_secondary_med"] & ~ev["negated"] & (ev["temporality"] == "current")

agg = ev.groupby("subject_id").agg({
    "is_sleep": "max",
    "is_impair": "max",
    "is_primary": "max",
    "is_secondary": "max"
}).reset_index()

# Rule A: Difficulty sleeping + daytime impairment
agg["rule_a_text"] = agg["is_sleep"] & agg["is_impair"]

# Rule B: Any primary insomnia medication
agg["rule_b_text"] = agg["is_primary"]

# Rule C: Secondary med + (sleep difficulty OR daytime impairment)
agg["rule_c_text"] = agg["is_secondary"] & (agg["is_sleep"] | agg["is_impair"])

# Any positive rule
agg["any_text"] = (agg["rule_a_text"] | agg["rule_b_text"] | agg["rule_c_text"]).astype(int)









Sampled 80 patients:
   - ICD+ (insomnia): 40
   - ICD- (no insomnia): 40
Loaded 265 notes from 80 patients
Extracted 472 candidate sentences
Classified 472 sentences


TypeError: bad operand type for unary ~: 'float'

In [27]:
# ======================================================================
# STEP 6 — SAFE PATIENT-LEVEL AGGREGATION (NO TEMPORALITY FILTER)
# ======================================================================

# Ensure correct boolean types
bool_cols = [
    "asserts_sleep_difficulty",
    "asserts_daytime_impairment",
    "asserts_primary_med",
    "asserts_secondary_med",
    "negated"
]

for col in bool_cols:
    ev[col] = ev[col].fillna(False).astype(bool)

# -----------------------------
# FEATURE DEFINITIONS (NO TEMPORALITY)
# -----------------------------
ev["is_sleep"] = (
    ev["asserts_sleep_difficulty"] &
    ~ev["negated"]
)

ev["is_impair"] = (
    ev["asserts_daytime_impairment"] &
    ~ev["negated"]
)

ev["is_primary"] = (
    ev["asserts_primary_med"] &
    ~ev["negated"]
)

ev["is_secondary"] = (
    ev["asserts_secondary_med"] &
    ~ev["negated"]
)

# -----------------------------
# RAW (LLM output only)
# -----------------------------
ev["sleep_raw"] = ev["asserts_sleep_difficulty"].astype(bool)
ev["impair_raw"] = ev["asserts_daytime_impairment"].astype(bool)

# -----------------------------
# PATIENT-LEVEL AGGREGATION
# -----------------------------
agg = ev.groupby("subject_id").agg({
    "is_sleep": "max",
    "is_impair": "max",
    "is_primary": "max",
    "is_secondary": "max"
}).reset_index()

agg_raw = ev.groupby("subject_id").agg({
    "sleep_raw": "max",
    "impair_raw": "max"
}).reset_index()

agg = agg.merge(agg_raw, on="subject_id", how="left")

# -----------------------------
# RULE DEFINITIONS (MATCH ICD LOGIC)
# -----------------------------

# Rule A: symptoms + impairment (raw signals, no temporality restriction)
agg["rule_a_text"] = (agg["sleep_raw"] & agg["impair_raw"]).astype(int)

# Rule B: insomnia-specific medications
agg["rule_b_text"] = agg["is_primary"].astype(int)

# Rule C: secondary meds + symptoms/impairment (raw recommended)
agg["rule_c_text"] = (
    agg["is_secondary"] &
    (agg["sleep_raw"] | agg["impair_raw"])
).astype(int)

# Any rule
agg["any_text"] = (
    agg["rule_a_text"] |
    agg["rule_b_text"] |
    agg["rule_c_text"]
).astype(int)


In [None]:
# ============================================================================
# STEP 7: Load gold standard and merge
# ============================================================================

gold = pd.read_sql("""
    SELECT subject_id, rule_a, rule_b, rule_c, insomnia_flag
    FROM mimic_omop.insomnia_cohort
""", engine)

# Rename rule columns
gold = gold.rename(columns={
    "rule_a": "rule_a_gold",
    "rule_b": "rule_b_gold",
    "rule_c": "rule_c_gold",
    "insomnia_flag": "any_gold"
})

# Convert to int (now all columns exist)
for col in ["rule_a_gold", "rule_b_gold", "rule_c_gold", "any_gold"]:
    gold[col] = gold[col].astype(int)

# Merge with the sampled patients
gold_with_icd = sample_patients.merge(gold, on="subject_id", how="left")

# Missing entries = no insomnia evidence → set to 0
gold_with_icd[["rule_a_gold", "rule_b_gold", "rule_c_gold", "any_gold"]] = \
    gold_with_icd[["rule_a_gold", "rule_b_gold", "rule_c_gold", "any_gold"]].fillna(0).astype(int)

# Merge LLM predictions
df = gold_with_icd.merge(agg, on="subject_id", how="left").fillna(0)

# Convert LLM predictions to int
for col in ["rule_a_text", "rule_b_text", "rule_c_text", "any_text"]:
    df[col] = df[col].astype(int)

print(f"Final merged dataframe shape: {df.shape}")


Final merged dataframe shape: (80, 16)


In [29]:
print("\n=== DIAGNOSTIC 1: Raw LLM Mentions BEFORE TEMPORALITY ===")
print("Raw sleep detections:", ev["sleep_raw"].sum())
print("Raw impairment detections:", ev["impair_raw"].sum())

print("\n=== DIAGNOSTIC 2: After temporality filtering ===")
print("is_sleep:", ev["is_sleep"].sum())
print("is_impair:", ev["is_impair"].sum())

print("\n=== DIAGNOSTIC 3: Patient-level RAW evidence ===")
print("sleep_raw (patient-level):")
print(agg["sleep_raw"].value_counts())
print("\nimpair_raw (patient-level):")
print(agg["impair_raw"].value_counts())

print("\n=== DIAGNOSTIC 4: Rule A comparison ===")
print("Gold positives:", df["rule_a_gold"].sum())
print("Text positives (new Rule A):", agg["rule_a_text"].sum())



=== DIAGNOSTIC 1: Raw LLM Mentions BEFORE TEMPORALITY ===
Raw sleep detections: 245
Raw impairment detections: 82

=== DIAGNOSTIC 2: After temporality filtering ===
is_sleep: 243
is_impair: 81

=== DIAGNOSTIC 3: Patient-level RAW evidence ===
sleep_raw (patient-level):
sleep_raw
True     45
False    13
Name: count, dtype: int64

impair_raw (patient-level):
impair_raw
False    34
True     24
Name: count, dtype: int64

=== DIAGNOSTIC 4: Rule A comparison ===
Gold positives: 2
Text positives (new Rule A): 20


In [31]:
print("Rule A gold positive count:", df["rule_a_gold"].sum())
print("Rule A gold negative count:", (df["rule_a_gold"] == 0).sum())


Rule A gold positive count: 2
Rule A gold negative count: 78


In [30]:
# ============================================================================
# STEP 8: Evaluation
# ============================================================================

print("\n" + "="*70)
print("EVALUATION RESULTS - SENTENCE-BASED APPROACH")
print("="*70)

def evaluate(true, pred, label):
    print(f"\n=== {label} ===")
    cm = confusion_matrix(true, pred)
    print("Confusion Matrix:")
    print(cm)
    
    prec = precision_score(true, pred, zero_division=0)
    rec = recall_score(true, pred, zero_division=0)
    f1 = f1_score(true, pred, zero_division=0)
    
    print(f"Precision: {prec:.3f}")
    print(f"Recall:    {rec:.3f}")
    print(f"F1 Score:  {f1:.3f}")
    
    return {"precision": prec, "recall": rec, "f1": f1, "cm": cm}

results = {}
results["Rule A"] = evaluate(df["rule_a_gold"], df["rule_a_text"], "Rule A (Symptoms)")
results["Rule B"] = evaluate(df["rule_b_gold"], df["rule_b_text"], "Rule B (Primary Meds)")
results["Rule C"] = evaluate(df["rule_c_gold"], df["rule_c_text"], "Rule C (Secondary Meds)")
results["Any Rule"] = evaluate(df["any_gold"], df["any_text"], "Any Rule (Insomnia)")


EVALUATION RESULTS - SENTENCE-BASED APPROACH

=== Rule A (Symptoms) ===
Confusion Matrix:
[[58 20]
 [ 2  0]]
Precision: 0.000
Recall:    0.000
F1 Score:  0.000

=== Rule B (Primary Meds) ===
Confusion Matrix:
[[45  8]
 [11 16]]
Precision: 0.667
Recall:    0.593
F1 Score:  0.627

=== Rule C (Secondary Meds) ===
Confusion Matrix:
[[56 12]
 [ 5  7]]
Precision: 0.368
Recall:    0.583
F1 Score:  0.452

=== Any Rule (Insomnia) ===
Confusion Matrix:
[[31  9]
 [14 26]]
Precision: 0.743
Recall:    0.650
F1 Score:  0.693


In [26]:
print(ev[ev["is_sleep"] == True][["subject_id","text_span"]].sample(20))


     subject_id                                          text_span
441    18570152  # HL: continued fenofibrate\n \n# DM: diabetic...
411    19342580  The modifiable risk factors identified were as...
207    13561687  He received a dose of Ativan with improvement ...
119    13512753  CHRONIC ISSUES:  \n# Insomnia: Continued trazo...
386    17230631  TraZODone 25 mg PO QHS:PRN insomnia \nRX *traz...
58     10569306  She is fatigued and tired, hoping to \nget som...
293    14744455                #Insomnia: Continued home zolpidem.
106    13477790  Reported sleep has been \n"okay" stating she g...
82     11464800  Zolpidem Tartrate 2.5-5 mg PO HS:PRN insomnia ...
113    13477790  Patient reported improved sleep and fewer nigh...
259    14010911  Additional information acquired through \ntran...
335    16296993  She reports increasing fatigue over the past w...
448    18728882  Endorses poor sleep, problems with initiation,...
351    16852131               She was \nawake, alert, oriented

In [17]:
# ============================================================================
# STEP 8: Evaluation
# ============================================================================

print("\n" + "="*70)
print("EVALUATION RESULTS - SENTENCE-BASED APPROACH")
print("="*70)

def evaluate(true, pred, label):
    print(f"\n=== {label} ===")
    cm = confusion_matrix(true, pred)
    print("Confusion Matrix:")
    print(cm)

    prec = precision_score(true, pred, zero_division=0)
    rec = recall_score(true, pred, zero_division=0)
    f1 = f1_score(true, pred, zero_division=0)

    print(f"Precision: {prec:.3f}")
    print(f"Recall:    {rec:.3f}")
    print(f"F1 Score:  {f1:.3f}")

    return {"precision": prec, "recall": rec, "f1": f1, "cm": cm}


# Ensure correct dtype for gold & predictions
for col in ["rule_a_gold", "rule_b_gold", "rule_c_gold", "any_gold",
            "rule_a_text", "rule_b_text", "rule_c_text", "any_text"]:
    df[col] = df[col].astype(int)

results = {}
results["Rule A"] = evaluate(df["rule_a_gold"], df["rule_a_text"], "Rule A (Symptoms)")
results["Rule B"] = evaluate(df["rule_b_gold"], df["rule_b_text"], "Rule B (Primary Meds)")
results["Rule C"] = evaluate(df["rule_c_gold"], df["rule_c_text"], "Rule C (Secondary Meds)")
results["Any Rule"] = evaluate(df["any_gold"], df["any_text"], "Any Rule (Insomnia)")


# ============================================================================
# STEP 9: Comprehensive Classification Analysis
# ============================================================================

print("\n" + "="*70)
print("COMPREHENSIVE CLASSIFICATION ANALYSIS")
print("="*70)

def analyze_classification_outcomes(gold_col, pred_col, rule_name):

    print(f"\n{'='*70}")
    print(f"{rule_name.upper()}")
    print(f"{'='*70}")

    true_pos  = df[(df[gold_col] == 1) & (df[pred_col] == 1)]
    true_neg  = df[(df[gold_col] == 0) & (df[pred_col] == 0)]
    false_pos = df[(df[gold_col] == 0) & (df[pred_col] == 1)]
    false_neg = df[(df[gold_col] == 1) & (df[pred_col] == 0)]

    total = len(df)

    print(f"\nOVERALL DISTRIBUTION:")
    print(f"  Total patients: {total}")
    print(f"  Gold positive:  {df[gold_col].sum()}")
    print(f"  Gold negative:  {(df[gold_col] == 0).sum()}")
    print(f"  LLM predicted positive: {df[pred_col].sum()}")
    print(f"  LLM predicted negative: {(df[pred_col] == 0).sum()}")

    print(f"\nCLASSIFICATION OUTCOMES:")
    print(f"  True Positives:  {len(true_pos):2d}")
    print(f"  True Negatives:  {len(true_neg):2d}")
    print(f"  False Positives: {len(false_pos):2d}")
    print(f"  False Negatives: {len(false_neg):2d}")

    # ------------------------
    # TRUE POSITIVES
    # ------------------------
    if len(true_pos) > 0:
        print(f"\n--- TRUE POSITIVES ({len(true_pos)}) ---")
        print("Patient IDs:", true_pos["subject_id"].tolist())

        for idx, (_, row) in enumerate(true_pos.iterrows(), 1):
            pid = row["subject_id"]
            patient_ev = ev[ev["subject_id"] == pid]

            print(f"\n  [{idx}] Patient {pid}")
            print(f"      Evidence sentences: {len(patient_ev)}")

            if "rule_a" in rule_name.lower():
                print(f"      Sleep mentions: {patient_ev['is_sleep'].sum()}")
                print(f"      Daytime impairment mentions: {patient_ev['is_impair'].sum()}")

            if "rule_b" in rule_name.lower():
                print(f"      Primary med mentions: {patient_ev['is_primary'].sum()}")

            if "rule_c" in rule_name.lower():
                print(f"      Secondary med mentions: {patient_ev['is_secondary'].sum()}")


    # ------------------------
    # TRUE NEGATIVES
    # ------------------------
    if len(true_neg) > 0:
        print(f"\n--- TRUE NEGATIVES ({len(true_neg)}) ---")
        print("Patient IDs:", true_neg["subject_id"].tolist())
        print("No evidence detected (correct).")


    # ------------------------
    # FALSE POSITIVES
    # ------------------------
    if len(false_pos) > 0:
        print(f"\n--- FALSE POSITIVES ({len(false_pos)}) ---")
        print("Patient IDs:", false_pos["subject_id"].tolist())

        for idx, (_, row) in enumerate(false_pos.iterrows(), 1):
            pid = row["subject_id"]
            patient_ev = ev[ev["subject_id"] == pid]

            print(f"\n  [{idx}] Patient {pid}")
            print("      Incorrectly flagged by LLM")
            print(f"      Sleep: {patient_ev['is_sleep'].sum()}  Impair: {patient_ev['is_impair'].sum()}")
            print(f"      Primary: {patient_ev['is_primary'].sum()}  Secondary: {patient_ev['is_secondary'].sum()}")


    # ------------------------
    # FALSE NEGATIVES
    # ------------------------
    if len(false_neg) > 0:
        print(f"\n--- FALSE NEGATIVES ({len(false_neg)}) ---")
        print("Patient IDs:", false_neg["subject_id"].tolist())

        for idx, (_, row) in enumerate(false_neg.iterrows(), 1):
            pid = row["subject_id"]
            patient_ev = ev[ev["subject_id"] == pid]

            print(f"\n  [{idx}] Patient {pid}")
            print("      LLM missed a real case")
            print(f"      Evidence sentences found: {len(patient_ev)}")


# Run for every rule
analyze_classification_outcomes("rule_a_gold", "rule_a_text", "Rule A")
analyze_classification_outcomes("rule_b_gold", "rule_b_text", "Rule B")
analyze_classification_outcomes("rule_c_gold", "rule_c_text", "Rule C")
analyze_classification_outcomes("any_gold", "any_text", "Any Rule")


# ============================================================================
# STEP 10: Summary Table
# ============================================================================

print("\n" + "="*70)
print("SUMMARY COMPARISON TABLE")
print("="*70 + "\n")

comparison = pd.DataFrame({
    'Rule': ['A: Symptoms', 'B: Primary Meds', 'C: Secondary Meds', 'Any Rule'],
    'Gold +': [
        df['rule_a_gold'].sum(),
        df['rule_b_gold'].sum(),
        df['rule_c_gold'].sum(),
        df['any_gold'].sum()
    ],
    'LLM +': [
        df['rule_a_text'].sum(),
        df['rule_b_text'].sum(),
        df['rule_c_text'].sum(),
        df['any_text'].sum()
    ]
})

# Add confusion components
comparison["TP"] = [
    ((df['rule_a_gold']==1) & (df['rule_a_text']==1)).sum(),
    ((df['rule_b_gold']==1) & (df['rule_b_text']==1)).sum(),
    ((df['rule_c_gold']==1) & (df['rule_c_text']==1)).sum(),
    ((df['any_gold']==1) & (df['any_text']==1)).sum()
]

comparison["FN"] = [
    ((df['rule_a_gold']==1) & (df['rule_a_text']==0)).sum(),
    ((df['rule_b_gold']==1) & (df['rule_b_text']==0)).sum(),
    ((df['rule_c_gold']==1) & (df['rule_c_text']==0)).sum(),
    ((df['any_gold']==1) & (df['any_text']==0)).sum()
]

comparison["FP"] = [
    ((df['rule_a_gold']==0) & (df['rule_a_text']==1)).sum(),
    ((df['rule_b_gold']==0) & (df['rule_b_text']==1)).sum(),
    ((df['rule_c_gold']==0) & (df['rule_c_text']==1)).sum(),
    ((df['any_gold']==0) & (df['any_text']==1)).sum()
]

comparison["TN"] = [
    ((df['rule_a_gold']==0) & (df['rule_a_text']==0)).sum(),
    ((df['rule_b_gold']==0) & (df['rule_b_text']==0)).sum(),
    ((df['rule_c_gold']==0) & (df['rule_c_text']==0)).sum(),
    ((df['any_gold']==0) & (df['any_text']==0)).sum()
]

comparison["Precision"] = comparison["TP"] / (comparison["TP"] + comparison["FP"]).replace({0:0})
comparison["Recall"]    = comparison["TP"] / (comparison["TP"] + comparison["FN"]).replace({0:0})
comparison["F1"]        = 2 * (comparison["Precision"] * comparison["Recall"]) / (comparison["Precision"] + comparison["Recall"]).replace({0:0})
comparison["Accuracy"]  = (comparison["TP"] + comparison["TN"]) / len(df)

print(comparison.round(3).to_string(index=False))


# ============================================================================
# STEP 11: Temporality Analysis
# ============================================================================

print("\n" + "="*70)
print("TEMPORALITY IMPACT ANALYSIS")
print("="*70)

print(ev["temporality"].value_counts())

print("\nPrimary Med by temporality:")
print(ev[ev["asserts_primary_med"]].groupby("temporality").size())

print("\nSecondary Med by temporality:")
print(ev[ev["asserts_secondary_med"]].groupby("temporality").size())




EVALUATION RESULTS - SENTENCE-BASED APPROACH

=== Rule A (Symptoms) ===
Confusion Matrix:
[[60 18]
 [ 2  0]]
Precision: 0.000
Recall:    0.000
F1 Score:  0.000

=== Rule B (Primary Meds) ===
Confusion Matrix:
[[46  7]
 [11 16]]
Precision: 0.696
Recall:    0.593
F1 Score:  0.640

=== Rule C (Secondary Meds) ===
Confusion Matrix:
[[57 11]
 [ 5  7]]
Precision: 0.389
Recall:    0.583
F1 Score:  0.467

=== Any Rule (Insomnia) ===
Confusion Matrix:
[[32  8]
 [14 26]]
Precision: 0.765
Recall:    0.650
F1 Score:  0.703

COMPREHENSIVE CLASSIFICATION ANALYSIS

RULE A

OVERALL DISTRIBUTION:
  Total patients: 80
  Gold positive:  2
  Gold negative:  78
  LLM predicted positive: 18
  LLM predicted negative: 62

CLASSIFICATION OUTCOMES:
  True Positives:   0
  True Negatives:  60
  False Positives: 18
  False Negatives:  2

--- TRUE NEGATIVES (60) ---
Patient IDs: [17406776, 10200543, 15889689, 18922904, 11464800, 14644835, 11411009, 19232236, 19060383, 14065960, 18602508, 14021871, 12659223, 19703

KeyError: "None of [Index([0, 0, 0, 0, 0, 0, 1, 1, 0, 0,\n       ...\n       0, 0, 1, 1, 1, 0, 0, 0, 0, 0],\n      dtype='int8', length=472)] are in the [columns]"

In [16]:
# ============================================================================
# STEP 12: Save Results (fix all types before parquet)
# ============================================================================

print("\n" + "="*70)
print("CLEANING DATA TYPES BEFORE SAVING")
print("="*70)

# ---------------------------------------------------------------------------
# 1) Clean EV dataframe
# ---------------------------------------------------------------------------
ev = ev.copy()

bool_like_cols_ev = [
    "is_sleep", "is_impair", "is_primary", "is_secondary",
    "asserts_sleep_difficulty", "asserts_daytime_impairment",
    "asserts_primary_med", "asserts_secondary_med",
    "negated"
]

for col in bool_like_cols_ev:
    if col in ev.columns:
        ev[col] = ev[col].map(lambda x: 1 if x in [True, "True", 1, "1"] else 0).astype("int8")

# Convert any remaining object columns to string
for col in ev.columns:
    if ev[col].dtype == "object":
        ev[col] = ev[col].astype(str)

# ---------------------------------------------------------------------------
# 2) Clean AGG dataframe
# ---------------------------------------------------------------------------
agg = agg.copy()

for col in ["rule_a_text", "rule_b_text", "rule_c_text", "any_text"]:
    if col in agg.columns:
        agg[col] = agg[col].astype(int)

# Convert any remaining object columns to string
for col in agg.columns:
    if agg[col].dtype == "object":
        agg[col] = agg[col].astype(str)

# ---------------------------------------------------------------------------
# 3) Clean DF dataframe (the one that fails!)
# ---------------------------------------------------------------------------
df = df.copy()

# Convert rule columns explicitly
rule_cols = [
    "rule_a_text", "rule_b_text", "rule_c_text", "any_text",
    "rule_a_gold", "rule_b_gold", "rule_c_gold", "any_gold",
    "icd_insomnia"
]

for col in rule_cols:
    if col in df.columns:
        df[col] = df[col].astype(int)

# Fix ALL remaining object columns safely
for col in df.columns:
    if df[col].dtype == "object":
        # Try to coerce booleans first
        df[col] = df[col].map(
            lambda x: 1 if x in [True, "True", 1, "1"] else
                      0 if x in [False, "False", 0, "0"] else
                      str(x)
        )
        # After mapping, ensure no object types remain
        if df[col].dtype == "object":
            df[col] = df[col].astype(str)

# ---------------------------------------------------------------------------
# SAVE FILES
# ---------------------------------------------------------------------------

notes.to_parquet("notes_sample_balanced.parquet", engine="fastparquet", index=False)
ev.to_parquet("ev_sentence_level_balanced.parquet", engine="fastparquet", index=False)
agg.to_parquet("agg_patient_level_balanced.parquet", engine="fastparquet", index=False)
df.to_parquet("df_evaluation_balanced.parquet", engine="fastparquet", index=False)

false_neg = df[(df["any_gold"] == 1) & (df["any_text"] == 0)]
false_pos = df[(df["icd_insomnia"] == 0) & (df["any_text"] == 1)]

false_neg.to_csv("false_negatives_balanced.csv", index=False)
false_pos.to_csv("false_positives_balanced.csv", index=False)

print("\nAll results saved successfully!")
print("Files created:")
print("  - notes_sample_balanced.parquet")
print("  - ev_sentence_level_balanced.parquet")
print("  - agg_patient_level_balanced.parquet")
print("  - df_evaluation_balanced.parquet")
print("  - false_negatives_balanced.csv")
print("  - false_positives_balanced.csv")



CLEANING DATA TYPES BEFORE SAVING

All results saved successfully!
Files created:
  - notes_sample_balanced.parquet
  - ev_sentence_level_balanced.parquet
  - agg_patient_level_balanced.parquet
  - df_evaluation_balanced.parquet
  - false_negatives_balanced.csv
  - false_positives_balanced.csv
