diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index 403dbf88b431..a6aa5445a845 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -38,6 +38,33 @@ config = PyramidAttentionBroadcastConfig( pipe.transformer.enable_cache(config) ``` +## Faster Cache + +[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong. + +FasterCache is a method that speeds up inference in diffusion transformers by: +- Reusing attention states between successive inference steps, due to high similarity between them +- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output + +```python +import torch +from diffusers import CogVideoXPipeline, FasterCacheConfig + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 681), + current_timestep_callback=lambda: pipe.current_timestep, + attention_weight_callback=lambda _: 0.3, + unconditional_batch_skip_range=5, + unconditional_batch_timestep_skip_range=(-1, 781), + tensor_format="BFCHW", +) +pipe.transformer.enable_cache(config) +``` + ### CacheMixin [[autodoc]] CacheMixin @@ -47,3 +74,9 @@ pipe.transformer.enable_cache(config) [[autodoc]] PyramidAttentionBroadcastConfig [[autodoc]] apply_pyramid_attention_broadcast + +### FasterCacheConfig + +[[autodoc]] FasterCacheConfig + +[[autodoc]] apply_faster_cache diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ad658f1b14ff..bc0f3eca3623 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -131,8 +131,10 @@ else: _import_structure["hooks"].extend( [ + "FasterCacheConfig", "HookRegistry", "PyramidAttentionBroadcastConfig", + "apply_faster_cache", "apply_pyramid_attention_broadcast", ] ) @@ -703,7 +705,13 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: - from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from .hooks import ( + FasterCacheConfig, + HookRegistry, + PyramidAttentionBroadcastConfig, + apply_faster_cache, + apply_pyramid_attention_broadcast, + ) from .models import ( AllegroTransformer3DModel, AsymmetricAutoencoderKL, diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 56be0bbdf305..764ceb25b465 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -2,6 +2,7 @@ if is_torch_available(): + from .faster_cache import FasterCacheConfig, apply_faster_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py new file mode 100644 index 000000000000..634635346474 --- /dev/null +++ b/src/diffusers/hooks/faster_cache.py @@ -0,0 +1,653 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from dataclasses import dataclass +from typing import Any, Callable, List, Optional, Tuple + +import torch + +from ..models.attention_processor import Attention, MochiAttention +from ..models.modeling_outputs import Transformer2DModelOutput +from ..utils import logging +from .hooks import HookRegistry, ModelHook + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser" +_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block" +_ATTENTION_CLASSES = (Attention, MochiAttention) +_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ( + "^blocks.*attn", + "^transformer_blocks.*attn", + "^single_transformer_blocks.*attn", +) +_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",) +_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS +_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = ( + "hidden_states", + "encoder_hidden_states", + "timestep", + "attention_mask", + "encoder_attention_mask", +) + + +@dataclass +class FasterCacheConfig: + r""" + Configuration for [FasterCache](https://huggingface.co/papers/2410.19355). + + Attributes: + spatial_attention_block_skip_range (`int`, defaults to `2`): + Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will + be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention + states again. + temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`): + Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will + be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention + states again. + spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`): + The timestep range within which the spatial attention computation can be skipped without a significant loss + in quality. This is to be determined by the user based on the underlying model. The first value in the + tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for + denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at + timestep 0). For the default values, this would mean that the spatial attention computation skipping will + be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising + process. + temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`): + The timestep range within which the temporal attention computation can be skipped without a significant + loss in quality. This is to be determined by the user based on the underlying model. The first value in the + tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for + denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at + timestep 0). 
+        low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
+            The timestep range within which the low frequency weight scaling update is applied. The first value in the
+            tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
+            function for the update is called only within this range.
+        high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
+            The timestep range within which the high frequency weight scaling update is applied. The first value in the
+            tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
+            function for the update is called only within this range.
+        alpha_low_frequency (`float`, defaults to `1.1`):
+            The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
+            the conditional branch outputs.
+        alpha_high_frequency (`float`, defaults to `1.1`):
+            The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
+            from the conditional branch outputs.
+        unconditional_batch_skip_range (`int`, defaults to `5`):
+            Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
+            computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before
+            computing the new unconditional branch states again.
+        unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
+            The timestep range within which the unconditional branch computation can be skipped without a significant
+            loss in quality. This is to be determined by the user based on the underlying model. The first value in the
+            tuple is the lower bound and the second value is the upper bound.
+        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("^blocks.*attn", "^transformer_blocks.*attn", "^single_transformer_blocks.*attn")`):
+            The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
+            of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
+            partial layer names, or regex patterns. Matching will always be done using a regex match.
+        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("^temporal_transformer_blocks.*attn",)`):
+            The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
+            of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
+            partial layer names, or regex patterns. Matching will always be done using a regex match.
+        attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
+            The callback function to determine the weight to scale the attention outputs by. This function should take
+            the attention module as input and return a float value. This is used to approximate the unconditional
+            branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
+            Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
+            progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
+            the number of inference steps and underlying model behaviour as denoising progresses.
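+            As a purely illustrative example (not prescribed by the paper), a linearly increasing schedule could be
+            written as `attention_weight_callback=lambda _: 1.0 - pipe.current_timestep / 1000`, assuming 1000
+            training timesteps and a `pipe` variable in scope when the config is created.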
+ low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`): + The callback function to determine the weight to scale the low frequency updates by. If not provided, the + default weight is 1.1 for timesteps within the range specified (as described in the paper). + high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`): + The callback function to determine the weight to scale the high frequency updates by. If not provided, the + default weight is 1.1 for timesteps within the range specified (as described in the paper). + tensor_format (`str`, defaults to `"BCFHW"`): + The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is + used to split individual latent frames in order for low and high frequency components to be computed. + is_guidance_distilled (`bool`, defaults to `False`): + Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be + applied at the denoiser-level to skip the unconditional branch computation (as there is none). + _unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`): + The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and + conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will + split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs + names that contain the batchwise-concatenated unconditional and conditional inputs. + """ + + # In the paper and codebase, they hardcode these values to 2. However, it can be made configurable + # after some testing. We default to 2 if these parameters are not provided. + spatial_attention_block_skip_range: int = 2 + temporal_attention_block_skip_range: Optional[int] = None + + spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) + temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) + + # Indicator functions for low/high frequency as mentioned in Equation 11 of the paper + low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901) + high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301) + + # ⍺1 and ⍺2 as mentioned in Equation 11 of the paper + alpha_low_frequency: float = 1.1 + alpha_high_frequency: float = 1.1 + + # n as described in CFG-Cache explanation in the paper - dependant on the model + unconditional_batch_skip_range: int = 5 + unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641) + + spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + temporal_attention_block_identifiers: Tuple[str, ...] 
= _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS + + attention_weight_callback: Callable[[torch.nn.Module], float] = None + low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None + high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None + + tensor_format: str = "BCFHW" + is_guidance_distilled: bool = False + + current_timestep_callback: Callable[[], int] = None + + _unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS + + def __repr__(self) -> str: + return ( + f"FasterCacheConfig(\n" + f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n" + f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n" + f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n" + f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n" + f" low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n" + f" high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n" + f" alpha_low_frequency={self.alpha_low_frequency},\n" + f" alpha_high_frequency={self.alpha_high_frequency},\n" + f" unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n" + f" unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n" + f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n" + f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n" + f" tensor_format={self.tensor_format},\n" + f")" + ) + + +class FasterCacheDenoiserState: + r""" + State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module. + """ + + def __init__(self) -> None: + self.iteration: int = 0 + self.low_frequency_delta: torch.Tensor = None + self.high_frequency_delta: torch.Tensor = None + + def reset(self): + self.iteration = 0 + self.low_frequency_delta = None + self.high_frequency_delta = None + + +class FasterCacheBlockState: + r""" + State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is + applied to will have an instance of this state. + """ + + def __init__(self) -> None: + self.iteration: int = 0 + self.batch_size: int = None + self.cache: Tuple[torch.Tensor, torch.Tensor] = None + + def reset(self): + self.iteration = 0 + self.batch_size = None + self.cache = None + + +class FasterCacheDenoiserHook(ModelHook): + _is_stateful = True + + def __init__( + self, + unconditional_batch_skip_range: int, + unconditional_batch_timestep_skip_range: Tuple[int, int], + tensor_format: str, + is_guidance_distilled: bool, + uncond_cond_input_kwargs_identifiers: List[str], + current_timestep_callback: Callable[[], int], + low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor], + high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor], + ) -> None: + super().__init__() + + self.unconditional_batch_skip_range = unconditional_batch_skip_range + self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range + # We can't easily detect what args are to be split in unconditional and conditional branches. We + # can only do it for kwargs, hence they are the only ones we split. The args are passed as-is. 
+ # If a model is to be made compatible with FasterCache, the user must ensure that the inputs that + # contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs. + self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers + self.tensor_format = tensor_format + self.is_guidance_distilled = is_guidance_distilled + + self.current_timestep_callback = current_timestep_callback + self.low_frequency_weight_callback = low_frequency_weight_callback + self.high_frequency_weight_callback = high_frequency_weight_callback + + def initialize_hook(self, module): + self.state = FasterCacheDenoiserState() + return module + + @staticmethod + def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs + # followed by conditional inputs. + _, cond = input.chunk(2, dim=0) + return cond + + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: + # Split the unconditional and conditional inputs. We only want to infer the conditional branch if the + # requirements for skipping the unconditional branch are met as described in the paper. + # We skip the unconditional branch only if the following conditions are met: + # 1. We have completed at least one iteration of the denoiser + # 2. The current timestep is within the range specified by the user. This is the optimal timestep range + # where approximating the unconditional branch from the computation of the conditional branch is possible + # without a significant loss in quality. + # 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that + # we compute the unconditional branch at least once every few iterations to ensure minimal quality loss. 
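+        # As a concrete illustration of the conditions above: with the default `unconditional_batch_skip_range=5`,
+        # the unconditional branch is actually computed on iterations 0, 5, 10, ... (and on every iteration whose
+        # timestep falls outside `unconditional_batch_timestep_skip_range`); all other iterations within that range
+        # approximate it from the conditional branch output.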
+ is_within_timestep_range = ( + self.unconditional_batch_timestep_skip_range[0] + < self.current_timestep_callback() + < self.unconditional_batch_timestep_skip_range[1] + ) + should_skip_uncond = ( + self.state.iteration > 0 + and is_within_timestep_range + and self.state.iteration % self.unconditional_batch_skip_range != 0 + and not self.is_guidance_distilled + ) + + if should_skip_uncond: + is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys()) + if is_any_kwarg_uncond: + logger.debug("FasterCache - Skipping unconditional branch computation") + args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args]) + kwargs = { + k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v) + for k, v in kwargs.items() + } + + output = self.fn_ref.original_forward(*args, **kwargs) + + if self.is_guidance_distilled: + self.state.iteration += 1 + return output + + if torch.is_tensor(output): + hidden_states = output + elif isinstance(output, (tuple, Transformer2DModelOutput)): + hidden_states = output[0] + + batch_size = hidden_states.size(0) + + if should_skip_uncond: + self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback( + module + ) + self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback( + module + ) + + if self.tensor_format == "BCFHW": + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW": + hidden_states = hidden_states.flatten(0, 1) + + low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float()) + + # Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper + low_freq_uncond = self.state.low_frequency_delta + low_freq_cond + high_freq_uncond = self.state.high_frequency_delta + high_freq_cond + uncond_freq = low_freq_uncond + high_freq_uncond + + uncond_states = torch.fft.ifftshift(uncond_freq) + uncond_states = torch.fft.ifft2(uncond_states).real + + if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW": + uncond_states = uncond_states.unflatten(0, (batch_size, -1)) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)) + if self.tensor_format == "BCFHW": + uncond_states = uncond_states.permute(0, 2, 1, 3, 4) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + # Concatenate the approximated unconditional and predicted conditional branches + uncond_states = uncond_states.to(hidden_states.dtype) + hidden_states = torch.cat([uncond_states, hidden_states], dim=0) + else: + uncond_states, cond_states = hidden_states.chunk(2, dim=0) + if self.tensor_format == "BCFHW": + uncond_states = uncond_states.permute(0, 2, 1, 3, 4) + cond_states = cond_states.permute(0, 2, 1, 3, 4) + if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW": + uncond_states = uncond_states.flatten(0, 1) + cond_states = cond_states.flatten(0, 1) + + low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float()) + low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float()) + self.state.low_frequency_delta = low_freq_uncond - low_freq_cond + self.state.high_frequency_delta = high_freq_uncond - high_freq_cond + + self.state.iteration += 1 + if torch.is_tensor(output): + output = hidden_states + elif isinstance(output, tuple): + output = (hidden_states, *output[1:]) + else: + output.sample = hidden_states + + return output + + def 
reset_state(self, module: torch.nn.Module) -> torch.nn.Module: + self.state.reset() + return module + + +class FasterCacheBlockHook(ModelHook): + _is_stateful = True + + def __init__( + self, + block_skip_range: int, + timestep_skip_range: Tuple[int, int], + is_guidance_distilled: bool, + weight_callback: Callable[[torch.nn.Module], float], + current_timestep_callback: Callable[[], int], + ) -> None: + super().__init__() + + self.block_skip_range = block_skip_range + self.timestep_skip_range = timestep_skip_range + self.is_guidance_distilled = is_guidance_distilled + + self.weight_callback = weight_callback + self.current_timestep_callback = current_timestep_callback + + def initialize_hook(self, module): + self.state = FasterCacheBlockState() + return module + + def _compute_approximated_attention_output( + self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int + ) -> torch.Tensor: + if t_2_output.size(0) != batch_size: + # The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just + # take the conditional branch outputs. + assert t_2_output.size(0) == 2 * batch_size + t_2_output = t_2_output[batch_size:] + if t_output.size(0) != batch_size: + # The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just + # take the conditional branch outputs. + assert t_output.size(0) == 2 * batch_size + t_output = t_output[batch_size:] + return t_output + (t_output - t_2_output) * weight + + def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: + batch_size = [ + *[arg.size(0) for arg in args if torch.is_tensor(arg)], + *[v.size(0) for v in kwargs.values() if torch.is_tensor(v)], + ][0] + if self.state.batch_size is None: + # Will be updated on first forward pass through the denoiser + self.state.batch_size = batch_size + + # If we have to skip due to the skip conditions, then let's skip as expected. + # But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This + # is because the expected output shapes of attention layer will not match if we only return values from + # the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true + # unconditional-conditional batch size) is same as the current batch size, we don't perform the layer + # skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns. + is_within_timestep_range = ( + self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1] + ) + if not is_within_timestep_range: + should_skip_attention = False + else: + should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0 + should_skip_attention = not should_compute_attention + if should_skip_attention: + should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size + + if should_skip_attention: + logger.debug("FasterCache - Skipping attention and using approximation") + if torch.is_tensor(self.state.cache[-1]): + t_2_output, t_output = self.state.cache + weight = self.weight_callback(module) + output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size) + else: + # The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them. + # Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity. 
+                # In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
+                # a forward pass of the block. We need to compute the approximated output for each of these tensors.
+                # The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
+                # allows us to compute the approximated attention output for each tensor in the cache.
+                output = ()
+                for t_2_output, t_output in zip(*self.state.cache):
+                    result = self._compute_approximated_attention_output(
+                        t_2_output, t_output, self.weight_callback(module), batch_size
+                    )
+                    output += (result,)
+        else:
+            logger.debug("FasterCache - Computing attention")
+            output = self.fn_ref.original_forward(*args, **kwargs)
+
+        # Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
+        # a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
+        # both cases.
+        if torch.is_tensor(output):
+            cache_output = output
+            if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
+                # The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
+                # This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
+                cache_output = cache_output.chunk(2, dim=0)[1]
+        else:
+            # Cache all return values and perform the same operation as above
+            cache_output = ()
+            for out in output:
+                if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
+                    out = out.chunk(2, dim=0)[1]
+                cache_output += (out,)
+
+        if self.state.cache is None:
+            self.state.cache = [cache_output, cache_output]
+        else:
+            self.state.cache = [self.state.cache[-1], cache_output]
+
+        self.state.iteration += 1
+        return output
+
+    def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
+        self.state.reset()
+        return module
+
+
+def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
+    r"""
+    Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given module.
+
+    Args:
+        module (`torch.nn.Module`):
+            The denoiser module (typically the pipeline's transformer) to apply FasterCache to.
+        config (`FasterCacheConfig`):
+            The configuration to use for FasterCache.
+
+    Example:
+    ```python
+    >>> import torch
+    >>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache
+
+    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+    >>> pipe.to("cuda")
+
+    >>> config = FasterCacheConfig(
+    ...     spatial_attention_block_skip_range=2,
+    ...     spatial_attention_timestep_skip_range=(-1, 681),
+    ...     low_frequency_weight_update_timestep_range=(99, 641),
+    ...     high_frequency_weight_update_timestep_range=(-1, 301),
+    ...     spatial_attention_block_identifiers=["transformer_blocks"],
+    ...     attention_weight_callback=lambda _: 0.3,
+    ...     current_timestep_callback=lambda: pipe.current_timestep,
+    ...     tensor_format="BFCHW",
+    ... )
+    >>> apply_faster_cache(pipe.transformer, config)
+    ```
+    """
+
+    logger.warning(
+        "FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
+        "The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
+        "https://github.com/huggingface/diffusers/issues."
+    )
+
+    if config.attention_weight_callback is None:
+        # If the user has not provided a weight callback, we default to 0.5 for all timesteps.
+ # In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but + # this depends from model-to-model. It is required by the user to provide a weight callback if they want to + # use a different weight function. Defaulting to 0.5 works well in practice for most cases. + logger.warning( + "No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps." + ) + config.attention_weight_callback = lambda _: 0.5 + + if config.low_frequency_weight_callback is None: + logger.debug( + "Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." + ) + + def low_frequency_weight_callback(module: torch.nn.Module) -> float: + is_within_range = ( + config.low_frequency_weight_update_timestep_range[0] + < config.current_timestep_callback() + < config.low_frequency_weight_update_timestep_range[1] + ) + return config.alpha_low_frequency if is_within_range else 1.0 + + config.low_frequency_weight_callback = low_frequency_weight_callback + + if config.high_frequency_weight_callback is None: + logger.debug( + "High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." + ) + + def high_frequency_weight_callback(module: torch.nn.Module) -> float: + is_within_range = ( + config.high_frequency_weight_update_timestep_range[0] + < config.current_timestep_callback() + < config.high_frequency_weight_update_timestep_range[1] + ) + return config.alpha_high_frequency if is_within_range else 1.0 + + config.high_frequency_weight_callback = high_frequency_weight_callback + + supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] # TODO(aryan): Support BSC for LTX Video + if config.tensor_format not in supported_tensor_formats: + raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.") + + _apply_faster_cache_on_denoiser(module, config) + + for name, submodule in module.named_modules(): + if not isinstance(submodule, _ATTENTION_CLASSES): + continue + if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS): + _apply_faster_cache_on_attention_class(name, submodule, config) + + +def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None: + hook = FasterCacheDenoiserHook( + config.unconditional_batch_skip_range, + config.unconditional_batch_timestep_skip_range, + config.tensor_format, + config.is_guidance_distilled, + config._unconditional_conditional_input_kwargs_identifiers, + config.current_timestep_callback, + config.low_frequency_weight_callback, + config.high_frequency_weight_callback, + ) + registry = HookRegistry.check_if_exists_or_initialize(module) + registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK) + + +def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None: + is_spatial_self_attention = ( + any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers) + and config.spatial_attention_block_skip_range is not None + and not getattr(module, "is_cross_attention", False) + ) + is_temporal_self_attention = ( + any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers) + and config.temporal_attention_block_skip_range is not None + and not module.is_cross_attention + ) + + block_skip_range, timestep_skip_range, block_type = 
None, None, None + if is_spatial_self_attention: + block_skip_range = config.spatial_attention_block_skip_range + timestep_skip_range = config.spatial_attention_timestep_skip_range + block_type = "spatial" + elif is_temporal_self_attention: + block_skip_range = config.temporal_attention_block_skip_range + timestep_skip_range = config.temporal_attention_timestep_skip_range + block_type = "temporal" + + if block_skip_range is None or timestep_skip_range is None: + logger.debug( + f'Unable to apply FasterCache to the selected layer: "{name}" because it does ' + f"not match any of the required criteria for spatial or temporal attention layers. Note, " + f"however, that this layer may still be valid for applying PAB. Please specify the correct " + f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` " + f"function to apply FasterCache to this layer." + ) + return + + logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}") + hook = FasterCacheBlockHook( + block_skip_range, + timestep_skip_range, + config.is_guidance_distilled, + config.attention_weight_callback, + config.current_timestep_callback, + ) + registry = HookRegistry.check_if_exists_or_initialize(module) + registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK) + + +# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39 +@torch.no_grad() +def _split_low_high_freq(x): + fft = torch.fft.fft2(x) + fft_shifted = torch.fft.fftshift(fft) + height, width = x.shape[-2:] + radius = min(height, width) // 5 + + y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width)) + center_x, center_y = width // 2, height // 2 + mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2 + + low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device) + high_freq_mask = ~low_freq_mask + + low_freq_fft = fft_shifted * low_freq_mask + high_freq_fft = fft_shifted * high_freq_mask + + return low_freq_fft, high_freq_fft diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py index 9f8597d52f8c..5d50f4b816c1 100644 --- a/src/diffusers/hooks/pyramid_attention_broadcast.py +++ b/src/diffusers/hooks/pyramid_attention_broadcast.py @@ -26,8 +26,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast" _ATTENTION_CLASSES = (Attention, MochiAttention) - _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks") _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) _CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks") @@ -87,7 +87,7 @@ class PyramidAttentionBroadcastConfig: def __repr__(self) -> str: return ( - f"PyramidAttentionBroadcastConfig(" + f"PyramidAttentionBroadcastConfig(\n" f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n" f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n" f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n" @@ -175,10 +175,7 @@ def reset_state(self, module: torch.nn.Module) -> None: return module -def apply_pyramid_attention_broadcast( - module: torch.nn.Module, - config: PyramidAttentionBroadcastConfig, -): +def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig): r""" Apply [Pyramid Attention 
Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline. @@ -311,4 +308,4 @@ def _apply_pyramid_attention_broadcast_hook( """ registry = HookRegistry.check_if_exists_or_initialize(module) hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback) - registry.register_hook(hook, "pyramid_attention_broadcast") + registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index f2c621b3011a..79bd8dc0b254 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -24,6 +24,7 @@ class CacheMixin: Supported caching techniques: - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) + - [FasterCache](https://huggingface.co/papers/2410.19355) """ _cache_config = None @@ -59,17 +60,31 @@ def enable_cache(self, config) -> None: ``` """ - from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast + from ..hooks import ( + FasterCacheConfig, + PyramidAttentionBroadcastConfig, + apply_faster_cache, + apply_pyramid_attention_broadcast, + ) + + if self.is_cache_enabled: + raise ValueError( + f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first." + ) if isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) + elif isinstance(config, FasterCacheConfig): + apply_faster_cache(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK + from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -77,7 +92,11 @@ def disable_cache(self) -> None: if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry = HookRegistry.check_if_exists_or_initialize(self) - registry.remove_hook("pyramid_attention_broadcast", recurse=True) + registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, FasterCacheConfig): + registry = HookRegistry.check_if_exists_or_initialize(self) + registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True) + registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 006ea8b4013f..b1e14ca6a7fe 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -336,7 +336,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"): " `from_numpy` is no longer required." " Pass `output_type='pt' to use the new version now." 
) - deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False) return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos) if embed_dim % 2 != 0: raise ValueError("embed_dim must be divisible by 2") diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 351ce7b1772c..be1ad1420a3e 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -37,7 +37,6 @@ from typing_extensions import Self from .. import __version__ -from ..hooks import apply_group_offloading, apply_layerwise_casting from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( @@ -504,6 +503,7 @@ def enable_layerwise_casting( non_blocking (`bool`, *optional*, defaults to `False`): If `True`, the weight casting operations are non-blocking. """ + from ..hooks import apply_layerwise_casting user_provided_patterns = True if skip_modules_pattern is None: @@ -570,6 +570,8 @@ def enable_group_offload( ... ) ``` """ + from ..hooks import apply_group_offloading + if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream: msg = ( "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first " diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 578f373e8e3f..e9a95e8be45c 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -817,7 +817,7 @@ def __call__( # predict noise model_output noise_pred = self.transformer( - latent_model_input, + hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=current_timestep, enable_temporal_attentions=enable_temporal_attentions, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 31d2e1e2d78d..3f443b5b40bf 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class FasterCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HookRegistry(metaclass=DummyObject): _backends = ["torch"] @@ -32,6 +47,10 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +def apply_faster_cache(*args, **kwargs): + requires_backends(apply_faster_cache, ["torch"]) + + def apply_pyramid_attention_broadcast(*args, **kwargs): requires_backends(apply_pyramid_attention_broadcast, ["torch"]) diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index c09b00e1d16b..388dc9ef7ec4 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -31,6 +31,7 @@ from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( + FasterCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, check_qkv_fusion_matches_attn_procs_length, @@ -42,7 +43,9 @@ 
enable_full_determinism() -class CogVideoXPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): +class CogVideoXPipelineFastTests( + PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase +): pipeline_class = CogVideoXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index e878216d1bab..6a560367a5b8 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -7,7 +7,13 @@ from huggingface_hub import hf_hub_download from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel +from diffusers import ( + AutoencoderKL, + FasterCacheConfig, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, + FluxTransformer2DModel, +) from diffusers.utils.testing_utils import ( backend_empty_cache, nightly, @@ -18,6 +24,7 @@ ) from ..test_pipelines_common import ( + FasterCacheTesterMixin, FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, @@ -27,7 +34,11 @@ class FluxPipelineFastTests( - unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin + unittest.TestCase, + PipelineTesterMixin, + FluxIPAdapterTesterMixin, + PyramidAttentionBroadcastTesterMixin, + FasterCacheTesterMixin, ): pipeline_class = FluxPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) @@ -38,6 +49,14 @@ class FluxPipelineFastTests( test_layerwise_casting = True test_group_offloading = True + faster_cache_config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 901), + unconditional_batch_skip_range=2, + attention_weight_callback=lambda _: 0.5, + is_guidance_distilled=True, + ) + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) transformer = FluxTransformer2DModel( diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index dd0f6437df87..aa4f045966c3 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -21,6 +21,7 @@ from diffusers import ( AutoencoderKLHunyuanVideo, + FasterCacheConfig, FlowMatchEulerDiscreteScheduler, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, @@ -30,13 +31,20 @@ torch_device, ) -from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np +from ..test_pipelines_common import ( + FasterCacheTesterMixin, + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + to_np, +) enable_full_determinism() -class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): +class HunyuanVideoPipelineFastTests( + PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase +): pipeline_class = HunyuanVideoPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) @@ -56,6 +64,14 @@ class HunyuanVideoPipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadca 
test_layerwise_casting = True test_group_offloading = True + faster_cache_config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 901), + unconditional_batch_skip_range=2, + attention_weight_callback=lambda _: 0.5, + is_guidance_distilled=True, + ) + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): torch.manual_seed(0) transformer = HunyuanVideoTransformer3DModel( diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py index 7530f06d9d18..80d370647f57 100644 --- a/tests/pipelines/latte/test_latte.py +++ b/tests/pipelines/latte/test_latte.py @@ -25,6 +25,7 @@ from diffusers import ( AutoencoderKL, DDIMScheduler, + FasterCacheConfig, LattePipeline, LatteTransformer3DModel, PyramidAttentionBroadcastConfig, @@ -40,13 +41,20 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np +from ..test_pipelines_common import ( + FasterCacheTesterMixin, + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + to_np, +) enable_full_determinism() -class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase): +class LattePipelineFastTests( + PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase +): pipeline_class = LattePipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -69,6 +77,15 @@ class LattePipelineFastTests(PipelineTesterMixin, PyramidAttentionBroadcastTeste cross_attention_block_identifiers=["transformer_blocks"], ) + faster_cache_config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + temporal_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 901), + temporal_attention_timestep_skip_range=(-1, 901), + unconditional_batch_skip_range=2, + attention_weight_callback=lambda _: 0.5, + ) + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = LatteTransformer3DModel( diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index 32d09155cdeb..ea2d015af52a 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -33,13 +33,13 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np enable_full_determinism() -class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase): pipeline_class = MochiPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -59,13 +59,13 @@ class MochiPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_layerwise_casting = True test_group_offloading = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 2): torch.manual_seed(0) transformer = MochiTransformer3DModel( patch_size=2, num_attention_heads=2, attention_head_dim=8, - num_layers=2, + num_layers=num_layers, pooled_projection_dim=16, in_channels=12, out_channels=None, diff --git a/tests/pipelines/test_pipelines_common.py 
b/tests/pipelines/test_pipelines_common.py index d965a4090d72..d069def66ecf 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -23,13 +23,16 @@ ConsistencyDecoderVAE, DDIMScheduler, DiffusionPipeline, + FasterCacheConfig, KolorsPipeline, PyramidAttentionBroadcastConfig, StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel, + apply_faster_cache, ) from diffusers.hooks import apply_group_offloading +from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin @@ -2551,6 +2554,167 @@ def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2) ), "Outputs from normal inference and after disabling cache should not differ." +class FasterCacheTesterMixin: + faster_cache_config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 901), + unconditional_batch_skip_range=2, + attention_weight_callback=lambda _: 0.5, + ) + + def test_faster_cache_basic_warning_or_errors_raised(self): + components = self.get_dummy_components() + + logger = logging.get_logger("diffusers.hooks.faster_cache") + logger.setLevel(logging.INFO) + + # Check if warning is raise when no attention_weight_callback is provided + pipe = self.pipeline_class(**components) + with CaptureLogger(logger) as cap_logger: + config = FasterCacheConfig(spatial_attention_block_skip_range=2, attention_weight_callback=None) + apply_faster_cache(pipe.transformer, config) + self.assertTrue("No `attention_weight_callback` provided when enabling FasterCache" in cap_logger.out) + + # Check if error raised when unsupported tensor format used + pipe = self.pipeline_class(**components) + with self.assertRaises(ValueError): + config = FasterCacheConfig(spatial_attention_block_skip_range=2, tensor_format="BFHWC") + apply_faster_cache(pipe.transformer, config) + + def test_faster_cache_inference(self, expected_atol: float = 0.1): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + def create_pipe(): + torch.manual_seed(0) + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + return pipe(**inputs)[0] + + # Run inference without FasterCache + pipe = create_pipe() + output = run_forward(pipe).flatten() + original_image_slice = np.concatenate((output[:8], output[-8:])) + + # Run inference with FasterCache enabled + self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep + pipe = create_pipe() + pipe.transformer.enable_cache(self.faster_cache_config) + output = run_forward(pipe).flatten().flatten() + image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:])) + + # Run inference with FasterCache disabled + pipe.transformer.disable_cache() + output = run_forward(pipe).flatten() + image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:])) + + assert np.allclose( + original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol + ), "FasterCache outputs should not differ much in specified timestep range." 
+ assert np.allclose( + original_image_slice, image_slice_faster_cache_disabled, atol=1e-4 + ), "Outputs from normal inference and after disabling cache should not differ." + + def test_faster_cache_state(self): + from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK + + device = "cpu" # ensure determinism for the device-dependent torch.Generator + num_layers = 0 + num_single_layers = 0 + dummy_component_kwargs = {} + dummy_component_parameters = inspect.signature(self.get_dummy_components).parameters + if "num_layers" in dummy_component_parameters: + num_layers = 2 + dummy_component_kwargs["num_layers"] = num_layers + if "num_single_layers" in dummy_component_parameters: + num_single_layers = 2 + dummy_component_kwargs["num_single_layers"] = num_single_layers + + components = self.get_dummy_components(**dummy_component_kwargs) + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + + self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep + pipe.transformer.enable_cache(self.faster_cache_config) + + expected_hooks = 0 + if self.faster_cache_config.spatial_attention_block_skip_range is not None: + expected_hooks += num_layers + num_single_layers + if self.faster_cache_config.temporal_attention_block_skip_range is not None: + expected_hooks += num_layers + num_single_layers + + # Check if faster_cache denoiser hook is attached + denoiser = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet + self.assertTrue( + hasattr(denoiser, "_diffusers_hook") + and isinstance(denoiser._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK), FasterCacheDenoiserHook), + "Hook should be of type FasterCacheDenoiserHook.", + ) + + # Check if all blocks have faster_cache block hook attached + count = 0 + for name, module in denoiser.named_modules(): + if hasattr(module, "_diffusers_hook"): + if name == "": + # Skip the root denoiser module + continue + count += 1 + self.assertTrue( + isinstance(module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK), FasterCacheBlockHook), + "Hook should be of type FasterCacheBlockHook.", + ) + self.assertEqual(count, expected_hooks, "Number of hooks should match expected number.") + + # Perform inference to ensure that states are updated correctly + def faster_cache_state_check_callback(pipe, i, t, kwargs): + for name, module in denoiser.named_modules(): + if not hasattr(module, "_diffusers_hook"): + continue + if name == "": + # Root denoiser module + state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state + if not self.faster_cache_config.is_guidance_distilled: + self.assertTrue(state.low_frequency_delta is not None, "Low frequency delta should be set.") + self.assertTrue(state.high_frequency_delta is not None, "High frequency delta should be set.") + else: + # Internal blocks + state = module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK).state + self.assertTrue(state.cache is not None and len(state.cache) == 2, "Cache should be set.") + self.assertTrue(state.iteration == i + 1, "Hook iteration state should have updated during inference.") + return {} + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + inputs["callback_on_step_end"] = faster_cache_state_check_callback + _ = pipe(**inputs)[0] + + # After inference, reset_stateful_hooks is called within the pipeline, which should have reset the states + for name, module in denoiser.named_modules(): + if not hasattr(module, "_diffusers_hook"): + continue + + if 
name == "": + # Root denoiser module + state = module._diffusers_hook.get_hook(_FASTER_CACHE_DENOISER_HOOK).state + self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.") + self.assertTrue(state.low_frequency_delta is None, "Low frequency delta should be reset to None.") + self.assertTrue(state.high_frequency_delta is None, "High frequency delta should be reset to None.") + else: + # Internal blocks + state = module._diffusers_hook.get_hook(_FASTER_CACHE_BLOCK_HOOK).state + self.assertTrue(state.iteration == 0, "Iteration should be reset to 0.") + self.assertTrue(state.batch_size is None, "Batch size should be reset to None.") + self.assertTrue(state.cache is None, "Cache should be reset to None.") + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image.