[Trainer] memory tracker metrics (#10225)
* memory tracker metrics

* go back to eval for somewhat more consistency

* handle no-gpu case

* deal with stackable eval calls

* restore callback order

* style

* simplify the API

* add test

* docs

* consistently use eval_ prefix

* improve docs

* Update src/transformers/trainer_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* rename method

* style

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
stas00 and sgugger committed Feb 18, 2021
1 parent d7f38c5 commit 97e688b
Showing 7 changed files with 294 additions and 14 deletions.
29 changes: 18 additions & 11 deletions examples/seq2seq/run_seq2seq.py
@@ -588,9 +588,12 @@ def compute_metrics(eval_preds):
         )
         metrics["train_samples"] = min(max_train_samples, len(train_dataset))
         if trainer.is_world_process_zero():
+            metrics_formatted = trainer.metrics_format(metrics)
             logger.info("***** train metrics *****")
-            for key in sorted(metrics.keys()):
-                logger.info(f"  {key} = {metrics[key]}")
+            k_width = max(len(str(x)) for x in metrics_formatted.keys())
+            v_width = max(len(str(x)) for x in metrics_formatted.values())
+            for key in sorted(metrics_formatted.keys()):
+                logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
             save_json(metrics, os.path.join(training_args.output_dir, "train_results.json"))
             all_metrics.update(metrics)


@@ -603,17 +606,19 @@ def compute_metrics(eval_preds):
         logger.info("*** Evaluate ***")

         metrics = trainer.evaluate(
-            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="val"
+            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
         )
-        metrics = {k: round(v, 4) for k, v in metrics.items()}
         max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
-        metrics["val_samples"] = min(max_val_samples, len(eval_dataset))
+        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))

         if trainer.is_world_process_zero():
+            metrics_formatted = trainer.metrics_format(metrics)
             logger.info("***** val metrics *****")
-            for key in sorted(metrics.keys()):
-                logger.info(f"  {key} = {metrics[key]}")
-            save_json(metrics, os.path.join(training_args.output_dir, "val_results.json"))
+            k_width = max(len(str(x)) for x in metrics_formatted.keys())
+            v_width = max(len(str(x)) for x in metrics_formatted.values())
+            for key in sorted(metrics_formatted.keys()):
+                logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
+            save_json(metrics, os.path.join(training_args.output_dir, "eval_results.json"))
             all_metrics.update(metrics)

     if training_args.do_predict:
@@ -628,12 +633,14 @@ def compute_metrics(eval_preds):
         metrics = test_results.metrics
         max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
         metrics["test_samples"] = min(max_test_samples, len(test_dataset))
-        metrics = {k: round(v, 4) for k, v in metrics.items()}

         if trainer.is_world_process_zero():
+            metrics_formatted = trainer.metrics_format(metrics)
             logger.info("***** test metrics *****")
-            for key in sorted(metrics.keys()):
-                logger.info(f"  {key} = {metrics[key]}")
+            k_width = max(len(str(x)) for x in metrics_formatted.keys())
+            v_width = max(len(str(x)) for x in metrics_formatted.values())
+            for key in sorted(metrics_formatted.keys()):
+                logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
             save_json(metrics, os.path.join(training_args.output_dir, "test_results.json"))
             all_metrics.update(metrics)

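The repeated `k_width`/`v_width` block above only pads keys and values so the log lines up in columns once `trainer.metrics_format` has rendered the raw values. A self-contained sketch of that logging pattern, with made-up metric values (the key names are illustrative, not the exact keys the memory tracker emits):

```python
# Standalone sketch of the aligned metric logging used in run_seq2seq.py above.
# The sample keys/values are made up for illustration.
metrics_formatted = {
    "eval_loss": 3.2146,
    "eval_samples": 1999,
    "eval_mem_gpu_alloc_delta": "512MB",  # memory values arrive pre-formatted as strings
}

k_width = max(len(str(x)) for x in metrics_formatted.keys())    # widest key
v_width = max(len(str(x)) for x in metrics_formatted.values())  # widest value
for key in sorted(metrics_formatted.keys()):
    # keys left-aligned to k_width, values right-aligned to v_width
    print(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
```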
4 changes: 2 additions & 2 deletions examples/tests/deepspeed/test_deepspeed.py
@@ -88,8 +88,8 @@ def test_do_eval_no_train(self):
             extra_args_str="--do_eval",
             remove_args_str="--do_train",
         )
-        val_metrics = load_json(os.path.join(output_dir, "val_results.json"))
-        assert "val_bleu" in val_metrics
+        val_metrics = load_json(os.path.join(output_dir, "eval_results.json"))
+        assert "eval_bleu" in val_metrics

     # XXX: need to do better validation beyond just that the run was successful
     def run_quick(self, distributed=True, extra_args_str=None, remove_args_str=None):
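The deepspeed test simply tracks the rename: eval metrics now land in `eval_results.json` with `eval_`-prefixed keys. A minimal standalone version of the same check (the `output_dir` value is only a placeholder):

```python
# Standalone version of the renamed check; "output" is a placeholder path.
import json
import os

output_dir = "output"
with open(os.path.join(output_dir, "eval_results.json")) as f:
    eval_metrics = json.load(f)
assert "eval_bleu" in eval_metrics, f"missing eval_bleu, got: {sorted(eval_metrics)}"
```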
9 changes: 9 additions & 0 deletions src/transformers/file_utils.py
@@ -236,6 +236,15 @@ def is_torch_available():
     return _torch_available


+def is_torch_cuda_available():
+    if is_torch_available():
+        import torch
+
+        return torch.cuda.is_available()
+    else:
+        return False
+
+
 def is_tf_available():
     return _tf_available

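This helper backs the "handle no-gpu case" item in the commit message: callers can branch on it without importing `torch` up front. A small usage sketch (the GPU memory query is just one plausible use, not taken from this diff):

```python
# Usage sketch for the new helper; the GPU memory query is illustrative.
from transformers.file_utils import is_torch_cuda_available

if is_torch_cuda_available():
    import torch

    gpu_bytes_in_use = torch.cuda.memory_allocated()  # current GPU allocation in bytes
else:
    gpu_bytes_in_use = 0  # CPU-only machine: nothing GPU-side to report

print(f"gpu mem allocated: {gpu_bytes_in_use >> 20}MB")
```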
49 changes: 49 additions & 0 deletions src/transformers/trainer.py
@@ -93,6 +93,7 @@
     EvalPrediction,
     HPSearchBackend,
     PredictionOutput,
+    TrainerMemoryTracker,
     TrainOutput,
     default_compute_objective,
     default_hp_space,
@@ -243,6 +244,10 @@ def __init__(
         self.hp_name = None
         self.deepspeed = None

+        # memory metrics - must set up as early as possible
+        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
+        self._memory_tracker.start()
+
         # force device and distributed setup init explicitly
         args._setup_devices

@@ -394,6 +399,9 @@ def __init__(
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
         self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

+        # very last
+        self._memory_tracker.stop_and_update_metrics()
+
     def add_callback(self, callback):
         """
         Add a callback to the current list of :class:`~transformer.TrainerCallback`.
@@ -761,6 +769,10 @@ def train(
             kwargs:
                 Additional keyword arguments used to hide deprecated arguments
         """
+
+        # memory metrics - must set up as early as possible
+        self._memory_tracker.start()
+
         if "model_path" in kwargs:
            resume_from_checkpoint = kwargs.pop("model_path")
            warnings.warn(
@@ -1077,6 +1089,8 @@ def train(
             self.model_wrapped = self.model
             gc.collect()  # force memory release

+        self._memory_tracker.stop_and_update_metrics(metrics)
+
         return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)

     def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
@@ -1306,6 +1320,29 @@ def log(self, logs: Dict[str, float]) -> None:
         self.state.log_history.append(output)
         self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)

+    def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
+        """
+        Reformat Trainer metrics values to a human-readable format
+
+        Args:
+            metrics (:obj:`Dict[str, float]`):
+                The metrics returned from train/evaluate/predict
+
+        Returns:
+            metrics (:obj:`Dict[str, float]`): The reformatted metrics
+        """
+
+        metrics_copy = metrics.copy()
+        for k, v in metrics_copy.items():
+            if "_mem_" in k:
+                metrics_copy[k] = f"{ v >> 20 }MB"
+            elif k == "total_flos":
+                metrics_copy[k] = f"{ int(v) >> 30 }GF"
+            elif type(metrics_copy[k]) == float:
+                metrics_copy[k] = round(v, 4)
+
+        return metrics_copy
+
     def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
         """
         Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
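The bit shifts in `metrics_format` are integer division by powers of two: `v >> 20` turns a byte count into MB (MiB), and `int(v) >> 30` turns a flos count into units of 2**30 floating-point operations, which the method labels GF. A quick check:

```python
# The shifts used by metrics_format are integer division by 2**20 and 2**30.
bytes_used = 512 * 2**20
total_flos = 7 * 2**30

assert (bytes_used >> 20) == bytes_used // 2**20 == 512  # rendered as "512MB"
assert (total_flos >> 30) == total_flos // 2**30 == 7    # rendered as "7GF"
```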
Expand Down Expand Up @@ -1542,6 +1579,9 @@ def evaluate(
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state.
"""
# memory metrics - must set up as early as possible
self._memory_tracker.start()

if eval_dataset is not None and not isinstance(eval_dataset, collections.abc.Sized):
raise ValueError("eval_dataset must implement __len__")

@@ -1567,6 +1607,9 @@
             xm.master_print(met.metrics_report())

         self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
+
+        self._memory_tracker.stop_and_update_metrics(output.metrics)
+
         return output.metrics

     def predict(
@@ -1602,6 +1645,9 @@ def predict(
             - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
               contained labels).
         """
+        # memory metrics - must set up as early as possible
+        self._memory_tracker.start()
+
         if test_dataset is not None and not isinstance(test_dataset, collections.abc.Sized):
             raise ValueError("test_dataset must implement __len__")

@@ -1612,6 +1658,9 @@
             test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
         )
         output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))
+
+        self._memory_tracker.stop_and_update_metrics(output.metrics)
+
         return output

     def prediction_loop(
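For orientation: the `_memory_tracker.start()` / `stop_and_update_metrics()` calls added above bracket each stage (`__init__`, `train`, `evaluate`, `predict`) and fold the measured deltas into that stage's metrics dict. `TrainerMemoryTracker` itself is defined in `src/transformers/trainer_utils.py`, one of the seven changed files not shown in this excerpt, so the sketch below only illustrates that bracket pattern with assumed key names; it is not the real implementation:

```python
# Toy illustration of the start/stop bracket pattern only -- the real
# TrainerMemoryTracker lives in trainer_utils.py and is not shown in this diff.
import torch


class ToyMemoryTracker:
    def __init__(self, skip_memory_metrics=False):
        # mirrors the no-gpu handling: do nothing when CUDA is unavailable
        self.skip = skip_memory_metrics or not torch.cuda.is_available()

    def start(self):
        if self.skip:
            return
        self.gpu_alloc_at_start = torch.cuda.memory_allocated()

    def stop_and_update_metrics(self, metrics=None):
        # called without a metrics dict at __init__ time, with one elsewhere
        if self.skip or metrics is None:
            return
        delta = torch.cuda.memory_allocated() - self.gpu_alloc_at_start
        metrics["eval_mem_gpu_alloc_delta"] = delta  # assumed key name, for illustration


tracker = ToyMemoryTracker()
tracker.start()
metrics = {"eval_loss": 0.0}  # ... an evaluation stage would run here ...
tracker.stop_and_update_metrics(metrics)
```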
