In [None]:
%md
### Kardiaflow - Silver Providers (SCD2)

**Source:** `kardia_bronze.bronze_providers` (snapshot input with `_ingest_ts`)

**Target:** `kardia_silver.silver_providers` (SCD2 with `is_current` flag)

**Pattern:** Detect changes; close current row, insert new current row (SCD2)

**Trigger:** Incremental batch

**Description:** The Silver layer is where raw data becomes trustworthy and usable. Across Silver tables we enforce constraints,
standardize types, rename fields to consistent names, mask PHI, and apply deduplication, SCD1/SCD2 handling, and
timezone normalization; this notebook applies the deduplication and SCD2 steps. In dbt, staging models handle renaming and typing while downstream (intermediate/mart) models handle business rules, which keeps lineage and documentation transparent in its SQL-first workflow. In Kardiaflow we combine both in Silver, following the Delta Lake medallion convention in which Silver covers both source adaptation and core business logic.

Notes: Tracks changes to provider specialty and location over time; each change closes the current row and opens a new current row.

In [0]:
from delta.tables import DeltaTable
from pyspark.sql import functions as F
from pyspark.sql.window import Window

from kflow.config import bronze_table, silver_paths
from kflow.notebook_utils import init, show_history

# Project-level notebook setup (see kflow.notebook_utils.init)
init()

# Resolve the Silver target (db / table / path) first, then the Bronze source table,
# both for the same dataset key.
_DATASET = "providers"

S         = silver_paths(_DATASET)
SRC_TABLE = bronze_table(_DATASET)
TGT_TABLE = S.table

In [0]:
# 1. Ensure Silver DB and Providers table exist
#    Both statements use IF NOT EXISTS, so re-running this cell leaves an existing
#    database/table untouched. The table is an external Delta table at S.path.
silver_db_ddl = f"CREATE DATABASE IF NOT EXISTS {S.db}"

providers_table_ddl = f"""
    CREATE TABLE IF NOT EXISTS {TGT_TABLE} (
        provider_id         STRING  NOT NULL,
        provider_specialty  STRING,
        provider_location   STRING,
        eff_start_ts        TIMESTAMP,
        eff_end_ts          TIMESTAMP,
        is_current          BOOLEAN,
        _ingest_ts          TIMESTAMP,
        _batch_id           STRING,
        _source_file        STRING
    ) USING DELTA
    LOCATION '{S.path}'
    """

# Database first, then the table that lives in it
for ddl_statement in (silver_db_ddl, providers_table_ddl):
    spark.sql(ddl_statement)

In [0]:
# 2. Build the latest snapshot per ProviderID from Bronze
bronze = (
    spark.table(SRC_TABLE)
         .filter(F.col("ProviderID").isNotNull())
)

# Newest ingest first. The remaining sort keys are tie-breakers: without them,
# row_number() picks an arbitrary row when a provider appears more than once with the
# same _ingest_ts. That choice can differ between runs, or between evaluations of the
# MERGE source, and produce spurious SCD2 versions. With every selected column in the
# ordering, any remaining tie is between rows that produce identical output, so the
# result is deterministic.
w_latest = (
    Window.partitionBy("ProviderID")
          .orderBy(
              F.col("_ingest_ts").desc_nulls_last(),
              F.col("_batch_id").desc_nulls_last(),
              F.col("_source_file").desc_nulls_last(),
              F.col("ProviderSpecialty").asc_nulls_last(),
              F.col("ProviderLocation").asc_nulls_last(),
          )
)

# Identify the most recent record per provider_id using ingest timestamp
deduped_df = (
    bronze
      .withColumn("_rn", F.row_number().over(w_latest))
      .filter(F.col("_rn") == 1)
      .select(
          F.col("ProviderID").alias("provider_id"),
          F.col("ProviderSpecialty").alias("provider_specialty"),
          F.col("ProviderLocation").alias("provider_location"),
          F.col("_ingest_ts"),
          F.col("_batch_id"),
          F.col("_source_file"),
      )
)

# Final DataFrame used in MERGE (add SCD-2 columns): each candidate version starts at
# its ingest time, is open-ended, and is flagged current.
latest_df = (
    deduped_df
      .withColumn("eff_start_ts", F.col("_ingest_ts"))
      .withColumn("eff_end_ts",   F.lit(None).cast("timestamp"))
      .withColumn("is_current",   F.lit(True))
)

In [None]:
# 3. Define what counts as a change (SCD-2 tracked fields)
#    One null-safe inequality per tracked field (`<=>` treats NULL = NULL as equal),
#    OR-ed together: any differing field means the provider changed. The aliases
#    `t` (target) and `s` (source) must match those used in the MERGE.
PROVIDER_CHANGE_CONDITION = " OR ".join([
    "NOT (t.provider_specialty <=> s.provider_specialty)",
    "NOT (t.provider_location  <=> s.provider_location)",
])

In [0]:
# 4. Perform SCD Type 2 MERGE
#    Track history for changes in specialty or location.
#
#    Closing the old version and inserting the new one happen in ONE MERGE, so they
#    commit atomically. Readers never see a changed provider without a current row, and
#    both actions are driven by the same evaluation of `latest_df`.

# Refresh so the newly created table is visible to the engine
spark.sql(f"REFRESH TABLE {TGT_TABLE}")

providers_delta = DeltaTable.forPath(spark, S.path)

# Providers whose tracked fields differ from their current Silver row. Each of these
# needs a NEW current row in addition to closing the old one, so they are staged a
# second time with a NULL merge key that can never match a target row.
changed_df = (
    latest_df.alias("s")
      .join(
          providers_delta.toDF().filter(F.col("is_current")).alias("t"),
          F.col("s.provider_id") == F.col("t.provider_id"),
      )
      .where(PROVIDER_CHANGE_CONDITION)
      .select("s.*")
)

staged_df = (
    # Keyed copy: matches the provider's current row (close it if changed; otherwise no-op),
    # or matches nothing for a brand-new provider (insert it).
    latest_df.withColumn("_merge_key", F.col("provider_id").cast("string"))
      # NULL-keyed copy: never matches, so it inserts the new version of a changed provider.
      .unionByName(changed_df.withColumn("_merge_key", F.lit(None).cast("string")))
)

# Insert exactly the target's columns (the helper `_merge_key` is never written).
insert_values = {c: f"s.{c}" for c in providers_delta.toDF().columns}

(
    providers_delta.alias("t")
      .merge(
          staged_df.alias("s"),
          "t.provider_id = s._merge_key AND t.is_current = TRUE"
      )
      # CASE 1: Close rows where tracked fields changed by setting eff_end_ts and is_current = FALSE.
      .whenMatchedUpdate(
          condition=PROVIDER_CHANGE_CONDITION,
          set={
              "eff_end_ts": F.col("s.eff_start_ts"),
              "is_current": F.lit(False),
          },
      )
      # CASE 2: Insert a new current row when:
      #    - the provider is brand new (keyed copy finds no current row), or
      #    - the provider changed (NULL-keyed copy never matches)
      .whenNotMatchedInsert(values=insert_values)
      .execute()
)

In [0]:
# 5. Verify Silver Providers SCD2 output
#    Report the total row count, show the five most recently ingested rows
#    (NULL _ingest_ts last), then show the Delta commit history of the target path.
df = spark.table(TGT_TABLE)

total_rows = df.count()
print(f"Silver Providers row count: {total_rows:,}")

newest_rows = df.orderBy(F.col("_ingest_ts").desc_nulls_last()).limit(5)
display(newest_rows)

show_history(S.path)