
Commit 778c9f7
Use updated API for overlapping grad sync with pipeline parallelism (NVIDIA#5236)

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
timmoon10 authored and Hainan Xu committed Nov 29, 2022
1 parent ecffc88 commit 778c9f7
Showing 2 changed files with 12 additions and 10 deletions.
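For orientation, here is a minimal, self-contained sketch (toy code, not NeMo or Apex internals) of the pattern this commit switches to: instead of the training step reducing overlapped gradients itself after the pipeline finishes, the pipeline schedule receives a grad-sync hook (custom_grad_sync_func) alongside the no_sync context handler, so it can launch the gradient reductions at a point where they overlap with the remaining pipeline work. Only the parameter names custom_sync_context_handler and custom_grad_sync_func are taken from the diff; everything else below is an illustrative stand-in. The two changed files apply the same pattern to two training_step implementations.

# Toy stand-in for the pipeline forward-backward schedule; it only shows how
# the two hooks wired up in this diff would be consumed.
from contextlib import nullcontext

import torch

def toy_pipeline_schedule(micro_batches, model, loss_fn,
                          custom_sync_context_handler=None,
                          custom_grad_sync_func=None):
    sync_ctx = custom_sync_context_handler or nullcontext
    for inputs, targets in micro_batches:
        with sync_ctx():                 # e.g. optimizer.no_sync(greedy_grad_copy=...)
            loss = loss_fn(model(inputs), targets)
            loss.backward()              # grads accumulate; the reduction is deferred
    if custom_grad_sync_func is not None:
        custom_grad_sync_func()          # real code passes self.reduce_overlap_gradients

# Toy usage with a plain linear layer and a no-op reduction hook.
model = torch.nn.Linear(4, 2)
batches = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(3)]
toy_pipeline_schedule(
    batches, model, torch.nn.functional.mse_loss,
    custom_grad_sync_func=lambda: None,  # stands in for reduce_overlap_gradients
)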
@@ -282,13 +282,16 @@ def training_step(self, batch, batch_idx):
        tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size]

        # handle asynchronous grad reduction
        custom_sync_context_handler = None
        custom_grad_sync_func = None
        if self.with_distributed_adam:
            if self.megatron_amp_o2:
                # copy grads to main grad
                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True)
            else:
                # keep grad tensors around
                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False)
            custom_grad_sync_func = self.reduce_overlap_gradients
        else:
            if self.megatron_amp_o2 and not self.cfg.get('sequence_parallel', False):
                custom_sync_context_handler = self._optimizer.no_sync
@@ -309,6 +312,7 @@ def training_step(self, batch, batch_idx):
                dtype=self.autocast_dtype,
                grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
                custom_sync_context_handler=custom_sync_context_handler,
                custom_grad_sync_func=custom_grad_sync_func,
                sequence_parallel_enabled=self.cfg.get('sequence_parallel', False),
                sync_batch_comm=self.cfg.get('sync_batch_comm', True),
                num_micro_batches_with_partial_activation_checkpoints=self.cfg.get(
@@ -330,11 +334,8 @@ def training_step(self, batch, batch_idx):
            self.allreduce_sequence_parallel_gradients()

        if self.with_distributed_adam:
            # launch grad reductions
            # Note: grads in first pipeline stage have already been
            # reduced
            if not parallel_state.is_pipeline_first_stage():
                self.reduce_overlap_gradients()
            # gradients are reduced internally in distributed optimizer
            pass
        elif self.megatron_amp_o2:
            # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
            if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False):
@@ -308,13 +308,16 @@ def training_step(self, batch, batch_idx):
        tensor_shape = [encoder_seq_length, get_micro_batch_size(), self.cfg.encoder.hidden_size]

        # handle asynchronous grad reduction
        custom_sync_context_handler = None
        custom_grad_sync_func = None
        if self.with_distributed_adam:
            if self.megatron_amp_o2:
                # copy grads to main grad
                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=True)
            else:
                # keep grad tensors around
                custom_sync_context_handler = lambda: self._optimizer.no_sync(greedy_grad_copy=False)
            custom_grad_sync_func = self.reduce_overlap_gradients
        else:
            if (
                self.megatron_amp_o2
@@ -340,6 +343,7 @@ def training_step(self, batch, batch_idx):
                sync_batch_comm=self.cfg.get('sync_batch_comm', False),
                grad_scaler=self.trainer.precision_plugin.scaler if self.cfg.precision == 16 else None,
                custom_sync_context_handler=custom_sync_context_handler,
                custom_grad_sync_func=custom_grad_sync_func,
            )
        else:
            losses_reduced_per_micro_batch = forward_backward_no_pipelining(
@@ -365,11 +369,8 @@ def training_step(self, batch, batch_idx):
            loss_mean = torch.tensor(0.0).cuda()

        if self.with_distributed_adam:
            # launch grad reductions
            # Note: grads in first pipeline stage have already been
            # reduced
            if not parallel_state.is_pipeline_first_stage():
                self.reduce_overlap_gradients()
            # gradients are reduced internally in distributed optimizer
            pass
        elif self.megatron_amp_o2:
            # when using pipeline parallelism grads must be reduced after the pipeline (not asynchronously)
            if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
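As the diff's own comments suggest ("copy grads to main grad" vs. "keep grad tensors around"), the two no_sync branches differ only in whether gradients are eagerly folded into the optimizer's main-grad buffers during backward; the all-reduce itself stays deferred and is now launched by the schedule through the grad-sync hook, which is why the explicit post-pipeline reduce_overlap_gradients() call guarded by parallel_state.is_pipeline_first_stage() could be dropped. A toy illustration of that behaviour under these assumptions (this is not the real distributed-optimizer API):

import contextlib

class ToyDistributedOptimizer:
    # Illustrative only: mimics the deferred-reduction behaviour implied by the diff.
    def __init__(self):
        self.reduction_pending = False
        self.main_grads_updated = False

    @contextlib.contextmanager
    def no_sync(self, greedy_grad_copy=False):
        # Defer the gradient all-reduce for work done inside this context.
        self.reduction_pending = True
        try:
            yield
        finally:
            if greedy_grad_copy:
                # AMP O2 path: fold fresh grads into the main-grad buffers right away.
                self.main_grads_updated = True

    def reduce_overlap_gradients(self):
        # Launched later by the pipeline schedule via custom_grad_sync_func.
        self.reduction_pending = False

opt = ToyDistributedOptimizer()
with opt.no_sync(greedy_grad_copy=True):
    pass  # one micro-batch's backward pass would run here
opt.reduce_overlap_gradients()
assert opt.main_grads_updated and not opt.reduction_pending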
