[Examples] Save SDXL LoRA weights with chosen precision #4791
Conversation
# Final inference
# Load previous pipeline
vae = AutoencoderKL.from_pretrained(
is this already defined?
Yes, we never `del` the one defined here, so it is still there even at the end. And... it works without it :)
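For context, a minimal sketch of what the final-inference block can look like once the earlier `vae` is reused instead of re-instantiated. This is a fragment assuming the example-script context; names such as `args`, `vae`, and `accelerator` follow the usual script conventions and are assumptions here, not the exact diff:

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Final inference: build the pipeline with the VAE that is still in scope
# from the start of training, instead of calling AutoencoderKL.from_pretrained() again.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    vae=vae,                    # reuse the VAE loaded earlier in the script
    torch_dtype=torch.float16,  # or the dtype matching the chosen mixed precision
)
pipeline = pipeline.to(accelerator.device)

# Load the LoRA weights that were just saved and run a quick validation prompt.
pipeline.load_lora_weights(args.output_dir)
image = pipeline(prompt=args.validation_prompt, num_inference_steps=25).images[0]
```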
Generally looks good to me - wdyt @sayakpaul ?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
@@ -1,4 +1,4 @@
-accelerate>=0.16.0
+accelerate>=0.22.0
Is this necessary?
Unfortunately, without this change there is still a spike in GPU memory during checkpointing :/ From Zach's message on the issue, my understanding is that there was a bug when saving with the accelerator while using mixed precision.
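For reference, a hedged sketch of the periodic checkpointing step where the spike was observed; the memory spike discussed above reportedly happened during `accelerator.save_state()` with mixed precision enabled, and bumping to `accelerate>=0.22.0` avoids it. Names like `global_step` and `args.checkpointing_steps` follow the example scripts but are assumptions here:

```python
import os

# Inside the training loop: save a full training-state checkpoint every
# `checkpointing_steps` optimizer steps. With accelerate >= 0.22.0 this
# no longer causes a GPU-memory spike when mixed precision is enabled.
if accelerator.is_main_process and global_step % args.checkpointing_steps == 0:
    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
    accelerator.save_state(save_path)
```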
Looks amazing actually. Thanks for delving deep here and for the fixes!
…4791)
* Increase min accelerate ver to avoid OOM when mixed precision
* Rm re-instantiation of VAE
* Rm casting to float32
* Del unused models and free GPU
* Fix style
What does this PR do?
As discussed in issue #4736, this PR tackles the following tasks:

- Increase the minimum version of `accelerate` from 0.16.0 to 0.22.0. This prevents unnecessary peaks in GPU memory consumption during checkpointing with mixed precision.

The command `make test-examples` completed successfully.

The following images were generated using the weights saved in `fp32`:

[sample images]

The following images were generated using the same weights as above, but saved in `fp16`:

[sample images]

Previously, peak memory consumption was borderline on an RTX 4090 during checkpointing and final weight saving. Now, it remains consistently below 75%.

Saving and loading also work with `bf16`.

Fixes #4736 (issue)
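To illustrate the "chosen precision" part, here is a hedged sketch of saving the LoRA layers in whatever dtype training ran in (no cast to float32 first) and then releasing the unused models, roughly following the commit bullets above. This is a fragment assuming the example-script context; helper names such as `unet_lora_layers` and `text_encoder_one` are assumptions for illustration:

```python
import gc
import torch
from diffusers import StableDiffusionXLPipeline

# Save the LoRA layers in the precision they were trained in (fp16/bf16/fp32);
# no explicit cast to torch.float32 before saving.
StableDiffusionXLPipeline.save_lora_weights(
    save_directory=args.output_dir,
    unet_lora_layers=unet_lora_layers,
    text_encoder_lora_layers=text_encoder_one_lora_layers,
    text_encoder_2_lora_layers=text_encoder_two_lora_layers,
)

# Delete the models that are no longer needed and free GPU memory
# before the final-inference pass.
del unet, text_encoder_one, text_encoder_two
gc.collect()
torch.cuda.empty_cache()
```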
Before submitting
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul