<a href="https://colab.research.google.com/github/ksetdekov/test_predict_ts/blob/main/sarimax_grid_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from math import sqrt
from multiprocessing import cpu_count
from joblib import Parallel
from joblib import delayed
from warnings import catch_warnings
from warnings import filterwarnings
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error
from pandas import read_csv
import pandas as pd


  import pandas.util.testing as tm


In [2]:
data_url = 'https://raw.githubusercontent.com/ksetdekov/test_predict_ts/main/2_5350833418869411277.csv'
data = pd.read_csv(data_url)
data.date = pd.to_datetime(data.date)

data = data.set_index('date')
data.dropna(inplace=True)

ts_data = data.copy()
ts_data['hour'] = [ts_data.index[i].hour for i in range(len(ts_data))]
ts_data['month'] = [ts_data.index[i].month for i in range(len(ts_data))]
ts_data['weekday'] = [ts_data.index[i].day for i in range(len(ts_data))]
ts_data['last_day_m'] = [ts_data.index[i].is_month_end for i in range(len(ts_data))]
ts_data['last_day_m']=ts_data['last_day_m'].astype(int)

# rolling window
clients_val = ts_data[["clients"]]
width = 24
shifted = clients_val.shift(width -1)
window = shifted.rolling(window=width)

new_df = pd.concat([window.min(), window.mean(), window.max()], axis=1)
new_df.columns = ['min24', 'mean24', 'max24']


width=48
shifted = clients_val.shift(width -1)
window = shifted.rolling(window=width)

new_df_2 = pd.concat([window.min(), window.mean(), window.max()], axis=1)
new_df_2.columns = ['min48', 'mean48', 'max48']

new_df_2= pd.merge(new_df_2, new_df, how='outer', on="date")
ts_data = pd.merge(ts_data, new_df_2, how='outer', on="date")

without_na_ts_data = ts_data.dropna()

without_na_ts_data.tail()

Unnamed: 0_level_0,clients,hour,month,weekday,last_day_m,min48,mean48,max48,min24,mean24,max24
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
2014-12-30 19:00:00,35.0,19,12,30,0,32.0,211.520833,444.0,20.0,180.125,373.0
2014-12-30 20:00:00,26.0,20,12,30,0,32.0,212.583333,444.0,20.0,173.25,373.0
2014-12-30 21:00:00,20.0,21,12,30,0,32.0,214.958333,444.0,20.0,166.583333,373.0
2014-12-30 22:00:00,8.0,22,12,30,0,32.0,218.020833,444.0,20.0,159.0,373.0
2014-12-30 23:00:00,16.0,23,12,30,0,32.0,222.166667,444.0,20.0,151.791667,373.0


In [8]:
# grid search sarima hyperparameters for monthly mean temp dataset

 
# one-step sarima forecast
def sarima_forecast(history, config):
	order, sorder, trend = config
	# define model
	model = SARIMAX(history, order=order, seasonal_order=sorder, trend=trend, enforce_stationarity=False, enforce_invertibility=False)
	# fit model
	model_fit = model.fit(disp=False)
	# make one step forecast
	yhat = model_fit.predict(len(history), len(history))
	return yhat[0]
 
# root mean squared error or rmse
def measure_rmse(actual, predicted):
	return sqrt(mean_squared_error(actual, predicted))
 
# split a univariate dataset into train/test sets
def train_test_split(data, n_test):
	return data[:-n_test], data[-n_test:]
 
# walk-forward validation for univariate data
def walk_forward_validation(data, n_test, cfg):
	predictions = list()
	# split dataset
	train, test = train_test_split(data, n_test)
	# seed history with training dataset
	history = [x for x in train]
	# step over each time-step in the test set
	for i in range(len(test)):
		# fit model and make forecast for history
		yhat = sarima_forecast(history, cfg)
		# store forecast in list of predictions
		predictions.append(yhat)
		# add actual observation to history for the next loop
		history.append(test[i])
	# estimate prediction error
	error = measure_rmse(test, predictions)
	return error
 
# score a model, return None on failure
def score_model(data, n_test, cfg, debug=False):
	result = None
	# convert config to a key
	key = str(cfg)
	# show all warnings and fail on exception if debugging
	if debug:
		result = walk_forward_validation(data, n_test, cfg)
	else:
		# one failure during model validation suggests an unstable config
		try:
			# never show warnings when grid searching, too noisy
			with catch_warnings():
				filterwarnings("ignore")
				result = walk_forward_validation(data, n_test, cfg)
		except:
			error = None
	# check for an interesting result
	if result is not None:
		print(' > Model[%s] %.3f' % (key, result))
	return (key, result)
 
# grid search configs
def grid_search(data, cfg_list, n_test, parallel=True):
	scores = None
	if parallel:
		# execute configs in parallel
		executor = Parallel(n_jobs=cpu_count(), backend='multiprocessing')
		tasks = (delayed(score_model)(data, n_test, cfg) for cfg in cfg_list)
		scores = executor(tasks)
	else:
		scores = [score_model(data, n_test, cfg) for cfg in cfg_list]
	# remove empty results
	scores = [r for r in scores if r[1] != None]
	# sort configs by error, asc
	scores.sort(key=lambda tup: tup[1])
	return scores
 
# create a set of sarima configs to try
def sarima_configs(seasonal=[0]):
	models = list()
	# define config lists
	p_params = [ 4]
	d_params = [ 1]
	q_params = [ 1, 2]
	t_params = ['n','c','t','ct']
	P_params = [0, 1, 2]
	D_params = [0, 1]
	Q_params = [0, 1, 2]
	m_params = seasonal
	# create config instances
	for p in p_params:
		for d in d_params:
			for q in q_params:
				for t in t_params:
					for P in P_params:
						for D in D_params:
							for Q in Q_params:
								for m in m_params:
									cfg = [(p,d,q), (P,D,Q,m), t]
									models.append(cfg)
	return models
 

In [4]:
data = without_na_ts_data.clients
# trim dataset to 1 year
data = data[-(30*24):]
data

date
2014-12-01 00:00:00    48.0
2014-12-01 01:00:00    36.0
2014-12-01 02:00:00    25.0
2014-12-01 03:00:00    21.0
2014-12-01 04:00:00    13.0
                       ... 
2014-12-30 19:00:00    35.0
2014-12-30 20:00:00    26.0
2014-12-30 21:00:00    20.0
2014-12-30 22:00:00     8.0
2014-12-30 23:00:00    16.0
Name: clients, Length: 720, dtype: float64

In [9]:
	# data split
	n_test = 12
	# model configs
	cfg_list = sarima_configs(seasonal=[24])
	# grid search
	scores = grid_search(data, cfg_list, n_test, parallel=True)
	print('done')
	# list top 3 configs
	for cfg, error in scores[:3]:
		print(cfg, error)

 > Model[[(4, 1, 1), (0, 0, 0, 24), 'n']] 16.964
 > Model[[(4, 1, 1), (0, 0, 1, 24), 'n']] 16.499
 > Model[[(4, 1, 1), (0, 1, 0, 24), 'n']] 38.388
 > Model[[(4, 1, 1), (0, 0, 2, 24), 'n']] 15.677


KeyboardInterrupt: ignored

### Выводы:
Лучшая модель из тех, что успел перебрать: [(p,d,q), (P,D,Q,m), t] = [[(1, 1, 1), (2, 0, 2, 24), 'n']]

RMSE на walk-forward валидации 15.156, среди данных за последний квартал, для выбора гиперпараметров.

Потенциально, в пространстве гиперпараметров можно найти лучше модель. 

обучим эту модель на всех данных и сделаем ей кросс-валидацию + предсказание.