Skip to content

Commit

Permalink
Fix nan-loss condition (#13911)
Browse files Browse the repository at this point in the history
  • Loading branch information
anton-l authored and LysandreJik committed Oct 6, 2021
1 parent bb2caca commit 3202896
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/transformers/trainer.py
Expand Up @@ -1315,10 +1315,13 @@ def train(
else:
tr_loss_step = self.training_step(model, inputs)

if args.logging_nan_inf_filter and not is_torch_tpu_available():
if torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step):
# if loss is nan or inf simply add the average of previous logged losses
tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
if (
args.logging_nan_inf_filter
and not is_torch_tpu_available()
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
):
# if loss is nan or inf simply add the average of previous logged losses
tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
else:
tr_loss += tr_loss_step

Expand Down

0 comments on commit 3202896

Please sign in to comment.