In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pyspark.sql.functions import *
import mlflow

In [None]:
# Exploratory data analysis for the insurance cost dataset, tracked as one
# MLflow run: summary metrics, a risk-factor table, two figures and a
# high-risk cohort summary are all logged to the experiment below.
EDA_EXPERIMENT_PATH = "/Users/juan.lamadrid@databricks.com/experiments/insurance_cost_prediction_eda"
SILVER_TABLE = "juan_dev.ml.insurance_silver"

mlflow.set_experiment(EDA_EXPERIMENT_PATH)

with mlflow.start_run(run_name="healthcare_insurance_eda"):
    # Load cleaned data.
    # NOTE(review): `spark` is not defined in this file; presumably it is the
    # SparkSession injected by the notebook runtime -- confirm.
    df = spark.table(SILVER_TABLE).toPandas()

    # Per-column share of missing values. Kept as its own Series because
    # MLflow metrics must be scalars; the full breakdown goes to an artifact.
    missing_pct_by_column = df.isnull().mean() * 100

    # Healthcare-specific data profiling (every value is a plain float so the
    # whole dict can be passed to mlflow.log_metrics).
    # NOTE(review): the smoker share assumes `smoker` is boolean or 0/1 -- confirm
    # against the silver table schema.
    eda_results = {
        "total_patients": float(len(df)),
        "avg_age": float(df['age'].mean()),
        "smoker_percentage": float(df['smoker'].mean() * 100),
        "high_cost_threshold": float(df['charges'].quantile(0.95)),
        # Every column has the same row count, so the mean of the per-column
        # percentages is the overall percentage of missing cells.
        "missing_data_percentage": float(missing_pct_by_column.mean()),
    }

    # Log profiling metrics plus the per-column missing-data breakdown.
    mlflow.log_metrics(eda_results)
    mlflow.log_dict(
        missing_pct_by_column.round(2).to_dict(),
        "missing_data_percentage_by_column.json",
    )

    # Risk factor analysis: cost and BMI statistics per smoker / age-group cell.
    risk_analysis = df.groupby(['smoker', 'age_group']).agg({
        'charges': ['mean', 'median', 'std'],
        'bmi': 'mean'
    }).round(2)
    # Persist the table with the run so it is not lost when the session ends.
    mlflow.log_text(risk_analysis.to_csv(), "risk_analysis_by_smoker_age_group.csv")

    # Cost distribution by region and smoking status.
    # Figures are intentionally left open so the notebook still renders them
    # inline at the end of the cell.
    cost_fig, cost_ax = plt.subplots(figsize=(12, 8))
    sns.boxplot(data=df, x='region', y='charges', hue='smoker', ax=cost_ax)
    cost_ax.set_title('Healthcare Costs by Region and Smoking Status')
    cost_ax.tick_params(axis='x', labelrotation=45)
    mlflow.log_figure(cost_fig, "cost_distribution_by_region_smoking.png")

    # Feature correlation analysis over the numeric columns.
    correlation_matrix = df[['age', 'bmi', 'children', 'charges']].corr()
    corr_fig, corr_ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', ax=corr_ax)
    corr_ax.set_title('Healthcare Feature Correlations')
    mlflow.log_figure(corr_fig, "feature_correlations.png")

    # Healthcare risk insights: smokers with BMI above 30 and age above 50.
    # `.eq(True)` is the elementwise form of `== True`.
    high_risk_patients = df[
        df['smoker'].eq(True)
        & (df['bmi'] > 30)
        & (df['age'] > 50)
    ]

    high_risk_metrics = {"high_risk_patients_count": len(high_risk_patients)}
    if not high_risk_patients.empty:
        # The mean of an empty selection is NaN; only log the average cost
        # when there is at least one high-risk patient.
        high_risk_metrics["high_risk_avg_cost"] = float(
            high_risk_patients['charges'].mean()
        )
    mlflow.log_metrics(high_risk_metrics)
