-
Notifications
You must be signed in to change notification settings - Fork 25.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use model.from_pretrained for DataParallel also #8795
Conversation
When training on multiple GPUs, the code wraps a model with torch.nn.DataParallel. However if the model has custom from_pretrained logic, it does not get applied during load_best_model_at_end. This commit uses the underlying model during load_best_model_at_end, and re-wraps the loaded model with DataParallel. If you choose to reject this change, then could you please move the this logic to a function, e.g. def load_best_model_checkpoint(best_model_checkpoint) or something, so that it can be overridden?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand the problem but I disagree with the solution you picked: the model
attribute of the Learner
is always a reference to the original (unwrapped) model. I suggested a change to use that for the fix, would you mind applying it?
Thanks!
src/transformers/trainer.py
Outdated
if is_data_parallel: | ||
# re-wrap with DataParallel | ||
self.model = torch.nn.DataParallel(self.model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The model attribute of the Learner
is never wrapped in DataParallel
or the like, it stays a reference to the original model so these lines are not necessary.
With the same reasoning, I think the only change to make is on the line above this comment (L816) and replace
self.model = model.from_pretrained(self.state.best_model_checkpoint)
by
self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
Thanks for the feedback. I made the change that you proposed, but I also think we should update L811 to check if `self.mode` is an instance of `PreTrained`, otherwise we would still not get into that `if` section, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perfect, thanks a lot!
Oh, looks like there is a last code-style issue to fix. Could you run |
I don't have |
|
Weird indeed. Will merge and fix if the issue persists. |
* Use model.from_pretrained for DataParallel also When training on multiple GPUs, the code wraps a model with torch.nn.DataParallel. However if the model has custom from_pretrained logic, it does not get applied during load_best_model_at_end. This commit uses the underlying model during load_best_model_at_end, and re-wraps the loaded model with DataParallel. If you choose to reject this change, then could you please move the this logic to a function, e.g. def load_best_model_checkpoint(best_model_checkpoint) or something, so that it can be overridden? * Fix silly bug * Address review comments Thanks for the feedback. I made the change that you proposed, but I also think we should update L811 to check if `self.mode` is an instance of `PreTrained`, otherwise we would still not get into that `if` section, right?
When training on multiple GPUs, the code wraps a model with torch.nn.DataParallel. However if the model has custom from_pretrained logic, it does not get applied during load_best_model_at_end.
This commit uses the underlying model during load_best_model_at_end, and re-wraps the loaded model with DataParallel.
If you choose to reject this change, then could you please move the this logic to a function, e.g. def load_best_model_checkpoint(best_model_checkpoint) or something, so that it can be overridden?
What does this PR do?
Fixes # (issue)
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors which may be interested in your PR.