Multiple fixes in SageMakerTrainer #10687
Conversation
if self.is_model_parallel_enabled:
    self._save_smp(output_dir)
We need to use a special save here because model parallelism requires us to:
- gather the state dict on all processes of dp_rank 0 (it triggers a sync across those processes)
- save it only on process 0 (see the sketch below)
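For illustration, here is a minimal sketch of what such a helper could look like, assuming the `smdistributed.modelparallel.torch` API (`smp.dp_rank()`, `smp.rank()`). The name `_save_smp` comes from the diff above; `model_wrapped`, the output file name, and the overall body are illustrative assumptions, not the actual implementation.

```python
import os

import torch
import smdistributed.modelparallel.torch as smp


def _save_smp(self, output_dir=None):
    # Only the processes of dp_rank 0 take part in the save.
    if smp.dp_rank() != 0:
        return
    output_dir = output_dir if output_dir is not None else self.args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    # Calling state_dict() on the smp-wrapped model gathers the partitioned
    # parameters, synchronizing across all dp_rank 0 processes, so every one
    # of them must reach this line (assumption based on the comment above).
    state_dict = self.model_wrapped.state_dict()
    # Writing to disk happens on process 0 only.
    if smp.rank() == 0:
        torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
```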
# Consolidate the state dict on all processes of dp_rank 0
opt_state_dict = self.optimizer.state_dict()
The method is overridden for this particular line/behavior: as for the model, the optimizer's state dict needs to be gathered on all dp_rank 0 processes.
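A hedged sketch of the same pattern applied to the optimizer state, again assuming `smdistributed.modelparallel.torch`; the helper name `_save_optimizer_state` is hypothetical, and only the two lines from the diff are taken from the PR itself.

```python
import os

import torch
import smdistributed.modelparallel.torch as smp


def _save_optimizer_state(self, output_dir):
    # Hypothetical helper: every dp_rank 0 process must call state_dict()
    # so the optimizer state can be consolidated (the call synchronizes
    # across those processes) ...
    if smp.dp_rank() == 0:
        opt_state_dict = self.optimizer.state_dict()
        # ... but only the global process 0 writes the result to disk.
        if smp.rank() == 0:
            torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
```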
    os.path.join(checkpoint, "scheduler.pt")
):
    self.optimizer.load_state_dict(
        torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
The method is overridden for this particular line: the state dict needs to be loaded on the CPU, not the device.
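A sketch of what the overridden loading could look like; the method name and the `map_location="cpu"` call appear in the diff, while the surrounding structure (including the `lr_scheduler` handling) is an illustrative reconstruction.

```python
import os

import torch


def _load_optimizer_and_scheduler(self, checkpoint):
    if checkpoint is None:
        return
    if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
        os.path.join(checkpoint, "scheduler.pt")
    ):
        # Load on the CPU rather than the device, per the review comment
        # above; the wrapped optimizer then distributes the state itself.
        self.optimizer.load_state_dict(
            torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
        )
        self.lr_scheduler.load_state_dict(
            torch.load(os.path.join(checkpoint, "scheduler.pt"))
        )
```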
@@ -927,6 +924,9 @@ def train(
         if delay_optimizer_creation:
             self.create_optimizer_and_scheduler(num_training_steps=max_steps)

+        # Check if saved optimizer or scheduler states exist
+        self._load_optimizer_and_scheduler(resume_from_checkpoint)
This move should be harmless: the optimizer and scheduler must exist before their states can be reloaded, so the call now happens right after the (possibly delayed) optimizer creation.
-elif self.args.local_rank != -1:
-    world_size = dist.get_world_size()
-world_size = max(1, world_size)
+world_size = max(1, self.args.world_size)
Refactor; it also allows world_size to be overridden in SageMakerTrainingArguments.
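For instance, such an override could look like the sketch below, assuming `smp.dp_size()` returns the data-parallel world size and reusing the `is_model_parallel_enabled` flag seen earlier in the diff; the exact body is an assumption, not the actual implementation.

```python
import smdistributed.modelparallel.torch as smp
from transformers import TrainingArguments


class SageMakerTrainingArguments(TrainingArguments):
    # Illustrative flag; the real class derives this from its SageMaker
    # configuration (the diff above only shows that such a flag exists).
    is_model_parallel_enabled: bool = False

    @property
    def world_size(self):
        # With model parallelism, the effective world size for data-parallel
        # bookkeeping is the number of model replicas (smp.dp_size()),
        # not the raw process count.
        if self.is_model_parallel_enabled:
            return smp.dp_size()
        return super().world_size
```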
LGTM. Tested on run_glue.py with mnli and mrpc.
LGTM
* Handle save differently
* Missing imports
* Fix typo
* Adapt to recent changes in save_pretrained
* Forgotten brackets
* Optimizer load
* Fix world size
* Deal with None
* Remove needless self
What does this PR do?
This PR adds quite a few fixes to the SageMakerTrainer to make sure the example scripts run fully. In particular, it fixes several issues, among them the use of drop_last=True, which is not something anyone wants.

The goal is now to test that functionality a bit more before merging the SageMakerTrainer into the main Trainer (otherwise one can't use model parallelism in the seq2seq or QA examples). The plan is to have them merged in v4.5.0.