From 5b47e7d2aca3257f06dc7131da3a04869db3bf3b Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 29 Oct 2020 10:04:38 -0400 Subject: [PATCH 1/2] Smarter prediction loop and no- -> no_ in console args --- examples/test_examples.py | 1 - src/transformers/benchmark/benchmark_args.py | 2 +- src/transformers/hf_argparser.py | 2 +- src/transformers/trainer.py | 27 +++++++++++++------- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/examples/test_examples.py b/examples/test_examples.py index f5252cdd63d41..240e7f010ad47 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -147,7 +147,6 @@ def test_run_clm(self): --num_train_epochs 2 --output_dir {tmp_dir} --overwrite_output_dir - --prediction_loss_only """.split() if torch.cuda.device_count() > 1: diff --git a/src/transformers/benchmark/benchmark_args.py b/src/transformers/benchmark/benchmark_args.py index d23880a9dc799..28f92eab1addf 100644 --- a/src/transformers/benchmark/benchmark_args.py +++ b/src/transformers/benchmark/benchmark_args.py @@ -55,7 +55,7 @@ def __init__(self, **kwargs): positive_arg = deprecated_arg[3:] setattr(self, positive_arg, not kwargs.pop(deprecated_arg)) logger.warning( - f"{deprecated_arg} is depreciated. Please use --no-{positive_arg} or {positive_arg}={kwargs[positive_arg]}" + f"{deprecated_arg} is deprecated. 
Please use --no_{positive_arg} or {positive_arg}={kwargs[positive_arg]}" ) self.torchscript = kwargs.pop("torchscript", self.torchscript) diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index 72c9c45c82959..20d5f96ba3d5a 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -66,7 +66,7 @@ def _add_dataclass_arguments(self, dtype: DataClassType): if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING): kwargs["action"] = "store_false" if field.default is True else "store_true" if field.default is True: - field_name = f"--no-{field.name}" + field_name = f"--no_{field.name}" kwargs["dest"] = field.name elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List): kwargs["nargs"] = "+" diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 69b346d063c9e..67fb183c2ab2a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1300,7 +1300,13 @@ def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]: eval_dataloader = self.get_eval_dataloader(eval_dataset) - output = self.prediction_loop(eval_dataloader, description="Evaluation") + output = self.prediction_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if self.compute_metrics is None else None, + ) self.log(output.metrics) @@ -1382,8 +1388,9 @@ def prediction_loop( world_size = max(1, world_size) eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) - preds_gatherer = DistributedTensorGatherer(world_size, num_examples) - labels_gatherer = DistributedTensorGatherer(world_size, num_examples) + if not prediction_loss_only: + preds_gatherer = DistributedTensorGatherer(world_size, num_examples) + labels_gatherer = 
DistributedTensorGatherer(world_size, num_examples) model.eval() @@ -1409,8 +1416,9 @@ def prediction_loop( # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0: eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) - preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) - labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + if not prediction_loss_only: + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) # Set back to None to begin a new accumulation losses_host, preds_host, labels_host = None, None, None @@ -1421,12 +1429,13 @@ def prediction_loop( # Gather all remaining tensors and put them back on the CPU eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) - preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) - labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + if not prediction_loss_only: + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) eval_loss = eval_losses_gatherer.finalize() - preds = preds_gatherer.finalize() - label_ids = labels_gatherer.finalize() + preds = preds_gatherer.finalize() if not prediction_loss_only else None + label_ids = labels_gatherer.finalize() if not prediction_loss_only else None if self.compute_metrics is not None and preds is not None and label_ids is not None: metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) From c0d35484dd5ebf43d26afc2aca14092f5d90a3fe Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 29 Oct 2020 10:28:50 -0400 Subject: 
[PATCH 2/2] Fix test --- tests/test_hf_argparser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_hf_argparser.py b/tests/test_hf_argparser.py index 3c219d0b6f3a0..c42e2cf8dcbbd 100644 --- a/tests/test_hf_argparser.py +++ b/tests/test_hf_argparser.py @@ -93,13 +93,13 @@ def test_with_default_bool(self): expected = argparse.ArgumentParser() expected.add_argument("--foo", action="store_true") - expected.add_argument("--no-baz", action="store_false", dest="baz") + expected.add_argument("--no_baz", action="store_false", dest="baz") self.argparsersEqual(parser, expected) args = parser.parse_args([]) self.assertEqual(args, Namespace(foo=False, baz=True)) - args = parser.parse_args(["--foo", "--no-baz"]) + args = parser.parse_args(["--foo", "--no_baz"]) self.assertEqual(args, Namespace(foo=True, baz=False)) def test_with_enum(self):