All-reduce in FP16 (#287)
myleott committed May 23, 2018
1 parent bd110fd commit a848a00
Showing 1 changed file with 21 additions and 12 deletions.
fairseq/fp16_trainer.py: 21 additions, 12 deletions
@@ -11,7 +11,7 @@
 
 import torch
 
-from fairseq import optim
+from fairseq import optim, utils
 from fairseq.meters import AverageMeter
 from fairseq.optim import lr_scheduler
 from fairseq.trainer import Trainer
@@ -105,8 +105,26 @@ def _all_reduce_and_rescale(self, grad_denom):
         # undo effect of dynamic loss scaling on gradients
         grad_denom *= self.scaler.loss_scale
 
-        # all-reduce and rescale gradients
-        grad_norm = super()._all_reduce_and_rescale(grad_denom)
+        if self.args.distributed_world_size > 1:
+            # flatten grads into a single buffer
+            flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
+
+            # scale gradients to avoid overflow in all-reduce
+            flat_grads.div_(self.args.distributed_world_size)
+            grad_denom /= self.args.distributed_world_size
+
+            # all-reduce flat grads
+            torch.distributed.all_reduce(flat_grads)
+
+            # copy grads back to FP32
+            self.fp32_params.grad.data.copy_(flat_grads)
+        else:
+            # single worker: copy grads directly to FP32
+            self._get_flat_grads(out=self.fp32_params.grad.data)
+
+        # rescale and clip grads
+        self.fp32_params.grad.data.div_(grad_denom)
+        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)
 
         # detect overflow and adjust loss scale
         overflow = DynamicLossScaler.has_overflow(grad_norm)
@@ -116,15 +134,6 @@ def _all_reduce_and_rescale(self, grad_denom):
 
         return grad_norm
 
-    def _get_flat_grads(self, out=None):
-        if out is None:
-            out = self.fp32_params.grad
-        return super()._get_flat_grads(out)
-
-    def _set_flat_grads(self, new_grads):
-        # no-op
-        assert new_grads.data_ptr() == self.fp32_params.grad.data.data_ptr()
-
     def _opt(self):
         # take an optimization step using the FP32 params and grads
         super()._opt()
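
The sketch below restates the pattern this commit switches to, outside of fairseq's Trainer class hierarchy: all-reduce the FP16 gradients from a single flat buffer, then rescale and clip an FP32 copy before the optimizer step. It is a minimal illustration under assumed names (sync_fp16_grads, fp32_flat_grad, and the hand-rolled clipping), not fairseq's actual API.

# Minimal sketch (assumed names, not fairseq's implementation) of the
# FP16 all-reduce pattern in this commit. Assumes torch.distributed is
# already initialized whenever world_size > 1.
import torch
import torch.distributed as dist

def sync_fp16_grads(fp16_params, fp32_flat_grad, grad_denom, loss_scale,
                    world_size, clip_norm):
    # undo the effect of dynamic loss scaling on the denominator
    # (the FP16 grads were computed from a loss multiplied by loss_scale)
    grad_denom = grad_denom * loss_scale

    # flatten the FP16 grads into one contiguous buffer so the whole model
    # is synchronized with a single all-reduce
    flat_fp16 = torch.cat([p.grad.data.contiguous().view(-1) for p in fp16_params])

    if world_size > 1:
        # pre-divide so the FP16 sum across workers cannot overflow,
        # and shrink the denominator by the same factor to compensate
        flat_fp16.div_(world_size)
        grad_denom /= world_size
        dist.all_reduce(flat_fp16)

    # copy into the FP32 master-gradient buffer, then rescale and clip there
    fp32_flat_grad.copy_(flat_fp16)
    fp32_flat_grad.div_(grad_denom)

    grad_norm = float(fp32_flat_grad.norm())
    if clip_norm > 0 and grad_norm > clip_norm:
        fp32_flat_grad.mul_(clip_norm / (grad_norm + 1e-6))

    # an inf/nan norm here is the overflow signal that a dynamic loss
    # scaler uses to lower the loss scale and skip the step
    return grad_norm

The order of operations is the point of the commit: dividing by distributed_world_size before the all-reduce keeps the summed FP16 gradients inside the representable range, and reducing grad_denom by the same factor leaves the final FP32 rescaling unchanged.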
