diff --git a/python/chronos/example/auto_model/autoprophet_nyc_taxi.py b/python/chronos/example/auto_model/autoprophet_nyc_taxi.py index 99d80f7428d..1501796c1bb 100644 --- a/python/chronos/example/auto_model/autoprophet_nyc_taxi.py +++ b/python/chronos/example/auto_model/autoprophet_nyc_taxi.py @@ -79,6 +79,10 @@ def get_data(args): autoprophet_fit_time = time.time() - start_time stop_orca_context() + # save and load + autoprophet.save("autoprophet.ckpt") + autoprophet = AutoProphet(load_dir="autoprophet.ckpt") + # evaluate auto_searched_mse = autoprophet.evaluate(df_test, metrics=['mse'])[0] nonauto_searched_mse = prophet.evaluate(df_test, metrics=['mse'])[0] diff --git a/python/chronos/src/bigdl/chronos/autots/model/auto_arima.py b/python/chronos/src/bigdl/chronos/autots/model/auto_arima.py index 596ad525c39..6252ed14ea6 100644 --- a/python/chronos/src/bigdl/chronos/autots/model/auto_arima.py +++ b/python/chronos/src/bigdl/chronos/autots/model/auto_arima.py @@ -16,12 +16,12 @@ # limitations under the License. # -from bigdl.orca.automl.auto_estimator import AutoEstimator -from bigdl.chronos.model.arima import ARIMABuilder - +import warnings +from bigdl.chronos.model.arima import ARIMABuilder, ARIMAModel # - + class AutoARIMA: def __init__(self, @@ -36,6 +36,7 @@ def __init__(self, cpus_per_trial=1, name="auto_arima", remote_dir=None, + load_dir=None, **arima_config ): """ @@ -75,22 +76,29 @@ def __init__(self, :param arima_config: Other ARIMA hyperparameters. 
""" - self.search_space = { - "p": p, - "q": q, - "seasonal": seasonal, - "P": P, - "Q": Q, - "m": m, - } - self.metric = metric - model_builder = ARIMABuilder() - self.auto_est = AutoEstimator(model_builder=model_builder, - logs_dir=logs_dir, - resources_per_trial={ - "cpu": cpus_per_trial}, - remote_dir=remote_dir, - name=name) + if load_dir: + self.best_model = ARIMAModel() + self.best_model.restore(load_dir) + try: + from bigdl.orca.automl.auto_estimator import AutoEstimator + self.search_space = { + "p": p, + "q": q, + "seasonal": seasonal, + "P": P, + "Q": Q, + "m": m, + } + self.metric = metric + model_builder = ARIMABuilder() + self.auto_est = AutoEstimator(model_builder=model_builder, + logs_dir=logs_dir, + resources_per_trial={ + "cpu": cpus_per_trial}, + remote_dir=remote_dir, + name=name) + except ImportError: + warnings.warn("You need to install `bigdl-orca[automl]` to use `fit` function.") def fit(self, data, diff --git a/python/chronos/src/bigdl/chronos/autots/model/auto_prophet.py b/python/chronos/src/bigdl/chronos/autots/model/auto_prophet.py index 9f6ea695392..a09abc24682 100644 --- a/python/chronos/src/bigdl/chronos/autots/model/auto_prophet.py +++ b/python/chronos/src/bigdl/chronos/autots/model/auto_prophet.py @@ -17,10 +17,8 @@ # import pandas as pd - -from bigdl.orca.automl.auto_estimator import AutoEstimator -import bigdl.orca.automl.hp as hp -from bigdl.chronos.model.prophet import ProphetBuilder +import warnings +from bigdl.chronos.model.prophet import ProphetBuilder, ProphetModel # - @@ -28,16 +26,17 @@ class AutoProphet: def __init__(self, - changepoint_prior_scale=hp.grid_search([0.005, 0.05, 0.1, 0.5]), - seasonality_prior_scale=hp.grid_search([0.01, 0.1, 1.0, 10.0]), - holidays_prior_scale=hp.loguniform(0.01, 10), - seasonality_mode=hp.choice(['additive', 'multiplicative']), - changepoint_range=hp.uniform(0.8, 0.95), + changepoint_prior_scale=None, + seasonality_prior_scale=None, + holidays_prior_scale=None, + seasonality_mode=None, + 
changepoint_range=None, metric='mse', logs_dir="/tmp/auto_prophet_logs", cpus_per_trial=1, name="auto_prophet", remote_dir=None, + load_dir=None, **prophet_config ): """ @@ -70,24 +69,43 @@ def __init__(self, :param remote_dir: String. Remote directory to sync training results and checkpoints. It defaults to None and doesn't take effects while running in local. While running in cluster, it defaults to "hdfs:///tmp/{name}". + :param load_dir: Load the ckpt from load_dir. The value defaults to None. :param prophet_config: Other Prophet hyperparameters. """ - self.search_space = { - "changepoint_prior_scale": changepoint_prior_scale, - "seasonality_prior_scale": seasonality_prior_scale, - "holidays_prior_scale": holidays_prior_scale, - "seasonality_mode": seasonality_mode, - "changepoint_range": changepoint_range - } - self.search_space.update(prophet_config) # update other configs - self.metric = metric - model_builder = ProphetBuilder() - self.auto_est = AutoEstimator(model_builder=model_builder, - logs_dir=logs_dir, - resources_per_trial={"cpu": cpus_per_trial}, - remote_dir=remote_dir, - name=name) + if load_dir: + self.best_model = ProphetModel() + self.best_model.restore(load_dir) + try: + from bigdl.orca.automl.auto_estimator import AutoEstimator + import bigdl.orca.automl.hp as hp + self.search_space = { + "changepoint_prior_scale": hp.grid_search([0.005, 0.05, 0.1, 0.5]) + if changepoint_prior_scale is None + else changepoint_prior_scale, + "seasonality_prior_scale": hp.grid_search([0.01, 0.1, 1.0, 10.0]) + if seasonality_prior_scale is None + else seasonality_prior_scale, + "holidays_prior_scale": hp.loguniform(0.01, 10) + if holidays_prior_scale is None + else holidays_prior_scale, + "seasonality_mode": hp.choice(['additive', 'multiplicative']) + if seasonality_mode is None + else seasonality_mode, + "changepoint_range": hp.uniform(0.8, 0.95) + if changepoint_range is None + else changepoint_range + } + self.search_space.update(prophet_config) # update 
other configs + self.metric = metric + model_builder = ProphetBuilder() + self.auto_est = AutoEstimator(model_builder=model_builder, + logs_dir=logs_dir, + resources_per_trial={"cpu": cpus_per_trial}, + remote_dir=remote_dir, + name=name) + except ImportError: + warnings.warn("You need to install `bigdl-orca[automl]` to use `fit` function.") def fit(self, data, diff --git a/python/chronos/src/bigdl/chronos/forecaster/arima_forecaster.py b/python/chronos/src/bigdl/chronos/forecaster/arima_forecaster.py index 3c4dcb384cf..89d9ed8470d 100644 --- a/python/chronos/src/bigdl/chronos/forecaster/arima_forecaster.py +++ b/python/chronos/src/bigdl/chronos/forecaster/arima_forecaster.py @@ -77,7 +77,7 @@ def fit(self, data, validation_data): """ self._check_data(data, validation_data) data = data.reshape(-1, 1) - validation_data = validation_data.reshape(-1, 1) + # validation_data = validation_data.reshape(-1, 1) return self.internal.fit_eval(data=data, validation_data=validation_data, **self.model_config) diff --git a/python/chronos/src/bigdl/chronos/model/arima.py b/python/chronos/src/bigdl/chronos/model/arima.py index 19de284695a..db51f49a6cb 100644 --- a/python/chronos/src/bigdl/chronos/model/arima.py +++ b/python/chronos/src/bigdl/chronos/model/arima.py @@ -19,11 +19,10 @@ from pmdarima.arima import ndiffs from pmdarima.arima import nsdiffs -from bigdl.orca.automl.metrics import Evaluator -from bigdl.orca.automl.model.abstract import BaseModel, ModelBuilder +from bigdl.chronos.metric.forecast_metrics import Evaluator -class ARIMAModel(BaseModel): +class ARIMAModel: def __init__(self): """ @@ -138,7 +137,7 @@ def evaluate(self, target, x=None, metrics=['mse'], rolling=False): forecasts = self.predict(horizon=len(target), rolling=rolling) - return [Evaluator.evaluate(m, target, forecasts) for m in metrics] + return Evaluator.evaluate(metrics, target, forecasts, aggregate="mean") def save(self, checkpoint_file): if self.model is None: @@ -152,7 +151,7 @@ def restore(self, 
checkpoint_file): self.model_init = True -class ARIMABuilder(ModelBuilder): +class ARIMABuilder: def __init__(self, **arima_config): """ diff --git a/python/chronos/src/bigdl/chronos/model/prophet.py b/python/chronos/src/bigdl/chronos/model/prophet.py index 3efc1c1806a..0b7ed673895 100644 --- a/python/chronos/src/bigdl/chronos/model/prophet.py +++ b/python/chronos/src/bigdl/chronos/model/prophet.py @@ -20,11 +20,10 @@ from prophet.diagnostics import performance_metrics from prophet.diagnostics import cross_validation -from bigdl.orca.automl.metrics import Evaluator -from bigdl.orca.automl.model.abstract import BaseModel, ModelBuilder +from bigdl.chronos.metric.forecast_metrics import Evaluator -class ProphetModel(BaseModel): +class ProphetModel: def __init__(self): """ @@ -129,7 +128,8 @@ def evaluate(self, target, data=None, metrics=['mse']): raise Exception("Needs to call fit_eval or restore first before calling evaluate") target_pred = self.model.predict(target) - return [Evaluator.evaluate(m, target.y.values, target_pred.yhat.values) for m in metrics] + return Evaluator.evaluate(metrics, target.y.values, + target_pred.yhat.values, aggregate="mean") def save(self, checkpoint): if self.model is None: @@ -143,7 +143,7 @@ def restore(self, checkpoint): self.model_init = True -class ProphetBuilder(ModelBuilder): +class ProphetBuilder: def __init__(self, **prophet_config): """ diff --git a/python/chronos/test/bigdl/chronos/autots/model/test_auto_prophet.py b/python/chronos/test/bigdl/chronos/autots/model/test_auto_prophet.py index d7d15add17d..74ed5608ee6 100644 --- a/python/chronos/test/bigdl/chronos/autots/model/test_auto_prophet.py +++ b/python/chronos/test/bigdl/chronos/autots/model/test_auto_prophet.py @@ -48,7 +48,7 @@ def test_auto_prophet_fit(self): seasonality_prior_scale=hp.loguniform(0.01, 10), holidays_prior_scale=hp.loguniform(0.01, 10), seasonality_mode=hp.choice(['additive', 'multiplicative']), - changepoint_range=hp.uniform(0.8, 0.95) + 
changepoint_range=hp.uniform(0.8, 0.95), ) auto_prophet.fit(data=data, @@ -101,4 +101,10 @@ def test_auto_prophet_save_load(self): with tempfile.TemporaryDirectory() as tmp_dir_name: ckpt_name = os.path.join(tmp_dir_name, "json") auto_prophet.save(ckpt_name) + pred = auto_prophet.predict(horizon=10, freq="D") auto_prophet.restore(ckpt_name) + pred_old = auto_prophet.predict(horizon=10, freq="D") + new_auto_prophet = AutoProphet(load_dir=ckpt_name) + pred_new = new_auto_prophet.predict(horizon=10, freq="D") + np.testing.assert_almost_equal(pred.yhat.values, pred_new.yhat.values) + np.testing.assert_almost_equal(pred.yhat.values, pred_old.yhat.values) diff --git a/python/orca/src/bigdl/orca/automl/search/ray_tune/ray_tune_search_engine.py b/python/orca/src/bigdl/orca/automl/search/ray_tune/ray_tune_search_engine.py index 09c9266fdf1..f43ef7b4e08 100644 --- a/python/orca/src/bigdl/orca/automl/search/ray_tune/ray_tune_search_engine.py +++ b/python/orca/src/bigdl/orca/automl/search/ray_tune/ray_tune_search_engine.py @@ -290,8 +290,9 @@ def train_func(config): train_data = ray.get(data_id) val_data = ray.get(validation_data_id) config = convert_bayes_configs(config).copy() - if not isinstance(model_builder, ModelBuilder): - raise ValueError(f"You must input a ModelBuilder instance for model_builder") + # This check is turned off to support duck typing + # if not isinstance(model_builder, ModelBuilder): + # raise ValueError(f"You must input a ModelBuilder instance for model_builder") trial_model = model_builder.build(config) # no need to call build since it is called the first time fit_eval is called.