Fix gradient clipping for Sharded DDP (#9168)
* Fix gradient clipping for Sharded DDP

* Fix typos in comments
sgugger committed Dec 17, 2020
1 parent 1aca3d6 commit 77d6941
Showing 1 changed file with 17 additions and 8 deletions.
src/transformers/trainer.py
@@ -804,14 +804,23 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
     steps_in_epoch <= self.args.gradient_accumulation_steps
     and (step + 1) == steps_in_epoch
 ):
-    if self.use_amp:
-        self.scaler.unscale_(self.optimizer)
-        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
-    elif self.use_apex:
-        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
-    else:
-        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
+    # Gradient clipping
+    if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0:
+        if self.use_amp:
+            # AMP: gradients need unscaling
+            self.scaler.unscale_(self.optimizer)
+
+        if hasattr(self.optimizer, "clip_grad_norm"):
+            # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
+            self.optimizer.clip_grad_norm(self.args.max_grad_norm)
+        else:
+            # Revert to normal clipping otherwise, handling Apex or full precision
+            torch.nn.utils.clip_grad_norm_(
+                amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
+                self.args.max_grad_norm,
+            )
 
     # Optimizer step
     if is_torch_tpu_available():
         xm.optimizer_step(self.optimizer)
     elif self.use_amp:
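The fix hinges on the hasattr(self.optimizer, "clip_grad_norm") check: when training with Sharded DDP, fairscale's sharded optimizer exposes its own clip_grad_norm method, and since each rank typically only holds the gradients for its own parameter shard, a local torch.nn.utils.clip_grad_norm_ over model.parameters() would not compute the true global norm. Below is a minimal, self-contained sketch of that dispatch outside the Trainer (AMP/Apex branches omitted); ClippingOptimizer and clip_gradients are hypothetical stand-ins written for illustration, not part of transformers or fairscale.

import torch


class ClippingOptimizer(torch.optim.SGD):
    # Hypothetical stand-in for an optimizer (like fairscale's sharded optimizer)
    # that implements its own clip_grad_norm instead of relying on a purely
    # local torch.nn.utils.clip_grad_norm_.

    def clip_grad_norm(self, max_norm):
        # A real sharded optimizer would aggregate the norm across ranks;
        # here we simply clip the locally held parameters.
        params = [p for group in self.param_groups for p in group["params"]]
        return torch.nn.utils.clip_grad_norm_(params, max_norm)


def clip_gradients(model, optimizer, max_grad_norm):
    # Mirrors the dispatch added in this commit (AMP/Apex handling omitted).
    if max_grad_norm is None or max_grad_norm <= 0:
        return
    if hasattr(optimizer, "clip_grad_norm"):
        # Sharded-style optimizers know how to clip their own gradients.
        optimizer.clip_grad_norm(max_grad_norm)
    else:
        # Fall back to standard clipping over the model's parameters.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    optimizer = ClippingOptimizer(model.parameters(), lr=0.1)
    model(torch.randn(8, 4)).sum().backward()
    clip_gradients(model, optimizer, max_grad_norm=1.0)
    optimizer.step()

With fairscale installed, an actual sharded optimizer instance would take the same first branch, since it also exposes a clip_grad_norm method; this is what the new Trainer code relies on.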
