
make tr_loss regular float
bminixhofer committed Nov 16, 2020
1 parent daaa684 commit c9d7ccf
Showing 1 changed file with 6 additions and 7 deletions.
src/transformers/trainer.py (6 additions, 7 deletions)
@@ -727,7 +727,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
         self.state.is_local_process_zero = self.is_local_process_zero()
         self.state.is_world_process_zero = self.is_world_process_zero()
 
-        tr_loss = torch.tensor(0.0).to(self.args.device)
+        tr_loss = 0.0
         self._logging_loss_scalar = 0
         self._globalstep_last_logged = 0
         self._total_flos = self.state.total_flos
@@ -770,9 +770,9 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
                     and _use_ddp_no_sync
                 ):
                     with model.no_sync():
-                        tr_loss += self.training_step(model, inputs)
+                        tr_loss += self.training_step(model, inputs).item()
                 else:
-                    tr_loss += self.training_step(model, inputs)
+                    tr_loss += self.training_step(model, inputs).item()
                 self._total_flos += self.floating_point_ops(inputs)
 
                 if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
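
In short, training_step returns a detached single-element loss tensor; calling .item() on it at every step keeps tr_loss a plain Python float instead of a device tensor. A minimal sketch of the new accumulation pattern (toy standalone values, not the Trainer itself):

import torch

# Keep the running loss as a regular float by converting each step's
# single-element loss tensor with .item(), as this commit does.
tr_loss = 0.0  # was: torch.tensor(0.0).to(self.args.device)

for step in range(4):
    # stand-in for self.training_step(model, inputs), which returns a
    # detached single-element loss tensor
    step_loss = torch.tensor(0.5)
    tr_loss += step_loss.item()  # .item() -> Python float

print(tr_loss)  # 2.0, a regular float
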
@@ -844,13 +844,12 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
 
         self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
 
-        return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
+        return TrainOutput(self.state.global_step, tr_loss / self.state.global_step)
 
     def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
         if self.control.should_log:
             logs: Dict[str, float] = {}
-            tr_loss_scalar = tr_loss.item()
-            logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / (
+            logs["loss"] = (tr_loss - self._logging_loss_scalar) / (
                 self.state.global_step - self._globalstep_last_logged
             )
             # backward compatibility for pytorch schedulers
@@ -859,7 +858,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
                 if version.parse(torch.__version__) >= version.parse("1.4")
                 else self.lr_scheduler.get_lr()[0]
             )
-            self._logging_loss_scalar = tr_loss_scalar
+            self._logging_loss_scalar = tr_loss
             self._globalstep_last_logged = self.state.global_step
 
             self.log(logs)
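
With tr_loss now a plain float, the "loss" that _maybe_log_save_evaluate logs is ordinary arithmetic: the cumulative loss accumulated since the last log, divided by the number of steps since then. A worked example with hypothetical standalone values (not the Trainer attributes themselves):

# mean per-step loss over the window since the last log
tr_loss = 12.0                # cumulative loss after 20 steps
logging_loss_scalar = 9.0     # cumulative loss when last logged, at step 15
global_step = 20
globalstep_last_logged = 15

loss = (tr_loss - logging_loss_scalar) / (global_step - globalstep_last_logged)
print(loss)  # 0.6 -> average loss per step over the last 5 steps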
