Smp grad accum #10488
Conversation
@@ -108,7 +107,7 @@ def _wrap_model(self, model, training=True):
        # Wrapping the base model twice in a DistributedModel will raise an error.
        if isinstance(self.model_wrapped, smp.model.DistributedModel):
            return self.model_wrapped
-       return smp.DistributedModel(model)
+       return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
This does the equivalent of `no_sync` in regular DDP.
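For context, a minimal sketch of that pattern with plain `DistributedDataParallel`: the gradient all-reduce is skipped on intermediate accumulation steps and only happens on the step that calls `optimizer.step()`. The helper name and loop below are illustrative, not the Trainer's actual code.

```python
import contextlib


def train_with_accumulation(model, dataloader, optimizer, accumulation_steps):
    """Illustrative gradient accumulation loop for a DDP-wrapped model."""
    for step, batch in enumerate(dataloader):
        is_sync_step = (step + 1) % accumulation_steps == 0
        # model.no_sync() delays the DDP gradient all-reduce on intermediate steps.
        context = contextlib.nullcontext() if is_sync_step else model.no_sync()
        with context:
            loss = model(**batch).loss / accumulation_steps  # assumes an HF-style output with .loss
            loss.backward()
        if is_sync_step:
            optimizer.step()
            optimizer.zero_grad()
```

With SMP, passing `backward_passes_per_step` to `smp.DistributedModel` achieves the same effect without the explicit `no_sync` context.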
def _no_sync_in_gradient_accumulation(self):
    """
    Whether or not to use no_sync for the gradients when doing gradient accumulation.
    """
    return not self.deepspeed
This is introduced to make it easy to skip the no_sync part in subclasses.
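A rough sketch of how such a hook could gate the `no_sync` context in the inner training loop; the helper below is illustrative and may not match the PR's exact code.

```python
import contextlib


def accumulation_context(self, model, step):
    """Return the context manager to use around the backward pass.

    Illustrative helper, not part of the actual Trainer API.
    """
    skip_sync = (
        (step + 1) % self.args.gradient_accumulation_steps != 0  # intermediate accumulation step
        and self.args.local_rank != -1                            # distributed training is active
        and self._no_sync_in_gradient_accumulation()              # subclasses can opt out here
    )
    # model.no_sync() delays the DDP gradient all-reduce; nullcontext is a no-op.
    return model.no_sync() if skip_sync else contextlib.nullcontext()
```

A subclass that already handles synchronization itself (for example through SMP's `backward_passes_per_step`) could then override `_no_sync_in_gradient_accumulation` to return `False`, which is presumably the point of factoring it out.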
Thanks a lot for the fix! LGTM.
Very nice! LGTM!
What does this PR do?
This PR adds support for gradient accumulation in `SageMakerTrainer`. It has been tested on the glue script with success (with and without gradient accumulation passed along).
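For illustration, a hedged sketch of enabling gradient accumulation through the training arguments; the import path and argument values below are assumptions, not verbatim from this PR.

```python
# Illustrative only: import path and argument names are assumptions.
from transformers.sagemaker import SageMakerTrainer, SageMakerTrainingArguments

args = SageMakerTrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    # Forwarded to smp.DistributedModel as backward_passes_per_step by this PR.
    gradient_accumulation_steps=4,
)
trainer = SageMakerTrainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()
```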