Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,12 @@
"LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"TaylorSeerCacheConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
"apply_taylorseer_cache",
]
)
_import_structure["models"].extend(
Expand Down Expand Up @@ -883,10 +885,12 @@
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)
from .models import (
AllegroTransformer3DModel,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
311 changes: 311 additions & 0 deletions src/diffusers/hooks/taylorseer_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
import torch
from dataclasses import dataclass
from typing import Callable, Optional, List, Dict
from .hooks import ModelHook
import math
from ..models.attention import Attention
from ..models.attention import AttentionModuleMixin
from ._common import (
_ATTENTION_CLASSES,
)
from ..hooks import HookRegistry
from ..utils import logging
import re
from collections import defaultdict
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"

SPECIAL_CACHE_IDENTIFIERS = {
"flux": [
r"transformer_blocks\.\d+\.attn",
r"transformer_blocks\.\d+\.ff",
r"transformer_blocks\.\d+\.ff_context",
r"single_transformer_blocks\.\d+\.proj_out",
]
}
SKIP_COMPUTE_IDENTIFIERS = {
"flux": [
r"single_transformer_blocks\.\d+\.attn",
r"single_transformer_blocks\.\d+\.proj_mlp",
r"single_transformer_blocks\.\d+\.act_mlp",
]
}


@dataclass
class TaylorSeerCacheConfig:
"""
Configuration for TaylorSeer cache.
See: https://huggingface.co/papers/2503.06923

Attributes:
warmup_steps (int, defaults to 3): Number of warmup steps without caching.
predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
architecture (str, defaults to None): Architecture for which the cache is applied. If we know the architecture, we can use the special cache identifiers.
skip_compute_identifiers (List[str], defaults to []): Identifiers for modules to skip computation.
special_cache_identifiers (List[str], defaults to []): Identifiers for modules to use special cache.
"""

warmup_steps: int = 3
predict_steps: int = 5
max_order: int = 1
taylor_factors_dtype: torch.dtype = torch.float32
architecture: str | None = None
skip_compute_identifiers: List[str] = None
special_cache_identifiers: List[str] = None

def __repr__(self) -> str:
return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_compute_identifiers={self.skip_compute_identifiers}, special_cache_identifiers={self.special_cache_identifiers})"


class TaylorSeerOutputState:
"""
Manages the state for Taylor series-based prediction of a single attention output.
Tracks Taylor expansion factors, last update step, and remaining prediction steps.
The Taylor expansion uses the timestep as the independent variable for approximation.
"""

def __init__(
self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype, is_skip: bool = False
):
self.module_name = module_name
self.remaining_predictions: int = 0
self.last_update_step: Optional[int] = None
self.taylor_factors: Dict[int, torch.Tensor] = {}
self.taylor_factors_dtype = taylor_factors_dtype
self.module_dtype = module_dtype
self.is_skip = is_skip
self.dummy_shape: Optional[Tuple[int, ...]] = None
self.device: Optional[torch.device] = None
self.dummy_tensor: Optional[torch.Tensor] = None

def reset(self):
self.remaining_predictions = 0
self.last_update_step = None
self.taylor_factors = {}
self.dummy_shape = None
self.device = None
self.dummy_tensor = None

def update(
self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool
):
"""
Updates the Taylor factors based on the current features and timestep.
Computes finite difference approximations for derivatives using recursive divided differences.

Args:
features (torch.Tensor): The attention output features to update with.
current_step (int): The current timestep or step number from the diffusion model.
max_order (int): Maximum order of the Taylor expansion.
predict_steps (int): Number of prediction steps to set after update.
is_first_update (bool): Whether this is the initial update (skips difference computation).
"""
if self.is_skip:
self.dummy_shape = features.shape
self.device = features.device
self.taylor_factors = {}
self.last_update_step = current_step
self.remaining_predictions = predict_steps
else:
features = features.to(self.taylor_factors_dtype)
new_factors = {0: features}
if not is_first_update:
if self.last_update_step is None:
raise ValueError("Cannot update without prior initialization.")
delta_step = current_step - self.last_update_step
if delta_step == 0:
raise ValueError("Delta step cannot be zero for updates.")
for i in range(max_order):
if i in self.taylor_factors:
new_factors[i + 1] = (
new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)
) / delta_step
else:
break

# taylor factors will be kept in the taylor_factors_dtype
self.taylor_factors = new_factors
self.last_update_step = current_step
self.remaining_predictions = predict_steps

def predict(self, current_step: int) -> torch.Tensor:
"""
Predicts the features using the Taylor series expansion at the given timestep.

Args:
current_step (int): The current timestep for prediction.

Returns:
torch.Tensor: The predicted features in the module's dtype.
"""
if self.is_skip:
if self.dummy_shape is None or self.device is None:
raise ValueError("Cannot predict for skip module without prior update.")
self.remaining_predictions -= 1
return torch.empty(self.dummy_shape, dtype=self.module_dtype, device=self.device)
else:
if self.last_update_step is None:
raise ValueError("Cannot predict without prior update.")
step_offset = current_step - self.last_update_step
output = 0
for order in range(len(self.taylor_factors)):
output += self.taylor_factors[order] * (step_offset**order) * (1 / math.factorial(order))
self.remaining_predictions -= 1
# output will be converted to the module's dtype
return output.to(self.module_dtype)


class TaylorSeerAttentionCacheHook(ModelHook):
"""
Hook for caching and predicting attention outputs using Taylor series approximations.
Applies to attention modules in diffusion models (e.g., Flux).
Performs full computations during warmup, then alternates between predictions and refreshes.
"""

_is_stateful = True

def __init__(
self,
module_name: str,
predict_steps: int,
max_order: int,
warmup_steps: int,
taylor_factors_dtype: torch.dtype,
is_skip_compute: bool = False,
):
super().__init__()
self.module_name = module_name
self.predict_steps = predict_steps
self.max_order = max_order
self.warmup_steps = warmup_steps
self.step_counter = -1
self.states: Optional[List[TaylorSeerOutputState]] = None
self.num_outputs: Optional[int] = None
self.taylor_factors_dtype = taylor_factors_dtype
self.is_skip_compute = is_skip_compute

def initialize_hook(self, module: torch.nn.Module):
self.step_counter = -1
self.states = None
self.num_outputs = None
return module

def new_forward(self, module: torch.nn.Module, *args, **kwargs):
self.step_counter += 1
is_warmup_phase = self.step_counter < self.warmup_steps

if self.states is None:
# First step: always full compute and initialize
attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
if isinstance(attention_outputs, torch.Tensor):
attention_outputs = [attention_outputs]
else:
attention_outputs = list(attention_outputs)
module_dtype = attention_outputs[0].dtype
self.num_outputs = len(attention_outputs)
self.states = [
TaylorSeerOutputState(
self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip_compute
)
for _ in range(self.num_outputs)
]
for i, features in enumerate(attention_outputs):
self.states[i].update(
features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True
)
return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)

should_predict = self.states[0].remaining_predictions > 0
if is_warmup_phase or not should_predict:
# Full compute during warmup or when refresh needed
attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
if isinstance(attention_outputs, torch.Tensor):
attention_outputs = [attention_outputs]
else:
attention_outputs = list(attention_outputs)
is_first_update = self.step_counter == 0 # Only True for the very first step
for i, features in enumerate(attention_outputs):
self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
else:
# Predict using Taylor series
predicted_outputs = [state.predict(self.step_counter) for state in self.states]
return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)

def reset_state(self, module: torch.nn.Module) -> None:
self.states = None


def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
"""
Applies the TaylorSeer cache to given pipeline.

Args:
module (torch.nn.Module): The model to apply the hook to.
config (TaylorSeerCacheConfig): Configuration for the cache.

Example:
```python
>>> import torch
>>> from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache

>>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32, architecture="flux")
>>> apply_taylorseer_cache(pipe.transformer, config)
```
"""
if config.skip_compute_identifiers:
skip_compute_identifiers = config.skip_compute_identifiers
else:
skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])

if config.special_cache_identifiers:
special_cache_identifiers = config.special_cache_identifiers
else:
special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS.get(config.architecture, [])

logger.debug(f"Skip compute identifiers: {skip_compute_identifiers}")
logger.debug(f"Special cache identifiers: {special_cache_identifiers}")

for name, submodule in module.named_modules():
if (skip_compute_identifiers and special_cache_identifiers) or (special_cache_identifiers):
if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any(
re.fullmatch(identifier, name) for identifier in special_cache_identifiers
):
logger.debug(f"Applying TaylorSeer cache to {name}")
_apply_taylorseer_cache_hook(name, submodule, config)
elif isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
logger.debug(f"Applying TaylorSeer cache to {name}")
_apply_taylorseer_cache_hook(name, submodule, config)


def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
"""
Registers the TaylorSeer hook on the specified attention module.
Args:
name (str): Name of the module.
module (Attention): The attention module.
config (TaylorSeerCacheConfig): Configuration for the cache.
"""

is_skip_compute = any(
re.fullmatch(identifier, name) for identifier in SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
)

registry = HookRegistry.check_if_exists_or_initialize(module)

hook = TaylorSeerAttentionCacheHook(
name,
config.predict_steps,
config.max_order,
config.warmup_steps,
config.taylor_factors_dtype,
is_skip_compute=is_skip_compute,
)

registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
9 changes: 8 additions & 1 deletion src/diffusers/models/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ def enable_cache(self, config) -> None:
FasterCacheConfig,
FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)

if self.is_cache_enabled:
Expand All @@ -83,16 +85,19 @@ def enable_cache(self, config) -> None:
apply_first_block_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, TaylorSeerCacheConfig):
apply_taylorseer_cache(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")

self._cache_config = config

def disable_cache(self) -> None:
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_ATTENTION_CACHE_HOOK

if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
Expand All @@ -107,6 +112,8 @@ def disable_cache(self) -> None:
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
registry.remove_hook(_TAYLORSEER_ATTENTION_CACHE_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")

Expand Down