
AnimateDiffPipeline performance regression on CPU. #12975

@jiqing-feng

Describe the bug

Regression introduced by PR #11098.

The attn_processor converts hidden_states to channels-last (NHWC) layout, which greatly benefits matmul performance on CPU.

However, the .contiguous() call converts the tensor back to its original NCHW layout, which makes the matmul slower than it would be with the NHWC layout produced by the attention processor. In addition, .contiguous() itself is expensive on CPU when the tensor layout is unfavorable, as it is in this case.
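
A minimal sketch (not the actual pipeline code; the tensor shape is illustrative) of the layout round-trip described above, showing that .contiguous() copies a channels-last tensor back to NCHW:

import time
import torch

x = torch.randn(2, 320, 64, 64)

# NHWC (channels_last) copy, standing in for what the attn processor produces
x_nhwc = x.to(memory_format=torch.channels_last)

# .contiguous() defaults to torch.contiguous_format (NCHW), so it copies the
# data back to the original layout and discards the channels_last benefit
start = time.perf_counter()
x_nchw = x_nhwc.contiguous()
print(f".contiguous() copy took {(time.perf_counter() - start) * 1e3:.2f} ms")

print(x_nhwc.is_contiguous(memory_format=torch.channels_last))  # True
print(x_nchw.is_contiguous())  # True, i.e. back to NCHW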

Reproduction

numactl -C 0-31 --membind 0 python test.py

from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from transformers import set_seed
import torch
import time

SEED = 42
device = "cpu"
model_dtype = torch.float16
WARM_UP = 4
RUN = 4

set_seed(SEED)

print("\nLoading AnimateDiff-Lightning model...")
model_id = "ByteDance/AnimateDiff-Lightning"
step = 4
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
base = "emilianJR/epiCRealism"

# Load the 4-step AnimateDiff-Lightning motion adapter and the base model
adapter = MotionAdapter().to(device, model_dtype)
adapter.load_state_dict(load_file(hf_hub_download(model_id, ckpt), device=device))
pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=model_dtype).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
)

def run_inference():
    set_seed(SEED)
    with torch.no_grad():
        output = pipe(
            prompt="An astronaut riding a green horse",
            guidance_scale=1.0,
            num_inference_steps=4,
        ).frames[0]
    return output

# Warm up
print(f"\nWarming up ({WARM_UP} iterations)...")
for i in range(WARM_UP):
    run_inference()
    print(f"  Warm-up {i+1}/{WARM_UP} done")

# Benchmark
print(f"\nRunning benchmark ({RUN} iterations)...")
elapsed_times = []
for i in range(RUN):
    start = time.perf_counter()
    output = run_inference()
    end = time.perf_counter()
    elapsed = (end - start) * 1000  # ms
    elapsed_times.append(elapsed)
    print(f"  Run {i+1}/{RUN}: {elapsed:.2f} ms")

# Statistics
avg_time = sum(elapsed_times) / len(elapsed_times)
min_time = min(elapsed_times)
max_time = max(elapsed_times)
print(f"\n{'='*50}")
print(f"Results ({RUN} runs):")
print(f"  Average: {avg_time:.2f} ms")
print(f"  Min:     {min_time:.2f} ms")
print(f"  Max:     {max_time:.2f} ms")
print(f"{'='*50}")

Pipeline latency regresses by about 50% after this PR.

Since the PR was targeted at fixing the DDP issue, I think we could check whether DDP is in use before calling .contiguous(). WDYT? @sayakpaul
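
For illustration, a hedged sketch of the kind of guard I have in mind; maybe_contiguous is a hypothetical helper, not existing diffusers code:

import torch
import torch.distributed as dist

def maybe_contiguous(hidden_states: torch.Tensor) -> torch.Tensor:
    # Only pay for the layout conversion when torch.distributed / DDP is
    # actually in use; single-process CPU inference keeps the NHWC layout.
    if dist.is_available() and dist.is_initialized():
        return hidden_states.contiguous()
    return hidden_states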

Hi @jinc7461. Could you please provide the script to reproduce the error, and give me some advice on what to check before using .contiguous()? Thanks!

cc @sywangyi

Logs

System Info

torch 2.11.0.dev20260113+cpu
platform: Intel Xeon 6

Who can help?

No response
