# SARIMA(p,d,q)(P,D,Q)m model

In [None]:
from statsmodels.tsa.seasonal import STL
from statsmodels.stats.diagnostic import acorr_ljungbox
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tsa.stattools import adfuller
from tqdm import tqdm_notebook
from itertools import product
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

In [None]:
# Default size (inches) for every figure created below
plt.rcParams.update({"figure.figsize": (9, 6)})

## Exploring seasonality 

In [None]:
# source: https://raw.githubusercontent.com/plotly/datasets/master/monthly-milk-production-pounds.csv



In [None]:
fig, ax = plt.subplots()

ax.plot(df['Month'], df['Milk'])
ax.set(xlabel='Date', ylabel='Milk production (lbs/cow)')

# Label every 12th position with a year (1962, 1963, ...). Assumes the series is
# drawn at positions 0, 1, 2, ... (e.g. string-valued 'Month') — TODO confirm
ax.set_xticks(np.arange(0, 179, 12))
ax.set_xticklabels(np.arange(1962, 1977, 1))

fig.autofmt_xdate()
fig.tight_layout()

In [None]:
fig, ax = plt.subplots()

# Put a marker on every 12th point starting at position 4, i.e. the same
# position within each 12-observation cycle
marked_points = np.arange(4, 169, 12)
ax.plot(df['Month'], df['Milk'], markevery=marked_points, marker='o')
ax.set(xlabel='Date', ylabel='Milk production (lbs/cow)')

ax.set_xticks(np.arange(0, 179, 12))
ax.set_xticklabels(np.arange(1962, 1977, 1))

fig.autofmt_xdate()
fig.tight_layout()

In [None]:
fig, ax = plt.subplots()

ax.plot(df['Month'], df['Milk'])

# Dashed vertical line every 12 positions (0, 12, ..., 168) to delimit each cycle
for cycle_start in np.arange(0, 169, 12):
    ax.axvline(x=cycle_start, linestyle='--', color='black', linewidth=1)

ax.set(xlabel='Date', ylabel='Milk production (lbs/cow)')

ax.set_xticks(np.arange(0, 179, 12))
ax.set_xticklabels(np.arange(1962, 1977, 1))

fig.autofmt_xdate()
fig.tight_layout()

## Decomposition 

In [None]:
# Decompose the series into trend, seasonal and residual components with STL.
# period=12: one seasonal cycle every 12 observations (the yearly ticks above).
decomposition = STL(df['Milk'], period=12).fit()

# Plot each component on its own row, sharing the x axis
fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=4, ncols=1, sharex=True, figsize=(10,8))

ax1.plot(decomposition.observed)
ax1.set_ylabel('Observed')

ax2.plot(decomposition.trend)
ax2.set_ylabel('Trend')

ax3.plot(decomposition.seasonal)
ax3.set_ylabel('Seasonal')

ax4.plot(decomposition.resid)
ax4.set_ylabel('Residuals')

plt.xticks(np.arange(0, 179, 12), np.arange(1962, 1977, 1))

fig.autofmt_xdate()
plt.tight_layout()

## Forecasting with SARIMA 

In [None]:
# Run the ADF test on the raw series (null hypothesis: a unit root is present)
ad_fuller_result = adfuller(df['Milk'])

print(f'ADF Statistic: {ad_fuller_result[0]}')
print(f'p-value: {ad_fuller_result[1]}')

### Differencing and stationarity

In [None]:
# Take the first-order difference (d = 1) and run the ADF test on it
df_diff = np.diff(df['Milk'], n=1)

ad_fuller_result = adfuller(df_diff)

print(f'ADF Statistic: {ad_fuller_result[0]}')
print(f'p-value: {ad_fuller_result[1]}')

In [None]:
# Seasonal difference (lag 12, D = 1) of the first-differenced series, then run the ADF test.
# Note: np.diff(x, n=12) would be a 12th-order difference, not a seasonal one.
df_diff = np.diff(df['Milk'], n=1)
df_diff_seasonal_diff = df_diff[12:] - df_diff[:-12]

ad_fuller_result = adfuller(df_diff_seasonal_diff)

print(f'ADF Statistic: {ad_fuller_result[0]}')
print(f'p-value: {ad_fuller_result[1]}')

### Define test set 

In [None]:
fig, ax = plt.subplots()

ax.plot(df['Month'], df['Milk'])
ax.set(xlabel='Date', ylabel='Milk production (lbs/cow)')

# Shade positions 120-167, which are the last 48 observations (the test set)
# assuming the series has 168 points — TODO confirm against len(df)
ax.axvspan(120, 167, color='#808080', alpha=0.2)

ax.set_xticks(np.arange(0, 179, 12))
ax.set_xticklabels(np.arange(1962, 1977, 1))

fig.autofmt_xdate()
fig.tight_layout()

In [None]:
# Split into a training and test set. Keep the last 48 data points for the test set



### Model selection with AIC 

In [None]:
def SARIMA_gridsearch(endog, min_p, max_p, min_q, max_q, min_P, max_P, min_Q, max_Q, d, D, s):
    """Fit SARIMA(p,d,q)(P,D,Q)s for every (p, q, P, Q) in the given ranges and rank by AIC.

    Parameters
    ----------
    endog : array-like
        Series the models are fit on.
    min_p, max_p, min_q, max_q : int
        Inclusive ranges of the non-seasonal AR (p) and MA (q) orders.
    min_P, max_P, min_Q, max_Q : int
        Inclusive ranges of the seasonal AR (P) and MA (Q) orders.
    d, D : int
        Non-seasonal and seasonal differencing orders (fixed for the search).
    s : int
        Seasonal period (fixed for the search).

    Returns
    -------
    pandas.DataFrame
        Columns '(p,q,P,Q)' and 'AIC', sorted by ascending AIC (lower is
        better), re-indexed 0..n-1. Combinations that fail to fit are skipped,
        so the frame is empty if none could be fit.
    """
    all_p = range(min_p, max_p+1, 1)
    all_q = range(min_q, max_q+1, 1)
    all_P = range(min_P, max_P+1, 1)
    all_Q = range(min_Q, max_Q+1, 1)

    # Every (p, q, P, Q) combination; d, D and s stay fixed
    all_orders = list(product(all_p, all_q, all_P, all_Q))

    print(f'Fitting {len(all_orders)} unique models')

    results = []

    for order in tqdm_notebook(all_orders):
        p, q, P, Q = order
        try:
            model = SARIMAX(
                endog,
                order=(p, d, q),
                seasonal_order=(P, D, Q, s),
                simple_differencing=False,
            ).fit(disp=False)
        except Exception:
            # Best-effort search: skip combinations statsmodels cannot fit
            continue

        results.append([order, model.aic])

    # Passing the columns up front also works when no model could be fit
    result_df = pd.DataFrame(results, columns=['(p,q,P,Q)', 'AIC'])

    # Sort in ascending order, lower AIC is better
    result_df = result_df.sort_values(by='AIC', ascending=True).reset_index(drop=True)

    return result_df

In [None]:
# Search ranges for the non-seasonal (p, q) and seasonal (P, Q) orders
min_p = 0
max_p = 2
min_q = 0
max_q = 4

min_P = 0
max_P = 2
min_Q = 0
max_Q = 2

# One regular and one seasonal difference, seasonal period s = 12
d = 1
D = 1
s = 12

# Pass min_Q (not min_q) as the lower bound of the seasonal MA order
result_df = SARIMA_gridsearch(train['Milk'], min_p, max_p, min_q, max_q, min_P, max_P, min_Q, max_Q, d, D, s)
result_df.head()

### Residuals analysis 

In [None]:
def ljung_box_test(residuals, is_seasonal, period):
    """Run the Ljung-Box test on residuals, plot its p-values and print a verdict.

    The verdict is "uncorrelated" only when every tested lag has a p-value
    above 0.05.

    Parameters
    ----------
    residuals : array-like
        Residuals of a fitted model.
    is_seasonal : bool
        If True, the lags are left to ``acorr_ljungbox``, which picks them
        using ``period``. Otherwise lags 1..max_lag are tested, with
        ``max_lag = min(10, len(residuals) // 5)`` and at least one lag.
    period : int
        Seasonal period; only used when ``is_seasonal`` is True.
    """
    if is_seasonal:
        lb_df = acorr_ljungbox(residuals, period=period)
    else:
        # Integer division: len(residuals) / 5 is a float, which would make the
        # lags floats and round the lag count up (33 residuals -> lags 1..7).
        # Keep at least one lag so very short series are still tested.
        max_lag = max(1, min(10, len(residuals) // 5))
        lb_df = acorr_ljungbox(residuals, lags=np.arange(1, max_lag + 1))

    fig, ax = plt.subplots()
    ax.plot(lb_df['lb_pvalue'], 'b-', label='p-values')
    # 5% significance threshold
    ax.hlines(y=0.05, xmin=1, xmax=len(lb_df), color='black')
    ax.set_xlabel('Lag')
    ax.set_ylabel('p-value')
    ax.legend()
    plt.tight_layout()

    if (lb_df['lb_pvalue'] > 0.05).all():
        print('All values are above 0.05. We fail to reject the null hypothesis. The residuals are uncorrelated')
    else:
        print('At least one p-value is smaller than 0.05')

In [None]:
# Run the Ljung-Box test



### Forecasting 

In [None]:
def rolling_predictions(df, train_len, horizon, window, period, method,
                        order=(0, 1, 1), seasonal_order=None):
    """Produce rolling forecasts for the `horizon` points that follow `train_len`.

    Starting at position `train_len`, each step uses only the observations
    before the current position, forecasts the next `window` values, then
    advances by `window` until `train_len + horizon` is covered. `df` is
    sliced positionally (`df[:i]`).

    Parameters
    ----------
    df : pandas.Series (or single-column DataFrame)
        Full series: training observations followed by the test period.
    train_len : int
        Number of leading observations in the initial training set.
    horizon : int
        Number of points to forecast.
    window : int
        Number of steps forecast at each step / refit.
    period : int
        Seasonal period. Used by 'last_season' and by the default seasonal order.
    method : {'mean', 'last', 'last_season', 'SARIMA'}
        'mean': historical mean. 'last': last observed value.
        'last_season': repeat the last `period` observations.
        'SARIMA': refit a SARIMAX model at each step.
    order : tuple, default (0, 1, 1)
        Non-seasonal (p, d, q) order, only used by 'SARIMA'.
    seasonal_order : tuple or None, default None
        Seasonal (P, D, Q, s) order, only used by 'SARIMA'. None means
        (0, 1, 1, period). Together with the default `order` this is the
        "airline" model, a generic baseline; pass the orders selected by
        SARIMA_gridsearch for a tuned model.

    Returns
    -------
    list (numpy.ndarray for 'last_season')
        `horizon` predictions.

    Raises
    ------
    ValueError
        If `method` is not one of the supported names.
    """
    total_len = train_len + horizon
    # Start of every forecast window; each forecast only sees df[:start]
    window_starts = range(train_len, total_len, window)

    if method == 'mean':
        pred_mean = []
        for i in window_starts:
            mean = np.mean(df[:i].values)
            pred_mean.extend([mean] * window)
        return pred_mean[:horizon]

    elif method == 'last':
        pred_last_value = []
        for i in window_starts:
            # Scalar for a Series, first column of the last row for a DataFrame
            last_value = np.ravel(df[:i].iloc[-1])[0]
            pred_last_value.extend([last_value] * window)
        return pred_last_value[:horizon]

    elif method == 'last_season':
        chunks = []
        for i in window_starts:
            last_season = df[:i][-period:].values
            # np.resize repeats last_season cyclically, so forecasts stay aligned
            # with the seasonal cycle even when window is not a multiple of period
            chunks.append(np.resize(last_season, window))
        return np.array(chunks).reshape(-1)[:horizon]

    elif method == 'SARIMA':
        if seasonal_order is None:
            seasonal_order = (0, 1, 1, period)
        pred_SARIMA = []
        for i in window_starts:
            model = SARIMAX(df[:i], order=order, seasonal_order=seasonal_order,
                            simple_differencing=False)
            res = model.fit(disp=False)
            # Out-of-sample forecast of the `window` steps following df[:i]
            pred_SARIMA.extend(np.asarray(res.forecast(steps=window)))
        return pred_SARIMA[:horizon]

    raise ValueError(
        f"Unknown method {method!r}; expected 'mean', 'last', 'last_season' or 'SARIMA'"
    )

In [None]:
TRAIN_LEN = len(train)
HORIZON = len(test)

windows = [12, 24, 36, 48]

# Start from the test set and add one prediction column per (method, window)
pred_df = test.copy()

for window in windows:
    common_args = (df['Milk'], TRAIN_LEN, HORIZON, window, 12)

    pred_last_season = rolling_predictions(*common_args, 'last_season')
    pred_SARIMA = rolling_predictions(*common_args, 'SARIMA')

    pred_df[f'pred_last_season_{window}'] = pred_last_season
    pred_df[f'pred_SARIMA_{window}'] = pred_SARIMA

pred_df.head()

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12,9))

# One panel per rolling window: full history, actual test values and both forecasts
for i, ax in enumerate(axes.flatten()):
    window = windows[i]

    ax.plot(df['Milk'])
    ax.plot(pred_df['Milk'], 'b-', label='actual')
    ax.plot(pred_df[f'pred_last_season_{window}'], 'r-.', label='last_season')
    ax.plot(pred_df[f'pred_SARIMA_{window}'], 'k--', label='SARIMA')
    ax.legend(loc=2)

    ax.set(xlabel='Date', ylabel='Milk production (lbs/cow)')
    # Shade the test period and zoom in on the end of the series
    ax.axvspan(120, 167, color='#808080', alpha=0.2)
    ax.set_xlim(100, 167)
    ax.set_title(f'Horizon = {window}')

fig.tight_layout()

#### Evaluation 

In [None]:
def mape(y_true, y_pred):
    """Return the mean absolute percentage error, rounded to 2 decimals.

    The arguments are combined element-wise (pandas Series align on their
    index). Zeros in ``y_true`` produce inf/nan.
    """
    relative_error = np.abs((y_true - y_pred) / y_true)
    percentage = np.mean(relative_error) * 100
    return round(percentage, 2)

In [None]:
# MAPE of each method for every rolling window
(mape_naive_seasonal_12, mape_naive_seasonal_24,
 mape_naive_seasonal_36, mape_naive_seasonal_48) = [
    mape(pred_df['Milk'], pred_df[f'pred_last_season_{w}']) for w in (12, 24, 36, 48)
]

(mape_SARIMA_12, mape_SARIMA_24,
 mape_SARIMA_36, mape_SARIMA_48) = [
    mape(pred_df['Milk'], pred_df[f'pred_SARIMA_{w}']) for w in (12, 24, 36, 48)
]

In [None]:
mapes_naive_seasonal = [mape_naive_seasonal_12, mape_naive_seasonal_24, mape_naive_seasonal_36, mape_naive_seasonal_48]
mapes_SARIMA = [mape_SARIMA_12, mape_SARIMA_24, mape_SARIMA_36, mape_SARIMA_48]

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12,9))

x = ['last season', 'SARIMA']
width = 0.3

# One bar chart per rolling window comparing the two methods
for i, ax in enumerate(axes.flatten()):
    y = [mapes_naive_seasonal[i], mapes_SARIMA[i]]
    ax.bar(x, y, width)
    ax.set_xlabel('Methods')
    ax.set_ylabel('MAPE (%)')
    ax.set_ylim(0, 10)
    ax.set_title(f'Horizon = {windows[i]}')

    # Write each bar's value just above it
    for pos, val in enumerate(y):
        ax.text(x=pos, y=val+0.5, s=str(val), ha='center')

fig.tight_layout()