# Medical Survey Research: Population Health Calibration

**Core Problem**: Survey statisticians and machine learning practitioners often need to adjust predicted class probabilities from a classifier so they match known population totals (column marginals). Simple post-hoc methods that apply separate logit shifts or raking to each class can scramble the ranking of individuals within a class when there are three or more classes.
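
A two-patient toy example shows how this happens. The probabilities and per-class log shifts below are made up for illustration and do not come from the dataset. The transformation is the multiply-then-renormalize form that per-class logit shifting uses.

```python
import numpy as np


def shift_and_renormalize(probs, log_shifts):
    """Add a constant to each class's log-probability, then renormalize each row."""
    scaled = probs * np.exp(log_shifts)
    return scaled / scaled.sum(axis=1, keepdims=True)


# Hypothetical patients; columns are (Low, Medium, High) risk
P = np.array([
    [0.50, 0.00, 0.50],  # patient A
    [0.00, 0.60, 0.40],  # patient B
])
log_shifts = np.array([1.0, -2.0, 0.0])  # hypothetical per-class shifts

Q = shift_and_renormalize(P, log_shifts)
print("P(High) before:", P[:, 2])           # A ranks above B
print("P(High) after: ", Q[:, 2].round(3))  # A now ranks below B
```

Patient A's High Risk probability is diluted by the boosted Low Risk class, while B's is inflated because the Medium Risk class shrinks. With only two classes, the renormalized probability is a monotone function of the original one, so this reversal cannot occur.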

This example demonstrates this problem using the Wisconsin Breast Cancer dataset, simulating a health survey scenario where:
1. A diagnostic model is trained on hospital data (biased sample)
2. We need to calibrate predictions to match true population health statistics
3. **Critical requirement**: Preserve individual patient risk rankings while adjusting marginals

## Medical Survey Context

**Scenario**: Large-scale health screening survey where:
- **Sampling bias**: Hospital training data over-represents high-risk patients  
- **Population matching**: Need to match Census health demographics
- **Multiple risk categories**: Low, Medium, High risk (3+ classes)
- **Ranking preservation**: Individual likelihood orderings must be maintained

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss, roc_auc_score
from scipy.stats import spearmanr
import pandas as pd

# Import the rank-preserving calibration package
from rank_preserving_calibration import (
    calibrate_dykstra, calibrate_admm,  # Two algorithms
    feasibility_metrics, isotonic_metrics, distance_metrics,  # Rich metrics
    sharpness_metrics, classwise_ece  # Advanced metrics
)

# Set style for publication-quality plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
np.random.seed(42)

print("üè• MEDICAL SURVEY CALIBRATION TOOLKIT LOADED")
print("Focus: Rank-preserving multinomial calibration for health surveys")
print("Package features: Dykstra, ADMM, nearly-isotonic, comprehensive metrics")

## Health Survey Data with Multinomial Risk Categories

We'll create a synthetic health survey scenario with 3 risk categories (Low, Medium, High) to demonstrate the core multinomial calibration problem.

In [None]:
print("üè• CREATING BIASED HEALTH SURVEY DATA")
print("="*50)

# Load breast cancer data and create multinomial risk categories  
data = load_breast_cancer()
X, y_binary = data.data, data.target

# Convert to multinomial: Create 3 risk categories based on features
# This simulates a health survey where we classify patients into risk levels
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Create synthetic 3-class problem: Low, Medium, High risk
# Use feature combinations to create realistic risk stratification
risk_score = (
    X_scaled[:, 0] * 0.3 +  # mean radius
    X_scaled[:, 7] * 0.3 +  # mean concave points
    X_scaled[:, 20] * 0.2 + # worst radius
    X_scaled[:, 27] * 0.2   # worst concave points
)

# Convert to 3-class labels based on percentiles (hospital bias)
# Hospital data biased toward high-risk patients
percentiles = [40, 75]  # Biased split: more high-risk
y_multinomial = np.digitize(risk_score, np.percentile(risk_score, percentiles))

# Create labels
risk_labels = ['Low Risk', 'Medium Risk', 'High Risk']
print(f"Dataset shape: {X.shape}")
print(f"Risk categories: {risk_labels}")

# Split data - hospital training sample  
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y_multinomial, test_size=0.3, random_state=42, stratify=y_multinomial
)

# Show hospital sample bias
hospital_distribution = np.bincount(y_train) / len(y_train)
print(f"\nHOSPITAL TRAINING SAMPLE (biased):")
for i, (label, pct) in enumerate(zip(risk_labels, hospital_distribution)):
    print(f"  {label}: {np.sum(y_train == i):,} ({pct:.1%})")

print(f"\nTest sample size: {len(y_test)}")
print(f"Features: {X.shape[1]} clinical measurements")
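
The percentile cut points fix the hospital mix by construction. A quick check on synthetic, normally distributed scores (an illustrative stand-in for `risk_score`, not the real data) shows the resulting class shares:

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=1000)                  # stand-in for risk_score
cuts = np.percentile(scores, [40, 75])          # same percentiles as above
labels = np.digitize(scores, cuts)              # 0 = Low, 1 = Medium, 2 = High
shares = np.bincount(labels, minlength=3) / len(labels)
print(shares)                                   # Low / Medium / High shares
```

Because the split above is stratified on these labels, the training and test sets both keep roughly this 40/35/25 mix. The test labels therefore reflect the hospital distribution, not the population one.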

## Model Training and Survey Bias

In [None]:
# Train multinomial classifier on biased hospital data
print("ü§ñ TRAINING HEALTH RISK CLASSIFIER")
print("="*40)

# Train Random Forest for 3-class classification
model = RandomForestClassifier(
    n_estimators=100, 
    max_depth=10, 
    random_state=42,
    class_weight='balanced'
)
model.fit(X_train, y_train)

# Get predictions and probabilities on test set
y_pred = model.predict(X_test)
y_proba = model.predict_proba(X_test)  # Shape: (n_test, 3)

# Model performance
from sklearn.metrics import accuracy_score, classification_report
accuracy = accuracy_score(y_test, y_pred)
print(f"Classifier accuracy: {accuracy:.3f}")
print(f"Test samples: {len(y_test)}")

# Current model marginals (reflects hospital bias)
current_marginals = np.mean(y_proba, axis=0)
print(f"\nMODEL PREDICTED DISTRIBUTION (hospital bias):")
for i, (label, marginal) in enumerate(zip(risk_labels, current_marginals)):
    print(f"  {label}: {marginal:.3f} ({marginal*100:.1f}%)")

# Show the five most important features (impurity-based importances)
top_features = np.argsort(model.feature_importances_)[::-1][:5]
print("\nTop predictive features:")
for idx in top_features:
    print(f"  {data.feature_names[idx]}: {model.feature_importances_[idx]:.3f}")

## Population Health Statistics and Calibration Challenge

**Key Problem**: Our hospital-trained model must be calibrated to match true population health demographics from Census/health survey data.

In [None]:
print("üìä POPULATION HEALTH TARGETS")
print("="*40)

# True population health distribution (from Census/health surveys)
# General population has different risk distribution than hospital sample
population_health_distribution = np.array([
    0.60,   # Low Risk: Higher in general population
    0.30,   # Medium Risk: Moderate 
    0.10    # High Risk: Much lower than hospital sample
])

print("TARGET POPULATION HEALTH DISTRIBUTION (Census data):")
for i, (label, target_pct) in enumerate(zip(risk_labels, population_health_distribution)):
    current_pct = current_marginals[i]
    change = target_pct - current_pct
    direction = "↑" if change > 0 else "↓" if change < 0 else "→"
    print(f"  {label}: {target_pct:.1%} (change: {change:+.1%} {direction})")

print(f"\nüéØ MULTINOMIAL CALIBRATION CHALLENGE:")
print(f"   ‚Ä¢ Hospital model biased toward high-risk patients")
print(f"   ‚Ä¢ Need to match Census population health marginals")  
print(f"   ‚Ä¢ Must preserve individual patient risk rankings")
print(f"   ‚Ä¢ Simple logit shifts will scramble patient orderings")

# Calculate target marginals for calibration
n_test_samples = len(y_test)
target_marginals = population_health_distribution * n_test_samples

print(f"\nCALIBRATION PARAMETERS:")
print(f"  Test samples: {n_test_samples}")
print(f"  Target marginals: {target_marginals}")
print(f"  Sum check: {np.sum(target_marginals):.1f} (should equal {n_test_samples})")

print(f"\nüö® WHY RANKING PRESERVATION IS CRITICAL:")
ranking_importance = [
    "Individual patient triage depends on relative risk rankings",
    "Resource allocation requires preserved within-category orderings", 
    "Treatment priority decisions based on individual likelihood rankings",
    "Health economics models depend on maintained patient risk orderings"
]

for importance in ranking_importance:
    print(f"   • {importance}")
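
The targets fix the column sums, while each patient's probabilities must still sum to one. The helper below is hypothetical and not part of the package; it checks both constraints. The worked numbers use the 171-patient test set (30% of 569 samples, with `train_test_split` rounding the test size up).

```python
import numpy as np


def check_feasibility(Q, M, atol=1e-6):
    """True if every row of Q sums to 1 and the column sums of Q equal M."""
    rows_ok = np.allclose(Q.sum(axis=1), 1.0, atol=atol)
    cols_ok = np.allclose(Q.sum(axis=0), M, atol=atol)
    return bool(rows_ok and cols_ok)


n = 171                                   # test-set size used above
shares = np.array([0.60, 0.30, 0.10])     # population targets
M = shares * n                            # column targets: 102.6, 51.3, 17.1
print(M, M.sum())

# Giving every patient the population shares is feasible,
# but it discards all individual information.
Q_flat = np.tile(shares, (n, 1))
print(check_feasibility(Q_flat, M))       # True

# Valid rows with the wrong column totals are not feasible.
Q_wrong = np.tile([0.40, 0.35, 0.25], (n, 1))
print(check_feasibility(Q_wrong, M))      # False
```

The flat solution shows why calibration is posed as staying as close as possible to the model's probabilities while meeting the totals. Without that objective, the constraints alone admit useless answers.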

## Demonstrating Simple Methods Fail (Ranking Scrambling)

In [None]:
print("⚠️ DEMONSTRATING RANKING SCRAMBLING WITH SIMPLE METHODS")
print("="*65)

# Simple post-hoc method: separate logit shifts per class
def simple_logit_calibration(probs, targets):
    """Apply separate logit shifts - CAN SCRAMBLE RANKINGS with 3+ classes"""
    current_marginals = np.mean(probs, axis=0)
    
    # Calculate logit shifts for each risk category
    logit_shifts = np.log(targets / np.sum(targets)) - np.log(current_marginals)
    
    # Apply shifts
    log_probs = np.log(probs + 1e-12)
    shifted_log_probs = log_probs + logit_shifts[np.newaxis, :]
    
    # Renormalize
    shifted_probs = np.exp(shifted_log_probs)
    calibrated_probs = shifted_probs / np.sum(shifted_probs, axis=1, keepdims=True)
    
    return calibrated_probs

# Apply simple method
y_proba_simple = simple_logit_calibration(y_proba, population_health_distribution)

# Check ranking preservation with the simple method.
# The rankings that matter are WITHIN each class: the ordering of patients
# by P(class k). Count patient pairs whose strict order flips in each column.
def count_rank_reversals(p_orig, p_cal):
    """Number of patient pairs (a, b) with p_orig[a] > p_orig[b] but p_cal[a] < p_cal[b]."""
    before = p_orig[:, None] > p_orig[None, :]
    after = p_cal[:, None] < p_cal[None, :]
    return int(np.sum(before & after))

print("RANKING PRESERVATION ANALYSIS - SIMPLE LOGIT METHOD:")
simple_reversals = []
for k, label in enumerate(risk_labels):
    corr, _ = spearmanr(y_proba[:, k], y_proba_simple[:, k])
    n_rev = count_rank_reversals(y_proba[:, k], y_proba_simple[:, k])
    simple_reversals.append(n_rev)
    print(f"  {label}: Spearman across patients = {corr:.4f}, reversed pairs = {n_rev}")
scrambled_simple = int(np.sum(simple_reversals))

# Check marginal accuracy (a single shift need not hit the targets exactly)
simple_achieved = np.mean(y_proba_simple, axis=0)
simple_marginal_error = np.max(np.abs(simple_achieved - population_health_distribution))
print(f"  Maximum marginal error: {simple_marginal_error:.4f}")

print(f"\n❌ PROBLEMS WITH SIMPLE LOGIT METHOD:")
print(f"   • Reverses the order of {scrambled_simple} patient pairs within risk categories")
print(f"   • Patient A can be more likely High Risk than B before calibration")
print(f"   • ...but less likely High Risk than B after calibration")
print(f"   • Violates clinical triage and priority principles")
print(f"   • Makes individual patient risk assessment unreliable")

# Show concrete pairs of patients whose High Risk ordering flipped
k = 2  # High Risk column
p_before, p_after = y_proba[:, k], y_proba_simple[:, k]
flipped = np.argwhere((p_before[:, None] > p_before[None, :]) & (p_after[:, None] < p_after[None, :]))
if len(flipped) > 0:
    print(f"\n🔍 EXAMPLES OF RANKING SCRAMBLING ({risk_labels[k]}):")
    for a, b in flipped[:3]:
        print(f"  Patients {a} vs {b}: before {p_before[a]:.3f} > {p_before[b]:.3f}, "
              f"after {p_after[a]:.3f} < {p_after[b]:.3f}")
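
The single shift above only approximates the targets. The introduction also mentions raking. One iterative variant, sketched below on synthetic probabilities, repeats the shift until the column means match to high precision. The function name is hypothetical, and this is a generic iterative-proportional-fitting loop rather than any specific survey package's raking routine. Every step still multiplies each class by one factor and renormalizes the rows, which is the same transformation that can reorder patients within a class.

```python
import numpy as np


def iterative_logit_shift(probs, target_shares, n_iter=1000, tol=1e-10):
    """Repeat per-class log shifts plus row renormalization until column means match."""
    log_shift = np.zeros(probs.shape[1])
    for _ in range(n_iter):
        q = probs * np.exp(log_shift)
        q /= q.sum(axis=1, keepdims=True)
        achieved = q.mean(axis=0)
        if np.max(np.abs(achieved - target_shares)) < tol:
            break
        # Move each class's shift by the remaining log-ratio gap
        log_shift += np.log(target_shares) - np.log(achieved)
    return q, log_shift


rng = np.random.default_rng(0)
P_demo = rng.dirichlet(2 * np.ones(3), size=200)  # synthetic, strictly positive
targets = np.array([0.60, 0.30, 0.10])
Q_demo, shifts = iterative_logit_shift(P_demo, targets)
print("achieved shares:", Q_demo.mean(axis=0).round(4))
print("final log shifts:", shifts.round(3))
```

Exact marginals do not change the form of the transformation. A within-class reversal count on the raked matrix can therefore still be nonzero.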

## Algorithm Comparison: Dykstra vs ADMM

In [None]:
print("üî¨ ALGORITHM COMPARISON: DYKSTRA vs ADMM")
print("="*50)

# Method 1: Dykstra's alternating projections (recommended default)
print("1️⃣ DYKSTRA'S ALTERNATING PROJECTIONS:")
result_dykstra = calibrate_dykstra(
    P=y_proba,
    M=target_marginals,
    max_iters=2000,
    tol=1e-7,
    verbose=True
)

y_proba_dykstra = result_dykstra.Q
print(f"\n   Converged: {result_dykstra.converged}")
print(f"   Iterations: {result_dykstra.iterations}")
print(f"   Final objective: {result_dykstra.objective:.2e}")
print(f"   Max column error: {result_dykstra.max_col_error:.2e}")
print(f"   Max rank violation: {result_dykstra.max_rank_violation:.2e}")

# Method 2: ADMM optimization (alternative with convergence history)
print(f"\n2️⃣ ADMM OPTIMIZATION:")
result_admm = calibrate_admm(
    P=y_proba,
    M=target_marginals,
    max_iters=1000,
    tol=1e-6,
    verbose=True
)

y_proba_admm = result_admm.Q
print(f"\n   Converged: {result_admm.converged}")
print(f"   Iterations: {result_admm.iterations}")
print(f"   Final objective: {result_admm.objective:.2e}")
print(f"   Max column error: {result_admm.max_col_error:.2e}")
print(f"   Max rank violation: {result_admm.max_rank_violation:.2e}")

# Comparison of results
print(f"\nüìä ALGORITHM COMPARISON:")
print(f"{'Metric':<20} {'Dykstra':<15} {'ADMM':<15}")
print("-" * 50)
print(f"{'Converged':<20} {result_dykstra.converged:<15} {result_admm.converged:<15}")
print(f"{'Iterations':<20} {result_dykstra.iterations:<15} {result_admm.iterations:<15}")
print(f"{'Final objective':<20} {result_dykstra.objective:<15.2e} {result_admm.objective:<15.2e}")
print(f"{'Max col error':<20} {result_dykstra.max_col_error:<15.2e} {result_admm.max_col_error:<15.2e}")

# Check if both methods give same result
prob_difference = np.max(np.abs(y_proba_dykstra - y_proba_admm))
print(f"\nMaximum probability difference: {prob_difference:.2e}")
print(f"Methods agree: {'Yes' if prob_difference < 1e-6 else 'No'}")

print(f"\nüéØ WHEN TO USE EACH ALGORITHM:")
print(f"   ‚Ä¢ Dykstra: Default choice, exact projections, reliable convergence")
print(f"   ‚Ä¢ ADMM: When you need convergence diagnostics, experimental features")
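
The sketch below illustrates what "alternating projections" means here. It is not the package's implementation, and all function names are hypothetical. It alternates between two convex sets:

1. Every row lies on the probability simplex.
2. Every column is non-decreasing in the order of the original column and sums to its target.

Dykstra's correction terms make the iterates converge to the closest point of the intersection, measured in squared error, rather than to an arbitrary feasible point. The column step relies on the fact that adding a constant keeps an isotonic fit isotonic. Pool-adjacent-violators followed by a shift is therefore the exact projection onto set 2. Tied original scores are simply kept in a fixed, stable order, which is a simplification.

```python
import numpy as np


def project_rows_to_simplex(Y):
    """Euclidean projection of each row onto {q >= 0, sum(q) = 1} (sort-based)."""
    n, k = Y.shape
    U = -np.sort(-Y, axis=1)                       # rows sorted descending
    css = np.cumsum(U, axis=1) - 1.0
    ranks = np.arange(1, k + 1)
    rho = np.sum(U - css / ranks > 0, axis=1)      # size of the positive support
    theta = css[np.arange(n), rho - 1] / rho
    return np.maximum(Y - theta[:, None], 0.0)


def pav(y):
    """Pool-adjacent-violators: least-squares non-decreasing fit to y."""
    means, sizes = [], []
    for v in y:
        means.append(float(v))
        sizes.append(1)
        while len(means) > 1 and means[-2] > means[-1]:
            m2, s2 = means.pop(), sizes.pop()
            m1, s1 = means.pop(), sizes.pop()
            means.append((m1 * s1 + m2 * s2) / (s1 + s2))
            sizes.append(s1 + s2)
    return np.repeat(means, sizes)


def project_columns(Y, order, M):
    """Per column: isotonic in the original order, then shift so the sum equals M[j]."""
    Z = np.empty_like(Y)
    n = Y.shape[0]
    for j in range(Y.shape[1]):
        idx = order[:, j]
        fit = pav(Y[idx, j])
        Z[idx, j] = fit + (M[j] - fit.sum()) / n
    return Z


def dykstra_sketch(P, M, n_iter=5000, tol=1e-10):
    """Dykstra's alternating projections between the row and column constraint sets."""
    order = np.argsort(P, axis=0, kind="stable")   # original within-class rankings
    X = P.copy()
    p_corr = np.zeros_like(P)                       # correction for the row set
    q_corr = np.zeros_like(P)                       # correction for the column set
    for _ in range(n_iter):
        Y = project_rows_to_simplex(X + p_corr)
        p_corr = X + p_corr - Y
        X_new = project_columns(Y + q_corr, order, M)
        q_corr = Y + q_corr - X_new
        converged = np.max(np.abs(X_new - X)) < tol
        X = X_new
        if converged:
            break
    return X


rng = np.random.default_rng(1)
P_demo = rng.dirichlet(np.ones(3), size=50)        # synthetic predictions
M_demo = np.array([0.60, 0.30, 0.10]) * 50         # column targets
Q_demo = dykstra_sketch(P_demo, M_demo)
print("column sums:", Q_demo.sum(axis=0).round(4))
print("max row-sum error:", np.abs(Q_demo.sum(axis=1) - 1).max())
```

The returned matrix comes from the column step. Its columns are therefore exactly isotonic and hit the targets, while the row sums approach one as the iterations converge.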

## Nearly Isotonic Constraints (Relaxed Rank Preservation)

In [None]:
print("üîÑ NEARLY ISOTONIC CALIBRATION (RELAXED CONSTRAINTS)")  
print("="*60)

print("Sometimes strict rank preservation is too restrictive...")
print("Nearly-isotonic allows small ranking violations for better fit")

# Method 1: Epsilon-slack approach (Dykstra)
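print("\n1️⃣ EPSILON-SLACK APPROACH:")
print("   Allows z[i+1] >= z[i] - eps instead of the exact isotonic constraint z[i+1] >= z[i]")
print("   (z: a calibrated column, ordered by the original predicted probability)")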

nearly_epsilon = {"mode": "epsilon", "eps": 0.05}
result_nearly_eps = calibrate_dykstra(
    P=y_proba,
    M=target_marginals,
    nearly=nearly_epsilon,
    max_iters=2000,
    tol=1e-7,
    verbose=True
)

y_proba_nearly_eps = result_nearly_eps.Q
print(f"\n   Converged: {result_nearly_eps.converged}")
print(f"   Iterations: {result_nearly_eps.iterations}")
print(f"   Max rank violation: {result_nearly_eps.max_rank_violation:.2e}")

# Method 2: Lambda-penalty approach (ADMM)
print("\n2️⃣ LAMBDA-PENALTY APPROACH (Experimental):")
print("   Penalizes ranking violations with λ * sum(violations)")

nearly_lambda = {"mode": "lambda", "lam": 1.0}
result_nearly_lam = calibrate_admm(
    P=y_proba,
    M=target_marginals,
    nearly=nearly_lambda,
    max_iters=1000,
    tol=1e-6,
    verbose=True
)

y_proba_nearly_lam = result_nearly_lam.Q
print(f"\n   Converged: {result_nearly_lam.converged}")
print(f"   Iterations: {result_nearly_lam.iterations}")
print(f"   Max rank violation: {result_nearly_lam.max_rank_violation:.2e}")

# Compare ranking preservation within each class (ordering of patients by P(class k))
def check_ranking_preservation(P_orig, P_cal, method_name):
    """Report per-class Spearman correlation and reversed patient pairs."""
    print(f"\n{method_name} ranking preservation:")
    correlations, total_reversals = [], 0
    for k, label in enumerate(risk_labels):
        corr, _ = spearmanr(P_orig[:, k], P_cal[:, k])
        before = P_orig[:, k][:, None] > P_orig[:, k][None, :]
        after = P_cal[:, k][:, None] < P_cal[:, k][None, :]
        n_rev = int(np.sum(before & after))
        correlations.append(corr)
        total_reversals += n_rev
        print(f"   {label}: Spearman = {corr:.6f}, reversed pairs = {n_rev}")
    print(f"   Total reversed pairs: {total_reversals}")
    
    return float(np.mean(correlations))

# Check all methods
strict_corr = check_ranking_preservation(y_proba, y_proba_dykstra, "Strict isotonic")
eps_corr = check_ranking_preservation(y_proba, y_proba_nearly_eps, "Epsilon-slack")
lam_corr = check_ranking_preservation(y_proba, y_proba_nearly_lam, "Lambda-penalty")

print(f"\nüéØ WHEN TO USE NEARLY-ISOTONIC:")
print(f"   ‚Ä¢ Strict isotonic too restrictive for your data")
print(f"   ‚Ä¢ Small ranking violations acceptable for better marginal fit")  
print(f"   ‚Ä¢ Epsilon-slack: Maintains convexity, theoretical guarantees")
print(f"   ‚Ä¢ Lambda-penalty: Experimental, may need parameter tuning")
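
The sketch below assumes the epsilon-slack constraint is exactly the one printed above; the package's internals may differ. Under that assumption, projecting a single ordered column onto the relaxed set reduces to ordinary isotonic regression. With w[i] = z[i] + i·eps, the condition z[i+1] >= z[i] - eps becomes w[i+1] >= w[i]. Because the substitution is a translation, fitting pool-adjacent-violators to y + i·eps and subtracting the offsets gives the exact Euclidean projection. The helper names are hypothetical.

```python
import numpy as np


def pav(y):
    """Pool-adjacent-violators: least-squares non-decreasing fit to y."""
    means, sizes = [], []
    for v in y:
        means.append(float(v))
        sizes.append(1)
        while len(means) > 1 and means[-2] > means[-1]:
            m2, s2 = means.pop(), sizes.pop()
            m1, s1 = means.pop(), sizes.pop()
            means.append((m1 * s1 + m2 * s2) / (s1 + s2))
            sizes.append(s1 + s2)
    return np.repeat(means, sizes)


def project_nearly_isotonic(y, eps):
    """Closest z to y (squared error) with z[i+1] >= z[i] - eps for all i."""
    offsets = eps * np.arange(len(y))
    return pav(np.asarray(y, dtype=float) + offsets) - offsets


y = np.array([0.30, 0.28, 0.45, 0.20])       # a column in original-score order
print(project_nearly_isotonic(y, eps=0.0))   # strict isotonic fit
print(project_nearly_isotonic(y, eps=0.05))  # the small 0.30 -> 0.28 dip survives
```

With eps = 0 this is plain isotonic regression. Larger eps leaves more of the original column untouched.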

## Performance Metrics Comparison

Let's quantify how calibration changes the probability estimates while checking that discrimination is preserved.
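
Discrimination can survive calibration because ROC AUC depends only on how the scores order patients. Any strictly increasing transformation of one class's scores therefore leaves its one-vs-rest AUC unchanged. The check below uses synthetic labels and scores, not the notebook's data.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=300)             # synthetic binary labels
scores = rng.uniform(size=300) + 0.5 * y     # synthetic, informative scores
transformed = np.sqrt(scores) / 2            # strictly increasing map

auc_raw = roc_auc_score(y, scores)
auc_transformed = roc_auc_score(y, transformed)
print(auc_raw, auc_transformed)
```

An isotonic fit can pool neighbouring scores into ties, and AUC counts a tied positive/negative pair as half-correct. Small AUC changes are therefore possible even when no pair is reversed.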

In [None]:
# Calculate various performance metrics
from sklearn.metrics import brier_score_loss, log_loss

# Evaluate the High Risk class one-vs-rest: positives are patients whose true
# category is High Risk (label 2). Calibrated scores come from the Dykstra fit;
# clipping guards against tiny solver round-off outside [0, 1].
y_high_risk = (y_test == 2).astype(int)
malignant_probs_original = y_proba[:, 2]                           # P(High Risk), hospital model
malignant_probs_calibrated = np.clip(y_proba_dykstra[:, 2], 0, 1)  # P(High Risk), calibrated
target_prevalence = population_health_distribution[2]              # population High Risk share

# Discrimination metrics (should be unchanged, up to any ties introduced by calibration)
auc_original = roc_auc_score(y_high_risk, malignant_probs_original)
auc_calibrated = roc_auc_score(y_high_risk, malignant_probs_calibrated)

# Calibration metrics, measured against the hospital test labels
brier_original = brier_score_loss(y_high_risk, malignant_probs_original)
brier_calibrated = brier_score_loss(y_high_risk, malignant_probs_calibrated)

logloss_original = log_loss(y_high_risk, malignant_probs_original)
logloss_calibrated = log_loss(y_high_risk, malignant_probs_calibrated)

# Calibration error (Expected Calibration Error, equal-width bins)
def expected_calibration_error(y_true, y_prob, n_bins=10):
    y_true, y_prob = np.asarray(y_true), np.asarray(y_prob)
    # Bins are (lower, upper]; probabilities of exactly 0 go into the first bin
    edges = np.linspace(0, 1, n_bins + 1)
    bin_ids = np.clip(np.digitize(y_prob, edges[1:-1], right=True), 0, n_bins - 1)
    
    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        if in_bin.any():
            gap = np.abs(y_prob[in_bin].mean() - y_true[in_bin].mean())
            ece += gap * in_bin.mean()
    
    return ece

ece_original = expected_calibration_error(y_high_risk, malignant_probs_original)
ece_calibrated = expected_calibration_error(y_high_risk, malignant_probs_calibrated)

# Create results summary
results_df = pd.DataFrame({
    'Metric': ['AUC-ROC', 'Brier Score', 'Log Loss', 'ECE', 'Mean Prediction'],
    'Original': [auc_original, brier_original, logloss_original, ece_original, malignant_probs_original.mean()],
    'Calibrated': [auc_calibrated, brier_calibrated, logloss_calibrated, ece_calibrated, malignant_probs_calibrated.mean()],
    'Target': ['-', '-', '-', '0', f"{target_prevalence:.4f}"]
})

# Changes (calibrated minus original)
results_df['Change'] = results_df['Calibrated'] - results_df['Original']

print("Performance Metrics Comparison (High Risk vs. rest, hospital test labels):")
print("=" * 80)
print(f"{'Metric':<16}{'Original':>10}{'Calibrated':>12}{'Change':>10}{'Target':>10}")
for _, row in results_df.iterrows():
    print(f"{row['Metric']:<16}{row['Original']:>10.4f}{row['Calibrated']:>12.4f}"
          f"{row['Change']:>+10.4f}{row['Target']:>10}")

print("\nKey Observations:")
print(f"• AUC-ROC maintained: {abs(auc_calibrated - auc_original) < 0.001} (Δ={auc_calibrated-auc_original:.6f})")
print(f"• ECE vs. hospital test labels: {ece_original:.4f} → {ece_calibrated:.4f}")
print("  (test labels follow the hospital mix, so this is not the population calibration error)")
print(f"• Mean prediction corrected: {malignant_probs_original.mean():.3f} → {malignant_probs_calibrated.mean():.3f} (target: {target_prevalence:.3f})")
print(f"• Brier score {'improved' if brier_calibrated < brier_original else 'changed'}: {brier_original:.4f} → {brier_calibrated:.4f}")

## Clinical Decision Analysis

Let's analyze how calibration affects clinical decision making at different risk thresholds.

In [None]:
# Clinical decision analysis
def analyze_clinical_decisions(y_true, y_prob_orig, y_prob_cal, thresholds):
    """Analyze clinical decisions at different thresholds."""
    results = []
    
    for thresh in thresholds:
        # Original model decisions
        decisions_orig = y_prob_orig >= thresh
        tp_orig = np.sum((decisions_orig == 1) & (y_true == 1))
        fp_orig = np.sum((decisions_orig == 1) & (y_true == 0))
        tn_orig = np.sum((decisions_orig == 0) & (y_true == 0))
        fn_orig = np.sum((decisions_orig == 0) & (y_true == 1))
        
        # Calibrated model decisions
        decisions_cal = y_prob_cal >= thresh
        tp_cal = np.sum((decisions_cal == 1) & (y_true == 1))
        fp_cal = np.sum((decisions_cal == 1) & (y_true == 0))
        tn_cal = np.sum((decisions_cal == 0) & (y_true == 0))
        fn_cal = np.sum((decisions_cal == 0) & (y_true == 1))
        
        # Calculate metrics
        sens_orig = tp_orig / (tp_orig + fn_orig) if (tp_orig + fn_orig) > 0 else 0
        spec_orig = tn_orig / (tn_orig + fp_orig) if (tn_orig + fp_orig) > 0 else 0
        ppv_orig = tp_orig / (tp_orig + fp_orig) if (tp_orig + fp_orig) > 0 else 0
        
        sens_cal = tp_cal / (tp_cal + fn_cal) if (tp_cal + fn_cal) > 0 else 0
        spec_cal = tn_cal / (tn_cal + fp_cal) if (tn_cal + fp_cal) > 0 else 0
        ppv_cal = tp_cal / (tp_cal + fp_cal) if (tp_cal + fp_cal) > 0 else 0
        
        # Decision changes
        decision_changes = np.sum(decisions_orig != decisions_cal)
        
        results.append({
            'Threshold': thresh,
            'Sensitivity_Orig': sens_orig,
            'Sensitivity_Cal': sens_cal,
            'Specificity_Orig': spec_orig,
            'Specificity_Cal': spec_cal,
            'PPV_Orig': ppv_orig,
            'PPV_Cal': ppv_cal,
            'Decision_Changes': decision_changes,
            'Patients_Flagged_Orig': np.sum(decisions_orig),
            'Patients_Flagged_Cal': np.sum(decisions_cal)
        })
    
    return pd.DataFrame(results)

# Analyze at key clinical thresholds (High Risk vs. rest)
clinical_thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
decision_analysis = analyze_clinical_decisions(
    (y_test == 2).astype(int), malignant_probs_original, malignant_probs_calibrated, clinical_thresholds
)

print("Clinical Decision Analysis:")
print("=" * 100)
print(f"{'Threshold':<10} {'Sensitivity':<20} {'Specificity':<20} {'PPV':<20} {'Changes':<10}")
print(f"{'':10} {'Orig':>8} {'Cal':>8} {'Δ':>6} {'Orig':>8} {'Cal':>8} {'Δ':>6} {'Orig':>8} {'Cal':>8} {'Δ':>6} {'N':>6}")
print("-" * 100)

for _, row in decision_analysis.iterrows():
    thresh = row['Threshold']
    sens_delta = row['Sensitivity_Cal'] - row['Sensitivity_Orig']
    spec_delta = row['Specificity_Cal'] - row['Specificity_Orig']
    ppv_delta = row['PPV_Cal'] - row['PPV_Orig']
    
    print(f"{thresh:<10.1f} {row['Sensitivity_Orig']:>8.3f} {row['Sensitivity_Cal']:>8.3f} {sens_delta:>+6.3f} "
          f"{row['Specificity_Orig']:>8.3f} {row['Specificity_Cal']:>8.3f} {spec_delta:>+6.3f} "
          f"{row['PPV_Orig']:>8.3f} {row['PPV_Cal']:>8.3f} {ppv_delta:>+6.3f} {row['Decision_Changes']:>6.0f}")

print(f"\nTotal patients: {len(y_test)}")
print(f"Actual High Risk cases: {np.sum(y_test == 2)}")
print(f"Target High Risk share: {target_prevalence:.1%}")
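
The counting in `analyze_clinical_decisions` can also be written with scikit-learn's `confusion_matrix`. The helper below is hypothetical and handles a single threshold. It is not used elsewhere in this notebook.

```python
import numpy as np
from sklearn.metrics import confusion_matrix


def threshold_metrics(y_true, y_prob, thresh):
    """Sensitivity, specificity and PPV of the rule 'flag if y_prob >= thresh'."""
    y_flag = (np.asarray(y_prob) >= thresh).astype(int)
    # With labels=[0, 1] the matrix is [[tn, fp], [fn, tp]]
    tn, fp, fn, tp = confusion_matrix(y_true, y_flag, labels=[0, 1]).ravel()
    sens = tp / (tp + fn) if tp + fn else 0.0
    spec = tn / (tn + fp) if tn + fp else 0.0
    ppv = tp / (tp + fp) if tp + fp else 0.0
    return sens, spec, ppv


y_true_demo = np.array([0, 0, 1, 1, 1])            # made-up labels
y_prob_demo = np.array([0.1, 0.2, 0.7, 0.8, 0.3])  # made-up scores
print(threshold_metrics(y_true_demo, y_prob_demo, 0.5))
```

Passing `labels=[0, 1]` keeps the matrix 2×2 even when a threshold flags nobody, so the unpacking never fails.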

## Summary and Clinical Implications

This example demonstrates the value of rank-preserving calibration in medical applications:

In [None]:
print("CLINICAL SUMMARY: Rank-Preserving Calibration for Medical Diagnosis")
print("=" * 80)

print("\nüéØ SCENARIO:")
print(f"   Model trained on population with {y_train.mean():.1%} disease prevalence")
print(f"   Deployed in high-risk population with {target_prevalence:.1%} prevalence")

print("\nüìä KEY RESULTS:")
print(f"   ‚úì Maintained perfect patient ranking (AUC: {auc_original:.3f} ‚Üí {auc_calibrated:.3f})")
print(f"   ‚úì Corrected prevalence estimate ({malignant_probs_original.mean():.3f} ‚Üí {malignant_probs_calibrated.mean():.3f})")
print(f"   ‚úì Improved calibration (ECE: {ece_original:.4f} ‚Üí {ece_calibrated:.4f})")
print(f"   ‚úì Better probability estimates for clinical decision making")

print("\nüè• CLINICAL BENEFITS:")
print("   ‚Ä¢ Maintains relative risk ranking of patients")
print("   ‚Ä¢ Provides accurate absolute risk estimates")
print("   ‚Ä¢ Enables proper resource allocation in new populations")
print("   ‚Ä¢ Supports evidence-based clinical decision thresholds")

print("\n‚ö†Ô∏è  IMPORTANT CONSIDERATIONS:")
print("   ‚Ä¢ Requires reliable estimates of target population prevalence")
print("   ‚Ä¢ Should be validated on representative test data")
print("   ‚Ä¢ Consider confidence intervals for prevalence estimates")
print("   ‚Ä¢ Monitor performance in production deployment")

print("\nüìà WHEN TO USE RANK-PRESERVING CALIBRATION:")
print("   ‚Ä¢ Deploying models across different populations")
print("   ‚Ä¢ When both ranking and absolute probabilities matter")
print("   ‚Ä¢ Resource allocation based on risk scores")
print("   ‚Ä¢ Clinical decision support systems")

# Show specific example of a clinical decision
print("\nüí° EXAMPLE CLINICAL DECISION (30% threshold):")
thresh_example = 0.3
orig_flagged = np.sum(malignant_probs_original >= thresh_example)
cal_flagged = np.sum(malignant_probs_calibrated >= thresh_example)
changes = np.sum((malignant_probs_original >= thresh_example) != (malignant_probs_calibrated >= thresh_example))

print(f"   Original model: {orig_flagged} patients flagged for biopsy")
print(f"   Calibrated model: {cal_flagged} patients flagged for biopsy")
print(f"   Decision changes: {changes} patients ({100*changes/len(y_test):.1f}%)")
print(f"   ‚Üí Better alignment with {target_prevalence:.0%} prevalence population")
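
Ranking preservation has a direct consequence for threshold rules. Suppose calibrated P(High Risk) is a non-decreasing function of the original score. Then the patients flagged at any threshold are exactly the top-k patients by original score, so calibration changes how many are flagged but not the order in which patients enter the list. The checker below is hypothetical and is shown on made-up scores.

```python
import numpy as np


def is_top_k_set(p_orig, flagged):
    """True if the flagged patients are exactly those at or above some original-score cut."""
    flagged = np.asarray(flagged, dtype=bool)
    if not flagged.any():
        return True
    cut = p_orig[flagged].min()
    return bool(np.array_equal(flagged, p_orig >= cut))


p = np.array([0.10, 0.40, 0.70, 0.90])            # made-up original P(High Risk)
q_monotone = p ** 2                               # order-preserving recalibration
q_scrambled = np.array([0.90, 0.40, 0.70, 0.10])  # order-breaking recalibration

print(is_top_k_set(p, q_monotone >= 0.3))         # True
print(is_top_k_set(p, q_scrambled >= 0.5))        # False
```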

## Next Steps

This example showed rank-preserving calibration for a three-category (multinomial) risk model in a health survey context. The same principles apply to:

- **Multiclass medical diagnosis** (e.g., different types of skin lesions)
- **Finer risk stratification** with more than three risk categories
- **Treatment response prediction** across patient populations
- **Biomarker discovery** with population-specific prevalences

For more examples, see the other notebooks in this series:
- Text classification with sentiment analysis
- Image classification with vision models
- Financial risk assessment
- Survey reweighting applications