### Prepare environment

In [0]:
%run ../environment/prepare_environment

# XGBoost - Telco Churn Classification

This notebook will cover:
- Training and tuning an XGBoost model using scikit-learn API
- Tracking experiments and results with MLflow
- Visualizing model performance and interpretability

**Why XGBoost?**
- State-of-the-art performance for tabular data
- Handles nonlinearity, missing values, and feature interactions
- Widely used in industry and competitions

In [0]:
import os
import mlflow
import logging
import numpy as np
import pandas as pd
import mlflow.xgboost
import xgboost as xgb
import matplotlib.pyplot as plt
from mlflow.models import infer_signature
from sklearn.model_selection import train_test_split
from databricks.feature_engineering import FeatureLookup, FeatureEngineeringClient
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, confusion_matrix, roc_curve, auc

# Set up logging
# basicConfig() does nothing when the root logger already has handlers, which is
# common in notebook kernels, so the level is also set on our own logger. This
# keeps the INFO messages emitted by this notebook from being filtered out.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("telco-churn-pipeline")
logger.setLevel(logging.INFO)

## 1. Load Feature Table

We use features engineered in the previous workshop.

If the feature table is missing, re-run the following notebooks in order:
- [2.1_telco_raw_to_bronze]($../2_data_preparations/2.1_telco_raw_to_bronze)
- [2.3_telco_bronze_to_silver]($../2_data_preparations/2.3_telco_bronze_to_silver)
- [3.4_telco_embedding_models]($../3_feature_engineering/3.4_telco_embedding_models)
- [3.1_telco_feature_table]($../3_feature_engineering/3.1_telco_feature_table)
- [3.2_telco_feature_store]($../3_feature_engineering/3.2_telco_feature_store)

In [0]:
from pyspark.sql.utils import AnalysisException

def table_exists(full_name: str) -> bool:
    """Return True when `full_name` names a persisted table.

    Args:
        full_name: Three-level table name in the form ``catalog.schema.table``.

    Returns:
        True if `SHOW TABLES` lists a non-temporary table with that name in the
        given schema; False if it does not, or if the lookup raises
        `AnalysisException` (for example because the schema does not exist).

    Raises:
        ValueError: If `full_name` is not a three-level name.
    """
    parts = full_name.split(".")
    if len(parts) != 3:
        raise ValueError(
            f"Expected a three-level name 'catalog.schema.table', got {full_name!r}"
        )
    catalog, schema, table = parts

    # Imported locally so the helper does not depend on another cell's imports.
    from pyspark.sql.functions import col

    try:
        matches = (
            spark.sql(f"SHOW TABLES IN {catalog}.{schema}")
            # Compare through the column API instead of splicing the name into
            # a SQL string, and skip temp views that SHOW TABLES also lists.
            .filter((col("tableName") == table) & ~col("isTemporary"))
            .limit(1)
            .count()
        )
    except AnalysisException:
        # A missing catalog/schema means the table cannot exist either.
        return False
    return matches > 0

def load_feature_table():
    """Load the silver Telco table as a Spark DataFrame.

    Returns:
        The Spark DataFrame for the table, or None (after logging a warning)
        when `table_exists` reports that the table is not available.
    """
    source_table = "ai_ml_in_practice.telco_customer_churn_silver.telco_silver"

    if table_exists(source_table):
        loaded_df = spark.table(source_table)
        logger.info(f"Loaded feature table: {source_table}")
        return loaded_df

    # Missing table: tell the user what to do and signal it with None.
    logger.warning("Feature table not found. Please run the required notebooks.")
    return None

# load_feature_table() returns None when the table is missing; fail here with an
# actionable message instead of an AttributeError raised by `None.select`.
source_df = load_feature_table()
if source_df is None:
    raise RuntimeError(
        "Source table with churn labels is missing. Re-run the data preparation "
        "notebooks listed above, then run this cell again."
    )

# Only the lookup key and the label are needed; features come from the feature table.
feature_df = source_df.select("customer_id", "churn")

## 2. Prepare Data for XGBoost

Convert Spark DataFrame to Pandas and split for scikit-learn/XGBoost. This is a common pattern for advanced ML on Spark clusters.

In [0]:
# Prepare feature lookup (None in feature names means take all)
fe = FeatureEngineeringClient()

feature_lookups = [
  FeatureLookup(
    table_name="ai_ml_in_practice.telco_customer_churn_silver.telco_customer_features",
    feature_names=None,
    lookup_key="customer_id"
  )
]

# Create training set with Feature Engineering client.
# The key column is excluded so it is not used as a model feature.
training_set = fe.create_training_set(
  df=feature_df,
  feature_lookups=feature_lookups,
  exclude_columns=["customer_id"],
  label="churn",
)

training_df = training_set.load_df()
display(training_df)

# Collect to pandas and mark string columns as categorical so XGBoost
# (enable_categorical=True) can consume them directly.
pdf = training_df.toPandas()
object_columns = pdf.select_dtypes(include=['object']).columns
pdf = pdf.astype({name: 'category' for name in object_columns})
X = pdf.drop(columns=["churn"])
y = pdf["churn"].astype(int)

# Stratify on the label so train and test keep the same churn ratio; otherwise
# precision/recall on the held-out set depend on how the random split fell.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
logger.info(f"Train: {len(X_train)}, Test: {len(X_test)}")

## 3. Model Training, Evaluation, and MLflow Logging

Train an XGBoost classifier and log parameters, metrics, and artifacts to MLflow. The confusion matrix, ROC curve, and feature importances are logged as figures so the run can be inspected visually.

To add preprocessing (or post-processing) code to the model and generate processed predictions with batch inference, we will build a custom PyFunc MLflow model.

In [0]:
# Custom preprocessing model
class churnClassifier(mlflow.pyfunc.PythonModel):
    """PyFunc wrapper that applies categorical preprocessing before predicting.

    Args:
        trained_model: Fitted estimator exposing ``predict(DataFrame)``.
        category_levels: Optional mapping ``column name -> sequence of category
            values`` as seen during training. When given, object columns are
            encoded against exactly these levels, so category codes do not
            depend on which values appear in an inference batch. Values not in
            the levels become missing (NaN). When omitted, levels are derived
            from each batch, which only matches training if the batch contains
            the same set of values.
    """

    def __init__(self, trained_model, category_levels=None):
        self.model = trained_model
        self.category_levels = dict(category_levels) if category_levels else {}

    def preprocess_result(self, model_input):
        """Cast object columns of `model_input` to categorical, in place, and return it."""
        # getattr keeps instances pickled before `category_levels` existed usable.
        known_levels = getattr(self, "category_levels", None) or {}
        for col in model_input.select_dtypes(include=['object']).columns:
            levels = known_levels.get(col)
            if levels is None:
                # No training levels recorded: derive them from this batch.
                model_input[col] = model_input[col].astype('category')
            else:
                model_input[col] = pd.Categorical(model_input[col], categories=list(levels))
        return model_input

    def predict(self, context, model_input):
        """Return the wrapped model's predictions for `model_input`.

        The input is copied first so the caller's DataFrame is not modified by
        the preprocessing step. `context` is accepted for the PyFunc interface
        and is not used.
        """
        processed_df = self.preprocess_result(model_input.copy())
        results = self.model.predict(processed_df)
        return results

In [0]:
def _log_confusion_matrix(y_true, y_pred):
    """Log the confusion matrix of `y_pred` vs `y_true` as confusion_matrix.png."""
    cm = confusion_matrix(y_true, y_pred)
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.matshow(cm, cmap=plt.cm.Blues)
    # Write the raw count into every cell of the matrix.
    for (i, j), z in np.ndenumerate(cm):
        ax.text(j, i, str(z), ha='center', va='center')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title('Confusion Matrix')
    plt.close(fig)
    mlflow.log_figure(fig, "confusion_matrix.png")


def _log_roc_curve(y_true, y_score, roc_auc):
    """Log the ROC curve for `y_score` as roc_curve.png, labelled with `roc_auc`."""
    fpr, tpr, _ = roc_curve(y_true, y_score)
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
    ax.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC)')
    ax.legend(loc='lower right')
    fig.tight_layout()
    plt.close(fig)
    mlflow.log_figure(fig, "roc_curve.png")


def _log_feature_importances(model, feature_names, top_n=20):
    """Log the `top_n` largest feature importances, by name, as feature_importances.png."""
    importances = np.asarray(model.feature_importances_)
    # Largest first, then reversed so the most important bar is drawn on top.
    top_idx = np.argsort(importances)[::-1][:top_n][::-1]
    labels = [str(feature_names[i]) for i in top_idx]
    positions = range(len(top_idx))

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.barh(positions, importances[top_idx], color='teal')
    ax.set_yticks(list(positions))
    ax.set_yticklabels(labels)
    ax.set_title('Feature Importances (XGBoost)')
    ax.set_xlabel('Importance')
    ax.set_ylabel('Feature')
    fig.tight_layout()
    plt.close(fig)
    mlflow.log_figure(fig, "feature_importances.png")


def train_and_log(X_train, y_train, X_test, y_test):
    """Train an XGBoost classifier, evaluate it, and log everything to MLflow.

    Inside a single MLflow run this fits the model, logs accuracy, precision,
    recall and ROC AUC on the test set, logs confusion-matrix, ROC and
    feature-importance figures, and registers both the raw XGBoost model and
    the `churnClassifier` PyFunc wrapper through the Feature Engineering client.

    Relies on the notebook-level names `fe` and `training_set`.

    Args:
        X_train: Training features (pandas DataFrame).
        y_train: Training labels.
        X_test: Held-out features used for evaluation.
        y_test: Held-out labels used for evaluation.

    Returns:
        The fitted `xgb.XGBClassifier`.
    """
    with mlflow.start_run(run_name="xgboost_classifier") as run:
        # Train a model
        model = xgb.XGBClassifier(
            n_estimators=100,
            max_depth=4,
            learning_rate=0.1,
            subsample=0.8,
            colsample_bytree=0.8,
            random_state=42,
            enable_categorical=True,
            eval_metric='logloss'
        )
        model.fit(X_train, y_train)

        # Evaluate the model using the test set
        y_pred = model.predict(X_test)
        y_score = model.predict_proba(X_test)[:, 1]
        acc = accuracy_score(y_test, y_pred)
        prec = precision_score(y_test, y_pred)
        rec = recall_score(y_test, y_pred)
        roc_auc = roc_auc_score(y_test, y_score)
        mlflow.log_metrics({
            'accuracy': acc,
            'precision': prec,
            'recall': rec,
            'roc_auc': roc_auc
        })
        logger.info(f"Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, ROC_AUC: {roc_auc:.4f}")

        # Diagnostic figures, each logged as an artifact of the active run.
        _log_confusion_matrix(y_test, y_pred)
        _log_roc_curve(y_test, y_score, roc_auc)
        _log_feature_importances(model, list(X_train.columns))

        # Save model to MLflow Model Registry
        signature = infer_signature(X_train, y_pred)

        fe.log_model(
            model=model,
            artifact_path='telco_churn_xgboost',
            signature=signature,
            flavor=mlflow.xgboost,
            training_set=training_set,
            registered_model_name='ai_ml_in_practice.telco_customer_churn_silver.telco_churn_xgboost_model'
        )

        # Save packaged model to MLFlow Model Registry
        pyfunc_model = churnClassifier(model)
        fe.log_model(
            model=pyfunc_model,
            artifact_path="telco_churn_xgboost_packaged",
            flavor=mlflow.pyfunc,
            training_set=training_set,
            registered_model_name="ai_ml_in_practice.telco_customer_churn_silver.telco_churn_xgboost_packaged_model",
        )

        logger.info('MLflow run completed. Run ID: %s', run.info.run_id)
        return model

# Run training and MLflow logging; the fitted classifier is kept in `model`
# because the tree-visualization cell further down plots it.
model = train_and_log(X_train, y_train, X_test, y_test)

## 4. Batch Inference and Model Loading

We load the registered model and score it with the Feature Engineering client. Only the primary key (`customer_id`) has to be provided; the features are looked up automatically.

In [0]:
# Score with the newest registered version instead of a hard-coded number, so
# the cell works on a fresh workspace and always uses the model trained above.
packaged_model_name = "ai_ml_in_practice.telco_customer_churn_silver.telco_churn_xgboost_packaged_model"

# The three-level name is a Unity Catalog model, hence the explicit registry URI.
registry_client = mlflow.MlflowClient(registry_uri="databricks-uc")
registered_versions = registry_client.search_model_versions(f"name='{packaged_model_name}'")
if not registered_versions:
    raise RuntimeError(
        f"No registered versions found for {packaged_model_name}. Run the training cell first."
    )
latest_version = max(int(mv.version) for mv in registered_versions)
logger.info(f"Scoring with {packaged_model_name} version {latest_version}")

predictions = fe.score_batch(
    model_uri=f"models:/{packaged_model_name}/{latest_version}",
    df=feature_df.select("customer_id")
)

display(predictions)

## 5. Tree Visualization (if supported)

**Note:** XGBoost tree plotting is not supported on Databricks serverless compute. If running locally or on a cluster with Graphviz, you can visualize the first tree as follows:

In [0]:
import shutil

# plot_tree renders through the Graphviz `dot` binary; skip when it is not on PATH.
if shutil.which('dot') is None:
    print("[INFO] Graphviz is not available on Databricks serverless compute. Skipping XGBoost tree plotting and artifact logging.")
else:
    from xgboost import plot_tree
    fig, ax = plt.subplots(figsize=(16, 8))
    plot_tree(model, num_trees=0, ax=ax)
    ax.set_title('XGBoost Tree 0 Structure')

    if mlflow.active_run() is not None:
        # A run is already open: log into it.
        mlflow.log_figure(fig, 'xgb_tree_0.png')
    else:
        # No run is active here, so a bare log_figure would open a new implicit
        # run and leave it open. Re-open the most recent run of this session
        # (expected to be the training run) so the figure sits next to the
        # other artifacts, and close it again when done.
        previous_run = mlflow.last_active_run()
        if previous_run is not None:
            run_kwargs = {"run_id": previous_run.info.run_id}
        else:
            run_kwargs = {"run_name": "xgboost_tree_plot"}
        with mlflow.start_run(**run_kwargs):
            mlflow.log_figure(fig, 'xgb_tree_0.png')