Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added scheduler support to Ray Tune hyperopt and fixed GPU usage #1088

Merged
merged 24 commits into from
Feb 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ install:
else
pip install tensorflow==$TENSORFLOW
fi
- pip install --no-build-isolation ConfigSpace # temporary fix: https://github.com/automl/ConfigSpace/issues/173
- HOROVOD_WITH_TENSORFLOW=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_PYTORCH=1 HOROVOD_WITHOUT_MXNET=1 pip install --no-cache-dir '.[test]'
script:
- pip list
Expand Down
16 changes: 8 additions & 8 deletions ludwig/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,26 @@


class Callback(ABC):
def on_batch_start(self, trainer, progress_tracker):
def on_batch_start(self, trainer, progress_tracker, save_path):
pass

def on_batch_end(self, trainer, progress_tracker):
def on_batch_end(self, trainer, progress_tracker, save_path):
pass

def on_epoch_start(self, trainer, progress_tracker):
def on_epoch_start(self, trainer, progress_tracker, save_path):
pass

def on_epoch_end(self, trainer, progress_tracker):
def on_epoch_end(self, trainer, progress_tracker, save_path):
pass

def on_validation_start(self, trainer, progress_tracker):
def on_validation_start(self, trainer, progress_tracker, save_path):
pass

def on_validation_end(self, trainer, progress_tracker):
def on_validation_end(self, trainer, progress_tracker, save_path):
pass

def on_test_start(self, trainer, progress_tracker):
def on_test_start(self, trainer, progress_tracker, save_path):
pass

def on_test_end(self, trainer, progress_tracker):
def on_test_end(self, trainer, progress_tracker, save_path):
pass
114 changes: 84 additions & 30 deletions ludwig/hyperopt/execution.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
import os
import copy
import json
import multiprocessing
import signal
import shutil
from abc import ABC, abstractmethod
from typing import Union

from ludwig.api import LudwigModel
from ludwig.callbacks import Callback
from ludwig.constants import *
from ludwig.hyperopt.sampling import HyperoptSampler, RayTuneSampler, logger
from ludwig.modules.metric_modules import get_best_function
from ludwig.utils.data_utils import NumpyEncoder
from ludwig.utils.defaults import default_random_seed
from ludwig.utils.misc_utils import get_available_gpu_memory, get_from_registry
from ludwig.utils.tf_utils import get_available_gpus_cuda_string

try:
import ray
from ray import tune
from ray.tune.utils import wait_for_gpu
except ImportError:
ray = None

Expand Down Expand Up @@ -45,7 +50,7 @@ def get_metric_score(self, train_stats, eval_stats) -> float:
"best validation performance")
return self.get_metric_score_from_eval_stats(eval_stats)

def get_metric_score_from_eval_stats(self, eval_stats) -> float:
def get_metric_score_from_eval_stats(self, eval_stats) -> Union[float, list]:
if '.' in self.metric:
metric_parts = self.metric.split('.')
stats = eval_stats[self.output_feature]
Expand Down Expand Up @@ -111,7 +116,7 @@ def execute(
skip_save_model=False,
skip_save_progress=False,
skip_save_log=False,
skip_save_processed_input=False,
skip_save_processed_input=True,
skip_save_unprocessed_output=False,
skip_save_predictions=False,
skip_save_eval_stats=False,
Expand Down Expand Up @@ -154,7 +159,7 @@ def execute(
skip_save_model=False,
skip_save_progress=False,
skip_save_log=False,
skip_save_processed_input=False,
skip_save_processed_input=True,
skip_save_unprocessed_output=False,
skip_save_predictions=False,
skip_save_eval_stats=False,
Expand Down Expand Up @@ -307,7 +312,7 @@ def execute(
skip_save_model=False,
skip_save_progress=False,
skip_save_log=False,
skip_save_processed_input=False,
skip_save_processed_input=True,
skip_save_unprocessed_output=False,
skip_save_predictions=False,
skip_save_eval_stats=False,
Expand Down Expand Up @@ -560,7 +565,7 @@ def execute(
skip_save_model=False,
skip_save_progress=False,
skip_save_log=False,
skip_save_processed_input=False,
skip_save_processed_input=True,
skip_save_unprocessed_output=False,
skip_save_predictions=False,
skip_save_eval_stats=False,
Expand Down Expand Up @@ -681,6 +686,8 @@ def __init__(
self.num_samples = hyperopt_sampler.num_samples
self.goal = hyperopt_sampler.goal
self.search_alg_dict = hyperopt_sampler.search_alg_dict
self.scheduler = hyperopt_sampler.scheduler
self.decode_ctx = hyperopt_sampler.decode_ctx
self.output_feature = output_feature
self.metric = metric
self.split = split
Expand All @@ -689,24 +696,54 @@ def __init__(
self.gpu_resources_per_trial = gpu_resources_per_trial
self.kubernetes_namespace = kubernetes_namespace

def _run_experiment(self, config, hyperopt_dict):
def _run_experiment(self, config, checkpoint_dir, hyperopt_dict, decode_ctx):
for gpu_id in ray.get_gpu_ids():
# Previous trial may not have freed its memory yet, so wait to avoid OOM
wait_for_gpu(gpu_id)

# Some config values may be JSON encoded as strings, so decode them here
config = RayTuneSampler.decode_values(config, decode_ctx)

trial_id = tune.get_trial_id()
gpus_ids = ray.get_gpu_ids()
if gpus_ids:
gpus = ",".join(str(id) for id in gpus_ids)
else:
gpus = None
modified_config = substitute_parameters(
copy.deepcopy(hyperopt_dict["config"]), config)
hyperopt_dict["config"] = modified_config
hyperopt_dict["experiment_name"] = f'{hyperopt_dict["experiment_name"]}_{trial_id}'
hyperopt_dict["gpus"] = gpus
copy.deepcopy(hyperopt_dict["config"]), config
)

train_stats, eval_stats = run_experiment(**hyperopt_dict)
metric_score = self.get_metric_score(train_stats, eval_stats)
hyperopt_dict['config'] = modified_config
hyperopt_dict['experiment_name '] = f'{hyperopt_dict["experiment_name"]}_{trial_id}'

tune_executor = self

class RayTuneReportCallback(Callback):
def on_epoch_end(self, trainer, progress_tracker, save_path):
if trainer.is_coordinator():
with tune.checkpoint_dir(step=progress_tracker.epoch) as checkpoint_dir:
checkpoint_model = os.path.join(checkpoint_dir, 'model')
shutil.copytree(save_path, checkpoint_model)

train_stats, eval_stats = progress_tracker.train_metrics, progress_tracker.vali_metrics
stats = eval_stats or train_stats
metric_score = tune_executor.get_metric_score_from_eval_stats(stats)[-1]
tune.report(
parameters=json.dumps(config, cls=NumpyEncoder),
metric_score=metric_score,
training_stats=json.dumps(train_stats, cls=NumpyEncoder),
eval_stats=json.dumps(eval_stats, cls=NumpyEncoder)
)

tune.report(parameters=str(config), metric_score=metric_score,
training_stats=str(train_stats), eval_stats=str(eval_stats))
train_stats, eval_stats = run_experiment(
**hyperopt_dict,
model_resume_path=checkpoint_dir,
callbacks=[RayTuneReportCallback()],
)

metric_score = self.get_metric_score(train_stats, eval_stats)
tune.report(
parameters=json.dumps(config, cls=NumpyEncoder),
metric_score=metric_score,
training_stats=json.dumps(train_stats, cls=NumpyEncoder),
eval_stats=json.dumps(eval_stats, cls=NumpyEncoder)
)

def execute(self,
config,
Expand All @@ -733,13 +770,18 @@ def execute(self,
gpus=None,
gpu_memory_limit=None,
allow_parallel_threads=True,
use_horovod=None,
backend=None,
random_seed=default_random_seed,
debug=False,
**kwargs):
if isinstance(dataset, str) and not os.path.isabs(dataset):
dataset = os.path.abspath(dataset)

if gpus is not None:
raise ValueError("Parameter `gpus` is not supported when using Ray Tune. "
"Configure GPU resources with Ray and set `gpu_resources_per_trial` in your "
"hyperopt config.")

hyperopt_dict = dict(
config=config,
dataset=dataset,
Expand All @@ -766,11 +808,13 @@ def execute(self,
gpus=gpus,
gpu_memory_limit=gpu_memory_limit,
allow_parallel_threads=allow_parallel_threads,
use_horovod=use_horovod,
backend=backend,
random_seed=random_seed,
debug=debug,
)

mode = "min" if self.goal != MAXIMIZE else "max"
metric = "metric_score"
if self.search_alg_dict is not None:
if TYPE not in self.search_alg_dict:
logger.warning(
Expand All @@ -779,10 +823,9 @@ def execute(self,
)
search_alg = None
else:
mode = "min" if self.goal != MAXIMIZE else "max"
search_alg_type = self.search_alg_dict.pop(TYPE)
search_alg = tune.create_searcher(
search_alg_type, metric="metric_score", mode=mode, **self.search_alg_dict)
search_alg_type, metric=metric, mode=mode, **self.search_alg_dict)
else:
search_alg = None

Expand All @@ -793,17 +836,28 @@ def execute(self,
sync_to_driver=NamespacedKubernetesSyncer(self.kubernetes_namespace)
)

resources_per_trial = {
"cpu": self.cpu_resources_per_trial or 1,
"gpu": self.gpu_resources_per_trial or 0,
}

def run_experiment_trial(config, checkpoint_dir=None):
return self._run_experiment(config, checkpoint_dir, hyperopt_dict, self.decode_ctx)

analysis = tune.run(
tune.with_parameters(self._run_experiment, hyperopt_dict=hyperopt_dict),
run_experiment_trial,
config=self.search_space,
scheduler=self.scheduler,
search_alg=search_alg,
num_samples=self.num_samples,
resources_per_trial={
"cpu": self.cpu_resources_per_trial or 1,
"gpu": self.gpu_resources_per_trial or 0,
},
resources_per_trial=resources_per_trial,
queue_trials=True,
sync_config=sync_config,
local_dir=output_directory,
metric=metric,
mode=mode,
trial_name_creator=lambda trial: f"trial_{trial.trial_id}",
trial_dirname_creator=lambda trial: f"trial_{trial.trial_id}",
)

hyperopt_results = analysis.results_df.sort_values(
Expand Down Expand Up @@ -876,7 +930,7 @@ def run_experiment(
experiment_name="hyperopt",
model_name="run",
# model_load_path=None,
# model_resume_path=None,
model_resume_path=None,
eval_split=VALIDATION,
skip_save_training_description=False,
skip_save_training_statistics=False,
Expand Down Expand Up @@ -916,7 +970,7 @@ def run_experiment(
experiment_name=experiment_name,
model_name=model_name,
# model_load_path=model_load_path,
# model_resume_path=model_resume_path,
model_resume_path=model_resume_path,
eval_split=eval_split,
skip_save_training_description=skip_save_training_description,
skip_save_training_statistics=skip_save_training_statistics,
Expand Down
Loading