In [None]:
import os
# Show the directory the session started in, then switch to a fixed absolute
# directory. The default data/ and evaluation_results/ paths used further down
# are relative, so they are resolved against this directory.
# NOTE(review): the path is machine-specific — adjust it when running elsewhere.
print(os.getcwd())
os.chdir("/Users/M1HR/Desktop/MIGRAINE")

In [None]:
import json
import pandas as pd
import numpy as np
from pathlib import Path
from collections import defaultdict


# =============================================================================
# 1. NORMALIZATION FUNCTIONS
# =============================================================================

def normalize_location(location_text):
    """Map a free-text pain location to a numeric code.

    Returns 2 when the text contains a bilateral/whole-head term, 1 when it
    contains a one-sided term, and 0 when the value is empty, NaN, or matches
    neither list. Matching is a case-insensitive substring search and the
    bilateral terms are tested first.
    """
    if not location_text or pd.isna(location_text):
        return 0

    lowered = str(location_text).lower()

    # Ordered (code, keywords) groups; the first group with a hit wins.
    keyword_groups = (
        (2, ('bilateral', 'both', 'everywhere', 'all over', 'entire head')),
        (1, ('left', 'right', 'side', 'temple', 'unilateral')),
    )
    for code, keywords in keyword_groups:
        for keyword in keywords:
            if keyword in lowered:
                return code

    return 0


def normalize_character(character_text):
    """Map a free-text pain character to a numeric code.

    Returns 1 for throbbing/pulsating wording, 2 for constant/pressing
    wording, and 0 when the value is empty, NaN, or matches neither list.
    Matching is a case-insensitive substring search on word stems and the
    throbbing stems are tested first.
    """
    if not character_text or pd.isna(character_text):
        return 0

    lowered = str(character_text).lower()

    # Ordered (code, stems) groups; earlier groups take precedence.
    stem_groups = (
        (1, ('throb', 'puls', 'pound', 'beat')),
        (2, ('constant', 'steady', 'press', 'tight', 'squeeze')),
    )
    return next(
        (code for code, stems in stem_groups
         if any(stem in lowered for stem in stems)),
        0,
    )


def normalize_intensity(intensity_text, intensity_numeric):
    """Map an intensity description to an ordinal code (0-3).

    The text field is preferred: 'severe' -> 3, 'moderate'/'medium' -> 2,
    'mild' -> 1 (case-insensitive substring match). When the text is empty,
    NaN, or contains none of those words, the numeric field is used instead:
    values >= 2.5 -> 3, >= 1.5 -> 2, >= 0.5 -> 1.

    Args:
        intensity_text: Free-text intensity, or None/NaN.
        intensity_numeric: Numeric intensity (number or numeric string),
            or None/NaN.

    Returns:
        int: 3 (severe), 2 (moderate), 1 (mild) or 0 when nothing usable
        was found, including a numeric value that cannot be converted.
    """
    # Try text field first (preferred)
    if intensity_text and not pd.isna(intensity_text):
        int_str = str(intensity_text).lower()

        if 'severe' in int_str:
            return 3
        if 'moderate' in int_str or 'medium' in int_str:
            return 2
        if 'mild' in int_str:
            return 1

    # Try numeric field (fallback)
    if intensity_numeric and not pd.isna(intensity_numeric):
        try:
            int_val = float(intensity_numeric)
        except (TypeError, ValueError, OverflowError):
            # Unparseable value: treat as "no intensity" instead of hiding
            # every possible exception behind a bare except.
            return 0

        if int_val >= 2.5:
            return 3  # Severe
        if int_val >= 1.5:
            return 2  # Moderate
        if int_val >= 0.5:
            return 1  # Mild

    return 0


# =============================================================================
# 2. LOAD AND PREPARE DATA
# =============================================================================

def load_ner_results(json_path):
    """Load the NER-extracted patient records and add normalized columns.

    The JSON content is turned into a DataFrame, the known lower-case keys
    are renamed to the capitalized column names used for the comparison, and
    three numeric columns are derived from the text fields:
    ``Location_Normalized``, ``Character_Normalized`` and
    ``Intensity_Normalized``.

    Args:
        json_path: Path of the JSON file holding the extracted records.

    Returns:
        pandas.DataFrame: The renamed records plus the three normalized
        columns.
    """
    print(f"\n{'='*80}")
    print("LOADING NER-EXTRACTED ENTITIES")
    print(f"{'='*80}")

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"✓ Loaded {len(data)} patients")

    # Convert to DataFrame
    df = pd.DataFrame(data)

    # Rename the extractor's keys to the column names used in the comparison
    df = df.rename(columns={
        'patient_id': 'Patient_ID',
        'duration': 'Duration',
        'intensity': 'Intensity',
        'intensity_text': 'Intensity_Text',
        'location': 'Location',
        'character': 'Character',
        'frequency': 'Frequency',
        'nausea': 'Nausea',
        'vomit': 'Vomit',
        'photophobia': 'Photophobia',
        'phonophobia': 'Phonophobia',
        'visual': 'Visual',
        'sensory': 'Sensory',
        'dysphasia': 'Dysphasia',
        'dysarthria': 'Dysarthria',
        'vertigo': 'Vertigo',
        'tinnitus': 'Tinnitus',
        'hypoacusis': 'Hypoacusis',
        'diplopia': 'Diplopia',
        'ataxia': 'Ataxia',
        'conscience': 'Conscience',
        'visual_defect': 'Visual_defect',
        'paresthesia': 'Paresthesia',
        'dpf': 'DPF'
    })

    # Normalize text fields to numeric
    print("\nNormalizing text fields to numeric values...")

    # row.get() yields None for a column that is absent, which the
    # normalizers map to 0.
    df['Location_Normalized'] = df.apply(
        lambda row: normalize_location(row.get('Location')), axis=1
    )

    df['Character_Normalized'] = df.apply(
        lambda row: normalize_character(row.get('Character')), axis=1
    )

    df['Intensity_Normalized'] = df.apply(
        lambda row: normalize_intensity(
            row.get('Intensity_Text'),
            row.get('Intensity')
        ), axis=1
    )

    print(f" Normalized location: {df['Location_Normalized'].value_counts().to_dict()}")
    print(f" Normalized character: {df['Character_Normalized'].value_counts().to_dict()}")
    print(f" Normalized intensity: {df['Intensity_Normalized'].value_counts().to_dict()}")

    return df


def load_kaggle_data(csv_path):
    """Read the original Kaggle CSV into a DataFrame.

    When the file has no ``Patient_ID`` column, one is created from the row
    index plus one. Prints a banner and the number of rows loaded.
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print("LOADING ORIGINAL KAGGLE DATA")
    print(banner)

    kaggle_df = pd.read_csv(csv_path)

    # Fall back to index-based ids when the file does not carry its own.
    has_patient_id = 'Patient_ID' in kaggle_df.columns
    if not has_patient_id:
        kaggle_df['Patient_ID'] = kaggle_df.index + 1

    row_count = len(kaggle_df)
    print(f" Loaded {row_count} patients")

    return kaggle_df


# =============================================================================
# 3. COMPARISON FUNCTIONS
# =============================================================================

def compare_numeric_field(kaggle_val, ner_val, tolerance=0.2):
    """Compare a numeric reference value with the extracted value.

    Args:
        kaggle_val: Reference value from the Kaggle data.
        ner_val: Value produced by the NER extraction.
        tolerance: Maximum relative difference (relative to the reference
            value) that still counts as a match.

    Returns:
        str: 'na' when the reference is NaN or 0, 'missing' when the
        reference is present but the extracted value is NaN or 0, 'match'
        when the relative difference is within ``tolerance``, otherwise
        'mismatch' (also returned when either value cannot be converted to
        float).
    """
    # Handle NaN / zero: both are treated as "no value"
    kaggle_present = not pd.isna(kaggle_val) and kaggle_val != 0
    ner_present = not pd.isna(ner_val) and ner_val != 0

    if not kaggle_present:
        return 'na'

    if not ner_present:
        return 'missing'

    # Only the conversions can legitimately fail; catch just those errors
    # instead of swallowing everything with a bare except.
    try:
        k_num = float(kaggle_val)
        n_num = float(ner_val)
    except (TypeError, ValueError, OverflowError):
        return 'mismatch'

    # Reachable for values such as the string "0", which passes the
    # `!= 0` presence test above but converts to 0.0.
    if k_num == 0:
        return 'match' if n_num == 0 else 'mismatch'

    # Check relative difference
    rel_diff = abs(k_num - n_num) / abs(k_num)
    return 'match' if rel_diff <= tolerance else 'mismatch'


def compare_binary_field(kaggle_val, ner_val):
    """Compare presence of a symptom in the reference and extracted data.

    A value counts as present when it is not NaN and its float value is
    greater than zero. Returns 'match' (present in both), 'missing' (only in
    the reference), 'hallucination' (only in the extraction) or 'na'
    (present in neither).
    """
    in_kaggle = False if pd.isna(kaggle_val) else float(kaggle_val) > 0
    in_ner = False if pd.isna(ner_val) else float(ner_val) > 0

    # Outcome keyed by (present in reference, present in extraction)
    outcome_by_presence = {
        (True, True): 'match',
        (True, False): 'missing',
        (False, True): 'hallucination',
        (False, False): 'na',
    }
    return outcome_by_presence[(in_kaggle, in_ner)]


def compare_categorical_field(kaggle_val, ner_val):
    """Compare a coded categorical value between reference and extraction.

    A code counts as present when it is not NaN and greater than zero.
    Returns 'na' when the reference code is absent, 'missing' when only the
    extracted code is absent, and otherwise 'match' or 'mismatch' depending
    on whether the two codes are equal as floats.
    """
    # Convert both sides up front; None marks a NaN/None input.
    kaggle_code = None if pd.isna(kaggle_val) else float(kaggle_val)
    ner_code = None if pd.isna(ner_val) else float(ner_val)

    if kaggle_code is None or not kaggle_code > 0:
        return 'na'

    if ner_code is None or not ner_code > 0:
        return 'missing'

    if kaggle_code == ner_code:
        return 'match'
    return 'mismatch'


# =============================================================================
# 4. COMPREHENSIVE COMPARISON
# =============================================================================

def compare_datasets(kaggle_df, ner_df):
    """Compare the NER-extracted values with the Kaggle reference per field.

    Both frames are inner-joined on ``Patient_ID``. Columns that exist in
    both frames receive the ``_kaggle`` / ``_ner`` suffixes; each configured
    field is then compared row by row with the numeric, binary or
    categorical comparator. Outcomes other than 'na' count as applicable
    comparisons.

    Args:
        kaggle_df: Reference DataFrame containing ``Patient_ID``.
        ner_df: Extracted DataFrame containing ``Patient_ID`` and the
            ``*_Normalized`` columns.

    Returns:
        dict with keys:
            'per_field': outcome counts per field,
            'per_patient': list of dicts (Patient_ID, matches, total,
                accuracy),
            'summary': list of per-field rate dicts (match, missing,
                mismatch, hallucination rates and total_applicable),
            'overall_accuracy': matches divided by applicable comparisons
                (0 when there are none).
    """
    print(f"\n{'='*80}")
    print("COMPARING DATASETS")
    print(f"{'='*80}")

    # Merge datasets
    merged = kaggle_df.merge(
        ner_df,
        on='Patient_ID',
        how='inner',
        suffixes=('_kaggle', '_ner')
    )

    print(f"✓ Merged {len(merged)} patients")

    # Define fields to compare
    fields_to_compare = {
        # Categorical (normalized)
        'Location': {
            'kaggle': 'Location_kaggle',
            'ner': 'Location_Normalized',
            'type': 'categorical'
        },
        'Character': {
            'kaggle': 'Character_kaggle',
            'ner': 'Character_Normalized',
            'type': 'categorical'
        },
        'Intensity': {
            'kaggle': 'Intensity_kaggle',
            'ner': 'Intensity_Normalized',
            'type': 'categorical'
        },

        # Numeric
        'Duration': {
            'kaggle': 'Duration_kaggle',
            'ner': 'Duration_ner',
            'type': 'numeric'
        },
        'Frequency': {
            'kaggle': 'Frequency_kaggle',
            'ner': 'Frequency_ner',
            'type': 'numeric'
        },

        # Binary
        'Nausea': {'kaggle': 'Nausea_kaggle', 'ner': 'Nausea_ner', 'type': 'binary'},
        'Vomit': {'kaggle': 'Vomit_kaggle', 'ner': 'Vomit_ner', 'type': 'binary'},
        'Photophobia': {'kaggle': 'Photophobia_kaggle', 'ner': 'Photophobia_ner', 'type': 'binary'},
        'Phonophobia': {'kaggle': 'Phonophobia_kaggle', 'ner': 'Phonophobia_ner', 'type': 'binary'},
        'Visual': {'kaggle': 'Visual_kaggle', 'ner': 'Visual_ner', 'type': 'binary'},
        'Sensory': {'kaggle': 'Sensory_kaggle', 'ner': 'Sensory_ner', 'type': 'binary'},
        'Dysphasia': {'kaggle': 'Dysphasia_kaggle', 'ner': 'Dysphasia_ner', 'type': 'binary'},
        'Dysarthria': {'kaggle': 'Dysarthria_kaggle', 'ner': 'Dysarthria_ner', 'type': 'binary'},
        'Vertigo': {'kaggle': 'Vertigo_kaggle', 'ner': 'Vertigo_ner', 'type': 'binary'},
        'Tinnitus': {'kaggle': 'Tinnitus_kaggle', 'ner': 'Tinnitus_ner', 'type': 'binary'},
        'DPF': {'kaggle': 'DPF_kaggle', 'ner': 'DPF_ner', 'type': 'binary'},
    }

    # A configured column that is absent from the merged frame (for example
    # because only one side had it, so no suffix was added) is read as None
    # below and silently turns into 'na'/'missing'. Make that visible.
    missing_columns = sorted({
        column
        for config in fields_to_compare.values()
        for column in (config['kaggle'], config['ner'])
        if column not in merged.columns
    })
    if missing_columns:
        print(f"⚠ Columns not found after merge (read as absent values): "
              f"{', '.join(missing_columns)}")

    # Comparator per field type; hoisted out of the row loop
    comparators = {
        'numeric': compare_numeric_field,
        'binary': compare_binary_field,
        'categorical': compare_categorical_field,
    }

    # Compare each field
    results = defaultdict(lambda: defaultdict(int))
    per_patient_results = []

    for _, row in merged.iterrows():
        patient_id = row['Patient_ID']
        patient_matches = 0
        patient_total = 0

        for field_name, config in fields_to_compare.items():
            kaggle_val = row.get(config['kaggle'])
            ner_val = row.get(config['ner'])

            # Compare based on type
            comparator = comparators.get(config['type'])
            if comparator is None:
                result = 'unknown'
            else:
                result = comparator(kaggle_val, ner_val)

            # Update counters
            results[field_name][result] += 1

            if result != 'na':
                patient_total += 1
                if result == 'match':
                    patient_matches += 1

        per_patient_results.append({
            'Patient_ID': patient_id,
            'matches': patient_matches,
            'total': patient_total,
            'accuracy': patient_matches / patient_total if patient_total > 0 else 0
        })

    # Print results
    print(f"\n{'='*80}")
    print("PER-FIELD COMPARISON RESULTS")
    print(f"{'='*80}\n")

    summary_records = []

    for field_name, counts in sorted(results.items()):
        total_applicable = sum(v for k, v in counts.items() if k != 'na')

        if total_applicable > 0:
            match_rate = counts['match'] / total_applicable
            missing_rate = counts['missing'] / total_applicable
            mismatch_rate = counts.get('mismatch', 0) / total_applicable
            # Hallucinations are part of total_applicable, so report them;
            # otherwise the rates of binary fields do not add up to 100%.
            hallucination_rate = counts.get('hallucination', 0) / total_applicable

            print(f"{field_name:15s}: "
                  f"Match {match_rate*100:5.1f}%  "
                  f"Missing {missing_rate*100:5.1f}%  "
                  f"Mismatch {mismatch_rate*100:5.1f}%  "
                  f"Halluc {hallucination_rate*100:5.1f}%  "
                  f"(n={total_applicable})")

            summary_records.append({
                'field': field_name,
                'match_rate': match_rate,
                'missing_rate': missing_rate,
                'mismatch_rate': mismatch_rate,
                'hallucination_rate': hallucination_rate,
                'total_applicable': total_applicable
            })

    # Overall statistics
    total_matches = sum(p['matches'] for p in per_patient_results)
    total_comparisons = sum(p['total'] for p in per_patient_results)
    overall_accuracy = total_matches / total_comparisons if total_comparisons > 0 else 0

    print(f"\n{'='*80}")
    print("OVERALL STATISTICS")
    print(f"{'='*80}\n")
    print(f"  Total Comparisons: {total_comparisons}")
    print(f"  Matches:           {total_matches} ({overall_accuracy*100:.1f}%)")
    # This figure is every applicable comparison that was not a match
    # (missing, mismatch and hallucination together).
    print(f"  Mismatches:        {total_comparisons - total_matches} ({(1-overall_accuracy)*100:.1f}%)")
    print(f"  Overall Accuracy:  {overall_accuracy:.4f}")

    return {
        'per_field': results,
        'per_patient': per_patient_results,
        'summary': summary_records,
        'overall_accuracy': overall_accuracy
    }


# =============================================================================
# 5. SAVE RESULTS
# =============================================================================

def save_comparison_results(comparison_results, output_dir):
    """Write the comparison results to ``output_dir``.

    Three files are produced:
        field_comparison_summary.csv  per-field rates, best match rate first
        patient_comparison.csv        per-patient match counts and accuracy
        overall_metrics.json          the overall accuracy

    Args:
        comparison_results: Dict with the keys 'summary', 'per_patient' and
            'overall_accuracy'.
        output_dir: Target directory; created (with parents) if needed.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*80}")
    print("SAVING RESULTS")
    print(f"{'='*80}")

    # Per-field summary
    summary_df = pd.DataFrame(comparison_results['summary'])
    if summary_df.empty:
        # No comparable fields: sorting a column-less frame by 'match_rate'
        # would raise KeyError. Write a header-only file instead.
        summary_df = pd.DataFrame(columns=[
            'field', 'match_rate', 'missing_rate',
            'mismatch_rate', 'total_applicable',
        ])
    else:
        summary_df = summary_df.sort_values('match_rate', ascending=False)
    summary_path = output_dir / 'field_comparison_summary.csv'
    summary_df.to_csv(summary_path, index=False)
    print(f"✓ Field summary: {summary_path}")

    # Per-patient results
    patient_df = pd.DataFrame(comparison_results['per_patient'])
    if patient_df.empty:
        # Keep the header so the file stays readable as a table.
        patient_df = pd.DataFrame(columns=[
            'Patient_ID', 'matches', 'total', 'accuracy',
        ])
    patient_path = output_dir / 'patient_comparison.csv'
    patient_df.to_csv(patient_path, index=False)
    print(f"✓ Patient results: {patient_path}")

    # Overall metrics
    overall_path = output_dir / 'overall_metrics.json'
    with open(overall_path, 'w', encoding='utf-8') as f:
        json.dump({
            'overall_accuracy': comparison_results['overall_accuracy']
        }, f, indent=2)
    print(f"✓ Overall metrics: {overall_path}")

    print(f"\n✓ All results saved to: {output_dir}")


# =============================================================================
# 6. MAIN
# =============================================================================

def compare_ner_with_kaggle(
    ner_json_path='data/ner_results/patient_summaries_fixed.json',
    kaggle_csv_path='data/migraine_with_id.csv',
    output_dir='evaluation_results/ner_comparison'
):
    """Run the full NER-vs-Kaggle comparison and save the results.

    Loads both datasets, compares them field by field, writes the result
    files to ``output_dir`` and prints the overall accuracy together with
    the three best and three worst fields by match rate.

    Args:
        ner_json_path: JSON file with the NER-extracted records.
        kaggle_csv_path: CSV file with the original Kaggle data.
        output_dir: Directory for the result files.

    Returns:
        dict: The result dictionary produced by ``compare_datasets``.
    """
    print("\n" + "="*80)
    print("NER EXTRACTION vs ORIGINAL DATA COMPARISON")
    print("="*80)

    # Load data
    ner_df = load_ner_results(ner_json_path)
    kaggle_df = load_kaggle_data(kaggle_csv_path)

    # Compare
    results = compare_datasets(kaggle_df, ner_df)

    # Save
    save_comparison_results(results, output_dir)

    # Final summary
    print("\n" + "="*80)
    print(" COMPARISON COMPLETE!")
    print("="*80)

    print(f"\nKEY FINDINGS:")
    print(f"  Overall Accuracy: {results['overall_accuracy']:.4f} ({results['overall_accuracy']*100:.1f}%)")

    # Show best and worst fields
    summary_df = pd.DataFrame(results['summary'])

    if summary_df.empty:
        # Without any per-field record the frame has no 'match_rate'
        # column and sorting by it would raise KeyError.
        print("\n  No per-field results available to rank.")
        return results

    summary_df = summary_df.sort_values('match_rate', ascending=False)

    print(f"\n  Best Preserved (Top 3):")
    for _, row in summary_df.head(3).iterrows():
        print(f"    {row['field']:15s}: {row['match_rate']*100:5.1f}%")

    print(f"\n  Worst Preserved (Bottom 3):")
    for _, row in summary_df.tail(3).iterrows():
        print(f"    {row['field']:15s}: {row['match_rate']*100:5.1f}%")

    return results


# Entry point: run the comparison with its default paths and keep the
# returned result dictionary in `results`.
if __name__ == "__main__":
    results = compare_ner_with_kaggle()

In [None]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Global plotting defaults applied to every figure created below
sns.set_style('whitegrid')
plt.rcParams.update({
    'figure.figsize': (14, 8),
    'font.size': 11,
})


# =============================================================================
# 1. FIELD COMPARISON VISUALIZATIONS
# =============================================================================

def plot_field_match_rates(df, output_path):
    """Draw a horizontal bar chart of the match rate per field.

    Bars are sorted by match rate and coloured by performance tier
    (>=0.9, >=0.8, >=0.7, below). Each bar carries its percentage and the
    number of applicable comparisons. The figure is saved to
    ``output_path`` and closed.

    Args:
        df: Frame with the columns 'field', 'match_rate' and
            'total_applicable'.
        output_path: Destination of the image file.
    """
    from matplotlib.patches import Patch

    ordered = df.sort_values('match_rate', ascending=True)

    # Okabe-Ito tiers as (lower bound, colour, legend label), best first;
    # the last tier has no lower bound and acts as the fallback.
    tiers = [
        (0.9, '#E69F00', 'Excellent (≥90%)'),
        (0.8, '#56B4E9', 'Good (80-89%)'),
        (0.7, '#009E73', 'Fair (70-79%)'),
        (None, '#F0E442', 'Poor (<70%)'),
    ]

    def tier_colour(rate):
        for lower_bound, colour, _ in tiers[:-1]:
            if rate >= lower_bound:
                return colour
        return tiers[-1][1]

    bar_colours = [tier_colour(rate) for rate in ordered['match_rate']]

    fig, ax = plt.subplots(figsize=(12, max(8, len(ordered) * 0.4)))

    bars = ax.barh(ordered['field'], ordered['match_rate'], color=bar_colours,
                   alpha=0.85, edgecolor='black', linewidth=1)

    ax.set_xlabel('Match Rate', fontsize=13, fontweight='bold')
    ax.set_ylabel('Clinical Attribute', fontsize=13, fontweight='bold')
    ax.set_title('NER Extraction: Information Preservation by Attribute',
                 fontsize=15, fontweight='bold', pad=20)
    ax.set_xlim([0, 1.05])

    # Dashed vertical grid drawn behind the bars
    ax.grid(axis='x', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    # Percentage label plus sample size for every bar
    rows = zip(bars, ordered['match_rate'], ordered['total_applicable'])
    for y_pos, (bar, rate, n_applicable) in enumerate(rows):
        bar_end = bar.get_width()
        pct_label = f'{rate*100:.1f}%'

        # Wide bars get the label inside (white), narrow bars outside
        if bar_end > 0.15:
            ax.text(bar_end - 0.03, y_pos, pct_label,
                    ha='right', va='center', fontsize=10, fontweight='bold', color='white')
        else:
            ax.text(bar_end + 0.02, y_pos, pct_label,
                    ha='left', va='center', fontsize=10, fontweight='bold')

        ax.text(1.01, y_pos, f'n={int(n_applicable)}',
                ha='left', va='center', fontsize=9, style='italic', color='gray')

    # Legend explaining the tier colours
    legend_handles = [
        Patch(facecolor=colour, alpha=0.85, label=label)
        for _, colour, label in tiers
    ]
    ax.legend(handles=legend_handles, loc='lower right', framealpha=0.95)

    # Dashed reference line at each tier boundary, in the tier's colour
    for lower_bound, colour, _ in tiers[:-1]:
        ax.axvline(x=lower_bound, color=colour, linestyle='--', alpha=0.4, linewidth=1.5)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f" Saved: {output_path}")
    plt.close()


def plot_field_breakdown_stacked(df, output_path):
    """Draw a stacked horizontal bar chart of match/missing/mismatch rates.

    Fields are sorted by match rate; each bar stacks the three rates as
    percentages, and segments wider than 8% are labelled. The figure is
    saved to ``output_path`` and closed.

    Args:
        df: Frame with the columns 'field', 'match_rate', 'missing_rate'
            and 'mismatch_rate'.
        output_path: Destination of the image file.
    """
    ordered = df.sort_values('match_rate', ascending=True)

    fig, ax = plt.subplots(figsize=(12, max(8, len(ordered) * 0.4)))

    field_names = ordered['field']
    match_pct = ordered['match_rate'] * 100
    missing_pct = ordered['missing_rate'] * 100
    mismatch_pct = ordered['mismatch_rate'] * 100

    # Okabe-Ito colorblind-safe palette; styling shared by all segments
    segment_style = dict(alpha=0.85, edgecolor='black', linewidth=0.5)
    ax.barh(field_names, match_pct, label='Match', color='#E69F00', **segment_style)
    ax.barh(field_names, missing_pct, left=match_pct, label='Missing',
            color='#56B4E9', **segment_style)
    ax.barh(field_names, mismatch_pct, left=match_pct + missing_pct, label='Mismatch',
            color='#F0E442', **segment_style)

    ax.set_xlabel('Percentage (%)', fontsize=13, fontweight='bold')
    ax.set_ylabel('Clinical Attribute', fontsize=13, fontweight='bold')
    ax.set_title('NER Extraction: Detailed Breakdown by Attribute',
                 fontsize=15, fontweight='bold', pad=20)
    ax.set_xlim([0, 100])

    # Dashed vertical grid drawn behind the bars
    ax.grid(axis='x', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    ax.legend(loc='lower right', framealpha=0.95, fontsize=11)

    # Label each segment at its centre, skipping segments of 8% or less
    for row_idx, segment_widths in enumerate(zip(match_pct, missing_pct, mismatch_pct)):
        segment_start = 0
        for segment_width in segment_widths:
            if segment_width > 8:
                ax.text(segment_start + segment_width/2, row_idx, f'{segment_width:.1f}%',
                        ha='center', va='center',
                        fontsize=9, fontweight='bold', color='white')
            segment_start = segment_start + segment_width

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f" Saved: {output_path}")
    plt.close()


def plot_field_comparison_heatmap(df, output_path):
    """Draw a heatmap of match/missing/mismatch percentages per field.

    Rows are the fields (best match rate first), columns the three rates.
    Every cell is annotated with its percentage. The figure is saved to
    ``output_path`` and closed.

    Args:
        df: Frame with the columns 'field', 'match_rate', 'missing_rate'
            and 'mismatch_rate'.
        output_path: Destination of the image file.
    """
    ranked = df.sort_values('match_rate', ascending=False)

    rate_columns = ['match_rate', 'missing_rate', 'mismatch_rate']
    pct_matrix = ranked[rate_columns].values * 100
    field_names = ranked['field'].tolist()

    fig, ax = plt.subplots(figsize=(10, max(8, len(ranked) * 0.4)))

    image = ax.imshow(pct_matrix, cmap='inferno', aspect='auto', vmin=0, vmax=100)

    # Axis ticks: one column per rate, one row per field
    ax.set_xticks([0, 1, 2])
    ax.set_xticklabels(['Match', 'Missing', 'Mismatch'], fontsize=12, fontweight='bold')
    ax.set_yticks(range(len(field_names)))
    ax.set_yticklabels(field_names, fontsize=11)

    # Annotate every cell; white text below 50%, black otherwise
    for (row_idx, col_idx), pct in np.ndenumerate(pct_matrix):
        text_colour = 'white' if pct < 50 else 'black'
        ax.text(col_idx, row_idx, f'{pct:.1f}%',
                ha='center', va='center',
                color=text_colour,
                fontsize=10, fontweight='bold')

    colorbar = plt.colorbar(image, ax=ax)
    colorbar.set_label('Percentage (%)', rotation=270, labelpad=20, fontsize=12, fontweight='bold')

    ax.set_title('NER Extraction: Performance Heatmap',
                 fontsize=15, fontweight='bold', pad=20)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved: {output_path}")
    plt.close()


# =============================================================================
# 2. PER-PATIENT VISUALIZATIONS
# =============================================================================

def plot_patient_accuracy_distribution(df, output_path):
    """Plot the per-patient accuracy as a histogram and a box plot.

    Left panel: 20-bin histogram with bins coloured by performance tier,
    a statistics box and mean/median reference lines. Right panel: box
    plot annotated with median and quartiles. The figure is saved to
    ``output_path`` and closed.

    Args:
        df: Frame with an 'accuracy' column holding fractions.
        output_path: Destination of the image file.
    """
    accuracy = df['accuracy']

    # Summary statistics in percent
    mean_acc = accuracy.mean() * 100
    median_acc = accuracy.median() * 100
    std_acc = accuracy.std() * 100
    q25 = accuracy.quantile(0.25) * 100
    q75 = accuracy.quantile(0.75) * 100

    def tier_colour(pct):
        # Same tier boundaries as the field plots: 90 / 80 / 70
        if pct >= 90:
            return '#E69F00'   # Orange - Excellent
        if pct >= 80:
            return '#56B4E9'   # Sky blue - Good
        if pct >= 70:
            return '#009E73'   # Bluish green - Fair
        return '#F0E442'       # Yellow - Poor

    fig, (hist_ax, box_ax) = plt.subplots(1, 2, figsize=(16, 6))

    # --- Histogram -------------------------------------------------------
    _, edges, hist_bars = hist_ax.hist(accuracy * 100, bins=20,
                                       color='#56B4E9', alpha=0.7,
                                       edgecolor='black', linewidth=1.5)

    # Recolour every bin according to the tier of its centre
    for hist_bar, left_edge, right_edge in zip(hist_bars, edges[:-1], edges[1:]):
        hist_bar.set_facecolor(tier_colour((left_edge + right_edge) / 2))

    hist_ax.set_xlabel('Accuracy (%)', fontsize=12, fontweight='bold')
    hist_ax.set_ylabel('Number of Patients', fontsize=12, fontweight='bold')
    hist_ax.set_title('Distribution of Per-Patient Accuracy', fontsize=14, fontweight='bold')
    hist_ax.grid(axis='y', alpha=0.3)

    stats_text = f'Mean: {mean_acc:.1f}%\nMedian: {median_acc:.1f}%\nStd: {std_acc:.1f}%'
    hist_ax.text(0.02, 0.98, stats_text, transform=hist_ax.transAxes,
                 fontsize=11, verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # Mean and median reference lines
    hist_ax.axvline(x=mean_acc, color='#E69F00', linestyle='--', linewidth=2,
                    label=f'Mean: {mean_acc:.1f}%')
    hist_ax.axvline(x=median_acc, color='#56B4E9', linestyle='--', linewidth=2,
                    label=f'Median: {median_acc:.1f}%')
    hist_ax.legend(loc='upper left')

    # --- Box plot --------------------------------------------------------
    box_ax.boxplot([accuracy * 100], vert=True, widths=0.5, patch_artist=True,
                   showmeans=True, meanline=True,
                   boxprops=dict(facecolor='#56B4E9', alpha=0.7, edgecolor='black', linewidth=2),
                   whiskerprops=dict(color='black', linewidth=2),
                   capprops=dict(color='black', linewidth=2),
                   medianprops=dict(color='#E69F00', linewidth=3),
                   meanprops=dict(color='#009E73', linewidth=3, linestyle='--'))

    box_ax.set_ylabel('Accuracy (%)', fontsize=14, fontweight='bold')
    box_ax.set_title('Per-Patient Accuracy Summary', fontsize=14, fontweight='bold')
    box_ax.set_xticks([1])
    box_ax.set_xticklabels(['All Patients'], fontsize=14)
    box_ax.grid(axis='y', alpha=0.3)

    # Median and quartile annotations to the right of the box
    box_ax.text(1.15, median_acc, f'Median: {median_acc:.1f}%',
                fontsize=16, va='center', fontweight='bold')
    box_ax.text(1.15, q75, f'Q3: {q75:.1f}%', fontsize=16, va='center', style='italic')
    box_ax.text(1.15, q25, f'Q1: {q25:.1f}%', fontsize=16, va='center', style='italic')

    plt.suptitle('NER Extraction: Per-Patient Accuracy Analysis',
                 fontsize=16, fontweight='bold', y=1.02)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f" Saved: {output_path}")
    plt.close()


def plot_patient_accuracy_cumulative(df, output_path):
    """Plot the cumulative distribution of per-patient accuracy.

    The sorted accuracies (in percent) are plotted against the cumulative
    share of patients, and the 50th, 80th, 90th and 95th percentiles are
    marked. The figure is saved to ``output_path`` and closed.

    Args:
        df: Frame with an 'accuracy' column holding fractions.
        output_path: Destination of the image file.
    """
    accuracy_pct = np.sort(df['accuracy'].values) * 100
    n_patients = len(accuracy_pct)
    cumulative_pct = np.arange(1, n_patients + 1) / n_patients * 100

    curve_colour = '#56B4E9'
    guide_colour = '#009E73'

    fig, ax = plt.subplots(figsize=(12, 7))

    ax.plot(accuracy_pct, cumulative_pct, linewidth=3, color=curve_colour,
            label='Cumulative Distribution')
    ax.fill_between(accuracy_pct, 0, cumulative_pct, alpha=0.3, color=curve_colour)

    ax.set_xlabel('Accuracy (%)', fontsize=13, fontweight='bold')
    ax.set_ylabel('Cumulative Percentage of Patients (%)', fontsize=13, fontweight='bold')
    ax.set_title('Cumulative Distribution of Per-Patient Accuracy',
                 fontsize=15, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3)
    ax.set_xlim([0, 100])
    ax.set_ylim([0, 100])

    # Cross-hair, marker and label at each key percentile
    for p in (50, 80, 90, 95):
        acc_at_p = np.percentile(accuracy_pct, p)
        ax.axhline(y=p, color=guide_colour, linestyle='--', alpha=0.5, linewidth=1)
        ax.axvline(x=acc_at_p, color=guide_colour, linestyle='--', alpha=0.5, linewidth=1)
        ax.plot(acc_at_p, p, 'o', color='maroon', markersize=8)
        ax.text(acc_at_p + 2, p - 3, f'{p}th: {acc_at_p:.1f}%',
                fontsize=10, fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    ax.legend(loc='lower right', fontsize=11, framealpha=0.95)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved: {output_path}")
    plt.close()


def plot_patient_performance_categories(df, output_path):
    """Plot how many patients fall into each accuracy tier.

    Tiers are left-closed so that they agree with their labels:
    Poor [0, 70), Fair [70, 80), Good [80, 90), Excellent [90, 100].
    Each bar is labelled with its count and its share of all patients.
    The figure is saved to ``output_path`` and closed. The input frame is
    not modified.

    Args:
        df: Frame with an 'accuracy' column holding fractions.
        output_path: Destination of the image file.
    """
    # Rounding removes float noise from the *100 scaling before binning.
    accuracy_pct = (df['accuracy'] * 100).round(9)

    # right=False makes each bin include its lower edge, so 0% is counted
    # (as Poor) and exactly 70/80/90% land in the tier their label names.
    # The open-ended last bin keeps 100% in "Excellent".
    # The result is held in a local Series rather than written to `df`,
    # so the caller's frame does not gain a 'category' column.
    categories = pd.cut(
        accuracy_pct,
        bins=[0, 70, 80, 90, float('inf')],
        right=False,
        labels=['Poor (<70%)', 'Fair (70-80%)', 'Good (80-90%)', 'Excellent (≥90%)'],
    )

    category_counts = categories.value_counts().sort_index()

    fig, ax = plt.subplots(figsize=(10, 7))

    # Colours in tier order: Poor, Fair, Good, Excellent
    colors = ['#F0E442', '#009E73', '#56B4E9', '#E69F00']
    bars = ax.bar(range(len(category_counts)), category_counts.values,
                  color=colors, alpha=0.85, edgecolor='black', linewidth=2)

    ax.set_xticks(range(len(category_counts)))
    ax.set_xticklabels(category_counts.index, fontsize=16, fontweight='bold')
    ax.set_ylabel('Number of Patients', fontsize=16, fontweight='bold')
    ax.set_title('Patient Distribution by Performance Category',
                 fontsize=15, fontweight='bold', pad=20)
    ax.grid(axis='y', alpha=0.3)

    # Add value labels and percentages (share of all patients in df)
    total = len(df)
    for bar, count in zip(bars, category_counts.values):
        height = bar.get_height()
        percentage = (count / total) * 100
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{int(count)}\n({percentage:.1f}%)',
                ha='center', va='bottom', fontsize=16, fontweight='bold')

    # Add total annotation
    ax.text(0.02, 0.98, f'Total Patients: {total}',
            transform=ax.transAxes, fontsize=16, fontweight='bold',
            verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f" Saved: {output_path}")
    plt.close()


# =============================================================================
# 3. MAIN VISUALIZATION FUNCTION
# =============================================================================

def visualize_ner_comparison(
    field_summary_path='evaluation_results/ner_comparison/field_comparison_summary.csv',
    patient_results_path='evaluation_results/ner_comparison/patient_comparison.csv',
    output_dir='evaluation_results/ner_comparison/plots'
):
    """Create all comparison plots from the saved result CSV files.

    Reads the per-field summary and the per-patient results, writes three
    field plots and three patient plots into ``output_dir`` (created if
    needed) and prints summary statistics.

    Returns:
        pathlib.Path: The directory the plots were written to.
    """
    rule = "=" * 80

    print("\n" + rule)
    print("VISUALIZING NER COMPARISON RESULTS")
    print(rule)

    # Make sure the target directory exists
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("\nLoading data...")
    field_df = pd.read_csv(field_summary_path)
    patient_df = pd.read_csv(patient_results_path)

    print(f" Loaded {len(field_df)} fields")
    print(f" Loaded {len(patient_df)} patients")

    # Plot functions paired with their output file names, in output order
    field_plots = (
        (plot_field_match_rates, '1_field_match_rates.png'),
        (plot_field_breakdown_stacked, '2_field_breakdown_stacked.png'),
        (plot_field_comparison_heatmap, '3_field_heatmap.png'),
    )
    patient_plots = (
        (plot_patient_accuracy_distribution, '4_patient_accuracy_distribution.png'),
        (plot_patient_accuracy_cumulative, '5_patient_accuracy_cumulative.png'),
        (plot_patient_performance_categories, '6_patient_performance_categories.png'),
    )

    print(f"\n{rule}")
    print("GENERATING FIELD VISUALIZATIONS")
    print(f"{rule}\n")
    for plot_fn, filename in field_plots:
        plot_fn(field_df, output_dir / filename)

    print(f"\n{rule}")
    print("GENERATING PATIENT VISUALIZATIONS")
    print(f"{rule}\n")
    for plot_fn, filename in patient_plots:
        plot_fn(patient_df, output_dir / filename)

    print(f"\n{rule}")
    print(" VISUALIZATION COMPLETE!")
    print(rule)

    print(f"\n All plots saved to: {output_dir}")

    # Field-level summary
    match_rates = field_df['match_rate']
    best_field = field_df.loc[match_rates.idxmax(), 'field']
    worst_field = field_df.loc[match_rates.idxmin(), 'field']

    print(f"\nField Performance:")
    print(f"  Best:  {best_field} ({match_rates.max()*100:.1f}%)")
    print(f"  Worst: {worst_field} ({match_rates.min()*100:.1f}%)")
    print(f"  Mean:  {match_rates.mean()*100:.1f}%")

    # Patient-level summary
    accuracy = patient_df['accuracy']

    print(f"\nPatient Performance:")
    print(f"  Mean:   {accuracy.mean()*100:.1f}%")
    print(f"  Median: {accuracy.median()*100:.1f}%")
    print(f"  Std:    {accuracy.std()*100:.1f}%")
    print(f"  Min:    {accuracy.min()*100:.1f}%")
    print(f"  Max:    {accuracy.max()*100:.1f}%")

    return output_dir


# Entry point: build all plots with the default paths and keep the plot
# directory returned by the function in `output_dir`.
if __name__ == "__main__":
    output_dir = visualize_ner_comparison()