diff --git a/fairseq/fp16_trainer.py b/fairseq/fp16_trainer.py
index cd11c0fb98..2b0736e287 100644
--- a/fairseq/fp16_trainer.py
+++ b/fairseq/fp16_trainer.py
@@ -11,7 +11,7 @@
 
 import torch
 
-from fairseq import optim
+from fairseq import optim, utils
 from fairseq.meters import AverageMeter
 from fairseq.optim import lr_scheduler
 from fairseq.trainer import Trainer
@@ -105,8 +105,26 @@ def _all_reduce_and_rescale(self, grad_denom):
         # undo effect of dynamic loss scaling on gradients
         grad_denom *= self.scaler.loss_scale
 
-        # all-reduce and rescale gradients
-        grad_norm = super()._all_reduce_and_rescale(grad_denom)
+        if self.args.distributed_world_size > 1:
+            # flatten grads into a single buffer
+            flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
+
+            # scale gradients to avoid overflow in all-reduce
+            flat_grads.div_(self.args.distributed_world_size)
+            grad_denom /= self.args.distributed_world_size
+
+            # all-reduce flat grads
+            torch.distributed.all_reduce(flat_grads)
+
+            # copy grads back to FP32
+            self.fp32_params.grad.data.copy_(flat_grads)
+        else:
+            # single worker: copy grads directly to FP32
+            self._get_flat_grads(out=self.fp32_params.grad.data)
+
+        # rescale and clip grads
+        self.fp32_params.grad.data.div_(grad_denom)
+        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)
 
         # detect overflow and adjust loss scale
         overflow = DynamicLossScaler.has_overflow(grad_norm)
@@ -116,15 +134,6 @@ def _all_reduce_and_rescale(self, grad_denom):
 
         return grad_norm
 
-    def _get_flat_grads(self, out=None):
-        if out is None:
-            out = self.fp32_params.grad
-        return super()._get_flat_grads(out)
-
-    def _set_flat_grads(self, new_grads):
-        # no-op
-        assert new_grads.data_ptr() == self.fp32_params.grad.data.data_ptr()
-
     def _opt(self):
         # take an optimization step using the FP32 params and grads
         super()._opt()
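
The new gradient path in this diff is: flatten the FP16 gradients into one buffer, pre-divide by the world size so the all-reduce sum cannot overflow FP16, all-reduce, copy the result into the FP32 master gradient buffer, then rescale by the sample-count denominator and clip in full precision before the optimizer step. Below is a minimal standalone sketch of that sequence, not the fairseq implementation itself; the function and argument names (all_reduce_and_rescale_sketch, fp16_params, fp32_flat_grad, grad_denom, loss_scale) are hypothetical, and the clipping and overflow checks are simplified stand-ins for utils.clip_grad_norm_ and DynamicLossScaler.has_overflow.

import torch


def all_reduce_and_rescale_sketch(fp16_params, fp32_flat_grad, grad_denom,
                                  loss_scale, world_size=1, clip_norm=0.0):
    # undo the dynamic loss scale that was multiplied into the loss
    grad_denom = grad_denom * loss_scale

    # flatten the FP16 grads into one contiguous buffer
    flat_grads = torch.cat([p.grad.detach().reshape(-1) for p in fp16_params])

    if world_size > 1:
        # pre-divide so the summed all-reduce result cannot overflow FP16
        flat_grads.div_(world_size)
        grad_denom /= world_size
        torch.distributed.all_reduce(flat_grads)

    # copy into the FP32 master grads, then rescale and clip in full precision
    fp32_flat_grad.copy_(flat_grads)
    fp32_flat_grad.div_(grad_denom)

    grad_norm = fp32_flat_grad.norm().item()
    if clip_norm > 0 and grad_norm > clip_norm:
        fp32_flat_grad.mul_(clip_norm / (grad_norm + 1e-6))

    # an inf/nan norm signals overflow, which drives dynamic loss scaling
    overflow = grad_norm == float('inf') or grad_norm != grad_norm
    return grad_norm, overflow

In this sketch the single-worker case simply skips the all-reduce and writes the flat grads straight into the FP32 buffer, mirroring the diff's else branch that calls self._get_flat_grads(out=self.fp32_params.grad.data).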