# Databricks notebook source
# =========================================
# Notebook: 03_silver_transform
# Purpose : Curate Silver layer:
#          - type casting + standardization
#          - dedup
#          - billing CDC apply via MERGE
#          - transformer SCD2 via MERGE
# =========================================

In [0]:
%run ./01_config

In [0]:


from pyspark.sql import functions as F, Window as W
from delta.tables import DeltaTable


In [0]:
def save_as_table(df, db, table, path, mode="overwrite"):
    """Write ``df`` as Delta files at ``path`` and register ``db.table`` over them.

    The data write uses the given save ``mode``. The metastore registration is
    ``CREATE TABLE IF NOT EXISTS``, so an already-registered table is left as is.
    """
    writer = df.write.format("delta").mode(mode)
    writer.save(path)
    ddl = f"CREATE TABLE IF NOT EXISTS {db}.{table} USING DELTA LOCATION '{path}'"
    spark.sql(ddl)

def table_exists(db, table):
    """Return True if ``db.table`` resolves in the Spark catalog, otherwise False.

    Any ordinary error raised while resolving the table is treated as "does not
    exist". Unlike the previous bare ``except:``, ``KeyboardInterrupt`` and
    ``SystemExit`` are no longer swallowed.
    """
    try:
        spark.table(f"{db}.{table}")
    except Exception:  # resolution failure (e.g. unknown table/database)
        return False
    return True

def delta_exists(path):
    """Return True if ``path`` holds a Delta table.

    Returns False if it does not, or if the check itself fails (for example an
    unreadable path). Only ordinary exceptions are caught; the previous bare
    ``except:`` also swallowed ``KeyboardInterrupt``/``SystemExit``.
    """
    try:
        return DeltaTable.isDeltaTable(spark, path)
    except Exception:  # best-effort probe: any check failure means "not a Delta table"
        return False

In [0]:
# -----------------------------------------
# 1) Smart Meter Silver: clean + dedup + anomaly scoring
# -----------------------------------------
sm_bronze = spark.table(f"{DB_BRONZE}.smart_meter_events")

# Parse the raw `timestamp` string into event_ts, derive the calendar date from it,
# then remove the raw string column so only the typed version remains.
sm = sm_bronze.withColumn("event_ts", F.to_timestamp("timestamp"))
sm = sm.withColumn("event_date", F.to_date("event_ts"))
sm = sm.drop("timestamp")

In [0]:
# Dedup: keep the latest record per event_id. When event_id is null, fall back to
# meter_id + event_ts (only if the feed has a meter_id column).
# Ties are broken by ingestion timestamp, then source file name (deterministic).
# Rows with no usable key at all are kept as-is: a window partitioned on a null key
# would otherwise treat all of them as duplicates of one another and keep only one.
if "meter_id" in sm.columns:
    # Non-null only when event_id is missing; F.concat yields NULL if either part is NULL.
    _sm_fallback_key = F.when(
        F.col("event_id").isNull(),
        F.concat(F.col("meter_id").cast("string"), F.lit("|"), F.col("event_ts").cast("string")),
    )
else:
    _sm_fallback_key = F.lit(None).cast("string")

w_sm = (W.partitionBy("event_id", "_dedup_fallback_key")
    .orderBy(F.col("_ingestion_ts").desc(), F.col("_source_file").desc()))
sm_dedup = (sm
    .withColumn("_dedup_fallback_key", _sm_fallback_key)
    .withColumn("_rn", F.row_number().over(w_sm))
    .filter(
        (F.col("_rn") == 1)
        | (F.col("event_id").isNull() & F.col("_dedup_fallback_key").isNull())
    )
    .drop("_rn", "_dedup_fallback_key")
)

In [0]:
# Robust anomaly thresholds per region (approximate 1st & 99th percentiles of kwh_consumed).
# (Edge: dynamic thresholds instead of hard-coded)
thr = (sm_dedup
    .groupBy("region")
    .agg(
        F.expr("percentile_approx(kwh_consumed, 0.01)").alias("p01_kwh"),
        F.expr("percentile_approx(kwh_consumed, 0.99)").alias("p99_kwh"),
    )
)

# Join the thresholds null-safely. groupBy("region") yields a threshold row for
# region = NULL, but a plain equi-join (on="region") never matches NULL keys, so
# readings without a region would get no thresholds and never be scored.
# The select keeps the column order the on="region" join produced
# (region first, then the remaining reading columns).
sm_region_first = ["region"] + [c for c in sm_dedup.columns if c != "region"]
sm_scored = (sm_dedup
    .join(
        thr.withColumnRenamed("region", "_thr_region"),
        F.col("region").eqNullSafe(F.col("_thr_region")),
        how="left",
    )
    .select(*sm_region_first, "p01_kwh", "p99_kwh")
    .withColumn("is_kwh_anomaly",
        (F.col("kwh_consumed") < F.col("p01_kwh")) | (F.col("kwh_consumed") > F.col("p99_kwh"))
    )
    .withColumn("anomaly_type",
        F.when(F.col("kwh_consumed") < F.col("p01_kwh"), F.lit("LOW_CONSUMPTION"))
         .when(F.col("kwh_consumed") > F.col("p99_kwh"), F.lit("HIGH_CONSUMPTION"))
         .otherwise(F.lit(None))
    )
    .drop("p01_kwh", "p99_kwh")
)

sm_silver_path = f"{PATH_SILVER}/smart_meter_readings"
save_as_table(sm_scored, DB_SILVER, "smart_meter_readings", sm_silver_path, mode="overwrite")

# Count the persisted table rather than sm_scored: counting the DataFrame would
# recompute the whole dedup / percentile / join pipeline a second time.
audit_table_metrics(
    f"{DB_SILVER}.smart_meter_readings", "SILVER", "rowcount",
    str(spark.table(f"{DB_SILVER}.smart_meter_readings").count()),
)


In [0]:
# -----------------------------------------
# 2) Substation Telemetry Silver: clean + dedup + transformer risk score
# -----------------------------------------
tel_bronze = spark.table(f"{DB_BRONZE}.substation_telemetry_events")

# Typed telemetry timestamp plus its calendar date; the raw string column is removed.
tel = tel_bronze.withColumn("telemetry_ts", F.to_timestamp("timestamp"))
tel = tel.withColumn("event_date", F.to_date("telemetry_ts"))
tel = tel.drop("timestamp")


In [0]:
# Dedup by telemetry_id: keep the most recently ingested record per id
# (ties broken by source file name, descending, for determinism).
# Rows with a null telemetry_id are all kept: a window partitioned on a null key
# would otherwise treat every such row as a duplicate of the others and keep one.
w_tel = W.partitionBy("telemetry_id").orderBy(F.col("_ingestion_ts").desc(), F.col("_source_file").desc())
tel_dedup = (tel
    .withColumn("_rn", F.row_number().over(w_tel))
    .filter((F.col("_rn") == 1) | F.col("telemetry_id").isNull())
    .drop("_rn")
)

In [0]:
# Risk scoring: explainable 0/1 rule flags combined into a weighted score.
#   risk_alarm     (weight 3): alarm_code present and non-empty
#   risk_oil_temp  (weight 2): oil minus ambient temperature above 25
#   risk_gas       (weight 2): dissolved_gas_ppm above 800
#   risk_vibration (weight 1): vibration_mm_s above 7
#   risk_overload  (weight 1): load_pct above 90
# risk_level: HIGH when score >= 6, MEDIUM when >= 3, otherwise LOW.
tel_scored = tel_dedup.withColumn(
    "oil_temp_delta_c", F.col("oil_temperature_c") - F.col("ambient_temperature_c")
)

# Flags are added in this fixed order; a NULL input falls through to 0.
for _flag_name, _flag_cond in [
    ("risk_alarm", F.col("alarm_code").isNotNull() & (F.col("alarm_code") != "")),
    ("risk_oil_temp", F.col("oil_temp_delta_c") > 25),
    ("risk_gas", F.col("dissolved_gas_ppm") > 800),
    ("risk_vibration", F.col("vibration_mm_s") > 7),
    ("risk_overload", F.col("load_pct") > 90),
]:
    tel_scored = tel_scored.withColumn(_flag_name, F.when(_flag_cond, 1).otherwise(0))

tel_scored = tel_scored.withColumn(
    "risk_score",
    3*F.col("risk_alarm") + 2*F.col("risk_oil_temp") + 2*F.col("risk_gas")
    + 1*F.col("risk_vibration") + 1*F.col("risk_overload"),
)
tel_scored = tel_scored.withColumn(
    "risk_level",
    F.when(F.col("risk_score") >= 6, F.lit("HIGH"))
     .when(F.col("risk_score") >= 3, F.lit("MEDIUM"))
     .otherwise(F.lit("LOW")),
)

tel_silver_path = f"{PATH_SILVER}/substation_telemetry"
save_as_table(tel_scored, DB_SILVER, "substation_telemetry", tel_silver_path, mode="overwrite")
audit_table_metrics(f"{DB_SILVER}.substation_telemetry", "SILVER", "rowcount", str(tel_scored.count()))


In [0]:
# -----------------------------------------
# 3) Maintenance Logs Silver: standardize
# -----------------------------------------
mnt_bronze = spark.table(f"{DB_BRONZE}.maintenance_logs")

# Standardize while staying flexible if the schema changes: all source columns are kept,
# and maintenance_ts is parsed from the first date-like column found (priority order below).
# If none is present, maintenance_ts is still added as a typed NULL (unless the source
# already carries one) so the Silver schema stays stable for downstream readers, the same
# way the renewable section guarantees its evolving columns exist.
mnt = mnt_bronze
mnt_date_col = next(
    (c for c in ["maintenance_date", "performed_at", "date"] if c in mnt.columns), None
)
if mnt_date_col is not None:
    mnt = mnt.withColumn("maintenance_ts", F.to_timestamp(F.col(mnt_date_col)))
elif "maintenance_ts" not in mnt.columns:
    mnt = mnt.withColumn("maintenance_ts", F.lit(None).cast("timestamp"))

mnt_silver_path = f"{PATH_SILVER}/maintenance_logs"
save_as_table(mnt, DB_SILVER, "maintenance_logs", mnt_silver_path, mode="overwrite")
audit_table_metrics(f"{DB_SILVER}.maintenance_logs", "SILVER", "rowcount", str(mnt.count()))


In [0]:
# -----------------------------------------
# 4) Customer Dimension Silver: SCD2-ready snapshot
# -----------------------------------------
cust_bronze = spark.table(f"{DB_BRONZE}.customer_master_snapshot")

# Keep SCD2 fields when the source already has them (detected via is_current);
# otherwise stamp every row as a single open-ended current version 1.
cust = cust_bronze
if "is_current" not in cust.columns:
    for _scd_col, _scd_value in (
        ("record_effective_date", F.to_date(F.lit("1900-01-01"))),
        ("record_end_date", F.to_date(F.lit("9999-12-31"))),
        ("is_current", F.lit(True)),
        ("scd_version", F.lit(1)),
    ):
        cust = cust.withColumn(_scd_col, _scd_value)

cust_silver_path = f"{PATH_SILVER}/dim_customer"
save_as_table(cust, DB_SILVER, "dim_customer", cust_silver_path, mode="overwrite")
# View exposing only rows flagged is_current = true.
spark.sql(
    f"CREATE OR REPLACE VIEW {DB_SILVER}.dim_customer_current AS "
    f"SELECT * FROM {DB_SILVER}.dim_customer WHERE is_current = true"
)
audit_table_metrics(f"{DB_SILVER}.dim_customer", "SILVER", "rowcount", str(cust.count()))


In [0]:
# -----------------------------------------
# 5) Transformer Dimension Silver: SCD2 via MERGE (input check)
# -----------------------------------------
# This cell previously repeated section 4 verbatim. That rewrote dim_customer a second
# time, recreated its view and logged a duplicate audit row for the same run.
# It now only verifies that the inputs used by the transformer SCD2 cells below are in
# scope. A missing one fails here with a clear message instead of a bare NameError
# partway through the MERGE sequence.
# NOTE(review): trf_silver_path, trf_snap (initial snapshot) and trf_delta (incoming
# changes) are not defined in this notebook. Presumably they come from 01_config or from a
# setup cell lost when the duplicate was pasted in — confirm.
_trf_missing = [n for n in ("trf_silver_path", "trf_snap", "trf_delta") if n not in globals()]
if _trf_missing:
    raise NameError(f"Transformer SCD2 inputs are not defined: {', '.join(_trf_missing)}")


In [0]:
# Initialize dim if not exists.
# Validate the business key first so a bad transformer delta fails before anything is
# written. An explicit raise (instead of `assert`) keeps the check active even when
# Python runs with assertions disabled (-O).
if "transformer_id" not in trf_delta.columns:
    raise ValueError("transformer_id not found in transformer delta")

# First run only: seed the dimension files from the transformer snapshot.
if not delta_exists(trf_silver_path):
    trf_snap.write.format("delta").mode("overwrite").save(trf_silver_path)

# Register the table in either case (no-op when it is already registered).
spark.sql(f"CREATE TABLE IF NOT EXISTS {DB_SILVER}.dim_transformer USING DELTA LOCATION '{trf_silver_path}'")



In [0]:
# Attributes used for SCD2 change detection: every transformer-delta column except
# the SCD bookkeeping columns and metadata columns (names starting with "_").
scd_cols = {"record_effective_date", "record_end_date", "is_current", "scd_version"}
meta_cols = set(filter(lambda name: name.startswith("_"), trf_delta.columns))
compare_cols = [
    name for name in trf_delta.columns
    if not (name in scd_cols or name in meta_cols)
]


In [0]:
# Attach the next SCD version to every incoming transformer row: the version of the
# key's current dim row plus one, or 1 when the key has no current row.
dim_current = (spark.read.format("delta")
    .load(trf_silver_path)
    .filter("is_current = true")
    .select("transformer_id", "scd_version"))

incoming = trf_delta.alias("d").join(dim_current.alias("c"), on="transformer_id", how="left")
incoming = incoming.withColumn(
    "next_scd_version",
    F.coalesce(F.col("c.scd_version") + F.lit(1), F.lit(1)),
)
# Remove scd_version by name, then promote the freshly computed version in its place.
incoming = incoming.drop("scd_version").withColumnRenamed("next_scd_version", "scd_version")

In [0]:
# Join to compute next version if needed
# NOTE(review): this cell repeats the previous cell verbatim. It only rebuilds the same
# lazy `dim_current` / `incoming` DataFrames (nothing is written), so it is redundant;
# consider deleting one copy.
dim_current = spark.read.format("delta").load(trf_silver_path).filter("is_current = true").select("transformer_id","scd_version")
incoming = (trf_delta.alias("d")
    .join(dim_current.alias("c"), on="transformer_id", how="left")
    .withColumn("next_scd_version", F.coalesce(F.col("c.scd_version") + F.lit(1), F.lit(1)))
    .drop("scd_version")  # we will use next_scd_version instead
    .withColumnRenamed("next_scd_version", "scd_version")
)

In [0]:
# MERGE: expire the current row of every key whose tracked attributes changed, and insert
# keys that have no current row. New versions for changed keys are appended in a later
# cell, because a MERGE cannot both update and insert for the same matched source row.
#
# change_pred was referenced here but is not defined in this notebook (NameError).
# Build it from compare_cols: a current row counts as changed when any compared column
# differs. `<=>` is null-safe equality, so NULL -> value and value -> NULL count as changes.
# The business key is skipped (always equal on a match). "false" covers the degenerate
# case of nothing to compare, i.e. never expire.
change_pred = " OR ".join(
    f"NOT (t.`{c}` <=> s.`{c}`)" for c in compare_cols if c != "transformer_id"
) or "false"

dt = DeltaTable.forPath(spark, trf_silver_path)

(dt.alias("t")
 .merge(incoming.alias("s"), "t.transformer_id = s.transformer_id AND t.is_current = true")
 .whenMatchedUpdate(
     condition=change_pred,
     set={
         # Close the old version the day before the new one becomes effective.
         "record_end_date": F.date_sub(F.to_date(F.col("s.record_effective_date")), 1),
         "is_current": F.lit(False)
     }
 )
 .whenNotMatchedInsert(values={c: F.col(f"s.{c}") for c in incoming.columns})
 .execute()
)

In [0]:

# After expiring, insert new current versions for the keys whose current row the MERGE
# just expired (a matched row cannot also be inserted by that same MERGE).
#
# After the MERGE, each incoming key is in exactly one state:
#   - new key       -> inserted by whenNotMatchedInsert, so it has a current row
#   - unchanged key -> its current row was left untouched
#   - changed key   -> its current row was expired, so it has NO current row
# The changed keys are therefore the incoming keys with no current row left.
# (The previous rule, "incoming keys with ANY non-current row", also picked up unchanged
# keys that merely had history from earlier runs, and re-inserted duplicate current rows.)
# NOTE(review): assumes incoming rows carry is_current = true, so rows inserted by the
# MERGE count as current — confirm against the transformer delta feed.
_trf_dim_after_merge = spark.read.format("delta").load(trf_silver_path)

changed_keys = (incoming.select("transformer_id").distinct()
    .join(
        _trf_dim_after_merge.filter("is_current = true").select("transformer_id"),
        on="transformer_id",
        how="left_anti",
    )
)

# scd_version in `incoming` comes from a lazy join on the dim's current rows. If that
# join is re-evaluated after the MERGE, it may see the post-MERGE table, where changed
# keys no longer have a current row and would get version 1. Derive the next version
# from the highest version stored for the key instead, which is correct either way.
_trf_max_version = (_trf_dim_after_merge
    .groupBy("transformer_id")
    .agg(F.max("scd_version").alias("_max_scd_version")))

to_insert = (incoming
    .join(changed_keys, on="transformer_id", how="inner")
    .join(_trf_max_version, on="transformer_id", how="left")
    .withColumn("scd_version", F.coalesce(F.col("_max_scd_version") + F.lit(1), F.lit(1)))
    .drop("_max_scd_version")
)


In [0]:
# Append the new current versions produced for changed keys.
to_insert.write.format("delta").mode("append").save(trf_silver_path)

# Publish a view restricted to is_current = true rows, then record the
# dimension's total row count.
spark.sql(
    f"CREATE OR REPLACE VIEW {DB_SILVER}.dim_transformer_current AS "
    f"SELECT * FROM {DB_SILVER}.dim_transformer WHERE is_current = true"
)
audit_table_metrics(
    f"{DB_SILVER}.dim_transformer", "SILVER", "rowcount",
    str(spark.table(f"{DB_SILVER}.dim_transformer").count()),
)


In [0]:
# -----------------------------------------
# 6) Renewable Production Silver: standardize schema evolution output
# -----------------------------------------
renew_bronze = spark.table(f"{DB_BRONZE}.renewable_production")

# Make the timestamp column consistent: use `timestamp` if present, else `datetime`,
# else derive it from the source file name.
renew = renew_bronze
if "timestamp" in renew.columns:
    renew = renew.withColumn("event_ts", F.to_timestamp("timestamp")).drop("timestamp")
elif "datetime" in renew.columns:
    renew = renew.withColumn("event_ts", F.to_timestamp("datetime")).drop("datetime")
else:
    # Fallback: parse from a file name like renewable_production_YYYYMMDD_HHMM.json.
    # A single pattern captures both parts, so date and time always come from the same
    # match. Previously two independent regexes were used, plus an unused `_fn` column.
    # A non-matching name yields "" for both groups, which does not parse as a timestamp.
    renew_fn_pattern = r"renewable_production_(\d{8})_(\d{4})\.json"
    renew = renew.withColumn(
        "event_ts",
        F.to_timestamp(
            F.concat_ws(
                " ",
                F.regexp_extract(F.col("_source_file"), renew_fn_pattern, 1),
                F.regexp_extract(F.col("_source_file"), renew_fn_pattern, 2),
            ),
            "yyyyMMdd HHmm",
        ),
    )

renew = renew.withColumn("event_date", F.to_date("event_ts"))

# Ensure evolving columns exist (curtailment_mw, battery_storage_mwh) even if null
for c in ["curtailment_mw", "battery_storage_mwh"]:
    if c not in renew.columns:
        renew = renew.withColumn(c, F.lit(None).cast("double"))

renew_silver_path = f"{PATH_SILVER}/renewable_production"
save_as_table(renew, DB_SILVER, "renewable_production", renew_silver_path, mode="overwrite")
audit_table_metrics(f"{DB_SILVER}.renewable_production", "SILVER", "rowcount", str(renew.count()))


In [0]:
# -----------------------------------------
# 7) Billing CDC Apply -> Silver current-state table (MERGE)
# -----------------------------------------
billing_bronze = spark.table(f"{DB_BRONZE}.billing_cdc_events")

# Standardize types: the CDC timestamp becomes a new cdc_ts column, and the three
# billing date columns are cast to DATE in place.
bill = billing_bronze.withColumn("cdc_ts", F.to_timestamp("cdc_timestamp"))
for _date_col in ("billing_period_start", "billing_period_end", "due_date"):
    bill = bill.withColumn(_date_col, F.to_date(_date_col))


In [0]:
# CDC semantics (operation code in `__$operation`):
# 1 = DELETE, 2 = INSERT, 3 = BEFORE-IMAGE (ignore), 4 = AFTER-IMAGE/UPDATE (upsert)
# We'll use transaction_id as natural key (exists in the file).
# Explicit raises (not `assert`) so the checks survive `python -O`.
for _required in ("transaction_id", "__$operation"):
    if _required not in bill.columns:
        raise ValueError(f"Expected {_required} in billing CDC feed")

# Drop before-images BEFORE picking the latest event per key. A before-image (3) may share
# its cdc timestamp (and seqval) with the after-image (4) of the same update, so it could
# win the tie-break below. It would then be discarded as neither an upsert nor a delete,
# and the update would be silently lost. Rows with a NULL operation are kept, as before.
bill_events = bill.filter(~F.col("__$operation").eqNullSafe(3))

# Keep latest event per transaction_id (tie-break seqval if present, then ingestion time)
order_cols = [F.col("cdc_ts").desc()]
if "__$seqval" in bill_events.columns:
    order_cols.append(F.col("__$seqval").desc())
order_cols.append(F.col("_ingestion_ts").desc())

w_cdc = W.partitionBy("transaction_id").orderBy(*order_cols)
bill_latest = (bill_events
    .withColumn("_rn", F.row_number().over(w_cdc))
    .filter(F.col("_rn") == 1)
    .drop("_rn")
)

In [0]:
# Split the latest-per-key events by CDC operation:
#   upserts -> INSERT (2) and AFTER-IMAGE/UPDATE (4), full rows
#   deletes -> DELETE (1), distinct transaction_id values only
bill_upserts = bill_latest.where(F.col("__$operation").isin([2, 4]))
bill_deletes = (bill_latest
    .where(F.col("__$operation") == 1)
    .select("transaction_id")
    .distinct())

billing_silver_path = f"{PATH_SILVER}/billing_transactions_current"


In [0]:
# First run only: create an empty Delta table at the target path with the upsert schema.
if not delta_exists(billing_silver_path):
    empty_upserts = bill_upserts.limit(0)
    empty_upserts.write.format("delta").mode("overwrite").save(billing_silver_path)

# Register the table in either case (no-op when it is already registered).
spark.sql(f"CREATE TABLE IF NOT EXISTS {DB_SILVER}.billing_transactions_current USING DELTA LOCATION '{billing_silver_path}'")

target = DeltaTable.forPath(spark, billing_silver_path)


In [0]:
# Apply deletes first, then upserts. Both sets are derived from bill_latest, which holds a
# single event per transaction_id, so they never touch the same key.
# take(1) only checks for emptiness; it avoids the full count() previously run just to
# decide whether to merge.
if bill_deletes.take(1):
    (target.alias("t")
        .merge(bill_deletes.alias("s"), "t.transaction_id = s.transaction_id")
        .whenMatchedDelete()
        .execute()
    )

# Apply upserts (insert new rows, overwrite existing ones column-by-column)
if bill_upserts.take(1):
    set_map = {c: F.col(f"s.{c}") for c in bill_upserts.columns}
    (target.alias("t")
        .merge(bill_upserts.alias("s"), "t.transaction_id = s.transaction_id")
        .whenMatchedUpdate(set=set_map)
        .whenNotMatchedInsert(values=set_map)
        .execute()
    )

audit_table_metrics(f"{DB_SILVER}.billing_transactions_current", "SILVER", "rowcount", str(spark.table(f"{DB_SILVER}.billing_transactions_current").count()))

print("✅ Silver transforms completed.")