42 changes: 42 additions & 0 deletions src/diffusers/models/transformers/latte_transformer_3d.py
@@ -165,6 +165,48 @@ def __init__(
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value

# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
Member:

I assume Latte doesn't use a novel transformer block like SD3?

If so, we need to do it like this:

`if self._chunk_size is not None:`

Contributor Author:

Latte uses the `BasicTransformerBlock`, which already has the chunked feed-forward implementation here:

`if self._chunk_size is not None:`

Have I missed anything else that might explain why the memory improvements are so small?
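For reference, the chunking pattern under discussion looks roughly like this. It is a simplified, illustrative sketch of the chunked feed-forward helper pattern used by `BasicTransformerBlock`, not the exact diffusers implementation:

```python
import torch
import torch.nn as nn


def chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int) -> torch.Tensor:
    # Split the input along `chunk_dim`, run the feed-forward on each slice,
    # and concatenate the results, so the feed-forward's large inner-projection
    # activation is only materialized one chunk at a time.
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"Dimension {hidden_states.shape[chunk_dim]} has to be divisible by chunk size {chunk_size}."
        )
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    return torch.cat(
        [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )


# Inside a transformer block's forward, this would be used roughly as:
# hidden_states = chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
```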

Member:

No, then we are good.

Contributor Author (@a-r-r-o-w, Jul 12, 2024):

As in, we close this PR without merging it, right? Since the memory improvements are negligible for far too large an increase in the time required.

Contributor Author:

We might benefit from adding VAE slicing and tiling support here, by the way. Decode memory goes up to 19 GB.
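For illustration, enabling that from the user side could look like the sketch below. The pipeline class and checkpoint id are assumptions here; `enable_slicing`/`enable_tiling` are the standard `AutoencoderKL` methods:

```python
import torch
from diffusers import LattePipeline

# Assumed pipeline/checkpoint for the sake of the sketch.
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to("cuda")

# Decode the latent frames slice-by-slice and tile-by-tile instead of all at once,
# trading a bit of speed for a lower decode-time memory peak.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

# Run the pipeline as usual afterwards; only the VAE decode path changes.
```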

Member:

Might be worth checking that, yes.

This forward-chunking phenomenon (smaller memory savings) is becoming evident for models that are purely transformer-based, as opposed to hybrid architectures such as the SDXL UNet. For I2VGen-XL or SVD we see savings; for pure transformer models like Latte and SD3, we don't.

For the SD3 800M variant, I did see nice improvements in memory, but not so much for the 2B one. So I think there is some specific FLOP/parameter-count regime where chunking tends to shine.

We should investigate this phenomenon further because feed-forward chunking could be crucial for these models. WDYT?

Contributor Author:

Oh interesting, I haven't really explored the transformer-based image/video models much and can't comment on why we see this behaviour, but we can definitely investigate. I will try to put together a script over the weekend unless we have something ready already (a rough sketch of what such a script might look like is below).

Quick question: IIUC, we do chunking only on the FeedForward modules and never on the individual linear layers (such as the QKV projection layers of attention). Isn't the benefit of chunking the FFN lost if we never also apply it to ALL the other linear layers? Apologies if I'm being thick-headed, but it's been a while since I've looked at the internal modeling code and I can't wrap my head around this. We have attention followed by the FFN. In attention, we don't chunk, and we materialize the QK product, which should be a main contributor to overall memory. In the FFN, we do chunk and save some memory, but as I understand it, this saving is largely "invisible" in the overall measurement, where we care about peak usage, because of the unchunked attention projection layers and the QK product, right?
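A rough sketch of such a comparison script (the pipeline class, checkpoint id, and call arguments are assumptions):

```python
import time

import torch
from diffusers import LattePipeline

pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to("cuda")
prompt = "a cat wearing sunglasses"


def benchmark(label: str) -> None:
    # Measure wall time and peak CUDA memory for one generation.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    pipe(prompt, num_inference_steps=25)
    elapsed = time.perf_counter() - start
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"{label}: {elapsed:.1f}s, peak memory {peak_gb:.2f} GB")


benchmark("no chunking")
pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
benchmark("forward chunking (chunk_size=1, dim=0)")
pipe.transformer.disable_forward_chunking()
```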

Member:

We cannot do chunking on all linear layers just like that. Read more here:
https://huggingface.co/blog/reformer

Regarding the QK materialization point, you are likely overlooking the fact that attention computation is handled by SDPA, which already optimizes memory there. So I don't think that is as much of an issue.
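To make the SDPA point concrete, here is a small illustrative comparison; the shapes are arbitrary and the memory figure approximate:

```python
import torch
import torch.nn.functional as F

# Arbitrary shapes: batch=2, heads=16, seq_len=4096, head_dim=64.
q = torch.randn(2, 16, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Naive attention materializes a (2, 16, 4096, 4096) score tensor (~1 GiB in fp16)
# before multiplying by the values.
scores = torch.softmax((q @ k.transpose(-1, -2)) / q.shape[-1] ** 0.5, dim=-1)
out_naive = scores @ v

# SDPA computes the same result inside a fused / memory-efficient kernel,
# without exposing the full score matrix in eager mode.
out_sdpa = F.scaled_dot_product_attention(q, k, v)
```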

"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

Parameters:
chunk_size (`int`, *optional*):
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
over each tensor of dim=`dim`.
dim (`int`, *optional*, defaults to `0`):
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
or dim=1 (sequence length).
"""
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

# By default chunk size is 1
chunk_size = chunk_size or 1

def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)

for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)

# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self):
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)

for module in self.children():
fn_recursive_feed_forward(module, None, 0)

def forward(
self,
hidden_states: torch.Tensor,
20 changes: 20 additions & 0 deletions tests/pipelines/latte/test_latte.py
@@ -256,6 +256,26 @@ def test_save_load_optional_components(self):
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1.0)

def test_feed_forward_chunking(self):
device = "cpu"

components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
image_slice_no_chunking = image[0, -3:, -3:, -1]

pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
image_slice_chunking = image[0, -3:, -3:, -1]

max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
self.assertLess(max_diff, 1e-4)


@slow
@require_torch_gpu