Skip to content

Commit

Permalink
ENH: Add STL Forecasting method
Browse files Browse the repository at this point in the history
Add method that allows STL to be combined with an model to
produce forecasts and prediction intervals

closes statsmodels#6372
  • Loading branch information
bashtage committed Jul 30, 2020
1 parent b706a69 commit ce1e101
Show file tree
Hide file tree
Showing 7 changed files with 873 additions and 95 deletions.
20 changes: 20 additions & 0 deletions docs/source/tsa.rst
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,26 @@ are available in:
ThetaModel
ThetaModelResults

Forecasting after STL Decomposition
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
:class:`statsmodels.tsa.seasonal.STL` is commonly used to remove seasonal
components from a time series. The deseasonalized time series can then
be modeled using a any non-seasonal model, and forecasts are constructed
by adding the forecast from the non-seasonal model to the estimates of
the seasonal component from the final full-cycle which are forecast using
a random-walk model.

.. module:: statsmodels.tsa.forecasting.stl
:synopsis: Models designed for forecasting

.. currentmodule:: statsmodels.tsa.forecasting.stl

.. autosummary::
:toctree: generated/

STLForecast
STLForecastResults

Prediction Results
""""""""""""""""""
Most foreasting methods support a ``get_prediction`` method that return
Expand Down
52 changes: 51 additions & 1 deletion examples/notebooks/stl_decomposition.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,56 @@
"res = mod.fit()\n",
"fig = res.plot(observed=False, resid=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Forecasting with STL\n",
"\n",
"``STLForecast`` simplifies the process of using STL to remove seasonalities and then using a standard time-series model to forecast the trend and cyclical components. \n",
"\n",
"Here we use STL to handle the seasonality and then an ARIMA(1,1,0) to model the deseasonalized data. The seasonal component is forecast from the find full cycle where \n",
"\n",
"$$E[S_{T+h}|\\mathcal{F}_T]=\\hat{S}_{T-k}$$\n",
"\n",
"where $k= m - h + m \\lfloor \\frac{h-1}{m} \\rfloor$. The forecast automatically adds the seasonal component forecast to the ARIMA forecast."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from statsmodels.tsa.forecasting.stl import STLForecast\n",
"from statsmodels.tsa.arima.model import ARIMA\n",
"\n",
"elec_equip.index.freq = elec_equip.index.inferred_freq\n",
"stlf = STLForecast(elec_equip, ARIMA, model_kwargs=dict(order=(1,1,0), trend=\"c\"))\n",
"stlf_res = stlf.fit()\n",
"\n",
"forecast = stlf_res.forecast(24)\n",
"plt.plot(elec_equip)\n",
"plt.plot(forecast)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"``summary`` contains information about both the time-series model and the STL decomposition."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(stlf_res.summary())"
]
}
],
"metadata": {
Expand All @@ -291,7 +341,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.7.7"
}
},
"nbformat": 4,
Expand Down
25 changes: 15 additions & 10 deletions statsmodels/tsa/_stl.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ cdef class STL(object):
"""
cdef object endog
cdef Py_ssize_t nobs
cdef int period, seasonal, trend, low_pass, seasonal_deg, trend_deg
cdef int _period, seasonal, trend, low_pass, seasonal_deg, trend_deg
cdef int low_pass_deg, low_pass_jump, trend_jump, seasonal_jump
cdef bint robust, _use_rw
cdef double[::1] _ya, _trend, _season, _rw
Expand All @@ -220,20 +220,20 @@ cdef class STL(object):
period = freq_to_period(freq)
if not _is_pos_int(period, False) or period < 2:
raise ValueError('period must be a positive integer >= 2')
self.period = period # np
self._period = period # np
if not _is_pos_int(seasonal, True) or seasonal < 3:
raise ValueError('seasonal must be an odd positive integer >= 3')
self.seasonal = seasonal # ns
if trend is None:
trend = int(np.ceil(1.5 * self.period / (1 - 1.5 / self.seasonal)))
trend = int(np.ceil(1.5 * self._period / (1 - 1.5 / self.seasonal)))
# ensure odd
trend += ((trend % 2) == 0)
if not _is_pos_int(trend, True) or trend < 3 or trend <= period:
raise ValueError('trend must be an odd positive integer '
'>= 3 where trend > period')
self.trend = trend # nt
if low_pass is None:
low_pass = self.period + 1
low_pass = self._period + 1
low_pass += ((low_pass % 2) == 0)
if not _is_pos_int(low_pass, True) or \
low_pass < 3 or low_pass <= period:
Expand All @@ -260,6 +260,11 @@ cdef class STL(object):
self._rw = np.ones(self.nobs)
self._work = np.zeros((7, self.nobs + 2 * period))

@property
def period(self) -> int:
"""The period length of the time series"""
return self._period

@property
def config(self) -> Dict[str, Union[int, bool]]:
"""
Expand All @@ -270,7 +275,7 @@ cdef class STL(object):
dict[str, Union[int, bool]]
The values used in the STL decomposition.
"""
return dict(period=self.period,
return dict(period=self._period,
seasonal=self.seasonal,
seasonal_deg=self.seasonal_deg,
seasonal_jump=self.seasonal_jump,
Expand Down Expand Up @@ -346,7 +351,7 @@ cdef class STL(object):
y, n, np, ns, nt, nl, isdeg, itdeg, ildeg, nsjump,
ntjump, nljump, ni, userw, rw, season, trend, work
->
self._ya, self.nobs, self.period, self.seasonal,
self._ya, self.nobs, self._period, self.seasonal,
self.trend, self.low_pass, self.seasonal_deg,
self.trend_deg, self.low_pass_deg, self.seasonal_jump,
self.trend_jump, self.low_pass_jump, inner_iter,
Expand All @@ -364,7 +369,7 @@ cdef class STL(object):
nl = self.low_pass
ildeg = self.low_pass_deg
nljump = self.low_pass_jump
np = self.period
np = self._period
season = self._season
nt = self.trend
itdeg = self.trend_deg
Expand Down Expand Up @@ -535,8 +540,8 @@ cdef class STL(object):
cdef int n, np

x = self._work[1, :]
n = self.nobs + 2 * self.period
np = self.period
n = self.nobs + 2 * self._period
np = self._period
trend = self._work[2, :]
work = self._work[0, :]
self._ma(x, n, np, trend)
Expand All @@ -559,7 +564,7 @@ cdef class STL(object):
# Original variable names
y = self._work[0, :]
n = self.nobs
np = self.period
np = self._period
ns = self.seasonal
isdeg = self.seasonal_deg
nsjump = self.seasonal_jump
Expand Down
3 changes: 2 additions & 1 deletion statsmodels/tsa/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
'SARIMAX', 'UnobservedComponents', 'VARMAX', 'DynamicFactor',
'MarkovRegression', 'MarkovAutoregression',
'ExponentialSmoothing', 'SimpleExpSmoothing', 'Holt',
'arma_generate_sample', 'ArmaProcess', 'STL',
'arma_generate_sample', 'ArmaProcess', 'STL', 'STLForecast',
'bk_filter', 'cf_filter', 'hp_filter']

from .ar_model import AR, AutoReg
Expand Down Expand Up @@ -52,4 +52,5 @@
from .holtwinters import ExponentialSmoothing, SimpleExpSmoothing, Holt
from .innovations import api as innovations
from .seasonal import STL
from .forecasting.stl import STLForecast
from .filters import bk_filter, cf_filter, hp_filter

0 comments on commit ce1e101

Please sign in to comment.