# Understanding Washington State Cannabis Sale Forecasts

## Introduction

In November 2012, Washington State voters approved the legalization of marijuana through [Washington Initiative 502 (I-502)](https://sos.wa.gov/_assets/elections/initiatives/i502.pdf) by a margin of approximately [56 to 44](https://results.vote.wa.gov/results/20121106/Initiative-Measure-No-502-Concerns-marijuana_ByCounty.html). The initiative charged the state's liquor regulator, now the [Washington State Liquor and Cannabis Board (WSLCB)](https://lcb.wa.gov), with licensing and regulating marijuana alongside liquor. As part of its mandate, the WSLCB provides public access to [data](https://data.lcb.wa.gov) on how Washington State's marijuana market is performing.

In this exercise, we will analyze daily sales data from 2015-11-01 to 2017-03-04 in an attempt to answer the following questions:

1. How many previous weeks influence cannabis sales?
1. What days exhibit unusual sales?
1. Do stores in counties that voted to legalize marijuana behave differently than those that didn't?

## Setup

This analysis uses H2O's Sparkling Water to analyze cannabis sales data: Spark handles data management and manipulation, H2O handles the data analysis, and Python serves as the client. This Jupyter notebook assumes that it was launched with a PySparkling context.

### Create H2O Context inside Spark Cluster

When using PySparkling, the first step is to create a Sparkling Water context within the Spark context, `spark`, so data can be passed back and forth between Spark and H2O.

In [None]:
from pysparkling import *
hc = H2OContext.getOrCreate(spark)

## Data Preparation using Spark

The next step is to use Spark to prepare the data for analysis by H2O. The particular data manipulations are:

1. Read daily sales transactions
1. Create daily sales aggregates by store
1. Find unusual days
1. Join county voting results for I-502 with aggregated data
1. Collapse infrequent counties and cities for analysis, i.e. manage high cardinality categorical columns
1. Create lagged predictors
1. Create train / test splits for modeling

### Read Daily Sales Transactions for Different Stores

Reading the data involves a straightforward use of the `spark.read.csv` function with the schema metadata. The data file contains six columns:

|     | Column Name | Description |
| --- | ----------- | ----------- |
|  1  | SalesDate | Date of sale |
|  2  | Organization | Organization that owns the store |
|  3  | County | County of store location |
|  4  | City | City of store location |
|  5  | SalesPrice | Price of line item |
|  6  | Freq | Number of occurrences |

In [None]:
from pyspark.sql.types import *
from pyspark.sql.functions import *

In [None]:
schema = StructType([StructField('SalesDate', DateType(), metadata = {'desc': 'Date of sale'}),
                     StructField('Organization', StringType(), metadata = {'desc': 'Organization that owns the store'}),
                     StructField('County', StringType(), metadata = {'desc': 'County of store location'}),
                     StructField('City', StringType(), metadata = {'desc': 'City of store location'}),
                     StructField('SalesPrice', DoubleType(), metadata = {'desc': 'Price of line item'}),
                     StructField('Freq', IntegerType(), metadata = {'desc': 'Number of occurrences'})
                    ])

# https://s3-us-west-2.amazonaws.com/h2o-tutorials/data/topics/time_series/wa_cannabis/WA_Cannabis_Sales_Daily.csv
raw_sales = spark.read.csv('../../data/time_series/wa_cannabis/WA_Cannabis_Sales_Daily.csv',
                           header = True, schema = schema)

In [None]:
# Describe numeric and string columns
raw_sales.describe().show()

In [None]:
# Additional summaries
raw_sales.select([count('*').alias('nrows'), min('SalesDate'), max('SalesDate'), countDistinct('Organization')]).show()

### Create Daily Sales Aggregates by Store

In order to analyze aggregate sales demand, the transactional sales data are aggregated three ways, each of which uses the `log(x + 1)` function to manage the inherent skewness of sales figures:

1. `Log1pDemandInThou = log1p(sum(store sales)/1000)`
1. `Log1pOtherDemandInThou = log1p(sum(citywide sales)/1000) - log1p(sum(store sales)/1000)`, demand from the rest of the organizations, where the "citywide" total sums the sales of all stores on that date
1. `Log1pNumSales = log1p(sum(I(store sales > 0)))`
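To make the transform concrete, here is a minimal sketch in plain Python rather than Spark. The daily store totals are made-up values chosen only to show how `log1p` handles a zero and compresses a large outlier; `expm1` is the inverse, which is useful for mapping a prediction back to thousands of dollars.

```python
import math

# Hypothetical daily store totals in dollars (illustrative values only)
store_sales = [0.0, 250.0, 1800.0, 42000.0]

# log1p(x) = log(1 + x): defined at zero, and it compresses large values
log1p_demand_in_thou = [math.log1p(s / 1000) for s in store_sales]

# expm1 inverts log1p, recovering the original dollar amounts
recovered = [math.expm1(v) * 1000 for v in log1p_demand_in_thou]

for s, v in zip(store_sales, log1p_demand_in_thou):
    print("{:>10.2f} -> {:.4f}".format(s, v))
```

A plain `log` would fail on the zero-sales day, which is the main practical reason to prefer `log1p` here.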

In [None]:
demand = raw_sales.groupBy('SalesDate', 'Organization', 'County', 'City') \
    .agg(log1p(sum(col('Freq') * col('SalesPrice')) / 1000).alias('Log1pDemandInThou'),
         log1p(sum(when(col('SalesPrice') > 0, col('Freq')).otherwise(0))).alias('Log1pNumSales')) \
    .alias('demand')
print("Number of Organization-Days: ", demand.count())

In [None]:
demand.describe().show()

In [None]:
daily_demand = raw_sales.groupBy('SalesDate') \
    .agg(log1p(sum(col('Freq') * col('SalesPrice')) / 1000).alias('Log1pCitywideDemandInThou')) \
    .alias('daily_demand')
print("Number of Days: ", daily_demand.count())

In [None]:
demand = demand.join(daily_demand, demand.SalesDate == daily_demand.SalesDate, how = "left_outer") \
    .select('demand.*', 'daily_demand.Log1pCitywideDemandInThou')
print("Number of Organization-Days: ", demand.count())

In [None]:
demand = demand.select('SalesDate', 'Organization', 'County', 'City', 'Log1pDemandInThou',
                       (col('Log1pCitywideDemandInThou') - col('Log1pDemandInThou')).alias('Log1pOtherDemandInThou'),
                       'Log1pNumSales') \
         .alias('demand')

In [None]:
demand.describe(['Log1pDemandInThou', 'Log1pOtherDemandInThou', 'Log1pNumSales']).show()

### Find Unusual Days

Fifteen unusual days are discovered by examining the week-over-week ratios in `Log1pCitywideDemandInThou`. Not surprisingly, these days are at or around [420](https://en.wikipedia.org/wiki/420_%28cannabis_culture%29) and the holidays Fourth of July, Thanksgiving, Christmas, and New Year's.
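The week-over-week idea can be sketched without Spark. The example below uses a made-up three-week series with one artificial spike; the function name `week_over_week` and the data are hypothetical, and the logic mirrors the ratio-to-seven-days-earlier calculation used in the cell that follows.

```python
# Hypothetical log-scale daily demand: the same weekly pattern repeated three times
series = [5.0, 5.1, 5.0, 5.2, 5.6, 5.8, 5.3] * 3
series[17] = 6.5  # artificial spike on day index 17

def week_over_week(values, lag=7):
    """Return (index, ratio, absolute deviation from 1) for each day with a value `lag` days earlier."""
    rows = []
    for i in range(lag, len(values)):
        ratio = values[i] / values[i - lag]
        rows.append((i, round(ratio, 4), round(abs(ratio - 1), 4)))
    return rows

# Rank days by how far their ratio departs from 1, largest first
ranked = sorted(week_over_week(series), key=lambda r: r[2], reverse=True)
print(ranked[:3])
```

Days whose ratio sits far from 1 are the candidates for "unusual". Note that a spike also distorts the ratio one week later, which is consistent with dates such as 2016-04-27 appearing alongside 2016-04-20.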

In [None]:
from pyspark.sql.window import Window

w = Window().orderBy(col('SalesDate'))
plot_data = \
  daily_demand.select('SalesDate', 'Log1pCitywideDemandInThou',
                      # The lag offset is passed positionally: its keyword is `count` in Spark 2.x but `offset` in Spark 3.x
                      round(col('Log1pCitywideDemandInThou') / lag('Log1pCitywideDemandInThou', 7).over(w), 4).alias('WoW'),
                      round(abs(col('Log1pCitywideDemandInThou') / lag('Log1pCitywideDemandInThou', 7).over(w) - 1), 4).alias('AbsWoWDiff')) \
                      .orderBy('AbsWoWDiff', ascending = False).toPandas()

In [None]:
%matplotlib inline
plot_data['WoW'].plot.hist(bins = 50)

In [None]:
plot_data.head(20)

In [None]:
import datetime

%matplotlib inline
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize = (20, 10))
for dt in ['2015-11-25', '2015-12-02', '2015-12-23', '2015-12-24', '2015-12-25', '2015-12-26', '2016-01-01',
           '2016-01-07', '2016-04-20', '2016-04-27', '2016-07-01',
           '2016-11-23', '2016-11-30',                             '2016-12-25',               '2017-01-01']:
    plt.axvline(x = datetime.datetime.strptime(dt, '%Y-%m-%d'), color = 'orange', linestyle='--')
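# plot_data is ordered by AbsWoWDiff, so restore date order before drawing the time series
plot_data.sort_values('SalesDate').plot(x = 'SalesDate', y = 'Log1pCitywideDemandInThou', ax = ax)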
plt.show()

In [None]:
demand = demand.select('*',
                       when(col('SalesDate') == '2016-04-20', 'FourTwenty')
            .otherwise(when(col('SalesDate') == '2016-04-27', 'FourTwentySeven')
            .otherwise(when(col('SalesDate') == '2016-07-01', 'PreJuly4th')
            .otherwise(when(col('SalesDate') == '2015-11-25', 'ThanksgivingMinusOne')
            .otherwise(when(col('SalesDate') == '2016-11-23', 'ThanksgivingMinusOne')
            .otherwise(when(col('SalesDate') == '2015-12-02', 'ThanksgivingPlusSix')
            .otherwise(when(col('SalesDate') == '2016-11-30', 'ThanksgivingPlusSix')
            .otherwise(when(col('SalesDate') == '2015-12-23', 'ChristmasMinusTwo')
            .otherwise(when(col('SalesDate') == '2015-12-24', 'ChristmasMinusOne')
            .otherwise(when(col('SalesDate') == '2015-12-25', 'Christmas')
            .otherwise(when(col('SalesDate') == '2016-12-25', 'Christmas')
            .otherwise(when(col('SalesDate') == '2015-12-26', 'ChristmasPlusOne')
            .otherwise(when(col('SalesDate') == '2016-01-01', 'NewYearsDay')
            .otherwise(when(col('SalesDate') == '2017-01-01', 'NewYearsDay')
            .otherwise(when(col('SalesDate') == '2016-01-07', 'NewYearsDayPlusSix')
            .otherwise('N/A'))))))))))))))).alias('DayOfInterest')).alias('demand')

In [None]:
demand.groupBy('DayOfInterest').count().sort('count', ascending = False).show()

### Join Sales Data with Washington Initiative 502 Vote to Legalize Cannabis

Generally speaking, counties in Western Washington voted to legalize cannabis sales, while those in Eastern Washington did not.

In [None]:
schema = StructType([StructField('County', StringType(), metadata = {'desc': 'County'}),
                     StructField('LegalizationVote', DoubleType(), metadata = {'desc': 'Fraction voting to legalize'})
                    ])

# https://s3-us-west-2.amazonaws.com/h2o-tutorials/data/topics/time_series/wa_cannabis/Initiative-Measure-No-502-Concerns-marijuana_ByCounty.csv
legalization = spark.read.csv('../../data/topics/time_series/wa_cannabis/Initiative-Measure-No-502-Concerns-marijuana_ByCounty.csv',
                              header = True, schema = schema).alias('legalization')

In [None]:
legalization.describe().show()

In [None]:
legalization.sort('LegalizationVote', ascending = False).show(5)

In [None]:
legalization.sort('LegalizationVote').show(5)

In [None]:
demand = demand.join(legalization, demand.County == legalization.County, how = "left_outer") \
    .select('demand.*', 'legalization.LegalizationVote')

In [None]:
demand.printSchema()

### Collapse Infrequent Counties and Cities

In order to correct for high cardinality county and city features, the infrequent locations are collapsed into an `OTHER` category.
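The notebook hard-codes which levels to keep after inspecting the frequency tables. As an alternative illustration only, the collapse can be driven by counts. The sketch below is plain Python with a hypothetical helper `collapse_infrequent` and made-up store locations; it is not the notebook's method.

```python
from collections import Counter

def collapse_infrequent(values, keep=3, other="OTHER"):
    """Keep the `keep` most frequent levels and map every other level to `other`."""
    top = {level for level, _ in Counter(values).most_common(keep)}
    return [v if v in top else other for v in values]

# Hypothetical store locations, one entry per store
cities = ["SEATTLE"] * 5 + ["SPOKANE"] * 3 + ["TACOMA"] * 2 + ["YAKIMA", "PULLMAN"]
collapsed = collapse_infrequent(cities, keep=3)
print(Counter(collapsed))
```

Hard-coding the levels, as the cells below do, has the advantage that the mapping stays fixed when new data arrive; a count-driven rule can shift as frequencies change.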

In [None]:
demand.select('Organization', 'County').distinct().groupBy('County') \
    .agg(count('*').alias('Freq')).orderBy('Freq', ascending = False).show()

In [None]:
demand = demand.withColumn('County',
                           when(col('County') == 'KING', 'KING')
                .otherwise(when(col('County') == 'SPOKANE', 'SPOKANE')
                .otherwise(when(col('County') == 'SNOHOMISH', 'SNOHOMISH')
                .otherwise(when(col('County') == 'PIERCE', 'PIERCE')
                .otherwise(when(col('County') == 'KITSAP', 'KITSAP')
                .otherwise(when(col('County') == 'THURSTON', 'THURSTON')
                .otherwise(when(col('County') == 'WHATCOM', 'WHATCOM')
                .otherwise(when(col('County') == 'CLARK', 'CLARK')
                .otherwise('OTHER')))))))))

In [None]:
demand.select('Organization', 'County').distinct().groupBy('County') \
    .agg(count('*').alias('Freq')).orderBy('Freq', ascending = False).show()

In [None]:
demand.select('Organization', 'City').distinct().groupBy('City') \
    .agg(count('*').alias('Freq')).orderBy('Freq', ascending = False).show()

In [None]:
demand = demand.withColumn('City',
                           when(col('City') == 'SEATTLE', 'SEATTLE')
                .otherwise(when(col('City') == 'SPOKANE', 'SPOKANE')
                .otherwise(when(col('City') == 'TACOMA', 'TACOMA')
                .otherwise('OTHER'))))

In [None]:
demand.select('Organization', 'City').distinct().groupBy('City') \
    .agg(count('*').alias('Freq')).orderBy('Freq', ascending = False).show()

### Create Lagged Predictors

Up to five weeks of lagged sales (lags of 7, 14, 21, 28, and 35 days) will be considered as features in the models.
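A window `lag` looks back a fixed number of rows within each partition, not a fixed number of calendar days. The plain-Python sketch below, using a hypothetical `add_lags` helper and made-up values for one store, shows the mechanics and where missing values appear at the start of the history.

```python
def add_lags(values, lags=(7, 14)):
    """For each position, look back `lag` rows; None when there is not enough history."""
    rows = []
    for i, v in enumerate(values):
        row = {"y": v}
        for lag in lags:
            row["y_L{}".format(lag)] = values[i - lag] if i >= lag else None
        rows.append(row)
    return rows

# Hypothetical daily values for one store, already sorted by date (20 consecutive days)
daily = list(range(100, 120))
rows = add_lags(daily)
print(rows[6])   # not enough history for either lag
print(rows[14])  # both lags available
```

Because the look-back counts rows, a lag of 7 equals "same weekday last week" only if the store has exactly one row for every day; stores with missing days would need their gaps filled first for that interpretation to hold.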

In [None]:
from pyspark.sql.window import Window

w = Window().partitionBy([col(x) for x in ['Organization']]).orderBy(col('SalesDate'))
demand = demand.select('SalesDate', 'DayOfInterest', 'Organization', 'County', 'City', 'LegalizationVote',
                       'Log1pDemandInThou',
                       # Lag offsets are passed positionally: the keyword is `count` in Spark 2.x but `offset` in Spark 3.x
                       lag('Log1pDemandInThou', 7).over(w).alias('Log1pDemandInThou_L7'),
                       lag('Log1pDemandInThou', 14).over(w).alias('Log1pDemandInThou_L14'),
                       lag('Log1pDemandInThou', 21).over(w).alias('Log1pDemandInThou_L21'),
                       lag('Log1pDemandInThou', 28).over(w).alias('Log1pDemandInThou_L28'),
                       lag('Log1pDemandInThou', 35).over(w).alias('Log1pDemandInThou_L35'),
                       lag('Log1pOtherDemandInThou', 7).over(w).alias('Log1pOtherDemandInThou_L7'),
                       lag('Log1pNumSales', 7).over(w).alias('Log1pNumSales_L7'))

In [None]:
demand.printSchema()

### Create Train / Test Splits for Modeling

Given the time series nature of this exercise, the train and test splits are based on time, where everything up to 2017-02-25 is in the training set and everything from 2017-02-26 onwards in the test set.
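As a small illustration of the split, the sketch below applies the same cutoff to a hypothetical run of dates at the end of the sample using plain Python. It shows that the test period is the final week of data and that no test date precedes a training date.

```python
from datetime import date, timedelta

cutoff = date(2017, 2, 25)

# Hypothetical slice of the calendar: 2017-02-20 through 2017-03-04
days = [date(2017, 2, 20) + timedelta(days=i) for i in range(13)]

train_days = [d for d in days if d <= cutoff]  # up to and including the cutoff
test_days = [d for d in days if d > cutoff]    # strictly after the cutoff

print(len(train_days), "training days;", len(test_days), "test days")
print("test period:", min(test_days), "to", max(test_days))
```

Splitting by time rather than at random keeps future observations out of the training set, which matters here because the predictors are lagged values of the response.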

In [None]:
train = demand.filter(demand.SalesDate <= '2017-02-25')
test = demand.filter(demand.SalesDate >= '2017-02-26')

In [None]:
train.select(countDistinct('SalesDate').alias('Number of Dates in Training Data Set')).show()
test.select(countDistinct('SalesDate').alias('Number of Dates in Testing Data Set')).show()

## Analyze Data in H2O

The steps for analyzing the data in H2O are as follows:

1. Copy data from Spark to H2O
1. Segment organizations into folds for cross-validation
1. Run automatic machine learning to experiment with generalized linear models, random forests, extremely randomized trees, and gradient boosting machines
1. Answer questions using the leading model

### Copy Data from Spark to H2O

First, copy the training and test data into the H2O context.

In [None]:
import h2o
train_hf = hc.as_h2o_frame(train, "train")
test_hf = hc.as_h2o_frame(test, "test")

In [None]:
for j in ['Organization', 'County', 'City', 'DayOfInterest']:
    train_hf[j] = train_hf[j].asfactor()
    test_hf[j] = test_hf[j].asfactor()

### Segment Organizations into Folds for Cross-Validation

Then segment the training data into folds for cross-validation using the organization.
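The point of folding by organization is that all rows from one store land in the same fold, so a model is never validated on a store it saw during training. A minimal plain-Python sketch of that idea follows; the helper `assign_group_folds` and the organization labels are hypothetical, and it uses the standard library's `random` rather than NumPy.

```python
import random

def assign_group_folds(groups, n_folds=5, seed=2307):
    """Draw one fold per distinct group, then broadcast it to every row of that group."""
    rng = random.Random(seed)
    # random.randint includes both endpoints, so folds range over 1..n_folds
    fold_of = {g: rng.randint(1, n_folds) for g in sorted(set(groups))}
    return [fold_of[g] for g in groups]

# Hypothetical organization label for each row of training data
orgs = ["A", "A", "B", "C", "B", "A", "D", "C"]
folds = assign_group_folds(orgs)
print(list(zip(orgs, folds)))
```

Random assignment does not guarantee equally sized folds, which is why the notebook prints the number of organizations per fold after drawing them.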

In [None]:
organizations = train_hf['Organization'].unique().sort(0).as_data_frame()
organizations = organizations.rename(columns = {'C1': 'Organization'})

In [None]:
import pandas as pd
import numpy as np
np.random.seed(2307)
# One fold label in 1..5 per organization (the upper bound of randint is exclusive)
organizations = organizations.assign(Fold = np.random.randint(1, 6, size = len(organizations)))

In [None]:
print(organizations.groupby(['Fold']).count())

In [None]:
organizations_hf = h2o.H2OFrame(organizations, 'Organization')
organizations_hf['Organization'] = organizations_hf['Organization'].asfactor()

In [None]:
train_hf = train_hf.merge(organizations_hf, all_x = True, all_y = False)

In [None]:
train_hf.describe()

### Run Automatic Machine Learning

With the training and test data in place, we use H2O's Automatic Machine Learning to explore models of daily sales based on generalized linear models, random forests, extremely randomized trees, and gradient boosting machines.

In [None]:
# Set Predictors
predictors = ['DayOfInterest', 'County', 'City', 'LegalizationVote',
              'Log1pDemandInThou_L7', 'Log1pDemandInThou_L14', 'Log1pDemandInThou_L21',
              'Log1pDemandInThou_L28', 'Log1pDemandInThou_L35',
              'Log1pOtherDemandInThou_L7', 'Log1pNumSales_L7']
response = 'Log1pDemandInThou'

In [None]:
from h2o.automl import H2OAutoML
aml = H2OAutoML(max_models = 6, exclude_algos = ['DeepLearning'])
aml.train(x = predictors, y = response,
          training_frame = train_hf,
          leaderboard_frame = test_hf,
          fold_column = 'Fold')

In [None]:
print(aml.leaderboard)

### Answer Questions using Leading Model

We can now return to the questions that motivated this analysis:

1. How many previous weeks influence cannabis sales?
1. What days exhibit unusual sales?
1. Do stores in counties that voted to legalize marijuana behave differently than those that didn't?

In [None]:
best_model = h2o.get_model(aml.leaderboard[0,'model_id'])

In [None]:
print("R^2: train = {:.4f}, valid = {:.4f}, xval = {:.4f}" \
      .format(best_model.r2(train = True), best_model.r2(valid = True), best_model.r2(xval = True)))

In [None]:
best_model.varimp_plot()

#### Examining demand lags

The variable importance plot of the leading model as well as the partial dependency plots show that five weeks' worth of sales should be sufficient for forecasting cannabis sales.
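For readers unfamiliar with partial dependence, the following plain-Python sketch shows the calculation on a toy problem: force one feature to each grid value in every row, predict, and average. The `partial_dependence` helper, the `toy_model` function, and the feature names are all hypothetical stand-ins, not H2O's implementation.

```python
def partial_dependence(predict, rows, feature, grid):
    """Average prediction over all rows with `feature` forced to each grid value."""
    result = []
    for value in grid:
        preds = [predict(dict(row, **{feature: value})) for row in rows]
        result.append((value, sum(preds) / len(preds)))
    return result

# Hypothetical stand-in model: strong dependence on the 7-day lag, weak on the 14-day lag
def toy_model(row):
    return 0.8 * row["lag7"] + 0.1 * row["lag14"]

rows = [{"lag7": 1.0, "lag14": 2.0}, {"lag7": 3.0, "lag14": 4.0}]
pd_lag7 = partial_dependence(toy_model, rows, "lag7", [0.0, 1.0, 2.0])
for value, mean_response in pd_lag7:
    print("lag7 = {:.1f} -> mean response {:.2f}".format(value, mean_response))
```

A steep curve indicates a feature the model leans on heavily, while a nearly flat one suggests the feature adds little; this is the reading applied to the lag plots below.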

In [None]:
pdp_demand = best_model.partial_plot(data = train_hf,
                                     cols = ['Log1pDemandInThou_L7', 'Log1pDemandInThou_L14',
                                             'Log1pDemandInThou_L21', 'Log1pDemandInThou_L28',
                                             'Log1pDemandInThou_L35'])

In [None]:
pdp_other_lagged = best_model.partial_plot(data = train_hf,
                                           cols = ['Log1pOtherDemandInThou_L7', 'Log1pNumSales_L7'])

In [None]:
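# Row 7 is assumed to hold the GLM in this run; check the printed leaderboard and adjust the index if needed
glm = h2o.get_model(aml.leaderboard[7, 'model_id'])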

In [None]:
print("R^2: train = {:.4f}, valid = {:.4f}, xval = {:.4f}" \
      .format(glm.r2(train = True), glm.r2(valid = True), glm.r2(xval = True)))

In [None]:
for j in ['Log1pDemandInThou_L7', 'Log1pDemandInThou_L14', 'Log1pDemandInThou_L21', 'Log1pDemandInThou_L28',
          'Log1pDemandInThou_L35']:
    print(j + ": {:.4f}".format(glm.coef()[j]))

#### Effects of unusual days

As the time series plot suggested, the effects of unusual days, such as 420 and Christmas, are factored into the trained model.

In [None]:
pdp_doi = best_model.partial_plot(data = train_hf, cols = ['DayOfInterest'], plot = False)[0].as_data_frame()

In [None]:
pdp_doi.sort_values('mean_response', ascending = False)

#### Effect of legalization vote

Not too surprisingly, stores in different counties tend to behave similarly despite the differences in their voters' desire to legalize cannabis.

In [None]:
pdp_legalization = best_model.partial_plot(data = train_hf, cols = ['LegalizationVote'])

In [None]:
pdp_cats = best_model.partial_plot(data = train_hf, cols = ['County', 'City'], plot = False)
pdp_county = pdp_cats[0].as_data_frame()
pdp_city = pdp_cats[1].as_data_frame()

In [None]:
pdp_county.sort_values('mean_response', ascending = False)

In [None]:
pdp_city.sort_values('mean_response', ascending = False)

## Shutdown Sparkling Water Services

The last step in this script is to be a good cloud citizen and shut down the H2O cluster.

In [None]:
h2o.cluster().shutdown()