In [0]:
from pyspark.sql.functions import * 
from pyspark.sql.types import * 
from delta.tables import DeltaTable 

###### Parameters

In [0]:
# Job parameters (Databricks widgets).
#   mode     : "initial" or "incremental" (default "incremental").
#   run_date : lower bound on Silver trip_date in incremental mode (default "2024-01-01").
dbutils.widgets.dropdown("mode","incremental",["initial","incremental"])
dbutils.widgets.text("run_date","2024-01-01")
mode = dbutils.widgets.get("mode").lower().strip()
run_date = dbutils.widgets.get("run_date").strip()

# Fail fast on bad parameters instead of running a wrong load:
#  - Any other mode value would skip the incremental filter, yet still take the
#    MERGE path when the Gold table exists, silently re-merging all of Silver.
#  - In incremental mode run_date goes through to_date(). Check it here with the
#    same expression the filter uses, so that a NULL result (non-ANSI Spark)
#    cannot turn the run into an empty no-op. With ANSI enabled, to_date() itself
#    raises on an invalid value, which also fails fast.
VALID_MODES = ("initial", "incremental")
if mode not in VALID_MODES:
    raise ValueError(f"Invalid mode {mode!r}; expected one of {VALID_MODES}.")

if mode == "incremental":
    parsed_run_date = (
        spark.range(1)
        .select(to_date(lit(run_date)).alias("d"))
        .first()["d"]
    )
    if parsed_run_date is None:
        raise ValueError(f"Invalid run_date {run_date!r}; expected a date such as 2024-01-01.")

###### Table names

In [0]:
# Three-part table names (catalog.schema.table) read and written by this notebook.
SILVER_TRIPS = "databricks_catalog.silver.NYC_Taxi_Trips"  # source: Silver trips
DIM_LOC_SCD2 = "databricks_catalog.gold.dim_location_scd2"  # lookup: SCD2 location dimension
GOLD_FACT = "databricks_catalog.gold.fact_trip"  # target: Gold trip fact

#### Data Reading

In [0]:
# Load the Silver trips table as a (lazy) DataFrame.
df_silver = spark.read.table(SILVER_TRIPS)

In [0]:
# Fail fast if Silver lacks the columns this notebook depends on:
#  - trip_key: the fact grain and MERGE key (expected: 1 row per trip).
#  - trip_date: used by the incremental filter below and as the partition
#    column of the Gold table on a full load.
if "trip_key" not in df_silver.columns:
    raise Exception("trip_key is missing in Silver. Create it in Silver first (1 row per trip).")
if "trip_date" not in df_silver.columns:
    raise Exception("trip_date is missing in Silver. It is required for incremental filtering and partitioning.")

# Incremental mode: only process trips with trip_date >= run_date.
# This reassigns df_silver; re-running the cell re-applies the same filter,
# which leaves the result unchanged.
if mode == "incremental":
    df_silver = df_silver.filter(col("trip_date") >= to_date(lit(run_date)))

#### Read Dimension

In [0]:
# Load the SCD2 location dimension as a (lazy) DataFrame.
df_loc = spark.read.table(DIM_LOC_SCD2)

In [0]:
# Project the SCD2 dimension down to the columns the fact joins need.
# Rows with a NULL effective_to (presumably the current version; confirm against
# the SCD2 loader) are closed with a far-future timestamp. This lets every
# version be matched with the half-open [effective_from, effective_to_filled)
# range used in the joins.
# NOTE: this reassigns df_loc and drops LocationID/effective_to, so re-run the
# previous cell before re-running this one.
df_loc = df_loc.select(
    col("LocationID").alias("location_id"),
    col("location_sk"),
    col("effective_from"),
    coalesce(
        col("effective_to"),
        lit("2999-12-31").cast("timestamp"),
    ).alias("effective_to_filled"),
)

##### Join trips to dim_location_scd2 twice (pickup and dropoff)

In [0]:
def _scd2_location_condition(dim_alias, trip_location_col, trip_ts_col):
    """Build the join condition matching a trip to the dim_location_scd2 version valid at a timestamp.

    A trip row (alias "t") matches a dimension row (alias ``dim_alias``) when the
    location IDs are equal and the trip timestamp falls in the half-open interval
    [effective_from, effective_to_filled).

    Args:
        dim_alias: alias given to the dimension DataFrame in the join ("pu" / "do").
        trip_location_col: location ID column on the trip side.
        trip_ts_col: trip timestamp column used to pick the SCD2 version.

    Returns:
        A boolean Column usable as a join condition.
    """
    trip_ts = col(f"t.{trip_ts_col}")
    return (
        (col(f"t.{trip_location_col}") == col(f"{dim_alias}.location_id"))
        & (trip_ts >= col(f"{dim_alias}.effective_from"))
        & (trip_ts < col(f"{dim_alias}.effective_to_filled"))
    )


# Resolve pickup and dropoff locations to the SCD2 version valid at pickup and
# dropoff time respectively. Left joins keep every trip. An unmatched lookup
# leaves the surrogate key NULL, and overlapping dimension versions would
# multiply trip rows; the data quality gate checks for both.
fact_joined = (
    df_silver.alias("t")
    .join(
        df_loc.alias("pu"),
        _scd2_location_condition("pu", "PULocationID", "tpep_pickup_datetime"),
        "left",
    )
    .join(
        df_loc.alias("do"),
        _scd2_location_condition("do", "DOLocationID", "tpep_dropoff_datetime"),
        "left",
    )
)


#### Select final fact columns

In [0]:
# Final fact projection: trip key and timestamps, measures, low-cardinality
# attributes, and the SCD2-resolved location surrogate keys. The surrogate keys
# are NULL when no dimension version matched (left joins), which the data
# quality gate rejects.
fact = fact_joined.select(
    col("t.trip_key"),
    col("t.tpep_pickup_datetime").alias("pickup_ts"),
    col("t.tpep_dropoff_datetime").alias("dropoff_ts"),
    col("t.trip_date"),

    # Measures
    col("t.trip_distance"),
    col("t.passenger_count"),
    col("t.fare_amount"),
    col("t.extra"),
    col("t.mta_tax"),
    col("t.tip_amount"),
    col("t.tolls_amount"),
    col("t.improvement_surcharge"),
    col("t.total_amount"),

    # Low-cardinality attributes (optional: keep them here or move to dims)
    col("t.payment_type"),
    col("t.store_and_fwd_flag"),

    # Foreign keys (SCD2-resolved)
    col("pu.location_sk").alias("pickup_location_sk"),
    col("do.location_sk").alias("dropoff_location_sk"),
)

###### Data quality gate (fail fast)

In [0]:
# Data quality gate: fail the run before anything is written to Gold.

# Null checks, computed in a single pass over `fact`:
#  - trip_key must be non-NULL. NULL never satisfies the MERGE condition
#    "trg.trip_key = src.trip_key", so such rows would be inserted again on every
#    incremental run. Also, groupBy puts all NULL keys into one group, which the
#    duplicate check below would misreport as a single duplicate key.
#  - pickup/dropoff location_sk must be non-NULL. The SCD2 lookups are left
#    joins, so a NULL here means no dimension version matched the trip.
null_stats = fact.agg(
    count(when(col("trip_key").isNull(), True)).alias("null_trip_key_cnt"),
    count(
        when(col("pickup_location_sk").isNull() | col("dropoff_location_sk").isNull(), True)
    ).alias("null_fk_cnt"),
).first()
null_key_cnt = null_stats["null_trip_key_cnt"]
null_fk_cnt = null_stats["null_fk_cnt"]

if null_key_cnt > 0:
    raise Exception(f"Data Quality FAILED: fact_trip has {null_key_cnt} rows with null trip_key.")

# No duplicates on trip_key (expected grain: 1 row per trip). Duplicates can come
# from Silver itself or from overlapping SCD2 versions matching the same trip.
# The MERGE below also needs at most one source row per key.
dup_cnt = (
    fact.groupBy("trip_key")
        .count()
        .filter(col("count") > 1)
        .count()
)
if dup_cnt > 0:
    raise Exception(f"Data Quality FAILED: fact_trip has {dup_cnt} duplicate trip_key values.")

# No null foreign keys (at least pickup/dropoff location)
if null_fk_cnt > 0:
    raise Exception(f"Data Quality FAILED: {null_fk_cnt} rows have null pickup/dropoff location_sk.")

#### Write Gold fact (initial: overwrite, incremental: MERGE upsert on trip_key)

In [0]:
# Persist the batch to the Gold fact table.
#  - Full load when the table does not exist yet or mode == "initial":
#    overwrite (replacing the schema), partitioned by trip_date.
#  - Otherwise MERGE on trip_key: matching keys get all source columns updated,
#    and new keys are inserted. Because this is an upsert, re-processing a date
#    range does not duplicate rows with non-NULL keys.
# NOTE(review): the MERGE condition has no trip_date predicate, so Delta may scan
# the whole target. Consider partition pruning if trip_date is stable per trip_key.
needs_full_load = not spark.catalog.tableExists(GOLD_FACT) or mode == "initial"

if needs_full_load:
    fact_writer = (
        fact.write.format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
        .partitionBy("trip_date")
    )
    fact_writer.saveAsTable(GOLD_FACT)
else:
    gold_fact_table = DeltaTable.forName(spark, GOLD_FACT)
    upsert = gold_fact_table.alias("trg").merge(
        fact.alias("src"),
        "trg.trip_key = src.trip_key",
    )
    upsert.whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()

In [0]:
# Run summary for the job log.
# NOTE: fact.count() re-executes the whole read/join pipeline; it is not reused
# from the write above.
print("DONE")
print("mode =", mode, "| run_date =", run_date)

source_row_count = fact.count()
print("source rows processed =", source_row_count)

target_row_count = spark.read.table(GOLD_FACT).count()
print("target total rows     =", target_row_count)