# Medical Risk Assessment: Population Deployment with Rank Preservation

**Problem**: A cardiovascular risk model trained on clinical trial data needs to be deployed for population screening. Clinical trials over-represent severe cases, so the model's risk probabilities need adjustment to match the general population's disease distribution - but critically, **patient risk rankings must be preserved** for proper triage.

## Unique Value Proposition

This example demonstrates why **rank-preserving calibration** is essential in medical applications:

- 🏥 **Clinical triage depends on relative risk rankings** between patients
- 📊 **Population estimates need accurate marginal distributions**  
- ⚠️ **Standard calibration methods can scramble patient orderings**
- ✅ **Our method preserves rankings while adjusting population rates**

We'll use the **UCI Heart Disease dataset** - real clinical data with documented population vs. clinical differences.

In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import spearmanr
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score,
    brier_score_loss,
    f1_score,
    log_loss,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.calibration import CalibratedClassifierCV

# Import our calibration package
from rank_preserving_calibration import calibrate_dykstra

# Suppress warnings for cleaner output.
# NOTE: this filter is global - every warning raised after this point
# (by any library) is hidden for the rest of the session.
warnings.filterwarnings('ignore')

# Set style: matplotlib style sheet plus a fixed five-colour seaborn palette
plt.style.use('seaborn-v0_8')
sns.set_palette(["#e74c3c", "#f39c12", "#3498db", "#2ecc71", "#9b59b6"])
# Seed NumPy's global random state so repeated runs behave the same
np.random.seed(42)

# Banner for the notebook output
print("üè• MEDICAL RISK CALIBRATION WITH REAL DATA")
print("Focus: Population deployment with rank preservation")

## Load UCI Heart Disease Dataset

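We'll try to load the well-known UCI Heart Disease dataset (real clinical measurements from patients) from OpenML. If the download fails, the loader falls back to a synthetic 5-class dataset that reuses the same 13 feature names.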

In [None]:
def load_heart_disease_data():
    """Load the heart disease data, falling back to a synthetic dataset.

    The function first tries to fetch the ``heart-disease`` dataset
    (version 1) from OpenML as a DataFrame. A non-numeric target (object or
    categorical dtype) is converted to integer codes with ``LabelEncoder``.

    If anything in that step fails (for example no network access or an
    unknown dataset name), the reason is printed and a synthetic dataset is
    generated instead: 1000 samples, 13 features named after the usual heart
    disease columns, and 5 classes.

    Returns:
        tuple: ``(X, y)`` - the feature DataFrame and the target values.
    """
    from sklearn.datasets import fetch_openml

    try:
        # Try to load real heart disease data from OpenML
        heart_data = fetch_openml(name='heart-disease', version=1, as_frame=True, parser='auto')
        X = heart_data.data
        y = heart_data.target

        # Convert target to numeric if needed. Checking "not numeric" (rather
        # than only dtype == 'object') also covers categorical targets.
        if not pd.api.types.is_numeric_dtype(y):
            from sklearn.preprocessing import LabelEncoder
            le = LabelEncoder()
            y = le.fit_transform(y)

    except Exception as exc:
        # Catch Exception (not a bare except) so that KeyboardInterrupt and
        # SystemExit still propagate instead of silently triggering the fallback.
        print(f"Could not load heart disease data from OpenML ({type(exc).__name__}: {exc})")
        print("Creating realistic heart disease simulation...")
        from sklearn.datasets import make_classification

        # Create a realistic 5-class heart disease severity dataset
        X, y = make_classification(
            n_samples=1000,
            n_features=13,  # Similar to actual heart disease features
            n_informative=10,
            n_redundant=3,
            n_classes=5,  # 0: No disease, 1-4: Increasing severity
            n_clusters_per_class=1,
            class_sep=0.8,
            random_state=42
        )

        # Create realistic feature names
        feature_names = [
            'age', 'sex', 'cp', 'trestbps', 'chol', 'fbs',
            'restecg', 'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal'
        ]

        X = pd.DataFrame(X, columns=feature_names)

    return X, y

# --- Load the dataset and summarise its class balance ---
print("üìä LOADING UCI HEART DISEASE DATA")
print("=" * 50)

X, y = load_heart_disease_data()

# The rest of the notebook works with five severity levels
# (0 = none, 1-4 = increasing severity); re-bin the target when the
# loaded data has a different number of distinct values.
n_distinct_targets = len(np.unique(y))
if n_distinct_targets != 5:
    y = pd.cut(y, bins=5, labels=[0, 1, 2, 3, 4]).astype(int)

print(f"Dataset shape: {X.shape}")
print(f"Features: {list(X.columns)[:5]}...")
print(f"Target classes: {sorted(np.unique(y))}")

# Per-class patient counts, reported next to human-readable severity names
class_counts = np.bincount(y)
severity_labels = ['No Disease', 'Mild', 'Moderate', 'Severe', 'Critical']

print("\nCLINICAL TRIAL DISTRIBUTION (biased toward severe cases):")
for label, count in zip(severity_labels, class_counts):
    pct = count / len(y) * 100
    print(f"  {label}: {count} patients ({pct:.1f}%)")

## Model Training & Clinical Trial Bias

We'll train a cardiovascular risk model and simulate the bias present in clinical trials.

In [None]:
# Standardise the features (the scaler is fitted on the full feature matrix)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 70/30 stratified train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled,
    y,
    test_size=0.3,
    random_state=42,
    stratify=y,
)

print("ü§ñ TRAINING CARDIOVASCULAR RISK MODEL")
print("=" * 45)

# Random forest with class re-weighting
model = RandomForestClassifier(
    n_estimators=100,
    max_depth=10,
    random_state=42,
    class_weight='balanced',
)
model.fit(X_train, y_train)

# Hard predictions and class probabilities on the held-out split
y_pred = model.predict(X_test)
y_proba = model.predict_proba(X_test)

print(f"Model accuracy: {accuracy_score(y_test, y_pred):.3f}")
print(f"Test samples: {len(y_test)}")

# Average predicted probability per class on the test split
clinical_marginals = np.mean(y_proba, axis=0)
print("\nCLINICAL TRIAL PREDICTIONS (biased):")
for label, prob in zip(severity_labels, clinical_marginals):
    print(f"  {label}: {prob:.3f} ({prob*100:.1f}%)")

# One-vs-rest AUC per class; a class is skipped when the test labels
# do not contain both members and non-members of it.
auc_scores = []
for class_idx, class_name in enumerate(severity_labels):
    is_member = (y_test == class_idx)
    if len(np.unique(is_member)) <= 1:
        continue
    auc = roc_auc_score(is_member.astype(int), y_proba[:, class_idx])
    auc_scores.append(auc)
    print(f"AUC {class_name}: {auc:.3f}")

print(f"Mean AUC: {np.mean(auc_scores):.3f}")

## Population Health Target Distribution

For population deployment, we need to match real-world cardiovascular disease prevalence.

In [None]:
print("üåç POPULATION HEALTH TARGET DISTRIBUTION")
print("=" * 45)

# Target class shares for population screening, in the order of
# severity_labels: No Disease, Mild, Moderate, Severe, Critical.
population_distribution = np.array([0.75, 0.12, 0.08, 0.04, 0.01])

print("POPULATION SCREENING TARGET DISTRIBUTION:")
for idx, label in enumerate(severity_labels):
    target_pct = population_distribution[idx]
    clinical_pct = clinical_marginals[idx]
    change = target_pct - clinical_pct

    # Arrow showing which way the class share has to move
    if change > 0:
        direction = "‚Üë"
    elif change < 0:
        direction = "‚Üì"
    else:
        direction = "‚Üí"

    print(f"  {label}: {target_pct:.1%} (clinical: {clinical_pct:.1%}, change: {change:+.1%} {direction})")

# Expected number of test patients per class under the target shares
n_test_samples = len(y_test)
target_marginals = population_distribution * n_test_samples

print("\nüéØ CALIBRATION TARGETS:")
print(f"   Test samples: {n_test_samples}")
print(f"   Target marginals: {target_marginals.astype(int)}")
print(f"   Sum check: {np.sum(target_marginals):.1f} (should equal {n_test_samples})")

print("\n‚ö†Ô∏è WHY RANK PRESERVATION IS CRITICAL:")
critical_reasons = [
    "Patient triage: Who gets priority for specialist referral?",
    "Treatment decisions: Medication intensity based on relative risk",
    "Resource allocation: ICU beds, cardiac procedures, preventive care",
    "Clinical trials: Patient stratification for drug studies",
    "Insurance: Risk-based premium calculations",
]
for reason in critical_reasons:
    print(f"   ‚Ä¢ {reason}")

## Calibration Methods Comparison

We'll compare rank-preserving calibration against standard methods.

In [None]:
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression as LogReg

def temperature_scaling(y_proba, y_true, temp_bounds=(0.05, 20.0)):
    """Recalibrate class probabilities with a single temperature.

    The log-probabilities are treated as logits, divided by a temperature
    ``T`` and pushed through a softmax. ``T`` is chosen to minimise the
    log loss of the rescaled probabilities against ``y_true``.

    Args:
        y_proba: Array of shape (n_samples, n_classes) with class probabilities.
        y_true: True labels, passed straight to ``log_loss``.
        temp_bounds: ``(low, high)`` search interval for the temperature.
            Keeping ``low`` strictly positive prevents the optimiser from
            reaching ``T <= 0``, which would divide by zero or reverse the
            ordering of the classes.

    Returns:
        Array with the same shape as ``y_proba`` holding the rescaled
        probabilities; each row sums to 1.
    """
    from scipy.optimize import minimize_scalar

    # Computed once: the small epsilon keeps log() finite for zero probabilities.
    logits = np.log(y_proba + 1e-12)

    def _softmax_at(temp):
        """Softmax of logits / temp, shifted by the row maximum for stability."""
        scaled = logits / temp
        scaled = scaled - np.max(scaled, axis=1, keepdims=True)
        scaled = np.exp(scaled)
        return scaled / np.sum(scaled, axis=1, keepdims=True)

    def temperature_loss(temp):
        return log_loss(y_true, _softmax_at(temp))

    # Bounded 1-D search instead of unconstrained BFGS on the raw temperature.
    temp_result = minimize_scalar(temperature_loss, bounds=temp_bounds, method='bounded')
    optimal_temp = temp_result.x

    return _softmax_at(optimal_temp)

def platt_scaling(y_proba, y_true):
    """Per-class isotonic recalibration followed by row renormalisation.

    Despite its name, this function does not fit a sigmoid: every class
    column is mapped through an isotonic regression fitted against the
    one-vs-rest indicator ``y_true == class``. A column whose indicator is
    constant (class entirely absent or the only class) is copied unchanged.
    Finally each row is divided by its sum.
    """
    recalibrated = np.zeros_like(y_proba)

    for k in range(y_proba.shape[1]):
        column = y_proba[:, k]
        indicator = (y_true == k).astype(int)

        # Nothing to fit when the indicator has a single value
        if len(np.unique(indicator)) <= 1:
            recalibrated[:, k] = column
            continue

        fitter = IsotonicRegression(out_of_bounds='clip')
        recalibrated[:, k] = fitter.fit_transform(column, indicator)

    # Bring every row back onto the probability simplex
    row_totals = np.sum(recalibrated, axis=1, keepdims=True)
    return recalibrated / row_totals

def histogram_binning(y_proba, y_true, n_bins=10):
    """Histogram-binning calibration, applied one class at a time.

    For every class the predicted probabilities are split into ``n_bins``
    equal-width bins on [0, 1]. Each probability is replaced by the observed
    frequency of that class among the samples in its bin, and the rows are
    renormalised to sum to 1 afterwards.

    Args:
        y_proba: Array of shape (n_samples, n_classes) with class probabilities.
        y_true: Integer class labels; compared against the column index.
        n_bins: Number of equal-width bins.

    Returns:
        Float array of the same shape as ``y_proba`` with calibrated,
        row-normalised probabilities.
    """
    y_proba = np.asarray(y_proba, dtype=float)
    labels = np.asarray(y_true)
    calibrated_proba = np.zeros_like(y_proba)

    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    for class_idx in range(y_proba.shape[1]):
        is_class = (labels == class_idx)
        probs = y_proba[:, class_idx]

        for bin_no, (bin_lower, bin_upper) in enumerate(zip(bin_lowers, bin_uppers)):
            # Bins are (lower, upper], except the first one, which is closed on
            # the left so that probabilities of exactly 0 are binned too.
            above_lower = (probs >= bin_lower) if bin_no == 0 else (probs > bin_lower)
            in_bin = above_lower & (probs <= bin_upper)
            if in_bin.any():
                calibrated_proba[in_bin, class_idx] = is_class[in_bin].mean()

    # Renormalize. A row whose calibrated values are all zero cannot be
    # normalised (0 / 0), so it falls back to its original probabilities.
    row_sums = calibrated_proba.sum(axis=1, keepdims=True)
    zero_rows = (row_sums == 0).ravel()
    if zero_rows.any():
        calibrated_proba[zero_rows] = y_proba[zero_rows]
        row_sums = calibrated_proba.sum(axis=1, keepdims=True)

    return calibrated_proba / row_sums

print("‚öñÔ∏è CALIBRATION METHODS COMPARISON")
print("="*40)

# Apply different calibration methods
# NOTE: the three baseline methods below are fitted with y_test, i.e. on the
# same samples whose probabilities they then transform.
print("\n1Ô∏è‚É£ Temperature Scaling:")
y_proba_temp = temperature_scaling(y_proba, y_test)
print(f"   Mean probability shift: {np.mean(np.abs(y_proba_temp - y_proba)):.3f}")

print("\n2Ô∏è‚É£ Platt/Isotonic Scaling:")
y_proba_platt = platt_scaling(y_proba, y_test)
print(f"   Mean probability shift: {np.mean(np.abs(y_proba_platt - y_proba)):.3f}")

print("\n3Ô∏è‚É£ Histogram Binning:")
y_proba_hist = histogram_binning(y_proba, y_test)
print(f"   Mean probability shift: {np.mean(np.abs(y_proba_hist - y_proba)):.3f}")

# Rank-preserving calibration uses no labels here: it receives the predicted
# probabilities and the per-class target totals computed earlier.
# NOTE(review): presumably P is the probability matrix and M the target
# column sums - confirm against the rank_preserving_calibration docs.
print("\n4Ô∏è‚É£ Rank-Preserving (Ours):")
result_ours = calibrate_dykstra(
    P=y_proba,
    M=target_marginals,
    max_iters=200,
    tol=1e-5,
    verbose=False
)
# The calibrated matrix is read from the result's Q attribute
y_proba_ours = result_ours.Q
print(f"   Converged: {result_ours.converged}")
print(f"   Iterations: {result_ours.iterations}")
print(f"   Max marginal error: {result_ours.max_col_error:.2e}")
print(f"   Mean probability shift: {np.mean(np.abs(y_proba_ours - y_proba)):.3f}")

## Rank Preservation Analysis

This is the key analysis: how well does each method preserve patient risk rankings?

In [None]:
def calculate_rank_preservation(y_orig, y_cal, method_name):
    """Summarise how well each sample's class ordering survives calibration.

    For every row, the Spearman rank correlation between the original and
    the calibrated class probabilities is computed. Rows whose correlation
    is NaN are left out of all statistics.

    Args:
        y_orig: Original probabilities, one row per sample.
        y_cal: Calibrated probabilities with the same layout as ``y_orig``.
        method_name: Label stored in the result under ``'method'``.

    Returns:
        dict with keys ``method``, ``mean_corr``, ``min_corr``,
        ``perfect_count`` (correlation equal to 1 within 1e-10),
        ``scrambled_count`` (correlation below 0.95) and ``total_patients``
        (number of rows with a defined correlation). When no row has a
        defined correlation, ``mean_corr`` and ``min_corr`` are NaN and the
        counts are 0.
    """
    rank_correlations = []

    for i, orig_row in enumerate(y_orig):
        corr, _ = spearmanr(orig_row, y_cal[i])
        if not np.isnan(corr):
            rank_correlations.append(corr)

    rank_correlations = np.array(rank_correlations, dtype=float)
    perfect_preservation = np.sum(np.isclose(rank_correlations, 1.0, atol=1e-10))
    scrambled = np.sum(rank_correlations < 0.95)  # Significantly scrambled

    # np.min raises on an empty array, so the empty case is handled explicitly.
    if rank_correlations.size:
        mean_corr = np.mean(rank_correlations)
        min_corr = np.min(rank_correlations)
    else:
        mean_corr = float('nan')
        min_corr = float('nan')

    return {
        'method': method_name,
        'mean_corr': mean_corr,
        'min_corr': min_corr,
        'perfect_count': perfect_preservation,
        'scrambled_count': scrambled,
        'total_patients': len(rank_correlations)
    }

def expected_calibration_error(y_true, y_proba, n_bins=10):
    """Expected Calibration Error of the top-class prediction.

    Samples are grouped into ``n_bins`` equal-width bins, open on the left
    and closed on the right, by the probability of their predicted class.
    The result is the sum over bins of
    ``|mean confidence - accuracy| * fraction of samples in the bin``.
    """
    top_class = np.argmax(y_proba, axis=1)
    top_prob = np.max(y_proba, axis=1)
    is_hit = (top_class == y_true).astype(float)

    edges = np.linspace(0, 1, n_bins + 1)

    ece = 0
    for lower_edge, upper_edge in zip(edges[:-1], edges[1:]):
        members = (top_prob > lower_edge) & (top_prob <= upper_edge)
        share = members.mean()

        # Bins without samples contribute nothing
        if not share > 0:
            continue

        bin_accuracy = is_hit[members].mean()
        bin_confidence = top_prob[members].mean()
        ece += np.abs(bin_confidence - bin_accuracy) * share

    return ece

def calculate_comprehensive_metrics(y_true, y_proba_orig, y_proba_cal, method_name):
    """Collect classification, calibration and ranking metrics for one method.

    Predictions are the argmax of ``y_proba_cal``. The returned dict holds
    accuracy, log loss, macro F1, the mean one-vs-rest AUC (over classes for
    which ``y_true`` contains both members and non-members), the expected
    calibration error, rank-preservation statistics relative to
    ``y_proba_orig``, and the largest absolute gap between the mean
    calibrated probabilities and the normalised target distribution.

    Note: the target distribution is read from the module-level
    ``target_marginals`` variable, not from an argument.
    """
    predicted = np.argmax(y_proba_cal, axis=1)

    # Classification quality
    accuracy = accuracy_score(y_true, predicted)
    log_loss_val = log_loss(y_true, y_proba_cal)
    f1_macro = f1_score(y_true, predicted, average='macro')

    # One-vs-rest AUC, only for classes that are both present and absent
    per_class_auc = [
        roc_auc_score((y_true == k).astype(int), y_proba_cal[:, k])
        for k in range(y_proba_cal.shape[1])
        if len(np.unique(y_true == k)) > 1
    ]
    auc_macro = np.mean(per_class_auc)

    # Calibration quality
    ece = expected_calibration_error(y_true, y_proba_cal)

    # Ordering of classes within each sample, before vs. after calibration
    rank_stats = calculate_rank_preservation(y_proba_orig, y_proba_cal, method_name)

    # Distance between achieved and targeted class shares
    achieved_marginals = np.mean(y_proba_cal, axis=0)
    target_dist = target_marginals / np.sum(target_marginals)
    marginal_error = np.max(np.abs(achieved_marginals - target_dist))

    summary = {
        'method': method_name,
        'accuracy': accuracy,
        'log_loss': log_loss_val,
        'f1_macro': f1_macro,
        'auc_macro': auc_macro,
        'ece': ece,
        'rank_corr': rank_stats['mean_corr'],
        'scrambled_patients': rank_stats['scrambled_count'],
        'marginal_error': marginal_error,
    }
    return summary

print("üìä COMPREHENSIVE METHODS COMPARISON")
print("=" * 50)

# (display name, calibrated probabilities) for every method, in table order
method_outputs = [
    ("Original", y_proba),
    ("Temperature Scale", y_proba_temp),
    ("Platt/Isotonic", y_proba_platt),
    ("Histogram Bin", y_proba_hist),
    ("Rank-Preserving", y_proba_ours),
]
results = [
    calculate_comprehensive_metrics(y_test, y_proba, probs, name)
    for name, probs in method_outputs
]

# One row per method
df_results = pd.DataFrame(results)

header_cells = [
    f"{'Method':<16}",
    f"{'Accuracy':<8}",
    f"{'AUC':<6}",
    f"{'ECE':<6}",
    f"{'RankCorr':<8}",
    f"{'Scrambled':<9}",
    f"{'MargErr':<8}",
]
print(" ".join(header_cells))
print("-" * 70)

for _, rec in df_results.iterrows():
    cells = [
        f"{rec['method']:<16}",
        f"{rec['accuracy']:<8.3f}",
        f"{rec['auc_macro']:<6.3f}",
        f"{rec['ece']:<6.3f}",
        f"{rec['rank_corr']:<8.4f}",
        f"{rec['scrambled_patients']:<9}",
        f"{rec['marginal_error']:<8.3f}",
    ]
    print(" ".join(cells))

# Row positions of the methods referred to in the summary lines
orig_idx, temp_idx, ours_idx = 0, 1, 4

print("\nüéØ KEY INSIGHTS:")
print(f"‚Ä¢ Rank-Preserving has {df_results.loc[ours_idx, 'scrambled_patients']} scrambled patients vs {df_results.loc[temp_idx, 'scrambled_patients']} for Temperature Scaling")
print(f"‚Ä¢ Rank correlation: Ours={df_results.loc[ours_idx, 'rank_corr']:.4f} vs Best Standard={df_results.loc[1:3, 'rank_corr'].max():.4f}")
print(f"‚Ä¢ Marginal accuracy: Ours={df_results.loc[ours_idx, 'marginal_error']:.3f} (lower is better)")
print(f"‚Ä¢ AUC preservation: Ours={df_results.loc[ours_idx, 'auc_macro']:.3f} vs Original={df_results.loc[orig_idx, 'auc_macro']:.3f}")

## Clinical Decision Impact Analysis

Let's see how ranking scrambling affects real clinical decisions.

In [None]:
def analyze_clinical_decision_impact(y_proba_orig, y_proba_cal, method_name, risk_threshold=0.15):
    """Measure how calibration changes high-risk referral decisions.

    A patient's high-risk score is the sum of the last two probability
    columns (Severe + Critical). A patient is flagged as high risk when that
    score exceeds ``risk_threshold``.

    Args:
        y_proba_orig: Original probabilities, shape (n_samples, n_classes).
        y_proba_cal: Calibrated probabilities with the same shape.
        method_name: Label stored in the result under ``'method'``.
        risk_threshold: Score above which a patient is flagged.

    Returns:
        dict with keys ``method``, ``orig_high_risk`` and ``cal_high_risk``
        (number of flagged patients before/after), ``decision_changes``
        (patients whose flag differs), ``ranking_tau`` (Kendall's tau between
        the original and calibrated scores of the originally flagged
        patients; 1.0 when fewer than two patients were flagged) and
        ``change_rate`` (``decision_changes`` as a percentage of all patients).
    """
    from scipy.stats import kendalltau

    # Combined probability of the two most severe classes
    high_risk_orig = y_proba_orig[:, -2:].sum(axis=1)  # Severe + Critical
    high_risk_cal = y_proba_cal[:, -2:].sum(axis=1)

    # Identify high-risk patients
    orig_high_risk = high_risk_orig > risk_threshold
    cal_high_risk = high_risk_cal > risk_threshold

    # Decision changes
    decision_changes = np.sum(orig_high_risk != cal_high_risk)

    # Ranking agreement among the originally high-risk patients.
    # Kendall's tau is computed on the scores themselves: comparing two
    # argsort outputs would compare index permutations, which is not the
    # rank correlation between the patients' scores.
    if np.sum(orig_high_risk) > 1:
        tau, _ = kendalltau(high_risk_orig[orig_high_risk], high_risk_cal[orig_high_risk])
    else:
        tau = 1.0

    return {
        'method': method_name,
        'orig_high_risk': np.sum(orig_high_risk),
        'cal_high_risk': np.sum(cal_high_risk),
        'decision_changes': decision_changes,
        'ranking_tau': tau,
        'change_rate': decision_changes / len(y_proba_orig) * 100
    }

print("üè• CLINICAL DECISION IMPACT ANALYSIS")
print("=" * 45)
print("Scenario: Identifying patients for urgent cardiology referral")
print("Threshold: >15% probability of severe/critical disease")

# Referral-decision statistics for every method, in table order
referral_inputs = [
    ("Original", y_proba),
    ("Temperature Scale", y_proba_temp),
    ("Platt/Isotonic", y_proba_platt),
    ("Histogram Bin", y_proba_hist),
    ("Rank-Preserving", y_proba_ours),
]
clinical_results = [
    analyze_clinical_decision_impact(y_proba, probs, name)
    for name, probs in referral_inputs
]

df_clinical = pd.DataFrame(clinical_results)

clinical_header = [
    f"{'Method':<16}",
    f"{'High Risk':<10}",
    f"{'Changes':<8}",
    f"{'Change%':<8}",
    f"{'RankTau':<8}",
]
print("\n" + " ".join(clinical_header))
print("-" * 55)

for _, rec in df_clinical.iterrows():
    cells = [
        f"{rec['method']:<16}",
        f"{rec['cal_high_risk']:<10}",
        f"{rec['decision_changes']:<8}",
        f"{rec['change_rate']:<8.1f}",
        f"{rec['ranking_tau']:<8.3f}",
    ]
    print(" ".join(cells))

print("\nüí° CLINICAL IMPLICATIONS:")

# Rows 1 and 4 hold Temperature Scale and Rank-Preserving respectively
temp_changes = df_clinical.loc[1, 'decision_changes']
ours_changes = df_clinical.loc[4, 'decision_changes']

print(f"‚Ä¢ Temperature Scaling changed referral decisions for {temp_changes} patients ({df_clinical.loc[1, 'change_rate']:.1f}%)")
print(f"‚Ä¢ Rank-Preserving changed referral decisions for {ours_changes} patients ({df_clinical.loc[4, 'change_rate']:.1f}%)")
print(f"‚Ä¢ Ranking preservation among high-risk patients: Ours={df_clinical.loc[4, 'ranking_tau']:.3f} vs Temp={df_clinical.loc[1, 'ranking_tau']:.3f}")

print("\n‚ö†Ô∏è CLINICAL RISKS OF POOR RANK PRESERVATION:")
risks = [
    "Patient A is sicker than B, but B gets referral priority after calibration",
    "ICU bed allocation based on scrambled risk rankings",
    "Medication dosing decisions using unreliable relative risk",
    "Clinical trial enrollment with biased patient stratification",
]
for risk in risks:
    print(f"   ‚Ä¢ {risk}")

## Visualization: Rank Preservation Quality

In [None]:
# Create visualization comparing methods: a 2x2 grid of panels
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('Medical Risk Calibration: Rank Preservation Analysis', fontsize=16, y=0.98)

# 1. Risk distribution comparison (mean predicted probability per class)
x_pos = np.arange(len(severity_labels))
width = 0.15

orig_dist = np.mean(y_proba, axis=0)
temp_dist = np.mean(y_proba_temp, axis=0)
ours_dist = np.mean(y_proba_ours, axis=0)

axes[0, 0].bar(x_pos - width, orig_dist, width, label='Original', alpha=0.8)
axes[0, 0].bar(x_pos, temp_dist, width, label='Temperature Scale', alpha=0.8)
axes[0, 0].bar(x_pos + width, ours_dist, width, label='Rank-Preserving', alpha=0.8)
# The target differs per class, so draw one dashed segment across each bar
# group. (axhline takes a single scalar y and cannot draw an array of targets.)
axes[0, 0].hlines(
    population_distribution,
    x_pos - 1.5 * width,
    x_pos + 1.5 * width,
    colors='red',
    linestyles='--',
    alpha=0.7,
    label='Population Target',
)

axes[0, 0].set_xlabel('Disease Severity')
axes[0, 0].set_ylabel('Probability')
axes[0, 0].set_title('Risk Distribution Adjustment')
axes[0, 0].set_xticks(x_pos)
axes[0, 0].set_xticklabels([s[:4] for s in severity_labels], rotation=45)
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# 2. Rank correlation distribution (per-patient Spearman correlation between
#    original and calibrated class probabilities; NaN values are dropped)
methods = ['Temp Scale', 'Platt/Iso', 'Histogram', 'Rank-Preserving']
method_probas = [y_proba_temp, y_proba_platt, y_proba_hist, y_proba_ours]
colors = ['orange', 'green', 'blue', 'red']

for method, proba, color in zip(methods, method_probas, colors):
    rank_corrs = []
    for i in range(len(y_proba)):
        corr, _ = spearmanr(y_proba[i], proba[i])
        if not np.isnan(corr):
            rank_corrs.append(corr)

    axes[0, 1].hist(rank_corrs, bins=20, alpha=0.6, label=method, color=color, density=True)

axes[0, 1].axvline(1.0, color='black', linestyle='--', alpha=0.7, label='Perfect Preservation')
axes[0, 1].set_xlabel('Spearman Rank Correlation')
axes[0, 1].set_ylabel('Density')
axes[0, 1].set_title('Rank Preservation Distribution')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# 3. Metrics comparison radar chart (simplified bar chart)
#    Rows 1 and 4 of df_results are Temperature Scale and Rank-Preserving.
metrics_names = ['Accuracy', 'AUC', 'Rank Corr', 'Cal Quality']
temp_metrics = [df_results.loc[1, 'accuracy'], df_results.loc[1, 'auc_macro'],
               df_results.loc[1, 'rank_corr'], 1-df_results.loc[1, 'ece']]  # 1-ECE for "quality"
ours_metrics = [df_results.loc[4, 'accuracy'], df_results.loc[4, 'auc_macro'],
               df_results.loc[4, 'rank_corr'], 1-df_results.loc[4, 'ece']]

x_met = np.arange(len(metrics_names))
axes[1, 0].bar(x_met - 0.2, temp_metrics, 0.4, label='Temperature Scale', alpha=0.8)
axes[1, 0].bar(x_met + 0.2, ours_metrics, 0.4, label='Rank-Preserving', alpha=0.8)
axes[1, 0].set_ylabel('Score')
axes[1, 0].set_title('Performance Metrics Comparison')
axes[1, 0].set_xticks(x_met)
axes[1, 0].set_xticklabels(metrics_names, rotation=45)
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# 4. Clinical decision impact (share of patients whose referral flag changed)
decision_methods = ['Original', 'Temp Scale', 'Platt/Iso', 'Histogram', 'Rank-Preserving']
decision_changes = [df_clinical.loc[i, 'change_rate'] for i in range(len(decision_methods))]

bars = axes[1, 1].bar(decision_methods, decision_changes, alpha=0.8, color=['gray', 'orange', 'green', 'blue', 'red'])
axes[1, 1].set_ylabel('Referral Decision Changes (%)')
axes[1, 1].set_title('Impact on Clinical Decisions')
# The bar labels already come from the categorical x values; only rotate them
# instead of re-setting tick labels without fixed tick positions.
axes[1, 1].tick_params(axis='x', rotation=45)
axes[1, 1].grid(True, alpha=0.3)

# Outline the last bar (Rank-Preserving)
bars[-1].set_edgecolor('black')
bars[-1].set_linewidth(2)

plt.tight_layout()
plt.show()

# Row positions in df_results / df_clinical of the methods quoted below
orig_row, temp_row, ours_row = 0, 1, 4

print("\nüèÜ SUMMARY: RANK-PRESERVING CALIBRATION ADVANTAGES")
print("=" * 60)

summary_lines = [
    f"‚úÖ Rank Correlation: {df_results.loc[ours_row, 'rank_corr']:.4f} (vs {df_results.loc[temp_row, 'rank_corr']:.4f} for Temperature Scaling)",
    f"‚úÖ Patients with Scrambled Rankings: {df_results.loc[ours_row, 'scrambled_patients']} (vs {df_results.loc[temp_row, 'scrambled_patients']} for Temperature Scaling)",
    f"‚úÖ Marginal Distribution Error: {df_results.loc[ours_row, 'marginal_error']:.4f} (lower is better)",
    f"‚úÖ AUC Preservation: {df_results.loc[ours_row, 'auc_macro']:.3f} (vs original {df_results.loc[orig_row, 'auc_macro']:.3f})",
    f"‚úÖ Clinical Decision Stability: {df_clinical.loc[ours_row, 'change_rate']:.1f}% changed (vs {df_clinical.loc[temp_row, 'change_rate']:.1f}% for Temperature)",
]
for line in summary_lines:
    print(line)

print("\nüéØ WHEN TO USE RANK-PRESERVING CALIBRATION:")
use_cases = [
    "Population deployment of clinical trial models",
    "Patient triage and resource allocation decisions",
    "Multi-class medical diagnosis with severity levels",
    "Risk stratification for treatment decisions",
    "Clinical trial enrollment and patient matching",
]
for use_case in use_cases:
    print(f"   ‚Ä¢ {use_case}")