Switch to DistributedDataParallelC10d and bump version 0.5.0 -> 0.6.0
- no more FP16Trainer; we just have an FP16Optimizer wrapper
- most of the distributed code is moved to a new wrapper class called DistributedFairseqModel, which behaves like both DistributedDataParallel and a FairseqModel at the same time
- Trainer now requires an extra dummy_batch argument at initialization, which we do a forward/backward pass on when workers have an uneven number of batches. We hide the gradients from these dummy batches by multiplying the loss by 0 (see the sketch after this list)
- Trainer.train_step now takes a list of samples, which will allow cleaner --update-freq handling
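
A minimal sketch of the dummy-batch trick described above, assuming a criterion with the (model, sample) -> (loss, sample_size, logging_output) signature used elsewhere in this commit; the helper name and arguments are illustrative, not the actual Trainer code:

    # Sketch: when a worker runs out of real batches it still runs fwd/bwd on a
    # cached dummy batch so the collective communication stays in lockstep, but
    # multiplies the loss by 0 so no real gradient is contributed.
    def _train_step_sketch(model, criterion, samples, dummy_batch):
        for sample in samples:
            ignore_grad = sample is None          # this worker ran out of data
            if ignore_grad:
                sample = dummy_batch
            loss, sample_size, logging_output = criterion(model, sample)
            if ignore_grad:
                loss = loss * 0                   # hide the dummy gradients
            loss.backward()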
edunov authored and myleott committed Sep 25, 2018
1 parent d8070c7 commit 4908863
Showing 20 changed files with 590 additions and 427 deletions.
2 changes: 1 addition & 1 deletion distributed_train.py
@@ -30,7 +30,7 @@ def main(args):
raise e
except FileNotFoundError as e: # Slurm is not installed
pass
if args.distributed_init_method is None:
if args.distributed_init_method is None and args.distributed_port is None:
raise ValueError('--distributed-init-method or --distributed-port '
'must be specified for distributed training')

4 changes: 2 additions & 2 deletions docs/conf.py
@@ -60,9 +60,9 @@
# built documents.
#
# The short X.Y version.
version = '0.5.0'
version = '0.6.0'
# The full version, including alpha/beta/rc tags.
release = '0.5.0'
release = '0.6.0'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
2 changes: 2 additions & 0 deletions docs/data.rst
@@ -36,5 +36,7 @@ Iterators
:members:
.. autoclass:: fairseq.data.EpochBatchIterator
:members:
.. autoclass:: fairseq.data.GroupedIterator
:members:
.. autoclass:: fairseq.data.ShardedIterator
:members:
4 changes: 4 additions & 0 deletions fairseq/criterions/adaptive_loss.py
@@ -54,6 +54,7 @@ def forward(self, model, sample, reduce=True):
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}
return loss, sample_size, logging_output
@@ -63,9 +64,12 @@ def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
if sample_size != ntokens:
4 changes: 4 additions & 0 deletions fairseq/criterions/cross_entropy.py
@@ -37,6 +37,7 @@ def forward(self, model, sample, reduce=True):
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}
return loss, sample_size, logging_output
@@ -46,9 +47,12 @@ def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
if sample_size != ntokens:
5 changes: 4 additions & 1 deletion fairseq/criterions/label_smoothed_cross_entropy.py
@@ -40,6 +40,7 @@ def forward(self, model, sample, reduce=True):
'loss': utils.item(loss.data) if reduce else loss.data,
'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}
return loss, sample_size, logging_output
@@ -58,14 +59,16 @@ def compute_loss(self, model, net_output, sample, reduce=True):
loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
return loss, nll_loss


@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
10 changes: 8 additions & 2 deletions fairseq/data/__init__.py
@@ -12,18 +12,24 @@
from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset

from .iterators import CountingIterator, EpochBatchIterator, ShardedIterator
from .iterators import (
CountingIterator,
EpochBatchIterator,
GroupedIterator,
ShardedIterator,
)

__all__ = [
'CountingIterator',
'Dictionary',
'EpochBatchIterator',
'FairseqDataset',
'GroupedIterator',
'IndexedDataset',
'IndexedInMemoryDataset',
'IndexedRawTextDataset',
'LanguagePairDataset',
'MonolingualDataset',
'TokenBlockDataset',
'ShardedIterator',
'TokenBlockDataset',
]
31 changes: 31 additions & 0 deletions fairseq/data/iterators.py
@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.

import itertools
import math

import numpy as np
import torch
@@ -150,6 +151,36 @@ def _get_iterator_for_epoch(self, epoch, shuffle):
))


class GroupedIterator(object):
"""Wrapper around an iterable that returns groups (chunks) of items.
Args:
iterable (iterable): iterable to wrap
chunk_size (int): size of each chunk
"""

def __init__(self, iterable, chunk_size):
self._len = int(math.ceil(len(iterable) / float(chunk_size)))
self.itr = iter(iterable)
self.chunk_size = chunk_size

def __len__(self):
return self._len

def __iter__(self):
return self

def __next__(self):
chunk = []
try:
for _ in range(self.chunk_size):
chunk.append(next(self.itr))
except StopIteration as e:
if len(chunk) == 0:
raise e
return chunk


class ShardedIterator(object):
"""A sharded wrapper around an iterable, padded to length.
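
The GroupedIterator added above chunks an iterator of batches into fixed-size groups, which is what lets Trainer.train_step receive a list of samples for --update-freq. A small usage sketch (the toy range iterable and update_freq value are made up for illustration):

    from fairseq.data import GroupedIterator

    update_freq = 4
    grouped = GroupedIterator(range(10), update_freq)  # any sized iterable works
    print(len(grouped))     # 3 == ceil(10 / 4)
    for chunk in grouped:
        print(chunk)        # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]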
119 changes: 90 additions & 29 deletions fairseq/distributed_utils.py
@@ -7,7 +7,9 @@

import pickle

import torch.distributed
import torch
from torch import distributed
from torch.distributed import group

from fairseq import utils

@@ -16,22 +18,39 @@ def is_master(args):
return args.distributed_rank == 0


_use_c10d = [None]


def distributed_init(args):
if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1')

if _use_c10d[0] is None:
_use_c10d[0] = not args.no_c10d

if _use_c10d[0] and not hasattr(torch.nn.parallel, '_DistributedDataParallelC10d'):
_use_c10d[0] = False
print('WARNING: cannot find DistributedDataParallelC10d, '
'falling back to standard DistributedDataParallel')

print('| distributed init (rank {}): {}'.format(
args.distributed_rank, args.distributed_init_method), flush=True)
if args.distributed_init_method.startswith('tcp://'):
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method=args.distributed_init_method,
world_size=args.distributed_world_size, rank=args.distributed_rank)

if _use_c10d[0]:
distributed.c10d.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)
else:
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method=args.distributed_init_method,
world_size=args.distributed_world_size)
distributed.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)

args.distributed_rank = torch.distributed.get_rank()
if not is_master(args):
suppress_output()

@@ -52,35 +71,77 @@ def print(*args, **kwargs):
__builtin__.print = print


def all_gather_list(data, max_size=16384):
"""Gathers arbitrary data from all nodes into a list."""
world_size = torch.distributed.get_world_size()
if not hasattr(all_gather_list, '_in_buffer') or \
max_size != all_gather_list._in_buffer.size():
all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
all_gather_list._out_buffers = [
torch.cuda.ByteTensor(max_size)
for i in range(world_size)
]
in_buffer = all_gather_list._in_buffer
out_buffers = all_gather_list._out_buffers
def get_rank():
if _use_c10d[0]:
return distributed.c10d.get_rank()
else:
return distributed.get_rank()


def get_world_size():
if _use_c10d[0]:
return distributed.c10d.get_world_size()
else:
return distributed.get_world_size()


def get_default_group():
if _use_c10d[0]:
return distributed.c10d.group.WORLD
else:
return distributed.group.WORLD


def all_reduce(tensor, group=None):
if group is None:
group = get_default_group()
if _use_c10d[0]:
return distributed.c10d.all_reduce(tensor, group=group)
else:
return distributed.all_reduce(tensor, group=group)


def all_gather_list(data, group=None, max_size=16384):
"""Gathers arbitrary data from all nodes into a list.
Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
data. Note that *data* must be picklable.
Args:
data (Any): data from the local worker to be gathered on other workers
group (optional): group of the collective
max_size (int, optional): maximum size of the data to be gathered
across workers
"""
rank = get_rank()
world_size = get_world_size()

buffer_size = max_size * world_size
if not hasattr(all_gather_list, '_buffer') or \
all_gather_list._buffer.numel() < buffer_size:
all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
buffer = all_gather_list._buffer
buffer.zero_()

enc = pickle.dumps(data)
enc_size = len(enc)
if enc_size + 2 > max_size:
raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
assert max_size < 255*256
in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
in_buffer[1] = enc_size % 255
in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))

torch.distributed.all_gather(out_buffers, in_buffer.cuda())
buffer_rank = buffer[rank * max_size : (rank + 1) * max_size]
buffer_rank[0] = enc_size // 255 # this encoding works for max_size < 65k
buffer_rank[1] = enc_size % 255
buffer_rank[2:enc_size+2] = torch.ByteTensor(list(enc))

all_reduce(buffer, group=group)

result = []
for i in range(world_size):
out_buffer = out_buffers[i]
out_buffer = buffer[i * max_size : (i + 1) * max_size]
size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
result.append(
pickle.loads(bytes(out_buffer[2:size+2].tolist()))
)
if size > 0:
result.append(
pickle.loads(bytes(out_buffer[2:size+2].tolist()))
)
return result
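
The rewritten all_gather_list packs each worker's pickled payload into its own max_size slice of a single zeroed CUDA buffer, prefixed by a two-byte length header, and then uses all_reduce as a gather: since every other rank's slice is zero, summing the buffers reassembles all payloads. A plain-Python sketch of just the header encoding and decoding (using a bytearray instead of a CUDA ByteTensor):

    import pickle

    def encode_slot(data, max_size=16384):
        # Mirror the 2-byte length header used above (valid while max_size < 255*256).
        enc = pickle.dumps(data)
        assert len(enc) + 2 <= max_size and max_size < 255 * 256
        slot = bytearray(max_size)
        slot[0] = len(enc) // 255
        slot[1] = len(enc) % 255
        slot[2:len(enc) + 2] = enc
        return slot

    def decode_slot(slot):
        size = 255 * slot[0] + slot[1]
        return pickle.loads(bytes(slot[2:size + 2])) if size > 0 else None

    print(decode_slot(encode_slot({'loss': 1.5, 'ntokens': 32})))  # {'loss': 1.5, 'ntokens': 32}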
