In [0]:
from pyspark.sql import functions as F

# Catalog name, and the fully qualified gold schema that the table
# references in this notebook are built from.
CATALOG = "energy_usage_data_platform"
GOLD_SCHEMA = CATALOG + ".gold"

In [0]:
# Source table for this notebook: one DataFrame read from the gold schema.
daily_table_name = f"{GOLD_SCHEMA}.daily_region_usage"
daily = spark.table(daily_table_name)


In [0]:
# For every (year, region_id) group, approximate the 95th percentile of each
# daily metric. These per-group values are the thresholds used for flagging.
p95_targets = [
    ("daily_kwh_total", "p95_kwh"),
    ("daily_kw_peak", "p95_kw_peak"),
    ("max_daily_temp_c", "p95_temp"),
    ("daily_precip_mm", "p95_precip"),
]
p95_exprs = [
    F.expr(f"percentile_approx({metric}, 0.95)").alias(alias)
    for metric, alias in p95_targets
]
stats = daily.groupBy("year", "region_id").agg(*p95_exprs)

# Left join keeps every daily row and attaches its group's thresholds.
d = daily.join(stats, on=["year", "region_id"], how="left")

In [0]:
# Basic event flags.
#
# Every flag is a strict boolean (never NULL): a row whose metric or threshold
# is missing is flagged False. Previously only is_heavy_precip_day behaved that
# way, while the other three flags came out NULL for missing metrics.
def _at_or_above(metric, threshold):
    """Build a non-null boolean column: True when `metric` >= `threshold`.

    Args:
        metric: name of the daily metric column in `d`.
        threshold: name of the per-group percentile column in `d`.

    Returns:
        A Column that is True when the comparison holds, and False when it
        does not hold or when either side is NULL.
    """
    return F.coalesce(F.col(metric) >= F.col(threshold), F.lit(False))


d = (
    d
    .withColumn(
        "is_high_load_day",
        _at_or_above("daily_kwh_total", "p95_kwh")
    )
    .withColumn(
        "is_peak_demand_day",
        _at_or_above("daily_kw_peak", "p95_kw_peak")
    )
    .withColumn(
        "is_heatwave_day",
        _at_or_above("max_daily_temp_c", "p95_temp")
    )
    .withColumn(
        "is_heavy_precip_day",
        # When at least 95% of a group's days are dry, p95_precip is 0 and a
        # plain `>=` would flag every dry day. Require actual precipitation.
        _at_or_above("daily_precip_mm", "p95_precip") &
        (F.col("daily_precip_mm") > 0)
    )
)

In [0]:

# Collapse the four boolean flags into one label per row.
# Branches are tested in the order written and the first match wins, so the
# combined heatwave + high-load case must come before the single-flag cases.
heatwave = F.col("is_heatwave_day")
high_load = F.col("is_high_load_day")
peak_demand = F.col("is_peak_demand_day")
heavy_precip = F.col("is_heavy_precip_day")

event_label = F.when(heatwave & high_load, "heatwave_high_load")
for flag, label in (
    (heatwave, "heatwave"),
    (peak_demand, "peak_demand"),
    (heavy_precip, "heavy_precip"),
    (high_load, "high_load"),
):
    event_label = event_label.when(flag, label)
event_label = event_label.otherwise("normal")

d = d.withColumn("event_label", event_label)

In [0]:
# Output layout: identifying columns, then the original daily metrics, then
# the derived flags and label.
key_cols = ["obs_date", "year", "region_id", "city", "substation_id"]
metric_cols = [
    "daily_kwh_total",
    "daily_kw_peak",
    "avg_daily_temp_c",
    "max_daily_temp_c",
    "daily_precip_mm",
]
flag_cols = [
    "is_high_load_day",
    "is_peak_demand_day",
    "is_heatwave_day",
    "is_heavy_precip_day",
    "event_label",
]
final_cols = key_cols + metric_cols + flag_cols

# Keep only the requested columns that are actually present in `d`;
# anything missing is silently left out rather than raising.
available = set(d.columns)
existing_cols = [name for name in final_cols if name in available]
extreme_days = d.select(existing_cols)

In [0]:
# Persist the result to the gold schema in overwrite mode, with the
# overwriteSchema option set on the writer.
target_table = f"{GOLD_SCHEMA}.extreme_event_days"

writer = extreme_days.write.mode("overwrite").option("overwriteSchema", "true")
writer.saveAsTable(target_table)

# Preview the first rows by reading back the table that was just saved.
display(spark.table(target_table).limit(10))