In [1]:
#https://machinelearningmastery.com/grid-search-arima-hyperparameters-with-python/

In [6]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.api as sm
from sklearn.metrics import mean_squared_error

import warnings
warnings.filterwarnings('ignore')

In [3]:
# one-step sarima forecast
def sarima_forecast(history, config):
	order, sorder, trend = config
	# define model
	model = sm.tsa.SARIMAX(history, order=order, 
                    seasonal_order=sorder, 
                    trend=trend, 
                    enforce_stationarity=False, 
                    enforce_invertibility=False)
	# fit model
	model_fit = model.fit(disp=False)
	# make one step forecast
	yhat = model_fit.predict(len(history), len(history))
	return yhat[0]

In [4]:
# split a univariate dataset into train and test set 
def split_train_test(data, train_size):
    return data[0:train_size], data[train_size:len(data)]

In [7]:
def get_rmse_socre(actual, predicted):
    return np.sqrt(mean_squared_error(actual, predicted))

The walk_forward_validation() function below taking a univariate time series, a number of time steps to use in the test set, and an array of model configuration.

In [9]:
def walk_forward_validation(data, n_tests, cfg):
    predictions=list()
    train, test = split_train_test(data, n_tests)
    # seed history with training dataset
    history = [x for x in train]
    # step over each time-step in the test set
    for t in range(len(test)):
        forecast_res = sarima_forecast(history=history, config=cfg)
        predictions.append(forecast_res)
        history.append(test[t])
        
    # estimate prediction error
    rmse = get_rmse_socre(test, predicted=predictions)
    return rmse

In [10]:
def score_model(data, n_test, cfg, debug=False):
	result = None
	# convert config to a key
	key = str(cfg)
	# show all warnings and fail on exception if debugging
	if debug:
		result = walk_forward_validation(data, n_test, cfg)
	else:
		# one failure during model validation suggests an unstable config
		try:
			# never show warnings when grid searching, too noisy
			with catch_warnings():
				filterwarnings("ignore")
				result = walk_forward_validation(data, n_test, cfg)
		except:
			error = None
	# check for an interesting result
	if result is not None:
		print(' > Model[%s] %.3f' % (key, result))
	return (key, result)

In [14]:
cpu_count()

12

In [12]:
from multiprocessing import cpu_count
from joblib import Parallel
from joblib import delayed
from warnings import catch_warnings
from warnings import filterwarnings
# grid search configs
def grid_search(data, cfg_list, n_test, parallel=True):
	scores = None
	if parallel:
		# execute configs in parallel
		executor = Parallel(n_jobs=cpu_count(), backend='multiprocessing')
		tasks = (delayed(score_model)(data, n_test, cfg) for cfg in cfg_list)
		scores = executor(tasks)
	else:
		scores = [score_model(data, n_test, cfg) for cfg in cfg_list]
	# remove empty results
	scores = [r for r in scores if r[1] != None]
	# sort configs by error, asc
	scores.sort(key=lambda tup: tup[1])
	return scores


In [15]:
# create a set of sarima configs to try
def sarima_configs(seasonal=[0]):
	models = list()
	# define config lists
	p_params = [0, 1, 2]
	d_params = [0, 1]
	q_params = [0, 1, 2]
	t_params = ['n','c','t','ct']
	P_params = [0, 1, 2]
	D_params = [0, 1]
	Q_params = [0, 1, 2]
	m_params = seasonal
	# create config instances
	for p in p_params:
		for d in d_params:
			for q in q_params:
				for t in t_params:
					for P in P_params:
						for D in D_params:
							for Q in Q_params:
								for m in m_params:
									cfg = [(p,d,q), (P,D,Q,m), t]
									models.append(cfg)
	return models

In [16]:
if __name__ == '__main__':
	# define dataset
	data = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
	print(data)
	# data split
	n_test = 4
	# model configs
	cfg_list = sarima_configs()
	# grid search
	scores = grid_search(data, cfg_list, n_test)
	print('done')
	# list top 3 configs
	for cfg, error in scores[:3]:
		print(cfg, error)

[10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
 > Model[[(0, 0, 0), (0, 0, 0, 0), 'n']] 76.920
 > Model[[(0, 0, 0), (0, 0, 0, 0), 'c']] 33.603
 > Model[[(0, 1, 0), (0, 0, 0, 0), 'n']] 10.000
 > Model[[(0, 1, 0), (0, 0, 0, 0), 't']] 5.512
 > Model[[(0, 0, 0), (0, 0, 0, 0), 't']] 6.383
 > Model[[(0, 1, 0), (0, 0, 0, 0), 'c']] 0.000
 > Model[[(0, 0, 1), (0, 0, 0, 0), 'n']] 76.920 > Model[[(0, 0, 1), (0, 0, 0, 0), 'c']] 52.886

 > Model[[(0, 0, 1), (0, 0, 0, 0), 't']] 66.599
 > Model[[(0, 1, 1), (0, 0, 0, 0), 'n']] 8.948
 > Model[[(0, 1, 1), (0, 0, 0, 0), 'c']] 0.001
 > Model[[(0, 0, 0), (0, 0, 0, 0), 'ct']] 0.052 > Model[[(0, 0, 1), (0, 0, 0, 0), 'ct']] 1.322

 > Model[[(1, 0, 0), (0, 0, 0, 0), 'n']] 6.383
 > Model[[(0, 1, 1), (0, 0, 0, 0), 't']] 2.944
 > Model[[(1, 0, 0), (0, 0, 0, 0), 't']] 6.383 > Model[[(1, 0, 0), (0, 0, 0, 0), 'c']] 0.000

 > Model[[(1, 0, 0), (0, 0, 0, 0), 'ct']] 0.000
 > Model[[(1, 0, 1), (0, 0, 0, 0), 'n']] 3.980
 > Model[[(1, 0, 1), (0, 0, 0, 0), '