In [0]:
# Databricks notebook source
import dlt
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window

# Make midterm.source1_layer the session's default catalog and schema.
for _use_stmt in ("USE CATALOG midterm", "USE SCHEMA source1_layer"):
    spark.sql(_use_stmt)

# ============================================
# DIM_RESTAURANT (SCD TYPE 2)
# ============================================

@dlt.table(name="bronze_restaurant_cdf")
def bronze_restaurant_cdf():
    """Build the restaurant-attribute feed for the SCD2 restaurant dimension.

    Reads ``midterm.source1_layer.silver_table`` (batch read) and keeps the
    restaurant attribute columns plus ``FileName``. It then strips a trailing
    ``_DALLAS`` / ``_CHICAGO`` / ``_Dallas`` / ``_Chicago`` suffix from
    ``License_No`` and de-duplicates on the attribute columns. ``FileName``
    is not part of the dedupe subset, so whichever row Spark keeps supplies
    it. Finally it adds CDC metadata: ``_commit_timestamp`` is the query's
    ``current_timestamp()``, which is the same value for every row, and
    ``_change_type`` is always ``"insert"``.

    NOTE: the suffix stripping must stay identical to the ``clean_license``
    derivation in ``fact_inspection``, or restaurant keys will not match.
    NOTE(review): because all rows share one ``_commit_timestamp``, rows with
    the same (License_No, City) but different attributes reach
    ``apply_changes`` with equal sequence values; confirm which should win.
    """
    attribute_cols = ["License_No", "DBA_Name", "AKA_Name",
                      "Facility_Type", "Risk_Category", "City"]

    # Apply the suffix patterns one after another, in the original order.
    # The order matters for values that carry more than one suffix.
    license_no = col("License_No")
    for suffix in ("_DALLAS$", "_CHICAGO$", "_Dallas$", "_Chicago$"):
        license_no = regexp_replace(license_no, suffix, "")

    restaurants = (
        spark.read.table("midterm.source1_layer.silver_table")
        .select(*attribute_cols, "FileName")
        .withColumn("License_No", license_no)
        .dropDuplicates(attribute_cols)
    )

    return (restaurants
            .withColumn("_commit_timestamp", current_timestamp())
            .withColumn("_change_type", lit("insert")))

@dlt.view
def silver_restaurant_cdf():
    """Streaming view that standardizes the bronze restaurant feed.

    * ``License_No``, ``DBA_Name`` and ``Facility_Type`` are trimmed.
    * ``AKA_Name`` is trimmed, or replaced by the trimmed ``DBA_Name`` when
      it is null or blank.
    * ``Risk_Category`` is trimmed, or ``"UNKNOWN"`` when null or blank.
    * ``City`` is upper-cased and trimmed.

    This view is the ``source`` of the SCD2 ``apply_changes`` flow.

    NOTE(review): ``bronze_restaurant_cdf`` is built with a batch
    ``spark.read`` while this view streams from it. Confirm the bronze table
    is only ever appended to, because a rewritten table is generally not a
    valid streaming source.
    """

    def is_blank(name):
        # True when the column is null or contains only whitespace.
        return col(name).isNull() | (trim(col(name)) == "")

    aka_name = (when(is_blank("AKA_Name"), trim(col("DBA_Name")))
                .otherwise(trim(col("AKA_Name"))))
    risk_category = (when(is_blank("Risk_Category"), lit("UNKNOWN"))
                     .otherwise(trim(col("Risk_Category"))))

    standardized = spark.readStream.table("LIVE.bronze_restaurant_cdf")
    for name, expression in [
        ("License_No", trim(col("License_No"))),
        ("DBA_Name", trim(col("DBA_Name"))),
        ("AKA_Name", aka_name),
        ("Facility_Type", trim(col("Facility_Type"))),
        ("Risk_Category", risk_category),
        ("City", trim(upper(col("City")))),
    ]:
        standardized = standardized.withColumn(name, expression)
    return standardized

# SCD Type 2 staging table keyed on (License_No, City), ordered by the
# `_commit_timestamp` stamped in bronze_restaurant_cdf.
dlt.create_streaming_table(name="restaurant_cdf_type2_stage")

dlt.apply_changes(
    target="restaurant_cdf_type2_stage",
    source="silver_restaurant_cdf",
    keys=["License_No", "City"],
    sequence_by=col("_commit_timestamp"),
    ignore_null_updates=True,
    # Upstream `_change_type` is always "insert", so this currently never
    # matches; it is kept so real delete events can be fed through later.
    apply_as_deletes=expr("_change_type = 'delete'"),
    # Keep CDC bookkeeping columns out of the target. `_commit_timestamp` is
    # re-stamped with the load time on every run. As a stored (and therefore
    # history-tracked) column, it would make every reload look like an
    # attribute change. The sequencing value is still exposed through
    # __START_AT / __END_AT.
    # NOTE(review): existing deployments may need a full refresh of the target
    # after this schema change. FileName is still tracked; consider
    # track_history_except_column_list=["FileName"] if it causes spurious versions.
    except_column_list=["_commit_timestamp", "_change_type"],
    stored_as_scd_type=2,
)

@dlt.view
def silver_restaurant_dim_final():
    """Reshape the SCD2 stage table into dimension-ready columns.

    ``__START_AT`` and ``__END_AT`` are timestamps, because the ``sequence_by``
    column is a timestamp. They are truncated to dates:

    * ``effective_date`` -- date of ``__START_AT``
    * ``end_date`` -- date of ``__END_AT``, or 9999-12-31 for the open version
    * ``is_current`` -- true exactly when ``__END_AT`` is null

    The view also adds job-audit columns, renames source columns to
    snake_case, and drops the SCD/CDC bookkeeping columns. ``drop`` ignores
    names that are not present.
    """
    still_open = col("__END_AT").isNull()

    versions = (spark.read.table("LIVE.restaurant_cdf_type2_stage")
        .withColumn("effective_date", to_date(col("__START_AT")))
        .withColumn("end_date", when(still_open, to_date(lit("9999-12-31")))
                                .otherwise(to_date(col("__END_AT"))))
        .withColumn("is_current", when(still_open, lit(True)).otherwise(lit(False)))
        .withColumn("job_load_id", lit("dim_restaurant_scd2"))
        .withColumn("job_load_date", current_timestamp()))

    snake_case_names = {
        "License_No": "license_no",
        "DBA_Name": "dba_name",
        "AKA_Name": "aka_name",
        "Facility_Type": "facility_type",
        "Risk_Category": "risk_category",
        "City": "city",
        "FileName": "file_name",
    }
    for old_name, new_name in snake_case_names.items():
        versions = versions.withColumnRenamed(old_name, new_name)

    return versions.drop("__START_AT", "__END_AT", "_commit_timestamp", "_change_type")

@dlt.table(
    name="dim_restaurant",
    partition_cols=["is_current"]
)
def dim_restaurant():
    """Publish the restaurant SCD2 dimension with a surrogate key.

    ``restaurant_key`` is a 1-based ``row_number`` over all rows, ordered by
    (license_no, city, effective_date). The window has a single constant
    partition, so the numbering is global. Keys are derived from the table's
    current contents and shift when rows are added.

    NOTE(review): rows tying on all three ordering columns (e.g. two versions
    starting on the same date) are numbered in arbitrary relative order.
    """
    output_columns = [
        "restaurant_key", "license_no", "dba_name", "aka_name",
        "facility_type", "risk_category", "city", "effective_date",
        "end_date", "is_current", "file_name", "job_load_id", "job_load_date",
    ]
    global_order = (Window.partitionBy(lit(1))
                    .orderBy("license_no", "city", "effective_date"))

    versions = spark.read.table("LIVE.silver_restaurant_dim_final")
    keyed = versions.withColumn("restaurant_key", row_number().over(global_order))
    return keyed.select(*output_columns)

# ============================================
# OTHER DIMENSIONS
# ============================================
@dlt.table(name="dim_date")
def dim_date():
    """Generate the calendar dimension: one row per day, 2018-01-01..2028-12-31.

    ``date_key`` is the zero-padded yyyyMMdd string cast to int (20180101).
    ``day_of_week`` uses Spark's ``dayofweek`` numbering (1 = Sunday,
    7 = Saturday), which is why weekends are days 1 and 7.
    """
    first_day = "2018-01-01"
    last_day = "2028-12-31"

    calendar = spark.sql(f"""
        SELECT explode(sequence(to_date('{first_day}'), to_date('{last_day}'), interval 1 day)) as full_date
    """)

    day = col("full_date")
    yyyymmdd = concat(
        lpad(year(day), 4, "0"),
        lpad(month(day), 2, "0"),
        lpad(dayofmonth(day), 2, "0"),
    )

    return calendar.select(
        yyyymmdd.cast("int").alias("date_key"),
        day,
        year(day).alias("year"),
        quarter(day).alias("quarter"),
        month(day).alias("month"),
        date_format(day, "MMMM").alias("month_name"),
        dayofmonth(day).alias("day"),
        dayofweek(day).alias("day_of_week"),
        date_format(day, "EEEE").alias("day_name"),
        weekofyear(day).alias("week_of_year"),
        when(dayofweek(day).isin([1, 7]), True).otherwise(False).alias("is_weekend"),
        lit("dim_date_initial_load").alias("job_load_id"),
        current_timestamp().alias("job_load_date"),
    )

@dlt.table(name="dim_location")
def dim_location():
    """Build the location dimension: one row per normalized (address, city, zip).

    ``location_business_key`` is md5 of ``lower(address)|lower(city)|zip_code``
    and must match the ``loc_bk`` expression in ``fact_inspection``. Rows with
    a null address or zip are excluded. For duplicate keys, latitude and
    longitude come from whichever row ``dropDuplicates`` keeps.
    ``location_key`` is a ``row_number`` over (address, city, zip_code).
    """
    silver = spark.read.table("midterm.source1_layer.silver_table")

    normalized = silver.select(
        trim(col("Address")).alias("address"),
        trim(upper(col("City"))).alias("city"),
        col("Zip_Code").cast("int").alias("zip_code"),
        col("Latitude").alias("latitude"),
        col("Longitude").alias("longitude"),
    )
    usable = normalized.where(col("address").isNotNull() & col("zip_code").isNotNull())

    business_key = md5(concat_ws("|", lower(col("address")), lower(col("city")), col("zip_code")))
    unique_locations = (usable
                        .withColumn("location_business_key", business_key)
                        .dropDuplicates(["location_business_key"]))

    by_location = Window.orderBy("address", "city", "zip_code")
    keyed = (unique_locations
             .withColumn("location_key", row_number().over(by_location))
             .withColumn("job_load_id", lit("dim_location_initial_load"))
             .withColumn("job_load_date", current_timestamp()))

    return keyed.select(
        "location_key", "location_business_key", "address", "city",
        "zip_code", "latitude", "longitude", "job_load_id", "job_load_date",
    )

@dlt.table(name="dim_inspection_type")
def dim_inspection_type():
    """Build the inspection-type dimension: one row per (inspection_type, city).

    ``inspection_type`` is trimmed but keeps its case. ``city`` is trimmed and
    upper-cased. ``inspection_type_business_key`` is
    md5(``inspection_type|city``), matching ``type_bk`` in ``fact_inspection``.
    ``inspection_category`` is the first case-insensitive substring hit among
    "routine", "follow" and "complaint"; otherwise it is "Other".
    """
    type_text = lower(col("inspection_type"))
    category = (when(type_text.contains("routine"), "Routine")
                .when(type_text.contains("follow"), "Follow-up")
                .when(type_text.contains("complaint"), "Complaint")
                .otherwise("Other"))

    types = (spark.read.table("midterm.source1_layer.silver_table")
             .select(trim(col("Inspection_Type")).alias("inspection_type"),
                     trim(upper(col("City"))).alias("city"))
             .where(col("inspection_type").isNotNull())
             .withColumn("inspection_type_business_key",
                         md5(concat_ws("|", col("inspection_type"), col("city"))))
             .dropDuplicates(["inspection_type_business_key"]))

    ordering = Window.orderBy("inspection_type", "city")
    return (types
            .withColumn("inspection_type_key", row_number().over(ordering))
            .withColumn("inspection_category", category)
            .withColumn("job_load_id", lit("dim_inspection_type_initial_load"))
            .withColumn("job_load_date", current_timestamp())
            .select("inspection_type_key", "inspection_type_business_key",
                    "inspection_type", "inspection_category", "city",
                    "job_load_id", "job_load_date"))

@dlt.table(name="dim_inspection_result")
def dim_inspection_result():
    """Build the inspection-result dimension: one row per distinct trimmed result.

    ``result_business_key`` is md5(``result_code``), matching ``result_bk`` in
    ``fact_inspection``. ``result_category`` is "Pass" if the upper-cased code
    contains "PASS", else "Fail" if it contains "FAIL", else "Other".
    """
    code = col("result_code")
    code_upper = upper(code)

    results = (spark.read.table("midterm.source1_layer.silver_table")
               .select(trim(col("Inspection_Results")).alias("result_code"))
               .where(code.isNotNull())
               .withColumn("result_business_key", md5(code))
               .dropDuplicates(["result_business_key"]))

    result_category = (when(code_upper.contains("PASS"), "Pass")
                       .when(code_upper.contains("FAIL"), "Fail")
                       .otherwise("Other"))

    by_code = Window.orderBy("result_code")
    return (results
            .withColumn("inspection_result_key", row_number().over(by_code))
            .withColumn("result_category", result_category)
            .withColumn("job_load_id", lit("dim_inspection_result_initial_load"))
            .withColumn("job_load_date", current_timestamp())
            .select("inspection_result_key", "result_business_key", "result_code",
                    "result_category", "job_load_id", "job_load_date"))

@dlt.table(name="dim_risk_category")
def dim_risk_category():
    """Build the risk-category dimension: one row per derived risk level.

    Raw ``Risk_Category`` values are bucketed by case-insensitive substring,
    checked in this order:

    * "high" or "1" -> High (priority 1)
    * "medium" or "2" -> Medium (2)
    * "low" or "3" -> Low (3)
    * anything else, including null -> Unknown (99)

    ``risk_business_key`` is md5(``risk_level|priority_level``), and
    ``risk_category_key`` numbers the rows by priority.

    Null categories are intentionally NOT filtered out. ``fact_inspection``
    applies the same bucketing, under which a null category becomes
    "Unknown". If nulls were dropped here, the dimension could lack an
    "Unknown" member whenever every non-null category matches a bucket.
    Those facts would then receive a null ``risk_category_key`` from their
    left join.
    """
    category = lower(col("risk_category"))
    risk_level = (when(category.contains("high") | category.contains("1"), "High")
                  .when(category.contains("medium") | category.contains("2"), "Medium")
                  .when(category.contains("low") | category.contains("3"), "Low")
                  .otherwise("Unknown"))
    priority_level = (when(col("risk_level") == "High", 1)
                      .when(col("risk_level") == "Medium", 2)
                      .when(col("risk_level") == "Low", 3)
                      .otherwise(99))

    df_risk = (spark.read.table("midterm.source1_layer.silver_table")
        .select(col("Risk_Category").alias("risk_category"))
        .withColumn("risk_level", risk_level)
        .withColumn("priority_level", priority_level)
        .select("risk_level", "priority_level")
        .distinct()
        # risk_level determines priority_level, so after distinct() this key is
        # already unique and needs no further de-duplication.
        .withColumn("risk_business_key",
            md5(concat_ws("|", col("risk_level"), col("priority_level").cast("string"))))
    )

    window_spec = Window.orderBy("priority_level")
    return (df_risk
        .withColumn("risk_category_key", row_number().over(window_spec))
        .withColumn("job_load_id", lit("dim_risk_category_initial_load"))
        .withColumn("job_load_date", current_timestamp())
        .select("risk_category_key", "risk_business_key", "risk_level",
                "priority_level", "job_load_id", "job_load_date")
    )

@dlt.table(name="dim_violation")
def dim_violation():
    """Build the violation dimension: one row per (violation_code, city_source).

    ``city_source`` is the raw ``City`` value (neither trimmed nor
    upper-cased), matching the business key built for
    ``bridge_fact_violation``. Rows with a null code or description are
    ignored. For each (code, city):

    * ``violation_desc`` is the lexicographically smallest trimmed
      description. ``min`` is used instead of ``first`` because ``first``
      depends on row order, which Spark documents as non-deterministic after
      a shuffle. With ``first``, the description could change between
      refreshes.
    * ``is_violation_critical`` / ``is_violation_urgent`` are true if any row
      has the flag; nulls count as false.

    ``violation_category`` is "Critical" if critical, else "Urgent" if urgent,
    else "Standard".
    """
    df_violation = (spark.read.table("midterm.source1_layer.silver_table")
        .select(
            col("Violation_Code").cast("int").alias("violation_code"),
            trim(col("Violation_Desc")).alias("violation_desc"),
            coalesce(col("is_violation_critical"), lit(False)).alias("is_violation_critical"),
            coalesce(col("is_violation_urgent"), lit(False)).alias("is_violation_urgent"),
            col("City").alias("city_source")
        )
        .filter(col("violation_code").isNotNull())
        .filter(col("violation_desc").isNotNull())
        .groupBy("violation_code", "city_source")
        .agg(
            min("violation_desc").alias("violation_desc"),
            max("is_violation_critical").alias("is_violation_critical"),
            max("is_violation_urgent").alias("is_violation_urgent")
        )
        # groupBy already yields one row per (code, city), so this key is unique.
        .withColumn("violation_business_key",
            md5(concat_ws("|", col("violation_code"), col("city_source"))))
    )

    window_spec = Window.orderBy("violation_code", "city_source")
    return (df_violation
        .withColumn("violation_key", row_number().over(window_spec))
        # Both flags are non-null after coalesce + max, so they can be tested directly.
        .withColumn("violation_category",
            when(col("is_violation_critical"), "Critical")
            .when(col("is_violation_urgent"), "Urgent")
            .otherwise("Standard"))
        .withColumn("job_load_id", lit("dim_violation_initial_load"))
        .withColumn("job_load_date", current_timestamp())
        .select("violation_key", "violation_business_key", "violation_code",
                "violation_desc", "is_violation_critical", "is_violation_urgent",
                "violation_category", "city_source", "job_load_id", "job_load_date")
    )

# ============================================
# FACT_INSPECTION
# ============================================
@dlt.table(
    name="fact_inspection",
    partition_cols=["date_key"]
)
def fact_inspection():
    """Build the inspection fact table: one row per ``inspection_id``.

    Reads ``midterm.source1_layer.silver_table`` and keeps one (arbitrary) row
    per inspection. It re-derives each dimension's business key with the same
    expressions the dimension tables use, then left-joins the dimensions, so
    inspections without a matching member keep a null surrogate key. The
    restaurant join matches only the current (``is_current``) version of each
    restaurant. ``inspection_fact_key`` is a ``row_number`` over
    ``inspection_id``.
    """

    def attach_key(fact, dim, fact_bk, dim_bk, key_col):
        # Broadcast-left-join `dim` on fact_bk == dim_bk, keep key_col, and drop
        # both business-key columns. The key names differ between the two sides,
        # so plain column names resolve unambiguously against the joined frames.
        lookup = broadcast(dim.select(dim_bk, key_col))
        return (fact
                .join(lookup, col(fact_bk) == col(dim_bk), "left")
                .drop(fact_bk, dim_bk))

    df_silver = spark.read.table("midterm.source1_layer.silver_table")

    # One row per inspection. Dedupe on the output name `inspection_id` so the
    # result does not depend on case-insensitive column resolution.
    df_fact = df_silver.select(
        col("Inspection_ID").alias("inspection_id"),
        col("Inspection_Date").alias("inspection_date"),
        col("License_No"),
        col("Address"),
        col("City"),
        col("Zip_Code").cast("int").alias("zip_code"),
        trim(col("Inspection_Type")).alias("inspection_type"),
        trim(col("Inspection_Results")).alias("inspection_results"),
        col("Risk_Category").alias("risk_category"),
        col("Inspection_Score").cast("int").alias("inspection_score"),
        col("Violation_Count").cast("int").alias("violation_count"),
        col("FileName").alias("file_name")
    ).dropDuplicates(["inspection_id"])

    # Strip the city suffix from License_No. Keep this identical to the
    # cleaning in bronze_restaurant_cdf (trimmed afterwards, as in
    # silver_restaurant_cdf); otherwise restaurant keys will not match.
    license_no = col("License_No")
    for suffix in ("_DALLAS$", "_CHICAGO$", "_Dallas$", "_Chicago$"):
        license_no = regexp_replace(license_no, suffix, "")

    df_fact = (df_fact
        .withColumn("clean_license", trim(license_no))
        .withColumn("clean_city", trim(upper(col("City"))))
        .withColumn("clean_address", trim(col("Address"))))

    # yyyyMMdd integer, built the same way as dim_date.date_key.
    df_fact = df_fact.withColumn("date_key",
        concat(
            lpad(year(col("inspection_date")), 4, "0"),
            lpad(month(col("inspection_date")), 2, "0"),
            lpad(dayofmonth(col("inspection_date")), 2, "0")
        ).cast("int"))

    # Risk bucketing; must match dim_risk_category. A null category falls
    # through to "Unknown".
    risk_text = lower(col("risk_category"))
    df_fact = df_fact.withColumn("risk_level",
        when(risk_text.contains("high") | risk_text.contains("1"), "High")
        .when(risk_text.contains("medium") | risk_text.contains("2"), "Medium")
        .when(risk_text.contains("low") | risk_text.contains("3"), "Low")
        .otherwise("Unknown"))
    priority_level = (when(col("risk_level") == "High", 1)
                      .when(col("risk_level") == "Medium", 2)
                      .when(col("risk_level") == "Low", 3)
                      .otherwise(99))

    # Business keys; each expression mirrors the key column of its dimension.
    df_fact = (df_fact
        .withColumn("loc_bk", md5(concat_ws("|",
            lower(col("clean_address")),
            lower(col("clean_city")),
            col("zip_code"))))
        .withColumn("type_bk", md5(concat_ws("|", col("inspection_type"), col("clean_city"))))
        .withColumn("result_bk", md5(col("inspection_results")))
        .withColumn("risk_bk", md5(concat_ws("|", col("risk_level"),
                                             priority_level.cast("string")))))

    # Restaurant: composite (license, city) match against current versions only.
    current_restaurants = (spark.read.table("LIVE.dim_restaurant")
        .filter(col("is_current") == True)
        .select(col("license_no").alias("r_lic"),
                col("city").alias("r_city"),
                col("restaurant_key")))
    df_fact = df_fact.join(
        broadcast(current_restaurants),
        (col("clean_license") == col("r_lic")) & (col("clean_city") == col("r_city")),
        "left"
    ).drop("r_lic", "r_city")

    df_fact = attach_key(df_fact, spark.read.table("LIVE.dim_location"),
                         "loc_bk", "location_business_key", "location_key")
    df_fact = attach_key(df_fact, spark.read.table("LIVE.dim_inspection_type"),
                         "type_bk", "inspection_type_business_key", "inspection_type_key")
    df_fact = attach_key(df_fact, spark.read.table("LIVE.dim_inspection_result"),
                         "result_bk", "result_business_key", "inspection_result_key")
    df_fact = attach_key(df_fact, spark.read.table("LIVE.dim_risk_category"),
                         "risk_bk", "risk_business_key", "risk_category_key")

    # Surrogate key plus audit columns.
    window_spec = Window.orderBy("inspection_id")
    df_fact = (df_fact
        .withColumn("inspection_fact_key", row_number().over(window_spec))
        .withColumn("job_load_id", lit("fact_inspection_load"))
        .withColumn("job_load_date", current_timestamp())
        .withColumn("created_date", current_timestamp()))

    return df_fact.select(
        "inspection_fact_key", "inspection_id", "date_key",
        "restaurant_key", "location_key", "inspection_type_key",
        "inspection_result_key", "risk_category_key",
        "inspection_score", "violation_count", "file_name",
        "inspection_date", "created_date", "job_load_id", "job_load_date"
    )

# ============================================
# FACT_VIOLATION (published as bridge_fact_violation)
# ============================================
@dlt.table(name="bridge_fact_violation")
def fact_violation():
    """Bridge inspections to violations: one row per silver row with a violation code.

    Each silver row with a non-null ``Violation_Code`` is handled as follows:

    * It is matched to its inspection fact by an inner join on
      ``inspection_id``; rows with no fact are dropped.
    * It is matched to ``dim_violation`` via md5(``violation_code|City``),
      using the raw City, by a left join; unmatched rows keep a null
      ``violation_key``.

    ``violation_fact_key`` numbers the rows by (inspection_id, violation_code).

    NOTE(review): duplicate (inspection_id, violation_code) pairs tie in that
    ordering, so their relative key order is arbitrary.
    """
    code_city_key = md5(concat_ws("|", col("violation_code"), col("city_source")))

    violations = (spark.read.table("midterm.source1_layer.silver_table")
                  .select(
                      col("Inspection_ID").alias("inspection_id"),
                      col("Violation_Code").cast("int").alias("violation_code"),
                      col("City").alias("city_source"),
                      trim(col("Violation_Comments")).alias("violation_comments"),
                  )
                  .where(col("violation_code").isNotNull())
                  .withColumn("violation_business_key", code_city_key))

    inspection_keys = (spark.read.table("LIVE.fact_inspection")
                       .select("inspection_id", col("inspection_fact_key")))
    violation_keys = (spark.read.table("LIVE.dim_violation")
                      .select("violation_business_key", col("violation_key")))

    linked = (violations
              .join(inspection_keys, "inspection_id", "inner")
              .join(violation_keys, "violation_business_key", "left"))

    numbering = Window.orderBy("inspection_id", "violation_code")
    return (linked
            .withColumn("violation_fact_key", row_number().over(numbering))
            .withColumn("job_load_id", lit("fact_violation_initial_load"))
            .withColumn("job_load_date", current_timestamp())
            .withColumn("created_date", current_timestamp())
            .select("violation_fact_key", "inspection_fact_key", "violation_key",
                    "violation_comments", "created_date", "job_load_id", "job_load_date"))