## Lakehouse 4: Load Star Schema
This notebook further processes data from the hourly aggregation table into a dimensional (star schema) model.
If needed, set the variable **sourceTableName** to the name of the hourly aggregation table.

In [None]:
from delta.tables import *
from pyspark.sql.functions import *
import datetime
from datetime import datetime

sourceTableName = 'stocks_hour_agg'

# Stop immediately if the hourly aggregation table is missing: every later
# cell reads from it, so continuing would only produce confusing errors.
source_exists = spark.catalog.tableExists(sourceTableName)
if not source_exists:
    error_message = f'Error! Source table not found: {sourceTableName}'
    print(error_message)
    raise SystemExit(error_message)

In [None]:
def dim_symbol_incremental_load(df_stocks, df_existing_symbols):
    """Insert symbols seen in ``df_stocks`` but missing from ``dim_symbol``.

    This lets new symbols be added to the feed over time. Each incoming
    symbol that has no row in the dimension is given the next surrogate
    key (``Symbol_SK``), a display ``Name`` and a ``Market``, and is then
    merged into the ``dim_symbol`` Delta table.

    Args:
        df_stocks: incoming stock rows; only the ``Symbol`` column is read.
        df_existing_symbols: current contents of ``dim_symbol``; its
            ``Symbol`` and ``Symbol_SK`` columns are read.

    Returns:
        ``df_existing_symbols`` itself when there are no new symbols,
        otherwise a fresh read of ``dim_symbol`` ordered by ``Symbol``.
    """
    # Local imports keep this cell self-contained: ``F`` is otherwise only
    # imported in a later cell, and ``Window`` is not imported elsewhere.
    from pyspark.sql import functions as F
    from pyspark.sql.window import Window

    # Display names for the known demo symbols; any other symbol gets
    # "Company Unknown".
    symbol_names = {
        "BCUZ": "Company Because",
        "IDGD": "Company IDontGiveADarn",
        "IDK": "Company IDontKnow",
        "TDY": "Company Today",
        "TMRW": "Company Tomorrow",
        "WHAT": "Company What",
        "WHY": "Company Why",
        "WHO": "Company Who",
    }

    # Highest surrogate key already issued. max() skips nulls and yields null
    # for an empty dimension, which maps to 0 so the first new key is 1.
    max_id = df_existing_symbols.agg(F.max("Symbol_SK")).first()[0] or 0

    # Distinct incoming symbols with no match in the dimension. A left_anti
    # join returns the same rows as a left-outer join filtered on a null
    # right side, without carrying the right-hand columns along.
    df_symbols = (
        df_stocks.select("Symbol").distinct()
        .join(df_existing_symbols.select("Symbol"), on="Symbol", how="left_anti")
    )

    # Assign contiguous keys max_id+1, max_id+2, ... in Symbol order.
    # monotonically_increasing_id() is unique but NOT consecutive: it stores
    # the partition id in the upper bits, so keys could jump by ~2^33 between
    # partitions. The un-partitioned window pulls rows onto one partition,
    # which is acceptable for the handful of new symbols a dimension load sees.
    key_order = Window.orderBy("Symbol")
    df_symbols = df_symbols.withColumn(
        "Symbol_SK",
        (F.row_number().over(key_order) + F.lit(max_id)).cast("long"),
    )

    # Name: nested CASE built from the lookup table, defaulting to "Company Unknown".
    name_col = F.lit("Company Unknown")
    for symbol, name in symbol_names.items():
        name_col = F.when(F.col("Symbol") == symbol, name).otherwise(name_col)

    # Market is derived from the first letter of the ticker.
    first_letter = F.substring("Symbol", 1, 1)
    market_col = (
        F.when(first_letter.isin("B", "W"), "NASDAQ")
        .when(first_letter.isin("I", "T"), "NYSE")
        .otherwise("No Market")
    )

    df_symbols = df_symbols.select(
        "Symbol_SK", "Symbol", name_col.alias("Name"), market_col.alias("Market")
    )

    # An empty frame means every incoming symbol is already in the dimension.
    if df_symbols.rdd.isEmpty():
        print("No new symbols.")
        return df_existing_symbols

    print("New Symbols:")
    df_symbols.orderBy("Symbol").show()

    # Insert-only MERGE: rows whose Symbol already exists are left untouched.
    dim_symbol_table = DeltaTable.forName(spark, "dim_symbol")
    (
        dim_symbol_table.alias('dim_symbol')
        .merge(df_symbols.alias('updates'), 'dim_symbol.Symbol = updates.Symbol')
        .whenNotMatchedInsert(values={
            "Symbol_SK": "updates.Symbol_SK",
            "Symbol": "updates.Symbol",
            "Name": "updates.Name",
            "Market": "updates.Market",
        })
        .execute()
    )

    return spark.sql("SELECT * FROM dim_symbol ORDER BY Symbol ASC")

In [None]:
# find the most recent PriceDateKey already loaded into the fact table;
# new source rows are read from the start of that day onwards

watermark_query = (
    "SELECT PriceDateKey "
    "    FROM fact_stocks_daily_prices "
    "    ORDER BY PriceDateKey DESC LIMIT 1"
)
df_watermark = spark.sql(watermark_query)

if df_watermark.rdd.isEmpty():
    print('Table is empty, using default date.')
    cutoff_datetime = '2000-01-01 00:00:00'
else:
    df_watermark.show()
    latest_date = df_watermark.first()["PriceDateKey"]
    # midnight at the start of the latest loaded day
    cutoff_datetime = datetime(latest_date.year, latest_date.month, latest_date.day, 0, 0, 0)

# manually specify a cutoff date
#cutoff_datetime = '2023-11-27 23:59:50'

print(f"Cutoff date: {cutoff_datetime}")

In [None]:
# read the hourly rows to ingest, starting at the watermark computed above;
# the row limit is arbitrary and exists mainly for demo purposes

stocks_query = (
    f"SELECT Symbol, MinPrice, MaxPrice, LastPrice, Datestamp, Hour FROM {sourceTableName} "
    f"    WHERE Datestamp >= '{cutoff_datetime}' "
    "    ORDER BY Datestamp ASC, Hour ASC LIMIT 5000000"
)
df_stocks = spark.sql(stocks_query)
df_stocks.show()

In [None]:
# read the whole date dimension; it is joined to the daily aggregates below

df_date = spark.table("dim_date")
df_date.show()

In [None]:
# read the symbols dimension, then add any symbols that appear in the new
# stock data but are not in it yet. loading incrementally this way lets new
# symbols enter the feed over time without manual edits to the dimension.

df_symbol = spark.table("dim_symbol").orderBy("Symbol")
print("Current Symbols:")
df_symbol.show()

# merge in missing symbols (returns the refreshed dimension when any were added)
df_symbol = dim_symbol_incremental_load(df_stocks, df_symbol)

print("Symbols After Merge:")
df_symbol.show()

In [None]:
# Code generated by Data Wrangler for PySpark DataFrame

from pyspark.sql import functions as F

def clean_data(df_stocks):
    """Roll hourly stock rows up to one row per symbol per day.

    Args:
        df_stocks: hourly rows with ``Symbol``, ``Datestamp``, ``MinPrice``,
            ``MaxPrice`` and ``LastPrice`` columns, normally also ``Hour``.

    Returns:
        DataFrame with ``Symbol``, ``Datestamp``, ``newMinPrice`` (daily
        low), ``newMaxPrice`` (daily high) and ``newClosePrice``. Any row
        containing a null is dropped, and the result is sorted by ``Symbol``
        then ``Datestamp``.
    """
    if "Hour" in df_stocks.columns:
        # The close is the LastPrice of the latest hour in the day.
        # F.last() inside groupBy().agg() is non-deterministic, because row
        # order is not preserved across the shuffle. Taking the max of a
        # (Hour, LastPrice) struct instead picks the latest hour, since
        # structs compare field by field. Assumes Hour sorts chronologically
        # (e.g. an integer hour) — confirm against the source table.
        close_price = F.max(F.struct("Hour", "LastPrice"))["LastPrice"]
    else:
        # Without an Hour column there is no ordering to rely on; keep the
        # previous best-effort behaviour.
        close_price = F.last("LastPrice")

    df_stocks = df_stocks.groupBy("Symbol", "Datestamp").agg(
        F.min("MinPrice").alias("newMinPrice"),
        F.max("MaxPrice").alias("newMaxPrice"),
        close_price.alias("newClosePrice"),
    )
    df_stocks = df_stocks.dropna()
    df_stocks = df_stocks.sort(df_stocks["Symbol"].asc(), df_stocks["Datestamp"].asc())
    return df_stocks

df_stocks_agg = clean_data(df_stocks)
display(df_stocks_agg)

In [None]:
# match each daily aggregate to its row in the date dimension (inner join)

date_match = df_stocks_agg["Datestamp"] == df_date["DateKey"]
df_join = df_stocks_agg.join(df_date, on=date_match, how="inner")
display(df_join)

In [None]:
# attach the symbol dimension (and its Symbol_SK surrogate key) to each row

symbol_match = df_join["Symbol"] == df_symbol["Symbol"]
df_join = df_join.join(df_symbol, on=symbol_match, how="inner")
display(df_join)

In [None]:
# keep only the columns the fact MERGE needs, under unambiguous names

final_columns = [
    col("DateKey").alias("newPriceDateKey"),
    col("dim_symbol.Symbol").alias("newSymbol"),
    col("dim_symbol.Symbol_SK").alias("newSymbol_SK"),
    "newMinPrice",
    "newMaxPrice",
    "newClosePrice",
]
df_final_view = df_join.select(*final_columns)

df_final_view.show()

In [None]:
# to insert the new data, we'll merge the dataframe with the fact table.
# for existing records, update the high/low/close price of the stock
# for new records, insert a new row with the current high/low/close

from delta.tables import *

fact_stock_prices_table = DeltaTable.forName(spark, "fact_stocks_daily_prices")

# a fact row is identified by (day, symbol surrogate key)
merge_condition = 'fact.PriceDateKey = updates.newPriceDateKey and fact.Symbol_SK = updates.newSymbol_SK'

# existing row: widen the day's low/high with the new values and take the
# newest close price
matched_updates = {
    "MinPrice": "CASE WHEN fact.MinPrice < updates.newMinPrice THEN fact.MinPrice ELSE updates.newMinPrice END",
    "MaxPrice": "CASE WHEN fact.MaxPrice > updates.newMaxPrice THEN fact.MaxPrice ELSE updates.newMaxPrice END",
    "ClosePrice": "updates.newClosePrice",
}

# new (day, symbol) pair: insert the aggregated values as-is
unmatched_inserts = {
    "Symbol_SK": "updates.newSymbol_SK",
    "PriceDateKey": "updates.newPriceDateKey",
    "MinPrice": "updates.newMinPrice",
    "MaxPrice": "updates.newMaxPrice",
    "ClosePrice": "updates.newClosePrice",
}

(
    fact_stock_prices_table.alias('fact')
    .merge(df_final_view.alias('updates'), merge_condition)
    .whenMatchedUpdate(set=matched_updates)
    .whenNotMatchedInsert(values=unmatched_inserts)
    .execute()
)


The code below is for observing the output and comparing results.

In [None]:
# helper that returns the last 30 days of fact rows, labelled with ticker symbols

def get_latest_fact():
    """Return fact_stocks_daily_prices rows dated within the last 30 days.

    Each row is joined to dim_symbol to pick up its ticker, and the result is
    ordered by PriceDateKey, then Symbol_SK.
    """
    latest_fact_query = (
        "SELECT dim.Symbol, fact.Symbol_SK, PriceDateKey, MinPrice, MaxPrice, ClosePrice "
        "        FROM fact_stocks_daily_prices fact "
        "        INNER JOIN dim_symbol dim on fact.Symbol_SK = dim.Symbol_SK "
        "        WHERE PriceDateKey >= date_add(current_date(),-30) "
        "        ORDER BY PriceDateKey ASC, fact.Symbol_SK ASC"
    )
    return spark.sql(latest_fact_query)

In [None]:
# results of this run: the most recent 30 days of the fact table
df_run = get_latest_fact()
display(df_run)

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

# one line trace of daily close price per symbol, symbols in sorted order
df_run_pd = df_run.toPandas()
symbols_pd = sorted(df_run_pd['Symbol'].unique())

fig = go.Figure()

for symbol in symbols_pd:
    is_symbol = df_run_pd['Symbol'] == symbol
    symbol_rows = df_run_pd.loc[is_symbol, ["PriceDateKey", "ClosePrice"]]
    trace = go.Scatter(
        x=symbol_rows['PriceDateKey'],
        y=symbol_rows['ClosePrice'],
        name=symbol,
        line={"width": 1},
    )
    fig.add_trace(trace)

fig.update_layout(title="Close Price - Last 30 Days", showlegend=True)
fig.show()