Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ProphetModel to handle external timestamp #203

Merged
merged 12 commits into from
Dec 26, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Update `ResampleWithDistributionTransform` to work with integer timestamp ([#165](https://github.com/etna-team/etna/pull/165))
-
- Update change point transforms (`ChangePointsSegmentationTransform`, `ChangePointsTrendTransform`, `ChangePointsLevelTransform`, `TrendTransform`) to handle integer timestamp ([#176](https://github.com/etna-team/etna/pull/176))
-
- Update `ProphetModel` to handle external timestamp ([#203](https://github.com/etna-team/etna/pull/203))

### Fixed
-
Expand Down
53 changes: 47 additions & 6 deletions etna/models/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Sequence
from typing import Set
from typing import Union
from typing import cast

import pandas as pd

Expand Down Expand Up @@ -50,6 +51,7 @@ def __init__(
uncertainty_samples: Union[int, bool] = 1000,
stan_backend: Optional[str] = None,
additional_seasonality_params: Iterable[Dict[str, Union[str, float, int]]] = (),
timestamp_column: Optional[str] = None,
):

self.growth = growth
Expand All @@ -69,6 +71,7 @@ def __init__(
self.uncertainty_samples = uncertainty_samples
self.stan_backend = stan_backend
self.additional_seasonality_params = additional_seasonality_params
self.timestamp_column = timestamp_column

self.model = self._create_model()

Expand Down Expand Up @@ -131,8 +134,12 @@ def _select_regressors(self, df: pd.DataFrame) -> Optional[pd.DataFrame]:
)

if self.regressor_columns:
columns = deepcopy(self.regressor_columns)
if self.timestamp_column in columns:
columns.remove(self.timestamp_column)

try:
result = df[self.regressor_columns].apply(pd.to_numeric)
result = df[columns].apply(pd.to_numeric)
except ValueError as e:
raise ValueError(f"Only convertible to numeric features are allowed! Error: {str(e)}")
else:
Expand All @@ -156,8 +163,11 @@ def fit(self, df: pd.DataFrame, regressors: List[str]) -> "_ProphetAdapter":

prophet_df = self._prepare_prophet_df(df=df)
for regressor in self.regressor_columns:
if regressor not in self.predefined_regressors_names:
self.model.add_regressor(regressor)
if regressor in self.predefined_regressors_names:
continue
if regressor == self.timestamp_column:
continue
self.model.add_regressor(regressor)
self.model.fit(prophet_df)
return self

Expand Down Expand Up @@ -193,20 +203,45 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Sequen
y_pred = y_pred.rename(rename_dict, axis=1)
return y_pred

def _validate_timestamp(self, df: pd.DataFrame):
self.regressor_columns = cast(List[str], self.regressor_columns)

if self.timestamp_column is None:
if not pd.api.types.is_datetime64_dtype(df["timestamp"]):
raise ValueError("Invalid timestamp! Only datetime type is supported.")

else:
if self.timestamp_column not in df.columns:
raise ValueError("Invalid timestamp_column! It isn't present in a given dataset.")

if self.timestamp_column not in self.regressor_columns:
raise ValueError("Invalid timestamp_column! It should be a regressor.")

if not pd.api.types.is_datetime64_dtype(df[self.timestamp_column]):
raise ValueError("Invalid timestamp_column! Only datetime type is supported.")

if len(df[self.timestamp_column]) >= 3 and pd.infer_freq(df[self.timestamp_column]) is None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't check that frequency is always the same. For example, it works fine if we have one frequency for train and for test. In theory, prophet can work fine even if there no regular frequency, but I'm not sure should we support this case or not.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could probably infer frequency during train, for example, and check if it is as expected.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should check the freq and give a warning if the freq for the train is different from the freq for the test.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont like the idea of warning. It will be thrown in every per-segment model, which seems too much.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is a bad idea to have different frequencies, so the only option is to fail in this situation.

raise ValueError("Invalid timestamp_column! It doesn't contain sequential timestamps.")

def _prepare_prophet_df(self, df: pd.DataFrame) -> pd.DataFrame:
"""Prepare dataframe for fit and predict."""
if self.regressor_columns is None:
raise ValueError("List of regressor is not set!")

self._validate_timestamp(df)
df = df.reset_index()

prophet_df = pd.DataFrame()
prophet_df["y"] = df["target"]
prophet_df["ds"] = df["timestamp"]

if self.timestamp_column is None:
prophet_df["ds"] = df["timestamp"]
else:
prophet_df["ds"] = df[self.timestamp_column]

regressors_data = self._select_regressors(df)
if regressors_data is not None:
prophet_df[self.regressor_columns] = regressors_data[self.regressor_columns]
prophet_df[regressors_data.columns] = regressors_data

return prophet_df

Expand Down Expand Up @@ -358,7 +393,7 @@ class ProphetModel(
daily_seasonality = 'auto', holidays = None, seasonality_mode = 'additive',
seasonality_prior_scale = 10.0, holidays_prior_scale = 10.0, changepoint_prior_scale = 0.05,
mcmc_samples = 0, interval_width = 0.8, uncertainty_samples = 1000, stan_backend = None,
additional_seasonality_params = (), )
additional_seasonality_params = (), timestamp_column = None, )
>>> forecast = model.forecast(future)
>>> forecast
segment segment_0 segment_1 segment_2 segment_3
Expand Down Expand Up @@ -392,6 +427,7 @@ def __init__(
uncertainty_samples: Union[int, bool] = 1000,
stan_backend: Optional[str] = None,
additional_seasonality_params: Iterable[Dict[str, Union[str, float, int]]] = (),
timestamp_column: Optional[str] = None,
):
"""
Create instance of Prophet model.
Expand Down Expand Up @@ -467,6 +503,9 @@ def __init__(
parameters that describe additional (not 'daily', 'weekly', 'yearly') seasonality that should be
added to model; dict with required keys 'name', 'period', 'fourier_order' and optional ones 'prior_scale',
'mode', 'condition_name' will be used for :py:meth:`prophet.Prophet.add_seasonality` method call.
timestamp_column:
Name of a column to be used as timestamp. If not given, index is used.
Column is expected to be regressor containing datetime values with some fixed frequency.
"""
self.growth = growth
self.n_changepoints = n_changepoints
Expand All @@ -485,6 +524,7 @@ def __init__(
self.uncertainty_samples = uncertainty_samples
self.stan_backend = stan_backend
self.additional_seasonality_params = additional_seasonality_params
self.timestamp_column = timestamp_column

super(ProphetModel, self).__init__(
base_model=_ProphetAdapter(
Expand All @@ -505,6 +545,7 @@ def __init__(
uncertainty_samples=self.uncertainty_samples,
stan_backend=self.stan_backend,
additional_seasonality_params=self.additional_seasonality_params,
timestamp_column=self.timestamp_column,
)
)

Expand Down
17 changes: 17 additions & 0 deletions tests/test_models/test_inference/conftest.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we don't add this code to tests/conftest.py ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, it seems specific for inference tests. I'll think about moving it higher.

Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pandas as pd
import pytest

from etna.datasets import TSDataset
from etna.datasets import generate_ar_df


@pytest.fixture
def ts_with_external_timestamp() -> TSDataset:
df = generate_ar_df(periods=100, start_time=10, n_segments=2, freq=None)
df_wide = TSDataset.to_dataset(df)
df_exog = generate_ar_df(periods=100, start_time=10, n_segments=2, freq=None)
df_exog["target"] = pd.date_range(start="2020-01-01", periods=100).tolist() * 2
df_exog_wide = TSDataset.to_dataset(df_exog)
df_exog_wide.rename(columns={"target": "external_timestamp"}, level="feature", inplace=True)
ts = TSDataset(df=df_wide.iloc[:-10], df_exog=df_exog_wide, known_future="all", freq=None)
return ts