
Commit 6290fdf

[Feat] TaylorSeer Cache (#12648)
* init taylor_seer cache
* make compatible with any tuple size returned
* use logger for printing, add warmup feature
* still update in warmup steps
* refactor, add docs
* add configurable cache, skip compute module
* allow special cache ids only
* add stop_predicts (cooldown)
* update docs
* apply ruff
* update to handle multiple calls per timestep
* refactor to use state manager
* fix format & doc
* chores: naming, remove redundancy
* add docs
* quality & style
* fix taylor precision
* Apply style fixes
* add tests
* Apply style fixes
* Remove TaylorSeerCacheTesterMixin from flux2 tests
* rename identifiers, use more expressive taylor predict loop
* torch compile compatible
* Apply style fixes
* Update src/diffusers/hooks/taylorseer_cache.py (Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>)
* update docs
* make fix-copies
* fix example usage
* remove tests on flux kontext

Co-authored-by: toilaluan <toilaluan@github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Parent: 256e010

File tree

10 files changed (+477 −1 lines)


docs/source/en/api/cache.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -34,3 +34,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
 [[autodoc]] FirstBlockCacheConfig
 
 [[autodoc]] apply_first_block_cache
+
+### TaylorSeerCacheConfig
+
+[[autodoc]] TaylorSeerCacheConfig
+
+[[autodoc]] apply_taylorseer_cache
```

docs/source/en/optimization/cache.md

Lines changed: 31 additions & 0 deletions
````diff
@@ -66,4 +66,35 @@ config = FasterCacheConfig(
     tensor_format="BFCHW",
 )
 pipeline.transformer.enable_cache(config)
+```
+
+## TaylorSeer Cache
+
+[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs from past computations and reuses them at specified intervals to cut redundant calculation.
+
+This caching mechanism delivers strong results with minimal additional memory overhead. For a detailed performance analysis, see [our findings here](https://github.com/huggingface/diffusers/pull/12648#issuecomment-3610615080).
+
+To enable TaylorSeer Cache, create a [`TaylorSeerCacheConfig`] and pass it to your pipeline's transformer:
+
+- `cache_interval`: number of steps to reuse cached outputs before performing a full forward pass
+- `disable_cache_before_step`: number of initial steps that run full computation to gather data for the approximation
+- `max_order`: order of the Taylor approximation (in theory, higher values improve quality at the cost of extra memory; we recommend setting it to `1`)
+
+```python
+import torch
+from diffusers import FluxPipeline, TaylorSeerCacheConfig
+
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    torch_dtype=torch.bfloat16,
+)
+pipe.to("cuda")
+
+config = TaylorSeerCacheConfig(
+    cache_interval=5,
+    max_order=1,
+    disable_cache_before_step=10,
+    taylor_factors_dtype=torch.bfloat16,
+)
+pipe.transformer.enable_cache(config)
 ```
````
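To make the mechanism above concrete, here is a minimal sketch of a first-order (`max_order=1`) Taylor prediction of a cached activation. The `TaylorState` class and its methods are illustrative only, not the actual diffusers internals (those live in `src/diffusers/hooks/taylorseer_cache.py`):

```python
import torch


class TaylorState:
    """Illustrative per-module state: last output plus a finite-difference derivative."""

    def __init__(self):
        self.last_output = None  # output at the most recent full forward pass
        self.derivative = None   # ~ d(output)/d(step), from the last two full passes
        self.last_step = None    # step index of the most recent full pass

    def update(self, output: torch.Tensor, step: int) -> None:
        # After a real forward pass, refresh the Taylor factors.
        if self.last_output is not None and step != self.last_step:
            self.derivative = (output - self.last_output) / (step - self.last_step)
        self.last_output, self.last_step = output, step

    def predict(self, step: int) -> torch.Tensor:
        # First-order Taylor expansion around the last computed step:
        #   f(step) ≈ f(last_step) + (step - last_step) * f'(last_step)
        if self.derivative is None:
            return self.last_output  # warmup: no derivative estimate yet
        return self.last_output + (step - self.last_step) * self.derivative
```

In this reading, the first `disable_cache_before_step` steps always run the real forward pass and call `update`; after that, `predict` stands in for the module on cached steps, and a full pass every `cache_interval` steps refreshes the factors.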

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -169,10 +169,12 @@
         "LayerSkipConfig",
         "PyramidAttentionBroadcastConfig",
         "SmoothedEnergyGuidanceConfig",
+        "TaylorSeerCacheConfig",
         "apply_faster_cache",
         "apply_first_block_cache",
         "apply_layer_skip",
         "apply_pyramid_attention_broadcast",
+        "apply_taylorseer_cache",
     ]
 )
 _import_structure["models"].extend(
@@ -899,10 +901,12 @@
         LayerSkipConfig,
         PyramidAttentionBroadcastConfig,
         SmoothedEnergyGuidanceConfig,
+        TaylorSeerCacheConfig,
         apply_faster_cache,
         apply_first_block_cache,
         apply_layer_skip,
         apply_pyramid_attention_broadcast,
+        apply_taylorseer_cache,
     )
     from .models import (
         AllegroTransformer3DModel,
```
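With these exports, both names are importable from the package root. As an alternative to `enable_cache`, the functional helper can be applied directly; a hedged sketch, assuming `apply_taylorseer_cache` takes `(module, config)` like the other `apply_*` helpers exported alongside it:

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

config = TaylorSeerCacheConfig(
    cache_interval=5,
    max_order=1,
    disable_cache_before_step=10,
)
# Assumed signature: (module, config), mirroring apply_first_block_cache.
apply_taylorseer_cache(pipe.transformer, config)
```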

src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -25,3 +25,4 @@
 from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
 from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
 from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
+from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
```
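Since the hooks package re-exports both symbols, they are also reachable without the top-level import:

```python
# Equivalent import via the hooks subpackage, per the line added above.
from diffusers.hooks import TaylorSeerCacheConfig, apply_taylorseer_cache
```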
