# 05 — Train + deploy ML model (gap risk) with MLflow + Unity Catalog

Trains a shift-level model that predicts whether a shift will have a staffing gap, registers it in the Unity Catalog model registry, and writes batch predictions back to a gold table.


In [None]:
%pip install mlflow==2.12.2 databricks-sdk==0.28.0


In [None]:
# Configuration (Databricks widgets)
# These widgets make the demo portable across workspaces/accounts.
# If you're running this outside a Databricks notebook, it will fall back to defaults.

DEFAULT_CATALOG = "rtpa_catalog"
DEFAULT_SCHEMA_REF = "credentialing_ref"
DEFAULT_SCHEMA_BRONZE = "credentialing_bronze"
DEFAULT_SCHEMA_SILVER = "credentialing_silver"
DEFAULT_SCHEMA_GOLD = "credentialing_gold"

DEFAULT_N_PROVIDERS = 200
DEFAULT_DAYS_SCHEDULE = 14
DEFAULT_SEED = 42


def _widget_text(name: str, default, label: str):
    """Create the text widget `name` and return its value, or `default`.

    `default` is returned when widgets are unavailable (for example `dbutils`
    is not defined outside a Databricks notebook) or when the widget is blank.
    """
    try:
        dbutils.widgets.text(name, str(default), label)
        value = dbutils.widgets.get(name)
    except Exception:
        # Deliberately broad: any failure to reach the widget API (including
        # NameError for a missing `dbutils`) means "use the default".
        return default
    # Whitespace-only input is treated like an empty widget.
    return (value or "").strip() or default


def _widget_int(name: str, default: int, label: str) -> int:
    """Return the integer value of widget `name`, or `default` if it is not an integer.

    Each numeric widget falls back on its own, so one malformed number does
    not discard the catalog/schema values read from the other widgets.
    """
    raw = _widget_text(name, default, label)
    try:
        return int(raw)
    except (TypeError, ValueError):
        print(f"Widget '{name}' value {raw!r} is not an integer; using default {default}.")
        return default


catalog = _widget_text("catalog", DEFAULT_CATALOG, "Catalog")
schema_ref = _widget_text("schema_ref", DEFAULT_SCHEMA_REF, "Schema (ref)")
schema_bronze = _widget_text("schema_bronze", DEFAULT_SCHEMA_BRONZE, "Schema (bronze)")
schema_silver = _widget_text("schema_silver", DEFAULT_SCHEMA_SILVER, "Schema (silver)")
schema_gold = _widget_text("schema_gold", DEFAULT_SCHEMA_GOLD, "Schema (gold)")

N_PROVIDERS = _widget_int("n_providers", DEFAULT_N_PROVIDERS, "N providers")
DAYS_SCHEDULE = _widget_int("days_schedule", DEFAULT_DAYS_SCHEDULE, "Days schedule")
SEED = _widget_int("seed", DEFAULT_SEED, "Random seed")


# Derived helpers
def fq(sch: str, tbl: str) -> str:
    """Return the fully qualified `<catalog>.<schema>.<table>` name for the configured catalog."""
    return f"{catalog}.{sch}.{tbl}"


# Model registry name (Unity Catalog 3-level name: <catalog>.<schema>.<model>)
MODEL_NAME = f"{catalog}.{schema_gold}.shift_gap_risk_model"

# Where we publish predictions
PRED_TABLE = fq(schema_gold, "shift_gap_predictions")


In [None]:
# Unity Catalog bootstrap (you may need permissions to create catalogs/schemas)
# Build the DDL first, then run it in order: catalog, USE CATALOG, one schema per layer.
bootstrap_sql = [
    f"CREATE CATALOG IF NOT EXISTS {catalog}",
    f"USE CATALOG {catalog}",
]
bootstrap_sql.extend(
    f"CREATE SCHEMA IF NOT EXISTS {catalog}.{layer_schema}"
    for layer_schema in (schema_ref, schema_bronze, schema_silver, schema_gold)
)

for ddl in bootstrap_sql:
    spark.sql(ddl)


## Build training set

We train on `gold.staffing_gaps` (label = gap_count > 0) plus joined reference attributes, with time-derived features.


In [None]:
from pyspark.sql import functions as F

# Shift-level gap rows from the gold schema; `gap_count` drives the label below.
gaps = spark.read.table(fq(schema_gold, "staffing_gaps"))

# Procedure reference flags, with the key renamed to match the join column on `gaps`.
procedure_key = F.col("procedure_code").alias("required_procedure_code")
ref_proc = spark.read.table(fq(schema_ref, "procedure")).select(
    procedure_key,
    "requires_privilege",
    "requires_acls",
)

# Left join so shifts whose procedure is missing from the reference table are kept.
joined = gaps.join(ref_proc, "required_procedure_code", "left")

# Calendar features derived from the shift start timestamp.
with_calendar = (
    joined
      .withColumn("shift_date", F.to_date("start_ts"))
      .withColumn("dow", F.dayofweek("start_ts"))
      .withColumn("hour", F.hour("start_ts"))
)

# Spark's dayofweek numbers Sunday as 1 and Saturday as 7.
weekend_flag = F.when(F.col("dow").isin([1, 7]), F.lit(1)).otherwise(F.lit(0))
# Days between today and the shift date (negative once the shift date has passed).
lead_days = F.datediff(F.to_date("start_ts"), F.current_date())
# Binary label: 1 when the shift has at least one unfilled slot.
gap_label = F.when(F.col("gap_count") > 0, F.lit(1)).otherwise(F.lit(0))

base = (
    with_calendar
      .withColumn("is_weekend", weekend_flag)
      .withColumn("days_to_shift", lead_days)
      .withColumn("label", gap_label)
)

# Basic sanity
preview_cols = [
    "shift_id",
    "facility_id",
    "required_procedure_code",
    "required_count",
    "assigned_count",
    "eligible_provider_count",
    "gap_count",
    "label",
]
display(base.select(*preview_cols).limit(10))


## Train a Spark ML pipeline

Uses categorical encodings + a classifier, logs to MLflow, and registers the model to Unity Catalog.


In [None]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
import mlflow
import mlflow.spark
from mlflow.models import infer_signature

# Ensure we use Unity Catalog model registry
mlflow.set_registry_uri("databricks-uc")

# Keep only rows with the minimal required fields
data = (
    base
      .select(
          "label",
          "facility_id",
          "required_procedure_code",
          "required_count",
          "assigned_count",
          "eligible_provider_count",
          "requires_privilege",
          "requires_acls",
          "dow",
          "hour",
          "is_weekend",
          "days_to_shift",
          "shift_id",
          "shift_date"
      )
      # Shifts with no matching reference row get False for both flags.
      .fillna({"requires_privilege": False, "requires_acls": False})
      # The assembler consumes the integer-cast copies (`*_i`), not the raw booleans.
      .withColumn("requires_privilege_i", F.col("requires_privilege").cast("int"))
      .withColumn("requires_acls_i", F.col("requires_acls").cast("int"))
)

# Seeded split so the train/test partition is repeatable for a given SEED.
train_df, test_df = data.randomSplit([0.8, 0.2], seed=SEED)

# Categorical
# handleInvalid="keep" puts values unseen at fit time into an extra bucket instead of failing.
categorical_cols = ["facility_id", "required_procedure_code"]
facility_indexer = StringIndexer(inputCol="facility_id", outputCol="facility_id_idx", handleInvalid="keep")
proc_indexer = StringIndexer(inputCol="required_procedure_code", outputCol="proc_code_idx", handleInvalid="keep")
encoder = OneHotEncoder(
    inputCols=["facility_id_idx", "proc_code_idx"],
    outputCols=["facility_ohe", "proc_ohe"]
)

# Numeric
numeric_cols = [
    "required_count",
    "assigned_count",
    "eligible_provider_count",
    "requires_privilege_i",
    "requires_acls_i",
    "dow",
    "hour",
    "is_weekend",
    "days_to_shift"
]

assembler = VectorAssembler(
    inputCols=["facility_ohe", "proc_ohe"] + numeric_cols,
    outputCol="features"
)

# Classifier: Logistic regression (fast, interpretable)
# elasticNetParam=0.0 makes the penalty pure L2.
clf = LogisticRegression(
    featuresCol="features",
    labelCol="label",
    maxIter=50,
    regParam=0.1,
    elasticNetParam=0.0
)

# The logged model contains every preprocessing stage, so callers pass raw columns.
pipeline = Pipeline(stages=[facility_indexer, proc_indexer, encoder, assembler, clf])

mlflow.spark.autolog()

with mlflow.start_run(run_name="shift_gap_risk_train"):
    model = pipeline.fit(train_df)
    pred = model.transform(test_df)

    evaluator = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
    auc = evaluator.evaluate(pred)

    # Basic accuracy
    scored = pred.select("label", F.col("prediction").cast("int").alias("pred"))
    # max(1, ...) avoids a division by zero when the test split is empty.
    acc = scored.filter(F.col("label") == F.col("pred")).count() / max(1, scored.count())

    mlflow.log_metric("test_auc", float(auc))
    mlflow.log_metric("test_accuracy", float(acc))

    # Unity Catalog refuses to register a model that was logged without a
    # signature. Describe the raw input columns the pipeline reads and the
    # `prediction` output column.
    signature = infer_signature(
        train_df.select(*categorical_cols, *numeric_cols),
        pred.select("prediction"),
    )

    # Log + register to Unity Catalog (3-level name)
    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="model",
        registered_model_name=MODEL_NAME,
        signature=signature
    )

print(f"AUC={auc:.4f}, accuracy={acc:.4f}")
display(pred.select("label", "prediction", "probability").limit(20))


## Set a ‘Champion’ alias (optional)

This makes it easier to deploy consistently (you can reference the model by alias in downstream jobs).


In [None]:
from mlflow.tracking import MlflowClient

# Registry client used to look up versions and manage aliases for MODEL_NAME.
client = MlflowClient()

# Get the newest version (by version number); None when nothing is registered.
versions = client.search_model_versions(f"name='{MODEL_NAME}'")
latest = max((int(mv.version) for mv in versions), default=None)
print(f"Latest registered version for {MODEL_NAME}: {latest}")

# Set alias if supported / permitted. This is best-effort: a failure is
# reported (truncated to 500 characters) and the notebook carries on.
if latest is not None:
    try:
        client.set_registered_model_alias(MODEL_NAME, "Champion", str(latest))
    except Exception as alias_error:
        print("Could not set alias (permissions or feature availability). Continuing.")
        print(str(alias_error)[:500])
    else:
        print("Alias set: Champion")


## Batch deployment: write predictions table

Loads the registered model and scores all shifts in `gold.staffing_gaps`, writing results to `gold.shift_gap_predictions` for dashboards/apps.


In [None]:
# Choose a stable model URI: prefer alias if available, else use latest version
model_uri = f"models:/{MODEL_NAME}@Champion"

try:
    loaded = mlflow.spark.load_model(model_uri)
    print(f"Loaded model by alias: {model_uri}")
except Exception as alias_error:
    # Fall back to latest numeric version.
    # The alias cell is optional, so build a registry client here instead of
    # relying on a `client` variable that may never have been defined.
    from mlflow.tracking import MlflowClient

    print(f"Could not load {model_uri}; falling back to latest version. Reason: {str(alias_error)[:300]}")
    versions = MlflowClient().search_model_versions(f"name='{MODEL_NAME}'")
    if not versions:
        raise RuntimeError(
            f"No registered versions found for {MODEL_NAME}; run the training cell first."
        ) from alias_error
    latest = max(int(v.version) for v in versions)
    model_uri = f"models:/{MODEL_NAME}/{latest}"
    loaded = mlflow.spark.load_model(model_uri)
    print(f"Loaded model by version: {model_uri}")

# Rebuild the same feature columns used for training
to_score = (
    spark.read.table(fq(schema_gold, "staffing_gaps"))
      .join(
          spark.read.table(fq(schema_ref, "procedure")).select(
              F.col("procedure_code").alias("required_procedure_code"),
              "requires_privilege",
              "requires_acls"
          ),
          "required_procedure_code",
          "left"
      )
      .withColumn("shift_date", F.to_date("start_ts"))
      .withColumn("dow", F.dayofweek("start_ts"))
      .withColumn("hour", F.hour("start_ts"))
      .withColumn("is_weekend", F.when(F.col("dow").isin([1, 7]), F.lit(1)).otherwise(F.lit(0)))
      .withColumn("days_to_shift", F.datediff(F.to_date("start_ts"), F.current_date()))
      .fillna({"requires_privilege": False, "requires_acls": False})
      .withColumn("requires_privilege_i", F.col("requires_privilege").cast("int"))
      .withColumn("requires_acls_i", F.col("requires_acls").cast("int"))
)

# The logged Spark ML pipeline includes the indexers/encoder/assembler, so it
# needs the exact input column names used at fit time. The assembler reads
# `requires_privilege_i` / `requires_acls_i`, so those names must NOT be
# renamed here; renaming them makes `transform` fail on missing columns.
scored = loaded.transform(
    to_score.select(
        "facility_id",
        "required_procedure_code",
        "required_count",
        "assigned_count",
        "eligible_provider_count",
        "requires_privilege_i",
        "requires_acls_i",
        "dow",
        "hour",
        "is_weekend",
        "days_to_shift",
        "shift_id",
        "shift_date",
        "gap_count",
        "risk_level",
        "risk_reason"
    )
)

# Extract probability of class 1 (gap)
pred_out = (
    scored
      .withColumn("predicted_is_gap", F.col("prediction").cast("int"))
      .withColumn("predicted_gap_prob", F.col("probability").getItem(1))
      .withColumn("scored_at", F.current_timestamp())
      .select(
          "shift_id",
          "shift_date",
          "facility_id",
          "required_procedure_code",
          "required_count",
          "assigned_count",
          "eligible_provider_count",
          "gap_count",
          "risk_level",
          "risk_reason",
          "predicted_is_gap",
          "predicted_gap_prob",
          "scored_at"
      )
)

# Full refresh: each run replaces the table contents (and schema, if it changed).
pred_out.write.format("delta").mode("overwrite").option("overwriteSchema", "true").saveAsTable(PRED_TABLE)
print(f"Wrote predictions: {PRED_TABLE}")
display(pred_out.orderBy(F.desc("predicted_gap_prob")).limit(25))


## Optional: create a Model Serving endpoint

If you have permissions, you can deploy the registered model as a real-time endpoint using the MLflow deployments client for Databricks (`mlflow.deployments.get_deploy_client("databricks")`).


In [None]:
# Opt-in switch: endpoint creation is skipped unless this is set to True.
DO_CREATE_SERVING_ENDPOINT = False
SERVING_ENDPOINT_NAME = "shift-gap-risk-endpoint"

if DO_CREATE_SERVING_ENDPOINT:
    import mlflow
    from mlflow.deployments import get_deploy_client

    mlflow.set_registry_uri("databricks-uc")
    # Use a dedicated name so the registry `client` (MlflowClient) used by
    # other cells is not overwritten with a deployments client.
    deploy_client = get_deploy_client("databricks")

    # `latest` is only defined if an earlier cell resolved a model version;
    # fall back to version "1" when it is missing or None.
    try:
        serving_version = latest
    except NameError:
        serving_version = None

    # Note: config details vary by workspace; this is a minimal example.
    # Requires permissions for model serving.
    endpoint = deploy_client.create_endpoint(
        name=SERVING_ENDPOINT_NAME,
        config={
            "served_entities": [
                {
                    "entity_name": MODEL_NAME,
                    "entity_version": str(serving_version) if serving_version is not None else "1",
                    "workload_size": "Small",
                    "scale_to_zero_enabled": True
                }
            ]
        }
    )
    print(endpoint)
else:
    print("Skipping endpoint creation. Set DO_CREATE_SERVING_ENDPOINT=True to attempt deployment.")
