In [None]:
from databricks.feature_engineering import FeatureEngineeringClient
import mlflow.pyfunc
from pyspark.sql.functions import *
from pyspark.sql.types import IntegerType, FloatType, DoubleType

class HealthcareBatchInference:
    """
    Batch inference class aligned with feature engineering and training.
    Uses dim_patients table and customer_id key mapping.

    Scores current patient records with a Unity Catalog model through the
    Databricks ``FeatureEngineeringClient``, layers business rules on top of
    the raw ``prediction`` column, optionally persists the result, and logs
    summary metrics to a new MLflow run.

    Relies on the notebook-provided global ``spark`` session.
    """

    # Every value the risk_category CASE expression can produce. Its ELSE
    # branch guarantees no row is left uncategorised.
    RISK_CATEGORIES = ("low", "medium", "high", "critical")

    def __init__(self, model_name="juan_dev.healthcare_data.insurance_model", model_alias="champion"):
        """
        Args:
            model_name: Three-level Unity Catalog model name.
            model_alias: Registry alias to score with; it is resolved to a
                concrete version at the start of each run.
        """
        self.model_name = model_name
        self.model_alias = model_alias
        self.fe = FeatureEngineeringClient()

        # Spark optimization for batch processing. These are set on the shared
        # session, so they also apply to anything else run in this notebook.
        spark.conf.set("spark.sql.adaptive.enabled", "true")
        spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
        spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")

    def run_batch_inference(self, input_table="juan_dev.healthcare_data.dim_patients", output_table=None):
        """Execute batch inference using feature engineering integration.

        Args:
            input_table: Patient table; must contain ``is_current_record`` and
                ``patient_natural_key``.
            output_table: If given, scored rows are written there (mode
                ``overwrite`` with ``overwriteSchema``).

        Returns:
            DataFrame of model output plus business-rule columns. When
            ``output_table`` is set this reads the persisted table, so further
            actions by the caller do not re-run the model.

        Raises:
            Any exception from model resolution, scoring, writing or metric
            logging, re-raised after troubleshooting hints are printed.
        """
        try:
            # Resolve the alias once and score with the pinned version URI.
            # The model_version written to the output and to MLflow is then
            # exactly the version that produced the predictions, even if the
            # alias is moved while this job runs.
            model_version = self._resolve_model_version()
            model_uri = f"models:/{self.model_name}/{model_version}"

            print(f"Loading model version {model_version} from {model_uri}")

            input_df_prepared = self._load_input(input_table)

            # Batch scoring with feature engineering integration
            # The fe.score_batch will automatically:
            # 1. Join with feature table (ml_insurance_features) using customer_id
            # 2. Get all required features (age_risk_score, smoking_impact, etc.)
            # 3. Apply the model's preprocessing pipeline
            # 4. Generate predictions
            print("Starting batch scoring with feature engineering integration...")

            predictions_df = self.fe.score_batch(
                df=input_df_prepared,
                model_uri=model_uri
            )

            print("Batch scoring completed successfully!")
            # No row count here. Every Spark action on this DataFrame re-runs
            # the whole scoring plan; the row count is reported with the
            # metrics below instead.
            print(f"Prediction columns: {predictions_df.columns}")

            final_predictions = self._apply_business_rules(predictions_df, model_version)

            # Save results if output table specified
            if output_table:
                print(f"Saving results to {output_table}...")
                (final_predictions
                 .write
                 .mode("overwrite")
                 .option("overwriteSchema", "true")
                 .saveAsTable(output_table))
                print("Results saved successfully!")
                # Read back the persisted rows. Display, metrics and the caller
                # then see exactly what was saved (including its
                # prediction_timestamp), and the model is not run again.
                final_predictions = spark.table(output_table)

            # Display results for inspection
            print("Sample predictions:")
            final_predictions.select(
                "customer_id", "prediction", "adjusted_prediction", "risk_category",
                "high_risk_patient", "requires_review"
            ).show(10)

            self._log_inference_metrics(final_predictions, model_version)

            return final_predictions

        except Exception as e:
            print(f"Error during batch inference: {str(e)}")
            print("Troubleshooting steps:")
            print(f"1. Check that the model exists and has the '{self.model_alias}' alias")
            print("2. Verify input data contains patient_natural_key column")
            print("3. Ensure feature table ml_insurance_features is accessible")
            print("4. Check Unity Catalog permissions")
            print("5. Verify customer_id mapping from patient_natural_key")
            import traceback
            traceback.print_exc()
            # Bare raise re-raises the active exception with its original traceback.
            raise

    def _resolve_model_version(self):
        """Return the registry version currently pointed to by ``self.model_alias``."""
        client = mlflow.MlflowClient()
        return client.get_model_version_by_alias(self.model_name, self.model_alias).version

    @staticmethod
    def _load_input(input_table):
        """Load current records from ``input_table`` and add the ``customer_id`` key."""
        # Load input data from dim_patients - filter for current records (same as training)
        input_df = spark.table(input_table).filter(col("is_current_record") == True)
        print(f"Input data shape: {input_df.count()} rows, {len(input_df.columns)} columns")

        # customer_id is the lookup key used for the feature join; it is
        # copied from patient_natural_key. The feature engineering client
        # will automatically join features from the feature table.
        input_df_prepared = input_df.withColumn("customer_id", col("patient_natural_key"))

        print("Data preparation completed successfully")
        print(f"Customer ID sample: {[row.customer_id for row in input_df_prepared.select('customer_id').take(3)]}")
        return input_df_prepared

    def _apply_business_rules(self, predictions_df, model_version):
        """Add run metadata, the adjusted score, risk category/flags and a ±10% band."""
        return (
            predictions_df
            .withColumn("prediction_timestamp", current_timestamp())
            .withColumn("model_version", lit(model_version))
            .withColumn("model_name", lit(self.model_name))

            # Business rule: minimum risk score of 10.
            # NOTE(review): Spark's GREATEST skips NULLs, so a NULL prediction
            # becomes 10 and is categorised 'low' below. Confirm this is intended.
            .withColumn("adjusted_prediction",
                        expr("GREATEST(prediction, 10)"))

            # Risk categorization for business use (adjusted for health risk scores 0-100)
            .withColumn("risk_category",
                        expr("CASE WHEN adjusted_prediction < 30 THEN 'low' "
                             "WHEN adjusted_prediction < 60 THEN 'medium' "
                             "WHEN adjusted_prediction < 85 THEN 'high' "
                             "ELSE 'critical' END"))

            # Fixed ±10% band around adjusted_prediction (a business rule,
            # not a statistical confidence interval).
            .withColumn("prediction_lower_bound",
                        expr("adjusted_prediction * 0.90"))
            .withColumn("prediction_upper_bound",
                        expr("adjusted_prediction * 1.10"))

            # Add risk flags for business decision making
            .withColumn("high_risk_patient",
                        expr("adjusted_prediction > 75 OR risk_category = 'critical'"))
            .withColumn("requires_review",
                        expr("adjusted_prediction > 90"))
        )

    def _log_inference_metrics(self, predictions, model_version):
        """Compute batch statistics in a single pass and log them to a new MLflow run.

        All risk categories are logged, with zero counts for categories that
        do not occur, so every run exposes the same metric keys.
        """
        # One aggregation job instead of one Spark action per metric.
        category_exprs = [
            f"COUNT_IF(risk_category = '{category}') AS risk_{category}_count"
            for category in self.RISK_CATEGORIES
        ]
        stats = predictions.selectExpr(
            "COUNT(*) AS inference_count",
            "AVG(adjusted_prediction) AS avg_prediction",
            "COUNT_IF(high_risk_patient) AS high_risk_count",
            "COUNT_IF(requires_review) AS requires_review_count",
            *category_exprs,
        ).first()

        inference_count = stats["inference_count"]
        avg_prediction = stats["avg_prediction"]
        high_risk_count = stats["high_risk_count"]
        requires_review_count = stats["requires_review_count"]
        risk_counts = {c: stats[f"risk_{c}_count"] for c in self.RISK_CATEGORIES}

        metrics = self._build_metrics(
            inference_count=inference_count,
            avg_prediction=avg_prediction,
            high_risk_count=high_risk_count,
            requires_review_count=requires_review_count,
            risk_counts=risk_counts,
            model_version=model_version,
        )

        # Log batch inference metrics for monitoring
        with mlflow.start_run(run_name="batch_inference_health_risk"):
            mlflow.log_metrics(metrics)

        if not inference_count:
            print("Logged metrics - no rows were scored")
            return
        print(f"Logged metrics - Average predicted risk score: {float(avg_prediction):.2f}")
        print(f"High risk patients: {high_risk_count} ({self._percentage(high_risk_count, inference_count):.1f}%)")
        print(f"Require review: {requires_review_count} ({self._percentage(requires_review_count, inference_count):.1f}%)")

    @staticmethod
    def _percentage(part, whole):
        """Return ``part`` as a percentage of ``whole``, or 0.0 when ``whole`` is 0."""
        return part / whole * 100 if whole else 0.0

    @classmethod
    def _build_metrics(cls, inference_count, avg_prediction, high_risk_count,
                       requires_review_count, risk_counts, model_version):
        """Assemble the MLflow metric dict from pre-computed batch statistics.

        Args:
            inference_count: Number of scored rows.
            avg_prediction: Mean ``adjusted_prediction``; ``None`` when no rows
                were scored.
            high_risk_count: Rows whose ``high_risk_patient`` flag is true.
            requires_review_count: Rows whose ``requires_review`` flag is true.
            risk_counts: Mapping of risk category -> row count.
            model_version: Registry version; must be convertible with ``float()``.

        Returns:
            dict of metric name -> number, suitable for ``mlflow.log_metrics``.
        """
        metrics = {
            "batch_inference_count": inference_count,
            "high_risk_patient_count": high_risk_count,
            "requires_review_count": requires_review_count,
            # Guarded so an empty batch does not divide by zero.
            "high_risk_percentage": cls._percentage(high_risk_count, inference_count),
            "model_version": float(model_version),
            **{f"risk_{category}_count": n for category, n in risk_counts.items()},
        }
        # AVG over zero rows is NULL and MLflow metrics must be numeric, so it
        # is omitted in that case. float() also normalises a Decimal average.
        if avg_prediction is not None:
            metrics["average_predicted_risk_score"] = float(avg_prediction)
        return metrics

# Example usage: score the current dim_patients records, persist them, and
# print a short summary. Relies on the notebook-provided global `spark`.
print("Initializing batch inference pipeline...")

batch_inference = HealthcareBatchInference()

print("Running batch inference on healthcare data using dim_patients table...")
try:
    results = batch_inference.run_batch_inference(
        input_table="juan_dev.healthcare_data.dim_patients",
        output_table="juan_dev.healthcare_data.ml_patient_predictions",
    )

    print("\n=== Batch Inference Summary ===")
    print(f"Successfully processed {results.count()} records")
    print("Results saved to predictions table")

    # Show sample results
    print("\n=== Sample Predictions ===")
    results.select(
        "customer_id",
        "adjusted_prediction",
        "risk_category",
        "high_risk_patient",
        "requires_review",
    ).show(5)

    # Show risk distribution
    print("\n=== Risk Category Distribution ===")
    results.groupBy("risk_category").count().orderBy("risk_category").show()

except Exception as e:
    print(f"Batch inference failed: {e}")
    print("Please check the error message above and follow troubleshooting steps")
    # Re-raise so the cell fails. Swallowing the exception would let a
    # scheduled notebook/job run be reported as successful even though no
    # predictions were produced.
    raise