# Healthcare Payer Member Churn - Model Training & Comparison

**Part 2 of 3**: This notebook covers model training, evaluation, and selection.

## Prerequisites
- Complete **Part 1: Feature Engineering** notebook first
- Ensure the following tables exist:
  - `demo.hls.label_features_versioned`
  - `demo.hls.member_features_versioned`

## What's Covered
- Champion model training (Random Forest)
- Model management and versioning
- MLflow tracing
- Challenger model training (Gradient Boosting)
- Model comparison and selection
- Hyperparameter tuning
- Champion promotion

---
# 4. Model Training with MLflow

Now we'll train our member churn prediction model using MLflow for experiment tracking and model management.

## 4.1 Create Training Dataset with Feature Lookups (Optional)

Feature lookups allow you to automatically join features from the feature store at training and inference time. This example shows how to use them, though we'll use a simpler approach for this demo.
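
To make the point-in-time idea concrete, here is a minimal, dependency-free sketch of what a timestamp-keyed lookup does conceptually: for each label row, it picks the most recent feature snapshot taken on or before the label's timestamp. The `feature_history` data and the `point_in_time_lookup` helper are hypothetical and only illustrate the concept; they are not how the feature store is implemented.

```python
from bisect import bisect_right
from datetime import date

# Hypothetical feature snapshots per member, sorted by snapshot date.
feature_history = {
    "m1": [
        (date(2024, 1, 1), {"claim_count": 2}),
        (date(2024, 3, 1), {"claim_count": 5}),
    ],
}

def point_in_time_lookup(member_id, as_of):
    """Return the latest feature snapshot taken on or before `as_of`, else None."""
    snapshots = feature_history.get(member_id, [])
    timestamps = [ts for ts, _ in snapshots]
    idx = bisect_right(timestamps, as_of)  # number of snapshots with ts <= as_of
    return snapshots[idx - 1][1] if idx > 0 else None

feb = point_in_time_lookup("m1", date(2024, 2, 15))        # sees only the January snapshot
apr = point_in_time_lookup("m1", date(2024, 4, 1))         # sees the March snapshot
too_early = point_in_time_lookup("m1", date(2023, 12, 31))  # no snapshot exists yet
print(feb, apr, too_early)
```

Because a label dated mid-February never sees the March snapshot, features computed after the label date cannot leak into training.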

In [0]:
# Example: How to define feature lookups (optional pattern)
# This pattern is useful when you want to separate feature computation from model training
from databricks.feature_engineering import FeatureFunction, FeatureLookup

# Define feature specifications for runtime lookups
features = [
    FeatureLookup(
        table_name="demo.hls.member_features_versioned",
        lookup_key=["member_id"],
        timestamp_lookup_key="version_ts"  # Ensures point-in-time correctness
    ),
    FeatureFunction(
        udf_name="demo.hls.current_age",
        input_bindings={"dob": "dob"},
        output_name="age"
    )
]

print("✓ Feature specifications defined (for reference)")

In [0]:
# Load the labels and features for training
labels_df = spark.read.table("demo.hls.label_features_versioned")
print(f"Loaded {labels_df.count():,} records for training")


In [0]:
# Create training set specifications using Feature Engineering Client
# This links the features to the model for lineage tracking
from databricks.feature_engineering import FeatureEngineeringClient

fe = FeatureEngineeringClient()

# Create training set with feature store linkage
training_set_specs = fe.create_training_set(
    df=labels_df,  # DataFrame with labels and features
    label="disenrolled",  # Target variable
    feature_lookups=[],  # Empty list since features are already in the DataFrame
    exclude_columns=["member_id", "version_ts"]  # Exclude ID and timestamp from training
)

print("✓ Training set specifications created")

In [0]:
# Load the training set as a Pandas DataFrame for scikit-learn training
# training_set_specs.load_df() returns a PySpark DataFrame, which we convert to Pandas
df_loaded = training_set_specs.load_df().toPandas()

print(f"✓ Training data loaded: {df_loaded.shape[0]:,} rows × {df_loaded.shape[1]} columns")

In [0]:
# Preview the training data
df_loaded.head(10)

---
## 4.2 Train Champion Model (Random Forest)

We'll start by training a **Random Forest classifier** as our baseline (Champion) model. Random Forest is a good starting point because it:
- Handles non-linear relationships well
- Provides feature importance insights
- Is relatively robust to outliers (missing values are filled with 0 before training)
- Works well out-of-the-box with minimal tuning

**Data Splitting Strategy**:
- **Training set (20%)**: For model training
- **Validation set (40%)**: For hyperparameter tuning and model selection
- **Test set (40%)**: For final evaluation (held out until the end)
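
The proportions above come from two successive splits: 20% goes to training, and the remaining 80% is halved into validation and test. The following dependency-free sketch checks that arithmetic on toy labels. The `stratified_split` helper is hypothetical and only mimics the effect of `train_test_split(..., stratify=y)`; the toy 10% disenrollment rate is an assumption.

```python
import random

def stratified_split(indices, labels, first_fraction, seed=42):
    """Split indices into two parts, keeping each class's share roughly equal."""
    rng = random.Random(seed)
    first, second = [], []
    for cls in sorted(set(labels[i] for i in indices)):
        members = [i for i in indices if labels[i] == cls]
        rng.shuffle(members)
        cut = round(len(members) * first_fraction)
        first.extend(members[:cut])
        second.extend(members[cut:])
    return first, second

# Toy labels: 1,000 members with an assumed 10% disenrollment rate.
labels = [1] * 100 + [0] * 900
all_idx = list(range(len(labels)))

train_idx, rest_idx = stratified_split(all_idx, labels, 0.20)  # 20% train
val_idx, test_idx = stratified_split(rest_idx, labels, 0.50)   # 80% * 0.5 = 40% each

for name, idx in [("train", train_idx), ("val", val_idx), ("test", test_idx)]:
    rate = sum(labels[i] for i in idx) / len(idx)
    print(f"{name}: {len(idx) / len(labels):.0%} of rows, churn rate {rate:.0%}")
```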

In [0]:
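# Imports used throughout this notebook
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score,
    brier_score_loss,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split

# Prepare features and target variable
# Exclude ID columns, timestamps, and the target variable from features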
selected_features = [
    c for c in df_loaded.columns 
    if c not in ["member_id", "version_ts", "last_claim_date", "first_claim_date", "disenrolled", "days_since_last_claim"]
]

print(f"Selected {len(selected_features)} features for modeling:")
print(selected_features)

X = df_loaded[selected_features]  # Feature matrix
y = df_loaded["disenrolled"]  # Target variable

# Split data: 20% train, 40% validation, 40% test
X_train, X_rem, y_train, y_rem = train_test_split(
    X.fillna(0), y, 
    train_size=0.20, 
    random_state=42,
    stratify=y  # Maintain class balance
)

X_val, X_test, y_val, y_test = train_test_split(
    X_rem, y_rem, 
    test_size=0.5, 
    random_state=42,
    stratify=y_rem  # Maintain class balance
)

print(f"\n✓ Data split complete:")
print(f"  Training set: {len(X_train):,} samples ({len(X_train)/len(X):.1%})")
print(f"  Validation set: {len(X_val):,} samples ({len(X_val)/len(X):.1%})")
print(f"  Test set: {len(X_test):,} samples ({len(X_test)/len(X):.1%})")
print(f"  Disenrollment rate: {y.mean():.2%}")


In [0]:
# Train Random Forest model with MLflow tracking
import mlflow
import mlflow.sklearn

# End any existing run
mlflow.end_run()

# Start MLflow run with a descriptive name
with mlflow.start_run(run_name="random_forest_champion") as run:
    # Set model hyperparameters
    n_estimators = 100
    max_depth = 6
    random_state = 42
    
    # Log hyperparameters
    mlflow.log_param("n_estimators", n_estimators)
    mlflow.log_param("max_depth", max_depth)
    mlflow.log_param("random_state", random_state)
    mlflow.log_param("algorithm", "RandomForest")
    
    # Create and train model
    rf = RandomForestClassifier(
        n_estimators=n_estimators, 
        max_depth=max_depth,
        random_state=random_state,
        n_jobs=-1  # Use all available cores
    )
    rf.fit(X_train, y_train)
    
    # Make predictions on validation set
    y_val_pred = rf.predict(X_val)
    y_val_proba = rf.predict_proba(X_val)[:, 1]
    
    # Calculate and log metrics
    val_roc_auc = roc_auc_score(y_val, y_val_proba)
    val_accuracy = accuracy_score(y_val, y_val_pred)
    val_f1 = f1_score(y_val, y_val_pred)
    val_precision = precision_score(y_val, y_val_pred)
    val_recall = recall_score(y_val, y_val_pred)
    
    mlflow.log_metric("val_roc_auc", val_roc_auc)
    mlflow.log_metric("val_accuracy", val_accuracy)
    mlflow.log_metric("val_f1", val_f1)
    mlflow.log_metric("val_precision", val_precision)
    mlflow.log_metric("val_recall", val_recall)
    
    print("=" * 60)
    print("RANDOM FOREST - VALIDATION METRICS")
    print("=" * 60)
    print(f"ROC AUC:   {val_roc_auc:.4f}")
    print(f"Accuracy:  {val_accuracy:.4f}")
    print(f"F1 Score:  {val_f1:.4f}")
    print(f"Precision: {val_precision:.4f}")
    print(f"Recall:    {val_recall:.4f}")
    print("=" * 60)
    
    run_id = run.info.run_id
    print(f"\n✓ Model trained successfully (Run ID: {run_id})")

---
## 4.3 Register Model with Feature Store Lineage

To maintain the connection between features and models, we use the **Feature Engineering client** to log the model. This enables:
- **Automatic feature lookup** during inference
- **Feature lineage tracking** (which features were used)
- **Consistent feature computation** across training and serving

In [0]:
# Create model signature and input example for documentation
from mlflow.models.signature import infer_signature

input_example = X_train.iloc[[0]]
signature = infer_signature(X_train, rf.predict(X_train))

# Log model with Feature Store lineage using fe.log_model()
# This maintains the connection between features and the model
fe.log_model(
    model=rf,
    artifact_path="mx_churn",
    flavor=mlflow.sklearn,
    training_set=training_set_specs,  # Links to feature store
    input_example=input_example,
    signature=signature,
    registered_model_name="demo.hls.mx_churn"
)

print("✓ Model registered in Unity Catalog with feature store lineage")


---
## 4.4 Viewing Experiments and Models in Databricks

### Viewing Experiment Runs:
1. Click the **Experiments** icon in the right sidebar
2. View parameters, metrics, and artifacts for each run
3. Compare multiple runs side-by-side
4. Click a run name to see detailed information

### Viewing Registered Models:
1. Navigate to **Catalog Explorer** (left sidebar → Catalog)
2. Browse to `demo.hls.mx_churn`
3. View all model versions, lineage, and metadata
4. Track which features were used in each version

---
## 4.5 Feature Importance and Exploratory Analysis

Understanding which features drive predictions helps:
- **Build trust** in the model
- **Identify intervention opportunities**
- **Guide future feature engineering**

In [0]:
# Analyze feature importance from the Random Forest model
importances = rf.feature_importances_
importance_df = pd.DataFrame({
    "feature": selected_features, 
    "importance": importances
}).sort_values(by="importance", ascending=False)

# Plot top 15 most important features
plt.figure(figsize=(12, 8))
top_features = importance_df.head(15)
sns.barplot(data=top_features, x="importance", y="feature", hue="feature", palette="viridis", legend=False)
plt.title("Top 15 Most Important Features for Churn Prediction", fontsize=14, fontweight='bold')
plt.xlabel("Feature Importance", fontsize=12)
plt.ylabel("Feature", fontsize=12)
plt.tight_layout()
plt.show()

print("\nTop 10 Most Important Features:")
print(importance_df.head(10).to_string(index=False))


In [0]:
# Visualize claim type distribution by disenrollment status
claim_cols = [c for c in df_loaded.columns if c in ["Outpatient", "Inpatient", "Pharmacy", "Professional"]]
if claim_cols:
    claim_summary = df_loaded.groupby("disenrolled")[claim_cols].mean().T
    claim_summary.columns = ["Retained", "Disenrolled"]
    
    # DataFrame.plot creates its own figure, so a separate plt.figure() call would leave an empty plot
    claim_summary.plot(kind="bar", figsize=(10, 6), color=["#2ecc71", "#e74c3c"])
    plt.title("Average Claim Counts by Type and Disenrollment Status", fontsize=14, fontweight='bold')
    plt.ylabel("Average Claims per Member", fontsize=12)
    plt.xlabel("Claim Type", fontsize=12)
    plt.xticks(rotation=45)
    plt.legend(title="Member Status")
    plt.tight_layout()
    plt.show()
    
    print("\nClaim Type Analysis:")
    print(claim_summary)



# 5. Model Management and Versioning

Unity Catalog provides enterprise-grade model governance with:
- **Version control**: Track all model versions
- **Aliasing**: Use aliases like "Champion" and "Challenger" for A/B testing
- **Lineage**: Trace features, data, and code used in each version
- **Access control**: Manage who can view, use, or deploy models

## 5.1 Load and Manage Model Versions

In [0]:
# Initialize MLflow client for model management
from mlflow import MlflowClient

# Point MLflow at the Unity Catalog registry before creating the client
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

model_name = "demo.hls.mx_churn"

# Helper function to get the latest READY model version
def get_latest_model_version(model_name):
    """Returns the latest READY version number for a model"""
    model_version_infos = client.search_model_versions(f"name = '{model_name}'")
    ready_versions = [
        int(v.version)
        for v in model_version_infos
        if v.status == "READY"
    ]
    if not ready_versions:
        raise ValueError(f"No READY versions found for model {model_name}")
    return max(ready_versions)

latest_version = get_latest_model_version(model_name)
print(f"✓ Latest model version: {latest_version}")

In [0]:
# Update model metadata and set aliases
# This provides context for model users and enables lifecycle management

# Update overall model description
client.update_registered_model(
    name=model_name,
    description="Healthcare member churn prediction model using claims data. Predicts disenrollment risk."
)

# Update specific version description
client.update_model_version(
    name=model_name,
    version=latest_version,
    description="Random Forest classifier (n_estimators=100, max_depth=6). Trained on member claims features."
)

# Set tags for tracking
client.set_model_version_tag(
    name=model_name,
    version=str(latest_version),
    key="stage",
    value="production"
)

client.set_model_version_tag(
    name=model_name,
    version=str(latest_version),
    key="algorithm",
    value="RandomForest"
)

# Set "Champion" alias for this version
# Aliases enable easy reference without hardcoding version numbers
client.set_registered_model_alias(model_name, "Champion", latest_version)

print(f"✓ Model version {latest_version} set as 'Champion'")
print(f"  Model URI: models:/{model_name}@Champion")


## 5.2 Viewing Models in Unity Catalog

You can view and manage registered models in **Unity Catalog** using Catalog Explorer:

1. In the left sidebar, click **Catalog**
2. Navigate to `demo` → `hls` → `mx_churn`
3. View:
   - All model versions and their metadata
   - Model lineage (features, datasets, code)
   - Performance metrics
   - Aliases and tags
   - Access permissions


## 5.3 Loading and Using Models

Models can be loaded from Unity Catalog in multiple ways:
- **By version number**: `models:/model_name/version`
- **By alias**: `models:/model_name@alias`

Using aliases (like "Champion") is recommended as it decouples code from specific versions.
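
The decoupling can be illustrated with a toy, in-memory stand-in for a registry. Everything in this sketch (`versions`, `aliases`, `resolve`) is hypothetical and is not the MLflow API; it only shows why a caller that uses an alias URI picks up a new version without any code change, while a version-pinned URI does not.

```python
# Hypothetical in-memory stand-in for a model registry; not the MLflow API.
versions = {1: "random_forest_v1", 2: "gradient_boosting_v2"}
aliases = {"Champion": 1}

def resolve(model_uri):
    """Resolve 'models:/<name>/<version>' or 'models:/<name>@<alias>' to a model."""
    path = model_uri[len("models:/"):]
    if "@" in path:
        _, alias = path.split("@", 1)
        return versions[aliases[alias]]
    _, version = path.rsplit("/", 1)
    return versions[int(version)]

uri = "models:/demo.hls.mx_churn@Champion"
before = resolve(uri)                              # alias points at version 1
aliases["Champion"] = 2                            # promotion: only the alias moves
after = resolve(uri)                               # same URI, new model
pinned = resolve("models:/demo.hls.mx_churn/1")    # pinned URI is unaffected
print(before, after, pinned)
```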

In [0]:
import mlflow.pyfunc

# Load model by version number
latest_version = get_latest_model_version(model_name)
model_version_uri = f"models:/{model_name}/{latest_version}"
print(f"Loading model from URI: {model_version_uri}")
model_by_version = mlflow.pyfunc.load_model(model_version_uri)

# Load model by alias (recommended approach)
model_champion_uri = f"models:/{model_name}@Champion"
print(f"Loading model from URI: {model_champion_uri}")
champion_model = mlflow.pyfunc.load_model(model_champion_uri)

print(f"\n✓ Model loaded successfully: {champion_model}")

In [0]:
# Helper functions for model predictions

def load_and_predict(model_name, model_alias, new_data):
    """
    Load model and make predictions (class labels)
    Uses pyfunc loader which works for any MLflow model flavor
    """
    model_uri = f"models:/{model_name}@{model_alias}"
    model = mlflow.pyfunc.load_model(model_uri)
    predictions = pd.DataFrame(model.predict(new_data))
    return predictions

def load_and_predict_proba(model_name, model_alias, new_data):
    """
    Load sklearn model and get prediction probabilities
    Uses sklearn loader to access predict_proba method
    """
    model_uri = f"models:/{model_name}@{model_alias}"
    model = mlflow.sklearn.load_model(model_uri)
    probs = model.predict_proba(new_data)[:, 1]  # Probability of churn (class 1)
    return probs

print("✓ Prediction helper functions defined")

In [0]:
# Generate predictions on validation set
rf_pred = load_and_predict(model_name, "Champion", X_val)
rf_pred.columns = ["disenrollment_prediction"]

rf_prob = pd.DataFrame(load_and_predict_proba(model_name, "Champion", X_val))
rf_prob.columns = ["disenrollment_probability"]

# Combine features with predictions
result_df = pd.concat([
    X_val.reset_index(drop=True), 
    rf_pred.reset_index(drop=True), 
    rf_prob.reset_index(drop=True)
], axis=1)

print(f"✓ Generated predictions for {len(result_df):,} members")
print(f"  Average churn probability: {result_df['disenrollment_probability'].mean():.2%}")
print(f"  Predicted churners: {result_df['disenrollment_prediction'].sum():,} ({result_df['disenrollment_prediction'].mean():.1%})")

display(result_df.head(20))

In [0]:
# Save predictions to Delta table for downstream use
# Clean column names (replace spaces with underscores)
result_df.columns = [col.replace(" ", "_") for col in result_df.columns]

spark.createDataFrame(result_df) \
    .write \
    .format("delta") \
    .mode("overwrite") \
    .option("mergeSchema", "true") \
    .saveAsTable("demo.hls.payer_1_disenroll_prediction")

print("✓ Predictions saved to demo.hls.payer_1_disenroll_prediction")


## 5.4 MLflow Tracing for Model Observability

**MLflow Tracing** provides detailed visibility into model predictions:
- Track inputs, outputs, and intermediate steps
- Debug production issues
- Monitor model behavior over time
- Audit predictions for compliance

Traces appear in the **Traces** tab of your experiment.

In [0]:
# Load model and make predictions with tracing
model_uri = f"models:/{model_name}@Champion"
model = mlflow.sklearn.load_model(model_uri)

# Use span to trace the prediction process
with mlflow.start_span(name="predict_proba_with_trace") as span:
    # Log input metadata
    span.set_inputs({
        "rows": X_test.shape[0],
        "features": list(X_test.columns),
        "model_alias": "Champion"
    })
    
    # Make predictions
    proba = model.predict_proba(X_test)
    pred = (proba[:, 1] >= 0.5).astype(int)
    
    # Log output samples (keep payloads small)
    span.set_outputs({
        "proba_sample": proba[:5].round(4).tolist(),
        "pred_sample": pred[:10].tolist(),
        "churn_rate": float(pred.mean())
    })

# Log metrics
test_auc = roc_auc_score(y_test, proba[:, 1])
test_accuracy = accuracy_score(y_test, pred)

mlflow.log_metric("test_auc", test_auc)
mlflow.log_metric("test_accuracy", test_accuracy)

print("=" * 60)
print("CHAMPION MODEL - TEST SET METRICS")
print("=" * 60)
print(f"ROC AUC:   {test_auc:.4f}")
print(f"Accuracy:  {test_accuracy:.4f}")
print("=" * 60)
print("\n✓ Prediction traced successfully (view in Experiments → Traces tab)")

---
# 6. Train Challenger Model (Gradient Boosting)

To improve model performance, we'll train a **Gradient Boosting** classifier as a challenger model. 

**Champion/Challenger Pattern**:
- **Champion**: Current production model (Random Forest)
- **Challenger**: New model being evaluated (Gradient Boosting)
- Compare performance before promoting the challenger

Gradient Boosting often, though not always, outperforms Random Forest because it:
- Builds trees sequentially, with each tree correcting the errors of the previous ones
- Can capture complex feature interactions
- Often produces sharper probability estimates, although calibration should still be checked
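
The sequential error-correction idea can be sketched without any libraries. The example below is a simplified least-squares boosting loop on a toy one-dimensional regression problem, where each stump is fit to the residuals of the current ensemble. It is an illustration under those assumptions, not the log-loss procedure that `GradientBoostingClassifier` uses; the data and the `fit_stump` helper are hypothetical.

```python
def fit_stump(xs, residuals):
    """Find the threshold split on x that minimizes squared error of the residuals."""
    best = None
    for threshold in sorted(set(xs))[:-1]:
        left = [r for x, r in zip(xs, residuals) if x <= threshold]
        right = [r for x, r in zip(xs, residuals) if x > threshold]
        left_mean, right_mean = sum(left) / len(left), sum(right) / len(right)
        error = sum((r - left_mean) ** 2 for r in left) + sum((r - right_mean) ** 2 for r in right)
        if best is None or error < best[0]:
            best = (error, threshold, left_mean, right_mean)
    _, threshold, left_mean, right_mean = best
    return lambda x: left_mean if x <= threshold else right_mean

# Toy data: a noisy step function.
xs = [1, 2, 3, 4, 5, 6, 7, 8]
ys = [1.0, 1.2, 0.9, 3.1, 3.0, 2.9, 5.2, 5.0]

learning_rate = 0.5
predictions = [sum(ys) / len(ys)] * len(xs)  # start from the mean
mse_history = []
for _ in range(10):
    residuals = [y - p for y, p in zip(ys, predictions)]
    stump = fit_stump(xs, residuals)  # each new stump targets the current errors
    predictions = [p + learning_rate * stump(x) for p, x in zip(predictions, xs)]
    mse_history.append(sum((y - p) ** 2 for y, p in zip(ys, predictions)) / len(ys))

print([round(m, 3) for m in mse_history])
```

The training error shrinks with each round because every stump is fit to what the ensemble still gets wrong; the learning rate controls how much of each correction is applied.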

In [0]:
# Train Gradient Boosting classifier
import mlflow.sklearn
from sklearn.ensemble import GradientBoostingClassifier

mlflow.end_run()  # End any existing run

with mlflow.start_run(run_name="gradient_boosting_challenger") as run:
    # Train with default parameters first
    gb = GradientBoostingClassifier(random_state=42)
    gb.fit(X_train, y_train)
    
    # Make predictions on validation set
    y_val_pred = gb.predict(X_val)
    y_val_proba = gb.predict_proba(X_val)[:, 1]
    
    # Calculate metrics
    val_roc_auc = roc_auc_score(y_val, y_val_proba)
    val_f1 = f1_score(y_val, y_val_pred)
    
    # Log metrics
    mlflow.log_metric("val_roc_auc", val_roc_auc)
    mlflow.log_metric("val_f1", val_f1)
    mlflow.log_param("algorithm", "GradientBoosting")
    
    print("=" * 60)
    print("GRADIENT BOOSTING (Default) - VALIDATION METRICS")
    print("=" * 60)
    print(f"ROC AUC:  {val_roc_auc:.4f}")
    print(f"F1 Score: {val_f1:.4f}")
    print("=" * 60)
    
    run_id = run.info.run_id
    print(f"\n✓ Gradient Boosting model trained (Run ID: {run_id})")

In [0]:
# Register Gradient Boosting model to Unity Catalog
example_input = X_val.iloc[[0]]

mlflow.sklearn.log_model(
    sk_model=gb,
    artifact_path="mx_churn",
    input_example=example_input,
    registered_model_name="demo.hls.mx_churn"
)

print("✓ Gradient Boosting model registered in Unity Catalog")

In [0]:
# Set this version as "Challenger" for A/B testing
latest_version = get_latest_model_version(model_name)

client.update_model_version(
    name=model_name,
    version=latest_version,
    description="Gradient Boosting Classifier (default parameters). Challenger model for comparison."
)

client.set_model_version_tag(
    name=model_name,
    version=str(latest_version),
    key="algorithm",
    value="GradientBoosting"
)

client.set_registered_model_alias(model_name, "Challenger", latest_version)

print(f"✓ Model version {latest_version} set as 'Challenger'")
print(f"  Model URI: models:/{model_name}@Challenger")

---
# 7. Model Comparison and Selection

Before promoting a challenger to champion, we need rigorous comparison:
- **Multiple metrics**: ROC AUC, F1, Precision, Recall, Brier Score
- **Business context**: Which metric matters most for member retention?
- **Validation set**: Used for model selection
- **Test set**: Final unbiased evaluation (held out until the end)
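
Ranking metrics and probability-quality metrics can disagree, which is why several metrics are compared. In the dependency-free sketch below, two hypothetical sets of scores rank the members identically, so they share the same ROC AUC, but they have very different Brier scores (mean squared error of the predicted probabilities; lower is better). It also shows how lowering the threshold trades precision for recall. All numbers are made up for illustration.

```python
def brier_score(y_true, y_prob):
    """Mean squared error between predicted probabilities and 0/1 outcomes."""
    return sum((p - y) ** 2 for y, p in zip(y_true, y_prob)) / len(y_true)

def precision_recall(y_true, y_prob, threshold):
    """Precision and recall for the positive class at a given threshold."""
    y_pred = [int(p >= threshold) for p in y_prob]
    tp = sum(1 for y, yp in zip(y_true, y_pred) if y == 1 and yp == 1)
    fp = sum(1 for y, yp in zip(y_true, y_pred) if y == 0 and yp == 1)
    fn = sum(1 for y, yp in zip(y_true, y_pred) if y == 1 and yp == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

y_true = [0, 0, 0, 1, 1]
confident = [0.10, 0.20, 0.30, 0.80, 0.90]  # well-separated probabilities
hesitant = [0.40, 0.45, 0.48, 0.55, 0.60]   # same ranking, less decisive

brier_confident = brier_score(y_true, confident)
brier_hesitant = brier_score(y_true, hesitant)
pr_default = precision_recall(y_true, confident, 0.5)
pr_low = precision_recall(y_true, confident, 0.25)

print(f"Brier (confident): {brier_confident:.3f}")
print(f"Brier (hesitant):  {brier_hesitant:.3f}")
print(f"Precision/recall at 0.50: {pr_default}")
print(f"Precision/recall at 0.25: {pr_low}")
```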

## 7.1 Define Evaluation Function

In [0]:
# Comprehensive evaluation function
def evaluate_model(y_true, y_probs, threshold=0.5):
    """
    Evaluate classification model with multiple metrics
    
    Args:
        y_true: True labels
        y_probs: Predicted probabilities
        threshold: Classification threshold (default 0.5)
    
    Returns:
        Dictionary of evaluation metrics
    """
    y_pred = (y_probs >= threshold).astype(int)
    
    return {
        "ROC AUC": roc_auc_score(y_true, y_probs),
        "F1": f1_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred),
        "Recall": recall_score(y_true, y_pred),
        "Accuracy": accuracy_score(y_true, y_pred),
        "Brier Score": brier_score_loss(y_true, y_probs)
    }

print("✓ Evaluation function defined")

In [0]:
# Compare Champion vs Challenger on the validation set
# (the test set stays held out for the final comparison)
print("Loading models for comparison...")
rf_probs = load_and_predict_proba(model_name, "Champion", X_val)
gb_probs = load_and_predict_proba(model_name, "Challenger", X_val)

# Evaluate both models
print("\nEvaluating models...")
rf_results = evaluate_model(y_val, rf_probs)
gb_results = evaluate_model(y_val, gb_probs)

# Create comparison DataFrame
comparison = pd.DataFrame(
    [rf_results, gb_results], 
    index=["Random Forest (Champion)", "Gradient Boosting (Challenger)"]
)

print("\n" + "=" * 80)
print("MODEL COMPARISON - VALIDATION SET RESULTS")
print("=" * 80)
print(comparison.to_string())
print("=" * 80)

# Highlight improvements
print("\n📊 ANALYSIS:")
for metric in comparison.columns:
    champion_val = rf_results[metric]
    challenger_val = gb_results[metric]
    
    # Relative change of the Challenger versus the Champion
    change = ((challenger_val - champion_val) / champion_val) * 100
    symbol = "↑" if challenger_val > champion_val else "↓"
    
    # For Brier Score, lower is better; for every other metric, higher is better
    if metric == "Brier Score":
        better = challenger_val < champion_val
    else:
        better = challenger_val > champion_val
    
    print(f"  {metric:12s}: {symbol} {change:+.2f}% ({'better' if better else 'not better'})")


---
## 7.2 Hyperparameter Tuning with Randomized Search

To further improve the Gradient Boosting model, we'll use **Randomized Search CV** to find optimal hyperparameters:
- More efficient than grid search for large parameter spaces
- Tests random combinations of parameters
- Uses cross-validation to avoid overfitting
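
The search procedure itself is simple: draw a fixed number of random configurations, score each one, and keep the best. The sketch below shows that loop with only the standard library. The two-parameter search space and the `toy_cv_score` function are hypothetical stand-ins for the real cross-validated ROC AUC, which `RandomizedSearchCV` computes by refitting the model on each fold.

```python
import random

rng = random.Random(42)

def sample_params():
    """Draw one hypothetical configuration from the search space."""
    return {
        "learning_rate": rng.uniform(0.01, 0.21),
        "max_depth": rng.randint(3, 9),  # random.randint is inclusive on both ends
    }

def toy_cv_score(params):
    """Stand-in for a cross-validated score; peaks at learning_rate=0.1, max_depth=5."""
    return 1.0 - abs(params["learning_rate"] - 0.1) - 0.01 * abs(params["max_depth"] - 5)

# Sample 50 configurations and keep the one with the highest score.
candidates = [sample_params() for _ in range(50)]
best = max(candidates, key=toy_cv_score)
print(best, round(toy_cv_score(best), 4))
```

The cost scales with the number of sampled configurations times the number of folds, not with the size of the search space, which is what makes this approach practical for wide ranges.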

In [0]:
# Hyperparameter tuning for Gradient Boosting
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import uniform, randint

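# Define hyperparameter search space
# Note: scipy's uniform(loc, scale) samples from [loc, loc + scale],
# and randint(low, high) samples integers from low to high - 1
param_distributions = {
    'n_estimators': randint(100, 1000),
    'learning_rate': uniform(0.01, 0.2),  # 0.01 to 0.21
    'max_depth': randint(3, 10),  # 3 to 9
    'subsample': uniform(0.7, 0.3),  # 0.7 to 1.0
    'min_samples_split': randint(2, 10),
    'min_samples_leaf': randint(1, 5),
}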

print("Starting hyperparameter search (this may take several minutes)...")
print(f"Search space: {len(param_distributions)} parameters")
print(f"Number of iterations: 50")
print(f"Cross-validation folds: 5\n")

gb = GradientBoostingClassifier(random_state=42)

gb_search = RandomizedSearchCV(
    estimator=gb,
    param_distributions=param_distributions,
    n_iter=50,  # Number of parameter settings sampled
    scoring='roc_auc',  # Optimize for ROC AUC
    cv=5,  # 5-fold cross-validation
    verbose=1,
    random_state=42,
    n_jobs=-1  # Use all available cores
)

gb_search.fit(X_train, y_train)

print("\n" + "=" * 60)
print("HYPERPARAMETER TUNING COMPLETE")
print("=" * 60)
print(f"Best CV Score (ROC AUC): {gb_search.best_score_:.4f}")
print(f"\nBest Parameters:")
for param, value in gb_search.best_params_.items():
    print(f"  {param:20s}: {value}")
print("=" * 60)

In [0]:
# Register the tuned Gradient Boosting model
example_input = X_val.iloc[[0]]

mlflow.sklearn.log_model(
    sk_model=gb_search.best_estimator_,
    artifact_path="mx_churn",
    input_example=example_input,
    registered_model_name="demo.hls.mx_churn"
)

print("✓ Tuned Gradient Boosting model registered in Unity Catalog")

In [0]:
# Update the Challenger alias to point to the tuned model
latest_version = get_latest_model_version(model_name)

client.update_model_version(
    name=model_name,
    version=latest_version,
    description="Gradient Boosting with Randomized Search CV tuning. Optimized for ROC AUC."
)

client.set_model_version_tag(
    name=model_name,
    version=str(latest_version),
    key="algorithm",
    value="GradientBoosting_Tuned"
)

client.set_registered_model_alias(model_name, "Challenger", latest_version)

print(f"✓ Tuned model (version {latest_version}) set as 'Challenger'")

In [0]:
# Final comparison: Champion vs Tuned Challenger
print("Evaluating tuned challenger model...\n")

gb_tuned_probs = load_and_predict_proba(model_name, "Challenger", X_test)
rf_probs = load_and_predict_proba(model_name, "Champion", X_test)

# Evaluate both models
gb_tuned_results = evaluate_model(y_test, gb_tuned_probs)
rf_results = evaluate_model(y_test, rf_probs)

# Create comparison DataFrame
final_comparison = pd.DataFrame(
    [rf_results, gb_tuned_results], 
    index=["Random Forest (Champion)", "Gradient Boosting - Tuned (Challenger)"]
)

print("=" * 80)
print("FINAL MODEL COMPARISON - TEST SET RESULTS")
print("=" * 80)
print(final_comparison.to_string())
print("=" * 80)

# Decision logic
print("\n🎯 RECOMMENDATION:")
if gb_tuned_results["ROC AUC"] > rf_results["ROC AUC"]:
    improvement = ((gb_tuned_results["ROC AUC"] - rf_results["ROC AUC"]) / rf_results["ROC AUC"]) * 100
    print(f"✅ PROMOTE Challenger to Champion")
    print(f"   ROC AUC improved by {improvement:.2f}%")
else:
    print(f"⏸️  KEEP current Champion")
    print(f"   Challenger did not outperform Champion")

---
## 7.3 Promote Challenger to Champion

Once we've validated that the challenger outperforms the champion, we can promote it. This involves:
1. Updating the "Champion" alias to point to the new model
2. Optionally removing the "Challenger" alias
3. Updating model tags and documentation
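
Promotion is safer when it is gated on the comparison result instead of being run unconditionally. A minimal sketch of such a gate follows; the `should_promote` helper, the `min_gain` margin, and the metric values are hypothetical, and the alias update itself would still be done with the registry client.

```python
def should_promote(champion_metrics, challenger_metrics, metric="ROC AUC", min_gain=0.0):
    """Return True when the challenger beats the champion by more than `min_gain`."""
    return challenger_metrics[metric] - champion_metrics[metric] > min_gain

# Made-up metric values for illustration only.
champion = {"ROC AUC": 0.80}
challenger = {"ROC AUC": 0.82}

decision = should_promote(champion, challenger, min_gain=0.01)
strict_decision = should_promote(champion, challenger, min_gain=0.05)
print("promote" if decision else "keep champion")
print("promote" if strict_decision else "keep champion")
```

Requiring a minimum gain helps avoid swapping models over differences that are within noise.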

In [0]:
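# Promote the Challenger to Champion
# Resolve the version behind the "Challenger" alias rather than assuming it is the latest one
new_champion_version = client.get_model_version_by_alias(model_name, "Challenger").version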

# Update Champion alias to point to the new best model
client.set_registered_model_alias(
    name=model_name,
    alias="Champion",
    version=new_champion_version
)

# Update tags
client.set_model_version_tag(
    name=model_name,
    version=str(new_champion_version),
    key="stage",
    value="production"
)

# Remove Challenger alias (optional)
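try:
    client.delete_registered_model_alias(name=model_name, alias="Challenger")
    print("✓ Removed 'Challenger' alias")
except Exception as e:
    # Report the failure instead of silently swallowing it
    print(f"Could not remove 'Challenger' alias: {e}")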

print(f"\n🎉 Model version {new_champion_version} promoted to Champion!")
print(f"   URI: models:/{model_name}@Champion")
print(f"\n📝 Next steps:")
print(f"   1. Deploy to production endpoint")
print(f"   2. Monitor performance metrics")
print(f"   3. Set up alerts for model drift")

# Note: Uncomment the following lines to delete old model versions if needed
# client.delete_model_version(name=model_name, version=old_version)