In [0]:
# Databricks notebook: 05_gold/01_build_gold_tables
from pyspark.sql import functions as F

# -----------------------------------------------------------------------------
# 0) Run banner and table names
# -----------------------------------------------------------------------------
print("Building Gold tables from Silver...")

# Fully qualified catalog.schema.table names used by every section below.
# Silver input: the canonical facility-month table that defines the Gold grain.
silver_canon = "tp_finance.silver.contract_facility_monthly_canonical"

# Gold outputs: facility-month fact, contract-month summary, month dimension.
gold_facility_fact, gold_contract_summary, gold_dim_month = (
    "tp_finance.gold.fact_facility_monthly",
    "tp_finance.gold.fact_contract_monthly_summary",
    "tp_finance.gold.dim_month",
)

# -----------------------------------------------------------------------------
# 1) Build/Upsert Gold facility-month fact
#    Reads the Silver canonical table, normalises monetary columns to
#    DECIMAL(18,2) and MERGEs on (customer_id, contract_id, facility_id, month).
# -----------------------------------------------------------------------------
src_fac = spark.table(silver_canon).select(
    "customer_id","contract_id","facility_id","month","currency",
    F.col("drawn_this_month").cast("decimal(18,2)").alias("drawn_this_month"),
    F.col("repaid_this_month").cast("decimal(18,2)").alias("repaid_this_month"),
    F.col("net_movement").cast("decimal(18,2)").alias("net_movement"),
    F.col("opening_balance").cast("decimal(18,2)").alias("opening_balance"),
    F.col("closing_balance").cast("decimal(18,2)").alias("closing_balance"),
    "balance_source",
    F.col("balance_diff").cast("decimal(18,2)").alias("balance_diff"),
    "is_mismatch",
    "is_backfilled",
    "load_ts",
)

# Guard the MERGE grain before writing. Duplicate keys in the source either make
# the MERGE fail (several source rows updating one target row) or, for keys not
# yet in Gold, insert duplicate Gold rows. Fail fast with a clear message instead.
# groupBy groups NULLs together, matching the null-safe MERGE condition below.
dup_key_count = (
    src_fac.groupBy("customer_id", "contract_id", "facility_id", "month")
    .count()
    .filter(F.col("count") > 1)
    .count()
)
if dup_key_count:
    raise RuntimeError(
        f"Silver source {silver_canon} has {dup_key_count} duplicated "
        "(customer_id, contract_id, facility_id, month) keys; refusing to MERGE into Gold"
    )

src_fac.createOrReplaceTempView("stg_gold_facility_monthly")

# `<=>` (null-safe equality) so a row with a NULL key component matches its
# existing Gold row and is updated, instead of never matching under `=` and
# being re-inserted on every run.
spark.sql(f"""
MERGE INTO {gold_facility_fact} t
USING stg_gold_facility_monthly s
ON  t.customer_id <=> s.customer_id
AND t.contract_id <=> s.contract_id
AND t.facility_id <=> s.facility_id
AND t.month       <=> s.month
WHEN MATCHED THEN UPDATE SET *
WHEN NOT MATCHED THEN INSERT *;
""")

fac_count = spark.table(gold_facility_fact).count()
print(f"Gold facility-month fact rows: {fac_count}")

# -----------------------------------------------------------------------------
# 2) Rebuild Gold contract-month summary deterministically
#    INSERT OVERWRITE replaces the whole table in one write. On Delta tables this
#    is a single transaction: readers never see an empty table, and a failed
#    rebuild leaves the previous snapshot in place. TRUNCATE followed by INSERT
#    was two separate commits and gave neither guarantee.
# -----------------------------------------------------------------------------
# NOTE(review): assumes static partition-overwrite mode (the Spark default). If
# this table is partitioned and spark.sql.sources.partitionOverwriteMode is
# "dynamic", only partitions present in the new result would be replaced.
# NOTE(review): MAX(currency) silently picks one value if a contract-month spans
# several currencies, and the SUMs below would then add amounts across
# currencies. Confirm that contracts are single-currency.
spark.sql(f"""
INSERT OVERWRITE TABLE {gold_contract_summary}
SELECT
  customer_id,
  contract_id,
  month,
  MAX(currency) AS currency,

  CAST(SUM(drawn_this_month) AS DECIMAL(18,2))  AS total_drawn,
  CAST(SUM(repaid_this_month) AS DECIMAL(18,2)) AS total_repaid,
  CAST(SUM(net_movement) AS DECIMAL(18,2))      AS net_movement,
  CAST(SUM(closing_balance) AS DECIMAL(18,2))   AS closing_balance,

  COUNT(DISTINCT facility_id)                   AS facilities,
  SUM(CASE WHEN is_mismatch THEN 1 ELSE 0 END)  AS mismatched_rows,

  MAX(load_ts)                                  AS load_ts
FROM {gold_facility_fact}
GROUP BY customer_id, contract_id, month
""")

summary_count = spark.table(gold_contract_summary).count()
print(f"Gold contract-month summary rows: {summary_count}")

# -----------------------------------------------------------------------------
# 3) Rebuild dim_month from gold facility fact bounds (month grain)
# -----------------------------------------------------------------------------
# One-row global aggregate holding the earliest and latest `month` in the Gold
# fact. Both fields come back as None when there is no non-null month.
bounds = (
    spark.table(gold_facility_fact)
    .agg(F.min("month").alias("min_m"), F.max("month").alias("max_m"))
    .first()
)

min_m, max_m = bounds["min_m"], bounds["max_m"]

if min_m is not None and max_m is not None:
    # One row per step of one calendar month, starting at min_m and stopping at
    # or before max_m.
    months = spark.range(1).select(
        F.explode(
            F.sequence(F.lit(min_m), F.lit(max_m), F.expr("interval 1 month"))
        ).alias("month")
    )
    # Derive the calendar attributes from each generated month value.
    months = months.select(
        F.col("month").cast("date").alias("month"),
        F.date_format(F.col("month"), "yyyyMM").cast("int").alias("month_key"),
        F.year(F.col("month")).alias("year"),
        F.month(F.col("month")).alias("month_num"),
        F.date_format(F.col("month"), "MMMM").alias("month_name"),
        F.date_format(F.col("month"), "yyyy-MM").alias("year_month"),
    )

    # Replace data and schema on every run. The dimension is derived entirely
    # from the fact bounds computed above.
    months.write.option("overwriteSchema", "true").saveAsTable(
        gold_dim_month, mode="overwrite"
    )

    dim_count = spark.table(gold_dim_month).count()
    print(f"Gold dim_month rows: {dim_count}")
else:
    print("No months found in gold fact. Skipping dim_month rebuild.")

# -----------------------------------------------------------------------------
# 4) Optional: quick reconciliation assertion (should be 0 mismatches)
#    Recomputes contract-month aggregates straight from the facility fact and
#    compares them with the summary table. A row is counted as bad when:
#      - the contract-month exists on only one side (FULL OUTER JOIN plus
#        presence markers; the previous INNER JOIN ignored such rows), or
#      - any compared amount differs, including NULL vs. non-NULL. The
#        null-safe `<=>` is used because a plain `<>` yields NULL, which is
#        treated as "not bad".
#    Join keys also use `<=>`, so NULL key groups on both sides pair up.
# -----------------------------------------------------------------------------
recon = spark.sql(f"""
SELECT count(*) AS bad_rows
FROM (
  SELECT customer_id, contract_id, month,
         total_drawn, total_repaid, closing_balance,
         1 AS in_summary
  FROM {gold_contract_summary}
) s
FULL OUTER JOIN (
  SELECT customer_id, contract_id, month,
         CAST(SUM(drawn_this_month) AS DECIMAL(18,2))  AS total_drawn_check,
         CAST(SUM(repaid_this_month) AS DECIMAL(18,2)) AS total_repaid_check,
         CAST(SUM(closing_balance) AS DECIMAL(18,2))   AS closing_balance_check,
         1 AS in_fact
  FROM {gold_facility_fact}
  GROUP BY customer_id, contract_id, month
) f
ON  s.customer_id <=> f.customer_id
AND s.contract_id <=> f.contract_id
AND s.month       <=> f.month
WHERE s.in_summary IS NULL
   OR f.in_fact IS NULL
   OR NOT (s.total_drawn     <=> f.total_drawn_check)
   OR NOT (s.total_repaid    <=> f.total_repaid_check)
   OR NOT (s.closing_balance <=> f.closing_balance_check)
""").collect()[0]["bad_rows"]

print("Reconciliation bad_rows:", recon)
if recon != 0:
    # RuntimeError subclasses Exception, so existing `except Exception` handlers
    # still catch it.
    raise RuntimeError(
        f"Gold reconciliation failed: {recon} mismatched or missing contract-month rows"
    )

print("Gold build complete.")