Skip to content

Commit

Permalink
fix the unit tests
Browse files — browse the repository at this point in the history
  • Loading branch information
codeloop committed Apr 29, 2024
1 parent 86d5790 commit 72db865
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 20 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -86,7 +86,10 @@ def create_horizon(self, spec, historical_data):
pd.date_range(
start=historical_data.get_max_time(),
periods=spec.horizon + 1,
freq=historical_data.freq,
freq=historical_data.freq
or pd.infer_freq(
historical_data.data.reset_index()[spec.datetime_column.name][-5:]
),
),
name=spec.datetime_column.name,
)
Expand Down
57 changes: 40 additions & 17 deletions ads/opctl/operator/lowcode/forecast/model/ml_forecast.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -73,20 +73,35 @@ def _train_model(self, data_train, data_test, model_kwargs):
alpha=model_kwargs["lower_quantile"],
),
},
freq=pd.infer_freq(data_train.Date.drop_duplicates()),
freq=pd.infer_freq(data_train["Date"].drop_duplicates())
or pd.infer_freq(data_train["Date"].drop_duplicates()[-5:]),
target_transforms=[Differences([12])],
lags=model_kwargs.get("lags", [1, 6, 12]),
lag_transforms={
1: [ExpandingMean()],
12: [RollingMean(window_size=24)],
},
lags=model_kwargs.get(
"lags",
(
[1, 6, 12]
if len(self.datasets.get_additional_data_column_names()) > 0
else []
),
),
lag_transforms=(
{
1: [ExpandingMean()],
12: [RollingMean(window_size=24)],
}
if len(self.datasets.get_additional_data_column_names()) > 0
else {}
),
# date_features=[hour_index],
)

num_models = model_kwargs.get("recursive_models", False)

self.model_columns = [
ForecastOutputColumns.SERIES
] + data_train.select_dtypes(exclude=["object"]).columns.to_list()
fcst.fit(
data_train,
data_train[self.model_columns],
static_features=model_kwargs.get("static_features", []),
id_col=ForecastOutputColumns.SERIES,
time_col=self.spec.datetime_column.name,
Expand All @@ -99,8 +114,10 @@ def _train_model(self, data_train, data_test, model_kwargs):
h=self.spec.horizon,
X_df=pd.concat(
[
data_test,
fcst.get_missing_future(h=self.spec.horizon, X_df=data_test),
data_test[self.model_columns],
fcst.get_missing_future(
h=self.spec.horizon, X_df=data_test[self.model_columns]
),
],
axis=0,
ignore_index=True,
Expand Down Expand Up @@ -166,12 +183,16 @@ def _generate_report(self):
# Section 1: Forecast Overview
sec1_text = rc.Block(
rc.Heading("Forecast Overview", level=2),
rc.Text("These plots show your forecast in the context of historical data.")
rc.Text(
"These plots show your forecast in the context of historical data."
),
)
sec_1 = _select_plot_list(
lambda s_id: plot_series(
self.datasets.get_all_data_long(include_horizon=False),
pd.concat([self.fitted_values,self.outputs], axis=0, ignore_index=True),
pd.concat(
[self.fitted_values, self.outputs], axis=0, ignore_index=True
),
id_col=ForecastOutputColumns.SERIES,
time_col=self.spec.datetime_column.name,
target_col=self.original_target_column,
Expand All @@ -184,7 +205,7 @@ def _generate_report(self):
# Section 2: MlForecast Model Parameters
sec2_text = rc.Block(
rc.Heading("MlForecast Model Parameters", level=2),
rc.Text("These are the parameters used for the MlForecast model.")
rc.Text("These are the parameters used for the MlForecast model."),
)

blocks = [
Expand All @@ -197,9 +218,11 @@ def _generate_report(self):
sec_2 = rc.Select(blocks=blocks)

all_sections = [sec1_text, sec_1, sec2_text, sec_2]
model_description = rc.Text("mlforecast is a framework to perform time series forecasting using machine learning models"
"with the option to scale to massive amounts of data using remote clusters."
"Fastest implementations of feature engineering for time series forecasting in Python."
"Support for exogenous variables and static covariates.")
model_description = rc.Text(
"mlforecast is a framework to perform time series forecasting using machine learning models"
"with the option to scale to massive amounts of data using remote clusters."
"Fastest implementations of feature engineering for time series forecasting in Python."
"Support for exogenous variables and static covariates."
)

return model_description, all_sections
return model_description, all_sections
5 changes: 4 additions & 1 deletion tests/operators/forecast/test_datasets.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -137,7 +137,10 @@ def test_load_datasets(model, data_details):

run(yaml_i, backend="operator.local", debug=False)
subprocess.run(f"ls -a {output_data_path}", shell=True)
if yaml_i["spec"]["generate_explanations"] and model != "automlx":
if yaml_i["spec"]["generate_explanations"] and model not in [
"automlx",
"mlforecast",
]:
verify_explanations(
tmpdirname=tmpdirname,
additional_cols=additional_cols,
Expand Down
2 changes: 1 addition & 1 deletion tests/operators/forecast/test_errors.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -687,7 +687,7 @@ def test_arima_automlx_errors(operator_setup, model):
in error_content["13"]["error"]
), "Error message mismatch"

if model not in ["autots", "automlx"]:
if model not in ["autots", "automlx", "mlforecast"]:
global_fn = f"{tmpdirname}/results/global_explanation.csv"
assert os.path.exists(
global_fn
Expand Down

0 comments on commit 72db865

Please sign in to comment.