Skip to content

Commit

Permalink
Added ability to stop and resume hyperopt / automl runs (#2108)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Jun 8, 2022
1 parent aa496e3 commit 6c26ee9
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 160 deletions.
19 changes: 10 additions & 9 deletions examples/hyperopt/model_hyperopt_example.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion ludwig/automl/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,13 @@ def auto_train(
parameter initialization and training set shuffling
:param use_reference_config: (bool) refine hyperopt search space by setting first
search point from reference model config, if any
:param kwargs: additional keyword args passed down to `ludwig.hyperopt.run.hyperopt`.
# Returns
:return: (AutoTrainResults) results containing hyperopt experiments and best model
"""
config = create_auto_config(
dataset, target, time_limit_s, tune_for_memory, user_config, random_seed, use_reference_config, **kwargs
dataset, target, time_limit_s, tune_for_memory, user_config, random_seed, use_reference_config
)
return train_with_config(dataset, config, output_directory=output_directory, random_seed=random_seed, **kwargs)

Expand Down Expand Up @@ -194,6 +195,7 @@ def train_with_config(
there is a call to a random number generator, including
hyperparameter search sampling, as well as data splitting,
parameter initialization and training set shuffling
:param kwargs: additional keyword args passed down to `ludwig.hyperopt.run.hyperopt`.
# Returns
:return: (AutoTrainResults) results containing hyperopt experiments and best model
Expand Down
7 changes: 7 additions & 0 deletions ludwig/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ def on_hyperopt_trial_end(self, parameters: Dict[str, Any]):
"""
pass

def should_stop_hyperopt(self):
"""Returns true if the entire hyperopt run (all trials) should be stopped.
See: https://docs.ray.io/en/latest/tune/api_docs/stoppers.html#ray.tune.Stopper
"""
return False

def on_train_init(
self,
base_config: Dict[str, Any],
Expand Down
39 changes: 29 additions & 10 deletions ludwig/hyperopt/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

from ludwig.api import LudwigModel
from ludwig.backend import initialize_backend, RAY
Expand All @@ -33,7 +33,7 @@
try:
import ray
from ray import tune
from ray.tune import register_trainable
from ray.tune import register_trainable, Stopper
from ray.tune.suggest import BasicVariantGenerator, ConcurrencyLimiter, SEARCH_ALG_IMPORT
from ray.tune.sync_client import CommandBasedClient
from ray.tune.syncer import get_cloud_sync_client
Expand All @@ -44,6 +44,7 @@
except ImportError as e:
logger.warning(f"ImportError (execution.py) failed to import ray with error: \n\t{e}")
ray = None
Stopper = object
get_horovod_kwargs = None


Expand Down Expand Up @@ -173,8 +174,7 @@ def execute(
data_format=None,
experiment_name="hyperopt",
model_name="run",
model_load_path=None,
model_resume_path=None,
resume=None,
skip_save_training_description=False,
skip_save_training_statistics=False,
skip_save_model=False,
Expand Down Expand Up @@ -552,8 +552,7 @@ def execute(
data_format=None,
experiment_name="hyperopt",
model_name="run",
# model_load_path=None,
# model_resume_path=None,
resume=None,
skip_save_training_description=False,
skip_save_training_statistics=False,
skip_save_model=False,
Expand Down Expand Up @@ -601,8 +600,6 @@ def execute(
data_format=data_format,
experiment_name=experiment_name,
model_name=model_name,
# model_load_path=model_load_path,
# model_resume_path=model_resume_path,
eval_split=self.split,
skip_save_training_description=skip_save_training_description,
skip_save_training_statistics=skip_save_training_statistics,
Expand Down Expand Up @@ -698,9 +695,15 @@ def run_experiment_trial(config, local_hyperopt_dict, checkpoint_dir=None):
run_experiment_trial_params = tune.with_parameters(run_experiment_trial, local_hyperopt_dict=hyperopt_dict)
register_trainable(f"trainable_func_f{hash_dict(config).decode('ascii')}", run_experiment_trial_params)

# Note that resume="AUTO" will attempt to resume the experiment if possible, and
# otherwise will start a new experiment:
# https://docs.ray.io/en/latest/tune/tutorials/tune-stopping.html
should_resume = "AUTO" if resume is None else resume

try:
analysis = tune.run(
f"trainable_func_f{hash_dict(config).decode('ascii')}",
name=experiment_name,
config={
**self.search_space,
**tune_config,
Expand All @@ -719,7 +722,9 @@ def run_experiment_trial(config, local_hyperopt_dict, checkpoint_dir=None):
trial_name_creator=lambda trial: f"trial_{trial.trial_id}",
trial_dirname_creator=lambda trial: f"trial_{trial.trial_id}",
callbacks=tune_callbacks,
stop=CallbackStopper(callbacks),
verbose=hyperopt_log_verbosity,
resume=should_resume,
)
except Exception as e:
# Explicitly raise a RuntimeError if an error is encountered during a Ray trial.
Expand Down Expand Up @@ -777,6 +782,22 @@ def run_experiment_trial(config, local_hyperopt_dict, checkpoint_dir=None):
return RayTuneResults(ordered_trials=ordered_trials, experiment_analysis=analysis)


class CallbackStopper(Stopper):
"""Ray Tune Stopper that triggers the entire job to stop if one callback returns True."""

def __init__(self, callbacks: Optional[List[Callback]]):
self.callbacks = callbacks or []

def __call__(self, trial_id, result):
return False

def stop_all(self):
for callback in self.callbacks:
if callback.should_stop_hyperopt():
return True
return False


def get_build_hyperopt_executor(executor_type):
return get_from_registry(executor_type, executor_registry)

Expand Down Expand Up @@ -833,7 +854,6 @@ def run_experiment(
data_format=None,
experiment_name="hyperopt",
model_name="run",
# model_load_path=None,
model_resume_path=None,
eval_split=VALIDATION,
skip_save_training_description=False,
Expand Down Expand Up @@ -877,7 +897,6 @@ def run_experiment(
data_format=data_format,
experiment_name=experiment_name,
model_name=model_name,
# model_load_path=model_load_path,
model_resume_path=model_resume_path,
eval_split=eval_split,
skip_save_training_description=skip_save_training_description,
Expand Down
28 changes: 16 additions & 12 deletions ludwig/hyperopt/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from pprint import pformat
from typing import List, Union
from typing import List, Optional, Union

import pandas as pd
import yaml
Expand All @@ -26,9 +27,6 @@ class RayBackend:
pass


logger = logging.getLogger(__name__)


def hyperopt(
config: Union[str, dict],
dataset: Union[str, dict, pd.DataFrame] = None,
Expand All @@ -39,6 +37,7 @@ def hyperopt(
data_format: str = None,
experiment_name: str = "hyperopt",
model_name: str = "run",
resume: Optional[bool] = None,
skip_save_training_description: bool = False,
skip_save_training_statistics: bool = False,
skip_save_model: bool = False,
Expand Down Expand Up @@ -93,6 +92,11 @@ def hyperopt(
the experiment.
:param model_name: (str, default: `'run'`) name of the model that is
being used.
:param resume: (bool) If true, continue hyperopt from the state of the previous
run in the output directory with the same experiment name. If false, will create
new trials, ignoring any previous state, even if they exist in the output_directory.
By default, will attempt to resume if there is already an existing experiment with
the same name, and will create new trials if not.
:param skip_save_training_description: (bool, default: `False`) disables
saving the description JSON file.
:param skip_save_training_statistics: (bool, default: `False`) disables
Expand Down Expand Up @@ -176,8 +180,8 @@ def hyperopt(
update_hyperopt_params_with_defaults(hyperopt_config)

# print hyperopt config
logger.info(pformat(hyperopt_config, indent=4))
logger.info("\n")
logging.info(pformat(hyperopt_config, indent=4))
logging.info("\n")

search_alg = hyperopt_config["search_alg"]
executor = hyperopt_config["executor"]
Expand Down Expand Up @@ -310,8 +314,7 @@ def hyperopt(
data_format=data_format,
experiment_name=experiment_name,
model_name=model_name,
# model_load_path=None,
# model_resume_path=None,
resume=resume,
skip_save_training_description=skip_save_training_description,
skip_save_training_statistics=skip_save_training_statistics,
skip_save_model=skip_save_model,
Expand All @@ -336,21 +339,22 @@ def hyperopt(
print_hyperopt_results(hyperopt_results)

if not skip_save_hyperopt_statistics:
makedirs(output_directory, exist_ok=True)
results_directory = os.path.join(output_directory, experiment_name)
makedirs(results_directory, exist_ok=True)

hyperopt_stats = {
"hyperopt_config": hyperopt_config,
"hyperopt_results": [t.to_dict() for t in hyperopt_results.ordered_trials],
}

save_hyperopt_stats(hyperopt_stats, output_directory)
logger.info(f"Hyperopt stats saved to: {output_directory}")
save_hyperopt_stats(hyperopt_stats, results_directory)
logging.info(f"Hyperopt stats saved to: {results_directory}")

for callback in callbacks or []:
callback.on_hyperopt_end(experiment_name)
callback.on_hyperopt_finish(experiment_name)

logger.info("Finished hyperopt")
logging.info("Finished hyperopt")

return hyperopt_results

Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,6 @@ def hyperopt_results():
# add hyperopt parameter space to the config
config["hyperopt"] = hyperopt_configs

hyperopt(config, dataset=rel_path, output_directory="results")
hyperopt(config, dataset=rel_path, output_directory="results", experiment_name="hyperopt_test")

return os.path.abspath("results")
return os.path.join(os.path.abspath("results"), "hyperopt_test")
66 changes: 36 additions & 30 deletions tests/integration_tests/test_hyperopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,18 @@ def ray_start(num_cpus: Optional[int] = None, num_gpus: Optional[int] = None):
ray.shutdown()


@pytest.fixture(scope="module")
def ray_cluster():
gpus = [i for i in range(torch.cuda.device_count())]
with ray_start(num_gpus=len(gpus)):
yield


@pytest.mark.distributed
@pytest.mark.parametrize("search_alg", SEARCH_ALGS)
def test_hyperopt_search_alg(search_alg, csv_filename, validate_output_feature=False, validation_metric=None):
def test_hyperopt_search_alg(
search_alg, csv_filename, tmpdir, ray_cluster, validate_output_feature=False, validation_metric=None
):
config, rel_path = _setup_ludwig_config(csv_filename)

hyperopt_config = HYPEROPT_CONFIG.copy()
Expand Down Expand Up @@ -151,29 +160,30 @@ def test_hyperopt_search_alg(search_alg, csv_filename, validate_output_feature=F
search_alg = hyperopt_config["search_alg"]

hyperopt_sampler = get_build_hyperopt_sampler(RAY)(parameters)

gpus = [i for i in range(torch.cuda.device_count())]
with ray_start(num_gpus=len(gpus)):
hyperopt_executor = get_build_hyperopt_executor(RAY)(
hyperopt_sampler, output_feature, metric, goal, split, search_alg=search_alg, **executor
)
raytune_results = hyperopt_executor.execute(config, dataset=rel_path)
assert isinstance(raytune_results, RayTuneResults)
hyperopt_executor = get_build_hyperopt_executor(RAY)(
hyperopt_sampler, output_feature, metric, goal, split, search_alg=search_alg, **executor
)
raytune_results = hyperopt_executor.execute(config, dataset=rel_path, output_directory=tmpdir)
assert isinstance(raytune_results, RayTuneResults)


@pytest.mark.distributed
def test_hyperopt_executor_with_metric(csv_filename):
def test_hyperopt_executor_with_metric(csv_filename, tmpdir, ray_cluster):
test_hyperopt_search_alg(
"variant_generator",
csv_filename,
tmpdir,
ray_cluster,
validate_output_feature=True,
validation_metric=ACCURACY,
)


@pytest.mark.distributed
@pytest.mark.parametrize("scheduler", SCHEDULERS)
def test_hyperopt_scheduler(scheduler, csv_filename, validate_output_feature=False, validation_metric=None):
def test_hyperopt_scheduler(
scheduler, csv_filename, tmpdir, ray_cluster, validate_output_feature=False, validation_metric=None
):
config, rel_path = _setup_ludwig_config(csv_filename)

hyperopt_config = HYPEROPT_CONFIG.copy()
Expand Down Expand Up @@ -211,25 +221,23 @@ def test_hyperopt_scheduler(scheduler, csv_filename, validate_output_feature=Fal

hyperopt_sampler = get_build_hyperopt_sampler(RAY)(parameters)

gpus = [i for i in range(torch.cuda.device_count())]
with ray_start(num_gpus=len(gpus)):
# TODO: Determine if we still need this if-then-else construct
if search_alg["type"] in {""}:
with pytest.raises(ImportError):
get_build_hyperopt_executor(RAY)(
hyperopt_sampler, output_feature, metric, goal, split, search_alg=search_alg, **executor
)
else:
hyperopt_executor = get_build_hyperopt_executor(RAY)(
# TODO: Determine if we still need this if-then-else construct
if search_alg["type"] in {""}:
with pytest.raises(ImportError):
get_build_hyperopt_executor(RAY)(
hyperopt_sampler, output_feature, metric, goal, split, search_alg=search_alg, **executor
)
raytune_results = hyperopt_executor.execute(config, dataset=rel_path)
assert isinstance(raytune_results, RayTuneResults)
else:
hyperopt_executor = get_build_hyperopt_executor(RAY)(
hyperopt_sampler, output_feature, metric, goal, split, search_alg=search_alg, **executor
)
raytune_results = hyperopt_executor.execute(config, dataset=rel_path, output_directory=tmpdir)
assert isinstance(raytune_results, RayTuneResults)


@pytest.mark.distributed
@pytest.mark.parametrize("search_space", ["random", "grid"])
def test_hyperopt_run_hyperopt(csv_filename, search_space):
def test_hyperopt_run_hyperopt(csv_filename, search_space, tmpdir, ray_cluster):
input_features = [
text_feature(name="utterance", cell_type="lstm", reduce_output="sum"),
category_feature(vocab_size=2, reduce_input="sum"),
Expand Down Expand Up @@ -288,9 +296,7 @@ def test_hyperopt_run_hyperopt(csv_filename, search_space):
# add hyperopt parameter space to the config
config["hyperopt"] = hyperopt_configs

with ray_start():
hyperopt_results = hyperopt(config, dataset=rel_path, output_directory="results_hyperopt")

hyperopt_results = hyperopt(config, dataset=rel_path, output_directory=tmpdir, experiment_name="test_hyperopt")
if search_space == "random":
assert hyperopt_results.experiment_analysis.results_df.shape[0] == RANDOM_SEARCH_SIZE
else:
Expand All @@ -304,7 +310,7 @@ def test_hyperopt_run_hyperopt(csv_filename, search_space):
assert isinstance(hyperopt_results, HyperoptResults)

# check for existence of the hyperopt statistics file
assert os.path.isfile(os.path.join("results_hyperopt", "hyperopt_statistics.json"))
assert os.path.isfile(os.path.join(tmpdir, "test_hyperopt", "hyperopt_statistics.json"))

if os.path.isfile(os.path.join("results_hyperopt", "hyperopt_statistics.json")):
os.remove(os.path.join("results_hyperopt", "hyperopt_statistics.json"))
if os.path.isfile(os.path.join(tmpdir, "test_hyperopt", "hyperopt_statistics.json")):
os.remove(os.path.join(tmpdir, "test_hyperopt", "hyperopt_statistics.json"))
Loading

0 comments on commit 6c26ee9

Please sign in to comment.