In [None]:
# Data Validation Notebook for Healthcare Insurance MLOps
# This notebook validates prerequisite data before running batch inference

import sys
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
import json

# Get parameters: every widget is declared as (name, default, label) and then
# read back in the same order.
widget_specs = (
    ("catalog", "juan_dev", "Unity Catalog name"),
    ("schema", "healthcare_data", "Schema name"),
    ("validation_date", "", "Validation date (YYYY-MM-DD)"),
)
for widget_name, default_value, widget_label in widget_specs:
    dbutils.widgets.text(widget_name, default_value, widget_label)

catalog, schema, validation_date = [dbutils.widgets.get(spec[0]) for spec in widget_specs]

print(f"Validating data for {catalog}.{schema} on {validation_date}")

In [None]:
# Validate prerequisite tables exist: each table is read and counted, and the
# outcome is recorded as a dict in `validation_results` (used by the summary).
required_tables = [
    f"{catalog}.{schema}.silver_patients",
    f"{catalog}.{schema}.dim_patients"
]


def _probe_table(name):
    """Count rows in `name`, print a one-line status, and return a validation record.

    Any exception raised while reading or counting the table (e.g. the table
    does not exist) produces a FAILED record carrying the exception text.
    """
    try:
        n_rows = spark.table(name).count()
    except Exception as e:
        print(f"‚ùå {name}: {str(e)}")
        return {
            "table": name,
            "status": "FAILED",
            "row_count": 0,
            "message": str(e)
        }
    print(f"‚úÖ {name}: {n_rows:,} rows")
    return {
        "table": name,
        "status": "SUCCESS",
        "row_count": n_rows,
        "message": f"Table exists with {n_rows:,} rows"
    }


validation_results = [_probe_table(name) for name in required_tables]

In [None]:
# Check data freshness and quality for dim_patients.
#
# Blocking problems (table unreadable, missing columns, no current records,
# no non-null BMI values) exit the notebook with a FAILED JSON payload.
# Null counts and out-of-range BMI values are only reported as warnings.
# NOTE(review): despite the "freshness" wording, nothing below uses the
# `validation_date` widget - confirm whether a date-based filter was intended.

# Columns that must exist in dim_patients (using actual schema from dim_patients)
required_columns = [
    "patient_natural_key", "patient_gender", "patient_region",
    "patient_age_category", "bmi", "patient_smoking_status",
    "is_current_record", "number_of_dependents"
]
# Critical columns whose nulls among current records are reported (warning only).
null_check_columns = ["patient_natural_key", "patient_gender", "bmi", "patient_smoking_status"]
# Reasonable adult BMI bounds; values outside them are reported as a warning.
BMI_MIN_REASONABLE = 10
BMI_MAX_REASONABLE = 80

try:
    dim_patients = spark.table(f"{catalog}.{schema}.dim_patients")

    # Check the schema first, so a missing column yields a clear message instead
    # of an AnalysisException from the filters below. Raise rather than calling
    # dbutils.notebook.exit here: keeping exit out of the try body means the broad
    # handler below cannot intercept it, should exit be implemented by raising.
    missing_columns = [c for c in required_columns if c not in dim_patients.columns]
    if missing_columns:
        print(f"‚ùå Missing required columns: {missing_columns}")
        raise ValueError(f"Missing columns: {missing_columns}")
    print(f"‚úÖ All required columns present")

    # Build the current-record predicate once and reuse it for every check.
    is_current = col("is_current_record") == True
    current_patients = dim_patients.filter(is_current)

    # Total and current record counts in a single aggregation pass.
    record_counts = dim_patients.select(
        count(lit(1)).alias("total_records"),
        count(when(is_current, 1)).alias("current_records"),
    ).collect()[0]
    total_records = record_counts["total_records"]
    current_records = record_counts["current_records"]

    # With no current records, every later aggregate is empty/null; fail explicitly.
    if current_records == 0:
        raise ValueError(f"No current records in dim_patients ({total_records:,} total records)")
    print(f"‚úÖ Current records: {current_records:,} out of {total_records:,} total")

    # Validate data quality
    print(f"\nüìä Data Quality Checks:")

    # Count nulls in critical columns among current records.
    null_checks = current_patients.select(
        [count(when(col(c).isNull(), c)).alias(c) for c in null_check_columns]
    ).collect()[0]

    for col_name in null_check_columns:
        null_count = null_checks[col_name]
        if null_count > 0:
            print(f"  ‚ö†Ô∏è  {col_name}: {null_count:,} null values")
        else:
            print(f"  ‚úÖ {col_name}: No null values")

    # Check BMI ranges (should be reasonable for adults). min/max/avg here are
    # the pyspark.sql.functions aggregates brought in by the star import.
    bmi_stats = current_patients.select(
        min("bmi").alias("min_bmi"),
        max("bmi").alias("max_bmi"),
        avg("bmi").alias("avg_bmi")
    ).collect()[0]

    # The aggregates are null when every current BMI is null; fail with a clear
    # message instead of a TypeError from formatting/comparing None.
    if bmi_stats.min_bmi is None:
        raise ValueError("No non-null bmi values among current records")

    print(f"\nüìà BMI Distribution:")
    print(f"  Min: {bmi_stats.min_bmi:.1f}, Max: {bmi_stats.max_bmi:.1f}, Avg: {bmi_stats.avg_bmi:.1f}")

    if bmi_stats.min_bmi < BMI_MIN_REASONABLE or bmi_stats.max_bmi > BMI_MAX_REASONABLE:
        print(f"  ‚ö†Ô∏è  BMI values outside reasonable range ({BMI_MIN_REASONABLE}-{BMI_MAX_REASONABLE})")
    else:
        print(f"  ‚úÖ BMI values within reasonable range")

except Exception as e:
    # Single exit point for every blocking failure above.
    print(f"‚ùå Data validation failed: {str(e)}")
    dbutils.notebook.exit(json.dumps({"status": "FAILED", "message": str(e)}))

In [None]:
# Final validation summary: report every table check recorded in
# `validation_results`, then return a JSON status payload to the caller via
# dbutils.notebook.exit.
failed_validations = list(filter(lambda result: result["status"] == "FAILED", validation_results))

if not failed_validations:
    print(f"‚úÖ All validations passed - ready for batch inference")
    exit_payload = {"status": "SUCCESS", "validated_tables": len(validation_results)}
else:
    print(f"‚ùå Validation failed: {len(failed_validations)} issues found")
    for failed in failed_validations:
        print(f"  - {failed['table']}: {failed['message']}")
    exit_payload = {"status": "FAILED", "failed_tables": len(failed_validations)}

# NOTE(review): exiting with a FAILED payload may still mark a job task as
# successful; confirm the caller inspects "status" (or raise to fail the task).
dbutils.notebook.exit(json.dumps(exit_payload))