## Introduction

This notebook is a follow-up to DS 1 and DS 2. In this notebook, building models and generating predictions are combined into a single process that is based on live data in the stocks_minute_agg table (having at least a few weeks of data is recommended) instead of downloaded history. In addition, this notebook attempts to pull parameters for each stock from MLflow -- this allows a model with ideal parameters to be created and stored (such as in DS 1).

In [None]:
!pip install prophet
import pyspark.sql.functions as F
import pandas as pd

from pyspark.sql.functions import concat, col, lit, when, substring 
from pyspark.sql.types import *

import datetime
import time
from datetime import datetime
from datetime import timedelta

In [None]:
# specify date cutoff -- this is typically the current date unless testing

# generated_date stamps the rows written by this run and is used to query them back afterwards;
# cutoff_date bounds the input data and is where the forecast window starts.
# Both are naive datetimes holding the current UTC time.
generated_date = datetime.utcnow()
cutoff_date = datetime.utcnow()

# manually specify a cutoff date
# cutoff_date = '2023-12-28 01:23:45'
# cutoff_date = datetime.strptime(cutoff_date, '%Y-%m-%d %H:%M:%S')

# truncate to the whole minute (seconds and microseconds are zeroed, not rounded)
cutoff_date = cutoff_date.replace(second=0, microsecond=0)
print(f'Cutoff date: {cutoff_date}')

In [None]:
def readStockHistoryLive():
    """Load the live minute-level stock data and shape it for modeling.

    Reads every row of the stocks_minute_agg table, derives a `timestamp`
    column from the datestamp, hour and minute columns, drops the columns
    that are no longer needed, and renames Symbol / LastPrice to lower case.

    Returns:
        Spark DataFrame sorted by `timestamp`.
    """
    # combine the date with an hour/minute interval to get the full timestamp
    timestamp_expr = F.expr(
        "to_timestamp(datestamp) + make_interval(0, 0, 0, 0, hour, minute, 0)")

    # source columns that are not carried forward
    columns_to_drop = ['Datestamp', 'Hour', 'Minute', 'MinPrice', 'MaxPrice']

    history = (
        spark.sql("SELECT * FROM stocks_minute_agg")
        .withColumn('timestamp', timestamp_expr)
        .drop(*columns_to_drop)
        .withColumnRenamed('Symbol', 'symbol')
        .withColumnRenamed('LastPrice', 'lastprice')
        .sort("timestamp")
    )
    return history

# read all data into a dataframe
# df_stocks is the shared input for the later cells (date filtering, symbol discovery, per-symbol modeling)
df_stocks = readStockHistoryLive()

In [None]:
def filterStocksByDate(df, date):
    """Return only the rows of `df` whose timestamp is strictly before `date`.

    Args:
        df: Spark DataFrame with a `timestamp` column.
        date: cutoff; its string form is embedded in the filter expression.

    Returns:
        The filtered Spark DataFrame.
    """
    cutoff_text = str(date)
    condition = f'timestamp < "{cutoff_text}"'
    return df.select("*").where(condition)

# keep only data up until current date (no future looking data)
# not needed for live data unless testing
# rebinds df_stocks to the rows strictly before cutoff_date
df_stocks = filterStocksByDate(df_stocks, cutoff_date)

In [None]:
def create_prediction_table():
    """Create the stocks_prediction Delta table when it does not exist yet.

    The table holds one row per predicted timestamp and symbol: the forecast
    value with its lower/upper bounds, plus the time the forecast was generated.
    """
    ddl = """
        CREATE TABLE IF NOT EXISTS stocks_prediction (
            Predict_time TIMESTAMP
            ,Symbol VARCHAR(5)
            ,yhat DOUBLE
            ,yhat_lower DOUBLE
            ,yhat_upper DOUBLE
            ,Generated TIMESTAMP)
        USING DELTA
        """
    spark.sql(ddl)
    
# create the stocks prediction table if needed
# (done up front; write_predictions later looks the table up by name)
create_prediction_table()

In [None]:
# gets all symbols to process.
# symbols can be specified explicitly or by filtering from dataframe

def get_symbols(df):
    """Collect the stock symbols that should be processed.

    Args:
        df: Spark DataFrame of stock data with a `symbol` column.

    Returns:
        List of Rows, each with a single `Symbol` field. When `df` has rows,
        these are its distinct symbols sorted by symbol; otherwise a fixed
        fallback list is returned.
    """
    # fallback list, for when the symbols cannot be taken from the data
    # (e.g. dim_symbol from the lakehouse module does not exist).
    # To process a single symbol, this can be narrowed to e.g. [['WHO']].
    fallback_symbols = [['BCUZ'], ['IDGD'], ['IDK'], ['TDY'], ['TMRW'], ['WHAT'], ['WHO'], ['WHY']]
    symbol_df = spark.createDataFrame(fallback_symbols, ['Symbol'])

    if df.rdd.isEmpty():
        return symbol_df.collect()

    # default: every distinct symbol present in the current dataframe
    symbol_df = (
        df.select('symbol')
        .distinct()
        .sort('symbol')
        .withColumnRenamed('symbol', 'Symbol')
    )
    return symbol_df.collect()

# symbols_list drives the main loop below; printed so the set of symbols can be checked before modeling
symbols_list = get_symbols(df_stocks)
print(symbols_list)

## Functions for filtering, building, and merging data

In [None]:
# remove all but specified stock symbol
# individual models can be built for each stock

def filterStocksBySymbol(df_stocks, symbol):
    """Return the rows of `df_stocks` for a single stock symbol.

    Args:
        df_stocks: Spark DataFrame with `symbol` and `timestamp` columns.
        symbol: the stock symbol to keep.

    Returns:
        Spark DataFrame containing only that symbol's rows, sorted by
        `timestamp`.
    """
    df_stocks_filtered = df_stocks.select("*").where(
        f'symbol == "{symbol}"'
    )

    # No tail()/show() here: tail() is an action that runs a Spark job and
    # pulls rows to the driver, and its result was never used.
    return df_stocks_filtered.sort("timestamp")

In [None]:
# establish begin/end dates for prediction
# returns a dataframe holding only the 'ds' timestamps to predict (one per minute), with no predictions yet

def make_prediction_dataframe(fromdate=None):
    """Build the frame of future timestamps to forecast.

    Args:
        fromdate: first timestamp of the forecast window. When omitted, the
            current UTC time at the moment of the call is used.

    Returns:
        pandas DataFrame with a single `ds` column holding one timestamp per
        minute from `fromdate` through `fromdate` + 7 days, both inclusive.
    """
    # Resolve the default at call time: a datetime.utcnow() default in the
    # signature is evaluated once, when the function is defined, and every
    # later call would silently reuse that stale value.
    if fromdate is None:
        fromdate = datetime.utcnow()

    enddate = fromdate + timedelta(days=7)

    print(f'Forecast {fromdate} to {enddate}')

    # 'min' is the minute frequency alias; the older 'T' spelling is deprecated
    future = pd.DataFrame({'ds': pd.date_range(start=fromdate, end=enddate, freq='min')})
    return future

In [None]:
import mlflow
import pandas as pd
from mlflow import MlflowClient
from mlflow.entities import ViewType

# queries mlflow for matching models for each stock, using the most recent model
def get_model_params(symbol):
    """Return the Prophet parameters to use for `symbol`.

    Searches MLflow for runs of the experiment named
    "<symbol>-stock-prediction", ordered by start time descending, and reads
    four parameters from the first run returned, converting each to float.
    Lookup is best-effort: any exception is printed and the function still
    returns, with defaults in place for every value that was not replaced
    before the failure.

    Args:
        symbol: stock symbol used to build the experiment name.

    Returns:
        dict with keys changepoint_prior_scale, changepoint_range,
        seasonality_prior_scale and weekly_seasonality.
    """
    # defaults for prophet, kept for anything MLflow does not supply
    param_grid = {
        'changepoint_prior_scale': 0.025,
        'changepoint_range': 0.95,
        'seasonality_prior_scale': 10,
        'weekly_seasonality': 6
    }

    # parameters copied from the MLflow run, in this order
    tuned_names = (
        "changepoint_prior_scale",
        "changepoint_range",
        "seasonality_prior_scale",
        "weekly_seasonality",
    )
    loaded_from_mlflow = False

    try:
        experiment = f"{symbol}-stock-prediction"
        # NOTE(review): MLflow documents run statuses in upper case ('FINISHED');
        # confirm this filter matches runs on the target tracking store.
        finished_runs = mlflow.search_runs(
            experiment_names=[experiment],
            run_view_type=ViewType.ACTIVE_ONLY,
            filter_string="attributes.status = 'Finished'",
            order_by=["attributes.start_time DESC"])

        if not finished_runs.empty:
            # first row is the most recently started run
            latest = finished_runs.iloc[0]
            model_uri = f"runs:/{latest.run_id}/{symbol}-stock-prediction-model"

            loaded_from_mlflow = True
            for name in tuned_names:
                param_grid[name] = float(latest[f"params.{name}"])

            print(f'Found params in MLflow for {symbol} {model_uri}')

            # the whole params section is also available through
            # mlflow.get_run(run_id).data.params, at the cost of another call

    except Exception as e:
        print(f'Exception in getting parameters from MLflow. Using defaults where needed. Details: {e}')

    source_label = "Using MLflow params" if loaded_from_mlflow else "Using default params"
    print(f'{source_label}: {param_grid}')
    return param_grid


In [None]:
from prophet import Prophet
from prophet.plot import add_changepoints_to_plot

def build_and_predict(symbol, dfStocks, prediction_begin_date):
    """Fit a Prophet model on one symbol's history and forecast ahead.

    Args:
        symbol: stock symbol, used to look up the model parameters.
        dfStocks: Spark DataFrame for that symbol with `timestamp` and
            `lastprice` columns.
        prediction_begin_date: datetime at which the forecast window starts.

    Returns:
        The DataFrame produced by Prophet's predict() over the forecast window.
    """
    # zero the seconds/microseconds so forecast timestamps line up across runs
    window_start = prediction_begin_date.replace(second=0, microsecond=0)
    predict_df = make_prediction_dataframe(window_start)

    # Prophet expects the time column to be named ds and the value column y
    history_pd = dfStocks.toPandas().rename(
        columns={'timestamp': 'ds', 'lastprice': 'y'})
    print('Min data date: ', history_pd['ds'].min())
    print('Max data date: ', history_pd['ds'].max())

    # default params, or the ones loaded from mlflow
    model = Prophet(**get_model_params(symbol))
    model.fit(history_pd)

    return model.predict(predict_df)


In [None]:
# merge the predictions with the table in the lakehouse

from delta.tables import *

def write_predictions(predictions_pd, symbol, generated):
    """Upsert one symbol's forecast into the stocks_prediction Delta table.

    Rows are matched on prediction time and symbol. A matching row gets its
    yhat / yhat_lower / yhat_upper and Generated values replaced; a row with
    no match is inserted.

    Args:
        predictions_pd: pandas DataFrame with ds, yhat, yhat_lower and
            yhat_upper columns.
        symbol: stock symbol the predictions belong to.
        generated: generation time recorded on every written row.
    """
    predictions_df = (
        spark.createDataFrame(predictions_pd)
        .withColumn("symbol", lit(symbol))
        .withColumn("generated", lit(generated))
    )

    # a prediction row is identified by its timestamp and its symbol
    match_condition = f'table.Predict_time = predictions.ds and table.Symbol = "{symbol}"'

    # the generation time goes into the merge as a quoted SQL literal
    generated_literal = f"'{str(generated)}'"

    forecast_columns = {
        "yhat": "predictions.yhat",
        "yhat_lower": "predictions.yhat_lower",
        "yhat_upper": "predictions.yhat_upper",
    }
    update_values = {**forecast_columns, "Generated": generated_literal}
    insert_values = {
        "Predict_time": "predictions.ds",
        "Symbol": f"'{symbol}'",
        **forecast_columns,
        "Generated": generated_literal,
    }

    target = DeltaTable.forName(spark, "stocks_prediction")
    merge = target.alias('table').merge(
        predictions_df.alias('predictions'), match_condition)
    merge = merge.whenMatchedUpdate(set=update_values)
    merge = merge.whenNotMatchedInsert(values=insert_values)
    merge.execute()

## Main loop to build predictions for each symbol

In [None]:
# loop through all the symbols,
# filter the data by symbol, generate predictions
# write predictions to table
# Symbols are processed one at a time; exceptions are not caught here, so a failure on one symbol stops the loop.

for row in symbols_list:
    start_time = datetime.utcnow()
    print(f'Starting: {row.Symbol} at {start_time}')

    # model step: fit on this symbol's history and forecast starting at cutoff_date
    df_stocks_filtered_symbol = filterStocksBySymbol(df_stocks, row.Symbol)
    forecast = build_and_predict(row.Symbol, df_stocks_filtered_symbol, cutoff_date)
    forecast_finish_time = datetime.utcnow()

    # write step: every symbol is stamped with the same generated_date for this run
    write_predictions(forecast, row.Symbol, generated_date)
    write_finish_time = datetime.utcnow()

    # elapsed time for the model step (filter + fit + predict) and for the write step
    forecast_elap = forecast_finish_time - start_time
    write_elap = write_finish_time - forecast_finish_time

    print(f'Completed: {row.Symbol} at {datetime.utcnow()}. ' \
        f'Model: {forecast_elap.total_seconds()}s ' \
        f'Write: {write_elap.total_seconds()}s')
    print('-' * 50)

In [None]:
# spark.sql("DELETE FROM stocks_prediction")

In [None]:
# select the predictions whose Generated value equals this run's generated_date, oldest prediction first
df = spark.sql(f'SELECT * FROM stocks_prediction WHERE Generated == "{generated_date}" ORDER BY Predict_time ASC')
# display() is not defined in this file -- presumably the notebook environment's built-in table renderer
display(df)

In [None]:
# visualize the predictions

import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

# bring the selected predictions into pandas for plotting
df_pd = df.toPandas()
symbols_pd = sorted(df_pd['Symbol'].unique())

# (column, legend label, line style) for each plotted series, in draw order
band_traces = [
    ('yhat_upper', 'Upper', dict(color='red', width=1)),
    ('yhat', 'Predicted', dict(color='black', width=2)),
    ('yhat_lower', 'Lower', dict(color='green', width=1)),
]

# one figure per symbol
for symbol in symbols_pd:
    fig = go.Figure()

    dftemp = df_pd.loc[df_pd['Symbol'] == symbol][["Predict_time","yhat","yhat_lower","yhat_upper"]]
    # index by prediction time so the frame can be resampled
    dftemp = dftemp.set_index(pd.DatetimeIndex(dftemp["Predict_time"])).drop("Predict_time", axis=1)

    # use resample when graphing to limit data points on graph
    dftemp = dftemp.resample("60min").mean()

    # turn the time index back into a Predict_time column for the x axis
    dftemp.reset_index(inplace = True)

    for trace_column, trace_label, trace_line in band_traces:
        fig.add_trace(go.Scatter(
            x=dftemp['Predict_time'],
            y=dftemp[trace_column],
            name=f'{symbol} {trace_label}',
            line=trace_line))

    fig.update_layout(title=f'{symbol} forecast', showlegend=True)
    fig.show()