[Feat] TaylorSeer Cache #12648
Conversation
Will you adapt this great PR for Flux Kontext ControlNet or Flux ControlNet? It would be nice to have it implemented, and I am very eager to try it out.

@seed93 yes, I am prioritizing the Flux series and Qwen Image.
Here is an analysis of TaylorSeer for Flux.

Flux visual results:
- Baseline
- Pyramid Attention Broadcast
- TaylorSeer Cache (this implementation)
- TaylorSeer Original (https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-Diffusers/taylorseer_flux/diffusers_taylorseer_flux.py)
Benchmark code is based on #10163:

```python
import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
from diffusers import (
    AllegroPipeline,
    CogVideoXPipeline,
    FluxPipeline,
    HunyuanVideoPipeline,
    LattePipeline,
    MochiPipeline,
)
from diffusers.models import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate

repo = git.Repo(path="/root/diffusers")
branch = repo.active_branch

from diffusers import (
    apply_taylorseer_cache,
    TaylorSeerCacheConfig,
    apply_faster_cache,
    FasterCacheConfig,
    apply_pyramid_attention_broadcast,
    PyramidAttentionBroadcastConfig,
)


def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)
    return elapsed_time, output


def prepare_flux(dtype: torch.dtype):
    model_id = "black-forest-labs/FLUX.1-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    generation_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
    }
    return pipe, generation_kwargs


def prepare_flux_config(cache_method: str, pipe: FluxPipeline):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(100, 950),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
            current_timestep_callback=lambda: pipe.current_timestep,
        )
    elif cache_method == "taylorseer_cache":
        return TaylorSeerCacheConfig(
            predict_steps=5,
            max_order=1,
            warmup_steps=3,
            taylor_factors_dtype=torch.float16,
            architecture="flux",
        )
    elif cache_method == "fastercache":
        return FasterCacheConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(-1, 681),
            low_frequency_weight_update_timestep_range=(99, 641),
            high_frequency_weight_update_timestep_range=(-1, 301),
            spatial_attention_block_identifiers=["transformer_blocks"],
            attention_weight_callback=lambda _: 0.3,
            tensor_format="BFCHW",
        )
    elif cache_method == "none":
        return None


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


MODEL_MAPPING = {
    "flux": {
        "prepare": prepare_flux,
        "config": prepare_flux_config,
        "decode": decode_flux,
    },
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    print(f"Generator: {generator}")
    print(f"Generation kwargs: {generation_kwargs}")
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(model_id: str, cache_method: str, output_dir: str, dtype: str):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]

    try:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=compute_dtype)
        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        model_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 2. Apply attention approximation technique
        config = model["config"](cache_method, pipe)
        if cache_method == "pyramid_attention_broadcast":
            apply_pyramid_attention_broadcast(pipe.transformer, config)
        elif cache_method == "fastercache":
            apply_faster_cache(pipe.transformer, config)
        elif cache_method == "taylorseer_cache":
            apply_taylorseer_cache(pipe.transformer, config)
        elif cache_method == "none":
            pass
        else:
            raise ValueError(f"Invalid {cache_method=} provided.")

        # 3. Benchmark
        time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        inference_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 4. Decode latents
        filename = output_dir / f"{model_id}---dtype-{dtype}---cache_method-{cache_method}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            video_length=generation_kwargs.get("video_length", None),
        )

        # 5. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "time": time,
            "model_memory": model_memory,
            "model_max_memory_reserved": model_max_memory_reserved,
            "inference_memory": inference_memory,
            "inference_max_memory_reserved": inference_max_memory_reserved,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }
    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 5. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "time": None,
            "model_memory": None,
            "model_max_memory_reserved": None,
            "inference_memory": None,
            "inference_max_memory_reserved": None,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--cache_method",
        type=str,
        default="pyramid_attention_broadcast",
        choices=["pyramid_attention_broadcast", "fastercache", "taylorseer_cache", "none"],
        help="Cache method to use.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path where the benchmark artifacts and outputs are to be saved."
    )
    parser.add_argument("--dtype", type=str, help="torch.dtype to use for inference")
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(args.model_id, args.cache_method, args.output_dir, args.dtype)
```
I think the current implementation is unified across all models that have attention modules, but to achieve full optimization we have to configure regexes for which layers to cache and which to skip computing, as sketched below.
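For illustration, such a selection could be driven by regex matching over module names; the helper and parameter names below (`select_modules_for_cache`, `cache_patterns`, `skip_patterns`) are hypothetical and not part of this PR:

```python
import re

import torch.nn as nn


def select_modules_for_cache(model: nn.Module, cache_patterns, skip_patterns):
    """Split module names into 'cache/predict output' vs. 'always recompute' buckets."""
    cached, recomputed = [], []
    for name, _ in model.named_modules():
        if any(re.search(p, name) for p in skip_patterns):
            recomputed.append(name)
        elif any(re.search(p, name) for p in cache_patterns):
            cached.append(name)
    return cached, recomputed


# Example (hypothetical): cache attention outputs of Flux's double- and single-stream
# blocks, and always recompute normalization and embedding layers.
# cached, recomputed = select_modules_for_cache(
#     pipe.transformer,
#     cache_patterns=[r"transformer_blocks\.\d+\.attn$", r"single_transformer_blocks\.\d+\.attn$"],
#     skip_patterns=[r"norm", r"embed"],
# )
```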





What does this PR do?
Adds the TaylorSeer caching method to accelerate inference, as mentioned in #12569.
Author's codebase: https://github.com/Shenyi-Z/TaylorSeer
The structure of this PR heavily mimics the FasterCache behaviour (https://github.com/huggingface/diffusers/pull/10163/files).
I prioritize making it work on image model pipelines (Flux, Qwen Image) for ease of evaluation; a usage sketch follows below.
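For reference, applying the cache outside the benchmark harness would look roughly like the sketch below. The config values are taken from the benchmark above; the import locations assume this PR's branch.

```python
import torch

from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Same settings as in the benchmark script above.
config = TaylorSeerCacheConfig(
    predict_steps=5,
    max_order=1,
    warmup_steps=3,
    taylor_factors_dtype=torch.float16,
    architecture="flux",
)
apply_taylorseer_cache(pipe.transformer, config)

image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=50,
    guidance_scale=5.0,
).images[0]
image.save("flux_taylorseer.png")
```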
Expected Output
A 4x to 5x speedup with these settings, while keeping output image quality acceptable.
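As a rough sanity check: with `num_inference_steps=50`, `warmup_steps=3`, and `predict_steps=5`, and assuming `predict_steps` means five predicted steps between consecutive full forward passes after warmup, only about 3 + 47 / 6 ≈ 11 of the 50 steps run the full transformer, which is consistent with a 4x to 5x reduction in transformer compute.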
State Design
The core of this algorithm is predicting the features at step `t` from actually computed features of previous steps, using a Taylor expansion approximation.

We design a `State` class that holds `predict` and `update` methods and `taylor_factors: Tensor` to maintain iteration information. Each feature tensor is bound to a state instance (in the double-stream attention blocks of Flux and Qwen Image, the module outputs `image_features` and `txt_features`, so we create two state instances for them).
- The `update` method is called on real compute timesteps and updates `taylor_factors` using the formula from the original implementation.
- The `predict` method is called to predict features from the current `taylor_factors` using the formula from the original implementation.
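A minimal sketch of such a state object is shown below, assuming the recursive finite-difference update used by the original TaylorSeer code; aside from the `predict`/`update` methods and the `taylor_factors` name, the details are illustrative rather than this PR's exact implementation:

```python
import math

import torch


class TaylorSeerState:
    """Tracks Taylor factors (feature value + estimated derivatives) for one feature tensor."""

    def __init__(self, max_order: int = 1, taylor_factors_dtype: torch.dtype = torch.float16):
        self.max_order = max_order
        self.dtype = taylor_factors_dtype
        self.taylor_factors = {}  # order -> tensor (0 = feature itself, k = k-th finite difference)
        self.last_update_step = None

    def update(self, features: torch.Tensor, step: int) -> None:
        """Refresh Taylor factors from an actually computed feature tensor at `step`."""
        new_factors = {0: features.to(self.dtype)}
        if self.last_update_step is not None:
            delta = step - self.last_update_step
            # Divided differences of successive orders approximate higher derivatives.
            for order in range(1, self.max_order + 1):
                if order - 1 not in self.taylor_factors:
                    break
                new_factors[order] = (new_factors[order - 1] - self.taylor_factors[order - 1]) / delta
        self.taylor_factors = new_factors
        self.last_update_step = step

    def predict(self, step: int) -> torch.Tensor:
        """Extrapolate the feature at `step` with a Taylor expansion around the last real step."""
        delta = step - self.last_update_step
        prediction = torch.zeros_like(self.taylor_factors[0])
        for order, factor in self.taylor_factors.items():
            prediction = prediction + factor * (delta ** order / math.factorial(order))
        return prediction
```

In the double-stream blocks, one such instance would be kept per output stream (image features and text features), as described above.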