In [0]:
"""
GOLD LAYER
- aggregate for daily reporting and upserts
"""
from delta.tables import DeltaTable
from pyspark.sql.functions import col, sum, avg, to_date

# Grouping keys of the gold daily energy report; also the MERGE keys.
_GOLD_KEYS = ("date", "county", "is_business", "product_type")


def _aggregate_energy(bronze_table):
    """Return consumption/production totals of ``target`` per gold key.

    ``total_energy`` sums ``target`` over rows with ``is_consumption == 1``;
    ``total_energy_production`` sums it over rows with ``is_consumption == 0``.
    ``date`` is derived from the bronze ``datetime`` column via ``to_date``.
    """
    from pyspark.sql.functions import when

    is_consumption = col("is_consumption")
    return (
        spark.read.table(bronze_table)
        .withColumn("date", to_date("datetime"))
        # Only consumption (1) and production (0) rows contribute to the report.
        .filter((is_consumption == 1) | (is_consumption == 0))
        .groupBy(*_GOLD_KEYS)
        .agg(
            # when() without otherwise() yields null for non-matching rows and
            # sum() skips nulls, so a key with no rows on one side gets a null
            # total for that side. Aggregating once (instead of aggregating each
            # side and full-outer-joining them on equality) also keeps a key
            # that contains a null on a single row instead of splitting it.
            sum(when(is_consumption == 1, col("target"))).alias("total_energy"),
            sum(when(is_consumption == 0, col("target"))).alias(
                "total_energy_production"
            ),
        )
    )


def _daily_county_weather(weather_view):
    """Return average temperature and direct solar radiation per date + county."""
    # NOTE(review): assumes wh_observ_date has the same (DATE) type as the
    # bronze-derived `date`; confirm, otherwise the left join will not match.
    return (
        spark.read.table(weather_view)
        .groupBy("wh_observ_date", "county_id")
        .agg(
            avg("wf_temperature").alias("avg_temp"),
            avg("wf_direct_solar_radiation").alias("avg_radiation"),
        )
        .withColumnRenamed("wh_observ_date", "date")
        .withColumnRenamed("county_id", "county")
    )


def _attach_dimensions(report_df, catalog):
    """Left-join county name, county lat/long and product name onto ``report_df``."""
    county_names = spark.read.table(f"{catalog}.silver.county_mapping").select(
        col("county_id").alias("county"),
        col("county_name"),
    )
    county_geo = spark.read.table(f"{catalog}.silver.county_geo").select(
        col("county_id").alias("county"),
        col("county_latitude").alias("latitude"),
        col("county_longitude").alias("longitude"),
    )
    # product_mapping stores the code as product_id and the human-readable
    # label as product_type; rename so the code becomes the join key.
    product_names = spark.read.table(f"{catalog}.silver.product_mapping").select(
        col("product_id").alias("product_type"),
        col("product_type").alias("product_type_name"),
    )
    # NOTE(review): assumes one row per id in each mapping table; duplicate
    # ids would duplicate report rows and make the Delta MERGE in
    # _upsert_gold fail on multiple matching source rows.
    return (
        report_df
        .join(county_names, on="county", how="left")
        .join(county_geo, on="county", how="left")
        .join(product_names, on="product_type", how="left")
    )


def _upsert_gold(final_df, target_table_name):
    """Create ``target_table_name`` from ``final_df``, or MERGE ``final_df`` into it."""
    if not spark.catalog.tableExists(target_table_name):
        print("Creating table...")
        final_df.write.format("delta").saveAsTable(target_table_name)
        print("Table created.")
        return

    # Null-safe equality (<=>) so a key containing a null updates its existing
    # row on re-runs instead of being inserted again every time.
    merge_condition = " AND ".join(f"t.`{k}` <=> s.`{k}`" for k in _GOLD_KEYS)
    (
        DeltaTable.forName(spark, target_table_name)
        .alias("t")
        .merge(final_df.alias("s"), merge_condition)
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )
    print(f"Merge/Upsert complete for {target_table_name}")


def create_gold(catalog="cscie103_catalog_final"):
    """Build the gold ``daily_energy_report`` table and upsert it into Delta.

    Steps:
      1. Aggregate ``<catalog>.bronze.train`` per (date, county, is_business,
         product_type) into consumption and production totals of ``target``.
      2. Left-join daily averages of temperature and direct solar radiation
         from ``<catalog>.gold.county_weather_4hours_vw``.
      3. Left-join county name, county latitude/longitude and product name
         from the ``<catalog>.silver`` mapping tables.
      4. Write to ``<catalog>.gold.daily_energy_report``. The first run creates
         the table; later runs MERGE on the four keys (matched rows updated,
         new rows inserted, target rows missing from the source left as-is).

    Args:
        catalog: Catalog holding the bronze, silver and gold schemas.

    Relies on a ``spark`` session already defined in the notebook/runtime.
    """
    print("Building Gold Layer")

    energy_df = _aggregate_energy(f"{catalog}.bronze.train")
    weather_df = _daily_county_weather(f"{catalog}.gold.county_weather_4hours_vw")
    report_df = energy_df.join(weather_df, on=["date", "county"], how="left")
    enriched_df = _attach_dimensions(report_df, catalog)

    # Final column selection / ordering of the gold table.
    final_df = enriched_df.select(
        "date",
        "county",
        "county_name",
        "latitude",
        "longitude",
        # 1 -> True, any other non-null value -> False; null stays null.
        (col("is_business") == 1).alias("is_business"),
        "product_type",             # product code (join key to product_mapping)
        "product_type_name",        # label from silver.product_mapping
        "total_energy",             # sum of target where is_consumption == 1
        "total_energy_production",  # sum of target where is_consumption == 0
        "avg_temp",
        "avg_radiation",
    )

    _upsert_gold(final_df, f"{catalog}.gold.daily_energy_report")

# Build (first run) or upsert (later runs) the gold daily energy report.
# Relies on a `spark` session already defined by the notebook/runtime.
create_gold()


Building Gold Layer
Creating table...
Table created.
