# Imports and Setup

In [0]:
from pyspark.sql.functions import col, upper, lower, initcap, trim, from_json, to_timestamp
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
from delta.tables import DeltaTable

In [0]:
catalog = "fuel"

# Every table name below is a fully-qualified "<catalog>.<schema>.<table>" string
# derived from `catalog`, so the pipeline can be pointed at another catalog by
# changing that single value.

# bronze-layer tables
bronze_stations_table = ".".join((catalog, "bronze", "nsw_fuel_stations_bronze"))
bronze_prices_table = ".".join((catalog, "bronze", "nsw_fuel_prices_bronze"))

# silver-layer tables
silver_state_table = ".".join((catalog, "silver", "state"))
silver_brand_table = ".".join((catalog, "silver", "brand"))
silver_station_table = ".".join((catalog, "silver", "station"))
silver_fuel_table = ".".join((catalog, "silver", "fuel"))
silver_price_table = ".".join((catalog, "silver", "price"))

In [0]:
def clean_name_string_col(col_exp):
    """Normalise a free-text name column expression.

    Surrounding whitespace is trimmed, the value is lower-cased, and the result is
    passed through Spark's ``initcap`` (e.g. for brand names).

    Args:
        col_exp: a Spark Column expression.

    Returns:
        A new Column expression; nothing is evaluated here.
    """
    stripped = trim(col_exp)
    lowered = lower(stripped)
    return initcap(lowered)


def clean_acronym_string_col(col_exp):
    """Normalise an acronym/code column expression (e.g. a state code).

    Surrounding whitespace is trimmed and the value is upper-cased.

    Args:
        col_exp: a Spark Column expression.

    Returns:
        A new Column expression; nothing is evaluated here.
    """
    stripped = trim(col_exp)
    return upper(stripped)

In [0]:
# Streaming source for this notebook: the bronze stations Delta table, read as a
# stream so the foreachBatch query further down can process it incrementally.
station_stream_df = spark.readStream.format("delta").table(bronze_stations_table)

# Silver Table Batch Functions

## State Table

In [0]:
def ingest_into_silver_state_table(batch_df, batch_id: int):
    """Insert state codes seen in this micro-batch that are not yet in the silver state table.

    ``state`` values are trimmed and upper-cased before being compared with existing
    rows, so the table gains at most one row per distinct normalised code. Null
    ``state`` values are skipped. ``state_id`` comes from the identity column.
    The silver state table is created on first use.

    Args:
        batch_df: micro-batch DataFrame of bronze station rows (needs a ``state`` column).
        batch_id: foreachBatch batch id (unused; kept for the foreachBatch signature).
    """
    # Run every statement on the micro-batch's own SparkSession. Under foreachBatch the
    # batch DataFrame can belong to a different session than the notebook's global
    # `spark`, and the temp views registered below only exist in batch_df's session.
    session = batch_df.sparkSession

    batch_df.createOrReplaceTempView("temp_batch_stations")

    # clean up and make states unique
    cleaned_unique_states_df = session.sql("""
        select distinct upper(trim(state)) as state_code
            from temp_batch_stations
            where state is not null
    """)
    cleaned_unique_states_df.createOrReplaceTempView("temp_cleaned_states")

    # create silver state table if it doesn't exist
    session.sql(f"""
        create table if not exists {silver_state_table} (
            state_id bigint generated by default as identity,
            state_code string not null
        ) using delta
    """)

    # find new states not in silver state table (left join + "is null" acts as an anti-join)
    states_to_add_df = session.sql(f"""
        select tcs.state_code
            from temp_cleaned_states tcs left join {silver_state_table} sss on tcs.state_code = sss.state_code
            where sss.state_code is null
    """)
    states_to_add_df.createOrReplaceTempView("temp_states_to_add")

    # insert new states into silver state table; state_id is generated by the identity column
    session.sql(f"""
        insert into {silver_state_table} (state_code)
            select state_code
                from temp_states_to_add
    """)

## Brand Table

In [0]:
def ingest_into_silver_brand_table(batch_df, batch_id: int):
    """Insert brand names seen in this micro-batch that are not yet in the silver brand table.

    ``brand`` values are trimmed, lower-cased and passed through ``initcap`` before
    being compared with existing rows, so the table gains at most one row per distinct
    normalised name. Null ``brand`` values are skipped. ``brand_id`` comes from the
    identity column. The silver brand table is created on first use.

    Args:
        batch_df: micro-batch DataFrame of bronze station rows (needs a ``brand`` column).
        batch_id: foreachBatch batch id (unused; kept for the foreachBatch signature).
    """
    # Run every statement on the micro-batch's own SparkSession. Under foreachBatch the
    # batch DataFrame can belong to a different session than the notebook's global
    # `spark`, and the temp views registered below only exist in batch_df's session.
    session = batch_df.sparkSession

    batch_df.createOrReplaceTempView("temp_batch_stations")

    # clean up and make brands unique
    cleaned_unique_brands_df = session.sql("""
        select distinct initcap(lower(trim(brand))) as brand_name
            from temp_batch_stations
            where brand is not null
    """)
    cleaned_unique_brands_df.createOrReplaceTempView("temp_cleaned_brands")

    # create silver brand table if it doesn't exist
    session.sql(f"""
        create table if not exists {silver_brand_table} (
            brand_id bigint generated by default as identity,
            brand_name string not null
        ) using delta
    """)

    # find new brands not in silver brand table (left join + "is null" acts as an anti-join)
    brands_to_add_df = session.sql(f"""
        select tcb.brand_name
            from temp_cleaned_brands tcb left join {silver_brand_table} sbs on tcb.brand_name = sbs.brand_name
            where sbs.brand_name is null
    """)
    brands_to_add_df.createOrReplaceTempView("temp_brands_to_add")

    # insert new brands into silver brand table; brand_id is generated by the identity column
    session.sql(f"""
        insert into {silver_brand_table} (brand_name)
            select brand_name
                from temp_brands_to_add
    """)

## Station Table

In [0]:
def ingest_into_silver_station_table(batch_df, batch_id: int):
    """Insert stations from this micro-batch whose station code is not yet in the silver table.

    Steps:
      1. flatten ``location`` into ``latitude`` / ``longitude``;
      2. normalise ``brand`` (trim + lower + initcap) and ``state`` (trim + upper), the
         same normalisation the silver brand/state ingests apply, and resolve them to
         ``brand_id`` / ``state_id`` by joining the silver brand and state tables;
      3. drop the raw brand/state values plus ``_ingest_ts``, ``_ingest_file``,
         ``brandid`` and ``stationid``, and rename ``code`` to ``station_code``;
      4. keep one row per ``station_code``, keep only codes not already in the silver
         station table, and insert them. Existing stations are never updated.

    The silver state and brand tables must already exist and hold this batch's values
    (``bronze_to_silver`` runs those ingests first). The station table is created on
    first use.

    Args:
        batch_df: micro-batch DataFrame of bronze station rows.
        batch_id: foreachBatch batch id (unused; kept for the foreachBatch signature).
    """
    # Run every read/statement on the micro-batch's own SparkSession. Under foreachBatch
    # the batch DataFrame can belong to a different session than the notebook's global
    # `spark`, and the temp view registered below only exists in batch_df's session.
    session = batch_df.sparkSession

    # split location into latitude and longitude, and derive normalised brand/state join keys
    station_df = batch_df.select(
        "*",
        col("location.latitude").alias("latitude"),
        col("location.longitude").alias("longitude"),
        clean_name_string_col(col("brand")).alias("brand_name"),
        clean_acronym_string_col(col("state")).alias("state_code"),
    ).drop("location")

    # replace state/brand with their surrogate ids. Joining on the column name keeps a
    # single key column, which is then dropped together with the raw value.
    # NOTE(review): a station whose brand/state is null (never inserted by the brand/state
    # ingests) or otherwise unmatched gets a null brand_id/state_id here, and the insert
    # below then violates the NOT NULL columns and fails the micro-batch. Confirm whether
    # such rows should be filtered out or quarantined instead.
    state_df = session.read.table(silver_state_table).select("state_id", "state_code")
    brand_df = session.read.table(silver_brand_table).select("brand_id", "brand_name")
    station_df = (
        station_df
        .join(state_df, on="state_code", how="left")
        .join(brand_df, on="brand_name", how="left")
        .drop("state", "state_code", "brand", "brand_name")
        # columns not carried into the silver station table
        .drop("_ingest_ts", "_ingest_file", "brandid", "stationid")
        .withColumnRenamed("code", "station_code")
        # one row per station code; if a code repeats in the batch, an arbitrary row wins
        .dropDuplicates(["station_code"])
    )

    # create station table if it does not exist
    session.sql(f"""
        create table if not exists {silver_station_table} (
            station_id bigint generated by default as identity,
            state_id bigint not null,
            brand_id bigint not null,
            station_code int not null,
            address string,
            name string,
            latitude double,
            longitude double
        ) using delta
    """)

    # keep only stations whose code is not already in the silver table
    current_stations_df = session.read.table(silver_station_table).select("station_code")
    new_stations_df = station_df.join(
        current_stations_df,
        on=station_df["station_code"] == current_stations_df["station_code"],
        how="left_anti",
    )

    # insert new stations into silver station table; station_id comes from the identity column
    new_stations_df.createOrReplaceTempView("temp_new_stations")
    session.sql(f"""
        insert into {silver_station_table} (state_id, brand_id, station_code, address, name, latitude, longitude)
            select state_id, brand_id, station_code, address, name, latitude, longitude
            from temp_new_stations
    """)

# Write Stream

In [0]:
def bronze_to_silver(batch_df, batch_id: int):
    """foreachBatch handler: fan one bronze-stations micro-batch out to the silver tables.

    The ingests run in a fixed order. States and brands are inserted before stations,
    because the station ingest resolves ``state_id`` / ``brand_id`` by joining those
    tables.

    Args:
        batch_df: the micro-batch DataFrame supplied by foreachBatch.
        batch_id: the micro-batch id supplied by foreachBatch.
    """
    station_based_ingests = (
        ingest_into_silver_state_table,
        ingest_into_silver_brand_table,
        ingest_into_silver_station_table,
    )
    for ingest in station_based_ingests:
        ingest(batch_df, batch_id)

# Push every currently-available bronze station change through bronze_to_silver,
# then stop (availableNow trigger). Stream progress is recorded at the checkpoint
# location.
query = (
    station_stream_df.writeStream
        .foreachBatch(bronze_to_silver)
        .trigger(availableNow=True)  # batch processing of all available changes
        .option("checkpointLocation", "/Workspace/nsw-fuel-project/silver_station_checkpoint")
        .start()
)

# block the cell until the availableNow run has drained its input and stopped
query.awaitTermination()