# In [None]:
import os
print(os.getcwd())
# The project root below is specific to the original author's machine. Allow
# overriding it with the MIGRAINE_PROJECT_DIR environment variable so the
# relative 'data/...' and 'evaluation_results/...' paths used later resolve on
# other machines; the original path remains the default.
os.chdir(os.environ.get("MIGRAINE_PROJECT_DIR", "/Users/M1HR/Desktop/MIGRAINE"))

# In [None]:

import pandas as pd
import numpy as np
import json
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, confusion_matrix
)
import warnings
# Silence every warning for cleaner notebook output. Note that this also hides
# pandas / scikit-learn warnings (e.g. undefined-metric notices).
warnings.simplefilter('ignore')

# Plot defaults: white-grid seaborn theme and a default 14x8-inch size for
# figures created without an explicit figsize.
sns.set_style(style='whitegrid')
plt.rcParams.update({'figure.figsize': (14, 8)})


# =============================================================================
# LOAD DATA
# =============================================================================

def load_ground_truth(csv_path):
    """Load the ground-truth CSV and print its label distribution.

    Parameters
    ----------
    csv_path : str or path-like
        CSV file with one row per patient. Must contain a ``Type`` column
        holding the true diagnosis.

    Returns
    -------
    pandas.DataFrame
        The CSV contents. A ``Patient_ID`` column equal to the row index + 1
        is added when the file does not already provide one.

    Raises
    ------
    ValueError
        If the CSV has no ``Type`` column.
    """
    print(f"\n{'='*80}")
    print("LOADING GROUND TRUTH")
    print(f"{'='*80}")

    df = pd.read_csv(csv_path)
    print(f" Loaded {len(df)} patients")

    # Validate before mutating the frame so a malformed file fails fast.
    if 'Type' not in df.columns:
        raise ValueError("CSV must have 'Type' column")

    # Synthesize IDs from row position (1-based). Presumably these line up
    # with the prediction file's patient_id values -- confirm with producer.
    if 'Patient_ID' not in df.columns:
        df['Patient_ID'] = df.index + 1

    total = len(df)
    print("\nGround Truth Distribution:")
    for diag, count in df['Type'].value_counts().sort_index().items():
        # str() so numeric label codes don't break the 's' format spec.
        print(f"  {str(diag):40s}: {count:4d} ({count/total*100:5.1f}%)")

    return df


def load_symbolic_predictions(json_path):
    """Load symbolic-reasoner predictions from a JSON file.

    Parameters
    ----------
    json_path : str or path-like
        JSON file containing a list of result objects. Each object must have
        ``patient_id`` and ``diagnosis`` keys; ``code`` and ``confidence``
        are optional.

    Returns
    -------
    pandas.DataFrame
        Columns ``patient_id``, ``diagnosis``, ``code`` (default ``'N/A'``)
        and ``confidence`` (default ``'unknown'``), one row per result.
        An empty JSON list yields an empty frame with these same columns.

    Raises
    ------
    KeyError
        If a result object lacks ``patient_id`` or ``diagnosis``.
    """
    print(f"\n{'='*80}")
    print("LOADING SYMBOLIC PREDICTIONS")
    print(f"{'='*80}")

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        results = json.load(f)

    columns = ['patient_id', 'diagnosis', 'code', 'confidence']
    rows = [
        {
            'patient_id': result['patient_id'],
            'diagnosis': result['diagnosis'],
            'code': result.get('code', 'N/A'),
            'confidence': result.get('confidence', 'unknown'),
        }
        for result in results
    ]
    # Explicit columns keep the schema even when there are no predictions,
    # so the df['diagnosis'] lookup below cannot raise KeyError.
    df = pd.DataFrame(rows, columns=columns)

    print(f"✓ Loaded {len(df)} predictions")

    total = len(df)
    print("\nPredicted Distribution:")
    for diag, count in df['diagnosis'].value_counts().sort_index().items():
        # str() so non-string diagnoses don't break the 's' format spec.
        print(f"  {str(diag):40s}: {count:4d} ({count/total*100:5.1f}%)")

    return df


# Canonical spelling for each known label variant, keyed by the lower-cased,
# whitespace-collapsed form. Ground-truth and predicted labels both go through
# this table so they share one vocabulary. Built once at import time.
_LABEL_MAPPINGS = {
    # Ground truth variations
    'migraine without aura': 'Migraine without aura',
    'migraine with aura': 'Typical aura with headache',
    'migraine with typical aura': 'Typical aura with headache',
    'typical aura with migraine': 'Typical aura with headache',
    'typical aura with headache': 'Typical aura with headache',
    'typical aura without headache': 'Typical aura without headache',
    'typical aura without migraine': 'Typical aura without headache',
    'familial hemiplegic migraine': 'Familial hemiplegic migraine',
    'sporadic hemiplegic migraine': 'Sporadic hemiplegic migraine',
    'basilar-type migraine': 'Basilar-type aura',
    'basilar-type aura': 'Basilar-type aura',
    'basilar migraine': 'Basilar-type aura',
    'other': 'Other',
    'no diagnosis': 'Other',
}


def normalize_labels(label):
    """Map a raw diagnosis label onto the shared canonical vocabulary.

    Parameters
    ----------
    label : object
        Raw label from either the ground-truth ``Type`` column or the
        predicted ``diagnosis`` field. Non-string values are converted
        with ``str``.

    Returns
    -------
    str
        ``'Unknown'`` for missing (NaN/None) or blank labels; the canonical
        name for a known variant; otherwise the cleaned label in title case.
    """
    if pd.isna(label):
        return 'Unknown'

    # Lower-case and collapse runs of whitespace so stray double spaces or
    # tabs in either source don't create spurious classes.
    key = ' '.join(str(label).split()).lower()
    if not key:
        return 'Unknown'

    return _LABEL_MAPPINGS.get(key, key.title())


# =============================================================================
# MERGE AND EVALUATE
# =============================================================================

def merge_and_evaluate(ground_truth_df, predictions_df):
    """Join ground truth with predictions and compute classification metrics.

    Rows are matched on ``Patient_ID`` (ground truth) == ``patient_id``
    (predictions) with an inner join. Patients present on only one side are
    counted and reported, then excluded from scoring. Both label columns are
    passed through ``normalize_labels`` before scoring.

    Parameters
    ----------
    ground_truth_df : pandas.DataFrame
        Must contain ``Patient_ID`` and ``Type``.
    predictions_df : pandas.DataFrame
        Must contain ``patient_id`` and ``diagnosis``.

    Returns
    -------
    tuple[pandas.DataFrame, dict[str, float]]
        The merged frame with added ``true_diagnosis`` and
        ``predicted_diagnosis`` columns, and a dict with accuracy plus
        macro/weighted precision, recall and F1.

    Raises
    ------
    ValueError
        If no patient IDs match between the two inputs, since every metric
        would then be undefined.
    """
    print(f"\n{'='*80}")
    print("MERGING DATA")
    print(f"{'='*80}")

    merged = ground_truth_df.merge(
        predictions_df,
        left_on='Patient_ID',
        right_on='patient_id',
        how='inner'
    )

    print(f" Merged {len(merged)} patients")

    # The inner join silently drops unmatched patients; surface how many so a
    # partial prediction file doesn't masquerade as a full evaluation.
    gt_ids = set(ground_truth_df['Patient_ID'])
    pred_ids = set(predictions_df['patient_id'])
    missing_preds = len(gt_ids - pred_ids)
    orphan_preds = len(pred_ids - gt_ids)
    if missing_preds:
        print(f"  Warning: {missing_preds} ground-truth patients have no prediction")
    if orphan_preds:
        print(f"  Warning: {orphan_preds} predictions have no ground-truth patient")

    if merged.empty:
        raise ValueError(
            "No patients matched between ground truth ('Patient_ID') and "
            "predictions ('patient_id'); check that the ID values agree."
        )

    # Normalize labels so both sides use the same vocabulary.
    merged['true_diagnosis'] = merged['Type'].apply(normalize_labels)
    merged['predicted_diagnosis'] = merged['diagnosis'].apply(normalize_labels)

    y_true = merged['true_diagnosis']
    y_pred = merged['predicted_diagnosis']

    # float() turns numpy scalars into plain Python floats for clean JSON output.
    metrics = {
        'accuracy': float(accuracy_score(y_true, y_pred)),
        'precision_macro': float(precision_score(y_true, y_pred, average='macro', zero_division=0)),
        'precision_weighted': float(precision_score(y_true, y_pred, average='weighted', zero_division=0)),
        'recall_macro': float(recall_score(y_true, y_pred, average='macro', zero_division=0)),
        'recall_weighted': float(recall_score(y_true, y_pred, average='weighted', zero_division=0)),
        'f1_macro': float(f1_score(y_true, y_pred, average='macro', zero_division=0)),
        'f1_weighted': float(f1_score(y_true, y_pred, average='weighted', zero_division=0)),
    }

    print(f"\n{'='*80}")
    print("METRICS")
    print(f"{'='*80}")
    print(f"\n  Accuracy:             {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.1f}%)")
    print(f"  Precision (macro):    {metrics['precision_macro']:.4f}")
    print(f"  Precision (weighted): {metrics['precision_weighted']:.4f}")
    print(f"  Recall (macro):       {metrics['recall_macro']:.4f}")
    print(f"  Recall (weighted):    {metrics['recall_weighted']:.4f}")
    print(f"  F1 (macro):           {metrics['f1_macro']:.4f}")
    print(f"  F1 (weighted):        {metrics['f1_weighted']:.4f}")

    # Per-class precision/recall/F1 breakdown.
    print(f"\n{'='*80}")
    print("CLASSIFICATION REPORT")
    print(f"{'='*80}\n")
    print(classification_report(y_true, y_pred, zero_division=0))

    return merged, metrics


# =============================================================================
# ERROR ANALYSIS
# =============================================================================

def analyze_errors(merged_df, top_n=10):
    """Print a summary of misclassifications and the most common confusions.

    Parameters
    ----------
    merged_df : pandas.DataFrame
        Must contain ``true_diagnosis`` and ``predicted_diagnosis`` columns
        (as produced by ``merge_and_evaluate``).
    top_n : int, optional
        Number of (true, predicted) error pairs to list, most frequent first.

    Returns
    -------
    None
        All output is printed.
    """
    print(f"\n{'='*80}")
    print("ERROR ANALYSIS")
    print(f"{'='*80}")

    is_error = merged_df['true_diagnosis'] != merged_df['predicted_diagnosis']
    errors = merged_df[is_error]

    total = len(merged_df)
    error_count = len(errors)

    print(f"\nTotal: {total} patients")
    # Guard: every percentage below divides by total.
    if total == 0:
        print("   Nothing to analyze.")
        return

    print(f"   Correct: {total - error_count} ({(total-error_count)/total*100:.1f}%)")
    print(f"   Errors:  {error_count} ({error_count/total*100:.1f}%)")

    if error_count == 0:
        print("\n Perfect accuracy!")
        return

    print(f"\n Top {top_n} Error Patterns:")
    print(f"{'─'*80}")

    error_patterns = (
        errors.groupby(['true_diagnosis', 'predicted_diagnosis'])
        .size()
        .sort_values(ascending=False)
        .head(top_n)
    )

    for (true_label, pred_label), count in error_patterns.items():
        pct = count / total * 100
        print(f"  {true_label:35s} → {pred_label:35s}: {count:3d} ({pct:4.1f}%)")


# =============================================================================
# VISUALIZATIONS
# =============================================================================

def plot_confusion_matrix(merged_df, output_path):
    """Save a row-normalized confusion-matrix heatmap.

    Parameters
    ----------
    merged_df : pandas.DataFrame
        Must contain ``true_diagnosis`` and ``predicted_diagnosis`` columns.
    output_path : str or path-like
        Destination image file (written at 300 dpi).

    Notes
    -----
    Both axes use the sorted union of true and predicted labels, so a class
    that appears only among the predictions still gets its own column. Each
    cell is the fraction of that row's true class; rows for classes with no
    true samples are drawn as 0 instead of NaN.
    """
    y_true = merged_df['true_diagnosis']
    y_pred = merged_df['predicted_diagnosis']

    # Pin the label order explicitly so matrix rows/columns match the tick
    # labels; deriving ticks from y_true alone misaligns them whenever the
    # predictions contain a label the ground truth does not.
    labels = sorted(set(y_true) | set(y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    # Row-normalize, leaving zero-support rows at 0 rather than dividing by 0.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(
        cm, row_sums,
        out=np.zeros(cm.shape, dtype=float),
        where=row_sums != 0,
    )

    fig, ax = plt.subplots(figsize=(12, 10))
    try:
        sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='inferno',
                    xticklabels=labels, yticklabels=labels,
                    ax=ax, cbar_kws={'label': 'Proportion'})

        ax.set_title('Symbolic Reasoning - Confusion Matrix', fontsize=14, fontweight='bold')
        ax.set_xlabel('Predicted Diagnosis', fontsize=12)
        ax.set_ylabel('True Diagnosis', fontsize=12)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f" Saved: {output_path}")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def plot_metrics(metrics, output_path):
    """Save a bar chart of accuracy and weighted precision/recall/F1.

    Bars are colored by score band using Okabe-Ito colors: orange for
    scores >= 0.8, sky blue for >= 0.6, and bluish green below that.
    Dashed reference lines mark the 0.8 and 0.6 thresholds.

    Parameters
    ----------
    metrics : dict
        Must provide ``accuracy``, ``precision_weighted``,
        ``recall_weighted`` and ``f1_weighted``.
    output_path : str or path-like
        Destination image file (written at 300 dpi).
    """
    orange, sky_blue, bluish_green = '#E69F00', '#56B4E9', '#009E73'

    # (bar label, metrics key), in display order.
    panels = [
        ('Accuracy', 'accuracy'),
        ('Precision\n(weighted)', 'precision_weighted'),
        ('Recall\n(weighted)', 'recall_weighted'),
        ('F1\n(weighted)', 'f1_weighted'),
    ]
    names = [label for label, _ in panels]
    scores = [metrics[key] for _, key in panels]

    def band_color(score):
        if score >= 0.8:
            return orange        # Excellent
        if score >= 0.6:
            return sky_blue      # Good
        return bluish_green      # Fair

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(names, scores, color=[band_color(s) for s in scores],
                  alpha=0.8, edgecolor='black', linewidth=1.5)

    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Symbolic Reasoning Performance', fontsize=14, fontweight='bold')
    ax.set_ylim([0, 1.1])
    ax.grid(axis='y', alpha=0.3)

    # Numeric value printed on top of each bar.
    for rect, score in zip(bars, scores):
        ax.text(rect.get_x() + rect.get_width() / 2., rect.get_height(),
                f'{score:.3f}', ha='center', va='bottom',
                fontsize=11, fontweight='bold')

    ax.axhline(y=0.8, color=orange, linestyle='--', alpha=0.4, label='Excellent (≥0.8)')
    ax.axhline(y=0.6, color=sky_blue, linestyle='--', alpha=0.4, label='Good (≥0.6)')
    ax.legend(loc='lower right')

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved: {output_path}")
    plt.close(fig)


# =============================================================================
# SAVE RESULTS
# =============================================================================

def save_results(merged_df, metrics, output_dir):
    """Write evaluation artifacts to ``output_dir``, creating it if needed.

    Files written:
      * ``evaluation_results.csv``: every merged row;
      * ``error_cases.csv``: only rows whose true and predicted diagnoses
        differ (not written when there are none);
      * ``metrics.json``: the ``metrics`` dict, indented by 2 spaces.

    Parameters
    ----------
    merged_df : pandas.DataFrame
        Must contain ``true_diagnosis`` and ``predicted_diagnosis``.
    metrics : dict
        JSON-serializable metric values.
    output_dir : str or path-like
        Target directory.
    """
    target = Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)

    rule = '=' * 80
    print(f"\n{rule}\nSAVING RESULTS\n{rule}")

    all_rows_file = target / 'evaluation_results.csv'
    merged_df.to_csv(all_rows_file, index=False)
    print(f"✓ Results: {all_rows_file}")

    mismatch = merged_df['true_diagnosis'] != merged_df['predicted_diagnosis']
    mismatched = merged_df.loc[mismatch]
    if not mismatched.empty:
        errors_file = target / 'error_cases.csv'
        mismatched.to_csv(errors_file, index=False)
        print(f"✓ Errors: {errors_file}")

    metrics_file = target / 'metrics.json'
    with open(metrics_file, 'w') as fh:
        json.dump(metrics, fh, indent=2)
    print(f"✓ Metrics: {metrics_file}")

    print(f"\n✓ All results saved to: {target}")


# =============================================================================
# MAIN
# =============================================================================

def evaluate_symbolic(
    ground_truth_path='data/migraine_with_id.csv',
    predictions_path='data/diagnoses/ichd3_diagnoses_final.json',
    output_dir='evaluation_results/symbolic'
):
    """Run the full symbolic-reasoning evaluation pipeline.

    Steps:
      1. Load the ground truth and the predictions.
      2. Merge and score them.
      3. Print an error analysis.
      4. Save the confusion-matrix and metric plots.
      5. Write the CSV/JSON artifacts.
      6. Print a final summary.

    Parameters
    ----------
    ground_truth_path : str or path-like
        CSV handed to ``load_ground_truth``.
    predictions_path : str or path-like
        JSON handed to ``load_symbolic_predictions``.
    output_dir : str or path-like
        Directory for plots and result files; created if missing.

    Returns
    -------
    tuple
        ``(metrics, merged)``: the metrics dict and the merged DataFrame
        produced by ``merge_and_evaluate``.
    """
    rule = "=" * 80
    print(f"\n{rule}\nSYMBOLIC REASONING EVALUATION\n{rule}")

    # Arguments are evaluated left to right: ground truth loads first.
    merged, metrics = merge_and_evaluate(
        load_ground_truth(ground_truth_path),
        load_symbolic_predictions(predictions_path),
    )

    analyze_errors(merged)

    print(f"\n{rule}\nGENERATING VISUALIZATIONS\n{rule}\n")

    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    plot_confusion_matrix(merged, out_dir / 'confusion_matrix.png')
    plot_metrics(metrics, out_dir / 'metrics.png')

    save_results(merged, metrics, out_dir)

    print(f"\n{rule}\n EVALUATION COMPLETE!\n{rule}")

    acc = metrics['accuracy']
    summary_lines = [
        "\n FINAL SUMMARY:",
        f"  Total Patients:  {len(merged)}",
        f"  Accuracy:        {acc:.4f} ({acc*100:.1f}%)",
        f"  F1 (weighted):   {metrics['f1_weighted']:.4f}",
        f"  Precision:       {metrics['precision_weighted']:.4f}",
        f"  Recall:          {metrics['recall_weighted']:.4f}",
        f"\n Results: {out_dir}",
    ]
    print("\n".join(summary_lines))

    return metrics, merged


if __name__ == "__main__":
    # Run the whole pipeline with the default paths. If this runs inside a
    # notebook kernel (the file carries notebook cell markers), __name__ is
    # also "__main__". The cell then runs the pipeline too and leaves
    # `metrics` and `results` as globals for later cells.
    metrics, results = evaluate_symbolic()