-
Notifications
You must be signed in to change notification settings - Fork 25.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a post init method to all models #14431
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
I'm wondering whether keeping self.init_weights()
in each model's __init__()
method might be better for readability as it's harder to see now for users where and how the model weights are initialized. But don't feel strongly about it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me and I don't think it hurts readability, but I don't feel strongly about it. If you'd rather the self.init_weights
call stay in each modeling file, that's also fine by me @patrickvonplaten
* Add a post init method to all models * Fix tests * Fix last tests * Fix templates * Add comment * Forgot to save
What does this PR do?
This PR introduces the proper fix for #14388 by introducing a new
post_init
method to each model, which replaces the currentinit_weights()
call. The method can execute any code that requires the model to be properly initialized, such as theinit_weights()
or the gradient checkpointing BC fix (and more if need to in the future).