Description
Describe the bug
Regression PR: #11098.

The attention processor converts `hidden_states` to channels-last (NHWC) layout, which greatly benefits matmul on CPU. But the `.contiguous()` call converts the tensor back to its original NCHW layout, making the matmul slower than it would be with the NHWC layout the attention processor produced. On top of that, `.contiguous()` itself is expensive on CPU when the tensor's layout is unfavorable, as it is in this case.
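A minimal, self-contained sketch of the layout behavior described above (the shapes are made up for illustration, this is not the diffusers code): a plain `.contiguous()` call on a channels-last tensor silently copies it back to NCHW, while requesting the layout the tensor already has avoids the copy entirely.

```python
import torch

# Assumed, illustrative (B, C, H, W) shape; not taken from the pipeline.
x = torch.randn(2, 320, 64, 64).to(memory_format=torch.channels_last)
print(x.is_contiguous(memory_format=torch.channels_last))  # True

# .contiguous() defaults to torch.contiguous_format, so it copies the
# tensor back to NCHW and discards the channels-last layout.
y = x.contiguous()
print(y.is_contiguous(memory_format=torch.channels_last))  # False
print(y.is_contiguous())                                   # True: NCHW again

# Asking for the layout the tensor already has is a no-op (no copy).
z = x.contiguous(memory_format=torch.channels_last)
print(z.data_ptr() == x.data_ptr())                        # True
```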
Reproduction
```bash
numactl -C 0-31 --membind 0 python test.py
```
```python
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from transformers import set_seed
import torch
import time

SEED = 42
device = "cpu"
model_dtype = torch.float16
WARM_UP = 4
RUN = 4

set_seed(SEED)

print("\nLoading AnimateDiff-Lightning model...")
model_id = "ByteDance/AnimateDiff-Lightning"
step = 4
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
base = "emilianJR/epiCRealism"

adapter = MotionAdapter().to(device, model_dtype)
adapter.load_state_dict(load_file(hf_hub_download(model_id, ckpt), device=device))
pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=model_dtype).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
)

def run_inference():
    set_seed(SEED)
    with torch.no_grad():
        output = pipe(
            prompt="An astronaut riding a green horse",
            guidance_scale=1.0,
            num_inference_steps=4,
        ).frames[0]
    return output

# Warm up
print(f"\nWarming up ({WARM_UP} iterations)...")
for i in range(WARM_UP):
    run_inference()
    print(f"  Warm-up {i+1}/{WARM_UP} done")

# Benchmark
print(f"\nRunning benchmark ({RUN} iterations)...")
elapsed_times = []
for i in range(RUN):
    start = time.perf_counter()
    output = run_inference()
    end = time.perf_counter()
    elapsed = (end - start) * 1000  # ms
    elapsed_times.append(elapsed)
    print(f"  Run {i+1}/{RUN}: {elapsed:.2f} ms")

# Statistics
avg_time = sum(elapsed_times) / len(elapsed_times)
min_time = min(elapsed_times)
max_time = max(elapsed_times)
print(f"\n{'='*50}")
print(f"Results ({RUN} runs):")
print(f"  Average: {avg_time:.2f} ms")
print(f"  Min: {min_time:.2f} ms")
print(f"  Max: {max_time:.2f} ms")
print(f"{'='*50}")
```

The pipeline latency has a ~50% regression after the PR above.
Since the PR targets a DDP issue, I think we could check whether DDP is in use before calling `.contiguous()`. WDYT? @sayakpaul
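As a rough sketch of that idea (a hypothetical helper, not the actual diffusers code), the copy could be gated on whether a distributed process group is initialized:

```python
# Hypothetical sketch of the proposal, not the actual diffusers code:
# only force the copy when a distributed process group is active (the
# case PR #11098 was fixing), and keep the existing layout otherwise.
import torch
import torch.distributed as dist

def maybe_contiguous(hidden_states: torch.Tensor) -> torch.Tensor:
    if dist.is_available() and dist.is_initialized():
        return hidden_states.contiguous()
    return hidden_states
```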
Hi @jinc7461. Could you please provide the script to reproduce the error, and give me some advice on what to check before using `.contiguous()`? Thanks!
cc @sywangyi
Logs
System Info
torch 2.11.0.dev20260113+cpu
platform: Intel Xeon 6
Who can help?
No response