Skip to content

Commit

Permalink
Update ProphetModel to handle external timestamp (#203)
Browse files Browse the repository at this point in the history
  • Loading branch information
d-a-bunin committed Dec 26, 2023
1 parent 4f3afd5 commit 1e9cb9d
Show file tree
Hide file tree
Showing 6 changed files with 1,288 additions and 757 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Update change point transforms (`ChangePointsSegmentationTransform`, `ChangePointsTrendTransform`, `ChangePointsLevelTransform`, `TrendTransform`) to handle integer timestamp ([#176](https://github.com/etna-team/etna/pull/176))
- Update `BATSModel`, `TBATSModel` models to work with integer timestamp ([#195](https://github.com/etna-team/etna/pull/195))
- Update `ProphetModel` to handle external timestamp ([#203](https://github.com/etna-team/etna/pull/203))

### Fixed
-
Expand Down
53 changes: 47 additions & 6 deletions etna/models/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Sequence
from typing import Set
from typing import Union
from typing import cast

import pandas as pd

Expand Down Expand Up @@ -50,6 +51,7 @@ def __init__(
uncertainty_samples: Union[int, bool] = 1000,
stan_backend: Optional[str] = None,
additional_seasonality_params: Iterable[Dict[str, Union[str, float, int]]] = (),
timestamp_column: Optional[str] = None,
):

self.growth = growth
Expand All @@ -69,6 +71,7 @@ def __init__(
self.uncertainty_samples = uncertainty_samples
self.stan_backend = stan_backend
self.additional_seasonality_params = additional_seasonality_params
self.timestamp_column = timestamp_column

self.model = self._create_model()

Expand Down Expand Up @@ -131,8 +134,12 @@ def _select_regressors(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
)

if self.regressor_columns:
columns = deepcopy(self.regressor_columns)
if self.timestamp_column in columns:
columns.remove(self.timestamp_column)

try:
result = df[self.regressor_columns].apply(pd.to_numeric)
result = df[columns].apply(pd.to_numeric)
except ValueError as e:
raise ValueError(f"Only convertible to numeric features are allowed! Error: {str(e)}")
else:
Expand All @@ -156,8 +163,11 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_ProphetAdapter":

prophet_df = self._prepare_prophet_df(df=df)
for regressor in self.regressor_columns:
if regressor not in self.predefined_regressors_names:
self.model.add_regressor(regressor)
if regressor in self.predefined_regressors_names:
continue
if regressor == self.timestamp_column:
continue
self.model.add_regressor(regressor)
self.model.fit(prophet_df)
return self

Expand Down Expand Up @@ -193,20 +203,45 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequen
y_pred = y_pred.rename(rename_dict, axis=1)
return y_pred

def _validate_timestamp(self, df: pd.DataFrame):
self.regressor_columns = cast(List[str], self.regressor_columns)

if self.timestamp_column is None:
if not pd.api.types.is_datetime64_dtype(df["timestamp"]):
raise ValueError("Invalid timestamp! Only datetime type is supported.")

else:
if self.timestamp_column not in df.columns:
raise ValueError("Invalid timestamp_column! It isn't present in a given dataset.")

if self.timestamp_column not in self.regressor_columns:
raise ValueError("Invalid timestamp_column! It should be a regressor.")

if not pd.api.types.is_datetime64_dtype(df[self.timestamp_column]):
raise ValueError("Invalid timestamp_column! Only datetime type is supported.")

if len(df[self.timestamp_column]) >= 3 and pd.infer_freq(df[self.timestamp_column]) is None:
raise ValueError("Invalid timestamp_column! It doesn't contain sequential timestamps.")

def _prepare_prophet_df(self, df: pd.DataFrame) -> pd.DataFrame:
"""Prepare dataframe for fit and predict."""
if self.regressor_columns is None:
raise ValueError("List of regressor is not set!")

self._validate_timestamp(df)
df = df.reset_index()

prophet_df = pd.DataFrame()
prophet_df["y"] = df["target"]
prophet_df["ds"] = df["timestamp"]

if self.timestamp_column is None:
prophet_df["ds"] = df["timestamp"]
else:
prophet_df["ds"] = df[self.timestamp_column]

regressors_data = self._select_regressors(df)
if regressors_data is not None:
prophet_df[self.regressor_columns] = regressors_data[self.regressor_columns]
prophet_df[regressors_data.columns] = regressors_data

return prophet_df

Expand Down Expand Up @@ -358,7 +393,7 @@ class ProphetModel(
daily_seasonality = 'auto', holidays = None, seasonality_mode = 'additive',
seasonality_prior_scale = 10.0, holidays_prior_scale = 10.0, changepoint_prior_scale = 0.05,
mcmc_samples = 0, interval_width = 0.8, uncertainty_samples = 1000, stan_backend = None,
additional_seasonality_params = (), )
additional_seasonality_params = (), timestamp_column = None, )
>>> forecast = model.forecast(future)
>>> forecast
segment segment_0 segment_1 segment_2 segment_3
Expand Down Expand Up @@ -392,6 +427,7 @@ def __init__(
uncertainty_samples: Union[int, bool] = 1000,
stan_backend: Optional[str] = None,
additional_seasonality_params: Iterable[Dict[str, Union[str, float, int]]] = (),
timestamp_column: Optional[str] = None,
):
"""
Create instance of Prophet model.
Expand Down Expand Up @@ -467,6 +503,9 @@ def __init__(
parameters that describe additional (not 'daily', 'weekly', 'yearly') seasonality that should be
added to model; dict with required keys 'name', 'period', 'fourier_order' and optional ones 'prior_scale',
'mode', 'condition_name' will be used for :py:meth:`prophet.Prophet.add_seasonality` method call.
timestamp_column:
Name of a column to be used as timestamp. If not given, index is used.
Column is expected to be regressor containing datetime values with some fixed frequency.
"""
self.growth = growth
self.n_changepoints = n_changepoints
Expand All @@ -485,6 +524,7 @@ def __init__(
self.uncertainty_samples = uncertainty_samples
self.stan_backend = stan_backend
self.additional_seasonality_params = additional_seasonality_params
self.timestamp_column = timestamp_column

super(ProphetModel, self).__init__(
base_model=_ProphetAdapter(
Expand All @@ -505,6 +545,7 @@ def __init__(
uncertainty_samples=self.uncertainty_samples,
stan_backend=self.stan_backend,
additional_seasonality_params=self.additional_seasonality_params,
timestamp_column=self.timestamp_column,
)
)

Expand Down
13 changes: 13 additions & 0 deletions tests/test_models/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from copy import deepcopy

import numpy as np
import pandas as pd
import pytest

from etna.datasets import generate_ar_df
Expand Down Expand Up @@ -80,3 +81,15 @@ def ts_with_non_regressor_exog(example_tsds) -> TSDataset:
df_exog_wide = TSDataset.to_dataset(df_exog)
ts = TSDataset(df=df_wide, df_exog=df_exog_wide, freq=ts.freq)
return ts


@pytest.fixture
def ts_with_external_timestamp() -> TSDataset:
df = generate_ar_df(periods=100, start_time=10, n_segments=2, freq=None)
df_wide = TSDataset.to_dataset(df)
df_exog = generate_ar_df(periods=100, start_time=10, n_segments=2, freq=None)
df_exog["target"] = pd.date_range(start="2020-01-01", periods=100).tolist() * 2
df_exog_wide = TSDataset.to_dataset(df_exog)
df_exog_wide.rename(columns={"target": "external_timestamp"}, level="feature", inplace=True)
ts = TSDataset(df=df_wide.iloc[:-10], df_exog=df_exog_wide, known_future="all", freq=None)
return ts

0 comments on commit 1e9cb9d

Please sign in to comment.