
Conversation

@toilaluan

What does this PR do?

Adds the TaylorSeer caching method to accelerate inference, as mentioned in #12569.

Author's codebase: https://github.com/Shenyi-Z/TaylorSeer

This PR's structure heavily mimics the FasterCache behaviour (https://github.com/huggingface/diffusers/pull/10163/files).
I am prioritizing making it work on image model pipelines (Flux, Qwen Image) first, for ease of evaluation.
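
For reference, the new API is exercised the same way as the other cache helpers. This usage sketch is lifted from the benchmark script later in this thread (the argument values are the ones used there, not recommendations):

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16
).to("cuda")

# Same settings as in the benchmark script below.
config = TaylorSeerCacheConfig(
    predict_steps=5,
    max_order=1,
    warmup_steps=3,
    taylor_factors_dtype=torch.float16,
    architecture="flux",
)
apply_taylorseer_cache(pipe.transformer, config)

image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=50,
    guidance_scale=5.0,
).images[0]
```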

Expected Output

A 4-5x speedup with these settings, while keeping output image quality acceptable.

image

State Design

The core of this algorithm is predicting the features of step t from the real computed features of previous steps, using a Taylor expansion approximation.
We design a State class with predict and update methods and taylor_factors: Tensor to maintain iteration information. Each feature tensor is bound to a state instance (in the double-stream attention class in Flux and QwenImage, the module outputs image_features and txt_features, so we create two state instances for them).

  • The update method is called on real-compute timesteps and updates taylor_factors using the formula from the original implementation.
  • The predict method is called to predict the feature from the current taylor_factors using the formula from the original implementation. A minimal sketch of this state object follows below.
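
Below is a minimal sketch of the state idea, assuming a finite-difference update of the Taylor factors; the names TaylorSeerState, update, and predict follow the description above, but the exact signatures and factor bookkeeping in the PR may differ.

```python
import math
import torch


class TaylorSeerState:
    """Sketch of a per-feature cache state (names and signatures are assumed)."""

    def __init__(self, max_order: int = 1):
        self.max_order = max_order
        self.taylor_factors: dict[int, torch.Tensor] = {}
        self.last_update_step = None

    def update(self, features: torch.Tensor, step: int) -> None:
        # Real-compute step: rebuild the factors by finite differences
        # against the factors stored at the previous real-compute step.
        new_factors = {0: features}
        if self.last_update_step is not None:
            dt = step - self.last_update_step
            for order in range(1, self.max_order + 1):
                if order - 1 not in self.taylor_factors:
                    break
                new_factors[order] = (
                    new_factors[order - 1] - self.taylor_factors[order - 1]
                ) / dt
        self.taylor_factors = new_factors
        self.last_update_step = step

    def predict(self, step: int) -> torch.Tensor:
        # Cached step: evaluate the Taylor expansion around the last
        # real-compute step instead of running the module.
        dt = step - self.last_update_step
        prediction = torch.zeros_like(self.taylor_factors[0])
        for order, factor in self.taylor_factors.items():
            prediction = prediction + factor * (dt ** order) / math.factorial(order)
        return prediction
```

In the double-stream blocks, one such state would be kept for the image stream and one for the text stream, mirroring the image_features / txt_features outputs mentioned above.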

@seed93

seed93 commented Nov 14, 2025

Will you adapt this great PR for Flux Kontext ControlNet or Flux ControlNet? It would be nice to have it implemented, and I am very eager to try it out.

@toilaluan
Author

@seed93 Yes, I am prioritizing the Flux series and Qwen Image.

@toilaluan
Author

Here is an analysis of TaylorSeer for Flux.
Compared with the baseline, the output image is different, whereas the PAB method gives a fairly close result.
This matches the author's implementation.

| model_id | cache_method | compute_dtype | compile | time (s) | model_memory (GB) | model_max_memory_reserved (GB) | inference_memory (GB) | inference_max_memory_reserved (GB) |
|---|---|---|---|---|---|---|---|---|
| flux | none | fp16 | False | 22.318 | 33.313 | 33.322 | 33.322 | 34.305 |
| flux | pyramid_attention_broadcast | fp16 | False | 18.394 | 33.313 | 33.322 | 33.322 | 35.789 |
| flux | taylorseer_cache | fp16 | False | 6.457 | 33.313 | 33.322 | 33.322 | 38.18 |

Flux visual results

Baseline

image

Pyramid Attention Broadcast

image

TaylorSeer Cache (this implementation)

image

TaylorSeer Original (https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-Diffusers/taylorseer_flux/diffusers_taylorseer_flux.py)

image

Benchmark code is based on #10163

import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
from diffusers import FluxPipeline
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/root/diffusers")
branch = repo.active_branch

from diffusers import (
    apply_taylorseer_cache, 
    TaylorSeerCacheConfig, 
    apply_faster_cache, 
    FasterCacheConfig, 
    apply_pyramid_attention_broadcast, 
    PyramidAttentionBroadcastConfig,
)

def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output

def prepare_flux(dtype: torch.dtype):
    model_id = "black-forest-labs/FLUX.1-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    generation_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
    }

    return pipe, generation_kwargs

def prepare_flux_config(cache_method: str, pipe: FluxPipeline):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(100, 950),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
            current_timestep_callback=lambda: pipe.current_timestep,
        )
    elif cache_method == "taylorseer_cache":
        return TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float16, architecture="flux")
    elif cache_method == "fastercache":
        return FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 681),
        low_frequency_weight_update_timestep_range=(99, 641),
        high_frequency_weight_update_timestep_range=(-1, 301),
        spatial_attention_block_identifiers=["transformer_blocks"],
        attention_weight_callback=lambda _: 0.3,
        tensor_format="BFCHW",
    )
    elif cache_method == "none":
        return None


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


MODEL_MAPPING = {
    "flux": {
        "prepare": prepare_flux,
        "config": prepare_flux_config,
        "decode": decode_flux,
    },
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    print(f"Generator: {generator}")
    print(f"Generation kwargs: {generation_kwargs}")
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(model_id: str, cache_method: str, output_dir: str, dtype: str):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]

    try:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=compute_dtype)

        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        model_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 2. Apply attention approximation technique
        config = model["config"](cache_method, pipe)
        if cache_method == "pyramid_attention_broadcast":
            apply_pyramid_attention_broadcast(pipe.transformer, config)
        elif cache_method == "fastercache":
            apply_faster_cache(pipe.transformer, config)
        elif cache_method == "taylorseer_cache":
            apply_taylorseer_cache(pipe.transformer, config)
        elif cache_method == "none":
            pass
        else:
            raise ValueError(f"Invalid {cache_method=} provided.")

        # 4. Benchmark
        time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        inference_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 5. Decode latents
        filename = output_dir / f"{model_id}---dtype-{dtype}---cache_method-{cache_method}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "time": time,
            "model_memory": model_memory,
            "model_max_memory_reserved": model_max_memory_reserved,
            "inference_memory": inference_memory,
            "inference_max_memory_reserved": inference_max_memory_reserved,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "time": None,
            "model_memory": None,
            "model_max_memory_reserved": None,
            "inference_memory": None,
            "inference_max_memory_reserved": None,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--cache_method",
        type=str,
        default="pyramid_attention_broadcast",
        choices=["pyramid_attention_broadcast", "fastercache", "taylorseer_cache", "none"],
        help="Cache method to use.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path where the benchmark artifacts and outputs are the be saved."
    )
    parser.add_argument("--dtype", type=str, help="torch.dtype to use for inference")
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(args.model_id, args.cache_method, args.output_dir, args.dtype)
    

@toilaluan
Author

More comparisons between this implementation, the baseline, and the author's implementation:

image

@toilaluan
Author

I think the current implementation is unified for every model that has attention modules, but to achieve full optimization we have to configure (e.g. via regex) which layers to cache and which to skip computing.
For example, in a sequence Linear1, Act1, Linear2, Act2: we need to add hooks so that Linear1, Act1, and Linear2 do nothing (return an empty tensor) while the output of Act2 is cached; a hypothetical sketch of such a selection follows below.
I have already added a template for Flux, but for other models users have to write their own and pass it to the config init.
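
To make the Linear1/Act1/Linear2/Act2 example concrete, here is a purely hypothetical sketch of selecting modules by regex; the pattern strings and the classify_modules helper are invented for illustration and are not part of the actual config API:

```python
import re

import torch

# Hypothetical illustration: regex patterns deciding which sub-modules get a
# "do nothing" hook and which get a caching state. The patterns and helper
# below are invented for the Linear1/Act1/Linear2/Act2 example, not the real
# TaylorSeerCacheConfig fields.
skip_patterns = [r".*\.linear1$", r".*\.act1$", r".*\.linear2$"]
cache_patterns = [r".*\.act2$"]


def classify_modules(model: torch.nn.Module):
    skip, cache = [], []
    for name, _ in model.named_modules():
        if any(re.fullmatch(p, name) for p in cache_patterns):
            cache.append(name)
        elif any(re.fullmatch(p, name) for p in skip_patterns):
            skip.append(name)
    return skip, cache
```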
@sayakpaul what do you think about this mechanism? I need some advice here.

sayakpaul requested a review from DN6 on November 14, 2025, 17:41