Dreambooth: reduce VRAM usage #2039
Use optimizer.zero_grad(set_to_none=True) to reduce VRAM usage.
cc @patil-suraj
Thanks a lot for the PR! It saves quite a bit of memory. But according to the torch docs, it changes certain behaviour. So instead of setting it to True by default, I think we could make it an argument that the user can control, something like --optimizer_set_to_none, which would be False by default. Also cc @pcuenca @patrickvonplaten
Agree with @patil-suraj here. @gleb-akhmerov, could we maybe add an argument?
@patil-suraj @patrickvonplaten Good point! I've changed the code to include the argument. Also, I've updated the readme to allow people to find this argument more easily. What do you think?
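For anyone following along, here is a minimal sketch of what an opt-in flag like this could look like. It is not the code from this PR: the flag name --optimizer_set_to_none is just the one suggested in the review above, and parse_args, clear_gradients and demo are hypothetical helpers made up for illustration. Only argparse and optimizer.zero_grad(set_to_none=...) are real APIs here.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Sketch of an opt-in zero_grad flag")
    # Assumed flag name, taken from the suggestion in this thread.
    # store_true means the flag is False unless the user passes it.
    parser.add_argument(
        "--optimizer_set_to_none",
        action="store_true",
        help="Pass set_to_none=True to optimizer.zero_grad() to save memory.",
    )
    return parser.parse_args(argv)


def clear_gradients(optimizer, set_to_none):
    # With set_to_none=True the .grad attributes are reset to None
    # instead of being overwritten with zero-filled tensors.
    optimizer.zero_grad(set_to_none=set_to_none)


def demo(argv=None):
    # Imported lazily so parse_args() can be used without PyTorch installed.
    import torch

    args = parse_args(argv)
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    optimizer.step()
    clear_gradients(optimizer, args.optimizer_set_to_none)

    # One entry per parameter: True if its gradient was dropped entirely.
    return [p.grad is None for p in model.parameters()]
```

Calling demo(["--optimizer_set_to_none"]) should report that every gradient is None after the step, while demo([]) passes set_to_none=False and leaves zero-filled gradient tensors in place. That difference is the behaviour change mentioned in the review: code that reads .grad after zero_grad has to handle None when the flag is on.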
Looks great, thanks!
This allowed me to run Dreambooth on a 3060 with VRAM usage of 11124MiB using the following parameters:
--mixed_precision="fp16" \
--use_8bit_adam \
--enable_xformers_memory_efficient_attention \
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--resolution=512 \
--train_batch_size=1 \
--train_text_encoder
Also, I was able to use batches of size up to 5 images, with VRAM usage of 11694MiB.
--resolution=768 is possible too, with VRAM usage of 11450MiB.
--resolution=1024 uses 11598MiB, but it's pretty slow (about 4s/it).