Fix non persistant buffer dispatch #1941
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks for addressing that issue. In general, the implementation LGTM, so no comments from me on that.

One thing I would like to see is some tests to check that non-persistent buffers are excluded. There are some tests for `named_module_tensors`, which could be extended to cover the new option. That doesn't cover every addition of this PR, so ideally the device hooks could also be tested; I'm not sure how easy that is.

Another, more general comment: the new parameter `remove_non_persistant` now has to be passed around all over the place. Are there use cases that require including non-persistent buffers? If not, always removing them would simplify things.
Finally, the functions `big_modeling.load_checkpoint_and_dispatch` and `utils.bnb.load_and_quantize_model` call `dispatch_model`; should their signatures be extended to include `remove_non_persistant` too?
Hi @BenjaminBossan, thanks for your review!

I don't think so. If a user needs to offload non-persistent buffers, they will need to create a mapping directly. In any case, I will add a few tests for `named_module_tensors` and try to come up with something for the hooks.
Yes, if we decide to keep
Maybe @muellerzr can comment on that? If we decide to make the change, we should make sure to mention it somewhere (like making a note to include it in the next release text).
I've added a few tests! Let me know if you want to make the change. If we make the change, @muellerzr, I will keep
@SunMarc let's go with that, and if users request that we expand the param/trickle it down to lower levels, we can.
Still having an issue with Falcon 180B GPTQ offload. I will put it in draft for now.
After digging a little bit into it, I think we should keep
Thanks for tagging me in. From the following snippet:

```python
from diffusers import PixArtAlphaPipeline
import torch

def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024

pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    torch_dtype=torch.float16,
)
pipe.enable_sequential_cpu_offload()

images = pipe("hey", num_images_per_prompt=4).images[0]
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)
```
Thanks for fixing this issue and adding a couple of tests to check it.

I'm a bit concerned that we have to rely on a private attribute, `_non_persistent_buffers_set`, but AFAICT it's the only place that information is stored, so there is no way around it; it also seems to have existed for many PyTorch releases, which hopefully means it's safe.
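For context, a minimal sketch (assuming a recent PyTorch version) of why that private attribute is needed: non-persistent buffers still appear in `named_buffers()` but are absent from the `state_dict`, and the per-module `_non_persistent_buffers_set` is the only record of which buffers are non-persistent.

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Saved in the state_dict (default behavior)
        self.register_buffer("persistent_buf", torch.zeros(2))
        # Excluded from the state_dict
        self.register_buffer("temp_buf", torch.ones(2), persistent=False)

m = Demo()
# Both buffers are visible through named_buffers()
print(sorted(n for n, _ in m.named_buffers()))  # ['persistent_buf', 'temp_buf']
# Only the persistent one makes it into the state_dict
print(list(m.state_dict().keys()))              # ['persistent_buf']
# The private set is the only place that records non-persistence
print(m._non_persistent_buffers_set)            # {'temp_buf'}
```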
This is a great fix! However, there is one big issue with the logic here that we need to address :)
What does this PR do?

This PR fixes the offloading of buffers when using `dispatch_model` or `cpu_offload` (related to hooks). Currently, `dispatch_model` will offload all buffers if `offload_buffers` is set to `True`. This is problematic when we have non-persistent buffers, which are, by definition, not saved in the `state_dict` of the model. When we use `dispatch_model` alone, we will store the `state_dict` of some modules on disk; hence, we won't be able to retrieve those buffers. To fix that, we simply remove the non-persistent buffers from the list of buffers to offload.

PS: the diffusers team actually uses offloaded buffers