In [1]:
!pip install seaborn

import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns

Collecting seaborn
  Using cached seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Using cached seaborn-0.13.2-py3-none-any.whl (294 kB)
Installing collected packages: seaborn
Successfully installed seaborn-0.13.2


In [3]:
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns

# ============================================================================
# LOAD DATA
# ============================================================================
# Inputs (paths relative to the working directory):
#   df    - patient-level table with *_gold and *_text rule labels
#   ev    - sentence-level LLM feature extractions keyed by subject_id
#   notes - note sample whose `note_text` is used for coverage estimates

print("\n".join(["=" * 80, "LOADING EVALUATION DATA", "=" * 80]))

df, ev, notes = (
    pd.read_parquet(parquet_path)
    for parquet_path in (
        "df_evaluation_balanced.parquet",
        "ev_sentence_level_balanced.parquet",
        "notes_sample_balanced.parquet",
    )
)

print("\n".join([
    "\nDataset sizes:",
    f"  - Patients: {len(df)}",
    f"  - Sentences: {len(ev)}",
    f"  - Notes: {len(notes)}",
]))

# ============================================================================
# 1. OVERALL PERFORMANCE METRICS
# ============================================================================

print("\n".join(["\n" + "=" * 80, "1. OVERALL PERFORMANCE METRICS", "=" * 80]))

def calculate_metrics(y_true, y_pred, label):
    """Compute, print, and return binary classification metrics for one rule.

    Args:
        y_true: Gold-standard binary labels (0/1 or bool), array-like.
        y_pred: Predicted binary labels (0/1 or bool), same length as y_true.
        label: Heading printed above the metrics; also stored under 'label'.

    Returns:
        dict with 'label'; confusion counts 'TP', 'FP', 'TN', 'FN'; and
        'accuracy', 'precision', 'recall', 'f1', 'specificity', 'npv'.
        Any ratio whose denominator is zero is reported as 0.
    """
    # Pin the label set so the matrix is always 2x2. Without it, a rule where
    # only one class occurs in both y_true and y_pred (e.g. no positives at
    # all) yields a 1x1 matrix and the four-way unpack below raises ValueError.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    accuracy = accuracy_score(y_true, y_pred)

    # Additional metrics: specificity = TN / (TN + FP), NPV = TN / (TN + FN),
    # both defaulting to 0 when the denominator is empty.
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0

    print(f"\n{label}")
    print("-" * len(label))
    print("Confusion Matrix:")
    print(f"  TN: {tn:3d}  |  FP: {fp:3d}")
    print(f"  FN: {fn:3d}  |  TP: {tp:3d}")
    print("\nPerformance Metrics:")
    print(f"  Accuracy:    {accuracy:.3f}")
    print(f"  Precision:   {precision:.3f}")
    print(f"  Recall:      {recall:.3f}")
    print(f"  F1-Score:    {f1:.3f}")
    print(f"  Specificity: {specificity:.3f}")
    print(f"  NPV:         {npv:.3f}")

    return {
        'label': label,
        'TP': tp, 'FP': fp, 'TN': tn, 'FN': fn,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'specificity': specificity,
        'npv': npv
    }

# Evaluate each individual rule and the combined "any rule" phenotype against
# gold. Order matters: later sections index `results` as [A, B, C, Any].
results = [
    calculate_metrics(df[gold_col], df[text_col], heading)
    for gold_col, text_col, heading in (
        ('rule_a_gold', 'rule_a_text', "Rule A: Sleep Difficulty + Daytime Impairment"),
        ('rule_b_gold', 'rule_b_text', "Rule B: Primary Insomnia Medications"),
        ('rule_c_gold', 'rule_c_text', "Rule C: Secondary Medications + Symptoms"),
        ('any_gold', 'any_text', "ANY RULE: Overall Insomnia Detection"),
    )
]

# One row per evaluated rule, columns taken from the metric dicts.
results_df = pd.DataFrame(results)
print("\n".join(["\n" + "=" * 80, "PERFORMANCE SUMMARY TABLE", "=" * 80]))
print(results_df.to_string(index=False))

# ============================================================================
# 2. INFORMATION EXPRESSIVENESS ANALYSIS
# ============================================================================

print("\n\n" + "="*80)
print("2. INFORMATION EXPRESSIVENESS ANALYSIS")
print("="*80)
print("\nFollowing Li et al. framework: analyzing LLM's ability to express")
print("different clinical concepts through extracted features")

# Per-patient sums of the sentence-level feature flags. Only patients that
# have at least one candidate sentence in `ev` get a row here.
feature_stats = ev.groupby('subject_id').agg({
    'asserts_sleep_difficulty': 'sum',
    'asserts_daytime_impairment': 'sum',
    'asserts_primary_med': 'sum',
    'asserts_secondary_med': 'sum',
    'negated': 'sum'
}).reset_index()

feature_stats.columns = ['subject_id', 'sleep_mentions', 'impairment_mentions',
                         'primary_med_mentions', 'secondary_med_mentions',
                         'negated_mentions']

# Attach the counts to EVERY evaluated patient. An inner merge would drop
# patients with no candidate sentences and inflate the percentages below, so
# left-merge from `df` and treat a missing row as zero mentions.
mention_cols = ['sleep_mentions', 'impairment_mentions',
                'primary_med_mentions', 'secondary_med_mentions',
                'negated_mentions']
feature_analysis = df[['subject_id', 'rule_a_gold', 'rule_b_gold',
                       'rule_c_gold', 'any_gold']].merge(
    feature_stats, on='subject_id', how='left'
)
feature_analysis[mention_cols] = feature_analysis[mention_cols].fillna(0)

print("\n2.1 Feature Extraction Frequencies")
print("-" * 40)
n_patients = len(feature_analysis)
for col, desc in [
    ('sleep_mentions', 'sleep difficulty'),
    ('impairment_mentions', 'daytime impairment'),
    ('primary_med_mentions', 'primary med'),
    ('secondary_med_mentions', 'secondary med'),
]:
    n_with = int((feature_analysis[col] > 0).sum())
    pct = n_with / n_patients * 100 if n_patients else 0.0
    print(f"Patients with {desc} mentions: {n_with} ({pct:.1f}%)")

print("\n2.2 Feature Co-occurrence Patterns")
print("-" * 40)

# Co-occurrence among gold-positive patients (any_gold == 1), including those
# with no candidate sentences at all.
icd_pos = feature_analysis[feature_analysis['any_gold'] == 1]
has_sleep = icd_pos['sleep_mentions'] > 0
has_impair = icd_pos['impairment_mentions'] > 0
both_sleep_impair = int((has_sleep & has_impair).sum())
sleep_only = int((has_sleep & ~has_impair).sum())
impair_only = int((~has_sleep & has_impair).sum())

print("\nAmong ICD+ patients:")
for desc, count in [("Both sleep + impairment features", both_sleep_impair),
                    ("Sleep only", sleep_only),
                    ("Impairment only", impair_only)]:
    pct = count / len(icd_pos) * 100 if len(icd_pos) else 0.0
    print(f"  - {desc}: {count}/{len(icd_pos)} ({pct:.1f}%)")

# ============================================================================
# 3. COMPLETENESS ANALYSIS
# ============================================================================

print("\n\n" + "="*80)
print("3. COMPLETENESS ANALYSIS")
print("="*80)
print("\nFollowing Li et al.: measuring coverage of insomnia-relevant information")

# Rough sentence count: number of '.'-separated pieces per note. This is an
# approximation (abbreviations and decimals also split). Missing note text is
# treated as an empty string rather than crashing on None.split().
total_sentences = notes['note_text'].fillna('').map(lambda text: len(text.split('.'))).sum()
candidate_sentences = len(ev)

print(f"\n3.1 Sentence-Level Coverage")
print("-" * 40)
print(f"Total sentences in corpus: ~{total_sentences:,}")
print(f"Candidate sentences extracted: {candidate_sentences:,}")
coverage_pct = candidate_sentences / total_sentences * 100 if total_sentences else 0.0
print(f"Coverage rate: {coverage_pct:.2f}%")

# Patient coverage, restricted to the evaluated cohort in `df` so the ratio
# cannot exceed 100% if `ev` holds extra subjects.
patients_with_candidates = df.loc[df['subject_id'].isin(ev['subject_id']), 'subject_id'].nunique()
total_patients = df['subject_id'].nunique()

print(f"\n3.2 Patient-Level Coverage")
print("-" * 40)
patient_pct = patients_with_candidates / total_patients * 100 if total_patients else 0.0
print(f"Patients with extracted features: {patients_with_candidates}/{total_patients} "
      f"({patient_pct:.1f}%)")

# Coverage among gold-positive patients
icd_pos_with_features = df[(df['any_gold'] == 1) &
                           (df['subject_id'].isin(ev['subject_id']))].shape[0]
icd_pos_total = (df['any_gold'] == 1).sum()

icd_pos_pct = icd_pos_with_features / icd_pos_total * 100 if icd_pos_total else 0.0
print(f"\nICD+ patients with extracted features: {icd_pos_with_features}/{icd_pos_total} "
      f"({icd_pos_pct:.1f}%)")

# Gold-positive patients that no text-derived rule flagged. Some of them may
# still have candidate sentences; Section 5.1 separates "no sentences
# extracted" from "sentences extracted but features missed".
fn_patients = df[(df['any_gold'] == 1) & (df['any_text'] == 0)]
print(f"\nFalse Negative patients: {len(fn_patients)}")
print("  - Likely due to incomplete feature extraction")

# ============================================================================
# 4. GRANULARITY ANALYSIS
# ============================================================================

print("\n\n" + "="*80)
print("4. GRANULARITY ANALYSIS")
print("="*80)
print("\nFollowing Li et al.: analyzing detail level and rule-specific patterns")

print("\n4.1 Rule-Specific Performance Comparison")
print("-" * 40)

# `results` is ordered [A, B, C, Any]; the df column prefixes follow suit.
rule_prefixes = ['rule_a', 'rule_b', 'rule_c', 'any']
rule_comparison = pd.DataFrame({
    'Rule': ['A (Symptoms)', 'B (Primary Meds)', 'C (Secondary Meds)', 'Any Rule'],
    'Precision': [res['precision'] for res in results[:4]],
    'Recall': [res['recall'] for res in results[:4]],
    'F1-Score': [res['f1'] for res in results[:4]],
    'Gold +': [df[f'{prefix}_gold'].sum() for prefix in rule_prefixes],
    'LLM +': [df[f'{prefix}_text'].sum() for prefix in rule_prefixes],
})

print(rule_comparison.to_string(index=False))

print("\n4.2 Phenotype Complexity Analysis")
print("-" * 40)

# Number of individual rules (A, B, C) each patient satisfies. Cast to int
# before summing: with boolean columns, `a + b + c` is an element-wise logical
# OR in numpy/pandas and every "count" would collapse to True/False.
df['gold_rule_count'] = df[['rule_a_gold', 'rule_b_gold', 'rule_c_gold']].astype(int).sum(axis=1)
df['llm_rule_count'] = df[['rule_a_text', 'rule_b_text', 'rule_c_text']].astype(int).sum(axis=1)

print("\nGold Standard Rule Distribution:")
print(df['gold_rule_count'].value_counts().sort_index().to_string())

print("\nLLM Predicted Rule Distribution:")
print(df['llm_rule_count'].value_counts().sort_index().to_string())

# Patient-level agreement on the combined phenotype, grouped by how many
# gold rules the patient satisfies (groupby keys come out sorted).
print("\n4.3 Performance by Phenotype Complexity")
print("-" * 40)

for n_rules, subset in df.groupby('gold_rule_count'):
    agreement = (subset['any_gold'] == subset['any_text']).sum()
    print(f"Patients with {n_rules} rules: {len(subset)} patients, "
          f"{agreement} correct ({agreement/len(subset)*100:.1f}%)")

# ============================================================================
# 5. ERROR ANALYSIS
# ============================================================================

print("\n\n" + "="*80)
print("5. DETAILED ERROR ANALYSIS")
print("="*80)

# 5.1 False Negatives
print("\n5.1 FALSE NEGATIVES (Gold=1, LLM=0)")
print("-" * 40)

# Gold-positive patients that no text-derived rule flagged.
fn_patients = df[(df['any_gold'] == 1) & (df['any_text'] == 0)]
print(f"\nTotal False Negatives: {len(fn_patients)}")

if len(fn_patients) > 0:
    print("\nFalse Negative Patient IDs:")
    print(fn_patients['subject_id'].tolist())

    # Per-rule gold positivity among the missed patients. A patient can satisfy
    # several rules, so these are NOT exclusive ("only") counts and may sum to
    # more than the FN total.
    print("\nFalse Negative Rule Breakdown (gold-positive per rule; rules can overlap):")
    print(f"  - Rule A gold+: {(fn_patients['rule_a_gold'] == 1).sum()}")
    print(f"  - Rule B gold+: {(fn_patients['rule_b_gold'] == 1).sum()}")
    print(f"  - Rule C gold+: {(fn_patients['rule_c_gold'] == 1).sum()}")

    # Number of candidate sentences per missed patient; patients absent from
    # `ev` get 0 after the left merge.
    fn_with_notes = fn_patients.merge(
        ev.groupby('subject_id').size().reset_index(name='n_sentences'),
        on='subject_id', how='left'
    )
    fn_with_notes['n_sentences'] = fn_with_notes['n_sentences'].fillna(0)

    print("\n5.1.1 False Negative Error Categories:")
    print("-" * 40)

    no_sentences = (fn_with_notes['n_sentences'] == 0).sum()
    print("  Category 1: No candidate sentences extracted")
    print(f"    → Count: {no_sentences}")
    print("    → Likely cause: Vocabulary mismatch or phrasing not captured")

    with_sentences = (fn_with_notes['n_sentences'] > 0).sum()
    print("\n  Category 2: Sentences extracted but features not detected")
    print(f"    → Count: {with_sentences}")
    print("    → Likely cause: LLM classification errors")

    # Sample sentences from missed patients that did have candidates
    if with_sentences > 0:
        fn_ids = fn_with_notes.loc[fn_with_notes['n_sentences'] > 0, 'subject_id'].tolist()
        fn_sentences = ev[ev['subject_id'].isin(fn_ids)]

        print("\n5.1.2 Sample False Negative Sentences:")
        print("-" * 40)
        for _, row in fn_sentences.head(10).iterrows():
            print(f"\nPatient: {row['subject_id']}")
            # str() so a missing span prints as text instead of raising TypeError.
            print(f"Sentence: {str(row['text_span'])[:200]}...")
            print(f"LLM Output: sleep={row['asserts_sleep_difficulty']}, "
                  f"impair={row['asserts_daytime_impairment']}, "
                  f"negated={row['negated']}")

# 5.2 False Positives
print("\n\n5.2 FALSE POSITIVES (Gold=0, LLM=1)")
print("-" * 40)

# Gold-negative patients that at least one text-derived rule flagged.
fp_patients = df[(df['any_gold'] == 0) & (df['any_text'] == 1)]
print(f"\nTotal False Positives: {len(fp_patients)}")

if len(fp_patients) > 0:
    print("\nFalse Positive Patient IDs:")
    print(fp_patients['subject_id'].tolist())

    print("\nFalse Positive Rule Breakdown:")
    print(f"  - Rule A predicted: {(fp_patients['rule_a_text'] == 1).sum()}")
    print(f"  - Rule B predicted: {(fp_patients['rule_b_text'] == 1).sum()}")
    print(f"  - Rule C predicted: {(fp_patients['rule_c_text'] == 1).sum()}")

    # All candidate sentences belonging to false-positive patients
    fp_ids = fp_patients['subject_id'].tolist()
    fp_sentences = ev[ev['subject_id'].isin(fp_ids)]

    print("\n5.2.1 False Positive Error Categories:")
    print("-" * 40)

    feature_cols = [
        'asserts_sleep_difficulty',
        'asserts_daytime_impairment',
        'asserts_primary_med',
        'asserts_secondary_med',
    ]

    # Assign every sentence to exactly one category so the three counts
    # partition len(fp_sentences). Subtracting overlapping subsets (negated
    # AND featureless sentences counted in both) would miscount Category 3.
    is_negated = fp_sentences['negated'].eq(True)
    has_feature = fp_sentences[feature_cols].sum(axis=1) > 0

    negation_errors = fp_sentences[is_negated & has_feature]
    ambiguous = fp_sentences[~has_feature]
    over_inference = fp_sentences[~is_negated & has_feature]

    print("  Category 1: Negation errors (features marked as present despite negation)")
    print(f"    → Negated sentences with features asserted: {len(negation_errors)}")
    print("    → NOTE: These shouldn't contribute to rules if negation worked correctly")

    print("\n  Category 2: False candidate extraction")
    print(f"    → Sentences with no features detected: {len(ambiguous)}")
    print("    → Likely caught by keyword filter but not truly relevant")

    print("\n  Category 3: Over-inference or context misunderstanding")
    print(f"    → Features detected inappropriately: {len(over_inference)}")
    print("    → May include: medication list mentions, family history, screening questions")

    # Sample false positive sentences
    print("\n5.2.2 Sample False Positive Sentences:")
    print("-" * 40)
    for _, row in fp_sentences.head(10).iterrows():
        print(f"\nPatient: {row['subject_id']}")
        # str() so a missing span prints as text instead of raising TypeError.
        print(f"Sentence: {str(row['text_span'])[:200]}...")
        print(f"LLM Output: sleep={row['asserts_sleep_difficulty']}, "
              f"impair={row['asserts_daytime_impairment']}, "
              f"primary_med={row['asserts_primary_med']}, "
              f"secondary_med={row['asserts_secondary_med']}, "
              f"negated={row['negated']}, temporality={row['temporality']}")

# 5.3 Feature-Level Error Analysis
print("\n\n5.3 FEATURE-LEVEL ERROR ANALYSIS")
print("-" * 40)

# Compare feature detection vs gold rules
print("\nRule A Errors (Sleep + Impairment):")
rule_a_errors = df[df['rule_a_gold'] != df['rule_a_text']]
print(f"  Total errors: {len(rule_a_errors)}/{len(df)} ({len(rule_a_errors)/len(df)*100:.1f}%)")

# Break Rule A misses down by which of its two required features was seen
rule_a_fn = df[(df['rule_a_gold'] == 1) & (df['rule_a_text'] == 0)]
if len(rule_a_fn) > 0:
    print(f"\n  False Negatives: {len(rule_a_fn)}")
    # Was each feature flagged in at least one of the patient's sentences?
    rule_a_fn_features = rule_a_fn.merge(
        ev.groupby('subject_id').agg({
            'asserts_sleep_difficulty': 'max',
            'asserts_daytime_impairment': 'max'
        }).reset_index(),
        on='subject_id',
        how='left'
    )

    # Patients with no candidate sentences get NaN from the left merge, and
    # `NaN == False` is False, so comparing against False would drop them from
    # every category. `.eq(True)` treats NaN as "not detected" instead, which
    # places them under "Neither" and makes the three counts cover all FNs.
    sleep_detected = rule_a_fn_features['asserts_sleep_difficulty'].eq(True)
    impair_detected = rule_a_fn_features['asserts_daytime_impairment'].eq(True)

    sleep_only = sleep_detected & ~impair_detected
    print(f"    → Sleep detected but no impairment: {sleep_only.sum()}")

    impair_only = ~sleep_detected & impair_detected
    print(f"    → Impairment detected but no sleep: {impair_only.sum()}")

    neither = ~sleep_detected & ~impair_detected
    print(f"    → Neither feature detected: {neither.sum()}")

print("\nRule B Errors (Primary Medications):")
rule_b_errors = df[df['rule_b_gold'] != df['rule_b_text']]
print(f"  Total errors: {len(rule_b_errors)}/{len(df)} ({len(rule_b_errors)/len(df)*100:.1f}%)")

print("\nRule C Errors (Secondary Medications + Symptoms):")
rule_c_errors = df[df['rule_c_gold'] != df['rule_c_text']]
print(f"  Total errors: {len(rule_c_errors)}/{len(df)} ({len(rule_c_errors)/len(df)*100:.1f}%)")

# ============================================================================
# 6. SUMMARY AND RECOMMENDATIONS
# ============================================================================

print("\n\n" + "="*80)
print("6. SUMMARY AND RECOMMENDATIONS")
print("="*80)

# `results` is ordered [Rule A, Rule B, Rule C, Any rule].
overall_metrics = results[3]
rule_metrics = results[:3]

print("\n6.1 Key Findings")
print("-" * 40)
print(f"✓ Overall F1-Score: {overall_metrics['f1']:.3f}")
# Report the underlying error counts rather than fixed adjectives, so the text
# cannot contradict the numbers computed in Section 1.
print(f"✓ Precision: {overall_metrics['precision']:.3f} ({overall_metrics['FP']} false positives)")
print(f"✓ Recall: {overall_metrics['recall']:.3f} ({overall_metrics['FN']} false negatives)")

print("\n6.2 Information Expressiveness")
print("-" * 40)
print("• LLM extracts multiple insomnia-related features (see Section 2.1)")
sentence_coverage_pct = candidate_sentences / total_sentences * 100 if total_sentences else 0.0
print(f"• Feature extraction coverage: {sentence_coverage_pct:.2f}% of sentences")
print("• Feature co-occurrence patterns: see Section 2.2")

print("\n6.3 Completeness")
print("-" * 40)
print(f"• Patient coverage: {patients_with_candidates}/{total_patients} patients")
print(f"• Main limitation: {len(fn_patients)} false negatives due to:")
print("  - Vocabulary/phrasing gaps")
print("  - Classification errors")

print("\n6.4 Granularity")
print("-" * 40)
# Rank the individual rules by F1 so "best"/"most challenging" always reflect
# the metrics actually computed for this dataset.
for rule_res in sorted(rule_metrics, key=lambda r: r['f1'], reverse=True):
    print(f"• {rule_res['label']}: F1={rule_res['f1']:.3f} "
          f"(precision={rule_res['precision']:.3f}, recall={rule_res['recall']:.3f})")
best_rule = max(rule_metrics, key=lambda r: r['f1'])
worst_rule = min(rule_metrics, key=lambda r: r['f1'])
print(f"• Best performance: {best_rule['label']}")
print(f"• Most challenging: {worst_rule['label']}")

print("\n6.5 Recommendations for Improvement")
print("-" * 40)
print("1. Expand keyword vocabulary for sentence extraction")
print("2. Improve negation and temporality detection")
print("3. Add contextual understanding for ambiguous cases")
print("4. Consider ensemble approaches for Rule A")
print("5. Manual review of borderline cases to refine prompts")

print("\n" + "="*80)
print("EVALUATION COMPLETE")
print("="*80)

# Save detailed error analysis: counts and per-patient rates for each error type.
n_total_patients = len(df)
error_counts = {
    'False Negatives': len(fn_patients),
    'False Positives': len(fp_patients),
    'Rule A Errors': len(rule_a_errors),
    'Rule B Errors': len(rule_b_errors),
    'Rule C Errors': len(rule_c_errors),
}
error_summary = pd.DataFrame({
    'Error_Type': list(error_counts),
    'Count': list(error_counts.values()),
    'Rate': [count / n_total_patients if n_total_patients else 0.0
             for count in error_counts.values()],
})

print("\nSaving detailed outputs...")
saved_files = []
for frame, path in [(error_summary, "error_analysis_summary.csv"),
                    (results_df, "performance_metrics_summary.csv"),
                    (rule_comparison, "rule_comparison.csv")]:
    frame.to_csv(path, index=False)
    saved_files.append(path)

# Sentence-level FN/FP detail is only written when such patients exist; track
# what was actually written so the report below never lists missing files.
if len(fn_patients) > 0:
    fn_ids = fn_patients['subject_id'].tolist()
    fn_detailed = ev[ev['subject_id'].isin(fn_ids)]
    fn_detailed.to_csv("false_negatives_detailed.csv", index=False)
    saved_files.append("false_negatives_detailed.csv")

if len(fp_patients) > 0:
    fp_ids = fp_patients['subject_id'].tolist()
    fp_detailed = ev[ev['subject_id'].isin(fp_ids)]
    fp_detailed.to_csv("false_positives_detailed.csv", index=False)
    saved_files.append("false_positives_detailed.csv")

print("\nFiles saved:")
for path in saved_files:
    print(f"  - {path}")

LOADING EVALUATION DATA

Dataset sizes:
  - Patients: 60
  - Sentences: 505
  - Notes: 212

1. OVERALL PERFORMANCE METRICS

Rule A: Sleep Difficulty + Daytime Impairment
---------------------------------------------
Confusion Matrix:
  TN:  38  |  FP:   9
  FN:   4  |  TP:   9

Performance Metrics:
  Accuracy:    0.783
  Precision:   0.500
  Recall:      0.692
  F1-Score:    0.581
  Specificity: 0.809
  NPV:         0.905

Rule B: Primary Insomnia Medications
------------------------------------
Confusion Matrix:
  TN:  35  |  FP:  10
  FN:   6  |  TP:   9

Performance Metrics:
  Accuracy:    0.733
  Precision:   0.474
  Recall:      0.600
  F1-Score:    0.529
  Specificity: 0.778
  NPV:         0.854

Rule C: Secondary Medications + Symptoms
----------------------------------------
Confusion Matrix:
  TN:  36  |  FP:   6
  FN:   5  |  TP:  13

Performance Metrics:
  Accuracy:    0.817
  Precision:   0.684
  Recall:      0.722
  F1-Score:    0.703
  Specificity: 0.857
  NPV:         0.