# US Covid-19 Forecasting

The objective of this notebook is to illustrate how we might forecast COVID-19 cases from county-level data, leveraging the computational power of Databricks.  For this exercise, we will make use of an increasingly popular forecasting library, [FBProphet](https://facebook.github.io/prophet/), which we will load into the notebook session associated with a cluster running Databricks 6.0 or higher:

In [2]:
# load fbprophet library
dbutils.library.installPyPI('FBProphet', version='0.5') # find latest version of fbprophet here: https://pypi.org/project/fbprophet/
dbutils.library.installPyPI('holidays','0.9.12') # this line is in response to this issue with fbprophet 0.5: https://github.com/facebook/prophet/issues/1293

dbutils.library.restartPython()

## Examine the Data

For our training dataset, we will make use of daily COVID-19 data for California counties. Each record holds the county, state and date along with case and death counts (`cases`, `deaths`) and daily new counts (`new_cases`, `new_deaths`).

Once we have the CSV file, we can upload it to */FileStore/tables/* as *covid_ca_county.csv* using the file import steps documented [here](https://docs.databricks.com/data/tables.html#create-table-ui). Please note that when performing the file import, you don't need to select the *Create Table with UI* or the *Create Table in Notebook* options to complete the import process.

With the dataset accessible within Databricks, we can now explore it in preparation for modeling:

In [4]:
from pyspark.sql.types import StructType, StructField, StringType, DateType, IntegerType

# structure of the training data set
train_schema = StructType([
  StructField('county', StringType()),
  StructField('state', StringType()),
  StructField('date', DateType()),
  StructField('cases', IntegerType()),
  StructField('deaths', IntegerType()),
  StructField('new_cases', IntegerType()),
  StructField('new_deaths', IntegerType())
  ])

# read the training file into a dataframe
train = spark.read.csv(
  '/FileStore/tables/covid_ca_county.csv', 
  header=True, 
  schema=train_schema
  )

# make the dataframe queriable as a temporary view
train.createOrReplaceTempView('train')

When forecasting, we are often interested in general trends and seasonality.  Let's start our exploration by examining the daily totals of cases and deaths across all counties:

In [6]:
%sql

SELECT date,
       sum(cases) as cases,
       sum(deaths) as deaths,
       sum(new_cases) as new_cases,
       sum(new_deaths) as new_deaths
 FROM train
 GROUP BY date;

date,cases,deaths,new_cases,new_deaths
2020-01-21,0,0,0,0
2020-04-30,50470,2057,1566,96
2020-03-07,100,1,19,0
2020-03-13,320,5,68,1
2020-02-04,6,0,0,0
2020-02-15,7,0,0,0
2020-05-23,92815,3768,2014,78
2020-02-12,7,0,0,0
2020-05-08,64618,2650,2135,89
2020-05-24,94743,3790,1928,22


In [7]:
%sql

SELECT 
       sum(cases) as cases,
       sum(deaths) as deaths

  from train
 where date = (select max(date) from train)

cases,deaths
400195,7764


In [8]:
%sql

SELECT 
       sum(new_cases) as new_cases,
       sum(new_deaths) as new_deaths

  from train

new_cases,new_deaths
400195,7764
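
The totals returned by the last two queries are identical (400,195 cases and 7,764 deaths). That is what we would expect if `new_cases` and `new_deaths` are day-over-day differences of cumulative `cases` and `deaths` within each county, though the dataset does not state how they were derived. A minimal pandas sketch of that relationship, using made-up numbers for two hypothetical counties:

```python
import pandas as pd

# toy cumulative counts for two hypothetical counties (not real data)
toy = pd.DataFrame({
    'county': ['A', 'A', 'A', 'B', 'B', 'B'],
    'date': pd.to_datetime(['2020-03-01', '2020-03-02', '2020-03-03'] * 2),
    'cases': [1, 4, 9, 0, 2, 5],
})

# daily increments: difference from the previous day within each county;
# the first day has no predecessor, so its increment is the cumulative value itself
toy['new_cases'] = (
    toy.groupby('county')['cases'].diff().fillna(toy['cases']).astype(int)
)

# summing every increment reproduces the cumulative total on the latest date
latest_total = toy.loc[toy['date'] == toy['date'].max(), 'cases'].sum()
total_new = toy['new_cases'].sum()
print(total_new, latest_total)
```

If the two numbers ever disagreed on the real data, that would point to missing days or corrections in one of the columns.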


It's very clear from the data that there is a generally upward trend in total cases. If we had better knowledge of how the outbreak is likely to develop, we might wish to identify whether there is a maximum growth capacity we'd expect to approach over the life of our forecast.  But without that knowledge, and by just quickly eyeballing this dataset, we will work from the simplifying assumption that growth continues linearly over a forecast horizon of a few days or months.

Now let's look at how cases and deaths are distributed across counties, using the most recent date in the dataset:

In [10]:
%sql

SELECT county,
       sum(cases) as cases,
       sum(deaths) as deaths
  from train
 where date = (select max(date) from train)
 group by county
 order by cases desc
 limit 10;

county,cases,deaths
Los Angeles,159045,4104
Riverside,30340,588
Orange,29986,493
San Diego,24198,478
San Bernardino,24099,329
Fresno,10639,100
Kern,10094,105
Alameda,9277,162
Imperial,8606,162
San Joaquin,8321,94
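
The query above filters on `max(date)` before aggregating. Since `cases` and `deaths` appear to be running totals, summing them over every date would count the same cases many times. A rough pandas equivalent of this "latest snapshot, top N" pattern, with a made-up frame standing in for the `train` view:

```python
import pandas as pd

# toy rows in the same shape as the train view (values are made up)
toy = pd.DataFrame({
    'county': ['A', 'B', 'C', 'A', 'B', 'C'],
    'date': pd.to_datetime(['2020-07-01'] * 3 + ['2020-07-02'] * 3),
    'cases': [10, 30, 20, 12, 35, 21],
    'deaths': [1, 2, 0, 1, 3, 0],
})

# keep only rows from the most recent date (the SQL subquery on max(date))
latest = toy[toy['date'] == toy['date'].max()]

# aggregate per county, sort by cases descending, and keep the top 2
top = (
    latest.groupby('county', as_index=False)[['cases', 'deaths']]
    .sum()
    .sort_values('cases', ascending=False)
    .head(2)
)
print(top)
```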


In [11]:
%sql

SELECT county,
       sum(cases) as cases
  from train
 where date = (select max(date) from train)
 and county <> "Unknown"
 group by county

county,cases
Plumas,17
Kings,3398
Marin,4376
Inyo,38
Sonoma,2212
Napa,633
Madera,1451
Siskiyou,51
Ventura,5748
Orange,29986


In [12]:
%sql

SELECT county,
       sum(deaths) as deaths
  from train
 where date = (select max(date) from train)
 and county <> "Unknown"
 group by county

county,deaths
Plumas,0
Kings,42
Marin,40
Inyo,1
Sonoma,20
Napa,5
Madera,13
Siskiyou,0
Ventura,58
Orange,493


We do not aggregate the data at a weekday level here, so any weekly seasonal pattern is left for the model's weekly component to estimate.  With well under a year of observations, a yearly seasonal pattern cannot be assessed from this dataset.

Now that we are oriented to the basic patterns within our data, let's explore how we might build a forecast.

### Build a Forecast

Before attempting to generate forecasts for individual counties, it is helpful to build a single statewide forecast, if only to orient ourselves to the use of FBProphet.

Our first step is to assemble the historical dataset on which we will train the model:

In [16]:
# query to aggregate data to date (ds) level
sql_statement = '''
  SELECT
    CAST(date as date) as ds,
    sum(cases) as y
  FROM train
  group by ds
  ORDER BY ds
  '''

# assemble dataset in Pandas dataframe
history_pd = spark.sql(sql_statement).toPandas()

# drop any missing records
history_pd = history_pd.dropna()

Now, we will import the fbprophet library, but because it can be a bit verbose when in use, we will need to fine-tune the logging settings in our environment:

In [18]:
from fbprophet import Prophet
import logging

# disable informational messages from fbprophet
logging.getLogger('py4j').setLevel(logging.ERROR)

Based on our review of the data, a linear growth pattern, which is Prophet's default, looks like a reasonable starting point. The cell below also keeps a disabled alternative configuration that widens the uncertainty interval to 95%, sets the daily, weekly and yearly seasonality flags explicitly, and switches the seasonality mode to multiplicative. The model we actually fit uses Prophet's default settings:

In [20]:
# set model parameters
# alternative configuration, left disabled:
# model = Prophet(
#   interval_width=0.95,
#   daily_seasonality=True,
#   weekly_seasonality=True,
#   yearly_seasonality=False,
#   seasonality_mode='multiplicative'
#   )

# use Prophet's default settings
model = Prophet()

# fit the model to historical data
model.fit(history_pd)

Now that we have a trained model, let's use it to build a 90-day forecast:

In [22]:
# define a dataset including both historical dates & 90-days beyond the last available date
future_pd = model.make_future_dataframe(
  periods=90, 
  freq='d', 
  include_history=True
  )

# predict over the dataset
forecast_pd = model.predict(future_pd)

display(forecast_pd)

ds,trend,yhat_lower,yhat_upper,trend_lower,trend_upper,additive_terms,additive_terms_lower,additive_terms_upper,weekly,weekly_lower,weekly_upper,multiplicative_terms,multiplicative_terms_lower,multiplicative_terms_upper,yhat
2020-01-21T00:00:00.000+0000,-482.9148929540729,-8439.054088387342,6592.216919609824,-482.9148929540729,-482.9148929540729,-382.819328158745,-382.819328158745,-382.819328158745,-382.819328158745,-382.819328158745,-382.819328158745,0.0,0.0,0.0,-865.734221112818
2020-01-22T00:00:00.000+0000,-463.93719217520567,-7323.213815845426,6798.248200739783,-463.93719217520567,-463.93719217520567,-235.3989685744908,-235.3989685744908,-235.3989685744908,-235.3989685744908,-235.3989685744908,-235.3989685744908,0.0,0.0,0.0,-699.3361607496964
2020-01-23T00:00:00.000+0000,-444.9594913963384,-7733.677707623275,6586.530415446468,-444.9594913963384,-444.9594913963384,-88.18568816978618,-88.18568816978618,-88.18568816978618,-88.18568816978618,-88.18568816978618,-88.18568816978618,0.0,0.0,0.0,-533.1451795661245
2020-01-24T00:00:00.000+0000,-425.9817906174712,-7705.571126217286,6963.239762248082,-425.9817906174712,-425.9817906174712,209.0198612726683,209.0198612726683,209.0198612726683,209.0198612726683,209.0198612726683,209.0198612726683,0.0,0.0,0.0,-216.9619293448029
2020-01-25T00:00:00.000+0000,-407.00408983860393,-7281.161095905872,7178.33961856273,-407.00408983860393,-407.00408983860393,333.5570770002374,333.5570770002374,333.5570770002374,333.5570770002374,333.5570770002374,333.5570770002374,0.0,0.0,0.0,-73.4470128383665
2020-01-26T00:00:00.000+0000,-388.02638905973663,-8104.845714580344,7251.469401703149,-388.02638905973663,-388.02638905973663,79.9640077801566,79.9640077801566,79.9640077801566,79.9640077801566,79.9640077801566,79.9640077801566,0.0,0.0,0.0,-308.06238127958
2020-01-27T00:00:00.000+0000,-369.04868828086944,-7387.71033446336,6695.5145218551415,-369.04868828086944,-369.04868828086944,83.86303885150465,83.86303885150465,83.86303885150465,83.86303885150465,83.86303885150465,83.86303885150465,0.0,0.0,0.0,-285.1856494293648
2020-01-28T00:00:00.000+0000,-350.0709823012129,-8003.679430464537,6161.261157688011,-350.0709823012129,-350.0709823012129,-382.8193281588256,-382.8193281588256,-382.8193281588256,-382.8193281588256,-382.8193281588256,-382.8193281588256,0.0,0.0,0.0,-732.8903104600386
2020-01-29T00:00:00.000+0000,-331.0932763215565,-7197.646893162919,6559.191156870424,-331.0932763215565,-331.0932763215565,-235.39896857509785,-235.39896857509785,-235.39896857509785,-235.39896857509785,-235.39896857509785,-235.39896857509785,0.0,0.0,0.0,-566.4922448966544
2020-01-30T00:00:00.000+0000,-312.1155703419,-7816.692408340182,6802.874652396866,-312.1155703419,-312.1155703419,-88.18568817013751,-88.18568817013751,-88.18568817013751,-88.18568817013751,-88.18568817013751,-88.18568817013751,0.0,0.0,0.0,-400.3012585120375
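
The component columns in this output combine into `yhat` as `trend * (1 + multiplicative_terms) + additive_terms`. With the default additive seasonality, `multiplicative_terms` is 0 and the weekly component is the only additive term. A quick arithmetic check using the values from the first row above:

```python
# values copied from the first forecast row above (2020-01-21)
trend = -482.9148929540729
additive_terms = -382.819328158745
multiplicative_terms = 0.0
yhat_reported = -865.734221112818

# Prophet combines the components as trend * (1 + multiplicative) + additive
yhat_rebuilt = trend * (1.0 + multiplicative_terms) + additive_terms

print(yhat_rebuilt)
print(abs(yhat_rebuilt - yhat_reported) < 1e-6)
```

Note that the rebuilt value is negative even though a case count cannot be: a linear trend fitted to this series dips below zero on the earliest dates.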


How did our model perform? Here we can see the general and seasonal trends in our model presented as graphs:

In [24]:
trends_fig = model.plot_components(forecast_pd)
display(trends_fig)

And here, we can see how our actual and predicted data line up, though we will limit our graph to the most recent months of historical data just to keep it readable:

In [26]:
predict_fig = model.plot( forecast_pd, xlabel='date', ylabel='cases')

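# adjust figure to display roughly the 150 days leading up to the start of the 90-day forecast
# (x-axis units are days, so xlim[1] - 90.0 falls near the last historical date)
xlim = predict_fig.axes[0].get_xlim()
#new_xlim = ( xlim[1]-(180.0+365.0), xlim[1]-90.0)
new_xlim = ( xlim[1] - 240.0, xlim[1]-90.0)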
predict_fig.axes[0].set_xlim(new_xlim)

display(predict_fig)
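
The x-axis limits in the cell above are expressed in days, so subtracting 240 and 90 from the right edge selects a fixed window of dates. The sketch below works through that arithmetic with plain dates. The last historical date used here is an assumption for illustration, and the small margin matplotlib adds around the plotted data is ignored:

```python
from datetime import date, timedelta

# hypothetical last historical date; the real value comes from the data
last_history = date(2020, 7, 20)
horizon_days = 90

# the right edge of the plot sits near the last forecast date
# (matplotlib adds a small margin, ignored here)
right_edge = last_history + timedelta(days=horizon_days)

# same arithmetic as new_xlim = (xlim[1] - 240.0, xlim[1] - 90.0)
window_start = right_edge - timedelta(days=240)
window_end = right_edge - timedelta(days=90)

print(window_start, window_end)
```

Under these assumptions the window ends at the last historical date and spans the 150 days before it.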

In [27]:
predict_fig = model.plot( forecast_pd, xlabel='date', ylabel='cases')
display(predict_fig)

**NOTE** This visualization is a bit busy. Bartosz Mikulski provides [an excellent breakdown](https://www.mikulskibartosz.name/prophet-plot-explained/) of it that is well worth checking out.  In a nutshell, the black dots represent our actuals with the darker blue line representing our predictions and the lighter blue band representing our uncertainty interval (80% under Prophet's default settings, which this model uses).

Visual inspection is useful, but a better way to evaluate the forecast is to calculate Mean Absolute Error, Mean Squared Error and Root Mean Squared Error values for the predicted relative to the actual values in our set:

In [30]:
from sklearn.metrics import mean_squared_error, mean_absolute_error
from math import sqrt
import pandas as pd

# get historical actuals & predictions for comparison
# (ds is converted to timestamps so both frames are compared against the same type)
cutoff = pd.Timestamp('2020-07-20')
actuals_pd = history_pd[ pd.to_datetime(history_pd['ds']) < cutoff ]['y']
predicted_pd = forecast_pd[ forecast_pd['ds'] < cutoff ]['yhat']

# calculate evaluation metrics
mae = mean_absolute_error(actuals_pd, predicted_pd)
mse = mean_squared_error(actuals_pd, predicted_pd)
rmse = sqrt(mse)

# print metrics to the screen
print( '\n'.join(['MAE: {0}', 'MSE: {1}', 'RMSE: {2}']).format(mae, mse, rmse) )
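
For reference, the three metrics reduce to a few lines of arithmetic. The sketch below computes them by hand with NumPy on made-up numbers, not on the notebook's data:

```python
import numpy as np

# toy actuals and predictions (made-up numbers)
actual = np.array([100.0, 200.0, 300.0])
predicted = np.array([110.0, 190.0, 330.0])

errors = actual - predicted

mae = np.mean(np.abs(errors))   # mean absolute error
mse = np.mean(errors ** 2)      # mean squared error
rmse = np.sqrt(mse)             # root mean squared error, same units as y

print(mae, mse, rmse)
```

Because the errors are squared before averaging, MSE and RMSE weight the single large miss (30) more heavily than MAE does.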

FBProphet provides [additional means](https://facebook.github.io/prophet/docs/diagnostics.html) for evaluating how your forecasts hold up over time. You're strongly encouraged to consider using these additional techniques when building your forecast models, but we'll skip them here.
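
The metrics calculated earlier compare predictions against the same dates the model was trained on, so they describe fit rather than forecast accuracy. The idea behind out-of-sample evaluation can be sketched with a simple holdout split. This is not Prophet's diagnostics API: the series is made up, and a naive "repeat the last value" forecast stands in for a fitted model.

```python
import numpy as np
import pandas as pd

# toy daily series standing in for history_pd (made-up values)
toy = pd.DataFrame({
    'ds': pd.date_range('2020-06-01', periods=10, freq='D'),
    'y': [10, 12, 15, 14, 18, 20, 23, 22, 26, 30],
})

# hold out the final 3 days; everything up to the cutoff is training data
cutoff = toy['ds'].max() - pd.Timedelta(days=3)
train_pd = toy[toy['ds'] <= cutoff]
test_pd = toy[toy['ds'] > cutoff]

# naive stand-in forecast: repeat the last training value
# (a fitted model's predictions for the held-out dates would go here instead)
naive_pred = np.full(len(test_pd), train_pd['y'].iloc[-1], dtype=float)

holdout_mae = np.mean(np.abs(test_pd['y'].to_numpy() - naive_pred))
print(len(train_pd), len(test_pd), holdout_mae)
```

Splitting by date rather than at random keeps the evaluation honest for time series, since the model never sees observations from after the cutoff.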

#### Build a Forecast for NEW_CASES

In [33]:
# query to aggregate data to date (ds) level
sql_statement = '''
  SELECT
    CAST(date as date) as ds,
    sum(new_cases) as y
  FROM train
  group by ds
  ORDER BY ds
  '''

# assemble dataset in Pandas dataframe
history_pd = spark.sql(sql_statement).toPandas()

# drop any missing records
history_pd = history_pd.dropna()

In [34]:
# set model parameters
# alternative configuration, left disabled:
# model = Prophet(
#   interval_width=0.95,
#   daily_seasonality=True,
#   weekly_seasonality=True,
#   yearly_seasonality=False,
#   seasonality_mode='multiplicative'
#   )

# use Prophet's default settings
model = Prophet()

# fit the model to historical data
model.fit(history_pd)

In [35]:
# define a dataset including both historical dates & 90-days beyond the last available date
future_pd = model.make_future_dataframe(
  periods=90, 
  freq='d', 
  include_history=True
  )

# predict over the dataset
forecast_pd = model.predict(future_pd)

display(forecast_pd)

ds,trend,yhat_lower,yhat_upper,trend_lower,trend_upper,additive_terms,additive_terms_lower,additive_terms_upper,weekly,weekly_lower,weekly_upper,multiplicative_terms,multiplicative_terms_lower,multiplicative_terms_upper,yhat
2020-01-21T00:00:00.000+0000,-273.59253127504843,-636.76983811458,691.4320375965895,-273.59253127504843,-273.59253127504843,288.36742975340474,288.36742975340474,288.36742975340474,288.36742975340474,288.36742975340474,288.36742975340474,0.0,0.0,0.0,14.774898478356306
2020-01-22T00:00:00.000+0000,-262.0482962749925,-841.9666394730186,469.652829220206,-262.0482962749925,-262.0482962749925,79.2471297577001,79.2471297577001,79.2471297577001,79.2471297577001,79.2471297577001,79.2471297577001,0.0,0.0,0.0,-182.8011665172924
2020-01-23T00:00:00.000+0000,-250.50406127493653,-889.1927703038157,438.6397027840146,-250.50406127493653,-250.50406127493653,26.12651324799514,26.12651324799514,26.12651324799514,26.12651324799514,26.12651324799514,26.12651324799514,0.0,0.0,0.0,-224.37754802694135
2020-01-24T00:00:00.000+0000,-238.9598262748806,-788.1037072318626,558.4229452519129,-238.95982627488064,-238.95982627488064,143.43935109610524,143.43935109610524,143.43935109610524,143.43935109610524,143.43935109610524,143.43935109610524,0.0,0.0,0.0,-95.52047517877536
2020-01-25T00:00:00.000+0000,-227.41559127482464,-1015.5695367018418,344.13603528218766,-227.41559127482464,-227.41559127482464,-79.98789472074749,-79.98789472074749,-79.98789472074749,-79.98789472074749,-79.98789472074749,-79.98789472074749,0.0,0.0,0.0,-307.40348599557217
2020-01-26T00:00:00.000+0000,-215.87135627476877,-1215.8375789819138,142.63212662149826,-215.87135627476877,-215.87135627476877,-330.61326377575995,-330.61326377575995,-330.61326377575995,-330.61326377575995,-330.61326377575995,-330.61326377575995,0.0,0.0,0.0,-546.4846200505287
2020-01-27T00:00:00.000+0000,-204.32712127471285,-974.9488880408736,321.8800098384018,-204.32712127471285,-204.32712127471285,-126.57926535763993,-126.57926535763993,-126.57926535763993,-126.57926535763993,-126.57926535763993,-126.57926535763993,0.0,0.0,0.0,-330.9063866323528
2020-01-28T00:00:00.000+0000,-192.7828860669532,-556.4050493142919,777.7488988851894,-192.7828860669532,-192.7828860669532,288.3674297534571,288.3674297534571,288.3674297534571,288.3674297534571,288.3674297534571,288.3674297534571,0.0,0.0,0.0,95.5845436865039
2020-01-29T00:00:00.000+0000,-181.23865085919348,-796.2247649591151,581.7123241086699,-181.2386508591935,-181.2386508591935,79.2471297577225,79.2471297577225,79.2471297577225,79.2471297577225,79.2471297577225,79.2471297577225,0.0,0.0,0.0,-101.99152110147098
2020-01-30T00:00:00.000+0000,-169.6944156514338,-855.871561142043,519.3313741321556,-169.6944156514338,-169.6944156514338,26.126513248252024,26.126513248252024,26.126513248252024,26.126513248252024,26.126513248252024,26.126513248252024,0.0,0.0,0.0,-143.5679024031818
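
The output starts on the first historical date because `include_history=True` places the 90 future days after the dates the model was trained on. The pandas sketch below approximates the kind of frame `make_future_dataframe` returns. It is an illustration on a made-up five-day history, not the library's implementation:

```python
import pandas as pd

# toy history of 5 daily dates (made up); Prophet expects a 'ds' column
history = pd.DataFrame({'ds': pd.date_range('2020-07-16', periods=5, freq='D')})
periods = 90

# future dates start the day after the last historical date
last_date = history['ds'].max()
future_dates = pd.date_range(
    start=last_date + pd.Timedelta(days=1), periods=periods, freq='D'
)

# include_history=True: historical dates followed by the future dates
future = pd.concat(
    [history[['ds']], pd.DataFrame({'ds': future_dates})],
    ignore_index=True,
)
print(len(future), future['ds'].min().date(), future['ds'].max().date())
```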


In [36]:
trends_fig = model.plot_components(forecast_pd)
display(trends_fig)

In [37]:
predict_fig = model.plot( forecast_pd, xlabel='date', ylabel='new_cases')

xlim = predict_fig.axes[0].get_xlim()
print(xlim)
# new_xlim = ( xlim[1]-(180.0+365.0), xlim[1]-90.0)
# adjust figure to display roughly the last 4 months of history + the first 30 days of the forecast
new_xlim = ( xlim[1] - 210.0, xlim[1] - 60.0)
predict_fig.axes[0].set_xlim(new_xlim)

display(predict_fig)

In [38]:
predict_fig = model.plot( forecast_pd, xlabel='date', ylabel='new_cases')
display(predict_fig)

In [39]:
from sklearn.metrics import mean_squared_error, mean_absolute_error
from math import sqrt
from datetime import date
import pandas as pd

# get historical actuals & predictions for comparison
# (ds is converted to timestamps so both frames are compared against the same type)
today = pd.Timestamp(date.today())
actuals_pd = history_pd[ pd.to_datetime(history_pd['ds']) < today ]['y']
predicted_pd = forecast_pd[ forecast_pd['ds'] < today ]['yhat']

# calculate evaluation metrics
mae = mean_absolute_error(actuals_pd, predicted_pd)
mse = mean_squared_error(actuals_pd, predicted_pd)
rmse = sqrt(mse)

# print metrics to the screen
print( '\n'.join(['MAE: {0}', 'MSE: {1}', 'RMSE: {2}']).format(mae, mse, rmse) )