perf: prefer batched matmuls for attention #1203
Conversation
Force-pushed from 220c0f0 to 4dcce80.
Thanks a lot @Birch-san, I'll take a look today (and test on CUDA too) :)
I've tested using the default PNDM scheduler for 50 steps on CUDA and MPS, and I'm observing about an 11.5% speedup on both devices:
The results are very similar, but not identical. The differences are probably caused by precision errors? Update: this refers to …
I think I was getting visually identical results on float32 on MPS.
Tangential question about how to handle this sort of thing in the future: `torch.einsum("b i d, b j d -> b i j", query, key) * self.scale` is so much more concise than the nine-parameter `baddbmm` call. Do you think the optimization of this into `baddbmm` could be done automatically? Or would an optimizer for einsum not recognize this as optimizable-to-`baddbmm`, because the scale (alpha) multiplication happens outside the einsum? (And unfortunately there is no einsum syntax to include scalar multiplication.) Alternately: lots of the cognitive noise in the `baddbmm` call comes from `torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device)`, which is unfortunate, because that empty tensor is ignored with `beta=0`. I might be tempted to make some kind of helper to hide it.
@keturn it's a good question. perhaps the shape of the empty tensor could be computed for you by a helper; maybe an even more concise API is possible.

it feels like it must be possible to optimize the einsum form into baddbmm, though.
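The equivalence under discussion can be sketched standalone; the shapes and scale value below are hypothetical, chosen only for illustration:

```python
import torch

# Hypothetical shapes: (batch*heads, query tokens, dim) and (batch*heads, key tokens, dim).
query = torch.randn(2, 64, 40)
key = torch.randn(2, 77, 40)
scale = 40 ** -0.5

# concise form: einsum followed by a separate scalar multiply
via_einsum = torch.einsum("b i d, b j d -> b i j", query, key) * scale

# fused form: one baddbmm call; beta=0 ignores the (uninitialized) input tensor,
# and alpha folds the scale factor into the matmul
via_baddbmm = torch.baddbmm(
    torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

assert torch.allclose(via_einsum, via_baddbmm, atol=1e-4)
```

With `beta=0`, `baddbmm` ignores the contents of its first argument, so the uninitialized `torch.empty` only supplies shape, dtype, and device.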
@Birch-san, did you also check how this affects performance for PyTorch + CUDA, which is the main use case here? I'm not super opposed to this, but we've seen changes like this possibly leading to big differences in generation. Would love to have comparisons also for PyTorch + CUDA on GPU, or else we could also think about making a separate attention function for MPS.
@patrickvonplaten I haven't and cannot. what did you think of @pcuenca's measurements, which showed an 11% speedup on CUDA?
Ah yeah true, sorry only realizing now! That's actually quite a bit - @patil-suraj could you also run this on an A100 and I'll take the V100 so that we also have numbers for those. If we also get high speed-ups there and the precision stays within limits, I'm actually happy to go forward with this PR. @patil-suraj @pcuenca, could we maybe all run this script:

```python
#!/usr/bin/env python3
import time

import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",  # comment out for fp32
    torch_dtype=torch.float16,  # comment out for fp32
)
pipe = pipe.to("cuda")


def time_stable_diffusion(model, _id, bs=1, num_inference_steps=50):
    print(f"time_bs={bs}_id={_id}_num_steps={num_inference_steps}")
    print(20 * "-")
    start_time = time.time()
    prompt = bs * ["a photo of an astronaut riding a horse on mars"]
    model(prompt, num_inference_steps=num_inference_steps).images[0]
    print(time.time() - start_time)
    print(20 * "-")


time_stable_diffusion(pipe, "v1-5-pndm-fp16", bs=1)
time_stable_diffusion(pipe, "v1-5-pndm-fp16", bs=2)
time_stable_diffusion(pipe, "v1-5-pndm-fp16", bs=4)
time_stable_diffusion(pipe, "v1-5-pndm-fp16", bs=8)
```

on CUDA PyTorch in both fp32 and fp16 for both "main" and this PR.
Here are some results for the V100 machine:

FP32: (results table not captured)

FP16: (results table not captured)

=> So we're also seeing a clear improvement on a V100. @patil-suraj @pcuenca it would be great if you could also run a quick benchmark.
src/diffusers/models/attention.py (outdated)

```python
# get scores
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)  # TODO: use baddbmm
scale = 1 / math.sqrt(self.channels / self.num_heads)
```
Could we move this line before the first if-else statement and then have only one big if-else statement? IMO this would make the code more readable.
done, rebased, force-pushed.
there's still another if-else below. merging that is possible, but only by duplicating the softmax.
@Birch-san, I rebased your PR to current main to make sure we can test it well, hope this is ok. @pcuenca @patil-suraj @anton-l just ran the whole slow test suite and everything passes => I'm in favor of merging this PR. Could you give it a look?
added fast-path to Decoder when num_heads=1: force-pushed from d8ac7ee to f9bf148.
Updated tests. (Note: I ran a warmup pass with bs=1 at the beginning of each test).
Updated A100 numbers after running on PyTorch version …
In terms of visual differences, these are my conclusions:
Examples (best seen by downloading and switching back and forth; images not captured): note the astronaut's boot and thigh; shadows, background structure.

TL;DR: happy to merge if the slow tests pass.
@pcuenca is the warmup time also counted? Because I'm getting slightly faster perf on A100 for fp16 with …
Thanks a lot for adding this! I'm seeing nice speed-ups on A100.
Left a comment about how to always make this a 3D problem.
If the slow tests pass and the generated images are not worse than main, this should be good to merge after the comments are addressed :)
```python
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)

# TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
# or reformulate this into a 3D problem?
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
```
It's possible to formulate this as a 3D problem, the same way we do in CrossAttention, by merging the batch and heads dimensions before the bmm and then splitting them again. For example:
```python
def reshape_heads_to_batch_dim(tensor, heads=2):
    batch_size, seq_len, dim = tensor.shape
    head_size = heads
    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
    return tensor


def reshape_batch_dim_to_heads(tensor, heads=2):
    batch_size, seq_len, dim = tensor.shape
    head_size = heads
    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
    return tensor


query_states = reshape_heads_to_batch_dim(query_proj)
key_states = reshape_heads_to_batch_dim(key_proj)
value_states = reshape_heads_to_batch_dim(value_proj)

attention_scores = torch.baddbmm(
    torch.empty(
        query_states.shape[0],
        query_states.shape[1],
        key_states.shape[1],
        dtype=query_states.dtype,
        device=query_states.device,
    ),
    query_states,
    key_states.transpose(-1, -2),
    beta=0,
    alpha=1,
)
attention_scores = reshape_batch_dim_to_heads(attention_scores)
```
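As a sanity check on this suggestion, here's a standalone sketch (with hypothetical batch=4, seq_len=16, dim=8, heads=2) verifying that merging heads into the batch dimension, running a 3D bmm, and splitting again reproduces the per-head 4D matmul:

```python
import torch

# reshape_heads_to_batch_dim folds the head dimension into the batch dimension so a
# plain 3D bmm can be used; reshape_batch_dim_to_heads undoes it afterwards.

def reshape_heads_to_batch_dim(tensor, heads=2):
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)

def reshape_batch_dim_to_heads(tensor, heads=2):
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size // heads, heads, seq_len, dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size // heads, seq_len, dim * heads)

x = torch.randn(4, 16, 8)  # (batch, seq_len, dim); dim divisible by heads

# merging then splitting is lossless:
assert torch.allclose(reshape_batch_dim_to_heads(reshape_heads_to_batch_dim(x)), x)

# the merged 3D bmm produces the same scores as the per-head 4D matmul:
q3 = reshape_heads_to_batch_dim(x)                    # (batch*heads, seq, dim//heads)
scores3 = torch.bmm(q3, q3.transpose(-1, -2))         # (batch*heads, seq, seq)
q4 = x.reshape(4, 16, 2, 4).permute(0, 2, 1, 3)       # (batch, heads, seq, dim//heads)
scores4 = torch.matmul(q4, q4.transpose(-1, -2))      # (batch, heads, seq, seq)
assert torch.allclose(scores3.reshape(4, 2, 16, 16), scores4, atol=1e-5)
```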
```python
# TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
# or reformulate this into a 3D problem?
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
hidden_states = torch.matmul(attention_probs, value_states)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
```
Same comment as above: it should be possible to make this 3D using the same merge/split logic.
```python
attn_slice = torch.baddbmm(
    torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
    query[start_idx:end_idx],
    key[start_idx:end_idx].transpose(-1, -2),
    beta=0,
    alpha=self.scale,
)
```
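For reference, a standalone sketch (hypothetical shapes) of why this pattern is safe: with `beta=0`, `baddbmm` reads the empty tensor only for its shape, dtype, and device, and the result equals a plain bmm scaled by alpha:

```python
import torch

# Hypothetical shapes: (batch*heads, tokens, dim_head) for query, (batch*heads, context, dim_head) for key.
query = torch.randn(8, 64, 40)
key = torch.randn(8, 77, 40)
scale = 40 ** -0.5

# beta=0 means the uninitialized `input` tensor is never read (NaN/inf in it do not
# propagate), so torch.empty is safe and avoids a zero-fill.
fused = torch.baddbmm(
    torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
reference = torch.bmm(query, key.transpose(-1, -2)) * scale
assert torch.allclose(fused, reference, atol=1e-4)
```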
Looks good!
@Birch-san Feel free to open a follow-up PR to address the comments, merging this now :)
perf: prefer batched matmuls for attention. added fast-path to Decoder when num_heads=1
Restores the `baddbmm()` optimization which was reverted in #689. The original attempt was reverted due to its regressing the performance of the `tests/test_pipelines.py::PipelineTesterMixin::test_stable_diffusion_memory_chunking` test. Testing is more mature now, so let's have another try and see whether the problem reproduces.

Implements this optimization, motivated by the following benchmarks of matmul strategies (on the MPS backend):

- batched matmul with scale factor: https://gist.github.com/Birch-san/8f3eb99deffdc3541595e46a01605dea
- batched matmul: https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0

Adds a fast-path to attention in the Decoder when `num_heads=1`. Most of the heavy lifting here is done by the Unet optimizations, but perhaps the simpler Decoder will compile better to CoreML. We should check CoreML performance — it was using `matmul()` until now for compatibility, which on MPS is the slowest way to run a batched matmul, according to my measurements.

Measurements (seconds) from PyTorch `1.14.0.dev20221105`, M1 Pro Max (MPS backend), 15 Heun steps (via k-diffusion), batch-of-1 512x512 image:
some measurements — especially first-in-batch — are outliers by about a second; they may be benefitting from lower GPU temperature.
Sliced
with Unet+Decoder opt:
17.946660667017568
18.44281687500188
18.492898124968633
unoptimized:
21.032894457981456
21.52378533303272
(I'll ignore the outlier, and average the rest)
Conclusion:
21.25÷18.45 = 15% faster
Unsliced
with Unet+Decoder opt:
17.943294666998554
17.721827082976233
17.2017434159643
17.328909542004112
with Unet opt only:
16.862462541961577
17.651387749996502
17.75758166599553
unoptimized:
19.89687283296371
20.449800083995797
20.862784416996874
(I'll ignore the 19.8 outlier, and average the rest)
Conclusion:
20.66÷17.53 = 18% faster
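As a sanity check, the quoted speedups can be recomputed from the raw timings above (using the Unet+Decoder numbers, outliers excluded as noted):

```python
# Recompute the quoted speedups from the raw timings listed above.
sliced_opt = [18.44281687500188, 18.492898124968633]  # 17.95 outlier dropped
sliced_unopt = [21.032894457981456, 21.52378533303272]
unsliced_opt = [17.943294666998554, 17.721827082976233, 17.2017434159643, 17.328909542004112]
unsliced_unopt = [20.449800083995797, 20.862784416996874]  # 19.90 outlier dropped

mean = lambda xs: sum(xs) / len(xs)
sliced_speedup = mean(sliced_unopt) / mean(sliced_opt) - 1
unsliced_speedup = mean(unsliced_unopt) / mean(unsliced_opt) - 1
print(f"sliced: {sliced_speedup:.0%}, unsliced: {unsliced_speedup:.0%}")
# prints: sliced: 15%, unsliced: 18%
```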