Distributed Trainer: 2 little fixes #7461
Conversation
Can we see when the config is accessed (in your error message)?
src/transformers/trainer.py (Outdated)

@@ -675,12 +675,14 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
+           config = model.config
We shouldn't assume model has a config without a proper test; having Trainer work with models that are not PreTrainedModels is a feature that has been asked for. If there is an access to config that makes the code fail, we should fix that place.
It's already assumed that model.config exists. The base trainer.py accesses model.config 23 times, including in the statement below this one: https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py#L682
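The failure mode under discussion can be shown without torch: DistributedDataParallel stores the wrapped model under `.module`, so attributes like `config` are no longer reachable directly on the wrapper. A minimal torch-free sketch (the `PretendModel`/`PretendDDP` names are stand-ins, not real library classes) of why saving `config = model.config` before wrapping works:

```python
# Plain-Python sketch of attribute hiding after wrapping.
# PretendDDP mimics torch.nn.parallel.DistributedDataParallel,
# which keeps the inner model under `.module`.

class PretendModel:
    def __init__(self):
        # stand-in for a real PretrainedConfig object
        self.config = {"hidden_size": 768}

class PretendDDP:
    """Stand-in for DistributedDataParallel: only exposes .module."""
    def __init__(self, module):
        self.module = module

model = PretendModel()
config = model.config        # save a reference BEFORE wrapping (the fix)
model = PretendDDP(model)

print(hasattr(model, "config"))       # False: the wrapper hides it
print(model.module.config is config)  # True: still reachable via .module
```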
It should add an assert that the model is a PreTrainedModel.
OK. I reduced the scope of this PR to just the
That works for me :-)
* reset model.config
* Update src/transformers/trainer.py
* use lower case tensor
* Just tensor change
This reverts commit 3f93ae7.
model.config. We could also set self.config = model.config earlier in __init__.
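The alternative suggested above can be sketched the same way: capture the config once in the trainer's __init__, before any distributed wrapping can hide it. `ToyTrainer`, `ToyModel`, and `Wrapped` below are hypothetical illustration names, not the real Trainer API:

```python
# Sketch of capturing the config in __init__ rather than in train().

class ToyModel:
    def __init__(self):
        self.config = {"vocab_size": 50257}  # stand-in config

class Wrapped:
    """Stand-in for a DistributedDataParallel wrapper."""
    def __init__(self, module):
        self.module = module

class ToyTrainer:
    def __init__(self, model):
        self.config = model.config  # saved before any wrapping happens
        self.model = model

    def train(self):
        # later wrapping no longer breaks config access
        self.model = Wrapped(self.model)
        return self.config

trainer = ToyTrainer(ToyModel())
cfg = trainer.train()
print(cfg["vocab_size"])  # 50257, despite the model now being wrapped
```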
After which, the command in "Seq2SeqTrainer Distributed: AttributeError and the RuntimeError" (#7460) works.
CC @patil-suraj , @TevenLeScao