[Single File] Allow loading T5 encoder in mixed precision #8778
```diff
@@ -1808,4 +1808,17 @@ def create_diffusers_t5_model_from_checkpoint(
     else:
         model.load_state_dict(diffusers_format_checkpoint)
 
+    use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
+    if use_keep_in_fp32_modules:
+        keep_in_fp32_modules = model._keep_in_fp32_modules
+    else:
+        keep_in_fp32_modules = []
+
+    if keep_in_fp32_modules is not None:
+        for name, param in model.named_parameters():
+            if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
+                # param = param.to(torch.float32) does not work here, since it only rebinds the local variable.
+                param.data = param.data.to(torch.float32)
+
     return model
```

Review comment: ohh is this related to this?

Reply: Yeah exactly. When you cast the entire model to …
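The upcast loop in the patch can be sketched standalone. This is a minimal illustration, not the diffusers implementation: `TinyEncoder` and its module names are hypothetical stand-ins for the real T5 encoder, whose `_keep_in_fp32_modules` list marks numerically sensitive submodules that should stay in fp32 even when the rest of the model is cast to fp16.

```python
import torch
from torch import nn

# Hypothetical stand-in for a T5-style encoder: pretend the final layer
# norm is numerically sensitive and must stay in fp32.
class TinyEncoder(nn.Module):
    _keep_in_fp32_modules = ["final_layer_norm"]  # matched against dotted parameter paths

    def __init__(self):
        super().__init__()
        self.block = nn.Linear(8, 8)
        self.final_layer_norm = nn.LayerNorm(8)

model = TinyEncoder().to(torch.float16)  # casts every floating-point parameter to fp16

# Same selective upcast as the patch: match listed module names against each
# parameter path, and rewrite param.data in place (rebinding `param` itself
# would only change the local loop variable).
for name, param in model.named_parameters():
    if any(m in name.split(".") for m in TinyEncoder._keep_in_fp32_modules):
        param.data = param.data.to(torch.float32)

print(model.block.weight.dtype)             # torch.float16
print(model.final_layer_norm.weight.dtype)  # torch.float32
```

Note the `param.data = ...` assignment mirrors the comment in the diff: `param = param.to(torch.float32)` would have no effect on the model, because it only rebinds the local name.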
Review comment: How is this handled then?

Reply: At the model level, by passing dtype to their respective loading methods.
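The reply above can be sketched in miniature. This is a hedged illustration only: `load_encoder` is a hypothetical stand-in for the real loading methods (e.g. diffusers' `from_pretrained` / `from_single_file`, which accept a `torch_dtype` argument), showing that the dtype cast happens once, inside the loader, at the model level.

```python
import torch
from torch import nn

# Hypothetical mini "loading method": accepts a dtype the way diffusers'
# loading methods accept torch_dtype, and applies it to the whole model.
def load_encoder(torch_dtype=None):
    model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8))
    if torch_dtype is not None:
        model = model.to(torch_dtype)  # model-level cast, done by the loader
    return model

fp16_encoder = load_encoder(torch_dtype=torch.float16)
print(fp16_encoder[1].weight.dtype)  # torch.float16
```

Under this scheme, the caller never casts the model afterwards, which is what makes a selective keep-in-fp32 pass (as in the diff above) necessary for mixed precision.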