fix exogenous for auto; made default
americast committed Sep 30, 2023
1 parent b422000 commit e176bd4
Showing 3 changed files with 31 additions and 29 deletions.
13 changes: 7 additions & 6 deletions docs/source/reference/ai/model-forecasting.rst
@@ -57,14 +57,14 @@ EvaDB's default forecast framework is `statsforecast <https://nixtla.github.io/s
- The name of column that represents an identifier for the series. If relevant column is not found, the whole table is considered as one series of data.
* - LIBRARY (str, default: 'statsforecast')
- We can select one of `statsforecast` (default) or `neuralforecast`. `statsforecast` provides access to statistical forecasting methods, while `neuralforecast` gives access to deep-learning based forecasting methods.
* - MODEL (str, default: 'AutoARIMA')
- If LIBRARY is `statsforecast`, we can select one of AutoARIMA, AutoCES, AutoETS, AutoTheta. The default is AutoARIMA. Check `Automatic Forecasting <https://nixtla.github.io/statsforecast/src/core/models_intro.html#automatic-forecasting>`_ to learn details about these models. If LIBRARY is `neuralforecast`, we can select one of NHITS or NBEATS. The default is NBEATS. Check `NBEATS docs <https://nixtla.github.io/neuralforecast/models.nbeats.html>`_ for details.
* - AUTO (str, default: 'F')
- The names of columns to be treated as exogenous variables, separated by comma. These columns would be considered for forecasting by the backend only for LIBRARY `neuralforecast`.
* - MODEL (str, default: 'ARIMA')
- If LIBRARY is `statsforecast`, we can select one of ARIMA, CES, ETS, Theta. The default is ARIMA. Check `Automatic Forecasting <https://nixtla.github.io/statsforecast/src/core/models_intro.html#automatic-forecasting>`_ to learn details about these models. If LIBRARY is `neuralforecast`, we can select one of NHITS or NBEATS. The default is NBEATS. Check `NBEATS docs <https://nixtla.github.io/neuralforecast/models.nbeats.html>`_ for details.
* - AUTO (str, default: 'T')
- If set to 'T', it enables automatic hyperparameter optimization. Must be set to 'T' for `statsforecast` library. One may set this parameter to `false` if LIBRARY is `neuralforecast` for faster (but less reliable) results.
* - Frequency (str, default: 'auto')
- A string indicating the frequency of the data. The commonly used ones are D, W, M, Y, which respectively represent day-, week-, month- and year-end frequency. The default value is M. Check `pandas available frequencies <https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases>`_ for all available frequencies. If it is not provided, EvaDB attempts to determine the frequency automatically.

Note: If columns other than the ones required as mentioned above are passed while creating the function, they will be treated as exogenous variables if LIBRARY is `neuralforecast` and the AUTO is set to F. In other situations, they would be ignored.
Note: If columns other than the ones required as mentioned above are passed while creating the function, they will be treated as exogenous variables if LIBRARY is `neuralforecast`. Otherwise, they would be ignored.

Below is an example query specifying the above parameters:

@@ -79,7 +79,7 @@ Below is an example query specifying the above parameters:
ID 'type'
Frequency 'W';
Below is an example query with `neuralforecast` with `trend` column as exogenous:
Below is an example query with `neuralforecast` with `trend` column as exogenous and without automatic hyperparameter optimization:

.. code-block:: sql
@@ -89,4 +89,5 @@ Below is an example query with `neuralforecast` with `trend` column as exogenous
HORIZON 12
PREDICT 'y'
LIBRARY 'neuralforecast'
AUTO 'f'
FREQUENCY 'M';
45 changes: 23 additions & 22 deletions evadb/executor/create_function_executor.py
@@ -290,9 +290,8 @@ def handle_forecasting_function(self):
if "model" not in arg_map.keys():
arg_map["model"] = "NBEATS"

if (
"auto" in arg_map.keys()
and arg_map["auto"].lower()[0] == "t"
if "auto" not in arg_map.keys() or (
arg_map["auto"].lower()[0] == "t"
and "auto" not in arg_map["model"].lower()
):
arg_map["model"] = "Auto" + arg_map["model"]
@@ -307,15 +306,21 @@ def handle_forecasting_function(self):

if "auto" not in arg_map["model"].lower():
model_args["input_size"] = 2 * horizon
if len(data.columns) >= 4:
exogenous_columns = [
x
for x in list(data.columns)
if x not in ["ds", "y", "unique_id"]
]
model_args["hist_exog_list"] = exogenous_columns

model_args["early_stop_patience_steps"] = 20
else:
model_args["config"] = {
"input_size": 2 * horizon,
"early_stop_patience_steps": 20,
}

if len(data.columns) >= 4:
exogenous_columns = [
x for x in list(data.columns) if x not in ["ds", "y", "unique_id"]
]
if "auto" not in arg_map["model"].lower():
model_args["hist_exog_list"] = exogenous_columns
else:
model_args["config"]["hist_exog_list"] = exogenous_columns

model_args["h"] = horizon
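
Read as a whole, this hunk appears to move the exogenous-column handling out of the non-auto branch so that it applies to both branches: plain models receive `hist_exog_list` directly, while `Auto*` models receive it inside `config`. A minimal sketch of that resulting logic follows. The helper name `build_neuralforecast_args` and its signature are hypothetical and not part of EvaDB; the sketch only mirrors the lines shown in the hunk and does not call `neuralforecast`.

```python
def build_neuralforecast_args(model_name, horizon, columns):
    """Hypothetical helper mirroring the argument building shown in the hunk."""
    is_auto = "auto" in model_name.lower()
    model_args = {}

    if not is_auto:
        # Plain models take their hyperparameters as direct keyword arguments.
        model_args["input_size"] = 2 * horizon
        model_args["early_stop_patience_steps"] = 20
    else:
        # Auto* models take the same hyperparameters wrapped in a config dict.
        model_args["config"] = {
            "input_size": 2 * horizon,
            "early_stop_patience_steps": 20,
        }

    # Any column beyond ds, y and unique_id is treated as exogenous.
    if len(columns) >= 4:
        exogenous_columns = [
            c for c in columns if c not in ["ds", "y", "unique_id"]
        ]
        if not is_auto:
            model_args["hist_exog_list"] = exogenous_columns
        else:
            model_args["config"]["hist_exog_list"] = exogenous_columns

    model_args["h"] = horizon
    return model_args


if __name__ == "__main__":
    cols = ["unique_id", "ds", "y", "trend"]
    print(build_neuralforecast_args("NBEATS", 12, cols))
    print(build_neuralforecast_args("AutoNBEATS", 12, cols))
```

With a `trend` column present, both calls carry `["trend"]` as the exogenous list; only its location in the argument dictionary differs.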

@@ -328,6 +333,10 @@ def handle_forecasting_function(self):
# Statsforecast implementation
# """
else:
if "auto" in arg_map.keys() and arg_map["auto"].lower()[0] != "t":
raise RuntimeError(
"Statsforecast implementation only supports automatic hyperparameter optimization. Please set AUTO to true."
)
try_to_import_statsforecast()
from statsforecast import StatsForecast
from statsforecast.models import AutoARIMA, AutoCES, AutoETS, AutoTheta
@@ -340,13 +349,9 @@ def handle_forecasting_function(self):
}

if "model" not in arg_map.keys():
arg_map["model"] = "AutoARIMA"
arg_map["model"] = "ARIMA"

if (
"auto" in arg_map.keys()
and arg_map["auto"].lower()[0] == "t"
and "auto" not in arg_map["model"].lower()
):
if "auto" not in arg_map["model"].lower():
arg_map["model"] = "Auto" + arg_map["model"]

try:
@@ -363,11 +368,7 @@ def handle_forecasting_function(self):
data["ds"] = pd.to_datetime(data["ds"])

model_save_dir_name = library + "_" + arg_map["model"] + "_" + new_freq
if (
len(data.columns) >= 4
and "auto" not in arg_map["model"].lower()
and library == "neuralforecast"
):
if len(data.columns) >= 4 and library == "neuralforecast":
model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns))

model_dir = os.path.join(
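
Taken together, the hunks in this file change how the final model name is derived from `LIBRARY`, `MODEL` and `AUTO`. The sketch below is one reading of the post-commit behaviour, not the actual EvaDB code: `resolve_model_name` is a hypothetical helper, and `arg_map` stands in for the parsed function arguments.

```python
def resolve_model_name(library, arg_map):
    """Hypothetical helper: derive the model name as the new code appears to."""
    auto = arg_map.get("auto")
    model = arg_map.get("model")

    if library == "neuralforecast":
        if model is None:
            model = "NBEATS"
        # AUTO now defaults to true: an omitted AUTO also adds the prefix.
        if auto is None or (
            auto.lower()[0] == "t" and "auto" not in model.lower()
        ):
            model = "Auto" + model
    else:
        # statsforecast only supports automatic hyperparameter optimization.
        if auto is not None and auto.lower()[0] != "t":
            raise RuntimeError(
                "Statsforecast implementation only supports automatic "
                "hyperparameter optimization. Please set AUTO to true."
            )
        if model is None:
            model = "ARIMA"
        if "auto" not in model.lower():
            model = "Auto" + model

    return model


if __name__ == "__main__":
    print(resolve_model_name("statsforecast", {}))
    print(resolve_model_name("neuralforecast", {}))
    print(resolve_model_name("neuralforecast", {"auto": "f"}))
```

One detail worth checking in review: as the `neuralforecast` condition is written, omitting AUTO adds the prefix without testing whether MODEL already starts with `Auto`, so a model passed as `AutoNHITS` with no AUTO argument would apparently be prefixed twice. The sketch reproduces that condition rather than correcting it.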
2 changes: 1 addition & 1 deletion test/integration_tests/long/test_model_forecasting.py
@@ -112,7 +112,7 @@ def test_forecast(self):
SELECT AirPanelForecast() order by y;
"""
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result), 12)
self.assertEqual(len(result), 24)
self.assertEqual(
result.columns,
["airpanelforecast.unique_id", "airpanelforecast.ds", "airpanelforecast.y"],
