Skip to content

Commit

Permalink
Smarter prediction loop and no- -> no_ in console args (#8151)
Browse files Browse the repository at this point in the history
* Smarter prediction loop and no- -> no_ in console args

* Fix test
  • Loading branch information
sgugger committed Oct 29, 2020
1 parent b0f1c0e commit acf5640
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 14 deletions.
1 change: 0 additions & 1 deletion examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ def test_run_clm(self):
--num_train_epochs 2
--output_dir {tmp_dir}
--overwrite_output_dir
--prediction_loss_only
""".split()

if torch.cuda.device_count() > 1:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/benchmark/benchmark_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, **kwargs):
positive_arg = deprecated_arg[3:]
setattr(self, positive_arg, not kwargs.pop(deprecated_arg))
logger.warning(
f"{deprecated_arg} is depreciated. Please use --no-{positive_arg} or {positive_arg}={kwargs[positive_arg]}"
f"{deprecated_arg} is depreciated. Please use --no_{positive_arg} or {positive_arg}={kwargs[positive_arg]}"
)

self.torchscript = kwargs.pop("torchscript", self.torchscript)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _add_dataclass_arguments(self, dtype: DataClassType):
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
kwargs["action"] = "store_false" if field.default is True else "store_true"
if field.default is True:
field_name = f"--no-{field.name}"
field_name = f"--no_{field.name}"
kwargs["dest"] = field.name
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
kwargs["nargs"] = "+"
Expand Down
27 changes: 18 additions & 9 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,7 +1300,13 @@ def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:

eval_dataloader = self.get_eval_dataloader(eval_dataset)

output = self.prediction_loop(eval_dataloader, description="Evaluation")
output = self.prediction_loop(
eval_dataloader,
description="Evaluation",
# No point gathering the predictions if there are no metrics to compute; otherwise pass None to
# defer to self.args.prediction_loss_only
prediction_loss_only=True if self.compute_metrics is None else None,
)

self.log(output.metrics)

Expand Down Expand Up @@ -1382,8 +1388,9 @@ def prediction_loop(
world_size = max(1, world_size)

eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
if not prediction_loss_only:
preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
labels_gatherer = DistributedTensorGatherer(world_size, num_examples)

model.eval()

Expand All @@ -1409,8 +1416,9 @@ def prediction_loop(
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
if not prediction_loss_only:
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

# Set back to None to begin a new accumulation
losses_host, preds_host, labels_host = None, None, None
Expand All @@ -1421,12 +1429,13 @@ def prediction_loop(

# Gather all remaining tensors and put them back on the CPU
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
if not prediction_loss_only:
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

eval_loss = eval_losses_gatherer.finalize()
preds = preds_gatherer.finalize()
label_ids = labels_gatherer.finalize()
preds = preds_gatherer.finalize() if not prediction_loss_only else None
label_ids = labels_gatherer.finalize() if not prediction_loss_only else None

if self.compute_metrics is not None and preds is not None and label_ids is not None:
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
Expand Down
4 changes: 2 additions & 2 deletions tests/test_hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ def test_with_default_bool(self):

expected = argparse.ArgumentParser()
expected.add_argument("--foo", action="store_true")
expected.add_argument("--no-baz", action="store_false", dest="baz")
expected.add_argument("--no_baz", action="store_false", dest="baz")
self.argparsersEqual(parser, expected)

args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=False, baz=True))

args = parser.parse_args(["--foo", "--no-baz"])
args = parser.parse_args(["--foo", "--no_baz"])
self.assertEqual(args, Namespace(foo=True, baz=False))

def test_with_enum(self):
Expand Down

0 comments on commit acf5640

Please sign in to comment.