
perf: prefer batched matmuls for attention #1203

Merged (1 commit) on Nov 21, 2022

Conversation

@Birch-san (Contributor) commented Nov 9, 2022

Restores the baddbmm() optimization that was reverted in #689.
The original attempt was reverted because it regressed the tests/test_pipelines.py::PipelineTesterMixin::test_stable_diffusion_memory_chunking test.
Testing is more mature now, so let's try again and see whether the problem reproduces.

Implements this optimization, motivated by the following benchmarks of matmul strategies (on MPS backend):

batched matmul with scale factor
https://gist.github.com/Birch-san/8f3eb99deffdc3541595e46a01605dea

batched matmul
https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0

Adds a fast path to the Decoder's attention when num_heads=1.
Most of the heavy lifting here is done by the Unet optimizations, but perhaps the simpler Decoder will compile better to CoreML.

We should check CoreML performance: until now it was using matmul() for compatibility, which on MPS is the slowest way to run a batched matmul, according to my measurements.
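
For context, here's a minimal sketch of the kind of change involved; the shapes and the scale value are illustrative, not taken from the diff:

import torch

# toy shapes: (batch*heads, query_tokens, dim_per_head) and (batch*heads, key_tokens, dim_per_head)
query = torch.randn(16, 4096, 40)
key = torch.randn(16, 77, 40)
scale = 40 ** -0.5

# before: plain matmul followed by a separate scalar multiply
scores_matmul = torch.matmul(query, key.transpose(-1, -2)) * scale

# after: one batched matmul with the scale folded in via alpha;
# the empty input tensor is ignored because beta=0
scores_baddbmm = torch.baddbmm(
    torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

assert torch.allclose(scores_matmul, scores_baddbmm, atol=1e-4)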

Measurements (seconds) from PyTorch 1.14.0.dev20221105, M1 Pro Max (MPS backend),
15 Heun steps (via k-diffusion), batch-of-1 512x512 image:

Some measurements (especially the first in a batch) are outliers by about a second; they may be benefitting from lower GPU temperature.

Sliced

with Unet+Decoder opt:
17.946660667017568
18.44281687500188
18.492898124968633

unoptimized:
21.032894457981456
21.52378533303272

(I'll ignore the outlier and average the rest.)
Conclusion:
21.25 ÷ 18.45 ≈ 1.15, i.e. about 15% faster

Unsliced

with Unet+Decoder opt:
17.943294666998554
17.721827082976233
17.2017434159643
17.328909542004112

with Unet opt only:
16.862462541961577
17.651387749996502
17.75758166599553

unoptimized:
19.89687283296371
20.449800083995797
20.862784416996874

(I'll ignore the 19.9 outlier and average the rest.)
Conclusion:
20.66 ÷ 17.53 ≈ 1.18, i.e. about 18% faster

@HuggingFaceDocBuilderDev commented Nov 9, 2022

The documentation is not available anymore as the PR was closed or merged.

@pcuenca (Member) commented Nov 9, 2022

Thanks a lot @Birch-san, I'll take a look today (and test on CUDA too) :)

@pcuenca (Member) commented Nov 10, 2022

I've tested using the default PNDM scheduler for 50 steps on CUDA and MPS, and I'm observing about an 11.5% speedup on both devices:

| Hardware  | bs | main  | bmm   | Δ (%) |
|-----------|----|-------|-------|-------|
| 3090      | 1  | 3.52  | 3.12  | 11.4  |
| 3090      | 2  | 6.15  | 5.46  | 11.2  |
| 3090      | 4  | 11.83 | 10.48 | 11.4  |
| 3090      | 6  | 17.3  | 15.23 | 12    |
| 3090      | 8  | 22.83 | 20.1  | 12    |
| mps, 64GB | 1  | 24.16 | 21.42 | 11.3  |

@pcuenca (Member) commented Nov 10, 2022

The results are very similar, but not identical. The differences are probably caused by precision errors?

Update: this refers to float16, I verified they are visually identical for float32.

@Birch-san (Contributor, Author) commented Nov 10, 2022

I think I was getting visually identical results on float32 on MPS.
I did simplify the arithmetic in the decoder (eliminated an unnecessary sqrt and a multiply), which might have reduced some numerical instability in the original approach. You could undo the decoder changes to verify that theory.
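
Roughly, the simplification being referred to looks like this; a sketch based on the Decoder code quoted later in the review thread, with illustrative shapes rather than the exact diff:

import math
import torch

d = 512   # channels / num_heads, illustrative
q = torch.randn(1, 1024, d)
k = torch.randn(1, 1024, d)

# original: nested sqrt, with the scale applied to both operands (two extra multiplies)
s = 1 / math.sqrt(math.sqrt(d))
scores_a = torch.matmul(q * s, k.transpose(-1, -2) * s)

# simplified: a single scale folded into the batched matmul via alpha
scores_b = torch.baddbmm(
    torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
    q,
    k.transpose(-1, -2),
    beta=0,
    alpha=1 / math.sqrt(d),
)

print((scores_a - scores_b).abs().max())   # agrees up to floating-point rounding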

@keturn (Contributor) commented Nov 10, 2022

Tangential question about how to handle this sort of thing in the future:

torch.einsum("b i d, b j d -> b i j", query, key) * self.scale

is so much more concise than the nine-parameter baddbmm call that replaces it.

Do you think the optimization of this to baddbmm() is something that could be handled by a future optimization to torch.einsum upstream?

Or would an optimizer for einsum not recognize this as optimizable-to-baddbmm because the scale (alpha) multiplication happens outside the einsum? (and unfortunately there is no einsum syntax to include scalar multiplication.)

Alternately:

Lots of the cognitive noise in the baddbmm call is its input argument:

torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device)

which is unfortunate, because that empty tensor is ignored with beta=0. Terrible signal-to-noise ratio for that line.

I might be tempted to make some kind of scaled_bmm(batch1, batch2, alpha) function to encapsulate that.
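
For concreteness, a minimal sketch of the kind of wrapper being suggested; the name scaled_bmm and its signature are hypothetical, not part of this PR or of torch:

import torch

def scaled_bmm(batch1: torch.Tensor, batch2: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # hypothetical helper: computes alpha * (batch1 @ batch2) via baddbmm.
    # the empty input tensor is never read because beta=0; it only supplies
    # the shape/dtype/device metadata that baddbmm expects.
    return torch.baddbmm(
        torch.empty(batch1.shape[0], batch1.shape[1], batch2.shape[2], dtype=batch1.dtype, device=batch1.device),
        batch1,
        batch2,
        beta=0,
        alpha=alpha,
    )

# usage mirroring the attention-score computation:
# attention_scores = scaled_bmm(query, key.transpose(-1, -2), alpha=self.scale)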

@Birch-san (Contributor, Author):

@keturn it's a good question.

Perhaps the shape of the torch.empty doesn't matter. When I originally tried this on CompVis, I used torch.empty(1,1,1). I assigned it as an nn.Parameter property during __init__() so I could re-use it (I think that has the downside that it gets serialized with the model weights, but the upside that it gets the device and dtype for free).

maybe an even more concise torch.empty() would work.
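
A quick sketch of that idea, assuming baddbmm broadcasts the ignored input tensor as documented (shapes are illustrative):

import torch

q = torch.randn(8, 4096, 40)   # (batch*heads, query_tokens, dim_per_head)
k = torch.randn(8, 77, 40)
scale = 40 ** -0.5

# the input tensor only needs to broadcast to (8, 4096, 77); with beta=0
# its values are never read, so a 1-element placeholder should suffice
placeholder = torch.empty(1, 1, 1, dtype=q.dtype, device=q.device)
scores = torch.baddbmm(placeholder, q, k.transpose(-1, -2), beta=0, alpha=scale)
assert scores.shape == (8, 4096, 77)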

Do you think the optimization of this to baddbmm() is something that could be handled by a future optimization to torch.einsum upstream?

it feels like it must be possible to optimize torch.einsum to recognise b i d, b j d -> b i j as a bmm(). I dunno whether it's so easy to make it eat the * scale that's outside the einsum. this kind of op-fusing feels like it requires a JIT, so maybe it's possible to get that with other tools, but not sure that torch does that for free.

@patrickvonplaten (Contributor):

@Birch-san, did you also check how this affects performance for PyTorch + CUDA, which is the main use case here? I'm not super opposed to this, but we've seen changes like this possibly leading to big differences in generation.

Would love to have comparisons for PyTorch + CUDA on GPU as well, or else we could also think about making a separate attention function for device="mps" - what do you think @pcuenca?

@Birch-san (Contributor, Author):

did you also check how this affects performance for PyTorch + CUDA

@patrickvonplaten I haven't and cannot. what did you think of @pcuenca's measurements, which showed an 11% speedup on CUDA?
#1203 (comment)

@patrickvonplaten (Contributor) commented Nov 18, 2022

Ah yeah true, sorry, only realizing now! That's actually quite a bit - @patil-suraj could you also run this on an A100, and I'll take the V100 so that we also have numbers for those. If we also get high speed-ups there and the precision stays within limits, I'm actually happy to go forward with this PR.

@patil-suraj @pcuenca , could we maybe all run this script:

#!/usr/bin/env python3
import torch
import time
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",  # comment out for fp32
    torch_dtype=torch.float16,  # comment out for fp32
)
pipe = pipe.to("cuda")


def time_stable_diffusion(model, _id, bs=1, num_inference_steps=50):
    print(f"time_bs={bs}_id={_id}_num_steps={num_inference_steps}")
    print(20 * "-")
    start_time = time.time()
    prompt = bs * ["a photo of an astronaut riding a horse on mars"]
    model(prompt, num_inference_steps=num_inference_steps).images[0]
    print(time.time() - start_time)
    print(20 * "-")


time_stable_diffusion(pipe, "v1-5-pndm-fp16", bs=1)
time_stable_diffusion(pipe, "v1-5-pndm-fp16", bs=2)
time_stable_diffusion(pipe, "v1-5-pndm-fp16", bs=4)
time_stable_diffusion(pipe, "v1-5-pndm-fp16", bs=8)

on CUDA PyTorch in both fp32 and fp16 for both "main" and this PR.

@patrickvonplaten (Contributor) commented Nov 20, 2022

Here are some results for the V100 machine:

FP32

| Hardware | bs | main  | bmm   | Δ (%) |
|----------|----|-------|-------|-------|
| V100     | 1  | 11.01 | 10.27 | 7.25  |
| V100     | 2  | 17.63 | 16.28 | 8.29  |
| V100     | 4  | 33.54 | 30.84 | 8.75  |
| V100     | 8  | -     | -     | -     |

FP16

| Hardware | bs | main  | bmm   | Δ (%) |
|----------|----|-------|-------|-------|
| V100     | 1  | 5.20  | 4.76  | 9.24  |
| V100     | 2  | 6.88  | 6.13  | 12.23 |
| V100     | 4  | 12.66 | 11.21 | 12.93 |
| V100     | 8  | 24.37 | 21.48 | 13.45 |

=> So we're also seeing a clear improvement on a V100.

@patil-suraj @pcuenca would be great if you could also run a quick benchmark.
Looks like this PR would be a nice improvement to the current codebase!


Review thread on these lines from the diff:

# get scores
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm
scale = 1 / math.sqrt(self.channels / self.num_heads)
Review comment (Contributor):

Could we move this line before the first if-else statement and then have only one big if-else statement? IMO this would make the code more readable.

@Birch-san (Contributor, Author):

done, rebased, force-pushed.

there's still another if-else below. merging that is possible, but only by duplicating the softmax.

@patrickvonplaten (Contributor):

@Birch-san, I rebased your PR to current main to make sure we can test it well, hope this is ok.

@pcuenca @patil-suraj @anton-l just ran the whole slow test suite and everything passes => I'm in favor of merging this PR. Could you give it a look?

@pcuenca (Member) commented Nov 21, 2022

Updated tests. (Note: I ran a warmup pass with bs=1 at the beginning of each test).

| Hardware   | bs | main  | bmm   | Δ (%) |
|------------|----|-------|-------|-------|
| 3090, fp32 | 1  | 8.00  | 7.33  | 8.38  |
| 3090, fp32 | 2  | 13.34 | 12.00 | 10.04 |
| 3090, fp32 | 4  | 26.12 | 23.39 | 10.45 |
| 3090, fp32 | 8  | 52.05 | 48.60 | 10.47 |

| Hardware   | bs | main  | bmm   | Δ (%) |
|------------|----|-------|-------|-------|
| 3090, fp16 | 1  | 3.44  | 3.08  | 10.47 |
| 3090, fp16 | 2  | 6.07  | 5.33  | 12.19 |
| 3090, fp16 | 4  | 11.79 | 10.31 | 12.55 |
| 3090, fp16 | 8  | 22.60 | 19.72 | 12.74 |

| Hardware   | bs | main  | bmm   | Δ (%) |
|------------|----|-------|-------|-------|
| A100, fp32 | 1  | 5.79  | 5.37  | 7.25  |
| A100, fp32 | 2  | 10.45 | 9.60  | 8.13  |
| A100, fp32 | 4  | 19.38 | 17.67 | 8.82  |
| A100, fp32 | 8  | 37.55 | 34.08 | 9.24  |

| Hardware   | bs | main  | bmm   | Δ (%) |
|------------|----|-------|-------|-------|
| A100, fp16 | 1  | 3.30  | 3.10  | 6.06  |
| A100, fp16 | 2  | 4.09  | 3.74  | 8.56  |
| A100, fp16 | 4  | 6.58  | 5.73  | 12.92 |
| A100, fp16 | 8  | 12.40 | 10.69 | 13.79 |

Updated A100 numbers after running on PyTorch version 1.13.0+cu116

@pcuenca (Member) commented Nov 21, 2022

In terms of visual differences, these are my conclusions:

  • CUDA @ fp32: No visual differences.
  • CUDA @ fp16: Minor visual differences. Examples (best seen by downloading and switching back and forth); note the astronaut's boot and thigh:

[image: v1-5-pndm-fp16_0_bmm-rebased]
[image: v1-5-pndm-fp16_0_main]

Shadows, background structure:

[image: v1-5-pndm-fp16_5_main]
[image: v1-5-pndm-fp16_5_bmm-rebased]

  • mps: Same as CUDA: no visual difference in float32 and some minor, sporadic differences in float16. Examples (mps @ float16); note the shadow behind the dog's ear:

[image: fp16-bs-1-0-main]
[image: fp16-bs-1-0-bmm-rebased]

TL;DR: happy to merge if the slow tests pass.

@patil-suraj (Contributor):

@pcuenca is the warmup time also counted? Because I'm getting slightly faster perf on A100 for fp16 with bmm, e.g. 10.56 sec for bs=8.

@patil-suraj (Contributor) left a review:

Thanks a lot for adding this! I'm seeing nice speed-ups on A100.
Left a comment about how to always make this a 3D problem.

If the slow tests pass and the generated images are not worse than main, this should be good to merge after the comments are addressed :)

Comment on lines +291 to +300
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)

# TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
# or reformulate this into a 3D problem?
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
Review comment (Contributor):

It's possible to formulate this as a 3D problem, the same way we do it in CrossAttention, by merging the batch and heads dimensions before the bmm and then splitting them again. For example:

def reshape_heads_to_batch_dim(tensor, heads=2):
    batch_size, seq_len, dim = tensor.shape
    head_size = heads
    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
    return tensor

def reshape_batch_dim_to_heads(tensor, heads=2):
    batch_size, seq_len, dim = tensor.shape
    head_size = heads
    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
    # tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
    return tensor

query_states = reshape_heads_to_batch_dim(query_proj)
key_states = reshape_heads_to_batch_dim(key_proj)
value_states = reshape_heads_to_batch_dim(value_proj)

attention_scores = torch.baddbmm(
    torch.empty(
        query_states.shape[0],
        query_states.shape[1],
        key_states.shape[1],
        dtype=query_states.dtype,
        device=query_states.device,
    ),
    query_states,
    key_states.transpose(-1, -2),
    beta=0,
    alpha=1,
)

attention_scores = reshape_batch_dim_to_heads(attention_scores)

Comment on lines +322 to +330
# TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
# or reformulate this into a 3D problem?
# TODO: measure whether on MPS device it would be faster to do this matmul via einsum
# as some matmuls can be 1.94x slower than an equivalent einsum on MPS
# https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
hidden_states = torch.matmul(attention_probs, value_states)
hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
hidden_states = hidden_states.view(new_hidden_states_shape)
Review comment (Contributor):

Same comment as above: it should be possible to make this 3D using the logic in the comment above.
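
For illustration, a hedged sketch (toy shapes, not the PR's code) of how this second matmul could stay a plain 3D bmm, with the heads folded back into the channel dimension afterwards:

import torch

batch, heads, seq_len, dim_per_head = 2, 2, 64, 32

# assumed to already be in the merged (batch*heads, seq, dim_per_head) layout
attention_probs = torch.randn(batch * heads, seq_len, seq_len).softmax(dim=-1)
value_states = torch.randn(batch * heads, seq_len, dim_per_head)

# second matmul as a plain bmm in the merged layout
hidden_states = torch.bmm(attention_probs, value_states)

# fold heads back into channels: (batch*heads, seq, d) -> (batch, seq, heads*d)
hidden_states = (
    hidden_states.reshape(batch, heads, seq_len, dim_per_head)
    .permute(0, 2, 1, 3)
    .reshape(batch, seq_len, heads * dim_per_head)
)
assert hidden_states.shape == (batch, seq_len, heads * dim_per_head)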

Comment on lines +565 to +571
attn_slice = torch.baddbmm(
    torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
    query[start_idx:end_idx],
    key[start_idx:end_idx].transpose(-1, -2),
    beta=0,
    alpha=self.scale,
)
Review comment (Contributor):

Looks good!

@patil-suraj (Contributor):

@Birch-san Feel free to open a follow-up PR to address the comments, merging this now :)

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request on Dec 25, 2023:
perf: prefer batched matmuls for attention. added fast-path to Decoder when num_heads=1