In [None]:
import os
# Switch to the project folder so the relative "data/..." paths used by the
# later cells resolve. NOTE(review): this absolute path is machine-specific;
# adjust it (or drop this line and start the kernel from the project root)
# on any other machine.
os.chdir("/Users/M1HR/Desktop/MIGRAINE")

In [None]:
pip install frozendict

In [None]:
pip install experta pandas

In [None]:
pip install --upgrade jsonpickle pyyaml nltk

In [None]:


import json
import numbers
import warnings
from collections import Counter, defaultdict
from pathlib import Path

import pandas as pd

warnings.filterwarnings('ignore')


# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

# Lower-cased, whitespace-stripped strings that mean "finding absent".
_ABSENT_STRINGS = frozenset({'not found', 'not specified', 'none', '', '0'})


def check_present(value):
    """Return True if *value* indicates that a finding is present.

    A value is treated as absent when it is:

    * falsy (``None``, ``0``, ``''``, ``False``, an empty container);
    * a real number that is not strictly positive. This includes NaN, since
      ``nan > 0`` is False. numpy integer/float scalars count as real numbers
      here too;
    * a string (or any other object, via ``str()``) that, after stripping
      surrounding whitespace and lower-casing, is one of ``'not found'``,
      ``'not specified'``, ``'none'``, ``''`` or ``'0'``.

    Everything else is treated as present.

    Returns:
        bool: always a builtin ``bool``, never a numpy boolean.
    """
    if not value:
        return False

    # numbers.Real also matches numpy integer scalars (e.g. values taken from a
    # pandas row). Those are not subclasses of the builtin int, so they would
    # otherwise be judged by their string form, where "-1" looks "present".
    if isinstance(value, numbers.Real):
        return bool(value > 0)

    return str(value).strip().lower() not in _ABSENT_STRINGS


def _count_pain_characteristics(patient):
    """Count the pain characteristics recorded for *patient*.

    Three characteristics are checked, each by case-insensitive substring
    matching on the field's string form:

    * unilateral location: ``location`` contains 'left', 'right', 'side',
      'temple' or 'unilateral';
    * pulsating quality: ``character`` contains 'throb', 'puls', 'pound'
      or 'beat';
    * moderate/severe intensity: ``intensity_text`` contains 'moderate' or
      'severe'; otherwise numeric ``intensity`` >= 2.

    Returns:
        tuple[int, list[str]]: how many characteristics matched (0-3) and
        labels describing which ones.
    """
    count = 0
    details = []

    location = patient.get('location')
    # NOTE(review): 'side' also matches text such as "both sides"; confirm the
    # extracted locations never describe bilateral pain that way.
    if location and any(term in str(location).lower()
                        for term in ('left', 'right', 'side', 'temple', 'unilateral')):
        count += 1
        details.append('unilateral')

    character = patient.get('character')
    if character and any(term in str(character).lower()
                         for term in ('throb', 'puls', 'pound', 'beat')):
        count += 1
        details.append('pulsating')

    # Text intensity is preferred when it gives an answer.
    moderate_severe = False
    int_text = patient.get('intensity_text')
    if int_text:
        it = str(int_text).lower()
        if 'moderate' in it or 'severe' in it:
            moderate_severe = True
            details.append(f'moderate_severe_text({int_text})')

    # Numeric fallback (2 = moderate, 3 = severe).
    if not moderate_severe:
        int_val = patient.get('intensity')
        try:
            numeric_hit = int_val is not None and float(int_val) >= 2
        except (TypeError, ValueError):
            numeric_hit = False
        if numeric_hit:
            moderate_severe = True
            details.append(f'moderate_severe_numeric({int_val})')

    if moderate_severe:
        count += 1
    return count, details


def check_base_migraine_criteria(patient, verbose=False, require_pain_characteristics=False):
    """Check the base (non-aura) migraine criteria for one patient.

    Evaluated in order:

    1. Pain characteristics: counts unilateral location, pulsating quality
       and moderate/severe intensity (see ``_count_pain_characteristics``).
       By default the count is only reported (when *verbose*) and does NOT
       affect the result. Pass ``require_pain_characteristics=True`` to fail
       when fewer than 2 of the 3 are present.
    2. Frequency: fails when ``frequency`` parses to a number below 5.
       A missing/zero frequency, or one that cannot be parsed, is assumed valid.
    3. Accompanying symptoms: fails unless at least one of nausea, vomit,
       photophobia or phonophobia is present (per ``check_present``).
       NOTE(review): ICHD-3 1.1 asks for nausea/vomiting OR both photophobia
       and phonophobia. This check accepts any single symptom; confirm the
       relaxation is intended.

    Args:
        patient: mapping of lower-case field names to values.
        verbose: print a per-criterion trace.
        require_pain_characteristics: enforce the >= 2 of 3 pain rule.

    Returns:
        tuple[bool, str]: ``(met, reason)``. *reason* is "Base criteria met"
        on success, otherwise a description of the first failed criterion.
    """
    # ------------------------------------------------------------------
    # CRITERION 1: pain characteristics (>= 2 of 3)
    # ------------------------------------------------------------------
    pain_count, pain_details = _count_pain_characteristics(patient)
    meets_pain = pain_count >= 2
    if verbose:
        mark = '✓' if meets_pain else '✗'
        print(f"  {mark} Pain characteristics: {pain_count}/3 {pain_details}")
    if require_pain_characteristics and not meets_pain:
        return False, f"Pain characteristics: {pain_count}/3 (need >=2)"

    # ------------------------------------------------------------------
    # CRITERION 2: frequency >= 5 attacks
    # ------------------------------------------------------------------
    frequency = patient.get('frequency')
    if not frequency:
        # Missing / zero frequency is treated as unknown, not as a failure.
        if verbose:
            print("    Frequency not specified, assuming valid")
    else:
        try:
            # Go through float() so numeric strings such as "4.0" parse too.
            freq_val = int(float(frequency))
        except (TypeError, ValueError, OverflowError):
            freq_val = None
            if verbose:
                print("    Frequency parse failed, assuming valid")
        if freq_val is not None:
            if freq_val < 5:
                reason = f"Frequency: {freq_val} (need >=5)"
                if verbose:
                    print(f"  ✗ {reason}")
                return False, reason
            if verbose:
                print(f"  ✓ Frequency: {freq_val}")

    # ------------------------------------------------------------------
    # CRITERION 3: accompanying symptoms (need >= 1)
    # ------------------------------------------------------------------
    symptom_list = [name for name in ('nausea', 'vomit', 'photophobia', 'phonophobia')
                    if check_present(patient.get(name))]
    if not symptom_list:
        reason = "No accompanying symptoms"
        if verbose:
            print(f"  ✗ {reason}")
        return False, reason

    if verbose:
        print(f"  ✓ Accompanying symptoms: {symptom_list}")
    return True, "Base criteria met"

def count_ha_symptoms(patient):
    """Count the headache descriptor fields that are present for a patient.

    Only ``location`` and ``intensity`` are inspected, each via
    ``check_present``.

    Returns:
        tuple[int, list[str]]: number of present fields and their names,
        in the order ``location``, ``intensity``.
    """
    present = [name for name in ('location', 'intensity')
               if check_present(patient.get(name))]
    return len(present), present
    
def count_aura_symptoms(patient):
    """Count aura-related fields that are present for a patient.

    The twelve inspected fields also include every brainstem field used by
    ``count_brainstem_symptoms`` (plus visual, sensory, dysphasia,
    visual_defect and paresthesia). Each is tested with ``check_present``.

    Returns:
        tuple[int, list[str]]: number of present fields and their names,
        in the fixed field order below.
    """
    fields = (
        'visual', 'sensory', 'dysphasia', 'dysarthria',
        'vertigo', 'tinnitus', 'hypoacusis', 'diplopia',
        'ataxia', 'conscience', 'visual_defect', 'paresthesia',
    )
    present = [name for name in fields if check_present(patient.get(name))]
    return len(present), present


def count_brainstem_symptoms(patient):
    """Count brainstem-type symptom fields that are present for a patient.

    Inspects dysarthria, vertigo, tinnitus, hypoacusis, diplopia, ataxia and
    conscience, each via ``check_present``.

    Returns:
        tuple[int, list[str]]: number of present fields and their names,
        in that order.
    """
    fields = (
        'dysarthria', 'vertigo', 'tinnitus', 'hypoacusis',
        'diplopia', 'ataxia', 'conscience',
    )
    present = [name for name in fields if check_present(patient.get(name))]
    return len(present), present


# =============================================================================
# MAIN DIAGNOSTIC FUNCTION
# =============================================================================

def _display_int(value):
    """Return ``int(value)`` when that conversion works, else *value* unchanged.

    Used only when formatting reasoning strings, so an unexpected ``dpf``
    value (``None``, NaN, free text) cannot crash the diagnosis.
    """
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return value


def diagnose_patient(patient, verbose=False):
    """Assign one ICHD-3 style diagnosis to a patient using ordered rules.

    Rules are tried in priority order and the first match wins:

    1. Familial hemiplegic migraine (ICHD-3 1.2.3.1): dysphasia present,
       brainstem count < 2, family history (``dpf``) present.
    2. Sporadic hemiplegic migraine (ICHD-3 1.2.3.2): dysphasia present,
       brainstem count < 2, no family history.
    3. Basilar-type aura (ICHD-3 1.2.2): no dysphasia, brainstem count > 1.
    4. Typical aura with migraine (ICHD-3 1.2.1.1): headache fields > 0,
       aura >= 1, no dysphasia, brainstem count <= 1.
    5. Typical aura without migraine (ICHD-3 1.2.1.2): headache fields == 0,
       aura >= 1, no dysphasia, brainstem count <= 1, frequency parses to <= 3.
    6. Migraine without aura (ICHD-3 1.1): base criteria met and aura == 0.
    7. Other: everything else (confidence "low").

    Only rule 6 (and the "Other" message) consults
    ``check_base_migraine_criteria``; rules 1-5 do not require it.

    Args:
        patient: mapping of lower-case field names to values.
        verbose: print a step-by-step trace.

    Returns:
        dict with keys ``patient_id``, ``diagnosis``, ``code``,
        ``confidence`` and ``reasoning`` (list of strings).
    """
    patient_id = patient.get('patient_id')
    reasoning = []

    if verbose:
        print(f"\n{'='*80}")
        print(f"DIAGNOSING PATIENT {patient_id}")
        print(f"{'='*80}")

    def _conclude(diagnosis, code, confidence):
        # Shared exit path: optional trace line plus the result record.
        if verbose:
            print(f"\n  → DIAGNOSIS: {diagnosis}")
        return {
            'patient_id': patient_id,
            'diagnosis': diagnosis,
            'code': code,
            'confidence': confidence,
            'reasoning': reasoning,
        }

    # =========================================================================
    # BASE CRITERIA
    # =========================================================================
    meets_base, base_reason = check_base_migraine_criteria(patient, verbose=verbose)
    # Recorded unconditionally so the returned reasoning does not depend on
    # the `verbose` print flag.
    if meets_base:
        reasoning.append("✓ Base migraine criteria met")
    else:
        reasoning.append(f"✗ Base migraine criteria failed: {base_reason}")

    # =========================================================================
    # SYMPTOM COUNTS
    # =========================================================================
    aura_count, aura_types = count_aura_symptoms(patient)
    brainstem_count, brainstem_types = count_brainstem_symptoms(patient)

    if verbose:
        if aura_count > 5:
            print(f"  Aura symptoms: {aura_count} ({aura_types[:5]}...)")
        else:
            print(f"  Aura symptoms: {aura_count} {aura_types}")
        print(f"  Brainstem symptoms: {brainstem_count} {brainstem_types}")

    reasoning.append(f"Aura: {aura_count}, Brainstem: {brainstem_count}")

    frequency = patient.get('frequency', 0)
    dpf = patient.get('dpf', 0)
    ha_count, _ = count_ha_symptoms(patient)

    has_dysphasia = check_present(patient.get('dysphasia'))
    has_family_history = check_present(dpf)

    # =========================================================================
    # DIAGNOSTIC RULES (first match wins)
    # =========================================================================

    # Rules 1-2: hemiplegic migraine, split on family history.
    if has_dysphasia and brainstem_count < 2:
        if has_family_history:
            diagnosis = "Familial hemiplegic migraine"
            reasoning.append(f"✓ {diagnosis}: dysphasia + DPF({_display_int(dpf)})=1 + brainstem({brainstem_count})<2")
            return _conclude(diagnosis, "ICHD-3 1.2.3.1", "high")
        diagnosis = "Sporadic hemiplegic migraine"
        reasoning.append(f"✓ {diagnosis}: dysphasia + DPF({_display_int(dpf)})=0 +brainstem({brainstem_count})<2")
        return _conclude(diagnosis, "ICHD-3 1.2.3.2", "high")

    # Rule 3: basilar-type aura.
    if not has_dysphasia and brainstem_count > 1:
        diagnosis = "Basilar-type aura"
        reasoning.append(f"✓ {diagnosis}: brainstem({brainstem_count})>1")
        return _conclude(diagnosis, "ICHD-3 1.2.2", "high")

    # Rule 4: typical aura with headache.
    if ha_count > 0 and aura_count >= 1 and not has_dysphasia and brainstem_count <= 1:
        diagnosis = "Typical aura with migraine"
        reasoning.append(f"✓ {diagnosis}: aura({aura_count})>0 + headache({ha_count})>0 + no dysphasia + brainstem({brainstem_count})<=1")
        return _conclude(diagnosis, "ICHD-3 1.2.1.1", "high")

    # Rule 5: typical aura without headache, with at most 3 recorded attacks.
    if ha_count == 0 and aura_count >= 1 and not has_dysphasia and brainstem_count <= 1:
        try:
            freq_val = int(float(frequency))
        except (TypeError, ValueError, OverflowError):
            freq_val = None  # unparseable frequency: rule does not apply
        if freq_val is not None and freq_val <= 3:
            diagnosis = "Typical aura without migraine"
            reasoning.append(f"✓ {diagnosis}: aura({aura_count})>0 + headache({ha_count})=0 + no dysphasia + brainstem({brainstem_count})<=1")
            return _conclude(diagnosis, "ICHD-3 1.2.1.2", "high")

    # Rule 6: migraine without aura.
    if meets_base and aura_count == 0:
        diagnosis = "Migraine without aura"
        reasoning.append(f"✓ {diagnosis}: Base + aura({aura_count})=0 ")
        return _conclude(diagnosis, "ICHD-3 1.1", "high")

    # Rule 7: everything else.
    diagnosis = "Other"
    if meets_base:
        reasoning.append(f"✗ {diagnosis}: No specific criteria matched")
    else:
        reasoning.append(f"✗ {diagnosis}: Base criteria not met")
    return _conclude(diagnosis, "N/A", "low")


# =============================================================================
# BATCH PROCESSING
# =============================================================================

def _lenient_label_match(predicted_norm, actual_norm):
    """Return True if both lower-cased labels contain the same coarse keyword.

    The keywords are 'without aura', 'with aura', 'hemiplegic' and 'basilar'.
    Because of 'hemiplegic', familial and sporadic hemiplegic migraine count
    as a match for each other, so this is looser than exact label equality.
    """
    keywords = ('without aura', 'with aura', 'hemiplegic', 'basilar')
    return any(key in predicted_norm and key in actual_norm for key in keywords)


def _validate_against_ground_truth(all_results, ground_truth_file):
    """Print lenient and exact-label accuracy of *all_results* against a CSV.

    Ground-truth labels come from the CSV's ``Type`` column. CSV row ``i``
    (0-based) is looked up as patient_id ``i + 1``. NOTE(review): this
    assumes the row order matches the ids in the summaries file; confirm.
    """
    gt_df = pd.read_csv(ground_truth_file)
    print(f"\n✓ Loaded ground truth: {len(gt_df)} patients")

    gt_map = {idx + 1: row.get('Type', 'Unknown') for idx, row in gt_df.iterrows()}

    exact_matches = 0
    matches = 0
    mismatch_details = Counter()

    for result in all_results:
        predicted = result['diagnosis']
        actual = gt_map.get(result['patient_id'], 'Unknown')

        # str() so a blank (NaN) Type cell does not raise on .lower().
        predicted_norm = str(predicted).lower()
        actual_norm = str(actual).lower()

        if predicted_norm == actual_norm:
            exact_matches += 1
            matches += 1
        elif _lenient_label_match(predicted_norm, actual_norm):
            matches += 1
        else:
            mismatch_details[f"{actual} → {predicted}"] += 1

    mismatches = sum(mismatch_details.values())
    total = matches + mismatches
    accuracy = matches / total * 100 if total > 0 else 0
    exact_accuracy = exact_matches / total * 100 if total > 0 else 0

    print("\n Accuracy:")
    print(f"   Total:      {total}")
    print(f"   Matches:    {matches}")
    print(f"   Mismatches: {mismatches}")
    print(f"   Accuracy:   {accuracy:.2f}%")
    # The lenient figure above groups related labels (e.g. familial vs
    # sporadic hemiplegic); the exact figure requires identical labels.
    print(f"   Exact-label accuracy: {exact_accuracy:.2f}%")

    if mismatches > 0:
        print("\n Top mismatches:")
        for error, count in mismatch_details.most_common(10):
            print(f"   {error}: {count}")


def batch_diagnose(
    summaries_file='data/ner_results/patient_summaries_fixed.json',
    output_file='data/diagnoses/ichd3_diagnoses_final.json',
    ground_truth_file=None
):
    """Diagnose every patient in a JSON summaries file and save the results.

    Steps:
    1. Load *summaries_file* (UTF-8 JSON list of patient dicts).
    2. Run ``diagnose_patient`` on each patient.
    3. Write the full results to *output_file* as JSON (UTF-8, indent=2).
       Write a summary CSV (patient_id, diagnosis, code, confidence) to the
       same path with a ``.csv`` suffix. Parent directories are created.
    4. Print the diagnosis distribution.
    5. If *ground_truth_file* is given, print accuracy against it. Validation
       errors are printed, not raised, so saved results are never lost.
    6. Print verbose traces for the first three patients.

    Returns:
        list[dict]: the ``diagnose_patient`` results, in input order.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("ICHD-3 SYMBOLIC REASONING ENGINE - BATCH DIAGNOSIS")
    print(banner)

    # 1. Load patient summaries
    print(f"\n1. Loading patient summaries from {summaries_file}...")
    with open(summaries_file, 'r', encoding='utf-8') as f:
        patients = json.load(f)
    print(f"   ✓ Loaded {len(patients)} patients")

    # 2. Diagnose all patients
    print("\n2. Running diagnostic engine...")
    all_results = []
    for position, patient in enumerate(patients, start=1):
        if position % 50 == 0:
            print(f"   Progress: {position}/{len(patients)}...")
        all_results.append(diagnose_patient(patient, verbose=False))
    diagnosis_counts = Counter(r['diagnosis'] for r in all_results)
    print(f"\n   ✓ Diagnosed {len(all_results)} patients")

    # 3. Save results
    print(f"\n3. Saving results to {output_file}...")
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, ensure_ascii=False)

    # Explicit columns keep the header even when there are no results.
    df = pd.DataFrame(all_results, columns=['patient_id', 'diagnosis', 'code', 'confidence'])
    csv_path = output_path.with_suffix('.csv')
    df.to_csv(csv_path, index=False, encoding='utf-8')

    print(f"   ✓ JSON: {output_path}")
    print(f"   ✓ CSV: {csv_path}")

    # 4. Distribution
    print("\n" + banner)
    print("DIAGNOSIS DISTRIBUTION")
    print(banner)
    print("\nResults:")
    for diagnosis, count in diagnosis_counts.most_common():
        percentage = count / len(all_results) * 100
        print(f"   {diagnosis:40s}: {count:4d} ({percentage:5.1f}%)")

    # 5. Optional validation
    if ground_truth_file:
        print("\n" + banner)
        print("VALIDATION AGAINST GROUND TRUTH")
        print(banner)
        try:
            _validate_against_ground_truth(all_results, ground_truth_file)
        except Exception as e:  # best-effort reporting; results are already saved
            print(f"\n⚠️  Validation error: {e}")

    # 6. Sample traces
    print("\n" + banner)
    print("SAMPLE DIAGNOSES")
    print(banner)
    for patient in patients[:3]:
        diagnose_patient(patient, verbose=True)

    print("\n" + banner)
    print(" BATCH DIAGNOSIS COMPLETE!")
    print(banner)

    return all_results


# =============================================================================
# MAIN ENTRY POINT
# =============================================================================

if __name__ == "__main__":
    # NOTE(review): sys is not used in this block; kept in case later cells
    # rely on the module-level name.
    import sys

    # Positional order: summaries_file, output_file, ground_truth_file.
    results = batch_diagnose(
        'data/ner_results/patient_summaries_fixed.json',
        'data/diagnoses/ichd3_diagnoses_final.json',
        'data/migraine_with_id.csv',
    )
    print(f"\n Diagnosed {len(results)} patients successfully!")
    
 

In [None]:
import pandas as pd

def run_engine_on_csv(csv_path, output_path="engine_output.csv", verbose=False):
    """
    Run your existing diagnose_patient() on every row in a CSV.
    Includes ground truth comparison using the 'Type' column.

    Column names are lower-cased before being passed to diagnose_patient().
    A row matches when the diagnosis string equals its 'Type' value exactly.

    Args:
        csv_path: CSV with one patient per row and a 'Type' column.
        output_path: where to write the per-patient results CSV. Parent
            directories are created if needed.
        verbose: print diagnose_patient()'s step-by-step trace.

    Returns:
        DataFrame with columns patient_id, ground_truth, diagnosis, match,
        code, confidence, reasoning. If the CSV has no rows, the DataFrame is
        empty but still has those columns.

    Raises:
        ValueError: if the CSV has no 'Type' column.
    """
    result_columns = ["patient_id", "ground_truth", "diagnosis", "match",
                      "code", "confidence", "reasoning"]

    df = pd.read_csv(csv_path)
    print(f"\nLoaded {len(df)} patients from {csv_path}")

    if "Type" not in df.columns:
        raise ValueError("CSV missing required ground truth column: 'Type'")

    results = []
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        # str() so non-string column labels do not break .lower().
        patient = {str(k).lower(): v for k, v in row.to_dict().items()}

        diag = diagnose_patient(patient, verbose=verbose)
        gt = row["Type"]  # Ground truth

        results.append({
            "patient_id": diag['patient_id'],
            "ground_truth": gt,
            "diagnosis": diag['diagnosis'],
            "match": diag['diagnosis'] == gt,
            "code": diag['code'],
            "confidence": diag['confidence'],
            "reasoning": " | ".join(diag['reasoning']),
        })

        if position % 20 == 0:
            print(f"Processed {position}/{len(df)} patients...")

    out_df = pd.DataFrame(results, columns=result_columns)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    out_df.to_csv(output_path, index=False)

    if out_df.empty:
        print(f"\nNo patients to score; saved empty output to {output_path}")
        return out_df

    acc = out_df['match'].mean() * 100
    print(f"\n🔥 DONE! Saved output to {output_path}")
    print(f"🎯 Accuracy vs ground truth: {acc:.2f}%\n")

    return out_df


run_engine_on_csv(csv_path="data/migraine_with_id.csv", verbose=True)

In [None]:
"""
Run ICHD-3 Diagnostic Engine on JSON Patient Summaries
========================================================

Runs diagnose_patient() on NER-extracted patient summaries (JSON format).
Compares against ground truth from original CSV.
"""

import pandas as pd
import json
from pathlib import Path


def run_engine_on_json(json_path, 
                       ground_truth_csv="data/migraine_with_id.csv",
                       output_path="engine_output_ner.csv", 
                       verbose=False):
    """
    Run diagnose_patient() on JSON patient summaries (NER-extracted data).

    Each JSON record is reduced to the fields diagnose_patient() reads. A
    missing field defaults to '' for the free-text fields (intensity_text,
    location, character) and to 0 for every other field. Each diagnosis is
    compared for exact string equality against the 'Type' column of
    *ground_truth_csv*, where CSV row i (0-based) is looked up as
    patient_id i + 1. NOTE(review): that id convention must match the ids
    stored in the JSON; confirm.

    Args:
        json_path: Path to a UTF-8 JSON file holding a list of patient dicts
        ground_truth_csv: Path to CSV with ground truth diagnoses ('Type')
        output_path: Where to save results CSV (parent dirs are created)
        verbose: Whether to print detailed reasoning

    Returns:
        DataFrame with columns patient_id, ground_truth, diagnosis, match,
        code, confidence, reasoning. If the JSON list is empty, the DataFrame
        is empty but still has those columns.
    """
    text_fields = ('intensity_text', 'location', 'character')
    numeric_fields = (
        'duration', 'intensity', 'frequency',
        'nausea', 'vomit', 'photophobia', 'phonophobia',
        'visual', 'sensory', 'dysphasia', 'dysarthria', 'vertigo',
        'tinnitus', 'hypoacusis', 'diplopia', 'ataxia', 'conscience',
        'visual_defect', 'paresthesia', 'dpf',
    )
    result_columns = ["patient_id", "ground_truth", "diagnosis", "match",
                      "code", "confidence", "reasoning"]

    print("\n" + "="*80)
    print("RUNNING ENGINE ON NER-EXTRACTED JSON DATA")
    print("="*80)

    # The NER summaries are UTF-8; don't depend on the platform default encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        patients_json = json.load(f)
    print(f"\n✓ Loaded {len(patients_json)} patients from {json_path}")

    gt_df = pd.read_csv(ground_truth_csv)
    print(f"✓ Loaded ground truth from {ground_truth_csv}")
    gt_map = {idx + 1: row.get('Type', 'Unknown') for idx, row in gt_df.iterrows()}

    results = []
    for i, patient_data in enumerate(patients_json):
        patient_id = patient_data.get('patient_id', i + 1)

        # Lower-case field dict in the shape diagnose_patient() expects.
        patient = {'patient_id': patient_id}
        patient.update({field: patient_data.get(field, '') for field in text_fields})
        patient.update({field: patient_data.get(field, 0) for field in numeric_fields})

        diag = diagnose_patient(patient, verbose=verbose)
        gt = gt_map.get(patient_id, 'Unknown')

        results.append({
            "patient_id": patient_id,
            "ground_truth": gt,
            "diagnosis": diag['diagnosis'],
            "match": diag['diagnosis'] == gt,
            "code": diag['code'],
            "confidence": diag['confidence'],
            "reasoning": " | ".join(diag['reasoning']),
        })

        if (i + 1) % 20 == 0:
            print(f"Processed {i+1}/{len(patients_json)} patients...")

    out_df = pd.DataFrame(results, columns=result_columns)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    out_df.to_csv(output_path, index=False)

    print(f"\n{'='*80}")
    print("RESULTS")
    print(f"{'='*80}")
    print(f"\n✓ Saved output to: {output_path}")

    if out_df.empty:
        print("No patients to score.")
        return out_df

    acc = out_df['match'].mean() * 100
    print(f"🎯 Accuracy vs ground truth: {acc:.2f}%")
    print(f"📊 Correct: {out_df['match'].sum()}/{len(out_df)}")

    print(f"\n{'─'*80}")
    print("Diagnosis Distribution:")
    print(f"{'─'*80}")
    print("\nGround Truth:")
    print(out_df['ground_truth'].value_counts())
    print("\nPredicted:")
    print(out_df['diagnosis'].value_counts())

    return out_df


def run_engine_on_csv(csv_path, output_path="engine_output.csv", verbose=False):
    """
    Run diagnose_patient() on CSV data (original structured format).

    Column names are lower-cased before being passed to diagnose_patient().
    A row matches when the diagnosis string equals its 'Type' value exactly.

    Args:
        csv_path: Path to CSV file (must contain a 'Type' column)
        output_path: Where to save results (parent dirs are created)
        verbose: Whether to print detailed reasoning

    Returns:
        DataFrame with columns patient_id, ground_truth, diagnosis, match,
        code, confidence, reasoning. If the CSV has no rows, the DataFrame is
        empty but still has those columns.

    Raises:
        ValueError: if the CSV has no 'Type' column.
    """
    result_columns = ["patient_id", "ground_truth", "diagnosis", "match",
                      "code", "confidence", "reasoning"]

    print("\n" + "="*80)
    print("RUNNING ENGINE ON ORIGINAL CSV DATA")
    print("="*80)

    df = pd.read_csv(csv_path)
    print(f"\n✓ Loaded {len(df)} patients from {csv_path}")

    if "Type" not in df.columns:
        raise ValueError("CSV missing required ground truth column: 'Type'")

    results = []
    for position, (_, row) in enumerate(df.iterrows(), start=1):
        # str() so non-string column labels do not break .lower().
        patient = {str(k).lower(): v for k, v in row.to_dict().items()}

        diag = diagnose_patient(patient, verbose=verbose)
        gt = row["Type"]  # Ground truth

        results.append({
            "patient_id": diag['patient_id'],
            "ground_truth": gt,
            "diagnosis": diag['diagnosis'],
            "match": diag['diagnosis'] == gt,
            "code": diag['code'],
            "confidence": diag['confidence'],
            "reasoning": " | ".join(diag['reasoning']),
        })

        if position % 20 == 0:
            print(f"Processed {position}/{len(df)} patients...")

    out_df = pd.DataFrame(results, columns=result_columns)
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    out_df.to_csv(output_path, index=False)

    print(f"\n{'='*80}")
    print("RESULTS")
    print(f"{'='*80}")
    print(f"\n✓ Saved output to: {output_path}")

    if out_df.empty:
        print("No patients to score.")
        return out_df

    acc = out_df['match'].mean() * 100
    print(f"🎯 Accuracy vs ground truth: {acc:.2f}%")
    print(f"📊 Correct: {out_df['match'].sum()}/{len(out_df)}")

    return out_df


# Example usage
if __name__ == "__main__":
    print("\n" + "=" * 80)
    print("EXAMPLE USAGE")
    print("=" * 80)

    # Score the NER-extracted summaries against the labelled CSV.
    # Positional order: json_path, ground_truth_csv, output_path, verbose.
    results_json = run_engine_on_json(
        "data/ner_results/patient_summaries_fixed.json",
        "data/migraine_with_id.csv",
        "engine_output_ner.csv",
        False,
    )



In [None]:
"""
Systematic Debug - Trace Rule Failures
=======================================

Find EXACTLY why "Typical aura with migraine" patients fail all rules.
"""

import json
import pandas as pd

# -----------------------------------------------------------------------------
# Helpers shared by the detailed trace and the summary
# -----------------------------------------------------------------------------
# NOTE(review): these re-implement the engine's checks so the trace shows what
# the engine should see. The "absent" markers and the unilateral/pulsating term
# lists match check_present() / check_base_migraine_criteria() in the engine
# cell. The intensity, frequency and DPF handling are reconstructed here.
# Confirm them against the engine code.
ABSENT_VALUES = ('not found', 'not specified', 'none', '', '0')
UNILATERAL_TERMS = ('left', 'right', 'side', 'temple', 'unilateral')
PULSATING_TERMS = ('throb', 'puls', 'pound', 'beat')
INTENSITY_TERMS = ('moderate', 'severe')
SYMPTOM_FIELDS = ('nausea', 'vomit', 'photophobia', 'phonophobia')
AURA_FIELDS = ('visual', 'sensory', 'dysphasia', 'dysarthria',
               'vertigo', 'tinnitus', 'hypoacusis', 'diplopia',
               'ataxia', 'conscience', 'visual_defect', 'paresthesia')
BRAINSTEM_FIELDS = ('dysarthria', 'vertigo', 'tinnitus', 'hypoacusis',
                    'diplopia', 'ataxia', 'conscience')
HEMIPLEGIC_AURA_FIELDS = ('dysphasia', 'visual', 'sensory')
PRINTED_FIELDS = ('location', 'character', 'intensity', 'intensity_text',
                  'frequency', 'nausea', 'vomit', 'photophobia', 'phonophobia',
                  'visual', 'sensory', 'dysphasia', 'dysarthria', 'vertigo',
                  'tinnitus', 'dpf')
SUMMARY_LIMIT = 50  # how many failing patients the summary section examines


def is_present(val):
    """Return True if a patient field marks the feature as present.

    Falsy values and the strings in ABSENT_VALUES count as absent. Numbers
    count as present only when > 0, so NaN is absent.
    """
    if not val:
        return False
    if isinstance(val, (int, float)):
        return val > 0
    return str(val).lower() not in ABSENT_VALUES


def to_int(val):
    """Return int(val), or None when val cannot be converted."""
    try:
        return int(val)
    except (TypeError, ValueError, OverflowError):
        return None


def present_fields(patient, fields):
    """Return the names in *fields* whose value in *patient* is present."""
    return [field for field in fields if is_present(patient.get(field))]


def pain_characteristics(patient):
    """Return a description of each ICHD-3 pain feature (max 3) found."""
    details = []

    location = patient.get('location')
    if location and any(t in str(location).lower() for t in UNILATERAL_TERMS):
        details.append(f'unilateral({location})')

    character = patient.get('character')
    if character and any(t in str(character).lower() for t in PULSATING_TERMS):
        details.append(f'pulsating({character})')

    intensity_text = patient.get('intensity_text')
    intensity = patient.get('intensity')
    if intensity_text and any(t in str(intensity_text).lower() for t in INTENSITY_TERMS):
        details.append(f'mod/sev_text({intensity_text})')
    elif intensity:
        try:
            if float(intensity) >= 2:
                details.append(f'mod/sev_num({intensity})')
        except (TypeError, ValueError):
            pass  # non-numeric intensity: the criterion is simply not met

    return details


def frequency_passes(frequency):
    """Return False only when *frequency* parses to an int below 5.

    Missing or unparseable values are assumed to pass.
    """
    if not frequency:
        return True
    parsed = to_int(frequency)
    return parsed is None or parsed >= 5


def dpf_flags(dpf):
    """Return (dpf_is_one, dpf_is_zero). An absent DPF counts as zero.

    A present but non-integer DPF gives (False, False) instead of crashing.
    """
    value = to_int(dpf) if is_present(dpf) else None
    return value == 1, (not is_present(dpf)) or value == 0


def analyze_patient(patient):
    """Collect every fact the rule trace needs for one patient dict."""
    pain = pain_characteristics(patient)
    symptoms = present_fields(patient, SYMPTOM_FIELDS)
    freq_ok = frequency_passes(patient.get('frequency'))
    dpf_one, dpf_zero = dpf_flags(patient.get('dpf'))
    return {
        'pain': pain,
        'freq_ok': freq_ok,
        'symptoms': symptoms,
        'base_ok': len(pain) >= 2 and freq_ok and bool(symptoms),
        'aura': present_fields(patient, AURA_FIELDS),
        'brainstem': present_fields(patient, BRAINSTEM_FIELDS),
        'has_hemi_aura': bool(present_fields(patient, HEMIPLEGIC_AURA_FIELDS)),
        'dpf_one': dpf_one,
        'dpf_zero': dpf_zero,
    }


def failure_reason(facts):
    """Classify why a patient matched no rule (key into fail_reasons)."""
    if not facts['base_ok']:
        return 'base_fail'
    if not facts['aura']:
        return 'no_aura'
    if len(facts['brainstem']) >= 2:
        return 'brainstem_high'  # base passes, so the basilar rule should match
    # DPF=1 with hemiplegic aura is covered by Rule 1. Any other non-zero DPF
    # misses Rule 3, which needs DPF=0.
    if not facts['dpf_zero'] and not (facts['dpf_one'] and facts['has_hemi_aura']):
        return 'dpf_issue'
    return 'other'


def print_verdict(matched):
    """Print the per-rule verdict line."""
    print("    ‚úÖ SHOULD MATCH" if matched else "    ‚ùå NO MATCH")


# Load everything
with open('data/diagnoses/ichd3_diagnoses_final.json', 'r') as f:
    results = json.load(f)

with open('data/ner_results/patient_summaries_fixed.json', 'r') as f:
    patients = json.load(f)

gt_df = pd.read_csv('data/migraine_with_id.csv')

# Create mappings
# NOTE(review): assumes CSV row order matches patient_id (row 0 -> id 1). Confirm.
gt_map = {idx + 1: row.get('Type', 'Unknown') for idx, row in gt_df.iterrows()}

patient_map = {p['patient_id']: p for p in patients}

print("="*80)
print("SYSTEMATIC DEBUG - FINDING THE BUG")
print("="*80)

# Find "Typical aura with migraine" → "Other" cases
print("\nFinding 'Typical aura with migraine' ‚Üí 'Other' patients...")

failing_patients = []
missing_patients = []  # failing ids with no NER summary (cannot be traced)
for result in results:
    pid = result['patient_id']
    # str() guards against NaN / non-string labels coming from the CSV
    actual = str(gt_map.get(pid, 'Unknown'))

    if 'typical aura with' in actual.lower() and result['diagnosis'] == "Other":
        if pid in patient_map:
            failing_patients.append(pid)
        else:
            missing_patients.append(pid)

print(f"Found {len(failing_patients)} failing patients")
if missing_patients:
    print(f"  (skipped {len(missing_patients)} without a patient summary: {missing_patients[:10]})")
print("\nAnalyzing first 3 patients in detail...\n")

# Detailed analysis
for pid in failing_patients[:3]:
    patient = patient_map[pid]
    facts = analyze_patient(patient)
    frequency = patient.get('frequency')

    print("="*80)
    print(f"PATIENT {pid} - TRACE THROUGH ALL RULES")
    print("="*80)

    print(f"\nGround truth: {gt_map.get(pid)}")
    print("Our diagnosis: Other")

    # Show all fields
    print("\nüìã PATIENT DATA:")
    print("-"*80)
    for field in PRINTED_FIELDS:
        print(f"  {field + ':':<16}{patient.get(field)}")

    # Check base criteria manually
    print("\nüîç BASE CRITERIA CHECK:")
    print("-"*80)

    # Pain characteristics
    pain_details = facts['pain']
    pain_count = len(pain_details)
    print(f"  Pain characteristics: {pain_count}/3 {pain_details}")
    if pain_count < 2:
        print(f"  ‚ùå FAIL: Need >=2, got {pain_count}")
    else:
        print("  ‚úÖ PASS")

    # Frequency
    if not frequency:
        print("  Frequency: Not specified (assuming pass)")
    elif to_int(frequency) is None:
        print(f"  Frequency: {frequency} (parse error, assuming pass)")
    else:
        print(f"  Frequency: {frequency}")
        if facts['freq_ok']:
            print("  ‚úÖ PASS")
        else:
            print(f"  ‚ùå FAIL: Need >=5, got {frequency}")

    # Symptoms
    print(f"  Symptoms: {facts['symptoms']}")
    if not facts['symptoms']:
        print("  ‚ùå FAIL: No symptoms")
    else:
        print("  ‚úÖ PASS")

    base_pass = facts['base_ok']
    print(f"\n  üéØ BASE CRITERIA: {'‚úÖ PASS' if base_pass else '‚ùå FAIL'}")

    # Count aura / brainstem
    print("\nüîç AURA COUNT:")
    print("-"*80)

    aura_count = len(facts['aura'])
    brainstem_count = len(facts['brainstem'])
    print(f"  Aura symptoms: {aura_count} {facts['aura']}")
    print(f"  Brainstem symptoms: {brainstem_count} {facts['brainstem']}")
    print(f"  DPF: {patient.get('dpf')}")

    # Check each rule
    print("\nüîç RULE-BY-RULE CHECK:")
    print("-"*80)

    has_hemi_aura = facts['has_hemi_aura']
    dpf_zero = facts['dpf_zero']

    print("\n  Rule 1: Familial hemiplegic")
    print(f"    Base: {base_pass}")
    print(f"    Hemiplegic aura (dysphasia/visual/sensory): {has_hemi_aura}")
    print(f"    DPF=1: {facts['dpf_one']}")
    print_verdict(base_pass and has_hemi_aura and facts['dpf_one'])

    print("\n  Rule 2: Sporadic hemiplegic")
    print(f"    Base: {base_pass}")
    print(f"    Hemiplegic aura: {has_hemi_aura}")
    print(f"    DPF=0: {dpf_zero}")
    print_verdict(base_pass and has_hemi_aura and dpf_zero)

    print("\n  Rule 3: Typical aura with headache")
    print(f"    Base: {base_pass}")
    print(f"    Aura>=1: {aura_count >= 1}")
    print(f"    DPF=0: {dpf_zero}")
    print(f"    Brainstem<2: {brainstem_count < 2}")
    print_verdict(base_pass and aura_count >= 1 and dpf_zero and brainstem_count < 2)

    # Rule 4 is shown for information only (criteria printed, no verdict).
    # Parse frequency independently of truthiness so that 0 shows as True
    # rather than 'N/A'.
    freq_value = to_int(frequency)
    print("\n  Rule 4: Typical aura without migraine")
    print(f"    NOT base: {not base_pass}")
    print(f"    Aura>=1: {aura_count >= 1}")
    print(f"    Freq=0: {freq_value == 0 if freq_value is not None else 'N/A'}")
    print(f"    DPF=0: {dpf_zero}")
    print(f"    Brainstem<2: {brainstem_count < 2}")

    print("\n  Rule 5: Basilar-type aura")
    print(f"    Base: {base_pass}")
    print(f"    Brainstem>=2: {brainstem_count >= 2}")
    print_verdict(base_pass and brainstem_count >= 2)

    print("\n" + "="*80)
    print("üéØ WHY IT FAILED:")
    print("="*80)

    reason = failure_reason(facts)
    if reason == 'base_fail':
        print("‚ùå BASE CRITERIA FAILED")
        if pain_count < 2:
            print(f"   - Pain characteristics: {pain_count}/3 (need >=2)")
        if not facts['freq_ok']:
            print(f"   - Frequency: {frequency} (need >=5)")
        if not facts['symptoms']:
            print("   - No accompanying symptoms")
    elif reason == 'no_aura':
        print("‚ùå NO AURA SYMPTOMS")
    elif reason == 'brainstem_high':
        print("‚ùå HAS >=2 BRAINSTEM (base passes, so basilar rule should have matched)")
    elif reason == 'dpf_issue':
        print("‚ùå DPF NOT 0 (Rule 3 needs DPF=0; Rule 1 needs DPF=1 plus hemiplegic aura)")
    else:
        print("‚ùå UNKNOWN - should have matched a rule!")

    print("\n")

# Summary
print("="*80)
print("SUMMARY OF COMMON FAILURES")
print("="*80)

fail_reasons = {
    'base_fail': 0,
    'no_aura': 0,
    'brainstem_high': 0,
    'dpf_issue': 0,
    'other': 0
}

summary_pids = failing_patients[:SUMMARY_LIMIT]
for pid in summary_pids:
    fail_reasons[failure_reason(analyze_patient(patient_map[pid]))] += 1

print(f"\nOut of first {len(summary_pids)} failing patients:")
print(f"  Base criteria failed: {fail_reasons['base_fail']}")
print(f"  No aura symptoms:     {fail_reasons['no_aura']}")
print(f"  >=2 brainstem:        {fail_reasons['brainstem_high']}")
print(f"  DPF mismatch:         {fail_reasons['dpf_issue']}")
print(f"  Other reasons:        {fail_reasons['other']}")

print("\nüéØ MAIN ISSUE IDENTIFIED!")

In [None]:
"""
Plot Confusion Matrix from Engine Output CSV
==============================================

Load engine_output.csv and plot confusion matrix with inferno colormap.
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from pathlib import Path

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 10)


def plot_confusion_matrix_from_csv(csv_path='engine_output.csv', 
                                   output_path='evaluation_results/symbolic_original/confusion_matrix_original_data.png'):
    """
    Load engine output CSV and plot confusion matrix
    
    Args:
        csv_path: Path to engine_output.csv
        output_path: Where to save the confusion matrix plot
    """
    
    print("\n" + "="*80)
    print("PLOTTING CONFUSION MATRIX FROM ENGINE OUTPUT")
    print("="*80)
    
    # Load results
    df = pd.read_csv(csv_path)
    print(f"\n‚úì Loaded {len(df)} patients from {csv_path}")
    
    # Extract true and predicted labels
    y_true = df['ground_truth']
    y_pred = df['diagnosis']
    
    # Calculate metrics
    accuracy = accuracy_score(y_true, y_pred)
    
    print(f"\n{'='*80}")
    print("PERFORMANCE METRICS")
    print(f"{'='*80}")
    print(f"\nOverall Accuracy: {accuracy:.4f} ({accuracy*100:.1f}%)")
    
    print(f"\n{'‚îÄ'*80}")
    print("Classification Report:")
    print(f"{'‚îÄ'*80}")
    print(classification_report(y_true, y_pred, zero_division=0))
    
    # Compute confusion matrix
    print(f"\n{'='*80}")
    print("GENERATING CONFUSION MATRIX")
    print(f"{'='*80}")
    
    cm = confusion_matrix(y_true, y_pred)
    labels = sorted(y_true.unique())
    
    print(f"\nUnique diagnoses found: {len(labels)}")
    print(f"Labels: {labels}")
    
    # Normalize
    cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    
    # Create figure
    fig, ax = plt.subplots(figsize=(12, 10))
    
    # Plot with inferno colormap
    sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='inferno',
                xticklabels=labels, yticklabels=labels,
                ax=ax, cbar_kws={'label': 'Proportion'})
    
    ax.set_title('Symbolic Reasoning on Original Data - Confusion Matrix', 
                 fontsize=14, fontweight='bold', pad=15)
    ax.set_xlabel('Predicted Diagnosis', fontsize=12, fontweight='bold')
    ax.set_ylabel('True Diagnosis', fontsize=12, fontweight='bold')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)
    
    plt.tight_layout()
    
    # Create output directory
    output_dir = Path(output_path).parent
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Save
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\n‚úì Saved confusion matrix: {output_path}")
    plt.close()
    
    # Print diagnosis distribution
    print(f"\n{'='*80}")
    print("DIAGNOSIS DISTRIBUTION")
    print(f"{'='*80}")
    
    print("\nGround Truth Distribution:")
    print(y_true.value_counts().sort_index())
    
    print("\nPredicted Distribution:")
    print(y_pred.value_counts().sort_index())
    
    # Calculate per-class accuracy
    print(f"\n{'='*80}")
    print("PER-CLASS ACCURACY")
    print(f"{'='*80}")
    
    for label in labels:
        mask = y_true == label
        if mask.sum() > 0:
            class_acc = (y_true[mask] == y_pred[mask]).mean()
            correct = (y_true[mask] == y_pred[mask]).sum()
            total = mask.sum()
            print(f"\n{label}:")
            print(f"  Accuracy: {class_acc:.2%} ({correct}/{total})")
    
    print("\n" + "="*80)
    print("‚úÖ COMPLETE!")
    print("="*80)
    print(f"\nGenerated file: {output_path}")
    
    return df


def main():
    """Locate an engine output CSV and plot its confusion matrix.

    The candidate paths are checked in order and the first existing one is
    used. If none exists, instructions for generating the file are printed
    and nothing is plotted. Always returns None.
    """
    rule = "=" * 80

    print("\n" + rule)
    print("üîç CHECKING FOR ENGINE OUTPUT...")
    print(rule)

    candidates = (
        'engine_output.csv',
        'data/engine_output.csv',
        '/home/claude/engine_output.csv',
    )
    csv_path = next((c for c in candidates if Path(c).exists()), None)

    if csv_path is not None:
        print(f"‚úì Found: {csv_path}")
        destination = 'evaluation_results/symbolic_original/confusion_matrix_original_data.png'
        plot_confusion_matrix_from_csv(csv_path, destination)
        return

    print("\n‚ùå ERROR: engine_output.csv not found in any location!")
    print("\n" + rule)
    print("üìã INSTRUCTIONS TO GENERATE engine_output.csv")
    print(rule)
    print("""
1. Make sure you have your diagnose_patient() function defined

2. Run this code in Python:

   import pandas as pd
   
   def run_engine_on_csv(csv_path, output_path="engine_output.csv", verbose=False):
       df = pd.read_csv(csv_path)
       print(f"\\nLoaded {len(df)} patients from {csv_path}")
       
       if "Type" not in df.columns:
           raise ValueError("CSV missing required column: 'Type'")
       
       results = []
       for i, row in df.iterrows():
           patient_dict = row.to_dict()
           patient = {k.lower(): v for k, v in patient_dict.items()}
           diag = diagnose_patient(patient, verbose=verbose)
           gt = row["Type"]
           
           results.append({
               "patient_id": diag['patient_id'],
               "ground_truth": gt,
               "diagnosis": diag['diagnosis'],
               "match": (diag['diagnosis'] == gt),
               "code": diag['code'],
               "confidence": diag['confidence'],
               "reasoning": " | ".join(diag['reasoning']),
           })
           
           if (i+1) % 20 == 0:
               print(f"Processed {i+1}/{len(df)} patients...")
       
       out_df = pd.DataFrame(results)
       out_df.to_csv(output_path, index=False)
       
       acc = out_df['match'].mean() * 100
       print(f"\\nüî• DONE! Saved to {output_path}")
       print(f"üéØ Accuracy: {acc:.2f}%\\n")
       return out_df
   
   # Run it
   results = run_engine_on_csv("data/migraine_with_id.csv", verbose=False)

3. Then run this script again:
   python /mnt/user-data/outputs/plot_engine_confusion_matrix.py
        """)


if __name__ == "__main__":
    main()

In [None]:
"""
Detailed Error Analysis for ICHD-3 Diagnostic Engine (original data)
======================================================

Analyzes which reasoning criteria fail for misclassified patients.
Visualizes error patterns and rule failures.
"""


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from collections import defaultdict, Counter
import re

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 10)

# Okabe-Ito colorblind-safe palette (name -> hex colour used by the plots below)
COLORS = dict(
    orange='#E69F00',
    sky_blue='#56B4E9',
    bluish_green='#009E73',
    yellow='#F0E442',
    vermillion='#D55E00',
    reddish_purple='#CC79A7',
)


def parse_reasoning(reasoning_str):
    """
    Extract specific criteria from a '|'-separated reasoning string.

    Each part is classified as *failed* if it contains a negative keyword, and
    otherwise as *met* if it contains a positive keyword. Negatives are checked
    first and the two classes are exclusive. Several positive keywords are
    substrings of their negations ("found" in "not found", "present" in
    "not present"), so checking positives first would count a failed criterion
    as met.

    Returns dict with parsed components:
    - met_criteria: list of criteria that passed
    - failed_criteria: list of criteria that failed
    - diagnosis_path: the last part mentioning "diagnosed as"/"diagnosis:"
      ('Unknown' if none)
    """
    if pd.isna(reasoning_str):
        return {
            'met_criteria': [],
            'failed_criteria': [],
            'diagnosis_path': 'Unknown'
        }

    negative_keywords = ('fails', 'missing', 'absent', 'not found', 'insufficient', '‚úó', 'fail',
                         'not present', 'not detected')
    positive_keywords = ('meets', 'has', 'present', 'found', 'detected', '‚úì', 'pass')

    met = []
    failed = []
    diagnosis = 'Unknown'

    for part in str(reasoning_str).split('|'):
        part = part.strip()
        lowered = part.lower()

        # Check for diagnosis assignment
        if 'diagnosed as' in lowered or 'diagnosis:' in lowered:
            diagnosis = part

        if any(word in lowered for word in negative_keywords):
            failed.append(part)
        elif any(word in lowered for word in positive_keywords):
            met.append(part)

    return {
        'met_criteria': met,
        'failed_criteria': failed,
        'diagnosis_path': diagnosis
    }


def extract_criterion_type(criterion_text):
    """
    Map a free-text criterion to a coarse category name.

    Categories are tried in a fixed order, and the first one with a keyword
    that occurs as a substring of the lower-cased text wins. For example,
    "hemiplegic aura" is filed under 'Aura Symptoms', not 'Motor Symptoms'.
    Returns 'Other' when nothing matches.
    """
    lowered = criterion_text.lower()

    ordered_categories = (
        ('Pain Characteristics', ('pain', 'character', 'location', 'intensity',
                                  'unilateral', 'pulsating', 'throbbing')),
        ('Duration', ('duration', 'hours', '4-72')),
        ('Frequency', ('frequency', 'attacks', 'episodes', '>=5')),
        ('Associated Symptoms', ('nausea', 'vomit', 'photophobia', 'phonophobia', 'accompanying')),
        ('Aura Symptoms', ('visual', 'sensory', 'dysphasia', 'aura', 'scotoma', 'paresthesia')),
        ('Brainstem Symptoms', ('dysarthria', 'vertigo', 'tinnitus', 'diplopia', 'ataxia', 'brainstem')),
        ('Motor Symptoms', ('hemiplegic', 'motor', 'weakness', 'paralysis')),
        ('Family History', ('dpf', 'family', 'familial', 'hereditary')),
    )

    for category, keywords in ordered_categories:
        if any(keyword in lowered for keyword in keywords):
            return category
    return 'Other'


def analyze_errors(csv_path='engine_output_ner.csv'):
    """
    Comprehensive error analysis of an engine output CSV.

    Splits rows into correct / incorrect using the 'match' column, prints
    summary statistics, then runs the per-diagnosis, reasoning-failure and
    confusion analyses on the incorrect rows.

    Args:
        csv_path: CSV with 'match', 'ground_truth' and 'diagnosis' columns.
            The detailed analyses also read 'reasoning', 'patient_id' and
            'confidence'.

    Returns:
        dict with keys 'overall', 'correct', 'incorrect' (DataFrames) and
        'error_patterns', 'reasoning_failures', 'confusion_patterns'.
    """
    
    print("\n" + "="*80)
    print("DETAILED ERROR ANALYSIS")
    print("="*80)
    
    # Load data
    df = pd.read_csv(csv_path)
    print(f"\n‚úì Loaded {len(df)} patients from {csv_path}")
    
    # Split into correct and incorrect. The comparisons are element-wise on
    # purpose: rows whose 'match' is missing (NaN) land in neither group.
    correct_df = df[df['match'] == True]  # noqa: E712
    incorrect_df = df[df['match'] == False]  # noqa: E712
    
    n_total = len(df)
    
    def share(n):
        # Percentage of all rows. An empty file gives 0.0 instead of
        # raising ZeroDivisionError.
        return n / n_total * 100 if n_total else 0.0
    
    print("\nüìä Overall Statistics:")
    print(f"   Correct: {len(correct_df)} ({share(len(correct_df)):.1f}%)")
    print(f"   Incorrect: {len(incorrect_df)} ({share(len(incorrect_df)):.1f}%)")
    
    # Analyze error patterns
    error_patterns = analyze_error_patterns(incorrect_df)
    
    # Analyze reasoning failures
    reasoning_failures = analyze_reasoning_failures(incorrect_df)
    
    # Analyze confusion patterns
    confusion_patterns = analyze_confusion_patterns(incorrect_df)
    
    return {
        'overall': df,
        'correct': correct_df,
        'incorrect': incorrect_df,
        'error_patterns': error_patterns,
        'reasoning_failures': reasoning_failures,
        'confusion_patterns': confusion_patterns
    }


def analyze_error_patterns(incorrect_df):
    """
    Summarise misclassifications per ground-truth diagnosis.

    For every distinct ground-truth label (in order of first appearance),
    print how often each wrong prediction occurred and keep up to three
    sample rows.

    Returns:
        {true_diagnosis: {'count': int,
                          'predicted_as': {prediction: count},
                          'sample_cases': [{'patient_id', 'predicted',
                                            'confidence', 'reasoning'}, ...]}}
    """
    bar = '=' * 80
    print(f"\n{bar}")
    print("ERROR PATTERNS BY TRUE DIAGNOSIS")
    print(bar)

    patterns = {}

    for truth in incorrect_df['ground_truth'].unique():
        group = incorrect_df[incorrect_df['ground_truth'] == truth]
        prediction_counts = group['diagnosis'].value_counts()
        n_errors = len(group)

        entry = {
            'count': n_errors,
            'predicted_as': prediction_counts.to_dict(),
            'sample_cases': [],
        }
        patterns[truth] = entry

        print(f"\n{truth}:")
        print(f"   Total errors: {n_errors}")
        print("   Predicted as:")
        for prediction, count in prediction_counts.items():
            print(f"      - {prediction}: {count} ({count / n_errors * 100:.1f}%)")

        # Keep the first three misclassified rows as examples
        entry['sample_cases'].extend(
            {
                'patient_id': case['patient_id'],
                'predicted': case['diagnosis'],
                'confidence': case.get('confidence', 'N/A'),
                'reasoning': case.get('reasoning', 'N/A'),
            }
            for _, case in group.head(3).iterrows()
        )

    return patterns


def analyze_reasoning_failures(incorrect_df):
    """
    Count failed criteria per category, overall and per true diagnosis.

    Each misclassified row's reasoning string is parsed with parse_reasoning().
    Every failed criterion is categorised with extract_criterion_type() and
    tallied. The printed percentages are relative to the number of
    misclassified rows. A row with several failures in one category is
    counted each time, so a category can exceed 100%.

    Returns:
        {'by_type': {category: count},
         'by_diagnosis': {true_diagnosis: {category: count}}}
    """
    separator = '=' * 80
    print(f"\n{separator}")
    print("REASONING FAILURE ANALYSIS")
    print(separator)

    totals = defaultdict(int)
    per_diagnosis = defaultdict(lambda: defaultdict(int))

    for _, record in incorrect_df.iterrows():
        parsed = parse_reasoning(record.get('reasoning', ''))
        truth = record['ground_truth']
        for criterion in parsed['failed_criteria']:
            category = extract_criterion_type(criterion)
            totals[category] += 1
            per_diagnosis[truth][category] += 1

    def by_count_desc(pairs):
        # Highest count first; ties keep their first-seen order (stable sort)
        return sorted(pairs, key=lambda kv: kv[1], reverse=True)

    print("\nüìä Most Common Criterion Failures (Overall):")
    for category, count in by_count_desc(totals.items()):
        print(f"   {category}: {count} ({count / len(incorrect_df) * 100:.1f}%)")

    dash_rule = '‚îÄ' * 80
    print(f"\n{dash_rule}")
    print("Criterion Failures by True Diagnosis:")
    print(dash_rule)

    for truth in sorted(per_diagnosis):
        print(f"\n{truth}:")
        for category, count in by_count_desc(per_diagnosis[truth].items())[:5]:  # top 5
            print(f"   - {category}: {count}")

    return {
        'by_type': dict(totals),
        'by_diagnosis': dict(per_diagnosis)
    }


def analyze_confusion_patterns(incorrect_df):
    """
    Count (true diagnosis, predicted diagnosis) pairs among misclassified rows.

    Prints the ten most frequent pairs and returns all counts as
    {(ground_truth, diagnosis): count}.
    """
    heading = '=' * 80
    print(f"\n{heading}")
    print("CONFUSION PATTERNS")
    print(heading)

    pair_counts = Counter(zip(incorrect_df['ground_truth'], incorrect_df['diagnosis']))

    print("\nMost Common Misclassifications:")
    for (truth, predicted), count in pair_counts.most_common(10):
        print(f"   {truth} ‚Üí {predicted}: {count}")

    return dict(pair_counts)


def visualize_error_analysis(analysis, output_dir='evaluation_results/symbolic_original'):
    """
    Render every error-analysis figure into *output_dir*, creating it if needed.

    The figures are produced in this order: error rates per diagnosis, the
    criterion-failure heatmap, the top-confusions bar chart and the
    failure-type breakdown.
    """
    line = '=' * 80
    print(f"\n{line}")
    print("GENERATING VISUALIZATIONS")
    print(line)

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    plotters = (
        plot_error_rates,         # 1. Error rate by diagnosis
        plot_criterion_failures,  # 2. Criterion failure heatmap
        plot_confusion_flow,      # 3. Confusion flow diagram
        plot_error_breakdown,     # 4. Detailed error breakdown
    )
    for plotter in plotters:
        plotter(analysis, target_dir)


def plot_error_rates(analysis, output_dir):
    """
    Plot per-diagnosis error rates next to correct/incorrect counts.

    Left panel: horizontal error-rate bars (%) per ground-truth diagnosis,
    sorted from highest to lowest. Bars are vermillion above 50%, orange above
    30% and bluish green otherwise, with dashed guides at 30% and 50%.
    Right panel: stacked correct vs. error counts. The figure is saved as
    'error_rates_by_diagnosis.png' inside *output_dir* (joined with ``/``, so a
    Path is expected).
    """
    df = analysis['overall']
    truth = df['ground_truth']
    wrong = df['match'] == False  # noqa: E712 - element-wise pandas comparison

    stats = {}
    for diag in truth.unique():
        in_class = truth == diag
        n_total = in_class.sum()
        n_errors = (in_class & wrong).sum()
        stats[diag] = {
            'error_rate': n_errors / n_total if n_total > 0 else 0,
            'total': n_total,
            'errors': n_errors,
            'correct': n_total - n_errors,
        }

    order = sorted(stats, key=lambda d: stats[d]['error_rate'], reverse=True)
    error_pcts = [stats[d]['error_rate'] * 100 for d in order]

    def bar_colour(pct):
        if pct > 50:
            return COLORS['vermillion']
        if pct > 30:
            return COLORS['orange']
        return COLORS['bluish_green']

    _, (ax_rate, ax_count) = plt.subplots(1, 2, figsize=(16, 6))

    # Left: error rates
    positions = range(len(order))
    ax_rate.barh(positions, error_pcts, color=[bar_colour(p) for p in error_pcts], alpha=0.8)
    ax_rate.set_yticks(positions)
    ax_rate.set_yticklabels(order, fontsize=16)
    ax_rate.set_xlabel('Error Rate (%)', fontsize=14, fontweight='bold')
    ax_rate.set_title('Error Rate by Diagnosis Type', fontsize=14, fontweight='bold', pad=15)
    ax_rate.axvline(50, color='red', linestyle='--', alpha=0.3, linewidth=1)
    ax_rate.axvline(30, color='orange', linestyle='--', alpha=0.3, linewidth=1)
    ax_rate.grid(axis='x', alpha=0.3)

    for row, pct in enumerate(error_pcts):
        ax_rate.text(pct + 1, row, f'{pct:.1f}%', va='center', fontsize=9)

    # Right: correct vs incorrect counts (errors stacked after the correct bar)
    y_pos = np.arange(len(order))
    correct_counts = [stats[d]['correct'] for d in order]
    error_counts = [stats[d]['errors'] for d in order]

    ax_count.barh(y_pos, correct_counts, color=COLORS['bluish_green'], alpha=0.8, label='Correct')
    ax_count.barh(y_pos, error_counts, left=correct_counts, color=COLORS['vermillion'], alpha=0.8, label='Errors')

    ax_count.set_yticks(y_pos)
    ax_count.set_yticklabels(order, fontsize=16)
    ax_count.set_xlabel('Number of Patients', fontsize=14, fontweight='bold')
    ax_count.set_title('Correct vs Incorrect Predictions', fontsize=14, fontweight='bold', pad=15)
    ax_count.legend(loc='lower right', fontsize=14)
    ax_count.grid(axis='x', alpha=0.3)

    plt.tight_layout()

    output_path = output_dir / 'error_rates_by_diagnosis.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"‚úì Saved: {output_path}")
    plt.close()


def plot_criterion_failures(analysis, output_dir):
    """
    Heatmap of failed-criterion counts, one row per true diagnosis.

    Rows (diagnoses) and columns (criterion categories) are both sorted
    alphabetically, and missing combinations count as 0. Prints a message and
    returns early when there is no failure data. Otherwise the figure is saved
    as 'criterion_failures_heatmap.png' inside *output_dir*.
    """
    by_diagnosis = analysis['reasoning_failures']['by_diagnosis']

    if not by_diagnosis:
        print("‚ö† No reasoning failure data available, skipping criterion failure plot")
        return

    diagnoses = sorted(by_diagnosis)
    criterion_types = sorted(set().union(*(counts.keys() for counts in by_diagnosis.values())))

    matrix = np.array(
        [[by_diagnosis[diag].get(criterion, 0) for criterion in criterion_types]
         for diag in diagnoses],
        dtype=float,
    ).reshape(len(diagnoses), len(criterion_types))

    _, ax = plt.subplots(figsize=(12, 8))

    sns.heatmap(matrix, annot=True, fmt='g', cmap='inferno',
                xticklabels=criterion_types, yticklabels=diagnoses,
                ax=ax, cbar_kws={'label': 'Failure Count'})

    ax.set_title('Criterion Failures by Diagnosis Type (Inferno)',
                 fontsize=14, fontweight='bold', pad=15)
    ax.set_xlabel('Criterion Type', fontsize=14, fontweight='bold')
    ax.set_ylabel('True Diagnosis', fontsize=14, fontweight='bold')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)

    plt.tight_layout()

    output_path = output_dir / 'criterion_failures_heatmap.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"‚úì Saved: {output_path}")
    plt.close()


def plot_confusion_flow(analysis, output_dir):
    """
    Chart the ten most frequent misclassification pairs and save the figure.

    Despite the name, this draws a ranked horizontal bar chart (one bar per
    ``true -> predicted`` pair), not a Sankey diagram.

    Args:
        analysis: dict from ``analyze_errors``; reads
            ``analysis['confusion_patterns']``, a mapping
            ``(true_diag, predicted_diag) -> count``.
        output_dir: ``pathlib.Path`` the figure is written into as
            ``confusion_flow.png``.
    """
    pair_counts = analysis['confusion_patterns']

    if not pair_counts:
        print("‚ö† No confusion pattern data, skipping flow plot")
        return

    # Most frequent pairs first; keep only the top ten.
    ranked = sorted(pair_counts.items(), key=lambda item: item[1], reverse=True)
    top_pairs = ranked[:10]

    labels = []
    counts = []
    for (true_diag, pred_diag), n in top_pairs:
        labels.append(f"{true_diag} ‚Üí {pred_diag}")
        counts.append(n)
    positions = range(len(labels))

    _, ax = plt.subplots(figsize=(12, 8))

    ax.barh(positions, counts, color=COLORS['orange'], alpha=0.8)
    ax.set_yticks(positions)
    ax.set_yticklabels(labels, fontsize=16)
    ax.set_xlabel('Number of Misclassifications', fontsize=16, fontweight='bold')
    ax.set_title('Top 10 Confusion Patterns', fontsize=14, fontweight='bold', pad=15)
    ax.grid(axis='x', alpha=0.3)

    # Write each count just past the end of its bar.
    for y, n in enumerate(counts):
        ax.text(n + 0.5, y, str(n), va='center', fontsize=14)

    plt.tight_layout()

    destination = output_dir / 'confusion_flow.png'
    plt.savefig(destination, dpi=300, bbox_inches='tight')
    print(f"‚úì Saved: {destination}")
    plt.close()


def plot_error_breakdown(analysis, output_dir):
    """
    Plot total failures per criterion category and save the figure.

    Bars are sorted by count, largest first. Each bar is coloured by its count
    relative to the largest count:
    - vermillion above 60%,
    - orange above 30%,
    - bluish green otherwise.

    Args:
        analysis: dict from ``analyze_errors``; reads
            ``analysis['reasoning_failures']['by_type']`` (category -> count).
        output_dir: ``pathlib.Path`` the figure is written into as
            ``failure_breakdown.png``.
    """
    failures_by_type = analysis['reasoning_failures']['by_type']

    if not failures_by_type:
        print("‚ö† No failure type data, skipping breakdown plot")
        return

    ordered = sorted(failures_by_type.items(), key=lambda kv: kv[1], reverse=True)
    categories = [name for name, _ in ordered]
    counts = [n for _, n in ordered]

    peak = max(counts)

    def _severity_color(n):
        # Severity bands relative to the most frequent category.
        if n > peak * 0.6:
            return COLORS['vermillion']
        if n > peak * 0.3:
            return COLORS['orange']
        return COLORS['bluish_green']

    bar_colors = [_severity_color(n) for n in counts]

    _, ax = plt.subplots(figsize=(12, 7))

    y_positions = range(len(categories))
    ax.barh(y_positions, counts, color=bar_colors, alpha=0.8)
    ax.set_yticks(y_positions)
    ax.set_yticklabels(categories, fontsize=14)
    ax.set_xlabel('Number of Failures', fontsize=14, fontweight='bold')
    ax.set_title('Criterion Failure Breakdown (All Misclassifications)',
                 fontsize=14, fontweight='bold', pad=15)
    ax.grid(axis='x', alpha=0.3)

    # Count labels sit 1% of the longest bar past each bar end.
    offset = peak * 0.01
    for y, n in enumerate(counts):
        ax.text(n + offset, y, str(n), va='center', fontsize=10)

    plt.tight_layout()

    destination = output_dir / 'failure_breakdown.png'
    plt.savefig(destination, dpi=300, bbox_inches='tight')
    print(f"‚úì Saved: {destination}")
    plt.close()


def print_detailed_examples(analysis, num_examples=5):
    """
    Print up to ``num_examples`` misclassified patients per true diagnosis.

    For each example, print:
    - the predicted diagnosis,
    - the confidence,
    - the first five ``|``-separated reasoning fragments.
    Any remaining fragments are summarised as a count.

    Args:
        analysis: dict from ``analyze_errors``; reads ``analysis['incorrect']``,
            a DataFrame with ``ground_truth``, ``patient_id`` and ``diagnosis``
            columns (``confidence`` / ``reasoning`` fall back to 'N/A').
        num_examples: maximum number of rows printed per true diagnosis.
    """
    banner = '=' * 80
    rule = '‚îÄ' * 80
    max_parts = 5

    print(f"\n{banner}")
    print(f"DETAILED ERROR EXAMPLES (First {num_examples} per diagnosis)")
    print(banner)

    errors = analysis['incorrect']

    for diagnosis_label in sorted(errors['ground_truth'].unique()):
        examples = errors[errors['ground_truth'] == diagnosis_label].head(num_examples)

        print(f"\n{rule}")
        print(f"TRUE DIAGNOSIS: {diagnosis_label}")
        print(rule)

        for _, record in examples.iterrows():
            print(f"\n  Patient {record['patient_id']}:")
            print(f"    Predicted: {record['diagnosis']}")
            print(f"    Confidence: {record.get('confidence', 'N/A')}")
            print("    Reasoning:")

            fragments = str(record.get('reasoning', 'N/A')).split('|')
            for fragment in fragments[:max_parts]:
                print(f"      ‚Ä¢ {fragment.strip()}")
            hidden = len(fragments) - max_parts
            if hidden > 0:
                print(f"      ... ({hidden} more criteria)")


def main():
    """
    Locate the engine output CSV and run the complete error analysis on it.

    Runs, in order:
    - ``analyze_errors`` on the CSV,
    - ``visualize_error_analysis`` to write the figures,
    - ``print_detailed_examples`` with three examples per diagnosis.
    If no candidate CSV exists, prints guidance and returns early.
    """
    print("\n" + "="*80)
    print("üîç CHECKING FOR ENGINE OUTPUT...")
    print("="*80)

    # Candidate locations for the engine output, tried in order.
    candidates = ['engine_output.csv']
    csv_path = next((candidate for candidate in candidates if Path(candidate).exists()), None)

    if csv_path is None:
        print("\n‚ùå ERROR: engine_output.csv not found!")
        print("\nPlease run your engine first to generate engine_output.csv")
        print("See: /mnt/user-data/outputs/QUICK_START.txt")
        return
    print(f"‚úì Found: {csv_path}")

    analysis = analyze_errors(csv_path)
    visualize_error_analysis(analysis)
    print_detailed_examples(analysis, num_examples=3)

    separator = "=" * 80
    print("\n" + separator)
    print("‚úÖ ERROR ANALYSIS COMPLETE!")
    print(separator)
    print("\nGenerated files in: evaluation_results/symbolic_original/")
    for listing in (
        "  1. error_rates_by_diagnosis.png",
        "  2. criterion_failures_heatmap.png (inferno colormap)",
        "  3. confusion_flow.png",
        "  4. failure_breakdown.png",
    ):
        print(listing)


# Run the analysis when this cell is executed. __name__ is "__main__" both
# when run as a script and inside an IPython/Jupyter kernel.
if __name__ == "__main__":
    main()

In [None]:
"""
Detailed Error Analysis for ICHD-3 Diagnostic Engine (NER)
==========================================================

Analyzes which reasoning criteria fail for misclassified patients in the
NER engine's output (engine_output_ner.csv) and visualizes the resulting
error patterns and rule failures.
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from collections import defaultdict, Counter
import re

# Global plotting defaults applied to every figure produced below.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (14, 10)})

# Colours from the Okabe-Ito colorblind-safe palette, keyed by colour name.
COLORS = dict(
    orange='#E69F00',
    sky_blue='#56B4E9',
    bluish_green='#009E73',
    yellow='#F0E442',
    vermillion='#D55E00',
    reddish_purple='#CC79A7',
)


def parse_reasoning(reasoning_str):
    """
    Split an engine reasoning string into passed and failed criteria.

    The string is split on ``|`` and each stripped fragment is classified by
    case-insensitive substring matching:

    * a fragment containing a negative marker (``fails``, ``missing``,
      ``absent``, ``not found``, ``insufficient``, a cross glyph, ``fail``)
      goes to ``failed_criteria``;
    * otherwise, a fragment containing a positive marker (``meets``, ``has``,
      ``present``, ``found``, ``detected``, a check glyph, ``pass``) goes to
      ``met_criteria``.

    Negative markers take precedence. Without that, "aura not found" would be
    counted as both failed and met, because ``found`` is a substring of
    ``not found``.

    Args:
        reasoning_str: reasoning text; a missing value (NaN/None) yields
            empty results.

    Returns:
        dict with keys:
        - ``met_criteria``: fragments classified as passed.
        - ``failed_criteria``: fragments classified as failed.
        - ``diagnosis_path``: the last fragment mentioning "diagnosed as" or
          "diagnosis:", or ``'Unknown'`` if none does.
    """
    met = []
    failed = []
    diagnosis = 'Unknown'

    if pd.isna(reasoning_str):
        return {
            'met_criteria': met,
            'failed_criteria': failed,
            'diagnosis_path': diagnosis
        }

    # NOTE(review): '‚úì' / '‚úó' are the UTF-8 glyphs '✓' / '✗' mis-decoded as
    # Mac Roman. Both spellings are accepted so either form of engine output
    # matches.
    positive_markers = ('meets', 'has', 'present', 'found', 'detected',
                        '‚úì', '✓', 'pass')
    negative_markers = ('fails', 'missing', 'absent', 'not found', 'insufficient',
                        '‚úó', '✗', 'fail')

    for part in str(reasoning_str).split('|'):
        part = part.strip()
        lowered = part.lower()

        # Diagnosis assignment; the last matching fragment wins.
        if 'diagnosed as' in lowered or 'diagnosis:' in lowered:
            diagnosis = part

        # Check negatives first so negated phrases are not also counted as met.
        if any(marker in lowered for marker in negative_markers):
            failed.append(part)
        elif any(marker in lowered for marker in positive_markers):
            met.append(part)

    return {
        'met_criteria': met,
        'failed_criteria': failed,
        'diagnosis_path': diagnosis
    }


def extract_criterion_type(criterion_text):
    """
    Map a free-text criterion fragment to a coarse category name.

    Categories are tested in a fixed priority order, and the first one with a
    keyword in the lower-cased text (substring match) wins. For example, a
    fragment mentioning both "pain" and "hours" is classed as
    'Pain Characteristics'. Returns 'Other' when no keyword matches.
    """
    lowered = criterion_text.lower()

    # (category, keywords) in priority order. The order is significant.
    rules = (
        ('Pain Characteristics', ('pain', 'character', 'location', 'intensity',
                                  'unilateral', 'pulsating', 'throbbing')),
        ('Duration', ('duration', 'hours', '4-72')),
        ('Frequency', ('frequency', 'attacks', 'episodes', '>=5')),
        ('Associated Symptoms', ('nausea', 'vomit', 'photophobia', 'phonophobia',
                                 'accompanying')),
        ('Aura Symptoms', ('visual', 'sensory', 'dysphasia', 'aura', 'scotoma',
                           'paresthesia')),
        ('Brainstem Symptoms', ('dysarthria', 'vertigo', 'tinnitus', 'diplopia',
                                'ataxia', 'brainstem')),
        ('Motor Symptoms', ('hemiplegic', 'motor', 'weakness', 'paralysis')),
        ('Family History', ('dpf', 'family', 'familial', 'hereditary')),
    )

    for category, keywords in rules:
        if any(keyword in lowered for keyword in keywords):
            return category
    return 'Other'


def analyze_errors(csv_path='engine_output_ner.csv'):
    """
    Load the engine output CSV and run every error analysis on it.

    Columns used:
    - ``match`` (bool), read here;
    - ``ground_truth``, ``diagnosis`` and ``patient_id``, read by the helpers;
    - ``reasoning`` and ``confidence``, used when present.

    Rows whose ``match`` is neither True nor False (e.g. NaN) fall into
    neither subset.

    Args:
        csv_path: path to the engine output CSV.

    Returns:
        dict with keys:
        - ``overall``: the full DataFrame.
        - ``correct`` / ``incorrect``: row subsets by ``match``.
        - ``error_patterns``, ``reasoning_failures``, ``confusion_patterns``:
          the outputs of the corresponding ``analyze_*`` helpers.
    """

    print("\n" + "="*80)
    print("DETAILED ERROR ANALYSIS")
    print("="*80)

    # Load data
    df = pd.read_csv(csv_path)
    print(f"\n‚úì Loaded {len(df)} patients from {csv_path}")

    # Split into correct and incorrect (element-wise comparison on the column)
    correct_df = df[df['match'] == True]
    incorrect_df = df[df['match'] == False]

    def _share(part):
        # Percentage of all rows. A header-only CSV gives 0.0 instead of
        # raising ZeroDivisionError.
        return len(part) / len(df) * 100 if len(df) else 0.0

    print(f"\nüìä Overall Statistics:")
    print(f"   Correct: {len(correct_df)} ({_share(correct_df):.1f}%)")
    print(f"   Incorrect: {len(incorrect_df)} ({_share(incorrect_df):.1f}%)")

    # Analyze error patterns
    error_patterns = analyze_error_patterns(incorrect_df)

    # Analyze reasoning failures
    reasoning_failures = analyze_reasoning_failures(incorrect_df)

    # Analyze confusion patterns
    confusion_patterns = analyze_confusion_patterns(incorrect_df)

    return {
        'overall': df,
        'correct': correct_df,
        'incorrect': incorrect_df,
        'error_patterns': error_patterns,
        'reasoning_failures': reasoning_failures,
        'confusion_patterns': confusion_patterns
    }


def analyze_error_patterns(incorrect_df):
    """
    Summarise how misclassified patients of each true diagnosis were labelled.

    Prints a per-diagnosis breakdown of predicted labels.

    Returns:
        ``{true_diag: {'count': int, 'predicted_as': {pred: n}, 'sample_cases': [...]}}``,
        where ``sample_cases`` holds up to three dicts with keys
        ``patient_id``, ``predicted``, ``confidence`` and ``reasoning``.
        ``confidence`` and ``reasoning`` are 'N/A' when the column is absent.
    """
    line = '=' * 80
    print(f"\n{line}")
    print("ERROR PATTERNS BY TRUE DIAGNOSIS")
    print(line)

    patterns = {}

    for true_label in incorrect_df['ground_truth'].unique():
        rows = incorrect_df[incorrect_df['ground_truth'] == true_label]
        n_rows = len(rows)
        predicted_counts = rows['diagnosis'].value_counts()

        print(f"\n{true_label}:")
        print(f"   Total errors: {n_rows}")
        print("   Predicted as:")
        for predicted_label, n in predicted_counts.items():
            print(f"      - {predicted_label}: {n} ({n / n_rows * 100:.1f}%)")

        samples = [
            {
                'patient_id': row['patient_id'],
                'predicted': row['diagnosis'],
                'confidence': row.get('confidence', 'N/A'),
                'reasoning': row.get('reasoning', 'N/A'),
            }
            for _, row in rows.head(3).iterrows()
        ]

        patterns[true_label] = {
            'count': n_rows,
            'predicted_as': predicted_counts.to_dict(),
            'sample_cases': samples,
        }

    return patterns


def analyze_reasoning_failures(incorrect_df):
    """
    Count failed reasoning criteria per category, overall and per diagnosis.

    Each row's ``reasoning`` is parsed with ``parse_reasoning``. Every failed
    fragment is mapped to a category with ``extract_criterion_type`` and
    tallied both overall and under the row's ``ground_truth``. The totals are
    printed.

    The printed percentage is failures divided by the number of incorrect
    rows. A category can therefore exceed 100% when one row contributes
    several failures of it.

    Returns:
        ``{'by_type': {category: count},
           'by_diagnosis': {true_diag: {category: count}}}``.
        Only diagnoses with at least one failure appear in ``by_diagnosis``.
    """
    divider = '=' * 80
    thin = '‚îÄ' * 80

    print(f"\n{divider}")
    print("REASONING FAILURE ANALYSIS")
    print(divider)

    totals = defaultdict(int)
    per_diagnosis = defaultdict(lambda: defaultdict(int))

    for _, row in incorrect_df.iterrows():
        parsed = parse_reasoning(row.get('reasoning', ''))
        true_label = row['ground_truth']
        for category in map(extract_criterion_type, parsed['failed_criteria']):
            totals[category] += 1
            per_diagnosis[true_label][category] += 1

    print("\nüìä Most Common Criterion Failures (Overall):")
    n_incorrect = len(incorrect_df)
    for category, count in sorted(totals.items(), key=lambda kv: kv[1], reverse=True):
        print(f"   {category}: {count} ({count / n_incorrect * 100:.1f}%)")

    print(f"\n{thin}")
    print("Criterion Failures by True Diagnosis:")
    print(thin)

    for true_label in sorted(per_diagnosis):
        print(f"\n{true_label}:")
        ranked = sorted(per_diagnosis[true_label].items(), key=lambda kv: kv[1], reverse=True)
        # Only the five most frequent categories per diagnosis are printed.
        for category, count in ranked[:5]:
            print(f"   - {category}: {count}")

    return {
        'by_type': dict(totals),
        'by_diagnosis': dict(per_diagnosis),
    }


def analyze_confusion_patterns(incorrect_df):
    """
    Count how often each (true diagnosis, predicted diagnosis) pair occurs.

    Prints the ten most frequent pairs.

    Returns:
        plain dict ``{(true_diag, predicted_diag): count}``.
    """
    rule = '=' * 80
    print(f"\n{rule}")
    print("CONFUSION PATTERNS")
    print(rule)

    pair_counts = Counter(
        (row['ground_truth'], row['diagnosis']) for _, row in incorrect_df.iterrows()
    )

    print("\nMost Common Misclassifications:")
    for (actual, predicted), n in pair_counts.most_common(10):
        print(f"   {actual} ‚Üí {predicted}: {n}")

    return dict(pair_counts)


def visualize_error_analysis(analysis, output_dir='evaluation_results/symbolic'):
    """
    Produce all four error-analysis figures in ``output_dir``.

    The directory is created if needed. Figures, in order:
    1. error rate per diagnosis,
    2. criterion-failure heatmap,
    3. top confusion pairs,
    4. overall failure-category breakdown.

    Args:
        analysis: dict returned by ``analyze_errors``.
        output_dir: target directory (str or path-like).
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print("GENERATING VISUALIZATIONS")
    print(banner)

    target = Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)

    plotters = (
        plot_error_rates,          # 1. error rate by diagnosis
        plot_criterion_failures,   # 2. criterion failure heatmap
        plot_confusion_flow,       # 3. top confusion pairs
        plot_error_breakdown,      # 4. failure-category breakdown
    )
    for plot in plotters:
        plot(analysis, target)


def plot_error_rates(analysis, output_dir):
    """
    Plot per-diagnosis error rates next to correct/incorrect counts.

    Left panel: error rate (%) per true diagnosis, sorted descending.
    - Bars are vermillion above 50%, orange above 30%, bluish green otherwise.
    - Dashed guide lines mark 30% and 50%.

    Right panel: stacked horizontal bars of correct vs. incorrect predictions
    per diagnosis.

    Args:
        analysis: dict from ``analyze_errors``; reads ``analysis['overall']``
            (needs ``ground_truth`` and ``match`` columns).
        output_dir: ``pathlib.Path`` the figure is written into as
            ``error_rates_by_diagnosis.png``.
    """
    df = analysis['overall']

    stats = {}
    for label in df['ground_truth'].unique():
        in_class = df['ground_truth'] == label
        n_total = in_class.sum()
        n_errors = (in_class & (df['match'] == False)).sum()
        stats[label] = {
            'error_rate': n_errors / n_total if n_total > 0 else 0,
            'total': n_total,
            'errors': n_errors,
            'correct': n_total - n_errors,
        }

    # Worst-performing diagnoses first.
    order = sorted(stats, key=lambda label: stats[label]['error_rate'], reverse=True)
    y = np.arange(len(order))

    _, (rate_ax, count_ax) = plt.subplots(1, 2, figsize=(16, 6))

    # ---- Left panel: error rate per diagnosis ----
    rates = [stats[label]['error_rate'] * 100 for label in order]
    rate_colors = []
    for rate in rates:
        if rate > 50:
            rate_colors.append(COLORS['vermillion'])
        elif rate > 30:
            rate_colors.append(COLORS['orange'])
        else:
            rate_colors.append(COLORS['bluish_green'])

    rate_ax.barh(y, rates, color=rate_colors, alpha=0.8)
    rate_ax.set_yticks(y)
    rate_ax.set_yticklabels(order, fontsize=16)
    rate_ax.set_xlabel('Error Rate (%)', fontsize=14, fontweight='bold')
    rate_ax.set_title('Error Rate by Diagnosis Type', fontsize=14, fontweight='bold', pad=15)
    for threshold, guide_color in ((50, 'red'), (30, 'orange')):
        rate_ax.axvline(threshold, color=guide_color, linestyle='--', alpha=0.3, linewidth=1)
    rate_ax.grid(axis='x', alpha=0.3)

    for row, rate in enumerate(rates):
        rate_ax.text(rate + 1, row, f'{rate:.1f}%', va='center', fontsize=9)

    # ---- Right panel: stacked correct / error counts ----
    n_correct = [stats[label]['correct'] for label in order]
    n_wrong = [stats[label]['errors'] for label in order]

    count_ax.barh(y, n_correct, color=COLORS['bluish_green'], alpha=0.8, label='Correct')
    count_ax.barh(y, n_wrong, left=n_correct, color=COLORS['vermillion'], alpha=0.8, label='Errors')

    count_ax.set_yticks(y)
    count_ax.set_yticklabels(order, fontsize=16)
    count_ax.set_xlabel('Number of Patients', fontsize=14, fontweight='bold')
    count_ax.set_title('Correct vs Incorrect Predictions', fontsize=14, fontweight='bold', pad=15)
    count_ax.legend(loc='lower right', fontsize=14)
    count_ax.grid(axis='x', alpha=0.3)

    plt.tight_layout()

    destination = output_dir / 'error_rates_by_diagnosis.png'
    plt.savefig(destination, dpi=300, bbox_inches='tight')
    print(f"‚úì Saved: {destination}")
    plt.close()


def plot_criterion_failures(analysis, output_dir):
    """
    Render a heatmap of criterion-failure counts and save it as a PNG.

    Rows are true diagnoses and columns are criterion categories, both sorted.
    A (diagnosis, category) pair with no recorded failures is shown as 0.

    Args:
        analysis: dict produced by ``analyze_errors``; only
            ``analysis['reasoning_failures']['by_diagnosis']`` is read
            (a mapping ``diagnosis -> {category: count}``).
        output_dir: ``pathlib.Path`` the figure is written into as
            ``criterion_failures_heatmap.png``.
    """
    by_diagnosis = analysis['reasoning_failures']['by_diagnosis']

    # Nothing to draw: report it and stop.
    if not by_diagnosis:
        print("‚ö† No reasoning failure data available, skipping criterion failure plot")
        return

    # Union of every category seen for any diagnosis, sorted for stable columns.
    columns = sorted({category for counts in by_diagnosis.values() for category in counts})
    rows = sorted(by_diagnosis)

    counts_matrix = np.array(
        [[by_diagnosis[row].get(col, 0) for col in columns] for row in rows],
        dtype=float,
    )

    _, ax = plt.subplots(figsize=(12, 8))

    # Annotated heatmap with the 'inferno' colormap.
    sns.heatmap(
        counts_matrix,
        annot=True,
        fmt='g',
        cmap='inferno',
        xticklabels=columns,
        yticklabels=rows,
        ax=ax,
        cbar_kws={'label': 'Failure Count'},
    )

    label_style = {'fontsize': 14, 'fontweight': 'bold'}
    ax.set_title('Criterion Failures by Diagnosis Type (Inferno)', pad=15, **label_style)
    ax.set_xlabel('Criterion Type', **label_style)
    ax.set_ylabel('True Diagnosis', **label_style)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)

    plt.tight_layout()

    destination = output_dir / 'criterion_failures_heatmap.png'
    plt.savefig(destination, dpi=300, bbox_inches='tight')
    print(f"‚úì Saved: {destination}")
    plt.close()


def plot_confusion_flow(analysis, output_dir):
    """
    Chart the ten most frequent misclassification pairs and save the figure.

    Despite the name, this draws a ranked horizontal bar chart (one bar per
    ``true -> predicted`` pair), not a Sankey diagram.

    Args:
        analysis: dict from ``analyze_errors``; reads
            ``analysis['confusion_patterns']``, a mapping
            ``(true_diag, predicted_diag) -> count``.
        output_dir: ``pathlib.Path`` the figure is written into as
            ``confusion_flow.png``.
    """
    pair_counts = analysis['confusion_patterns']

    if not pair_counts:
        print("‚ö† No confusion pattern data, skipping flow plot")
        return

    # Most frequent pairs first; keep only the top ten.
    ranked = sorted(pair_counts.items(), key=lambda item: item[1], reverse=True)
    top_pairs = ranked[:10]

    labels = []
    counts = []
    for (true_diag, pred_diag), n in top_pairs:
        labels.append(f"{true_diag} ‚Üí {pred_diag}")
        counts.append(n)
    positions = range(len(labels))

    _, ax = plt.subplots(figsize=(12, 8))

    ax.barh(positions, counts, color=COLORS['yellow'], alpha=0.8)
    ax.set_yticks(positions)
    ax.set_yticklabels(labels, fontsize=16)
    ax.set_xlabel('Number of Misclassifications', fontsize=14, fontweight='bold')
    ax.set_title('Top 10 Confusion Patterns', fontsize=14, fontweight='bold', pad=15)
    ax.grid(axis='x', alpha=0.3)

    # Write each count just past the end of its bar.
    for y, n in enumerate(counts):
        ax.text(n + 0.5, y, str(n), va='center', fontsize=10)

    plt.tight_layout()

    destination = output_dir / 'confusion_flow.png'
    plt.savefig(destination, dpi=300, bbox_inches='tight')
    print(f"‚úì Saved: {destination}")
    plt.close()


def plot_error_breakdown(analysis, output_dir):
    """
    Plot total failures per criterion category across all misclassifications.

    Bars are sorted by count, largest first. Each bar is coloured by severity
    relative to the largest count:
    - vermillion above 60%,
    - orange above 30%,
    - bluish green otherwise.
    This is the same warm-to-cool ramp used for the error-rate panel.

    Args:
        analysis: dict from ``analyze_errors``; reads
            ``analysis['reasoning_failures']['by_type']`` (category -> count).
        output_dir: ``pathlib.Path`` the figure is written into as
            ``failure_breakdown.png``.
    """

    reasoning_failures = analysis['reasoning_failures']['by_type']

    if not reasoning_failures:
        print("‚ö† No failure type data, skipping breakdown plot")
        return

    # Sort by count, most frequent first
    sorted_failures = sorted(reasoning_failures.items(), key=lambda x: x[1], reverse=True)

    criterion_types = [name for name, _ in sorted_failures]
    counts = [count for _, count in sorted_failures]

    # Create figure
    fig, ax = plt.subplots(figsize=(12, 7))

    # Colour by severity relative to the most frequent category. The top band
    # uses vermillion, not yellow: yellow reads as *less* severe than the
    # orange middle band, which would invert the ramp used by plot_error_rates.
    max_count = max(counts)
    colors = []
    for count in counts:
        if count > max_count * 0.6:
            colors.append(COLORS['vermillion'])
        elif count > max_count * 0.3:
            colors.append(COLORS['orange'])
        else:
            colors.append(COLORS['bluish_green'])

    ax.barh(range(len(criterion_types)), counts, color=colors, alpha=0.8)
    ax.set_yticks(range(len(criterion_types)))
    ax.set_yticklabels(criterion_types, fontsize=16)
    ax.set_xlabel('Number of Failures', fontsize=14, fontweight='bold')
    ax.set_title('Criterion Failure Breakdown (All Misclassifications)',
                 fontsize=14, fontweight='bold', pad=15)
    ax.grid(axis='x', alpha=0.3)

    # Value labels, offset by 1% of the longest bar
    for i, count in enumerate(counts):
        ax.text(count + max_count*0.01, i, str(count), va='center', fontsize=10)

    plt.tight_layout()

    output_path = output_dir / 'failure_breakdown.png'
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"‚úì Saved: {output_path}")
    plt.close()


def print_detailed_examples(analysis, num_examples=5):
    """
    Print up to ``num_examples`` misclassified patients per true diagnosis.

    For each example, print:
    - the predicted diagnosis,
    - the confidence,
    - the first five ``|``-separated reasoning fragments.
    Any remaining fragments are summarised as a count.

    Args:
        analysis: dict from ``analyze_errors``; reads ``analysis['incorrect']``,
            a DataFrame with ``ground_truth``, ``patient_id`` and ``diagnosis``
            columns (``confidence`` / ``reasoning`` fall back to 'N/A').
        num_examples: maximum number of rows printed per true diagnosis.
    """
    banner = '=' * 80
    rule = '‚îÄ' * 80
    max_parts = 5

    print(f"\n{banner}")
    print(f"DETAILED ERROR EXAMPLES (First {num_examples} per diagnosis)")
    print(banner)

    errors = analysis['incorrect']

    for diagnosis_label in sorted(errors['ground_truth'].unique()):
        examples = errors[errors['ground_truth'] == diagnosis_label].head(num_examples)

        print(f"\n{rule}")
        print(f"TRUE DIAGNOSIS: {diagnosis_label}")
        print(rule)

        for _, record in examples.iterrows():
            print(f"\n  Patient {record['patient_id']}:")
            print(f"    Predicted: {record['diagnosis']}")
            print(f"    Confidence: {record.get('confidence', 'N/A')}")
            print("    Reasoning:")

            fragments = str(record.get('reasoning', 'N/A')).split('|')
            for fragment in fragments[:max_parts]:
                print(f"      ‚Ä¢ {fragment.strip()}")
            hidden = len(fragments) - max_parts
            if hidden > 0:
                print(f"      ... ({hidden} more criteria)")

def main():
    """
    Run the NER-engine error analysis end to end.

    Looks for ``engine_output_ner.csv`` in the working directory, then:
    - runs ``analyze_errors`` on it,
    - writes the four figures with ``visualize_error_analysis``,
    - prints three example misclassifications per diagnosis.
    If the CSV is missing, prints guidance naming that file and returns early.
    """
    # Figure directory, passed explicitly so the closing summary names the
    # directory that was actually written. The summary previously claimed
    # "symbolic_original".
    output_dir = 'evaluation_results/symbolic'

    print("\n" + "="*80)
    print("üîç CHECKING FOR ENGINE OUTPUT...")
    print("="*80)

    # Candidate locations for the engine output, tried in order
    possible_paths = [
        'engine_output_ner.csv'
    ]

    csv_path = None
    for path in possible_paths:
        if Path(path).exists():
            csv_path = path
            print(f"‚úì Found: {path}")
            break

    if csv_path is None:
        # Name the file that was actually searched for (not engine_output.csv).
        print(f"\n‚ùå ERROR: {', '.join(possible_paths)} not found!")
        print(f"\nPlease run your engine first to generate {possible_paths[0]}")
        print("See: /mnt/user-data/outputs/QUICK_START.txt")
        return

    # Run analysis
    analysis = analyze_errors(csv_path)

    # Create visualizations
    visualize_error_analysis(analysis, output_dir=output_dir)

    # Print detailed examples
    print_detailed_examples(analysis, num_examples=3)

    print("\n" + "="*80)
    print("‚úÖ ERROR ANALYSIS COMPLETE!")
    print("="*80)
    print(f"\nGenerated files in: {output_dir}/")
    print("  1. error_rates_by_diagnosis.png")
    print("  2. criterion_failures_heatmap.png (inferno colormap)")
    print("  3. confusion_flow.png")
    print("  4. failure_breakdown.png")


# Run the analysis when this cell is executed. __name__ is "__main__" both
# when run as a script and inside an IPython/Jupyter kernel.
if __name__ == "__main__":
    main()