Skip to content

Commit

Permalink
Merge pull request #726 from mv1388/ddp_dist_init_process_additional_params
Browse files Browse the repository at this point in the history

Add additional user-specified params for DDP dist init process (backend & init_method)
  • Loading branch information
mv1388 committed Aug 8, 2022
2 parents f7af487 + d2e38c9 commit 1362e8e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions aitoolbox/torchtrain/train_loop/train_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,8 @@ def _train_dp(self, num_epochs, num_iterations, callbacks=None, grad_accumulatio

def _train_ddp(self, num_epochs, num_iterations, callbacks=None, grad_accumulation=1,
ddp_model_args=None, in_process_data_load=None,
num_nodes=1, node_rank=0, num_gpus=torch.cuda.device_count()):
num_nodes=1, node_rank=0, num_gpus=torch.cuda.device_count(),
backend='nccl', init_method='env://'):
"""Train the model using the train loop in the Distributed Data Parallel setting
During the training, multiple processes will be spawned, one for each of the available GPUs.
Expand All @@ -769,6 +770,10 @@ def _train_ddp(self, num_epochs, num_iterations, callbacks=None, grad_accumulati
num_nodes (int): number of nodes in the cluster
node_rank (int): rank of the current node
num_gpus (int): number of GPUs in the node
backend (str): The backend to use. For more information look up the documentation for
``dist.init_process_group()``. Valid values include ``mpi``, ``gloo``, and ``nccl``.
init_method (str): URL specifying how to initialize the process group. For more information look up
the documentation for ``dist.init_process_group()``.
"""
self.ddp_training_mode = True
os.environ['MASTER_ADDR'] = 'localhost'
Expand All @@ -781,18 +786,22 @@ def _train_ddp(self, num_epochs, num_iterations, callbacks=None, grad_accumulati
'node_rank': node_rank,
'num_gpus': num_gpus,
'world_size': num_nodes * num_gpus,
'backend': backend,
'init_method': init_method,
'ddp_model_args': ddp_model_args if ddp_model_args is not None else {}
}

from aitoolbox.torchtrain.callbacks.abstract import AbstractCallback
if isinstance(in_process_data_load, AbstractCallback):
in_process_data_load = [in_process_data_load]

mp.spawn(self._spawn_fit,
args=(
ddp_args, num_epochs, num_iterations, callbacks, grad_accumulation, in_process_data_load
),
nprocs=ddp_args['world_size'])
mp.spawn(
self._spawn_fit,
args=(
ddp_args, num_epochs, num_iterations, callbacks, grad_accumulation, in_process_data_load
),
nprocs=ddp_args['world_size']
)

def _spawn_fit(self, gpu, ddp_args, num_epochs, num_iterations, callbacks, grad_accumulation, in_process_data_load):
"""Helper function that prepares the TrainLoop state inside each of the spawned processes and initiates training
Expand All @@ -811,7 +820,10 @@ def _spawn_fit(self, gpu, ddp_args, num_epochs, num_iterations, callbacks, grad_
every spawned training process. This can in turn in cause extensive overall memory consumption.
"""
rank = ddp_args['node_rank'] * ddp_args['num_gpus'] + gpu
dist.init_process_group(backend='nccl', init_method='env://', world_size=ddp_args['world_size'], rank=rank)
dist.init_process_group(
backend=ddp_args['backend'], init_method=ddp_args['init_method'],
world_size=ddp_args['world_size'], rank=rank
)
torch.manual_seed(0)
torch.cuda.set_device(gpu)
self.device = torch.device(f"cuda:{gpu}")
Expand Down
Binary file modified dist/aitoolbox-1.6.1-py3-none-any.whl
Binary file not shown.
Binary file modified dist/aitoolbox-1.6.1.tar.gz
Binary file not shown.

0 comments on commit 1362e8e

Please sign in to comment.