# Business Validation for Healthcare Insurance Risk Model

This notebook validates the business value and clinical relevance of the healthcare insurance risk prediction model.

## Validation Areas:
1. **Business KPIs**: Key performance indicators aligned with business objectives
2. **Clinical Relevance**: Medical validity of risk predictions
3. **Regional Equity**: Fair predictions across geographic regions (and between sexes)
4. **Prediction Stability**: Consistency over time
5. **ROI Metrics**: Business impact and cost-benefit analysis
6. **Model Governance**: Champion model status, metrics, and governance tags
7. **Final Summary**: Aggregated pass/warn/fail status across all checks

In [None]:
# Business Validation Notebook for Healthcare Insurance MLOps
import sys
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import Window
import mlflow
from mlflow.tracking import MlflowClient
from datetime import datetime, timedelta
import json

# Notebook parameters: (widget name, default value, label) for each input widget.
for widget_name, widget_default, widget_label in (
    ("catalog", "juan_dev", "Unity Catalog name"),
    ("ml_schema", "healthcare_data", "ML Schema name"),
    ("model_name", "insurance_model", "Model name"),
    ("lookback_days", "30", "Days to look back for trend analysis"),
):
    dbutils.widgets.text(widget_name, widget_default, widget_label)

# Read the widget values back; the lookback window is converted to an int.
catalog = dbutils.widgets.get("catalog")
ml_schema = dbutils.widgets.get("ml_schema")
model_name = dbutils.widgets.get("model_name")
lookback_days = int(dbutils.widgets.get("lookback_days"))

# Use the "databricks-uc" registry URI, then create the tracking client.
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

# Report header, emitted as a single block of lines.
banner_rule = "=" * 80
report_header = [
    banner_rule,
    "BUSINESS VALIDATION REPORT",
    banner_rule,
    f"Model: {catalog}.{ml_schema}.{model_name}",
    f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
    f"Lookback period: {lookback_days} days",
    banner_rule,
]
print("\n".join(report_header))

## 1. Business KPI Validation

Validate that the model meets key business objectives and performance indicators.

In [None]:
# Load predictions data and validate the headline business KPIs for the lookback window.
#
# NOTE: `from pyspark.sql.functions import *` is in effect, so avg/stddev/min/max/sum/count
# inside .agg() below are Spark column functions, not the Python builtins.
import traceback

predictions_table = f"{catalog}.{ml_schema}.ml_patient_predictions"

kpi_results = {}

# Filled in when the notebook must stop after this cell. The actual
# dbutils.notebook.exit() call is made AFTER the try/except so that the broad
# `except Exception` below can never intercept it and re-report a NO_DATA exit
# as an ERROR.
# NOTE(review): this assumes dbutils.notebook.exit() stops the run by raising;
# confirm against the Databricks documentation. The restructure is harmless either way.
kpi_exit_payload = None

try:
    predictions_df = spark.table(predictions_table)

    # Filter for recent predictions
    cutoff_date = (datetime.now() - timedelta(days=lookback_days)).strftime('%Y-%m-%d')
    recent_predictions = predictions_df.filter(
        col("prediction_timestamp") >= cutoff_date
    )

    total_predictions = recent_predictions.count()
    print(f"\nüìä KPI Summary (Last {lookback_days} days)")
    print(f"-" * 80)
    print(f"Total predictions: {total_predictions:,}")

    if total_predictions == 0:
        print(f"\n‚ö†Ô∏è  No predictions found in the last {lookback_days} days")
        kpi_exit_payload = {"status": "NO_DATA", "message": "No recent predictions"}
    else:
        # Calculate key business metrics (using actual column names from schema)
        business_metrics = recent_predictions.agg(
            avg("adjusted_prediction").alias("avg_risk_score"),
            stddev("adjusted_prediction").alias("risk_score_std"),
            min("adjusted_prediction").alias("min_risk"),
            max("adjusted_prediction").alias("max_risk"),
            (sum(when(col("high_risk_patient") == True, 1).otherwise(0)) / count("*") * 100).alias("high_risk_pct"),
            avg("prediction_lower_bound").alias("avg_ci_lower"),
            avg("prediction_upper_bound").alias("avg_ci_upper")
        ).collect()[0]

        # The sample standard deviation is NULL when only one prediction is present;
        # show a placeholder instead of crashing the whole cell on formatting.
        if business_metrics.risk_score_std is None:
            risk_std_text = "n/a"
        else:
            risk_std_text = f"{business_metrics.risk_score_std:.2f}"

        print(f"\nRisk Score Statistics:")
        print(f"  Average: {business_metrics.avg_risk_score:.2f}")
        print(f"  Std Dev: {risk_std_text}")
        print(f"  Range: {business_metrics.min_risk:.2f} - {business_metrics.max_risk:.2f}")
        print(f"  High-risk patients: {business_metrics.high_risk_pct:.1f}%")

        # Business KPI Targets
        kpi_targets = {
            "high_risk_pct_target": (5.0, 25.0),  # Target range: 5-25%
            "avg_risk_target": (20.0, 60.0),       # Average risk should be moderate
            "prediction_volume_min": 100           # Minimum predictions for statistical validity
        }

        print(f"\n‚úÖ KPI Target Validation:")

        # High-risk percentage check
        if kpi_targets["high_risk_pct_target"][0] <= business_metrics.high_risk_pct <= kpi_targets["high_risk_pct_target"][1]:
            print(f"  ‚úÖ High-risk %: {business_metrics.high_risk_pct:.1f}% (Target: {kpi_targets['high_risk_pct_target'][0]}-{kpi_targets['high_risk_pct_target'][1]}%)")
            kpi_results["high_risk_pct"] = "PASS"
        else:
            print(f"  ‚ùå High-risk %: {business_metrics.high_risk_pct:.1f}% (Target: {kpi_targets['high_risk_pct_target'][0]}-{kpi_targets['high_risk_pct_target'][1]}%)")
            kpi_results["high_risk_pct"] = "FAIL"

        # Average risk check
        if kpi_targets["avg_risk_target"][0] <= business_metrics.avg_risk_score <= kpi_targets["avg_risk_target"][1]:
            print(f"  ‚úÖ Avg risk: {business_metrics.avg_risk_score:.2f} (Target: {kpi_targets['avg_risk_target'][0]}-{kpi_targets['avg_risk_target'][1]})")
            kpi_results["avg_risk"] = "PASS"
        else:
            print(f"  ‚ùå Avg risk: {business_metrics.avg_risk_score:.2f} (Target: {kpi_targets['avg_risk_target'][0]}-{kpi_targets['avg_risk_target'][1]})")
            kpi_results["avg_risk"] = "FAIL"

        # Prediction volume check
        if total_predictions >= kpi_targets["prediction_volume_min"]:
            print(f"  ‚úÖ Prediction volume: {total_predictions:,} (Minimum: {kpi_targets['prediction_volume_min']:,})")
            kpi_results["volume"] = "PASS"
        else:
            print(f"  ‚ùå Prediction volume: {total_predictions:,} (Minimum: {kpi_targets['prediction_volume_min']:,})")
            kpi_results["volume"] = "FAIL"

except Exception as e:
    print(f"\n‚ùå KPI validation failed: {str(e)}")
    traceback.print_exc()
    kpi_exit_payload = {"status": "ERROR", "message": str(e)}

# Later cells depend on recent_predictions / total_predictions / kpi_results,
# so stop the notebook here when there is no data or the KPI step failed.
if kpi_exit_payload is not None:
    dbutils.notebook.exit(json.dumps(kpi_exit_payload))

## 2. Clinical Relevance Validation

Validate that predictions align with medical knowledge and healthcare best practices.

In [None]:
# Clinical relevance checks on the recent predictions: smoking, age, BMI and the
# risk-category mix. Each check records PASS / WARN / FAIL in clinical_results;
# a check that cannot be evaluated records INSUFFICIENT_DATA instead of silently
# disappearing from the final summary.
import traceback

print(f"\nüè• Clinical Relevance Validation")
print(f"-" * 80)

clinical_results = {}

try:
    # 1. Smoking Impact Validation
    print(f"\n1. Smoking Impact Analysis:")
    smoking_analysis = recent_predictions.groupBy("patient_smoking_status").agg(
        avg("adjusted_prediction").alias("avg_risk"),
        (sum(when(col("high_risk_patient") == True, 1).otherwise(0)) / count("*") * 100).alias("high_risk_pct"),
        count("*").alias("patient_count")
    ).orderBy("patient_smoking_status").collect()

    for row in smoking_analysis:
        print(f"   {row.patient_smoking_status}: Avg Risk={row.avg_risk:.2f}, High-Risk%={row.high_risk_pct:.1f}%, N={row.patient_count:,}")

    # Validate that smokers have higher average risk
    smoker_risks = {row.patient_smoking_status: row.avg_risk for row in smoking_analysis}
    if "yes" in smoker_risks and "no" in smoker_risks:
        if smoker_risks["yes"] > smoker_risks["no"]:
            print(f"   ‚úÖ Smokers have higher average risk (clinically valid)")
            clinical_results["smoking_correlation"] = "PASS"
        else:
            print(f"   ‚ùå Smokers do NOT have higher risk (clinically invalid)")
            clinical_results["smoking_correlation"] = "FAIL"
    else:
        # The comparison needs both groups; record that it could not be made.
        print(f"   Smoking comparison skipped: both 'yes' and 'no' groups are required")
        clinical_results["smoking_correlation"] = "INSUFFICIENT_DATA"

    # 2. Age Correlation Validation
    print(f"\n2. Age Correlation Analysis:")
    # NOTE(review): buckets are compared in the sort order of the category label;
    # this is only an age ordering if the labels sort that way - confirm against the schema.
    age_analysis = recent_predictions.groupBy("patient_age_category").agg(
        avg("adjusted_prediction").alias("avg_risk"),
        count("*").alias("patient_count")
    ).orderBy("patient_age_category").collect()

    # Start from "no previous bucket" rather than 0 so the first bucket is never
    # compared against an arbitrary baseline.
    prev_risk = None
    age_monotonic = True
    for row in age_analysis:
        print(f"   {row.patient_age_category}: Avg Risk={row.avg_risk:.2f}, N={row.patient_count:,}")
        if prev_risk is not None and row.avg_risk < prev_risk:
            age_monotonic = False
        prev_risk = row.avg_risk

    if age_monotonic:
        print(f"   ‚úÖ Age positively correlated with risk (clinically valid)")
        clinical_results["age_correlation"] = "PASS"
    else:
        print(f"   ‚ö†Ô∏è  Age NOT consistently correlated with risk")
        clinical_results["age_correlation"] = "WARN"

    # 3. BMI Impact Validation
    print(f"\n3. BMI Category Impact Analysis:")
    # Derive BMI category from numeric bmi field
    bmi_predictions = recent_predictions.withColumn(
        "bmi_category",
        when(col("bmi") < 18.5, "underweight")
        .when((col("bmi") >= 18.5) & (col("bmi") < 25), "normal")
        .when((col("bmi") >= 25) & (col("bmi") < 30), "overweight")
        .when(col("bmi") >= 30, "obese")
        .otherwise("unknown")
    )

    bmi_analysis = bmi_predictions.groupBy("bmi_category").agg(
        avg("adjusted_prediction").alias("avg_risk"),
        count("*").alias("patient_count")
    ).orderBy("bmi_category").collect()

    for row in bmi_analysis:
        print(f"   {row.bmi_category}: Avg Risk={row.avg_risk:.2f}, N={row.patient_count:,}")

    # Validate that obese patients have higher risk than normal
    bmi_risks = {row.bmi_category: row.avg_risk for row in bmi_analysis}
    if "obese" in bmi_risks and "normal" in bmi_risks:
        if bmi_risks["obese"] > bmi_risks["normal"]:
            print(f"   ‚úÖ Obese patients have higher risk than normal BMI (clinically valid)")
            clinical_results["bmi_correlation"] = "PASS"
        else:
            print(f"   ‚ùå Obese patients do NOT have higher risk (clinically questionable)")
            clinical_results["bmi_correlation"] = "FAIL"
    else:
        # The comparison needs both groups; record that it could not be made.
        print(f"   BMI comparison skipped: both 'obese' and 'normal' groups are required")
        clinical_results["bmi_correlation"] = "INSUFFICIENT_DATA"

    # 4. Risk Category Distribution
    print(f"\n4. Risk Category Distribution:")
    risk_dist = recent_predictions.groupBy("risk_category").agg(
        count("*").alias("patient_count")
    ).withColumn("percentage", col("patient_count") / total_predictions * 100).orderBy("risk_category").collect()

    for row in risk_dist:
        print(f"   {row.risk_category}: {row.patient_count:,} ({row.percentage:.1f}%)")

    # Healthcare industry benchmark: Expect majority in low/medium risk
    risk_pcts = {row.risk_category: row.percentage for row in risk_dist}
    low_medium_pct = risk_pcts.get("low", 0) + risk_pcts.get("medium", 0)

    if low_medium_pct >= 60:
        print(f"   ‚úÖ Low/Medium risk patients: {low_medium_pct:.1f}% (Expected >60%)")
        clinical_results["risk_distribution"] = "PASS"
    else:
        print(f"   ‚ö†Ô∏è  Low/Medium risk patients: {low_medium_pct:.1f}% (Expected >60%)")
        clinical_results["risk_distribution"] = "WARN"

except Exception as e:
    print(f"\n‚ùå Clinical validation failed: {str(e)}")
    traceback.print_exc()
    # Record the crash so the final summary cannot report a clean result when
    # some or all clinical checks never ran.
    clinical_results["validation_run"] = "FAIL"

## 3. Regional Equity Validation

Check that predictions are fair and equitable across geographic regions and between sexes.

In [None]:
# Equity checks: compare average predicted risk across regions and between sexes.
#
# `from pyspark.sql.functions import *` shadows Python's min/max/abs with Spark
# column functions, so the builtins are reached through the `builtins` module.
# It is imported at the top of the cell because the gender check needs it even
# when the regional branch (which used to hold the import) is skipped.
import builtins
import traceback


def _fmt_or_na(value, spec):
    """Format a numeric aggregate, or return 'n/a' when Spark returned NULL."""
    return "n/a" if value is None else format(value, spec)


print(f"\nüåé Regional Equity Validation")
print(f"-" * 80)

equity_results = {}

try:
    # Regional distribution analysis
    regional_analysis = recent_predictions.groupBy("region").agg(
        avg("adjusted_prediction").alias("avg_risk"),
        stddev("adjusted_prediction").alias("risk_std"),
        (sum(when(col("high_risk_patient") == True, 1).otherwise(0)) / count("*") * 100).alias("high_risk_pct"),
        count("*").alias("patient_count")
    ).orderBy("region").collect()

    print(f"\nRisk Metrics by Region:")
    print(f"{'Region':<15} {'Avg Risk':<12} {'Std Dev':<12} {'High-Risk %':<12} {'Count':<10}")
    print(f"-" * 80)

    regional_risks = []
    for row in regional_analysis:
        # NULL averages cannot take part in the disparity calculation.
        if row.avg_risk is not None:
            regional_risks.append(row.avg_risk)
        # risk_std is NULL for a region with a single prediction; print a
        # placeholder rather than aborting the whole equity cell.
        print(
            f"{str(row.region):<15} {_fmt_or_na(row.avg_risk, '.2f'):<12} "
            f"{_fmt_or_na(row.risk_std, '.2f'):<12} {row.high_risk_pct:<12.1f} {row.patient_count:<10,}"
        )

    # Calculate regional disparity using Python's built-in max/min
    if len(regional_risks) > 1:
        max_regional_risk = builtins.max(regional_risks)
        min_regional_risk = builtins.min(regional_risks)

        print(f"\nRegional Disparity Analysis:")
        print(f"  Max regional risk: {max_regional_risk:.2f}")
        print(f"  Min regional risk: {min_regional_risk:.2f}")

        if min_regional_risk > 0:
            regional_disparity = ((max_regional_risk - min_regional_risk) / min_regional_risk) * 100
            print(f"  Disparity: {regional_disparity:.1f}%")

            # Acceptable disparity threshold: 20%
            if regional_disparity <= 20:
                print(f"  ‚úÖ Regional disparity within acceptable range (<20%)")
                equity_results["regional_equity"] = "PASS"
            else:
                print(f"  ‚ö†Ô∏è  High regional disparity ({regional_disparity:.1f}% > 20%) - investigate regional bias")
                equity_results["regional_equity"] = "WARN"
        else:
            # The disparity is relative to the minimum, so it is undefined here.
            print(f"  Disparity: not computable (minimum regional risk is not positive)")
            equity_results["regional_equity"] = "INSUFFICIENT_DATA"
    else:
        print(f"\nRegional disparity skipped: fewer than two regions with an average risk")
        equity_results["regional_equity"] = "INSUFFICIENT_DATA"

    # Gender equity analysis
    print(f"\nGender Equity Analysis:")
    gender_analysis = recent_predictions.groupBy("sex").agg(
        avg("adjusted_prediction").alias("avg_risk"),
        (sum(when(col("high_risk_patient") == True, 1).otherwise(0)) / count("*") * 100).alias("high_risk_pct"),
        count("*").alias("patient_count")
    ).collect()

    for row in gender_analysis:
        print(f"  {row.sex}: Avg Risk={_fmt_or_na(row.avg_risk, '.2f')}, High-Risk%={row.high_risk_pct:.1f}%, N={row.patient_count:,}")

    # Check for extreme gender bias
    gender_risks = {row.sex: row.avg_risk for row in gender_analysis if row.avg_risk is not None}
    if len(gender_risks) == 2:
        first_risk, second_risk = gender_risks.values()
        gender_disparity = builtins.abs(first_risk - second_risk)
        if gender_disparity <= 10:
            print(f"  ‚úÖ Gender disparity acceptable: {gender_disparity:.2f} points")
            equity_results["gender_equity"] = "PASS"
        else:
            print(f"  ‚ö†Ô∏è  High gender disparity: {gender_disparity:.2f} points")
            equity_results["gender_equity"] = "WARN"
    else:
        # The pairwise comparison is only defined for exactly two groups.
        print(f"  Gender comparison skipped: expected 2 groups with an average risk, found {len(gender_risks)}")
        equity_results["gender_equity"] = "INSUFFICIENT_DATA"

except Exception as e:
    print(f"\n‚ùå Equity validation failed: {str(e)}")
    traceback.print_exc()
    # Record the crash so the final summary cannot report a clean result when
    # some or all equity checks never ran.
    equity_results["validation_run"] = "FAIL"

## 4. Prediction Stability Over Time

Validate that predictions are consistent and stable over the lookback period.

In [None]:
# Temporal stability: coefficient of variation (CV) of the daily average risk.
#
# `from pyspark.sql.functions import *` replaces Python's min/sum with Spark
# column functions, so plain `min(7, n)` / `sum(list)` raise TypeError here.
# The Python builtins are therefore reached explicitly through `builtins`.
import builtins
import traceback

print(f"\nüìà Prediction Stability Analysis")
print(f"-" * 80)

stability_results = {}

try:
    # Daily trend analysis (avg/count/sum/when inside .agg() are the Spark functions)
    daily_trends = recent_predictions.withColumn(
        "prediction_date", date_format(col("prediction_timestamp"), "yyyy-MM-dd")
    ).groupBy("prediction_date").agg(
        avg("adjusted_prediction").alias("avg_risk"),
        count("*").alias("prediction_count"),
        (sum(when(col("high_risk_patient") == True, 1).otherwise(0)) / count("*") * 100).alias("high_risk_pct")
    ).orderBy("prediction_date").collect()

    # Days whose average is NULL cannot contribute to the CV.
    daily_risks = [row.avg_risk for row in daily_trends if row.avg_risk is not None]

    if len(daily_risks) > 1:
        print(f"\nDaily Prediction Trends (Last {builtins.min(7, len(daily_trends))} days):")
        print(f"{'Date':<12} {'Avg Risk':<12} {'High-Risk %':<12} {'Count':<10}")
        print(f"-" * 80)

        # Show last 7 days
        for row in daily_trends[-7:]:
            avg_risk_text = "n/a" if row.avg_risk is None else f"{row.avg_risk:.2f}"
            print(f"{row.prediction_date:<12} {avg_risk_text:<12} {row.high_risk_pct:<12.1f} {row.prediction_count:<10,}")

        # Calculate coefficient of variation (CV) for stability, using the
        # population standard deviation (divide by N) of the daily averages.
        day_count = len(daily_risks)
        avg_daily_risk = builtins.sum(daily_risks) / day_count
        variance = builtins.sum((x - avg_daily_risk) ** 2 for x in daily_risks) / day_count
        std_dev = variance ** 0.5

        if avg_daily_risk == 0:
            # CV is relative to the mean, so it is undefined for a zero mean.
            print(f"\nStability Metrics:")
            print(f"  Coefficient of Variation: not computable (mean daily risk is 0)")
            stability_results["temporal_stability"] = "INSUFFICIENT_DATA"
        else:
            cv = (std_dev / builtins.abs(avg_daily_risk)) * 100

            print(f"\nStability Metrics:")
            print(f"  Coefficient of Variation: {cv:.2f}%")

            # CV < 10% indicates good stability
            if cv < 10:
                print(f"  ‚úÖ High prediction stability (CV < 10%)")
                stability_results["temporal_stability"] = "PASS"
            elif cv < 20:
                print(f"  ‚ö†Ô∏è  Moderate prediction stability (CV < 20%)")
                stability_results["temporal_stability"] = "WARN"
            else:
                print(f"  ‚ùå Low prediction stability (CV > 20%)")
                stability_results["temporal_stability"] = "FAIL"
    else:
        print(f"  ‚ÑπÔ∏è  Insufficient data for temporal stability analysis")
        stability_results["temporal_stability"] = "INSUFFICIENT_DATA"

except Exception as e:
    print(f"\n‚ùå Stability validation failed: {str(e)}")
    traceback.print_exc()
    # Record the crash so the final summary cannot report a clean result when
    # the stability check never completed.
    stability_results["validation_run"] = "FAIL"

## 5. Business Impact and ROI Metrics

Calculate the business value and return on investment of the risk prediction model.

In [None]:
# Business impact / ROI estimate for the lookback window.
# All monetary figures come from the hard-coded `assumptions` below, not from data;
# only the high-risk patient count is read from the predictions.
import traceback

print(f"\nüí∞ Business Impact & ROI Analysis")
print(f"-" * 80)

roi_results = {}

try:
    # Business assumptions (adjust based on actual business metrics)
    assumptions = {
        "avg_intervention_cost": 500,        # Cost per high-risk patient intervention
        "avg_prevented_claim": 5000,         # Average claim prevented by early intervention
        "intervention_success_rate": 0.30,   # 30% of interventions prevent a claim
        "model_operating_cost_monthly": 2000 # Monthly cost to run the model
    }

    # Calculate high-risk patient counts
    # (`== True` builds a Spark column expression; it must not be replaced with `is True`.)
    high_risk_patients = recent_predictions.filter(col("high_risk_patient") == True).count()

    # ROI Calculation
    print(f"\nBusiness Impact Estimates:")
    print(f"  High-risk patients identified: {high_risk_patients:,}")

    # Calculate costs; the monthly operating cost is pro-rated to the lookback window.
    intervention_costs = high_risk_patients * assumptions["avg_intervention_cost"]
    model_costs = assumptions["model_operating_cost_monthly"] * (lookback_days / 30)
    total_costs = intervention_costs + model_costs

    print(f"\nCosts:")
    print(f"  Intervention costs: ${intervention_costs:,.2f}")
    print(f"  Model operating costs: ${model_costs:,.2f}")
    print(f"  Total costs: ${total_costs:,.2f}")

    # Calculate benefits
    successful_interventions = high_risk_patients * assumptions["intervention_success_rate"]
    prevented_claims_value = successful_interventions * assumptions["avg_prevented_claim"]

    print(f"\nBenefits:")
    print(f"  Estimated successful interventions: {successful_interventions:.0f}")
    print(f"  Value of prevented claims: ${prevented_claims_value:,.2f}")

    # Calculate ROI (defined as 0 when there are no costs, to avoid dividing by zero)
    net_benefit = prevented_claims_value - total_costs
    roi_percentage = (net_benefit / total_costs) * 100 if total_costs > 0 else 0

    print(f"\nROI Summary:")
    print(f"  Net benefit: ${net_benefit:,.2f}")
    print(f"  ROI: {roi_percentage:.1f}%")

    if roi_percentage > 100:
        # Print the actual ROI; the previous ">{value}%" wording made the measured
        # value look like a threshold.
        print(f"  ‚úÖ Strong positive ROI ({roi_percentage:.0f}%)")
        roi_results["roi_status"] = "EXCELLENT"
    elif roi_percentage > 50:
        print(f"  ‚úÖ Positive ROI ({roi_percentage:.0f}%)")
        roi_results["roi_status"] = "GOOD"
    elif roi_percentage > 0:
        print(f"  ‚ö†Ô∏è  Marginal ROI ({roi_percentage:.0f}%)")
        roi_results["roi_status"] = "MARGINAL"
    else:
        print(f"  ‚ùå Negative ROI ({roi_percentage:.0f}%)")
        roi_results["roi_status"] = "NEGATIVE"

    print(f"\nüìù Note: ROI calculations based on business assumptions. Adjust assumptions for actual metrics.")

except Exception as e:
    print(f"\n‚ùå ROI validation failed: {str(e)}")
    traceback.print_exc()
    # Record the crash so ROI does not silently drop out of the final summary.
    roi_results["validation_run"] = "FAIL"

## 6. Model Governance Status

Check current model governance and champion model status.

In [None]:
# Model governance: resolve the "champion" alias of the registered model and
# print its version, run metrics and selected governance tags.
#
# The alias lookup and the detail lookup are handled separately so that a failure
# while fetching run metrics or tags is not misreported as "No champion model found".
import traceback

print(f"\nüèõÔ∏è  Model Governance Status")
print(f"-" * 80)

governance_results = {}

full_model_name = f"{catalog}.{ml_schema}.{model_name}"

# Check champion model
champion_info = None
try:
    champion_info = client.get_model_version_by_alias(full_model_name, "champion")
except Exception as e:
    print(f"  ‚ö†Ô∏è  No champion model found: {e}")
    governance_results["champion_exists"] = False

if champion_info is not None:
    # The alias resolved, so the champion exists regardless of what follows.
    governance_results["champion_exists"] = True
    governance_results["champion_version"] = champion_info.version

    try:
        print(f"\nChampion Model:")
        print(f"  Version: {champion_info.version}")
        print(f"  Status: {champion_info.status}")
        print(f"  Run ID: {champion_info.run_id}")

        # Get model metrics (only possible when the version is linked to a run)
        if champion_info.run_id:
            run_data = client.get_run(champion_info.run_id)
            metrics = run_data.data.metrics

            print(f"\n  Performance Metrics:")
            for metric_name in ["r2_score", "mean_absolute_error", "high_risk_accuracy"]:
                if metric_name in metrics:
                    print(f"    {metric_name}: {metrics[metric_name]:.4f}")

        # Check governance tags
        if champion_info.tags:
            print(f"\n  Governance Tags:")
            for key, value in champion_info.tags.items():
                if key in ["healthcare_compliance", "validation_r2", "hipaa_compliant"]:
                    print(f"    {key}: {value}")

    except Exception as e:
        print(f"\n‚ùå Governance check failed: {str(e)}")
        traceback.print_exc()
        # The champion exists but its details could not be read; surface that
        # as a warning instead of flipping champion_exists to False.
        governance_results["champion_details"] = "WARN"

## 7. Final Business Validation Summary

Comprehensive summary of all business validation checks.

In [None]:
# Final summary: tally every recorded check, derive an overall status and return
# it to the caller via dbutils.notebook.exit().
print(f"\n{'='*80}")
print(f"BUSINESS VALIDATION SUMMARY")
print(f"{'='*80}")

# Collect all validation results
all_results = {
    "Business KPIs": kpi_results,
    "Clinical Relevance": clinical_results,
    "Regional Equity": equity_results,
    "Temporal Stability": stability_results,
    "ROI": roi_results,
    "Governance": governance_results
}

# Result values grouped by outcome. Booleans are matched by identity below, because
# `1 == True` and `0 == False` in Python would otherwise let a numeric value such as
# a model version be counted as a passed or failed check.
pass_values = ("PASS", "EXCELLENT", "GOOD")
warn_values = ("WARN", "MARGINAL")
fail_values = ("FAIL", "NEGATIVE")

total_checks = 0
passed_checks = 0
warnings = 0
failed_checks = 0
info_checks = 0

for category, results in all_results.items():
    print(f"\n{category}:")
    if not results:
        # A category without any recorded result means its validation did not run
        # or skipped everything; count it as a warning so it cannot contribute to
        # an EXCELLENT overall status.
        print(f"  (no checks recorded for this category)")
        warnings += 1
        continue
    for check_name, result in results.items():
        total_checks += 1
        is_text = isinstance(result, str)
        if result is True or (is_text and result in pass_values):
            print(f"  ‚úÖ {check_name}: {result}")
            passed_checks += 1
        elif is_text and result in warn_values:
            print(f"  ‚ö†Ô∏è  {check_name}: {result}")
            warnings += 1
        elif result is False or (is_text and result in fail_values):
            print(f"  ‚ùå {check_name}: {result}")
            failed_checks += 1
        else:
            # Informational entries (e.g. a version number or INSUFFICIENT_DATA).
            print(f"  ‚ÑπÔ∏è  {check_name}: {result}")
            info_checks += 1

print(f"\n{'='*80}")
print(f"Validation Score: {passed_checks}/{total_checks} checks passed")
print(f"Warnings: {warnings}")
print(f"Failures: {failed_checks}")
print(f"Informational: {info_checks}")

if failed_checks == 0 and warnings == 0:
    overall_status = "EXCELLENT"
    print(f"\n‚úÖ‚úÖ‚úÖ OVERALL STATUS: EXCELLENT - Model meets all business objectives ‚úÖ‚úÖ‚úÖ")
elif failed_checks == 0:
    overall_status = "GOOD"
    print(f"\n‚úÖ OVERALL STATUS: GOOD - Model meets business objectives with minor warnings ‚úÖ")
elif failed_checks <= 2:
    overall_status = "ACCEPTABLE"
    print(f"\n‚ö†Ô∏è  OVERALL STATUS: ACCEPTABLE - Model has some issues to address ‚ö†Ô∏è")
else:
    overall_status = "NEEDS_IMPROVEMENT"
    print(f"\n‚ùå OVERALL STATUS: NEEDS IMPROVEMENT - Model requires attention ‚ùå")

print(f"{'='*80}")

# Create summary for return ("informational" is an added key; the existing keys are unchanged)
summary = {
    "status": overall_status,
    "total_checks": total_checks,
    "passed": passed_checks,
    "warnings": warnings,
    "failures": failed_checks,
    "informational": info_checks,
    "timestamp": datetime.now().isoformat(),
    "lookback_days": lookback_days,
    "total_predictions": total_predictions
}

dbutils.notebook.exit(json.dumps(summary))