Skip to content

Commit

Permalink
Break out clip at zero into separate param
Browse files Browse the repository at this point in the history
  • Loading branch information
mszheng committed Nov 5, 2017
1 parent 9985ce7 commit 6c80d89
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions python/fbprophet/forecaster.py
Expand Up @@ -98,6 +98,7 @@ def __init__(
interval_width=0.80,
uncertainty_samples=1000,
seasonality_type='additive',
clip_at_zero=False,
):
self.growth = growth

Expand All @@ -114,6 +115,8 @@ def __init__(
self.weekly_seasonality = weekly_seasonality
self.daily_seasonality = daily_seasonality

self.clip_at_zero = clip_at_zero

if holidays is not None:
if not (
isinstance(holidays, pd.DataFrame)
Expand Down Expand Up @@ -889,10 +892,12 @@ def predict(self, df=None):
df2 = pd.concat((df[cols], intervals, seasonal_components), axis=1)
if self.seasonality_type == 'multiplicative':
df2['yhat'] = df2['trend'] * df2['seasonal']
df2['yhat'] = df2['yhat'].clip(lower=0)
else:
df2['yhat'] = df2['trend'] + df2['seasonal']

if self.clip_at_zero:
df2['yhat'] = df2['yhat'].clip(lower=0)

return df2

@staticmethod
Expand Down Expand Up @@ -981,7 +986,7 @@ def predict_trend(self, df):

scaled_trend = trend * self.y_scale + df['floor']

if self.seasonality_type == 'multiplicative':
if self.clip_at_zero:
scaled_trend = np.clip(scaled_trend, 0, None)

return scaled_trend
Expand Down Expand Up @@ -1037,10 +1042,18 @@ def predict_seasonal_components(self, df):
data[component + '_upper'] = np.nanpercentile(comp, upper_p,
axis=1)

component_predictions = pd.DataFrame(data)

if self.seasonality_type == 'multiplicative':
return pd.DataFrame(data) + 1
else:
return pd.DataFrame(data)
component_predictions = component_predictions + 1

if self.clip_at_zero:
clip_cols = [col for col in component_predictions.columns
if (col.endswith('_lower') or col.endswith('_upper'))]
for col in clip_cols:
component_predictions[col] = component_predictions[col].clip(lower=0)

return component_predictions

def add_group_component(self, components, name, group):
"""Adds a component with given name that contains all of the components
Expand Down Expand Up @@ -1133,7 +1146,7 @@ def predict_uncertainty(self, df):

uncertainties = pd.DataFrame(series)

if self.seasonality_type == 'multiplicative':
if self.clip_at_zero:
clip_cols = [col for col in uncertainties.columns
if (col.endswith('_lower') or col.endswith('_upper'))]
for col in clip_cols:
Expand Down

0 comments on commit 6c80d89

Please sign in to comment.