### Prepare environment

In [0]:
%run ../environment/prepare_environment


### Data drift detection

This notebook detects feature-level data drift between training and inference datasets using PySpark and statistical population metrics.

The purpose of this notebook is to detect when the data used during model inference stops being statistically consistent with the data used during training, ideally before any degradation in model quality becomes visible.

In detail, this notebook:
- Accepts three input parameters (Databricks widgets): the training table and the inference table (Delta tables), plus the number of quantile bins used for PSI (`num_bins`, default 10).
- Loads both datasets using Spark and identifies numerical feature columns automatically.
- Uses the training dataset as the reference distribution: bin boundaries are derived from training-data quantiles only, so the inference data cannot leak into the binning.
- Computes Population Stability Index (PSI) for each numerical feature using distributed PySpark operations.
- Classifies drift severity per feature using common PSI rule-of-thumb thresholds: PSI < 0.1 means no drift, 0.1 ≤ PSI < 0.2 moderate drift, and PSI ≥ 0.2 severe drift.
- Logs drift metrics and metadata to MLflow for auditability, historical tracking, and correlation with deployments.

In [0]:
# Load the data. Parameters come from dbutils.widgets so they can be overridden
# when this notebook is run as a Databricks Job.
from pyspark.sql.types import NumericType

# NOTE: both table defaults point at the same silver table, so a run with the
# defaults compares the data with itself (every feature gets PSI 0).
# Override `inference_table` for a real drift check.
dbutils.widgets.text("train_table", "ai_ml_in_practice.telco_customer_churn_silver.telco_silver")
dbutils.widgets.text("inference_table", "ai_ml_in_practice.telco_customer_churn_silver.telco_silver")
dbutils.widgets.text("num_bins", "10")

# Strip stray whitespace that is easy to paste into a widget or job parameter.
TRAIN_TABLE = dbutils.widgets.get("train_table").strip()
INFERENCE_TABLE = dbutils.widgets.get("inference_table").strip()
NUM_BINS = int(dbutils.widgets.get("num_bins"))

# PSI needs at least one bucket; 0 or negative values would silently produce
# a meaningless binning further down.
if NUM_BINS < 1:
    raise ValueError(f"num_bins widget must be >= 1, got {NUM_BINS}")

train_df = spark.read.table(TRAIN_TABLE)
inf_df = spark.read.table(INFERENCE_TABLE)

In [0]:
from pyspark.sql import functions as F

# Prepare function to calculate Population Stability Index (PSI)
def compute_psi(train_df, inf_df, col_name, num_bins=10, eps=1e-6, relative_error=0.01):
    """
    Computes the Population Stability Index (PSI) for a single numerical feature
    between a training dataset and an inference dataset using PySpark.

    The training dataset is the reference distribution. Bucket boundaries are the
    training quantiles (``approxQuantile``), so the inference data never
    influences the binning (no data leakage). Both datasets are histogrammed over
    those buckets in Spark. Only the per-bucket counts (at most ``num_bins + 1``
    rows per dataset) are collected to the driver, where PSI is summed as

        PSI = sum_b (train_pct_b - inf_pct_b) * ln(train_pct_b / inf_pct_b)

    Null values are counted in their own bucket instead of being mixed into the
    highest quantile bucket. ``approxQuantile`` ignores nulls when computing the
    boundaries, and a change in null rate should show up as its own drift signal.
    NaN, which Spark orders above every other value, falls into the highest bucket.

    PSI is commonly used to detect feature-level data drift in production
    machine learning systems.

    Parameters:
        train_df (pyspark.sql.DataFrame): Spark DataFrame containing the training data.
        inf_df (pyspark.sql.DataFrame): Spark DataFrame containing the inference data.
        col_name (str): Name of the numerical feature column to evaluate.
        num_bins (int): Number of quantile-based buckets to use for PSI calculation.
        eps (float): Proportion substituted for a bucket that has no observations
                     in one of the datasets, to avoid division by zero and the
                     logarithm of zero.
        relative_error (float): Target relative error passed to ``approxQuantile``
                     (0 computes exact quantiles at a higher cost).

    Returns:
        float: The computed PSI value for the given feature. Lower values indicate
               similar distributions, while higher values indicate stronger drift.
               Returns 0.0 when both datasets are empty.

    Raises:
        ValueError: If ``num_bins`` is smaller than 1.
    """
    if num_bins < 1:
        raise ValueError(f"num_bins must be >= 1, got {num_bins}")

    probabilities = [i / num_bins for i in range(1, num_bins)]
    quantiles = train_df.approxQuantile(col_name, probabilities, relative_error)
    splits = [-float("inf")] + quantiles + [float("inf")]

    bucket = _psi_bucket_expression(col_name, splits)
    train_counts = _psi_bucket_counts(train_df, bucket)
    inf_counts = _psi_bucket_counts(inf_df, bucket)
    return _psi_from_counts(train_counts, inf_counts, eps)


# Bucket index reserved for null values (value buckets are 0 .. len(splits) - 2).
_PSI_NULL_BUCKET = -1


def _psi_bucket_expression(col_name, splits):
    """Build a Column expression mapping each value of `col_name` to a bucket index.

    Bucket i holds splits[i] <= value < splits[i + 1]. The `when` branches are
    evaluated in order, so only the upper bound has to be tested. Duplicate
    splits (common for discrete features) just yield empty buckets. Nulls map
    to _PSI_NULL_BUCKET; anything not matched earlier (the top range and NaN)
    goes to the last bucket.
    """
    value = F.col(col_name)
    expr = F.when(value.isNull(), _PSI_NULL_BUCKET)
    for i in range(len(splits) - 2):
        expr = expr.when(value < splits[i + 1], i)
    return expr.otherwise(len(splits) - 2)


def _psi_bucket_counts(df, bucket_expr):
    """Return {bucket_index: row_count} for `df`; only one small row per bucket is collected."""
    rows = df.groupBy(bucket_expr.alias("bucket")).count().collect()
    return {row["bucket"]: row["count"] for row in rows}


def _psi_from_counts(train_counts, inf_counts, eps):
    """Sum PSI over the union of buckets; a bucket absent from one side uses `eps` as its proportion."""
    import math

    train_total = sum(train_counts.values())
    inf_total = sum(inf_counts.values())

    psi = 0.0
    # Sorted for a deterministic floating-point summation order.
    for bucket in sorted(set(train_counts) | set(inf_counts)):
        train_pct = train_counts[bucket] / train_total if bucket in train_counts else eps
        inf_pct = inf_counts[bucket] / inf_total if bucket in inf_counts else eps
        psi += (train_pct - inf_pct) * math.log(train_pct / inf_pct)
    return float(psi)


In [0]:
import re

import mlflow

# PSI severity thresholds (common rule of thumb):
#   PSI < 0.1 -> no drift, 0.1 <= PSI < 0.2 -> moderate drift, PSI >= 0.2 -> severe drift.
PSI_MODERATE_THRESHOLD = 0.1
PSI_SEVERE_THRESHOLD = 0.2


def classify_drift(psi):
    """Map a PSI value to a drift label: "NO_DRIFT", "MODERATE" or "SEVERE"."""
    if psi < PSI_MODERATE_THRESHOLD:
        return "NO_DRIFT"
    if psi < PSI_SEVERE_THRESHOLD:
        return "MODERATE"
    return "SEVERE"


def _psi_metric_key(feature):
    """Build the MLflow metric key for a feature.

    MLflow only accepts alphanumerics, '_', '-', '.', ' ' and '/' in metric
    names, so any other character in the column name is replaced with '_'.
    """
    return "psi_" + re.sub(r"[^0-9A-Za-z_\-. /]", "_", feature)


# Compare only columns that are numeric in BOTH tables. A column missing from,
# or non-numeric in, the inference table would otherwise fail mid-run.
inference_numeric = {
    field.name
    for field in inf_df.schema.fields
    if isinstance(field.dataType, NumericType)
}
train_numeric = [
    field.name
    for field in train_df.schema.fields
    if isinstance(field.dataType, NumericType)
]
numeric_cols = [name for name in train_numeric if name in inference_numeric]
skipped_cols = [name for name in train_numeric if name not in inference_numeric]
if skipped_cols:
    print(f"WARNING: skipping columns not numeric/present in the inference table: {skipped_cols}")

# Start mlflow run context to run the test and log metrics
with mlflow.start_run(run_name="pyspark_feature_drift"):
    mlflow.log_param("train_table", TRAIN_TABLE)
    mlflow.log_param("inference_table", INFERENCE_TABLE)
    mlflow.log_param("num_bins", NUM_BINS)

    results = []
    for col_name in numeric_cols:
        psi = compute_psi(train_df, inf_df, col_name, NUM_BINS)
        results.append({
            "feature": col_name,
            "psi": psi,
            "drift_level": classify_drift(psi),
        })

    # An explicit schema is needed because createDataFrame cannot infer one from
    # an empty list (no comparable numeric columns); it also fixes column order.
    psi_df = spark.createDataFrame(results, schema="feature string, psi double, drift_level string")
    display(psi_df.orderBy(F.desc("psi")))

    # Metrics come straight from the driver-side results; no need to collect psi_df.
    metrics = {_psi_metric_key(row["feature"]): row["psi"] for row in results}

    mlflow.log_metrics(metrics)
