TensorBoard/Wandb/optuna/raytune integration improvements. (#7935)
Improved the TensorBoard and Weights & Biases (wandb) integrations, as well as Optuna and Ray Tune hyperparameter-search support, with minor modifications to the core Trainer code.
madlag committed Oct 21, 2020
1 parent bf162ce · commit e174bfe
Showing 7 changed files with 344 additions and 23 deletions.
85 changes: 75 additions & 10 deletions src/transformers/integrations.py
@@ -85,6 +85,17 @@ def is_ray_available():
return _has_ray


def hp_params(trial):
if is_optuna_available():
if isinstance(trial, optuna.Trial):
return trial.params
if is_ray_available():
if isinstance(trial, dict):
return trial

raise RuntimeError(f"Unknown type for trial {trial.__class__}")
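As an illustration (not part of the diff), the new helper gives callbacks a uniform dict view of the current trial, whichever backend produced it:

from transformers.integrations import hp_params

# A Ray Tune trial is already a plain dict of hyperparameters and passes through unchanged;
# an optuna.Trial would instead be reduced to its .params dict. Assumes Ray Tune is installed.
ray_style_trial = {"learning_rate": 3e-5, "num_train_epochs": 4}
assert hp_params(ray_style_trial) == ray_style_trial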


def default_hp_search_backend():
if is_optuna_available():
return "optuna"
@@ -192,6 +203,18 @@ def _objective(trial, checkpoint_dir=None):
return best_run


def rewrite_logs(d):
new_d = {}
eval_prefix = "eval_"
eval_prefix_len = len(eval_prefix)
for k, v in d.items():
if k.startswith(eval_prefix):
new_d["eval/" + k[eval_prefix_len:]] = v
else:
new_d["train/" + k] = v
return new_d
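For example, the prefix rewrite groups metrics into "train/" and "eval/" namespaces so TensorBoard and wandb display them on separate panels (values below are made up):

from transformers.integrations import rewrite_logs

rewrite_logs({"loss": 0.91, "learning_rate": 5e-5, "eval_loss": 0.42})
# -> {"train/loss": 0.91, "train/learning_rate": 5e-5, "eval/loss": 0.42}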


class TensorBoardCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
@@ -208,17 +231,39 @@ def __init__(self, tb_writer=None):
), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
self.tb_writer = tb_writer

def on_init_end(self, args, state, control, **kwargs):
if self.tb_writer is None and state.is_world_process_zero:
self.tb_writer = SummaryWriter(log_dir=args.logging_dir)
def _init_summary_writer(self, args, log_dir=None):
log_dir = log_dir or args.logging_dir
self.tb_writer = SummaryWriter(log_dir=log_dir)

def on_train_begin(self, args, state, control, **kwargs):
if not state.is_world_process_zero:
return

log_dir = None

if state.is_hyper_param_search:
trial_name = state.trial_name
if trial_name is not None:
log_dir = os.path.join(args.logging_dir, trial_name)

self._init_summary_writer(args, log_dir)

if self.tb_writer is not None:
self.tb_writer.add_text("args", args.to_json_string())
if "model" in kwargs:
model = kwargs["model"]
if hasattr(model, "config") and model.config is not None:
model_config_json = model.config.to_json_string()
self.tb_writer.add_text("model_config", model_config_json)
self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
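A small sketch of the log-dir logic above, with hypothetical paths: during a hyperparameter search every trial writes to its own TensorBoard subdirectory instead of sharing args.logging_dir.

import os

logging_dir, trial_name = "runs/exp", "trial-2"   # hypothetical values
log_dir = os.path.join(logging_dir, trial_name)   # -> "runs/exp/trial-2"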

def on_log(self, args, state, control, logs=None, **kwargs):
if state.is_world_process_zero:
if self.tb_writer is None:
self._init_summary_writer(args)

if self.tb_writer:
logs = rewrite_logs(logs)
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, state.global_step)
@@ -249,7 +294,7 @@ def __init__(self):
assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
self._initialized = False

def setup(self, args, state, model):
def setup(self, args, state, model, reinit, **kwargs):
"""
Setup the optional Weights & Biases (`wandb`) integration.
@@ -271,21 +316,41 @@ def setup(self, args, state, model):
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
combined_dict = {**args.to_sanitized_dict()}
if getattr(model, "config", None) is not None:
combined_dict = {**model.config.to_dict(), **combined_dict}
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=args.run_name)

if hasattr(model, "config") and model.config is not None:
model_config = model.config.to_dict()
combined_dict = {**model_config, **combined_dict}
trial_name = state.trial_name
init_args = {}
if trial_name is not None:
run_name = trial_name
init_args["group"] = args.run_name
else:
run_name = args.run_name

wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"),
config=combined_dict,
name=run_name,
reinit=reinit,
**init_args,
)

# keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
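A hedged sketch of the wandb.init call this produces during a search (project, group and run names are hypothetical): each trial becomes its own run, grouped under the base run_name, and reinit allows a fresh run per trial within the same process.

import os
import wandb  # assumes wandb is installed and configured

run = wandb.init(
    project=os.getenv("WANDB_PROJECT", "huggingface"),
    name="trial-2",    # per-trial run name
    group="my-sweep",  # base args.run_name groups all trials of the search
    reinit=True,
)
run.finish()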

def on_train_begin(self, args, state, control, model=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
hp_search = state.is_hyper_param_search
if not self._initialized or hp_search:
print(args.run_name)
self.setup(args, state, model, reinit=hp_search, **kwargs)

def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if not self._initialized:
self.setup(args, state, model)
self.setup(args, state, model, reinit=False)
if state.is_world_process_zero:
logs = rewrite_logs(logs)
wandb.log(logs, step=state.global_step)


27 changes: 27 additions & 0 deletions src/transformers/testing_utils.py
@@ -20,6 +20,7 @@
_torch_available,
_torch_tpu_available,
)
from .integrations import _has_optuna, _has_ray


SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -233,6 +234,32 @@ def require_faiss(test_case):
return test_case


def require_optuna(test_case):
"""
Decorator marking a test that requires optuna.
These tests are skipped when optuna isn't installed.
"""
if not _has_optuna:
return unittest.skip("test requires optuna")(test_case)
else:
return test_case


def require_ray(test_case):
"""
Decorator marking a test that requires Ray/tune.
These tests are skipped when Ray/tune isn't installed.
"""
if not _has_ray:
return unittest.skip("test requires Ray/tune")(test_case)
else:
return test_case
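Illustrative usage on a hypothetical test case, mirroring the existing require_* decorators:

import unittest

from transformers.testing_utils import require_optuna


@require_optuna
class HyperparameterSearchTest(unittest.TestCase):
    def test_search_runs(self):
        ...  # executed only when optuna is installed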


def get_tests_dir(append_path=None):
"""
Args:
34 changes: 25 additions & 9 deletions src/transformers/trainer.py
@@ -39,6 +39,7 @@
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_in_notebook, is_torch_tpu_available
from .integrations import (
default_hp_search_backend,
hp_params,
is_comet_available,
is_optuna_available,
is_ray_available,
@@ -224,6 +225,7 @@ def __init__(
model is not None or model_init is not None
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
self.model_init = model_init
self.hp_name = None
if model is None and model_init is not None:
model = self.call_model_init()
self.model = model.to(args.device) if model is not None else None
@@ -508,8 +510,11 @@ def num_examples(self, dataloader: DataLoader) -> int:

def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
""" HP search setup code """
self._trial = trial

if self.hp_search_backend is None or trial is None:
return

params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
for key, value in params.items():
if not hasattr(self.args, key):
@@ -558,7 +563,10 @@ def call_model_init(self, trial=None):
elif model_init_argcount == 1:
model = self.model_init(trial)
else:
raise Exception("model_init should have 0 or 1 argument.")
raise RuntimeError("model_init should have 0 or 1 argument.")

if model is None:
raise RuntimeError("model_init should not return None.")

return model
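A minimal sketch of a model_init that satisfies these checks (the model name is just an example):

from transformers import AutoModelForSequenceClassification

def model_init(trial=None):
    # May take zero or one argument (the trial) and must return a model, never None.
    return AutoModelForSequenceClassification.from_pretrained("bert-base-cased")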

@@ -617,6 +625,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None

# Check if saved optimizer or scheduler states exist
if (
@@ -702,6 +711,8 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
self.callback_handler.optimizer = self.optimizer
self.callback_handler.lr_scheduler = self.lr_scheduler
self.callback_handler.train_dataloader = train_dataloader
self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
self.state.trial_params = hp_params(trial) if trial is not None else None
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load.
self.state.max_steps = max_steps
@@ -783,13 +794,13 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)

self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

if self.control.should_epoch_stop or self.control.should_training_stop:
break

self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch)

if self.args.tpu_metrics_debug or self.args.debug:
if is_torch_tpu_available():
@@ -823,7 +834,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)

def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
if self.control.should_log:
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item()
@@ -842,6 +853,7 @@ def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
if self.control.should_evaluate:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)

self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)

if self.control.should_save:
@@ -857,12 +869,15 @@ def _save_checkpoint(self, model, trial, metrics=None):
assert model is self.model, f"Model {model} should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

if self.hp_search_backend is not None and trial is not None:
run_id = trial.number if self.hp_search_backend == HPSearchBackend.OPTUNA else tune.get_trial_id()
checkpoint_folder += f"-run-{run_id}"
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
output_dir = os.path.join(self.args.output_dir, run_name, checkpoint_folder)
else:
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)

self.store_flos()
self.store_flos()
self.save_model(output_dir)

# Save optimizer and scheduler
@@ -909,6 +924,7 @@ def hyperparameter_search(
n_trials: int = 20,
direction: str = "minimize",
backend: Optional[Union["str", HPSearchBackend]] = None,
hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
**kwargs
) -> BestRun:
"""
@@ -966,13 +982,13 @@ def hyperparameter_search(
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
)
self.hp_search_backend = backend

if self.model_init is None:
raise RuntimeError(
"To use hyperparameter search, you need to pass your model through a model_init function."
)

self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
self.hp_name = hp_name
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective

run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
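A hedged usage sketch of the new hp_name hook, assuming a Trainer built with model_init (the search space and names below are hypothetical, not taken from this diff):

def my_hp_space(trial):
    return {"learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True)}

def my_hp_name(trial):
    return f"trial-{trial.number}"

best_run = trainer.hyperparameter_search(
    hp_space=my_hp_space,
    hp_name=my_hp_name,   # names the TensorBoard dir, wandb run and checkpoint folder per trial
    backend="optuna",     # assumes optuna is installed
    n_trials=5,
    direction="minimize",
)

With a trial name set, checkpoints land under output_dir/<trial name>/checkpoint-<step>, matching the _save_checkpoint change above.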
@@ -997,12 +1013,12 @@ def log(self, logs: Dict[str, float]) -> None:
FutureWarning,
)
return self._log(logs)

if self.state.epoch is not None:
logs["epoch"] = self.state.epoch
if self._total_flos is not None:
self.store_flos()
logs["total_flos"] = self.state.total_flos

self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
output = {**logs, **{"step": self.state.global_step}}
self.state.log_history.append(output)
8 changes: 7 additions & 1 deletion src/transformers/trainer_callback.py
@@ -19,7 +19,7 @@
import dataclasses
import json
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from tqdm.auto import tqdm

@@ -66,6 +66,9 @@ class TrainerState:
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not this process is the global main process (when training in a distributed fashion on
several machines, this is only going to be :obj:`True` for one process).
is_hyper_param_search (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search.
This will impact the way data will be logged in TensorBoard.
"""

epoch: Optional[float] = None
@@ -78,6 +81,9 @@ class TrainerState:
best_model_checkpoint: Optional[str] = None
is_local_process_zero: bool = True
is_world_process_zero: bool = True
is_hyper_param_search: bool = False
trial_name: str = None
trial_params: Dict[str, Union[str, float, int, bool]] = None

def __post_init__(self):
if self.log_history is None:
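For illustration, the new TrainerState fields as they might look mid-search (hypothetical values):

from transformers.trainer_callback import TrainerState

state = TrainerState()
state.is_hyper_param_search = True
state.trial_name = "trial-2"
state.trial_params = {"learning_rate": 3e-5, "seed": 17}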
8 changes: 7 additions & 1 deletion src/transformers/trainer_utils.py
@@ -16,6 +16,7 @@
Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow.
"""

import copy
import random
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

@@ -110,10 +111,15 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
Return:
:obj:`float`: The objective to minimize or maximize
"""
metrics = copy.deepcopy(metrics)
loss = metrics.pop("eval_loss", None)
_ = metrics.pop("epoch", None)
_ = metrics.pop("total_flos", None)
return loss if len(metrics) == 0 else sum(metrics.values())
if len(metrics) != 0:
raise RuntimeError(
"Metrics contains more entries than just 'eval_loss', 'epoch' and 'total_flos', please provide your own compute_objective function."
)
return loss
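Illustrating the stricter default objective (values hypothetical):

from transformers.trainer_utils import default_compute_objective

default_compute_objective({"eval_loss": 0.37, "epoch": 2.0, "total_flos": 1.2e15})  # -> 0.37
# Any extra key (e.g. "eval_accuracy") now raises RuntimeError instead of being silently
# summed, nudging users toward supplying their own compute_objective.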


def default_hp_space_optuna(trial) -> Dict[str, float]:
