Merge pull request #734 from mv1388/ddp_support_for_cpu_training
DDP support for CPU training
mv1388 committed Aug 13, 2022
2 parents 3867a2b + f136f95 commit 7ad2e70
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions aitoolbox/torchtrain/train_loop/train_loop.py
@@ -748,7 +748,7 @@ def _train_dp(self, num_epochs, num_iterations, callbacks=None, grad_accumulation=1,
     def _train_ddp(self, num_epochs, num_iterations, callbacks=None, grad_accumulation=1,
                    ddp_model_args=None, in_process_data_load=None,
                    num_nodes=1, node_rank=0, num_gpus=torch.cuda.device_count(),
-                   backend='nccl', init_method='env://'):
+                   backend='nccl', init_method='env://', on_gpu=True):
         """Train the model using the train loop in the Distributed Data Parallel setting

         During the training, multiple processes will be spawned, one for each of the available GPUs.
@@ -774,6 +774,7 @@ def _train_ddp(self, num_epochs, num_iterations, callbacks=None, grad_accumulation=1,
                 ``dist.init_process_group()``. Valid values include ``mpi``, ``gloo``, and ``nccl``.
             init_method (str): URL specifying how to initialize the process group. For more information look up
                 the documentation for ``dist.init_process_group()``.
+            on_gpu (bool): whether the DDP training is executed on the GPU or on the CPU
         """
         self.ddp_training_mode = True
         os.environ['MASTER_ADDR'] = 'localhost'
@@ -787,6 +788,7 @@ def _train_ddp(self, num_epochs, num_iterations, callbacks=None, grad_accumulation=1,
             'num_gpus': num_gpus,
             'world_size': num_nodes * num_gpus,
             'backend': backend,
+            'on_gpu': on_gpu,
             'init_method': init_method,
             'ddp_model_args': ddp_model_args if ddp_model_args is not None else {}
         }
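To illustrate the new flag: a CPU-only run would pass on_gpu=False together with a CPU-capable backend such as gloo (nccl requires GPUs). A hypothetical sketch of the spawn arguments such a call would assemble; the variable names below are illustrative, not part of the library's API:

    # Illustrative single-node run with 4 CPU worker processes
    num_nodes, node_rank, num_processes = 1, 0, 4

    ddp_args = {
        'node_rank': node_rank,
        'num_gpus': num_processes,  # reused as the per-node process count on CPU
        'world_size': num_nodes * num_processes,
        'backend': 'gloo',   # CPU-capable collective backend
        'on_gpu': False,     # the new flag: skip all CUDA device setup
        'init_method': 'env://',
        'ddp_model_args': {}
    }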
@@ -820,13 +822,18 @@ def _spawn_fit(self, gpu, ddp_args, num_epochs, num_iterations, callbacks, grad_accumulation,
         every spawned training process. This can in turn cause extensive overall memory consumption.
         """
         rank = ddp_args['node_rank'] * ddp_args['num_gpus'] + gpu
+
         dist.init_process_group(
             backend=ddp_args['backend'], init_method=ddp_args['init_method'],
             world_size=ddp_args['world_size'], rank=rank
         )
+
         torch.manual_seed(0)
-        torch.cuda.set_device(gpu)
-        self.device = torch.device(f"cuda:{gpu}")
+        if ddp_args['on_gpu']:
+            torch.cuda.set_device(gpu)
+            self.device = torch.device(f"cuda:{gpu}")
+
+            ddp_args['ddp_model_args']['device_ids'] = [gpu]

         # DDP MP device filter any existing callbacks and add new ones
         self.callbacks_handler.mp_filter_callbacks()
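The branch above touches CUDA state only when on_gpu is true; on the CPU the process joins the group without pinning a device. A standalone sketch of the same per-process setup; the function name, port value, and defaults are assumptions for illustration, not the library's API:

    import os

    import torch
    import torch.distributed as dist

    def init_ddp_process(proc_idx, node_rank, procs_per_node, world_size,
                         backend='gloo', on_gpu=False):
        # Placeholder rendezvous settings; any free port works
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '29500')

        # Global rank: processes on earlier nodes come first
        rank = node_rank * procs_per_node + proc_idx
        dist.init_process_group(backend=backend, init_method='env://',
                                world_size=world_size, rank=rank)

        torch.manual_seed(0)
        if on_gpu:
            # Pin this process to its own GPU
            torch.cuda.set_device(proc_idx)
            return torch.device(f'cuda:{proc_idx}')
        # On CPU no per-process device pinning is needed
        return torch.device('cpu')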
@@ -851,9 +858,9 @@ def _spawn_fit(self, gpu, ddp_args, num_epochs, num_iterations, callbacks, grad_accumulation,

         # Wrap models into DDP module
         if isinstance(self.model, TTModel):
-            self.model = TTDistributedDataParallel(self.model, device_ids=[gpu], **ddp_args['ddp_model_args'])
+            self.model = TTDistributedDataParallel(self.model, **ddp_args['ddp_model_args'])
         else:
-            self.model = DistributedDataParallel(self.model, device_ids=[gpu], **ddp_args['ddp_model_args'])
+            self.model = DistributedDataParallel(self.model, **ddp_args['ddp_model_args'])

         self._train(num_epochs, num_iterations, callbacks, grad_accumulation)

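With device_ids now supplied through ddp_model_args only on the GPU path, the same wrapping code works on the CPU, where DistributedDataParallel must be constructed without device_ids. A minimal end-to-end CPU DDP sketch in plain PyTorch, independent of aitoolbox:

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    def worker(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'
        # 'gloo' supports CPU tensors; 'nccl' would fail without GPUs
        dist.init_process_group('gloo', rank=rank, world_size=world_size)

        model = nn.Linear(10, 1)
        # No device_ids: the wrapped model stays on the CPU in every process
        ddp_model = DistributedDataParallel(model)

        optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)
        loss = ddp_model(torch.randn(8, 10)).sum()
        loss.backward()  # gradients are all-reduced across the process group
        optimizer.step()

        dist.destroy_process_group()

    if __name__ == '__main__':
        world_size = 2
        mp.spawn(worker, args=(world_size,), nprocs=world_size)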
