Initialize distributed using multiproc with all visible GPUs
Summary: Pull Request resolved: #695

Differential Revision: D15182613

Pulled By: myleott

fbshipit-source-id: 4196346517d8e75ed9e903e9e01ab943d086f6f1
myleott authored and facebook-github-bot committed May 5, 2019
1 parent 96ac28d commit cf17068
Showing 3 changed files with 45 additions and 12 deletions.
24 changes: 20 additions & 4 deletions fairseq/distributed_utils.py
@@ -8,6 +8,7 @@
 from collections import namedtuple
 import os
 import pickle
+import socket
 import subprocess
 import warnings

@@ -42,9 +43,20 @@ def infer_init_method(args):
                 hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
                 args.distributed_init_method = 'tcp://{host}:{port}'.format(
                     host=hostnames.split()[0].decode('utf-8'),
-                    port=args.distributed_port)
-                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
-                args.device_id = int(os.environ.get('SLURM_LOCALID'))
+                    port=args.distributed_port,
+                )
+                nnodes = int(os.environ.get('SLURM_NNODES'))
+                ntasks_per_node = int(os.environ.get('SLURM_NTASKS_PER_NODE'))
+                if ntasks_per_node == 1:
+                    assert args.distributed_world_size % nnodes == 0
+                    gpus_per_node = args.distributed_world_size // nnodes
+                    node_id = int(os.environ.get('SLURM_NODEID'))
+                    args.distributed_rank = node_id * gpus_per_node
+                else:
+                    assert ntasks_per_node == args.distributed_world_size // nnodes
+                    args.distributed_no_spawn = True
+                    args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
+                    args.device_id = int(os.environ.get('SLURM_LOCALID'))
             except subprocess.CalledProcessError as e:  # scontrol failed
                 raise e
             except FileNotFoundError:  # Slurm is not installed
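To make the rank arithmetic in the hunk above concrete, here is an illustrative sketch that is not part of the commit: it plugs made-up values for a 2-node, 8-GPU-per-node job launched with one Slurm task per node into the same branch logic. The variable names mirror the diff; the environment values and world size are assumptions for the example only.

# Illustrative only: fake Slurm environment for a 2-node x 8-GPU job,
# launched with a single task per node (one process that later spawns per GPU).
import os

os.environ['SLURM_NNODES'] = '2'
os.environ['SLURM_NTASKS_PER_NODE'] = '1'
os.environ['SLURM_NODEID'] = '1'              # pretend we are the second node

distributed_world_size = 16                   # 2 nodes x 8 GPUs (made up)

nnodes = int(os.environ['SLURM_NNODES'])
ntasks_per_node = int(os.environ['SLURM_NTASKS_PER_NODE'])
if ntasks_per_node == 1:
    # one task per node: this process only needs the first global rank
    # owned by its node; the per-GPU ranks are filled in after spawning
    gpus_per_node = distributed_world_size // nnodes              # 8
    distributed_rank = int(os.environ['SLURM_NODEID']) * gpus_per_node
    print(distributed_rank)                                       # -> 8
else:
    # one task per GPU: Slurm already hands out one global rank per task,
    # so spawning is skipped (distributed_no_spawn) and SLURM_PROCID is used
    distributed_rank = int(os.environ['SLURM_PROCID'])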
@@ -60,13 +72,17 @@ def distributed_init(args):
     else:
         print('| distributed init (rank {}): {}'.format(
             args.distributed_rank, args.distributed_init_method), flush=True)
-
         dist.init_process_group(
             backend=args.distributed_backend,
             init_method=args.distributed_init_method,
             world_size=args.distributed_world_size,
             rank=args.distributed_rank,
         )
+        print('| initialized host {} as rank {}'.format(
+            socket.gethostname(), args.distributed_rank), flush=True)
+
+        # perform a dummy all-reduce to initialize the NCCL communicator
+        dist.all_reduce(torch.rand(1).cuda())
 
     suppress_output(is_master(args))
 
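The init-plus-warm-up pattern added to distributed_init can be tried in isolation with the standalone sketch below. It is not the commit's code path: the backend, port, and world size are placeholders chosen so a single CPU-only process can run it with the gloo backend, whereas the commit uses the configured backend (typically NCCL) and a CUDA tensor so the communicator is created before training starts.

# Standalone sketch of init_process_group followed by a dummy all-reduce;
# backend, port, and world size are placeholder values for a one-process demo.
import socket
import torch
import torch.distributed as dist

dist.init_process_group(
    backend='gloo',                        # the commit uses args.distributed_backend
    init_method='tcp://localhost:29512',   # placeholder port
    world_size=1,
    rank=0,
)
print('| initialized host {} as rank {}'.format(socket.gethostname(), dist.get_rank()))

# the dummy all-reduce forces the process group / communicator to be fully set up
dist.all_reduce(torch.rand(1))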
2 changes: 2 additions & 0 deletions fairseq/options.py
@@ -266,6 +266,8 @@ def add_distributed_training_args(parser):
                        help='port number (not required if using --distributed-init-method)')
     group.add_argument('--device-id', '--local_rank', default=0, type=int,
                        help='which GPU to use (usually configured automatically)')
+    group.add_argument('--distributed-no-spawn', action='store_true',
+                       help='do not spawn multiple processes even if multiple GPUs are visible')
     group.add_argument('--ddp-backend', default='c10d', type=str,
                        choices=['c10d', 'no_c10d'],
                        help='DistributedDataParallel backend')
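As a quick check of how the new flag parses, the minimal argparse sketch below mirrors just this one option; it is illustrative only and not how fairseq builds its full parser.

# Minimal illustration of the new store_true flag: absent -> False, present -> True.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--distributed-no-spawn', action='store_true',
                    help='do not spawn multiple processes even if multiple GPUs are visible')

print(parser.parse_args([]).distributed_no_spawn)                          # False
print(parser.parse_args(['--distributed-no-spawn']).distributed_no_spawn)  # True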
31 changes: 23 additions & 8 deletions train.py
@@ -23,16 +23,21 @@
 from fairseq.meters import AverageMeter, StopwatchMeter
 
 
-def main(args):
+def main(args, init_distributed=False):
     utils.import_user_module(args)
 
-    if args.max_tokens is None:
-        args.max_tokens = 6000
-    print(args)
+    assert args.max_tokens is not None or args.max_sentences is not None, \
+        'Must specify batch size either with --max-tokens or --max-sentences'
 
+    # Initialize CUDA and distributed training
     if torch.cuda.is_available() and not args.cpu:
         torch.cuda.set_device(args.device_id)
     torch.manual_seed(args.seed)
+    if init_distributed:
+        args.distributed_rank = distributed_utils.distributed_init(args)
+
+    # Print args
+    print(args)
 
     # Setup task, e.g., translation, language modeling, etc.
     task = tasks.setup_task(args)
@@ -372,11 +377,11 @@ def load_dataset_splits(args, task):
                 raise e
 
 
-def distributed_main(i, args):
+def distributed_main(i, args, start_rank=0):
     args.device_id = i
     if args.distributed_rank is None:  # torch.multiprocessing.spawn
-        args.distributed_rank = i
-    main(args)
+        args.distributed_rank = start_rank + i
+    main(args, init_distributed=True)
 
 
 def cli_main():
@@ -388,9 +393,19 @@ def cli_main():
 
     if args.distributed_init_method is not None:
         # distributed training
-        distributed_main(args.device_id, args)
+        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
+            start_rank = args.distributed_rank
+            args.distributed_rank = None  # assign automatically
+            torch.multiprocessing.spawn(
+                fn=distributed_main,
+                args=(args, start_rank),
+                nprocs=torch.cuda.device_count(),
+            )
+        else:
+            distributed_main(args.device_id, args)
     elif args.distributed_world_size > 1:
         # fallback for single node with multiple GPUs
+        assert args.distributed_world_size <= torch.cuda.device_count()
         port = random.randint(10000, 20000)
         args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
         args.distributed_rank = None  # set based on device id
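Finally, a toy sketch of the spawn path in cli_main: torch.multiprocessing.spawn calls fn(i, *args) once per child process, which is why distributed_main can recover its device id from i and its global rank from start_rank + i. The worker below only prints instead of training, and the start rank, world size, and process count are made-up example values.

# Toy version of the spawn path: each child derives device_id and rank from its index.
import torch.multiprocessing as mp


def worker(i, start_rank, world_size):
    # mirrors distributed_main: device_id = i, global rank = start_rank + i
    print('child {}: device_id={}, rank={} of {}'.format(i, i, start_rank + i, world_size))


if __name__ == '__main__':
    start_rank = 8                                          # e.g. second node of a 2x8-GPU job
    mp.spawn(fn=worker, args=(start_rank, 16), nprocs=4)    # pretend 4 GPUs are visible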
