In [0]:
from databricks.feature_engineering import FeatureEngineeringClient
import mlflow.pyfunc
from pyspark.sql.functions import *
from pyspark.sql.types import IntegerType, FloatType, DoubleType

class HealthcareBatchInferenceV2:
    """
    Batch inference runner for the healthcare insurance cost model.

    Resolves a registered model by alias, scores a Spark table through
    ``FeatureEngineeringClient.score_batch``, derives business columns from the
    raw ``prediction`` column, optionally saves the result as a table and logs
    summary metrics to MLflow.

    Relies on a ``spark`` session being available as a global name (as in a
    Databricks notebook) and on the module-level
    ``from pyspark.sql.functions import *``.
    """

    # Columns that must be present in the input table before scoring.
    REQUIRED_BASE_COLUMNS = ('customer_id', 'sex', 'region', 'smoker', 'age', 'bmi', 'children')

    def __init__(self, model_name="juan_dev.ml.healthcare_insurance_model", model_alias="champion"):
        """
        Store the model coordinates and tune the Spark session for batch work.

        Args:
            model_name: Registered model name used to build the ``models:/`` URI.
            model_alias: Alias of the model version to load.

        Note:
            Sets three session-level Spark configuration values as a side effect.
        """
        self.model_name = model_name
        self.model_alias = model_alias
        self.fe = FeatureEngineeringClient()

        # Spark optimization for batch processing
        spark.conf.set("spark.sql.adaptive.enabled", "true")
        spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
        spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")

    @staticmethod
    def _percentage(part, total):
        """Return ``part`` as a percentage of ``total``; 0.0 when ``total`` is zero."""
        if not total:
            return 0.0
        return (part / total) * 100

    @staticmethod
    def _build_metrics(inference_count, avg_prediction, high_risk_count,
                       requires_review_count, model_version, risk_distribution):
        """
        Assemble the metric dictionary that is sent to ``mlflow.log_metrics``.

        Args:
            inference_count: Number of scored rows.
            avg_prediction: Mean adjusted prediction, or ``None`` for an empty batch.
            high_risk_count: Rows flagged ``high_risk_patient``.
            requires_review_count: Rows flagged ``requires_review``.
            model_version: Model version (anything ``float()`` accepts).
            risk_distribution: Mapping of risk category name to row count.

        Returns:
            dict mapping metric name to numeric value. ``average_predicted_cost``
            is omitted when ``avg_prediction`` is ``None`` because MLflow metrics
            must be numeric.
        """
        metrics = {
            "batch_inference_count": inference_count,
            "high_risk_patient_count": high_risk_count,
            "requires_review_count": requires_review_count,
            "high_risk_percentage": HealthcareBatchInferenceV2._percentage(
                high_risk_count, inference_count
            ),
            "model_version": float(model_version),
        }
        if avg_prediction is not None:
            metrics["average_predicted_cost"] = avg_prediction
        for category, row_count in risk_distribution.items():
            metrics[f"risk_{category}_count"] = row_count
        return metrics

    def run_batch_inference(self, input_table, output_table=None):
        """
        Execute batch inference on ``input_table`` and return the scored DataFrame.

        Args:
            input_table: Name of the Spark table to score. Must contain every
                column listed in ``REQUIRED_BASE_COLUMNS``.
            output_table: Optional table name. When given, the result is written
                with mode ``overwrite`` (schema overwrite enabled).

        Returns:
            Spark DataFrame holding the ``score_batch`` output plus metadata
            columns (``prediction_timestamp``, ``model_version``, ``model_name``)
            and business columns (``adjusted_prediction``, ``cost_risk_category``,
            ``prediction_lower_bound``, ``prediction_upper_bound``,
            ``high_risk_patient``, ``requires_review``).

        Raises:
            ValueError: If required columns are missing from ``input_table``.
            Exception: Any other failure is reported with troubleshooting hints
                and re-raised unchanged.
        """

        try:
            # Load model from Unity Catalog
            model_uri = f"models:/{self.model_name}@{self.model_alias}"

            # Get model version info
            client = mlflow.MlflowClient()
            model_version_info = client.get_model_version_by_alias(self.model_name, self.model_alias)
            model_version = model_version_info.version

            print(f"Loading model version {model_version} from {model_uri}")

            # Load input data
            input_df = spark.table(input_table)
            print(f"Input data shape: {input_df.count()} rows, {len(input_df.columns)} columns")
            print(f"Input columns: {input_df.columns}")

            # Validate required columns are present.
            # The loop variable is deliberately not called `col`, which is the
            # star-imported pyspark column function used further down.
            missing_columns = [
                name for name in self.REQUIRED_BASE_COLUMNS
                if name not in input_df.columns
            ]
            if missing_columns:
                raise ValueError(f"Input data is missing required columns: {missing_columns}")

            # Ensure proper data types for the base columns
            # The model pipeline will handle the encoding internally
            input_df_prepared = (
                input_df
                .withColumn("sex", upper(col("sex")).cast("string"))  # Ensure consistent case
                .withColumn("region", upper(col("region")).cast("string"))
                # Anything other than YES/TRUE/1 (case-insensitive), including NULL, becomes False
                .withColumn("smoker",
                           when(upper(col("smoker")).isin("YES", "TRUE", "1"), True)
                           .otherwise(False).cast("boolean"))
                .withColumn("age", col("age").cast("integer"))
                .withColumn("bmi", col("bmi").cast("double"))
                .withColumn("children", col("children").cast("integer"))
                .withColumn("customer_id", col("customer_id").cast("long"))
            )

            print("Data preparation completed successfully")
            print(f"Prepared data columns: {input_df_prepared.columns}")

            # Batch scoring with feature engineering integration
            # The model pipeline is expected to handle:
            # 1. Feature lookup from juan_dev.ml.healthcare_features
            # 2. Categorical encoding (sex, region)
            # 3. Numerical scaling
            # 4. Model prediction
            print("Starting batch scoring with feature engineering integration...")

            predictions_df = self.fe.score_batch(
                df=input_df_prepared,
                model_uri=model_uri
            )

            print("Batch scoring completed successfully!")
            print(f"Predictions shape: {predictions_df.count()} rows")

            # Add business logic and metadata
            final_predictions = (
                predictions_df
                .withColumn("prediction_timestamp", current_timestamp())
                .withColumn("model_version", lit(model_version))
                .withColumn("model_name", lit(self.model_name))

                # Business rule: minimum charge threshold
                .withColumn("adjusted_prediction",
                           expr("GREATEST(prediction, 500)"))

                # Risk categorization for business use
                .withColumn("cost_risk_category",
                           expr("CASE WHEN adjusted_prediction < 2000 THEN 'low' " +
                                "WHEN adjusted_prediction < 8000 THEN 'medium' " +
                                "WHEN adjusted_prediction < 20000 THEN 'high' " +
                                "ELSE 'very_high' END"))

                # Confidence intervals (approximate business rules)
                .withColumn("prediction_lower_bound",
                           expr("adjusted_prediction * 0.85"))
                .withColumn("prediction_upper_bound",
                           expr("adjusted_prediction * 1.15"))

                # Add risk flags for business decision making
                .withColumn("high_risk_patient",
                           expr("adjusted_prediction > 15000 OR cost_risk_category = 'very_high'"))
                .withColumn("requires_review",
                           expr("adjusted_prediction > 25000 OR (smoker AND adjusted_prediction > 10000)"))
            )

            # Display results for inspection
            print("Sample predictions:")
            final_predictions.select(
                "customer_id", "sex", "region", "smoker",
                "prediction", "adjusted_prediction", "cost_risk_category"
            ).show(10)

            # Save results if output table specified
            if output_table:
                print(f"Saving results to {output_table}...")
                (final_predictions
                 .write
                 .mode("overwrite")
                 .option("overwriteSchema", "true")
                 .saveAsTable(output_table))
                print("Results saved successfully!")

            # Log batch inference metrics for monitoring
            with mlflow.start_run(run_name="batch_inference"):
                # One aggregation instead of four separate actions: every action
                # re-evaluates the scoring plan, so fewer actions means fewer
                # full passes over the batch.
                summary = final_predictions.agg(
                    count(lit(1)).alias("inference_count"),
                    avg("adjusted_prediction").alias("avg_prediction"),
                    # when() without otherwise() yields NULL for non-matching
                    # rows, and count() skips NULLs.
                    count(when(col("high_risk_patient"), 1)).alias("high_risk_count"),
                    count(when(col("requires_review"), 1)).alias("requires_review_count"),
                ).collect()[0]

                inference_count = summary["inference_count"]
                avg_prediction = summary["avg_prediction"]  # None when no rows were scored
                high_risk_count = summary["high_risk_count"]
                requires_review_count = summary["requires_review_count"]

                # Risk category distribution
                risk_distribution = final_predictions.groupBy("cost_risk_category").count().collect()
                risk_dist_dict = {row.cost_risk_category: row['count'] for row in risk_distribution}

                mlflow.log_metrics(self._build_metrics(
                    inference_count,
                    avg_prediction,
                    high_risk_count,
                    requires_review_count,
                    model_version,
                    risk_dist_dict,
                ))

                high_risk_pct = self._percentage(high_risk_count, inference_count)
                review_pct = self._percentage(requires_review_count, inference_count)

                if avg_prediction is None:
                    print("Logged metrics - Average predicted cost: n/a (no rows scored)")
                else:
                    print(f"Logged metrics - Average predicted cost: ${avg_prediction:,.2f}")
                print(f"High risk patients: {high_risk_count} ({high_risk_pct:.1f}%)")
                print(f"Require review: {requires_review_count} ({review_pct:.1f}%)")

            return final_predictions

        except Exception as e:
            print(f"Error during batch inference: {str(e)}")
            print("Troubleshooting steps:")
            print(f"1. Check that the model exists and has the '{self.model_alias}' alias")
            print("2. Verify input data contains all required columns")
            print("3. Ensure feature table juan_dev.ml.healthcare_features is accessible")
            print("4. Check Unity Catalog permissions")
            # Bare raise keeps the original traceback intact.
            raise

# Driver cell: build the pipeline, score the silver table, then print a short
# report. Any failure is reported on stdout rather than propagated.
print("Initializing improved batch inference pipeline...")

batch_inference = HealthcareBatchInferenceV2()

print("Running batch inference on insurance data...")
try:
    # Positional arguments: (input_table, output_table)
    results = batch_inference.run_batch_inference(
        "juan_dev.ml.insurance_silver",
        "juan_dev.ml.cost_predictions",
    )

    # Headline numbers for the run
    print("\n=== Batch Inference Summary ===")
    print(f"Successfully processed {results.count()} records")
    print("Results saved to juan_dev.ml.cost_predictions")

    # Preview of the business-facing columns
    print("\n=== Sample Predictions ===")
    results.select([
        "customer_id",
        "adjusted_prediction",
        "cost_risk_category",
        "high_risk_patient",
        "requires_review",
    ]).show(5)

except Exception as exc:
    print(f"Batch inference failed: {exc}")
    print("Please check the error message above and follow troubleshooting steps")