Split hp search methods (#6857)
* Split the run_hp_search by backend

* Unused import
sgugger committed Aug 31, 2020
1 parent 23f9611 commit a59bcef
Showing 2 changed files with 83 additions and 73 deletions.
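
Before the diff itself: the public entry point, `Trainer.hyperparameter_search`, is unchanged by this commit; only the backend-specific helper it calls is split in two. Below is a minimal sketch of driving the Optuna backend with the API of that era; the checkpoint name, datasets, and search ranges are illustrative assumptions, not part of the commit.

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

def model_init():
    # A fresh model is built for every trial; the checkpoint name is only an example.
    return AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

def optuna_hp_space(trial):
    # Optuna trial API; the ranges here are placeholders.
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
    }

args = TrainingArguments(output_dir="hp_search", do_eval=True, evaluate_during_training=True)
trainer = Trainer(
    args=args,
    model_init=model_init,
    train_dataset=train_dataset,  # assumed to be defined by the caller
    eval_dataset=eval_dataset,    # assumed to be defined by the caller
)

# Extra keyword arguments (e.g. timeout, n_jobs) are now forwarded to run_hp_search_optuna.
best_run = trainer.hyperparameter_search(
    hp_space=optuna_hp_space,
    backend="optuna",
    n_trials=10,
    direction="minimize",
)
print(best_run.run_id, best_run.objective, best_run.hyperparameters)
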
src/transformers/integrations.py (150 changes: 79 additions, 71 deletions)
@@ -3,7 +3,7 @@

 import numpy as np
 
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, HPSearchBackend
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun
 from transformers.utils import logging
 
 
@@ -83,7 +83,7 @@ def default_hp_search_backend():
     return "ray"
 
 
-def run_hp_search(trainer, n_trials, direction, kwargs):
+def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     def _objective(trial, checkpoint_dir=None):
         model_path = None
         if checkpoint_dir:
@@ -96,80 +96,88 @@ def _objective(trial, checkpoint_dir=None):
         if getattr(trainer, "objective", None) is None:
             metrics = trainer.evaluate()
             trainer.objective = trainer.compute_objective(metrics)
-        if trainer.hp_search_backend == HPSearchBackend.RAY:
-            trainer._tune_save_checkpoint()
-            ray.tune.report(objective=trainer.objective)
         return trainer.objective
 
-    if trainer.hp_search_backend == HPSearchBackend.OPTUNA:
-        timeout = kwargs.pop("timeout", None)
-        n_jobs = kwargs.pop("n_jobs", 1)
-        study = optuna.create_study(direction=direction, **kwargs)
-        study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
-        best_trial = study.best_trial
-        best_run = BestRun(str(best_trial.number), best_trial.value, best_trial.params)
-    elif trainer.hp_search_backend == HPSearchBackend.RAY:
-        # The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
-        # while doing the ray hp search.
-        _tb_writer = trainer.tb_writer
-        trainer.tb_writer = None
-        trainer.model = None
-        # Setup default `resources_per_trial` and `reporter`.
-        if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
-            # `args.n_gpu` is considered the total number of GPUs that will be split
-            # among the `n_jobs`
-            n_jobs = int(kwargs.pop("n_jobs", 1))
-            num_gpus_per_trial = trainer.args.n_gpu
-            if num_gpus_per_trial / n_jobs >= 1:
-                num_gpus_per_trial = int(np.ceil(num_gpus_per_trial / n_jobs))
-            kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial}
-
-        if "reporter" not in kwargs:
-            from ray.tune import CLIReporter
-
-            kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
-        if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
-            # `keep_checkpoints_num=0` would disable checkpointing
-            trainer.use_tune_checkpoints = True
-            if kwargs["keep_checkpoints_num"] > 1:
-                logger.warning(
-                    "Currently keeping {} checkpoints for each trial. Checkpoints are usually huge, "
-                    "consider setting `keep_checkpoints_num=1`."
-                )
-        if "scheduler" in kwargs:
-            from ray.tune.schedulers import (
-                ASHAScheduler,
-                HyperBandForBOHB,
-                MedianStoppingRule,
-                PopulationBasedTraining,
-            )
-
-            # Check if checkpointing is enabled for PopulationBasedTraining
-            if isinstance(kwargs["scheduler"], PopulationBasedTraining):
-                if not trainer.use_tune_checkpoints:
-                    logger.warning(
-                        "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
-                        "This means your trials will train from scratch every time they are exploiting "
-                        "new configurations. Consider enabling checkpointing by passing "
-                        "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
-                    )
-
-            # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
-            if isinstance(
-                kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
-            ) and (not trainer.args.do_eval or not trainer.args.evaluate_during_training):
-                raise RuntimeError(
-                    "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
-                    "This means your trials will not report intermediate results to Ray Tune, and "
-                    "can thus not be stopped early or used to exploit other trials' parameters. "
-                    "If this is what you want, do not use {cls}. If you would like to use {cls}, "
-                    "make sure you pass `do_eval=True` and `evaluate_during_training=True` in the "
-                    "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
-                )
-
-        analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
-        best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
-        best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
-        trainer.tb_writer = _tb_writer
-    return best_run
+    timeout = kwargs.pop("timeout", None)
+    n_jobs = kwargs.pop("n_jobs", 1)
+    study = optuna.create_study(direction=direction, **kwargs)
+    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
+    best_trial = study.best_trial
+    return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
+
+
+def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
+    def _objective(trial, checkpoint_dir=None):
+        model_path = None
+        if checkpoint_dir:
+            for subdir in os.listdir(checkpoint_dir):
+                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
+                    model_path = os.path.join(checkpoint_dir, subdir)
+        trainer.objective = None
+        trainer.train(model_path=model_path, trial=trial)
+        # If there hasn't been any evaluation during the training loop.
+        if getattr(trainer, "objective", None) is None:
+            metrics = trainer.evaluate()
+            trainer.objective = trainer.compute_objective(metrics)
+        trainer._tune_save_checkpoint()
+        ray.tune.report(objective=trainer.objective)
+        return trainer.objective
+
+    # The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
+    # while doing the ray hp search.
+    _tb_writer = trainer.tb_writer
+    trainer.tb_writer = None
+    trainer.model = None
+    # Setup default `resources_per_trial` and `reporter`.
+    if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
+        # `args.n_gpu` is considered the total number of GPUs that will be split
+        # among the `n_jobs`
+        n_jobs = int(kwargs.pop("n_jobs", 1))
+        num_gpus_per_trial = trainer.args.n_gpu
+        if num_gpus_per_trial / n_jobs >= 1:
+            num_gpus_per_trial = int(np.ceil(num_gpus_per_trial / n_jobs))
+        kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial}
+
+    if "reporter" not in kwargs:
+        from ray.tune import CLIReporter
+
+        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
+    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
+        # `keep_checkpoints_num=0` would disable checkpointing
+        trainer.use_tune_checkpoints = True
+        if kwargs["keep_checkpoints_num"] > 1:
+            logger.warning(
+                "Currently keeping {} checkpoints for each trial. Checkpoints are usually huge, "
+                "consider setting `keep_checkpoints_num=1`."
+            )
+    if "scheduler" in kwargs:
+        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining
+
+        # Check if checkpointing is enabled for PopulationBasedTraining
+        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
+            if not trainer.use_tune_checkpoints:
+                logger.warning(
+                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
+                    "This means your trials will train from scratch every time they are exploiting "
+                    "new configurations. Consider enabling checkpointing by passing "
+                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
+                )
+
+        # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
+        if isinstance(
+            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
+        ) and (not trainer.args.do_eval or not trainer.args.evaluate_during_training):
+            raise RuntimeError(
+                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
+                "This means your trials will not report intermediate results to Ray Tune, and "
+                "can thus not be stopped early or used to exploit other trials' parameters. "
+                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
+                "make sure you pass `do_eval=True` and `evaluate_during_training=True` in the "
+                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
+            )
+
+    analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
+    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
+    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
+    trainer.tb_writer = _tb_writer
+    return best_run
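
The keyword arguments consumed above (`keep_checkpoints_num`, `resources_per_trial`, `scheduler`, ...) are whatever the caller passes through `Trainer.hyperparameter_search`; anything not popped is handed on to `ray.tune.run`. A hedged sketch reusing the `trainer` from the earlier example; the search space and scheduler settings are assumptions, not part of the commit.

from ray import tune
from ray.tune.schedulers import ASHAScheduler

def ray_hp_space(trial):
    # Ray Tune search space; the keys must be valid TrainingArguments fields.
    return {
        "learning_rate": tune.loguniform(1e-6, 1e-4),
        "per_device_train_batch_size": tune.choice([8, 16, 32]),
    }

best_run = trainer.hyperparameter_search(
    hp_space=ray_hp_space,
    backend="ray",
    n_trials=8,
    direction="minimize",
    # The arguments below are inspected by run_hp_search_ray and/or forwarded to ray.tune.run:
    scheduler=ASHAScheduler(metric="objective", mode="min"),
    keep_checkpoints_num=1,                    # enables trainer.use_tune_checkpoints above
    resources_per_trial={"cpu": 2, "gpu": 1},
)

Note that `direction="minimize"` maps to `mode="min"` in `get_best_trial`, so the scheduler's `mode` should agree with it.
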
src/transformers/trainer.py (6 changes: 4 additions, 2 deletions)
@@ -27,7 +27,8 @@
     is_ray_available,
     is_tensorboard_available,
     is_wandb_available,
-    run_hp_search,
+    run_hp_search_optuna,
+    run_hp_search_ray,
 )
 from .modeling_utils import PreTrainedModel
 from .optimization import AdamW, get_linear_schedule_with_warmup
@@ -884,7 +885,8 @@ def hyperparameter_search(
         self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
         self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
 
-        best_run = run_hp_search(self, n_trials, direction, kwargs)
+        run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
+        best_run = run_hp_search(self, n_trials, direction, **kwargs)
 
         self.hp_search_backend = None
         return best_run
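
The one-line conditional above keeps the call site identical for both backends. Purely as an illustration (not what the commit uses), the same dispatch could be written as a lookup table keyed by the `HPSearchBackend` enum, which scales more naturally if further backends are ever added:

run_hp_search = {
    HPSearchBackend.OPTUNA: run_hp_search_optuna,
    HPSearchBackend.RAY: run_hp_search_ray,
}[backend]
best_run = run_hp_search(self, n_trials, direction, **kwargs)

Either form assumes `backend` has already been resolved to a valid `HPSearchBackend` member earlier in `hyperparameter_search`.
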
