From 406b1062f8274b9551058fa1dc79ab62519770fc Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 31 Mar 2025 04:27:35 +0200 Subject: [PATCH 01/19] update --- docs/source/en/api/cache.md | 81 ++++--- src/diffusers/__init__.py | 4 + src/diffusers/hooks/__init__.py | 15 ++ src/diffusers/hooks/_common.py | 30 +++ src/diffusers/hooks/_helpers.py | 199 ++++++++++++++++ src/diffusers/hooks/first_block_cache.py | 220 ++++++++++++++++++ src/diffusers/models/cache_utils.py | 26 ++- .../models/transformers/transformer_ltx.py | 3 +- src/diffusers/utils/dummy_pt_objects.py | 19 ++ tests/pipelines/cogvideo/test_cogvideox.py | 7 +- tests/pipelines/flux/test_pipeline_flux.py | 4 +- .../hunyuan_video/test_hunyuan_video.py | 7 +- tests/pipelines/ltx/test_ltx.py | 8 +- tests/pipelines/mochi/test_mochi.py | 6 +- tests/pipelines/test_pipelines_common.py | 52 ++++- 15 files changed, 632 insertions(+), 49 deletions(-) create mode 100644 src/diffusers/hooks/_common.py create mode 100644 src/diffusers/hooks/_helpers.py create mode 100644 src/diffusers/hooks/first_block_cache.py diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index a6aa5445a845..a1d961cc2974 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -11,6 +11,50 @@ specific language governing permissions and limitations under the License. --> # Caching methods +## Faster Cache + +[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong. + +FasterCache is a method that speeds up inference in diffusion transformers by: +- Reusing attention states between successive inference steps, due to high similarity between them +- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output + +```python +import torch +from diffusers import CogVideoXPipeline, FasterCacheConfig + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +config = FasterCacheConfig( + spatial_attention_block_skip_range=2, + spatial_attention_timestep_skip_range=(-1, 681), + current_timestep_callback=lambda: pipe.current_timestep, + attention_weight_callback=lambda _: 0.3, + unconditional_batch_skip_range=5, + unconditional_batch_timestep_skip_range=(-1, 781), + tensor_format="BFCHW", +) +pipe.transformer.enable_cache(config) +``` + +## First Block Cache + +[First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) is a method that builds upon the ideas of [TeaCache](https://huggingface.co/papers/2411.19108) to speed up inference in diffusion transformers. The generation quality is superior with greatly reduced inference time. This method always computes the output of the first transformer block and computes the differences between past and current outputs of the first transformer block. If the difference is smaller than a predefined threshold, the computation of remaining transformer blocks is skipped, and otherwise the computation is performed as usual. + +```python +import torch +from diffusers import CogVideoXPipeline, FirstBlockCacheConfig + +pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +# Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos. +# Smaller values between 0.02-2.0 are recommended based on the model being used. The default value is 0.05. +config = FirstBlockCacheConfig(threshold=0.07) +pipe.transformer.enable_cache(config) +``` + ## Pyramid Attention Broadcast [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You. @@ -38,45 +82,24 @@ config = PyramidAttentionBroadcastConfig( pipe.transformer.enable_cache(config) ``` -## Faster Cache +### CacheMixin -[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong. +[[autodoc]] CacheMixin -FasterCache is a method that speeds up inference in diffusion transformers by: -- Reusing attention states between successive inference steps, due to high similarity between them -- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output +### FasterCacheConfig -```python -import torch -from diffusers import CogVideoXPipeline, FasterCacheConfig +[[autodoc]] FasterCacheConfig -pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) -pipe.to("cuda") +[[autodoc]] apply_faster_cache -config = FasterCacheConfig( - spatial_attention_block_skip_range=2, - spatial_attention_timestep_skip_range=(-1, 681), - current_timestep_callback=lambda: pipe.current_timestep, - attention_weight_callback=lambda _: 0.3, - unconditional_batch_skip_range=5, - unconditional_batch_timestep_skip_range=(-1, 781), - tensor_format="BFCHW", -) -pipe.transformer.enable_cache(config) -``` +### FirstBlockCacheConfig -### CacheMixin +[[autodoc]] FirstBlockCacheConfig -[[autodoc]] CacheMixin +[[autodoc]] apply_first_block_cache ### PyramidAttentionBroadcastConfig [[autodoc]] PyramidAttentionBroadcastConfig [[autodoc]] apply_pyramid_attention_broadcast - -### FasterCacheConfig - -[[autodoc]] FasterCacheConfig - -[[autodoc]] apply_faster_cache diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 656f9b27db90..2c7372baa678 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -132,9 +132,11 @@ _import_structure["hooks"].extend( [ "FasterCacheConfig", + "FirstBlockCacheConfig", "HookRegistry", "PyramidAttentionBroadcastConfig", "apply_faster_cache", + "apply_first_block_cache", "apply_pyramid_attention_broadcast", ] ) @@ -709,9 +711,11 @@ else: from .hooks import ( FasterCacheConfig, + FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, apply_faster_cache, + apply_first_block_cache, apply_pyramid_attention_broadcast, ) from .models import ( diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 764ceb25b465..365bed371864 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -1,8 +1,23 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from ..utils import is_torch_available if is_torch_available(): from .faster_cache import FasterCacheConfig, apply_faster_cache + from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py new file mode 100644 index 000000000000..3be77dd4cedf --- /dev/null +++ b/src/diffusers/hooks/_common.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..models.attention_processor import Attention, MochiAttention + + +_ATTENTION_CLASSES = (Attention, MochiAttention) + +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") + +_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple( + { + *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, + } +) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py new file mode 100644 index 000000000000..606a58cd578e --- /dev/null +++ b/src/diffusers/hooks/_helpers.py @@ -0,0 +1,199 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Callable, Type + +from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock +from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock +from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from ..models.transformers.transformer_hunyuan_video import ( + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTokenReplaceSingleTransformerBlock, + HunyuanVideoTokenReplaceTransformerBlock, + HunyuanVideoTransformerBlock, +) +from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock +from ..models.transformers.transformer_mochi import MochiTransformerBlock +from ..models.transformers.transformer_wan import WanTransformerBlock + + +@dataclass +class TransformerBlockMetadata: + skip_block_output_fn: Callable[[Any], Any] + return_hidden_states_index: int = None + return_encoder_hidden_states_index: int = None + + +class TransformerBlockRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: TransformerBlockMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> TransformerBlockMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +def _register_transformer_blocks_metadata(): + # CogVideoX + TransformerBlockRegistry.register( + model_class=CogVideoXBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # CogView4 + TransformerBlockRegistry.register( + model_class=CogView4TransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Flux + TransformerBlockRegistry.register( + model_class=FluxTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + TransformerBlockRegistry.register( + model_class=FluxSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # HunyuanVideo + TransformerBlockRegistry.register( + model_class=HunyuanVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceSingleTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # LTXVideo + TransformerBlockRegistry.register( + model_class=LTXVideoTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # Mochi + TransformerBlockRegistry.register( + model_class=MochiTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Wan + TransformerBlockRegistry.register( + model_class=WanTransformerBlock, + metadata=TransformerBlockMetadata( + skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock, + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + +# fmt: off +def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return encoder_hidden_states, hidden_states + + +_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states +_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states +_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +# fmt: on + + +_register_transformer_blocks_metadata() diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py new file mode 100644 index 000000000000..1f1bfd6c8cf9 --- /dev/null +++ b/src/diffusers/hooks/first_block_cache.py @@ -0,0 +1,220 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Tuple, Union + +import torch + +from ..utils import get_logger +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS +from ._helpers import TransformerBlockRegistry +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook" +_FBC_BLOCK_HOOK = "fbc_block_hook" + + +@dataclass +class FirstBlockCacheConfig: + r""" + Configuration for [First Block + Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching). + + Args: + threshold (`float`, defaults to `0.05`): + The threshold to determine whether or not a forward pass through all layers of the model is required. A + higher threshold usually results in lower number of forward passes and faster inference, but might lead to + poorer generation quality. A lower threshold may not result in significant generation speedup. The + threshold is compared against the absmean difference of the residuals between the current and cached + outputs from the first transformer block. If the difference is below the threshold, the forward pass is + skipped. + """ + + threshold: float = 0.05 + + +class FBCSharedBlockState: + def __init__(self) -> None: + self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.head_block_residual: torch.Tensor = None + self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.should_compute: bool = True + + def reset(self): + self.tail_block_residuals = None + self.should_compute = True + + +class FBCHeadBlockHook(ModelHook): + _is_stateful = True + + def __init__(self, shared_state: FBCSharedBlockState, threshold: float): + self.shared_state = shared_state + self.threshold = threshold + self._metadata = None + + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(module.__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs) + original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index] + + output = self.fn_ref.original_forward(*args, **kwargs) + is_output_tuple = isinstance(output, tuple) + + hs_residual = output[self._metadata.return_hidden_states_index] - original_hs + hs, ehs = None, None + + should_compute = self._should_compute_remaining_blocks(hs_residual) + self.shared_state.should_compute = should_compute + + if not should_compute: + # Apply caching + logger.info("Skipping forward pass through remaining blocks") + + if is_output_tuple: + hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] + else: + hs = output + + if self._metadata.return_encoder_hidden_states_index is not None: + ehs = ( + self.shared_state.tail_block_residuals[1] + + output[self._metadata.return_encoder_hidden_states_index] + ) + + if is_output_tuple: + return_output = [None] * len(output) + return_output[self._metadata.return_hidden_states_index] = hs + return_output[self._metadata.return_encoder_hidden_states_index] = ehs + return_output = tuple(return_output) + else: + return_output = hs + return return_output + else: + logger.info("Computing forward pass through remaining blocks") + if is_output_tuple: + head_block_output = [None] * len(output) + head_block_output[0] = output[self._metadata.return_hidden_states_index] + head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index] + else: + head_block_output = output + self.shared_state.head_block_output = head_block_output + self.shared_state.head_block_residual = hs_residual + return output + + def reset_state(self, module): + self.shared_state.reset() + return module + + def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: + if self.shared_state.head_block_residual is None: + return True + prev_hs_residual = self.shared_state.head_block_residual + hs_absmean = (hs_residual - prev_hs_residual).abs().mean() + prev_hs_mean = prev_hs_residual.abs().mean() + diff = (hs_absmean / prev_hs_mean).item() + logger.info(f"Diff: {diff}, Threshold: {self.threshold}") + return diff > self.threshold + + +class FBCBlockHook(ModelHook): + def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False): + super().__init__() + self.shared_state = shared_state + self.is_tail = is_tail + self._metadata = None + + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(module.__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs) + original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index] + original_ehs = None + if self._metadata.return_encoder_hidden_states_index is not None: + original_ehs = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index] + + if self.shared_state.should_compute: + output = self.fn_ref.original_forward(*args, **kwargs) + if self.is_tail: + hs_residual, ehs_residual = None, None + if isinstance(output, tuple): + hs_residual = ( + output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0] + ) + ehs_residual = ( + output[self._metadata.return_encoder_hidden_states_index] + - self.shared_state.head_block_output[1] + ) + else: + if isinstance(self.shared_state.head_block_output, list): + # For cases where double blocks returning list is followed by single blocks returning single value (Flux) + hs_residual = output - self.shared_state.head_block_output[0] + else: + hs_residual = output - self.shared_state.head_block_output + self.shared_state.tail_block_residuals = (hs_residual, ehs_residual) + return output + + output_count = len(outputs_if_skipped) if isinstance(outputs_if_skipped, tuple) else 1 + if output_count == 1: + return_output = original_hs + else: + return_output = [None] * output_count + return_output[self._metadata.return_hidden_states_index] = original_hs + return_output[self._metadata.return_encoder_hidden_states_index] = original_ehs + return return_output + + +def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None: + shared_state = FBCSharedBlockState() + remaining_blocks = [] + + for name, submodule in module.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for block in submodule: + remaining_blocks.append((name, block)) + + head_block_name, head_block = remaining_blocks.pop(0) + tail_block_name, tail_block = remaining_blocks.pop(-1) + + logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'") + apply_fbc_head_block_hook(head_block, shared_state, config.threshold) + + for name, block in remaining_blocks: + logger.debug(f"Apply FBCBlockHook to '{name}'") + apply_fbc_block_hook(block, shared_state) + + logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'") + apply_fbc_block_hook(tail_block, shared_state, is_tail=True) + + +def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = FBCHeadBlockHook(state, threshold) + registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK) + + +def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = FBCBlockHook(state, is_tail) + registry.register_hook(hook, _FBC_BLOCK_HOOK) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 79bd8dc0b254..6d0192239ec5 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -25,6 +25,7 @@ class CacheMixin: Supported caching techniques: - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) - [FasterCache](https://huggingface.co/papers/2410.19355) + - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) """ _cache_config = None @@ -62,8 +63,10 @@ def enable_cache(self, config) -> None: from ..hooks import ( FasterCacheConfig, + FirstBlockCacheConfig, PyramidAttentionBroadcastConfig, apply_faster_cache, + apply_first_block_cache, apply_pyramid_attention_broadcast, ) @@ -72,31 +75,36 @@ def enable_cache(self, config) -> None: f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first." ) - if isinstance(config, PyramidAttentionBroadcastConfig): - apply_pyramid_attention_broadcast(self, config) - elif isinstance(config, FasterCacheConfig): + if isinstance(config, FasterCacheConfig): apply_faster_cache(self, config) + elif isinstance(config, FirstBlockCacheConfig): + apply_first_block_cache(self, config) + elif isinstance(config, PyramidAttentionBroadcastConfig): + apply_pyramid_attention_broadcast(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK + from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") return - if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): - registry = HookRegistry.check_if_exists_or_initialize(self) - registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) - elif isinstance(self._cache_config, FasterCacheConfig): - registry = HookRegistry.check_if_exists_or_initialize(self) + registry = HookRegistry.check_if_exists_or_initialize(self) + if isinstance(self._cache_config, FasterCacheConfig): registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True) registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True) + elif isinstance(self._cache_config, FirstBlockCacheConfig): + registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True) + registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) + elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): + registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index c1f2df587927..2ae2418098f6 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -26,6 +26,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention +from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -298,7 +299,7 @@ def forward( @maybe_allow_in_graph -class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): +class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin): r""" A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6edbd737e32c..dfbac9512e91 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -17,6 +17,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class FirstBlockCacheConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HookRegistry(metaclass=DummyObject): _backends = ["torch"] @@ -51,6 +66,10 @@ def apply_faster_cache(*args, **kwargs): requires_backends(apply_faster_cache, ["torch"]) +def apply_first_block_cache(*args, **kwargs): + requires_backends(apply_first_block_cache, ["torch"]) + + def apply_pyramid_attention_broadcast(*args, **kwargs): requires_backends(apply_pyramid_attention_broadcast, ["torch"]) diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index 388dc9ef7ec4..385984f0b497 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -32,6 +32,7 @@ from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import ( FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, check_qkv_fusion_matches_attn_procs_length, @@ -44,7 +45,11 @@ class CogVideoXPipelineFastTests( - PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + unittest.TestCase, ): pipeline_class = CogVideoXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 6a560367a5b8..b9795fc20b1f 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -25,6 +25,7 @@ from ..test_pipelines_common import ( FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, FluxIPAdapterTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, @@ -34,11 +35,12 @@ class FluxPipelineFastTests( - unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + unittest.TestCase, ): pipeline_class = FluxPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py index aa4f045966c3..e6587520c932 100644 --- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py +++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py @@ -33,6 +33,7 @@ from ..test_pipelines_common import ( FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np, @@ -43,7 +44,11 @@ class HunyuanVideoPipelineFastTests( - PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase + PipelineTesterMixin, + PyramidAttentionBroadcastTesterMixin, + FasterCacheTesterMixin, + FirstBlockCacheTesterMixin, + unittest.TestCase, ): pipeline_class = HunyuanVideoPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py index 4f72729fc9ce..1f94b746f12f 100644 --- a/tests/pipelines/ltx/test_ltx.py +++ b/tests/pipelines/ltx/test_ltx.py @@ -23,13 +23,13 @@ from diffusers.utils.testing_utils import enable_full_determinism, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, to_np +from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np enable_full_determinism() -class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase): pipeline_class = LTXPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): test_layerwise_casting = True test_group_offloading = True - def get_dummy_components(self): + def get_dummy_components(self, num_layers: int = 1): torch.manual_seed(0) transformer = LTXVideoTransformer3DModel( in_channels=8, @@ -59,7 +59,7 @@ def get_dummy_components(self): num_attention_heads=4, attention_head_dim=8, cross_attention_dim=32, - num_layers=1, + num_layers=num_layers, caption_channels=32, ) diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index ea2d015af52a..ce052962e511 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -33,13 +33,15 @@ ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np +from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np enable_full_determinism() -class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase): +class MochiPipelineFastTests( + PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase +): pipeline_class = MochiPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index d069def66ecf..08029419de3b 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -33,6 +33,7 @@ ) from diffusers.hooks import apply_group_offloading from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook +from diffusers.hooks.first_block_cache import FirstBlockCacheConfig from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin @@ -2608,7 +2609,7 @@ def run_forward(pipe): self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep pipe = create_pipe() pipe.transformer.enable_cache(self.faster_cache_config) - output = run_forward(pipe).flatten().flatten() + output = run_forward(pipe).flatten() image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:])) # Run inference with FasterCache disabled @@ -2715,6 +2716,55 @@ def faster_cache_state_check_callback(pipe, i, t, kwargs): self.assertTrue(state.cache is None, "Cache should be reset to None.") +# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out +# of the box once there is better cache support/implementation +class FirstBlockCacheTesterMixin: + # threshold is intentionally set higher than usual values since we're testing with random unconverged models + # that will not satisfy the expected properties of the denoiser for caching to be effective + first_block_cache_config = FirstBlockCacheConfig(threshold=0.8) + + def test_first_block_cache_inference(self, expected_atol: float = 0.1): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + def create_pipe(): + torch.manual_seed(0) + num_layers = 2 + components = self.get_dummy_components(num_layers=num_layers) + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + return pipe + + def run_forward(pipe): + torch.manual_seed(0) + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 4 + return pipe(**inputs)[0] + + # Run inference without FirstBlockCache + pipe = create_pipe() + output = run_forward(pipe).flatten() + original_image_slice = np.concatenate((output[:8], output[-8:])) + + # Run inference with FirstBlockCache enabled + pipe = create_pipe() + pipe.transformer.enable_cache(self.first_block_cache_config) + output = run_forward(pipe).flatten() + image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:])) + + # Run inference with FirstBlockCache disabled + pipe.transformer.disable_cache() + output = run_forward(pipe).flatten() + image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:])) + + assert np.allclose( + original_image_slice, image_slice_fbc_enabled, atol=expected_atol + ), "FirstBlockCache outputs should not differ much." + assert np.allclose( + original_image_slice, image_slice_fbc_disabled, atol=1e-4 + ), "Outputs from normal inference and after disabling cache should not differ." + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image. From dd69b418349bb923155d11371185a0424a1c0041 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 1 Apr 2025 01:28:09 +0200 Subject: [PATCH 02/19] modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code) --- src/diffusers/hooks/_helpers.py | 6 +++--- src/diffusers/hooks/first_block_cache.py | 6 +----- .../models/transformers/transformer_flux.py | 21 ++++++++++--------- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 606a58cd578e..253ca88059e5 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -84,8 +84,8 @@ def _register_transformer_blocks_metadata(): model_class=FluxSingleTransformerBlock, metadata=TransformerBlockMetadata( skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock, - return_hidden_states_index=0, - return_encoder_hidden_states_index=None, + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, ), ) @@ -185,7 +185,7 @@ def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___en _skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states _skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states _skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states -_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states +_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states _skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states _skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states _skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index 1f1bfd6c8cf9..b440af0faed2 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -166,11 +166,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): - self.shared_state.head_block_output[1] ) else: - if isinstance(self.shared_state.head_block_output, list): - # For cases where double blocks returning list is followed by single blocks returning single value (Flux) - hs_residual = output - self.shared_state.head_block_output[0] - else: - hs_residual = output - self.shared_state.head_block_output + hs_residual = output - self.shared_state.head_block_output self.shared_state.tail_block_residuals = (hs_residual, ehs_residual) return output diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 87537890d246..b0fb3900f657 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -79,10 +79,14 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, def forward( self, hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) @@ -100,7 +104,8 @@ def forward( if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) - return hidden_states + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states @maybe_allow_in_graph @@ -508,20 +513,21 @@ def forward( ) else: hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, + encoder_hidden_states, temb, image_rotary_emb, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, @@ -531,12 +537,7 @@ def forward( if controlnet_single_block_samples is not None: interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) interval_control = int(np.ceil(interval_control)) - hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( - hidden_states[:, encoder_hidden_states.shape[1] :, ...] - + controlnet_single_block_samples[index_block // interval_control] - ) - - hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) From 7ab424a15a9f89fc7679cd72a22a1a756959be4e Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 1 Apr 2025 01:39:00 +0200 Subject: [PATCH 03/19] remove debug logs --- src/diffusers/hooks/first_block_cache.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index b440af0faed2..cdc08b4a4c9f 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -87,8 +87,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if not should_compute: # Apply caching - logger.info("Skipping forward pass through remaining blocks") - if is_output_tuple: hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] else: @@ -109,7 +107,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): return_output = hs return return_output else: - logger.info("Computing forward pass through remaining blocks") if is_output_tuple: head_block_output = [None] * len(output) head_block_output[0] = output[self._metadata.return_hidden_states_index] @@ -131,7 +128,6 @@ def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: hs_absmean = (hs_residual - prev_hs_residual).abs().mean() prev_hs_mean = prev_hs_residual.abs().mean() diff = (hs_absmean / prev_hs_mean).item() - logger.info(f"Diff: {diff}, Threshold: {self.threshold}") return diff > self.threshold From d71fe55895c4503b313a2d0c0740b50d51e5eb31 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 1 Apr 2025 17:06:45 +0200 Subject: [PATCH 04/19] update --- src/diffusers/hooks/first_block_cache.py | 14 +++++++++----- .../models/transformers/transformer_wan.py | 3 ++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index cdc08b4a4c9f..f1b150ac75d7 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -105,7 +105,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): return_output = tuple(return_output) else: return_output = hs - return return_output + output = return_output else: if is_output_tuple: head_block_output = [None] * len(output) @@ -115,12 +115,14 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): head_block_output = output self.shared_state.head_block_output = head_block_output self.shared_state.head_block_residual = hs_residual - return output + + return output def reset_state(self, module): self.shared_state.reset() return module + @torch.compiler.disable def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: if self.shared_state.head_block_residual is None: return True @@ -144,6 +146,8 @@ def initialize_hook(self, module): def new_forward(self, module: torch.nn.Module, *args, **kwargs): outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs) + if not isinstance(outputs_if_skipped, tuple): + outputs_if_skipped = (outputs_if_skipped,) original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index] original_ehs = None if self._metadata.return_encoder_hidden_states_index is not None: @@ -166,7 +170,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): self.shared_state.tail_block_residuals = (hs_residual, ehs_residual) return output - output_count = len(outputs_if_skipped) if isinstance(outputs_if_skipped, tuple) else 1 + output_count = len(outputs_if_skipped) if output_count == 1: return_output = original_hs else: @@ -183,8 +187,8 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConf for name, submodule in module.named_children(): if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): continue - for block in submodule: - remaining_blocks.append((name, block)) + for index, block in enumerate(submodule): + remaining_blocks.append((f"{name}.{index}", block)) head_block_name, head_block = remaining_blocks.pop(0) tail_block_name, tail_block = remaining_blocks.pop(-1) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 4eb4add37601..aa03e97093aa 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -24,6 +24,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention +from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -288,7 +289,7 @@ def forward( return hidden_states -class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" A Transformer model for video-like data used in the Wan model. From 2557238b4d33ea60b6c5e1829c065a132aa9c9aa Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 1 Apr 2025 19:40:23 +0200 Subject: [PATCH 05/19] cache context for different batches of data --- src/diffusers/hooks/first_block_cache.py | 7 +- src/diffusers/hooks/hooks.py | 88 +++++++++++++++++++ src/diffusers/models/cache_utils.py | 20 +++++ .../pipelines/cogview4/pipeline_cogview4.py | 4 +- .../hunyuan_video/pipeline_hunyuan_video.py | 4 +- src/diffusers/pipelines/wan/pipeline_wan.py | 4 +- 6 files changed, 122 insertions(+), 5 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index f1b150ac75d7..306825800e76 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -20,7 +20,7 @@ from ..utils import get_logger from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS from ._helpers import TransformerBlockRegistry -from .hooks import HookRegistry, ModelHook +from .hooks import BaseMarkedState, HookRegistry, ModelHook logger = get_logger(__name__) # pylint: disable=invalid-name @@ -48,8 +48,10 @@ class FirstBlockCacheConfig: threshold: float = 0.05 -class FBCSharedBlockState: +class FBCSharedBlockState(BaseMarkedState): def __init__(self) -> None: + super().__init__() + self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None self.head_block_residual: torch.Tensor = None self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None @@ -130,6 +132,7 @@ def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: hs_absmean = (hs_residual - prev_hs_residual).abs().mean() prev_hs_mean = prev_hs_residual.abs().mean() diff = (hs_absmean / prev_hs_mean).item() + print("diff:", self.shared_state._mark_name, diff, flush=True) return diff > self.threshold diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 3b2e4ed91c2f..9e8128d0bb18 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -23,6 +23,70 @@ logger = get_logger(__name__) # pylint: disable=invalid-name +class BaseState: + def reset(self, *args, **kwargs) -> None: + raise NotImplementedError( + "BaseState::reset is not implemented. Please implement this method in the derived class." + ) + + +class BaseMarkedState(BaseState): + def __init__(self, init_args=None, init_kwargs=None): + super().__init__() + + self._init_args = init_args if init_args is not None else () + self._init_kwargs = init_kwargs if init_kwargs is not None else {} + self._mark_name = None + self._state_cache = {} + + def get_current_state(self) -> "BaseMarkedState": + if self._mark_name is None: + # If no mark name is set, simply return a dummy object since we're not going to be using it + return self + if self._mark_name not in self._state_cache.keys(): + self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs) + return self._state_cache[self._mark_name] + + def mark_batch(self, name: str) -> None: + self._mark_name = name + + def reset(self, *args, **kwargs) -> None: + for name, state in list(self._state_cache.items()): + state.reset(*args, **kwargs) + self._state_cache.pop(name) + self._mark_name = None + + def __getattribute__(self, name): + if name in ( + "get_current_state", + "mark_batch", + "reset", + "_init_args", + "_init_kwargs", + "_mark_name", + "_state_cache", + ) or _is_dunder_method(name): + return object.__getattribute__(self, name) + else: + current_state = BaseMarkedState.get_current_state(self) + return object.__getattribute__(current_state, name) + + def __setattr__(self, name, value): + if name in ( + "get_current_state", + "mark_batch", + "reset", + "_init_args", + "_init_kwargs", + "_mark_name", + "_state_cache", + ) or _is_dunder_method(name): + object.__setattr__(self, name, value) + else: + current_state = BaseMarkedState.get_current_state(self) + object.__setattr__(current_state, name, value) + + class ModelHook: r""" A hook that contains callbacks to be executed just before and after the forward method of a model. @@ -99,6 +163,14 @@ def reset_state(self, module: torch.nn.Module): raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") return module + def _mark_state(self, module: torch.nn.Module, name: str) -> None: + # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_batch` on them. + for attr_name in dir(self): + attr = getattr(self, attr_name) + if isinstance(attr, BaseMarkedState): + attr.mark_batch(name) + return module + class HookFunctionReference: def __init__(self) -> None: @@ -223,6 +295,18 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry module._diffusers_hook = cls(module) return module._diffusers_hook + def _mark_state(self, name: str) -> None: + for hook_name in reversed(self._hook_order): + hook = self.hooks[hook_name] + if hook._is_stateful: + hook._mark_state(self._module_ref, name) + + for module_name, module in self._module_ref.named_modules(): + if module_name == "": + continue + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook._mark_state(name) + def __repr__(self) -> str: registry_repr = "" for i, hook_name in enumerate(self._hook_order): @@ -234,3 +318,7 @@ def __repr__(self) -> str: if i < len(self._hook_order) - 1: registry_repr += "\n" return f"HookRegistry(\n{registry_repr}\n)" + + +def _is_dunder_method(name: str) -> bool: + return name.startswith("__") and name.endswith("__") and name in dir(object) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 6d0192239ec5..6c4bcb301d70 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import contextmanager + from ..utils.logging import get_logger @@ -114,3 +116,21 @@ def _reset_stateful_cache(self, recurse: bool = True) -> None: from ..hooks import HookRegistry HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse) + + @contextmanager + def _cache_context(self): + r"""Context manager that provides additional methods for cache management.""" + cache_context = _CacheContextManager(self) + yield cache_context + + +class _CacheContextManager: + def __init__(self, model: CacheMixin): + self.model = model + + def mark_state(self, name: str) -> None: + from ..hooks import HookRegistry + + if self.model.is_cache_enabled: + registry = HookRegistry.check_if_exists_or_initialize(self.model) + registry._mark_state(name) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index c27a1a19774d..6cf74ac5d942 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -610,7 +610,7 @@ def __call__( transformer_dtype = self.transformer.dtype num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -621,6 +621,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]) + cc.mark_state("cond") noise_pred_cond = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, @@ -634,6 +635,7 @@ def __call__( # perform guidance if self.do_classifier_free_guidance: + cc.mark_state("uncond") noise_pred_uncond = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=negative_prompt_embeds, diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 3cb91b3782f2..b36de61c02ef 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -683,7 +683,7 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -693,6 +693,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) + cc.mark_state("cond") noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, @@ -705,6 +706,7 @@ def __call__( )[0] if do_true_cfg: + cc.mark_state("uncond") neg_noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 6fab997e6660..733d79b5ac2c 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -512,7 +512,7 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -521,6 +521,7 @@ def __call__( latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) + cc.mark_state("cond") noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, @@ -530,6 +531,7 @@ def __call__( )[0] if self.do_classifier_free_guidance: + cc.mark_state("uncond") noise_uncond = self.transformer( hidden_states=latent_model_input, timestep=timestep, From 0e232ac8c0a922e2caf137641b29b7b2cc59a529 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 2 Apr 2025 00:38:11 +0200 Subject: [PATCH 06/19] fix hs residual bug for single return outputs; support ltx --- src/diffusers/hooks/first_block_cache.py | 11 +++++++---- src/diffusers/pipelines/ltx/pipeline_ltx.py | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index 306825800e76..1293ded558f0 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -81,9 +81,12 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): output = self.fn_ref.original_forward(*args, **kwargs) is_output_tuple = isinstance(output, tuple) - hs_residual = output[self._metadata.return_hidden_states_index] - original_hs - hs, ehs = None, None + if is_output_tuple: + hs_residual = output[self._metadata.return_hidden_states_index] - original_hs + else: + hs_residual = output - original_hs + hs, ehs = None, None should_compute = self._should_compute_remaining_blocks(hs_residual) self.shared_state.should_compute = should_compute @@ -92,9 +95,10 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if is_output_tuple: hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] else: - hs = output + hs = self.shared_state.tail_block_residuals[0] + output if self._metadata.return_encoder_hidden_states_index is not None: + assert is_output_tuple ehs = ( self.shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index] @@ -132,7 +136,6 @@ def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool: hs_absmean = (hs_residual - prev_hs_residual).abs().mean() prev_hs_mean = prev_hs_residual.abs().mean() diff = (hs_absmean / prev_hs_mean).item() - print("diff:", self.shared_state._mark_name, diff, flush=True) return diff > self.threshold diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index f7b0811d1a22..316fc4d6b722 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -701,7 +701,7 @@ def __call__( ) # 7. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -712,6 +712,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) + cc.mark_state("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, From 41b0c473d2c8da7eef17abf3a1290878ec509f1b Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 2 Apr 2025 01:20:53 +0200 Subject: [PATCH 07/19] fix controlnet flux --- src/diffusers/models/controlnets/controlnet_flux.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index 51c34b7fe965..04ab72e82a03 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -343,25 +343,25 @@ def forward( ) block_samples = block_samples + (hidden_states,) - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - single_block_samples = () for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, + encoder_hidden_states, temb, image_rotary_emb, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, ) - single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) + single_block_samples = single_block_samples + (hidden_states,) # controlnet block controlnet_block_samples = () From 1f33ca276d064b258dc67b285fad5c6c80f43a98 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 2 Apr 2025 01:21:09 +0200 Subject: [PATCH 08/19] support flux, ltx i2v, ltx condition --- src/diffusers/pipelines/flux/pipeline_flux.py | 5 ++++- src/diffusers/pipelines/ltx/pipeline_ltx_condition.py | 3 ++- src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py | 3 ++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 862c279cfaf3..a7195d3a679d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -906,7 +906,7 @@ def __call__( ) # 6. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -917,6 +917,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) + cc.mark_state("cond") noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, @@ -932,6 +933,8 @@ def __call__( if do_true_cfg: if negative_image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + cc.mark_state("uncond") neg_noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py index e7f3666cb2c7..e3b49cb673a3 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py @@ -1061,7 +1061,7 @@ def __call__( self._num_timesteps = len(timesteps) # 7. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -1090,6 +1090,7 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float() timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0) + cc.mark_state("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 6c4214fe1b26..9ee96e6a3954 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -771,7 +771,7 @@ def __call__( ) # 7. Denoising loop - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): if self.interrupt: continue @@ -783,6 +783,7 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]) timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + cc.mark_state("cond_uncond") noise_pred = self.transformer( hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, From c76e1cc17e451724848a96d7f3bbf6c8aa184267 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 2 Apr 2025 21:52:33 +0200 Subject: [PATCH 09/19] update --- src/diffusers/hooks/first_block_cache.py | 5 +++-- src/diffusers/hooks/hooks.py | 17 ++++++++++------- src/diffusers/utils/torch_utils.py | 5 +++++ 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py index 1293ded558f0..7863a1268843 100644 --- a/src/diffusers/hooks/first_block_cache.py +++ b/src/diffusers/hooks/first_block_cache.py @@ -18,6 +18,7 @@ import torch from ..utils import get_logger +from ..utils.torch_utils import unwrap_module from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS from ._helpers import TransformerBlockRegistry from .hooks import BaseMarkedState, HookRegistry, ModelHook @@ -71,7 +72,7 @@ def __init__(self, shared_state: FBCSharedBlockState, threshold: float): self._metadata = None def initialize_hook(self, module): - self._metadata = TransformerBlockRegistry.get(module.__class__) + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): @@ -147,7 +148,7 @@ def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False): self._metadata = None def initialize_hook(self, module): - self._metadata = TransformerBlockRegistry.get(module.__class__) + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 9e8128d0bb18..c42592783d91 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -18,6 +18,7 @@ import torch from ..utils.logging import get_logger +from ..utils.torch_utils import unwrap_module logger = get_logger(__name__) # pylint: disable=invalid-name @@ -47,7 +48,7 @@ def get_current_state(self) -> "BaseMarkedState": self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs) return self._state_cache[self._mark_name] - def mark_batch(self, name: str) -> None: + def mark_state(self, name: str) -> None: self._mark_name = name def reset(self, *args, **kwargs) -> None: @@ -59,7 +60,7 @@ def reset(self, *args, **kwargs) -> None: def __getattribute__(self, name): if name in ( "get_current_state", - "mark_batch", + "mark_state", "reset", "_init_args", "_init_kwargs", @@ -74,7 +75,7 @@ def __getattribute__(self, name): def __setattr__(self, name, value): if name in ( "get_current_state", - "mark_batch", + "mark_state", "reset", "_init_args", "_init_kwargs", @@ -164,11 +165,11 @@ def reset_state(self, module: torch.nn.Module): return module def _mark_state(self, module: torch.nn.Module, name: str) -> None: - # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_batch` on them. + # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them. for attr_name in dir(self): attr = getattr(self, attr_name) if isinstance(attr, BaseMarkedState): - attr.mark_batch(name) + attr.mark_state(name) return module @@ -283,9 +284,10 @@ def reset_stateful_hooks(self, recurse: bool = True) -> None: hook.reset_state(self._module_ref) if recurse: - for module_name, module in self._module_ref.named_modules(): + for module_name, module in unwrap_module(self._module_ref).named_modules(): if module_name == "": continue + module = unwrap_module(module) if hasattr(module, "_diffusers_hook"): module._diffusers_hook.reset_stateful_hooks(recurse=False) @@ -301,9 +303,10 @@ def _mark_state(self, name: str) -> None: if hook._is_stateful: hook._mark_state(self._module_ref, name) - for module_name, module in self._module_ref.named_modules(): + for module_name, module in unwrap_module(self._module_ref).named_modules(): if module_name == "": continue + module = unwrap_module(module) if hasattr(module, "_diffusers_hook"): module._diffusers_hook._mark_state(name) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 3c8911773e39..06f9981f0138 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool: return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) +def unwrap_module(module): + """Unwraps a module if it was compiled with torch.compile()""" + return module._orig_mod if is_compiled_module(module) else module + + def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). From 594e8d663f136e88ff7c4a4574b5b5372a770c57 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 3 Apr 2025 00:13:15 +0200 Subject: [PATCH 10/19] classifier-free guidance --- src/diffusers/__init__.py | 3 + src/diffusers/guiders/__init__.py | 20 ++++ .../guiders/classifier_free_guidance.py | 86 +++++++++++++++++ src/diffusers/guiders/guider_utils.py | 96 +++++++++++++++++++ .../pipelines/cogview4/pipeline_cogview4.py | 74 +++++++------- src/diffusers/utils/dummy_pt_objects.py | 15 +++ 6 files changed, 260 insertions(+), 34 deletions(-) create mode 100644 src/diffusers/guiders/__init__.py create mode 100644 src/diffusers/guiders/classifier_free_guidance.py create mode 100644 src/diffusers/guiders/guider_utils.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 60ac21626943..cb007b1c1d97 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -33,6 +33,7 @@ _import_structure = { "configuration_utils": ["ConfigMixin"], + "guiders": [], "hooks": [], "loaders": ["FromOriginalModelMixin"], "models": [], @@ -129,6 +130,7 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: + _import_structure["guiders"].extend(["ClassifierFreeGuidance"]) _import_structure["hooks"].extend( [ "FasterCacheConfig", @@ -710,6 +712,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: + from .guiders import ClassifierFreeGuidance from .hooks import ( FasterCacheConfig, FirstBlockCacheConfig, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py new file mode 100644 index 000000000000..c56f825512de --- /dev/null +++ b/src/diffusers/guiders/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..utils import is_torch_available + + +if is_torch_available(): + from .classifier_free_guidance import ClassifierFreeGuidance + from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py new file mode 100644 index 000000000000..2de97291c629 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -0,0 +1,86 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch + +from .guider_utils import GuidanceMixin, rescale_noise_cfg + + +class ClassifierFreeGuidance(GuidanceMixin): + """ + Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598 + + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by + jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during + inference. This allows the model to tradeoff between generation quality and sample diversity. + + The original paper proposes scaling and shifting the conditional distribution based on the difference between + conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + + Diffusers implemented the scaling and shifting on the unconditional prediction instead, which is equivalent to what + the original paper proposed in theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + + The intution behind the original formulation can be thought of as moving the conditional distribution estimates + further away from the unconditional distribution estimates, while the diffusers-native implementation can be + thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of + the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) + + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the + paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + + Args: + scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. + """ + + def __init__(self, scale: float = 7.5, rescale: float = 0.0, use_original_formulation: bool = False): + self.scale = scale + self.rescale = rescale + self.use_original_formulation = use_original_formulation + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + if math.isclose(self.scale, 1.0): + return pred_cond + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.scale * shift + if self.rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.rescale) + return pred + + @property + def num_conditions(self) -> int: + if math.isclose(self.scale, 1.0): + return 1 + return 2 + + @property + def guidance_scale(self) -> float: + return self.scale + + @property + def guidance_rescale(self) -> float: + return self.rescale diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py new file mode 100644 index 000000000000..54d3c519556e --- /dev/null +++ b/src/diffusers/guiders/guider_utils.py @@ -0,0 +1,96 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Tuple, Union + +import torch + +from ..utils import deprecate, get_logger + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class GuidanceMixin: + r"""Base mixin class providing the skeleton for implementing guidance techniques.""" + + def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + num_conditions = self.num_conditions + list_of_inputs = [] + for arg in args: + if isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + inputs = [x for x in arg if x is not None] + if len(inputs) < num_conditions: + raise ValueError(f"Required at least {num_conditions} inputs, but got {len(inputs)}.") + list_of_inputs.append(inputs[:num_conditions]) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." + ) + return tuple(list_of_inputs) + + def __call__(self, *args) -> Any: + if len(args) != self.num_conditions: + raise ValueError( + f"Expected {self.num_conditions} arguments, but got {len(args)}. Please provide the correct number of arguments." + ) + return self.forward(*args) + + def forward(self, *args, **kwargs) -> Any: + raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.") + + @property + def num_conditions(self) -> int: + raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.") + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def _raise_guidance_deprecation_warning( + *, + guidance_scale: Optional[float] = None, + guidance_rescale: Optional[float] = None, +) -> None: + if guidance_scale is not None: + msg = "The `guidance_scale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead." + deprecate("guidance_scale", "1.0.0", msg, standard_warn=False) + if guidance_rescale is not None: + msg = "The `guidance_rescale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead." + deprecate("guidance_rescale", "1.0.0", msg, standard_warn=False) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index 6cf74ac5d942..568cfd04b833 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -21,6 +21,7 @@ from transformers import AutoTokenizer, GlmModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning from ...image_processor import VaeImageProcessor from ...loaders import CogView4LoraLoaderMixin from ...models import AutoencoderKL, CogView4Transformer2DModel @@ -428,6 +429,7 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 1024, + guidance: Optional[GuidanceMixin] = None, ) -> Union[CogView4PipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -516,6 +518,10 @@ def __call__( `tuple`. When returning a tuple, the first element is a list with the generated images. """ + _raise_guidance_deprecation_warning(guidance_scale=guidance_scale) + if guidance is None: + guidance = ClassifierFreeGuidance(scale=guidance_scale) + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -606,52 +612,45 @@ def __call__( ) self._num_timesteps = len(timesteps) + latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs( + latents, + (prompt_embeds, negative_prompt_embeds), + original_size, + target_size, + crops_coords_top_left, + ) + # Denoising loop transformer_dtype = self.transformer.dtype num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): + self._current_timestep = t if self.interrupt: continue - self._current_timestep = t - latent_model_input = latents.to(transformer_dtype) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]) - - cc.mark_state("cond") - noise_pred_cond = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - original_size=original_size, - target_size=target_size, - crop_coords=crops_coords_top_left, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - cc.mark_state("uncond") - noise_pred_uncond = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=negative_prompt_embeds, + noise_preds = [] + for i, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate( + zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left) + ): + cc.mark_state(f"batch_{i}") + latent = latent.to(transformer_dtype) + timestep = t.expand(latent.shape[0]) + noise_pred = self.transformer( + hidden_states=latent, + encoder_hidden_states=condition, timestep=timestep, - original_size=original_size, - target_size=target_size, - crop_coords=crops_coords_top_left, + original_size=original_size_c, + target_size=target_size_c, + crop_coords=crop_coord_c, attention_kwargs=attention_kwargs, return_dict=False, )[0] + noise_preds.append(noise_pred) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) - else: - noise_pred = noise_pred_cond - - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + noise_pred = guidance(*noise_preds) + latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0] # call the callback, if provided if callback_on_step_end is not None: @@ -660,8 +659,14 @@ def __call__( callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])] + negative_prompt_embeds = [ + callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0]) + ] + + latents, prompt_embeds = guidance.prepare_inputs( + latents, (prompt_embeds[0], negative_prompt_embeds[0]) + ) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() @@ -670,6 +675,7 @@ def __call__( xm.mark_step() self._current_timestep = None + latents = latents[0] if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index dfbac9512e91..7ae9ca4c67a6 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class ClassifierFreeGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class FasterCacheConfig(metaclass=DummyObject): _backends = ["torch"] From 5ac7f360b015cd505159a90c3e68bfb9a93a7709 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 3 Apr 2025 03:26:55 +0200 Subject: [PATCH 11/19] skip layer guidance --- src/diffusers/__init__.py | 8 +- src/diffusers/guiders/__init__.py | 1 + .../guiders/classifier_free_guidance.py | 35 ++-- src/diffusers/guiders/guider_utils.py | 32 ++- src/diffusers/guiders/skip_layer_guidance.py | 195 ++++++++++++++++++ src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/layer_skip.py | 110 ++++++++++ .../pipelines/cogview4/pipeline_cogview4.py | 33 +-- src/diffusers/utils/dummy_pt_objects.py | 34 +++ 9 files changed, 407 insertions(+), 42 deletions(-) create mode 100644 src/diffusers/guiders/skip_layer_guidance.py create mode 100644 src/diffusers/hooks/layer_skip.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index cb007b1c1d97..91c41bdd438e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -130,15 +130,17 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["guiders"].extend(["ClassifierFreeGuidance"]) + _import_structure["guiders"].extend(["ClassifierFreeGuidance", "SkipLayerGuidance"]) _import_structure["hooks"].extend( [ "FasterCacheConfig", "FirstBlockCacheConfig", "HookRegistry", + "LayerSkipConfig", "PyramidAttentionBroadcastConfig", "apply_faster_cache", "apply_first_block_cache", + "apply_layer_skip", "apply_pyramid_attention_broadcast", ] ) @@ -712,14 +714,16 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: - from .guiders import ClassifierFreeGuidance + from .guiders import ClassifierFreeGuidance, SkipLayerGuidance from .hooks import ( FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, + LayerSkipConfig, PyramidAttentionBroadcastConfig, apply_faster_cache, apply_first_block_cache, + apply_layer_skip, apply_pyramid_attention_broadcast, ) from .models import ( diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index c56f825512de..9724d307560f 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -18,3 +18,4 @@ if is_torch_available(): from .classifier_free_guidance import ClassifierFreeGuidance from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning + from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 2de97291c629..18f2a2d31bdd 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -43,11 +43,11 @@ class ClassifierFreeGuidance(GuidanceMixin): paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. Args: - scale (`float`, defaults to `7.5`): + guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and deterioration of image quality. - rescale (`float`, defaults to `0.0`): + guidance_rescale (`float`, defaults to `0.0`): The rescale factor applied to the noise predictions. This is used to improve image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). @@ -56,31 +56,26 @@ class ClassifierFreeGuidance(GuidanceMixin): we use the diffusers-native implementation that has been in the codebase for a long time. """ - def __init__(self, scale: float = 7.5, rescale: float = 0.0, use_original_formulation: bool = False): - self.scale = scale - self.rescale = rescale + def __init__( + self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False + ): + self.guidance_scale = guidance_scale + self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: - if math.isclose(self.scale, 1.0): + if math.isclose(self.guidance_scale, 1.0): return pred_cond shift = pred_cond - pred_uncond pred = pred_cond if self.use_original_formulation else pred_uncond - pred = pred + self.scale * shift - if self.rescale > 0.0: - pred = rescale_noise_cfg(pred, pred_cond, self.rescale) + pred = pred + self.guidance_scale * shift + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) return pred @property def num_conditions(self) -> int: - if math.isclose(self.scale, 1.0): - return 1 - return 2 - - @property - def guidance_scale(self) -> float: - return self.scale - - @property - def guidance_rescale(self) -> float: - return self.rescale + num_conditions = 1 + if not math.isclose(self.guidance_scale, 1.0): + num_conditions += 1 + return num_conditions diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 54d3c519556e..413a33c41cef 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -25,6 +25,19 @@ class GuidanceMixin: r"""Base mixin class providing the skeleton for implementing guidance techniques.""" + def __init__(self): + self._step: int = None + self._num_inference_steps: int = None + self._timestep: torch.LongTensor = None + + def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: + self._step = step + self._num_inference_steps = num_inference_steps + self._timestep = timestep + + def prepare_models(self, denoiser: torch.nn.Module) -> None: + pass + def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: num_conditions = self.num_conditions list_of_inputs = [] @@ -32,16 +45,27 @@ def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) if isinstance(arg, torch.Tensor): list_of_inputs.append([arg] * num_conditions) elif isinstance(arg, (tuple, list)): - inputs = [x for x in arg if x is not None] - if len(inputs) < num_conditions: - raise ValueError(f"Required at least {num_conditions} inputs, but got {len(inputs)}.") - list_of_inputs.append(inputs[:num_conditions]) + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + # Alternating conditional and unconditional inputs as batches + inputs = [arg[i % 2] for i in range(num_conditions)] + list_of_inputs.append(inputs) else: raise ValueError( f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." ) return tuple(list_of_inputs) + def cleanup_models(self, denoiser: torch.nn.Module) -> None: + pass + def __call__(self, *args) -> Any: if len(args) != self.num_conditions: raise ValueError( diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py new file mode 100644 index 000000000000..677d97a47c9f --- /dev/null +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -0,0 +1,195 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch + +from ..hooks import HookRegistry, LayerSkipConfig +from ..hooks.layer_skip import _apply_layer_skip_hook +from .guider_utils import GuidanceMixin, rescale_noise_cfg + + +class SkipLayerGuidance(GuidanceMixin): + """ + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 + + CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by + jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during + inference. This allows the model to tradeoff between generation quality and sample diversity. + + The original paper proposes scaling and shifting the conditional distribution based on the difference between + conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] + + Diffusers implemented the scaling and shifting on the unconditional prediction instead, which is equivalent to what + the original paper proposed in theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + + The intution behind the original formulation can be thought of as moving the conditional distribution estimates + further away from the unconditional distribution estimates, while the diffusers-native implementation can be + thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of + the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) + + The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the + paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. + """ + + def __init__( + self, + guidance_scale: float = 7.5, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_start: float = 0.01, + skip_layer_guidance_stop: float = 0.2, + skip_guidance_layers: Optional[Union[int, List[int]]] = None, + skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + ): + self.guidance_scale = guidance_scale + self.skip_layer_guidance_scale = skip_layer_guidance_scale + self.skip_layer_guidance_start = skip_layer_guidance_start + self.skip_layer_guidance_stop = skip_layer_guidance_stop + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + if skip_guidance_layers is None and skip_layer_config is None: + raise ValueError( + "Either `skip_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." + ) + if skip_guidance_layers is not None and skip_layer_config is not None: + raise ValueError("Only one of `skip_guidance_layers` or `skip_layer_config` can be provided.") + + if skip_guidance_layers is not None: + if isinstance(skip_guidance_layers, int): + skip_guidance_layers = [skip_guidance_layers] + if not isinstance(skip_guidance_layers, list): + raise ValueError( + f"Expected `skip_guidance_layers` to be an int or a list of ints, but got {type(skip_guidance_layers)}." + ) + skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_guidance_layers] + + if isinstance(skip_layer_config, LayerSkipConfig): + skip_layer_config = [skip_layer_config] + + if not isinstance(skip_layer_config, list): + raise ValueError( + f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}." + ) + + self.skip_layer_config = skip_layer_config + self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))] + + def prepare_models(self, denoiser: torch.nn.Module): + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + + # Register the hooks for layer skipping if the step is within the specified range + if skip_start_step < self._step < skip_stop_step: + for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config): + _apply_layer_skip_hook(denoiser, config, name=name) + + def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + num_conditions = self.num_conditions + list_of_inputs = [] + for arg in args: + if isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + list_of_inputs.append([arg[0], arg[1], arg[0]]) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." + ) + return tuple(list_of_inputs) + + def cleanup_models(self, denoiser: torch.nn.Module): + registry = HookRegistry.check_if_exists_or_initialize(denoiser) + # Remove the hooks after inference + for hook_name in self._skip_layer_hook_names: + registry.remove_hook(hook_name, recurse=True) + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_cond_skip: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + + if math.isclose(self.guidance_scale, 1.0) and math.isclose(self.skip_layer_guidance_scale, 1.0): + pred = pred_cond + + elif math.isclose(self.guidance_scale, 1.0): + if skip_start_step < self._step < skip_stop_step: + shift = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift + else: + pred = pred_cond + + elif math.isclose(self.skip_layer_guidance_scale, 1.0): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if skip_start_step < self._step < skip_stop_step: + shift_skip = pred_cond - pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift_skip + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + + if not math.isclose(self.guidance_scale, 1.0): + num_conditions += 1 + if not math.isclose(self.skip_layer_guidance_scale, 1.0) and skip_start_step < self._step < skip_stop_step: + num_conditions += 1 + + return num_conditions diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 365bed371864..2db36d4366f1 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -20,5 +20,6 @@ from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook + from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py new file mode 100644 index 000000000000..45f9365bcdce --- /dev/null +++ b/src/diffusers/hooks/layer_skip.py @@ -0,0 +1,110 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS +from ._helpers import TransformerBlockRegistry +from .hooks import HookRegistry, ModelHook + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_LAYER_SKIP_HOOK = "layer_skip_hook" + + +@dataclass +class LayerSkipConfig: + r""" + Configuration for skipping internal transformer blocks when executing a transformer model. + + Args: + indices (`List[int]`): + The indices of the layer to skip. This is typically the first layer in the transformer block. + fqn (`str`, defaults to `"auto"`): + The fully qualified name identifying the stack of transformer blocks. Typically, this is + `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`. + """ + + indices: List[int] + fqn: str = "auto" + + +class LayerSkipHook(ModelHook): + def initialize_hook(self, module): + self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + return self._metadata.skip_block_output_fn(module, *args, **kwargs) + + +def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: + r""" + Apply layer skipping to internal layers of a transformer. + + Args: + module (`torch.nn.Module`): + The transformer model to which the layer skip hook should be applied. + config (`LayerSkipConfig`): + The configuration for the layer skip hook. + + Example: + + ```python + >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig + + >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) + >>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks") + >>> apply_layer_skip_hook(transformer, config) + ``` + """ + _apply_layer_skip_hook(module, config) + + +def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None: + name = name or _LAYER_SKIP_HOOK + + if config.fqn == "auto": + for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: + if hasattr(module, identifier): + config.fqn = identifier + break + else: + raise ValueError( + "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid " + "`fqn` (fully qualified name) that identifies a stack of transformer blocks." + ) + + transformer_blocks = getattr(module, config.fqn, None) + if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList): + raise ValueError( + f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify " + f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks." + ) + if len(config.indices) == 0: + raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.") + + for i, block in enumerate(transformer_blocks): + if i not in config.indices: + continue + logger.debug(f"Apply LayerSkipHook to '{config.fqn}.{i}'") + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = LayerSkipHook() + registry.register_hook(hook, name) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index 568cfd04b833..5d63b588a66d 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -520,7 +520,7 @@ def __call__( _raise_guidance_deprecation_warning(guidance_scale=guidance_scale) if guidance is None: - guidance = ClassifierFreeGuidance(scale=guidance_scale) + guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -612,29 +612,34 @@ def __call__( ) self._num_timesteps = len(timesteps) - latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs( - latents, - (prompt_embeds, negative_prompt_embeds), - original_size, - target_size, - crops_coords_top_left, - ) - # Denoising loop transformer_dtype = self.transformer.dtype num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left] + prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds] + with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc: for i, t in enumerate(timesteps): self._current_timestep = t if self.interrupt: continue + guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + guidance.prepare_models(self.transformer) + latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs( + latents, + (prompt_embeds[0], negative_prompt_embeds[0]), + original_size[0], + target_size[0], + crops_coords_top_left[0], + ) + noise_preds = [] - for i, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate( + for batch_index, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate( zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left) ): - cc.mark_state(f"batch_{i}") + cc.mark_state(f"batch_{batch_index}") latent = latent.to(transformer_dtype) timestep = t.expand(latent.shape[0]) noise_pred = self.transformer( @@ -651,6 +656,7 @@ def __call__( noise_pred = guidance(*noise_preds) latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0] + guidance.cleanup_models(self.transformer) # call the callback, if provided if callback_on_step_end is not None: @@ -664,10 +670,6 @@ def __call__( callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0]) ] - latents, prompt_embeds = guidance.prepare_inputs( - latents, (prompt_embeds[0], negative_prompt_embeds[0]) - ) - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() @@ -675,7 +677,6 @@ def __call__( xm.mark_step() self._current_timestep = None - latents = latents[0] if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 7ae9ca4c67a6..3c0f45461b63 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -17,6 +17,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class SkipLayerGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class FasterCacheConfig(metaclass=DummyObject): _backends = ["torch"] @@ -62,6 +77,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LayerSkipConfig(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class PyramidAttentionBroadcastConfig(metaclass=DummyObject): _backends = ["torch"] @@ -85,6 +115,10 @@ def apply_first_block_cache(*args, **kwargs): requires_backends(apply_first_block_cache, ["torch"]) +def apply_layer_skip(*args, **kwargs): + requires_backends(apply_layer_skip, ["torch"]) + + def apply_pyramid_attention_broadcast(*args, **kwargs): requires_backends(apply_pyramid_attention_broadcast, ["torch"]) From d91d10737aaf0823c455a4bd487fefd81559a14f Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 3 Apr 2025 03:43:10 +0200 Subject: [PATCH 12/19] update slg docstring --- src/diffusers/guiders/skip_layer_guidance.py | 29 ++++++++------------ 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 677d97a47c9f..97d6d0b4e9ec 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -26,23 +26,18 @@ class SkipLayerGuidance(GuidanceMixin): """ Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 - CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by - jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during - inference. This allows the model to tradeoff between generation quality and sample diversity. - - The original paper proposes scaling and shifting the conditional distribution based on the difference between - conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] - - Diffusers implemented the scaling and shifting on the unconditional prediction instead, which is equivalent to what - the original paper proposed in theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] - - The intution behind the original formulation can be thought of as moving the conditional distribution estimates - further away from the unconditional distribution estimates, while the diffusers-native implementation can be - thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of - the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.) - - The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the - paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time. + SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by + skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional + batch of data, apart from the conditional and unconditional batches already used in CFG + ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions + based on the difference between conditional without skipping and conditional with skipping predictions. + + The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from + worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse + version of the model for the conditional prediction). + + Additional reading: + - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) Args: guidance_scale (`float`, defaults to `7.5`): From 9997c223a82c9ef47a39e968cc41ce308ac87fae Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 3 Apr 2025 03:50:30 +0200 Subject: [PATCH 13/19] more slg improvements --- src/diffusers/guiders/skip_layer_guidance.py | 29 +++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 97d6d0b4e9ec..50f864331ad0 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -39,18 +39,36 @@ class SkipLayerGuidance(GuidanceMixin): Additional reading: - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) + The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are + defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium. + Args: guidance_scale (`float`, defaults to `7.5`): The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and deterioration of image quality. + skip_layer_guidance_scale (`float`, defaults to `2.8`): + The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher + values, but it may also lead to overexposure and saturation. + skip_layer_guidance_start (`float`, defaults to `0.01`): + The fraction of the total number of denoising steps after which skip layer guidance starts. + skip_layer_guidance_stop (`float`, defaults to `0.2`): + The fraction of the total number of denoising steps after which skip layer guidance stops. + skip_guidance_layers (`int` or `List[int]`, *optional*): + The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not + provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion + 3.5 Medium. + skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): + The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of + `LayerSkipConfig`. If not provided, `skip_guidance_layers` must be provided. guidance_rescale (`float`, defaults to `0.0`): The rescale factor applied to the noise predictions. This is used to improve image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). use_original_formulation (`bool`, defaults to `False`): Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, - we use the diffusers-native implementation that has been in the codebase for a long time. + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. """ def __init__( @@ -71,6 +89,15 @@ def __init__( self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation + if not (0.0 <= skip_layer_guidance_start < 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}." + ) + if not (0.0 < skip_layer_guidance_stop <= 1.0): + raise ValueError( + f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}." + ) + if skip_guidance_layers is None and skip_layer_config is None: raise ValueError( "Either `skip_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." From 05d74ef3e7e1bc09888145dd27a3b82844280189 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 3 Apr 2025 04:21:24 +0200 Subject: [PATCH 14/19] cfg zero star --- src/diffusers/__init__.py | 6 +- src/diffusers/guiders/__init__.py | 1 + .../guiders/classifier_free_guidance.py | 13 ++- .../classifier_free_zero_star_guidance.py | 100 ++++++++++++++++++ src/diffusers/guiders/skip_layer_guidance.py | 3 - src/diffusers/utils/dummy_pt_objects.py | 15 +++ 6 files changed, 129 insertions(+), 9 deletions(-) create mode 100644 src/diffusers/guiders/classifier_free_zero_star_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 91c41bdd438e..e0d629087e88 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -130,7 +130,9 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["guiders"].extend(["ClassifierFreeGuidance", "SkipLayerGuidance"]) + _import_structure["guiders"].extend( + ["ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance"] + ) _import_structure["hooks"].extend( [ "FasterCacheConfig", @@ -714,7 +716,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: - from .guiders import ClassifierFreeGuidance, SkipLayerGuidance + from .guiders import ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance from .hooks import ( FasterCacheConfig, FirstBlockCacheConfig, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 9724d307560f..af6c961e23ad 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -17,5 +17,6 @@ if is_torch_available(): from .classifier_free_guidance import ClassifierFreeGuidance + from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 18f2a2d31bdd..96ac875db5fd 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -64,13 +64,18 @@ def __init__( self.use_original_formulation = use_original_formulation def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + if math.isclose(self.guidance_scale, 1.0): - return pred_cond - shift = pred_cond - pred_uncond - pred = pred_cond if self.use_original_formulation else pred_uncond - pred = pred + self.guidance_scale * shift + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + return pred @property diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py new file mode 100644 index 000000000000..518b108554f6 --- /dev/null +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -0,0 +1,100 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch + +from .guider_utils import GuidanceMixin, rescale_noise_cfg + + +class ClassifierFreeZeroStarGuidance(GuidanceMixin): + """ + Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886 + + This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free + guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion + process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the + quality of generated images. + + The authors of the paper suggest setting zero initialization in the first 4% of the inference steps. + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + zero_init_steps (`int`, defaults to `1`): + The number of inference steps for which the noise predictions are zeroed out (see Section 4.2). + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + """ + + def __init__( + self, + guidance_scale: float = 7.5, + zero_init_steps: int = 1, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + ): + self.guidance_scale = guidance_scale + self.zero_init_steps = zero_init_steps + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if self._step < self.zero_init_steps: + pred = torch.zeros_like(pred_cond) + elif math.isclose(self.guidance_scale, 1.0): + pred = pred_cond + else: + shift = pred_cond - pred_uncond + pred_cond_flat = pred_cond.flatten(1) + pred_uncond_flat = pred_uncond.flatten(1) + alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat) + alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1)) + pred_uncond = pred_uncond * alpha + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if not math.isclose(self.guidance_scale, 1.0): + num_conditions += 1 + return num_conditions + + +def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + cond = cond.float() + uncond = uncond.float() + dot_product = torch.sum(cond * uncond, dim=1, keepdim=True) + squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + scale = dot_product / squared_norm + return scale.to(cond.dtype) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 50f864331ad0..120b0d632ba0 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -175,7 +175,6 @@ def forward( if math.isclose(self.guidance_scale, 1.0) and math.isclose(self.skip_layer_guidance_scale, 1.0): pred = pred_cond - elif math.isclose(self.guidance_scale, 1.0): if skip_start_step < self._step < skip_stop_step: shift = pred_cond - pred_cond_skip @@ -183,12 +182,10 @@ def forward( pred = pred + self.skip_layer_guidance_scale * shift else: pred = pred_cond - elif math.isclose(self.skip_layer_guidance_scale, 1.0): shift = pred_cond - pred_uncond pred = pred_cond if self.use_original_formulation else pred_uncond pred = pred + self.guidance_scale * shift - else: shift = pred_cond - pred_uncond pred = pred_cond if self.use_original_formulation else pred_uncond diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3c0f45461b63..9e9e2cdfbb70 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -17,6 +17,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ClassifierFreeZeroStarGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class SkipLayerGuidance(metaclass=DummyObject): _backends = ["torch"] From 77324c40c4de9353b95e091cf4f8f2a91e32ce86 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 3 Apr 2025 05:01:54 +0200 Subject: [PATCH 15/19] adaptive projected guidance --- src/diffusers/__init__.py | 9 +- src/diffusers/guiders/__init__.py | 1 + .../guiders/adaptive_projected_guidance.py | 134 ++++++++++++++++++ .../classifier_free_zero_star_guidance.py | 2 +- src/diffusers/utils/dummy_pt_objects.py | 15 ++ 5 files changed, 158 insertions(+), 3 deletions(-) create mode 100644 src/diffusers/guiders/adaptive_projected_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e0d629087e88..58db74395ea2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -131,7 +131,7 @@ else: _import_structure["guiders"].extend( - ["ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance"] + ["AdaptiveProjectedGuidance", "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance"] ) _import_structure["hooks"].extend( [ @@ -716,7 +716,12 @@ except OptionalDependencyNotAvailable: from .utils.dummy_pt_objects import * # noqa F403 else: - from .guiders import ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance + from .guiders import ( + AdaptiveProjectedGuidance, + ClassifierFreeGuidance, + ClassifierFreeZeroStarGuidance, + SkipLayerGuidance, + ) from .hooks import ( FasterCacheConfig, FirstBlockCacheConfig, diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index af6c961e23ad..9b6cf6093152 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -16,6 +16,7 @@ if is_torch_available(): + from .adaptive_projected_guidance import AdaptiveProjectedGuidance from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py new file mode 100644 index 000000000000..e05bdfb429a1 --- /dev/null +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -0,0 +1,134 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import torch + +from .guider_utils import GuidanceMixin, rescale_noise_cfg + + +class AdaptiveProjectedGuidance(GuidanceMixin): + """ + Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416 + + Args: + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + adaptive_projected_guidance_momentum (`float`, defaults to `None`): + The momentum parameter for the adaptive projected guidance. Disabled if set to `None`. + adaptive_projected_guidance_rescale (`float`, defaults to `15.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. + """ + + def __init__( + self, + guidance_scale: float = 7.5, + adaptive_projected_guidance_momentum: Optional[float] = None, + adaptive_projected_guidance_rescale: float = 15.0, + eta: float = 1.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + ): + self.guidance_scale = guidance_scale + self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum + self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale + self.eta = eta + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def prepare_inputs(self, *args): + if self._step == 0: + if self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + return super().prepare_inputs(*args) + + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: + pred = None + + if math.isclose(self.guidance_scale, 1.0): + pred = pred_cond + else: + pred = normalized_guidance( + pred_cond, + pred_uncond, + self.guidance_scale, + self.momentum_buffer, + self.eta, + self.adaptive_projected_guidance_rescale, + self.use_original_formulation, + ) + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if not math.isclose(self.guidance_scale, 1.0): + num_conditions += 1 + return num_conditions + + +class MomentumBuffer: + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + +def normalized_guidance( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + momentum_buffer: Optional[MomentumBuffer] = None, + eta: float = 1.0, + norm_threshold: float = 0.0, + use_original_formulation: bool = False, +): + diff = pred_cond - pred_uncond + dim = [-i for i in range(1, len(diff.shape))] + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + if norm_threshold > 0: + ones = torch.ones_like(diff) + diff_norm = diff.norm(p=2, dim=dim, keepdim=True) + scale_factor = torch.minimum(ones, norm_threshold / diff_norm) + diff = diff * scale_factor + v0, v1 = diff.double(), pred_cond.double() + v1 = torch.nn.functional.normalize(v1, dim=dim) + v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) + normalized_update = diff_orthogonal + eta * diff_parallel + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + (guidance_scale - 1) * normalized_update + return pred diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 518b108554f6..8507b2892217 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -97,4 +97,4 @@ def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1 squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps # st_star = v_cond^T * v_uncond / ||v_uncond||^2 scale = dot_product / squared_norm - return scale.to(cond.dtype) + return scale.type_as(cond) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 9e9e2cdfbb70..be0765c99dd7 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class AdaptiveProjectedGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ClassifierFreeGuidance(metaclass=DummyObject): _backends = ["torch"] From 46643564a3379c51f2adf8174cd69706eeddcd4b Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 4 Apr 2025 01:41:34 +0200 Subject: [PATCH 16/19] refactor --- .../guiders/adaptive_projected_guidance.py | 14 ++++++- .../guiders/classifier_free_guidance.py | 14 ++++++- .../classifier_free_zero_star_guidance.py | 14 ++++++- src/diffusers/guiders/guider_utils.py | 32 ++++++++++++--- src/diffusers/guiders/skip_layer_guidance.py | 41 +++++++++++++++---- .../pipelines/cogview4/pipeline_cogview4.py | 6 +-- 6 files changed, 98 insertions(+), 23 deletions(-) diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index e05bdfb429a1..4ee52e2376f6 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -42,6 +42,8 @@ class AdaptiveProjectedGuidance(GuidanceMixin): we use the diffusers-native implementation that has been in the codebase for a long time. """ + _input_predictions = ["pred_cond", "pred_uncond"] + def __init__( self, guidance_scale: float = 7.5, @@ -51,6 +53,8 @@ def __init__( guidance_rescale: float = 0.0, use_original_formulation: bool = False, ): + super().__init__() + self.guidance_scale = guidance_scale self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale @@ -68,7 +72,7 @@ def prepare_inputs(self, *args): def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: pred = None - if math.isclose(self.guidance_scale, 1.0): + if self._is_cfg_enabled(): pred = pred_cond else: pred = normalized_guidance( @@ -89,10 +93,16 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = @property def num_conditions(self) -> int: num_conditions = 1 - if not math.isclose(self.guidance_scale, 1.0): + if self._is_cfg_enabled(): num_conditions += 1 return num_conditions + def _is_cfg_enabled(self) -> bool: + if self.use_original_formulation: + return not math.isclose(self.guidance_scale, 0.0) + else: + return not math.isclose(self.guidance_scale, 1.0) + class MomentumBuffer: def __init__(self, momentum: float): diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 96ac875db5fd..95bb380e19d1 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -56,9 +56,13 @@ class ClassifierFreeGuidance(GuidanceMixin): we use the diffusers-native implementation that has been in the codebase for a long time. """ + _input_predictions = ["pred_cond", "pred_uncond"] + def __init__( self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False ): + super().__init__() + self.guidance_scale = guidance_scale self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation @@ -66,7 +70,7 @@ def __init__( def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor: pred = None - if math.isclose(self.guidance_scale, 1.0): + if not self._is_cfg_enabled(): pred = pred_cond else: shift = pred_cond - pred_uncond @@ -81,6 +85,12 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = @property def num_conditions(self) -> int: num_conditions = 1 - if not math.isclose(self.guidance_scale, 1.0): + if self._is_cfg_enabled(): num_conditions += 1 return num_conditions + + def _is_cfg_enabled(self) -> bool: + if self.use_original_formulation: + return not math.isclose(self.guidance_scale, 0.0) + else: + return not math.isclose(self.guidance_scale, 1.0) diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 8507b2892217..f34675e1a93a 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -48,6 +48,8 @@ class ClassifierFreeZeroStarGuidance(GuidanceMixin): [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. """ + _input_predictions = ["pred_cond", "pred_uncond"] + def __init__( self, guidance_scale: float = 7.5, @@ -55,6 +57,8 @@ def __init__( guidance_rescale: float = 0.0, use_original_formulation: bool = False, ): + super().__init__() + self.guidance_scale = guidance_scale self.zero_init_steps = zero_init_steps self.guidance_rescale = guidance_rescale @@ -65,7 +69,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = if self._step < self.zero_init_steps: pred = torch.zeros_like(pred_cond) - elif math.isclose(self.guidance_scale, 1.0): + elif self._is_cfg_enabled(): pred = pred_cond else: shift = pred_cond - pred_uncond @@ -85,10 +89,16 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = @property def num_conditions(self) -> int: num_conditions = 1 - if not math.isclose(self.guidance_scale, 1.0): + if self._is_cfg_enabled(): num_conditions += 1 return num_conditions + def _is_cfg_enabled(self) -> bool: + if self.use_original_formulation: + return not math.isclose(self.guidance_scale, 0.0) + else: + return not math.isclose(self.guidance_scale, 1.0) + def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: cond = cond.float() diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 413a33c41cef..015713f564d1 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -25,15 +25,26 @@ class GuidanceMixin: r"""Base mixin class providing the skeleton for implementing guidance techniques.""" + _input_predictions = None + def __init__(self): self._step: int = None self._num_inference_steps: int = None self._timestep: torch.LongTensor = None + self._preds: Dict[str, torch.Tensor] = {} + self._num_outputs_prepared: int = 0 + + if self._input_predictions is None or not isinstance(self._input_predictions, list): + raise ValueError( + "`_input_predictions` must be a list of required prediction names for the guidance technique." + ) def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None: self._step = step self._num_inference_steps = num_inference_steps self._timestep = timestep + self._preds = {} + self._num_outputs_prepared = 0 def prepare_models(self, denoiser: torch.nn.Module) -> None: pass @@ -63,15 +74,22 @@ def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) ) return tuple(list_of_inputs) + def prepare_outputs(self, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + self._preds[key] = pred + def cleanup_models(self, denoiser: torch.nn.Module) -> None: pass - def __call__(self, *args) -> Any: - if len(args) != self.num_conditions: + def __call__(self, **kwargs) -> Any: + if len(kwargs) != self.num_conditions: raise ValueError( - f"Expected {self.num_conditions} arguments, but got {len(args)}. Please provide the correct number of arguments." + f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments." ) - return self.forward(*args) + return self.forward(**kwargs) def forward(self, *args, **kwargs) -> Any: raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.") @@ -80,6 +98,10 @@ def forward(self, *args, **kwargs) -> Any: def num_conditions(self) -> int: raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.") + @property + def outputs(self) -> Dict[str, torch.Tensor]: + return self._preds + def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): r""" diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 120b0d632ba0..6d0d9cb4a7e6 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -71,6 +71,8 @@ class SkipLayerGuidance(GuidanceMixin): [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. """ + _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + def __init__( self, guidance_scale: float = 7.5, @@ -82,6 +84,8 @@ def __init__( guidance_rescale: float = 0.0, use_original_formulation: bool = False, ): + super().__init__() + self.guidance_scale = guidance_scale self.skip_layer_guidance_scale = skip_layer_guidance_scale self.skip_layer_guidance_start = skip_layer_guidance_start @@ -157,6 +161,18 @@ def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) ) return tuple(list_of_inputs) + def prepare_outputs(self, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + if not self._is_cfg_enabled() and self._is_slg_enabled(): + # If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip + # to avoid writing into pred_uncond which is not used + if self._num_outputs_prepared == 2: + key = "pred_cond_skip" + self._preds[key] = pred + def cleanup_models(self, denoiser: torch.nn.Module): registry = HookRegistry.check_if_exists_or_initialize(denoiser) # Remove the hooks after inference @@ -173,16 +189,16 @@ def forward( skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) - if math.isclose(self.guidance_scale, 1.0) and math.isclose(self.skip_layer_guidance_scale, 1.0): + if not self._is_cfg_enabled() and not self._is_slg_enabled(): pred = pred_cond - elif math.isclose(self.guidance_scale, 1.0): + elif not self._is_cfg_enabled(): if skip_start_step < self._step < skip_stop_step: shift = pred_cond - pred_cond_skip pred = pred_cond if self.use_original_formulation else pred_cond_skip pred = pred + self.skip_layer_guidance_scale * shift else: pred = pred_cond - elif math.isclose(self.skip_layer_guidance_scale, 1.0): + elif not self._is_slg_enabled(): shift = pred_cond - pred_uncond pred = pred_cond if self.use_original_formulation else pred_uncond pred = pred + self.guidance_scale * shift @@ -203,12 +219,19 @@ def forward( @property def num_conditions(self) -> int: num_conditions = 1 - skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) - skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) - - if not math.isclose(self.guidance_scale, 1.0): + if self._is_cfg_enabled(): num_conditions += 1 - if not math.isclose(self.skip_layer_guidance_scale, 1.0) and skip_start_step < self._step < skip_stop_step: + if self._is_slg_enabled(): num_conditions += 1 - return num_conditions + + def _is_cfg_enabled(self) -> bool: + if self.use_original_formulation: + return not math.isclose(self.guidance_scale, 0.0) + else: + return not math.isclose(self.guidance_scale, 1.0) + + def _is_slg_enabled(self) -> bool: + skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) + skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) + return skip_start_step < self._step < skip_stop_step diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py index 5d63b588a66d..5d7e09c389ae 100644 --- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -635,7 +635,6 @@ def __call__( crops_coords_top_left[0], ) - noise_preds = [] for batch_index, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate( zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left) ): @@ -652,9 +651,10 @@ def __call__( attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_preds.append(noise_pred) + guidance.prepare_outputs(noise_pred) - noise_pred = guidance(*noise_preds) + outputs = guidance.outputs + noise_pred = guidance(**outputs) latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0] guidance.cleanup_models(self.transformer) From 53b6b9fcb6d1f578041398b4c1c0741bdc67f422 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 4 Apr 2025 04:16:46 +0200 Subject: [PATCH 17/19] perturbed attention guidance --- src/diffusers/__init__.py | 9 +- src/diffusers/guiders/__init__.py | 1 + src/diffusers/guiders/guider_utils.py | 73 ++++++- .../guiders/perturbed_attention_guidance.py | 180 ++++++++++++++++++ src/diffusers/guiders/skip_layer_guidance.py | 22 +-- src/diffusers/hooks/_helpers.py | 36 +++- .../transformers/transformer_cogview4.py | 81 ++++++++ src/diffusers/utils/dummy_pt_objects.py | 15 ++ 8 files changed, 400 insertions(+), 17 deletions(-) create mode 100644 src/diffusers/guiders/perturbed_attention_guidance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 58db74395ea2..8b5257f46c8d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -131,7 +131,13 @@ else: _import_structure["guiders"].extend( - ["AdaptiveProjectedGuidance", "ClassifierFreeGuidance", "ClassifierFreeZeroStarGuidance", "SkipLayerGuidance"] + [ + "AdaptiveProjectedGuidance", + "ClassifierFreeGuidance", + "ClassifierFreeZeroStarGuidance", + "PerturbedAttentionGuidance", + "SkipLayerGuidance", + ] ) _import_structure["hooks"].extend( [ @@ -720,6 +726,7 @@ AdaptiveProjectedGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, + PerturbedAttentionGuidance, SkipLayerGuidance, ) from .hooks import ( diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py index 9b6cf6093152..3893b30935c0 100644 --- a/src/diffusers/guiders/__init__.py +++ b/src/diffusers/guiders/__init__.py @@ -20,4 +20,5 @@ from .classifier_free_guidance import ClassifierFreeGuidance from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning + from .perturbed_attention_guidance import PerturbedAttentionGuidance from .skip_layer_guidance import SkipLayerGuidance diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 015713f564d1..36a9fa552e54 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -12,13 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +import re +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch from ..utils import deprecate, get_logger +if TYPE_CHECKING: + from ..models.attention_processor import AttentionProcessor + + logger = get_logger(__name__) # pylint: disable=invalid-name @@ -129,6 +134,72 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg +def _replace_attention_processors( + module: torch.nn.Module, + pag_applied_layers: Optional[Union[str, List[str]]] = None, + skip_context_attention: bool = False, + processors: Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]] = None, + metadata_name: Optional[str] = None, +) -> Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]]: + if processors is not None and metadata_name is not None: + raise ValueError("Cannot pass both `processors` and `metadata_name` at the same time.") + if metadata_name is not None: + if isinstance(pag_applied_layers, str): + pag_applied_layers = [pag_applied_layers] + return _replace_layers_with_guidance_processors( + module, pag_applied_layers, skip_context_attention, metadata_name + ) + if processors is not None: + _replace_layers_with_existing_processors(processors) + + +def _replace_layers_with_guidance_processors( + module: torch.nn.Module, + pag_applied_layers: List[str], + skip_context_attention: bool, + metadata_name: str, +) -> List[Tuple[torch.nn.Module, "AttentionProcessor"]]: + from ..hooks._common import _ATTENTION_CLASSES + from ..hooks._helpers import GuidanceMetadataRegistry + + processors = [] + for name, submodule in module.named_modules(): + if ( + (not isinstance(submodule, _ATTENTION_CLASSES)) + or (getattr(submodule, "processor", None) is None) + or not ( + any( + re.search(pag_layer, name) is not None and not _is_fake_integral_match(pag_layer, name) + for pag_layer in pag_applied_layers + ) + ) + ): + continue + old_attention_processor = submodule.processor + metadata = GuidanceMetadataRegistry.get(old_attention_processor.__class__) + new_attention_processor_cls = getattr(metadata, metadata_name) + new_attention_processor = new_attention_processor_cls() + # !!! dunder methods cannot be replaced on instances !!! + # if "skip_context_attention" in inspect.signature(new_attention_processor.__call__).parameters: + # new_attention_processor.__call__ = partial( + # new_attention_processor.__call__, skip_context_attention=skip_context_attention + # ) + submodule.processor = new_attention_processor + processors.append((submodule, old_attention_processor)) + return processors + + +def _replace_layers_with_existing_processors(processors: List[Tuple[torch.nn.Module, "AttentionProcessor"]]) -> None: + for module, proc in processors: + module.processor = proc + + +def _is_fake_integral_match(layer_id, name): + layer_id = layer_id.split(".")[-1] + name = name.split(".")[-1] + return layer_id.isnumeric() and name.isnumeric() and layer_id == name + + def _raise_guidance_deprecation_warning( *, guidance_scale: Optional[float] = None, diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py new file mode 100644 index 000000000000..87a5fb9614f6 --- /dev/null +++ b/src/diffusers/guiders/perturbed_attention_guidance.py @@ -0,0 +1,180 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch + +from .guider_utils import GuidanceMixin, _replace_attention_processors, rescale_noise_cfg + + +class PerturbedAttentionGuidance(GuidanceMixin): + """ + Perturbed Attention Guidance (PAB): https://huggingface.co/papers/2403.17377 + + Args: + pag_applied_layers (`str` or `List[str]`): + The name of the attention layers where Perturbed Attention Guidance is applied. This can be a single layer + name or a list of layer names. The names should either be FQNs (fully qualified names) to each attention + layer or a regex pattern that matches the FQNs of the attention layers. For example, if you want to apply + PAG to transformer blocks 10 and 20, you can set this to `["transformer_blocks.10", + "transformer_blocks.20"]`, or `"transformer_blocks.(10|20)"`. + guidance_scale (`float`, defaults to `7.5`): + The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text + prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and + deterioration of image quality. + pag_scale (`float`, defaults to `3.0`): + The scale parameter for perturbed attention guidance. + guidance_rescale (`float`, defaults to `0.0`): + The rescale factor applied to the noise predictions. This is used to improve image quality and fix + overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891). + use_original_formulation (`bool`, defaults to `False`): + Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. + """ + + _input_predictions = ["pred_cond", "pred_uncond", "pred_perturbed"] + + def __init__( + self, + pag_applied_layers: Union[str, List[str]], + guidance_scale: float = 7.5, + pag_scale: float = 3.0, + skip_context_attention: bool = False, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + ): + super().__init__() + + self.pag_applied_layers = pag_applied_layers + self.guidance_scale = guidance_scale + self.pag_scale = pag_scale + self.skip_context_attention = skip_context_attention + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + + self._is_pag_batch = False + self._original_processors = None + self._denoiser = None + + def prepare_models(self, denoiser: torch.nn.Module): + self._denoiser = denoiser + + def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]: + num_conditions = self.num_conditions + list_of_inputs = [] + for arg in args: + if isinstance(arg, torch.Tensor): + list_of_inputs.append([arg] * num_conditions) + elif isinstance(arg, (tuple, list)): + if len(arg) != 2: + raise ValueError( + f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 " + f"with the first element being the conditional input and the second element being the unconditional input or None." + ) + if arg[1] is None: + # Only conditioning inputs for all batches + list_of_inputs.append([arg[0]] * num_conditions) + else: + list_of_inputs.append([arg[0], arg[1], arg[0]]) + else: + raise ValueError( + f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list." + ) + return tuple(list_of_inputs) + + def prepare_outputs(self, pred: torch.Tensor) -> None: + self._num_outputs_prepared += 1 + if self._num_outputs_prepared > self.num_conditions: + raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.") + key = self._input_predictions[self._num_outputs_prepared - 1] + if not self._is_cfg_enabled() and self._is_pag_enabled(): + # If we're predicting pred_cond and pred_perturbed only, we need to set the key to pred_perturbed + # to avoid writing into pred_uncond which is not used + if self._num_outputs_prepared == 2: + key = "pred_perturbed" + self._preds[key] = pred + + # Prepare denoiser for perturbed attention prediction if needed + if not self._is_pag_enabled(): + return + should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or ( + not self._is_cfg_enabled() and self._num_outputs_prepared == 1 + ) + if should_register_pag: + self._is_pag_batch = True + self._original_processors = _replace_attention_processors( + self._denoiser, + self.pag_applied_layers, + skip_context_attention=self.skip_context_attention, + metadata_name="perturbed_attention_guidance_processor_cls", + ) + elif self._is_pag_batch: + # Restore the original attention processors + _replace_attention_processors(self._denoiser, processors=self._original_processors) + self._is_pag_batch = False + self._original_processors = None + + def cleanup_models(self, denoiser: torch.nn.Module): + self._denoiser = None + + def forward( + self, + pred_cond: torch.Tensor, + pred_uncond: Optional[torch.Tensor] = None, + pred_perturbed: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pred = None + + if not self._is_cfg_enabled() and not self._is_pag_enabled(): + pred = pred_cond + elif not self._is_cfg_enabled(): + shift = pred_cond - pred_perturbed + pred = pred_cond + self.pag_scale * shift + elif not self._is_pag_enabled(): + shift = pred_cond - pred_uncond + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + else: + shift = pred_cond - pred_uncond + shift_perturbed = pred_cond - pred_perturbed + pred = pred_cond if self.use_original_formulation else pred_uncond + pred = pred + self.guidance_scale * shift + self.pag_scale * shift_perturbed + + if self.guidance_rescale > 0.0: + pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) + + return pred + + @property + def num_conditions(self) -> int: + num_conditions = 1 + if self._is_cfg_enabled(): + num_conditions += 1 + if self._is_pag_enabled(): + num_conditions += 1 + return num_conditions + + def _is_cfg_enabled(self) -> bool: + if self.use_original_formulation: + return not math.isclose(self.guidance_scale, 0.0) + else: + return not math.isclose(self.guidance_scale, 1.0) + + def _is_pag_enabled(self) -> bool: + is_zero = math.isclose(self.pag_scale, 0.0) + return not is_zero diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 6d0d9cb4a7e6..61f452a97673 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -186,30 +186,22 @@ def forward( pred_cond_skip: Optional[torch.Tensor] = None, ) -> torch.Tensor: pred = None - skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) - skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) if not self._is_cfg_enabled() and not self._is_slg_enabled(): pred = pred_cond elif not self._is_cfg_enabled(): - if skip_start_step < self._step < skip_stop_step: - shift = pred_cond - pred_cond_skip - pred = pred_cond if self.use_original_formulation else pred_cond_skip - pred = pred + self.skip_layer_guidance_scale * shift - else: - pred = pred_cond + shift = pred_cond - pred_cond_skip + pred = pred_cond if self.use_original_formulation else pred_cond_skip + pred = pred + self.skip_layer_guidance_scale * shift elif not self._is_slg_enabled(): shift = pred_cond - pred_uncond pred = pred_cond if self.use_original_formulation else pred_uncond pred = pred + self.guidance_scale * shift else: shift = pred_cond - pred_uncond + shift_skip = pred_cond - pred_cond_skip pred = pred_cond if self.use_original_formulation else pred_uncond - pred = pred + self.guidance_scale * shift - - if skip_start_step < self._step < skip_stop_step: - shift_skip = pred_cond - pred_cond_skip - pred = pred + self.skip_layer_guidance_scale * shift_skip + pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip if self.guidance_rescale > 0.0: pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale) @@ -234,4 +226,6 @@ def _is_cfg_enabled(self) -> bool: def _is_slg_enabled(self) -> bool: skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps) skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps) - return skip_start_step < self._step < skip_stop_step + is_within_range = skip_start_step < self._step < skip_stop_step + is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0) + return is_within_range and not is_zero diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 253ca88059e5..c87468001e1f 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -16,7 +16,11 @@ from typing import Any, Callable, Type from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock -from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock +from ..models.transformers.transformer_cogview4 import ( + CogView4AttnProcessor, + CogView4PAGAttnProcessor, + CogView4TransformerBlock, +) from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock from ..models.transformers.transformer_hunyuan_video import ( HunyuanVideoSingleTransformerBlock, @@ -50,6 +54,25 @@ def get(cls, model_class: Type) -> TransformerBlockMetadata: return cls._registry[model_class] +@dataclass +class GuidanceMetadata: + perturbed_attention_guidance_processor_cls: Type = None + + +class GuidanceMetadataRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: GuidanceMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> GuidanceMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + def _register_transformer_blocks_metadata(): # CogVideoX TransformerBlockRegistry.register( @@ -154,6 +177,16 @@ def _register_transformer_blocks_metadata(): ) +def _register_guidance_metadata(): + # CogView4 + GuidanceMetadataRegistry.register( + model_class=CogView4AttnProcessor, + metadata=GuidanceMetadata( + perturbed_attention_guidance_processor_cls=CogView4PAGAttnProcessor, + ), + ) + + # fmt: off def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): hidden_states = kwargs.get("hidden_states", None) @@ -197,3 +230,4 @@ def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___en _register_transformer_blocks_metadata() +_register_guidance_metadata() diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index 41c4cbbf97c7..d393490c7e8c 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -460,3 +460,84 @@ def forward( if not return_dict: return (output,) return Transformer2DModelOutput(sample=output) + + +### ===== Custom attention processors for guidance methods ===== ### + + +class CogView4PAGAttnProcessor: + """ + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + skip_context_attention: bool = False, + ) -> torch.Tensor: + batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape + batch_size, image_seq_length, embed_dim = hidden_states.shape + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + query[:, :, text_seq_length:, :] = apply_rotary_emb( + query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + ) + key[:, :, text_seq_length:, :] = apply_rotary_emb( + key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + ) + + # 4. Attention + if skip_context_attention: + hidden_states = value + else: + # PAG uses a custom attention mask for perturbed attention path: + # - Create attention mask with `float("-inf")` for image tokens and `0.0` for text tokens + # - Set diagonal to `0.0` for attention between image tokens + seq_length = text_seq_length + image_seq_length + perturbed_attention_mask = hidden_states.new_full((seq_length, seq_length), float("-inf")) + perturbed_attention_mask[:text_seq_length, :text_seq_length] = 0.0 + perturbed_attention_mask.fill_diagonal_(0.0) + perturbed_attention_mask = perturbed_attention_mask.unsqueeze(0).unsqueeze(0) + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=perturbed_attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + # 5. Output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index be0765c99dd7..1e933483a5ca 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -47,6 +47,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class PerturbedAttentionGuidance(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class SkipLayerGuidance(metaclass=DummyObject): _backends = ["torch"] From 357f4f056bdb3b3077f4bc3d14c7c9e0bcd3837d Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 4 Apr 2025 22:44:02 +0200 Subject: [PATCH 18/19] update --- .../guiders/adaptive_projected_guidance.py | 3 +- .../guiders/classifier_free_guidance.py | 8 +++-- .../guiders/perturbed_attention_guidance.py | 32 +++++++++---------- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 4ee52e2376f6..06faaa80e894 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -39,7 +39,8 @@ class AdaptiveProjectedGuidance(GuidanceMixin): Flawed](https://huggingface.co/papers/2305.08891). use_original_formulation (`bool`, defaults to `False`): Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, - we use the diffusers-native implementation that has been in the codebase for a long time. + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. """ _input_predictions = ["pred_cond", "pred_uncond"] diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 95bb380e19d1..3423eb3f5fc2 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -31,8 +31,9 @@ class ClassifierFreeGuidance(GuidanceMixin): The original paper proposes scaling and shifting the conditional distribution based on the difference between conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)] - Diffusers implemented the scaling and shifting on the unconditional prediction instead, which is equivalent to what - the original paper proposed in theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] + Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen + paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in + theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)] The intution behind the original formulation can be thought of as moving the conditional distribution estimates further away from the unconditional distribution estimates, while the diffusers-native implementation can be @@ -53,7 +54,8 @@ class ClassifierFreeGuidance(GuidanceMixin): Flawed](https://huggingface.co/papers/2305.08891). use_original_formulation (`bool`, defaults to `False`): Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default, - we use the diffusers-native implementation that has been in the codebase for a long time. + we use the diffusers-native implementation that has been in the codebase for a long time. See + [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details. """ _input_predictions = ["pred_cond", "pred_uncond"] diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py index 87a5fb9614f6..242860886ec4 100644 --- a/src/diffusers/guiders/perturbed_attention_guidance.py +++ b/src/diffusers/guiders/perturbed_attention_guidance.py @@ -109,26 +109,26 @@ def prepare_outputs(self, pred: torch.Tensor) -> None: key = "pred_perturbed" self._preds[key] = pred - # Prepare denoiser for perturbed attention prediction if needed - if not self._is_pag_enabled(): - return - should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or ( - not self._is_cfg_enabled() and self._num_outputs_prepared == 1 - ) - if should_register_pag: - self._is_pag_batch = True - self._original_processors = _replace_attention_processors( - self._denoiser, - self.pag_applied_layers, - skip_context_attention=self.skip_context_attention, - metadata_name="perturbed_attention_guidance_processor_cls", - ) - elif self._is_pag_batch: - # Restore the original attention processors + # Restore the original attention processors if previously replaced + if self._is_pag_batch: _replace_attention_processors(self._denoiser, processors=self._original_processors) self._is_pag_batch = False self._original_processors = None + # Prepare denoiser for perturbed attention prediction if needed + if self._is_pag_enabled(): + should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or ( + not self._is_cfg_enabled() and self._num_outputs_prepared == 1 + ) + if should_register_pag: + self._is_pag_batch = True + self._original_processors = _replace_attention_processors( + self._denoiser, + self.pag_applied_layers, + skip_context_attention=self.skip_context_attention, + metadata_name="perturbed_attention_guidance_processor_cls", + ) + def cleanup_models(self, denoiser: torch.nn.Module): self._denoiser = None From b30cf5d4528353bc86e87d50c809d8d3d981d4c6 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 5 Apr 2025 10:39:46 +0200 Subject: [PATCH 19/19] spatio temporal guidance --- src/diffusers/guiders/skip_layer_guidance.py | 32 +++---- src/diffusers/hooks/_common.py | 2 + src/diffusers/hooks/_helpers.py | 79 ++++++++++++++---- src/diffusers/hooks/layer_skip.py | 88 ++++++++++++++++++-- 4 files changed, 161 insertions(+), 40 deletions(-) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 61f452a97673..1fe09ddac615 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -24,7 +24,8 @@ class SkipLayerGuidance(GuidanceMixin): """ - Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 + Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG): + https://huggingface.co/papers/2411.18664 SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional @@ -36,6 +37,9 @@ class SkipLayerGuidance(GuidanceMixin): worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse version of the model for the conditional prediction). + STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving + generation quality in video diffusion models. + Additional reading: - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507) @@ -54,13 +58,13 @@ class SkipLayerGuidance(GuidanceMixin): The fraction of the total number of denoising steps after which skip layer guidance starts. skip_layer_guidance_stop (`float`, defaults to `0.2`): The fraction of the total number of denoising steps after which skip layer guidance stops. - skip_guidance_layers (`int` or `List[int]`, *optional*): + skip_layer_guidance_layers (`int` or `List[int]`, *optional*): The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion 3.5 Medium. skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*): The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of - `LayerSkipConfig`. If not provided, `skip_guidance_layers` must be provided. + `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided. guidance_rescale (`float`, defaults to `0.0`): The rescale factor applied to the noise predictions. This is used to improve image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are @@ -79,7 +83,7 @@ def __init__( skip_layer_guidance_scale: float = 2.8, skip_layer_guidance_start: float = 0.01, skip_layer_guidance_stop: float = 0.2, - skip_guidance_layers: Optional[Union[int, List[int]]] = None, + skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None, skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None, guidance_rescale: float = 0.0, use_original_formulation: bool = False, @@ -102,21 +106,21 @@ def __init__( f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}." ) - if skip_guidance_layers is None and skip_layer_config is None: + if skip_layer_guidance_layers is None and skip_layer_config is None: raise ValueError( - "Either `skip_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." + "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance." ) - if skip_guidance_layers is not None and skip_layer_config is not None: - raise ValueError("Only one of `skip_guidance_layers` or `skip_layer_config` can be provided.") + if skip_layer_guidance_layers is not None and skip_layer_config is not None: + raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.") - if skip_guidance_layers is not None: - if isinstance(skip_guidance_layers, int): - skip_guidance_layers = [skip_guidance_layers] - if not isinstance(skip_guidance_layers, list): + if skip_layer_guidance_layers is not None: + if isinstance(skip_layer_guidance_layers, int): + skip_layer_guidance_layers = [skip_layer_guidance_layers] + if not isinstance(skip_layer_guidance_layers, list): raise ValueError( - f"Expected `skip_guidance_layers` to be an int or a list of ints, but got {type(skip_guidance_layers)}." + f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}." ) - skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_guidance_layers] + skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers] if isinstance(skip_layer_config, LayerSkipConfig): skip_layer_config = [skip_layer_config] diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py index 3be77dd4cedf..6ea83dcbf6a7 100644 --- a/src/diffusers/hooks/_common.py +++ b/src/diffusers/hooks/_common.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from ..models.attention import FeedForward, LuminaFeedForward from ..models.attention_processor import Attention, MochiAttention _ATTENTION_CLASSES = (Attention, MochiAttention) +_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward) _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index c87468001e1f..9dabc7b286b5 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -33,6 +33,16 @@ from ..models.transformers.transformer_wan import WanTransformerBlock +@dataclass +class AttentionProcessorMetadata: + skip_processor_output_fn: Callable[[Any], Any] + + +@dataclass +class GuidanceMetadata: + perturbed_attention_guidance_processor_cls: Type = None + + @dataclass class TransformerBlockMetadata: skip_block_output_fn: Callable[[Any], Any] @@ -40,25 +50,20 @@ class TransformerBlockMetadata: return_encoder_hidden_states_index: int = None -class TransformerBlockRegistry: +class AttentionProcessorRegistry: _registry = {} @classmethod - def register(cls, model_class: Type, metadata: TransformerBlockMetadata): + def register(cls, model_class: Type, metadata: AttentionProcessorMetadata): cls._registry[model_class] = metadata @classmethod - def get(cls, model_class: Type) -> TransformerBlockMetadata: + def get(cls, model_class: Type) -> AttentionProcessorMetadata: if model_class not in cls._registry: raise ValueError(f"Model class {model_class} not registered.") return cls._registry[model_class] -@dataclass -class GuidanceMetadata: - perturbed_attention_guidance_processor_cls: Type = None - - class GuidanceMetadataRegistry: _registry = {} @@ -73,6 +78,40 @@ def get(cls, model_class: Type) -> GuidanceMetadata: return cls._registry[model_class] +class TransformerBlockRegistry: + _registry = {} + + @classmethod + def register(cls, model_class: Type, metadata: TransformerBlockMetadata): + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> TransformerBlockMetadata: + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + +def _register_attention_processors_metadata(): + # CogView4 + AttentionProcessorRegistry.register( + model_class=CogView4AttnProcessor, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor, + ), + ) + + +def _register_guidance_metadata(): + # CogView4 + GuidanceMetadataRegistry.register( + model_class=CogView4AttnProcessor, + metadata=GuidanceMetadata( + perturbed_attention_guidance_processor_cls=CogView4PAGAttnProcessor, + ), + ) + + def _register_transformer_blocks_metadata(): # CogVideoX TransformerBlockRegistry.register( @@ -177,17 +216,20 @@ def _register_transformer_blocks_metadata(): ) -def _register_guidance_metadata(): - # CogView4 - GuidanceMetadataRegistry.register( - model_class=CogView4AttnProcessor, - metadata=GuidanceMetadata( - perturbed_attention_guidance_processor_cls=CogView4PAGAttnProcessor, - ), - ) +# fmt: off +def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states -# fmt: off def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs): hidden_states = kwargs.get("hidden_states", None) if hidden_states is None and len(args) > 0: @@ -229,5 +271,6 @@ def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___en # fmt: on -_register_transformer_blocks_metadata() +_register_attention_processors_metadata() _register_guidance_metadata() +_register_transformer_blocks_metadata() diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 45f9365bcdce..30b169e8f4af 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -13,14 +13,14 @@ # limitations under the License. from dataclasses import dataclass -from typing import List, Optional +from typing import Callable, List, Optional import torch from ..utils import get_logger from ..utils.torch_utils import unwrap_module -from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS -from ._helpers import TransformerBlockRegistry +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES +from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry from .hooks import HookRegistry, ModelHook @@ -44,9 +44,50 @@ class LayerSkipConfig: indices: List[int] fqn: str = "auto" + skip_attention: bool = True + skip_attention_scores: bool = False + skip_ff: bool = True -class LayerSkipHook(ModelHook): +class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): + def __init__(self) -> None: + super().__init__() + + def __torch_function__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func is torch.nn.functional.scaled_dot_product_attention: + value = kwargs.get("value", None) + if value is None: + value = args[2] + return value + return func(*args, **kwargs) + + +class AttentionProcessorSkipHook(ModelHook): + def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False): + self.skip_processor_output_fn = skip_processor_output_fn + self.skip_attention_scores = skip_attention_scores + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.skip_attention_scores: + with AttentionScoreSkipFunctionMode(): + return self.fn_ref.original_forward(*args, **kwargs) + else: + return self.skip_processor_output_fn(module, *args, **kwargs) + + +class FeedForwardSkipHook(ModelHook): + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + output = kwargs.get("hidden_states", None) + if output is None: + output = kwargs.get("x", None) + if output is None and len(args) > 0: + output = args[0] + return output + + +class TransformerBlockSkipHook(ModelHook): def initialize_hook(self, module): self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__) return module @@ -81,6 +122,9 @@ def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None: def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None: name = name or _LAYER_SKIP_HOOK + if config.skip_attention and config.skip_attention_scores: + raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.") + if config.fqn == "auto": for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS: if hasattr(module, identifier): @@ -101,10 +145,38 @@ def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, nam if len(config.indices) == 0: raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.") + blocks_found = False for i, block in enumerate(transformer_blocks): if i not in config.indices: continue - logger.debug(f"Apply LayerSkipHook to '{config.fqn}.{i}'") - registry = HookRegistry.check_if_exists_or_initialize(block) - hook = LayerSkipHook() - registry.register_hook(hook, name) + blocks_found = True + if config.skip_attention and config.skip_ff: + logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'") + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = TransformerBlockSkipHook() + registry.register_hook(hook, name) + elif config.skip_attention or config.skip_attention_scores: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention: + logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'") + output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores) + registry.register_hook(hook, name) + elif config.skip_ff: + for submodule_name, submodule in block.named_modules(): + if isinstance(submodule, _FEEDFORWARD_CLASSES): + logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'") + registry = HookRegistry.check_if_exists_or_initialize(submodule) + hook = FeedForwardSkipHook() + registry.register_hook(hook, name) + else: + raise ValueError( + "At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True." + ) + + if not blocks_found: + raise ValueError( + f"Could not find any transformer blocks matching the provided indices {config.indices} and " + f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness." + )