-
Notifications
You must be signed in to change notification settings - Fork 0
/
sarimax.py
executable file
·80 lines (67 loc) · 2.43 KB
/
sarimax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import pmdarima
from . import _stat_model
class Arima(_stat_model.StatModel):
    """
    (Seasonal) Autoregressive Integrated Moving Average model with eXogenous factors ((S)ARIMAX),
    built on :class:`pmdarima.ARIMA`.

    Both the non-seasonal orders (p, d, q) and the seasonal orders (P, D, Q, seasonal_periods)
    are tuned via Optuna; their search spaces are declared in :meth:`define_hyperparams_to_tune`.
    See :obj:`~ForeTiS.model._base_model.BaseModel` for more information on the attributes.
    """

    def define_model(self) -> pmdarima.ARIMA:
        """
        Build the pmdarima estimator from the hyperparameters suggested for the current trial.

        Besides returning the model, this sets ``conf`` and ``use_exog`` to ``True`` and resets
        ``exog_cols_dropped`` and ``trend`` to ``None`` on the instance (presumably consumed by
        the StatModel base class -- confirm there).
        See :obj:`~ForeTiS.model._base_model.BaseModel` for more information.

        :return: an unfitted ``pmdarima.ARIMA`` configured with the suggested orders
        """
        self.conf = True
        self.use_exog = True
        self.exog_cols_dropped = None

        # Hyperparameters are requested in a fixed order (seasonal terms first). Keep this
        # order, since the order of requests to an Optuna trial may matter.
        suggested = {
            name: self.suggest_hyperparam_to_optuna(name)
            for name in ('P', 'D', 'Q', 'seasonal_periods', 'p', 'd', 'q')
        }
        self.trend = None

        non_seasonal_order = [suggested['p'], suggested['d'], suggested['q']]
        seasonal_order = [
            suggested['P'],
            suggested['D'],
            suggested['Q'],
            suggested['seasonal_periods'],
        ]
        # NOTE(review): `disp` and `enforce_stationarity` are not named ARIMA options here; they
        # are passed through as extra keywords -- confirm `disp` reaches the optimizer as intended.
        return pmdarima.ARIMA(
            order=non_seasonal_order,
            seasonal_order=seasonal_order,
            method='lbfgs',
            maxiter=50,
            disp=1,
            with_intercept=True,
            enforce_stationarity=False,
            suppress_warnings=True,
        )

    def define_hyperparams_to_tune(self) -> dict:
        """
        Declare the Optuna search space for this model.

        Search space:
          * All orders are integers with lower bound 0.
          * The differencing orders ``d`` and ``D`` have upper bound 1.
          * The AR/MA orders ``p``, ``q``, ``P`` and ``Q`` have upper bound 3.
          * ``seasonal_periods`` is categorical; its only choice is the dataset's seasonal period.

        See :obj:`~ForeTiS.model._base_model.BaseModel` for more information on the format.

        :return: mapping from hyperparameter name to its search-space specification
        """
        # Insertion order (p, d, q, P, D, Q) is preserved by the dict comprehension.
        int_upper_bounds = {'p': 3, 'd': 1, 'q': 3, 'P': 3, 'D': 1, 'Q': 3}
        search_space = {
            name: {'datatype': 'int', 'lower_bound': 0, 'upper_bound': upper}
            for name, upper in int_upper_bounds.items()
        }
        search_space['seasonal_periods'] = {
            'datatype': 'categorical',
            'list_of_values': [self.datasets.seasonal_periods],
        }
        return search_space