In [None]:
# Inference Validation Notebook for Healthcare Insurance MLOps
# This notebook validates batch inference results

import sys
import json
from datetime import date
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

# Get parameters
# Each widget declares a notebook parameter: (name, default value, label).
dbutils.widgets.text("catalog", "juan_dev", "Unity Catalog name")
dbutils.widgets.text("ml_schema", "ml", "ML Schema name")
dbutils.widgets.text("predictions_table", "juan_dev.healthcare_data.ml_patient_predictions", "Predictions table")
dbutils.widgets.text("batch_date", "", "Batch date (YYYY-MM-DD)")

# Read the values back; surrounding whitespace is dropped so that a padded
# value cannot slip past the emptiness check or break the date comparison.
catalog = dbutils.widgets.get("catalog")
ml_schema = dbutils.widgets.get("ml_schema")
predictions_table = (dbutils.widgets.get("predictions_table") or "").strip()
batch_date = (dbutils.widgets.get("batch_date") or "").strip()

# Validate required parameters
if not predictions_table:
    print("‚ùå ERROR: predictions_table parameter is required")
    dbutils.notebook.exit(json.dumps({"status": "FAILED", "message": "predictions_table parameter is required"}))

# Validate the optional batch date. Downstream it is compared as a string
# against a "yyyy-MM-dd" rendering of the prediction timestamp, so a value in
# any other shape would silently match no rows. Parse it and re-render it in
# canonical YYYY-MM-DD form; reject anything that is not a real calendar date.
batch_date_error = None
if batch_date:
    try:
        batch_date = date.fromisoformat(batch_date).isoformat()
    except ValueError:
        batch_date_error = f"batch_date must be a valid YYYY-MM-DD date, got '{batch_date}'"

# The exit call is kept outside the try above so that it can never be
# intercepted by the except clause.
if batch_date_error is not None:
    print(f"ERROR: {batch_date_error}")
    dbutils.notebook.exit(json.dumps({"status": "FAILED", "message": batch_date_error}))

print(f"Validating inference results for {predictions_table} on {batch_date if batch_date else 'current date'}")

In [None]:
# Load and validate predictions
#
# Produces `predictions_df`, `daily_predictions` and `prediction_count` for the
# following cells. Any failure is recorded in `load_failure` and reported via
# dbutils.notebook.exit AFTER the try block.
# NOTE(review): depending on the runtime, dbutils.notebook.exit may signal
# termination by raising; calling it outside the try guarantees that the
# `except Exception` handler cannot intercept it and rewrite the message.
load_failure = None
prediction_count = 0
try:
    predictions_df = spark.table(predictions_table)

    # Restrict to the requested batch date when one is given; otherwise keep
    # every row whose timestamp is on or after the start of the current date.
    if batch_date:
        daily_predictions = predictions_df.filter(
            date_format(col("prediction_timestamp"), "yyyy-MM-dd") == batch_date
        )
    else:
        daily_predictions = predictions_df.filter(
            col("prediction_timestamp") >= current_date()
        )

    prediction_count = daily_predictions.count()

except Exception as e:
    print(f"‚ùå Could not load predictions: {str(e)}")
    load_failure = str(e)
else:
    print(f"‚úÖ Found {prediction_count:,} predictions for validation")

    # An empty batch cannot be validated (and later cells divide by this count).
    if prediction_count == 0:
        print("‚ùå No predictions found for validation")
        load_failure = "No predictions found"

if load_failure is not None:
    dbutils.notebook.exit(json.dumps({"status": "FAILED", "message": load_failure}))

In [None]:
# Validate prediction quality
#
# Fails the notebook when any `adjusted_prediction` is null; otherwise prints
# min / max / mean / spread and warns when values fall outside 0-200.
# Leaves `min_pred`, `max_pred`, `avg_pred` and `std_pred` defined for later cells.
# Failures are recorded in `quality_failure` and reported after the try block.
# NOTE(review): depending on the runtime, dbutils.notebook.exit may signal
# termination by raising; calling it outside the try guarantees that the
# `except Exception` handler cannot intercept it.
quality_failure = None
try:
    # Check for null predictions
    null_predictions = daily_predictions.filter(col("adjusted_prediction").isNull()).count()

    if null_predictions > 0:
        print(f"‚ùå Found {null_predictions} null predictions")
        quality_failure = f"{null_predictions} null predictions"
    else:
        print(f"‚úÖ No null predictions found")

        # Check prediction range (risk scores should be 0-100)
        # min/max/avg/stddev here are the Spark column functions brought in by
        # the star import, not the Python builtins.
        prediction_stats = daily_predictions.agg(
            min("adjusted_prediction").alias("min_pred"),
            max("adjusted_prediction").alias("max_pred"),
            avg("adjusted_prediction").alias("avg_pred"),
            stddev("adjusted_prediction").alias("std_pred")
        ).collect()[0]

        min_pred = prediction_stats.min_pred
        max_pred = prediction_stats.max_pred
        avg_pred = prediction_stats.avg_pred
        # The sample standard deviation is undefined for a single row and comes
        # back as null; report a spread of 0.0 instead of crashing in the
        # format call below and failing an otherwise valid one-row batch.
        std_pred = prediction_stats.std_pred if prediction_stats.std_pred is not None else 0.0

        print(f"Prediction range: {min_pred:.2f} - {max_pred:.2f}")
        print(f"Average prediction: {avg_pred:.2f} ¬± {std_pred:.2f}")

        # Validate reasonable ranges (warning only, does not fail the run)
        if min_pred < 0 or max_pred > 200:  # Allow some flexibility for different target types
            print(f"‚ö†Ô∏è  Predictions outside expected range: {min_pred:.2f} - {max_pred:.2f}")
        else:
            print(f"‚úÖ Predictions within reasonable range")

except Exception as e:
    print(f"‚ùå Prediction validation failed: {str(e)}")
    quality_failure = str(e)

if quality_failure is not None:
    dbutils.notebook.exit(json.dumps({"status": "FAILED", "message": quality_failure}))

In [None]:
# Check business logic validation
#
# Prints the risk-category distribution, compares the high / critical / low
# shares against the thresholds in `business_kpis`, and runs two sanity checks
# (smokers vs. high-risk flag, average risk per age category). Every check in
# this cell only prints a warning; the notebook is failed only when an
# exception is raised. Leaves `high_risk_pct` defined for the summary cell.
try:
    # Risk category distribution
    risk_distribution = daily_predictions.groupBy("risk_category").count().collect()

    print("\nüìä Risk category distribution:")
    for row in risk_distribution:
        category = row.risk_category
        # Deliberately NOT named `count`: this cell runs at notebook-global
        # scope, so that name would overwrite the `count` function that the
        # star import from pyspark.sql.functions provides.
        category_count = row['count']
        percentage = (category_count / prediction_count) * 100
        print(f"  {category}: {category_count:,} ({percentage:.1f}%)")

    # Check for high-risk flags
    high_risk_count = daily_predictions.filter(col("high_risk_patient") == True).count()
    high_risk_pct = (high_risk_count / prediction_count) * 100

    print(f"\nüö® High-risk patients: {high_risk_count:,} ({high_risk_pct:.1f}%)")

    # Business KPI Thresholds (healthcare industry standards)
    business_kpis = {
        "high_risk_pct_min": 5.0,   # At least 5% should be flagged as high-risk
        "high_risk_pct_max": 30.0,  # No more than 30% should be high-risk
        "critical_risk_pct_max": 10.0,  # Critical cases should be < 10%
        "low_risk_pct_min": 30.0,   # At least 30% should be low-risk (healthy population)
        "smoker_high_risk_pct_min": 30.0  # At least 30% of smokers should be flagged high-risk
    }

    # Calculate category percentages
    category_pcts = {row.risk_category: (row['count'] / prediction_count) * 100 for row in risk_distribution}

    print(f"\n‚úÖ Business KPI Validation:")

    # Validate high-risk percentage
    if high_risk_pct < business_kpis["high_risk_pct_min"]:
        print(f"  ‚ö†Ô∏è  High-risk percentage too low: {high_risk_pct:.1f}% < {business_kpis['high_risk_pct_min']}%")
        print(f"      ‚Üí Model may be under-identifying at-risk patients")
    elif high_risk_pct > business_kpis["high_risk_pct_max"]:
        print(f"  ‚ö†Ô∏è  High-risk percentage too high: {high_risk_pct:.1f}% > {business_kpis['high_risk_pct_max']}%")
        print(f"      ‚Üí Model may be over-flagging patients, causing alert fatigue")
    else:
        print(f"  ‚úÖ High-risk patient percentage within acceptable range: {high_risk_pct:.1f}%")

    # Validate critical risk percentage (a category that is absent counts as 0%)
    critical_pct = category_pcts.get("critical", 0)
    if critical_pct > business_kpis["critical_risk_pct_max"]:
        print(f"  ‚ö†Ô∏è  Critical risk percentage too high: {critical_pct:.1f}% > {business_kpis['critical_risk_pct_max']}%")
    else:
        print(f"  ‚úÖ Critical risk percentage acceptable: {critical_pct:.1f}%")

    # Validate low-risk percentage (a category that is absent counts as 0%)
    low_pct = category_pcts.get("low", 0)
    if low_pct < business_kpis["low_risk_pct_min"]:
        print(f"  ‚ö†Ô∏è  Low-risk percentage too low: {low_pct:.1f}% < {business_kpis['low_risk_pct_min']}%")
        print(f"      ‚Üí Population may be unusually unhealthy, verify data quality")
    else:
        print(f"  ‚úÖ Low-risk percentage acceptable: {low_pct:.1f}%")

    # Clinical Relevance Checks
    print(f"\nüè• Clinical Relevance Validation:")

    # Check smoking vs high-risk correlation
    smoking_high_risk = daily_predictions.filter(
        (col("patient_smoking_status") == "yes") & (col("high_risk_patient") == True)
    ).count()
    smoking_total = daily_predictions.filter(col("patient_smoking_status") == "yes").count()

    # Skipped silently when the batch contains no smokers (avoids dividing by zero).
    if smoking_total > 0:
        smoking_high_risk_pct = (smoking_high_risk / smoking_total) * 100
        print(f"  Smokers flagged as high-risk: {smoking_high_risk_pct:.1f}%")
        if smoking_high_risk_pct < business_kpis["smoker_high_risk_pct_min"]:
            print(f"    ‚ö†Ô∏è  Low high-risk rate among smokers - check model calibration")
        else:
            print(f"    ‚úÖ Smoking correctly correlated with higher risk")

    # Check age vs risk correlation (older patients should have higher average risk)
    # NOTE(review): rows are ordered by the category label's own sort order; this
    # only corresponds to increasing age if the labels sort that way - confirm.
    age_risk_correlation = daily_predictions.groupBy("patient_age_category").agg(
        avg("adjusted_prediction").alias("avg_risk")
    ).orderBy("patient_age_category").collect()

    print(f"\n  Average risk by age category:")
    # None marks "no previous category yet", so the first category is never
    # compared against an arbitrary baseline (a 0 baseline would flag a
    # negative first average as a decrease).
    prev_risk = None
    age_correlation_valid = True
    for row in age_risk_correlation:
        avg_risk = row.avg_risk
        print(f"    {row.patient_age_category}: {avg_risk:.2f}")
        if prev_risk is not None and avg_risk < prev_risk:
            age_correlation_valid = False
        prev_risk = avg_risk

    if age_correlation_valid:
        print(f"  ‚úÖ Age positively correlated with risk (clinically valid)")
    else:
        print(f"  ‚ö†Ô∏è  Age not consistently correlated with risk - verify model")

except Exception as e:
    print(f"‚ùå Business logic validation failed: {str(e)}")
    dbutils.notebook.exit(json.dumps({"status": "FAILED", "message": str(e)}))

In [None]:
# Final validation summary
# Assemble the JSON payload handed back to the caller. When no batch date was
# supplied, today's date is reported as the validation date.
summary_date = batch_date if batch_date else str(date.today())
validation_summary = dict(
    status="SUCCESS",
    prediction_count=prediction_count,
    avg_prediction=float(avg_pred),
    high_risk_percentage=float(high_risk_pct),
    validation_date=summary_date,
)

# Human-readable recap, emitted as a single write with one line per entry.
summary_lines = [
    "\n=== Validation Summary ===",
    "‚úÖ Inference validation completed successfully",
    f"Predictions validated: {prediction_count:,}",
    f"Average prediction: {avg_pred:.2f}",
    f"High-risk patients: {high_risk_pct:.1f}%",
]
print("\n".join(summary_lines))

# Return the summary to whoever invoked this notebook.
dbutils.notebook.exit(json.dumps(validation_summary))