diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 524a92ea9966..ecec7322fac9 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -19,6 +19,7 @@ from .context_parallel import apply_context_parallel from .faster_cache import FasterCacheConfig, apply_faster_cache from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache + from .flux_teacache import FluxTeaCacheConfig, apply_flux_teacache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook from .layer_skip import LayerSkipConfig, apply_layer_skip diff --git a/src/diffusers/hooks/flux_teacache.py b/src/diffusers/hooks/flux_teacache.py new file mode 100644 index 000000000000..c31723d99803 --- /dev/null +++ b/src/diffusers/hooks/flux_teacache.py @@ -0,0 +1,474 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Callable, List, Optional + +import numpy as np +import torch + +from ..utils import logging +from .hooks import BaseState, HookRegistry, ModelHook, StateManager + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +_FLUX_TEACACHE_HOOK = "flux_teacache" + + +@dataclass +class FluxTeaCacheConfig: + r""" + Configuration for [TeaCache](https://liewfeng.github.io/TeaCache/) applied to FLUX models. + + TeaCache (Timestep Embedding Aware Cache) is an adaptive caching technique that speeds up diffusion model + inference by reusing transformer block computations when consecutive timestep embeddings are similar. It uses + polynomial rescaling of L1 distances between modulated inputs to intelligently decide when to cache. + + Args: + rel_l1_thresh (`float`, defaults to `0.2`): + Threshold for accumulated relative L1 distance. When the accumulated distance is below this threshold, + the cached residual from the previous timestep is reused instead of computing the full transformer. + Based on the original TeaCache paper, values in the range [0.1, 0.3] work best for balancing speed + and quality: + - 0.25 for ~1.5x speedup with minimal quality loss + - 0.4 for ~1.8x speedup with slight quality loss + - 0.6 for ~2.0x speedup with noticeable quality loss + - 0.8 for ~2.25x speedup with significant quality loss + Higher thresholds lead to more aggressive caching and faster inference, but may reduce output quality. + coefficients (`List[float]`, *optional*, defaults to FLUX-specific polynomial coefficients): + FLUX-specific polynomial coefficients used for rescaling the raw L1 distance. These coefficients + transform the relative L1 distance into a model-specific caching signal. If not provided, defaults + to the coefficients determined for FLUX models in the TeaCache paper: + [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]. + The polynomial is evaluated as: `c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]` where x is the + relative L1 distance. 
+ current_timestep_callback (`Callable[[], int]`, *optional*, defaults to `None`): + Callback function that returns the current timestep during inference. This is used internally for + debugging and statistics tracking. If not provided, TeaCache will still function correctly. + num_inference_steps_callback (`Callable[[], int]`, *optional*, defaults to `None`): + Callback function that returns the total number of inference steps. This is used to ensure the first + and last timesteps are always computed (never cached) for maximum quality. If not provided, TeaCache + will attempt to detect the number of steps automatically from the pipeline. + + Examples: + ```python + from diffusers import FluxPipeline + from diffusers.hooks import FluxTeaCacheConfig + + # Load FLUX pipeline + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + pipe.to("cuda") + + # Enable TeaCache with default settings (1.5x speedup) + config = FluxTeaCacheConfig(rel_l1_thresh=0.2) + pipe.transformer.enable_cache(config) + + # Generate image with caching + image = pipe("A cat sitting on a windowsill", num_inference_steps=4).images[0] + + # Disable caching + pipe.transformer.disable_cache() + + # For more aggressive caching (2x speedup, slight quality loss) + config = FluxTeaCacheConfig(rel_l1_thresh=0.6) + pipe.transformer.enable_cache(config) + image = pipe("A cat sitting on a windowsill", num_inference_steps=4).images[0] + ``` + """ + rel_l1_thresh: float = 0.2 + coefficients: Optional[List[float]] = None + current_timestep_callback: Optional[Callable[[], int]] = None + num_inference_steps_callback: Optional[Callable[[], int]] = None + + def __post_init__(self): + # Validate rel_l1_thresh + if not isinstance(self.rel_l1_thresh, (int, float)): + raise TypeError( + f"rel_l1_thresh must be a number, got {type(self.rel_l1_thresh).__name__}. " + f"Please provide a float value between 0.1 and 1.0." + ) + if self.rel_l1_thresh <= 0: + raise ValueError( + f"rel_l1_thresh must be positive, got {self.rel_l1_thresh}. " + f"Based on the TeaCache paper, values between 0.1 and 0.3 work best. " + f"Try 0.25 for 1.5x speedup or 0.6 for 2x speedup." + ) + if self.rel_l1_thresh < 0.05: + import warnings + warnings.warn( + f"rel_l1_thresh={self.rel_l1_thresh} is very low and may result in minimal caching. " + f"Consider using values between 0.1 and 0.3 for optimal performance.", + UserWarning + ) + if self.rel_l1_thresh > 1.0: + import warnings + warnings.warn( + f"rel_l1_thresh={self.rel_l1_thresh} is very high and may cause quality degradation. " + f"Consider using values between 0.1 and 0.6 for better quality-speed tradeoff.", + UserWarning + ) + + # Set default coefficients if not provided + if self.coefficients is None: + # Original FLUX coefficients from TeaCache paper + self.coefficients = [4.98651651e+02, -2.83781631e+02, + 5.58554382e+01, -3.82021401e+00, 2.64230861e-01] + + # Validate coefficients + if not isinstance(self.coefficients, (list, tuple)): + raise TypeError( + f"coefficients must be a list or tuple, got {type(self.coefficients).__name__}. " + f"Please provide a list of 5 polynomial coefficients." + ) + if len(self.coefficients) != 5: + raise ValueError( + f"coefficients must contain exactly 5 elements for 4th-degree polynomial, " + f"got {len(self.coefficients)}. 
The polynomial is evaluated as: " + f"c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]" + ) + if not all(isinstance(c, (int, float)) for c in self.coefficients): + raise TypeError( + f"All coefficients must be numbers. " + f"Got types: {[type(c).__name__ for c in self.coefficients]}" + ) + + def __repr__(self) -> str: + return ( + f"FluxTeaCacheConfig(\n" + f" rel_l1_thresh={self.rel_l1_thresh},\n" + f" coefficients={self.coefficients},\n" + f" current_timestep_callback={self.current_timestep_callback},\n" + f" num_inference_steps_callback={self.num_inference_steps_callback}\n" + f")" + ) + + +class FluxTeaCacheState(BaseState): + """ + State management for FLUX TeaCache hook. + + This class tracks the caching state across diffusion timesteps, managing counters, accumulated distances, + and cached values needed for the TeaCache algorithm. The state persists across multiple forward passes + during a single inference run and is automatically reset when a new inference begins. + + Attributes: + cnt (int): + Current timestep counter, incremented with each forward pass. Used to identify first/last timesteps + which are always computed (never cached) for maximum quality. + num_steps (int): + Total number of inference steps for the current run. Used to identify the last timestep. Automatically + detected from callbacks or pipeline attributes if not explicitly set. + accumulated_rel_l1_distance (float): + Running accumulator for rescaled L1 distances between consecutive modulated inputs. Compared against + the threshold to make caching decisions. Reset to 0 when the decision is made to recompute. + previous_modulated_input (torch.Tensor): + Modulated input from the previous timestep, extracted from the first transformer block's norm1 layer. + Used for computing L1 distance to determine similarity between consecutive timesteps. + previous_residual (torch.Tensor): + Cached residual (output - input) from the previous timestep's full transformer computation. Applied + directly when caching is triggered instead of computing all transformer blocks. + """ + def __init__(self): + self.cnt = 0 + self.num_steps = 0 + self.accumulated_rel_l1_distance = 0.0 + self.previous_modulated_input = None + self.previous_residual = None + + def reset(self): + """Reset all state variables to initial values for a new inference run.""" + self.cnt = 0 + self.num_steps = 0 + self.accumulated_rel_l1_distance = 0.0 + self.previous_modulated_input = None + self.previous_residual = None + + def __repr__(self) -> str: + return ( + f"FluxTeaCacheState(\n" + f" cnt={self.cnt},\n" + f" num_steps={self.num_steps},\n" + f" accumulated_rel_l1_distance={self.accumulated_rel_l1_distance:.6f},\n" + f" previous_modulated_input={'cached' if self.previous_modulated_input is not None else 'None'},\n" + f" previous_residual={'cached' if self.previous_residual is not None else 'None'}\n" + f")" + ) + + +class FluxTeaCacheHook(ModelHook): + """ + ModelHook implementing TeaCache for FLUX transformer models. + + This hook intercepts the FLUX transformer forward pass and implements adaptive caching based on timestep + embedding similarity. It extracts modulated inputs from the first transformer block, computes L1 distances, + applies polynomial rescaling, and decides whether to reuse cached residuals or compute full transformer blocks. + + The hook follows the original TeaCache algorithm from the paper: + 1. Extract modulated input from first transformer block's norm1 layer with timestep embedding + 2. 
Compute relative L1 distance between current and previous modulated inputs + 3. Apply polynomial rescaling with FLUX-specific coefficients to the distance + 4. Accumulate rescaled distances and compare to threshold + 5. If below threshold: reuse cached residual (fast path, skip transformer computation) + 6. If above threshold: compute full transformer blocks and cache new residual (slow path) + + The first and last timesteps are always computed fully (never cached) to ensure maximum quality. + + Attributes: + config (FluxTeaCacheConfig): + Configuration containing threshold, polynomial coefficients, and optional callbacks. + rescale_func (np.poly1d): + Polynomial function for rescaling L1 distances using FLUX-specific coefficients. + state_manager (StateManager): + Manages FluxTeaCacheState across forward passes, maintaining counters and cached values. + """ + + _is_stateful = True + + def __init__(self, config: FluxTeaCacheConfig): + super().__init__() + self.config = config + self.rescale_func = np.poly1d(config.coefficients) + self.state_manager = StateManager(FluxTeaCacheState, (), {}) + + def initialize_hook(self, module): + self.state_manager.set_context("flux_teacache") + return module + + def new_forward(self, module, hidden_states, timestep, pooled_projections, + encoder_hidden_states, txt_ids, img_ids, **kwargs): + """ + Replace FLUX transformer forward pass with TeaCache-enabled version. + + This method implements the full TeaCache algorithm inline, processing transformer blocks directly instead + of calling the original forward method. It extracts modulated inputs, makes caching decisions, and either + applies cached residuals (fast path) or computes full transformer blocks (slow path). + + Args: + module: The FluxTransformer2DModel instance. + hidden_states (`torch.Tensor`): Input latent tensor of shape (batch, channels, height, width). + timestep (`torch.Tensor`): Current diffusion timestep. + pooled_projections (`torch.Tensor`): Pooled text embeddings for timestep conditioning. + encoder_hidden_states (`torch.Tensor`): Text encoder outputs for cross-attention. + txt_ids (`torch.Tensor`): Position IDs for text tokens. + img_ids (`torch.Tensor`): Position IDs for image tokens. + **kwargs: Additional arguments including 'guidance' and 'joint_attention_kwargs'. + + Returns: + `torch.Tensor`: Denoised output tensor. 
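As a self-contained illustration of the distance computation and polynomial rescaling described above, here is a minimal sketch; the tensor shapes are arbitrary, and the coefficients and the 0.2 threshold are the defaults from `FluxTeaCacheConfig`:

```python
import numpy as np
import torch

# FLUX defaults from the TeaCache paper, highest-degree coefficient first as numpy.poly1d expects.
coefficients = [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01]
rescale_func = np.poly1d(coefficients)

prev = torch.randn(1, 1024, 64)              # hypothetical modulated input at step t-1
curr = prev + 0.01 * torch.randn_like(prev)  # hypothetical modulated input at step t

# Relative L1 distance between consecutive modulated inputs.
rel_l1 = ((curr - prev).abs().mean() / prev.abs().mean()).item()
rescaled = rescale_func(rel_l1)              # this value is accumulated across steps

# The cached residual is reused while the accumulated rescaled distance stays below the threshold.
use_cache = rescaled < 0.2
print(f"rel_l1={rel_l1:.4f}, rescaled={rescaled:.4f}, use_cache={use_cache}")
```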
+ """ + state = self.state_manager.get_state() + + # Reset counter if we've completed all steps (new inference run) + if state.cnt == state.num_steps and state.num_steps > 0: + logger.info("TeaCache inference completed") + state.cnt = 0 + state.accumulated_rel_l1_distance = 0.0 + state.previous_modulated_input = None + state.previous_residual = None + + # Set num_steps on first timestep if not already set + if state.cnt == 0 and state.num_steps == 0: + if self.config.num_inference_steps_callback is not None: + state.num_steps = self.config.num_inference_steps_callback() + # If still not set, try to get from module attribute (set by pipeline) + if state.num_steps == 0 and hasattr(module, 'num_steps'): + state.num_steps = module.num_steps + + # Process inputs like original TeaCache + # Must process hidden_states through x_embedder first + hidden_states = module.x_embedder(hidden_states) + + # Extract timestep embedding + timestep_scaled = timestep.to(hidden_states.dtype) * 1000 + if kwargs.get('guidance') is not None: + guidance = kwargs['guidance'].to(hidden_states.dtype) * 1000 + temb = module.time_text_embed(timestep_scaled, guidance, pooled_projections) + else: + temb = module.time_text_embed(timestep_scaled, pooled_projections) + + # Extract modulated input from first transformer block like original + inp = hidden_states.clone() + temb_clone = temb.clone() + modulated_inp, _, _, _, _ = module.transformer_blocks[0].norm1(inp, emb=temb_clone) + + # Make caching decision + should_calc = self._should_compute_full_transformer(state, modulated_inp) + + if not should_calc: + # Fast path: apply cached residual + logger.debug( + f"TeaCache: reusing cached residual at step {state.cnt}/{state.num_steps} " + f"(accumulated distance: {state.accumulated_rel_l1_distance:.6f})" + ) + output = hidden_states + state.previous_residual + else: + # Slow path: full computation inline (like original TeaCache) + logger.debug( + f"TeaCache: computing full transformer at step {state.cnt}/{state.num_steps} " + f"(accumulated distance: {state.accumulated_rel_l1_distance:.6f})" + ) + ori_hidden_states = hidden_states.clone() + + # Process encoder_hidden_states + encoder_hidden_states = module.context_embedder(encoder_hidden_states) + + # Process txt_ids and img_ids + if txt_ids.ndim == 3: + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = module.pos_embed(ids) + + # Process through transformer blocks + for block in module.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=kwargs.get('joint_attention_kwargs'), + ) + + # Process through single transformer blocks + # Note: single blocks concatenate internally, so pass separately + for block in module.single_transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=kwargs.get('joint_attention_kwargs'), + ) + + # Cache the residual + state.previous_residual = hidden_states - ori_hidden_states + + state.previous_modulated_input = modulated_inp + state.cnt += 1 + + # Apply final norm and projection (always needed) + hidden_states = module.norm_out(hidden_states, temb) + output = module.proj_out(hidden_states) + + return output + + def 
_should_compute_full_transformer(self, state, modulated_inp): + """ + Determine whether to compute full transformer blocks or reuse cached residual. + + This method implements the core caching decision logic from the TeaCache paper: + - Always compute first and last timesteps (for maximum quality) + - For intermediate timesteps, compute relative L1 distance between current and previous modulated inputs + - Apply polynomial rescaling to convert distance to model-specific caching signal + - Accumulate rescaled distances and compare to threshold + - Return True (compute) if accumulated distance exceeds threshold, False (cache) otherwise + + Args: + state (`FluxTeaCacheState`): Current state containing counters and cached values. + modulated_inp (`torch.Tensor`): Modulated input from first transformer block's norm1 layer. + + Returns: + `bool`: True to compute full transformer, False to reuse cached residual. + """ + # Compute first timestep + if state.cnt == 0: + state.accumulated_rel_l1_distance = 0 + return True + + # compute last timestep (if num_steps is set) + if state.num_steps > 0 and state.cnt == state.num_steps - 1: + state.accumulated_rel_l1_distance = 0 + return True + + # Need previous modulated input for comparison + if state.previous_modulated_input is None: + return True + + # Compute relative L1 distance + rel_distance = ((modulated_inp - state.previous_modulated_input).abs().mean() + / state.previous_modulated_input.abs().mean()).cpu().item() + + # Apply polynomial rescaling + rescaled_distance = self.rescale_func(rel_distance) + state.accumulated_rel_l1_distance += rescaled_distance + + # Make decision based on accumulated threshold + if state.accumulated_rel_l1_distance < self.config.rel_l1_thresh: + return False + else: + state.accumulated_rel_l1_distance = 0 # Reset accumulator + return True + + def reset_state(self, module): + self.state_manager.reset() + return module + + +def apply_flux_teacache(module, config: FluxTeaCacheConfig): + """ + Apply TeaCache optimization to a FLUX transformer model. + + This function registers a FluxTeaCacheHook on the provided FLUX transformer, enabling adaptive caching of + transformer block computations based on timestep embedding similarity. The hook intercepts the forward pass + and implements the TeaCache algorithm to achieve 1.5x-2x speedup with minimal quality loss. + + Args: + module: The FLUX transformer model (FluxTransformer2DModel) to optimize. + config (`FluxTeaCacheConfig`): Configuration specifying caching threshold and optional callbacks. + + Raises: + ValueError: If the module is not a FluxTransformer2DModel. + + Examples: + ```python + from diffusers import FluxPipeline + from diffusers.hooks import FluxTeaCacheConfig, apply_flux_teacache + + # Load FLUX pipeline + pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + pipe.to("cuda") + + # Apply TeaCache directly to transformer + config = FluxTeaCacheConfig(rel_l1_thresh=0.2) + apply_flux_teacache(pipe.transformer, config) + + # Generate with caching enabled + image = pipe("A cat on a windowsill", num_inference_steps=4).images[0] + + # Or use the convenience method via CacheMixin + pipe.transformer.enable_cache(config) + ``` + + Note: + For most use cases, it's recommended to use the CacheMixin interface: + `pipe.transformer.enable_cache(FluxTeaCacheConfig(...))` which provides additional convenience methods + like `disable_cache()` for easy toggling. 
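If the pipeline does not expose the step count to the transformer, the optional `num_inference_steps_callback` from the config can be wired explicitly. A minimal sketch, assuming the same checkpoint as the example above; `num_inference_steps` is defined by the caller:

```python
import torch

from diffusers import FluxPipeline
from diffusers.hooks import FluxTeaCacheConfig

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.to("cuda")

num_inference_steps = 4
config = FluxTeaCacheConfig(
    rel_l1_thresh=0.2,
    # Lets the hook identify the last step so it is always computed, never cached.
    num_inference_steps_callback=lambda: num_inference_steps,
)
pipe.transformer.enable_cache(config)

image = pipe("A cat on a windowsill", num_inference_steps=num_inference_steps).images[0]
```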
+ """ + from ..models.transformers.transformer_flux import FluxTransformer2DModel + + # Validate FLUX model + if not isinstance(module, FluxTransformer2DModel): + raise ValueError( + f"TeaCache currently supports only FLUX transformer models. " + f"Got {type(module).__name__}. Please ensure you're applying TeaCache to a " + f"FluxTransformer2DModel instance (e.g., pipe.transformer)." + ) + + # Register hook on main transformer + registry = HookRegistry.check_if_exists_or_initialize(module) + hook = FluxTeaCacheHook(config) + registry.register_hook(hook, _FLUX_TEACACHE_HOOK) diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 605c0d588c8c..48055716307f 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -28,6 +28,7 @@ class CacheMixin: - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) - [FasterCache](https://huggingface.co/papers/2410.19355) - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) + - [TeaCache](https://huggingface.co/papers/2411.19108) (FLUX-specific) """ _cache_config = None @@ -66,9 +67,11 @@ def enable_cache(self, config) -> None: from ..hooks import ( FasterCacheConfig, FirstBlockCacheConfig, + FluxTeaCacheConfig, PyramidAttentionBroadcastConfig, apply_faster_cache, apply_first_block_cache, + apply_flux_teacache, apply_pyramid_attention_broadcast, ) @@ -83,15 +86,18 @@ def enable_cache(self, config) -> None: apply_first_block_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) + elif isinstance(config, FluxTeaCacheConfig): + apply_flux_teacache(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, FluxTeaCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK + from ..hooks.flux_teacache import _FLUX_TEACACHE_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK if self._cache_config is None: @@ -107,6 +113,8 @@ def disable_cache(self) -> None: registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, FluxTeaCacheConfig): + registry.remove_hook(_FLUX_TEACACHE_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") @@ -128,3 +136,12 @@ def cache_context(self, name: str): yield registry._set_context(None) + + def enable_flux_teacache(self, rel_l1_thresh: float = 0.2, **kwargs): + r""" + Enable FLUX TeaCache on the model. + """ + from ..hooks import FluxTeaCacheConfig + + config = FluxTeaCacheConfig(rel_l1_thresh=rel_l1_thresh, **kwargs) + self.enable_cache(config) diff --git a/tests/hooks/test_flux_teacache.py b/tests/hooks/test_flux_teacache.py new file mode 100644 index 000000000000..1a2bcd613520 --- /dev/null +++ b/tests/hooks/test_flux_teacache.py @@ -0,0 +1,167 @@ +# Copyright 2025 HuggingFace Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import warnings + +import torch + +from diffusers.hooks import FluxTeaCacheConfig, HookRegistry + + +class FluxTeaCacheConfigTests(unittest.TestCase): + """Tests for FluxTeaCacheConfig parameter validation.""" + + def test_valid_config(self): + """Test valid configuration is accepted.""" + config = FluxTeaCacheConfig(rel_l1_thresh=0.2) + self.assertEqual(config.rel_l1_thresh, 0.2) + self.assertIsNotNone(config.coefficients) + self.assertEqual(len(config.coefficients), 5) + + def test_invalid_type(self): + """Test invalid type for rel_l1_thresh raises TypeError.""" + with self.assertRaises(TypeError) as context: + FluxTeaCacheConfig(rel_l1_thresh="invalid") + self.assertIn("must be a number", str(context.exception)) + + def test_negative_value(self): + """Test negative threshold raises ValueError.""" + with self.assertRaises(ValueError) as context: + FluxTeaCacheConfig(rel_l1_thresh=-0.5) + self.assertIn("must be positive", str(context.exception)) + + def test_invalid_coefficients_length(self): + """Test wrong coefficient count raises ValueError.""" + with self.assertRaises(ValueError) as context: + FluxTeaCacheConfig(rel_l1_thresh=0.2, coefficients=[1.0, 2.0, 3.0]) + self.assertIn("exactly 5 elements", str(context.exception)) + + def test_invalid_coefficients_type(self): + """Test invalid coefficient types raise TypeError.""" + with self.assertRaises(TypeError) as context: + FluxTeaCacheConfig(rel_l1_thresh=0.2, coefficients=[1.0, 2.0, "invalid", 4.0, 5.0]) + self.assertIn("must be numbers", str(context.exception)) + + def test_warning_very_low_threshold(self): + """Test warning is issued for very low threshold.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + FluxTeaCacheConfig(rel_l1_thresh=0.01) + self.assertEqual(len(w), 1) + self.assertIn("very low", str(w[0].message)) + + def test_warning_very_high_threshold(self): + """Test warning is issued for very high threshold.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + FluxTeaCacheConfig(rel_l1_thresh=1.5) + self.assertEqual(len(w), 1) + self.assertIn("very high", str(w[0].message)) + + def test_config_repr(self): + """Test __repr__ method works correctly.""" + config = FluxTeaCacheConfig(rel_l1_thresh=0.25) + repr_str = repr(config) + self.assertIn("FluxTeaCacheConfig", repr_str) + self.assertIn("0.25", repr_str) + + def test_custom_coefficients(self): + """Test custom coefficients are accepted.""" + custom_coeffs = [1.0, 2.0, 3.0, 4.0, 5.0] + config = FluxTeaCacheConfig(rel_l1_thresh=0.2, coefficients=custom_coeffs) + self.assertEqual(config.coefficients, custom_coeffs) + + +class FluxTeaCacheStateTests(unittest.TestCase): + """Tests for FluxTeaCacheState.""" + + def test_state_initialization(self): + """Test state initializes with correct default values.""" + from diffusers.hooks.flux_teacache import FluxTeaCacheState + + state = FluxTeaCacheState() + self.assertEqual(state.cnt, 0) + 
self.assertEqual(state.num_steps, 0) + self.assertEqual(state.accumulated_rel_l1_distance, 0.0) + self.assertIsNone(state.previous_modulated_input) + self.assertIsNone(state.previous_residual) + + def test_state_reset(self): + """Test state reset clears all values.""" + from diffusers.hooks.flux_teacache import FluxTeaCacheState + + state = FluxTeaCacheState() + # Modify state + state.cnt = 5 + state.num_steps = 10 + state.accumulated_rel_l1_distance = 0.5 + state.previous_modulated_input = torch.randn(1, 10) + state.previous_residual = torch.randn(1, 10) + + # Reset + state.reset() + + # Verify reset + self.assertEqual(state.cnt, 0) + self.assertEqual(state.num_steps, 0) + self.assertEqual(state.accumulated_rel_l1_distance, 0.0) + self.assertIsNone(state.previous_modulated_input) + self.assertIsNone(state.previous_residual) + + def test_state_repr(self): + """Test __repr__ method works correctly.""" + from diffusers.hooks.flux_teacache import FluxTeaCacheState + + state = FluxTeaCacheState() + state.cnt = 3 + state.num_steps = 10 + repr_str = repr(state) + self.assertIn("FluxTeaCacheState", repr_str) + self.assertIn("cnt=3", repr_str) + self.assertIn("num_steps=10", repr_str) + + +class FluxTeaCacheHookTests(unittest.TestCase): + """Tests for FluxTeaCacheHook functionality.""" + + def test_hook_initialization(self): + """Test hook initializes correctly with config.""" + from diffusers.hooks.flux_teacache import FluxTeaCacheHook + + config = FluxTeaCacheConfig(rel_l1_thresh=0.2) + hook = FluxTeaCacheHook(config) + + self.assertEqual(hook.config.rel_l1_thresh, 0.2) + self.assertIsNotNone(hook.rescale_func) + self.assertIsNotNone(hook.state_manager) + + def test_apply_flux_teacache_validation(self): + """Test apply_flux_teacache validates input module type.""" + from diffusers.hooks import apply_flux_teacache + + # Create a dummy module that's not a FluxTransformer2DModel + class DummyModule(torch.nn.Module): + pass + + module = DummyModule() + config = FluxTeaCacheConfig(rel_l1_thresh=0.2) + + with self.assertRaises(ValueError) as context: + apply_flux_teacache(module, config) + self.assertIn("FLUX transformer models", str(context.exception)) + + +if __name__ == "__main__": + unittest.main()
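The tests above cover configuration and state handling but not the caching decision itself. A possible follow-up test, sketched under the assumption that the private `_should_compute_full_transformer` helper keeps the signature shown in this diff; with the default FLUX coefficients a near-zero distance rescales to roughly 0.26, hence the 0.6 threshold below:

```python
import unittest

import torch

from diffusers.hooks import FluxTeaCacheConfig
from diffusers.hooks.flux_teacache import FluxTeaCacheHook, FluxTeaCacheState


class FluxTeaCacheDecisionTests(unittest.TestCase):
    """Sketch: exercises the caching decision without a full FLUX model."""

    def test_caching_decision(self):
        hook = FluxTeaCacheHook(FluxTeaCacheConfig(rel_l1_thresh=0.6))
        state = FluxTeaCacheState()
        state.num_steps = 4

        first = torch.ones(1, 16, 8)
        # The first timestep is always computed, never cached.
        self.assertTrue(hook._should_compute_full_transformer(state, first))
        state.previous_modulated_input = first
        state.cnt = 1

        # A nearly identical modulated input rescales to ~0.26 < 0.6, so the cached residual is reused.
        self.assertFalse(hook._should_compute_full_transformer(state, first * 1.0001))
```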