In [None]:
# CXR Report Evaluation - Modular Architecture
# Updated to use the new modular metrics system

import pandas as pd
import numpy as np
import os
import json
import time
import pickle
import hashlib
import warnings
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, clear_output

# Import the new modular metrics
from CXRMetric.metrics.rouge import ROUGEEvaluator
from CXRMetric.metrics.bleu import BLEUEvaluator
from CXRMetric.metrics.bertscore import BERTScoreEvaluator
from CXRMetric.metrics.perplexity import PerplexityEvaluator
from CXRMetric.metrics.composite import CompositeEvaluator
from CXRMetric.metrics.semantic_embedding import SemanticEmbeddingEvaluator

# Import configuration
import sys
sys.path.append('..')
from config import *

# Notebook-wide setup: use seaborn's whitegrid theme for every figure and
# silence all Python warnings so library noise does not clutter cell output.
sns.set(style='whitegrid')
warnings.filterwarnings('ignore')

# Startup banner listing the metric families this notebook can compute.
# It is emitted with one print call; the text on stdout is the same as
# printing each line separately.
print("\n".join([
    "🚀 CXR Report Evaluation - Modular Architecture",
    "=" * 60,
    "Available metrics:",
    "  ✅ ROUGE-L (lightweight)",
    "  ✅ BLEU-4 (lightweight)",
    "  ✅ BERTScore (moderate GPU usage)",
    "  ✅ Perplexity (GPU-accelerated)",
    "  ✅ Composite RadCliQ (v0, v1)",
    "  ✅ Semantic Embedding (CheXbert)",
    "=" * 60,
]))

# Registry of active evaluators keyed by metric name. It starts empty and is
# rebuilt once the user's options are known.
evaluators = dict()

def initialize_evaluators(config_options):
    """Rebuild the module-level ``evaluators`` registry from user options.

    ROUGE and BLEU evaluators are always created. Each remaining evaluator is
    created only when its ``enable_*`` option is truthy; when an option is
    absent, BERTScore and composite default to on, perplexity and semantic
    default to off.

    Args:
        config_options: Mapping of option names to values. Keys read here:
            ``enable_bertscore``, ``use_idf``, ``enable_perplexity``,
            ``perplexity_model``, ``enable_composite``, ``enable_semantic``.

    Returns:
        The registry dict (the same object bound to the global
        ``evaluators``), keyed by metric name in creation order.
    """
    global evaluators
    option = config_options.get

    # Rebind the global first so stale evaluators never survive a re-run.
    evaluators = registry = {}

    # Lightweight metrics: created unconditionally.
    registry['rouge'] = ROUGEEvaluator(beta=1.2)
    registry['bleu'] = BLEUEvaluator()

    # Optional metrics, in the order they will later be evaluated.
    if option('enable_bertscore', True):
        idf_weighting = option('use_idf', False)
        registry['bertscore'] = BERTScoreEvaluator(use_idf=idf_weighting)

    if option('enable_perplexity', False):
        registry['perplexity'] = PerplexityEvaluator(
            model_name=option('perplexity_model', 'distilgpt2')
        )

    if option('enable_composite', True):
        registry['composite'] = CompositeEvaluator()

    if option('enable_semantic', False):
        registry['semantic'] = SemanticEmbeddingEvaluator()

    print(f"📊 Initialized {len(evaluators)} metric evaluators")
    return evaluators

def get_cache_key(input_paths, options):
    """Return an MD5 hex digest identifying a set of inputs and options.

    The digest is built from ``path:mtime`` for every path in ``input_paths``
    that currently exists (missing paths contribute nothing), followed by the
    string form of the options sorted by key.

    Args:
        input_paths: Iterable of file paths feeding the evaluation.
        options: Dict of evaluation options; its items must be sortable.

    Returns:
        A 32-character hexadecimal string.
    """
    stamped_paths = [
        f"{candidate}:{os.path.getmtime(candidate)}"
        for candidate in input_paths
        if os.path.exists(candidate)
    ]
    option_repr = str(sorted(options.items()))
    fingerprint = "|".join(stamped_paths) + "|" + option_repr
    return hashlib.md5(fingerprint.encode()).hexdigest()

def load_cached_result(cache_key, cache_dir="cache/results"):
    """Return a previously cached evaluation, or ``None`` when unavailable.

    Looks for ``<cache_dir>/<cache_key>.pkl`` and unpickles it.

    Args:
        cache_key: Key identifying the cached evaluation.
        cache_dir: Directory that holds the cache files.

    Returns:
        ``(pred_df, summary)`` taken from the cached payload, or ``None`` if
        the file does not exist or could not be read (any error is printed,
        not raised).
    """
    cache_path = os.path.join(cache_dir, f"{cache_key}.pkl")
    if not os.path.exists(cache_path):
        return None

    try:
        with open(cache_path, 'rb') as handle:
            payload = pickle.load(handle)
        print(f"✓ Loaded cached result from {cache_path}")
        return payload['pred_df'], payload['summary']
    except Exception as err:
        # A damaged or incompatible cache file is treated as a cache miss.
        print(f"Failed to load cache: {err}")
        return None

def save_cached_result(cache_key, pred_df, summary, cache_dir="cache/results"):
    """Pickle an evaluation result to ``<cache_dir>/<cache_key>.pkl``.

    The payload stores ``pred_df``, ``summary`` and the save time under the
    keys ``'pred_df'``, ``'summary'`` and ``'timestamp'``. Write failures are
    printed rather than raised.

    Args:
        cache_key: Key identifying this evaluation; used as the file stem.
        pred_df: Per-report results to cache.
        summary: Summary object to cache alongside the results.
        cache_dir: Directory for cache files; created if it does not exist.
    """
    os.makedirs(cache_dir, exist_ok=True)
    # The file must live inside cache_dir: that is the directory created
    # above and the one load_cached_result reads from.
    cache_path = os.path.join(cache_dir, f"{cache_key}.pkl")
    try:
        cached_data = {
            'pred_df': pred_df,
            'summary': summary,
            'timestamp': time.time()
        }
        with open(cache_path, 'wb') as f:
            pickle.dump(cached_data, f)
        print(f"✓ Saved result to cache: {cache_path}")
    except Exception as e:
        # Caching is best-effort; a failed write must not discard results.
        print(f"Failed to save cache: {e}")

# Columns (and fill values) written for a metric whose evaluator raised, so
# later cells and plots still find every column they expect.
_FAILED_METRIC_PLACEHOLDERS = {
    'rouge': {'rouge_l': 0.0},
    'bleu': {'bleu4_score': 0.0},
    'bertscore': {'bertscore_f1': 0.0},
    'perplexity': {'perplexity_generated': 100.0, 'perplexity_reference': 100.0},
    'composite': {'radcliq_v0': 0.0, 'radcliq_v1': 0.0},
    'semantic': {'semantic_similarity': 0.0},
}


def _write_evaluation_outputs(results_df, summary, out_csv):
    """Write the per-report CSV and its ``.summary.json`` sidecar to disk."""
    # The output folder may not exist yet; create it so a finished
    # evaluation is never lost to a missing directory.
    out_dir = os.path.dirname(out_csv)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    results_df.to_csv(out_csv, index=False)
    with open(out_csv + '.summary.json', 'w') as f:
        json.dump(summary, f, indent=2)


def run_modular_evaluation(gt_csv, pred_csv, out_csv, config_options):
    """Evaluate predicted reports against ground truth with the enabled metrics.

    Both CSVs are read with pandas, sorted by ``STUDY_ID_COL_NAME`` and
    restricted to study ids present in both files. Every evaluator returned
    by ``initialize_evaluators`` is then run; each result column is added to
    a copy of the prediction frame and its NaN-aware mean is recorded in the
    summary. A metric that raises is reported and replaced by placeholder
    columns so the remaining metrics still run.

    The results CSV and ``<out_csv>.summary.json`` are written on every call,
    including cache hits (the cache key does not depend on ``out_csv``).
    Results are cached only when every metric succeeded, so placeholder
    values from a transient failure are never replayed from the cache.

    Args:
        gt_csv: Path to the ground-truth CSV.
        pred_csv: Path to the predictions CSV.
        out_csv: Destination path for the per-report results CSV; missing
            parent directories are created.
        config_options: Dict of evaluation options; passed to
            ``initialize_evaluators`` and folded into the cache key.

    Returns:
        ``(results_df, summary)`` where ``summary['mean_metrics']`` maps each
        computed column name to its mean.
    """
    print("🚀 Starting modular evaluation pipeline...")
    start_time = time.time()

    # Create cache key
    cache_key = get_cache_key([gt_csv, pred_csv], config_options)

    # Try cache first
    cached_result = load_cached_result(cache_key)
    if cached_result is not None:
        pred_df, summary = cached_result
        # Still honour the requested output path: the caller reports the
        # results as saved there, and it may differ from the cached run's.
        print("💾 Saving results...")
        _write_evaluation_outputs(pred_df, summary, out_csv)
        print(f"⚡ Evaluation completed in {time.time() - start_time:.1f}s (from cache)")
        return pred_df, summary

    # Initialize evaluators based on configuration
    active_evaluators = initialize_evaluators(config_options)

    # Load and align datasets
    print("📊 Loading and aligning datasets...")
    gt = pd.read_csv(gt_csv).sort_values(by=[STUDY_ID_COL_NAME]).reset_index(drop=True)
    pred = pd.read_csv(pred_csv).sort_values(by=[STUDY_ID_COL_NAME]).reset_index(drop=True)

    # Keep only the studies that appear in both files
    shared_ids = set(gt[STUDY_ID_COL_NAME]).intersection(set(pred[STUDY_ID_COL_NAME]))
    print(f"Found {len(shared_ids)} shared study IDs")

    gt = gt[gt[STUDY_ID_COL_NAME].isin(shared_ids)].reset_index(drop=True)
    pred = pred[pred[STUDY_ID_COL_NAME].isin(shared_ids)].reset_index(drop=True)

    # Initialize results dataframe
    results_df = pred.copy()
    summary = {'mean_metrics': {}}
    failed_metrics = []

    # Run each metric evaluation
    for metric_name, evaluator in active_evaluators.items():
        try:
            print(f"🔄 Computing {metric_name} metrics...")
            metric_start = time.time()

            # NOTE(review): compute_metric is assumed to return a mapping of
            # column name -> per-report values; confirm against the evaluators.
            metric_results = evaluator.compute_metric(gt, pred)

            # Add results to main dataframe
            for col_name, values in metric_results.items():
                results_df[col_name] = values
                summary['mean_metrics'][col_name] = float(np.nanmean(values))

            metric_time = time.time() - metric_start
            print(f"  ✅ {metric_name} completed in {metric_time:.2f}s")

        except Exception as e:
            print(f"  ❌ {metric_name} failed: {str(e)}")
            failed_metrics.append(metric_name)
            # Add placeholder columns for failed metrics
            placeholders = _FAILED_METRIC_PLACEHOLDERS.get(metric_name, {})
            for col_name, fill_value in placeholders.items():
                results_df[col_name] = [fill_value] * len(results_df)

    # Save results
    print("💾 Saving results...")
    _write_evaluation_outputs(results_df, summary, out_csv)

    # Cache only complete runs so a later run retries any failed metric
    if failed_metrics:
        print(f"⚠️ Not caching results because these metrics failed: {', '.join(failed_metrics)}")
    else:
        save_cached_result(cache_key, results_df, summary)

    elapsed_time = time.time() - start_time
    print(f"✅ Modular evaluation completed in {elapsed_time:.1f}s")

    return results_df, summary

# CXR Report Evaluation - Modular Metrics Notebook

This notebook provides an interactive interface for evaluating CXR report generation using the **new modular metrics architecture**.

## ✨ New Modular Architecture Features

### 🎯 Available Metrics
- **ROUGE-L**: Content coverage and paraphrasing (pure Python, fastest)
- **BLEU-4**: Precision-focused n-gram matching (fast)  
- **BERTScore**: Semantic similarity with transformers (moderate speed)
- **Perplexity**: Text fluency with GPT models (GPU-accelerated)
- **Composite RadCliQ**: Clinical quality assessment (v0, v1)
- **Semantic Embedding**: Medical-specific similarity with CheXbert

### 🏗️ Modular Benefits
- **Flexible**: Enable/disable individual metrics as needed
- **Cacheable**: Intelligent caching for faster repeated evaluations
- **GPU-Optimized**: Automatic GPU detection and acceleration
- **Consistent Interface**: All metrics follow the same API pattern
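- **Error Resilient**: A metric that fails does not abort the evaluation; its result columns are filled with placeholder values (0.0, or 100.0 for perplexity) and the remaining metrics still run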

## 📋 Prerequisites

### 1. Install Dependencies
```bash
pip install -r requirements.txt
```

### 2. Configure Paths (Optional)
Update `config.py` if using:
- CheXbert model path (for semantic embedding)
- Custom model checkpoints
- GPU vs CPU preferences

### 3. GPU Support (Recommended)
- For Perplexity metrics: PyTorch with CUDA
- For BERTScore: GPU acceleration available
- For Semantic Embedding: CheXbert model benefits from GPU

## 🚀 Quick Start
1. **Configure file paths** in the interface below
2. **Select desired metrics** (lightweight metrics are always enabled)
3. **Enable GPU metrics** if you have GPU support
4. **Run evaluation** and explore results with built-in visualizations

## 📊 Output Files
- **Results CSV**: Per-report metric scores
- **Summary JSON**: Mean value of each computed metric, written next to the results as `<output CSV>.summary.json`
- **Cache files**: For faster repeated runs

---
*Updated for modular architecture - September 2025*

In [None]:
# Interactive Modular Evaluation Interface
import ipywidgets as widgets
from IPython.display import display, clear_output

# File path configuration
# Free-text inputs; their .value is read when the evaluation button is clicked.
gt_csv_widget = widgets.Text(
    value='path/to/ground_truth.csv',
    placeholder='Enter path to ground truth CSV',
    description='GT CSV:',
    style={'description_width': '120px'}
)

pred_csv_widget = widgets.Text(
    value='path/to/predictions.csv', 
    placeholder='Enter path to predictions CSV',
    description='Pred CSV:',
    style={'description_width': '120px'}
)

out_csv_widget = widgets.Text(
    value='results/modular_evaluation_results.csv',
    placeholder='Enter output path for results',
    description='Output CSV:',
    style={'description_width': '120px'}
)

# Modular metric configuration
# The ROUGE and BLEU checkboxes are informational only: they are created
# checked and disabled, and their values are not part of the options passed
# to the evaluation.
enable_rouge_widget = widgets.Checkbox(
    value=True,
    description='✅ ROUGE-L (Always enabled - lightweight)',
    disabled=True,
    style={'description_width': 'initial'}
)

enable_bleu_widget = widgets.Checkbox(
    value=True,
    description='✅ BLEU-4 (Always enabled - lightweight)', 
    disabled=True,
    style={'description_width': 'initial'}
)

# Optional metrics: each checkbox feeds one 'enable_*' evaluation option.
enable_bertscore_widget = widgets.Checkbox(
    value=True,
    description='🤖 BERTScore (Transformer-based semantic similarity)',
    style={'description_width': 'initial'}
)

# Sub-option of BERTScore; forwarded as the 'use_idf' option.
use_idf_widget = widgets.Checkbox(
    value=False,
    description='    ↳ Use IDF weighting (slightly slower but better quality)',
    style={'description_width': 'initial'}
)

enable_perplexity_widget = widgets.Checkbox(
    value=False,
    description='🔥 Perplexity (GPU-accelerated text fluency)',
    style={'description_width': 'initial'}
)

# Sub-option of Perplexity; forwarded as the 'perplexity_model' option.
perplexity_model_widget = widgets.Dropdown(
    options=['distilgpt2', 'gpt2'],
    value='distilgpt2',
    description='    ↳ Model:',
    style={'description_width': '80px'}
)

enable_composite_widget = widgets.Checkbox(
    value=True,
    description='🏥 Composite RadCliQ (Clinical quality assessment)',
    style={'description_width': 'initial'}
)

enable_semantic_widget = widgets.Checkbox(
    value=False,
    description='🧠 Semantic Embedding (CheXbert-based medical similarity)',
    style={'description_width': 'initial'}
)

# Visualization options
# Column names offered for plotting. The plot handlers keep only those that
# actually exist in the evaluation results.
available_viz_metrics = ['rouge_l', 'bleu4_score', 'bertscore_f1', 'perplexity_generated', 
                        'perplexity_reference', 'radcliq_v0', 'radcliq_v1', 'semantic_similarity']

viz_metrics_widget = widgets.SelectMultiple(
    options=available_viz_metrics,
    value=['rouge_l', 'bleu4_score', 'bertscore_f1'],
    description='Visualization metrics:',
    style={'description_width': 'initial'}
)

# Slider for the number of reports to plot (100 to 5000 in steps of 100).
sample_size_widget = widgets.IntSlider(
    value=1000,
    min=100,
    max=5000,
    step=100,
    description='Sample size for plots:',
    style={'description_width': '150px'}
)

# Output area
# Shared output widget: every button handler prints and plots inside it.
output = widgets.Output()

# Global variables for results
# Both stay None until an evaluation has completed; the plot handlers use
# that to detect that no evaluation has been run yet.
evaluation_results = None
evaluation_summary = None

def run_modular_evaluation_click(b):
    """Button callback: run the evaluation with the current widget settings.

    Clears the output area, collects the metric options from the widgets,
    runs the evaluation, stores the outcome in the notebook globals
    ``evaluation_results`` / ``evaluation_summary`` and prints the mean of
    every metric. Any exception is printed with its traceback instead of
    being raised.

    Args:
        b: The clicked button (unused).
    """
    global evaluation_results, evaluation_summary

    # Perplexity means are shown with 2 decimals, all other metrics with 4.
    two_decimal_metrics = ('perplexity_generated', 'perplexity_reference')

    with output:
        clear_output()
        print("🔄 Running modular evaluation...")

        try:
            # Snapshot the widget state into the options dict.
            config_options = dict(
                enable_bertscore=enable_bertscore_widget.value,
                use_idf=use_idf_widget.value,
                enable_perplexity=enable_perplexity_widget.value,
                perplexity_model=perplexity_model_widget.value,
                enable_composite=enable_composite_widget.value,
                enable_semantic=enable_semantic_widget.value,
            )

            outcome = run_modular_evaluation(
                gt_csv=gt_csv_widget.value,
                pred_csv=pred_csv_widget.value,
                out_csv=out_csv_widget.value,
                config_options=config_options
            )
            evaluation_results, evaluation_summary = outcome

            # Display summary metrics
            print("\n📊 Evaluation Results Summary:")
            print("-" * 40)
            if 'mean_metrics' in evaluation_summary:
                metric_means = evaluation_summary['mean_metrics']
                for metric, value in metric_means.items():
                    decimals = 2 if metric in two_decimal_metrics else 4
                    print(f"  {metric}: {value:.{decimals}f}")

            print(f"\n📋 Total reports evaluated: {len(evaluation_results)}")
            print(f"📄 Results saved to: {out_csv_widget.value}")
            print("\n✅ Modular evaluation complete! Use visualization buttons below.")

        except Exception as err:
            import traceback
            print(f"❌ Error during evaluation: {str(err)}")
            print(traceback.format_exc())

# Visualization functions (updated for modular metrics)
def plot_distributions_click(b):
    """Button callback: plot histograms of the selected metrics.

    Prints a message and does nothing when no evaluation has been run or when
    none of the selected metrics exist in the results. When the results hold
    more rows than the "Sample size for plots" slider allows, a reproducible
    random sample of that size is plotted instead of the full table.

    Args:
        b: The clicked button (unused).
    """
    with output:
        if evaluation_results is None:
            print("❌ Please run evaluation first!")
            return

        selected_metrics = [m for m in viz_metrics_widget.value if m in evaluation_results.columns]
        if not selected_metrics:
            print("❌ No valid metrics selected for visualization!")
            return

        # Honour the sample-size slider; a fixed seed keeps repeated clicks
        # on the same results showing the same plot.
        plot_df = evaluation_results
        max_rows = sample_size_widget.value
        if len(plot_df) > max_rows:
            print(f"ℹ️ Plotting a random sample of {max_rows} of {len(plot_df)} reports")
            plot_df = plot_df.sample(n=max_rows, random_state=0)

        print("📈 Plotting metric distributions...")
        plot_metric_distributions(plot_df, selected_metrics)

def plot_boxplots_click(b):
    """Button callback: draw boxplots of the selected metrics.

    Prints a message and does nothing when no evaluation has been run or when
    none of the selected metrics exist in the results. When the results hold
    more rows than the "Sample size for plots" slider allows, a reproducible
    random sample of that size is plotted instead of the full table.

    Args:
        b: The clicked button (unused).
    """
    with output:
        if evaluation_results is None:
            print("❌ Please run evaluation first!")
            return

        selected_metrics = [m for m in viz_metrics_widget.value if m in evaluation_results.columns]
        if not selected_metrics:
            print("❌ No valid metrics selected for visualization!")
            return

        # Honour the sample-size slider; a fixed seed keeps repeated clicks
        # on the same results showing the same plot.
        plot_df = evaluation_results
        max_rows = sample_size_widget.value
        if len(plot_df) > max_rows:
            print(f"ℹ️ Plotting a random sample of {max_rows} of {len(plot_df)} reports")
            plot_df = plot_df.sample(n=max_rows, random_state=0)

        print("📦 Creating boxplots...")
        plot_metric_boxplots(plot_df, selected_metrics)

def plot_correlations_click(b):
    """Button callback: draw a correlation heatmap of the selected metrics.

    Prints a message and does nothing when no evaluation has been run or when
    fewer than two of the selected metrics exist in the results. When the
    results hold more rows than the "Sample size for plots" slider allows, a
    reproducible random sample of that size is used instead of the full table.

    Args:
        b: The clicked button (unused).
    """
    with output:
        if evaluation_results is None:
            print("❌ Please run evaluation first!")
            return

        selected_metrics = [m for m in viz_metrics_widget.value if m in evaluation_results.columns]
        if len(selected_metrics) < 2:
            print("❌ Need at least 2 metrics for correlation analysis!")
            return

        # Honour the sample-size slider; a fixed seed keeps repeated clicks
        # on the same results showing the same plot.
        plot_df = evaluation_results
        max_rows = sample_size_widget.value
        if len(plot_df) > max_rows:
            print(f"ℹ️ Plotting a random sample of {max_rows} of {len(plot_df)} reports")
            plot_df = plot_df.sample(n=max_rows, random_state=0)

        print("🔗 Plotting correlation heatmap...")
        plot_correlation_heatmap(plot_df, selected_metrics)

def show_top_bottom_click(b):
    """Button callback: list the best and worst reports per selected metric.

    Only the first three selected metrics that exist in the results are
    analysed, with three reports shown at each end. Prints a message and does
    nothing when no evaluation has been run or no selected metric is present.

    Args:
        b: The clicked button (unused).
    """
    with output:
        if evaluation_results is None:
            print("❌ Please run evaluation first!")
            return

        known_columns = evaluation_results.columns
        chosen = [name for name in viz_metrics_widget.value if name in known_columns]
        if len(chosen) == 0:
            print("❌ No valid metrics selected!")
            return

        # Cap the output at three metrics.
        for name in chosen[:3]:
            print(f"\n📋 Analysis for {name}:")
            show_top_bottom_reports(evaluation_results, name, k=3)

# Visualization helper functions (simplified versions)
def plot_metric_distributions(df, metrics, bins=30, figsize=(12, 8)):
    """Plot one histogram (with KDE) per metric on a grid of up to 3 columns.

    Each panel is titled with the metric name and annotated with the mean and
    standard deviation of its non-null values. A metric with no non-null
    values leaves its panel empty; unused grid cells are hidden.

    Args:
        df: DataFrame holding one column per metric.
        metrics: Sequence of column names to plot. An empty sequence draws
            nothing.
        bins: Number of histogram bins.
        figsize: Figure size passed to matplotlib.
    """
    n_metrics = len(metrics)
    if n_metrics == 0:
        # Nothing to draw; also avoids dividing by a zero column count.
        return

    n_cols = min(3, n_metrics)
    n_rows = (n_metrics + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array of Axes, whatever the grid
    # shape, so a single row-major flat view indexes every layout.
    _, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    flat_axes = axes.ravel()

    for ax, metric in zip(flat_axes, metrics):
        data = df[metric].dropna()
        if len(data) > 0:
            sns.histplot(data, kde=True, bins=bins, ax=ax, alpha=0.7)
            ax.set_title(f'{metric}')
            mean_val, std_val = data.mean(), data.std()
            ax.text(0.02, 0.98, f'μ={mean_val:.3f}\nσ={std_val:.3f}',
                    transform=ax.transAxes, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # Hide empty subplots
    for ax in flat_axes[n_metrics:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

def plot_metric_boxplots(df, metrics, figsize=(12, 6)):
    """Draw side-by-side boxplots of the given metrics with their means marked.

    Args:
        df: DataFrame holding one column per metric.
        metrics: List of column names to compare.
        figsize: Figure size passed to matplotlib.
    """
    # Long format: one row per (metric, value) pair.
    long_form = df[metrics].melt(var_name='metric', value_name='value')
    metric_means = long_form.groupby('metric')['value'].mean()

    plt.figure(figsize=figsize)
    ax = sns.boxplot(x='metric', y='value', data=long_form)

    # Overlay each metric's mean as a red diamond at its box position.
    for position, name in enumerate(metrics):
        if name not in metric_means.index:
            continue
        ax.scatter(position, metric_means[name], marker='D', s=50, color='red', zorder=10)

    plt.xticks(rotation=45, ha='right')
    plt.title('Metric Distribution Comparison')
    plt.tight_layout()
    plt.show()

def plot_correlation_heatmap(df, metrics, figsize=(10, 8)):
    """Show the pairwise correlations of the given metrics as a heatmap.

    The upper triangle, diagonal included, is masked, so each pair of metrics
    appears exactly once.

    Args:
        df: DataFrame holding one column per metric.
        metrics: List of column names to correlate.
        figsize: Figure size passed to matplotlib.
    """
    correlations = df[metrics].corr()
    upper_triangle = np.triu(np.ones_like(correlations, dtype=bool))

    heatmap_options = dict(
        mask=upper_triangle,
        annot=True,
        cmap='coolwarm',
        center=0,
        square=True,
        fmt='.3f',
    )

    plt.figure(figsize=figsize)
    sns.heatmap(correlations, **heatmap_options)
    plt.title('Metric Correlation Heatmap')
    plt.tight_layout()
    plt.show()

def show_top_bottom_reports(df, metric, k=5):
    """Display the ``k`` highest- and ``k`` lowest-scoring reports for a metric.

    Prints the metric's min, max and mean, then displays the study id, report
    text and score (rounded to 3 decimals) of the top and bottom rows. Prints
    a message and returns when ``metric`` is not a column of ``df``.

    Args:
        df: Evaluation results, one row per report.
        metric: Name of the score column to rank by.
        k: Number of reports to show at each end.
    """
    if metric not in df.columns:
        print(f'Metric {metric} not found!')
        return

    scores = df[metric]
    print(f"📊 {metric}: {scores.min():.3f} to {scores.max():.3f} (μ={scores.mean():.3f})")

    shown_columns = [STUDY_ID_COL_NAME, REPORT_COL_NAME, metric]

    # Best reports first, then the worst, each under its own heading.
    sections = (
        (f"🏆 Top {k} reports:", df.nlargest),
        (f"⚠️ Bottom {k} reports:", df.nsmallest),
    )
    for heading, pick_rows in sections:
        subset = pick_rows(k, metric)
        print(heading)
        display(subset[shown_columns].round(3))

# Create interface buttons
# Main action button; a click runs run_modular_evaluation_click.
run_button = widgets.Button(
    description='🚀 Run Modular Evaluation',
    button_style='success',
    layout=widgets.Layout(width='220px', height='40px')
)
run_button.on_click(run_modular_evaluation_click)

# Visualization buttons
# (label, click handler) pairs; one button is created per pair below.
viz_buttons = [
    ('📈 Distributions', plot_distributions_click),
    ('📦 Boxplots', plot_boxplots_click),
    ('🔗 Correlations', plot_correlations_click),
    ('📋 Top/Bottom', show_top_bottom_click)
]

viz_button_widgets = []
for desc, handler in viz_buttons:
    btn = widgets.Button(
        description=desc,
        button_style='info',
        layout=widgets.Layout(width='140px', height='35px')
    )
    btn.on_click(handler)
    viz_button_widgets.append(btn)

# Clears everything currently shown in the shared output widget.
clear_button = widgets.Button(
    description='🗑️ Clear',
    button_style='warning',
    layout=widgets.Layout(width='80px', height='30px')
)
clear_button.on_click(lambda b: output.clear_output())

# Display the interface
# Each section is a printed heading followed by its widgets, in this order:
# file paths, metric options, run button, visualization controls, output.
print("🎛️ Modular CXR Report Evaluation Interface")
print("=" * 55)

print("\n📁 File Configuration:")
display(widgets.VBox([gt_csv_widget, pred_csv_widget, out_csv_widget]))

print("\n🎯 Metric Configuration:")
display(widgets.VBox([
    widgets.HTML("<b>Lightweight metrics (always enabled):</b>"),
    enable_rouge_widget,
    enable_bleu_widget,
    widgets.HTML("<br><b>Advanced metrics (configurable):</b>"),
    enable_bertscore_widget,
    use_idf_widget,
    enable_perplexity_widget,
    perplexity_model_widget,
    enable_composite_widget,
    enable_semantic_widget
]))

print("\n🚀 Evaluation:")
display(run_button)

print("\n📊 Visualization:")
display(widgets.VBox([
    widgets.HTML("<b>Select metrics for visualization:</b>"),
    viz_metrics_widget,
    sample_size_widget,
    widgets.HTML("<br><b>Visualization Actions:</b>"),
    widgets.HBox(viz_button_widgets + [clear_button])
]))

# The output widget is displayed last so handler output appears below the
# controls.
print("\n📋 Output:")
display(output)

# Closing usage tips printed under the interface.
print("\n" + "=" * 55)
print("💡 Usage Tips:")
print("• GPU metrics (Perplexity, BERTScore) are much faster with CUDA")
print("• Results are automatically cached for faster repeated runs")  
print("• Enable only needed metrics to optimize evaluation time")
print("• Check utility_scripts/ folder for additional evaluation tools")