Skip to content

Commit

Permalink
[ENH] NeuralForecastRNN should auto-detect freq (sktime#6003)
Browse files Browse the repository at this point in the history
Enhances `NeuralForecastRNN` to interpret `freq` from `ForecastingHorizon` when passed as `"auto"`
  • Loading branch information
geetu040 committed Mar 1, 2024
1 parent 80001b9 commit ad840de
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 4 deletions.
9 changes: 9 additions & 0 deletions .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -2604,6 +2604,15 @@
"bug",
"code"
]
},
{
"login": "geetu040",
"name": "Armaghan",
"avatar_url": "https://avatars.githubusercontent.com/u/90601662?s=96&v=4",
"profile": "https://github.com/geetu040",
"contributions": [
"code"
]
}
]
}
16 changes: 14 additions & 2 deletions sktime/forecasting/base/adapters/_neuralforecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ class _NeuralForecastAdapter(BaseForecaster):
Parameters
----------
freq : str
freq : str (default="auto")
frequency of the data, see available frequencies [1]_ from ``pandas``
default ("auto") interprets freq from ForecastingHorizon in ``fit``
local_scaler_type : str (default=None)
scaler to apply per-series to all features before fitting, which is inverted
after predicting
Expand Down Expand Up @@ -66,7 +68,7 @@ class _NeuralForecastAdapter(BaseForecaster):

def __init__(
self: "_NeuralForecastAdapter",
freq: str,
freq: str = "auto",
local_scaler_type: typing.Optional[
typing.Literal["standard", "robust", "robust-iqr", "minmax", "boxcox"]
] = None,
Expand Down Expand Up @@ -179,6 +181,16 @@ def _fit(
if not fh.is_all_out_of_sample(cutoff=self.cutoff):
raise NotImplementedError("in-sample prediction is currently not supported")

if self.freq == "auto":
if fh.freq:
# interpret freq from ForecastingHorizon
self.freq = fh.freq
else:
# when freq is not interpreted from ForecastingHorizon
raise ValueError(
"Could not interpret freq, try passing freq in model initialization"
)

train_indices = y.index
if isinstance(train_indices, pandas.PeriodIndex):
train_indices = train_indices.to_timestamp(freq=self.freq)
Expand Down
6 changes: 4 additions & 2 deletions sktime/forecasting/neuralforecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ class NeuralForecastRNN(_NeuralForecastAdapter):
Parameters
----------
freq : str
freq : str (default="auto")
frequency of the data, see available frequencies [4]_ from ``pandas``
default ("auto") interprets freq from ForecastingHorizon in ``fit``
local_scaler_type : str (default=None)
scaler to apply per-series to all features before fitting, which is inverted
after predicting
Expand Down Expand Up @@ -156,7 +158,7 @@ class NeuralForecastRNN(_NeuralForecastAdapter):

def __init__(
self: "NeuralForecastRNN",
freq: str,
freq: str = "auto",
local_scaler_type: typing.Optional[
typing.Literal["standard", "robust", "robust-iqr", "minmax", "boxcox"]
] = None,
Expand Down

0 comments on commit ad840de

Please sign in to comment.