In [0]:
# Databricks notebook source
# =========================================
# Notebook: 05_dq_audit_report
# Purpose : Data Quality + Audit Report (judge-friendly)
# =========================================
%run ./01_config
from pyspark.sql import functions as F

def write_dq(dataset, layer, rule_name, severity, df_violations, key_cols):
    """
    Append one aggregated DQ violation row to ``{DB_AUDIT}.dq_violations``.

    Nothing is written when ``df_violations`` has no rows.

    Args:
        dataset: Dataset name stored in the ``dataset`` column.
        layer: Layer label stored in the ``layer`` column (e.g. "SILVER", "GOLD").
        rule_name: Identifier of the violated rule.
        severity: Severity label (e.g. "CRITICAL", "WARNING").
        df_violations: Spark DataFrame holding only the violating rows.
        key_cols: Columns to sample for ``sample_keys``; names that are not
            present in ``df_violations`` are ignored.

    The stored ``sample_keys`` is up to 20 rows of the present key columns,
    cast to string and rendered as CSV (header first) with lines joined by
    " | ". It is an empty string when none of ``key_cols`` exist.
    """
    cnt = df_violations.count()
    if cnt == 0:
        return

    # Keep only key columns that actually exist. Selecting zero columns would
    # yield a column-less frame whose CSV rendering carries no information.
    available = set(df_violations.columns)
    present_cols = [c for c in key_cols if c in available]

    sample_keys = ""
    if present_cols:
        # sample up to 20 keys as a compact string
        sample = (df_violations
            .select(*[F.col(c).cast("string") for c in present_cols])
            .limit(20)
            .toPandas()
        )
        if len(sample) > 0:
            sample_keys = sample.to_csv(index=False).strip().replace("\n", " | ")

    out = spark.createDataFrame([(
        BATCH_ID, layer, dataset, rule_name, severity, int(cnt), sample_keys
    )], schema="batch_id string, layer string, dataset string, rule_name string, severity string, violation_count long, sample_keys string")

    (out
     .withColumn("created_ts", F.current_timestamp())
     .write.format("delta").mode("append").saveAsTable(f"{DB_AUDIT}.dq_violations")
    )

def metric(table_fqn, layer, name, value):
    """Record one audit metric through ``audit_table_metrics``, passing the value as ``str``."""
    value_text = str(value)
    audit_table_metrics(table_fqn, layer, name, value_text)

# -----------------------------------------
# Load tables
# -----------------------------------------
def _silver(table_name):
    """Return the table ``{DB_SILVER}.<table_name>`` as a DataFrame."""
    return spark.table(f"{DB_SILVER}.{table_name}")


sm = _silver("smart_meter_readings")
tel = _silver("substation_telemetry")
bill = _silver("billing_transactions_current")
cust = _silver("dim_customer_current")
trf = _silver("dim_transformer")  # full history (all versions, incl. non-current)
trf_cur = _silver("dim_transformer_current")
gold_ops = spark.table(f"{DB_GOLD}.ops_daily_dashboard")

# -----------------------------------------
# 1) Core DQ Rules — Smart Meter
# -----------------------------------------
# Each entry: (rule_name, severity, violation predicate, key columns to sample).
# Predicates are unresolved Column expressions; the filter is applied per rule
# inside the loop, so rules are evaluated and written in this exact order.
_SMART_METER_RULES = [
    ("PK_NOT_NULL_event_id", "CRITICAL",
     F.col("event_id").isNull() | (F.trim("event_id") == ""),
     ["meter_id", "event_ts", "region"]),
    ("PK_NOT_NULL_meter_id", "CRITICAL",
     F.col("meter_id").isNull() | (F.trim("meter_id") == ""),
     ["event_id", "event_ts", "region"]),
    ("NEGATIVE_KWH", "WARNING",
     F.col("kwh_consumed") < 0,
     ["event_id", "meter_id", "kwh_consumed"]),
    ("INVALID_TIMESTAMP", "CRITICAL",
     F.col("event_ts").isNull(),
     ["event_id", "meter_id"]),
    # Voltage plausibility (broad bounds)
    ("VOLTAGE_OUT_OF_RANGE", "WARNING",
     (F.col("voltage_v") < 150) | (F.col("voltage_v") > 300),
     ["event_id", "meter_id", "voltage_v", "region"]),
]

for _rule_name, _severity, _predicate, _keys in _SMART_METER_RULES:
    write_dq(
        dataset="smart_meter_readings",
        layer="SILVER",
        rule_name=_rule_name,
        severity=_severity,
        df_violations=sm.filter(_predicate),
        key_cols=_keys,
    )

# -----------------------------------------
# 2) Core DQ Rules — Telemetry
# -----------------------------------------
def _check_telemetry(rule_name, severity, condition, key_cols):
    """Record rows of ``tel`` matching ``condition`` as a SILVER telemetry violation."""
    write_dq(
        dataset="substation_telemetry",
        layer="SILVER",
        rule_name=rule_name,
        severity=severity,
        df_violations=tel.filter(condition),
        key_cols=key_cols,
    )


_check_telemetry(
    "PK_NOT_NULL_telemetry_id", "CRITICAL",
    F.col("telemetry_id").isNull() | (F.trim("telemetry_id") == ""),
    ["transformer_id", "telemetry_ts", "region"],
)
_check_telemetry(
    "PK_NOT_NULL_transformer_id", "CRITICAL",
    F.col("transformer_id").isNull() | (F.trim("transformer_id") == ""),
    ["telemetry_id", "telemetry_ts", "region"],
)
_check_telemetry(
    "INVALID_TIMESTAMP", "CRITICAL",
    F.col("telemetry_ts").isNull(),
    ["telemetry_id", "transformer_id", "region"],
)
# Oil temp delta plausibility (loose upper bound only)
_check_telemetry(
    "OIL_TEMP_DELTA_IMPLAUSIBLE", "WARNING",
    F.col("oil_temp_delta_c") > 80,
    ["telemetry_id", "transformer_id", "oil_temp_delta_c", "region"],
)

# -----------------------------------------
# 3) Core DQ Rules — Billing (current-state)
# -----------------------------------------
write_dq(
    dataset="billing_transactions_current",
    layer="SILVER",
    rule_name="PK_NOT_NULL_transaction_id",
    severity="CRITICAL",
    df_violations=bill.filter(
        F.col("transaction_id").isNull() | (F.trim("transaction_id") == "")
    ),
    key_cols=["customer_id", "cdc_ts"],
)

# Amount due plausibility (if present): use the first amount-like column that
# exists in the table, or skip the rule entirely when none does.
_bill_columns = bill.columns
amt_col = next(
    (name for name in ("amount_due", "amount", "total_amount", "bill_amount")
     if name in _bill_columns),
    None,
)

if amt_col is not None:
    write_dq(
        dataset="billing_transactions_current",
        layer="SILVER",
        rule_name="NEGATIVE_AMOUNT",
        severity="WARNING",
        df_violations=bill.filter(F.col(amt_col) < 0),
        key_cols=["transaction_id", "customer_id", amt_col],
    )

# -----------------------------------------
# 4) SCD2 sanity checks — Transformer dimension
# -----------------------------------------
# Rule: Only one current record per transformer_id
multi_current = (trf
    .filter("is_current = true")
    .groupBy("transformer_id")
    .count()
    .filter("count > 1")
)

write_dq(
    dataset="dim_transformer",
    layer="SILVER",
    rule_name="SCD2_MULTIPLE_CURRENT_RECORDS",
    severity="CRITICAL",
    df_violations=multi_current,
    key_cols=["transformer_id","count"]
)

# Rule: current record must have end_date = 9999-12-31 (if column exists).
# The null-safe comparison (<=>) matters here: with plain "<>", a NULL
# record_end_date yields NULL, so such rows would silently pass the check even
# though they do not carry the expected max end date.
if "record_end_date" in trf.columns:
    wrong_end = trf.filter(
        "is_current = true AND NOT (record_end_date <=> to_date('9999-12-31'))"
    )
    write_dq(
        dataset="dim_transformer",
        layer="SILVER",
        rule_name="SCD2_CURRENT_END_DATE_NOT_MAX",
        severity="WARNING",
        df_violations=wrong_end,
        key_cols=["transformer_id","record_end_date","scd_version"]
    )

# -----------------------------------------
# 5) Join Coverage / Referential sanity checks (Gold readiness)
# Customer ↔ meter coverage is computed only if the customer dim has meter_id.
# -----------------------------------------
def _join_coverage(left_df, right_df, key):
    """
    Count how many distinct ``key`` values of ``left_df`` also appear in ``right_df``.

    NULL keys in ``left_df`` are excluded. They can never match in an
    equi-join and are already reported by the PK_NOT_NULL_* rules, so counting
    them would only deflate the coverage rate.

    Returns:
        (total, hits): the number of distinct non-null keys in ``left_df`` and
        how many of them exist in ``right_df``.
    """
    left_keys = left_df.select(key).where(F.col(key).isNotNull()).distinct()
    # distinct() on the right side prevents the left join from fanning out.
    right_keys = right_df.select(key).distinct().withColumn("hit", F.lit(1))
    row = (left_keys
        .join(right_keys, key, "left")
        .agg(
            F.count("*").alias("total"),
            F.sum(F.coalesce("hit", F.lit(0))).alias("hits"),
        )
        .collect()[0]
    )
    # SUM over zero rows is NULL; normalise so callers can do arithmetic.
    return row["total"], row["hits"] or 0


if "meter_id" in cust.columns:
    # measure coverage as: meters in readings that exist in customer dim
    meters_in_readings, meters_with_customer = _join_coverage(sm, cust, "meter_id")
    cov_rate = meters_with_customer / meters_in_readings if meters_in_readings else 0.0

    metric(f"{DB_SILVER}.smart_meter_readings", "SILVER", "meter_to_customer_join_coverage", f"{cov_rate:.4f}")

# telemetry ↔ transformer join coverage (against dim_transformer_current)
tc_total, tc_hit = _join_coverage(tel, trf_cur, "transformer_id")
tc_rate = tc_hit / tc_total if tc_total else 0.0
metric(f"{DB_SILVER}.substation_telemetry", "SILVER", "telemetry_to_transformer_join_coverage", f"{tc_rate:.4f}")

# -----------------------------------------
# 6) Pipeline health summary metrics (min/max timestamps)
# -----------------------------------------
def _ts_bounds(df, ts_col):
    """Return the single aggregate Row with ``min_ts`` / ``max_ts`` of ``ts_col``."""
    bounds = df.agg(
        F.min(ts_col).alias("min_ts"),
        F.max(ts_col).alias("max_ts"),
    )
    return bounds.collect()[0]


# Collect both tables' bounds before writing any metric.
sm_ts = _ts_bounds(sm, "event_ts")
tel_ts = _ts_bounds(tel, "telemetry_ts")

for _fqn, _min_name, _max_name, _row in (
    (f"{DB_SILVER}.smart_meter_readings", "min_event_ts", "max_event_ts", sm_ts),
    (f"{DB_SILVER}.substation_telemetry", "min_telemetry_ts", "max_telemetry_ts", tel_ts),
):
    metric(_fqn, "SILVER", _min_name, _row["min_ts"])
    metric(_fqn, "SILVER", _max_name, _row["max_ts"])

# -----------------------------------------
# 7) Gold sanity checks (ops dashboard not empty)
# -----------------------------------------
ops_cnt = gold_ops.count()
metric(f"{DB_GOLD}.ops_daily_dashboard", "GOLD", "rowcount", ops_cnt)

if ops_cnt == 0:
    # write_dq() returns without writing when its DataFrame has zero rows, so
    # passing the (empty) gold table itself would never record this violation.
    # Pass a one-row marker naming the empty table instead, which stores
    # exactly one CRITICAL violation for this batch.
    empty_marker = spark.createDataFrame(
        [(f"{DB_GOLD}.ops_daily_dashboard",)],
        schema="table_fqn string",
    )
    write_dq(
        dataset="ops_daily_dashboard",
        layer="GOLD",
        rule_name="GOLD_EMPTY_TABLE",
        severity="CRITICAL",
        df_violations=empty_marker,
        key_cols=["table_fqn"],
    )

# -----------------------------------------
# 8) Print judge-friendly report
# -----------------------------------------
# Audit queries scoped to the current batch; each section prints a heading and
# then shows up to 200 rows without truncation.
_dq_violations_sql = f"""
SELECT severity, dataset, rule_name, violation_count, sample_keys
FROM {DB_AUDIT}.dq_violations
WHERE batch_id = '{BATCH_ID}'
ORDER BY
  CASE severity WHEN 'CRITICAL' THEN 1 WHEN 'WARNING' THEN 2 ELSE 3 END,
  dataset, rule_name
"""

_table_metrics_sql = f"""
SELECT layer, table_fqn, metric_name, metric_value
FROM {DB_AUDIT}.table_metrics
WHERE batch_id = '{BATCH_ID}'
ORDER BY layer, table_fqn, metric_name
"""

for _banner_line in ("\n======================", "✅ DATA QUALITY REPORT", "======================\n"):
    print(_banner_line)

print("Latest run batch_id:", BATCH_ID)

for _heading, _query in (
    ("\n-- DQ Violations (this run) --", _dq_violations_sql),
    ("\n-- Pipeline Metrics (this run) --", _table_metrics_sql),
):
    print(_heading)
    spark.sql(_query).show(200, truncate=False)

print("\n-- Gold Executive Snapshot (sample) --")
_gold_snapshot = spark.table(f"{DB_GOLD}.ops_daily_dashboard").orderBy(F.col("event_date").desc())
_gold_snapshot.show(30, truncate=False)

print("\n✅ DQ + audit report completed.")