Implementation of ARIMA forecasting using the `ibex_ARIMA` class in `src.arima`

In [1]:
# Standard Imports
import importlib
import os
import numpy as np
import pandas as pd
import sys

# visualisation
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go

# Src modules
sys.path.append(os.path.dirname(os.getcwd())) # Add the parent directory to the Python path so we can import src modules
import src
from src.arima import arima_trial, ibex_ARIMA
from src.data_setup import get_data, train_val_split, create_day_of_week, RAW_PATH, RESULTS_PATH
from src.model_evaluation import model_eval_pipeline, eval_hypothesis_test, transform_daily_sales_predictions
from src.visualisation import plot_heatmap, plot_sales_by, plot_time_series_preds, plot_rolling_average_stdev

In [2]:
# Load the project dataframes through the project's `get_data` helper.
train, test, stores, transactions = get_data()
# Attach the store-level columns to every training row, joining on the shared
# `store_nbr` key (DataFrame.merge performs an inner join by default).
train = train.merge(stores, on='store_nbr')
# Sanity checks on the merged frame; each assert names the problem if it fires.
assert not train.isnull().any().any(), 'training data contains null values'
assert not train.duplicated().any(), 'training data contains duplicate rows'
print('no null or duplicate values in the training data')

loading pickled dataframes...
no null or duplicate values in the training data


In [3]:
# Partition the merged frame into a training part and a held-out validation part.
# NOTE(review): `train` is rebound here, so re-running this cell on its own splits
# the already-split frame again -- re-run from the data-loading cell instead.
train, validation = train_val_split(train)
# The per-date aggregation of `validation` is done later, at the point where the
# ARIMA model is evaluated, so the row-level frame is kept intact here.
# Show (rows, columns) for both partitions.
tuple(frame.shape for frame in (train, validation))

((2949210, 25), (51678, 25))

In [4]:
# Re-execute src.arima so edits made to the module on disk are picked up without
# restarting the kernel, then rebind the class from the freshly reloaded module.
ibex_ARIMA = importlib.reload(src.arima).ibex_ARIMA

In [5]:
# Build the ARIMA wrapper from the (store/family-level) training frame.
arima = ibex_ARIMA(train)
# Optional diagnostics, left disabled.
# NOTE(review): assumes ibex_ARIMA exposes these two methods -- TODO confirm in src.arima.
# arima.plot_autocorrelation()
# arima.test_stationarity()

In [6]:
# Orders handed to fit(). Presumably the non-seasonal (p, d, q) and seasonal
# (P, D, Q, s) ARIMA orders, with s=7 for a weekly cycle -- TODO confirm in src.arima.
order = (1, 1, 1)
seasonal_order = (1, 1, 1, 7)
arima.fit(order, seasonal_order, plot_summary=False)

In [7]:
# Collapse the row-level validation frame into one total-sales row per date,
# then hand those daily totals to the fitted model's evaluate().
daily_validation_sales = validation.groupby('date')['sales'].sum().reset_index()
val = arima.evaluate(validation=daily_validation_sales)

In [8]:
# Display the frame returned by arima.evaluate: actual daily `sales` next to the
# model's `pred_sales` (see the rendered table for the full set of columns).
val

Unnamed: 0,date,sales,day_of_week,pred_diff_sales,pred_sales
0,2017-07-18,730133.7,2,-79304.309105,739021.2
1,2017-07-19,767978.8,3,31470.695726,770491.9
2,2017-07-20,688288.1,4,-133149.956091,637341.9
3,2017-07-21,782418.3,5,146184.255046,783526.2
4,2017-07-22,932902.1,6,236679.372029,1020206.0
5,2017-07-23,1024289.0,7,96954.549813,1117160.0
6,2017-07-24,816564.3,1,-311570.117048,805590.0
7,2017-07-25,713581.6,2,-70624.491944,734965.5
8,2017-07-26,740653.1,3,28274.542886,763240.0
9,2017-07-27,659849.8,4,-131729.048438,631511.0


In [23]:
# Remove the intermediate `pred_diff_sales` column, keeping `pred_sales`.
# errors='ignore' keeps this cell re-runnable: without it, a second execution
# raises KeyError because the column has already been dropped from `val`.
val = val.drop(columns=['pred_diff_sales'], errors='ignore')
val


Unnamed: 0,date,sales,day_of_week,pred_sales
0,2017-07-18,730133.7,2,739021.2
1,2017-07-19,767978.8,3,770491.9
2,2017-07-20,688288.1,4,637341.9
3,2017-07-21,782418.3,5,783526.2
4,2017-07-22,932902.1,6,1020206.0
5,2017-07-23,1024289.0,7,1117160.0
6,2017-07-24,816564.3,1,805590.0
7,2017-07-25,713581.6,2,734965.5
8,2017-07-26,740653.1,3,763240.0
9,2017-07-27,659849.8,4,631511.0


In [25]:
# Preview the first rows of the training frame to recall its columns
# (date, store_nbr, family, sales, calendar features and merged store attributes).
train.head()

Unnamed: 0,date,store_nbr,family,sales,onpromotion,year,month,week,day_of_week,day_of_month,...,is_month_end,is_quarter_start,is_quarter_end,is_year_start,is_year_end,season,city,state,type,cluster
0,2013-01-01,1,AUTOMOTIVE,0.0,0.0,2013,1,1,2,1,...,0,1,0,1,0,0,Quito,Pichincha,D,13
1,2013-01-01,1,BABY CARE,0.0,0.0,2013,1,1,2,1,...,0,1,0,1,0,0,Quito,Pichincha,D,13
2,2013-01-01,1,BEAUTY,0.0,0.0,2013,1,1,2,1,...,0,1,0,1,0,0,Quito,Pichincha,D,13
3,2013-01-01,1,BEVERAGES,0.0,0.0,2013,1,1,2,1,...,0,1,0,1,0,0,Quito,Pichincha,D,13
4,2013-01-01,1,BOOKS,0.0,0.0,2013,1,1,2,1,...,0,1,0,1,0,0,Quito,Pichincha,D,13


In [26]:
# Reload src.model_evaluation so on-disk edits take effect without a kernel
# restart, and rebind the helper from the freshly reloaded module.
transform_daily_sales_predictions = importlib.reload(
    src.model_evaluation
).transform_daily_sales_predictions

# Expand the daily predictions using the training frame; the rendered result has
# one row per date/store_nbr/family with `pct_sales` and `transformed_sales` columns.
# NOTE(review): `val` is overwritten, so re-running this cell alone feeds the
# already-transformed frame back into the helper.
val = transform_daily_sales_predictions(val, train)
val.head()

Unnamed: 0,date,sales,day_of_week,pred_sales,store_nbr,family,pct_sales,transformed_sales
0,2017-07-18,730133.6875,2,739021.190895,1,AUTOMOTIVE,6.798028e-06,5.023887
1,2017-07-18,730133.6875,2,739021.190895,1,BABY CARE,0.0,0.0
2,2017-07-18,730133.6875,2,739021.190895,1,BEAUTY,4.529532e-06,3.34742
3,2017-07-18,730133.6875,2,739021.190895,1,BEVERAGES,0.002892669,2137.74392
4,2017-07-18,730133.6875,2,739021.190895,1,BOOKS,2.611756e-07,0.193014


In [21]:
# Score the store/family-level predictions against the actual validation sales.
# Uses `val['transformed_sales']`, the column produced by
# transform_daily_sales_predictions above; the previous reference to
# `a['transformed_pred_sales']` pointed at a variable that is never defined in
# this notebook and failed with NameError on a fresh Restart & Run All.
# NOTE(review): assumes the rows of `val` line up one-to-one with `validation`
# (same date/store/family order) -- confirm before trusting the metrics, and
# re-run this cell to refresh the displayed output.
model_eval_pipeline(validation['sales'], val['transformed_sales'])

{'mae': 336.5845219207806,
 'mse': 1118342.995234885,
 'rmse': 1057.5173734908024,
 'rmsle': 1.1717345176616767,
 'r2': 0.2998583776179432}