In [66]:
import pandas as pd
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType, DateType
from math import log as pylog
from fbprophet import Prophet

In [14]:
playoffs = pd.DataFrame({
  'holiday': 'playoff',
  'ds': pd.to_datetime(['2008-01-13', '2009-01-03', '2010-01-16',
                        '2010-01-24', '2010-02-07', '2011-01-08',
                        '2013-01-12', '2014-01-12', '2014-01-19',
                        '2014-02-02', '2015-01-11', '2016-01-17',
                        '2016-01-24', '2016-02-07']),
  'lower_window': 0,
  'upper_window': 1,
})
superbowls = pd.DataFrame({
  'holiday': 'superbowl',
  'ds': pd.to_datetime(['2010-02-07', '2014-02-02', '2016-02-07']),
  'lower_window': 0,
  'upper_window': 1,
})
holidays = pd.concat((playoffs, superbowls))

In [38]:
wallmart_data = pd.read_csv('wallmart_sales.csv')
wallmart_data.head()

Unnamed: 0,Store,Dept,Date,Weekly_Sales,IsHoliday
0,1,1,2010-02-05,24924.5,False
1,1,1,2010-02-12,46039.49,True
2,1,1,2010-02-19,41595.55,False
3,1,1,2010-02-26,19403.54,False
4,1,1,2010-03-05,21827.9,False


In [8]:
wallmart_1 = wallmart_data[wallmart_data.Store == 1]
wallmart_1.describe()

Unnamed: 0,Store,Dept,Weekly_Sales
count,10244.0,10244.0,10244.0
mean,1.0,44.391742,21710.543621
std,0.0,29.867247,27748.945511
min,1.0,1.0,-863.0
25%,1.0,20.0,3465.6225
50%,1.0,38.0,10289.375
75%,1.0,72.0,31452.9575
max,1.0,99.0,203670.47


In [39]:
wallmart_1['Date'] = pd.to_datetime(wallmart_1['Date'], format='%Y-%m-%d')#.map(lambda x: x.year)


In [43]:
wallmart_1['yearmonth'] = wallmart_1['Date'].map(lambda x: 100*x.year+x.month)

In [45]:
wallmart_1.yearmonth.describe()

count     10244.000000
mean     201102.860894
std          79.126794
min      201002.000000
25%      201010.000000
50%      201106.000000
75%      201202.000000
max      201210.000000
Name: yearmonth, dtype: float64

In [46]:
train = wallmart_1[wallmart_1.yearmonth <= 201202]

In [47]:
train.Date.min()

Timestamp('2010-02-05 00:00:00')

In [48]:
train.Date.max()

Timestamp('2012-02-24 00:00:00')

In [49]:
test = wallmart_1[wallmart_1.yearmonth > 201202]

In [51]:
print('Min date: {}, Max date: {}'.format(test.Date.min(), test.Date.max()))

Min date: 2012-03-02 00:00:00, Max date: 2012-10-26 00:00:00


In [52]:
test.describe()

Unnamed: 0,Store,Dept,Weekly_Sales,year,yearmonth
count,2499.0,2499.0,2499.0,2499.0,2499.0
mean,1.0,44.662265,22281.029448,2012.0,201206.438575
std,0.0,30.165536,28614.397402,0.0,2.288388
min,1.0,1.0,-223.0,2012.0,201203.0
25%,1.0,20.0,3355.86,2012.0,201204.0
50%,1.0,37.0,10284.77,2012.0,201206.0
75%,1.0,72.0,32861.655,2012.0,201208.0
max,1.0,99.0,165039.54,2012.0,201210.0


In [53]:
train.describe()

Unnamed: 0,Store,Dept,Weekly_Sales,year,yearmonth
count,7745.0,7745.0,7745.0,7745.0,7745.0
mean,1.0,44.304454,21526.470789,2010.629826,201069.440542
std,0.0,29.771799,27463.239011,0.618009,60.833349
min,1.0,1.0,-863.0,2010.0,201002.0
25%,1.0,20.0,3516.39,2010.0,201008.0
50%,1.0,38.0,10291.31,2011.0,201102.0
75%,1.0,72.0,31150.62,2011.0,201108.0
max,1.0,99.0,203670.47,2012.0,201202.0


In [54]:
sc = SparkContext()
sqlCtx = SQLContext(sc)

In [55]:
spark_df = sqlCtx.createDataFrame(train)

In [56]:
spark_df.show(3)

+-----+----+-------------------+------------+---------+----+---------+
|Store|Dept|               Date|Weekly_Sales|IsHoliday|year|yearmonth|
+-----+----+-------------------+------------+---------+----+---------+
|    1|   1|2010-02-05 00:00:00|     24924.5|    false|2010|   201002|
|    1|   1|2010-02-12 00:00:00|    46039.49|     true|2010|   201002|
|    1|   1|2010-02-19 00:00:00|    41595.55|    false|2010|   201002|
+-----+----+-------------------+------------+---------+----+---------+
only showing top 3 rows



In [59]:
schema = StructType([
    StructField('Store', IntegerType(), False), 
    StructField('Dept', IntegerType(), False),
    StructField('Date', DateType(), False),
    StructField('Weekly_Sales', DoubleType(), True),
    StructField('pred', DoubleType(), True)
])

In [60]:
# Input/output are both a pandas.DataFrame
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def forecast_sales(wm_df):    
    from fbprophet import Prophet
    
    if wm_df.shape[0] < 2:
        plan_forward = None
    else:
        first_sku = wm_df.sort_values(by = 'Date')
        first_sku = first_sku.filter(['dt','sales'], axis =1)
        first_sku = first_sku.rename(columns={'dt':'ds','sales':'y'})
        #first_sku['y'] = first_sku['y'].apply(pd.to_numeric, errors='coerce', downcast='float')
        first_sku['y'] = first_sku.y.apply(pylog)
        n_data_points = len(first_sku.index)
        potential_change_points = 25
        if(n_data_points < potential_change_points ):
            potential_change_points  = n_data_points
        test_prophet = Prophet(n_changepoints=potential_change_points)
        test_prophet.fit(first_sku)
        future = test_prophet.make_future_dataframe(periods = 30)
        forecast = test_prophet.predict(future)
        forecast['trend'] = pow(2.303,forecast.trend)
        plan_forward = round(forecast['trend'].sum() - first_sku['y'].sum(), 0)
        if plan_forward > 1000000:
            test_prophet = Prophet(n_changepoints=potential_change_points)
            test_prophet.fit(first_sku, algorithm = 'Newton')
            future = test_prophet.make_future_dataframe(periods=30)
            forecast = test_prophet.predict(future)
            forecast['trend'] = pow(2.303, forecast.trend)
            plan_forward = round(forecast['trend'].sum() - first_sku['y'].sum(), 0)
    ret_df = sku_df.assign(pred=plan_forward)
    return ret_df[['sku','pred']]

In [84]:
sample = spark_df.filter("Dept = 1").toPandas()

sample = sample.sort_values(by = 'Date')
sample = sample.filter(['Date','Weekly_Sales'], axis =1)
sample = sample.rename(columns={'Date':'ds','Weekly_Sales':'y'})
sample['y'] = sample.y.apply(pylog)
test_prophet = Prophet()
test_prophet.fit(sample)
future = test_prophet.make_future_dataframe(periods = 10, freq='W-FRI')
forecast = test_prophet.predict(future)
forecast['trend'] = pow(2.303,forecast.trend)
# plan_forward = round(forecast['trend'].sum() - sample['y'].sum(), 0)
# if plan_forward > 1000000:
#     test_prophet = Prophet(n_changepoints=potential_change_points)
#     test_prophet.fit(sample, algorithm = 'Newton')
#     future = test_prophet.make_future_dataframe(periods=30)
#     forecast = test_prophet.predict(future)
#     forecast['trend'] = pow(2.303, forecast.trend)
#     plan_forward = round(forecast['trend'].sum() - sample['y'].sum())
    
# sample = sample.assign(pred=plan_forward)
forecast.tail(10)

INFO:fbprophet.forecaster:Disabling weekly seasonality. Run prophet with weekly_seasonality=True to override this.
INFO:fbprophet.forecaster:Disabling daily seasonality. Run prophet with daily_seasonality=True to override this.


Unnamed: 0,ds,trend,yhat_lower,yhat_upper,trend_lower,trend_upper,additive_terms,additive_terms_lower,additive_terms_upper,yearly,yearly_lower,yearly_upper,multiplicative_terms,multiplicative_terms_lower,multiplicative_terms_upper,yhat
108,2012-03-02,4028.379418,9.704353,10.206602,9.950844,9.950844,-0.005189,-0.005189,-0.005189,-0.005189,-0.005189,-0.005189,0.0,0.0,0.0,9.945655
109,2012-03-09,4027.763349,9.531339,10.013865,9.950651,9.95067,-0.160124,-0.160124,-0.160124,-0.160124,-0.160124,-0.160124,0.0,0.0,0.0,9.790537
110,2012-03-16,4027.147374,9.644925,10.131054,9.950449,9.950505,-0.068515,-0.068515,-0.068515,-0.068515,-0.068515,-0.068515,0.0,0.0,0.0,9.881962
111,2012-03-23,4026.531494,9.874842,10.346969,9.950245,9.950347,0.1616,0.1616,0.1616,0.1616,0.1616,0.1616,0.0,0.0,0.0,10.111894
112,2012-03-30,4025.915707,10.059199,10.528488,9.950033,9.950193,0.343548,0.343548,0.343548,0.343548,0.343548,0.343548,0.0,0.0,0.0,10.293658
113,2012-04-06,4025.300015,10.101811,10.57953,9.949814,9.950041,0.380973,0.380973,0.380973,0.380973,0.380973,0.380973,0.0,0.0,0.0,10.3309
114,2012-04-13,4024.684417,10.007674,10.513075,9.949597,9.949895,0.314641,0.314641,0.314641,0.314641,0.314641,0.314641,0.0,0.0,0.0,10.264385
115,2012-04-20,4024.068913,9.925621,10.416623,9.949377,9.949753,0.226083,0.226083,0.226083,0.226083,0.226083,0.226083,0.0,0.0,0.0,10.175643
116,2012-04-27,4023.453503,9.847831,10.313016,9.949137,9.949609,0.134283,0.134283,0.134283,0.134283,0.134283,0.134283,0.0,0.0,0.0,10.08366
117,2012-05-04,4022.838187,9.709837,10.190669,9.948907,9.949479,0.005192,0.005192,0.005192,0.005192,0.005192,0.005192,0.0,0.0,0.0,9.954386


In [74]:
test.head()

Unnamed: 0,Store,Dept,Date,Weekly_Sales,IsHoliday,year,yearmonth
108,1,1,2012-03-02,20113.03,False,2012,201203
109,1,1,2012-03-09,21140.07,False,2012,201203
110,1,1,2012-03-16,22366.88,False,2012,201203
111,1,1,2012-03-23,22107.7,False,2012,201203
112,1,1,2012-03-30,28952.86,False,2012,201203
