# Part 2: Streaming application using Spark Structured Streaming  
In this task, you will implement Spark Structured Streaming to consume the data from task 1 and perform a prediction.    
Important:   
-	This task uses PySpark Structured Streaming with PySpark Dataframe APIs and PySpark ML.  
-	You also need your pipeline model from A2A to make predictions and persist the results.  

1.	Write code to create a SparkSession, which 1) uses four cores with a proper application name; 2) use the Melbourne timezone; 3) ensure a checkpoint location has been set.


In [1]:
from pyspark.sql import SparkSession





# ---------- CONFIGURATION ----------
APP_NAME = "FIT5202_A2B_WeatherStream"
CHECKPOINT_DIR = "/home/student/checkpoints/a2b_task1"  # make sure this exists or will be created
TIMEZONE = "Australia/Melbourne"

# ---------- CREATE SPARK SESSION ----------
from pyspark.sql import SparkSession

SPARK_VER = "3.5.0"

spark = (
    SparkSession.builder
    .appName(APP_NAME)
    .master("local[4]")
    .config("spark.sql.session.timeZone", "Australia/Melbourne")
    .config("spark.driver.memory", "7g")
    .config("spark.executor.memory", "7g")
    .config("spark.sql.shuffle.partitions", "8")
    .config("spark.memory.fraction", "0.6")
    .config(
        "spark.jars.packages",
        f"org.apache.spark:spark-sql-kafka-0-10_2.12:{SPARK_VER},"
        f"org.apache.kafka:kafka-clients:{SPARK_VER}"
    )
    .getOrCreate()
)
spark.sparkContext.setLogLevel("WARN")

spark.conf.set("spark.sql.files.maxPartitionBytes", "64m")
spark.conf.set("spark.sql.shuffle.partitions", "4")


print("✅ SparkSession created successfully")
print(f"Application Name: {spark.sparkContext.appName}")
print(f"Master: {spark.sparkContext.master}")
print(f"Timezone: {spark.conf.get('spark.sql.session.timeZone')}")
print(f"Checkpoint Directory: {spark.conf.get('spark.sql.streaming.checkpointLocation')}")


✅ SparkSession created successfully
Application Name: FIT5202_A2B_WeatherStream
Master: local[4]
Timezone: Australia/Melbourne
Checkpoint Directory: None


2.	Write code to define the data schema for the data files, following the data types suggested in the metadata file. Load the static datasets (e.g. building information) into data frames. (You can reuse your code from 2A.)


In [2]:
# === A2B — Define schemas per metadata & load static datasets (reusing 2A style) ===
from pyspark.sql.types import (
    StructType, StructField,
    IntegerType, StringType, TimestampType, DecimalType, DoubleType
)
from pyspark.sql import functions as F

# ---------------------------
# 1) meters.csv schema (time-series) — per metadata
# ---------------------------
meters_schema = StructType([
    StructField("building_id", IntegerType(),  False),
    StructField("meter_type",  StringType(),   False),   # Char(1) -> StringType in Spark
    StructField("ts",          TimestampType(),False),
    StructField("value",       DecimalType(20, 6), False),
    StructField("row_id",      IntegerType(),  False),
])




# ---------------------------
# Paths (adjust to your files)
# ---------------------------
METERS_CSV    = "new_meters.csv"              

# Load meters/weather as static tables (only if needed now)
# ---------------------------
meters_df = (spark.read
     .option("header", True)
      .schema(meters_schema)
     .csv(METERS_CSV))




print(" Meters:")
meters_df.printSchema()
print("Sample Meters:")
meters_df.show(5, truncate=False)


 Meters:
root
 |-- building_id: integer (nullable = true)
 |-- meter_type: string (nullable = true)
 |-- ts: timestamp (nullable = true)
 |-- value: decimal(20,6) (nullable = true)
 |-- row_id: integer (nullable = true)

Sample Meters:
+-----------+----------+-------------------+---------+------+
|building_id|meter_type|ts                 |value    |row_id|
+-----------+----------+-------------------+---------+------+
|163        |c         |2022-01-01 00:00:00|4.571900 |3     |
|170        |c         |2022-01-01 00:00:00|11.289100|8     |
|171        |c         |2022-01-01 00:00:00|0.000000 |9     |
|172        |c         |2022-01-01 00:00:00|0.000000 |10    |
|174        |c         |2022-01-01 00:00:00|52.858300|12    |
+-----------+----------+-------------------+---------+------+
only showing top 5 rows



3.	Using the Kafka topic from the producer in Task 1, ingest the streaming data into Spark Streaming, assuming all data comes in the String format. Except for the 'weather_ts' column, you shall receive it as an Int type. Load the new building information CSV file into a dataframe. Then, the data frames should be transformed into the proper formats following the metadata file schema, similar to assignment 2A.


In [3]:
# pyspark_app_weather_stream.py  (no building join)

import time, uuid
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, from_json, array, coalesce, explode_outer,
    to_timestamp, from_unixtime
)
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType, ArrayType
)

# -------------------- CONFIG --------------------
TIMEZONE                 = "Australia/Melbourne"
CHECKPOINT_DIR  = "/home/student/checkpoints/a2b_task1/weather"
OUT_WEATHER     = "/home/student/out/weather_6h_ts"
KAFKA_BOOTSTRAP          = "kafka:9092"
TOPIC_SUBSCRIBE_PATTERN  = r"weather-site-\d+"
STARTING_OFFSETS         = "latest"   # use "earliest" if you want to replay retained data

# ------------------ KAFKA SOURCE -----------------
kafka_df = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", KAFKA_BOOTSTRAP)
    .option("subscribePattern", TOPIC_SUBSCRIBE_PATTERN)
    .option("startingOffsets", STARTING_OFFSETS)
    .option("failOnDataLoss", "false")  # dev-friendly: don’t crash if retention trimmed
    .load()
)

# value is binary; cast to string
value_str_df = kafka_df.selectExpr(
    "CAST(value AS STRING) AS value", "topic", "partition", "offset", "timestamp as kafka_ingest_ts"
)

# ------------------ WIRE SCHEMA ------------------
wire_obj_schema = StructType([
    StructField("site_id",            StringType(),  True),
    StructField("timestamp",          StringType(),  True),
    StructField("air_temperature",    StringType(),  True),
    StructField("cloud_coverage",     StringType(),  True),
    StructField("dew_temperature",    StringType(),  True),
    StructField("sea_level_pressure", StringType(),  True),
    StructField("wind_direction",     StringType(),  True),
    StructField("wind_speed",         StringType(),  True),
    StructField("weather_ts",         IntegerType(), True),
    StructField("day_index",          StringType(),  True),
])

wire_array_schema = ArrayType(wire_obj_schema)

parsed_arr = value_str_df.select(
    col("value"),
    col("topic"), col("partition"), col("offset"), col("kafka_ingest_ts"),
    from_json(col("value"), wire_array_schema).alias("rows_arr"),
    from_json(col("value"), wire_obj_schema).alias("row_obj")
)

rows_df = parsed_arr.select(
    col("topic"), col("partition"), col("offset"), col("kafka_ingest_ts"),
    coalesce(col("rows_arr"), array(col("row_obj"))).alias("rows")
)

exploded = rows_df.select(
    col("topic"), col("partition"), col("offset"), col("kafka_ingest_ts"),
    explode_outer(col("rows")).alias("r")
).select(
    "topic", "partition", "offset", "kafka_ingest_ts",
    col("r.site_id").alias("site_id_str"),
    col("r.timestamp").alias("ts_str"),
    col("r.air_temperature").alias("air_temperature_str"),
    col("r.cloud_coverage").alias("cloud_coverage_str"),
    col("r.dew_temperature").alias("dew_temperature_str"),
    col("r.sea_level_pressure").alias("sea_level_pressure_str"),
    col("r.wind_direction").alias("wind_direction_str"),
    col("r.wind_speed").alias("wind_speed_str"),
    col("r.weather_ts").alias("weather_ts_int"),
    col("r.day_index").alias("day_index_str")
)


In [4]:
# ---------------------------
# 2) buildings.csv schema (static) — per metadata
# ---------------------------
buildings_schema = StructType([
    StructField("site_id",       IntegerType(),      False),
    StructField("building_id",   IntegerType(),      False),
    StructField("primary_use",   StringType(),       True),
    StructField("square_feet",   IntegerType(),      True),
    StructField("floor_count",   IntegerType(),      True),
    StructField("row_id",        IntegerType(),      False),
    StructField("year_built",    IntegerType(),      True),
    StructField("latent_y",      DecimalType(20, 6), True),
    StructField("latent_s",      DecimalType(20, 6), True),
    StructField("latent_r",      DecimalType(20, 6), True),
])
BUILDINGS_CSV = "new_building_information.csv"   

# ---------------------------
# Load STATIC dataset(s): buildings.csv
# ---------------------------
buildings_df = (
    spark.read
         .option("header", True)
         .schema(buildings_schema)
         .csv(BUILDINGS_CSV)
)


# ---------------------------
# Quick sanity prints
# ---------------------------
print("Buildings schema:")
buildings_df.printSchema()
print("Sample buildings:")
buildings_df.show(5, truncate=False)


Buildings schema:
root
 |-- site_id: integer (nullable = true)
 |-- building_id: integer (nullable = true)
 |-- primary_use: string (nullable = true)
 |-- square_feet: integer (nullable = true)
 |-- floor_count: integer (nullable = true)
 |-- row_id: integer (nullable = true)
 |-- year_built: integer (nullable = true)
 |-- latent_y: decimal(20,6) (nullable = true)
 |-- latent_s: decimal(20,6) (nullable = true)
 |-- latent_r: decimal(20,6) (nullable = true)

Sample buildings:
+-------+-----------+-----------+-----------+-----------+------+----------+---------+--------+--------+
|site_id|building_id|primary_use|square_feet|floor_count|row_id|year_built|latent_y |latent_s|latent_r|
+-------+-----------+-----------+-----------+-----------+------+----------+---------+--------+--------+
|10     |1017       |Technology |109263     |6          |1018  |1971      |29.000000|4.260310|4.000000|
|4      |587        |Technology |53234      |5          |588   |1949      |51.000000|4.027186|3.000000|


4.	Use a watermark on weather_ts, if data points are received 5 seconds late, discard the data.

In [14]:
from pyspark.sql import functions as F

# Make sure session TZ matches producer/CSV expectations
# spark.conf.set("spark.sql.session.timeZone", "Australia/Melbourne")

# 1) Parse columns
weather_base = (
    exploded
    .withColumn("measure_ts", F.to_timestamp("ts_str"))                # 2022 event time
    .withColumn("event_time", F.to_timestamp(F.from_unixtime("weather_ts_int")))  # producer clock
    .withColumn("site_id", F.col("site_id_str").cast("int"))
    .withColumn("air_temperature",    F.col("air_temperature_str").cast("double"))
    .withColumn("cloud_coverage",     F.col("cloud_coverage_str").cast("double"))
    .withColumn("dew_temperature",    F.col("dew_temperature_str").cast("double"))
    .withColumn("sea_level_pressure", F.col("sea_level_pressure_str").cast("double"))
    .withColumn("wind_direction",     F.col("wind_direction_str").cast("int"))
    .withColumn("wind_speed",         F.col("wind_speed_str").cast("double"))
    .drop(
        "site_id_str","ts_str","air_temperature_str","cloud_coverage_str",
        "dew_temperature_str","sea_level_pressure_str","wind_direction_str","wind_speed_str"
    )
)

# 2) Watermark on event_time (assignment) + dedup at hour grain
dedupbed = (
    weather_base
    .withWatermark("event_time", "5 seconds")
    .withColumn("measure_hour", F.date_trunc("hour", F.col("measure_ts")))
    .dropDuplicates(["site_id", "measure_hour"])   # bounded state ~140k keys in your scale
)

# 3) Derive join keys & friendly label
with_keys = (
    dedupbed
    .withColumn("date", F.to_date("measure_hour"))
    .withColumn("slot", F.floor(F.hour("measure_hour")/6).cast("int"))
    .withColumn(
        "slot_label",
        F.when(F.col("slot")==0, F.lit("00:00-05:59"))
         .when(F.col("slot")==1, F.lit("06:00-11:59"))
         .when(F.col("slot")==2, F.lit("12:00-17:59"))
         .otherwise(F.lit("18:00-23:59"))
    )
)

# 4) 6h aggregates
weather_6h = (
    with_keys
    .groupBy("site_id","date","slot","slot_label")
    .agg(
        F.avg("air_temperature").alias("s_air"),
        F.avg("cloud_coverage").alias("s_cloud"),
        F.avg("dew_temperature").alias("s_dew"),
        F.avg("sea_level_pressure").alias("s_slp"),
        F.avg("wind_speed").alias("s_wspd"),
        F.count(F.lit(1)).alias("n_obs"),
        F.atan2(
            F.avg(F.sin(F.radians("wind_direction"))),
            F.avg(F.cos(F.radians("wind_direction")))
        ).alias("s_wdir_rad")
    )
    .withColumn("s_wdir", ((F.degrees("s_wdir_rad") + F.lit(360.0)) % F.lit(360.0)))
    .drop("s_wdir_rad")
    .withColumn(
        "peak_flag",
        F.when(F.col("slot").isin(1,2), F.lit("peak")).otherwise(F.lit("off-peak"))
    )
)


5.	Perform the necessary transformation you used in A2A. (note: every student may have used different features, feel free to reuse the code you have written in A2A. If you built an end-to-end pipeline, you can ignore this task.) 

In [15]:
from pyspark.sql import functions as F

# Add hour-of-day and 6-hour interval slot
meters_with_slot = (
    meters_df
      .withColumn("hour", F.hour("ts"))
      .withColumn("slot", (F.col("hour") / 6).cast("int"))  # 0,1,2,3
)


In [16]:
# Add date for grouping
meters_with_slot = meters_with_slot.withColumn("date", F.to_date("ts"))

# Aggregate
agg_df = (
    meters_with_slot
      .groupBy("building_id", "date", "slot")
      .agg(F.sum("value").alias("energy_6h"))
)


In [17]:
agg_df = agg_df.withColumn(
    "slot_label",
    F.when(F.col("slot") == 0, "00:00-05:59")
     .when(F.col("slot") == 1, "06:00-11:59")
     .when(F.col("slot") == 2, "12:00-17:59")
     .when(F.col("slot") == 3, "18:00-23:59")
)


In [18]:
print("6-hour aggregated energy consumption:")
agg_df.show(12, truncate=False)


6-hour aggregated energy consumption:
+-----------+----------+----+------------+-----------+
|building_id|date      |slot|energy_6h   |slot_label |
+-----------+----------+----+------------+-----------+
|194        |2022-01-01|0   |1746.996100 |00:00-05:59|
|260        |2022-01-01|0   |2472.028500 |00:00-05:59|
|889        |2022-01-01|0   |1960.266000 |00:00-05:59|
|895        |2022-01-01|0   |998.004900  |00:00-05:59|
|926        |2022-01-01|0   |247.469100  |00:00-05:59|
|931        |2022-01-01|0   |2898.023600 |00:00-05:59|
|933        |2022-01-01|0   |205.366700  |00:00-05:59|
|952        |2022-01-01|0   |6129.160000 |00:00-05:59|
|973        |2022-01-01|0   |1973.255500 |00:00-05:59|
|1092       |2022-01-01|0   |10067.018600|00:00-05:59|
|1140       |2022-01-01|0   |43435.750400|00:00-05:59|
|1160       |2022-01-01|0   |132.299000  |00:00-05:59|
+-----------+----------+----+------------+-----------+
only showing top 12 rows



6.	Load your pipeline model and perform the following aggregations:  
a)	Print the prediction from your model as a stream comes in.  
b)	Every 7 seconds, print the total energy consumption for each 6-hour interval, aggregated by building, and print 20 records. (Note: This is simulating energy data each day in a week)  
c)	Every 14 seconds, for each site, print the daily total energy consumption.  

In [19]:
# ===== CELL 0: Spark tuning & safety =====
from pyspark.sql import functions as F
import time

# Kill any old queries to avoid resource contention
for q in list(spark.streams.active):
    print("stopping", q.name); q.stop()
time.sleep(1.0)

# Practical defaults
spark.sparkContext.setLogLevel("WARN")
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.shuffle.partitions", "8")   # tune 8–16 based on cores


In [20]:
# ===== CELL 1: Config, model, helpers =====
from pyspark.sql import functions as F, Window as W
from pyspark.ml import PipelineModel
import uuid, datetime as dt, os, math, time

# ---------- CONFIG ----------
MODEL_PATH = "models/gbt_best_model"           # change if needed
BASE = "/tmp/stream_jobs"                      # use a host mount if you want persistence
CKPT_ONE   = f"{BASE}/ckpt_one_{dt.datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"

# Optional batch Parquet outputs (inside foreachBatch; set to False to disable)
WRITE_PARQUET_B = False   # 6h-by-building sink
WRITE_PARQUET_C = False   # daily-by-site sink
OUT_B = f"{BASE}/out/energy_6h_by_building"
OUT_C = f"{BASE}/out/daily_energy_by_site"

# Predictions clamp (avoid negative & extreme outliers)
CLAMP_CAP = 1e6
PRINT_SAMPLE_ROWS = 10

# ---------- ASSUMED INPUTS (already exist in your session) ----------
# - weather_6h: streaming DF with [site_id,date,slot,slot_label,s_air,s_dew,s_wspd,s_cloud,s_slp,...]
# - buildings_df: static DF with [building_id,site_id,square_feet,floor_count,primary_use, (optional) latent_*]
# - agg_df: static 6h energy DF with [building_id,date,slot,energy_6h,slot_label]

# ---------- UTIL: tiny inspectors ----------
def nonempty(df): 
    try: 
        return not df.rdd.isEmpty()
    except Exception:
        return False

def _print_stats(tag, batch_id, df):
    try:
        n = df.count()
    except Exception:
        n = 0
    print(f"\n[{tag}] batch {batch_id} | rows={n}", flush=True)

# ---------- Load model & figure out assembler inputs ----------
model = PipelineModel.load(MODEL_PATH)

assembler_input_cols = []
model_output_cols = {"prediction","rawPrediction","probability"}
for stg in model.stages:
    cname = stg.__class__.__name__
    if cname == "VectorAssembler":
        assembler_input_cols = list(stg.getInputCols())
        model_output_cols.add(stg.getOutputCol())
    elif "StringIndexerModel" in cname:
        try:
            model_output_cols.add(stg.getOutputCol())
        except Exception:
            pass

def ensure_model_inputs(df):
    for c in assembler_input_cols:
        if c not in df.columns:
            df = df.withColumn(c, F.lit(0.0).cast("double"))
    return df

def drop_model_outputs_if_exist(df):
    return df.drop(*[c for c in model_output_cols if c in df.columns])

# ---------- Precompute lags/rolling from full historic energy (STATIC, once) ----------
w = W.partitionBy("building_id").orderBy("date","slot")
energy_lagged = (
    agg_df
    .withColumn("lag_1",       F.lag("energy_6h", 1).over(w))
    .withColumn("lag_4",       F.lag("energy_6h", 4).over(w))
    .withColumn("lag_28",      F.lag("energy_6h", 28).over(w))
    .withColumn("roll_mean_4", F.avg("energy_6h").over(w.rowsBetween(-4, -1)))
    .withColumn("roll_std_4",  F.stddev("energy_6h").over(w.rowsBetween(-4, -1)))
    .fillna({"lag_1":0.0,"lag_4":0.0,"lag_28":0.0,"roll_mean_4":0.0,"roll_std_4":0.0})
)

energy_enriched = (
    energy_lagged
    .join(buildings_df.select("building_id","site_id"), on="building_id", how="left")
    .select("building_id","site_id","date","slot","energy_6h",
            "lag_1","lag_4","lag_28","roll_mean_4","roll_std_4")
)

# ---------- Join weather × buildings (per micro-batch) ----------
buildings_slim = buildings_df.select(
    "building_id","site_id","square_feet","floor_count","primary_use",
    *[c for c in ["latent_y","latent_s","latent_r"] if c in buildings_df.columns]
)

def join_weather_buildings(batch_weather):
    return (
        batch_weather
        .select(
            "site_id","date","slot","slot_label",
            F.col("s_air").alias("avg_air_temp"),
            F.col("s_dew").alias("avg_dew_temp"),
            F.col("s_wspd").alias("avg_wind_speed"),
            F.col("s_cloud").alias("avg_cloud_coverage"),
            F.col("s_slp").alias("avg_sea_level_pressure"),
        )
        .withColumn("site_id", F.col("site_id").cast("int"))
        .withColumn("slot",    F.col("slot").cast("int"))
        .join(F.broadcast(buildings_slim.withColumn("site_id", F.col("site_id").cast("int"))),
              on="site_id", how="left")
        .withColumn("primary_use", F.coalesce(F.col("primary_use"), F.lit("unknown")))
    )

# ---------- Feature shaping ----------
def shape_for_model_from_joined(joined):
    base = (
        joined
        .withColumn("peak_flag",
            F.when(F.col("slot_label").isin("06:00-11:59","12:00-17:59"), "peak").otherwise("off-peak"))
        .withColumn("slot_sin", F.sin(2*math.pi*F.col("slot")/F.lit(4)))
        .withColumn("slot_cos", F.cos(2*math.pi*F.col("slot")/F.lit(4)))
        .withColumn("dow", F.dayofweek("date"))
        .withColumn("dow_sin", F.sin(2*math.pi*(F.col("dow")-1)/F.lit(7)))
        .withColumn("dow_cos", F.cos(2*math.pi*(F.col("dow")-1)/F.lit(7)))
        .withColumn("cdh", F.greatest(F.col("avg_air_temp")-F.lit(18.0), F.lit(0.0)))
        .withColumn("hdh", F.greatest(F.lit(18.0)-F.col("avg_air_temp"), F.lit(0.0)))
        .withColumn("sum_weather_core",
            F.coalesce(F.col("avg_air_temp"),F.lit(0.0)) +
            F.coalesce(F.col("avg_dew_temp"),F.lit(0.0)) +
            F.coalesce(F.col("avg_wind_speed"),F.lit(0.0)))
    )

    # ensure lag/rolling columns exist
    for c in ["lag_1","lag_4","lag_28","roll_mean_4","roll_std_4"]:
        if c not in base.columns:
            base = base.withColumn(c, F.lit(0.0))

    # weather variance features
    nW = F.lit(5.0)
    sq_sum = (
        F.pow(F.coalesce(F.col("avg_air_temp"),F.lit(0.0)),2) +
        F.pow(F.coalesce(F.col("avg_dew_temp"),F.lit(0.0)),2) +
        F.pow(F.coalesce(F.col("avg_wind_speed"),F.lit(0.0)),2) +
        F.pow(F.coalesce(F.col("avg_cloud_coverage"),F.lit(0.0)),2) +
        F.pow(F.coalesce(F.col("avg_sea_level_pressure"),F.lit(0.0)),2)
    )
    mean = (
        F.coalesce(F.col("avg_air_temp"),F.lit(0.0)) +
        F.coalesce(F.col("avg_dew_temp"),F.lit(0.0)) +
        F.coalesce(F.col("avg_wind_speed"),F.lit(0.0)) +
        F.coalesce(F.col("avg_cloud_coverage"),F.lit(0.0)) +
        F.coalesce(F.col("avg_sea_level_pressure"),F.lit(0.0))
    )/nW
    var = (sq_sum/nW) - (mean*mean)

    base = (base
        .withColumn("std_weather", F.sqrt(F.when(var < 0, 0.0).otherwise(var)))
        .withColumn("root_ratio_weather_var_over_size",
                    F.sqrt((F.col("std_weather")**2) /
                           (F.coalesce(F.col("square_feet"),F.lit(0.0)) + F.lit(1e-6))))
        .withColumn("root_ratio_rollstd_over_size",
                    F.sqrt(F.coalesce(F.col("roll_std_4"),F.lit(0.0)) /
                           (F.coalesce(F.col("square_feet"),F.lit(0.0)) + F.lit(1e-6))))
    )

    for lc in ("latent_y","latent_s","latent_r"):
        if lc not in base.columns:
            base = base.withColumn(lc, F.lit(0.0).cast("double"))

    base = base.fillna({
        "square_feet": 0.0, "floor_count": 0.0,
        "avg_air_temp": 0.0, "avg_dew_temp": 0.0, "avg_wind_speed": 0.0,
        "avg_cloud_coverage": 0.0, "avg_sea_level_pressure": 0.0,
        "lag_1": 0.0, "lag_4": 0.0, "lag_28": 0.0, "roll_mean_4": 0.0, "roll_std_4": 0.0
    })

    base = ensure_model_inputs(base)
    base = drop_model_outputs_if_exist(base)
    return base.withColumn("label", F.lit(None).cast("double"))


In [21]:
# ===== CELL 2: One stream, three tasks (7s + 14s) =====
import os, time

def process_all(batch_df, epochId: int):
    # Quick empty check
    if batch_df.head(1) == []:
        print(f"\n[batch {epochId}] no new rows"); 
        return

    # --- common prep (once per micro-batch) ---
    wb = join_weather_buildings(batch_df)
    wb_lag = (
        wb.join(
            F.broadcast(
                energy_enriched.select(
                    "building_id","date","slot",
                    "lag_1","lag_4","lag_28","roll_mean_4","roll_std_4"
                )
            ),
            on=["building_id","date","slot"], how="left"
        )
        .fillna({"lag_1":0.0,"lag_4":0.0,"lag_28":0.0,"roll_mean_4":0.0,"roll_std_4":0.0})
    )
    shaped = drop_model_outputs_if_exist(shape_for_model_from_joined(wb_lag))

    # --- (a) Predictions (every batch) ---
    preds = (
        model.transform(shaped)
        .select("building_id","site_id","date","slot",
                F.col("prediction").alias("energy_raw"))
        .withColumn(
            "energy",
            F.greatest(F.lit(0.0), F.least(F.col("energy_raw"), F.lit(CLAMP_CAP)))
        )
        .drop("energy_raw")
        .cache()
    )

    print(f"\n--- (a) predictions — batch {epochId} ---")
    preds.orderBy("date","slot","building_id").show(PRINT_SAMPLE_ROWS, truncate=False)

    # --- (b) 6h energy by building (every batch ~7s) ---
    six = (preds.groupBy("building_id","date","slot")
                 .agg(F.sum("energy").alias("energy_6h_total")))
    print(f"\n--- (b) 6h energy by building — batch {epochId} ---")
    six.orderBy("date","slot","building_id").show(20, truncate=False)

    if WRITE_PARQUET_B:
        (six
          .write
          .mode("append")
          .partitionBy("building_id","date","slot")
          .parquet(OUT_B))

    # --- (c) daily energy by site (every second batch ~14s) ---
    if epochId % 2 == 0:
        daily = (preds.groupBy("site_id","date")
                       .agg(F.sum("energy").alias("daily_energy")))
        print(f"\n--- (c) daily energy by site — batch {epochId} ---")
        daily.orderBy("date","site_id").show(50, truncate=False)

        if WRITE_PARQUET_C:
            (daily
              .write
              .mode("append")
              .partitionBy("site_id","date")
              .parquet(OUT_C))

    preds.unpersist()

# Start ONE streaming query with a 7s trigger; (c) runs every 2nd batch => ~14s
one_q = (
    weather_6h
    .writeStream
    .queryName("one_stream_three_tasks")
    .outputMode("update")
    .option("checkpointLocation", os.path.join(CKPT_ONE, "one_stream_three_tasks_v11"))
    .trigger(processingTime="7 seconds")
    .foreachBatch(process_all)
    .start()
)

# Optional: small peek after a couple of batches
time.sleep(10)
print(one_q.lastProgress)



[batch 0] no new rows

--- (a) predictions — batch 1 ---
{'id': '5f589a5d-6ce6-4c29-8a96-d0cc054d9d6a', 'runId': '006c2f43-69ce-403d-992a-59b1ec7a7c8c', 'name': 'one_stream_three_tasks', 'timestamp': '2025-10-25T22:52:08.404Z', 'batchId': 0, 'numInputRows': 0, 'inputRowsPerSecond': 0.0, 'processedRowsPerSecond': 0.0, 'durationMs': {'addBatch': 429, 'commitOffsets': 16, 'getBatch': 0, 'latestOffset': 60, 'queryPlanning': 18, 'triggerExecution': 565, 'walCommit': 38}, 'eventTime': {'watermark': '1970-01-01T00:00:00.000Z'}, 'stateOperators': [{'operatorName': 'stateStoreSave', 'numRowsTotal': 0, 'numRowsUpdated': 0, 'allUpdatesTimeMs': 52, 'numRowsRemoved': 0, 'allRemovalsTimeMs': 0, 'commitTimeMs': 207, 'memoryUsedBytes': 1792, 'numRowsDroppedByWatermark': 0, 'numShufflePartitions': 8, 'numStateStoreInstances': 8, 'customMetrics': {'loadedMapCacheHitCount': 0, 'loadedMapCacheMissCount': 0, 'stateOnCurrentVersionSizeBytes': 640}}, {'operatorName': 'dedupe', 'numRowsTotal': 0, 'numRowsUpd


--- (a) predictions — batch 4 ---
+-----------+-------+----------+----+------------------+
|building_id|site_id|date      |slot|energy            |
+-----------+-------+----------+----+------------------+
|4          |0      |2022-09-08|0   |11687.217094702759|
|7          |0      |2022-09-08|0   |4574.489311491051 |
|11         |0      |2022-09-08|0   |2951.662716071554 |
|14         |0      |2022-09-08|0   |6070.670218895307 |
|15         |0      |2022-09-08|0   |12374.51618811274 |
|19         |0      |2022-09-08|0   |1331.8485989187102|
|25         |0      |2022-09-08|0   |2951.662716071554 |
|35         |0      |2022-09-08|0   |1331.8485989187102|
|42         |0      |2022-09-08|0   |8157.238547299251 |
|44         |0      |2022-09-08|0   |0.0               |
+-----------+-------+----------+----+------------------+
only showing top 10 rows


--- (b) 6h energy by building — batch 4 ---
+-----------+----------+----+------------------+
|building_id|date      |slot|energy_6h_total   


--- (a) predictions — batch 8 ---
+-----------+-------+----------+----+------------------+
|building_id|site_id|date      |slot|energy            |
+-----------+-------+----------+----+------------------+
|4          |0      |2022-10-08|0   |8157.238547299251 |
|7          |0      |2022-10-08|0   |0.0               |
|11         |0      |2022-10-08|0   |2951.662716071554 |
|14         |0      |2022-10-08|0   |84253.46804401917 |
|15         |0      |2022-10-08|0   |90557.3140132366  |
|19         |0      |2022-10-08|0   |1331.8485989187102|
|25         |0      |2022-10-08|0   |1331.8485989187102|
|35         |0      |2022-10-08|0   |1331.8485989187102|
|42         |0      |2022-10-08|0   |5383.371125485325 |
|44         |0      |2022-10-08|0   |90.69999381569704 |
+-----------+-------+----------+----+------------------+
only showing top 10 rows


--- (b) 6h energy by building — batch 8 ---
+-----------+----------+----+------------------+
|building_id|date      |slot|energy_6h_total   

In [23]:
one_q.stop()

7.	Save the data from 6 to Parquet files as streams. (Hint: Parquet files support streaming writing/reading. The file keeps updating while new batches arrive.)

In [24]:



# ===== 6a) PREDICTIONS → PARQUET (STREAM) =====
from pyspark.sql import functions as F
from pyspark.sql.types import *
import os

spark.conf.set("spark.sql.shuffle.partitions", "8")       # down from 200
spark.conf.set("spark.sql.files.maxPartitionBytes", "64m")
spark.conf.set("spark.sql.streaming.fileSource.logSegmentBytes", "1048576")  # 1MB segments


OUT_A  = "/tmp/stream_jobs/out/predictions_a_v02"   # use a new path if you changed partitioning
CKPT_A = os.path.join(CHECKPOINT_DIR, "predictions_a_ckpt_v02")

CLAMP_CAP = 5000.0
PRINT_SAMPLE_ROWS = 10

def write_predictions(batch_df, epochId: int):
    if batch_df.head(1) == []:
        print(f"[6a/batch {epochId}] no new rows")
        return

    # 1) site→building expansion for current batch
    wb = join_weather_buildings(batch_df)

    # 2) add lag/rolling features from historical energy (left join)
    wb_lag = (
        wb.join(
            F.broadcast(
                energy_enriched.select(
                    "building_id","date","slot",
                    "lag_1","lag_4","lag_28","roll_mean_4","roll_std_4"
                )
            ),
            on=["building_id","date","slot"], how="left"
        )
        .fillna({"lag_1":0.0,"lag_4":0.0,"lag_28":0.0,"roll_mean_4":0.0,"roll_std_4":0.0})
    )

    # 3) shape features for the model
    shaped = drop_model_outputs_if_exist(shape_for_model_from_joined(wb_lag))

    # 4) predict
    preds = (
        model.transform(shaped)
        .select("building_id","site_id","date","slot",
                F.col("prediction").alias("energy_raw"))
        .withColumn("energy", F.greatest(F.lit(0.0),
                              F.least(F.col("energy_raw"), F.lit(CLAMP_CAP))))
        .drop("energy_raw")
        .cache()
    )


    # 5) write to parquet (INSIDE this function)
    (preds
        .coalesce(2)                # fewer, larger files
        .write.mode("append")
        .partitionBy("date")        # simpler partitioning to reduce file listing
        .parquet(OUT_A)
    )

    preds.unpersist()

# Start the streaming query (from your weather_6h upstream)
predictions_q = (
    weather_6h
      .writeStream
      .outputMode("update")
      .queryName("predictions_to_parquet_v2")
      .foreachBatch(write_predictions)
      .trigger(processingTime="7 seconds")
      .option("checkpointLocation", CKPT_A)
      .start()
)

print("✅ 6a running — writing predictions to", OUT_A)


✅ 6a running — writing predictions to /tmp/stream_jobs/out/predictions_a_v02
[6a/batch 0] no new rows


In [25]:
from pyspark.sql import functions as F
from pyspark.sql.types import *

spark.conf.set("spark.sql.shuffle.partitions", "8")
spark.conf.set("spark.sql.files.maxPartitionBytes", "64m")

schema = StructType([
    StructField("building_id", IntegerType(), True),
    StructField("site_id",     IntegerType(), True),
    StructField("date",        StringType(),  True),
    StructField("slot",        IntegerType(), True),
    StructField("energy",      DoubleType(),  True),
])

IN_A  = "/tmp/stream_jobs/out/predictions_a_v02"
OUT_B = "/tmp/stream_jobs/out/energy_6h_by_building_ver02"
CKPT_B = os.path.join(CHECKPOINT_DIR, "energy_6h_by_building_ckpt_v02")  # NEW

energy_source = (
    spark.readStream
         .schema(schema)
         .option("maxFilesPerTrigger", 20)
         .parquet(IN_A)
)

building_6h_stream = (
    energy_source
    .groupBy("building_id", "date", "slot")
    .agg(F.sum("energy").alias("energy_6h_total"))
)

def write_b(df, epochId: int):
    rows = df.count()
    if rows == 0:
        return
    (df.coalesce(2)
       .write.mode("append")
       .partitionBy("date","slot","building_id")
       .parquet(OUT_B))

b_parquet_q = (
    building_6h_stream.writeStream
    .outputMode("update")
    .queryName("energy_6h_from_v4_b")
    .foreachBatch(write_b)
    .option("checkpointLocation", CKPT_B)
    .trigger(processingTime="7 seconds")
    .start()
)

print("✅ 6b running →", OUT_B)


✅ 6b running → /tmp/stream_jobs/out/energy_6h_by_building_ver02


In [26]:
# ===== 6c) STREAM DAILY ENERGY TOTALS BY SITE TO PARQUET (with logging) =====
from pyspark.sql import functions as F
import os

# Reuse the SAME energy_source used by 6b (do NOT recreate with a different path).
# energy_source has schema: building_id, site_id, date, slot, energy

OUT_C  = "/tmp/stream_jobs/out/energy_daily_by_site_ver02"
CKPT_C = os.path.join(CHECKPOINT_DIR, "energy_daily_by_site_ckpt_ver02")  # <- NEW checkpoint

site_daily_stream = (
    energy_source
    .groupBy("site_id", "date")
    .agg(F.sum("energy").alias("daily_energy"))
)

def write_c(df, epochId: int):
    # cheap emptiness check for streaming batches
    if not df.take(1):
        print(f"[6c/batch {epochId}] no new rows")
        return

    # write out
    (df.coalesce(2)
       .write
       .mode("append")
       .partitionBy("date")             # daily partition
       .parquet(OUT_C))

# stop any old 6c with same name
for q in spark.streams.active:
    if q.name == "energy_daily_by_site_v3":
        print("stopping old", q.name)
        q.stop()

c_parquet_q = (
    site_daily_stream
    .writeStream
    .outputMode("update")               # fine with foreachBatch
    .queryName("energy_daily_by_site_v3")
    .foreachBatch(write_c)
    .option("checkpointLocation", CKPT_C)
    .trigger(processingTime="14 seconds")
    .start()
)

print("✅ 6c running — writing daily-by-site to", OUT_C)


✅ 6c running — writing daily-by-site to /tmp/stream_jobs/out/energy_daily_by_site_ver02


8.	Read the parquet files from task 7 as data streams and send them to Kafka topics with appropriate names.
(Note: You shall read the parquet files as a streaming data frame and send messages to the Kafka topic when new data appears in the parquet file.)

In [27]:
from pyspark.sql import functions as F
from pyspark.sql.types import *
import os

# ---------- Paths (adjust if you moved outputs off /tmp) ----------
IN_A = "/tmp/stream_jobs/out/predictions_a_v02"
IN_B = "/tmp/stream_jobs/out/energy_6h_by_building_ver02"
IN_C = "/tmp/stream_jobs/out/energy_daily_by_site_ver02"

CHECKPOINT_BASE = "/home/student/work/ass2b/ckpts_task02"
os.makedirs(CHECKPOINT_BASE, exist_ok=True)

KAFKA_BOOTSTRAP = "kafka:9092"   # your docker-compose host:port

# ---------- Schemas ----------
# A) predictions from 6a
pred_schema = StructType([
    StructField("building_id", IntegerType(), True),
    StructField("site_id",     IntegerType(), True),
    StructField("date",        StringType(),  True),   # was written as string in your code
    StructField("slot",        IntegerType(), True),
    StructField("energy",      DoubleType(),  True),
])

# B) 6-hour totals by building (from 6b)
b_schema = StructType([
    StructField("building_id",    IntegerType(), True),
    StructField("date",           StringType(),  True),
    StructField("slot",           IntegerType(), True),
    StructField("energy_6h_total",DoubleType(),  True),
])

# C) daily totals by site (from 6c)
c_schema = StructType([
    StructField("site_id",      IntegerType(), True),
    StructField("date",         StringType(),  True),
    StructField("daily_energy", DoubleType(),  True),
])

# ---------- Utility: build Kafka (key, value) columns ----------
def to_kafka_cols(df, key_cols, value_cols=None):
    """Return df with 'key' and 'value' as STRING, ready for Kafka sink."""
    if value_cols is None:
        value_cols = df.columns
    key_col = F.concat_ws(":", *[F.col(c).cast("string") for c in key_cols]).alias("key")
    val_col = F.to_json(F.struct(*[F.col(c) for c in value_cols])).alias("value")
    return df.select(key_col, val_col)


In [28]:
predictions_stream = (
    spark.readStream.schema(pred_schema)
         .option("maxFilesPerTrigger", 50)   # tune 25–200
         .option("latestFirst", "true")
         .option("failOnDataLoss", "false")
         .parquet(IN_A)
)

predictions_kafka = to_kafka_cols(
    predictions_stream,
    key_cols=["building_id", "site_id", "date", "slot"],     # consistent partitioning key
    value_cols=["building_id","site_id","date","slot","energy"]
)

q_pred = (
    predictions_kafka
      .select(F.col("key").cast("binary").alias("key"),
              F.col("value").cast("binary").alias("value"))
      .writeStream
      .format("kafka")
      .option("kafka.bootstrap.servers", KAFKA_BOOTSTRAP)
      .option("topic", "predictions_a")
      .option("checkpointLocation", os.path.join(CHECKPOINT_BASE, "predictions_a"))
      .outputMode("append")
      .start()
)

print("▶️ task7-A: streaming Parquet → Kafka topic 'predictions_a'")


▶️ task7-A: streaming Parquet → Kafka topic 'predictions_a'


In [29]:
sixh_stream = (
    spark.readStream.schema(b_schema)
         .option("maxFilesPerTrigger", 50)
         .option("latestFirst", "true")
         .option("failOnDataLoss", "false")
         .parquet(IN_B)
)

sixh_kafka = to_kafka_cols(
    sixh_stream,
    key_cols=["building_id", "date", "slot"],
    value_cols=["building_id","date","slot","energy_6h_total"]
)

q_b = (
    sixh_kafka
      .select(F.col("key").cast("binary").alias("key"),
              F.col("value").cast("binary").alias("value"))
      .writeStream
      .format("kafka")
      .option("kafka.bootstrap.servers", KAFKA_BOOTSTRAP)
      .option("topic", "energy_6h_by_building")
      .option("checkpointLocation", os.path.join(CHECKPOINT_BASE, "energy_6h_by_building"))
      .outputMode("append")
      .start()
)

print("▶️ task7-B: streaming Parquet → Kafka topic 'energy_6h_by_building'")


▶️ task7-B: streaming Parquet → Kafka topic 'energy_6h_by_building'


In [30]:
daily_stream = (
    spark.readStream.schema(c_schema)
         .option("maxFilesPerTrigger", 50)
         .option("latestFirst", "true")
         .option("failOnDataLoss", "false")
         .parquet(IN_C)
)

daily_kafka = to_kafka_cols(
    daily_stream,
    key_cols=["site_id", "date"],
    value_cols=["site_id","date","daily_energy"]
)

q_c = (
    daily_kafka
      .select(F.col("key").cast("binary").alias("key"),
              F.col("value").cast("binary").alias("value"))
      .writeStream
      .format("kafka")
      .option("kafka.bootstrap.servers", KAFKA_BOOTSTRAP)
      .option("topic", "energy_daily_by_site")
      .option("checkpointLocation", os.path.join(CHECKPOINT_BASE, "energy_daily_by_site"))
      .outputMode("append")
      .start()
)

print("▶️ task7-C: streaming Parquet → Kafka topic 'energy_daily_by_site'")


▶️ task7-C: streaming Parquet → Kafka topic 'energy_daily_by_site'
