diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index a6aa5445a845..a1d961cc2974 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -11,6 +11,50 @@ specific language governing permissions and limitations under the License. --> # Caching methods +## Faster Cache + +[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong. + +FasterCache is a method that speeds up inference in diffusion transformers by: +- Reusing attention states between successive inference steps, due to high similarity between them +- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output + +```python +import torch +from diffusers import CogVideoXPipeline, FasterCacheConfig + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 681), + current_timestep_callback=lambda: pipe.current_timestep, + attention_weight_callback=lambda _: 0.3, + unconditional_batch_skip_range=5, + unconditional_batch_timestep_skip_range=(-1, 781), + tensor_format="BFCHW", +) +pipe.transformer.enable_cache(config) +``` + +## First Block Cache + +[First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) is a method that builds upon the ideas of [TeaCache](https://huggingface.co/papers/2411.19108) to speed up inference in diffusion transformers. The generation quality is superior with greatly reduced inference time. This method always computes the output of the first transformer block and computes the differences between past and current outputs of the first transformer block. If the difference is smaller than a predefined threshold, the computation of remaining transformer blocks is skipped, and otherwise the computation is performed as usual. + +```python +import torch +from diffusers import CogVideoXPipeline, FirstBlockCacheConfig + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos. +# Smaller values between 0.02-2.0 are recommended based on the model being used. The default value is 0.05. +config = FirstBlockCacheConfig(threshold=0.07) +pipe.transformer.enable_cache(config) +``` + ## Pyramid Attention Broadcast [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You. @@ -38,45 +82,24 @@ config = PyramidAttentionBroadcastConfig( pipe.transformer.enable_cache(config) ``` -## Faster Cache +### CacheMixin -[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong. 
+[[autodoc]] CacheMixin -FasterCache is a method that speeds up inference in diffusion transformers by: -- Reusing attention states between successive inference steps, due to high similarity between them -- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output +### FasterCacheConfig -```python -import torch -from diffusers import CogVideoXPipeline, FasterCacheConfig +[[autodoc]] FasterCacheConfig -pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) -pipe.to("cuda") +[[autodoc]] apply_faster_cache -config = FasterCacheConfig( - spatial_attention_block_skip_range=2, - spatial_attention_timestep_skip_range=(-1, 681), - current_timestep_callback=lambda: pipe.current_timestep, - attention_weight_callback=lambda _: 0.3, - unconditional_batch_skip_range=5, - unconditional_batch_timestep_skip_range=(-1, 781), - tensor_format="BFCHW", -) -pipe.transformer.enable_cache(config) -``` +### FirstBlockCacheConfig -### CacheMixin +[[autodoc]] FirstBlockCacheConfig -[[autodoc]] CacheMixin +[[autodoc]] apply_first_block_cache ### PyramidAttentionBroadcastConfig [[autodoc]] PyramidAttentionBroadcastConfig [[autodoc]] apply_pyramid_attention_broadcast - -### FasterCacheConfig - -[[autodoc]] FasterCacheConfig - -[[autodoc]] apply_faster_cache diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9304c34b4e01..8b5257f46c8d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -33,6 +33,7 @@ _import_structure = { "configuration_utils": ["ConfigMixin"], + "guiders": [], "hooks": [], "loaders": ["FromOriginalModelMixin"], "models": [], @@ -129,12 +130,25 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: + _import_structure["guiders"].extend( + [ + "AdaptiveProjectedGuidance", + "ClassifierFreeGuidance", + "ClassifierFreeZeroStarGuidance", + "PerturbedAttentionGuidance", + "SkipLayerGuidance", + ] + ) _import_structure["hooks"].extend( [ "FasterCacheConfig", + "FirstBlockCacheConfig", "HookRegistry", + "LayerSkipConfig", "PyramidAttentionBroadcastConfig", "apply_faster_cache", + "apply_first_block_cache", + "apply_layer_skip", "apply_pyramid_attention_broadcast", ] ) @@ -708,11 +722,22 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: + from .guiders import ( + AdaptiveProjectedGuidance, + ClassifierFreeGuidance, + ClassifierFreeZeroStarGuidance, + PerturbedAttentionGuidance, + SkipLayerGuidance, + ) from .hooks import ( FasterCacheConfig, + FirstBlockCacheConfig, HookRegistry, + LayerSkipConfig, PyramidAttentionBroadcastConfig, apply_faster_cache, + apply_first_block_cache, + apply_layer_skip, apply_pyramid_attention_broadcast, ) from .models import ( diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py new file mode 100644 index 000000000000..3893b30935c0 --- /dev/null +++ b/src/diffusers/guiders/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..utils import is_torch_available + + +if is_torch_available(): + from .adaptive_projected_guidance import AdaptiveProjectedGuidance + from .classifier_free_guidance import ClassifierFreeGuidance + from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance + from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning + from .perturbed_attention_guidance import PerturbedAttentionGuidance + from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py new file mode 100644 index 000000000000..06faaa80e894 --- /dev/null +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -0,0 +1,145 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch + +from .guider_utils import GuidanceMixin, rescale_noise_cfg + + +class AdaptiveProjectedGuidance(GuidanceMixin): + """ + Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + adaptive_projected_guidance_momentum (`float`, defaults to `None`): + The momentum parameter for the adaptive projected guidance. Disabled if set to `None`. + adaptive_projected_guidance_rescale (`float`, defaults to `15.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. 
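+
+    Example (a minimal construction sketch; the momentum value below is illustrative rather than a tuned
+    recommendation, and whether a given pipeline accepts a guider through its `guidance` argument should be
+    verified against that pipeline's signature):
+
+    ```python
+    from diffusers import AdaptiveProjectedGuidance
+
+    guidance = AdaptiveProjectedGuidance(
+        guidance_scale=7.5,
+        adaptive_projected_guidance_momentum=-0.5,
+        adaptive_projected_guidance_rescale=15.0,
+    )
+    ```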
+ """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + adaptive_projected_guidance_momentum: Optional[float] = None, + adaptive_projected_guidance_rescale: float = 15.0, + eta: float = 1.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + ): + super().__init__() + + self.guidance_scale = guidance_scale + self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum + self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale + self.eta = eta + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def prepare_inputs(self, *args): + if self._step == 0: + if self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + return super().prepare_inputs(*args) + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if self._is_cfg_enabled(): + pred = pred_cond + else: + pred = normalized_guidance( + pred_cond, + pred_uncond, + self.guidance_scale, + self.momentum_buffer, + self.eta, + self.adaptive_projected_guidance_rescale, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if self.use_original_formulation: + return not math.isclose(self.guidance_scale, 0.0) + else: + return not math.isclose(self.guidance_scale, 1.0) + + +class MomentumBuffer: + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + +def normalized_guidance( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + momentum_buffer: Optional[MomentumBuffer] = None, + eta: float = 1.0, + norm_threshold: float = 0.0, + use_original_formulation: bool = False, +): + diff = pred_cond - pred_uncond + dim = [-i for i in range(1, len(diff.shape))] + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + if norm_threshold > 0: + ones = torch.ones_like(diff) + diff_norm = diff.norm(p=2, dim=dim, keepdim=True) + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) + diff = diff * scale_factor + v0, v1 = diff.double(), pred_cond.double() + v1 = torch.nn.functional.normalize(v1, dim=dim) + v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) + normalized_update = diff_orthogonal + eta * diff_parallel + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + (guidance_scale - 1) * normalized_update + return pred diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py new file mode 100644 index 000000000000..3423eb3f5fc2 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -0,0 +1,98 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch + +from .guider_utils import GuidanceMixin, rescale_noise_cfg + + +class ClassifierFreeGuidance(GuidanceMixin): + """ + Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598 + + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by + jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during + inference. This allows the model to tradeoff between generation quality and sample diversity. + + The original paper proposes scaling and shifting the conditional distribution based on the difference between + conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + + Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen + paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in + theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + + The intution behind the original formulation can be thought of as moving the conditional distribution estimates + further away from the unconditional distribution estimates, while the diffusers-native implementation can be + thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of + the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) + + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the + paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. 
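+
+    Example (a minimal sketch applying the guider directly to two noise predictions; inside a pipeline the
+    guider is expected to be driven through `prepare_inputs`/`prepare_outputs` rather than called by hand):
+
+    ```python
+    import torch
+
+    from diffusers import ClassifierFreeGuidance
+
+    guidance = ClassifierFreeGuidance(guidance_scale=7.5)
+    pred_cond = torch.randn(1, 4, 64, 64)
+    pred_uncond = torch.randn(1, 4, 64, 64)
+    # Diffusers-native formulation: pred_uncond + 7.5 * (pred_cond - pred_uncond)
+    pred = guidance(pred_cond=pred_cond, pred_uncond=pred_uncond)
+    ```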
+ """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False + ): + super().__init__() + + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if self.use_original_formulation: + return not math.isclose(self.guidance_scale, 0.0) + else: + return not math.isclose(self.guidance_scale, 1.0) diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py new file mode 100644 index 000000000000..f34675e1a93a --- /dev/null +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -0,0 +1,110 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch + +from .guider_utils import GuidanceMixin, rescale_noise_cfg + + +class ClassifierFreeZeroStarGuidance(GuidanceMixin): + """ + Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886 + + This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free + guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion + process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the + quality of generated images. + + The authors of the paper suggest setting zero initialization in the first 4% of the inference steps. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + zero_init_steps (`int`, defaults to `1`): + The number of inference steps for which the noise predictions are zeroed out (see Section 4.2). + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). 
+ use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + """ + + _input_predictions = ["pred_cond", "pred_uncond"] + + def __init__( + self, + guidance_scale: float = 7.5, + zero_init_steps: int = 1, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + ): + super().__init__() + + self.guidance_scale = guidance_scale + self.zero_init_steps = zero_init_steps + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if self._step < self.zero_init_steps: + pred = torch.zeros_like(pred_cond) + elif self._is_cfg_enabled(): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred_cond_flat = pred_cond.flatten(1) + pred_uncond_flat = pred_uncond.flatten(1) + alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat) + alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1)) + pred_uncond = pred_uncond * alpha + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if self.use_original_formulation: + return not math.isclose(self.guidance_scale, 0.0) + else: + return not math.isclose(self.guidance_scale, 1.0) + + +def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + cond = cond.float() + uncond = uncond.float() + dot_product = torch.sum(cond * uncond, dim=1, keepdim=True) + squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + scale = dot_product / squared_norm + return scale.type_as(cond) diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py new file mode 100644 index 000000000000..36a9fa552e54 --- /dev/null +++ b/src/diffusers/guiders/guider_utils.py @@ -0,0 +1,213 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
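+
+# Shared utilities for the guider classes: the `GuidanceMixin` base class, the CFG rescaling helper, and the
+# attention-processor replacement helpers used by perturbed-attention style guidance.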
+ +import re +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch + +from ..utils import deprecate, get_logger + + +if TYPE_CHECKING: + from ..models.attention_processor import AttentionProcessor + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class GuidanceMixin: + r"""Base mixin class providing the skeleton for implementing guidance techniques.""" + + _input_predictions = None + + def __init__(self): + self._step: int = None + self._num_inference_steps: int = None + self._timestep: torch.LongTensor = None + self._preds: Dict[str, torch.Tensor] = {} + self._num_outputs_prepared: int = 0 + + if self._input_predictions is None or not isinstance(self._input_predictions, list): + raise ValueError( + "`_input_predictions` must be a list of required prediction names for the guidance technique." + ) + + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: + self._step = step + self._num_inference_steps = num_inference_steps + self._timestep = timestep + self._preds = {} + self._num_outputs_prepared = 0 + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + pass + + def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + num_conditions = self.num_conditions + list_of_inputs = [] + for arg in args: + if isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + # Alternating conditional and unconditional inputs as batches + inputs = [arg[i % 2] for i in range(num_conditions)] + list_of_inputs.append(inputs) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." + ) + return tuple(list_of_inputs) + + def prepare_outputs(self, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + pass + + def __call__(self, **kwargs) -> Any: + if len(kwargs) != self.num_conditions: + raise ValueError( + f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments." + ) + return self.forward(**kwargs) + + def forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.") + + @property + def num_conditions(self) -> int: + raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.") + + @property + def outputs(self) -> Dict[str, torch.Tensor]: + return self._preds + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. 
Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def _replace_attention_processors( + module: torch.nn.Module, + pag_applied_layers: Optional[Union[str, List[str]]] = None, + skip_context_attention: bool = False, + processors: Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]] = None, + metadata_name: Optional[str] = None, +) -> Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]]: + if processors is not None and metadata_name is not None: + raise ValueError("Cannot pass both `processors` and `metadata_name` at the same time.") + if metadata_name is not None: + if isinstance(pag_applied_layers, str): + pag_applied_layers = [pag_applied_layers] + return _replace_layers_with_guidance_processors( + module, pag_applied_layers, skip_context_attention, metadata_name + ) + if processors is not None: + _replace_layers_with_existing_processors(processors) + + +def _replace_layers_with_guidance_processors( + module: torch.nn.Module, + pag_applied_layers: List[str], + skip_context_attention: bool, + metadata_name: str, +) -> List[Tuple[torch.nn.Module, "AttentionProcessor"]]: + from ..hooks._common import _ATTENTION_CLASSES + from ..hooks._helpers import GuidanceMetadataRegistry + + processors = [] + for name, submodule in module.named_modules(): + if ( + (not isinstance(submodule, _ATTENTION_CLASSES)) + or (getattr(submodule, "processor", None) is None) + or not ( + any( + re.search(pag_layer, name) is not None and not _is_fake_integral_match(pag_layer, name) + for pag_layer in pag_applied_layers + ) + ) + ): + continue + old_attention_processor = submodule.processor + metadata = GuidanceMetadataRegistry.get(old_attention_processor.__class__) + new_attention_processor_cls = getattr(metadata, metadata_name) + new_attention_processor = new_attention_processor_cls() + # !!! dunder methods cannot be replaced on instances !!! 
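+        # (`__call__` is resolved on the type rather than the instance, so assigning a wrapped
+        # callable to `new_attention_processor.__call__` would never actually be invoked.)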
+ # if "skip_context_attention" in inspect.signature(new_attention_processor.__call__).parameters: + # new_attention_processor.__call__ = partial( + # new_attention_processor.__call__, skip_context_attention=skip_context_attention + # ) + submodule.processor = new_attention_processor + processors.append((submodule, old_attention_processor)) + return processors + + +def _replace_layers_with_existing_processors(processors: List[Tuple[torch.nn.Module, "AttentionProcessor"]]) -> None: + for module, proc in processors: + module.processor = proc + + +def _is_fake_integral_match(layer_id, name): + layer_id = layer_id.split(".")[-1] + name = name.split(".")[-1] + return layer_id.isnumeric() and name.isnumeric() and layer_id == name + + +def _raise_guidance_deprecation_warning( + *, + guidance_scale: Optional[float] = None, + guidance_rescale: Optional[float] = None, +) -> None: + if guidance_scale is not None: + msg = "The `guidance_scale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead." + deprecate("guidance_scale", "1.0.0", msg, standard_warn=False) + if guidance_rescale is not None: + msg = "The `guidance_rescale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead." + deprecate("guidance_rescale", "1.0.0", msg, standard_warn=False) diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py new file mode 100644 index 000000000000..242860886ec4 --- /dev/null +++ b/src/diffusers/guiders/perturbed_attention_guidance.py @@ -0,0 +1,180 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch + +from .guider_utils import GuidanceMixin, _replace_attention_processors, rescale_noise_cfg + + +class PerturbedAttentionGuidance(GuidanceMixin): + """ + Perturbed Attention Guidance (PAB): https://huggingface.co/papers/2403.17377 + + Args: + pag_applied_layers (`str` or `List[str]`): + The name of the attention layers where Perturbed Attention Guidance is applied. This can be a single layer + name or a list of layer names. The names should either be FQNs (fully qualified names) to each attention + layer or a regex pattern that matches the FQNs of the attention layers. For example, if you want to apply + PAG to transformer blocks 10 and 20, you can set this to `["transformer_blocks.10", + "transformer_blocks.20"]`, or `"transformer_blocks.(10|20)"`. + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + pag_scale (`float`, defaults to `3.0`): + The scale parameter for perturbed attention guidance. 
+ guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_perturbed"] + + def __init__( + self, + pag_applied_layers: Union[str, List[str]], + guidance_scale: float = 7.5, + pag_scale: float = 3.0, + skip_context_attention: bool = False, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + ): + super().__init__() + + self.pag_applied_layers = pag_applied_layers + self.guidance_scale = guidance_scale + self.pag_scale = pag_scale + self.skip_context_attention = skip_context_attention + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + self._is_pag_batch = False + self._original_processors = None + self._denoiser = None + + def prepare_models(self, denoiser: torch.nn.Module): + self._denoiser = denoiser + + def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + num_conditions = self.num_conditions + list_of_inputs = [] + for arg in args: + if isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + list_of_inputs.append([arg[0], arg[1], arg[0]]) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." 
+ ) + return tuple(list_of_inputs) + + def prepare_outputs(self, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + if not self._is_cfg_enabled() and self._is_pag_enabled(): + # If we're predicting pred_cond and pred_perturbed only, we need to set the key to pred_perturbed + # to avoid writing into pred_uncond which is not used + if self._num_outputs_prepared == 2: + key = "pred_perturbed" + self._preds[key] = pred + + # Restore the original attention processors if previously replaced + if self._is_pag_batch: + _replace_attention_processors(self._denoiser, processors=self._original_processors) + self._is_pag_batch = False + self._original_processors = None + + # Prepare denoiser for perturbed attention prediction if needed + if self._is_pag_enabled(): + should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or ( + not self._is_cfg_enabled() and self._num_outputs_prepared == 1 + ) + if should_register_pag: + self._is_pag_batch = True + self._original_processors = _replace_attention_processors( + self._denoiser, + self.pag_applied_layers, + skip_context_attention=self.skip_context_attention, + metadata_name="perturbed_attention_guidance_processor_cls", + ) + + def cleanup_models(self, denoiser: torch.nn.Module): + self._denoiser = None + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_perturbed: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_pag_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_perturbed + pred = pred_cond + self.pag_scale * shift + elif not self._is_pag_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_perturbed = pred_cond - pred_perturbed + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.pag_scale * shift_perturbed + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_pag_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if self.use_original_formulation: + return not math.isclose(self.guidance_scale, 0.0) + else: + return not math.isclose(self.guidance_scale, 1.0) + + def _is_pag_enabled(self) -> bool: + is_zero = math.isclose(self.pag_scale, 0.0) + return not is_zero diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py new file mode 100644 index 000000000000..1fe09ddac615 --- /dev/null +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -0,0 +1,235 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch + +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import GuidanceMixin, rescale_noise_cfg + + +class SkipLayerGuidance(GuidanceMixin): + """ + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG): + https://huggingface.co/papers/2411.18664 + + SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by + skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional + batch of data, apart from the conditional and unconditional batches already used in CFG + ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions + based on the difference between conditional without skipping and conditional with skipping predictions. + + The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from + worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse + version of the model for the conditional prediction). + + STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving + generation quality in video diffusion models. + + Additional reading: + - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) + + The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are + defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + skip_layer_guidance_scale (`float`, defaults to `2.8`): + The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + skip_layer_guidance_start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which skip layer guidance starts. + skip_layer_guidance_stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which skip layer guidance stops. + skip_layer_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not + provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. 
If not provided, `skip_layer_guidance_layers` must be provided. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + + def __init__( + self, + guidance_scale: float = 7.5, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_start: float = 0.01, + skip_layer_guidance_stop: float = 0.2, + skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None, + skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + ): + super().__init__() + + self.guidance_scale = guidance_scale + self.skip_layer_guidance_scale = skip_layer_guidance_scale + self.skip_layer_guidance_start = skip_layer_guidance_start + self.skip_layer_guidance_stop = skip_layer_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if not (0.0 <= skip_layer_guidance_start < 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}." + ) + if not (0.0 < skip_layer_guidance_stop <= 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}." + ) + + if skip_layer_guidance_layers is None and skip_layer_config is None: + raise ValueError( + "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." + ) + if skip_layer_guidance_layers is not None and skip_layer_config is not None: + raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.") + + if skip_layer_guidance_layers is not None: + if isinstance(skip_layer_guidance_layers, int): + skip_layer_guidance_layers = [skip_layer_guidance_layers] + if not isinstance(skip_layer_guidance_layers, list): + raise ValueError( + f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}." + ) + skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers] + + if isinstance(skip_layer_config, LayerSkipConfig): + skip_layer_config = [skip_layer_config] + + if not isinstance(skip_layer_config, list): + raise ValueError( + f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}." 
+ ) + + self.skip_layer_config = skip_layer_config + self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] + + def prepare_models(self, denoiser: torch.nn.Module): + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + + # Register the hooks for layer skipping if the step is within the specified range + if skip_start_step < self._step < skip_stop_step: + for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + num_conditions = self.num_conditions + list_of_inputs = [] + for arg in args: + if isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + list_of_inputs.append([arg[0], arg[1], arg[0]]) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." + ) + return tuple(list_of_inputs) + + def prepare_outputs(self, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + if not self._is_cfg_enabled() and self._is_slg_enabled(): + # If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip + # to avoid writing into pred_uncond which is not used + if self._num_outputs_prepared == 2: + key = "pred_cond_skip" + self._preds[key] = pred + + def cleanup_models(self, denoiser: torch.nn.Module): + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._skip_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_skip: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_slg_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift + elif not self._is_slg_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_skip = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + 
num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_slg_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if self.use_original_formulation: + return not math.isclose(self.guidance_scale, 0.0) + else: + return not math.isclose(self.guidance_scale, 1.0) + + def _is_slg_enabled(self) -> bool: + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + is_within_range = skip_start_step < self._step < skip_stop_step + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) + return is_within_range and not is_zero diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 764ceb25b465..2db36d4366f1 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -1,9 +1,25 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from ..utils import is_torch_available if is_torch_available(): from .faster_cache import FasterCacheConfig, apply_faster_cache + from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook + from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py new file mode 100644 index 000000000000..6ea83dcbf6a7 --- /dev/null +++ b/src/diffusers/hooks/_common.py @@ -0,0 +1,32 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
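+
+# Class tuples and module-name identifiers shared by the hooks when matching attention layers, feed-forward
+# layers, and transformer blocks inside supported models.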
+ +from ..models.attention import FeedForward, LuminaFeedForward +from ..models.attention_processor import Attention, MochiAttention + + +_ATTENTION_CLASSES = (Attention, MochiAttention) +_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward) + +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") + +_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple( + { + *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, + } +) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py new file mode 100644 index 000000000000..9dabc7b286b5 --- /dev/null +++ b/src/diffusers/hooks/_helpers.py @@ -0,0 +1,276 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Callable, Type + +from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock +from ..models.transformers.transformer_cogview4 import ( + CogView4AttnProcessor, + CogView4PAGAttnProcessor, + CogView4TransformerBlock, +) +from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from ..models.transformers.transformer_hunyuan_video import ( + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTokenReplaceSingleTransformerBlock, + HunyuanVideoTokenReplaceTransformerBlock, + HunyuanVideoTransformerBlock, +) +from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock +from ..models.transformers.transformer_mochi import MochiTransformerBlock +from ..models.transformers.transformer_wan import WanTransformerBlock + + +@dataclass +class AttentionProcessorMetadata: + skip_processor_output_fn: Callable[[Any], Any] + + +@dataclass +class GuidanceMetadata: + perturbed_attention_guidance_processor_cls: Type = None + + +@dataclass +class TransformerBlockMetadata: + skip_block_output_fn: Callable[[Any], Any] + return_hidden_states_index: int = None + return_encoder_hidden_states_index: int = None + + +class AttentionProcessorRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: AttentionProcessorMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> AttentionProcessorMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +class GuidanceMetadataRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: GuidanceMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> GuidanceMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return 
cls._registry[model_class] + + +class TransformerBlockRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: TransformerBlockMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> TransformerBlockMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +def _register_attention_processors_metadata(): + # CogView4 + AttentionProcessorRegistry.register( + model_class=CogView4AttnProcessor, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor, + ), + ) + + +def _register_guidance_metadata(): + # CogView4 + GuidanceMetadataRegistry.register( + model_class=CogView4AttnProcessor, + metadata=GuidanceMetadata( + perturbed_attention_guidance_processor_cls=CogView4PAGAttnProcessor, + ), + ) + + +def _register_transformer_blocks_metadata(): + # CogVideoX + TransformerBlockRegistry.register( + model_class=CogVideoXBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # CogView4 + TransformerBlockRegistry.register( + model_class=CogView4TransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Flux + TransformerBlockRegistry.register( + model_class=FluxTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + TransformerBlockRegistry.register( + model_class=FluxSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + + # HunyuanVideo + TransformerBlockRegistry.register( + model_class=HunyuanVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # LTXVideo + TransformerBlockRegistry.register( + model_class=LTXVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # Mochi + 
TransformerBlockRegistry.register( + model_class=MochiTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Wan + TransformerBlockRegistry.register( + model_class=WanTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + +# fmt: off +def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return encoder_hidden_states, hidden_states + + +_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = 
_skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +# fmt: on + + +_register_attention_processors_metadata() +_register_guidance_metadata() +_register_transformer_blocks_metadata() diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py new file mode 100644 index 000000000000..7863a1268843 --- /dev/null +++ b/src/diffusers/hooks/first_block_cache.py @@ -0,0 +1,223 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Tuple, Union + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS +from ._helpers import TransformerBlockRegistry +from .hooks import BaseMarkedState, HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook" +_FBC_BLOCK_HOOK = "fbc_block_hook" + + +@dataclass +class FirstBlockCacheConfig: + r""" + Configuration for [First Block + Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching). + + Args: + threshold (`float`, defaults to `0.05`): + The threshold to determine whether or not a forward pass through all layers of the model is required. A + higher threshold usually results in lower number of forward passes and faster inference, but might lead to + poorer generation quality. A lower threshold may not result in significant generation speedup. The + threshold is compared against the absmean difference of the residuals between the current and cached + outputs from the first transformer block. If the difference is below the threshold, the forward pass is + skipped. 
+ """ + + threshold: float = 0.05 + + +class FBCSharedBlockState(BaseMarkedState): + def __init__(self) -> None: + super().__init__() + + self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.head_block_residual: torch.Tensor = None + self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.should_compute: bool = True + + def reset(self): + self.tail_block_residuals = None + self.should_compute = True + + +class FBCHeadBlockHook(ModelHook): + _is_stateful = True + + def __init__(self, shared_state: FBCSharedBlockState, threshold: float): + self.shared_state = shared_state + self.threshold = threshold + self._metadata = None + + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs) + original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index] + + output = self.fn_ref.original_forward(*args, **kwargs) + is_output_tuple = isinstance(output, tuple) + + if is_output_tuple: + hs_residual = output[self._metadata.return_hidden_states_index] - original_hs + else: + hs_residual = output - original_hs + + hs, ehs = None, None + should_compute = self._should_compute_remaining_blocks(hs_residual) + self.shared_state.should_compute = should_compute + + if not should_compute: + # Apply caching + if is_output_tuple: + hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] + else: + hs = self.shared_state.tail_block_residuals[0] + output + + if self._metadata.return_encoder_hidden_states_index is not None: + assert is_output_tuple + ehs = ( + self.shared_state.tail_block_residuals[1] + + output[self._metadata.return_encoder_hidden_states_index] + ) + + if is_output_tuple: + return_output = [None] * len(output) + return_output[self._metadata.return_hidden_states_index] = hs + return_output[self._metadata.return_encoder_hidden_states_index] = ehs + return_output = tuple(return_output) + else: + return_output = hs + output = return_output + else: + if is_output_tuple: + head_block_output = [None] * len(output) + head_block_output[0] = output[self._metadata.return_hidden_states_index] + head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index] + else: + head_block_output = output + self.shared_state.head_block_output = head_block_output + self.shared_state.head_block_residual = hs_residual + + return output + + def reset_state(self, module): + self.shared_state.reset() + return module + + @torch.compiler.disable + def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: + if self.shared_state.head_block_residual is None: + return True + prev_hs_residual = self.shared_state.head_block_residual + hs_absmean = (hs_residual - prev_hs_residual).abs().mean() + prev_hs_mean = prev_hs_residual.abs().mean() + diff = (hs_absmean / prev_hs_mean).item() + return diff > self.threshold + + +class FBCBlockHook(ModelHook): + def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False): + super().__init__() + self.shared_state = shared_state + self.is_tail = is_tail + self._metadata = None + + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + outputs_if_skipped = 
self._metadata.skip_block_output_fn(module, *args, **kwargs) + if not isinstance(outputs_if_skipped, tuple): + outputs_if_skipped = (outputs_if_skipped,) + original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index] + original_ehs = None + if self._metadata.return_encoder_hidden_states_index is not None: + original_ehs = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index] + + if self.shared_state.should_compute: + output = self.fn_ref.original_forward(*args, **kwargs) + if self.is_tail: + hs_residual, ehs_residual = None, None + if isinstance(output, tuple): + hs_residual = ( + output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0] + ) + ehs_residual = ( + output[self._metadata.return_encoder_hidden_states_index] + - self.shared_state.head_block_output[1] + ) + else: + hs_residual = output - self.shared_state.head_block_output + self.shared_state.tail_block_residuals = (hs_residual, ehs_residual) + return output + + output_count = len(outputs_if_skipped) + if output_count == 1: + return_output = original_hs + else: + return_output = [None] * output_count + return_output[self._metadata.return_hidden_states_index] = original_hs + return_output[self._metadata.return_encoder_hidden_states_index] = original_ehs + return return_output + + +def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None: + shared_state = FBCSharedBlockState() + remaining_blocks = [] + + for name, submodule in module.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for index, block in enumerate(submodule): + remaining_blocks.append((f"{name}.{index}", block)) + + head_block_name, head_block = remaining_blocks.pop(0) + tail_block_name, tail_block = remaining_blocks.pop(-1) + + logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'") + apply_fbc_head_block_hook(head_block, shared_state, config.threshold) + + for name, block in remaining_blocks: + logger.debug(f"Apply FBCBlockHook to '{name}'") + apply_fbc_block_hook(block, shared_state) + + logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'") + apply_fbc_block_hook(tail_block, shared_state, is_tail=True) + + +def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = FBCHeadBlockHook(state, threshold) + registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK) + + +def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = FBCBlockHook(state, is_tail) + registry.register_hook(hook, _FBC_BLOCK_HOOK) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 3b2e4ed91c2f..c42592783d91 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -18,11 +18,76 @@ import torch from ..utils.logging import get_logger +from ..utils.torch_utils import unwrap_module logger = get_logger(__name__) # pylint: disable=invalid-name +class BaseState: + def reset(self, *args, **kwargs) -> None: + raise NotImplementedError( + "BaseState::reset is not implemented. Please implement this method in the derived class." 
+ ) + + +class BaseMarkedState(BaseState): + def __init__(self, init_args=None, init_kwargs=None): + super().__init__() + + self._init_args = init_args if init_args is not None else () + self._init_kwargs = init_kwargs if init_kwargs is not None else {} + self._mark_name = None + self._state_cache = {} + + def get_current_state(self) -> "BaseMarkedState": + if self._mark_name is None: + # If no mark name is set, simply return a dummy object since we're not going to be using it + return self + if self._mark_name not in self._state_cache.keys(): + self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs) + return self._state_cache[self._mark_name] + + def mark_state(self, name: str) -> None: + self._mark_name = name + + def reset(self, *args, **kwargs) -> None: + for name, state in list(self._state_cache.items()): + state.reset(*args, **kwargs) + self._state_cache.pop(name) + self._mark_name = None + + def __getattribute__(self, name): + if name in ( + "get_current_state", + "mark_state", + "reset", + "_init_args", + "_init_kwargs", + "_mark_name", + "_state_cache", + ) or _is_dunder_method(name): + return object.__getattribute__(self, name) + else: + current_state = BaseMarkedState.get_current_state(self) + return object.__getattribute__(current_state, name) + + def __setattr__(self, name, value): + if name in ( + "get_current_state", + "mark_state", + "reset", + "_init_args", + "_init_kwargs", + "_mark_name", + "_state_cache", + ) or _is_dunder_method(name): + object.__setattr__(self, name, value) + else: + current_state = BaseMarkedState.get_current_state(self) + object.__setattr__(current_state, name, value) + + class ModelHook: r""" A hook that contains callbacks to be executed just before and after the forward method of a model. @@ -99,6 +164,14 @@ def reset_state(self, module: torch.nn.Module): raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") return module + def _mark_state(self, module: torch.nn.Module, name: str) -> None: + # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them. 
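+        # Marking switches the active state bucket of a `BaseMarkedState`, so named branches such as "cond" and
+        # "uncond" keep independent cached values instead of overwriting each other within a single timestep.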
+ for attr_name in dir(self): + attr = getattr(self, attr_name) + if isinstance(attr, BaseMarkedState): + attr.mark_state(name) + return module + class HookFunctionReference: def __init__(self) -> None: @@ -211,9 +284,10 @@ def reset_stateful_hooks(self, recurse: bool = True) -> None: hook.reset_state(self._module_ref) if recurse: - for module_name, module in self._module_ref.named_modules(): + for module_name, module in unwrap_module(self._module_ref).named_modules(): if module_name == "": continue + module = unwrap_module(module) if hasattr(module, "_diffusers_hook"): module._diffusers_hook.reset_stateful_hooks(recurse=False) @@ -223,6 +297,19 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry module._diffusers_hook = cls(module) return module._diffusers_hook + def _mark_state(self, name: str) -> None: + for hook_name in reversed(self._hook_order): + hook = self.hooks[hook_name] + if hook._is_stateful: + hook._mark_state(self._module_ref, name) + + for module_name, module in unwrap_module(self._module_ref).named_modules(): + if module_name == "": + continue + module = unwrap_module(module) + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook._mark_state(name) + def __repr__(self) -> str: registry_repr = "" for i, hook_name in enumerate(self._hook_order): @@ -234,3 +321,7 @@ def __repr__(self) -> str: if i < len(self._hook_order) - 1: registry_repr += "\n" return f"HookRegistry(\n{registry_repr}\n)" + + +def _is_dunder_method(name: str) -> bool: + return name.startswith("__") and name.endswith("__") and name in dir(object) diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py new file mode 100644 index 000000000000..30b169e8f4af --- /dev/null +++ b/src/diffusers/hooks/layer_skip.py @@ -0,0 +1,182 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Callable, List, Optional + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES +from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_LAYER_SKIP_HOOK = "layer_skip_hook" + + +@dataclass +class LayerSkipConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. 
+ """ + + indices: List[int] + fqn: str = "auto" + skip_attention: bool = True + skip_attention_scores: bool = False + skip_ff: bool = True + + +class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): + def __init__(self) -> None: + super().__init__() + + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.nn.functional.scaled_dot_product_attention: + value = kwargs.get("value", None) + if value is None: + value = args[2] + return value + return func(*args, **kwargs) + + +class AttentionProcessorSkipHook(ModelHook): + def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False): + self.skip_processor_output_fn = skip_processor_output_fn + self.skip_attention_scores = skip_attention_scores + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.skip_attention_scores: + with AttentionScoreSkipFunctionMode(): + return self.fn_ref.original_forward(*args, **kwargs) + else: + return self.skip_processor_output_fn(module, *args, **kwargs) + + +class FeedForwardSkipHook(ModelHook): + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + output = kwargs.get("hidden_states", None) + if output is None: + output = kwargs.get("x", None) + if output is None and len(args) > 0: + output = args[0] + return output + + +class TransformerBlockSkipHook(ModelHook): + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + return self._metadata.skip_block_output_fn(module, *args, **kwargs) + + +def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: + r""" + Apply layer skipping to internal layers of a transformer. + + Args: + module (`torch.nn.Module`): + The transformer model to which the layer skip hook should be applied. + config (`LayerSkipConfig`): + The configuration for the layer skip hook. + + Example: + + ```python + >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks") + >>> apply_layer_skip_hook(transformer, config) + ``` + """ + _apply_layer_skip_hook(module, config) + + +def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None: + name = name or _LAYER_SKIP_HOOK + + if config.skip_attention and config.skip_attention_scores: + raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.") + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + transformer_blocks = getattr(module, config.fqn, None) + if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList): + raise ValueError( + f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify " + f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks." 
+ ) + if len(config.indices) == 0: + raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.") + + blocks_found = False + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + blocks_found = True + if config.skip_attention and config.skip_ff: + logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = TransformerBlockSkipHook() + registry.register_hook(hook, name) + elif config.skip_attention or config.skip_attention_scores: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: + logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'") + output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores) + registry.register_hook(hook, name) + elif config.skip_ff: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _FEEDFORWARD_CLASSES): + logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = FeedForwardSkipHook() + registry.register_hook(hook, name) + else: + raise ValueError( + "At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True." + ) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." + ) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 79bd8dc0b254..6c4bcb301d70 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import contextmanager + from ..utils.logging import get_logger @@ -25,6 +27,7 @@ class CacheMixin: Supported caching techniques: - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) - [FasterCache](https://huggingface.co/papers/2410.19355) + - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) """ _cache_config = None @@ -62,8 +65,10 @@ def enable_cache(self, config) -> None: from ..hooks import ( FasterCacheConfig, + FirstBlockCacheConfig, PyramidAttentionBroadcastConfig, apply_faster_cache, + apply_first_block_cache, apply_pyramid_attention_broadcast, ) @@ -72,31 +77,36 @@ def enable_cache(self, config) -> None: f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first." 
) - if isinstance(config, PyramidAttentionBroadcastConfig): - apply_pyramid_attention_broadcast(self, config) - elif isinstance(config, FasterCacheConfig): + if isinstance(config, FasterCacheConfig): apply_faster_cache(self, config) + elif isinstance(config, FirstBlockCacheConfig): + apply_first_block_cache(self, config) + elif isinstance(config, PyramidAttentionBroadcastConfig): + apply_pyramid_attention_broadcast(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK + from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") return - if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): - registry = HookRegistry.check_if_exists_or_initialize(self) - registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) - elif isinstance(self._cache_config, FasterCacheConfig): - registry = HookRegistry.check_if_exists_or_initialize(self) + registry = HookRegistry.check_if_exists_or_initialize(self) + if isinstance(self._cache_config, FasterCacheConfig): registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True) registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True) + elif isinstance(self._cache_config, FirstBlockCacheConfig): + registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True) + registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) + elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): + registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") @@ -106,3 +116,21 @@ def _reset_stateful_cache(self, recurse: bool = True) -> None: from ..hooks import HookRegistry HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse) + + @contextmanager + def _cache_context(self): + r"""Context manager that provides additional methods for cache management.""" + cache_context = _CacheContextManager(self) + yield cache_context + + +class _CacheContextManager: + def __init__(self, model: CacheMixin): + self.model = model + + def mark_state(self, name: str) -> None: + from ..hooks import HookRegistry + + if self.model.is_cache_enabled: + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry._mark_state(name) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 51c34b7fe965..04ab72e82a03 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -343,25 +343,25 @@ def forward( ) block_samples = block_samples + (hidden_states,) - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - single_block_samples = () for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, 
hidden_states, + encoder_hidden_states, temb, image_rotary_emb, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, ) - single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) + single_block_samples = single_block_samples + (hidden_states,) # controlnet block controlnet_block_samples = () diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 41c4cbbf97c7..d393490c7e8c 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -460,3 +460,84 @@ def forward( if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) + + +### ===== Custom attention processors for guidance methods ===== ### + + +class CogView4PAGAttnProcessor: + """ + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + skip_context_attention: bool = False, + ) -> torch.Tensor: + batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape + batch_size, image_seq_length, embed_dim = hidden_states.shape + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + query[:, :, text_seq_length:, :] = apply_rotary_emb( + query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + ) + key[:, :, text_seq_length:, :] = apply_rotary_emb( + key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + ) + + # 4. 
Attention + if skip_context_attention: + hidden_states = value + else: + # PAG uses a custom attention mask for perturbed attention path: + # - Create attention mask with `float("-inf")` for image tokens and `0.0` for text tokens + # - Set diagonal to `0.0` for attention between image tokens + seq_length = text_seq_length + image_seq_length + perturbed_attention_mask = hidden_states.new_full((seq_length, seq_length), float("-inf")) + perturbed_attention_mask[:text_seq_length, :text_seq_length] = 0.0 + perturbed_attention_mask.fill_diagonal_(0.0) + perturbed_attention_mask = perturbed_attention_mask.unsqueeze(0).unsqueeze(0) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=perturbed_attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + # 5. Output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 87537890d246..b0fb3900f657 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -79,10 +79,14 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, def forward( self, hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) @@ -100,7 +104,8 @@ def forward( if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) - return hidden_states + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states @maybe_allow_in_graph @@ -508,20 +513,21 @@ def forward( ) else: hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, + encoder_hidden_states, temb, image_rotary_emb, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, @@ -531,12 +537,7 @@ def forward( if controlnet_single_block_samples is not None: interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) interval_control = int(np.ceil(interval_control)) - hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
= ( - hidden_states[:, encoder_hidden_states.shape[1] :, ...] - + controlnet_single_block_samples[index_block // interval_control] - ) - - hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index 8550fa94f9e4..72e8c70f899a 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -21,6 +21,7 @@ from transformers import AutoTokenizer, GlmModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning from ...image_processor import VaeImageProcessor from ...loaders import CogView4LoraLoaderMixin from ...models import AutoencoderKL, CogView4Transformer2DModel @@ -426,6 +427,7 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 1024, + guidance: Optional[GuidanceMixin] = None, ) -> Union[CogView4PipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -514,6 +516,10 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated images. """ + _raise_guidance_deprecation_warning(guidance_scale=guidance_scale) + if guidance is None: + guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -608,46 +614,47 @@ def __call__( transformer_dtype = self.transformer.dtype num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - with self.progress_bar(total=num_inference_steps) as progress_bar: + conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left] + prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds] + + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): + self._current_timestep = t if self.interrupt: continue - self._current_timestep = t - latent_model_input = latents.to(transformer_dtype) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]) - - noise_pred_cond = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - original_size=original_size, - target_size=target_size, - crop_coords=crops_coords_top_left, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=negative_prompt_embeds, + guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + guidance.prepare_models(self.transformer) + latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs( + latents, + (prompt_embeds[0], negative_prompt_embeds[0]), + original_size[0], + target_size[0], + crops_coords_top_left[0], + ) + + for batch_index, (latent, condition, 
original_size_c, target_size_c, crop_coord_c) in enumerate( + zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left) + ): + cc.mark_state(f"batch_{batch_index}") + latent = latent.to(transformer_dtype) + timestep = t.expand(latent.shape[0]) + noise_pred = self.transformer( + hidden_states=latent, + encoder_hidden_states=condition, timestep=timestep, - original_size=original_size, - target_size=target_size, - crop_coords=crops_coords_top_left, + original_size=original_size_c, + target_size=target_size_c, + crop_coords=crop_coord_c, attention_kwargs=attention_kwargs, return_dict=False, )[0] + guidance.prepare_outputs(noise_pred) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) - else: - noise_pred = noise_pred_cond - - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + outputs = guidance.outputs + noise_pred = guidance(**outputs) + latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0] + guidance.cleanup_models(self.transformer) # call the callback, if provided if callback_on_step_end is not None: @@ -656,8 +663,10 @@ def __call__( callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])] + negative_prompt_embeds = [ + callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0]) + ] if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 862c279cfaf3..a7195d3a679d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -906,7 +906,7 @@ def __call__( ) # 6. 
Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -917,6 +917,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) + cc.mark_state("cond") noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, @@ -932,6 +933,8 @@ def __call__( if do_true_cfg: if negative_image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + cc.mark_state("uncond") neg_noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 3cb91b3782f2..b36de61c02ef 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -683,7 +683,7 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -693,6 +693,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) + cc.mark_state("cond") noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, @@ -705,6 +706,7 @@ def __call__( )[0] if do_true_cfg: + cc.mark_state("uncond") neg_noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 6f3faed8ff72..2fa9fa53e8f0 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -706,7 +706,7 @@ def __call__( ) # 7. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -719,6 +719,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) + cc.mark_state("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index dcfdfaf23288..aa7e8eb5597c 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -1072,7 +1072,7 @@ def __call__( self._num_timesteps = len(timesteps) # 6. 
Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -1105,6 +1105,7 @@ def __call__( if is_conditioning_image_or_video: timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) + cc.mark_state("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 1ae67967c6f5..e9d2566a9bf1 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -778,7 +778,7 @@ def __call__( ) # 7. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -792,6 +792,7 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + cc.mark_state("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 3294e9a56a07..54164437e7f2 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -519,7 +519,7 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -528,6 +528,7 @@ def __call__( latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) + cc.mark_state("cond") noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, @@ -537,6 +538,7 @@ def __call__( )[0] if self.do_classifier_free_guidance: + cc.mark_state("uncond") noise_uncond = self.transformer( hidden_states=latent_model_input, timestep=timestep, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6edbd737e32c..1e933483a5ca 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,81 @@ from ..utils import DummyObject, requires_backends +class AdaptiveProjectedGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ClassifierFreeGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class ClassifierFreeZeroStarGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def 
__init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class PerturbedAttentionGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class SkipLayerGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class FasterCacheConfig(metaclass=DummyObject): _backends = ["torch"] @@ -17,6 +92,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class FirstBlockCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HookRegistry(metaclass=DummyObject): _backends = ["torch"] @@ -32,6 +122,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LayerSkipConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PyramidAttentionBroadcastConfig(metaclass=DummyObject): _backends = ["torch"] @@ -51,6 +156,14 @@ def apply_faster_cache(*args, **kwargs): requires_backends(apply_faster_cache, ["torch"]) +def apply_first_block_cache(*args, **kwargs): + requires_backends(apply_first_block_cache, ["torch"]) + + +def apply_layer_skip(*args, **kwargs): + requires_backends(apply_layer_skip, ["torch"]) + + def apply_pyramid_attention_broadcast(*args, **kwargs): requires_backends(apply_pyramid_attention_broadcast, ["torch"]) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 3c8911773e39..06f9981f0138 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool: return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) +def unwrap_module(module): + """Unwraps a module if it was compiled with torch.compile()""" + return module._orig_mod if is_compiled_module(module) else module + + def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). 
diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 388dc9ef7ec4..385984f0b497 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -32,6 +32,7 @@ from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, check_qkv_fusion_matches_attn_procs_length, @@ -44,7 +45,11 @@ class CogVideoXPipelineFastTests( - PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + unittest.TestCase, ): pipeline_class = CogVideoXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 6a560367a5b8..b9795fc20b1f 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -25,6 +25,7 @@ from ..test_pipelines_common import ( FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, @@ -34,11 +35,12 @@ class FluxPipelineFastTests( - unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + unittest.TestCase, ): pipeline_class = FluxPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index aa4f045966c3..e6587520c932 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -33,6 +33,7 @@ from ..test_pipelines_common import ( FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np, @@ -43,7 +44,11 @@ class HunyuanVideoPipelineFastTests( - PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + unittest.TestCase, ): pipeline_class = HunyuanVideoPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py index 4f72729fc9ce..1f94b746f12f 100644 --- a/tests/pipelines/ltx/test_ltx.py +++ b/tests/pipelines/ltx/test_ltx.py @@ -23,13 +23,13 @@ from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np enable_full_determinism() -class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase): pipeline_class = LTXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -49,7 +49,7 
@@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_layerwise_casting = True test_group_offloading = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = LTXVideoTransformer3DModel( in_channels=8, @@ -59,7 +59,7 @@ def get_dummy_components(self): num_attention_heads=4, attention_head_dim=8, cross_attention_dim=32, - num_layers=1, + num_layers=num_layers, caption_channels=32, ) diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index ea2d015af52a..ce052962e511 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -33,13 +33,15 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np +from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np enable_full_determinism() -class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase): +class MochiPipelineFastTests( + PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase +): pipeline_class = MochiPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index d3e39e363f91..7d7be3a4951a 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -33,6 +33,7 @@ ) from diffusers.hooks import apply_group_offloading from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook +from diffusers.hooks.first_block_cache import FirstBlockCacheConfig from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin @@ -2631,7 +2632,7 @@ def run_forward(pipe): self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep pipe = create_pipe() pipe.transformer.enable_cache(self.faster_cache_config) - output = run_forward(pipe).flatten().flatten() + output = run_forward(pipe).flatten() image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:])) # Run inference with FasterCache disabled @@ -2738,6 +2739,55 @@ def faster_cache_state_check_callback(pipe, i, t, kwargs): self.assertTrue(state.cache is None, "Cache should be reset to None.") +# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out +# of the box once there is better cache support/implementation +class FirstBlockCacheTesterMixin: + # threshold is intentionally set higher than usual values since we're testing with random unconverged models + # that will not satisfy the expected properties of the denoiser for caching to be effective + first_block_cache_config = FirstBlockCacheConfig(threshold=0.8) + + def test_first_block_cache_inference(self, expected_atol: float = 0.1): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + def create_pipe(): + torch.manual_seed(0) + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): 
+ torch.manual_seed(0) + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + return pipe(**inputs)[0] + + # Run inference without FirstBlockCache + pipe = create_pipe() + output = run_forward(pipe).flatten() + original_image_slice = np.concatenate((output[:8], output[-8:])) + + # Run inference with FirstBlockCache enabled + pipe = create_pipe() + pipe.transformer.enable_cache(self.first_block_cache_config) + output = run_forward(pipe).flatten() + image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:])) + + # Run inference with FirstBlockCache disabled + pipe.transformer.disable_cache() + output = run_forward(pipe).flatten() + image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:])) + + assert np.allclose( + original_image_slice, image_slice_fbc_enabled, atol=expected_atol + ), "FirstBlockCache outputs should not differ much." + assert np.allclose( + original_image_slice, image_slice_fbc_disabled, atol=1e-4 + ), "Outputs from normal inference and after disabling cache should not differ." + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image.
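As a usage sketch of the new guider path added in `pipeline_cogview4.py` above: classifier-free guidance can now be supplied as a guider object instead of only a `guidance_scale` value. The checkpoint name and the `.images` output attribute below are assumptions for illustration, not taken from this diff.

```python
import torch
from diffusers import ClassifierFreeGuidance, CogView4Pipeline

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Pass a guider object explicitly; if omitted, the pipeline falls back to
# ClassifierFreeGuidance(guidance_scale=guidance_scale) internally.
guider = ClassifierFreeGuidance(guidance_scale=5.0)
image = pipe("a photo of an astronaut riding a horse on the moon", guidance=guider).images[0]
image.save("cogview4.png")
```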