In [0]:
%sql
use catalog juan_dev;
use schema ml;

In [0]:
from databricks.sdk import WorkspaceClient
import mlflow
from pyspark.sql.functions import *
from pyspark.sql.types import *

class HealthcareModelMonitor:
    """
    Comprehensive model monitoring system specifically designed for your healthcare insurance model.
    Updated to use the unified juan_dev.ml schema for all monitoring assets.
    
    This design provides a cleaner, more maintainable monitoring architecture where
    all ML assets live in the same logical space.
    """
    
    def __init__(self,
                 model_name="juan_dev.ml.healthcare_insurance_model",
                 baseline_table="juan_dev.ml.insurance_silver",
                 monitoring_table="juan_dev.ml.insurance_predictions",
                 schema_name="juan_dev.ml",
                 thresholds=None):
        """
        Configure the monitored tables, output schema, and alert thresholds.

        Args:
            model_name: Fully qualified name of the registered model being monitored.
            baseline_table: Table holding training data, used as the "healthy"
                reference distribution for drift comparisons.
            monitoring_table: Table holding logged prediction results to monitor.
            schema_name: ``catalog.schema`` where monitoring views and Lakehouse
                Monitoring output tables are created. Defaults to the unified
                ``juan_dev.ml`` schema (previously hard-coded).
            thresholds: Optional mapping overriding any of the default threshold
                values; keys not supplied keep their defaults.

        Raises:
            ValueError: If ``thresholds`` contains a key that is not a known
                threshold name (prevents typos from being silently ignored).
        """
        self.model_name = model_name
        self.baseline_table = baseline_table  # Training data (our "healthy" reference)
        self.monitoring_table = monitoring_table  # Prediction results (what we monitor)
        self.schema_name = schema_name  # Schema for all generated monitoring assets

        # Healthcare-specific monitoring thresholds
        self.thresholds = {
            "max_mae_threshold": 2500,  # Maximum acceptable Mean Absolute Error
            "min_daily_predictions": 50,  # Minimum predictions per day
            "max_prediction_drift": 0.20,  # 20% drift in average predictions
            "min_r2_score": 0.75,  # Minimum acceptable R² when ground truth available
            "max_demographic_drift": 0.10  # Maximum acceptable demographic shift
        }
        if thresholds:
            unknown = set(thresholds) - set(self.thresholds)
            if unknown:
                raise ValueError(f"Unknown threshold keys: {sorted(unknown)}")
            self.thresholds.update(thresholds)

        # Created after argument validation so bad config fails fast.
        self.workspace = WorkspaceClient()

        print("✅ Healthcare Model Monitor initialized")
        print(f"   Schema: {self.schema_name}")
        print(f"   Model: {self.model_name}")
        print(f"   Baseline: {self.baseline_table}")
        print(f"   Monitoring: {self.monitoring_table}")
    
    def setup_lakehouse_monitoring(self):
        """
        Setup Databricks Lakehouse Monitoring for your healthcare model.

        Creates an inference-log quality monitor on ``self.monitoring_table``,
        writing metric tables into ``self.schema_name`` and comparing against
        ``self.baseline_table``.

        The SDK's ``quality_monitors.create`` expects typed request objects
        (``MonitorInferenceLog``, ``MonitorDataClassificationConfig``), not
        plain dicts. The previous dict-based call could never succeed.

        Returns:
            The ``MonitorInfo`` returned by the SDK on success, or ``None`` if
            creation failed. The error is printed, not raised, because drift
            views and alerts work without Lakehouse Monitoring.
        """
        # Local import: typed request classes live in the SDK's catalog service module.
        from databricks.sdk.service.catalog import (
            MonitorDataClassificationConfig,
            MonitorInferenceLog,
            MonitorInferenceLogProblemType,
        )

        print("Setting up Lakehouse Monitoring for healthcare insurance model...")

        # First, let's ensure our monitoring table has the right structure
        self._prepare_monitoring_table()

        assets_dir = "/juan_dev/ml/monitoring/healthcare_insurance/"

        # Slice data by important healthcare demographics for bias monitoring.
        # NOTE(review): confirm Lakehouse Monitoring accepts "AS alias" inside slicing expressions.
        slicing_exprs = [
            "region",           # Geographic bias
            "age_group",        # Age-based bias
            "CASE WHEN smoker THEN 'smoker' ELSE 'non_smoker' END as smoking_status",  # Smoking bias
            "CASE WHEN bmi > 30 THEN 'obese' WHEN bmi > 25 THEN 'overweight' ELSE 'normal' END as bmi_category"  # Weight bias
        ]

        # Configure how to monitor predictions over time. problem_type is a
        # required field of the inference log itself, not a top-level option.
        inference_log = MonitorInferenceLog(
            problem_type=MonitorInferenceLogProblemType.PROBLEM_TYPE_REGRESSION,  # continuous charges
            granularities=["1 hour", "1 day", "1 week"],  # Monitor at multiple time scales
            model_id_col="model_name",  # Identifies which model made the prediction
            prediction_col="adjusted_prediction",  # Our main prediction column
            timestamp_col="prediction_timestamp",  # When prediction was made
            label_col="actual_charges",  # Ground truth (when available from claims data)
        )

        try:
            monitor_info = self.workspace.quality_monitors.create(
                table_name=self.monitoring_table,
                assets_dir=assets_dir,
                output_schema_name=self.schema_name,
                baseline_table_name=self.baseline_table,
                slicing_exprs=slicing_exprs,
                inference_log=inference_log,
                # Enable data classification for healthcare compliance
                data_classification_config=MonitorDataClassificationConfig(enabled=True),
            )
        except Exception as e:
            print(f"❌ Error creating Lakehouse Monitor: {str(e)}")
            print("This might be because:")
            print("1. The monitoring table doesn't exist yet (run batch inference first)")
            print("2. You don't have permissions to create monitors")
            print("3. The table schema doesn't match expected format")
            print("Note: Drift detection and alerts will still work without Lakehouse Monitoring")
            return None

        print("✅ Lakehouse Monitor created successfully!")
        # MonitorInfo has no monitor_name; the monitor is identified by its table.
        print(f"Monitor table: {monitor_info.table_name}")
        print(f"Dashboard ID: {monitor_info.dashboard_id}")
        print(f"Dashboard available at: {assets_dir}")
        print(f"Monitoring assets stored in: {self.schema_name}")

        return monitor_info
    
    def _prepare_monitoring_table(self):
        """
        Ensure the monitoring table has the columns needed for slicing.

        Adds a derived ``age_group`` column (used by Lakehouse Monitoring's
        slicing expressions) when it is missing. Rows with NULL ``age`` get a
        NULL ``age_group`` rather than being counted as 'senior'.

        Failures are printed rather than raised, so the caller can continue.
        A missing table and a failed rewrite are reported separately.
        """
        print("Preparing monitoring table structure...")

        try:
            existing_table = spark.table(self.monitoring_table)
        except Exception as e:
            print(f"⚠️  Monitoring table {self.monitoring_table} not found: {e}")
            print("Run batch inference first to create the table")
            return

        print(f"✅ Monitoring table {self.monitoring_table} exists")

        if "age_group" in existing_table.columns:
            return

        print("Adding age_group column for demographic monitoring...")
        enhanced_table = existing_table.withColumn(
            "age_group",
            expr(
                "CASE WHEN age IS NULL THEN NULL "
                "WHEN age < 30 THEN 'young' "
                "WHEN age < 50 THEN 'middle' "
                "ELSE 'senior' END"
            ),
        )

        try:
            # Delta schema enforcement rejects writes that add columns unless
            # schema evolution is enabled. mergeSchema only allows additions.
            # NOTE(review): rows appended later by batch inference need to
            # populate age_group themselves; this is a one-off backfill.
            (
                enhanced_table.write
                .mode("overwrite")
                .option("mergeSchema", "true")
                .saveAsTable(self.monitoring_table)
            )
        except Exception as e:
            print(f"❌ Failed to add age_group column to {self.monitoring_table}: {e}")
            return

        print("✅ Enhanced monitoring table with age_group column")
    
    def setup_drift_detection(self):
        """
        Setup comprehensive drift detection for healthcare-specific features.

        Creates (or replaces) the view ``<schema_name>.healthcare_drift_detection``.
        The view computes, per prediction date and model version over the last
        30 days:
        - prediction volume and distribution statistics;
        - demographic and regional mix;
        - MAE, RMSE and R² where ground truth exists;
        - drift against the baseline table;
        - status flags driven by ``self.thresholds``.

        Returns:
            The fully qualified view name. It is returned even if creation
            failed; errors are printed, not raised.
        """
        print("Setting up healthcare-specific drift detection...")

        # daily_r2: Spark rejects nested aggregates (AVG inside SUM), so the
        # total sum of squares is computed as COUNT(y) * VAR_POP(y), which
        # equals SUM((y - mean(y))^2). NULLIF avoids dividing by zero when all
        # labels in a day are identical or only one label is present.
        drift_detection_query = f"""
        WITH daily_stats AS (
            SELECT 
                DATE(prediction_timestamp) as prediction_date,
                model_version,
                
                -- Prediction statistics
                COUNT(*) as daily_prediction_count,
                AVG(adjusted_prediction) as avg_daily_prediction,
                STDDEV(adjusted_prediction) as std_daily_prediction,
                PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY adjusted_prediction) as median_prediction,
                
                -- Healthcare demographic monitoring (critical for bias detection)
                AVG(CASE WHEN smoker THEN 1.0 ELSE 0.0 END) as smoker_rate,
                AVG(age) as avg_age,
                AVG(bmi) as avg_bmi,
                AVG(children) as avg_children,
                
                -- Regional distribution (important for healthcare equity)
                SUM(CASE WHEN UPPER(region) = 'NORTHEAST' THEN 1 ELSE 0 END) / COUNT(*) as northeast_pct,
                SUM(CASE WHEN UPPER(region) = 'SOUTHEAST' THEN 1 ELSE 0 END) / COUNT(*) as southeast_pct,
                SUM(CASE WHEN UPPER(region) = 'NORTHWEST' THEN 1 ELSE 0 END) / COUNT(*) as northwest_pct,
                SUM(CASE WHEN UPPER(region) = 'SOUTHWEST' THEN 1 ELSE 0 END) / COUNT(*) as southwest_pct,
                
                -- Risk category distribution
                SUM(CASE WHEN cost_risk_category = 'high' THEN 1 ELSE 0 END) / COUNT(*) as high_risk_pct,
                SUM(CASE WHEN cost_risk_category = 'very_high' THEN 1 ELSE 0 END) / COUNT(*) as very_high_risk_pct,
                
                -- Performance metrics (when ground truth is available)
                CASE WHEN COUNT(actual_charges) > 0 THEN
                    AVG(ABS(adjusted_prediction - actual_charges))
                ELSE NULL END as daily_mae,
                
                CASE WHEN COUNT(actual_charges) > 0 THEN
                    SQRT(AVG(POWER(adjusted_prediction - actual_charges, 2)))
                ELSE NULL END as daily_rmse,
                
                CASE WHEN COUNT(actual_charges) > 0 THEN
                    1 - (SUM(POWER(adjusted_prediction - actual_charges, 2)) / 
                         NULLIF(VAR_POP(actual_charges) * COUNT(actual_charges), 0))
                ELSE NULL END as daily_r2
                
            FROM {self.monitoring_table}
            WHERE prediction_timestamp >= CURRENT_DATE() - INTERVAL 30 DAYS
            GROUP BY DATE(prediction_timestamp), model_version
        ),
        
        baseline_stats AS (
            SELECT 
                AVG(charges) as baseline_avg_charges,
                STDDEV(charges) as baseline_std_charges,
                AVG(CASE WHEN smoker THEN 1.0 ELSE 0.0 END) as baseline_smoker_rate,
                AVG(age) as baseline_avg_age,
                AVG(bmi) as baseline_avg_bmi,
                
                -- Baseline regional distribution
                SUM(CASE WHEN UPPER(region) = 'NORTHEAST' THEN 1 ELSE 0 END) / COUNT(*) as baseline_northeast_pct,
                SUM(CASE WHEN UPPER(region) = 'SOUTHEAST' THEN 1 ELSE 0 END) / COUNT(*) as baseline_southeast_pct,
                SUM(CASE WHEN UPPER(region) = 'NORTHWEST' THEN 1 ELSE 0 END) / COUNT(*) as baseline_northwest_pct,
                SUM(CASE WHEN UPPER(region) = 'SOUTHWEST' THEN 1 ELSE 0 END) / COUNT(*) as baseline_southwest_pct
                
            FROM {self.baseline_table}
        )
        
        SELECT 
            ds.*,
            bs.baseline_avg_charges,
            bs.baseline_std_charges,
            bs.baseline_smoker_rate,
            bs.baseline_avg_age,
            bs.baseline_avg_bmi,
            bs.baseline_northeast_pct,
            bs.baseline_southeast_pct,
            bs.baseline_northwest_pct,
            bs.baseline_southwest_pct,
            
            -- Statistical drift indicators
            ABS(ds.avg_daily_prediction - bs.baseline_avg_charges) / GREATEST(bs.baseline_std_charges, 1) as prediction_drift_zscore,
            ABS(ds.avg_daily_prediction - bs.baseline_avg_charges) / GREATEST(bs.baseline_avg_charges, 1) as prediction_drift_percentage,
            
            -- Demographic drift indicators (critical for healthcare fairness)
            ABS(ds.smoker_rate - bs.baseline_smoker_rate) as smoker_rate_drift,
            ABS(ds.avg_age - bs.baseline_avg_age) as age_drift,
            ABS(ds.avg_bmi - bs.baseline_avg_bmi) as bmi_drift,
            
            -- Regional distribution drift (important for access equity)
            ABS(ds.northeast_pct - bs.baseline_northeast_pct) as northeast_drift,
            ABS(ds.southeast_pct - bs.baseline_southeast_pct) as southeast_drift,
            ABS(ds.northwest_pct - bs.baseline_northwest_pct) as northwest_drift,
            ABS(ds.southwest_pct - bs.baseline_southwest_pct) as southwest_drift,
            
            -- Composite drift score (overall health indicator)
            (ABS(ds.avg_daily_prediction - bs.baseline_avg_charges) / GREATEST(bs.baseline_avg_charges, 1) +
             ABS(ds.smoker_rate - bs.baseline_smoker_rate) +
             ABS(ds.avg_age - bs.baseline_avg_age) / GREATEST(bs.baseline_avg_age, 1)) / 3 as composite_drift_score,
            
            -- Alert status flags
            CASE WHEN ds.daily_prediction_count < {self.thresholds['min_daily_predictions']} 
                 THEN 'VOLUME_LOW' ELSE 'VOLUME_OK' END as volume_status,
                 
            CASE WHEN ds.daily_mae > {self.thresholds['max_mae_threshold']} 
                 THEN 'ACCURACY_DEGRADED' ELSE 'ACCURACY_OK' END as accuracy_status,
                 
            CASE WHEN ABS(ds.avg_daily_prediction - bs.baseline_avg_charges) / GREATEST(bs.baseline_avg_charges, 1) > {self.thresholds['max_prediction_drift']}
                 THEN 'PREDICTION_DRIFT' ELSE 'PREDICTION_STABLE' END as prediction_drift_status,
                 
            CASE WHEN ABS(ds.smoker_rate - bs.baseline_smoker_rate) > {self.thresholds['max_demographic_drift']}
                 THEN 'DEMOGRAPHIC_DRIFT' ELSE 'DEMOGRAPHIC_STABLE' END as demographic_drift_status,
            
            CURRENT_TIMESTAMP() as analysis_timestamp
                 
        FROM daily_stats ds
        CROSS JOIN baseline_stats bs
        ORDER BY ds.prediction_date DESC
        """

        # Create drift detection view in the configured schema
        drift_view_name = f"{self.schema_name}.healthcare_drift_detection"

        try:
            spark.sql(f"CREATE OR REPLACE VIEW {drift_view_name} AS {drift_detection_query}")
            print(f"✅ Created drift detection view: {drift_view_name}")

            # Count every period in the view (not just a LIMITed sample) so the
            # reported number is accurate; only the display is capped at 5 rows.
            drift_results = spark.sql(f"SELECT * FROM {drift_view_name}")
            drift_count = drift_results.count()

            if drift_count > 0:
                print(f"✅ Drift detection operational - analyzing {drift_count} recent monitoring periods")
                print("Sample drift analysis:")
                drift_results.select(
                    "prediction_date", "daily_prediction_count", 
                    "prediction_drift_percentage", "composite_drift_score",
                    "volume_status", "prediction_drift_status", "demographic_drift_status"
                ).show(5, truncate=False)
            else:
                print("⚠️  No recent prediction data found for drift analysis")
                print("Run batch inference first to generate monitoring data")

        except Exception as e:
            print(f"❌ Error creating drift detection: {e}")
            print("Make sure the monitoring table exists and has recent data")

        return drift_view_name
    
    def setup_performance_alerts(self):
        """
        Setup automated alerts for healthcare model performance issues.

        Creates (or replaces) ``<schema_name>.healthcare_model_alerts``. The
        view reads ``healthcare_drift_detection``, so run
        ``setup_drift_detection`` first. It keeps only the last 7 days of rows
        that have at least one alerting condition, and adds to each row:
        - a primary alert type;
        - a severity;
        - a business-impact description;
        - a recommended action.

        All four classifications use the same priority order
        (accuracy > demographic > prediction drift > volume > composite drift),
        so one row never gets a type that contradicts its description.

        Returns:
            Tuple of (alert_conditions, alerts_view_name). alert_conditions is
            a list of dicts describing each alert rule; view-creation errors
            are printed, not raised.
        """
        print("Setting up healthcare-specific performance alerts...")

        max_mae = self.thresholds['max_mae_threshold']
        # Single composite-drift threshold shared by the filter and every CASE.
        # Previously the filter used 0.2 and the CASEs used 0.3, which emitted
        # 'HEALTHY' rows as active alerts.
        composite_alert_threshold = 0.3

        # Define healthcare-specific alert conditions
        alert_conditions = [
            {
                "name": "high_mae_alert",
                "description": "Model prediction error exceeds healthcare accuracy standards",
                "condition": f"daily_mae > {max_mae}",
                "severity": "HIGH",
                "message": f"Model MAE exceeded ${max_mae} threshold - may impact underwriting decisions"
            },
            {
                "name": "prediction_volume_drop",
                "description": "Daily prediction volume below expected threshold",
                "condition": f"daily_prediction_count < {self.thresholds['min_daily_predictions']}",
                "severity": "MEDIUM", 
                "message": f"Daily prediction volume below {self.thresholds['min_daily_predictions']} - check data pipeline health"
            },
            {
                "name": "prediction_drift_alert",
                "description": "Significant drift in average predictions detected",
                "condition": "prediction_drift_zscore > 2.0",
                "severity": "HIGH",
                "message": "Significant prediction drift detected - model may need retraining"
            },
            {
                "name": "demographic_bias_alert", 
                "description": "Potential bias detected in demographic predictions",
                "condition": f"smoker_rate_drift > {self.thresholds['max_demographic_drift']}",
                "severity": "HIGH",
                "message": "Smoking rate distribution has shifted significantly - check for selection bias"
            },
            {
                "name": "regional_distribution_alert",
                "description": "Regional distribution of predictions has changed",
                # Compare each region's share against the baseline. The old
                # condition checked only that the shares summed to ~1, which
                # never fires on a real shift between the four regions.
                "condition": (
                    "GREATEST(northeast_drift, southeast_drift, northwest_drift, southwest_drift) > "
                    f"{self.thresholds['max_demographic_drift']}"
                ),
                "severity": "MEDIUM",
                "message": "Regional distribution has shifted - verify data source coverage"
            }
        ]

        # Create alert monitoring query
        alert_query = f"""
        SELECT 
            prediction_date,
            model_version,
            daily_prediction_count,
            avg_daily_prediction,
            daily_mae,
            composite_drift_score,
            
            -- Specific alert conditions with business context
            CASE 
                WHEN accuracy_status = 'ACCURACY_DEGRADED' THEN 'HIGH_PREDICTION_ERROR'
                WHEN demographic_drift_status = 'DEMOGRAPHIC_DRIFT' THEN 'DEMOGRAPHIC_BIAS_RISK'
                WHEN prediction_drift_status = 'PREDICTION_DRIFT' THEN 'SIGNIFICANT_PREDICTION_DRIFT'
                WHEN volume_status = 'VOLUME_LOW' THEN 'LOW_PREDICTION_VOLUME'
                WHEN composite_drift_score > {composite_alert_threshold} THEN 'MULTIPLE_DRIFT_INDICATORS'
                ELSE 'HEALTHY'
            END as primary_alert_type,
            
            -- Alert severity based on business impact
            CASE 
                WHEN accuracy_status = 'ACCURACY_DEGRADED' AND daily_mae > {max_mae * 1.5} THEN 'CRITICAL'
                WHEN accuracy_status = 'ACCURACY_DEGRADED' THEN 'HIGH'
                WHEN demographic_drift_status = 'DEMOGRAPHIC_DRIFT' THEN 'HIGH'
                WHEN prediction_drift_status = 'PREDICTION_DRIFT' THEN 'HIGH'
                WHEN volume_status = 'VOLUME_LOW' THEN 'MEDIUM'
                WHEN composite_drift_score > {composite_alert_threshold} THEN 'MEDIUM'
                ELSE 'LOW'
            END as alert_severity,
            
            -- Business impact description
            CASE 
                WHEN accuracy_status = 'ACCURACY_DEGRADED' THEN 
                    CONCAT('Prediction errors exceed $', CAST({max_mae} AS STRING), ' - may impact underwriting decisions')
                WHEN demographic_drift_status = 'DEMOGRAPHIC_DRIFT' THEN 
                    'Demographic distribution has shifted significantly - potential bias concern'
                WHEN prediction_drift_status = 'PREDICTION_DRIFT' THEN 
                    'Average predictions have drifted from baseline - model may need retraining'
                WHEN volume_status = 'VOLUME_LOW' THEN 
                    'Daily prediction volume below threshold - check data pipeline health'
                WHEN composite_drift_score > {composite_alert_threshold} THEN 
                    'Several drift indicators are elevated together - model inputs may be shifting'
                ELSE 'Model operating within normal parameters'
            END as business_impact_description,
            
            -- Recommended actions
            CASE 
                WHEN accuracy_status = 'ACCURACY_DEGRADED' THEN 'Investigate model performance and consider immediate retraining'
                WHEN demographic_drift_status = 'DEMOGRAPHIC_DRIFT' THEN 'Review data sources and conduct bias analysis'
                WHEN prediction_drift_status = 'PREDICTION_DRIFT' THEN 'Schedule model retraining and validate feature engineering'
                WHEN volume_status = 'VOLUME_LOW' THEN 'Check upstream data pipelines and batch job health'
                WHEN composite_drift_score > {composite_alert_threshold} THEN 'Review drift detection view and plan retraining'
                ELSE 'Continue routine monitoring'
            END as recommended_action,
            
            CURRENT_TIMESTAMP() as alert_timestamp
            
        FROM {self.schema_name}.healthcare_drift_detection
        WHERE prediction_date >= CURRENT_DATE() - INTERVAL 7 DAYS
          AND (
              volume_status != 'VOLUME_OK' OR
              accuracy_status != 'ACCURACY_OK' OR
              prediction_drift_status != 'PREDICTION_STABLE' OR
              demographic_drift_status != 'DEMOGRAPHIC_STABLE' OR
              composite_drift_score > {composite_alert_threshold}
          )
        ORDER BY 
            CASE alert_severity 
                WHEN 'CRITICAL' THEN 1 
                WHEN 'HIGH' THEN 2 
                WHEN 'MEDIUM' THEN 3 
                ELSE 4 
            END,
            prediction_date DESC
        """

        # Create alerts view in the configured schema
        alerts_view_name = f"{self.schema_name}.healthcare_model_alerts"

        try:
            spark.sql(f"CREATE OR REPLACE VIEW {alerts_view_name} AS {alert_query}")
            print(f"✅ Created alerts monitoring view: {alerts_view_name}")

            # Check for current alerts
            current_alerts = spark.sql(f"SELECT * FROM {alerts_view_name}")
            alert_count = current_alerts.count()

            if alert_count > 0:
                print(f"⚠️  Found {alert_count} active alerts:")
                current_alerts.select(
                    "prediction_date", "primary_alert_type", "alert_severity", 
                    "business_impact_description", "recommended_action"
                ).show(truncate=False)
            else:
                print("✅ No active alerts - model is performing within expected parameters")

        except Exception as e:
            print(f"❌ Error setting up alerts: {e}")

        return alert_conditions, alerts_view_name
    
    def create_performance_dashboard(self):
        """
        Create an executive summary view of the model's last 7 days.

        Creates (or replaces) ``<schema_name>.model_performance_dashboard``
        from ``healthcare_drift_detection``, so run ``setup_drift_detection``
        first. The view aggregates volume, cost, error, drift, fairness and
        regional metrics into a single row and adds rating columns.

        The view is a global aggregate, so it always returns exactly one row.
        An empty window shows up as ``active_days = 0``, not as zero rows.
        Ratings report 'NO_DATA' in that case instead of claiming the model is
        stable.

        Returns:
            ``{"success": True, "view_name": ...}`` on success, or
            ``{"success": False, "error": ...}`` if view creation/query failed.
        """
        print("Creating executive performance dashboard...")

        max_mae = self.thresholds['max_mae_threshold']
        max_demo = self.thresholds['max_demographic_drift']

        dashboard_query = f"""
        WITH performance_summary AS (
            SELECT 
                -- Time period analysis
                'Last 7 Days' as reporting_period,
                MIN(prediction_date) as period_start,
                MAX(prediction_date) as period_end,
                COUNT(DISTINCT prediction_date) as active_days,
                
                -- Volume and operational metrics
                SUM(daily_prediction_count) as total_predictions,
                AVG(daily_prediction_count) as avg_daily_predictions,
                MIN(daily_prediction_count) as min_daily_predictions,
                MAX(daily_prediction_count) as max_daily_predictions,
                
                -- Financial impact metrics
                AVG(avg_daily_prediction) as avg_predicted_cost,
                STDDEV(avg_daily_prediction) as prediction_volatility,
                SUM(daily_prediction_count * avg_daily_prediction) as total_predicted_exposure,
                
                -- Quality and performance metrics
                AVG(daily_mae) as avg_prediction_error,
                AVG(daily_rmse) as avg_rmse,
                AVG(daily_r2) as avg_r2_score,
                
                -- Drift and stability metrics
                AVG(composite_drift_score) as avg_composite_drift,
                MAX(composite_drift_score) as max_composite_drift,
                AVG(prediction_drift_percentage) as avg_prediction_drift,
                
                -- Operational health indicators
                COUNT(CASE WHEN volume_status != 'VOLUME_OK' THEN 1 END) as days_with_volume_issues,
                COUNT(CASE WHEN accuracy_status != 'ACCURACY_OK' THEN 1 END) as days_with_accuracy_issues,
                COUNT(CASE WHEN prediction_drift_status != 'PREDICTION_STABLE' THEN 1 END) as days_with_drift_issues,
                COUNT(CASE WHEN demographic_drift_status != 'DEMOGRAPHIC_STABLE' THEN 1 END) as days_with_bias_concerns,
                
                -- Demographic fairness and equity metrics
                AVG(smoker_rate) as avg_smoker_rate,
                STDDEV(smoker_rate) as smoker_rate_stability,
                MAX(smoker_rate_drift) as max_smoker_rate_drift,
                AVG(age_drift) as avg_age_drift,
                AVG(bmi_drift) as avg_bmi_drift,
                
                -- Regional coverage and equity
                AVG(northeast_pct) as avg_northeast_coverage,
                AVG(southeast_pct) as avg_southeast_coverage,
                AVG(northwest_pct) as avg_northwest_coverage,
                AVG(southwest_pct) as avg_southwest_coverage,
                
                -- Risk distribution metrics
                AVG(high_risk_pct) as avg_high_risk_percentage,
                AVG(very_high_risk_pct) as avg_very_high_risk_percentage
                
            FROM {self.schema_name}.healthcare_drift_detection
            WHERE prediction_date >= CURRENT_DATE() - INTERVAL 7 DAYS
        )
        
        SELECT 
            *,
            -- Executive summary ratings (NO_DATA when the 7-day window is empty)
            CASE 
                WHEN active_days = 0 THEN 'NO_DATA'
                WHEN avg_prediction_error IS NULL THEN 'MONITORING_ONLY'
                WHEN avg_prediction_error <= {max_mae} * 0.7 THEN 'EXCELLENT'
                WHEN avg_prediction_error <= {max_mae} THEN 'GOOD'
                WHEN avg_prediction_error <= {max_mae} * 1.2 THEN 'ACCEPTABLE'
                ELSE 'NEEDS_ATTENTION'
            END as performance_rating,
            
            CASE 
                WHEN active_days = 0 THEN 'NO_DATA'
                WHEN days_with_volume_issues = 0 AND days_with_accuracy_issues = 0 AND days_with_drift_issues = 0 THEN 'STABLE'
                WHEN days_with_accuracy_issues > 0 OR days_with_drift_issues > 2 THEN 'UNSTABLE'
                ELSE 'MINOR_ISSUES'
            END as operational_stability,
            
            CASE 
                WHEN active_days = 0 THEN 'NO_DATA'
                WHEN max_smoker_rate_drift <= {max_demo} * 0.5 THEN 'EXCELLENT'
                WHEN max_smoker_rate_drift <= {max_demo} THEN 'GOOD'
                ELSE 'REQUIRES_REVIEW'
            END as fairness_rating,
            
            -- Business impact and recommendations
            CASE 
                WHEN active_days = 0 THEN 'NO_DATA'
                WHEN avg_prediction_error > {max_mae} THEN 'IMMEDIATE_RETRAINING_RECOMMENDED'
                WHEN avg_composite_drift > 0.3 THEN 'SCHEDULE_RETRAINING_SOON'
                WHEN days_with_drift_issues > 3 THEN 'MONITOR_CLOSELY'
                ELSE 'CONTINUE_ROUTINE_OPERATIONS'
            END as business_recommendation,
            
            -- Financial impact assessment
            CASE 
                WHEN active_days = 0 THEN 'NO_DATA'
                WHEN avg_prediction_error IS NOT NULL AND avg_prediction_error > {max_mae} THEN 'HIGH_FINANCIAL_RISK'
                WHEN prediction_volatility > avg_predicted_cost * 0.3 THEN 'MODERATE_FINANCIAL_RISK'
                ELSE 'LOW_FINANCIAL_RISK'
            END as financial_risk_assessment,
            
            CURRENT_TIMESTAMP() as dashboard_generated_timestamp
            
        FROM performance_summary
        """

        # Create dashboard view in the configured schema
        dashboard_view_name = f"{self.schema_name}.model_performance_dashboard"

        try:
            spark.sql(f"CREATE OR REPLACE VIEW {dashboard_view_name} AS {dashboard_query}")
            print(f"✅ Created performance dashboard: {dashboard_view_name}")

            dashboard_data = spark.sql(f"SELECT * FROM {dashboard_view_name}")
            # count() > 0 is always true for a global aggregate; inspect the
            # single summary row instead to detect an empty reporting window.
            summary = dashboard_data.first()
            if summary is not None and summary["active_days"] > 0:
                print("📊 Executive Performance Summary:")
                dashboard_data.select(
                    "reporting_period", "total_predictions", "avg_predicted_cost",
                    "performance_rating", "operational_stability", "fairness_rating",
                    "business_recommendation", "financial_risk_assessment"
                ).show(truncate=False)

                # Show detailed metrics
                print("📈 Detailed Metrics:")
                dashboard_data.select(
                    "avg_prediction_error", "avg_composite_drift", 
                    "days_with_accuracy_issues", "days_with_drift_issues", "days_with_bias_concerns"
                ).show(truncate=False)
            else:
                print("⚠️  No prediction data in the last 7 days - dashboard metrics are empty")

            return {"success": True, "view_name": dashboard_view_name}

        except Exception as e:
            print(f"❌ Failed to create performance dashboard: {str(e)}")
            return {"success": False, "error": str(e)}
    
    def run_comprehensive_health_check(self):
        """
        Run a complete health check of your healthcare model monitoring system.
        This verifies everything is working correctly within the unified schema.
        """
        
        print("🏥 Running Comprehensive Healthcare Model Health Check...")
        print("=" * 70)
        
        health_status = {
            "monitoring_table_exists": False,
            "baseline_table_exists": False, 
            "recent_predictions": False,
            "drift_detection_working": False,
            "alerts_configured": False,
            "dashboard_available": False,
            "overall_health": "UNKNOWN"
        }
        
        # Check 1: Monitoring table exists and has data
        try:
            monitoring_df = spark.table(self.monitoring_table)
            row_count = monitoring_df.count()
            health_status["monitoring_table_exists"] = True
            print(f"✅ Monitoring table exists with {row_count:,} records")
            
            # Check for recent data
            recent_data = monitoring_df.filter(
                col("prediction_timestamp") >= expr("CURRENT_DATE() - INTERVAL 7 DAYS")
            ).count()
            
            if recent_data > 0:
                health_status["recent_predictions"] = True
                print(f"✅ Found {recent_data:,} recent predictions in last 7 days")
            else:
                print("⚠️  No recent predictions found - run batch inference")
                
        except Exception as e:
            print(f"❌ Monitoring table check failed: {e}")
        
        # Check 2: Baseline table exists
        try:
            baseline_df = spark.table(self.baseline_table)
            baseline_count = baseline_df.count()
            health_status["baseline_table_exists"] = True
            print(f"✅ Baseline table exists with {baseline_count:,} records")
        except Exception as e:
            print(f"❌ Baseline table check failed: {e}")
        
        # Check 3: Drift detection (only if view exists)
        try:
            # First check if the view exists
            existing_views = spark.sql(f"SHOW VIEWS IN {self.schema_name}")
            view_names = [row.viewName for row in existing_views.collect()]
            
            if "healthcare_drift_detection" in view_names:
                drift_df = spark.sql(f"SELECT * FROM {self.schema_name}.healthcare_drift_detection LIMIT 1")
                if drift_df.count() > 0:
                    health_status["drift_detection_working"] = True
                    print("✅ Drift detection is working")
                else:
                    print("⚠️  Drift detection view exists but no data")
            else:
                print("⚠️  Drift detection not yet configured")
        except Exception as e:
            print(f"⚠️  Drift detection check failed: {e}")
        
        # Check 4: Alerts (only if view exists)
        try:
            existing_views = spark.sql(f"SHOW VIEWS IN {self.schema_name}")
            view_names = [row.viewName for row in existing_views.collect()]
            
            if "healthcare_model_alerts" in view_names:
                alerts_df = spark.sql(f"SELECT * FROM {self.schema_name}.healthcare_model_alerts LIMIT 1")
                health_status["alerts_configured"] = True
                print("✅ Alerts system is configured")
            else:
                print("⚠️  Alerts system not yet configured")
        except Exception as e:
            print(f"⚠️  Alerts check failed: {e}")
        
        # Check 5: Dashboard (only if view exists)
        try:
            existing_views = spark.sql(f"SHOW VIEWS IN {self.schema_name}")
            view_names = [row.viewName for row in existing_views.collect()]
            
            if "model_performance_dashboard" in view_names:
                dashboard_df = spark.sql(f"SELECT * FROM {self.schema_name}.model_performance_dashboard LIMIT 1")
                health_status["dashboard_available"] = True
                print("✅ Performance dashboard is available")
            else:
                print("⚠️  Performance dashboard not yet configured")
        except Exception as e:
            print(f"⚠️  Dashboard check failed: {e}")
        
        # Overall health assessment
        checks_passed = sum([
            health_status["monitoring_table_exists"],
            health_status["baseline_table_exists"],
            health_status["recent_predictions"],
            health_status["drift_detection_working"],
            health_status["alerts_configured"],
            health_status["dashboard_available"]
        ])
        
        if checks_passed >= 5:
            health_status["overall_health"] = "HEALTHY"
            print("\n🎉 Overall Status: HEALTHY - Monitoring system is fully operational!")
        elif checks_passed >= 3:
            health_status["overall_health"] = "PARTIALLY_HEALTHY"
            print("\n⚠️  Overall Status: PARTIALLY HEALTHY - Some components need setup")
        else:
            health_status["overall_health"] = "UNHEALTHY"
            print("\n❌ Overall Status: UNHEALTHY - Monitoring system needs configuration")
        
        print("=" * 70)
        return health_status
    
    def setup_complete_monitoring_system(self):
        """
        One-stop method to set up the complete monitoring infrastructure.
        This creates all monitoring components in the correct order.
        """
        
        print("🚀 Setting up Complete Healthcare Model Monitoring System")
        print("All monitoring assets will be created in the unified juan_dev.ml schema")
        print("=" * 75)
        
        setup_results = {}
        
        # Step 1: Health check foundation
        print("\nStep 1: Validating monitoring foundation...")
        health_status = self.run_comprehensive_health_check()
        setup_results["foundation"] = health_status
        
        if not health_status["monitoring_table_exists"]:
            print("❌ Cannot proceed without monitoring table. Run batch inference first.")
            return setup_results
        
        # Step 2: Set up drift detection
        print("\nStep 2: Setting up drift detection system...")
        try:
            drift_view = self.setup_drift_detection()
            setup_results["drift_detection"] = {"success": True, "view": drift_view}
        except Exception as e:
            setup_results["drift_detection"] = {"success": False, "error": str(e)}
        
        # Step 3: Set up alerts
        print("\nStep 3: Setting up intelligent alerts...")
        try:
            alert_conditions, alerts_view = self.setup_performance_alerts()
            setup_results["alerts"] = {
                "success": True, 
                "view": alerts_view, 
                "conditions": len(alert_conditions)
            }
        except Exception as e:
            setup_results["alerts"] = {"success": False, "error": str(e)}
        
        # Step 4: Set up dashboard
        print("\nStep 4: Creating performance dashboard...")
        try:
            dashboard_result = self.create_performance_dashboard()
            setup_results["dashboard"] = dashboard_result
        except Exception as e:
            setup_results["dashboard"] = {"success": False, "error": str(e)}
        
        # Step 5: Try Lakehouse Monitoring (optional)
        print("\nStep 5: Attempting Lakehouse Monitoring setup...")
        try:
            monitor_info = self.setup_lakehouse_monitoring()
            if monitor_info:
                setup_results["lakehouse_monitoring"] = {"success": True, "monitor": monitor_info}
            else:
                setup_results["lakehouse_monitoring"] = {"success": False, "note": "Requires additional permissions"}
        except Exception as e:
            setup_results["lakehouse_monitoring"] = {"success": False, "error": str(e)}
        
        # Final validation
        print("\nStep 6: Final system validation...")
        final_health = self.run_comprehensive_health_check()
        setup_results["final_health"] = final_health
        
        # Print setup summary
        self._print_setup_summary(setup_results)
        
        return setup_results
    
    def _print_setup_summary(self, setup_results):
        """Print a comprehensive summary of the monitoring setup."""
        
        print("\n" + "=" * 75)
    
    def diagnose_and_fix_setup_issues(self):
        """
        Diagnose common setup issues and provide specific fixes.
        Run this if you encounter errors during monitoring setup.
        """
        
        print("🔧 HEALTHCARE MONITORING SYSTEM DIAGNOSTICS")
        print("=" * 60)
        
        issues_found = []
        fixes_applied = []
        
        # Check 1: Schema exists and is accessible
        print("\n1. Checking schema accessibility...")
        try:
            schemas = spark.sql("SHOW TABLES IN juan_dev.ml")
            if schemas.count() > 0:
                print("✅ juan_dev catalog is accessible")
                
                # Check if ml schema exists
                ml_tables = spark.sql("SHOW TABLES IN juan_dev.ml")
                print(f"✅ juan_dev.ml schema exists with {ml_tables.count()} objects")
            else:
                issues_found.append("Cannot access juan_dev catalog")
                print("❌ Cannot access juan_dev catalog - check permissions")
                
        except Exception as e:
            issues_found.append(f"Schema access error: {str(e)}")
            print(f"❌ Schema access error: {e}")
        
        # Check 2: Required tables exist
        print("\n2. Checking required tables...")
        required_tables = [
            ("insurance_silver", "Training baseline data"),
            ("insurance_predictions", "Live prediction results")
        ]
        
        for table_name, description in required_tables:
            try:
                full_table_name = f"juan_dev.ml.{table_name}"
                df = spark.table(full_table_name)
                count = df.count()
                print(f"✅ {table_name}: {count:,} records ({description})")
                
                # Check if predictions table has recent data
                if table_name == "insurance_predictions":
                    recent_count = df.filter(
                        col("prediction_timestamp") >= expr("CURRENT_DATE() - INTERVAL 1 DAYS")
                    ).count()
                    
                    if recent_count > 0:
                        print(f"   ✅ Has {recent_count:,} recent predictions (last 24 hours)")
                    else:
                        issues_found.append("No recent prediction data for monitoring")
                        print("   ⚠️  No recent predictions - run batch inference first")
                        
            except Exception as e:
                issues_found.append(f"Missing table {table_name}: {str(e)}")
                print(f"❌ {table_name}: Not found - {e}")
        
        # Check 3: Column requirements
        print("\n3. Checking column requirements...")
        try:
            pred_df = spark.table("juan_dev.ml.insurance_predictions")
            required_columns = [
                "prediction_timestamp", "adjusted_prediction", "model_name", 
                "sex", "region", "smoker", "age", "bmi", "cost_risk_category"
            ]
            
            missing_columns = [col for col in required_columns if col not in pred_df.columns]
            
            if missing_columns:
                issues_found.append(f"Monitoring table missing columns: {missing_columns}")
                print(f"❌ Missing required columns: {missing_columns}")
                print("   💡 Run updated batch inference to add missing columns")
            else:
                print("✅ All required columns present for monitoring")
                
        except Exception as e:
            print(f"⚠️  Could not check column requirements: {e}")
        
        # Check 4: Existing monitoring views
        print("\n4. Checking existing monitoring views...")
        try:
            views = spark.sql(f"SHOW VIEWS IN juan_dev.ml")
            view_names = [row.viewName for row in views.collect()]
            
            monitoring_views = [name for name in view_names if 'healthcare' in name or 'monitoring' in name or 'dashboard' in name]
            
            if monitoring_views:
                print("✅ Found existing monitoring views:")
                for view in monitoring_views:
                    print(f"   - {view}")
            else:
                print("ℹ️  No monitoring views found yet (this is normal for first setup)")
                
        except Exception as e:
            print(f"⚠️  Could not check views: {e}")
        
        # Provide specific fixes
        print(f"\n{'='*60}")
        print("🔧 DIAGNOSTIC SUMMARY AND FIXES")
        print(f"{'='*60}")
        
        if not issues_found:
            print("🎉 No issues found! Your system is ready for monitoring setup.")
            print("\nNext step: Run monitor.setup_complete_monitoring_system()")
        else:
            print(f"Found {len(issues_found)} issues that need attention:")
            for i, issue in enumerate(issues_found, 1):
                print(f"\n{i}. ❌ {issue}")
                
                # Provide specific fixes
                if "No recent prediction data" in issue:
                    print("   💡 FIX: Run batch inference first:")
                    print("      batch_inference.run_batch_inference(")
                    print("          input_table='juan_dev.ml.insurance_silver',")
                    print("          output_table='juan_dev.ml.insurance_predictions')")
                
                elif "Missing table" in issue and "insurance_predictions" in issue:
                    print("   💡 FIX: Create predictions table by running batch inference")
                
                elif "Missing required columns" in issue:
                    print("   💡 FIX: Update batch inference with latest code that includes all columns")
                
                elif "Cannot access" in issue:
                    print("   💡 FIX: Check Unity Catalog permissions for juan_dev.ml schema")
        
        return {
            "issues_found": issues_found,
            "fixes_applied": fixes_applied,
            "ready_for_setup": len(issues_found) == 0
        }
    
    def quick_monitoring_setup(self):
        """
        Simplified setup method that handles common issues automatically.
        Use this if the full setup encounters problems.
        """
        
        print("🚀 Quick Healthcare Monitoring Setup")
        print("This simplified setup works around common configuration issues")
        print("=" * 65)
        
        setup_success = True
        
        # Step 1: Basic validation
        print("\nStep 1: Running diagnostics...")
        diagnostic_results = self.diagnose_and_fix_setup_issues()
        
        if not diagnostic_results["ready_for_setup"]:
            print("\n❌ Cannot proceed with setup - please address the issues above first")
            return {"success": False, "diagnostics": diagnostic_results}
        
        # Step 2: Create drift detection view with error handling
        print("\nStep 2: Setting up drift detection...")
        try:
            drift_view = self.setup_drift_detection()
            print(f"✅ Drift detection created: {drift_view}")
        except Exception as e:
            print(f"❌ Drift detection failed: {e}")
            setup_success = False
        
        # Step 3: Create alerts view
        print("\nStep 3: Setting up alerts...")
        try:
            alert_conditions, alerts_view = self.setup_performance_alerts()
            print(f"✅ Alert system created: {alerts_view}")
        except Exception as e:
            print(f"❌ Alert setup failed: {e}")
            setup_success = False
        
        # Step 4: Create dashboard
        print("\nStep 4: Setting up dashboard...")
        try:
            dashboard_result = self.create_performance_dashboard()
            if dashboard_result.get("success"):
                print(f"✅ Dashboard created: {dashboard_result['view_name']}")
            else:
                print(f"❌ Dashboard failed: {dashboard_result.get('error')}")
                setup_success = False
        except Exception as e:
            print(f"❌ Dashboard setup failed: {e}")
            setup_success = False
        
        # Final status
        if setup_success:
            print(f"\n🎉 Quick setup completed successfully!")
            print(f"All monitoring views created in {self.schema_name}")
            
            # Show what was created
            try:
                views = spark.sql(f"SHOW VIEWS IN {self.schema_name}")
                monitoring_views = [row.viewName for row in views.collect() 
                                  if any(keyword in row.viewName.lower() 
                                       for keyword in ['healthcare', 'monitoring', 'dashboard', 'alert'])]
                
                if monitoring_views:
                    print(f"\n📋 Created monitoring views:")
                    for view in monitoring_views:
                        print(f"   • {self.schema_name}.{view}")
            except:
                pass
                
        else:
            print(f"\n⚠️  Setup completed with some errors")
            print("Check the error messages above for specific issues")
        
        return {
            "success": setup_success,
            "diagnostics": diagnostic_results,
            "monitoring_ready": setup_success
        }
        print("🏥 HEALTHCARE MODEL MONITORING SETUP COMPLETE")
        print("=" * 75)
        
        # Component status
        components = [
            ("Foundation", setup_results.get("foundation", {}).get("overall_health", "Unknown")),
            ("Drift Detection", "✅ Operational" if setup_results.get("drift_detection", {}).get("success") else "❌ Failed"),
            ("Alert System", "✅ Operational" if setup_results.get("alerts", {}).get("success") else "❌ Failed"),
            ("Performance Dashboard", "✅ Operational" if setup_results.get("dashboard", {}).get("success") else "❌ Failed"),
            ("Lakehouse Monitoring", "✅ Enabled" if setup_results.get("lakehouse_monitoring", {}).get("success") else "⚠️  Optional")
        ]
        
        print("\n📊 COMPONENT STATUS:")
        for component, status in components:
            print(f"   {component:.<25} {status}")
        
        # Monitoring assets created
        print(f"\n🗂️  MONITORING ASSETS (in {self.schema_name}):")
        if setup_results.get("drift_detection", {}).get("success"):
            print("   ✅ healthcare_drift_detection (view)")
        if setup_results.get("alerts", {}).get("success"):
            print("   ✅ healthcare_model_alerts (view)")
        if setup_results.get("dashboard", {}).get("success"):
            print("   ✅ model_performance_dashboard (view)")
        
        # Quick access queries
        print(f"\n🔍 QUICK ACCESS QUERIES:")
        print(f"   • Check recent drift: SELECT * FROM {self.schema_name}.healthcare_drift_detection LIMIT 5")
        print(f"   • View active alerts: SELECT * FROM {self.schema_name}.healthcare_model_alerts")
        print(f"   • Executive summary: SELECT * FROM {self.schema_name}.model_performance_dashboard")
        
        # Next steps
        print(f"\n🎯 RECOMMENDED NEXT STEPS:")
        if setup_results.get("foundation", {}).get("recent_predictions", False):
            print("   1. ✅ You have recent prediction data - monitoring is active")
        else:
            print("   1. 🔄 Run batch inference to generate fresh monitoring data")
        
        print("   2. 📈 Review the performance dashboard for baseline insights")
        print("   3. 🚨 Set up notification channels for critical alerts")
        print("   4. 📅 Schedule automated monitoring reports")
        print("   5. 🔄 Connect monitoring to your retraining pipeline")
        
        print("\n" + "=" * 75)


# Example usage with the unified schema approach
print("🚀 Initializing Healthcare Model Monitoring System")
print("   Using unified juan_dev.ml schema for all ML assets")

# Create the monitor with unified schema
monitor = HealthcareModelMonitor(
    model_name="juan_dev.ml.healthcare_insurance_model",
    baseline_table="juan_dev.ml.insurance_silver",      # Your training data
    monitoring_table="juan_dev.ml.insurance_predictions" # Your prediction results
)

# Step 1: Run diagnostics first (recommended). This only reads catalog
# metadata and prints suggested fixes; it does not create anything.
print("Running diagnostics to check system readiness...")
diagnostic_results = monitor.diagnose_and_fix_setup_issues()

# Step 2: If diagnostics show you're ready, run the full setup.
# (monitor.quick_monitoring_setup() is a simpler alternative if this fails.)
setup_results = None
if diagnostic_results["ready_for_setup"]:
    print("System ready! Running monitoring setup...")
    setup_results = monitor.setup_complete_monitoring_system()
else:
    print("Please address the diagnostic issues first, then run setup")

# Quick verification queries you can run (printed, not executed)
monitoring_schema = monitor.schema_name
print("\n🔧 VERIFICATION COMMANDS:")
print("Run these to verify your monitoring system:")
print()
print("# Check drift detection")
print(f"display(spark.sql('SELECT * FROM {monitoring_schema}.healthcare_drift_detection ORDER BY prediction_date DESC LIMIT 5'))")
print()
print("# Check for alerts")
print(f"display(spark.sql('SELECT * FROM {monitoring_schema}.healthcare_model_alerts'))")
print()
print("# View executive dashboard")
print(f"display(spark.sql('SELECT * FROM {monitoring_schema}.model_performance_dashboard'))")
print()
print("# List all monitoring views")
# SHOW VIEWS ... LIKE takes one pattern: '*' is a wildcard and '|' separates alternatives
# (chaining "OR LIKE" is not valid syntax).
print(f"display(spark.sql('SHOW VIEWS IN {monitoring_schema} LIKE \"*healthcare*|*monitoring*|*dashboard*\"'))")