In [0]:
from pyspark.sql import Window
from pyspark.sql import functions as F

# Mismatch tolerance for closing balances, as a decimal(18,2) literal column
# (built from exact decimal text rather than a float literal).
TOL = F.lit("1.00").cast("decimal(18,2)")

# LTI-derived month series (already gap-filled). Every amount column is
# normalised to decimal(18,2) and given an "lti_" prefix so it stays distinct
# from the historic snapshot's columns after the join.
lti = spark.table("tp_finance.silver.facility_monthly_balance").select(
    "customer_id", "contract_id", "facility_id", "month",
    *[
        F.col(amount_col).cast("decimal(18,2)").alias(f"lti_{amount_col}")
        for amount_col in (
            "drawn_this_month",
            "repaid_this_month",
            "net_movement",
            "opening_balance",
            "closing_balance",
        )
    ],
    F.col("source_system").alias("lti_source_system"),
)

# Historic monthly snapshot. Currency is carried through as-is; amount columns
# are cast to decimal(18,2) and prefixed with "hist_" so they do not collide
# with the LTI columns after the join.
hist = spark.table("tp_finance.bronze.sp_facility_monthly_snapshot").select(
    "customer_id", "contract_id", "facility_id", "month",
    F.col("currency").alias("currency"),
    *(
        F.col(amount_col).cast("decimal(18,2)").alias("hist_" + amount_col)
        for amount_col in (
            "opening_balance",
            "closing_balance",
            "drawn_this_month",
            "repaid_this_month",
        )
    ),
    F.col("source_system").alias("hist_source_system"),
)

# Full outer join on the facility-month key so that months present in only one
# of the two sources are kept (their other-side columns come through as NULL).
j = lti.join(
    hist,
    on=["customer_id", "contract_id", "facility_id", "month"],
    how="full_outer",
)

# Canonical facility-month rows built from the LTI / historic-snapshot join `j`.
# Column by column the LTI value is preferred and the snapshot fills gaps, so a
# row can mix both sources when an individual LTI column is NULL.
# NOTE(review): confirm per-column mixing (rather than one source per row) is
# intended.
_facility_by_month = (
  Window.partitionBy("customer_id", "contract_id", "facility_id").orderBy("month")
)

canon = (
  j.withColumn(
      # Only the historic snapshot supplies a currency, so months present only
      # in LTI would otherwise be written with a NULL currency. Use the
      # facility's nearest earlier known currency, then the nearest later one
      # for months that precede the first snapshot row.
      # NOTE(review): assumes a facility's currency does not change over time;
      # the closing-balance comparison below already treats both sources as
      # being in the same units.
      "currency",
      F.coalesce(
          F.col("currency"),
          F.last("currency", ignorenulls=True).over(
              _facility_by_month.rowsBetween(Window.unboundedPreceding, Window.currentRow)
          ),
          F.first("currency", ignorenulls=True).over(
              _facility_by_month.rowsBetween(Window.currentRow, Window.unboundedFollowing)
          ),
      ),
  )
  .withColumn(
      # Records which source supplied closing_balance.
      "balance_source",
      F.when(F.col("lti_closing_balance").isNotNull(), F.lit("LTI_TX"))
       .when(F.col("hist_closing_balance").isNotNull(), F.lit("HIST_SNAPSHOT"))
       .otherwise(F.lit("UNKNOWN"))
  )
  # Flows default to 0 when neither source has a value.
  .withColumn("drawn_this_month",
      F.coalesce(F.col("lti_drawn_this_month"), F.col("hist_drawn_this_month"), F.lit(0)).cast("decimal(18,2)")
  )
  .withColumn("repaid_this_month",
      F.coalesce(F.col("lti_repaid_this_month"), F.col("hist_repaid_this_month"), F.lit(0)).cast("decimal(18,2)")
  )
  # Recomputed from the canonical flows; lti_net_movement is not used.
  .withColumn("net_movement",
      (F.col("drawn_this_month") - F.col("repaid_this_month")).cast("decimal(18,2)")
  )
  # Balances have no default: NULL when neither source has a value.
  .withColumn("opening_balance",
      F.coalesce(F.col("lti_opening_balance"), F.col("hist_opening_balance")).cast("decimal(18,2)")
  )
  .withColumn("closing_balance",
      F.coalesce(F.col("lti_closing_balance"), F.col("hist_closing_balance")).cast("decimal(18,2)")
  )
  .withColumn(
      # Absolute closing-balance difference, only when both sources report one.
      "balance_diff",
      F.when(
          F.col("lti_closing_balance").isNotNull() & F.col("hist_closing_balance").isNotNull(),
          F.abs(F.col("lti_closing_balance") - F.col("hist_closing_balance")).cast("decimal(18,2)")
      ).otherwise(F.lit(None).cast("decimal(18,2)"))
  )
  .withColumn(
      # Strictly above TOL counts as a mismatch; a NULL diff (one side missing)
      # is not a mismatch.
      "is_mismatch",
      F.coalesce(F.col("balance_diff") > TOL, F.lit(False))
  )
  .withColumn(
      # Month known only from the historic snapshot.
      "is_backfilled",
      F.col("lti_closing_balance").isNull() & F.col("hist_closing_balance").isNotNull()
  )
  .withColumn("load_ts", F.current_timestamp())
  # Column set and order must match the MERGE target (UPDATE SET * / INSERT *).
  .select(
      "customer_id","contract_id","facility_id","month","currency",
      "drawn_this_month","repaid_this_month","net_movement",
      "opening_balance","closing_balance",
      "lti_opening_balance","lti_closing_balance",
      "hist_opening_balance","hist_closing_balance",
      "balance_source","balance_diff","is_mismatch","is_backfilled",
      "load_ts"
  )
)

# Delta MERGE needs at most one source row per target key. A duplicated key that
# matches a target row makes the MERGE fail; one that does not match is inserted
# twice, silently breaking key uniqueness in the canonical table. The full outer
# join that feeds `canon` multiplies rows whenever either input repeats a key,
# so check up front (one extra aggregation over `canon`) and fail with a
# readable error before touching the target.
_merge_keys = ["customer_id", "contract_id", "facility_id", "month"]
_duplicate_key_sample = (
    canon.groupBy(*_merge_keys)
    .count()
    .filter(F.col("count") > 1)
    .limit(5)
    .collect()
)
if _duplicate_key_sample:
    raise ValueError(
        f"stg_canon has duplicate merge keys {_merge_keys}; refusing to MERGE. "
        f"Sample: {[row.asDict() for row in _duplicate_key_sample]}"
    )

canon.createOrReplaceTempView("stg_canon")

# Upsert on the facility-month key: matched rows are fully overwritten, new keys inserted.
spark.sql("""
MERGE INTO tp_finance.silver.contract_facility_monthly_canonical t
USING stg_canon s
ON  t.customer_id = s.customer_id
AND t.contract_id = s.contract_id
AND t.facility_id = s.facility_id
AND t.month       = s.month
WHEN MATCHED THEN UPDATE SET *
WHEN NOT MATCHED THEN INSERT *
""")