[core] support chunked feed forward in latte #8842
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Failing tests seem unrelated.
thanks! did you see a significant difference with this for latte?
Unfortunately, the impact is minimal. I expected it to be larger, but only realised this after testing, which I got to a while after opening the PR since GPUs were occupied at the time.
Given those results, I don't think it makes sense to support this: the increase in inference time is not justified by the memory savings.
sayakpaul left a comment
Thanks! Left a single comment.
self.gradient_checkpointing = value

# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
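For context, the copied method essentially walks the module tree and hands a chunk size to every block that supports chunked feed forward. A simplified sketch of that pattern (not the verbatim diffusers implementation):

```python
from typing import Optional

import torch.nn as nn


def enable_forward_chunking(model: nn.Module, chunk_size: Optional[int] = None, dim: int = 0) -> None:
    # Recursively pass the chunk size to every submodule that implements
    # `set_chunk_feed_forward` (e.g. BasicTransformerBlock); such a block then runs
    # its feed-forward network chunk by chunk along `dim` instead of on the full input.
    chunk_size = chunk_size or 1  # a chunk size of 1 chunks most aggressively

    def _recurse(module: nn.Module) -> None:
        if hasattr(module, "set_chunk_feed_forward"):
            module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
        for child in module.children():
            _recurse(child)

    _recurse(model)
```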
I assume Latte doesn't use a novel transformer block like SD3's JointTransformerBlock?
If so, we need to do it like this:
diffusers/src/diffusers/models/attention.py (line 182 in 973a62d):
if self._chunk_size is not None:
Latte uses the BasicTransformerBlock, which already has the chunked feed-forward implementation here:
diffusers/src/diffusers/models/attention.py (line 515 in bbd2f9d):
if self._chunk_size is not None:
Have I missed anything else that might be causing the very poor memory improvements?
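For reference, the chunked feed forward in BasicTransformerBlock boils down to roughly the following pattern (a simplified sketch, not the exact code at the line referenced above):

```python
import torch
import torch.nn as nn


def chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int) -> torch.Tensor:
    # Run the feed-forward network on `chunk_size`-sized slices along `chunk_dim` and
    # concatenate the results, so only one slice's expanded intermediate activations
    # are alive at any point instead of the full sequence's.
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"dim {chunk_dim} of size {hidden_states.shape[chunk_dim]} must be divisible by chunk_size {chunk_size}"
        )
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    return torch.cat(
        [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )
```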
No, then we are good.
As in, we close this PR without merging it, right? The memory improvements are negligible for far too large an increase in inference time.
We might benefit from adding VAE slicing and tiling support here, by the way. Decode memory goes up to 19 GB.
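A minimal sketch of what that could look like, assuming the pipeline's VAE is an AutoencoderKL (which already exposes enable_slicing / enable_tiling); the checkpoint name and prompt are only illustrative:

```python
import torch

from diffusers import LattePipeline

# Assumes the pipeline's VAE is an AutoencoderKL; checkpoint and prompt are illustrative.
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to("cuda")

pipe.vae.enable_slicing()  # decode the batch one sample at a time
pipe.vae.enable_tiling()   # decode each sample in overlapping spatial tiles

output = pipe("a cat playing the piano", num_inference_steps=25)
```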
Might be worth checking that, yes.
This forward-chunking phenomenon (smaller memory savings) is becoming evident for models that are purely transformer based, as opposed to hybrid architectures such as the SDXL UNet. For I2VGenXL or SVD we see savings; for pure transformer models like Latte and SD3, we don't.
For the SD3 800M variant, I did see nice improvements in memory, but not so much for the 2B one. So I think there is some specific FLOP/parameter-count regime where chunking tends to shine.
We should investigate this phenomenon further because feed-forward chunking could be crucial for these models. WDYT?
Oh interesting. I haven't really explored the transformer-based image/video models much and can't comment on why we see this behaviour, but we can definitely investigate. I will try to put together a script over the weekend unless we have something ready already.
Quick question: IIUC, we do chunking only on the FeedForward modules and never on the individual linear layers (such as the QKV projection layers of attention). Isn't the benefit of chunking the FFN lost if we never also apply it to all the linear layers around it? Apologies if I'm being thick-headed; it's been a while since I've looked at the internal modeling code and I can't wrap my head around this. We have attention followed by the FFN. In attention we don't chunk, and we materialize the QK product, which should be a main contributor to the overall memory required. In the FFN we chunk and save some memory, but as I understand it, that saving is largely "invisible" in the overall measurement, where we care about peak usage, because of the unchunked attention projection layers and the QK product, right?
We cannot chunk all linear layers just like that. Read more here:
https://huggingface.co/blog/reformer
Regarding QK materialization, you are likely overlooking the fact that attention is computed with SDPA, which already optimizes memory. So I don't think that is as much of an issue.
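To illustrate the SDPA point with a toy comparison (not diffusers code; the tensor sizes are only illustrative): a naive implementation materializes the full query-key score matrix, while torch.nn.functional.scaled_dot_product_attention can dispatch to flash / memory-efficient kernels that never hold it in memory all at once.

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) -- sizes chosen only to make the point
q = k = v = torch.randn(1, 16, 4096, 64, device="cuda", dtype=torch.float16)

# Naive attention materializes a (seq_len x seq_len) score matrix per head:
# 1 * 16 * 4096 * 4096 fp16 values, roughly 0.5 GB, before the softmax and value matmul.
scores = (q @ k.transpose(-1, -2)) / (q.shape[-1] ** 0.5)
naive_out = scores.softmax(dim=-1) @ v

# SDPA computes the same attention but can pick a flash / memory-efficient kernel
# that processes the scores in tiles and never materializes the full matrix.
sdpa_out = F.scaled_dot_product_attention(q, k, v)
```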
let's not add this then if it does not help?
What does this PR do?
Adds chunked feed-forward support to the Latte video pipeline.
Code
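The original snippet from the PR description is not reproduced here; below is a hypothetical sketch of how the feature this PR adds would be used (the checkpoint, prompt, and chunking parameters are placeholders, not values from the PR):

```python
import torch

from diffusers import LattePipeline

# Hypothetical usage of the enable_forward_chunking method this PR adds to the Latte
# transformer; checkpoint, chunk_size, and dim are placeholders.
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to("cuda")
pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)

torch.cuda.reset_peak_memory_stats()
pipe("a cat playing the piano", num_inference_steps=25)
print(f"peak memory with chunking: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```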
Who can review?
@maxin-cn @sayakpaul @yiyixuxu