From d9e7372d6237973a0521b5057c5672ef220fb0c3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 14:28:38 +0100 Subject: [PATCH 01/26] init --- src/diffusers/models/hooks.py | 228 ++++++++++++++++++ src/diffusers/pipelines/faster_cache_utils.py | 16 ++ src/diffusers/pipelines/pipeline_utils.py | 4 + 3 files changed, 248 insertions(+) create mode 100644 src/diffusers/models/hooks.py create mode 100644 src/diffusers/pipelines/faster_cache_utils.py diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py new file mode 100644 index 000000000000..9c3ca0a76fe6 --- /dev/null +++ b/src/diffusers/models/hooks.py @@ -0,0 +1,228 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Any, Callable, Dict, Tuple + +import torch + + +# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py +class ModelHook: + r""" + A hook that contains callbacks to be executed just before and after the forward method of a model. The difference + with PyTorch existing hooks is that they get passed along the kwargs. + """ + + def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when a model is initialized. + Args: + module (`torch.nn.Module`): + The module attached to this hook. + """ + return module + + def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: + r""" + Hook that is executed just before the forward method of the model. + Args: + module (`torch.nn.Module`): + The module whose forward pass will be executed just after this event. + args (`Tuple[Any]`): + The positional arguments passed to the module. + kwargs (`Dict[Str, Any]`): + The keyword arguments passed to the module. + Returns: + `Tuple[Tuple[Any], Dict[Str, Any]]`: + A tuple with the treated `args` and `kwargs`. + """ + return args, kwargs + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + r""" + Hook that is executed just after the forward method of the model. + Args: + module (`torch.nn.Module`): + The module whose forward pass been executed just before this event. + output (`Any`): + The output of the module. + Returns: + `Any`: The processed `output`. + """ + return output + + def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: + r""" + Hook that is executed when the hook is detached from a module. + Args: + module (`torch.nn.Module`): + The module detached from this hook. 
+ """ + return module + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + return module + + +class SequentialHook(ModelHook): + r"""A hook that can contain several hooks and iterates through them at each event.""" + + def __init__(self, *hooks): + self.hooks = hooks + + def init_hook(self, module): + for hook in self.hooks: + module = hook.init_hook(module) + return module + + def pre_forward(self, module, *args, **kwargs): + for hook in self.hooks: + args, kwargs = hook.pre_forward(module, *args, **kwargs) + return args, kwargs + + def post_forward(self, module, output): + for hook in self.hooks: + output = hook.post_forward(module, output) + return output + + def detach_hook(self, module): + for hook in self.hooks: + module = hook.detach_hook(module) + return module + + def reset_state(self, module): + for hook in self.hooks: + module = hook.reset_state(module) + return module + + +class FasterCacheHook(ModelHook): + def __init__( + self, + skip_callback: Callable[[torch.nn.Module], bool], + ) -> None: + super().__init__() + + self.skip_callback = skip_callback + + self.cache = None + self._iteration = 0 + + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + + if self.cache is not None and self.skip_callback(module): + output = self.cache + else: + output = module._old_forward(*args, **kwargs) + + return module._diffusers_hook.post_forward(module, output) + + def post_forward(self, module: torch.nn.Module, output: Any) -> Any: + self.cache = output + return output + + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + self.cache = None + self._iteration = 0 + return module + + +def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): + r""" + Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove + this behavior and restore the original `forward` method, use `remove_hook_from_module`. + + If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks + together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. + + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + hook (`ModelHook`): + The hook to attach. + append (`bool`, *optional*, defaults to `False`): + Whether the hook should be chained with an existing one (if module already contains a hook) or not. + Returns: + `torch.nn.Module`: + The same module, with the hook attached (the module is modified in place, so the result can be discarded). + """ + original_hook = hook + + if append and getattr(module, "_diffusers_hook", None) is not None: + old_hook = module._diffusers_hook + remove_hook_from_module(module) + hook = SequentialHook(old_hook, hook) + + if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"): + # If we already put some hook on this module, we replace it with the new one. 
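+        # Reuse the original forward already stored in `_old_forward` (rather than the currently wrapped
+        # `forward`) so that `remove_hook_from_module` can still restore the unwrapped method later.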
+ old_forward = module._old_forward + else: + old_forward = module.forward + module._old_forward = old_forward + + module = hook.init_hook(module) + module._diffusers_hook = hook + + if hasattr(original_hook, "new_forward"): + new_forward = original_hook.new_forward + else: + + def new_forward(module, *args, **kwargs): + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + output = module._old_forward(*args, **kwargs) + return module._diffusers_hook.post_forward(module, output) + + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(module)): + module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) + else: + module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) + + return module + + +def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module: + """ + Removes any hook attached to a module via `add_hook_to_module`. + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + recurse (`bool`, defaults to `False`): + Whether to remove the hooks recursively + Returns: + `torch.nn.Module`: + The same module, with the hook detached (the module is modified in place, so the result can be discarded). + """ + + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook.detach_hook(module) + delattr(module, "_diffusers_hook") + + if hasattr(module, "_old_forward"): + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(module)): + module.__class__.forward = module._old_forward + else: + module.forward = module._old_forward + delattr(module, "_old_forward") + + if recurse: + for child in module.children(): + remove_hook_from_module(child, recurse) + + return module diff --git a/src/diffusers/pipelines/faster_cache_utils.py b/src/diffusers/pipelines/faster_cache_utils.py new file mode 100644 index 000000000000..01f2aa822ea7 --- /dev/null +++ b/src/diffusers/pipelines/faster_cache_utils.py @@ -0,0 +1,16 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a504184ea2f2..b296f2cef46a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1088,6 +1088,10 @@ def maybe_free_model_hooks(self): is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions correctly when applying enable_model_cpu_offload. 
""" + + if hasattr(self, "_diffusers_hook"): + self._diffusers_hook.reset_state() + if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing return From 9a732f062c3e2de4ef0e8c15c09f218bfb115a48 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 22:06:35 +0100 Subject: [PATCH 02/26] update --- src/diffusers/models/hooks.py | 51 +--- .../pipelines/cogvideo/pipeline_cogvideox.py | 4 + src/diffusers/pipelines/faster_cache_utils.py | 289 ++++++++++++++++++ src/diffusers/pipelines/pipeline_utils.py | 2 +- 4 files changed, 304 insertions(+), 42 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index 9c3ca0a76fe6..af21e03f775b 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -13,7 +13,7 @@ # limitations under the License. import functools -from typing import Any, Callable, Dict, Tuple +from typing import Any, Dict, Tuple import torch @@ -28,6 +28,7 @@ class ModelHook: def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. + Args: module (`torch.nn.Module`): The module attached to this hook. @@ -37,6 +38,7 @@ def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: r""" Hook that is executed just before the forward method of the model. + Args: module (`torch.nn.Module`): The module whose forward pass will be executed just after this event. @@ -53,6 +55,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[A def post_forward(self, module: torch.nn.Module, output: Any) -> Any: r""" Hook that is executed just after the forward method of the model. + Args: module (`torch.nn.Module`): The module whose forward pass been executed just before this event. @@ -66,15 +69,13 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any: def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when the hook is detached from a module. + Args: module (`torch.nn.Module`): The module detached from this hook. 
""" return module - def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: - return module - class SequentialHook(ModelHook): r"""A hook that can contain several hooks and iterates through them at each event.""" @@ -102,52 +103,19 @@ def detach_hook(self, module): module = hook.detach_hook(module) return module - def reset_state(self, module): - for hook in self.hooks: - module = hook.reset_state(module) - return module - - -class FasterCacheHook(ModelHook): - def __init__( - self, - skip_callback: Callable[[torch.nn.Module], bool], - ) -> None: - super().__init__() - - self.skip_callback = skip_callback - - self.cache = None - self._iteration = 0 - - def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: - args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - - if self.cache is not None and self.skip_callback(module): - output = self.cache - else: - output = module._old_forward(*args, **kwargs) - - return module._diffusers_hook.post_forward(module, output) - - def post_forward(self, module: torch.nn.Module, output: Any) -> Any: - self.cache = output - return output - - def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: - self.cache = None - self._iteration = 0 - return module - def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): r""" Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove this behavior and restore the original `forward` method, use `remove_hook_from_module`. + + If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. + + Args: module (`torch.nn.Module`): The module to attach a hook to. @@ -198,6 +166,7 @@ def new_forward(module, *args, **kwargs): def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module: """ Removes any hook attached to a module via `add_hook_to_module`. + Args: module (`torch.nn.Module`): The module to attach a hook to. diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 27c2de384cb8..c6b392393c1d 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -622,6 +622,7 @@ def __call__( ) self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. 
Default call parameters @@ -700,6 +701,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -755,6 +757,8 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None + if not output_type == "latent": # Discard any padding frames that were added for CogVideoX 1.5 latents = latents[:, additional_frames:] diff --git a/src/diffusers/pipelines/faster_cache_utils.py b/src/diffusers/pipelines/faster_cache_utils.py index 01f2aa822ea7..9f0f9899ccb8 100644 --- a/src/diffusers/pipelines/faster_cache_utils.py +++ b/src/diffusers/pipelines/faster_cache_utils.py @@ -12,5 +12,294 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass +from typing import Any, Callable, Optional, Tuple +import torch +import torch.fft as FFT +import torch.nn as nn +from ..models.attention_processor import Attention +from ..models.hooks import ModelHook, add_hook_to_module +from ..utils import logging +from .pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +_ATTENTION_CLASSES = (Attention,) + +_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") +_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) + + +@dataclass +class FasterCacheConfig: + r""" + Configuration for [FasterCache](https://huggingface.co/papers/2410.19355). + """ + + # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable + # after some testing. We default to 2 if these parameters are not provided. + spatial_attention_block_skip_range: Optional[int] = None + temporal_attention_block_skip_range: Optional[int] = None + + # TODO(aryan): write heuristics for what the best way to obtain these values are + spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) + temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) + + # Indicator functions for low/high frequency as mentioned in Equation 11 of the paper + low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 641) + high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301) + + # ⍺1 and ⍺2 as mentioned in Equation 11 of the paper + alpha_low_frequency = 1.1 + alpha_high_frequency = 1.1 + + spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS + + attention_weight_callback: Callable[[nn.Module], float] = None + low_frequency_weight_callback: Callable[[nn.Module], float] = None + high_frequency_weight_callback: Callable[[nn.Module], float] = None + + +class FasterCacheDenoiserState: + r""" + State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module. 
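+
+    Attributes:
+        iteration (`int`):
+            Number of denoiser forward passes completed so far. Call `reset_state` before starting a new inference
+            run so that this counter starts from zero.
+        low_frequency_delta (`float`):
+            Cached low-frequency delta, refreshed on every `update_state` call through `delta_update_callback`.
+        high_frequency_delta (`float`):
+            Cached high-frequency delta, refreshed on every `update_state` call through `delta_update_callback`.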
+ """ + + def __init__(self, delta_update_callback: Callable[[Any, int, float, float], Tuple[float, float]]) -> None: + self.delta_update_callback = delta_update_callback + + self.iteration = 0 + self.low_frequency_delta = None + self.high_frequency_delta = None + + def update_state(self, output: Any) -> None: + self.iteration += 1 + self.low_frequency_delta, self.high_frequency_delta = self.delta_update_callback( + output, self.iteration, self.low_frequency_delta, self.high_frequency_delta + ) + + def reset_state(self): + self.iteration = 0 + self.low_frequency_delta = None + self.high_frequency_delta = None + + +class FasterCacheState: + r""" + State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is + applied to will have an instance of this state. + + Attributes: + iteration (`int`): + The current iteration of the FasterCache. It is necessary to ensure that `reset_state` is called before + starting a new inference forward pass for this to work correctly. + """ + + def __init__(self) -> None: + self.iteration = 0 + self.cache = None + + def update_state(self, output: Any) -> None: + self.iteration += 1 + if self.cache is None: + self.cache = [output, output] + else: + self.cache = [self.cache[-1], output] + + def reset_state(self): + self.iteration = 0 + self.cache = None + + +def apply_faster_cache( + pipeline: DiffusionPipeline, + config: Optional[FasterCacheConfig] = None, + denoiser: Optional[nn.Module] = None, +) -> None: + r""" + Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline. + + Args: + pipeline (`DiffusionPipeline`): + The diffusion pipeline to apply FasterCache to. + config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`): + The configuration to use for FasterCache. + denoiser (`Optional[nn.Module]`, `optional`, defaults to `None`): + The denoiser module to apply FasterCache to. If `None`, the pipeline's transformer or unet module will be + used. + + Example: + ```python + # TODO(aryan) + ``` + """ + + if config is None: + config = FasterCacheConfig() + + if config.spatial_attention_block_skip_range is None and config.temporal_attention_block_skip_range is None: + logger.warning( + "FasterCache requires one of `spatial_attention_block_skip_range` or `temporal_attention_block_skip_range` " + "to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2` and " + "`temporal_attention_block_skip_range=2`. To avoid this warning, please set one of the above parameters." + ) + config.spatial_attention_block_skip_range = 2 + config.temporal_attention_block_skip_range = 2 + + if config.attention_weight_callback is None: + # If the user has not provided a weight callback, we default to 0.5 for all timesteps. + # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but + # this depends from model-to-model. It is required by the user to provide a weight callback if they want to + # use a different weight function. Defaulting to 0.5 works well in practice for most cases. + logger.warning( + "FasterCache requires an `attention_weight_callback` to be set. Defaulting to using a weight of 0.5 for all timesteps." + ) + config.attention_weight_callback = lambda _: 0.5 + + if config.low_frequency_weight_callback is None: + logger.debug( + "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." 
+ ) + config.low_frequency_weight_callback = lambda _: config.alpha_low_frequency + + if config.high_frequency_weight_callback is None: + logger.debug( + "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." + ) + config.high_frequency_weight_callback = lambda _: config.alpha_high_frequency + + if denoiser is None: + denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet + + for name, module in denoiser.named_modules(): + if not isinstance(module, _ATTENTION_CLASSES): + continue + if isinstance(module, Attention): + _apply_fastercache_on_attention_class(pipeline, name, module, config) + + +def apply_fastercache_on_module( + module: nn.Module, skip_callback: Callable[[nn.Module], bool], weight_callback: Callable[[nn.Module], float] +) -> None: + module._fastercache_state = FasterCacheState() + hook = FasterCacheBlockHook(skip_callback, weight_callback) + add_hook_to_module(module, hook, append=True) + + +def _apply_fastercache_on_attention_class( + pipeline: DiffusionPipeline, name: str, module: Attention, config: FasterCacheConfig +) -> None: + # Similar check as PEFT to determine if a string layer name matches a module name + # TODO(aryan): make this regex based + is_spatial_self_attention = ( + any( + f"{identifier}." in name or identifier == name for identifier in config.spatial_attention_block_identifiers + ) + and config.spatial_attention_block_skip_range is not None + and not module.is_cross_attention + ) + is_temporal_self_attention = ( + any( + f"{identifier}." in name or identifier == name + for identifier in config.temporal_attention_block_identifiers + ) + and config.temporal_attention_block_skip_range is not None + and not module.is_cross_attention + ) + + block_skip_range, timestep_skip_range, block_type = None, None, None + if is_spatial_self_attention: + block_skip_range = config.spatial_attention_block_skip_range + timestep_skip_range = config.spatial_attention_timestep_skip_range + block_type = "spatial" + elif is_temporal_self_attention: + block_skip_range = config.temporal_attention_block_skip_range + timestep_skip_range = config.temporal_attention_timestep_skip_range + block_type = "temporal" + + if block_skip_range is None or timestep_skip_range is None: + logger.info( + f'Unable to apply FasterCache to the selected layer: "{name}" because it does ' + f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, " + f"however, that this layer may still be valid for applying PAB. Please specify the correct " + f"block identifiers in the configuration or use the specialized `apply_fastercache_on_module` " + f"function to apply FasterCache to this layer." + ) + return + + def skip_callback(module: nn.Module) -> bool: + is_using_classifier_free_guidance = pipeline.do_classifier_free_guidance + if not is_using_classifier_free_guidance: + return False + + fastercache_state = module._fastercache_state + is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1] + + if not is_within_timestep_range: + # We are still not in the phase of inference where skipping attention is possible without minimal quality + # loss, as described in the paper. 
So, the attention computation cannot be skipped + return False + + should_compute_attention = ( + fastercache_state.iteration > 0 and fastercache_state.iteration % block_skip_range == 0 + ) + return not should_compute_attention + + logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}") + apply_fastercache_on_module(module, skip_callback, config.attention_weight_callback) + + +class FasterCacheModelHook(ModelHook): + def __init__(self) -> None: + super().__init__() + + +class FasterCacheBlockHook(ModelHook): + def __init__( + self, skip_callback: Callable[[nn.Module], bool], weight_callback: Callable[[nn.Module], float] + ) -> None: + super().__init__() + + self.skip_callback = skip_callback + self.weight_callback = weight_callback + + def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + + if self.skip_callback(module): + t_2_output, t_output = module._fastercache_state.cache + output = t_output + (t_output - t_2_output) * self.weight_callback(module) + else: + output = module._old_forward(*args, **kwargs) + + return module._diffusers_hook.post_forward(module, output) + + def post_forward(self, module: nn.Module, output: Any) -> Any: + module._fastercache_state.update_state(output) + return output + + +# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/fastercache_sample_latte.py#L127C1-L143C39 +@torch.no_grad() +def _fft(tensor): + tensor_fft = FFT.fft2(tensor) + tensor_fft_shifted = FFT.fftshift(tensor_fft) + batch_size, num_channels, height, width = tensor.size() + radius = min(height, width) // 5 + + y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width)) + center_x, center_y = width // 2, height // 2 + mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2 + + low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device) + high_freq_mask = ~low_freq_mask + + low_freq_fft = tensor_fft_shifted * low_freq_mask + high_freq_fft = tensor_fft_shifted * high_freq_mask + + return low_freq_fft, high_freq_fft diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index b296f2cef46a..6725fe49dfc1 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1091,7 +1091,7 @@ def maybe_free_model_hooks(self): if hasattr(self, "_diffusers_hook"): self._diffusers_hook.reset_state() - + if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing return From 80c5acdaebdf500deeb7bae19c3e74a38c20b918 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 9 Dec 2024 22:11:17 +0100 Subject: [PATCH 03/26] update --- src/diffusers/pipelines/pipeline_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 6725fe49dfc1..a504184ea2f2 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1088,10 +1088,6 @@ def maybe_free_model_hooks(self): is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it functions correctly when applying enable_model_cpu_offload. 
""" - - if hasattr(self, "_diffusers_hook"): - self._diffusers_hook.reset_state() - if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0: # `enable_model_cpu_offload` has not be called, so silently do nothing return From 6047114e90089a964efef79f7facd2286b97f6a7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 27 Dec 2024 16:12:52 +0100 Subject: [PATCH 04/26] update --- src/diffusers/models/hooks.py | 31 +- src/diffusers/pipelines/faster_cache_utils.py | 324 ++++++++++++++---- .../pipelines/latte/pipeline_latte.py | 10 +- 3 files changed, 298 insertions(+), 67 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index af21e03f775b..f29ecfe9a969 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -13,7 +13,7 @@ # limitations under the License. import functools -from typing import Any, Dict, Tuple +from typing import Any, Dict, List, Tuple import torch @@ -25,6 +25,8 @@ class ModelHook: with PyTorch existing hooks is that they get passed along the kwargs. """ + _is_stateful = False + def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: r""" Hook that is executed when a model is initialized. @@ -75,13 +77,17 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: The module detached from this hook. """ return module + + def reset_state(self): + if self._is_stateful: + raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") class SequentialHook(ModelHook): r"""A hook that can contain several hooks and iterates through them at each event.""" def __init__(self, *hooks): - self.hooks = hooks + self.hooks: List[ModelHook] = hooks def init_hook(self, module): for hook in self.hooks: @@ -102,6 +108,11 @@ def detach_hook(self, module): for hook in self.hooks: module = hook.detach_hook(module) return module + + def reset_state(self): + for hook in self.hooks: + if hook._is_stateful: + hook.reset_state() def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): @@ -195,3 +206,19 @@ def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> t remove_hook_from_module(child, recurse) return module + + +def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False): + """ + Resets the state of all stateful hooks attached to a module. + + Args: + module (`torch.nn.Module`): + The module to reset the stateful hooks from. + """ + if hasattr(module, "_diffusers_hook") and (module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook)): + module._diffusers_hook.reset_state(module) + + if recurse: + for child in module.children(): + reset_stateful_hooks(child, recurse) diff --git a/src/diffusers/pipelines/faster_cache_utils.py b/src/diffusers/pipelines/faster_cache_utils.py index 9f0f9899ccb8..020e92794b34 100644 --- a/src/diffusers/pipelines/faster_cache_utils.py +++ b/src/diffusers/pipelines/faster_cache_utils.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import re from dataclasses import dataclass -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple import torch import torch.fft as FFT @@ -30,8 +31,18 @@ _ATTENTION_CLASSES = (Attention,) -_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") +_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ( + "blocks", + "transformer_blocks", +) _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = ( + "hidden_states", + "encoder_hidden_states", + "timestep", + "attention_mask", + "encoder_attention_mask", +) @dataclass @@ -40,23 +51,29 @@ class FasterCacheConfig: Configuration for [FasterCache](https://huggingface.co/papers/2410.19355). """ + num_train_timesteps: int = 1000 + # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable # after some testing. We default to 2 if these parameters are not provided. spatial_attention_block_skip_range: Optional[int] = None temporal_attention_block_skip_range: Optional[int] = None # TODO(aryan): write heuristics for what the best way to obtain these values are - spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) - temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) + spatial_attention_timestep_skip_range: Tuple[float, float] = (-1, 681) + temporal_attention_timestep_skip_range: Tuple[float, float] = (-1, 681) # Indicator functions for low/high frequency as mentioned in Equation 11 of the paper - low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 641) + low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901) high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301) # ⍺1 and ⍺2 as mentioned in Equation 11 of the paper alpha_low_frequency = 1.1 alpha_high_frequency = 1.1 + # n as described in CFG-Cache explanation in the paper - dependant on the model + unconditional_batch_skip_range: int = 5 + unconditional_batch_timestep_skip_range: Tuple[float, float] = (-1, 641) + spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS @@ -64,26 +81,30 @@ class FasterCacheConfig: low_frequency_weight_callback: Callable[[nn.Module], float] = None high_frequency_weight_callback: Callable[[nn.Module], float] = None + tensor_format: str = "BCFHW" + unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS + class FasterCacheDenoiserState: r""" State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module. 
""" - def __init__(self, delta_update_callback: Callable[[Any, int, float, float], Tuple[float, float]]) -> None: - self.delta_update_callback = delta_update_callback + def __init__( + self, + low_frequency_weight_callback: Callable[[nn.Module], torch.Tensor], + high_frequency_weight_callback: Callable[[nn.Module], torch.Tensor], + uncond_skip_callback: Callable[[nn.Module], bool], + ) -> None: + self.low_frequency_weight_callback = low_frequency_weight_callback + self.high_frequency_weight_callback = high_frequency_weight_callback + self.uncond_skip_callback = uncond_skip_callback self.iteration = 0 self.low_frequency_delta = None self.high_frequency_delta = None - def update_state(self, output: Any) -> None: - self.iteration += 1 - self.low_frequency_delta, self.high_frequency_delta = self.delta_update_callback( - output, self.iteration, self.low_frequency_delta, self.high_frequency_delta - ) - - def reset_state(self): + def reset(self): self.iteration = 0 self.low_frequency_delta = None self.high_frequency_delta = None @@ -100,19 +121,19 @@ class FasterCacheState: starting a new inference forward pass for this to work correctly. """ - def __init__(self) -> None: + def __init__( + self, skip_callback: Callable[[nn.Module], bool], weight_callback: Callable[[nn.Module], float] + ) -> None: + self.skip_callback = skip_callback + self.weight_callback = weight_callback + self.iteration = 0 + self.batch_size = None self.cache = None - def update_state(self, output: Any) -> None: - self.iteration += 1 - if self.cache is None: - self.cache = [output, output] - else: - self.cache = [self.cache[-1], output] - - def reset_state(self): + def reset(self): self.iteration = 0 + self.batch_size = None self.cache = None @@ -144,7 +165,7 @@ def apply_faster_cache( if config.spatial_attention_block_skip_range is None and config.temporal_attention_block_skip_range is None: logger.warning( - "FasterCache requires one of `spatial_attention_block_skip_range` or `temporal_attention_block_skip_range` " + "FasterCache requires one of `spatial_attention_block_skip_range` and/or `temporal_attention_block_skip_range` " "to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2` and " "`temporal_attention_block_skip_range=2`. To avoid this warning, please set one of the above parameters." ) @@ -165,16 +186,39 @@ def apply_faster_cache( logger.debug( "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." ) - config.low_frequency_weight_callback = lambda _: config.alpha_low_frequency + + def low_frequency_weight_callback(module: nn.Module) -> float: + is_within_range = ( + config.low_frequency_weight_update_timestep_range[0] + < pipeline._current_timestep + < config.low_frequency_weight_update_timestep_range[1] + ) + return config.alpha_low_frequency if is_within_range else 1.0 + + config.low_frequency_weight_callback = low_frequency_weight_callback if config.high_frequency_weight_callback is None: logger.debug( "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." 
) - config.high_frequency_weight_callback = lambda _: config.alpha_high_frequency + + def high_frequency_weight_callback(module: nn.Module) -> float: + is_within_range = ( + config.high_frequency_weight_update_timestep_range[0] + < pipeline._current_timestep + < config.high_frequency_weight_update_timestep_range[1] + ) + return config.alpha_high_frequency if is_within_range else 1.0 + + config.high_frequency_weight_callback = high_frequency_weight_callback + + supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] # TODO(aryan): Support BSC for LTX Video + if config.tensor_format not in supported_tensor_formats: + raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.") if denoiser is None: denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet + _apply_fastercache_on_denoiser(pipeline, denoiser, config) for name, module in denoiser.named_modules(): if not isinstance(module, _ATTENTION_CLASSES): @@ -183,23 +227,44 @@ def apply_faster_cache( _apply_fastercache_on_attention_class(pipeline, name, module, config) -def apply_fastercache_on_module( - module: nn.Module, skip_callback: Callable[[nn.Module], bool], weight_callback: Callable[[nn.Module], float] +def _apply_fastercache_on_denoiser( + pipeline: DiffusionPipeline, denoiser: nn.Module, config: FasterCacheConfig ) -> None: - module._fastercache_state = FasterCacheState() - hook = FasterCacheBlockHook(skip_callback, weight_callback) - add_hook_to_module(module, hook, append=True) + def uncond_skip_callback(module: nn.Module) -> bool: + # If we are not using classifier-free guidance, we cannot skip the denoiser computation. We only compute the + # conditional branch in this case. + is_using_classifier_free_guidance = pipeline.do_classifier_free_guidance + if not is_using_classifier_free_guidance: + return False + + # We skip the unconditional branch only if the following conditions are met: + # 1. We have completed at least one iteration of the denoiser + # 2. The current timestep is within the range specified by the user. This is the optimal timestep range + # where approximating the unconditional branch from the computation of the conditional branch is possible + # without a significant loss in quality. + # 3. The current iteration is not a multiple of the unconditional batch skip range. This is to ensure that + # we compute the unconditional branch at least once every few iterations to ensure minimal quality loss. + + state: FasterCacheDenoiserState = module._fastercache_state + is_within_range = ( + config.unconditional_batch_timestep_skip_range[0] + < pipeline._current_timestep + < config.unconditional_batch_timestep_skip_range[1] + ) + return state.iteration > 0 and is_within_range and state.iteration % config.unconditional_batch_skip_range != 0 + + denoiser._fastercache_state = FasterCacheDenoiserState( + config.low_frequency_weight_callback, config.high_frequency_weight_callback, uncond_skip_callback + ) + hook = FasterCacheModelHook(config.unconditional_conditional_input_kwargs_identifiers, config.tensor_format) + add_hook_to_module(denoiser, hook, append=True) def _apply_fastercache_on_attention_class( pipeline: DiffusionPipeline, name: str, module: Attention, config: FasterCacheConfig ) -> None: - # Similar check as PEFT to determine if a string layer name matches a module name - # TODO(aryan): make this regex based is_spatial_self_attention = ( - any( - f"{identifier}." 
in name or identifier == name for identifier in config.spatial_attention_block_identifiers - ) + any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers) and config.spatial_attention_block_skip_range is not None and not module.is_cross_attention ) @@ -223,9 +288,9 @@ def _apply_fastercache_on_attention_class( block_type = "temporal" if block_skip_range is None or timestep_skip_range is None: - logger.info( + logger.debug( f'Unable to apply FasterCache to the selected layer: "{name}" because it does ' - f"not match any of the required criteria for spatial, temporal or cross attention layers. Note, " + f"not match any of the required criteria for spatial or temporal attention layers. Note, " f"however, that this layer may still be valid for applying PAB. Please specify the correct " f"block identifiers in the configuration or use the specialized `apply_fastercache_on_module` " f"function to apply FasterCache to this layer." @@ -237,13 +302,16 @@ def skip_callback(module: nn.Module) -> bool: if not is_using_classifier_free_guidance: return False - fastercache_state = module._fastercache_state + fastercache_state: FasterCacheState = module._fastercache_state is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1] if not is_within_timestep_range: # We are still not in the phase of inference where skipping attention is possible without minimal quality # loss, as described in the paper. So, the attention computation cannot be skipped return False + if fastercache_state.cache is None or fastercache_state.iteration < 2: + # We need at least 2 iterations to start skipping attention computation + return False should_compute_attention = ( fastercache_state.iteration > 0 and fastercache_state.iteration % block_skip_range == 0 @@ -251,55 +319,185 @@ def skip_callback(module: nn.Module) -> bool: return not should_compute_attention logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}") - apply_fastercache_on_module(module, skip_callback, config.attention_weight_callback) + module._fastercache_state = FasterCacheState(skip_callback, config.attention_weight_callback) + hook = FasterCacheBlockHook() + add_hook_to_module(module, hook, append=True) class FasterCacheModelHook(ModelHook): - def __init__(self) -> None: + _is_stateful = True + + def __init__(self, uncond_cond_input_kwargs_identifiers: List[str], tensor_format: str) -> None: super().__init__() + # We can't easily detect what args are to be split in unconditional and conditional branches. We + # can only do it for kwargs, hence they are the only ones we split. The args are passed as-is. + # If a model is to be made compatible with FasterCache, the user must ensure that the inputs that + # contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs. + self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers + self.tensor_format = tensor_format -class FasterCacheBlockHook(ModelHook): - def __init__( - self, skip_callback: Callable[[nn.Module], bool], weight_callback: Callable[[nn.Module], float] - ) -> None: - super().__init__() - - self.skip_callback = skip_callback - self.weight_callback = weight_callback + def _get_cond_input(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs + # followed by conditional inputs. 
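+        # For example, an input laid out as [uncond_0, ..., uncond_{B-1}, cond_0, ..., cond_{B-1}] along the batch
+        # dimension returns only the conditional half [cond_0, ..., cond_{B-1}].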
+ _, cond = input.chunk(2, dim=0) + return cond def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + state: FasterCacheDenoiserState = module._fastercache_state + + # Split the unconditional and conditional inputs. We only want to infer the conditional branch if the + # requirements for skipping the unconditional branch are met as described in the paper. + should_skip_uncond = state.uncond_skip_callback(module) + if should_skip_uncond: + kwargs = { + k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v) + for k, v in kwargs.items() + } + # TODO(aryan): remove later + logger.debug("Skipping unconditional branch computation") + + if should_skip_uncond: + breakpoint() + output = module._old_forward(*args, **kwargs) + # TODO(aryan): handle Transformer2DModelOutput + hidden_states = output[0] if isinstance(output, tuple) else output + batch_size = hidden_states.size(0) + + if should_skip_uncond: + state.low_frequency_delta = state.low_frequency_delta * state.low_frequency_weight_callback(module) + state.high_frequency_delta = state.high_frequency_delta * state.high_frequency_weight_callback(module) + + if self.tensor_format == "BCFHW": + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW": + hidden_states = hidden_states.flatten(0, 1) + + low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float()) + + # Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper + low_freq_uncond = state.low_frequency_delta + low_freq_cond + high_freq_uncond = state.high_frequency_delta + high_freq_cond + uncond_freq = low_freq_uncond + high_freq_uncond + + uncond_states = FFT.ifftshift(uncond_freq) + uncond_states = FFT.ifft2(uncond_states).real + + if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW": + uncond_states = uncond_states.unflatten(0, (batch_size, -1)) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)) + if self.tensor_format == "BCFHW": + uncond_states = uncond_states.permute(0, 2, 1, 3, 4) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + # Concatenate the approximated unconditional and predicted conditional branches + uncond_states = uncond_states.to(hidden_states.dtype) + hidden_states = torch.cat([uncond_states, hidden_states], dim=0) + else: + # TODO(aryan): remove later + logger.debug("Computing unconditional branch") + + uncond_states, cond_states = hidden_states.chunk(2, dim=0) + if self.tensor_format == "BCFHW": + uncond_states = uncond_states.permute(0, 2, 1, 3, 4) + cond_states = cond_states.permute(0, 2, 1, 3, 4) + if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW": + uncond_states = uncond_states.flatten(0, 1) + cond_states = cond_states.flatten(0, 1) + + low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float()) + low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float()) + state.low_frequency_delta = low_freq_uncond - low_freq_cond + state.high_frequency_delta = high_freq_uncond - high_freq_cond + + state.iteration += 1 + output = (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states + return output + + def reset_state(self, module: nn.Module) -> None: + module._fastercache_state.reset() + - if self.skip_callback(module): - t_2_output, t_output = module._fastercache_state.cache - output = t_output + (t_output - t_2_output) * 
self.weight_callback(module) +class FasterCacheBlockHook(ModelHook): + _is_stateful = True + + def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: + args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) + state: FasterCacheState = module._fastercache_state + + batch_size = [ + *[arg.size(0) for arg in args if torch.is_tensor(arg)], + *[v.size(0) for v in kwargs.values() if torch.is_tensor(v)], + ][0] + if state.batch_size is None: + # Will be updated on first forward pass through the denoiser + state.batch_size = batch_size + + # If we have to skip due to the skip conditions, then let's skip as expected. + # But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. So, + # if state.batch_size (which is the true unconditional-conditional batch size) is same as the current + # batch size, we don't perform the layer skip. Otherwise, we conditionally skip the layer based on + # what state.skip_callback returns. + if state.skip_callback(module) and state.batch_size != batch_size: + # TODO(aryan): remove later + logger.debug("Skipping layer computation") + t_2_output, t_output = state.cache + + # TODO(aryan): these conditions may not be needed after latest refactor. they exist for safety. do test if they can be removed + if t_2_output.size(0) != batch_size: + # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just + # take the conditional branch outputs. + assert t_2_output.size(0) == 2 * batch_size + t_2_output = t_2_output[batch_size:] + if t_output.size(0) != batch_size: + # The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just + # take the conditional branch outputs. + assert t_output.size(0) == 2 * batch_size + t_output = t_output[batch_size:] + + output = t_output + (t_output - t_2_output) * state.weight_callback(module) else: output = module._old_forward(*args, **kwargs) - return module._diffusers_hook.post_forward(module, output) + # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs. + # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs. + cache_output = output + if output.size(0) == state.batch_size: + cache_output = cache_output.chunk(2, dim=0)[1] + + # Just to be safe that the output is of the correct size for both unconditional-conditional branch inference + # and only-conditional branch inference. 
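+        # In both cases `cache_output` holds only the conditional half, so its batch dimension must be half of
+        # the full unconditional+conditional batch size tracked in `state.batch_size`.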
+ assert 2 * cache_output.size(0) == state.batch_size + + if state.cache is None: + state.cache = [cache_output, cache_output] + else: + state.cache = [state.cache[-1], cache_output] - def post_forward(self, module: nn.Module, output: Any) -> Any: - module._fastercache_state.update_state(output) - return output + state.iteration += 1 + return module._diffusers_hook.post_forward(module, output) + + def reset_state(self, module: nn.Module) -> None: + module._fastercache_state.reset() # Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/fastercache_sample_latte.py#L127C1-L143C39 @torch.no_grad() -def _fft(tensor): - tensor_fft = FFT.fft2(tensor) - tensor_fft_shifted = FFT.fftshift(tensor_fft) - batch_size, num_channels, height, width = tensor.size() +def _split_low_high_freq(x): + fft = FFT.fft2(x) + fft_shifted = FFT.fftshift(fft) + height, width = x.shape[-2:] radius = min(height, width) // 5 y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width)) center_x, center_y = width // 2, height // 2 mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2 - low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(tensor.device) + low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device) high_freq_mask = ~low_freq_mask - low_freq_fft = tensor_fft_shifted * low_freq_mask - high_freq_fft = tensor_fft_shifted * high_freq_mask + low_freq_fft = fft_shifted * low_freq_mask + high_freq_fft = fft_shifted * high_freq_mask return low_freq_fft, high_freq_fft diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 19c4a6d1ddf9..9933b4b90029 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -25,6 +25,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKL, LatteTransformer3DModel +from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -719,6 +720,7 @@ def __call__( negative_prompt_embeds, ) self._guidance_scale = guidance_scale + self._current_timestep = None self._interrupt = False # 2. 
Default height and width to transformer @@ -780,6 +782,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) @@ -800,7 +803,7 @@ def __call__( # predict noise model_output noise_pred = self.transformer( - latent_model_input, + hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=current_timestep, enable_temporal_attentions=enable_temporal_attentions, @@ -836,7 +839,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if not output_type == "latents": + self._current_timestep = None + + if not output_type == "latent": video = self.decode_latents(latents, video_length, decode_chunk_size=14) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: @@ -844,6 +849,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) From 82d85bd1a615541af2b43f86e7c8c099ab7174a7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 27 Dec 2024 16:13:03 +0100 Subject: [PATCH 05/26] make style --- src/diffusers/models/hooks.py | 8 +++++--- src/diffusers/pipelines/faster_cache_utils.py | 20 +++++++++---------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index f29ecfe9a969..b048869d9abd 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -77,7 +77,7 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: The module detached from this hook. """ return module - + def reset_state(self): if self._is_stateful: raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") @@ -108,7 +108,7 @@ def detach_hook(self, module): for hook in self.hooks: module = hook.detach_hook(module) return module - + def reset_state(self): for hook in self.hooks: if hook._is_stateful: @@ -216,7 +216,9 @@ def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False): module (`torch.nn.Module`): The module to reset the stateful hooks from. """ - if hasattr(module, "_diffusers_hook") and (module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook)): + if hasattr(module, "_diffusers_hook") and ( + module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook) + ): module._diffusers_hook.reset_state(module) if recurse: diff --git a/src/diffusers/pipelines/faster_cache_utils.py b/src/diffusers/pipelines/faster_cache_utils.py index 020e92794b34..6febda4f4377 100644 --- a/src/diffusers/pipelines/faster_cache_utils.py +++ b/src/diffusers/pipelines/faster_cache_utils.py @@ -236,7 +236,7 @@ def uncond_skip_callback(module: nn.Module) -> bool: is_using_classifier_free_guidance = pipeline.do_classifier_free_guidance if not is_using_classifier_free_guidance: return False - + # We skip the unconditional branch only if the following conditions are met: # 1. We have completed at least one iteration of the denoiser # 2. The current timestep is within the range specified by the user. 
This is the optimal timestep range @@ -326,7 +326,7 @@ def skip_callback(module: nn.Module) -> bool: class FasterCacheModelHook(ModelHook): _is_stateful = True - + def __init__(self, uncond_cond_input_kwargs_identifiers: List[str], tensor_format: str) -> None: super().__init__() @@ -397,7 +397,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: else: # TODO(aryan): remove later logger.debug("Computing unconditional branch") - + uncond_states, cond_states = hidden_states.chunk(2, dim=0) if self.tensor_format == "BCFHW": uncond_states = uncond_states.permute(0, 2, 1, 3, 4) @@ -412,16 +412,16 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: state.high_frequency_delta = high_freq_uncond - high_freq_cond state.iteration += 1 - output = (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states + output = (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states return output - + def reset_state(self, module: nn.Module) -> None: module._fastercache_state.reset() class FasterCacheBlockHook(ModelHook): _is_stateful = True - + def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) state: FasterCacheState = module._fastercache_state @@ -443,7 +443,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: # TODO(aryan): remove later logger.debug("Skipping layer computation") t_2_output, t_output = state.cache - + # TODO(aryan): these conditions may not be needed after latest refactor. they exist for safety. do test if they can be removed if t_2_output.size(0) != batch_size: # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just @@ -455,7 +455,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: # take the conditional branch outputs. assert t_output.size(0) == 2 * batch_size t_output = t_output[batch_size:] - + output = t_output + (t_output - t_2_output) * state.weight_callback(module) else: output = module._old_forward(*args, **kwargs) @@ -465,7 +465,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: cache_output = output if output.size(0) == state.batch_size: cache_output = cache_output.chunk(2, dim=0)[1] - + # Just to be safe that the output is of the correct size for both unconditional-conditional branch inference # and only-conditional branch inference. 
assert 2 * cache_output.size(0) == state.batch_size @@ -477,7 +477,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: state.iteration += 1 return module._diffusers_hook.post_forward(module, output) - + def reset_state(self, module: nn.Module) -> None: module._fastercache_state.reset() From 535922287bd834949eb49db99ce8eb1b577aaa59 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 29 Dec 2024 16:24:13 +0100 Subject: [PATCH 06/26] update --- src/diffusers/models/embeddings.py | 2 +- .../pipelines/cogvideo/pipeline_cogvideox.py | 2 + src/diffusers/pipelines/faster_cache_utils.py | 127 ++++++++++-------- .../hunyuan_video/pipeline_hunyuan_video.py | 6 + .../pipelines/mochi/pipeline_mochi.py | 13 +- 5 files changed, 89 insertions(+), 61 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1768c81ce039..c2eb5b31f8cf 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -334,7 +334,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"): " `from_numpy` is no longer required." " Pass `output_type='pt' to use the new version now." ) - deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + # deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos) if embed_dim % 2 != 0: raise ValueError("embed_dim must be divisible by 2") diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 312155c816fa..112b4c132261 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -24,6 +24,7 @@ from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed +from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import logging, replace_example_docstring @@ -769,6 +770,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/faster_cache_utils.py b/src/diffusers/pipelines/faster_cache_utils.py index 6febda4f4377..b77f6daea91d 100644 --- a/src/diffusers/pipelines/faster_cache_utils.py +++ b/src/diffusers/pipelines/faster_cache_utils.py @@ -49,13 +49,12 @@ class FasterCacheConfig: r""" Configuration for [FasterCache](https://huggingface.co/papers/2410.19355). - """ - num_train_timesteps: int = 1000 + Attributes:""" # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable # after some testing. We default to 2 if these parameters are not provided. - spatial_attention_block_skip_range: Optional[int] = None + spatial_attention_block_skip_range: int = 2 temporal_attention_block_skip_range: Optional[int] = None # TODO(aryan): write heuristics for what the best way to obtain these values are @@ -145,6 +144,9 @@ def apply_faster_cache( r""" Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline. + Note: FasterCache should only be applied when using classifer-free guidance. It will not work as expected even if + the inference runs successfully. 
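    As a rough illustration of that note, FasterCache would typically be enabled only for runs where classifier-free
    guidance is actually active. The following is an illustrative sketch (not part of the patch); it assumes the
    top-level `FasterCacheConfig`/`apply_faster_cache` exports added later in this series, and the prompt,
    `guidance_scale`, and output handling are placeholders:

    ```python
    import torch
    from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache

    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

    guidance_scale = 6.0  # classifier-free guidance is only active for guidance_scale > 1
    if guidance_scale > 1.0:
        config = FasterCacheConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_block_identifiers=["transformer_blocks"],
            attention_weight_callback=lambda _: 0.5,
            tensor_format="BFCHW",
        )
        apply_faster_cache(pipe, config)

    video = pipe("A panda playing a guitar", guidance_scale=guidance_scale).frames[0]
    ```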
+ Args: pipeline (`DiffusionPipeline`): The diffusion pipeline to apply FasterCache to. @@ -163,15 +165,6 @@ def apply_faster_cache( if config is None: config = FasterCacheConfig() - if config.spatial_attention_block_skip_range is None and config.temporal_attention_block_skip_range is None: - logger.warning( - "FasterCache requires one of `spatial_attention_block_skip_range` and/or `temporal_attention_block_skip_range` " - "to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2` and " - "`temporal_attention_block_skip_range=2`. To avoid this warning, please set one of the above parameters." - ) - config.spatial_attention_block_skip_range = 2 - config.temporal_attention_block_skip_range = 2 - if config.attention_weight_callback is None: # If the user has not provided a weight callback, we default to 0.5 for all timesteps. # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but @@ -231,12 +224,6 @@ def _apply_fastercache_on_denoiser( pipeline: DiffusionPipeline, denoiser: nn.Module, config: FasterCacheConfig ) -> None: def uncond_skip_callback(module: nn.Module) -> bool: - # If we are not using classifier-free guidance, we cannot skip the denoiser computation. We only compute the - # conditional branch in this case. - is_using_classifier_free_guidance = pipeline.do_classifier_free_guidance - if not is_using_classifier_free_guidance: - return False - # We skip the unconditional branch only if the following conditions are met: # 1. We have completed at least one iteration of the denoiser # 2. The current timestep is within the range specified by the user. This is the optimal timestep range @@ -298,10 +285,6 @@ def _apply_fastercache_on_attention_class( return def skip_callback(module: nn.Module) -> bool: - is_using_classifier_free_guidance = pipeline.do_classifier_free_guidance - if not is_using_classifier_free_guidance: - return False - fastercache_state: FasterCacheState = module._fastercache_state is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1] @@ -309,9 +292,6 @@ def skip_callback(module: nn.Module) -> bool: # We are still not in the phase of inference where skipping attention is possible without minimal quality # loss, as described in the paper. So, the attention computation cannot be skipped return False - if fastercache_state.cache is None or fastercache_state.iteration < 2: - # We need at least 2 iterations to start skipping attention computation - return False should_compute_attention = ( fastercache_state.iteration > 0 and fastercache_state.iteration % block_skip_range == 0 @@ -358,8 +338,6 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: # TODO(aryan): remove later logger.debug("Skipping unconditional branch computation") - if should_skip_uncond: - breakpoint() output = module._old_forward(*args, **kwargs) # TODO(aryan): handle Transformer2DModelOutput hidden_states = output[0] if isinstance(output, tuple) else output @@ -422,6 +400,22 @@ def reset_state(self, module: nn.Module) -> None: class FasterCacheBlockHook(ModelHook): _is_stateful = True + def _compute_approximated_attention_output( + self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int + ) -> torch.Tensor: + # TODO(aryan): these conditions may not be needed after latest refactor. they exist for safety. 
do test if they can be removed + if t_2_output.size(0) != batch_size: + # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just + # take the conditional branch outputs. + assert t_2_output.size(0) == 2 * batch_size + t_2_output = t_2_output[batch_size:] + if t_output.size(0) != batch_size: + # The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just + # take the conditional branch outputs. + assert t_output.size(0) == 2 * batch_size + t_output = t_output[batch_size:] + return t_output + (t_output - t_2_output) * weight + def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) state: FasterCacheState = module._fastercache_state @@ -435,40 +429,59 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: state.batch_size = batch_size # If we have to skip due to the skip conditions, then let's skip as expected. - # But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. So, - # if state.batch_size (which is the true unconditional-conditional batch size) is same as the current - # batch size, we don't perform the layer skip. Otherwise, we conditionally skip the layer based on - # what state.skip_callback returns. - if state.skip_callback(module) and state.batch_size != batch_size: + # But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This + # is because the expected output shapes of attention layer will not match if we only return values from + # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true + # unconditional-conditional batch size) is same as the current batch size, we don't perform the layer + # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns. + should_skip_attention = state.skip_callback(module) and state.batch_size != batch_size + + if should_skip_attention: # TODO(aryan): remove later - logger.debug("Skipping layer computation") - t_2_output, t_output = state.cache - - # TODO(aryan): these conditions may not be needed after latest refactor. they exist for safety. do test if they can be removed - if t_2_output.size(0) != batch_size: - # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just - # take the conditional branch outputs. - assert t_2_output.size(0) == 2 * batch_size - t_2_output = t_2_output[batch_size:] - if t_output.size(0) != batch_size: - # The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just - # take the conditional branch outputs. - assert t_output.size(0) == 2 * batch_size - t_output = t_output[batch_size:] - - output = t_output + (t_output - t_2_output) * state.weight_callback(module) + logger.debug("Skipping attention") + + if torch.is_tensor(state.cache): + t_2_output, t_output = state.cache + weight = state.weight_callback(module) + output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size) + else: + # The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them. + # Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity. + # In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from + # a forward pass of the block. 
We need to compute the approximated output for each of these tensors. + # The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which + # allows us to compute the approximated attention output for each tensor in the cache. + output = () + for t_2_output, t_output in zip(*state.cache): + result = self._compute_approximated_attention_output( + t_2_output, t_output, state.weight_callback(module), batch_size + ) + output += (result,) else: + logger.debug("Computing attention") output = module._old_forward(*args, **kwargs) - # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs. - # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs. - cache_output = output - if output.size(0) == state.batch_size: - cache_output = cache_output.chunk(2, dim=0)[1] - - # Just to be safe that the output is of the correct size for both unconditional-conditional branch inference - # and only-conditional branch inference. - assert 2 * cache_output.size(0) == state.batch_size + # Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return + # a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle + # both cases. + if torch.is_tensor(output): + cache_output = output + if cache_output.size(0) == state.batch_size: + # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs. + # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs. + cache_output = cache_output.chunk(2, dim=0)[1] + + # Just to be safe that the output is of the correct size for both unconditional-conditional branch inference + # and only-conditional branch inference. 
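The arithmetic behind this skip path is a linear extrapolation from the two most recent cached outputs. A small self-contained sketch (shapes and the 0.5 weight are illustrative; it mirrors `_compute_approximated_attention_output` above rather than replacing it):

```python
import torch

t_2_output = torch.randn(2, 16, 32)  # cached attention output from iteration t-2
t_output = torch.randn(2, 16, 32)    # cached attention output from iteration t-1
weight = 0.5                         # would come from `state.weight_callback(module)`

# Single-tensor case, as in `_compute_approximated_attention_output`:
approx = t_output + (t_output - t_2_output) * weight

# Tuple case: the cache holds the full return tuple from the last two iterations, and
# `zip(*cache)` pairs up corresponding tensors so each one is extrapolated independently.
cache = [(t_2_output, t_2_output.clone()), (t_output, t_output.clone())]
approx_tuple = tuple(new + (new - old) * weight for old, new in zip(*cache))
```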
+ assert 2 * cache_output.size(0) == state.batch_size + else: + # Cache all return values and perform the same operation as above + cache_output = () + for out in output: + if out.size(0) == state.batch_size: + out = out.chunk(2, dim=0)[1] + assert 2 * out.size(0) == state.batch_size + cache_output += (out,) if state.cache is None: state.cache = [cache_output, cache_output] diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 3b0956a32da3..e2200ef39e3e 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -22,6 +22,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import HunyuanVideoLoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel +from ...models.hooks import reset_stateful_hooks from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor @@ -573,6 +574,7 @@ def __call__( self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False device = self._execution_device @@ -640,6 +642,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t latent_model_input = latents.to(transformer_dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -671,6 +674,8 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] @@ -680,6 +685,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index aac4e32e33f0..7899ab5f409c 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -21,8 +21,8 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import Mochi1LoraLoaderMixin -from ...models.autoencoders import AutoencoderKL -from ...models.transformers import MochiTransformer3DModel +from ...models import AutoencoderKLHunyuanVideo, MochiTransformer3DModel +from ...models.hooks import reset_stateful_hooks from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( is_torch_xla_available, @@ -184,7 +184,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin): def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, - vae: AutoencoderKL, + vae: AutoencoderKLHunyuanVideo, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, transformer: MochiTransformer3DModel, @@ -604,6 +604,7 @@ def __call__( self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Define call parameters @@ -673,6 +674,9 @@ def __call__( if self.interrupt: continue + # Note: Mochi uses reversed timesteps. 
To ensure compatibility with methods like FasterCache, we need + # to make sure we're using the correct non-reversed timestep values. + self._current_timestep = 1000 - t latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) @@ -718,6 +722,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if output_type == "latent": video = latents else: @@ -741,6 +747,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) From c02f72d61f953f7adbc99bd3fc7e98ad4510a01e Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 29 Dec 2024 16:44:59 +0100 Subject: [PATCH 07/26] fix --- src/diffusers/pipelines/faster_cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/faster_cache_utils.py b/src/diffusers/pipelines/faster_cache_utils.py index b77f6daea91d..aad477774e29 100644 --- a/src/diffusers/pipelines/faster_cache_utils.py +++ b/src/diffusers/pipelines/faster_cache_utils.py @@ -440,7 +440,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: # TODO(aryan): remove later logger.debug("Skipping attention") - if torch.is_tensor(state.cache): + if torch.is_tensor(state.cache[-1]): t_2_output, t_output = state.cache weight = state.weight_callback(module) output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size) From 30d9aafdabeeef51cf4d0b5473a317b3c152440f Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 00:18:09 +0100 Subject: [PATCH 08/26] make it work with guidance distilled models --- src/diffusers/models/hooks.py | 10 +- src/diffusers/pipelines/faster_cache_utils.py | 184 +++++++++++++++--- src/diffusers/pipelines/flux/pipeline_flux.py | 10 +- 3 files changed, 171 insertions(+), 33 deletions(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index b048869d9abd..5433a0bd525a 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -78,9 +78,10 @@ def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: """ return module - def reset_state(self): + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: if self._is_stateful: raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") + return module class SequentialHook(ModelHook): @@ -109,13 +110,13 @@ def detach_hook(self, module): module = hook.detach_hook(module) return module - def reset_state(self): + def reset_state(self, module): for hook in self.hooks: if hook._is_stateful: - hook.reset_state() + hook.reset_state(module) -def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False): +def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module: r""" Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove this behavior and restore the original `forward` method, use `remove_hook_from_module`. @@ -134,6 +135,7 @@ def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = The hook to attach. append (`bool`, *optional*, defaults to `False`): Whether the hook should be chained with an existing one (if module already contains a hook) or not. 
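    For context on how these hook primitives are meant to be used downstream, here is a minimal sketch of attaching
    a stateful hook and clearing it between runs. The `LoggingHook` is hypothetical and the import path assumes this
    patch's `src/diffusers/models/hooks.py` layout:

    ```python
    import torch
    from diffusers.models.hooks import ModelHook, add_hook_to_module, reset_stateful_hooks


    class LoggingHook(ModelHook):
        _is_stateful = True

        def __init__(self):
            self.num_pre_forward_calls = 0

        def pre_forward(self, module, *args, **kwargs):
            self.num_pre_forward_calls += 1
            return args, kwargs

        def reset_state(self, module):
            self.num_pre_forward_calls = 0
            return module


    layer = torch.nn.Linear(4, 4)
    # append=True chains with any existing hook through a SequentialHook instead of replacing it.
    add_hook_to_module(layer, LoggingHook(), append=True)
    # Pipelines in this series clear state between sampling runs in the same way, e.g.
    # `reset_stateful_hooks(self.transformer, recurse=True)`.
    reset_stateful_hooks(layer, recurse=True)
    ```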
+ Returns: `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can be discarded). diff --git a/src/diffusers/pipelines/faster_cache_utils.py b/src/diffusers/pipelines/faster_cache_utils.py index aad477774e29..6c8afd257760 100644 --- a/src/diffusers/pipelines/faster_cache_utils.py +++ b/src/diffusers/pipelines/faster_cache_utils.py @@ -32,10 +32,11 @@ _ATTENTION_CLASSES = (Attention,) _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ( - "blocks", - "transformer_blocks", + "blocks.*attn1", + "transformer_blocks.*attn1", + "single_transformer_blocks.*attn1", ) -_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks.*attn1",) _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = ( "hidden_states", "encoder_hidden_states", @@ -43,6 +44,7 @@ "attention_mask", "encoder_attention_mask", ) +_GUIDANCE_DISTILLATION_KWARGS_IDENTIFIERS = ("guidance",) @dataclass @@ -50,7 +52,85 @@ class FasterCacheConfig: r""" Configuration for [FasterCache](https://huggingface.co/papers/2410.19355). - Attributes:""" + Attributes: + spatial_attention_block_skip_range (`int`, defaults to `2`): + Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will + be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention + states again. + temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`): + Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will + be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention + states again. + spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`): + The timestep range within which the spatial attention computation can be skipped without a significant loss + in quality. This is to be determined by the user based on the underlying model. The first value in the tuple + is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for denoising are + in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at timestep 0). For the + default values, this would mean that the spatial attention computation skipping will be applicable only + after denoising timestep 681 is reached, and continue until the end of the denoising process. + temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`): + The timestep range within which the temporal attention computation can be skipped without a significant loss + in quality. This is to be determined by the user based on the underlying model. The first value in the tuple + is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for denoising are + in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at timestep 0). + low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`): + The timestep range within which the low frequency weight scaling update is applied. The first value in the tuple + is the lower bound and the second value is the upper bound of the timestep range. The callback function for + the update is called only within this range. 
+ high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`): + The timestep range within which the high frequency weight scaling update is applied. The first value in the tuple + is the lower bound and the second value is the upper bound of the timestep range. The callback function for + the update is called only within this range. + alpha_low_frequency (`float`, defaults to `1.1`): + The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from + the conditional branch outputs. + alpha_high_frequency (`float`, defaults to `1.1`): + The weight to scale the high frequency updates by. This is used to approximate the unconditional branch from + the conditional branch outputs. + unconditional_batch_skip_range (`int`, defaults to `5`): + Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch + computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before + computing the new unconditional branch states again. + unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`): + The timestep range within which the unconditional branch computation can be skipped without a significant loss + in quality. This is to be determined by the user based on the underlying model. The first value in the tuple + is the lower bound and the second value is the upper bound. + spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`): + The identifiers to match the spatial attention blocks in the model. If the name of the block contains any of + these identifiers, FasterCache will be applied to that block. This can either be the full layer names, partial + layer names, or regex patterns. Matching will always be done using a regex match. + temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`): + The identifiers to match the temporal attention blocks in the model. If the name of the block contains any of + these identifiers, FasterCache will be applied to that block. This can either be the full layer names, partial + layer names, or regex patterns. Matching will always be done using a regex match. + attention_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`): + The callback function to determine the weight to scale the attention outputs by. This function should take the + attention module as input and return a float value. This is used to approximate the unconditional branch from + the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps. Typically, as + described in the paper, this weight should gradually increase from 0 to 1 as the inference progresses. Users + are encouraged to experiment and provide custom weight schedules that take into account the number of inference + steps and underlying model behaviour as denoising progresses. + low_frequency_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`): + The callback function to determine the weight to scale the low frequency updates by. If not provided, the default + weight is 1.1 for timesteps within the range specified (as described in the paper). + high_frequency_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`): + The callback function to determine the weight to scale the high frequency updates by. 
If not provided, the default + weight is 1.1 for timesteps within the range specified (as described in the paper). + tensor_format (`str`, defaults to `"BCFHW"`): + The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is used to + split individual latent frames in order for low and high frequency components to be computed. + _unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`): + The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and conditional + inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will split the inputs + into unconditional and conditional branches. This must be a list of exact input kwargs names that contain the + batchwise-concatenated unconditional and conditional inputs. + _guidance_distillation_kwargs_identifiers (`List[str]`, defaults to `("guidance",)`): + The identifiers to match the input kwargs that contain the guidance distillation inputs. If the name of the input + kwargs contains any of these identifiers, FasterCache will not split the inputs into unconditional and conditional + branches (unconditional branches are only computed sometimes based on certain checks). This allows usage of + FasterCache in models like Flux-Dev and HunyuanVideo which are guidance-distilled (only attention skipping + related parts are applied, and not unconditional branch approximation). + """ # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable # after some testing. We default to 2 if these parameters are not provided. @@ -81,7 +161,9 @@ class FasterCacheConfig: high_frequency_weight_callback: Callable[[nn.Module], float] = None tensor_format: str = "BCFHW" - unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS + + _unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS + _guidance_distillation_kwargs_identifiers: List[str] = _GUIDANCE_DISTILLATION_KWARGS_IDENTIFIERS class FasterCacheDenoiserState: @@ -102,14 +184,16 @@ def __init__( self.iteration = 0 self.low_frequency_delta = None self.high_frequency_delta = None + self.is_guidance_distilled = None def reset(self): self.iteration = 0 self.low_frequency_delta = None self.high_frequency_delta = None + self.is_guidance_distilled = None -class FasterCacheState: +class FasterCacheBlockState: r""" State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is applied to will have an instance of this state. @@ -129,11 +213,13 @@ def __init__( self.iteration = 0 self.batch_size = None self.cache = None + self.is_guidance_distilled = None def reset(self): self.iteration = 0 self.batch_size = None self.cache = None + self.is_guidance_distilled = None def apply_faster_cache( @@ -158,7 +244,22 @@ def apply_faster_cache( Example: ```python - # TODO(aryan) + >>> import torch + >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache + + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> config = FasterCacheConfig( + ... spatial_attention_block_skip_range=2, + ... spatial_attention_timestep_skip_range=(-1, 681), + ... low_frequency_weight_update_timestep_range=(99, 641), + ... 
high_frequency_weight_update_timestep_range=(-1, 301), + ... spatial_attention_block_identifiers=["transformer_blocks"], + ... attention_weight_callback=lambda _: 0.3, + ... tensor_format="BFCHW", + ... ) + >>> apply_faster_cache(pipe, config) ``` """ @@ -229,7 +330,7 @@ def uncond_skip_callback(module: nn.Module) -> bool: # 2. The current timestep is within the range specified by the user. This is the optimal timestep range # where approximating the unconditional branch from the computation of the conditional branch is possible # without a significant loss in quality. - # 3. The current iteration is not a multiple of the unconditional batch skip range. This is to ensure that + # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that # we compute the unconditional branch at least once every few iterations to ensure minimal quality loss. state: FasterCacheDenoiserState = module._fastercache_state @@ -243,7 +344,7 @@ def uncond_skip_callback(module: nn.Module) -> bool: denoiser._fastercache_state = FasterCacheDenoiserState( config.low_frequency_weight_callback, config.high_frequency_weight_callback, uncond_skip_callback ) - hook = FasterCacheModelHook(config.unconditional_conditional_input_kwargs_identifiers, config.tensor_format) + hook = FasterCacheDenoiserHook(config._unconditional_conditional_input_kwargs_identifiers, config._guidance_distillation_kwargs_identifiers, config.tensor_format) add_hook_to_module(denoiser, hook, append=True) @@ -285,7 +386,7 @@ def _apply_fastercache_on_attention_class( return def skip_callback(module: nn.Module) -> bool: - fastercache_state: FasterCacheState = module._fastercache_state + fastercache_state: FasterCacheBlockState = module._fastercache_state is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1] if not is_within_timestep_range: @@ -299,15 +400,19 @@ def skip_callback(module: nn.Module) -> bool: return not should_compute_attention logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}") - module._fastercache_state = FasterCacheState(skip_callback, config.attention_weight_callback) + module._fastercache_state = FasterCacheBlockState(skip_callback, config.attention_weight_callback) hook = FasterCacheBlockHook() add_hook_to_module(module, hook, append=True) -class FasterCacheModelHook(ModelHook): +class FasterCacheDenoiserHook(ModelHook): _is_stateful = True - def __init__(self, uncond_cond_input_kwargs_identifiers: List[str], tensor_format: str) -> None: + def __init__(self, + uncond_cond_input_kwargs_identifiers: List[str], + guidance_distillation_kwargs_identifiers: List[str], + tensor_format: str + ) -> None: super().__init__() # We can't easily detect what args are to be split in unconditional and conditional branches. We @@ -315,9 +420,14 @@ def __init__(self, uncond_cond_input_kwargs_identifiers: List[str], tensor_forma # If a model is to be made compatible with FasterCache, the user must ensure that the inputs that # contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs. 
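        # The splitting convention assumed here is standard classifier-free-guidance batching, where the
        # unconditional half precedes the conditional half along dim 0. A toy sketch of what `_get_cond_input`
        # below keeps (illustrative shapes only):
        #
        #     uncond = torch.zeros(1, 4)
        #     cond = torch.ones(1, 4)
        #     hidden_states = torch.cat([uncond, cond], dim=0)   # how CFG pipelines batch the two branches
        #     _, cond_only = hidden_states.chunk(2, dim=0)       # the conditional half that gets forwarded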
self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers + + # See documentation for `guidance_distillation_kwargs_identifiers` in FasterCacheConfig for more information + self.guidance_distillation_kwargs_identifiers = guidance_distillation_kwargs_identifiers + self.tensor_format = tensor_format - def _get_cond_input(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + @staticmethod + def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs # followed by conditional inputs. _, cond = input.chunk(2, dim=0) @@ -330,7 +440,24 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: # Split the unconditional and conditional inputs. We only want to infer the conditional branch if the # requirements for skipping the unconditional branch are met as described in the paper. should_skip_uncond = state.uncond_skip_callback(module) - if should_skip_uncond: + + if state.is_guidance_distilled is None: + # The following check assumes that the guidance embedding are torch tensors for check to pass. This + # seems to be true for all models supported in diffusers + state.is_guidance_distilled = any( + identifier in kwargs and kwargs[identifier] is not None and torch.is_tensor(kwargs[identifier]) + for identifier in self.guidance_distillation_kwargs_identifiers + ) + # Make all children FasterCacheBlockHooks aware of whether the model is guidance distilled or not + # because we cannot determine this within the block hooks + for name, child_module in module.named_modules(): + if hasattr(child_module, "_fastercache_state") and isinstance(child_module._fastercache_state, FasterCacheBlockState): + # TODO(aryan): remove later + logger.debug(f"Setting guidance distillation flag for layer: {name}") + child_module._fastercache_state.is_guidance_distilled = state.is_guidance_distilled + assert state.is_guidance_distilled is not None + + if should_skip_uncond and not state.is_guidance_distilled: kwargs = { k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v) for k, v in kwargs.items() @@ -339,6 +466,11 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: logger.debug("Skipping unconditional branch computation") output = module._old_forward(*args, **kwargs) + + if state.is_guidance_distilled: + state.iteration += 1 + return output + # TODO(aryan): handle Transformer2DModelOutput hidden_states = output[0] if isinstance(output, tuple) else output batch_size = hidden_states.size(0) @@ -393,8 +525,9 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: output = (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states return output - def reset_state(self, module: nn.Module) -> None: + def reset_state(self, module: nn.Module) -> nn.Module: module._fastercache_state.reset() + return module class FasterCacheBlockHook(ModelHook): @@ -418,7 +551,10 @@ def _compute_approximated_attention_output( def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - state: FasterCacheState = module._fastercache_state + state: FasterCacheBlockState = module._fastercache_state + + # The denoiser should have set this flag for all children FasterCacheBlockHooks to either True or False + assert state.is_guidance_distilled is not None batch_size = [ *[arg.size(0) for arg in args if torch.is_tensor(arg)], @@ 
-434,7 +570,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true # unconditional-conditional batch size) is same as the current batch size, we don't perform the layer # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns. - should_skip_attention = state.skip_callback(module) and state.batch_size != batch_size + should_skip_attention = state.skip_callback(module) and (state.is_guidance_distilled or state.batch_size != batch_size) if should_skip_attention: # TODO(aryan): remove later @@ -466,21 +602,16 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: # both cases. if torch.is_tensor(output): cache_output = output - if cache_output.size(0) == state.batch_size: + if not state.is_guidance_distilled and cache_output.size(0) == state.batch_size: # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs. # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs. cache_output = cache_output.chunk(2, dim=0)[1] - - # Just to be safe that the output is of the correct size for both unconditional-conditional branch inference - # and only-conditional branch inference. - assert 2 * cache_output.size(0) == state.batch_size else: # Cache all return values and perform the same operation as above cache_output = () for out in output: - if out.size(0) == state.batch_size: + if not state.is_guidance_distilled and out.size(0) == state.batch_size: out = out.chunk(2, dim=0)[1] - assert 2 * out.size(0) == state.batch_size cache_output += (out,) if state.cache is None: @@ -491,8 +622,9 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: state.iteration += 1 return module._diffusers_hook.post_forward(module, output) - def reset_state(self, module: nn.Module) -> None: + def reset_state(self, module: nn.Module) -> nn.Module: module._fastercache_state.reset() + return module # Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/fastercache_sample_latte.py#L127C1-L143C39 diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 181f0269ce3e..819b0cd5c7c8 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -28,8 +28,8 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin -from ...models.autoencoders import AutoencoderKL -from ...models.transformers import FluxTransformer2DModel +from ...models import AutoencoderKL, FluxTransformer2DModel +from ...models.hooks import reset_stateful_hooks from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, @@ -760,6 +760,7 @@ def __call__( self._guidance_scale = guidance_scale self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None self._interrupt = False # 2. 
Define call parameters @@ -881,6 +882,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t if image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds # broadcast to batch dimension in a way that's compatible with ONNX/Core ML @@ -939,9 +941,10 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + if output_type == "latent": image = latents - else: latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor @@ -950,6 +953,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (image,) From 07edfa986f5e7969760252977eac2728ea83aff8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 10:37:24 +0100 Subject: [PATCH 09/26] update --- src/diffusers/__init__.py | 4 + src/diffusers/models/embeddings.py | 4 +- src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/faster_cache_utils.py | 129 +++++++++--------- 4 files changed, 77 insertions(+), 62 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5e9ab2a117d1..bb8bdb5930ce 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -276,6 +276,7 @@ "CogVideoXVideoToVideoPipeline", "CogView3PlusPipeline", "CycleDiffusionPipeline", + "FasterCacheConfig", "FluxControlImg2ImgPipeline", "FluxControlInpaintPipeline", "FluxControlNetImg2ImgPipeline", @@ -422,6 +423,7 @@ "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", + "apply_faster_cache", ] ) @@ -765,6 +767,7 @@ CogVideoXVideoToVideoPipeline, CogView3PlusPipeline, CycleDiffusionPipeline, + FasterCacheConfig, FluxControlImg2ImgPipeline, FluxControlInpaintPipeline, FluxControlNetImg2ImgPipeline, @@ -909,6 +912,7 @@ WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, + apply_faster_cache, ) try: diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index c2eb5b31f8cf..2c0582e9528a 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -334,7 +334,9 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"): " `from_numpy` is no longer required." " Pass `output_type='pt' to use the new version now." ) - # deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + # TODO: Needs to be handled or errors out. Updated to 0.34.0 so that the benchmark code + # runs without issues, but this should be handled properly before merge. 
+ deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False) return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos) if embed_dim % 2 != 0: raise ValueError("embed_dim must be divisible by 2") diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ce291e5ceb45..278fb5d0b8aa 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -58,6 +58,7 @@ "StableDiffusionMixin", "ImagePipelineOutput", ] + _import_structure["faster_cache_utils"] = ["FasterCacheConfig", "apply_faster_cache"] _import_structure["deprecated"].extend( [ "PNDMPipeline", @@ -449,6 +450,7 @@ from .ddpm import DDPMPipeline from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline + from .faster_cache_utils import FasterCacheConfig, apply_faster_cache from .latent_diffusion import LDMSuperResolutionPipeline from .pipeline_utils import ( AudioPipelineOutput, diff --git a/src/diffusers/pipelines/faster_cache_utils.py b/src/diffusers/pipelines/faster_cache_utils.py index 6c8afd257760..10a77cb98d4e 100644 --- a/src/diffusers/pipelines/faster_cache_utils.py +++ b/src/diffusers/pipelines/faster_cache_utils.py @@ -29,6 +29,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# TODO(aryan): handle mochi attention _ATTENTION_CLASSES = (Attention,) _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ( @@ -63,73 +64,75 @@ class FasterCacheConfig: states again. spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`): The timestep range within which the spatial attention computation can be skipped without a significant loss - in quality. This is to be determined by the user based on the underlying model. The first value in the tuple - is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for denoising are - in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at timestep 0). For the - default values, this would mean that the spatial attention computation skipping will be applicable only - after denoising timestep 681 is reached, and continue until the end of the denoising process. + in quality. This is to be determined by the user based on the underlying model. The first value in the + tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for + denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at + timestep 0). For the default values, this would mean that the spatial attention computation skipping will + be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising + process. temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`): - The timestep range within which the temporal attention computation can be skipped without a significant loss - in quality. This is to be determined by the user based on the underlying model. The first value in the tuple - is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for denoising are - in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at timestep 0). + The timestep range within which the temporal attention computation can be skipped without a significant + loss in quality. This is to be determined by the user based on the underlying model. 
The first value in the + tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for + denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at + timestep 0). low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`): - The timestep range within which the low frequency weight scaling update is applied. The first value in the tuple - is the lower bound and the second value is the upper bound of the timestep range. The callback function for - the update is called only within this range. + The timestep range within which the low frequency weight scaling update is applied. The first value in the + tuple is the lower bound and the second value is the upper bound of the timestep range. The callback + function for the update is called only within this range. high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`): - The timestep range within which the high frequency weight scaling update is applied. The first value in the tuple - is the lower bound and the second value is the upper bound of the timestep range. The callback function for - the update is called only within this range. + The timestep range within which the high frequency weight scaling update is applied. The first value in the + tuple is the lower bound and the second value is the upper bound of the timestep range. The callback + function for the update is called only within this range. alpha_low_frequency (`float`, defaults to `1.1`): The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from the conditional branch outputs. alpha_high_frequency (`float`, defaults to `1.1`): - The weight to scale the high frequency updates by. This is used to approximate the unconditional branch from - the conditional branch outputs. + The weight to scale the high frequency updates by. This is used to approximate the unconditional branch + from the conditional branch outputs. unconditional_batch_skip_range (`int`, defaults to `5`): Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before computing the new unconditional branch states again. unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`): - The timestep range within which the unconditional branch computation can be skipped without a significant loss - in quality. This is to be determined by the user based on the underlying model. The first value in the tuple - is the lower bound and the second value is the upper bound. + The timestep range within which the unconditional branch computation can be skipped without a significant + loss in quality. This is to be determined by the user based on the underlying model. The first value in the + tuple is the lower bound and the second value is the upper bound. spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`): - The identifiers to match the spatial attention blocks in the model. If the name of the block contains any of - these identifiers, FasterCache will be applied to that block. This can either be the full layer names, partial - layer names, or regex patterns. Matching will always be done using a regex match. 
+ The identifiers to match the spatial attention blocks in the model. If the name of the block contains any + of these identifiers, FasterCache will be applied to that block. This can either be the full layer names, + partial layer names, or regex patterns. Matching will always be done using a regex match. temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`): - The identifiers to match the temporal attention blocks in the model. If the name of the block contains any of - these identifiers, FasterCache will be applied to that block. This can either be the full layer names, partial - layer names, or regex patterns. Matching will always be done using a regex match. + The identifiers to match the temporal attention blocks in the model. If the name of the block contains any + of these identifiers, FasterCache will be applied to that block. This can either be the full layer names, + partial layer names, or regex patterns. Matching will always be done using a regex match. attention_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`): - The callback function to determine the weight to scale the attention outputs by. This function should take the - attention module as input and return a float value. This is used to approximate the unconditional branch from - the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps. Typically, as - described in the paper, this weight should gradually increase from 0 to 1 as the inference progresses. Users - are encouraged to experiment and provide custom weight schedules that take into account the number of inference - steps and underlying model behaviour as denoising progresses. + The callback function to determine the weight to scale the attention outputs by. This function should take + the attention module as input and return a float value. This is used to approximate the unconditional + branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps. + Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference + progresses. Users are encouraged to experiment and provide custom weight schedules that take into account + the number of inference steps and underlying model behaviour as denoising progresses. low_frequency_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`): - The callback function to determine the weight to scale the low frequency updates by. If not provided, the default - weight is 1.1 for timesteps within the range specified (as described in the paper). + The callback function to determine the weight to scale the low frequency updates by. If not provided, the + default weight is 1.1 for timesteps within the range specified (as described in the paper). high_frequency_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`): - The callback function to determine the weight to scale the high frequency updates by. If not provided, the default - weight is 1.1 for timesteps within the range specified (as described in the paper). + The callback function to determine the weight to scale the high frequency updates by. If not provided, the + default weight is 1.1 for timesteps within the range specified (as described in the paper). tensor_format (`str`, defaults to `"BCFHW"`): - The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. 
The format is used to - split individual latent frames in order for low and high frequency components to be computed. + The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is + used to split individual latent frames in order for low and high frequency components to be computed. _unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`): - The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and conditional - inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will split the inputs - into unconditional and conditional branches. This must be a list of exact input kwargs names that contain the - batchwise-concatenated unconditional and conditional inputs. + The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and + conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will + split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs + names that contain the batchwise-concatenated unconditional and conditional inputs. _guidance_distillation_kwargs_identifiers (`List[str]`, defaults to `("guidance",)`): - The identifiers to match the input kwargs that contain the guidance distillation inputs. If the name of the input - kwargs contains any of these identifiers, FasterCache will not split the inputs into unconditional and conditional - branches (unconditional branches are only computed sometimes based on certain checks). This allows usage of - FasterCache in models like Flux-Dev and HunyuanVideo which are guidance-distilled (only attention skipping - related parts are applied, and not unconditional branch approximation). + The identifiers to match the input kwargs that contain the guidance distillation inputs. If the name of the + input kwargs contains any of these identifiers, FasterCache will not split the inputs into unconditional + and conditional branches (unconditional branches are only computed sometimes based on certain checks). This + allows usage of FasterCache in models like Flux-Dev and HunyuanVideo which are guidance-distilled (only + attention skipping related parts are applied, and not unconditional branch approximation). """ # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable @@ -225,7 +228,6 @@ def reset(self): def apply_faster_cache( pipeline: DiffusionPipeline, config: Optional[FasterCacheConfig] = None, - denoiser: Optional[nn.Module] = None, ) -> None: r""" Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline. @@ -238,9 +240,6 @@ def apply_faster_cache( The diffusion pipeline to apply FasterCache to. config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`): The configuration to use for FasterCache. - denoiser (`Optional[nn.Module]`, `optional`, defaults to `None`): - The denoiser module to apply FasterCache to. If `None`, the pipeline's transformer or unet module will be - used. 
Example: ```python @@ -310,8 +309,7 @@ def high_frequency_weight_callback(module: nn.Module) -> float: if config.tensor_format not in supported_tensor_formats: raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.") - if denoiser is None: - denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet + denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet _apply_fastercache_on_denoiser(pipeline, denoiser, config) for name, module in denoiser.named_modules(): @@ -344,7 +342,11 @@ def uncond_skip_callback(module: nn.Module) -> bool: denoiser._fastercache_state = FasterCacheDenoiserState( config.low_frequency_weight_callback, config.high_frequency_weight_callback, uncond_skip_callback ) - hook = FasterCacheDenoiserHook(config._unconditional_conditional_input_kwargs_identifiers, config._guidance_distillation_kwargs_identifiers, config.tensor_format) + hook = FasterCacheDenoiserHook( + config._unconditional_conditional_input_kwargs_identifiers, + config._guidance_distillation_kwargs_identifiers, + config.tensor_format, + ) add_hook_to_module(denoiser, hook, append=True) @@ -408,11 +410,12 @@ def skip_callback(module: nn.Module) -> bool: class FasterCacheDenoiserHook(ModelHook): _is_stateful = True - def __init__(self, - uncond_cond_input_kwargs_identifiers: List[str], - guidance_distillation_kwargs_identifiers: List[str], - tensor_format: str - ) -> None: + def __init__( + self, + uncond_cond_input_kwargs_identifiers: List[str], + guidance_distillation_kwargs_identifiers: List[str], + tensor_format: str, + ) -> None: super().__init__() # We can't easily detect what args are to be split in unconditional and conditional branches. We @@ -451,7 +454,9 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: # Make all children FasterCacheBlockHooks aware of whether the model is guidance distilled or not # because we cannot determine this within the block hooks for name, child_module in module.named_modules(): - if hasattr(child_module, "_fastercache_state") and isinstance(child_module._fastercache_state, FasterCacheBlockState): + if hasattr(child_module, "_fastercache_state") and isinstance( + child_module._fastercache_state, FasterCacheBlockState + ): # TODO(aryan): remove later logger.debug(f"Setting guidance distillation flag for layer: {name}") child_module._fastercache_state.is_guidance_distilled = state.is_guidance_distilled @@ -570,7 +575,9 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true # unconditional-conditional batch size) is same as the current batch size, we don't perform the layer # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns. 
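    Because the guidance-distillation flag lets this skip apply without a split unconditional-conditional batch, a
    guidance-distilled model such as Flux-Dev only exercises the attention-skipping part of FasterCache. A possible
    configuration, sketched with illustrative and untuned values (model id, timestep range, and prompt are
    assumptions, not recommendations from this patch):

    ```python
    import torch
    from diffusers import FluxPipeline, FasterCacheConfig, apply_faster_cache

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

    config = FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 681),  # untuned, illustrative range
        spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
        attention_weight_callback=lambda _: 0.5,
        tensor_format="BCHW",
    )
    apply_faster_cache(pipe, config)

    image = pipe("A cat holding a sign that says hello world", guidance_scale=3.5).images[0]
    ```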
- should_skip_attention = state.skip_callback(module) and (state.is_guidance_distilled or state.batch_size != batch_size) + should_skip_attention = state.skip_callback(module) and ( + state.is_guidance_distilled or state.batch_size != batch_size + ) if should_skip_attention: # TODO(aryan): remove later From 436b7727a2b18bc2c2648b63f95980c5e644efe6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 10:37:37 +0100 Subject: [PATCH 10/26] make fix-copies --- .../dummy_torch_and_transformers_objects.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9b36be9e0604..7298c6a66726 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -377,6 +377,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class FasterCacheConfig(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxControlImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -2535,3 +2550,7 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + + +def apply_faster_cache(*args, **kwargs): + requires_backends(apply_faster_cache, ["torch", "transformers"]) From d68977d7c6bb6eceaabb10bbf07a3f7d12cfba4a Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 11:44:17 +0100 Subject: [PATCH 11/26] add tests --- src/diffusers/pipelines/faster_cache_utils.py | 11 +- tests/pipelines/cogvideo/test_cogvideox.py | 7 +- tests/pipelines/flux/test_pipeline_flux.py | 9 +- tests/pipelines/latte/test_latte.py | 18 +- tests/pipelines/mochi/test_mochi.py | 8 +- tests/pipelines/test_pipelines_common.py | 164 ++++++++++++++++++ 6 files changed, 197 insertions(+), 20 deletions(-) diff --git a/src/diffusers/pipelines/faster_cache_utils.py b/src/diffusers/pipelines/faster_cache_utils.py index 10a77cb98d4e..cc4bbe23d880 100644 --- a/src/diffusers/pipelines/faster_cache_utils.py +++ b/src/diffusers/pipelines/faster_cache_utils.py @@ -33,11 +33,11 @@ _ATTENTION_CLASSES = (Attention,) _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ( - "blocks.*attn1", - "transformer_blocks.*attn1", - "single_transformer_blocks.*attn1", + "blocks.*attn", + "transformer_blocks.*attn", + "single_transformer_blocks.*attn", ) -_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks.*attn1",) +_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks.*attn",) _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = ( "hidden_states", "encoder_hidden_states", @@ -263,6 +263,7 @@ def apply_faster_cache( """ if config is None: + logger.warning("No FasterCacheConfig provided. Using default configuration.") config = FasterCacheConfig() if config.attention_weight_callback is None: @@ -271,7 +272,7 @@ def apply_faster_cache( # this depends from model-to-model. It is required by the user to provide a weight callback if they want to # use a different weight function. Defaulting to 0.5 works well in practice for most cases. 
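Where the paper's gradually increasing weight is wanted instead of the constant 0.5 default described above, a minimal sketch of such a callback follows; it assumes a pipeline instance `pipe` that exposes `_current_timestep` counting down from a 1000-step schedule (as the other callbacks in this file do), and the name `ramping_attention_weight` is illustrative.

```python
# Minimal sketch of a gradually increasing attention weight, along the lines of the comment above.
# Assumes `pipe` is the pipeline instance and that it exposes `_current_timestep`, counting
# down from roughly 1000 to 0 over the denoising loop (a 1000-step training schedule).
def ramping_attention_weight(module):
    t = pipe._current_timestep      # ~1000 at the first step, ~0 at the last
    return 1.0 - t / 1000.0         # weight grows from 0 towards 1 as inference progresses

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    attention_weight_callback=ramping_attention_weight,
)
```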
logger.warning( - "FasterCache requires an `attention_weight_callback` to be set. Defaulting to using a weight of 0.5 for all timesteps." + "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps." ) config.attention_weight_callback = lambda _: 0.5 diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 884ddfb2a95a..2b86ca19911f 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -31,6 +31,7 @@ from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( + FasterCacheTesterMixin, PipelineTesterMixin, check_qkv_fusion_matches_attn_procs_length, check_qkv_fusion_processors_exist, @@ -41,7 +42,7 @@ enable_full_determinism() -class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class CogVideoXPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase): pipeline_class = CogVideoXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -59,7 +60,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = CogVideoXTransformer3DModel( # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings @@ -71,7 +72,7 @@ def get_dummy_components(self): out_channels=4, time_embed_dim=2, text_embed_dim=32, # Must match with tiny-random-t5 - num_layers=1, + num_layers=num_layers, sample_width=2, # latent width: 2 -> final width: 16 sample_height=2, # latent height: 2 -> final height: 16 sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 7981e6c2a93b..8ac10adeaf4c 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -16,6 +16,7 @@ ) from ..test_pipelines_common import ( + FasterCacheTesterMixin, FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fusion_matches_attn_procs_length, @@ -23,7 +24,7 @@ ) -class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin): +class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, FasterCacheTesterMixin): pipeline_class = FluxPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) @@ -31,13 +32,13 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapte # there is no xformers processor for Flux test_xformers_attention = False - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) transformer = FluxTransformer2DModel( patch_size=1, in_channels=4, - num_layers=1, - num_single_layers=1, + num_layers=num_layers, + num_single_layers=num_single_layers, attention_head_dim=16, num_attention_heads=2, joint_attention_dim=32, diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index 9667ebff249d..38f879aa096e 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -25,6 +25,7 @@ from 
diffusers import ( AutoencoderKL, DDIMScheduler, + FasterCacheConfig, LattePipeline, LatteTransformer3DModel, ) @@ -38,13 +39,13 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np enable_full_determinism() -class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class LattePipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase): pipeline_class = LattePipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -53,11 +54,20 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase): required_optional_params = PipelineTesterMixin.required_optional_params - def get_dummy_components(self): + fastercache_config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + temporal_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 901), + temporal_attention_timestep_skip_range=(-1, 901), + unconditional_batch_skip_range=2, + attention_weight_callback=lambda _: 0.5, + ) + + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = LatteTransformer3DModel( sample_size=8, - num_layers=1, + num_layers=num_layers, patch_size=2, attention_head_dim=8, num_attention_heads=3, diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index bbcf6d210ce5..dbbc7ef4f748 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -30,13 +30,13 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np enable_full_determinism() -class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase): pipeline_class = MochiPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -54,13 +54,13 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 2): torch.manual_seed(0) transformer = MochiTransformer3DModel( patch_size=2, num_attention_heads=2, attention_head_dim=8, - num_layers=2, + num_layers=num_layers, pooled_projection_dim=16, in_channels=12, out_channels=None, diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 764be1890cc5..368572e4531a 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -23,10 +23,12 @@ ConsistencyDecoderVAE, DDIMScheduler, DiffusionPipeline, + FasterCacheConfig, KolorsPipeline, StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel, + apply_faster_cache, ) from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin @@ -35,6 +37,7 @@ from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet from diffusers.models.unets.unet_motion_model import UNetMotionModel +from diffusers.pipelines.faster_cache_utils import 
FasterCacheBlockHook, FasterCacheDenoiserHook from diffusers.pipelines.pipeline_utils import StableDiffusionMixin from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import logging @@ -2271,6 +2274,167 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): self.assertLess(max_diff, expected_max_difference) +class FasterCacheTesterMixin: + fastercache_config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 901), + unconditional_batch_skip_range=2, + attention_weight_callback=lambda _: 0.5, + ) + + def test_fastercache_basic_warning_or_errors_raised(self): + components = self.get_dummy_components() + + logger = logging.get_logger("diffusers.pipelines.faster_cache_utils") + logger.setLevel(logging.INFO) + + # Check if warning is raised when no FasterCacheConfig is provided + pipe = self.pipeline_class(**components) + with CaptureLogger(logger) as cap_logger: + apply_faster_cache(pipe) + self.assertTrue("No FasterCacheConfig provided" in cap_logger.out) + + # Check if warning is raise when no attention_weight_callback is provided + pipe = self.pipeline_class(**components) + with CaptureLogger(logger) as cap_logger: + config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None) + apply_faster_cache(pipe, config) + self.assertTrue("No `attention_weight_callback` provided when enabling FasterCache" in cap_logger.out) + + # Check if error raised when unsupported tensor format used + pipe = self.pipeline_class(**components) + with self.assertRaises(ValueError): + config = FasterCacheConfig(spatial_attention_block_skip_range=2, tensor_format="BFHWC") + apply_faster_cache(pipe, config) + + def test_fastercache_inference(self, expected_atol: float = 0.1): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + output = pipe(**inputs)[0] + original_image_slice = output.flatten() + original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:])) + + apply_faster_cache(pipe, self.fastercache_config) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + output = pipe(**inputs)[0] + image_slice_fastercache_enabled = output.flatten() + image_slice_fastercache_enabled = np.concatenate( + (image_slice_fastercache_enabled[:8], image_slice_fastercache_enabled[-8:]) + ) + + assert np.allclose( + original_image_slice, image_slice_fastercache_enabled, atol=expected_atol + ), "FasterCache outputs should not differ much in specified timestep range." 
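For context, a minimal sketch of how a pipeline test suite opts into these shared checks; `MyVideoPipeline` and its dummy components are placeholders, the config mirrors the mixin defaults above, and the real CogVideoX/Flux/Latte/Mochi suites follow this same pattern.

```python
# Hypothetical pipeline test suite opting into the shared FasterCache checks above.
# `MyVideoPipeline` is a placeholder pipeline class.
class MyVideoPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
    pipeline_class = MyVideoPipeline

    # Override the mixin default when the model needs different skip ranges or callbacks.
    fastercache_config = FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 901),
        unconditional_batch_skip_range=2,
        attention_weight_callback=lambda _: 0.5,
    )

    def get_dummy_components(self, num_layers: int = 1):
        # Must accept `num_layers` so that test_fastercache_state can count the expected hooks.
        ...
```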
+ + def test_fastercache_state(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + num_layers = 0 + num_single_layers = 0 + dummy_component_kwargs = {} + dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters + if "num_layers" in dummy_component_parameters: + num_layers = 2 + dummy_component_kwargs["num_layers"] = num_layers + if "num_single_layers" in dummy_component_parameters: + num_single_layers = 2 + dummy_component_kwargs["num_single_layers"] = num_single_layers + + components = self.get_dummy_components(**dummy_component_kwargs) + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + apply_faster_cache(pipe, self.fastercache_config) + + expected_hooks = 0 + if self.fastercache_config.spatial_attention_block_skip_range is not None: + expected_hooks += num_layers + num_single_layers + if self.fastercache_config.temporal_attention_block_skip_range is not None: + expected_hooks += num_layers + num_single_layers + + # Check if fastercache denoiser hook is attached + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet + self.assertTrue( + hasattr(denoiser, "_diffusers_hook") and isinstance(denoiser._diffusers_hook, FasterCacheDenoiserHook), + "Hook should be of type FasterCacheDenoiserHook.", + ) + + # Check if all blocks have fastercache block hook attached + count = 0 + for name, module in denoiser.named_modules(): + if hasattr(module, "_diffusers_hook"): + if name == "": + # Skip the root denoiser module + continue + count += 1 + self.assertTrue( + isinstance(module._diffusers_hook, FasterCacheBlockHook), + "Hook should be of type FasterCacheBlockHook.", + ) + self.assertEqual(count, expected_hooks, "Number of hooks should match expected number.") + + # Perform inference to ensure that states are updated correctly + def fastercache_state_check_callback(pipe, i, t, kwargs): + for name, module in denoiser.named_modules(): + if not hasattr(module, "_diffusers_hook"): + continue + + state = module._fastercache_state + + if name == "": + # Root denoiser module + self.assertTrue(state.low_frequency_delta is not None, "Low frequency delta should be set.") + self.assertTrue(state.high_frequency_delta is not None, "High frequency delta should be set.") + else: + # Internal blocks + self.assertTrue(state.cache is not None and len(state.cache) == 2, "Cache should be set.") + + self.assertTrue(state.iteration == i + 1, "Hook iteration state should have updated during inference.") + self.assertTrue( + state.is_guidance_distilled is not None, + "`is_guidance_distilled` should be set to either True or False.", + ) + + return {} + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + inputs["callback_on_step_end"] = fastercache_state_check_callback + _ = pipe(**inputs)[0] + + # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states + for name, module in denoiser.named_modules(): + if not hasattr(module, "_diffusers_hook"): + continue + + state = module._fastercache_state + + if name == "": + # Root denoiser module + self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.") + self.assertTrue(state.low_frequency_delta is None, "Low frequency delta should be reset to None.") + self.assertTrue(state.high_frequency_delta is None, "High frequency delta should be reset to None.") + self.assertTrue( + state.is_guidance_distilled is None, "`is_guidance_distilled` should be reset to None." 
+ ) + else: + self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.") + self.assertTrue(state.batch_size is None, "Batch size should be reset to None.") + self.assertTrue(state.cache is None, "Cache should be reset to None.") + self.assertTrue( + state.is_guidance_distilled is None, "`is_guidance_distilled` should be reset to None." + ) + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image. From f3cb80caf870660419656052d8585df63ba9b8a7 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 13:17:28 +0100 Subject: [PATCH 12/26] update --- src/diffusers/pipelines/faster_cache_utils.py | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/faster_cache_utils.py b/src/diffusers/pipelines/faster_cache_utils.py index cc4bbe23d880..fbc8b0677d9c 100644 --- a/src/diffusers/pipelines/faster_cache_utils.py +++ b/src/diffusers/pipelines/faster_cache_utils.py @@ -141,20 +141,20 @@ class FasterCacheConfig: temporal_attention_block_skip_range: Optional[int] = None # TODO(aryan): write heuristics for what the best way to obtain these values are - spatial_attention_timestep_skip_range: Tuple[float, float] = (-1, 681) - temporal_attention_timestep_skip_range: Tuple[float, float] = (-1, 681) + spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) + temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) # Indicator functions for low/high frequency as mentioned in Equation 11 of the paper low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901) high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301) # ⍺1 and ⍺2 as mentioned in Equation 11 of the paper - alpha_low_frequency = 1.1 - alpha_high_frequency = 1.1 + alpha_low_frequency: float = 1.1 + alpha_high_frequency: float = 1.1 # n as described in CFG-Cache explanation in the paper - dependant on the model unconditional_batch_skip_range: int = 5 - unconditional_batch_timestep_skip_range: Tuple[float, float] = (-1, 641) + unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641) spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS @@ -184,10 +184,10 @@ def __init__( self.high_frequency_weight_callback = high_frequency_weight_callback self.uncond_skip_callback = uncond_skip_callback - self.iteration = 0 - self.low_frequency_delta = None - self.high_frequency_delta = None - self.is_guidance_distilled = None + self.iteration: int = 0 + self.low_frequency_delta: torch.Tensor = None + self.high_frequency_delta: torch.Tensor = None + self.is_guidance_distilled: bool = None def reset(self): self.iteration = 0 @@ -213,10 +213,10 @@ def __init__( self.skip_callback = skip_callback self.weight_callback = weight_callback - self.iteration = 0 - self.batch_size = None - self.cache = None - self.is_guidance_distilled = None + self.iteration: int = 0 + self.batch_size: int = None + self.cache: Tuple[torch.Tensor, torch.Tensor] = None + self.is_guidance_distilled: bool = None def reset(self): self.iteration = 0 @@ -232,9 +232,6 @@ def apply_faster_cache( r""" Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline. 
- Note: FasterCache should only be applied when using classifer-free guidance. It will not work as expected even if - the inference runs successfully. - Args: pipeline (`DiffusionPipeline`): The diffusion pipeline to apply FasterCache to. From 3c498efaa573604a45963ef73ae63b60f52155a5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 13:18:41 +0100 Subject: [PATCH 13/26] apply_faster_cache -> apply_fastercache --- src/diffusers/__init__.py | 4 ++-- src/diffusers/pipelines/__init__.py | 4 ++-- ...{faster_cache_utils.py => fastercache_utils.py} | 6 +++--- .../utils/dummy_torch_and_transformers_objects.py | 4 ++-- tests/pipelines/test_pipelines_common.py | 14 +++++++------- 5 files changed, 16 insertions(+), 16 deletions(-) rename src/diffusers/pipelines/{faster_cache_utils.py => fastercache_utils.py} (99%) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index bb8bdb5930ce..d5a4720ef52c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -423,7 +423,7 @@ "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", - "apply_faster_cache", + "apply_fastercache", ] ) @@ -912,7 +912,7 @@ WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, - apply_faster_cache, + apply_fastercache, ) try: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 278fb5d0b8aa..7194d247ee64 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -58,7 +58,7 @@ "StableDiffusionMixin", "ImagePipelineOutput", ] - _import_structure["faster_cache_utils"] = ["FasterCacheConfig", "apply_faster_cache"] + _import_structure["faster_cache_utils"] = ["FasterCacheConfig", "apply_fastercache"] _import_structure["deprecated"].extend( [ "PNDMPipeline", @@ -450,7 +450,7 @@ from .ddpm import DDPMPipeline from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline - from .faster_cache_utils import FasterCacheConfig, apply_faster_cache + from .fastercache_utils import FasterCacheConfig, apply_fastercache from .latent_diffusion import LDMSuperResolutionPipeline from .pipeline_utils import ( AudioPipelineOutput, diff --git a/src/diffusers/pipelines/faster_cache_utils.py b/src/diffusers/pipelines/fastercache_utils.py similarity index 99% rename from src/diffusers/pipelines/faster_cache_utils.py rename to src/diffusers/pipelines/fastercache_utils.py index fbc8b0677d9c..14a85dcd0244 100644 --- a/src/diffusers/pipelines/faster_cache_utils.py +++ b/src/diffusers/pipelines/fastercache_utils.py @@ -225,7 +225,7 @@ def reset(self): self.is_guidance_distilled = None -def apply_faster_cache( +def apply_fastercache( pipeline: DiffusionPipeline, config: Optional[FasterCacheConfig] = None, ) -> None: @@ -241,7 +241,7 @@ def apply_faster_cache( Example: ```python >>> import torch - >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache + >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_fastercache >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") @@ -255,7 +255,7 @@ def apply_faster_cache( ... attention_weight_callback=lambda _: 0.3, ... tensor_format="BFCHW", ... 
) - >>> apply_faster_cache(pipe, config) + >>> apply_fastercache(pipe, config) ``` """ diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 7298c6a66726..2ac3d30d86c8 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2552,5 +2552,5 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -def apply_faster_cache(*args, **kwargs): - requires_backends(apply_faster_cache, ["torch", "transformers"]) +def apply_fastercache(*args, **kwargs): + requires_backends(apply_fastercache, ["torch", "transformers"]) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 368572e4531a..ee9900ebc1bd 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -28,7 +28,7 @@ StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel, - apply_faster_cache, + apply_fastercache, ) from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin @@ -37,7 +37,7 @@ from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet from diffusers.models.unets.unet_motion_model import UNetMotionModel -from diffusers.pipelines.faster_cache_utils import FasterCacheBlockHook, FasterCacheDenoiserHook +from diffusers.pipelines.fastercache_utils import FasterCacheBlockHook, FasterCacheDenoiserHook from diffusers.pipelines.pipeline_utils import StableDiffusionMixin from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import logging @@ -2291,21 +2291,21 @@ def test_fastercache_basic_warning_or_errors_raised(self): # Check if warning is raised when no FasterCacheConfig is provided pipe = self.pipeline_class(**components) with CaptureLogger(logger) as cap_logger: - apply_faster_cache(pipe) + apply_fastercache(pipe) self.assertTrue("No FasterCacheConfig provided" in cap_logger.out) # Check if warning is raise when no attention_weight_callback is provided pipe = self.pipeline_class(**components) with CaptureLogger(logger) as cap_logger: config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None) - apply_faster_cache(pipe, config) + apply_fastercache(pipe, config) self.assertTrue("No `attention_weight_callback` provided when enabling FasterCache" in cap_logger.out) # Check if error raised when unsupported tensor format used pipe = self.pipeline_class(**components) with self.assertRaises(ValueError): config = FasterCacheConfig(spatial_attention_block_skip_range=2, tensor_format="BFHWC") - apply_faster_cache(pipe, config) + apply_fastercache(pipe, config) def test_fastercache_inference(self, expected_atol: float = 0.1): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -2321,7 +2321,7 @@ def test_fastercache_inference(self, expected_atol: float = 0.1): original_image_slice = output.flatten() original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:])) - apply_faster_cache(pipe, self.fastercache_config) + apply_fastercache(pipe, self.fastercache_config) inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 4 @@ -2353,7 +2353,7 @@ def test_fastercache_state(self): pipe = self.pipeline_class(**components) pipe.set_progress_bar_config(disable=None) - 
apply_faster_cache(pipe, self.fastercache_config) + apply_fastercache(pipe, self.fastercache_config) expected_hooks = 0 if self.fastercache_config.spatial_attention_block_skip_range is not None: From 4996dfd0280184088f984437537ee808c6f056c5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 15:29:07 +0100 Subject: [PATCH 14/26] fix --- src/diffusers/models/hooks.py | 1 + src/diffusers/pipelines/fastercache_utils.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py index 5433a0bd525a..1686505f8f94 100644 --- a/src/diffusers/models/hooks.py +++ b/src/diffusers/models/hooks.py @@ -114,6 +114,7 @@ def reset_state(self, module): for hook in self.hooks: if hook._is_stateful: hook.reset_state(module) + return module def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module: diff --git a/src/diffusers/pipelines/fastercache_utils.py b/src/diffusers/pipelines/fastercache_utils.py index 14a85dcd0244..84280a708c6e 100644 --- a/src/diffusers/pipelines/fastercache_utils.py +++ b/src/diffusers/pipelines/fastercache_utils.py @@ -526,7 +526,7 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: state.iteration += 1 output = (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states - return output + return module._diffusers_hook.post_forward(module, output) def reset_state(self, module: nn.Module) -> nn.Module: module._fastercache_state.reset() From 04874a7a5bc5384d155198b953d02565f49a115b Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 2 Jan 2025 15:48:15 +0100 Subject: [PATCH 15/26] reorder --- src/diffusers/pipelines/fastercache_utils.py | 360 +++++++++---------- 1 file changed, 180 insertions(+), 180 deletions(-) diff --git a/src/diffusers/pipelines/fastercache_utils.py b/src/diffusers/pipelines/fastercache_utils.py index 84280a708c6e..542e1b7b888d 100644 --- a/src/diffusers/pipelines/fastercache_utils.py +++ b/src/diffusers/pipelines/fastercache_utils.py @@ -225,186 +225,6 @@ def reset(self): self.is_guidance_distilled = None -def apply_fastercache( - pipeline: DiffusionPipeline, - config: Optional[FasterCacheConfig] = None, -) -> None: - r""" - Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline. - - Args: - pipeline (`DiffusionPipeline`): - The diffusion pipeline to apply FasterCache to. - config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`): - The configuration to use for FasterCache. - - Example: - ```python - >>> import torch - >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_fastercache - - >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) - >>> pipe.to("cuda") - - >>> config = FasterCacheConfig( - ... spatial_attention_block_skip_range=2, - ... spatial_attention_timestep_skip_range=(-1, 681), - ... low_frequency_weight_update_timestep_range=(99, 641), - ... high_frequency_weight_update_timestep_range=(-1, 301), - ... spatial_attention_block_identifiers=["transformer_blocks"], - ... attention_weight_callback=lambda _: 0.3, - ... tensor_format="BFCHW", - ... ) - >>> apply_fastercache(pipe, config) - ``` - """ - - if config is None: - logger.warning("No FasterCacheConfig provided. Using default configuration.") - config = FasterCacheConfig() - - if config.attention_weight_callback is None: - # If the user has not provided a weight callback, we default to 0.5 for all timesteps. 
- # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but - # this depends from model-to-model. It is required by the user to provide a weight callback if they want to - # use a different weight function. Defaulting to 0.5 works well in practice for most cases. - logger.warning( - "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps." - ) - config.attention_weight_callback = lambda _: 0.5 - - if config.low_frequency_weight_callback is None: - logger.debug( - "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." - ) - - def low_frequency_weight_callback(module: nn.Module) -> float: - is_within_range = ( - config.low_frequency_weight_update_timestep_range[0] - < pipeline._current_timestep - < config.low_frequency_weight_update_timestep_range[1] - ) - return config.alpha_low_frequency if is_within_range else 1.0 - - config.low_frequency_weight_callback = low_frequency_weight_callback - - if config.high_frequency_weight_callback is None: - logger.debug( - "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." - ) - - def high_frequency_weight_callback(module: nn.Module) -> float: - is_within_range = ( - config.high_frequency_weight_update_timestep_range[0] - < pipeline._current_timestep - < config.high_frequency_weight_update_timestep_range[1] - ) - return config.alpha_high_frequency if is_within_range else 1.0 - - config.high_frequency_weight_callback = high_frequency_weight_callback - - supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] # TODO(aryan): Support BSC for LTX Video - if config.tensor_format not in supported_tensor_formats: - raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.") - - denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet - _apply_fastercache_on_denoiser(pipeline, denoiser, config) - - for name, module in denoiser.named_modules(): - if not isinstance(module, _ATTENTION_CLASSES): - continue - if isinstance(module, Attention): - _apply_fastercache_on_attention_class(pipeline, name, module, config) - - -def _apply_fastercache_on_denoiser( - pipeline: DiffusionPipeline, denoiser: nn.Module, config: FasterCacheConfig -) -> None: - def uncond_skip_callback(module: nn.Module) -> bool: - # We skip the unconditional branch only if the following conditions are met: - # 1. We have completed at least one iteration of the denoiser - # 2. The current timestep is within the range specified by the user. This is the optimal timestep range - # where approximating the unconditional branch from the computation of the conditional branch is possible - # without a significant loss in quality. - # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that - # we compute the unconditional branch at least once every few iterations to ensure minimal quality loss. 
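As a rough worked example of how such a timestep window behaves, assuming a 1000-step training schedule sampled down to 50 inference steps and the default `unconditional_batch_timestep_skip_range=(-1, 641)`:

```python
# Rough illustration of the timestep window used by the skip conditions above, assuming a
# 1000-step training schedule sampled down to 50 inference steps (timesteps run ~1000 -> 0).
import torch

timesteps = torch.linspace(1000, 0, steps=50)
skip_range = (-1, 641)  # default unconditional_batch_timestep_skip_range

within_range = [skip_range[0] < float(t) < skip_range[1] for t in timesteps]
print(sum(within_range))  # 32 -- the first ~18 (high-timestep) steps are never skipped
```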
- - state: FasterCacheDenoiserState = module._fastercache_state - is_within_range = ( - config.unconditional_batch_timestep_skip_range[0] - < pipeline._current_timestep - < config.unconditional_batch_timestep_skip_range[1] - ) - return state.iteration > 0 and is_within_range and state.iteration % config.unconditional_batch_skip_range != 0 - - denoiser._fastercache_state = FasterCacheDenoiserState( - config.low_frequency_weight_callback, config.high_frequency_weight_callback, uncond_skip_callback - ) - hook = FasterCacheDenoiserHook( - config._unconditional_conditional_input_kwargs_identifiers, - config._guidance_distillation_kwargs_identifiers, - config.tensor_format, - ) - add_hook_to_module(denoiser, hook, append=True) - - -def _apply_fastercache_on_attention_class( - pipeline: DiffusionPipeline, name: str, module: Attention, config: FasterCacheConfig -) -> None: - is_spatial_self_attention = ( - any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers) - and config.spatial_attention_block_skip_range is not None - and not module.is_cross_attention - ) - is_temporal_self_attention = ( - any( - f"{identifier}." in name or identifier == name - for identifier in config.temporal_attention_block_identifiers - ) - and config.temporal_attention_block_skip_range is not None - and not module.is_cross_attention - ) - - block_skip_range, timestep_skip_range, block_type = None, None, None - if is_spatial_self_attention: - block_skip_range = config.spatial_attention_block_skip_range - timestep_skip_range = config.spatial_attention_timestep_skip_range - block_type = "spatial" - elif is_temporal_self_attention: - block_skip_range = config.temporal_attention_block_skip_range - timestep_skip_range = config.temporal_attention_timestep_skip_range - block_type = "temporal" - - if block_skip_range is None or timestep_skip_range is None: - logger.debug( - f'Unable to apply FasterCache to the selected layer: "{name}" because it does ' - f"not match any of the required criteria for spatial or temporal attention layers. Note, " - f"however, that this layer may still be valid for applying PAB. Please specify the correct " - f"block identifiers in the configuration or use the specialized `apply_fastercache_on_module` " - f"function to apply FasterCache to this layer." - ) - return - - def skip_callback(module: nn.Module) -> bool: - fastercache_state: FasterCacheBlockState = module._fastercache_state - is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1] - - if not is_within_timestep_range: - # We are still not in the phase of inference where skipping attention is possible without minimal quality - # loss, as described in the paper. 
So, the attention computation cannot be skipped - return False - - should_compute_attention = ( - fastercache_state.iteration > 0 and fastercache_state.iteration % block_skip_range == 0 - ) - return not should_compute_attention - - logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}") - module._fastercache_state = FasterCacheBlockState(skip_callback, config.attention_weight_callback) - hook = FasterCacheBlockHook() - add_hook_to_module(module, hook, append=True) - - class FasterCacheDenoiserHook(ModelHook): _is_stateful = True @@ -632,6 +452,186 @@ def reset_state(self, module: nn.Module) -> nn.Module: return module +def apply_fastercache( + pipeline: DiffusionPipeline, + config: Optional[FasterCacheConfig] = None, +) -> None: + r""" + Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline. + + Args: + pipeline (`DiffusionPipeline`): + The diffusion pipeline to apply FasterCache to. + config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`): + The configuration to use for FasterCache. + + Example: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_fastercache + + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> config = FasterCacheConfig( + ... spatial_attention_block_skip_range=2, + ... spatial_attention_timestep_skip_range=(-1, 681), + ... low_frequency_weight_update_timestep_range=(99, 641), + ... high_frequency_weight_update_timestep_range=(-1, 301), + ... spatial_attention_block_identifiers=["transformer_blocks"], + ... attention_weight_callback=lambda _: 0.3, + ... tensor_format="BFCHW", + ... ) + >>> apply_fastercache(pipe, config) + ``` + """ + + if config is None: + logger.warning("No FasterCacheConfig provided. Using default configuration.") + config = FasterCacheConfig() + + if config.attention_weight_callback is None: + # If the user has not provided a weight callback, we default to 0.5 for all timesteps. + # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but + # this depends from model-to-model. It is required by the user to provide a weight callback if they want to + # use a different weight function. Defaulting to 0.5 works well in practice for most cases. + logger.warning( + "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps." + ) + config.attention_weight_callback = lambda _: 0.5 + + if config.low_frequency_weight_callback is None: + logger.debug( + "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." + ) + + def low_frequency_weight_callback(module: nn.Module) -> float: + is_within_range = ( + config.low_frequency_weight_update_timestep_range[0] + < pipeline._current_timestep + < config.low_frequency_weight_update_timestep_range[1] + ) + return config.alpha_low_frequency if is_within_range else 1.0 + + config.low_frequency_weight_callback = low_frequency_weight_callback + + if config.high_frequency_weight_callback is None: + logger.debug( + "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." 
+ ) + + def high_frequency_weight_callback(module: nn.Module) -> float: + is_within_range = ( + config.high_frequency_weight_update_timestep_range[0] + < pipeline._current_timestep + < config.high_frequency_weight_update_timestep_range[1] + ) + return config.alpha_high_frequency if is_within_range else 1.0 + + config.high_frequency_weight_callback = high_frequency_weight_callback + + supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] # TODO(aryan): Support BSC for LTX Video + if config.tensor_format not in supported_tensor_formats: + raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.") + + denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet + _apply_fastercache_on_denoiser(pipeline, denoiser, config) + + for name, module in denoiser.named_modules(): + if not isinstance(module, _ATTENTION_CLASSES): + continue + if isinstance(module, Attention): + _apply_fastercache_on_attention_class(pipeline, name, module, config) + + +def _apply_fastercache_on_denoiser( + pipeline: DiffusionPipeline, denoiser: nn.Module, config: FasterCacheConfig +) -> None: + def uncond_skip_callback(module: nn.Module) -> bool: + # We skip the unconditional branch only if the following conditions are met: + # 1. We have completed at least one iteration of the denoiser + # 2. The current timestep is within the range specified by the user. This is the optimal timestep range + # where approximating the unconditional branch from the computation of the conditional branch is possible + # without a significant loss in quality. + # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that + # we compute the unconditional branch at least once every few iterations to ensure minimal quality loss. + + state: FasterCacheDenoiserState = module._fastercache_state + is_within_range = ( + config.unconditional_batch_timestep_skip_range[0] + < pipeline._current_timestep + < config.unconditional_batch_timestep_skip_range[1] + ) + return state.iteration > 0 and is_within_range and state.iteration % config.unconditional_batch_skip_range != 0 + + denoiser._fastercache_state = FasterCacheDenoiserState( + config.low_frequency_weight_callback, config.high_frequency_weight_callback, uncond_skip_callback + ) + hook = FasterCacheDenoiserHook( + config._unconditional_conditional_input_kwargs_identifiers, + config._guidance_distillation_kwargs_identifiers, + config.tensor_format, + ) + add_hook_to_module(denoiser, hook, append=True) + + +def _apply_fastercache_on_attention_class( + pipeline: DiffusionPipeline, name: str, module: Attention, config: FasterCacheConfig +) -> None: + is_spatial_self_attention = ( + any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers) + and config.spatial_attention_block_skip_range is not None + and not module.is_cross_attention + ) + is_temporal_self_attention = ( + any( + f"{identifier}." 
in name or identifier == name + for identifier in config.temporal_attention_block_identifiers + ) + and config.temporal_attention_block_skip_range is not None + and not module.is_cross_attention + ) + + block_skip_range, timestep_skip_range, block_type = None, None, None + if is_spatial_self_attention: + block_skip_range = config.spatial_attention_block_skip_range + timestep_skip_range = config.spatial_attention_timestep_skip_range + block_type = "spatial" + elif is_temporal_self_attention: + block_skip_range = config.temporal_attention_block_skip_range + timestep_skip_range = config.temporal_attention_timestep_skip_range + block_type = "temporal" + + if block_skip_range is None or timestep_skip_range is None: + logger.debug( + f'Unable to apply FasterCache to the selected layer: "{name}" because it does ' + f"not match any of the required criteria for spatial or temporal attention layers. Note, " + f"however, that this layer may still be valid for applying PAB. Please specify the correct " + f"block identifiers in the configuration or use the specialized `apply_fastercache_on_module` " + f"function to apply FasterCache to this layer." + ) + return + + def skip_callback(module: nn.Module) -> bool: + fastercache_state: FasterCacheBlockState = module._fastercache_state + is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1] + + if not is_within_timestep_range: + # We are still not in the phase of inference where skipping attention is possible without minimal quality + # loss, as described in the paper. So, the attention computation cannot be skipped + return False + + should_compute_attention = ( + fastercache_state.iteration > 0 and fastercache_state.iteration % block_skip_range == 0 + ) + return not should_compute_attention + + logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}") + module._fastercache_state = FasterCacheBlockState(skip_callback, config.attention_weight_callback) + hook = FasterCacheBlockHook() + add_hook_to_module(module, hook, append=True) + + # Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/fastercache_sample_latte.py#L127C1-L143C39 @torch.no_grad() def _split_low_high_freq(x): From 6de34fe21eb7c23a98a83cc020433877bd8b341b Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 28 Jan 2025 05:48:21 +0100 Subject: [PATCH 16/26] update --- src/diffusers/models/hooks.py | 229 ------------------ .../pipelines/cogvideo/pipeline_cogvideox.py | 2 - src/diffusers/pipelines/flux/pipeline_flux.py | 1 - .../hunyuan_video/pipeline_hunyuan_video.py | 2 - .../pipelines/latte/pipeline_latte.py | 2 - .../pipelines/mochi/pipeline_mochi.py | 1 - tests/pipelines/cogvideo/test_cogvideox.py | 4 +- tests/pipelines/flux/test_pipeline_flux.py | 7 +- tests/pipelines/latte/test_latte.py | 11 +- 9 files changed, 18 insertions(+), 241 deletions(-) delete mode 100644 src/diffusers/models/hooks.py diff --git a/src/diffusers/models/hooks.py b/src/diffusers/models/hooks.py deleted file mode 100644 index 1686505f8f94..000000000000 --- a/src/diffusers/models/hooks.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import functools -from typing import Any, Dict, List, Tuple - -import torch - - -# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py -class ModelHook: - r""" - A hook that contains callbacks to be executed just before and after the forward method of a model. The difference - with PyTorch existing hooks is that they get passed along the kwargs. - """ - - _is_stateful = False - - def init_hook(self, module: torch.nn.Module) -> torch.nn.Module: - r""" - Hook that is executed when a model is initialized. - - Args: - module (`torch.nn.Module`): - The module attached to this hook. - """ - return module - - def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]: - r""" - Hook that is executed just before the forward method of the model. - - Args: - module (`torch.nn.Module`): - The module whose forward pass will be executed just after this event. - args (`Tuple[Any]`): - The positional arguments passed to the module. - kwargs (`Dict[Str, Any]`): - The keyword arguments passed to the module. - Returns: - `Tuple[Tuple[Any], Dict[Str, Any]]`: - A tuple with the treated `args` and `kwargs`. - """ - return args, kwargs - - def post_forward(self, module: torch.nn.Module, output: Any) -> Any: - r""" - Hook that is executed just after the forward method of the model. - - Args: - module (`torch.nn.Module`): - The module whose forward pass been executed just before this event. - output (`Any`): - The output of the module. - Returns: - `Any`: The processed `output`. - """ - return output - - def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module: - r""" - Hook that is executed when the hook is detached from a module. - - Args: - module (`torch.nn.Module`): - The module detached from this hook. - """ - return module - - def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: - if self._is_stateful: - raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") - return module - - -class SequentialHook(ModelHook): - r"""A hook that can contain several hooks and iterates through them at each event.""" - - def __init__(self, *hooks): - self.hooks: List[ModelHook] = hooks - - def init_hook(self, module): - for hook in self.hooks: - module = hook.init_hook(module) - return module - - def pre_forward(self, module, *args, **kwargs): - for hook in self.hooks: - args, kwargs = hook.pre_forward(module, *args, **kwargs) - return args, kwargs - - def post_forward(self, module, output): - for hook in self.hooks: - output = hook.post_forward(module, output) - return output - - def detach_hook(self, module): - for hook in self.hooks: - module = hook.detach_hook(module) - return module - - def reset_state(self, module): - for hook in self.hooks: - if hook._is_stateful: - hook.reset_state(module) - return module - - -def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False) -> torch.nn.Module: - r""" - Adds a hook to a given module. 
This will rewrite the `forward` method of the module to include the hook, to remove - this behavior and restore the original `forward` method, use `remove_hook_from_module`. - - - - If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks - together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. - - - - Args: - module (`torch.nn.Module`): - The module to attach a hook to. - hook (`ModelHook`): - The hook to attach. - append (`bool`, *optional*, defaults to `False`): - Whether the hook should be chained with an existing one (if module already contains a hook) or not. - - Returns: - `torch.nn.Module`: - The same module, with the hook attached (the module is modified in place, so the result can be discarded). - """ - original_hook = hook - - if append and getattr(module, "_diffusers_hook", None) is not None: - old_hook = module._diffusers_hook - remove_hook_from_module(module) - hook = SequentialHook(old_hook, hook) - - if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"): - # If we already put some hook on this module, we replace it with the new one. - old_forward = module._old_forward - else: - old_forward = module.forward - module._old_forward = old_forward - - module = hook.init_hook(module) - module._diffusers_hook = hook - - if hasattr(original_hook, "new_forward"): - new_forward = original_hook.new_forward - else: - - def new_forward(module, *args, **kwargs): - args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - output = module._old_forward(*args, **kwargs) - return module._diffusers_hook.post_forward(module, output) - - # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. - # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 - if "GraphModuleImpl" in str(type(module)): - module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) - else: - module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) - - return module - - -def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module: - """ - Removes any hook attached to a module via `add_hook_to_module`. - - Args: - module (`torch.nn.Module`): - The module to attach a hook to. - recurse (`bool`, defaults to `False`): - Whether to remove the hooks recursively - Returns: - `torch.nn.Module`: - The same module, with the hook detached (the module is modified in place, so the result can be discarded). - """ - - if hasattr(module, "_diffusers_hook"): - module._diffusers_hook.detach_hook(module) - delattr(module, "_diffusers_hook") - - if hasattr(module, "_old_forward"): - # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. - # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 - if "GraphModuleImpl" in str(type(module)): - module.__class__.forward = module._old_forward - else: - module.forward = module._old_forward - delattr(module, "_old_forward") - - if recurse: - for child in module.children(): - remove_hook_from_module(child, recurse) - - return module - - -def reset_stateful_hooks(module: torch.nn.Module, recurse: bool = False): - """ - Resets the state of all stateful hooks attached to a module. - - Args: - module (`torch.nn.Module`): - The module to reset the stateful hooks from. 
- """ - if hasattr(module, "_diffusers_hook") and ( - module._diffusers_hook._is_stateful or isinstance(module._diffusers_hook, SequentialHook) - ): - module._diffusers_hook.reset_state(module) - - if recurse: - for child in module.children(): - reset_stateful_hooks(child, recurse) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 068b7c54406f..99ae9025cd3e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -24,7 +24,6 @@ from ...loaders import CogVideoXLoraLoaderMixin from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from ...models.embeddings import get_3d_rotary_pos_embed -from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring @@ -782,7 +781,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index b49b0ebc757c..aa02dc1de5da 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -974,7 +974,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 5797a0a93dae..8cc77ed4c148 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -22,7 +22,6 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import HunyuanVideoLoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel -from ...models.hooks import reset_stateful_hooks from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor @@ -696,7 +695,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index b154f7102319..6ec3eaf65005 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -25,7 +25,6 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...models import AutoencoderKL, LatteTransformer3DModel -from ...models.hooks import reset_stateful_hooks from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -874,7 +873,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index fd0fa75ce3e6..d1f88b02c5cc 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ 
b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -737,7 +737,6 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - reset_stateful_hooks(self.transformer, recurse=True) if not return_dict: return (video,) diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index fbedd12fd6d2..f48532f37e1e 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -43,7 +43,9 @@ enable_full_determinism() -class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase): +class CogVideoXPipelineFastTests( + PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase +): pipeline_class = CogVideoXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index bc5dcf6776af..fa8979c4ad52 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -27,7 +27,12 @@ class FluxPipelineFastTests( - unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin): + unittest.TestCase, + PipelineTesterMixin, + FluxIPAdapterTesterMixin, + PyramidAttentionBroadcastTesterMixin, + FasterCacheTesterMixin, +): pipeline_class = FluxPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index 69f197aeb57e..be731813cd1b 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -40,13 +40,20 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np +from ..test_pipelines_common import ( + FasterCacheTesterMixin, + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + to_np, +) enable_full_determinism() -class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase): +class LattePipelineFastTests( + PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase +): pipeline_class = LattePipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS From d98473d3b7f5df133cbd8717b74bb18b0b9cea2c Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 28 Jan 2025 22:04:29 +0100 Subject: [PATCH 17/26] refactor --- src/diffusers/__init__.py | 14 +- src/diffusers/hooks/__init__.py | 1 + .../faster_cache.py} | 402 +++++++++--------- .../hooks/pyramid_attention_broadcast.py | 2 +- src/diffusers/pipelines/__init__.py | 2 - src/diffusers/utils/dummy_pt_objects.py | 19 + .../dummy_torch_and_transformers_objects.py | 19 - tests/pipelines/test_pipelines_common.py | 12 +- 8 files changed, 238 insertions(+), 233 deletions(-) rename src/diffusers/{pipelines/fastercache_utils.py => hooks/faster_cache.py} (70%) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d5cf47f74571..67e3d4451c8a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -78,8 +78,10 @@ else: 
_import_structure["hooks"].extend( [ + "FasterCacheConfig", "HookRegistry", "PyramidAttentionBroadcastConfig", + "apply_faster_cache", "apply_pyramid_attention_broadcast", ] ) @@ -287,7 +289,6 @@ "CogView3PlusPipeline", "ConsisIDPipeline", "CycleDiffusionPipeline", - "FasterCacheConfig", "FluxControlImg2ImgPipeline", "FluxControlInpaintPipeline", "FluxControlNetImg2ImgPipeline", @@ -434,7 +435,6 @@ "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", - "apply_fastercache", ] ) @@ -599,7 +599,13 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: - from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from .hooks import ( + FasterCacheConfig, + HookRegistry, + PyramidAttentionBroadcastConfig, + apply_faster_cache, + apply_pyramid_attention_broadcast, + ) from .models import ( AllegroTransformer3DModel, AsymmetricAutoencoderKL, @@ -782,7 +788,6 @@ CogView3PlusPipeline, ConsisIDPipeline, CycleDiffusionPipeline, - FasterCacheConfig, FluxControlImg2ImgPipeline, FluxControlInpaintPipeline, FluxControlNetImg2ImgPipeline, @@ -927,7 +932,6 @@ WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, - apply_fastercache, ) try: diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index e745b1320e84..dcdbd20664b4 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -2,6 +2,7 @@ if is_torch_available(): + from .faster_cache import FasterCacheConfig, apply_faster_cache from .hooks import HookRegistry, ModelHook from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff --git a/src/diffusers/pipelines/fastercache_utils.py b/src/diffusers/hooks/faster_cache.py similarity index 70% rename from src/diffusers/pipelines/fastercache_utils.py rename to src/diffusers/hooks/faster_cache.py index 542e1b7b888d..3f54d305eb73 100644 --- a/src/diffusers/pipelines/fastercache_utils.py +++ b/src/diffusers/hooks/faster_cache.py @@ -17,20 +17,19 @@ from typing import Any, Callable, List, Optional, Tuple import torch -import torch.fft as FFT -import torch.nn as nn -from ..models.attention_processor import Attention -from ..models.hooks import ModelHook, add_hook_to_module +from ..models.attention_processor import Attention, MochiAttention +from ..models.modeling_outputs import Transformer2DModelOutput from ..utils import logging -from .pipeline_utils import DiffusionPipeline +from .hooks import HookRegistry, ModelHook logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# TODO(aryan): handle mochi attention -_ATTENTION_CLASSES = (Attention,) +_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser" +_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block" +_ATTENTION_CLASSES = (Attention, MochiAttention) _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ( "blocks.*attn", @@ -45,7 +44,6 @@ "attention_mask", "encoder_attention_mask", ) -_GUIDANCE_DISTILLATION_KWARGS_IDENTIFIERS = ("guidance",) @dataclass @@ -106,33 +104,30 @@ class FasterCacheConfig: The identifiers to match the temporal attention blocks in the model. If the name of the block contains any of these identifiers, FasterCache will be applied to that block. This can either be the full layer names, partial layer names, or regex patterns. Matching will always be done using a regex match. 
- attention_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`): + attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`): The callback function to determine the weight to scale the attention outputs by. This function should take the attention module as input and return a float value. This is used to approximate the unconditional branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps. Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference progresses. Users are encouraged to experiment and provide custom weight schedules that take into account the number of inference steps and underlying model behaviour as denoising progresses. - low_frequency_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`): + low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`): The callback function to determine the weight to scale the low frequency updates by. If not provided, the default weight is 1.1 for timesteps within the range specified (as described in the paper). - high_frequency_weight_callback (`Callable[[nn.Module], float]`, defaults to `None`): + high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`): The callback function to determine the weight to scale the high frequency updates by. If not provided, the default weight is 1.1 for timesteps within the range specified (as described in the paper). tensor_format (`str`, defaults to `"BCFHW"`): The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is used to split individual latent frames in order for low and high frequency components to be computed. + is_guidance_distilled (`bool`, defaults to `False`): + Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be + applied at the denoiser-level to skip the unconditional branch computation (as there is none). _unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`): The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs names that contain the batchwise-concatenated unconditional and conditional inputs. - _guidance_distillation_kwargs_identifiers (`List[str]`, defaults to `("guidance",)`): - The identifiers to match the input kwargs that contain the guidance distillation inputs. If the name of the - input kwargs contains any of these identifiers, FasterCache will not split the inputs into unconditional - and conditional branches (unconditional branches are only computed sometimes based on certain checks). This - allows usage of FasterCache in models like Flux-Dev and HunyuanVideo which are guidance-distilled (only - attention skipping related parts are applied, and not unconditional branch approximation). """ # In the paper and codebase, they hardcode these values to 2. 
However, it can be made configurable @@ -140,7 +135,6 @@ class FasterCacheConfig: spatial_attention_block_skip_range: int = 2 temporal_attention_block_skip_range: Optional[int] = None - # TODO(aryan): write heuristics for what the best way to obtain these values are spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) @@ -159,14 +153,35 @@ class FasterCacheConfig: spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS - attention_weight_callback: Callable[[nn.Module], float] = None - low_frequency_weight_callback: Callable[[nn.Module], float] = None - high_frequency_weight_callback: Callable[[nn.Module], float] = None + attention_weight_callback: Callable[[torch.nn.Module], float] = None + low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None + high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None tensor_format: str = "BCFHW" + is_guidance_distilled: bool = False + + current_timestep_callback: Callable[[], int] = None _unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS - _guidance_distillation_kwargs_identifiers: List[str] = _GUIDANCE_DISTILLATION_KWARGS_IDENTIFIERS + + def __repr__(self) -> str: + return ( + f"FasterCacheConfig(\n" + f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n" + f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n" + f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n" + f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n" + f" low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n" + f" high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n" + f" alpha_low_frequency={self.alpha_low_frequency},\n" + f" alpha_high_frequency={self.alpha_high_frequency},\n" + f" unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n" + f" unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n" + f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n" + f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n" + f" tensor_format={self.tensor_format},\n" + f")" + ) class FasterCacheDenoiserState: @@ -174,55 +189,32 @@ class FasterCacheDenoiserState: State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module. 
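+    The state tracks the current inference iteration and the low/high frequency deltas between the unconditional
+    and conditional branch outputs from the most recent full forward pass; these deltas are reused to approximate
+    the unconditional branch output on iterations where its computation is skipped.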
""" - def __init__( - self, - low_frequency_weight_callback: Callable[[nn.Module], torch.Tensor], - high_frequency_weight_callback: Callable[[nn.Module], torch.Tensor], - uncond_skip_callback: Callable[[nn.Module], bool], - ) -> None: - self.low_frequency_weight_callback = low_frequency_weight_callback - self.high_frequency_weight_callback = high_frequency_weight_callback - self.uncond_skip_callback = uncond_skip_callback - + def __init__(self) -> None: self.iteration: int = 0 self.low_frequency_delta: torch.Tensor = None self.high_frequency_delta: torch.Tensor = None - self.is_guidance_distilled: bool = None def reset(self): self.iteration = 0 self.low_frequency_delta = None self.high_frequency_delta = None - self.is_guidance_distilled = None class FasterCacheBlockState: r""" State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is applied to will have an instance of this state. - - Attributes: - iteration (`int`): - The current iteration of the FasterCache. It is necessary to ensure that `reset_state` is called before - starting a new inference forward pass for this to work correctly. """ - def __init__( - self, skip_callback: Callable[[nn.Module], bool], weight_callback: Callable[[nn.Module], float] - ) -> None: - self.skip_callback = skip_callback - self.weight_callback = weight_callback - + def __init__(self) -> None: self.iteration: int = 0 self.batch_size: int = None self.cache: Tuple[torch.Tensor, torch.Tensor] = None - self.is_guidance_distilled: bool = None def reset(self): self.iteration = 0 self.batch_size = None self.cache = None - self.is_guidance_distilled = None class FasterCacheDenoiserHook(ModelHook): @@ -230,22 +222,34 @@ class FasterCacheDenoiserHook(ModelHook): def __init__( self, - uncond_cond_input_kwargs_identifiers: List[str], - guidance_distillation_kwargs_identifiers: List[str], + unconditional_batch_skip_range: int, + unconditional_batch_timestep_skip_range: Tuple[int, int], tensor_format: str, + is_guidance_distilled: bool, + uncond_cond_input_kwargs_identifiers: List[str], + current_timestep_callback: Callable[[], int], + low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor], + high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor], ) -> None: super().__init__() + self.unconditional_batch_skip_range = unconditional_batch_skip_range + self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range # We can't easily detect what args are to be split in unconditional and conditional branches. We # can only do it for kwargs, hence they are the only ones we split. The args are passed as-is. # If a model is to be made compatible with FasterCache, the user must ensure that the inputs that # contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs. 
self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers + self.tensor_format = tensor_format + self.is_guidance_distilled = is_guidance_distilled - # See documentation for `guidance_distillation_kwargs_identifiers` in FasterCacheConfig for more information - self.guidance_distillation_kwargs_identifiers = guidance_distillation_kwargs_identifiers + self.current_timestep_callback = current_timestep_callback + self.low_frequency_weight_callback = low_frequency_weight_callback + self.high_frequency_weight_callback = high_frequency_weight_callback - self.tensor_format = tensor_format + def initialize_hook(self, module): + self.state = FasterCacheDenoiserState() + return module @staticmethod def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -254,53 +258,57 @@ def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: _, cond = input.chunk(2, dim=0) return cond - def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: - args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - state: FasterCacheDenoiserState = module._fastercache_state - + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: # Split the unconditional and conditional inputs. We only want to infer the conditional branch if the # requirements for skipping the unconditional branch are met as described in the paper. - should_skip_uncond = state.uncond_skip_callback(module) - - if state.is_guidance_distilled is None: - # The following check assumes that the guidance embedding are torch tensors for check to pass. This - # seems to be true for all models supported in diffusers - state.is_guidance_distilled = any( - identifier in kwargs and kwargs[identifier] is not None and torch.is_tensor(kwargs[identifier]) - for identifier in self.guidance_distillation_kwargs_identifiers - ) - # Make all children FasterCacheBlockHooks aware of whether the model is guidance distilled or not - # because we cannot determine this within the block hooks - for name, child_module in module.named_modules(): - if hasattr(child_module, "_fastercache_state") and isinstance( - child_module._fastercache_state, FasterCacheBlockState - ): - # TODO(aryan): remove later - logger.debug(f"Setting guidance distillation flag for layer: {name}") - child_module._fastercache_state.is_guidance_distilled = state.is_guidance_distilled - assert state.is_guidance_distilled is not None - - if should_skip_uncond and not state.is_guidance_distilled: - kwargs = { - k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v) - for k, v in kwargs.items() - } - # TODO(aryan): remove later - logger.debug("Skipping unconditional branch computation") - - output = module._old_forward(*args, **kwargs) - - if state.is_guidance_distilled: - state.iteration += 1 + # We skip the unconditional branch only if the following conditions are met: + # 1. We have completed at least one iteration of the denoiser + # 2. The current timestep is within the range specified by the user. This is the optimal timestep range + # where approximating the unconditional branch from the computation of the conditional branch is possible + # without a significant loss in quality. + # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that + # we compute the unconditional branch at least once every few iterations to ensure minimal quality loss. 
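+        # For example, with `unconditional_batch_skip_range=5` and
+        # `unconditional_batch_timestep_skip_range=(-1, 781)`, the unconditional branch is computed at
+        # iteration 0, at every iteration that is a multiple of 5, and at all timesteps >= 781; on the
+        # remaining iterations it is approximated from the conditional branch output.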
+ is_within_timestep_range = ( + self.unconditional_batch_timestep_skip_range[0] + < self.current_timestep_callback() + < self.unconditional_batch_timestep_skip_range[1] + ) + should_skip_uncond = ( + self.state.iteration > 0 + and is_within_timestep_range + and self.state.iteration % self.unconditional_batch_skip_range != 0 + ) + + if should_skip_uncond and not self.is_guidance_distilled: + is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys()) + if is_any_kwarg_uncond: + logger.debug("FasterCache - Skipping unconditional branch computation") + args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args]) + kwargs = { + k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v) + for k, v in kwargs.items() + } + + output = self.fn_ref.original_forward(*args, **kwargs) + + if self.is_guidance_distilled: + self.state.iteration += 1 return output - # TODO(aryan): handle Transformer2DModelOutput - hidden_states = output[0] if isinstance(output, tuple) else output + if torch.is_tensor(output): + hidden_states = output + elif isinstance(output, (tuple, Transformer2DModelOutput)): + hidden_states = output[0] + batch_size = hidden_states.size(0) if should_skip_uncond: - state.low_frequency_delta = state.low_frequency_delta * state.low_frequency_weight_callback(module) - state.high_frequency_delta = state.high_frequency_delta * state.high_frequency_weight_callback(module) + self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback( + module + ) + self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback( + module + ) if self.tensor_format == "BCFHW": hidden_states = hidden_states.permute(0, 2, 1, 3, 4) @@ -310,12 +318,12 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float()) # Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper - low_freq_uncond = state.low_frequency_delta + low_freq_cond - high_freq_uncond = state.high_frequency_delta + high_freq_cond + low_freq_uncond = self.state.low_frequency_delta + low_freq_cond + high_freq_uncond = self.state.high_frequency_delta + high_freq_cond uncond_freq = low_freq_uncond + high_freq_uncond - uncond_states = FFT.ifftshift(uncond_freq) - uncond_states = FFT.ifft2(uncond_states).real + uncond_states = torch.fft.ifftshift(uncond_freq) + uncond_states = torch.fft.ifft2(uncond_states).real if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW": uncond_states = uncond_states.unflatten(0, (batch_size, -1)) @@ -328,9 +336,6 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: uncond_states = uncond_states.to(hidden_states.dtype) hidden_states = torch.cat([uncond_states, hidden_states], dim=0) else: - # TODO(aryan): remove later - logger.debug("Computing unconditional branch") - uncond_states, cond_states = hidden_states.chunk(2, dim=0) if self.tensor_format == "BCFHW": uncond_states = uncond_states.permute(0, 2, 1, 3, 4) @@ -341,25 +346,51 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float()) low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float()) - state.low_frequency_delta = low_freq_uncond - low_freq_cond - state.high_frequency_delta = high_freq_uncond - high_freq_cond + 
self.state.low_frequency_delta = low_freq_uncond - low_freq_cond + self.state.high_frequency_delta = high_freq_uncond - high_freq_cond - state.iteration += 1 - output = (hidden_states, *output[1:]) if isinstance(output, tuple) else hidden_states - return module._diffusers_hook.post_forward(module, output) + self.state.iteration += 1 + if torch.is_tensor(output): + output = hidden_states + elif isinstance(output, tuple): + output = (hidden_states, *output[1:]) + else: + output.sample = hidden_states + + return output - def reset_state(self, module: nn.Module) -> nn.Module: - module._fastercache_state.reset() + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + self.state.reset() return module class FasterCacheBlockHook(ModelHook): _is_stateful = True + def __init__( + self, + block_skip_range: int, + timestep_skip_range: Tuple[int, int], + is_guidance_distilled: bool, + weight_callback: Callable[[torch.nn.Module], float], + current_timestep_callback: Callable[[], int], + ) -> None: + super().__init__() + + self.block_skip_range = block_skip_range + self.timestep_skip_range = timestep_skip_range + self.is_guidance_distilled = is_guidance_distilled + + self.weight_callback = weight_callback + self.current_timestep_callback = current_timestep_callback + + def initialize_hook(self, module): + self.state = FasterCacheBlockState() + return module + def _compute_approximated_attention_output( self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int ) -> torch.Tensor: - # TODO(aryan): these conditions may not be needed after latest refactor. they exist for safety. do test if they can be removed if t_2_output.size(0) != batch_size: # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just # take the conditional branch outputs. @@ -372,20 +403,14 @@ def _compute_approximated_attention_output( t_output = t_output[batch_size:] return t_output + (t_output - t_2_output) * weight - def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: - args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs) - state: FasterCacheBlockState = module._fastercache_state - - # The denoiser should have set this flag for all children FasterCacheBlockHooks to either True or False - assert state.is_guidance_distilled is not None - + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: batch_size = [ *[arg.size(0) for arg in args if torch.is_tensor(arg)], *[v.size(0) for v in kwargs.values() if torch.is_tensor(v)], ][0] - if state.batch_size is None: + if self.state.batch_size is None: # Will be updated on first forward pass through the denoiser - state.batch_size = batch_size + self.state.batch_size = batch_size # If we have to skip due to the skip conditions, then let's skip as expected. # But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This @@ -393,17 +418,22 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true # unconditional-conditional batch size) is same as the current batch size, we don't perform the layer # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns. 
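+        # For example, with the default `spatial_attention_block_skip_range=2` and
+        # `spatial_attention_timestep_skip_range=(-1, 681)`, attention is recomputed at iteration 0 and at every
+        # even iteration while the current timestep is below 681; when it is skipped, the output is extrapolated
+        # from the two most recently computed outputs as `t_output + (t_output - t_2_output) * weight`
+        # (see `_compute_approximated_attention_output` above).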
- should_skip_attention = state.skip_callback(module) and ( - state.is_guidance_distilled or state.batch_size != batch_size + is_within_timestep_range = ( + self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1] ) - + if not is_within_timestep_range: + should_skip_attention = False + else: + should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0 + should_skip_attention = not should_compute_attention if should_skip_attention: - # TODO(aryan): remove later - logger.debug("Skipping attention") + should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size - if torch.is_tensor(state.cache[-1]): - t_2_output, t_output = state.cache - weight = state.weight_callback(module) + if should_skip_attention: + logger.debug("FasterCache - Skipping attention and using approximation") + if torch.is_tensor(self.state.cache[-1]): + t_2_output, t_output = self.state.cache + weight = self.weight_callback(module) output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size) else: # The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them. @@ -413,21 +443,21 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: # The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which # allows us to compute the approximated attention output for each tensor in the cache. output = () - for t_2_output, t_output in zip(*state.cache): + for t_2_output, t_output in zip(*self.state.cache): result = self._compute_approximated_attention_output( - t_2_output, t_output, state.weight_callback(module), batch_size + t_2_output, t_output, self.weight_callback(module), batch_size ) output += (result,) else: - logger.debug("Computing attention") - output = module._old_forward(*args, **kwargs) + logger.debug("FasterCache - Computing attention") + output = self.fn_ref.original_forward(*args, **kwargs) # Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return # a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle # both cases. if torch.is_tensor(output): cache_output = output - if not state.is_guidance_distilled and cache_output.size(0) == state.batch_size: + if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size: # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs. # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs. 
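+                # e.g. an output of shape (2 * B, ...) containing the batched unconditional and conditional
+                # branches is reduced to its conditional half (the second chunk) of shape (B, ...) before caching.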
cache_output = cache_output.chunk(2, dim=0)[1] @@ -435,25 +465,25 @@ def new_forward(self, module: nn.Module, *args, **kwargs) -> Any: # Cache all return values and perform the same operation as above cache_output = () for out in output: - if not state.is_guidance_distilled and out.size(0) == state.batch_size: + if not self.is_guidance_distilled and out.size(0) == self.state.batch_size: out = out.chunk(2, dim=0)[1] cache_output += (out,) - if state.cache is None: - state.cache = [cache_output, cache_output] + if self.state.cache is None: + self.state.cache = [cache_output, cache_output] else: - state.cache = [state.cache[-1], cache_output] + self.state.cache = [self.state.cache[-1], cache_output] - state.iteration += 1 - return module._diffusers_hook.post_forward(module, output) + self.state.iteration += 1 + return output - def reset_state(self, module: nn.Module) -> nn.Module: - module._fastercache_state.reset() + def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + self.state.reset() return module -def apply_fastercache( - pipeline: DiffusionPipeline, +def apply_faster_cache( + module: torch.nn.Module, config: Optional[FasterCacheConfig] = None, ) -> None: r""" @@ -468,7 +498,7 @@ def apply_fastercache( Example: ```python >>> import torch - >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_fastercache + >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") @@ -482,7 +512,7 @@ def apply_fastercache( ... attention_weight_callback=lambda _: 0.3, ... tensor_format="BFCHW", ... ) - >>> apply_fastercache(pipe, config) + >>> apply_faster_cache(pipe.transformer, config) ``` """ @@ -505,10 +535,10 @@ def apply_fastercache( "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." ) - def low_frequency_weight_callback(module: nn.Module) -> float: + def low_frequency_weight_callback(module: torch.nn.Module) -> float: is_within_range = ( config.low_frequency_weight_update_timestep_range[0] - < pipeline._current_timestep + < config.current_timestep_callback() < config.low_frequency_weight_update_timestep_range[1] ) return config.alpha_low_frequency if is_within_range else 1.0 @@ -520,10 +550,10 @@ def low_frequency_weight_callback(module: nn.Module) -> float: "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." 
) - def high_frequency_weight_callback(module: nn.Module) -> float: + def high_frequency_weight_callback(module: torch.nn.Module) -> float: is_within_range = ( config.high_frequency_weight_update_timestep_range[0] - < pipeline._current_timestep + < config.current_timestep_callback() < config.high_frequency_weight_update_timestep_range[1] ) return config.alpha_high_frequency if is_within_range else 1.0 @@ -534,50 +564,30 @@ def high_frequency_weight_callback(module: nn.Module) -> float: if config.tensor_format not in supported_tensor_formats: raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.") - denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet - _apply_fastercache_on_denoiser(pipeline, denoiser, config) + _apply_faster_cache_on_denoiser(module, config) - for name, module in denoiser.named_modules(): - if not isinstance(module, _ATTENTION_CLASSES): + for name, submodule in module.named_modules(): + if not isinstance(submodule, _ATTENTION_CLASSES): continue - if isinstance(module, Attention): - _apply_fastercache_on_attention_class(pipeline, name, module, config) - + _apply_faster_cache_on_attention_class(name, submodule, config) -def _apply_fastercache_on_denoiser( - pipeline: DiffusionPipeline, denoiser: nn.Module, config: FasterCacheConfig -) -> None: - def uncond_skip_callback(module: nn.Module) -> bool: - # We skip the unconditional branch only if the following conditions are met: - # 1. We have completed at least one iteration of the denoiser - # 2. The current timestep is within the range specified by the user. This is the optimal timestep range - # where approximating the unconditional branch from the computation of the conditional branch is possible - # without a significant loss in quality. - # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that - # we compute the unconditional branch at least once every few iterations to ensure minimal quality loss. 
- - state: FasterCacheDenoiserState = module._fastercache_state - is_within_range = ( - config.unconditional_batch_timestep_skip_range[0] - < pipeline._current_timestep - < config.unconditional_batch_timestep_skip_range[1] - ) - return state.iteration > 0 and is_within_range and state.iteration % config.unconditional_batch_skip_range != 0 - denoiser._fastercache_state = FasterCacheDenoiserState( - config.low_frequency_weight_callback, config.high_frequency_weight_callback, uncond_skip_callback - ) +def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None: hook = FasterCacheDenoiserHook( - config._unconditional_conditional_input_kwargs_identifiers, - config._guidance_distillation_kwargs_identifiers, + config.unconditional_batch_skip_range, + config.unconditional_batch_timestep_skip_range, config.tensor_format, + config.is_guidance_distilled, + config._unconditional_conditional_input_kwargs_identifiers, + config.current_timestep_callback, + config.low_frequency_weight_callback, + config.high_frequency_weight_callback, ) - add_hook_to_module(denoiser, hook, append=True) + registry = HookRegistry.check_if_exists_or_initialize(module) + registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK) -def _apply_fastercache_on_attention_class( - pipeline: DiffusionPipeline, name: str, module: Attention, config: FasterCacheConfig -) -> None: +def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None: is_spatial_self_attention = ( any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers) and config.spatial_attention_block_skip_range is not None @@ -607,36 +617,28 @@ def _apply_fastercache_on_attention_class( f'Unable to apply FasterCache to the selected layer: "{name}" because it does ' f"not match any of the required criteria for spatial or temporal attention layers. Note, " f"however, that this layer may still be valid for applying PAB. Please specify the correct " - f"block identifiers in the configuration or use the specialized `apply_fastercache_on_module` " + f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` " f"function to apply FasterCache to this layer." ) return - def skip_callback(module: nn.Module) -> bool: - fastercache_state: FasterCacheBlockState = module._fastercache_state - is_within_timestep_range = timestep_skip_range[0] < pipeline._current_timestep < timestep_skip_range[1] - - if not is_within_timestep_range: - # We are still not in the phase of inference where skipping attention is possible without minimal quality - # loss, as described in the paper. 
So, the attention computation cannot be skipped - return False - - should_compute_attention = ( - fastercache_state.iteration > 0 and fastercache_state.iteration % block_skip_range == 0 - ) - return not should_compute_attention - logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}") - module._fastercache_state = FasterCacheBlockState(skip_callback, config.attention_weight_callback) - hook = FasterCacheBlockHook() - add_hook_to_module(module, hook, append=True) + hook = FasterCacheBlockHook( + block_skip_range, + timestep_skip_range, + config.is_guidance_distilled, + config.attention_weight_callback, + config.current_timestep_callback, + ) + registry = HookRegistry.check_if_exists_or_initialize(module) + registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK) # Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/fastercache_sample_latte.py#L127C1-L143C39 @torch.no_grad() def _split_low_high_freq(x): - fft = FFT.fft2(x) - fft_shifted = FFT.fftshift(fft) + fft = torch.fft.fft2(x) + fft_shifted = torch.fft.fftshift(fft) height, width = x.shape[-2:] radius = min(height, width) // 5 diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index 9f8597d52f8c..de914b92b95b 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -87,7 +87,7 @@ class PyramidAttentionBroadcastConfig: def __repr__(self) -> str: return ( - f"PyramidAttentionBroadcastConfig(" + f"PyramidAttentionBroadcastConfig(\n" f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n" f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n" f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n" diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1463b3832225..5829cf495dcc 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -58,7 +58,6 @@ "StableDiffusionMixin", "ImagePipelineOutput", ] - _import_structure["faster_cache_utils"] = ["FasterCacheConfig", "apply_fastercache"] _import_structure["deprecated"].extend( [ "PNDMPipeline", @@ -451,7 +450,6 @@ from .ddpm import DDPMPipeline from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline from .dit import DiTPipeline - from .fastercache_utils import FasterCacheConfig, apply_fastercache from .latent_diffusion import LDMSuperResolutionPipeline from .pipeline_utils import ( AudioPipelineOutput, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6a1978944c9f..1b3119a04ea3 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class FasterCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HookRegistry(metaclass=DummyObject): _backends = ["torch"] @@ -32,6 +47,10 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +def apply_faster_cache(*args, **kwargs): + requires_backends(apply_faster_cache, ["torch"]) + + def 
apply_pyramid_attention_broadcast(*args, **kwargs): requires_backends(apply_pyramid_attention_broadcast, ["torch"]) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 23ef00e4fe55..b899915c3046 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -392,21 +392,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class FasterCacheConfig(metaclass=DummyObject): - _backends = ["torch", "transformers"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers"]) - - class FluxControlImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -2565,7 +2550,3 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) - - -def apply_fastercache(*args, **kwargs): - requires_backends(apply_fastercache, ["torch", "transformers"]) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 1b0f1a3cc9fa..6984779e98ad 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -29,7 +29,7 @@ StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel, - apply_fastercache, + apply_faster_cache, ) from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor @@ -2479,21 +2479,21 @@ def test_fastercache_basic_warning_or_errors_raised(self): # Check if warning is raised when no FasterCacheConfig is provided pipe = self.pipeline_class(**components) with CaptureLogger(logger) as cap_logger: - apply_fastercache(pipe) + apply_faster_cache(pipe) self.assertTrue("No FasterCacheConfig provided" in cap_logger.out) # Check if warning is raise when no attention_weight_callback is provided pipe = self.pipeline_class(**components) with CaptureLogger(logger) as cap_logger: config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None) - apply_fastercache(pipe, config) + apply_faster_cache(pipe, config) self.assertTrue("No `attention_weight_callback` provided when enabling FasterCache" in cap_logger.out) # Check if error raised when unsupported tensor format used pipe = self.pipeline_class(**components) with self.assertRaises(ValueError): config = FasterCacheConfig(spatial_attention_block_skip_range=2, tensor_format="BFHWC") - apply_fastercache(pipe, config) + apply_faster_cache(pipe, config) def test_fastercache_inference(self, expected_atol: float = 0.1): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -2509,7 +2509,7 @@ def test_fastercache_inference(self, expected_atol: float = 0.1): original_image_slice = output.flatten() original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:])) - apply_fastercache(pipe, self.fastercache_config) + apply_faster_cache(pipe, self.fastercache_config) inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 4 @@ -2541,7 +2541,7 @@ def test_fastercache_state(self): pipe = self.pipeline_class(**components) 
pipe.set_progress_bar_config(disable=None) - apply_fastercache(pipe, self.fastercache_config) + apply_faster_cache(pipe, self.fastercache_config) expected_hooks = 0 if self.fastercache_config.spatial_attention_block_skip_range is not None: From 93de5f36d52428506d6ada0f3477699891abd05a Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 28 Jan 2025 22:23:31 +0100 Subject: [PATCH 18/26] update docs --- docs/source/en/api/cache.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index 403dbf88b431..a6aa5445a845 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -38,6 +38,33 @@ config = PyramidAttentionBroadcastConfig( pipe.transformer.enable_cache(config) ``` +## Faster Cache + +[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong. + +FasterCache is a method that speeds up inference in diffusion transformers by: +- Reusing attention states between successive inference steps, due to high similarity between them +- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output + +```python +import torch +from diffusers import CogVideoXPipeline, FasterCacheConfig + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 681), + current_timestep_callback=lambda: pipe.current_timestep, + attention_weight_callback=lambda _: 0.3, + unconditional_batch_skip_range=5, + unconditional_batch_timestep_skip_range=(-1, 781), + tensor_format="BFCHW", +) +pipe.transformer.enable_cache(config) +``` + ### CacheMixin [[autodoc]] CacheMixin @@ -47,3 +74,9 @@ pipe.transformer.enable_cache(config) [[autodoc]] PyramidAttentionBroadcastConfig [[autodoc]] apply_pyramid_attention_broadcast + +### FasterCacheConfig + +[[autodoc]] FasterCacheConfig + +[[autodoc]] apply_faster_cache From ea18eb68e9b0d092b951c124f0e714cfadd307df Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 28 Jan 2025 22:23:43 +0100 Subject: [PATCH 19/26] add fastercache to CacheMixin --- src/diffusers/hooks/faster_cache.py | 1 - .../hooks/pyramid_attention_broadcast.py | 4 +-- src/diffusers/models/cache_utils.py | 25 ++++++++++++++++--- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py index 3f54d305eb73..0860dfe3f949 100644 --- a/src/diffusers/hooks/faster_cache.py +++ b/src/diffusers/hooks/faster_cache.py @@ -30,7 +30,6 @@ _FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser" _FASTER_CACHE_BLOCK_HOOK = "faster_cache_block" _ATTENTION_CLASSES = (Attention, MochiAttention) - _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ( "blocks.*attn", "transformer_blocks.*attn", diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index de914b92b95b..020e5e4e4566 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -26,8 +26,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast" 
_ATTENTION_CLASSES = (Attention, MochiAttention) - _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks") _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) _CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") @@ -311,4 +311,4 @@ def _apply_pyramid_attention_broadcast_hook( """ registry = HookRegistry.check_if_exists_or_initialize(module) hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback) - registry.register_hook(hook, "pyramid_attention_broadcast") + registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index f2c621b3011a..79bd8dc0b254 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -24,6 +24,7 @@ class CacheMixin: Supported caching techniques: - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) + - [FasterCache](https://huggingface.co/papers/2410.19355) """ _cache_config = None @@ -59,17 +60,31 @@ def enable_cache(self, config) -> None: ``` """ - from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from ..hooks import ( + FasterCacheConfig, + PyramidAttentionBroadcastConfig, + apply_faster_cache, + apply_pyramid_attention_broadcast, + ) + + if self.is_cache_enabled: + raise ValueError( + f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first." + ) if isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) + elif isinstance(config, FasterCacheConfig): + apply_faster_cache(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK + from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -77,7 +92,11 @@ def disable_cache(self) -> None: if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry = HookRegistry.check_if_exists_or_initialize(self) - registry.remove_hook("pyramid_attention_broadcast", recurse=True) + registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, FasterCacheConfig): + registry = HookRegistry.check_if_exists_or_initialize(self) + registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True) + registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") From f92f45e95ff9a7527eb6ea162f1140cf885f8792 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 28 Jan 2025 23:13:52 +0100 Subject: [PATCH 20/26] update tests --- src/diffusers/hooks/faster_cache.py | 30 ++--- .../hooks/pyramid_attention_broadcast.py | 2 +- tests/pipelines/flux/test_pipeline_flux.py | 16 ++- .../hunyuan_video/test_hunyuan_video.py | 20 ++- tests/pipelines/latte/test_latte.py | 2 +- tests/pipelines/test_pipelines_common.py | 123 +++++++++--------- 6 files changed, 110 insertions(+), 83 deletions(-) diff --git 
a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py index 0860dfe3f949..9980af949140 100644 --- a/src/diffusers/hooks/faster_cache.py +++ b/src/diffusers/hooks/faster_cache.py @@ -31,11 +31,12 @@ _FASTER_CACHE_BLOCK_HOOK = "faster_cache_block" _ATTENTION_CLASSES = (Attention, MochiAttention) _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ( - "blocks.*attn", - "transformer_blocks.*attn", - "single_transformer_blocks.*attn", + "^blocks.*attn", + "^transformer_blocks.*attn", + "^single_transformer_blocks.*attn" ) -_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks.*attn",) +_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",) +_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = ( "hidden_states", "encoder_hidden_states", @@ -276,9 +277,10 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: self.state.iteration > 0 and is_within_timestep_range and self.state.iteration % self.unconditional_batch_skip_range != 0 + and not self.is_guidance_distilled ) - if should_skip_uncond and not self.is_guidance_distilled: + if should_skip_uncond: is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys()) if is_any_kwarg_uncond: logger.debug("FasterCache - Skipping unconditional branch computation") @@ -483,7 +485,7 @@ def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: def apply_faster_cache( module: torch.nn.Module, - config: Optional[FasterCacheConfig] = None, + config: FasterCacheConfig ) -> None: r""" Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline. @@ -515,10 +517,6 @@ def apply_faster_cache( ``` """ - if config is None: - logger.warning("No FasterCacheConfig provided. Using default configuration.") - config = FasterCacheConfig() - if config.attention_weight_callback is None: # If the user has not provided a weight callback, we default to 0.5 for all timesteps. # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but @@ -568,7 +566,8 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float: for name, submodule in module.named_modules(): if not isinstance(submodule, _ATTENTION_CLASSES): continue - _apply_faster_cache_on_attention_class(name, submodule, config) + if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS): + _apply_faster_cache_on_attention_class(name, submodule, config) def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None: @@ -590,13 +589,10 @@ def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: is_spatial_self_attention = ( any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers) and config.spatial_attention_block_skip_range is not None - and not module.is_cross_attention + and not getattr(module, "is_cross_attention", False) ) is_temporal_self_attention = ( - any( - f"{identifier}." 
in name or identifier == name - for identifier in config.temporal_attention_block_identifiers - ) + any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers) and config.temporal_attention_block_skip_range is not None and not module.is_cross_attention ) @@ -633,7 +629,7 @@ def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK) -# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/fastercache_sample_latte.py#L127C1-L143C39 +# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39 @torch.no_grad() def _split_low_high_freq(x): fft = torch.fft.fft2(x) diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index 020e5e4e4566..c815076795d6 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -177,7 +177,7 @@ def reset_state(self, module: torch.nn.Module) -> None: def apply_pyramid_attention_broadcast( module: torch.nn.Module, - config: PyramidAttentionBroadcastConfig, + config: PyramidAttentionBroadcastConfig ): r""" Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline. diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index fa8979c4ad52..56e241c4af4f 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -7,7 +7,13 @@ from huggingface_hub import hf_hub_download from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel +from diffusers import ( + AutoencoderKL, + FasterCacheConfig, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, + FluxTransformer2DModel, +) from diffusers.utils.testing_utils import ( nightly, numpy_cosine_similarity_distance, @@ -41,6 +47,14 @@ class FluxPipelineFastTests( test_xformers_attention = False test_layerwise_casting = True + faster_cache_config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 901), + unconditional_batch_skip_range=2, + attention_weight_callback=lambda _: 0.5, + is_guidance_distilled=True, + ) + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) transformer = FluxTransformer2DModel( diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index 1ecfee666fcd..e5c5a705ee7b 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -21,6 +21,7 @@ from diffusers import ( AutoencoderKLHunyuanVideo, + FasterCacheConfig, FlowMatchEulerDiscreteScheduler, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, @@ -30,13 +31,20 @@ torch_device, ) -from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np +from ..test_pipelines_common import ( + FasterCacheTesterMixin, + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + to_np, +) enable_full_determinism() -class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, 
unittest.TestCase): +class HunyuanVideoPipelineFastTests( + PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase +): pipeline_class = HunyuanVideoPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) @@ -55,6 +63,14 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca test_xformers_attention = False test_layerwise_casting = True + faster_cache_config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 901), + unconditional_batch_skip_range=2, + attention_weight_callback=lambda _: 0.5, + is_guidance_distilled=True, + ) + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) transformer = HunyuanVideoTransformer3DModel( diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index be731813cd1b..1d0fcce04bc0 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -75,7 +75,7 @@ class LattePipelineFastTests( cross_attention_block_identifiers=["transformer_blocks"], ) - fastercache_config = FasterCacheConfig( + faster_cache_config = FasterCacheConfig( spatial_attention_block_skip_range=2, temporal_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(-1, 901), diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 6984779e98ad..8bd5dc7280c8 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -31,6 +31,7 @@ UNet2DConditionModel, apply_faster_cache, ) +from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin @@ -39,7 +40,6 @@ from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet from diffusers.models.unets.unet_motion_model import UNetMotionModel -from diffusers.pipelines.fastercache_utils import FasterCacheBlockHook, FasterCacheDenoiserHook from diffusers.pipelines.pipeline_utils import StableDiffusionMixin from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import logging @@ -2463,69 +2463,78 @@ def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2) class FasterCacheTesterMixin: - fastercache_config = FasterCacheConfig( + faster_cache_config = FasterCacheConfig( spatial_attention_block_skip_range=2, spatial_attention_timestep_skip_range=(-1, 901), unconditional_batch_skip_range=2, attention_weight_callback=lambda _: 0.5, ) - def test_fastercache_basic_warning_or_errors_raised(self): + def test_faster_cache_basic_warning_or_errors_raised(self): components = self.get_dummy_components() - logger = logging.get_logger("diffusers.pipelines.faster_cache_utils") + logger = logging.get_logger("diffusers.hooks.faster_cache") logger.setLevel(logging.INFO) - # Check if warning is raised when no FasterCacheConfig is provided - pipe = self.pipeline_class(**components) - with CaptureLogger(logger) as cap_logger: - apply_faster_cache(pipe) - self.assertTrue("No FasterCacheConfig provided" in cap_logger.out) - # Check if warning is raise when no attention_weight_callback is provided pipe 
= self.pipeline_class(**components) with CaptureLogger(logger) as cap_logger: config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None) - apply_faster_cache(pipe, config) + apply_faster_cache(pipe.transformer, config) self.assertTrue("No `attention_weight_callback` provided when enabling FasterCache" in cap_logger.out) # Check if error raised when unsupported tensor format used pipe = self.pipeline_class(**components) with self.assertRaises(ValueError): config = FasterCacheConfig(spatial_attention_block_skip_range=2, tensor_format="BFHWC") - apply_faster_cache(pipe, config) + apply_faster_cache(pipe.transformer, config) - def test_fastercache_inference(self, expected_atol: float = 0.1): + def test_faster_cache_inference(self, expected_atol: float = 0.1): device = "cpu" # ensure determinism for the device-dependent torch.Generator - num_layers = 2 - components = self.get_dummy_components(num_layers=num_layers) - pipe = self.pipeline_class(**components) - pipe = pipe.to(device) - pipe.set_progress_bar_config(disable=None) - inputs = self.get_dummy_inputs(device) - inputs["num_inference_steps"] = 4 - output = pipe(**inputs)[0] - original_image_slice = output.flatten() - original_image_slice = np.concatenate((original_image_slice[:8], original_image_slice[-8:])) + def create_pipe(): + torch.manual_seed(0) + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + return pipe(**inputs)[0] - apply_faster_cache(pipe, self.fastercache_config) + # Run inference without FasterCache + pipe = create_pipe() + output = run_forward(pipe).flatten() + original_image_slice = np.concatenate((output[:8], output[-8:])) - inputs = self.get_dummy_inputs(device) - inputs["num_inference_steps"] = 4 - output = pipe(**inputs)[0] - image_slice_fastercache_enabled = output.flatten() - image_slice_fastercache_enabled = np.concatenate( - (image_slice_fastercache_enabled[:8], image_slice_fastercache_enabled[-8:]) - ) + # Run inference with FasterCache enabled + self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep + pipe = create_pipe() + pipe.transformer.enable_cache(self.faster_cache_config) + output = run_forward(pipe).flatten().flatten() + image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:])) + + # Run inference with FasterCache disabled + pipe.transformer.disable_cache() + output = run_forward(pipe).flatten() + image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:])) assert np.allclose( - original_image_slice, image_slice_fastercache_enabled, atol=expected_atol + original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol ), "FasterCache outputs should not differ much in specified timestep range." + assert np.allclose( + original_image_slice, image_slice_faster_cache_disabled, atol=1e-4 + ), "Outputs from normal inference and after disabling cache should not differ." 
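
For readers following the test refactor above, here is a minimal usage sketch of the cache API it exercises. The model id, prompt, and sampling settings are illustrative placeholders; `FasterCacheConfig`, `current_timestep_callback`, and the `enable_cache`/`disable_cache` methods on the denoiser are the ones introduced in this PR and used by `FasterCacheTesterMixin`, not a separately documented public API.

```python
import torch

from diffusers import FasterCacheConfig, FluxPipeline

# Load a guidance-distilled pipeline (model id and generation settings are illustrative).
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 901),
    attention_weight_callback=lambda _: 0.5,
    # Flux has no separate unconditional branch, so it is marked as guidance-distilled
    # and the unconditional-batch skipping logic stays inactive.
    is_guidance_distilled=True,
    # Tells the hooks which denoising timestep the pipeline is currently on.
    current_timestep_callback=lambda: pipe.current_timestep,
)

# Attach the FasterCache hooks to the denoiser and run cached inference.
pipe.transformer.enable_cache(config)
image = pipe("A cat holding a sign that says hello world", num_inference_steps=28).images[0]

# Removing the hooks restores the original forward pass.
pipe.transformer.disable_cache()
```

Disabling the cache is expected to restore baseline behaviour, which is exactly what the new `atol=1e-4` assertion in the refactored test checks.
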
- def test_fastercache_state(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator + def test_faster_cache_state(self): + from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK + device = "cpu" # ensure determinism for the device-dependent torch.Generator num_layers = 0 num_single_layers = 0 dummy_component_kwargs = {} @@ -2541,22 +2550,24 @@ def test_fastercache_state(self): pipe = self.pipeline_class(**components) pipe.set_progress_bar_config(disable=None) - apply_faster_cache(pipe, self.fastercache_config) + self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep + pipe.transformer.enable_cache(self.faster_cache_config) expected_hooks = 0 - if self.fastercache_config.spatial_attention_block_skip_range is not None: + if self.faster_cache_config.spatial_attention_block_skip_range is not None: expected_hooks += num_layers + num_single_layers - if self.fastercache_config.temporal_attention_block_skip_range is not None: + if self.faster_cache_config.temporal_attention_block_skip_range is not None: expected_hooks += num_layers + num_single_layers - # Check if fastercache denoiser hook is attached + # Check if faster_cache denoiser hook is attached denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet self.assertTrue( - hasattr(denoiser, "_diffusers_hook") and isinstance(denoiser._diffusers_hook, FasterCacheDenoiserHook), + hasattr(denoiser, "_diffusers_hook") + and isinstance(denoiser._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK), FasterCacheDenoiserHook), "Hook should be of type FasterCacheDenoiserHook.", ) - # Check if all blocks have fastercache block hook attached + # Check if all blocks have faster_cache block hook attached count = 0 for name, module in denoiser.named_modules(): if hasattr(module, "_diffusers_hook"): @@ -2565,38 +2576,32 @@ def test_fastercache_state(self): continue count += 1 self.assertTrue( - isinstance(module._diffusers_hook, FasterCacheBlockHook), + isinstance(module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK), FasterCacheBlockHook), "Hook should be of type FasterCacheBlockHook.", ) self.assertEqual(count, expected_hooks, "Number of hooks should match expected number.") # Perform inference to ensure that states are updated correctly - def fastercache_state_check_callback(pipe, i, t, kwargs): + def faster_cache_state_check_callback(pipe, i, t, kwargs): for name, module in denoiser.named_modules(): if not hasattr(module, "_diffusers_hook"): continue - - state = module._fastercache_state - if name == "": # Root denoiser module - self.assertTrue(state.low_frequency_delta is not None, "Low frequency delta should be set.") - self.assertTrue(state.high_frequency_delta is not None, "High frequency delta should be set.") + state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state + if not self.faster_cache_config.is_guidance_distilled: + self.assertTrue(state.low_frequency_delta is not None, "Low frequency delta should be set.") + self.assertTrue(state.high_frequency_delta is not None, "High frequency delta should be set.") else: # Internal blocks + state = module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK).state self.assertTrue(state.cache is not None and len(state.cache) == 2, "Cache should be set.") - self.assertTrue(state.iteration == i + 1, "Hook iteration state should have updated during inference.") - self.assertTrue( - state.is_guidance_distilled is not None, - "`is_guidance_distilled` should be set to either 
True or False.", - ) - return {} inputs = self.get_dummy_inputs(device) inputs["num_inference_steps"] = 4 - inputs["callback_on_step_end"] = fastercache_state_check_callback + inputs["callback_on_step_end"] = faster_cache_state_check_callback _ = pipe(**inputs)[0] # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states @@ -2604,23 +2609,19 @@ def fastercache_state_check_callback(pipe, i, t, kwargs): if not hasattr(module, "_diffusers_hook"): continue - state = module._fastercache_state if name == "": # Root denoiser module + state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.") self.assertTrue(state.low_frequency_delta is None, "Low frequency delta should be reset to None.") self.assertTrue(state.high_frequency_delta is None, "High frequency delta should be reset to None.") - self.assertTrue( - state.is_guidance_distilled is None, "`is_guidance_distilled` should be reset to None." - ) else: + # Internal blocks + state = module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK).state self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.") self.assertTrue(state.batch_size is None, "Batch size should be reset to None.") self.assertTrue(state.cache is None, "Cache should be reset to None.") - self.assertTrue( - state.is_guidance_distilled is None, "`is_guidance_distilled` should be reset to None." - ) # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. From 251ade15c9c7388e60be7be0b262e2f449f4af58 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 29 Jan 2025 03:47:17 +0530 Subject: [PATCH 21/26] Apply suggestions from code review --- src/diffusers/models/embeddings.py | 2 -- src/diffusers/pipelines/latte/pipeline_latte.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index bbd8425e4d80..c0646437e54c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -334,8 +334,6 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"): " `from_numpy` is no longer required." " Pass `output_type='pt' to use the new version now." ) - # TODO: Needs to be handled or errors out. Updated to 0.34.0 so that the benchmark code - # runs without issues, but this should be handled properly before merge. deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False) return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos) if embed_dim % 2 != 0: diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 6ec3eaf65005..e9a95e8be45c 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -858,7 +858,7 @@ def __call__( self._current_timestep = None - if output_type == "latent": + if output_type == "latents": deprecation_message = ( "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead." 
) From fa9a1f3050d0b8e07e3ba51827d07a3bbcab40f1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 28 Jan 2025 23:17:53 +0100 Subject: [PATCH 22/26] make style --- src/diffusers/hooks/faster_cache.py | 7 ++----- src/diffusers/hooks/pyramid_attention_broadcast.py | 5 +---- tests/pipelines/test_pipelines_common.py | 1 - 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py index 9980af949140..722eb6832589 100644 --- a/src/diffusers/hooks/faster_cache.py +++ b/src/diffusers/hooks/faster_cache.py @@ -33,7 +33,7 @@ _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ( "^blocks.*attn", "^transformer_blocks.*attn", - "^single_transformer_blocks.*attn" + "^single_transformer_blocks.*attn", ) _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",) _TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS @@ -483,10 +483,7 @@ def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: return module -def apply_faster_cache( - module: torch.nn.Module, - config: FasterCacheConfig -) -> None: +def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None: r""" Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline. diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index c815076795d6..5d50f4b816c1 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -175,10 +175,7 @@ def reset_state(self, module: torch.nn.Module) -> None: return module -def apply_pyramid_attention_broadcast( - module: torch.nn.Module, - config: PyramidAttentionBroadcastConfig -): +def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig): r""" Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline. diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 8bd5dc7280c8..29049df357cd 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2609,7 +2609,6 @@ def faster_cache_state_check_callback(pipe, i, t, kwargs): if not hasattr(module, "_diffusers_hook"): continue - if name == "": # Root denoiser module state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state From 7ad7cc89c11787c849a4a7872ff34161ed7eb6b2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 28 Jan 2025 23:55:27 +0100 Subject: [PATCH 23/26] try to fix partial import error --- src/diffusers/models/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 3ef40ffb5783..57767a7b62de 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -33,7 +33,6 @@ from torch import Tensor, nn from .. import __version__ -from ..hooks import apply_layerwise_casting from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -402,6 +401,7 @@ def enable_layerwise_casting( non_blocking (`bool`, *optional*, defaults to `False`): If `True`, the weight casting operations are non-blocking. 
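
The "try to fix partial import error" commit here moves the `apply_layerwise_casting` import from module scope into the method body, the usual way to break an import cycle between `diffusers.models.modeling_utils` and `diffusers.hooks` during package initialization. A hedged sketch of calling the patched method follows; the checkpoint id is illustrative and the keyword arguments are assumed from the surrounding docstring rather than confirmed by this diff.

```python
import torch

from diffusers import FluxTransformer2DModel

# enable_layerwise_casting performs `from ..hooks import apply_layerwise_casting`
# inside the method body (as in the hunk above) instead of at import time, so the
# hooks package is fully initialized before it is touched and the circular-import
# error no longer occurs. Behaviour for callers is unchanged.
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
model.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
```
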
""" + from ..hooks import apply_layerwise_casting user_provided_patterns = True if skip_modules_pattern is None: From a20e8462eb4fe5e51cbcd9f578f0d820a4526528 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 19 Mar 2025 05:17:11 +0000 Subject: [PATCH 24/26] Apply style fixes --- src/diffusers/models/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index fefed4b6da1e..2b9c490d788c 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -37,7 +37,7 @@ from typing_extensions import Self from .. import __version__ -from ..hooks import apply_group_offloading, apply_layerwise_casting +from ..hooks import apply_group_offloading from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( From 2a342157068e5b0827f70c480508b2fd9482424d Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 21 Mar 2025 03:04:11 +0100 Subject: [PATCH 25/26] raise warning --- src/diffusers/hooks/faster_cache.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py index 722eb6832589..634635346474 100644 --- a/src/diffusers/hooks/faster_cache.py +++ b/src/diffusers/hooks/faster_cache.py @@ -514,6 +514,12 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No ``` """ + logger.warning( + "FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. " + "The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at " + "https://github.com/huggingface/diffusers/issues." + ) + if config.attention_weight_callback is None: # If the user has not provided a weight callback, we default to 0.5 for all timesteps. # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but From 4a4bab8740bd0569c0b7194c79654cc1036cf06e Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 21 Mar 2025 04:02:37 +0100 Subject: [PATCH 26/26] update --- src/diffusers/models/modeling_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 8497f4dddfd3..be1ad1420a3e 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -37,7 +37,6 @@ from typing_extensions import Self from .. import __version__ -from ..hooks import apply_group_offloading from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -571,6 +570,8 @@ def enable_group_offload( ... ) ``` """ + from ..hooks import apply_group_offloading + if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream: msg = ( "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "