In [None]:
# Inference Validation Notebook for Healthcare Insurance MLOps
# This notebook validates batch inference results

import json
import sys
from datetime import date

from pyspark.sql import SparkSession
# NOTE: this star import shadows Python builtins such as min/max/sum with Spark column functions.
from pyspark.sql.functions import *

# Job parameters (Databricks widgets).
#   catalog / ml_schema : used to qualify predictions_table when a bare table name is given
#   predictions_table   : required; "table", "schema.table" or "catalog.schema.table"
#   batch_date          : optional YYYY-MM-DD; empty means "the Spark session's current date"
dbutils.widgets.text("catalog", "juan_dev", "Unity Catalog name")
dbutils.widgets.text("ml_schema", "ml", "ML Schema name")
dbutils.widgets.text("predictions_table", "", "Predictions table")
dbutils.widgets.text("batch_date", "", "Batch date (YYYY-MM-DD)")

catalog = dbutils.widgets.get("catalog").strip()
ml_schema = dbutils.widgets.get("ml_schema").strip()
# Strip so stray whitespace from job parameters does not break table lookup or date matching.
predictions_table = dbutils.widgets.get("predictions_table").strip()
batch_date = dbutils.widgets.get("batch_date").strip()

if not predictions_table:
    # Fail fast with a clear message instead of a later spark.table("") parse error.
    print("❌ predictions_table parameter is required")
    dbutils.notebook.exit(json.dumps({"status": "FAILED", "message": "predictions_table parameter is required"}))
elif "." not in predictions_table:
    # A bare table name is resolved inside the configured catalog/schema;
    # partially or fully qualified names are used exactly as given.
    predictions_table = f"{catalog}.{ml_schema}.{predictions_table}"

print(f"Validating inference results for {predictions_table} on {batch_date or 'current date'}")

In [None]:
# Load the predictions table and select the batch to validate.
# Produces `daily_predictions` (filtered DataFrame) and `prediction_count` for later cells.
#
# NOTE(review): dbutils.notebook.exit is called outside the try/except on purpose.
# If it signals exit by raising, a broad `except Exception` would intercept it and
# re-exit with a misleading "Could not load predictions" message.
load_error = None
try:
    predictions_df = spark.table(predictions_table)

    if batch_date:
        # Half-open day range [batch_date, batch_date + 1 day) on the raw column.
        # Formatting every timestamp to a string before comparing would block file
        # pruning/data skipping.
        day_start = to_date(lit(batch_date))
        daily_predictions = predictions_df.filter(
            (col("prediction_timestamp") >= day_start)
            & (col("prediction_timestamp") < date_add(day_start, 1))
        )
    else:
        # No batch_date: everything stamped at or after the start of the session's current date.
        daily_predictions = predictions_df.filter(
            col("prediction_timestamp") >= current_date()
        )

    prediction_count = daily_predictions.count()
except Exception as e:
    load_error = str(e)

if load_error is not None:
    print(f"❌ Could not load predictions: {load_error}")
    dbutils.notebook.exit(json.dumps({"status": "FAILED", "message": load_error}))
elif prediction_count == 0:
    print("❌ No predictions found for validation")
    dbutils.notebook.exit(json.dumps({"status": "FAILED", "message": "No predictions found"}))
else:
    print(f"✅ Found {prediction_count:,} predictions for validation")

In [None]:
# Validate prediction quality on `adjusted_prediction`:
#   1. no null predictions (hard failure),
#   2. summary statistics (min / max / mean / stddev),
#   3. a loose range sanity check (warning only).
# Defines min_pred / max_pred / avg_pred / std_pred for later cells.
#
# min/max/avg/stddev below are the Spark aggregate functions brought in by
# `from pyspark.sql.functions import *`, which shadows Python's builtins.
# NOTE(review): dbutils.notebook.exit is kept outside the try/except so a broad
# `except Exception` cannot intercept it if it signals exit by raising.

# Range-warning bounds. Risk scores are expected in 0-100; the upper bound is
# deliberately loose to allow for other target types.
PREDICTION_MIN_EXPECTED = 0
PREDICTION_MAX_EXPECTED = 200

quality_error = None
try:
    null_predictions = daily_predictions.filter(col("adjusted_prediction").isNull()).count()
    prediction_stats = daily_predictions.agg(
        min("adjusted_prediction").alias("min_pred"),
        max("adjusted_prediction").alias("max_pred"),
        avg("adjusted_prediction").alias("avg_pred"),
        stddev("adjusted_prediction").alias("std_pred"),
    ).collect()[0]
except Exception as e:
    quality_error = str(e)

if quality_error is not None:
    print(f"❌ Prediction validation failed: {quality_error}")
    dbutils.notebook.exit(json.dumps({"status": "FAILED", "message": quality_error}))
elif null_predictions > 0:
    print(f"❌ Found {null_predictions} null predictions")
    dbutils.notebook.exit(json.dumps({"status": "FAILED", "message": f"{null_predictions} null predictions"}))
else:
    print("✅ No null predictions found")

    min_pred = prediction_stats.min_pred
    max_pred = prediction_stats.max_pred
    avg_pred = prediction_stats.avg_pred
    std_pred = prediction_stats.std_pred

    # Sample stddev is undefined for a single row. Spark returns null there
    # (NaN on older versions), so don't format a None with :.2f.
    std_text = "n/a" if std_pred is None else f"{std_pred:.2f}"

    print(f"Prediction range: {min_pred:.2f} - {max_pred:.2f}")
    print(f"Average prediction: {avg_pred:.2f} ± {std_text}")

    if min_pred < PREDICTION_MIN_EXPECTED or max_pred > PREDICTION_MAX_EXPECTED:
        print(f"⚠️  Predictions outside expected range: {min_pred:.2f} - {max_pred:.2f}")
    else:
        print("✅ Predictions within reasonable range")

In [None]:
# Business-logic checks: the risk-category mix and the share of high-risk patients.
# Unusual percentages are reported as warnings only. An error while computing
# them fails the run. Defines high_risk_pct for the summary cell.

HIGH_RISK_PCT_MAX = 50  # more than 50% high-risk patients seems unusual
HIGH_RISK_PCT_MIN = 1   # less than 1% high-risk patients also seems unusual

try:
    # Sorted so the printed distribution is stable from run to run.
    risk_distribution = (
        daily_predictions.groupBy("risk_category").count().orderBy("risk_category").collect()
    )

    print("Risk category distribution:")
    for row in risk_distribution:
        # Use row["count"], not row.count, because Row.count is the tuple method.
        # The local is named category_count, not `count`, so it does not overwrite
        # the global pyspark `count` function from the star import.
        category_count = row["count"]
        percentage = (category_count / prediction_count) * 100
        print(f"  {row.risk_category}: {category_count:,} ({percentage:.1f}%)")

    # Spark Column comparison (not a Python identity check); null flags are excluded.
    high_risk_count = daily_predictions.filter(col("high_risk_patient") == True).count()  # noqa: E712
    high_risk_pct = (high_risk_count / prediction_count) * 100

    print(f"High-risk patients: {high_risk_count:,} ({high_risk_pct:.1f}%)")

    if high_risk_pct > HIGH_RISK_PCT_MAX:
        print(f"⚠️  High percentage of high-risk patients: {high_risk_pct:.1f}%")
    elif high_risk_pct < HIGH_RISK_PCT_MIN:
        print(f"⚠️  Very low percentage of high-risk patients: {high_risk_pct:.1f}%")
    else:
        print("✅ High-risk patient percentage within reasonable range")

except Exception as e:
    print(f"❌ Business logic validation failed: {str(e)}")
    dbutils.notebook.exit(json.dumps({"status": "FAILED", "message": str(e)}))

In [None]:
# Final validation summary, emitted as this notebook's JSON exit value.
# Relies on prediction_count, avg_pred and high_risk_pct from the earlier cells.

# When batch_date is empty, the load cell filtered with Spark's current_date()
# (session time zone). Report that same day rather than the driver's local
# date.today(), which can differ.
validation_date = batch_date or spark.range(1).select(current_date()).first()[0].isoformat()

validation_summary = {
    "status": "SUCCESS",
    "prediction_count": prediction_count,
    "avg_prediction": float(avg_pred),
    "high_risk_percentage": float(high_risk_pct),
    "validation_date": validation_date,
}

print("\n=== Validation Summary ===")
print("✅ Inference validation completed successfully")
print(f"Predictions validated: {prediction_count:,}")
print(f"Average prediction: {avg_pred:.2f}")
print(f"High-risk patients: {high_risk_pct:.1f}%")

dbutils.notebook.exit(json.dumps(validation_summary))