From e174bfeb340d3d3468d9c8eebce95c42aa2dcf84 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Lagunas?=
Date: Wed, 21 Oct 2020 17:18:52 +0200
Subject: [PATCH] TensorBoard/Wandb/optuna/raytune integration improvements. (#7935)

Improved TensorBoard and Wandb integration, as well as optuna and ray/tune support,
with minor modifications to trainer core code.
---
 src/transformers/integrations.py     |  85 +++++++++++++--
 src/transformers/testing_utils.py    |  27 +++++
 src/transformers/trainer.py          |  34 ++++--
 src/transformers/trainer_callback.py |   8 +-
 src/transformers/trainer_utils.py    |   8 +-
 src/transformers/utils/hp_naming.py  | 148 +++++++++++++++++++++++++++
 tests/test_trainer.py                |  57 ++++++++++-
 7 files changed, 344 insertions(+), 23 deletions(-)
 create mode 100644 src/transformers/utils/hp_naming.py

diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 486736f0868f4..93e1e6eab62b5 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -85,6 +85,17 @@ def is_ray_available():
     return _has_ray


+def hp_params(trial):
+    if is_optuna_available():
+        if isinstance(trial, optuna.Trial):
+            return trial.params
+    if is_ray_available():
+        if isinstance(trial, dict):
+            return trial
+
+    raise RuntimeError(f"Unknown type for trial {trial.__class__}")
+
+
 def default_hp_search_backend():
     if is_optuna_available():
         return "optuna"
@@ -192,6 +203,18 @@ def _objective(trial, checkpoint_dir=None):
     return best_run


+def rewrite_logs(d):
+    new_d = {}
+    eval_prefix = "eval_"
+    eval_prefix_len = len(eval_prefix)
+    for k, v in d.items():
+        if k.startswith(eval_prefix):
+            new_d["eval/" + k[eval_prefix_len:]] = v
+        else:
+            new_d["train/" + k] = v
+    return new_d
+
+
 class TensorBoardCallback(TrainerCallback):
     """
     A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
@@ -208,17 +231,39 @@ def __init__(self, tb_writer=None):
         ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
         self.tb_writer = tb_writer

-    def on_init_end(self, args, state, control, **kwargs):
-        if self.tb_writer is None and state.is_world_process_zero:
-            self.tb_writer = SummaryWriter(log_dir=args.logging_dir)
+    def _init_summary_writer(self, args, log_dir=None):
+        log_dir = log_dir or args.logging_dir
+        self.tb_writer = SummaryWriter(log_dir=log_dir)

     def on_train_begin(self, args, state, control, **kwargs):
+        if not state.is_world_process_zero:
+            return
+
+        log_dir = None
+
+        if state.is_hyper_param_search:
+            trial_name = state.trial_name
+            if trial_name is not None:
+                log_dir = os.path.join(args.logging_dir, trial_name)
+
+        self._init_summary_writer(args, log_dir)
+
         if self.tb_writer is not None:
             self.tb_writer.add_text("args", args.to_json_string())
+            if "model" in kwargs:
+                model = kwargs["model"]
+                if hasattr(model, "config") and model.config is not None:
+                    model_config_json = model.config.to_json_string()
+                    self.tb_writer.add_text("model_config", model_config_json)
             self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

     def on_log(self, args, state, control, logs=None, **kwargs):
+        if state.is_world_process_zero:
+            if self.tb_writer is None:
+                self._init_summary_writer(args)
+
         if self.tb_writer:
+            logs = rewrite_logs(logs)
             for k, v in logs.items():
                 if isinstance(v, (int, float)):
                     self.tb_writer.add_scalar(k, v, state.global_step)
@@ -249,7 +294,7 @@ def __init__(self):
         assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
         self._initialized = False

-    def setup(self, args, state, model):
+    def setup(self, args, state, model, reinit, **kwargs):
         """
         Setup the optional Weights & Biases (`wandb`) integration.

@@ -271,21 +316,41 @@ def setup(self, args, state, model):
                 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
             )
             combined_dict = {**args.to_sanitized_dict()}
-            if getattr(model, "config", None) is not None:
-                combined_dict = {**model.config.to_dict(), **combined_dict}
-            wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=args.run_name)
+
+            if hasattr(model, "config") and model.config is not None:
+                model_config = model.config.to_dict()
+                combined_dict = {**model_config, **combined_dict}
+            trial_name = state.trial_name
+            init_args = {}
+            if trial_name is not None:
+                run_name = trial_name
+                init_args["group"] = args.run_name
+            else:
+                run_name = args.run_name
+
+            wandb.init(
+                project=os.getenv("WANDB_PROJECT", "huggingface"),
+                config=combined_dict,
+                name=run_name,
+                reinit=reinit,
+                **init_args,
+            )
+
             # keep track of model topology and gradients, unsupported on TPU
             if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                 wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))

     def on_train_begin(self, args, state, control, model=None, **kwargs):
-        if not self._initialized:
-            self.setup(args, state, model)
+        hp_search = state.is_hyper_param_search
+        if not self._initialized or hp_search:
+            print(args.run_name)
+            self.setup(args, state, model, reinit=hp_search, **kwargs)

     def on_log(self, args, state, control, model=None, logs=None, **kwargs):
         if not self._initialized:
-            self.setup(args, state, model)
+            self.setup(args, state, model, reinit=False)
         if state.is_world_process_zero:
+            logs = rewrite_logs(logs)
             wandb.log(logs, step=state.global_step)


diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index d108112e8f4b4..9250933d703c5 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -20,6 +20,7 @@
     _torch_available,
     _torch_tpu_available,
 )
+from .integrations import _has_optuna, _has_ray


 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -233,6 +234,32 @@ def require_faiss(test_case):
     return test_case


+def require_optuna(test_case):
+    """
+    Decorator marking a test that requires optuna.
+
+    These tests are skipped when optuna isn't installed.
+
+    """
+    if not _has_optuna:
+        return unittest.skip("test requires optuna")(test_case)
+    else:
+        return test_case
+
+
+def require_ray(test_case):
+    """
+    Decorator marking a test that requires Ray/tune.
+
+    These tests are skipped when Ray/tune isn't installed.
+
+    """
+    if not _has_ray:
+        return unittest.skip("test requires Ray/tune")(test_case)
+    else:
+        return test_case
+
+
 def get_tests_dir(append_path=None):
     """
     Args:
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 318c5a6e19410..c20799f638829 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -39,6 +39,7 @@
 from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
 from .integrations import (
     default_hp_search_backend,
+    hp_params,
     is_comet_available,
     is_optuna_available,
     is_ray_available,
@@ -224,6 +225,7 @@ def __init__(
             model is not None or model_init is not None
         ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
         self.model_init = model_init
+        self.hp_name = None
         if model is None and model_init is not None:
             model = self.call_model_init()
         self.model = model.to(args.device) if model is not None else None
@@ -508,8 +510,11 @@ def num_examples(self, dataloader: DataLoader) -> int:

     def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
         """ HP search setup code """
+        self._trial = trial
+
         if self.hp_search_backend is None or trial is None:
             return
+
         params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
         for key, value in params.items():
             if not hasattr(self.args, key):
@@ -558,7 +563,10 @@ def call_model_init(self, trial=None):
         elif model_init_argcount == 1:
             model = self.model_init(trial)
         else:
-            raise Exception("model_init should have 0 or 1 argument.")
+            raise RuntimeError("model_init should have 0 or 1 argument.")
+
+        if model is None:
+            raise RuntimeError("model_init should not return None.")

         return model

@@ -617,6 +625,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)

         self.state = TrainerState()
+        self.state.is_hyper_param_search = trial is not None

         # Check if saved optimizer or scheduler states exist
         if (
@@ -702,6 +711,8 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
         self.callback_handler.optimizer = self.optimizer
         self.callback_handler.lr_scheduler = self.lr_scheduler
         self.callback_handler.train_dataloader = train_dataloader
+        self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
+        self.state.trial_params = hp_params(trial) if trial is not None else None
         # This should be the same if the state has been saved but in case the training arguments changed, it's safer
         # to set this after the load.
         self.state.max_steps = max_steps
@@ -783,13 +794,13 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                     self.state.epoch = epoch + (step + 1) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)

-                    self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
+                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     break

             self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
-            self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
+            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

             if self.args.tpu_metrics_debug or self.args.debug:
                 if is_torch_tpu_available():
@@ -823,7 +834,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

         return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)

-    def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
+    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
         if self.control.should_log:
             logs: Dict[str, float] = {}
             tr_loss_scalar = tr_loss.item()
@@ -842,6 +853,7 @@ def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
         if self.control.should_evaluate:
             metrics = self.evaluate()
             self._report_to_hp_search(trial, epoch, metrics)
+
             self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)

         if self.control.should_save:
@@ -857,12 +869,15 @@ def _save_checkpoint(self, model, trial, metrics=None):
         assert model is self.model, f"Model {model} should be a reference to self.model"
         # Save model checkpoint
         checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
         if self.hp_search_backend is not None and trial is not None:
             run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
-            checkpoint_folder += f"-run-{run_id}"
-        output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
+            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
+            output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
+        else:
+            output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

-        self.store_flos()
+        self.store_flos()
         self.save_model(output_dir)

         # Save optimizer and scheduler
@@ -909,6 +924,7 @@ def hyperparameter_search(
         n_trials: int = 20,
         direction: str = "minimize",
         backend: Optional[Union["str", HPSearchBackend]] = None,
+        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
         **kwargs
     ) -> BestRun:
         """
@@ -966,13 +982,13 @@ def hyperparameter_search(
                 "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
             )
         self.hp_search_backend = backend
-
         if self.model_init is None:
             raise RuntimeError(
                 "To use hyperparameter search, you need to pass your model through a model_init function."
             )

         self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
+        self.hp_name = hp_name
         self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

         run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
@@ -997,12 +1013,12 @@ def log(self, logs: Dict[str, float]) -> None:
                 FutureWarning,
             )
             return self._log(logs)
-
         if self.state.epoch is not None:
             logs["epoch"] = self.state.epoch
         if self._total_flos is not None:
             self.store_flos()
             logs["total_flos"] = self.state.total_flos
+
         self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
         output = {**logs, **{"step": self.state.global_step}}
         self.state.log_history.append(output)
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index c3057f8825012..2222362618afb 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -19,7 +19,7 @@
 import dataclasses
 import json
 from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 from tqdm.auto import tqdm

@@ -66,6 +66,9 @@ class TrainerState:
         is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not this process is the global main process (when training in a distributed fashion on
             several machines, this is only going to be :obj:`True` for one process).
+        is_hyper_param_search (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search.
+            This will impact the way data will be logged in TensorBoard.
     """

     epoch: Optional[float] = None
@@ -78,6 +81,9 @@
     best_model_checkpoint: Optional[str] = None
     is_local_process_zero: bool = True
     is_world_process_zero: bool = True
+    is_hyper_param_search: bool = False
+    trial_name: str = None
+    trial_params: Dict[str, Union[str, float, int, bool]] = None

     def __post_init__(self):
         if self.log_history is None:
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index ef3eaa1b0594c..0c747c226c16a 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -16,6 +16,7 @@
 Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow.
 """

+import copy
 import random
 from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

@@ -110,10 +111,15 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
     Return:
         :obj:`float`: The objective to minimize or maximize
     """
+    metrics = copy.deepcopy(metrics)
     loss = metrics.pop("eval_loss", None)
     _ = metrics.pop("epoch", None)
     _ = metrics.pop("total_flos", None)
-    return loss if len(metrics) == 0 else sum(metrics.values())
+    if len(metrics) != 0:
+        raise RuntimeError(
+            "Metrics contains more entries than just 'eval_loss', 'epoch' and 'total_flos', please provide your own compute_objective function."
+        )
+    return loss


 def default_hp_space_optuna(trial) -> Dict[str, float]:
diff --git a/src/transformers/utils/hp_naming.py b/src/transformers/utils/hp_naming.py
new file mode 100644
index 0000000000000..6954da95a6b61
--- /dev/null
+++ b/src/transformers/utils/hp_naming.py
@@ -0,0 +1,148 @@
+import copy
+import re
+
+
+class TrialShortNamer:
+    PREFIX = "hp"
+    DEFAULTS = {}
+    NAMING_INFO = None
+
+    @classmethod
+    def set_defaults(cls, prefix, defaults):
+        cls.PREFIX = prefix
+        cls.DEFAULTS = defaults
+        cls.build_naming_info()
+
+    @staticmethod
+    def shortname_for_word(info, word):
+        if len(word) == 0:
+            return ""
+        short_word = None
+        if any(char.isdigit() for char in word):
+            raise Exception(f"Parameters should not contain numbers: '{word}' contains a number")
+        if word in info["short_word"]:
+            return info["short_word"][word]
+        for prefix_len in range(1, len(word) + 1):
+            prefix = word[:prefix_len]
+            if prefix in info["reverse_short_word"]:
+                continue
+            else:
+                short_word = prefix
+                break
+
+        if short_word is None:
+            # Paranoid fallback
+            def int_to_alphabetic(integer):
+                s = ""
+                while integer != 0:
+                    s = chr(ord("A") + integer % 10) + s
+                    integer //= 10
+                return s
+
+            i = 0
+            while True:
+                sword = word + "#" + int_to_alphabetic(i)
+                if sword in info["reverse_short_word"]:
+                    continue
+                else:
+                    short_word = sword
+                    break
+
+        info["short_word"][word] = short_word
+        info["reverse_short_word"][short_word] = word
+        return short_word
+
+    @staticmethod
+    def shortname_for_key(info, param_name):
+        words = param_name.split("_")
+
+        shortname_parts = [TrialShortNamer.shortname_for_word(info, word) for word in words]
+
+        # We try to create a separatorless short name, but if there is a collision we have to fallback
+        # to a separated short name
+        separators = ["", "_"]
+
+        for separator in separators:
+            shortname = separator.join(shortname_parts)
+            if shortname not in info["reverse_short_param"]:
+                info["short_param"][param_name] = shortname
+                info["reverse_short_param"][shortname] = param_name
+                return shortname
+
+        return param_name
+
+    @staticmethod
+    def add_new_param_name(info, param_name):
+        short_name = TrialShortNamer.shortname_for_key(info, param_name)
+        info["short_param"][param_name] = short_name
+        info["reverse_short_param"][short_name] = param_name
+
+    @classmethod
+    def build_naming_info(cls):
+        if cls.NAMING_INFO is not None:
+            return
+
+        info = dict(
+            short_word={},
+            reverse_short_word={},
+            short_param={},
+            reverse_short_param={},
+        )
+
+        field_keys = list(cls.DEFAULTS.keys())
+
+        for k in field_keys:
+            cls.add_new_param_name(info, k)
+
+        cls.NAMING_INFO = info
+
+    @classmethod
+    def shortname(cls, params):
+        cls.build_naming_info()
+        assert cls.PREFIX is not None
+        name = [copy.copy(cls.PREFIX)]
+
+        for k, v in params.items():
+            if k not in cls.DEFAULTS:
+                raise Exception(f"You should provide a default value for the param name {k} with value {v}")
+            if v == cls.DEFAULTS[k]:
+                # The default value is not added to the name
+                continue
+
+            key = cls.NAMING_INFO["short_param"][k]
+
+            if isinstance(v, bool):
+                v = 1 if v else 0
+
+            sep = "" if isinstance(v, (int, float)) else "-"
+            e = f"{key}{sep}{v}"
+            name.append(e)
+
+        return "_".join(name)
+
+    @classmethod
+    def parse_repr(cls, repr):
+        repr = repr[len(cls.PREFIX) + 1 :]
+        if repr == "":
+            values = []
+        else:
+            values = repr.split("_")
+
+        parameters = {}
+
+        for value in values:
+            if "-" in value:
+                p_k, p_v = value.split("-")
+            else:
+                p_k = re.sub("[0-9.]", "", value)
+                p_v = float(re.sub("[^0-9.]", "", value))
+
+            key = cls.NAMING_INFO["reverse_short_param"][p_k]
+
+            parameters[key] = p_v
+
+        for k in cls.DEFAULTS:
+            if k not in parameters:
+                parameters[k] = cls.DEFAULTS[k]
+
+        return parameters
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index ec8a339281768..6505539cdac04 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -21,9 +21,17 @@
 import datasets
 import numpy as np

-from transformers import AutoTokenizer, PretrainedConfig, TrainingArguments, is_torch_available
+from transformers import AutoTokenizer, EvaluationStrategy, PretrainedConfig, TrainingArguments, is_torch_available
 from transformers.file_utils import WEIGHTS_NAME
-from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, require_torch, slow
+from transformers.testing_utils import (
+    get_tests_dir,
+    require_optuna,
+    require_sentencepiece,
+    require_tokenizers,
+    require_torch,
+    slow,
+)
+from transformers.utils.hp_naming import TrialShortNamer


 if is_torch_available():
@@ -142,6 +150,7 @@ def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len
     data_collator = kwargs.pop("data_collator", None)
     optimizers = kwargs.pop("optimizers", (None, None))
     output_dir = kwargs.pop("output_dir", "./regression")
+    model_init = kwargs.pop("model_init", None)
     args = TrainingArguments(output_dir, **kwargs)
     return Trainer(
         model,
@@ -151,6 +160,7 @@ def get_regression_trainer(a=0, b=0, double_output=False, train_len=64, eval_len
         eval_dataset=eval_dataset,
         compute_metrics=compute_metrics,
         optimizers=optimizers,
+        model_init=model_init,
     )


@@ -617,3 +627,46 @@ def assert_flos_extraction(trainer, wrapped_model_to_check):

         # with enforced DataParallel
         assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
+
+
+@require_torch
+@require_optuna
+class TrainerHyperParameterIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        args = TrainingArguments(".")
+        self.n_epochs = args.num_train_epochs
+        self.batch_size = args.train_batch_size
+
+    def test_hyperparameter_search(self):
+        class MyTrialShortNamer(TrialShortNamer):
+            DEFAULTS = {"a": 0, "b": 0}
+
+        def hp_space(trial):
+            return {}
+
+        def model_init(trial):
+            if trial is not None:
+                a = trial.suggest_int("a", -4, 4)
+                b = trial.suggest_int("b", -4, 4)
+            else:
+                a = 0
+                b = 0
+            config = RegressionModelConfig(a=a, b=b, double_output=False)
+
+            return RegressionPreTrainedModel(config)
+
+        def hp_name(trial):
+            return MyTrialShortNamer.shortname(trial.params)
+
+        trainer = get_regression_trainer(
+            learning_rate=0.1,
+            logging_steps=1,
+            evaluation_strategy=EvaluationStrategy.EPOCH,
+            num_train_epochs=4,
+            disable_tqdm=True,
+            load_best_model_at_end=True,
+            logging_dir="runs",
+            run_name="test",
+            model_init=model_init,
+        )
+        trainer.hyperparameter_search(direction="minimize", hp_space=hp_space, hp_name=hp_name, n_trials=4)
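
For reference, a minimal usage sketch of the naming helper this patch introduces (not part of the patch itself; MyNamer, its PREFIX/DEFAULTS and the parameter values below are illustrative assumptions). The shortname is what ends up as state.trial_name, which the TensorBoard/wandb callbacks pick up and which names the checkpoint sub-directory (otherwise "run-<trial id>" is used):

# Illustrative sketch only, assuming this patch is applied; names and values are made up.
from transformers.utils.hp_naming import TrialShortNamer

class MyNamer(TrialShortNamer):
    PREFIX = "run"
    DEFAULTS = {"learning_rate": 0.1, "warmup_steps": 0}

# Values equal to DEFAULTS are dropped from the name; "learning_rate" shortens to "lr",
# "warmup_steps" to "ws".
name = MyNamer.shortname({"learning_rate": 0.3, "warmup_steps": 500})
print(name)                      # run_lr0.3_ws500
print(MyNamer.parse_repr(name))  # {'learning_rate': 0.3, 'warmup_steps': 500.0} (values come back as floats)

# A matching hp_name callback for Trainer.hyperparameter_search with the optuna backend,
# mirroring the integration test above:
def hp_name(trial):
    return MyNamer.shortname(trial.params)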