In [0]:
# Databricks notebook source
# =========================================
# Notebook: 04_gold_marts
# Purpose : Gold layer marts + KPIs + exec dashboard table
# =========================================
%run ./01_config
from pyspark.sql import functions as F, Window as W

def save_gold(df, table, path, partition_cols=None, mode="overwrite"):
    """Write `df` as Delta files at `path` and register it as a Gold table.

    Args:
        df: DataFrame to persist.
        table: Unqualified table name; it is created inside ``DB_GOLD``.
        path: Storage location the Delta files are written to.
        partition_cols: Optional list of column names to partition by.
        mode: Save mode handed to the DataFrame writer.
    """
    delta_writer = df.write.format("delta").mode(mode)
    # Partitioning is optional: None or an empty list writes an unpartitioned table.
    delta_writer = (
        delta_writer.partitionBy(*partition_cols) if partition_cols else delta_writer
    )
    delta_writer.save(path)

    # Register the location in the metastore; a no-op when the table already exists.
    register_stmt = f"CREATE TABLE IF NOT EXISTS {DB_GOLD}.{table} USING DELTA LOCATION '{path}'"
    spark.sql(register_stmt)

# -----------------------------------------
# Silver inputs
# -----------------------------------------
def _silver(table_name):
    """Return the table `table_name` from the Silver database as a DataFrame."""
    return spark.table(f"{DB_SILVER}.{table_name}")


def _hour_bucket(ts_col):
    """Return a column expression truncating `ts_col` to the start of its hour."""
    return F.date_trunc("hour", F.col(ts_col))


sm = _silver("smart_meter_readings")                 # event_ts, kwh_consumed, region/state, is_kwh_anomaly
tel = _silver("substation_telemetry")                # telemetry_ts, transformer_id, risk_score, risk_level
mnt = _silver("maintenance_logs")
cust_cur = _silver("dim_customer_current")           # current customers
trf_cur = _silver("dim_transformer_current")         # current transformers
bill_cur = _silver("billing_transactions_current")
ren = _silver("renewable_production")

# -----------------------------------------
# Hourly bucket column shared by the hourly marts
# -----------------------------------------
sm = sm.withColumn("hour_ts", _hour_bucket("event_ts"))
tel = tel.withColumn("hour_ts", _hour_bucket("telemetry_ts"))
ren = ren.withColumn("hour_ts", _hour_bucket("event_ts"))

# -----------------------------------------
# 1) Gold: Customer Hourly Usage
# - Aggregates smart meter readings to one row per hour / meter / region / state
# - Enriches each row with customer_id from the current customer dimension
#   when that dimension exposes meter_id
# -----------------------------------------
# If customer dim contains meter_id, join; otherwise we keep meter_id-level output.
joinable = ("meter_id" in cust_cur.columns)

sm_hourly = (sm
    .groupBy("hour_ts", "event_date", "meter_id", "region", "state")
    .agg(
        F.sum("kwh_consumed").alias("kwh_total"),
        F.avg("voltage_v").alias("voltage_avg"),
        F.avg("current_a").alias("current_avg"),
        F.avg("power_factor").alias("pf_avg"),
        # Cast the anomaly flag to int so it can be summed into an event count.
        F.sum(F.col("is_kwh_anomaly").cast("int")).alias("anomaly_events"),
        F.count("*").alias("reading_events")
    )
    # reading_events is count(*) of a non-empty group, so the divisor is never zero.
    .withColumn("anomaly_rate", F.col("anomaly_events") / F.col("reading_events"))
)

if joinable:
    # Bring over customer_id only. sm_hourly already carries region/state from the
    # readings; selecting them from the dimension as well would leave two columns
    # named region and two named state, which makes later references to them
    # ambiguous and causes the Delta write to reject the duplicate column names.
    # NOTE(review): assumes at most one current customer row per meter_id; more
    # than one would fan out the hourly rows - confirm against the dimension.
    sm_hourly = sm_hourly.join(
        cust_cur.select("customer_id", "meter_id"),
        on=["meter_id"],
        how="left",
    )

gold_customer_hourly_path = f"{PATH_GOLD}/customer_hourly_usage"
save_gold(sm_hourly, "customer_hourly_usage", gold_customer_hourly_path, partition_cols=["event_date"])
audit_table_metrics(f"{DB_GOLD}.customer_hourly_usage", "GOLD", "rowcount", str(sm_hourly.count()))

# -----------------------------------------
# 2) Gold: Transformer Health Hourly
# - Aggregates telemetry to one row per hour / transformer / region / state
# - Derives an hourly risk_level from the hourly maximum risk_score
# - Left-joins the current transformer dimension on transformer_id
# -----------------------------------------
tel_hourly = (tel
    .groupBy("hour_ts", "event_date", "transformer_id", "region", "state")
    .agg(
        F.max("risk_score").alias("risk_score_max"),
        F.avg("risk_score").alias("risk_score_avg"),
        F.expr("percentile_approx(oil_temp_delta_c, 0.95)").alias("oil_temp_delta_p95"),
        F.max("dissolved_gas_ppm").alias("dissolved_gas_max"),
        F.max("load_pct").alias("load_pct_max"),
        F.sum(F.col("risk_alarm")).alias("alarm_events"),
        F.count("*").alias("telemetry_events")
    )
    # Bucket the hour by its worst reading: >= 6 is HIGH, >= 3 is MEDIUM, else LOW.
    .withColumn("risk_level",
        F.when(F.col("risk_score_max") >= 6, F.lit("HIGH"))
         .when(F.col("risk_score_max") >= 3, F.lit("MEDIUM"))
         .otherwise(F.lit("LOW"))
    )
)

# Join transformer current dim on the key only.
trf_join_cols = ["transformer_id"]

# Any non-key dimension column whose name also exists on the hourly fact would end
# up duplicated after the join, and the Delta write rejects duplicate column names.
# Drop those from the dimension side so the fact's values win. Names are compared
# case-insensitively because Spark resolves column names that way by default.
_fact_col_names = {name.lower() for name in tel_hourly.columns}
_overlapping_dim_cols = [
    name for name in trf_cur.columns
    if name not in trf_join_cols and name.lower() in _fact_col_names
]
_trf_dim = trf_cur.drop(*_overlapping_dim_cols) if _overlapping_dim_cols else trf_cur

tel_hourly_enriched = tel_hourly.join(_trf_dim, on=trf_join_cols, how="left")

gold_transformer_health_path = f"{PATH_GOLD}/transformer_health_hourly"
save_gold(tel_hourly_enriched, "transformer_health_hourly", gold_transformer_health_path, partition_cols=["event_date"])
audit_table_metrics(f"{DB_GOLD}.transformer_health_hourly", "GOLD", "rowcount", str(tel_hourly_enriched.count()))

# -----------------------------------------
# 3) Gold: Billing KPIs Daily
# - One row per bill_date: totals, paid vs outstanding amounts,
#   paid / pending / disputed counts, fraud count and fraud rate
# -----------------------------------------
# Pick the date grain: due_date when present, else the date of cdc_ts,
# else today's date as a last resort.
if "due_date" in bill_cur.columns:
    _bill_date_expr = F.col("due_date")
elif "cdc_ts" in bill_cur.columns:
    _bill_date_expr = F.to_date("cdc_ts")
else:
    _bill_date_expr = F.current_date()
bill = bill_cur.withColumn("bill_date", _bill_date_expr)

# The amount column goes by different names across schemas; take the first match.
amount_col = next(
    (name for name in ["amount_due", "amount", "total_amount", "bill_amount"] if name in bill.columns),
    None,
)
if amount_col is None:
    # No known amount column: fall back to a constant zero amount.
    bill = bill.withColumn("amount_due", F.lit(0.0))
    amount_col = "amount_due"

# Optional columns; the KPIs that need them become NULL when they are absent.
status_col = "payment_status" if "payment_status" in bill.columns else None
fraud_col = "fraud_flag" if "fraud_flag" in bill.columns else None

_amount = F.col(amount_col)
_null_double = F.lit(None).cast("double")
_null_bigint = F.lit(None).cast("bigint")

_billing_aggs = [
    F.count("*").alias("tx_count"),
    F.sum(_amount).alias("amount_total"),
]

if status_col:
    _status = F.col(status_col)
    _billing_aggs.append(
        F.sum(F.when(_status == "PAID", _amount).otherwise(F.lit(0.0))).alias("amount_paid")
    )
    _billing_aggs.append(
        F.sum(F.when(_status != "PAID", _amount).otherwise(F.lit(0.0))).alias("amount_outstanding")
    )
    # One counter per status value: paid_count, pending_count, disputed_count.
    for _label in ("PAID", "PENDING", "DISPUTED"):
        _billing_aggs.append(
            F.sum(F.when(_status == _label, 1).otherwise(0)).alias(f"{_label.lower()}_count")
        )
else:
    _billing_aggs.extend([
        _null_double.alias("amount_paid"),
        _null_double.alias("amount_outstanding"),
        _null_bigint.alias("paid_count"),
        _null_bigint.alias("pending_count"),
        _null_bigint.alias("disputed_count"),
    ])

if fraud_col:
    _billing_aggs.append(
        F.sum(F.when(F.col(fraud_col) == True, 1).otherwise(0)).alias("fraud_count")
    )
    _fraud_rate = F.col("fraud_count") / F.col("tx_count")
else:
    _billing_aggs.append(_null_bigint.alias("fraud_count"))
    _fraud_rate = F.lit(None).cast("double")

bill_daily = (
    bill.groupBy("bill_date")
    .agg(*_billing_aggs)
    .withColumn("fraud_rate", _fraud_rate)
)

gold_billing_daily_path = f"{PATH_GOLD}/billing_kpis_daily"
save_gold(bill_daily, "billing_kpis_daily", gold_billing_daily_path, partition_cols=["bill_date"])
audit_table_metrics(f"{DB_GOLD}.billing_kpis_daily", "GOLD", "rowcount", str(bill_daily.count()))

# -----------------------------------------
# 4) Gold: Renewable KPIs Hourly
# - production vs grid injection, curtailment, battery storage
# - curtailment_mw / battery_storage_mwh are added as NULL columns when the
#   feed does not provide them
# -----------------------------------------
def _first_present(candidates, available):
    """Return the first name in `candidates` found in `available`, else None."""
    return next((name for name in candidates if name in available), None)


# Resolve the production and injection columns among the known aliases.
prod_col = _first_present(["production_mw", "generation_mw", "total_generation_mw"], ren.columns)
inj_col = _first_present(["grid_injection_mw", "injection_mw", "grid_export_mw"], ren.columns)

if prod_col is None:
    # No production column: treat production as a constant zero.
    ren = ren.withColumn("production_mw", F.lit(0.0))
    prod_col = "production_mw"
if inj_col is None:
    # No injection column: keep it as a NULL double.
    ren = ren.withColumn("grid_injection_mw", F.lit(None).cast("double"))
    inj_col = "grid_injection_mw"

# Make sure the optional measures exist so the aggregation below can reference them.
for _optional in ["curtailment_mw", "battery_storage_mwh"]:
    if _optional not in ren.columns:
        ren = ren.withColumn(_optional, F.lit(None).cast("double"))

# Group per plant when plant_id exists, otherwise under a single "ALL" bucket.
_plant_key = "plant_id" if "plant_id" in ren.columns else F.lit("ALL").alias("plant_id")
_ren_grouped = ren.groupBy("hour_ts", F.to_date("hour_ts").alias("event_date"), _plant_key)

_ren_totals = _ren_grouped.agg(
    F.sum(F.col(prod_col)).alias("production_mw_sum"),
    F.sum(F.col(inj_col)).alias("grid_injection_mw_sum"),
    # NULL curtailment counts as zero in the hourly sum.
    F.sum(F.coalesce(F.col("curtailment_mw"), F.lit(0.0))).alias("curtailment_mw_sum"),
    F.max(F.col("battery_storage_mwh")).alias("battery_storage_mwh_max"),
)

# Curtailment as a share of production; 0.0 whenever production is not positive.
_curtailment_share = F.col("curtailment_mw_sum") / F.col("production_mw_sum")
ren_hourly = _ren_totals.withColumn(
    "curtailment_rate",
    F.when(F.col("production_mw_sum") > 0, _curtailment_share).otherwise(F.lit(0.0)),
)

gold_renewable_hourly_path = f"{PATH_GOLD}/renewable_kpis_hourly"
save_gold(ren_hourly, "renewable_kpis_hourly", gold_renewable_hourly_path, partition_cols=["event_date"])
audit_table_metrics(f"{DB_GOLD}.renewable_kpis_hourly", "GOLD", "rowcount", str(ren_hourly.count()))

# -----------------------------------------
# 5) Gold: Executive Ops Daily Dashboard
# One row per day / region / state: consumption, meter anomalies and transformer
# risk mix, plus the day's global billing and renewable figures joined on date.
# -----------------------------------------
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator as a column, or 0.0 when the denominator is not > 0."""
    top, bottom = F.col(numerator), F.col(denominator)
    return F.when(bottom > 0, top / bottom).otherwise(F.lit(0.0))


_region_day_keys = ["event_date", "region", "state"]

# Daily energy: roll the hourly meter mart up to day / region / state.
_energy_totals = sm_hourly.groupBy(*_region_day_keys).agg(
    F.sum("kwh_total").alias("kwh_total"),
    F.sum("anomaly_events").alias("meter_anomaly_events"),
    F.sum("reading_events").alias("meter_reading_events"),
)
energy_daily = _energy_totals.withColumn(
    "meter_anomaly_rate",
    _safe_ratio("meter_anomaly_events", "meter_reading_events"),
)

# Daily transformer risk mix: count transformer-hours at each risk level.
_risk_level_counts = [
    F.sum(F.when(F.col("risk_level") == _level, 1).otherwise(0))
     .alias(f"transformers_{_level.lower()}_risk_hours")
    for _level in ("HIGH", "MEDIUM", "LOW")
]
_risk_totals = tel_hourly.groupBy(*_region_day_keys).agg(
    *_risk_level_counts,
    F.count("*").alias("transformer_hours"),
)
risk_daily = _risk_totals.withColumn(
    "high_risk_rate",
    _safe_ratio("transformers_high_risk_hours", "transformer_hours"),
)

# Billing is global (no region): expose bill_date as event_date so it can be joined by day.
_billing_measures = ["tx_count", "amount_total", "amount_paid", "amount_outstanding", "fraud_rate"]
bill_daily_for_join = bill_daily.select(
    F.col("bill_date").alias("event_date"),
    *_billing_measures,
)

# Renewables are global too: collapse the hourly plant rows to one row per day.
_ren_totals_daily = ren_hourly.groupBy("event_date").agg(
    F.sum("production_mw_sum").alias("renewable_production_mw"),
    F.sum("grid_injection_mw_sum").alias("renewable_injection_mw"),
    F.sum("curtailment_mw_sum").alias("renewable_curtailment_mw"),
)
ren_daily = _ren_totals_daily.withColumn(
    "renewable_curtailment_rate",
    _safe_ratio("renewable_curtailment_mw", "renewable_production_mw"),
)

# Energy rows drive the dashboard; everything else is attached with left joins,
# so the global billing / renewable values repeat on every region row of a day.
ops_daily = energy_daily.join(risk_daily, on=_region_day_keys, how="left")
ops_daily = ops_daily.join(bill_daily_for_join, on=["event_date"], how="left")
ops_daily = ops_daily.join(ren_daily, on=["event_date"], how="left")

gold_ops_daily_path = f"{PATH_GOLD}/ops_daily_dashboard"
save_gold(ops_daily, "ops_daily_dashboard", gold_ops_daily_path, partition_cols=["event_date"])
audit_table_metrics(f"{DB_GOLD}.ops_daily_dashboard", "GOLD", "rowcount", str(ops_daily.count()))

# Final confirmation: list the tables now registered in the Gold database.
# The check mark is written as an escape (U+2705) so it cannot be corrupted by a
# wrong file encoding, which is how the previous literal became mojibake.
print("\u2705 Gold marts created:")
spark.sql(f"SHOW TABLES IN {DB_GOLD}").show(truncate=False)