diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py
index 71d19216e5ff..e5548eeb0c6c 100644
--- a/src/diffusers/models/transformers/latte_transformer_3d.py
+++ b/src/diffusers/models/transformers/latte_transformer_3d.py
@@ -165,6 +165,48 @@ def __init__(
     def _set_gradient_checkpointing(self, module, value=False):
         self.gradient_checkpointing = value
 
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+        """
+        Sets the attention processor to use [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+    def disable_forward_chunking(self):
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py
index 94ff7fc0faf9..95e47cea444d 100644
--- a/tests/pipelines/latte/test_latte.py
+++ b/tests/pipelines/latte/test_latte.py
@@ -256,6 +256,26 @@ def test_save_load_optional_components(self):
         max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
         self.assertLess(max_diff, 1.0)
 
+    def test_feed_forward_chunking(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs)[0]
+        image_slice_no_chunking = image[0, -3:, -3:, -1]
+
+        pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs)[0]
+        image_slice_chunking = image[0, -3:, -3:, -1]
+
+        max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
+        self.assertLess(max_diff, 1e-4)
+
 
 @slow
 @require_torch_gpu
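
A minimal usage sketch of the API this diff adds. The `enable_forward_chunking(chunk_size, dim)` and `disable_forward_chunking()` calls come from the diff itself; the `LattePipeline` loading code, the `maxin-cn/Latte-1` checkpoint, the CUDA device, and the prompt are assumptions for illustration:

```python
import torch
from diffusers import LattePipeline

# Checkpoint name and device are assumptions; substitute your own Latte weights/hardware.
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to("cuda")

# Run the transformer's feed-forward layers one slice at a time along the
# batch dimension (dim=0), trading extra kernel launches for lower peak memory.
pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)

video = pipe("a cat chasing a laser pointer").frames[0]

# Restore the default, unchunked feed-forward behavior.
pipe.transformer.disable_forward_chunking()
```

Chunking over `dim=1` (sequence length) is also supported; it splits the large intermediate feed-forward activation across the token axis, which is typically where the memory savings come from in a video transformer with long flattened spatio-temporal sequences.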