### Data Quality Monitor for timeseries snapshot features
**Relevant docs**:
 - [Create a data profile using the API](https://docs.databricks.com/aws/en/data-quality-monitoring/data-profiling/create-monitor-api)
 - [Data Quality Monitor Python SDK](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/dataquality/data_quality.html)
 - [Notebook example: time series profile](https://docs.databricks.com/aws/en/data-quality-monitoring/data-profiling/create-monitor-api#notebook-example-time-series-profile)


In [0]:
%pip install scikit-learn==1.7.0 databricks-sdk>=0.68.0
%restart_python

In [0]:
import uuid
from sklearn.datasets import make_blobs
import pandas as pd
import numpy as np
import mlflow
from pyspark.sql.functions import col
import pyspark.sql.functions as func

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.dataquality import Monitor, DataProfilingConfig, TimeSeriesConfig, AggregationGranularity, DataProfilingStatus, RefreshState, Refresh

In [0]:
# Declare the notebook's text widgets; each one starts out empty and its
# value is read back in a later cell.
for widget_name, widget_label in (
    ('catalog_name', 'Enter catalog name'),
    ('schema_name', 'Enter schema name'),
    ('suffix', 'Enter suffix'),
):
    dbutils.widgets.text(widget_name, '', widget_label)

In [0]:
#dbutils.widgets.removeAll()

In [0]:
# Read the notebook parameters from the widgets.
catalog_name = dbutils.widgets.get('catalog_name')
schema_name = dbutils.widgets.get('schema_name')
suffix = dbutils.widgets.get('suffix')

# The widgets default to empty strings. Fail fast here instead of letting an
# empty catalog/schema turn into a malformed table name in a later cell.
if not catalog_name or not schema_name:
    raise ValueError(
        "Set the 'catalog_name' and 'schema_name' widgets before running this notebook"
    )

# Two-level Unity Catalog namespace (catalog.schema) used to build table names.
uc_location = f"{catalog_name}.{schema_name}"

print(f"UC location: {uc_location}")

In [0]:
# --- Synthetic data shape -------------------------------------------------
N_SAMPLES = 10000        # rows in the baseline dataset
N_DRIFT_SAMPLES = 2000
CENTERS = 4              # number of clusters passed to the generator
N_FEATURES = 20          # number of feature_<i> columns

# --- Snapshot timestamps --------------------------------------------------
# The baseline gets a single timestamp; drift is then introduced over the
# later snapshots listed below.
TRAIN_TIMESTAMP = "2026-02-01 00:00:00"
FEATURE_DRIFT_TIMESTAMPS = [
    "2026-02-02 00:00:00",
    "2026-02-03 00:00:00",
    "2026-02-04 00:00:00",
    "2026-02-05 00:00:00",
    "2026-02-06 00:00:00",
]

# --- Unity Catalog table names --------------------------------------------
BASELINE_UC_TABLE = f"{uc_location}.baseline_features_{suffix}"
DRIFT_UC_TABLE = f"{uc_location}.drifted_features_{suffix}"

print(f"Baseline UC table: {BASELINE_UC_TABLE}\nDrift UC table: {DRIFT_UC_TABLE}")

In [0]:
#spark.sql(f"DROP TABLE IF EXISTS {DRIFT_UC_TABLE}")

Generate a baseline training dataset

In [0]:
def generate_dataset(n_samples, centers, n_features, timestamp, random_state=42, cluster_std=1):
  """
  Build a synthetic clustered feature snapshot as a Spark DataFrame.

  Points are drawn with make_blobs; every row is then stamped with the same
  `timestamp` and given a freshly generated UUID string as `customer_id`.

  Args:
    n_samples: forwarded to make_blobs as n_samples.
    centers: forwarded to make_blobs as centers.
    n_features: number of feature columns (named feature_0 .. feature_{n_features-1}).
    timestamp: value passed to pd.Timestamp and written to every row.
    random_state: forwarded to make_blobs as random_state.
    cluster_std: forwarded to make_blobs as cluster_std.

  Returns:
    Spark DataFrame with columns customer_id, timestamp, the feature columns
    and the "cluster" label, in that order.
  """
  feature_names = [f'feature_{idx}' for idx in range(n_features)]
  label_name = "cluster"

  points, labels = make_blobs(
    n_samples=n_samples,
    n_features=n_features,
    centers=centers,
    cluster_std=cluster_std,
    random_state=random_state,
  )

  # Assemble features and label side by side, then attach the metadata columns.
  snapshot = pd.DataFrame(points, columns=feature_names)
  snapshot[label_name] = labels
  snapshot['timestamp'] = pd.Timestamp(timestamp)
  snapshot['customer_id'] = [str(uuid.uuid4()) for _ in range(len(snapshot))]

  # Put the identifier and timestamp first, label last.
  ordered_columns = ["customer_id", "timestamp", *feature_names, label_name]
  return spark.createDataFrame(snapshot[ordered_columns])

In [0]:
# Generate the baseline snapshot, all rows stamped with TRAIN_TIMESTAMP.
baseline_features = generate_dataset(
    n_samples=N_SAMPLES,
    centers=CENTERS,
    n_features=N_FEATURES,
    timestamp=TRAIN_TIMESTAMP,
)

# Persist it (replacing any previous contents), then read it back so the
# DataFrame used from here on is backed by the saved table.
(
    baseline_features.write
    .mode('overwrite')
    .saveAsTable(BASELINE_UC_TABLE)
)
baseline_features = spark.table(BASELINE_UC_TABLE)
display(baseline_features)

In [0]:
# Row count of the persisted baseline table (shown as the cell's result).
baseline_features.count()

Generate drifted features over the timestamp snapshots

In [0]:
def generate_drift(
    baseline_df,
    timestamp,
    drift_columns=None,
    loc=2.0,
    scale=2.0,
    cluster_std=2.0,
    random_state=42,
):
    """
    Return a drifted copy of `baseline_df`, re-stamped with `timestamp`.

    Every row (and its customer_id) is kept; only the values in
    `drift_columns` are changed:
    - One shift vector per cluster is drawn from Normal(loc, scale) and added
      to all rows of that cluster.
    - Independent row-level noise drawn from Normal(0, cluster_std) is added
      on top.
    - The result can be fed back in as `baseline_df` for the next snapshot.

    Args:
        baseline_df: Spark DataFrame with a "cluster" column, a "customer_id"
            column and the feature columns.
        timestamp: value passed to pd.Timestamp and written to every row.
        drift_columns: columns to perturb; when None or empty, every column
            whose name starts with "feature_" is used.
        loc: mean of the per-cluster shift distribution.
        scale: standard deviation of the per-cluster shift distribution.
        cluster_std: standard deviation of the row-level noise.
        random_state: seed for numpy's default_rng.

    Returns:
        Spark DataFrame with the same columns as `baseline_df`.
    """
    feature_cols = [c for c in baseline_df.columns if c.startswith("feature_")]
    drift_columns = drift_columns or feature_cols

    # Only the distinct cluster labels are needed to draw one shift per
    # cluster. They are sorted so that, for a given random_state, each
    # cluster always receives the same draw.
    cluster_values = (
        baseline_df
        .select("cluster")
        .distinct()
        .orderBy("cluster")
        .toPandas()["cluster"]
        .tolist()
    )

    rng = np.random.default_rng(random_state)
    cluster_shifts = {
        cluster_value: rng.normal(loc=loc, scale=scale, size=len(drift_columns))
        for cluster_value in cluster_values
    }

    # Sort by customer_id so the row order (and therefore the noise assigned
    # to each row) does not depend on Spark's partition ordering.
    baseline_pdf = baseline_df.orderBy("customer_id").toPandas()

    for cluster_value, shift in cluster_shifts.items():
        mask = baseline_pdf["cluster"] == cluster_value
        if mask.any():
            noise = rng.normal(
                loc=0.0,
                scale=cluster_std,
                size=(mask.sum(), len(drift_columns)),
            )
            # `shift` broadcasts across the rows of this cluster.
            baseline_pdf.loc[mask, drift_columns] = (
                baseline_pdf.loc[mask, drift_columns].to_numpy()
                + shift
                + noise
            )

    baseline_pdf["timestamp"] = pd.Timestamp(timestamp)
    return spark.createDataFrame(baseline_pdf)

In [0]:
baseline_features = spark.table(BASELINE_UC_TABLE)

seed = 42
# Each snapshot is generated from the previous one, so drift accumulates
# across timestamps. The first snapshot starts from the baseline table.
current_df = baseline_features
for step, ts in enumerate(FEATURE_DRIFT_TIMESTAMPS):
    drifted_features = generate_drift(
        baseline_df=current_df,
        timestamp=ts,
        loc=2.0,
        scale=2.0,
        cluster_std=2.0,
        drift_columns=['feature_0', 'feature_1', 'feature_2'],
        # A different seed per snapshot so each step draws new shifts/noise.
        random_state=seed + step,
    )
    # NOTE: "append" means re-running this cell adds the same snapshots again;
    # drop DRIFT_UC_TABLE first if you need a clean re-run.
    drifted_features.write.mode("append").saveAsTable(DRIFT_UC_TABLE)
    current_df = drifted_features

For TimeSeries and Inference profiles, it's a best practice to enable change data feed (CDF) on your table. When CDF is enabled, only newly appended data is processed, rather than re-processing the entire table every refresh. This makes execution more efficient and reduces costs as you scale across many tables.

In [0]:
# Turn on Delta change data feed for the drift table.
spark.sql(
    f"ALTER TABLE {DRIFT_UC_TABLE} "
    "SET TBLPROPERTIES ('delta.enableChangeDataFeed' = 'true')"
)

In [0]:
# Reload the drift table and show how many rows each snapshot timestamp holds.
drifted_features = spark.table(DRIFT_UC_TABLE)
rows_per_snapshot = drifted_features.groupBy("timestamp").count()
display(rows_per_snapshot.orderBy("timestamp"))

Create a timeseries profile monitor

In [0]:
# Drop the profile-metrics table for the drift table if it already exists.
# The name is derived from DRIFT_UC_TABLE (i.e. from the notebook widgets)
# rather than being hard-coded to one catalog/schema/suffix.
spark.sql(f"DROP TABLE IF EXISTS {DRIFT_UC_TABLE}_profile_metrics")

In [0]:
w = WorkspaceClient()
schema = w.schemas.get(full_name=f"{catalog_name}.{schema_name}")
table = w.tables.get(full_name=DRIFT_UC_TABLE)

# Keep the monitor's assets under the workspace folder of whoever runs the
# notebook instead of a hard-coded user path.
current_user_name = w.current_user.me().user_name
assets_dir = (
    f"/Workspace/Users/{current_user_name}"
    f"/mlops_workshop/04_lakehouse_monitor/{DRIFT_UC_TABLE}"
)

# Time-series profile: daily windows keyed on the "timestamp" column, with
# drift measured against the baseline table. Output goes to the same schema
# that holds the monitored table.
config = DataProfilingConfig(
    output_schema_id=schema.schema_id,
    assets_dir=assets_dir,
    time_series=TimeSeriesConfig(
        timestamp_column="timestamp",
        granularities=[AggregationGranularity.AGGREGATION_GRANULARITY_1_DAY],
    ),
    baseline_table_name=BASELINE_UC_TABLE,
)

info = w.data_quality.create_monitor(
    monitor=Monitor(
        object_type="table",  # object_type is always "table" for data profiling
        object_id=table.table_id,
        data_profiling_config=config,
    ),
)

In [0]:
# Look up the monitor attached to the drift table, addressed by the table's id.
w.data_quality.get_monitor(object_type="table", object_id=table.table_id)

In [0]:
# Materialise the refreshes listed for the drift table's monitor into a list
# so they can be indexed in later cells, then show them.
refreshes = list(w.data_quality.list_refresh(object_type="table", object_id=table.table_id))
refreshes

In [0]:
# Fetch the details of the first listed refresh.
# NOTE: refreshes[0] raises IndexError if no refresh has been listed yet.
w.data_quality.get_refresh(object_type="table", object_id=table.table_id, refresh_id = refreshes[0].refresh_id)

In [0]:
import time

# A metric refresh will automatically be triggered on creation
it = w.data_quality.list_refresh(object_type="table", object_id=table.table_id)

# Take the first listed refresh; without one there is nothing to wait on.
run_info = next(it, None)
if run_info is None:
    raise RuntimeError("No refresh found for the monitor; trigger one before waiting on it")

# Poll every 10 seconds while the refresh is still pending or running.
# Sleeping before each poll avoids an extra wait once a terminal state is seen.
in_progress_states = (
    RefreshState.MONITOR_REFRESH_STATE_PENDING,
    RefreshState.MONITOR_REFRESH_STATE_RUNNING,
)
while run_info.state in in_progress_states:
    time.sleep(10)
    run_info = w.data_quality.get_refresh(
        object_type="table",
        object_id=table.table_id,
        refresh_id=run_info.refresh_id,
    )

assert run_info.state == RefreshState.MONITOR_REFRESH_STATE_SUCCESS, "Monitor refresh failed"

In [0]:
# Re-fetch the first refresh captured in `refreshes` to see its current details.
# NOTE: refreshes[0] raises IndexError if that list is empty.
w.data_quality.get_refresh(object_type="table", object_id=table.table_id, refresh_id = refreshes[0].refresh_id)

Trigger a refresh

In [0]:
#run_info =  w.data_quality.create_refresh(object_type="table", object_id=table.table_id, refresh=Refresh(
#   object_type="table",
#   object_id=table.table_id,
# )
#)

#### View Lakehouse Monitor tables

#### Profile metrics

In [0]:
# Name of the profile-metrics table the monitor writes for the drift table.
PROFILE_METRICS_TABLE = (w.quality_monitors.get(table_name=DRIFT_UC_TABLE)
                                           .profile_metrics_table_name)

print(PROFILE_METRICS_TABLE)

profile_metrics_df = spark.table(PROFILE_METRICS_TABLE)

# Every monitored column except the identifier and timestamp columns.
# The loop variable is deliberately not called `col`, so it cannot be confused
# with pyspark's `col` function used in the filter below.
features_and_prediction = [
    column_name
    for column_name in spark.table(DRIFT_UC_TABLE).columns
    if column_name not in ("customer_id", "timestamp")
]

# Summary statistics for the INPUT rows of those columns. Each metric is
# selected once (the original list repeated "avg").
display(
    profile_metrics_df
    .filter((col("log_type") == "INPUT") & (col("column_name").isin(features_and_prediction)))
    .select(["window", "log_type", "granularity", "column_name", "count",
             "data_type", "num_nulls", "avg", "median", "quantiles", "min", "max",
             "num_zeros", "num_nan", "percent_nan", "percent_null", "percent_distinct"])
)

In [0]:
# Number of profile-metric rows for each log_type value.
display(profile_metrics_df.groupBy("log_type").agg(func.count("*")))

In [0]:
# Profile-metric rows whose log_type is "BASELINE".
display(profile_metrics_df.filter(col("log_type") == "BASELINE"))

In [0]:
# Number of profile-metric rows for each window.
display(profile_metrics_df.groupBy("window").agg(func.count("*")))

#### Drift metrics overview
**KS Test**: A non-parametric test that measures the maximum distance between the cumulative distribution functions of two distributions. It returns a statistic and a p-value; a statistic close to 0 means the distributions are similar, while a statistic close to 1 means they are very different. This interpretation assumes a small p-value (<0.05). This statistic expects continuous variables.

**Wasserstein Distance**: Measures the minimal amount of work needed to transform one distribution into another (how much mass must be moved and how far?). Identical distributions have a value of 0.0 while higher values indicate larger differences. Its value is always >= 0. This statistic expects continuous variables.

**Population Stability Index (PSI)**: Compares the binned distributions of a feature across two datasets. Roughly speaking, values of < 0.1 mean no significant drift, >= 0.1 and <= 0.25 mean moderate drift, and > 0.25 mean significant drift. PSI works well for detecting changes in the frequency distribution of categorical variables.


Below results:
 - The high population stability index score (5.02) on the cluster's column indicates **significant drift** in the predicted cluster distribution.  
 - For feature columns:
   - The **KS Test** indicates moderate difference in data distributions compared to baseline.
   - **Wasserstein Distance** is frequently above 1; a strong signal for drift.
   - The **PSI** is typically > 0.25, indicating significant drift.

In [0]:
# Name of the drift-metrics table the monitor writes for the drift table.
drift_monitor_info = w.quality_monitors.get(table_name=DRIFT_UC_TABLE)
DRIFT_METRICS_TABLE = drift_monitor_info.drift_metrics_table_name

print(DRIFT_METRICS_TABLE)

# NOTE: from here on `profile_metrics_df` holds the *drift* metrics table;
# the cells below rely on that name.
profile_metrics_df = spark.table(DRIFT_METRICS_TABLE)

# Every monitored column except the identifier and timestamp columns.
excluded_columns = ["customer_id", "timestamp"]
features_and_prediction = [
    column_name
    for column_name in spark.table(DRIFT_UC_TABLE).columns
    if column_name not in excluded_columns
]

# Drift statistics for those columns, ordered by column name.
drift_view = (
    profile_metrics_df
    .filter(col("column_name").isin(features_and_prediction))
    .select(["window", "granularity", "column_name", "data_type", "drift_type",
             "ks_test", "wasserstein_distance", "population_stability_index"])
    .orderBy("column_name")
)
display(drift_view)

In [0]:
# Number of drift-metric rows for each drift_type value
# (profile_metrics_df was rebound to the drift metrics table in the cell above).
display(profile_metrics_df.groupBy("drift_type").agg(func.count("*")))

In [0]:
# Number of drift-metric rows for each window
# (profile_metrics_df was rebound to the drift metrics table in an earlier cell).
display(profile_metrics_df.groupBy("window").agg(func.count("*")))

#### Create a monitor query
Return rows that indicate drift

In [0]:
# Rows from the most recent window whose drift against the baseline crosses
# the thresholds below. The table name comes from DRIFT_METRICS_TABLE instead
# of a hard-coded catalog/schema, and the string literal uses single quotes
# so it is never parsed as an identifier.
drift_alert_query = f"""
SELECT window,
       column_name,
       data_type,
       drift_type,
       count_delta,
       ks_test,
       wasserstein_distance,
       population_stability_index
FROM {DRIFT_METRICS_TABLE}
WHERE drift_type = 'BASELINE' AND
      window.end = (SELECT MAX(window.end) FROM {DRIFT_METRICS_TABLE}) AND
      ((ks_test.pvalue < 0.05 AND ks_test.statistic > 0.2) OR wasserstein_distance >= 0.1)
"""

# Print the rendered SQL so it can be copied into a SQL alert as-is.
print(drift_alert_query)
display(spark.sql(drift_alert_query))

#### After model retraining, recalculate Lakehouse Monitor's metrics against a new baseline table.
 - Either overwrite the monitor's baseline table or update it to a different table
 - Continue to write predictions to the same inference table
 - See update monitor [documentation](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/dataquality/data_quality.html#databricks.sdk.service.dataquality.DataQualityAPI.update_monitor)

Delete a monitor

In [0]:
#w = WorkspaceClient()
#table = w.tables.get(full_name=DRIFT_UC_TABLE)
#w.data_quality.delete_monitor(object_type="table", object_id=table.table_id)