added fast stats sync option (facebookresearch#858)
Summary:
Added `--fast-stat-sync` option.
This avoids pickle and achieves `~7%` more `wps` on 16 nodes.
It is less flexible, as it aggregates only basic stats and ignores the aggregate function defined by the criterion.

Let me know what you think, myleott.
Pull Request resolved: fairinternal/fairseq-py#858

Differential Revision: D17398770

fbshipit-source-id: 36261a1d970e67deeda8211af8f009ef9b4f9c14
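The summary above is the whole mechanism in a nutshell: instead of pickling every worker's logging_output dict and gathering the bytes, each worker packs a fixed set of scalars into a flat buffer, and the buffers are summed elementwise with a single all_reduce. A minimal single-process sketch of that equivalence (the pack() helper, the key order, and the worker numbers below are illustrative, not fairseq internals):

# Minimal single-process sketch of the idea behind --fast-stat-sync; the pack()
# helper, key order, and worker values are illustrative, not fairseq internals.
import pickle

KEYS = ['sample_size', 'nsentences', 'loss', 'nll_loss', 'ntokens', 'ooms']

def pack(logging_output, sample_size, ooms):
    # Flatten the handful of stats we care about into a fixed-length buffer.
    return [
        float(sample_size),
        float(logging_output.get('nsentences', 0.0)),
        float(logging_output.get('loss', 0.0)),
        float(logging_output.get('nll_loss', 0.0)),
        float(logging_output.get('ntokens', 0.0)),
        float(ooms),
    ]

# Two fake workers' per-step stats.
w0 = pack({'nsentences': 8, 'loss': 5.2, 'nll_loss': 5.2, 'ntokens': 400}, 400, 0)
w1 = pack({'nsentences': 8, 'loss': 4.8, 'nll_loss': 4.8, 'ntokens': 380}, 380, 0)

# Slow path: serialize each buffer (a stand-in for all_gather_list's pickling),
# gather, then sum key by key on every worker.
gathered = [pickle.loads(pickle.dumps(w)) for w in (w0, w1)]
slow = [sum(col) for col in zip(*gathered)]

# Fast path: elementwise sum, which is exactly what one all_reduce computes.
fast = [a + b for a, b in zip(w0, w1)]

assert slow == fast
print(dict(zip(KEYS, fast)))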
Naman Goyal authored and facebook-github-bot committed Sep 16, 2019
1 parent 1fd8943 commit e1ba32a
Showing 4 changed files with 66 additions and 11 deletions.
1 change: 1 addition & 0 deletions fairseq/criterions/cross_entropy.py
@@ -30,6 +30,7 @@ def forward(self, model, sample, reduce=True):
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'nll_loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
1 change: 1 addition & 0 deletions fairseq/criterions/masked_lm.py
@@ -47,6 +47,7 @@ def forward(self, model, sample, reduce=True):

        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'nll_loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
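Both criterion changes add an `nll_loss` entry that, for plain cross-entropy, simply mirrors `loss`; this guarantees the fixed set of keys the fast path reads is always present. A sketch of the resulting logging_output for one worker, with invented numbers:

# Illustrative only: the shape of a logging_output dict after this change,
# with made-up numbers; 'nll_loss' mirrors 'loss' for plain cross-entropy.
logging_output = {
    'loss': 5.2,        # summed training loss (invented value)
    'nll_loss': 5.2,    # same value, kept under the key the fast path expects
    'ntokens': 400,
    'nsentences': 8,
    'sample_size': 400,
}
assert logging_output['nll_loss'] == logging_output['loss']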
3 changes: 3 additions & 0 deletions fairseq/options.py
@@ -332,6 +332,9 @@ def add_distributed_training_args(parser):
    group.add_argument('--find-unused-parameters', default=False, action='store_true',
                       help='disable unused parameter detection (not applicable to '
                            'no_c10d ddp-backend)')
    group.add_argument('--fast-stat-sync', default=False, action='store_true',
                       help='enable fast sync of stats between nodes; this hardcodes to '
                            'sync only some default stats from logging_output.')
    # fmt: on
    return group
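For reference, a standalone sketch of how the new flag behaves, using plain argparse rather than the real fairseq parser:

# Standalone sketch of the new flag's behaviour; this is plain argparse,
# not the real fairseq options machinery.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--fast-stat-sync', default=False, action='store_true',
                    help='enable fast sync of stats between nodes; syncs only '
                         'some default stats from logging_output')

assert parser.parse_args([]).fast_stat_sync is False           # off by default
assert parser.parse_args(['--fast-stat-sync']).fast_stat_sync  # opt-in flag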

72 changes: 61 additions & 11 deletions fairseq/trainer.py
@@ -57,6 +57,11 @@ def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=Non
        self._wrapped_criterion = None
        self._wrapped_model = None

        # Fast stats sync avoids memcpy and is 7% faster when tested on 16 nodes.
        # It is less flexible and syncs only the default stats.
        self._all_reduce_list = [0.0] * 6
        self.fast_stat_sync = args.fast_stat_sync

        self.init_meters(args)

    def init_meters(self, args):
@@ -292,6 +297,13 @@ def maybe_no_sync():
                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)

                    if self.fast_stat_sync:
                        self._all_reduce_list[0] += sample_size
                        self._all_reduce_list[1] += logging_output.get('nsentences', 0.0)
                        self._all_reduce_list[2] += logging_output.get('loss', 0.0)
                        self._all_reduce_list[3] += logging_output.get('nll_loss', 0.0)
                        self._all_reduce_list[4] += logging_output.get('ntokens', 0.0)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    msg = (
@@ -311,20 +323,41 @@ def maybe_no_sync():
                else:
                    raise e

            if self.fast_stat_sync:
                self._all_reduce_list[5] += ooms

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        if self.fast_stat_sync:
            # rework all_gather_list
            all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
            if self._sync_stats():
                torch.distributed.all_reduce(all_reduce_list_tensor)
                # Normalize loss and nll_loss by "sample_size"
                # and convert to log base 2
                all_reduce_list_tensor[2:4].div_(
                    (
                        all_reduce_list_tensor[0:1] *
                        torch.log(torch.cuda.DoubleTensor([2]))
                    )
                )
            self._all_reduce_list = all_reduce_list_tensor.tolist()
            logging_output = {}
            [
                sample_size,
                logging_output['nsentences'],
                logging_output['loss'],
                logging_output['nll_loss'],
                logging_output['ntokens'],
                ooms,
            ] = self._all_reduce_list
        elif self._sync_stats():
            logging_outputs, sample_sizes, ooms, prev_norms = \
                zip(*distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
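The div_ above folds the usual normalization into the reduced tensor: slots 2:4 hold the summed loss and nll_loss, slot 0 holds the summed sample_size, and dividing by sample_size * ln(2) converts the average to log base 2 for the downstream meters. A small pure-Python check of that arithmetic, with made-up totals:

# Pure-Python check of the normalization applied to the reduced tensor above;
# the totals are made up. Dividing the summed loss by sample_size * ln(2)
# gives the average per unit of sample_size in log base 2.
import math

reduced = [800.0, 16.0, 555.0, 555.0, 780.0, 0.0]  # [sample_size, nsentences, loss, nll_loss, ntokens, ooms]
sample_size, loss_sum = reduced[0], reduced[2]

loss_base2 = loss_sum / (sample_size * math.log(2))
assert abs(loss_base2 - (loss_sum / sample_size) / math.log(2)) < 1e-12
print(round(loss_base2, 4))  # about 1.0009 with these numbers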
@@ -345,11 +378,12 @@ def maybe_no_sync():
            self.zero_grad()
            return None

        if not self.fast_stat_sync:
            # aggregate logging outputs and sample sizes
            logging_output = self.task.aggregate_logging_outputs(
                logging_outputs, self.get_criterion()
            )
            sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception((
@@ -400,6 +434,7 @@ def maybe_no_sync():
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.clear_buffered_stats()
        self.meters['train_wall'].stop()

        return logging_output
@@ -484,6 +519,9 @@ def handle_ooms(self, number_of_ooms):
    def zero_grad(self):
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        self._all_reduce_list = [0.0] * 6

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        self.lr_scheduler.step(epoch, val_loss)
@@ -545,3 +583,15 @@ def _set_seed(self):
        torch.manual_seed(seed)
        if self.cuda:
            torch.cuda.manual_seed(seed)

    def _sync_stats(self):
        return (
            self.args.distributed_world_size > 1 and
            (
                (not self.args.use_bmuf) or
                (
                    self.args.use_bmuf
                    and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
                )
            )
        )
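_sync_stats() centralizes the gating that previously wrapped all_gather_list: a single worker never syncs, with BMUF stats are synced only every global_sync_iter-th update, and otherwise they are synced on every update. A standalone restatement of that logic with assumed argument values (global_sync_iter = 50 is just an example):

# Standalone restatement of the gating in _sync_stats(); the argument values
# below are assumptions for illustration, not fairseq defaults.
def sync_stats(distributed_world_size, use_bmuf, num_updates, global_sync_iter):
    return distributed_world_size > 1 and (
        (not use_bmuf) or ((num_updates + 1) % global_sync_iter == 0)
    )

assert not sync_stats(1, False, 0, 50)    # single worker: never sync
assert sync_stats(16, False, 7, 50)       # multi-node, no BMUF: every update
assert not sync_stats(16, True, 48, 50)   # BMUF: most updates skip the sync
assert sync_stats(16, True, 49, 50)       # BMUF: sync on every 50th update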
