Skip to content

Commit

Permalink
BUG: Fix a future issue in ExpSmooth
Browse files Browse the repository at this point in the history
Ensure inputs are sanitized to prevent pandas Series from entering
  • Loading branch information
bashtage committed Jul 18, 2019
1 parent 2a0cb32 commit a387cda
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 19 deletions.
47 changes: 30 additions & 17 deletions statsmodels/tsa/holtwinters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
from scipy.stats import boxcox

from statsmodels.base.model import Results
from statsmodels.base.wrapper import populate_wrapper, union_dicts, ResultsWrapper
from statsmodels.tools.validation import array_like
from statsmodels.base.wrapper import (populate_wrapper, union_dicts,
ResultsWrapper)
from statsmodels.tools.validation import (array_like, bool_like, float_like,
string_like, int_like)
from statsmodels.tsa.base.tsa_model import TimeSeriesModel
from statsmodels.tsa.tsatools import freq_to_period
import statsmodels.tsa._exponential_smoothers as smoothers
Expand Down Expand Up @@ -488,10 +490,14 @@ def __init__(self, endog, trend=None, damped=False, seasonal=None,
self.endog = self.endog
self._y = self._data = array_like(endog, 'endog', contiguous=True,
order='C')
options = ("add", "mul", "additive", "multiplicative")
trend = string_like(trend, 'trend', options=options, optional=True)
if trend in ['additive', 'multiplicative']:
trend = {'additive': 'add', 'multiplicative': 'mul'}[trend]
self.trend = trend
self.damped = damped
self.damped = bool_like(damped, 'damped')
seasonal = string_like(seasonal, 'seasonal', options=options,
optional=True)
if seasonal in ['additive', 'multiplicative']:
seasonal = {'additive': 'add', 'multiplicative': 'mul'}[seasonal]
self.seasonal = seasonal
Expand All @@ -504,7 +510,8 @@ def __init__(self, endog, trend=None, damped=False, seasonal=None,
if self.damped and not self.trending:
raise ValueError('Can only dampen the trend component')
if self.seasoning:
self.seasonal_periods = seasonal_periods
self.seasonal_periods = int_like(seasonal_periods,
'seasonal_periods', optional=True)
if seasonal_periods is None:
self.seasonal_periods = freq_to_period(self._index_freq)
if self.seasonal_periods <= 1:
Expand Down Expand Up @@ -608,13 +615,15 @@ def fit(self, smoothing_level=None, smoothing_slope=None, smoothing_seasonal=Non
"""
# Variable renames to alpha,beta, etc as this helps with following the
# mathematical notation in general
alpha = smoothing_level
beta = smoothing_slope
gamma = smoothing_seasonal
phi = damping_slope
l0 = self._l0 = initial_level
b0 = self._b0 = initial_slope

alpha = float_like(smoothing_level, 'smoothing_level', True)
beta = float_like(smoothing_slope, 'smoothing_slope', True)
gamma = float_like(smoothing_seasonal, 'smoothing_seasonal', True)
phi = float_like(damping_slope, 'damping_slope', True)
l0 = self._l0 = float_like(initial_level, 'initial_level', True)
b0 = self._b0 = float_like(initial_slope, 'initial_slope', True)
if start_params is not None:
start_params = array_like(start_params, 'start_params',
contiguous=True)
data = self._data
damped = self.damped
seasoning = self.seasoning
Expand Down Expand Up @@ -675,18 +684,22 @@ def fit(self, smoothing_level=None, smoothing_slope=None, smoothing_seasonal=Non
txi = txi.astype(np.bool)
bounds = np.array([(0.0, 1.0), (0.0, 1.0), (0.0, 1.0),
(0.0, None), (0.0, None), (0.0, 1.0)] + [(None, None), ] * m)
args = (txi.astype(np.uint8), p, y, lvls, b, s, m, self.nobs, max_seen)
args = (txi.astype(np.uint8), p, y, lvls, b, s, m, self.nobs,
max_seen)
if start_params is None and np.any(txi) and use_brute:
res = brute(func, bounds[txi], args, Ns=20, full_output=True, finish=None)
res = brute(func, bounds[txi], args, Ns=20,
full_output=True, finish=None)
p[txi], max_seen, _, _ = res
else:
if start_params is not None:
start_params = np.atleast_1d(np.squeeze(start_params))
if len(start_params) != xi.sum():
raise ValueError('start_params must have {0} values but '
'has {1} instead'.format(len(xi), len(start_params)))
msg = 'start_params must have {0} values but ' \
'has {1} instead'
nxi, nsp = len(xi), len(start_params)
raise ValueError(msg.format(nxi, nsp))
p[xi] = start_params
args = (xi.astype(np.uint8), p, y, lvls, b, s, m, self.nobs, max_seen)
args = (xi.astype(np.uint8), p, y, lvls, b, s, m,
self.nobs, max_seen)
max_seen = func(np.ascontiguousarray(p[xi]), *args)
# alpha, beta, gamma, l0, b0, phi = p[:6]
# s0 = p[6:]
Expand Down
2 changes: 1 addition & 1 deletion statsmodels/tsa/tests/test_holtwinters.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def test_negative_multipliative(trend_seasonal):
@pytest.mark.parametrize('seasonal', SEASONALS)
def test_dampen_no_trend(seasonal):
y = -np.ones(100)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
ExponentialSmoothing(housing_data, trend=False, seasonal=seasonal, damped=True,
seasonal_periods=10)

Expand Down
2 changes: 1 addition & 1 deletion statsmodels/tsa/tsatools.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,4 +804,4 @@ def freq_to_period(freq):

__all__ = ['lagmat', 'lagmat2ds','add_trend', 'duplication_matrix',
'elimination_matrix', 'commutation_matrix',
'vec', 'vech', 'unvec', 'unvech']
'vec', 'vech', 'unvec', 'unvech', 'freq_to_period']

0 comments on commit a387cda

Please sign in to comment.