# Healthcare Payer Member Churn Prediction
## End-to-End MLflow Demo with Databricks Feature Engineering

### Overview
This notebook demonstrates an end-to-end machine learning workflow for predicting member churn in the healthcare payer industry. Member churn occurs when members disenroll from a health insurance plan, which is costly for payers and disrupts continuity of care.

**Business Problem**: Healthcare payers need to identify members at risk of disenrollment to implement targeted retention strategies, improve member satisfaction, and reduce acquisition costs.

**Solution**: We'll build a predictive model using claims data to identify members likely to disenroll, leveraging:
- **Databricks Feature Engineering** for scalable feature management
- **MLflow** for experiment tracking and model lifecycle management
- **Unity Catalog** for secure model governance
- **Champion/Challenger** model comparison strategies

### Key Technologies
- **Databricks Runtime**: 16.3.x-cpu-ml-scala2.12
- **MLflow**: Experiment tracking, model registry, and deployment
- **Feature Engineering Client**: Feature store management
- **Unity Catalog**: Centralized governance

### Prerequisites
⚠️ **Note**: Use classic compute assigned to a single user or group, running Databricks Runtime **16.3.x-cpu-ml-scala2.12** or later.

---
# 1. Data Preparation and Loading

In this section, we'll load healthcare claims data that will be used to predict member disenrollment. The data contains detailed claims information including:
- **Member demographics**: age, date of birth
- **Claims activity**: service dates, claim types, amounts
- **Provider information**: provider IDs, service locations
- **Claim status**: approved, denied, paid amounts

In [None]:
%sql
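-- Load raw claims data from a Unity Catalog volume
-- CSV files dropped into the volume are read with READ_FILES to create the claims table
-- This pattern enables easy data ingestion from various payers
-- Note: CREATE TABLE IF NOT EXISTS loads the files once; files added later are not picked up
-- Assumption: the CSV files have a header row with column names such as member_id

CREATE TABLE IF NOT EXISTS demo.hls.PAYER_DETAILED_CLAIMS
AS SELECT *
FROM
  READ_FILES('/Volumes/demo/hls/payer_detailed_claims/*.csv', format => 'csv', header => true)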

In [None]:
# Install required libraries
%pip install databricks-feature-engineering

# Restart Python to load the newly installed packages
dbutils.library.restartPython()

In [None]:
# Import required libraries
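# Note: this wildcard import shadows Python built-ins such as sum, min, and max with their Spark equivalents
from pyspark.sql.functions import *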
from pyspark.sql.window import Window
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    brier_score_loss
)
import mlflow

# Load claims data from Delta table
df = spark.table("demo.hls.PAYER_DETAILED_CLAIMS")

print(f"Total records loaded: {df.count():,}")
print(f"Total unique members: {df.select('member_id').distinct().count():,}")

---
# 2. Feature Engineering

Feature engineering is critical for predicting member churn. We'll create features that capture:

**Temporal Patterns**:
- Days between consecutive claims (engagement frequency)
- Days since last claim (recency indicator)
- Number of months with claims (tenure activity)

**Utilization Metrics**:
- Total number of claims
- Claim type distribution (Inpatient, Outpatient, Pharmacy, Professional)
- Number of distinct providers used

**Financial Indicators**:
- Total and average charge amounts
- Total and average paid amounts
- Denial rate (percentage of denied claims)
- Allowed-to-paid ratio

**Member Demographics**:
- Age calculated from date of birth

These features help identify patterns that distinguish members who are likely to disenroll from those who remain active.
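
To make the feature definitions above concrete, here is a minimal pure-Python sketch that computes a few of them for one member. The claim tuples, the `member_features` helper, and the `as_of` date are hypothetical and for illustration only; the notebook itself computes these features with Spark aggregations.

```python
from datetime import date
from statistics import mean

# Hypothetical toy claims for a single member: (service_date, status, paid_amount).
claims = [
    (date(2024, 1, 10), "Paid", 120.0),
    (date(2024, 1, 25), "Denied", 0.0),
    (date(2024, 3, 5), "Paid", 80.0),
    (date(2024, 3, 20), "Paid", 200.0),
]

def member_features(claims, as_of):
    """Compute a few member-level features from a list of claim tuples."""
    dates = sorted(d for d, _, _ in claims)
    # Gaps between consecutive claims; empty when the member has a single claim.
    gaps = [(b - a).days for a, b in zip(dates, dates[1:])]
    return {
        "num_claims_total": len(claims),
        "num_months_with_claims": len({(d.year, d.month) for d in dates}),
        "pct_denied_claims": sum(s == "Denied" for _, s, _ in claims) / len(claims),
        "avg_paid_amount": mean(p for _, _, p in claims),
        "avg_days_between_claims": mean(gaps) if gaps else None,
        "days_since_last_claim": (as_of - dates[-1]).days,
    }

features = member_features(claims, as_of=date(2024, 6, 30))
print(features)
```

A member with a single claim has no gaps, so `avg_days_between_claims` is `None` in this sketch. The Spark version yields a null in the same situation, which a downstream model has to handle.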

In [None]:
## 2.1 Date Preprocessing
# Convert string dates to proper date types
cols_to_cast = ["service_date", "dob"]
for col_name in cols_to_cast:
    df = df.withColumn(col_name, to_date(col_name))

## 2.2 Temporal Features - Calculate days between consecutive claims
# Use window functions to calculate claim frequency patterns
w = Window.partitionBy("member_id").orderBy("service_date")
df_with_lag = df.withColumn("prev_date", lag("service_date").over(w)) \
                .withColumn("gap_days", datediff(col("service_date"), col("prev_date")))

## 2.3 Claim Type Distribution - Pivot claim types for each member
claim_type_pivot = df.groupBy("member_id").pivot("claim_type").count().na.fill(0)

## 2.4 Aggregate Features at Member Level
features_df = df.groupBy("member_id").agg(
    # Utilization metrics
    count("*").alias("num_claims_total"),
    countDistinct("provider_id").alias("num_distinct_providers"),
    countDistinct(date_format(col("service_date"), "yyyy-MM")).alias("num_months_with_claims"),
    
    # Quality/Status metrics
    (sum(when(col("status") == "Denied", 1).otherwise(0)) / count("*")).alias("pct_denied_claims"),
    
    # Financial metrics
    sum("charge_amount").alias("total_charge_amount"),
    sum("paid_amount").alias("total_paid_amount"),
    avg("allowed_amount").alias("avg_allowed_amount"),
    avg("paid_amount").alias("avg_paid_amount"),
    
    # Temporal markers
    max("service_date").alias("last_claim_date"),
    min("service_date").alias("first_claim_date"),
    
    # Demographics
    max(months_between(current_date(), col("dob")) / 12).alias("age")
)

## 2.5 Add temporal gap features
avg_gap = df_with_lag.groupBy("member_id").agg(avg("gap_days").alias("avg_days_between_claims"))

## 2.6 Combine all features
features_df = features_df.join(avg_gap, on="member_id", how="left")
features_df = features_df.join(claim_type_pivot, on="member_id", how="left")

# Calculate derived ratio (left null when the average paid amount is zero)
features_df = features_df.withColumn(
    "avg_allowed_paid_ratio",
    when(col("avg_paid_amount") != 0, col("avg_allowed_amount") / col("avg_paid_amount"))
)

# Add recency metric
features_df = features_df.withColumn("days_since_last_claim", datediff(current_date(), col("last_claim_date")))

# Add versioning timestamp for feature store
features_df = features_df.withColumn("version_ts", current_timestamp())

## 2.7 Create Target Variable (Label)
# Proxy label: members whose most recent claim is more than 180 days old are treated as disenrolled
label_features_df = features_df.withColumn(
    "disenrolled", 
    when(col("days_since_last_claim") > 180, 1).otherwise(0)
)

# Remove date columns that won't be used in modeling
label_features_df = label_features_df.drop("last_claim_date", "first_claim_date", "days_since_last_claim")

total_members = label_features_df.count()
disenrolled_members = label_features_df.filter(col("disenrolled") == 1).count()
print(f"Features created for {total_members:,} members")
print(f"Disenrollment rate: {disenrolled_members / total_members:.2%}")


---
# 3. Feature Store Management with Databricks

The **Databricks Feature Engineering** client provides a centralized feature store that:
- **Tracks feature lineage**: Automatically links features to models
- **Enables feature reuse**: Share features across multiple models and teams
- **Supports point-in-time lookups**: Ensures training/serving consistency
- **Manages versioning**: Track feature evolution over time

We'll create two tables:
1. **Label table**: Contains labels (disenrolled) along with features for model training
2. **Feature table**: Contains only features for inference and feature sharing
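
The `disenrolled` label stored in the label table is a proxy derived from claim recency: a member with no claims for more than 180 days is treated as disenrolled. Below is a minimal pure-Python sketch of that rule. The function name and the `as_of` date are hypothetical; only the 180-day threshold comes from this notebook.

```python
from datetime import date

INACTIVITY_THRESHOLD_DAYS = 180  # threshold used by the labeling step in this notebook

def disenrolled_label(last_claim_date, as_of):
    """Return 1 if the last claim is more than the threshold days before as_of, else 0."""
    days_since_last_claim = (as_of - last_claim_date).days
    return 1 if days_since_last_claim > INACTIVITY_THRESHOLD_DAYS else 0

as_of = date(2025, 1, 1)
print(disenrolled_label(date(2024, 7, 5), as_of))  # exactly 180 days ago -> 0
print(disenrolled_label(date(2024, 7, 4), as_of))  # 181 days ago -> 1
```

Because the comparison is strict, a member whose last claim is exactly 180 days old is still counted as active.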

In [None]:
%sql
-- Clean up existing table if needed (for demo purposes)
DROP TABLE IF EXISTS demo.hls.label_features_versioned

In [None]:
# Save label + features table as a Delta table
# This table contains both features and the target variable for training
label_features_df.write.format("delta") \
    .mode("overwrite") \
    .option("mergeSchema", "true") \
    .saveAsTable("demo.hls.label_features_versioned")

print("✓ Label features table saved successfully")

In [None]:
%sql
-- Define a primary key on the label features table
-- NOT NULL is enforced; the primary key itself is informational in Unity Catalog and does not reject duplicates

ALTER TABLE demo.hls.label_features_versioned DROP CONSTRAINT IF EXISTS member_labels_versioned_pk;
ALTER TABLE demo.hls.label_features_versioned ALTER COLUMN member_id SET NOT NULL;
ALTER TABLE demo.hls.label_features_versioned ALTER COLUMN version_ts SET NOT NULL;
ALTER TABLE demo.hls.label_features_versioned ADD CONSTRAINT member_labels_versioned_pk PRIMARY KEY(member_id, version_ts);

### Viewing Feature Tables

Feature store tables are saved as Delta tables in Unity Catalog. You can:
- Browse them in **Catalog Explorer** (left sidebar → Catalog)
- Query them using SQL
- Track their lineage and usage
- Set permissions and governance policies

In [None]:
# Preview the label features table
display(label_features_df.limit(100))

---
## 3.1 Register Features with Feature Engineering Client

Now we'll register the feature table with the Feature Engineering client. This provides:
- **Automatic lineage tracking**: Link features to models automatically
- **Point-in-time correctness**: Use `timeseries_columns` to ensure temporal consistency
- **Feature discovery**: Teams can browse and reuse features
- **Serving integration**: Features can be automatically looked up during inference
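
The idea behind point-in-time correctness can be sketched in a few lines of plain Python: for a given observation time, use the latest feature version whose timestamp is not after it. This only illustrates the concept and is not how the Feature Engineering client is implemented; `feature_history` and `point_in_time_lookup` are hypothetical names.

```python
from bisect import bisect_right
from datetime import datetime

# Hypothetical feature history for one member: (version_ts, feature values),
# sorted by version_ts, mimicking rows keyed by (member_id, version_ts).
feature_history = [
    (datetime(2024, 1, 1), {"num_claims_total": 3}),
    (datetime(2024, 4, 1), {"num_claims_total": 7}),
    (datetime(2024, 7, 1), {"num_claims_total": 9}),
]

def point_in_time_lookup(history, as_of):
    """Return the latest feature values with version_ts <= as_of, or None."""
    timestamps = [ts for ts, _ in history]
    idx = bisect_right(timestamps, as_of)  # number of versions at or before as_of
    return history[idx - 1][1] if idx > 0 else None

print(point_in_time_lookup(feature_history, datetime(2024, 5, 15)))   # uses the April version
print(point_in_time_lookup(feature_history, datetime(2023, 12, 31)))  # no version yet -> None
```

Looking up features this way keeps values computed after the observation time out of the training rows.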

In [None]:
# Initialize Feature Engineering Client
from databricks.feature_engineering import FeatureEngineeringClient
fe = FeatureEngineeringClient()

# Create feature table in Unity Catalog
# This registers the table with the feature store for lineage tracking
disenrollment_feature_table = fe.create_table(
    name="demo.hls.member_features_versioned",
    primary_keys=["member_id", "version_ts"],
    schema=features_df.schema,
    timeseries_columns="version_ts",  # Enables point-in-time lookups
    description="Member churn prediction features derived from claims data for PAYER1"
)

print("✓ Feature table registered in Unity Catalog")

In [None]:
# Write features to the feature table
# mode='merge' upserts rows by primary key; columns not yet in the table are added as new features
fe.write_table(
    name="demo.hls.member_features_versioned",
    df=features_df,  # Can also be a streaming DataFrame for real-time updates
    mode='merge'
)

print("✓ Features written to feature store")

In [None]:
%sql
-- Verify features were written correctly (showing most recent versions)
SELECT * FROM demo.hls.member_features_versioned 
ORDER BY version_ts DESC 
LIMIT 100

---
## 3.2 On-Demand Feature Functions (Optional)

**Feature Functions** let you define features that are calculated on demand at inference time instead of being pre-computed and stored. This is useful for:
- **Time-dependent features**: Like age, which changes over time
- **Complex transformations**: That depend on multiple inputs
- **Dynamic calculations**: That need to be computed at serving time

Example: Member age needs to be calculated based on the current date at inference time, not pre-computed during training.
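
The age arithmetic can be checked locally before it is wrapped in a feature function. The sketch below uses a hypothetical helper, `age_on`, that takes the reference date as an explicit `as_of` argument so the result is deterministic; a function running at inference time would use the current date instead.

```python
from datetime import date

def age_on(dob, as_of):
    """Whole years between dob and as_of; None when dob is missing."""
    if dob is None:
        return None
    # Subtract one year if the birthday has not yet occurred in the as_of year.
    had_birthday = (as_of.month, as_of.day) >= (dob.month, dob.day)
    return as_of.year - dob.year - (0 if had_birthday else 1)

dob = date(1989, 9, 14)
print(age_on(dob, date(2025, 9, 13)))  # day before the birthday -> 35
print(age_on(dob, date(2025, 9, 14)))  # on the birthday -> 36
```

Comparing `(month, day)` tuples avoids the off-by-one error that a plain year subtraction makes before the member's birthday.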

In [None]:
%sql
-- Create a user-defined function to calculate age dynamically at inference time
-- This ensures age is always current, not stale from training time

CREATE OR REPLACE FUNCTION demo.hls.current_age(dob DATE)
RETURNS INT
LANGUAGE PYTHON
COMMENT "[Feature Function] Calculate current age based on date of birth"
AS $$
from datetime import datetime
if dob is None:
    return None
    
current_date = datetime.now()
age = current_date.year - dob.year - ((current_date.month, current_date.day) < (dob.month, dob.day))
return age
$$

In [None]:
%sql
-- Verify the function was created successfully
DESCRIBE FUNCTION demo.hls.current_age;

-- Test the function with an explicit DATE literal (uncomment to run):
-- SELECT demo.hls.current_age(DATE'1989-09-14') AS calculated_age;