[core] support chunked feed forward in latte #8842
I assume Latte doesn't use a novel transformer block like SD3 (diffusers/src/diffusers/models/transformers/transformer_sd3.py, line 97 at 973a62d)?
If so, we need to do it like this: diffusers/src/diffusers/models/attention.py, line 182 at 973a62d.
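For readers following along, the helper being referenced boils down to splitting the tokens into chunks, running the feed-forward on each chunk, and concatenating the results. A minimal sketch of the pattern (not the verbatim upstream code):

```python
import torch
import torch.nn as nn


def chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int) -> torch.Tensor:
    # Split the hidden states along `chunk_dim`, run the feed-forward on each
    # slice, and concatenate the results. The peak activation of the FF's inner
    # projection then scales with `chunk_size` instead of the full sequence length.
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"Dimension {hidden_states.shape[chunk_dim]} must be divisible by chunk size {chunk_size}."
        )
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    return torch.cat(
        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )
```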
Latte uses the BasicTransformerBlock, which already has the chunked feed-forward implementation here: diffusers/src/diffusers/models/attention.py, line 515 at bbd2f9d.
Have I missed anything else that might be causing the very poor memory improvements?
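(For context, the existing block-level hook referenced above looks roughly like this in use; the block config, chunk size, and dim here are arbitrary illustrations, not Latte's actual settings.)

```python
from diffusers.models.attention import BasicTransformerBlock

# Illustrative config; Latte's actual hidden size / head layout may differ.
block = BasicTransformerBlock(dim=1152, num_attention_heads=16, attention_head_dim=72)

# The hook BasicTransformerBlock already exposes: run the feed-forward in
# chunks of 256 tokens along the sequence dimension (dim=1).
block.set_chunk_feed_forward(256, dim=1)
```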
No, then we are good.
As in, we close this PR without merging it, right? The memory improvements are negligible for far too large an increase in the time required.
We might benefit from adding VAE slicing and tiling support here, by the way. Decode memory goes up to 19 GB.
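A rough sketch of what that could build on, assuming Latte's VAE is a standard AutoencoderKL (which already exposes enable_slicing and enable_tiling); the checkpoint id is a placeholder and the shapes are illustrative:

```python
import torch
from diffusers import AutoencoderKL

# "<latte-checkpoint>" is a placeholder; substitute an actual Latte checkpoint.
vae = AutoencoderKL.from_pretrained(
    "<latte-checkpoint>", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

vae.enable_slicing()  # decode the latent batch one sample at a time
vae.enable_tiling()   # decode each sample in overlapping spatial tiles

# e.g. 16 video frames of 64x64 latents with 4 channels
latents = torch.randn(16, 4, 64, 64, device="cuda", dtype=torch.float16)
frames = vae.decode(latents / vae.config.scaling_factor).sample
```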
Might be worth checking that, yes.
This forward-chunking phenomenon (smaller memory savings) is becoming evident for models that are purely transformer-based, as opposed to hybrid architectures such as the SDXL UNet. For I2VGen-XL or SVD we see savings; however, for purely transformer-based ones like Latte and SD3, we don't.
For the SD3 800M variant I did see nice improvements in memory, but not so much for the 2B one. So I think there is some specific FLOP/parameter-count pattern where chunking tends to shine.
We should investigate this phenomenon further because feed-forward chunking could be crucial for these models. WDYT?
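A minimal sketch of the kind of measurement such an investigation could start from: compare peak memory of a single BasicTransformerBlock forward pass with and without feed-forward chunking (the block config and tensor shapes are illustrative assumptions, not Latte's actual configuration):

```python
import torch
from diffusers.models.attention import BasicTransformerBlock

torch.set_grad_enabled(False)

block = BasicTransformerBlock(dim=1152, num_attention_heads=16, attention_head_dim=72).to("cuda", torch.float16)
hidden_states = torch.randn(2, 4096, 1152, device="cuda", dtype=torch.float16)

for chunk_size in (None, 256):
    block.set_chunk_feed_forward(chunk_size, dim=1)  # None disables chunking
    torch.cuda.reset_peak_memory_stats()
    _ = block(hidden_states)
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"chunk_size={chunk_size}: peak {peak_gib:.2f} GiB")
```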
Oh interesting, I haven't really explored much of the transformer-based image/video models and can't comment much on why we see this behaviour, but we can definitely try and investigate. I will try to put together a script over the weekend unless we have something ready already.
Quick question: IIUC, we do chunking only on the FeedForward modules and never on the individual linear layers (such as the QKV projection layers of attention). Isn't the benefit of chunking the FFN lost if we never also apply it to ALL the linear layers that follow? Apologies if I'm being thick-headed, but it's been a while since I've looked at the internal modeling code and I can't wrap my head around this. We have Attn followed by FFN. In Attn we don't chunk, and we materialize the QK product, which should be a main contributor to overall memory. In FFN we do chunk and save some memory, but as I understand it, this saving is largely "invisible" in the overall measurement, where we care about peak usage, because of the unchunked attention projection layers and QK, right?
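To put rough numbers on the question, here is a back-of-envelope comparison of the per-block activation sizes involved (all dimensions are illustrative assumptions, not Latte's actual config):

```python
# Per-block activation sizes in fp16 (2 bytes/element); dimensions are illustrative.
batch, heads, seq_len, dim, ff_mult = 2, 16, 4096, 1152, 4
mib = 1024**2

ff_intermediate = batch * seq_len * dim * ff_mult * 2 / mib      # inner (~4x dim) FF activation: what chunking shrinks
attn_scores_naive = batch * heads * seq_len * seq_len * 2 / mib  # full QK^T score matrix, if it were materialized
qkv_projections = 3 * batch * seq_len * dim * 2 / mib            # Q, K, V tensors (never chunked)

print(f"FF intermediate:    {ff_intermediate:8.1f} MiB")    # ~72 MiB
print(f"naive QK^T scores:  {attn_scores_naive:8.1f} MiB")  # ~1024 MiB
print(f"QKV projections:    {qkv_projections:8.1f} MiB")    # ~54 MiB
```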
We cannot do chunking on all linear layers just like that. Read more here:
https://huggingface.co/blog/reformer
Regarding the QK materialization point, you are likely overlooking the fact that the attention computation is handled by SDPA, which already optimizes that memory. So I don't think it is as much of an issue.
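For reference, torch.nn.functional.scaled_dot_product_attention dispatches to fused (Flash / memory-efficient) kernels that compute attention block-wise, so the full seq x seq score matrix is never materialized. A minimal sketch, with the same illustrative shapes as above:

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 16, 4096, 72
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Naive attention would allocate a (batch, heads, seq_len, seq_len) score matrix
# (~1 GiB in fp16 here); SDPA's fused backends compute the same result without it.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 16, 4096, 72])
```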