
Smp grad accum #10488

Merged 2 commits on Mar 3, 2021

9 changes: 4 additions & 5 deletions src/transformers/sagemaker/trainer_sm.py

```diff
@@ -37,9 +37,10 @@
 import smdistributed.modelparallel.torch as smp

 @smp.step()
-def forward_backward(model, inputs):
+def forward_backward(model, inputs, gradient_accumulation_steps=1):
     outputs = model(**inputs)
     loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+    loss /= gradient_accumulation_steps
     model.backward(loss)
     return loss
```
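
Dividing each micro-batch loss by `gradient_accumulation_steps` inside the `smp.step`-decorated function keeps the gradients SMP accumulates equal to the gradient of the mean loss over the full effective batch. As a point of reference, here is a minimal sketch of the same scaling in a plain PyTorch accumulation loop; `model`, `optimizer`, `data_loader`, and the helper name are placeholders, not code from this PR.

```python
# Minimal sketch (not from this PR): the same loss scaling in a plain PyTorch
# gradient-accumulation loop. `model`, `optimizer`, and `data_loader` are placeholders.
def train_with_accumulation(model, optimizer, data_loader, gradient_accumulation_steps=4):
    model.train()
    optimizer.zero_grad()
    for step, inputs in enumerate(data_loader):
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        # Scale each micro-batch loss so the summed gradients equal the gradient
        # of the mean loss over the whole effective batch.
        (loss / gradient_accumulation_steps).backward()
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```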

```diff
@@ -73,8 +74,6 @@ class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
         self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
         super().__init__(args=args, **kwargs)
-        if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
-            raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")

     def is_world_process_zero(self) -> bool:
         """
```
```diff
@@ -108,7 +107,7 @@ def _wrap_model(self, model, training=True):
             # Wrapping the base model twice in a DistributedModel will raise an error.
             if isinstance(self.model_wrapped, smp.model.DistributedModel):
                 return self.model_wrapped
-            return smp.DistributedModel(model)
+            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
         else:
             return super()._wrap_model(model)
```

Review comment (collaborator, PR author) on the `smp.DistributedModel(...)` change: This does the equivalent of `no_sync` in regular DDP.
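
For comparison, the pattern this replaces in plain DDP is to skip the gradient all-reduce with `no_sync()` on every micro-step except the last one of an accumulation window; `backward_passes_per_step` asks SMP to do the equivalent internally. A rough sketch of the plain-DDP version, with `ddp_model`, `batches`, and `accum_steps` as illustrative placeholders:

```python
# Rough sketch (not from this PR) of the plain-DDP no_sync pattern that
# backward_passes_per_step stands in for. Placeholder names throughout.
from contextlib import nullcontext

def backward_with_accumulation(ddp_model, batches, accum_steps=4):
    for step, inputs in enumerate(batches):
        sync_now = (step + 1) % accum_steps == 0
        # Only all-reduce gradients across ranks on the boundary micro-step.
        ctx = nullcontext() if sync_now else ddp_model.no_sync()
        with ctx:
            outputs = ddp_model(**inputs)
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
            (loss / accum_steps).backward()
```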

```diff
@@ -121,7 +120,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
         if self.is_model_parallel_enabled:
             model.train()
             inputs = self._prepare_inputs(inputs)
-            loss_mb = forward_backward(model, inputs)
+            loss_mb = forward_backward(model, inputs, self.args.gradient_accumulation_steps)
             return loss_mb.reduce_mean().detach().to(self.args.device)
         else:
             return super().training_step(model, inputs)
```

4 changes: 4 additions & 0 deletions src/transformers/sagemaker/training_args_sm.py

```diff
@@ -87,3 +87,7 @@ def _setup_devices(self) -> "torch.device":
     @property
     def place_model_on_device(self):
         return not (is_smdistributed_available() and self.mp_parameters != "")
+
+    @property
+    def _no_sync_in_gradient_accumulation(self):
+        return False
```

2 changes: 1 addition & 1 deletion src/transformers/trainer.py

```diff
@@ -1039,7 +1039,7 @@ def train(
                 if (
                     ((step + 1) % self.args.gradient_accumulation_steps != 0)
                     and self.args.local_rank != -1
-                    and not self.args.deepspeed
+                    and self.args._no_sync_in_gradient_accumulation
                 ):
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
```
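
Put together, the `no_sync()` branch is now taken only when the step is not an accumulation boundary, training is distributed (`local_rank != -1`), and `_no_sync_in_gradient_accumulation` is true; DeepSpeed (the base property returns `not self.deepspeed`) and SageMaker model parallelism (the override returns `False`) both fall through to the regular path. A small illustrative check, using throwaway namespace objects in place of real `TrainingArguments`:

```python
# Illustrative only: when does Trainer.train() enter the no_sync() branch?
# SimpleNamespace objects stand in for TrainingArguments; attribute names mirror the diff.
from types import SimpleNamespace

def uses_no_sync(step, args):
    return (
        (step + 1) % args.gradient_accumulation_steps != 0
        and args.local_rank != -1
        and args._no_sync_in_gradient_accumulation
    )

ddp = SimpleNamespace(gradient_accumulation_steps=4, local_rank=0, _no_sync_in_gradient_accumulation=True)
smp_mp = SimpleNamespace(gradient_accumulation_steps=4, local_rank=0, _no_sync_in_gradient_accumulation=False)

print(uses_no_sync(0, ddp))     # True  -> skip the all-reduce on this micro-step
print(uses_no_sync(3, ddp))     # False -> boundary step, synchronize as usual
print(uses_no_sync(0, smp_mp))  # False -> SMP handles accumulation via backward_passes_per_step
```
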
7 changes: 7 additions & 0 deletions src/transformers/training_args.py

```diff
@@ -737,6 +737,13 @@ def place_model_on_device(self):
         """
         return True

+    @property
+    def _no_sync_in_gradient_accumulation(self):
+        """
+        Whether or not to use no_sync for the gradients when doing gradient accumulation.
+        """
+        return not self.deepspeed
+
     def to_dict(self):
         """
         Serializes this instance while replace `Enum` by their values (for JSON serialization support).
```

Review comment (collaborator, PR author) on the added property (lines +741 to +745): This is introduced to make it easy to skip the no_sync part in subclasses.
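
A backend that manages accumulation itself can now opt out by overriding the property, as `SageMakerTrainingArguments` does above. A hypothetical sketch (the class name is illustrative, not part of this PR):

```python
# Hypothetical subclass (not from this PR) that opts out of the no_sync() wrapping,
# mirroring the SageMakerTrainingArguments override.
from dataclasses import dataclass

from transformers import TrainingArguments


@dataclass
class MyBackendTrainingArguments(TrainingArguments):
    @property
    def _no_sync_in_gradient_accumulation(self):
        # The backend already delays gradient synchronization during accumulation,
        # so Trainer should not wrap backward passes in model.no_sync().
        return False
```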