Add feedback for forecasting (#1258)
Provide feedback when the `Forecasting` UDF is called, in the following ways:

- [x] Reporting confidence intervals
- [x] Returning a metric for the forecasting performance.
- [x] Providing suggestions in simple special cases, such as flat
predictions.

E.g.:
```sql
SELECT HomeForecast();
```

```
SUGGESTION: Predictions are flat. Consider using LIBRARY 'neuralforecast' for more accurate predictions.
```

Partially fixes #1257 and #1243.

---------

Co-authored-by: Andy Xu <xzdandy@gmail.com>
americast and xzdandy committed Nov 16, 2023
1 parent 2575f4f commit 69b39b8
Showing 6 changed files with 269 additions and 24 deletions.
22 changes: 17 additions & 5 deletions docs/source/reference/ai/model-forecasting.rst
@@ -3,6 +3,7 @@
Time Series Forecasting
========================

A time series is a sequence of data points recorded at successive time intervals. Time series forecasting involves estimating future values of a time series by analyzing historical data.
You can train a forecasting model easily in EvaDB.

.. note::
@@ -28,15 +29,14 @@ Next, we create a function of `TYPE Forecasting`. We must enter the column name
CREATE FUNCTION IF NOT EXISTS Forecast FROM
(SELECT y FROM AirData)
TYPE Forecasting
HORIZON 12
PREDICT 'y';
This trains a forecasting model. Here, the `HORIZON` is `12`, which means the model forecasts 12 steps into the future. The model can then be called to generate the forecast.

.. code-block:: sql

   SELECT Forecast();
Forecast Parameters
-------------------
@@ -61,10 +61,22 @@ EvaDB's default forecast framework is `statsforecast <https://nixtla.github.io/s
- If LIBRARY is `statsforecast`, we can select one of ARIMA, CES, ETS, Theta. The default is ARIMA. Check `Automatic Forecasting <https://nixtla.mintlify.app/statsforecast/index.html#automatic-forecasting>`_ to learn details about these models. If LIBRARY is `neuralforecast`, we can select one of NHITS or NBEATS. The default is NBEATS. Check `NBEATS docs <https://nixtla.github.io/neuralforecast/models.nbeats.html>`_ for details.
* - AUTO (str, default: 'T')
- If set to 'T', it enables automatic hyperparameter optimization. It must be 'T' when LIBRARY is `statsforecast`. You may set it to 'F' when LIBRARY is `neuralforecast` for faster (but less reliable) results.
* - CONF (int, default: 90)
- Sets the confidence level, in percent, for the forecast interval. Must be a number between 0 and 100. The lower and upper bounds of the confidence interval are returned in two separate columns, named after the PREDICT column with `-lo` and `-hi` suffixes.
* - FREQUENCY (str, default: 'auto')
- A string indicating the frequency of the data. The commonly used ones are D, W, M, and Y, which respectively represent day-, week-, month-, and year-end frequency. Check `pandas available frequencies <https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases>`_ for all available frequencies. If it is not provided, EvaDB attempts to determine the frequency automatically.
* - METRICS (str, default: 'True')
- Computes NRMSE by performing cross-validation (see the formula below this table). It defaults to `False` when `LIBRARY` is `neuralforecast`, as cross-validation can take a very long time there. The metrics are logged locally.

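As a sketch of how this commit computes the metric: the reported value is each series' cross-validation RMSE, normalized by that series' mean, averaged over all unique series :math:`U`:

.. math::

   \mathrm{NRMSE} = \frac{1}{|U|} \sum_{u \in U} \frac{\sqrt{\frac{1}{n}\sum_{t=1}^{n} \left(\hat{y}_{u,t} - y_{u,t}\right)^2}}{\bar{y}_u}
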
.. note::

Any columns beyond the required ones mentioned above that are passed while creating the function are treated as exogenous variables if LIBRARY is `neuralforecast`. Otherwise, they are ignored.


.. note::

The `Forecasting` function also logs suggestions. Logged information, such as metrics and suggestions, is sent to STDOUT by default. To suppress this output, pass `FALSE` as an optional argument when calling the function, e.g., `SELECT Forecast(FALSE);`


Below is an example query specifying the above parameters:

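(A sketch assuming the `AirData` table from the earlier example; the function name and parameter values are illustrative.)

.. code-block:: sql

   CREATE FUNCTION IF NOT EXISTS AirForecast FROM
   (SELECT y FROM AirData)
   TYPE Forecasting
   HORIZON 12
   PREDICT 'y'
   LIBRARY 'statsforecast'
   FREQUENCY 'M'
   CONF 95
   METRICS 'True';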
17 changes: 17 additions & 0 deletions evadb/binder/statement_binder.py
@@ -89,6 +89,7 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement):
for column in all_column_list:
if column.name in predict_columns:
column.name = column.name + "_predictions"

outputs.append(column)
else:
inputs.append(column)
@@ -122,6 +123,22 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement):
assert (
len(required_columns) == 0
), f"Missing required {required_columns} columns for forecasting function."
outputs.extend(
[
ColumnDefinition(
arg_map.get("predict", "y") + "-lo",
ColumnType.FLOAT,
None,
None,
),
ColumnDefinition(
arg_map.get("predict", "y") + "-hi",
ColumnType.FLOAT,
None,
None,
),
]
)
else:
raise BinderError(
f"Unsupported type of function: {node.function_type}."
127 changes: 117 additions & 10 deletions evadb/executor/create_function_executor.py
@@ -22,6 +22,7 @@
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd

from evadb.catalog.catalog_utils import get_metadata_properties
@@ -55,6 +56,10 @@
from evadb.utils.logging_manager import logger


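# Plain RMSE helper; callers divide the result by the series mean to report NRMSE.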
def root_mean_squared_error(y_true, y_pred):
return np.sqrt(np.mean(np.square(y_pred - y_true)))


# From https://stackoverflow.com/a/34333710
@contextlib.contextmanager
def set_env(**environ):
@@ -354,6 +359,20 @@ def handle_forecasting_function(self):
aggregated_batch.rename(columns={arg_map["time"]: "ds"})
if "id" in arg_map.keys():
aggregated_batch.rename(columns={arg_map["id"]: "unique_id"})
if "conf" in arg_map.keys():
try:
conf = round(arg_map["conf"])
except Exception:
err_msg = "Confidence must be a number."
logger.error(err_msg)
raise FunctionIODefinitionError(err_msg)
else:
conf = 90

if conf < 0 or conf > 100:
err_msg = "Confidence must be a number between 0 and 100."
logger.error(err_msg)
raise FunctionIODefinitionError(err_msg)

data = aggregated_batch.frames
if "unique_id" not in list(data.columns):
@@ -396,18 +415,51 @@
if library == "neuralforecast":
try_to_import_neuralforecast()
from neuralforecast import NeuralForecast
from neuralforecast.auto import (
AutoDeepAR,
AutoFEDformer,
AutoInformer,
AutoNBEATS,
AutoNHITS,
AutoPatchTST,
AutoTFT,
)

# from neuralforecast.auto import AutoAutoformer as AutoAFormer
from neuralforecast.losses.pytorch import MQLoss
from neuralforecast.models import (
NBEATS,
NHITS,
TFT,
DeepAR,
FEDformer,
Informer,
PatchTST,
)

# from neuralforecast.models import Autoformer as AFormer

model_dict = {
"AutoNBEATS": AutoNBEATS,
"AutoNHITS": AutoNHITS,
"NBEATS": NBEATS,
"NHITS": NHITS,
"PatchTST": PatchTST,
"AutoPatchTST": AutoPatchTST,
"DeepAR": DeepAR,
"AutoDeepAR": AutoDeepAR,
"FEDformer": FEDformer,
"AutoFEDformer": AutoFEDformer,
# "AFormer": AFormer,
# "AutoAFormer": AutoAFormer,
"Informer": Informer,
"AutoInformer": AutoInformer,
"TFT": TFT,
"AutoTFT": AutoTFT,
}

if "model" not in arg_map.keys():
arg_map["model"] = "NBEATS"
arg_map["model"] = "TFT"

if "auto" not in arg_map.keys() or (
arg_map["auto"].lower()[0] == "t"
@@ -441,13 +493,16 @@
else:
model_args_config["hist_exog_list"] = exogenous_columns

if "auto" in arg_map["model"].lower():

    def get_optuna_config(trial):
        return model_args_config

    model_args["config"] = get_optuna_config
    model_args["backend"] = "optuna"

model_args["h"] = horizon
model_args["loss"] = MQLoss(level=[conf])

model = NeuralForecast(
[model_here(**model_args)],
@@ -492,7 +547,11 @@ def get_optuna_config(trial):

data["ds"] = pd.to_datetime(data["ds"])

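# The confidence level is part of the cache key only for neuralforecast, where
# it is baked into training via MQLoss; statsforecast applies it at predict time.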
model_save_dir_name = (
library + "_" + arg_map["model"] + "_" + new_freq
if "statsforecast" in library
else library + "_" + str(conf) + "_" + arg_map["model"] + "_" + new_freq
)
if len(data.columns) >= 4 and library == "neuralforecast":
model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns))

@@ -524,6 +583,7 @@ def get_optuna_config(trial):
data[column] = data.apply(
lambda x: self.convert_to_numeric(x[column]), axis=1
)
rmses = []
if library == "neuralforecast":
cuda_devices_here = "0"
if "CUDA_VISIBLE_DEVICES" in os.environ:
@@ -532,6 +592,26 @@
with set_env(CUDA_VISIBLE_DEVICES=cuda_devices_here):
model.fit(df=data, val_size=horizon)
model.save(model_path, overwrite=True)
if "metrics" in arg_map and arg_map["metrics"].lower()[0] == "t":
crossvalidation_df = model.cross_validation(
df=data, val_size=horizon
)
for uid in crossvalidation_df.unique_id.unique():
crossvalidation_df_here = crossvalidation_df[
crossvalidation_df.unique_id == uid
]
rmses.append(
root_mean_squared_error(
crossvalidation_df_here.y,
crossvalidation_df_here[
arg_map["model"] + "-median"
],
)
/ np.mean(crossvalidation_df_here.y)
)
mean_rmse = np.mean(rmses)
with open(model_path + "_rmse", "w") as f:
f.write(str(mean_rmse) + "\n")
else:
# The following lines of code help eliminate the math error encountered in statsforecast when only one datapoint is available in a time series
for col in data["unique_id"].unique():
@@ -541,14 +621,40 @@
)

model.fit(df=data[["ds", "y", "unique_id"]])
hypers = ""
if "arima" in arg_map["model"].lower():
from statsforecast.arima import arima_string

hypers += arima_string(model.fitted_[0, 0].model_)
with open(model_path, "wb") as f:
pickle.dump(model, f)
if "metrics" not in arg_map or arg_map["metrics"].lower()[0] == "t":
crossvalidation_df = model.cross_validation(
df=data[["ds", "y", "unique_id"]],
h=horizon,
step_size=24,
n_windows=1,
).reset_index()
for uid in crossvalidation_df.unique_id.unique():
crossvalidation_df_here = crossvalidation_df[
crossvalidation_df.unique_id == uid
]
rmses.append(
root_mean_squared_error(
crossvalidation_df_here.y,
crossvalidation_df_here[arg_map["model"]],
)
/ np.mean(crossvalidation_df_here.y)
)
mean_rmse = np.mean(rmses)
with open(model_path + "_rmse", "w") as f:
f.write(str(mean_rmse) + "\n")
f.write(hypers + "\n")
elif not Path(model_path).exists():
model_path = os.path.join(model_dir, existing_model_files[-1])

io_list = self._resolve_function_io(None)

data["ds"] = data.ds.astype(str)
metadata_here = [
FunctionMetadataCatalogEntry("model_name", arg_map["model"]),
FunctionMetadataCatalogEntry("model_path", model_path),
Expand All @@ -563,6 +669,7 @@ def get_optuna_config(trial):
),
FunctionMetadataCatalogEntry("horizon", horizon),
FunctionMetadataCatalogEntry("library", library),
FunctionMetadataCatalogEntry("conf", conf),
]

return (
61 changes: 57 additions & 4 deletions evadb/functions/forecast.py
@@ -14,6 +14,7 @@
# limitations under the License.


import os
import pickle

import pandas as pd
@@ -37,6 +38,7 @@ def setup(
id_column_rename: str,
horizon: int,
library: str,
conf: int,
):
self.library = library
if "neuralforecast" in self.library:
@@ -53,18 +55,69 @@
self.time_column_rename = time_column_rename
self.id_column_rename = id_column_rename
self.horizon = int(horizon)
self.suggestion_dict = {
1: "Predictions are flat. Consider using LIBRARY 'neuralforecast' for more accrate predictions.",
}
self.conf = conf
self.hypers = None
self.rmse = None
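# The executor writes the mean NRMSE (and, for ARIMA, the fitted hyperparameters)
# to "<model_path>_rmse" at training time; read it back here for feedback logging.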
if os.path.isfile(model_path + "_rmse"):
with open(model_path + "_rmse", "r") as f:
self.rmse = float(f.readline())
if "arima" in model_name.lower():
self.hypers = "p,d,q: " + f.readline()

def forward(self, data) -> pd.DataFrame:
log_str = ""
if self.library == "statsforecast":
forecast_df = self.model.predict(
h=self.horizon, level=[self.conf]
).reset_index()
else:
forecast_df = self.model.predict().reset_index()

# Feedback
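# Logged when called with no argument or with TRUE; `SELECT Forecast(FALSE)` suppresses it.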
if len(data) == 0 or list(list(data.iloc[0]))[0] is True:
# Suggestions
suggestion_list = []
# 1: Flat predictions
if self.library == "statsforecast":
for type_here in forecast_df["unique_id"].unique():
if (
forecast_df.loc[forecast_df["unique_id"] == type_here][
self.model_name
].nunique()
== 1
):
suggestion_list.append(1)

for suggestion in set(suggestion_list):
log_str += "\nSUGGESTION: " + self.suggestion_dict[suggestion]

# Metrics
if self.rmse is not None:
log_str += "\nMean normalized RMSE: " + str(self.rmse)
if self.hypers is not None:
log_str += "\nHyperparameters: " + self.hypers

print(log_str)

forecast_df = forecast_df.rename(
columns={
"unique_id": self.id_column_rename,
"ds": self.time_column_rename,
self.model_name
if self.library == "statsforecast"
else self.model_name + "-median": self.predict_column_rename,
self.model_name
+ "-lo-"
+ str(self.conf): self.predict_column_rename
+ "-lo",
self.model_name
+ "-hi-"
+ str(self.conf): self.predict_column_rename
+ "-hi",
}
)[: self.horizon * forecast_df["unique_id"].nunique()]
return forecast_df
