From a4bfa451fe38496dfd2a48a18076d0baf12b0999 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Thu, 13 Nov 2025 15:06:36 +0700 Subject: [PATCH 1/7] init taylor_seer cache --- src/diffusers/hooks/taylorseer_cache.py | 118 ++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 src/diffusers/hooks/taylorseer_cache.py diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py new file mode 100644 index 000000000000..2f8f6a4b476e --- /dev/null +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -0,0 +1,118 @@ +# Experimental hook for TaylorSeer cache +# Supports Flux only for now + +import torch +from dataclasses import dataclass +from typing import Callable +from .hooks import ModelHook +import math +from ..models.attention import Attention +from ..models.attention import AttentionModuleMixin +from ._common import ( + _ATTENTION_CLASSES, +) +from ..hooks import HookRegistry + +_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" + +@dataclass +class TaylorSeerCacheConfig: + fresh_threshold: int = 5 # interleave cache and compute: `fresh_threshold` steps are cached, then 1 full compute step is performed + max_order: int = 1 # order of Taylor series expansion + current_timestep_callback: Callable[[], int] = None + +class TaylorSeerState: + def __init__(self): + self.predict_counter: int = 1 + self.last_step: int = 1000 + self.taylor_factors: dict[int, torch.Tensor] = {} + + def reset(self): + self.predict_counter = 1 + self.last_step = 1000 + self.taylor_factors = {} + + def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int): + N = math.abs(current_step - self.last_step) + # initialize the first order taylor factors + new_taylor_factors = {0: features} + for i in range(max_order): + if (self.taylor_factors.get(i) is not None) and current_step > 1: + new_taylor_factors[i+1] = (self.taylor_factors[i] - new_taylor_factors[i]) / N + else: + break + self.taylor_factors = new_taylor_factors + self.last_step = current_step + self.predict_counter = (self.predict_counter + 1) % refresh_threshold + + def predict(self, current_step: int, refresh_threshold: int): + k = current_step - self.last_step + device = self.taylor_factors[0].device + output = torch.zeros_like(self.taylor_factors[0], device=device) + for i in range(len(self.taylor_factors)): + output += self.taylor_factors[i] * (k ** i) / math.factorial(i) + self.predict_counter = (self.predict_counter + 1) % refresh_threshold + return output + +class TaylorSeerAttentionCacheHook(ModelHook): + _is_stateful = True + + def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int]): + super().__init__() + self.fresh_threshold = fresh_threshold + self.max_order = max_order + self.current_timestep_callback = current_timestep_callback + + def initialize_hook(self, module): + self.img_state = TaylorSeerState() + self.txt_state = TaylorSeerState() + self.ip_state = TaylorSeerState() + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + current_step = self.current_timestep_callback() + assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook" + should_predict = self.img_state.predict_counter > 0 + + if not should_predict: + attention_outputs = self.fn_ref.original_forward(*args, **kwargs) + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + self.img_state.update(attn_output, current_step, self.max_order, 
self.fresh_threshold) + self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold) + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold) + self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold) + self.ip_state.update(ip_attn_output, current_step, self.max_order, self.fresh_threshold) + else: + attn_output = self.img_state.predict(current_step, self.fresh_threshold) + context_attn_output = self.txt_state.predict(current_step, self.fresh_threshold) + ip_attn_output = self.ip_state.predict(current_step, self.fresh_threshold) + attention_outputs = (attn_output, context_attn_output, ip_attn_output) + return attention_outputs + + def reset_state(self, module: torch.nn.Module) -> None: + self.img_state.reset() + self.txt_state.reset() + self.ip_state.reset() + return module + +def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): + for name, submodule in module.named_modules(): + if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): + # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB + # cannot be applied to this layer. For custom layers, users can extend this functionality and implement + # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`. + continue + _apply_taylorseer_cache_on_attention_class(name, submodule, config) + + +def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, config: TaylorSeerCacheConfig): + _apply_taylorseer_cache_hook(module, config) + + +def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig): + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback) + registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) \ No newline at end of file From 8f495b607f1176d1dd11101c21bf12f35892f945 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Thu, 13 Nov 2025 11:37:54 +0000 Subject: [PATCH 2/7] make compatible with any tuple size returned --- src/diffusers/hooks/taylorseer_cache.py | 55 ++++++++++++------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 2f8f6a4b476e..b339ee1d6b9f 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -23,12 +23,12 @@ class TaylorSeerCacheConfig: class TaylorSeerState: def __init__(self): - self.predict_counter: int = 1 + self.predict_counter: int = 0 self.last_step: int = 1000 self.taylor_factors: dict[int, torch.Tensor] = {} def reset(self): - self.predict_counter = 1 + self.predict_counter = 0 self.last_step = 1000 self.taylor_factors = {} @@ -43,15 +43,15 @@ def update(self, features: torch.Tensor, current_step: int, max_order: int, refr break self.taylor_factors = new_taylor_factors self.last_step = current_step - self.predict_counter = (self.predict_counter + 1) % refresh_threshold + self.predict_counter = refresh_threshold - def predict(self, current_step: int, refresh_threshold: int): + def predict(self, current_step: int): k = current_step - self.last_step device = self.taylor_factors[0].device output = torch.zeros_like(self.taylor_factors[0], device=device) for i in 
range(len(self.taylor_factors)): output += self.taylor_factors[i] * (k ** i) / math.factorial(i) - self.predict_counter = (self.predict_counter + 1) % refresh_threshold + self.predict_counter -= 1 return output class TaylorSeerAttentionCacheHook(ModelHook): @@ -64,47 +64,44 @@ def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callba self.current_timestep_callback = current_timestep_callback def initialize_hook(self, module): - self.img_state = TaylorSeerState() - self.txt_state = TaylorSeerState() - self.ip_state = TaylorSeerState() + self.states = None + self.num_outputs = None return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): current_step = self.current_timestep_callback() assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook" - should_predict = self.img_state.predict_counter > 0 + + if self.states is None: + attention_outputs = self.fn_ref.original_forward(*args, **kwargs) + self.num_outputs = len(attention_outputs) + self.states = [TaylorSeerState() for _ in range(self.num_outputs)] + for i, feat in enumerate(attention_outputs): + self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) + return attention_outputs + + should_predict = self.states[0].predict_counter > 0 if not should_predict: attention_outputs = self.fn_ref.original_forward(*args, **kwargs) - if len(attention_outputs) == 2: - attn_output, context_attn_output = attention_outputs - self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold) - self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold) - elif len(attention_outputs) == 3: - attn_output, context_attn_output, ip_attn_output = attention_outputs - self.img_state.update(attn_output, current_step, self.max_order, self.fresh_threshold) - self.txt_state.update(context_attn_output, current_step, self.max_order, self.fresh_threshold) - self.ip_state.update(ip_attn_output, current_step, self.max_order, self.fresh_threshold) - else: - attn_output = self.img_state.predict(current_step, self.fresh_threshold) - context_attn_output = self.txt_state.predict(current_step, self.fresh_threshold) - ip_attn_output = self.ip_state.predict(current_step, self.fresh_threshold) - attention_outputs = (attn_output, context_attn_output, ip_attn_output) + for i, feat in enumerate(attention_outputs): + self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) return attention_outputs + else: + predicted_outputs = [state.predict(current_step) for state in self.states] + return tuple(predicted_outputs) def reset_state(self, module: torch.nn.Module) -> None: - self.img_state.reset() - self.txt_state.reset() - self.ip_state.reset() + if self.states is not None: + for state in self.states: + state.reset() return module def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): for name, submodule in module.named_modules(): if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): - # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB - # cannot be applied to this layer. For custom layers, users can extend this functionality and implement - # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`. 
continue + print(f"Applying TaylorSeer cache to {name}") _apply_taylorseer_cache_on_attention_class(name, submodule, config) From 8f8007261844069068ca70ba5d3497b66b1be526 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Thu, 13 Nov 2025 13:11:29 +0000 Subject: [PATCH 3/7] use logger for printing, add warmup feature --- src/diffusers/hooks/taylorseer_cache.py | 35 ++++++++++++++++++++----- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index b339ee1d6b9f..8c3c6a7c3614 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -12,11 +12,13 @@ _ATTENTION_CLASSES, ) from ..hooks import HookRegistry - +from ..utils import logging +logger = logging.get_logger(__name__) # pylint: disable=invalid-name _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" @dataclass class TaylorSeerCacheConfig: + warmup_steps: int = 3 # full compute some first steps fresh_threshold: int = 5 # interleave cache and compute: `fresh_threshold` steps are cached, then 1 full compute step is performed max_order: int = 1 # order of Taylor series expansion current_timestep_callback: Callable[[], int] = None @@ -33,7 +35,9 @@ def reset(self): self.taylor_factors = {} def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int): - N = math.abs(current_step - self.last_step) + logger.debug("="*10) + N = self.last_step - current_step + logger.debug(f"update: N: {N}, current_step: {current_step}, last_step: {self.last_step}") # initialize the first order taylor factors new_taylor_factors = {0: features} for i in range(max_order): @@ -44,6 +48,9 @@ def update(self, features: torch.Tensor, current_step: int, max_order: int, refr self.taylor_factors = new_taylor_factors self.last_step = current_step self.predict_counter = refresh_threshold + logger.debug(f"last_step: {self.last_step}") + logger.debug(f"predict_counter: {self.predict_counter}") + logger.debug("="*10) def predict(self, current_step: int): k = current_step - self.last_step @@ -52,20 +59,24 @@ def predict(self, current_step: int): for i in range(len(self.taylor_factors)): output += self.taylor_factors[i] * (k ** i) / math.factorial(i) self.predict_counter -= 1 + logger.debug(f"predict_counter: {self.predict_counter}") + logger.debug(f"k: {k}") return output class TaylorSeerAttentionCacheHook(ModelHook): _is_stateful = True - def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int]): + def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int], warmup_steps: int): super().__init__() self.fresh_threshold = fresh_threshold self.max_order = max_order self.current_timestep_callback = current_timestep_callback + self.warmup_steps = warmup_steps def initialize_hook(self, module): self.states = None self.num_outputs = None + self.warmup_steps_counter = 0 return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): @@ -74,21 +85,31 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.states is None: attention_outputs = self.fn_ref.original_forward(*args, **kwargs) + if self.warmup_steps_counter < self.warmup_steps: + logger.debug(f"warmup_steps_counter: {self.warmup_steps_counter}") + self.warmup_steps_counter += 1 + return attention_outputs + if isinstance(attention_outputs, torch.Tensor): + attention_outputs = [attention_outputs] self.num_outputs = 
len(attention_outputs) self.states = [TaylorSeerState() for _ in range(self.num_outputs)] for i, feat in enumerate(attention_outputs): self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) - return attention_outputs + return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs should_predict = self.states[0].predict_counter > 0 if not should_predict: attention_outputs = self.fn_ref.original_forward(*args, **kwargs) + if isinstance(attention_outputs, torch.Tensor): + attention_outputs = [attention_outputs] for i, feat in enumerate(attention_outputs): self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) - return attention_outputs + return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs else: predicted_outputs = [state.predict(current_step) for state in self.states] + if len(predicted_outputs) == 1: + return predicted_outputs[0] return tuple(predicted_outputs) def reset_state(self, module: torch.nn.Module) -> None: @@ -101,7 +122,7 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi for name, submodule in module.named_modules(): if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): continue - print(f"Applying TaylorSeer cache to {name}") + logger.debug(f"Applying TaylorSeer cache to {name}") _apply_taylorseer_cache_on_attention_class(name, submodule, config) @@ -111,5 +132,5 @@ def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, con def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig): registry = HookRegistry.check_if_exists_or_initialize(module) - hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback) + hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback, config.warmup_steps) registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) \ No newline at end of file From 0602044da71913832c2e81350d76f0327567efa2 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Thu, 13 Nov 2025 17:03:35 +0000 Subject: [PATCH 4/7] still update in warmup steps --- src/diffusers/hooks/taylorseer_cache.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 8c3c6a7c3614..6c99f095e26f 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -85,10 +85,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if self.states is None: attention_outputs = self.fn_ref.original_forward(*args, **kwargs) - if self.warmup_steps_counter < self.warmup_steps: - logger.debug(f"warmup_steps_counter: {self.warmup_steps_counter}") - self.warmup_steps_counter += 1 - return attention_outputs if isinstance(attention_outputs, torch.Tensor): attention_outputs = [attention_outputs] self.num_outputs = len(attention_outputs) @@ -97,7 +93,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs - should_predict = self.states[0].predict_counter > 0 + should_predict = self.states[0].predict_counter > 0 and self.warmup_steps_counter > self.warmup_steps if not should_predict: attention_outputs = self.fn_ref.original_forward(*args, **kwargs) @@ -108,9 +104,7 @@ def new_forward(self, module: 
torch.nn.Module, *args, **kwargs):
             return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
         else:
             predicted_outputs = [state.predict(current_step) for state in self.states]
-            if len(predicted_outputs) == 1:
-                return predicted_outputs[0]
-            return tuple(predicted_outputs)
+            return predicted_outputs[0] if len(predicted_outputs) == 1 else predicted_outputs
 
     def reset_state(self, module: torch.nn.Module) -> None:
         if self.states is not None:

From 1099e493e635526c8ecbc8ebca0f57e4bea2a0d8 Mon Sep 17 00:00:00 2001
From: toilaluan
Date: Fri, 14 Nov 2025 07:00:12 +0000
Subject: [PATCH 5/7] refactor, add docs

---
 src/diffusers/__init__.py               |   4 +
 src/diffusers/hooks/__init__.py         |   1 +
 src/diffusers/hooks/taylorseer_cache.py | 246 +++++++++++++++++-------
 src/diffusers/models/cache_utils.py     |   9 +-
 4 files changed, 185 insertions(+), 75 deletions(-)

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 02df34c07e8e..69d4aa4ba345 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -169,10 +169,12 @@
             "LayerSkipConfig",
             "PyramidAttentionBroadcastConfig",
             "SmoothedEnergyGuidanceConfig",
+            "TaylorSeerCacheConfig",
             "apply_faster_cache",
             "apply_first_block_cache",
             "apply_layer_skip",
             "apply_pyramid_attention_broadcast",
+            "apply_taylorseer_cache",
         ]
     )
     _import_structure["models"].extend(
@@ -883,10 +885,12 @@
         LayerSkipConfig,
         PyramidAttentionBroadcastConfig,
         SmoothedEnergyGuidanceConfig,
+        TaylorSeerCacheConfig,
         apply_faster_cache,
         apply_first_block_cache,
         apply_layer_skip,
         apply_pyramid_attention_broadcast,
+        apply_taylorseer_cache,
     )
     from .models import (
         AllegroTransformer3DModel,
diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py
index 524a92ea9966..1d9d43d96b2a 100644
--- a/src/diffusers/hooks/__init__.py
+++ b/src/diffusers/hooks/__init__.py
@@ -25,3 +25,4 @@
     from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
     from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
     from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
+    from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
\ No newline at end of file
diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py
index 6c99f095e26f..509f6ba1179d 100644
--- a/src/diffusers/hooks/taylorseer_cache.py
+++ b/src/diffusers/hooks/taylorseer_cache.py
@@ -1,9 +1,6 @@
-# Experimental hook for TaylorSeer cache
-# Supports Flux only for now
-
 import torch
 from dataclasses import dataclass
-from typing import Callable
+from typing import Callable, Optional, List, Dict
 from .hooks import ModelHook
 import math
 from ..models.attention import Attention
@@ -13,118 +10,219 @@
 )
 from ..hooks import HookRegistry
 from ..utils import logging
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
 
 @dataclass
 class TaylorSeerCacheConfig:
+    """
+    Configuration for TaylorSeer cache.
+
+    See: https://huggingface.co/papers/2503.06923
+
+    Attributes:
+        warmup_steps (int, defaults to 3): Number of warmup steps without caching.
+        predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
+        max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
+        taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
+    """
+
+    warmup_steps: int = 3
+    predict_steps: int = 5
+    max_order: int = 1
+    taylor_factors_dtype: torch.dtype = torch.float32
+
+    def __repr__(self) -> str:
+        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype})"
+
+class TaylorSeerOutputState:
+    """
+    Manages the state for Taylor series-based prediction of a single attention output.
+    Tracks Taylor expansion factors, last update step, and remaining prediction steps.
+    The Taylor expansion uses the timestep as the independent variable for approximation.
+    """
+
+    def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype):
+        self.module_name = module_name
+        self.remaining_predictions: int = 0
+        self.last_update_step: Optional[int] = None
+        self.taylor_factors: Dict[int, torch.Tensor] = {}
+        self.taylor_factors_dtype = taylor_factors_dtype
+        self.module_dtype = module_dtype
 
     def reset(self):
-        self.predict_counter = 0
-        self.last_step = 1000
+        self.remaining_predictions = 0
+        self.last_update_step = None
         self.taylor_factors = {}
 
-    def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int):
+    def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool):
+        """
+        Updates the Taylor factors based on the current features and timestep.
+        Computes finite difference approximations for derivatives using recursive divided differences.
+
+        Args:
+            features (torch.Tensor): The attention output features to update with.
+            current_step (int): The current timestep or step number from the diffusion model.
+            max_order (int): Maximum order of the Taylor expansion.
+            predict_steps (int): Number of prediction steps to set after update.
+            is_first_update (bool): Whether this is the initial update (skips difference computation).
+ """ + features = features.to(self.taylor_factors_dtype) + new_factors = {0: features} + if not is_first_update: + if self.last_update_step is None: + raise ValueError("Cannot update without prior initialization.") + delta_step = current_step - self.last_update_step + if delta_step == 0: + raise ValueError("Delta step cannot be zero for updates.") + for i in range(max_order): + if i in self.taylor_factors: + # Finite difference: (current - previous) / delta for forward approximation + new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step + + # taylor factors will be kept in the taylor_factors_dtype + self.taylor_factors = new_factors + self.last_update_step = current_step + self.remaining_predictions = predict_steps + + def predict(self, current_step: int) -> torch.Tensor: + """ + Predicts the features using the Taylor series expansion at the given timestep. + + Args: + current_step (int): The current timestep for prediction. + + Returns: + torch.Tensor: The predicted features in the module's dtype. + """ + if self.last_update_step is None: + raise ValueError("Cannot predict without prior update.") + step_offset = current_step - self.last_update_step device = self.taylor_factors[0].device - output = torch.zeros_like(self.taylor_factors[0], device=device) - for i in range(len(self.taylor_factors)): - output += self.taylor_factors[i] * (k ** i) / math.factorial(i) - self.predict_counter -= 1 - logger.debug(f"predict_counter: {self.predict_counter}") - logger.debug(f"k: {k}") - return output + output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype) + for order in range(len(self.taylor_factors)): + output += self.taylor_factors[order] * (step_offset ** order) / math.factorial(order) + self.remaining_predictions -= 1 + # output will be converted to the module's dtype + return output.to(self.module_dtype) class TaylorSeerAttentionCacheHook(ModelHook): + """ + Hook for caching and predicting attention outputs using Taylor series approximations. + Applies to attention modules in diffusion models (e.g., Flux). + Performs full computations during warmup, then alternates between predictions and refreshes. 
+ """ _is_stateful = True - def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int], warmup_steps: int): + def __init__( + self, + module_name: str, + predict_steps: int, + max_order: int, + warmup_steps: int, + taylor_factors_dtype: torch.dtype, + module_dtype: torch.dtype = None, + ): super().__init__() - self.fresh_threshold = fresh_threshold + self.module_name = module_name + self.predict_steps = predict_steps self.max_order = max_order - self.current_timestep_callback = current_timestep_callback self.warmup_steps = warmup_steps - - def initialize_hook(self, module): + self.step_counter = -1 + self.states: Optional[List[TaylorSeerOutputState]] = None + self.num_outputs: Optional[int] = None + self.taylor_factors_dtype = taylor_factors_dtype + self.module_dtype = module_dtype + + def initialize_hook(self, module: torch.nn.Module): + self.step_counter = -1 self.states = None self.num_outputs = None - self.warmup_steps_counter = 0 + self.module_dtype = None return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): - current_step = self.current_timestep_callback() - assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook" + self.step_counter += 1 + is_warmup_phase = self.step_counter < self.warmup_steps if self.states is None: + # First step: always full compute and initialize attention_outputs = self.fn_ref.original_forward(*args, **kwargs) if isinstance(attention_outputs, torch.Tensor): attention_outputs = [attention_outputs] + else: + attention_outputs = list(attention_outputs) + module_dtype = attention_outputs[0].dtype self.num_outputs = len(attention_outputs) - self.states = [TaylorSeerState() for _ in range(self.num_outputs)] - for i, feat in enumerate(attention_outputs): - self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) - return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs - - should_predict = self.states[0].predict_counter > 0 and self.warmup_steps_counter > self.warmup_steps - - if not should_predict: + self.states = [ + TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype) + for _ in range(self.num_outputs) + ] + for i, features in enumerate(attention_outputs): + self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True) + return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) + + should_predict = self.states[0].remaining_predictions > 0 + if is_warmup_phase or not should_predict: + # Full compute during warmup or when refresh needed attention_outputs = self.fn_ref.original_forward(*args, **kwargs) if isinstance(attention_outputs, torch.Tensor): attention_outputs = [attention_outputs] - for i, feat in enumerate(attention_outputs): - self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) - return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs + else: + attention_outputs = list(attention_outputs) + is_first_update = self.step_counter == 0 # Only True for the very first step + for i, features in enumerate(attention_outputs): + self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update) + return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) else: - predicted_outputs = [state.predict(current_step) for state in self.states] - return predicted_outputs[0] if len(predicted_outputs) == 1 else 
predicted_outputs
+            # Predict using Taylor series
+            predicted_outputs = [state.predict(self.step_counter) for state in self.states]
+            return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)
 
     def reset_state(self, module: torch.nn.Module) -> None:
         if self.states is not None:
             for state in self.states:
                 state.reset()
-        return module
 
 def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
+    """
+    Applies the TaylorSeer cache to the given module.
+
+    Args:
+        module (torch.nn.Module): The model to apply the hook to.
+        config (TaylorSeerCacheConfig): Configuration for the cache.
+
+    Example:
+    ```python
+    >>> import torch
+    >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache
+
+    >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+    >>> pipe.to("cuda")
+
+    >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32)
+    >>> apply_taylorseer_cache(pipe.transformer, config)
+    ```
+    """
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
-            continue
-        logger.debug(f"Applying TaylorSeer cache to {name}")
-        _apply_taylorseer_cache_on_attention_class(name, submodule, config)
+        if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
+            logger.debug(f"Applying TaylorSeer cache to {name}")
+            _apply_taylorseer_cache_hook(name, submodule, config)
 
-def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, config: TaylorSeerCacheConfig):
-    _apply_taylorseer_cache_hook(module, config)
 
-def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig):
+def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
+    """
+    Registers the TaylorSeer hook on the specified attention module.
+
+    Args:
+        name (str): Name of the module.
+        module (Attention): The attention module.
+        config (TaylorSeerCacheConfig): Configuration for the cache.
+ """ registry = HookRegistry.check_if_exists_or_initialize(module) - hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback, config.warmup_steps) + hook = TaylorSeerAttentionCacheHook( + name, + config.predict_steps, + config.max_order, + config.warmup_steps, + config.taylor_factors_dtype, + ) registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) \ No newline at end of file diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 605c0d588c8c..ffbf296ff617 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -67,9 +67,11 @@ def enable_cache(self, config) -> None: FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig, + TaylorSeerCacheConfig, apply_faster_cache, apply_first_block_cache, apply_pyramid_attention_broadcast, + apply_taylorseer_cache, ) if self.is_cache_enabled: @@ -83,16 +85,19 @@ def enable_cache(self, config) -> None: apply_first_block_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) + elif isinstance(config, TaylorSeerCacheConfig): + apply_taylorseer_cache(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK + from ..hooks.taylorseer_cache import _TAYLORSEER_ATTENTION_CACHE_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -107,6 +112,8 @@ def disable_cache(self) -> None: registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, TaylorSeerCacheConfig): + registry.remove_hook(_TAYLORSEER_ATTENTION_CACHE_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") From 7b4ad2de63c314489b8129a496bea5c67e31cf7e Mon Sep 17 00:00:00 2001 From: toilaluan Date: Fri, 14 Nov 2025 09:09:46 +0000 Subject: [PATCH 6/7] add configurable cache, skip compute module --- src/diffusers/hooks/taylorseer_cache.py | 169 ++++++++++++++++++------ 1 file changed, 126 insertions(+), 43 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 509f6ba1179d..89d6da307488 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -10,10 +10,28 @@ ) from ..hooks import HookRegistry from ..utils import logging - +import re +from collections import defaultdict logger = logging.get_logger(__name__) # pylint: disable=invalid-name _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" +SPECIAL_CACHE_IDENTIFIERS = { + "flux": [ + r"transformer_blocks\.\d+\.attn", + r"transformer_blocks\.\d+\.ff", + r"transformer_blocks\.\d+\.ff_context", + r"single_transformer_blocks\.\d+\.proj_out", + ] +} 
+SKIP_COMPUTE_IDENTIFIERS = {
+    "flux": [
+        r"single_transformer_blocks\.\d+\.attn",
+        r"single_transformer_blocks\.\d+\.proj_mlp",
+        r"single_transformer_blocks\.\d+\.act_mlp",
+    ]
+}
+
+
 @dataclass
 class TaylorSeerCacheConfig:
     """
@@ -25,14 +43,22 @@ class TaylorSeerCacheConfig:
         predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
         max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
         taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
+        architecture (str, defaults to None): Architecture the cache is applied to. If set, the built-in identifier patterns for that architecture are used.
+        skip_compute_identifiers (List[str], defaults to None): Regex identifiers for modules whose computation is skipped during predicted steps.
+        special_cache_identifiers (List[str], defaults to None): Regex identifiers for modules whose outputs are cached and approximated.
     """
+
     warmup_steps: int = 3
     predict_steps: int = 5
     max_order: int = 1
     taylor_factors_dtype: torch.dtype = torch.float32
+    architecture: Optional[str] = None
+    skip_compute_identifiers: Optional[List[str]] = None
+    special_cache_identifiers: Optional[List[str]] = None
 
     def __repr__(self) -> str:
-        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype})"
+        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_compute_identifiers={self.skip_compute_identifiers}, special_cache_identifiers={self.special_cache_identifiers})"
+
 
 class TaylorSeerOutputState:
     """
@@ -41,20 +67,31 @@ class TaylorSeerOutputState:
     The Taylor expansion uses the timestep as the independent variable for approximation.
     """
 
-    def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype):
+    def __init__(
+        self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype, is_skip: bool = False
+    ):
         self.module_name = module_name
         self.remaining_predictions: int = 0
         self.last_update_step: Optional[int] = None
         self.taylor_factors: Dict[int, torch.Tensor] = {}
         self.taylor_factors_dtype = taylor_factors_dtype
         self.module_dtype = module_dtype
+        self.is_skip = is_skip
+        self.dummy_shape: Optional[torch.Size] = None
+        self.device: Optional[torch.device] = None
+        self.dummy_tensor: Optional[torch.Tensor] = None
 
     def reset(self):
         self.remaining_predictions = 0
         self.last_update_step = None
         self.taylor_factors = {}
+        self.dummy_shape = None
+        self.device = None
+        self.dummy_tensor = None
 
-    def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool):
+    def update(
+        self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool
+    ):
         """
         Updates the Taylor factors based on the current features and timestep.
         Computes finite difference approximations for derivatives using recursive divided differences.
 
         Args:
             features (torch.Tensor): The attention output features to update with.
             current_step (int): The current timestep or step number from the diffusion model.
             max_order (int): Maximum order of the Taylor expansion.
@@ -66,23 +103,33 @@ def update(self, features: torch.Tensor, current_step: int, max_order: int, pred
             predict_steps (int): Number of prediction steps to set after update.
             is_first_update (bool): Whether this is the initial update (skips difference computation).
""" - features = features.to(self.taylor_factors_dtype) - new_factors = {0: features} - if not is_first_update: - if self.last_update_step is None: - raise ValueError("Cannot update without prior initialization.") - delta_step = current_step - self.last_update_step - if delta_step == 0: - raise ValueError("Delta step cannot be zero for updates.") - for i in range(max_order): - if i in self.taylor_factors: - # Finite difference: (current - previous) / delta for forward approximation - new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step - - # taylor factors will be kept in the taylor_factors_dtype - self.taylor_factors = new_factors - self.last_update_step = current_step - self.remaining_predictions = predict_steps + if self.is_skip: + self.dummy_shape = features.shape + self.device = features.device + self.taylor_factors = {} + self.last_update_step = current_step + self.remaining_predictions = predict_steps + else: + features = features.to(self.taylor_factors_dtype) + new_factors = {0: features} + if not is_first_update: + if self.last_update_step is None: + raise ValueError("Cannot update without prior initialization.") + delta_step = current_step - self.last_update_step + if delta_step == 0: + raise ValueError("Delta step cannot be zero for updates.") + for i in range(max_order): + if i in self.taylor_factors: + new_factors[i + 1] = ( + new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype) + ) / delta_step + else: + break + + # taylor factors will be kept in the taylor_factors_dtype + self.taylor_factors = new_factors + self.last_update_step = current_step + self.remaining_predictions = predict_steps def predict(self, current_step: int) -> torch.Tensor: """ @@ -94,16 +141,22 @@ def predict(self, current_step: int) -> torch.Tensor: Returns: torch.Tensor: The predicted features in the module's dtype. """ - if self.last_update_step is None: - raise ValueError("Cannot predict without prior update.") - step_offset = current_step - self.last_update_step - device = self.taylor_factors[0].device - output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype) - for order in range(len(self.taylor_factors)): - output += self.taylor_factors[order] * (step_offset ** order) / math.factorial(order) - self.remaining_predictions -= 1 - # output will be converted to the module's dtype - return output.to(self.module_dtype) + if self.is_skip: + if self.dummy_shape is None or self.device is None: + raise ValueError("Cannot predict for skip module without prior update.") + self.remaining_predictions -= 1 + return torch.empty(self.dummy_shape, dtype=self.module_dtype, device=self.device) + else: + if self.last_update_step is None: + raise ValueError("Cannot predict without prior update.") + step_offset = current_step - self.last_update_step + output = 0 + for order in range(len(self.taylor_factors)): + output += self.taylor_factors[order] * (step_offset**order) * (1 / math.factorial(order)) + self.remaining_predictions -= 1 + # output will be converted to the module's dtype + return output.to(self.module_dtype) + class TaylorSeerAttentionCacheHook(ModelHook): """ @@ -111,6 +164,7 @@ class TaylorSeerAttentionCacheHook(ModelHook): Applies to attention modules in diffusion models (e.g., Flux). Performs full computations during warmup, then alternates between predictions and refreshes. 
""" + _is_stateful = True def __init__( @@ -120,7 +174,7 @@ def __init__( max_order: int, warmup_steps: int, taylor_factors_dtype: torch.dtype, - module_dtype: torch.dtype = None, + is_skip_compute: bool = False, ): super().__init__() self.module_name = module_name @@ -131,13 +185,12 @@ def __init__( self.states: Optional[List[TaylorSeerOutputState]] = None self.num_outputs: Optional[int] = None self.taylor_factors_dtype = taylor_factors_dtype - self.module_dtype = module_dtype + self.is_skip_compute = is_skip_compute def initialize_hook(self, module: torch.nn.Module): self.step_counter = -1 self.states = None self.num_outputs = None - self.module_dtype = None return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): @@ -154,11 +207,15 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): module_dtype = attention_outputs[0].dtype self.num_outputs = len(attention_outputs) self.states = [ - TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype) + TaylorSeerOutputState( + self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip_compute + ) for _ in range(self.num_outputs) ] for i, features in enumerate(attention_outputs): - self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True) + self.states[i].update( + features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True + ) return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) should_predict = self.states[0].remaining_predictions > 0 @@ -179,9 +236,8 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs) def reset_state(self, module: torch.nn.Module) -> None: - if self.states is not None: - for state in self.states: - state.reset() + self.states = None + def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): """ @@ -199,30 +255,57 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32) + >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32, architecture="flux") >>> apply_taylorseer_cache(pipe.transformer, config) ``` """ + if config.skip_compute_identifiers: + skip_compute_identifiers = config.skip_compute_identifiers + else: + skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, []) + + if config.special_cache_identifiers: + special_cache_identifiers = config.special_cache_identifiers + else: + special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS.get(config.architecture, []) + + logger.debug(f"Skip compute identifiers: {skip_compute_identifiers}") + logger.debug(f"Special cache identifiers: {special_cache_identifiers}") + for name, submodule in module.named_modules(): - if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): + if skip_compute_identifiers and special_cache_identifiers: + if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any( + re.fullmatch(identifier, name) for identifier in special_cache_identifiers + ): + logger.debug(f"Applying TaylorSeer cache to {name}") + _apply_taylorseer_cache_hook(name, submodule, config) + elif 
isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             logger.debug(f"Applying TaylorSeer cache to {name}")
             _apply_taylorseer_cache_hook(name, submodule, config)
 
 
 def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
     """
     Registers the TaylorSeer hook on the specified attention module.
-
     Args:
         name (str): Name of the module.
         module (Attention): The attention module.
         config (TaylorSeerCacheConfig): Configuration for the cache.
     """
+
+    is_skip_compute = any(
+        re.fullmatch(identifier, name)
+        for identifier in (config.skip_compute_identifiers or SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, []))
+    )
+
     registry = HookRegistry.check_if_exists_or_initialize(module)
+
     hook = TaylorSeerAttentionCacheHook(
         name,
         config.predict_steps,
         config.max_order,
         config.warmup_steps,
         config.taylor_factors_dtype,
+        is_skip_compute=is_skip_compute,
     )
-    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
\ No newline at end of file
+
+    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)

From 51b4318a3e5b2dc2b3df93f6e2fc2decc254a320 Mon Sep 17 00:00:00 2001
From: toilaluan
Date: Sat, 15 Nov 2025 05:13:33 +0000
Subject: [PATCH 7/7] allow special cache ids only

---
 src/diffusers/hooks/taylorseer_cache.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py
index 89d6da307488..3c5d0a2f3991 100644
--- a/src/diffusers/hooks/taylorseer_cache.py
+++ b/src/diffusers/hooks/taylorseer_cache.py
@@ -273,7 +273,7 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
     logger.debug(f"Special cache identifiers: {special_cache_identifiers}")
 
     for name, submodule in module.named_modules():
-        if skip_compute_identifiers and special_cache_identifiers:
+        if special_cache_identifiers:
             if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any(
                 re.fullmatch(identifier, name) for identifier in special_cache_identifiers
            ):
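
The series maintained by `TaylorSeerOutputState` approximates each cached module output f(t) around the last fully computed step with a truncated Taylor expansion, estimating derivatives by forward finite differences. The following is a minimal, self-contained sketch of that recurrence, separate from the patch itself; the class name `TaylorSeerSketch` and the toy trajectory are illustrative only.

```python
import math
from typing import Dict, Optional

import torch


class TaylorSeerSketch:
    """Standalone illustration of the divided-difference state kept per output."""

    def __init__(self, max_order: int = 1):
        self.max_order = max_order
        self.factors: Dict[int, torch.Tensor] = {}  # order -> factor at the last refresh
        self.last_step: Optional[int] = None

    def update(self, features: torch.Tensor, step: int) -> None:
        # Order 0 is the freshly computed output; higher orders are forward
        # finite differences against the factors from the previous refresh.
        new_factors = {0: features}
        if self.last_step is not None:
            delta = step - self.last_step
            for i in range(self.max_order):
                if i in self.factors:
                    new_factors[i + 1] = (new_factors[i] - self.factors[i]) / delta
                else:
                    break
        self.factors, self.last_step = new_factors, step

    def predict(self, step: int) -> torch.Tensor:
        # Truncated Taylor expansion around the last refreshed step.
        k = step - self.last_step
        out = torch.zeros_like(self.factors[0])
        for order, factor in self.factors.items():
            out = out + factor * k**order / math.factorial(order)
        return out


# With max_order=1 a linear feature trajectory is extrapolated exactly:
state = TaylorSeerSketch(max_order=1)
state.update(torch.zeros(2), 0)
state.update(torch.ones(2), 1)
print(state.predict(3))  # tensor([3., 3.])
```

The hook drives exactly this loop: `update` runs after every full compute and re-arms `remaining_predictions`, while `predict` serves the intermediate steps until a refresh is due.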
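Patch 5 also wires `TaylorSeerCacheConfig` into `CacheMixin.enable_cache` / `disable_cache` in `cache_utils.py`, so the hook can be driven from a pipeline without calling `apply_taylorseer_cache` directly. A hedged end-to-end sketch, assuming this branch is installed and a CUDA device is available; the prompt string and step count are illustrative:

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = TaylorSeerCacheConfig(
    warmup_steps=3,       # first steps always run the full model
    predict_steps=5,      # predicted (cached) steps between full computes
    max_order=1,          # truncation order of the Taylor expansion
    taylor_factors_dtype=torch.float32,
    architecture="flux",  # selects the built-in Flux identifier patterns
)
pipe.transformer.enable_cache(config)  # dispatched to apply_taylorseer_cache

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]

pipe.transformer.disable_cache()  # removes the registered hooks again
```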
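Patch 6 makes the targeted modules configurable: `special_cache_identifiers` selects modules whose outputs are cached and approximated, while `skip_compute_identifiers` marks modules whose outputs only feed other cached modules, so during predicted steps they can return dummy tensors instead of running. A sketch with hand-written patterns; these regexes simply restate the built-in `"flux"` defaults, so the values are assumptions rather than a tuning recommendation:

```python
import torch
from diffusers import TaylorSeerCacheConfig

# Equivalent to architecture="flux", but spelled out. Patterns must fully
# match module names as produced by named_modules().
config = TaylorSeerCacheConfig(
    warmup_steps=3,
    predict_steps=5,
    max_order=1,
    taylor_factors_dtype=torch.float32,
    special_cache_identifiers=[
        r"transformer_blocks\.\d+\.attn",
        r"transformer_blocks\.\d+\.ff",
        r"transformer_blocks\.\d+\.ff_context",
        r"single_transformer_blocks\.\d+\.proj_out",
    ],
    skip_compute_identifiers=[
        r"single_transformer_blocks\.\d+\.attn",
        r"single_transformer_blocks\.\d+\.proj_mlp",
        r"single_transformer_blocks\.\d+\.act_mlp",
    ],
)
```

In the Flux defaults the skipped submodules all feed `proj_out`, which is itself cached, so their dummy outputs are never consumed on predicted steps.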