# 04 — Build silver current-state + gold analytics

Builds silver current-state tables and gold analytics tables (provider 360, staffing gaps, shift recommendations, credential risk, daily KPIs, risk actions, nurse staffing summary, and nurse credential gaps).


In [None]:
%pip install faker==25.2.0


In [None]:
# Configuration (Databricks widgets)
# These widgets make the demo portable across workspaces/accounts.
# If you're running this outside a Databricks notebook, it will fall back to defaults.

DEFAULT_CATALOG = "staffing_catalog"
DEFAULT_SCHEMA_REF = "credentialing_ref"
DEFAULT_SCHEMA_BRONZE = "credentialing_bronze"
DEFAULT_SCHEMA_SILVER = "credentialing_silver"
DEFAULT_SCHEMA_GOLD = "credentialing_gold"

DEFAULT_N_PROVIDERS = 200
DEFAULT_DAYS_SCHEDULE = 14
DEFAULT_SEED = 42

# (widget name, default text, label) for every configurable value, in display order.
_WIDGET_SPECS = [
    ("catalog", DEFAULT_CATALOG, "Catalog"),
    ("schema_ref", DEFAULT_SCHEMA_REF, "Schema (ref)"),
    ("schema_bronze", DEFAULT_SCHEMA_BRONZE, "Schema (bronze)"),
    ("schema_silver", DEFAULT_SCHEMA_SILVER, "Schema (silver)"),
    ("schema_gold", DEFAULT_SCHEMA_GOLD, "Schema (gold)"),
    ("n_providers", str(DEFAULT_N_PROVIDERS), "N providers"),
    ("days_schedule", str(DEFAULT_DAYS_SCHEDULE), "Days schedule"),
    ("seed", str(DEFAULT_SEED), "Random seed"),
]

# Only the dbutils interaction is guarded: outside a Databricks notebook `dbutils`
# is undefined (NameError), and in that case every setting uses its default.
# Value parsing happens below, outside this guard, so a bad value is reported
# instead of silently reverting *all* settings (including catalog/schemas).
try:
    for _name, _default, _label in _WIDGET_SPECS:
        dbutils.widgets.text(_name, _default, _label)
    _widget_values = {_name: dbutils.widgets.get(_name) for _name, _, _ in _WIDGET_SPECS}
except Exception:  # deliberate best-effort: no dbutils / widgets unavailable here
    _widget_values = {}


def _widget_str(name: str, default: str) -> str:
    """Return the whitespace-trimmed widget value for *name*, or *default* if blank/unset."""
    return (_widget_values.get(name) or "").strip() or default


def _widget_int(name: str, default: int) -> int:
    """Return the widget value for *name* as an int, or *default* if blank/unset.

    Raises:
        ValueError: if the widget holds text that is not an integer.
    """
    raw = _widget_str(name, "")
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(f"Widget {name!r} must be an integer, got {raw!r}") from err


catalog = _widget_str("catalog", DEFAULT_CATALOG)
schema_ref = _widget_str("schema_ref", DEFAULT_SCHEMA_REF)
schema_bronze = _widget_str("schema_bronze", DEFAULT_SCHEMA_BRONZE)
schema_silver = _widget_str("schema_silver", DEFAULT_SCHEMA_SILVER)
schema_gold = _widget_str("schema_gold", DEFAULT_SCHEMA_GOLD)

N_PROVIDERS = _widget_int("n_providers", DEFAULT_N_PROVIDERS)
DAYS_SCHEDULE = _widget_int("days_schedule", DEFAULT_DAYS_SCHEDULE)
SEED = _widget_int("seed", DEFAULT_SEED)


# Derived helpers
def fq(sch: str, tbl: str) -> str:
    """Return the fully-qualified ``catalog.schema.table`` name.

    Reads the module-level ``catalog`` at call time, matching the original lambda.
    """
    return f"{catalog}.{sch}.{tbl}"


In [None]:
# Unity Catalog bootstrap (you may need permissions to create catalogs/schemas).
# Statements run in order: create the catalog, switch to it, then create each schema.
_uc_bootstrap_statements = [
    f"CREATE CATALOG IF NOT EXISTS {catalog}",
    f"USE CATALOG {catalog}",
    *(
        f"CREATE SCHEMA IF NOT EXISTS {catalog}.{sch}"
        for sch in (schema_ref, schema_bronze, schema_silver, schema_gold)
    ),
]
for _statement in _uc_bootstrap_statements:
    spark.sql(_statement)


## Silver current-state
`silver.current_credential` keeps the latest record per provider + cred_type (ranked by `verified_at` descending with nulls last, then `ingested_at` descending) and adds `days_until_expiration` relative to the build date.


In [None]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from datetime import datetime

def _overwrite_table(df, schema_name, table_name):
    """Replace the Delta table ``<catalog>.<schema_name>.<table_name>`` (data and schema) with *df*."""
    (
        df.write.format("delta")
          .mode("overwrite")
          .option("overwriteSchema", "true")
          .saveAsTable(fq(schema_name, table_name))
    )


# Bronze + reference inputs (several are only consumed by later cells).
provider = spark.read.table(fq(schema_bronze, "provider_raw"))
cred_evt = spark.read.table(fq(schema_bronze, "credential_event_raw"))
priv_raw = spark.read.table(fq(schema_bronze, "privilege_raw"))
enr_raw = spark.read.table(fq(schema_bronze, "payer_enrollment_raw"))
shift = spark.read.table(fq(schema_bronze, "shift_raw"))
assign = spark.read.table(fq(schema_bronze, "assignment_raw"))

ref_facility = spark.read.table(fq(schema_ref, "facility"))
ref_procedure = spark.read.table(fq(schema_ref, "procedure"))

# Newest verification wins (rows lacking verified_at sort last); ties go to the
# most recently ingested event.
cred_recency_window = Window.partitionBy("provider_id", "cred_type").orderBy(
    F.col("verified_at").desc_nulls_last(),
    F.col("ingested_at").desc(),
)

ranked_cred_evt = cred_evt.withColumn("rn", F.row_number().over(cred_recency_window))
current_credential = (
    ranked_cred_evt
      .where(F.col("rn") == 1)
      .drop("rn")
      .withColumn("days_until_expiration", F.datediff(F.to_date("expires_at"), F.current_date()))
)
_overwrite_table(current_credential, schema_silver, "current_credential")

# Demo simplification: privileges and payer enrollments are copied from bronze as-is.
for _source_df, _target_table in (
    (priv_raw, "current_privilege"),
    (enr_raw, "current_payer_enrollment"),
):
    _overwrite_table(_source_df, schema_silver, _target_table)

# Nurse assignment current state: demo data uses fixed dates, so "current" means
# the most recent assignment_date among ASSIGNED rows.
nurse_assign_raw = spark.read.table(fq(schema_bronze, "nurse_assignment_raw"))
assigned_nurse_rows = nurse_assign_raw.where(F.col("assignment_status") == "ASSIGNED")
latest_date = assigned_nurse_rows.agg(F.max("assignment_date")).first()[0]
nurse_assign_current = assigned_nurse_rows.where(F.col("assignment_date") == latest_date)
_overwrite_table(nurse_assign_current, schema_silver, "nurse_assignment_current")
print(f"Using assignment date: {latest_date}, rows: {nurse_assign_current.count()}")


## Gold: provider_360_flat
Provider row enriched with facility name and summarized credential/privilege/enrollment fields.


In [None]:
# Silver inputs for the provider 360 roll-up.
current_cred = spark.read.table(fq(schema_silver, "current_credential"))
current_priv = spark.read.table(fq(schema_silver, "current_privilege"))
current_enr = spark.read.table(fq(schema_silver, "current_payer_enrollment"))


def _cred_columns(cred_type, status_alias, days_alias):
    """Return provider_id plus renamed status / days-left columns for one credential type."""
    only_this_type = current_cred.where(F.col("cred_type") == cred_type)
    return only_this_type.select(
        "provider_id",
        F.col("cred_status").alias(status_alias),
        F.col("days_until_expiration").alias(days_alias),
    )


lic = _cred_columns("STATE_MED_LICENSE", "state_license_status", "state_license_days_left")
acls = _cred_columns("ACLS", "acls_status", "acls_days_left")

# Per-provider counts over ACTIVE privileges and ACTIVE payer enrollments.
priv_rollup = (
    current_priv
      .where(F.col("privilege_status") == "ACTIVE")
      .groupBy("provider_id")
      .agg(
          F.count("*").alias("active_privilege_count"),
          F.countDistinct("facility_id").alias("active_privilege_facility_count"),
      )
)
payer_rollup = (
    current_enr
      .where(F.col("enrollment_status") == "ACTIVE")
      .groupBy("provider_id")
      .agg(F.countDistinct("payer_id").alias("active_payer_count"))
)

# Unit reference, used to resolve primary_unit_name.
ref_unit = spark.read.table(fq(schema_ref, "unit"))
home_facility_names = ref_facility.select("facility_id", F.col("facility_name").alias("home_facility_name"))
primary_unit_names = ref_unit.select("unit_id", F.col("unit_name").alias("primary_unit_name"))

provider_360 = (
    provider
      .join(home_facility_names, provider.home_facility_id == F.col("facility_id"), "left")
      .drop("facility_id")
      .join(primary_unit_names, provider.primary_unit_id == F.col("unit_id"), "left")
      .drop("unit_id")
)
# Attach each per-provider summary in a fixed order: license, ACLS, privileges, payers.
for _summary_df in (lic, acls, priv_rollup, payer_rollup):
    provider_360 = provider_360.join(_summary_df, "provider_id", "left")

# Providers without active privileges/enrollments get explicit zero counts.
zero_count_defaults = {
    "active_privilege_count": 0,
    "active_privilege_facility_count": 0,
    "active_payer_count": 0,
}
provider_360 = provider_360.fillna(zero_count_defaults).withColumn("last_built_at", F.current_timestamp())

(
    provider_360.write.format("delta")
      .mode("overwrite")
      .option("overwriteSchema", "true")
      .saveAsTable(fq(schema_gold, "provider_360_flat"))
)


## Gold: staffing_gaps + shift_recommendations
Eligibility rules: provider ACTIVE, valid state license, >=1 ACTIVE payer enrollment, plus privilege/ACLS as required by the procedure.


In [None]:
# Eligibility building blocks
# A credential is usable only when its current record is ACTIVE *and* unexpired.
# This matches the cred_status / days_until_expiration rule used for nurse
# certifications in the credential_gaps cell. Checking only the expiration date
# would let a non-ACTIVE credential with a future expires_at count as valid.
p_active = provider.filter(F.col("provider_status") == "ACTIVE").select("provider_id").distinct()


def _providers_with_valid_cred(cred_type):
    """Return distinct provider_ids whose current *cred_type* credential is ACTIVE and not expired."""
    return (
        current_cred
          .filter(F.col("cred_type") == cred_type)
          .filter(F.col("cred_status") == "ACTIVE")
          .filter(F.col("days_until_expiration") >= 0)
          .select("provider_id").distinct()
    )


lic_ok = _providers_with_valid_cred("STATE_MED_LICENSE")
payer_ok = current_enr.filter(F.col("enrollment_status") == "ACTIVE").select("provider_id").distinct()
acls_ok = _providers_with_valid_cred("ACLS")

# Baseline: ACTIVE provider + valid state license + at least one ACTIVE payer enrollment.
base_ok = p_active.join(lic_ok, "provider_id", "inner").join(payer_ok, "provider_id", "inner")

# Shift enrichment: attach procedure requirements and facility name.
shift_req = (
    shift
      .join(
          ref_procedure.select("procedure_code", "procedure_name", "requires_privilege", "requires_acls"),
          shift.required_procedure_code == F.col("procedure_code"),
          "left"
      )
      .drop("procedure_code")
      .join(ref_facility.select("facility_id", "facility_name"), "facility_id", "left")
)

# Candidate provider x shift (small demo sizes make crossJoin acceptable)
cand = shift_req.select(
    "shift_id", "facility_id", "required_procedure_code", "requires_privilege", "requires_acls"
).crossJoin(base_ok)

# Privilege requirement: ACTIVE privilege at the shift's facility for its procedure.
priv_ok = (
    current_priv
      .filter(F.col("privilege_status") == "ACTIVE")
      .select(
          "provider_id",
          "facility_id",
          F.col("procedure_code").alias("required_procedure_code")
      )
      # current_privilege is a straight copy of bronze and may repeat a
      # (provider, facility, procedure) row; de-duplicate so the left join below
      # cannot multiply candidate rows.
      .distinct()
      .withColumn("has_priv", F.lit(1))
)
cand = cand.join(priv_ok, ["provider_id", "facility_id", "required_procedure_code"], "left")

# ACLS requirement
cand = cand.join(acls_ok.withColumn("has_acls", F.lit(1)), "provider_id", "left")

# A null requires_* flag (e.g. procedure missing from the reference table) is treated
# as "not required".
eligible = (
    cand
      .withColumn(
          "eligible",
          F.when((F.col("requires_privilege") == True) & (F.col("has_priv").isNull()), F.lit(False))
           .when((F.col("requires_acls") == True) & (F.col("has_acls").isNull()), F.lit(False))
           .otherwise(F.lit(True))
      )
      .filter(F.col("eligible") == True)
      .select("shift_id", "provider_id")
      # Exactly one row per (shift, provider) so recommendations never repeat a provider.
      .distinct()
)

assigned = (
    assign
      .filter(F.col("assignment_status") == "ASSIGNED")
      .groupBy("shift_id")
      .agg(F.countDistinct("provider_id").alias("assigned_count"))
)
eligible_cnt = eligible.groupBy("shift_id").agg(F.countDistinct("provider_id").alias("eligible_provider_count"))

staffing_gaps = (
    shift_req
      .join(assigned, "shift_id", "left")
      .join(eligible_cnt, "shift_id", "left")
      .fillna({"assigned_count": 0, "eligible_provider_count": 0})
      .withColumn("gap_count", F.greatest(F.col("required_count") - F.col("assigned_count"), F.lit(0)))
      .withColumn(
          "risk_reason",
          F.when(F.col("gap_count") <= 0, F.lit("OK"))
           .when(F.col("eligible_provider_count") == 0, F.lit("No eligible providers"))
           .when(F.col("assigned_count") == 0, F.lit("Unfilled shift"))
           .otherwise(F.lit("Partial coverage"))
      )
      .withColumn(
          "risk_level",
          F.when((F.col("gap_count") > 0) & (F.col("eligible_provider_count") == 0), F.lit("HIGH"))
           .when(F.col("gap_count") >= 2, F.lit("HIGH"))
           .when(F.col("gap_count") == 1, F.lit("MEDIUM"))
           .otherwise(F.lit("LOW"))
      )
      .withColumn("last_built_at", F.current_timestamp())
      .select(
          "shift_id", "facility_id", "facility_name", "start_ts", "end_ts",
          "required_procedure_code", "procedure_name",
          "required_count", "assigned_count", "eligible_provider_count",
          "gap_count", "risk_reason", "risk_level", "last_built_at"
      )
)

staffing_gaps.write.format("delta").mode("overwrite").option("overwriteSchema", "true").saveAsTable(
    fq(schema_gold, "staffing_gaps")
)

# Recommendations: up to 5 eligible providers per shift, in seeded random order.
# NOTE(review): providers already ASSIGNED to the shift are not excluded here — confirm intended.
rank_w = Window.partitionBy("shift_id").orderBy(F.rand(SEED))
shift_recommendations = (
    eligible
      .withColumn("rn", F.row_number().over(rank_w))
      .filter(F.col("rn") <= 5)
      .groupBy("shift_id")
      .agg(F.collect_list("provider_id").alias("recommended_provider_ids"))
)

shift_recommendations.write.format("delta").mode("overwrite").option("overwriteSchema", "true").saveAsTable(
    fq(schema_gold, "shift_recommendations")
)


## Gold: credential_risk + kpi_summary_daily
Buckets each current credential by days until expiration, and writes a simple daily KPI snapshot.


In [None]:
from datetime import datetime, timezone

# Demo assumption: estimated daily revenue at risk for each provider whose state
# license expires within the next 30 days.
REVENUE_AT_RISK_PER_PROVIDER_PER_DAY = 7500.0

# NOTE(review): credentials with a null days_until_expiration (no expires_at) fall
# through to the ">90" bucket — confirm that is the intended treatment.
credential_risk = (
    current_cred
      .withColumn(
          "risk_bucket",
          F.when(F.col("days_until_expiration") < 0, F.lit("EXPIRED"))
           .when(F.col("days_until_expiration") <= 14, F.lit("0-14"))
           .when(F.col("days_until_expiration") <= 30, F.lit("15-30"))
           .when(F.col("days_until_expiration") <= 90, F.lit("31-90"))
           .otherwise(F.lit(">90"))
      )
      .withColumn("last_built_at", F.current_timestamp())
)

credential_risk.write.format("delta").mode("overwrite").option("overwriteSchema", "true").saveAsTable(
    fq(schema_gold, "credential_risk")
)

# Daily KPIs
providers_total = provider.count()
# "Pending" = provider rows not in base_ok (ACTIVE + valid license + active payer).
providers_pending = provider.join(base_ok, "provider_id", "left_anti").count()
providers_expiring_30d = (
    current_cred
      .filter(F.col("cred_type") == "STATE_MED_LICENSE")
      .filter((F.col("days_until_expiration") >= 0) & (F.col("days_until_expiration") <= 30))
      .select("provider_id").distinct().count()
)

daily_revenue_at_risk_est = float(providers_expiring_30d) * REVENUE_AT_RISK_PER_PROVIDER_PER_DAY

# Capture one timezone-aware UTC instant for both kpi_date and last_built_at.
# datetime.utcnow() is deprecated and returns a naive value, which PySpark converts
# using the driver's local timezone.
built_at_utc = datetime.now(timezone.utc)

kpi_df = spark.createDataFrame(
    [(
        built_at_utc.date(),
        int(providers_total),
        int(providers_pending),
        int(providers_expiring_30d),
        float(daily_revenue_at_risk_est),
        built_at_utc,
    )],
    ["kpi_date", "providers_total", "providers_pending", "providers_expiring_30d", "daily_revenue_at_risk_est", "last_built_at"]
)

kpi_df.write.format("delta").mode("overwrite").option("overwriteSchema", "true").saveAsTable(
    fq(schema_gold, "kpi_summary_daily")
)

display(spark.read.table(fq(schema_gold, "staffing_gaps")).orderBy(F.desc("gap_count")).limit(25))


In [None]:
## Gold: risk_actions (closed-loop mitigation workflow)
# This table is intentionally simple: it captures operational actions tied to either a SHIFT or a PROVIDER.

import uuid
from datetime import datetime, timedelta
from pyspark.sql import functions as F

# Fixed anchor timestamp. The created_at values below are base_ts plus a
# hash-derived number of hours, so they do not change between reruns.
base_ts = datetime(2026, 1, 1, 8, 0, 0)

# Stable UUIDs make this table deterministic/reproducible across reruns and across accounts.
def _action_id(key: str) -> str:
    """Return a UUIDv5 string derived from *key* and the global SEED.

    The same (SEED, key) pair always yields the same id.
    """
    seeded_name = f"risk-action-{SEED}-{key}"
    action_uuid = uuid.uuid5(uuid.NAMESPACE_URL, seeded_name)
    return str(action_uuid)

_action_id_udf = F.udf(_action_id)

owners = ["staffing_coordinator", "med_staff_office", "ops_manager"]

# Output layout shared by both action sources (unionByName matches on these names).
risk_action_columns = [
    "action_id",
    "entity_type",
    "entity_id",
    "facility_id",
    "action_type",
    "status",
    "priority",
    "owner",
    "created_at",
    "updated_at",
    "resolved_at",
    "notes",
    "last_built_at",
]


def _deterministic_created_at(key_col):
    """Return base_ts plus (hash(key_col) mod 120) hours: a stable pseudo-random creation time."""
    return F.lit(base_ts) + F.expr("INTERVAL 1 HOURS") * F.pmod(F.hash(F.col(key_col)), F.lit(120))


def _as_open_action(df):
    """Add the lifecycle columns every new action shares and project to risk_action_columns.

    *df* must already carry action_id, entity_type, entity_id, facility_id,
    action_type, priority, owner, notes and created_at.
    """
    return (
        df
          .withColumn("status", F.lit("OPEN"))
          .withColumn("updated_at", F.col("created_at"))
          .withColumn("resolved_at", F.lit(None).cast("timestamp"))
          .withColumn("last_built_at", F.current_timestamp())
          .select(*risk_action_columns)
    )


# SHIFT actions (from staffing gaps)
gaps_for_actions = (
    spark.read.table(fq(schema_gold, "staffing_gaps"))
      .filter(F.col("gap_count") > 0)
      .filter(F.col("risk_level").isin(["HIGH", "MEDIUM"]))
      .orderBy(F.desc("gap_count"), F.asc("start_ts"))
      .limit(75)
)

shift_actions = _as_open_action(
    gaps_for_actions
      .withColumn("action_id", _action_id_udf(F.concat(F.lit("SHIFT:"), F.col("shift_id"))))
      .withColumn("entity_type", F.lit("SHIFT"))
      .withColumn("entity_id", F.col("shift_id"))
      .withColumn("action_type", F.lit("OUTREACH"))
      .withColumn(
          "priority",
          F.when(F.col("risk_level") == "HIGH", F.lit("HIGH")).otherwise(F.lit("MEDIUM"))
      )
      # Spread owners deterministically across shifts via a hash bucket (element_at is 1-based).
      .withColumn("owner", F.element_at(F.array([F.lit(o) for o in owners]), (F.pmod(F.hash(F.col("shift_id")), F.lit(len(owners))) + 1)))
      .withColumn("notes", F.concat(F.lit("Outreach for uncovered shift (gap="), F.col("gap_count").cast("string"), F.lit(")")))
      .withColumn("created_at", _deterministic_created_at("shift_id"))
)

# PROVIDER actions (from expiring credentials)
provider_home = spark.read.table(fq(schema_gold, "provider_360_flat")).select("provider_id", "home_facility_id")

cred_for_actions = (
    spark.read.table(fq(schema_gold, "credential_risk"))
      .filter(F.col("cred_type").isin(["STATE_MED_LICENSE", "ACLS"]))
      .filter(F.col("risk_bucket").isin(["EXPIRED", "0-14", "15-30"]))
      .orderBy(F.asc("days_until_expiration"))
      .limit(75)
      .join(provider_home, "provider_id", "left")
)

provider_actions = _as_open_action(
    cred_for_actions
      .withColumn("action_id", _action_id_udf(F.concat(F.lit("PROVIDER:"), F.col("provider_id"), F.lit(":"), F.col("cred_type"))))
      .withColumn("entity_type", F.lit("PROVIDER"))
      .withColumn("entity_id", F.col("provider_id"))
      .withColumn("facility_id", F.col("home_facility_id"))
      .withColumn("action_type", F.lit("CREDENTIAL_EXPEDITE"))
      .withColumn(
          "priority",
          F.when(F.col("risk_bucket") == "EXPIRED", F.lit("HIGH"))
           .when(F.col("risk_bucket") == "0-14", F.lit("HIGH"))
           .otherwise(F.lit("MEDIUM"))
      )
      .withColumn("owner", F.lit("med_staff_office"))
      .withColumn(
          "notes",
          F.concat(
              F.lit("Credential renewal outreach: "),
              F.col("cred_type"),
              F.lit(" (days_left="),
              F.col("days_until_expiration").cast("string"),
              F.lit(")")
          )
      )
      .withColumn("created_at", _deterministic_created_at("provider_id"))
)

risk_actions = shift_actions.unionByName(provider_actions)

risk_actions.write.format("delta").mode("overwrite").option("overwriteSchema", "true").saveAsTable(
    fq(schema_gold, "risk_actions")
)

# Priorities are strings, and "HIGH" < "MEDIUM" alphabetically, so a descending string
# sort would list MEDIUM first. Sort by an explicit rank to show HIGH first.
priority_rank = (
    F.when(F.col("priority") == "HIGH", F.lit(0))
     .when(F.col("priority") == "MEDIUM", F.lit(1))
     .otherwise(F.lit(2))
)
display(risk_actions.orderBy(priority_rank, F.desc("created_at")).limit(50))

## Gold: nurse_staffing_summary
Daily staffing summary per unit with nurse-to-patient ratios, staffing status, and labor costs.

In [None]:
import random

# Load reference tables for nurse staffing
ref_unit = spark.read.table(fq(schema_ref, "unit"))
nurse_assign_current = spark.read.table(fq(schema_silver, "nurse_assignment_current"))

# Join nurse assignments with provider info to get employment type and hourly rate
nurse_with_info = (
    nurse_assign_current
      .join(
          provider.select("provider_id", "employment_type", "hourly_rate"),
          "provider_id",
          "left"
      )
)

# Aggregate by unit: headcount by employment type and daily labor cost.
staffing_by_unit = (
    nurse_with_info
      .groupBy("unit_id")
      .agg(
          F.count("*").alias("nurses_assigned"),
          F.sum(F.when(F.col("employment_type") == "INTERNAL", 1).otherwise(0)).alias("nurses_internal"),
          F.sum(F.when(F.col("employment_type") == "CONTRACT", 1).otherwise(0)).alias("nurses_contract"),
          F.sum(F.when(F.col("employment_type") == "AGENCY", 1).otherwise(0)).alias("nurses_agency"),
          F.sum(F.col("hourly_rate") * 12).alias("labor_cost_daily"),  # 12-hour shifts
      )
)

# Simulate census data (in production this would come from ADT/census table)
# For demo: census = bed_count * random factor (0.6-0.9)
# Units are sorted by unit_id because collect() row order is not guaranteed; without
# sorting, the same seed could assign different census values to units across runs.
# A private RNG keeps the global `random` state untouched.
_census_rng = random.Random(SEED)
census_data = [
    (row["unit_id"], int(row["bed_count"] * _census_rng.uniform(0.6, 0.9)))
    for row in ref_unit.orderBy("unit_id").collect()
]

census_df = spark.createDataFrame(census_data, ["unit_id", "current_census"])

# Build nurse staffing summary.
# NOTE(review): a null/zero target_ratio yields a null nurses_required, which the
# status expression reports as "OPTIMAL" — confirm reference data always sets it.
nurse_staffing_summary = (
    ref_unit
      .join(ref_facility.select("facility_id", "facility_name"), "facility_id", "left")
      .join(census_df, "unit_id", "left")
      .join(staffing_by_unit, "unit_id", "left")
      .fillna({
          "nurses_assigned": 0,
          "nurses_internal": 0,
          "nurses_contract": 0,
          "nurses_agency": 0,
          "labor_cost_daily": 0.0,
          "current_census": 0,
      })
      .withColumn("summary_date", F.current_date())
      .withColumn("nurses_required", F.ceil(F.col("current_census") / F.col("target_ratio")).cast("int"))
      .withColumn("staffing_delta", F.col("nurses_assigned") - F.col("nurses_required"))
      .withColumn(
          "staffing_status",
          F.when(F.col("nurses_assigned") < F.col("nurses_required"), F.lit("UNDERSTAFFED"))
           .when(F.col("nurses_assigned") > F.col("nurses_required") + 1, F.lit("OVERSTAFFED"))
           .otherwise(F.lit("OPTIMAL"))
      )
      .withColumn("last_built_at", F.current_timestamp())
      .select(
          "summary_date",
          "unit_id",
          "facility_id",
          "facility_name",
          "unit_name",
          "unit_type",
          "bed_count",
          "current_census",
          "target_ratio",
          "nurses_required",
          "nurses_assigned",
          "nurses_internal",
          "nurses_contract",
          "nurses_agency",
          "staffing_delta",
          "staffing_status",
          "labor_cost_daily",
          "last_built_at",
      )
)

nurse_staffing_summary.write.format("delta").mode("overwrite").option("overwriteSchema", "true").saveAsTable(
    fq(schema_gold, "nurse_staffing_summary")
)

display(nurse_staffing_summary.orderBy(F.desc("staffing_status"), "unit_name"))

## Gold: credential_gaps (nurse staffing)
Units where required certifications are missing among assigned nurses.

In [None]:
# Load unit certification requirements
unit_cert = spark.read.table(fq(schema_ref, "unit_certification")).filter(F.col("is_required") == True)

# Nurses assigned to each unit on the current (latest) assignment date.
# NOTE(review): not referenced elsewhere in this cell.
nurses_by_unit = (
    nurse_assign_current
      .groupBy("unit_id")
      .agg(
          F.count("*").alias("nurses_assigned"),
          F.collect_list("provider_id").alias("nurse_ids"),
      )
)

# Join unit -> unit_type -> unit_cert requirements
unit_with_type = ref_unit.select("unit_id", "unit_type", "unit_name", "facility_id")

# NOTE(review): not referenced elsewhere in this cell.
unit_cert_requirements = (
    unit_with_type
      .join(unit_cert, "unit_type", "inner")
      .join(ref_facility.select("facility_id", "facility_name"), "facility_id", "left")
)

# Certifications each nurse currently holds (ACTIVE and unexpired), with an explicit
# marker column so a match can be detected after the left join. This is the same
# pattern as has_priv / has_acls, and avoids referencing the right-hand key of a
# USING join.
nurse_certs = (
    spark.read.table(fq(schema_gold, "credential_risk"))
      .filter(F.col("cred_status") == "ACTIVE")
      .filter(F.col("days_until_expiration") >= 0)
      .select("provider_id", "cred_type")
      .distinct()
      .withColumn("has_cert_flag", F.lit(1))
)

# One row per (assigned nurse, required cert for the nurse's unit type), flagged 1/0.
nurse_unit_cert_check = (
    nurse_assign_current
      .select("unit_id", "provider_id")
      .join(unit_with_type, "unit_id", "left")
      .join(unit_cert, "unit_type", "inner")
      .join(nurse_certs, ["provider_id", "cred_type"], "left_outer")
      .withColumn("has_cert", F.coalesce(F.col("has_cert_flag"), F.lit(0)))
      .drop("has_cert_flag")
)

# Aggregate: for each unit + cert type, count nurses with/without cert
credential_gaps_raw = (
    nurse_unit_cert_check
      .groupBy("unit_id", "unit_type", "cred_type")
      .agg(
          F.count("*").alias("nurses_assigned"),
          F.sum("has_cert").alias("nurses_with_cert"),
          # collect_list drops the nulls produced for nurses who do have the cert.
          F.collect_list(
              F.when(F.col("has_cert") == 0, F.col("provider_id"))
          ).alias("affected_nurse_ids_raw"),
      )
      .withColumn("nurses_missing_cert", F.col("nurses_assigned") - F.col("nurses_with_cert"))
      .filter(F.col("nurses_missing_cert") > 0)  # Only show gaps
)

# Calculate severity from the share of assigned nurses missing the cert.
credential_gaps = (
    credential_gaps_raw
      .join(unit_with_type.select("unit_id", "unit_name", "facility_id"), "unit_id", "left")
      .join(ref_facility.select("facility_id", "facility_name"), "facility_id", "left")
      .withColumn(
          "gap_severity",
          F.when(F.col("nurses_missing_cert") >= F.col("nurses_assigned") * 0.5, F.lit("CRITICAL"))
           .when(F.col("nurses_missing_cert") >= F.col("nurses_assigned") * 0.25, F.lit("HIGH"))
           .when(F.col("nurses_missing_cert") > 0, F.lit("MEDIUM"))
           .otherwise(F.lit("LOW"))
      )
      .withColumn(
          "affected_nurse_ids",
          F.expr("filter(affected_nurse_ids_raw, x -> x IS NOT NULL)")
      )
      .withColumn("required_cred_type", F.col("cred_type"))
      .select(
          "unit_id",
          "facility_id",
          "facility_name",
          "unit_name",
          "unit_type",
          "required_cred_type",
          "nurses_assigned",
          "nurses_with_cert",
          "nurses_missing_cert",
          "gap_severity",
          "affected_nurse_ids",
      )
)

credential_gaps.write.format("delta").mode("overwrite").option("overwriteSchema", "true").saveAsTable(
    fq(schema_gold, "credential_gaps")
)

# Severity labels do not sort meaningfully as strings (CRITICAL would come last
# when sorted descending), so order by an explicit rank, most severe first.
severity_rank = (
    F.when(F.col("gap_severity") == "CRITICAL", F.lit(0))
     .when(F.col("gap_severity") == "HIGH", F.lit(1))
     .when(F.col("gap_severity") == "MEDIUM", F.lit(2))
     .otherwise(F.lit(3))
)
display(credential_gaps.orderBy(severity_rank, "unit_name"))