
Conversation


@Narsil Narsil commented Sep 14, 2022

While investigating the performance of diffusers, I figured out a relatively simple way to get a 35% speedup on the default Stable Diffusion pipeline.

Simply by removing autocast.

Happy to add any test that might be worthwhile, or move some of the casting to different places.

Ran on Titan RTX:

Before

Took 0:00:18.994290

After

Took 0:00:14.049952

The reason autocast adds so much overhead is that a few tensors were still in fp32, so EVERY single operation afterwards would downcast them to fp16 for the op and upcast the result back to fp32 afterwards, leading to an insane number of tensor copies.
Screenshot from 2022-09-14 17-17-48

Since it seems very easy to end up with inefficient code when using autocast, I took the liberty of removing it in the library itself.
If that's OK with the maintainers, I would like to remove it from the docs too.
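For illustration, here is a minimal sketch of the failure mode (not part of this PR; the tensors are made up and a CUDA device is assumed): under autocast, a stray fp32 tensor is silently papered over with a cast (an extra copy) on every op, whereas without autocast the same mismatch fails loudly and gets fixed at the source.

import torch

weight = torch.randn(64, 64, device="cuda", dtype=torch.float16)
stray = torch.randn(1, 64, device="cuda")  # accidentally left in fp32

with torch.autocast("cuda"):
    out = stray @ weight  # works, but autocast inserts a cast (an extra copy) on every call

try:
    out = stray @ weight  # without autocast, the dtype mismatch is an immediate error
except RuntimeError as err:
    print(err)  # e.g. complains that mat1 and mat2 do not have the same dtype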


HuggingFaceDocBuilderDev commented Sep 14, 2022

The documentation is not available anymore as the PR was closed or merged.

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
# before this PR:
#     with autocast("cuda"):
#         image = pipe(prompt).images[0]
# after this PR:
image = pipe(prompt).images[0]
Contributor

Are you sure this is faster? Using autocast currently (before this PR) gives a 2x boost in generation speed.

Will also test a bit locally on a GPU tomorrow.

Contributor Author

This is extremely surprising, but I am also measuring a 2x speedup with autocast on fp32.

I am looking into it. I do see copies, but not nearly the same amount; there's probably a device-to-host / host-to-device transfer somewhere that kills performance, but I haven't found it yet.

@Narsil Narsil Sep 15, 2022

Okay, I figured it out.

autocast will actually use fp16 for some ops, based on heuristics: https://pytorch.org/docs/stable/amp.html#cuda-op-specific-behavior

So it's faster because it's actually running in fp16 even though the model was loaded in fp32, and without autocast it's slower simply because it really is running in fp32.

If we enable real fp16 with its big performance boost, I feel like we shouldn't need autocast for fp32 (pure fp32 is slower, but also "more" correct). Some ops are kind of dangerous to actually run in fp16, but we should be able to tell them apart (and for now the generations seem to still be good enough even when running everything in fp16).

But it's still a nice way to get fp16 running "for free". The heuristics they use seem OK (though, as they mention, they probably wouldn't work for gradients, I guess because fp16 overflows faster).
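For reference, here is a tiny sketch of that heuristic in action (illustrative only, assuming a CUDA device): matmul is on autocast's fp16 op list, while precision-sensitive ops such as softmax are kept in fp32.

import torch

a = torch.randn(1024, 1024, device="cuda")  # fp32
b = torch.randn(1024, 1024, device="cuda")  # fp32

with torch.autocast("cuda"):
    c = a @ b                     # autocast runs matmul in fp16
    s = torch.softmax(c, dim=-1)  # autocast runs softmax in fp32

print(c.dtype, s.dtype)  # torch.float16 torch.float32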

@patrickvonplaten patrickvonplaten Oct 5, 2022

Commented extensively here: #511

Should we maybe change the example to load native fp16 weights then?

Contributor

So replace:

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)

by

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, revision="fp16")

Contributor Author

That seems OK. Is fp16 supported by enough GPU cards at this point? That would be my only point of concern.
But given the speed difference, advocating for fp16 is definitely something we should do!

Contributor

Yes, I think fp16 is widely supported now.

@patrickvonplaten patrickvonplaten left a comment

Overall this looks very good to me. I'm not 100% sure whether we're getting a big speed-up if everything is in fp32, though - will test tomorrow.


keturn commented Sep 14, 2022

Do you think it would be practical to take something from test_pipelines and make some sort of pipeline benchmark test from it?

I've always been reluctant to try to write performance tests when CI is made out of The Cloud, but if you can trust the tests will run somewhere with a consistent execution environment, then I think it could be very informative.


Narsil commented Sep 15, 2022

I've always been reluctant to try to write performance tests when CI is made out of The Cloud, but if you can trust the tests will run somewhere with a consistent execution environment, then I think it could be very informative.

100% agree, I was not thinking about performance, merely correctness (the pipeline works in fp16 without autocast).
This is the biggest point with autocast: if ANY tensor forgets to be instantiated in the "correct" dtype, then a huge number of copies will probably be added. So making sure we can run without it saves us from this issue creeping up again. On a pure fp16 pipeline, autocast is then a slight slowdown, but only because it will UPcast the ops it decides should run in fp32 (which might be more correct).

Edit: Added a slow test.
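For context, a rough sketch of what such a slow test could look like (hypothetical; this is not the actual test added in the PR): load the fp16 weights, run the pipeline without autocast, and assert that the call succeeds and produces an image of the expected size.

import torch
from diffusers import StableDiffusionPipeline

def test_stable_diffusion_fp16_no_autocast():
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, revision="fp16"
    ).to("cuda")
    generator = torch.Generator(device="cuda").manual_seed(0)
    # a couple of steps is enough to exercise every dtype path
    image = pipe("an astronaut riding a horse", generator=generator, num_inference_steps=2).images[0]
    assert image.size == (512, 512)  # PIL size is (width, height)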


keturn commented Sep 27, 2022

I've been successfully using some code based on this patch. The one change I've had to make is to the stochastic schedulers such as DDIM. Their use of randn needs to be adjusted to use an explicit dtype, in much the same way the randn call has been changed in pipeline_stable_diffusion here.
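Something along these lines, presumably (a sketch of the kind of change described, not the actual scheduler code): any noise the scheduler draws has to inherit the latents' device and dtype, otherwise an fp32 tensor sneaks back into an otherwise fp16 pipeline.

import torch

def scheduler_noise_like(latents: torch.Tensor, generator=None) -> torch.Tensor:
    # explicit device/dtype, mirroring the randn change in pipeline_stable_diffusion
    return torch.randn(
        latents.shape,
        generator=generator,
        device=latents.device,
        dtype=latents.dtype,
    )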


keturn commented Sep 30, 2022

With #371 merged, this needs conflicts resolved.

I think #371 actually included most, if not all, of the changes from this PR, but the DDIM issue I mentioned seems to still exist when I try to call the pipeline without autocast.

@patrickvonplaten

Sorry to be so late here @Narsil! Do you think there is a way to extract everything out of this PR that was not already merged in #371?


Narsil commented Oct 4, 2022

I think it included most of the changes here - sorry, I'm also lagging behind on my GH notifications.

For DDIM I don't know, but probably all schedulers/snippets need to handle precision.

I still advocate removing autocast from all snippets (and making sure they work).
If we don't, the latency regression is likely to come back very fast.


patrickvonplaten commented Oct 4, 2022

@NouamaneTazi do you have time by any chance to rebase this PR on current main and see if there is anything we forgot to add in your PR, #371?

Also, I agree with @Narsil that we should update the docs / readmes to not have "autocast" if pure fp16 is faster and less error-prone. @NouamaneTazi, if you have time, maybe we could go through the docs together and see if we can remove autocast in favor of pure fp16, provided we get more or less the same visual results.

    device=latents_device,
    dtype=text_embeddings.dtype,
)
if self.device.type == "mps":
Contributor

@pcuenca could you take a look here?

Member

Yes, I commented in the other conversation; I think it's OK like this.

diff = np.abs(image_chunked.flatten() - image.flatten())
# They ARE different since ops are not always run at the same precision;
# however, they should be extremely close.
assert diff.mean() < 2e-2
Contributor

this test tolerance seems a bit high to me... => will play around with it a bit!

@Narsil Narsil Oct 5, 2022

It's roughly the same difference as the pure fp16 run in #371.

Running some ops directly in fp32 instead of fp16 (which is what autocast will do) does change some patches of the image.
This is less than 2% total variance in the images, so it really doesn't show visually. (Much less than some other changes in #371, where I think some visual differences were actually visible.)

@patrickvonplaten

@Narsil I ran a quick benchmark which I think is in line with your findings:

The following script:

#!/usr/bin/env python3
from torch import autocast
from diffusers import StableDiffusionPipeline
import torch
from time import time
pipe_fp32 = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
pipe_fp16 = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, revision="fp16").to("cuda")


prompt = 4 * ["Dark, eerie forest with beautiful sunshine."]


def benchmark(name, do_autocast, pipe, generator):
    print(name)
    start = time()
    if not do_autocast:
        image = pipe(prompt, generator=generator).images[0]
    else:
        with autocast("cuda"):
            image = pipe(prompt, generator=generator).images[0]
    image.save(name + ".png")
    end = time()
    print("time", end - start)
    print(50 * "-")


torch_device = "cuda"
name = "FP32 - Autocast"
generator = torch.Generator(device=torch_device).manual_seed(0)
benchmark(name, True, pipe_fp32, generator)

name = "FP32"
generator = torch.Generator(device=torch_device).manual_seed(0)
benchmark(name, False, pipe_fp32, generator)

name = "FP16 - Autocast"
generator = torch.Generator(device=torch_device).manual_seed(0)
benchmark(name, True, pipe_fp16, generator)

name = "FP16"
generator = torch.Generator(device=torch_device).manual_seed(0)
benchmark(name, False, pipe_fp16, generator)

gives these results on a TITAN RTX:

FP32 - Autocast
100%|██████████| 51/51 [00:19<00:00,  2.62it/s]
time 20.597806692123413
--------------------------------------------------
FP32
100%|██████████| 51/51 [00:37<00:00,  1.34it/s]
time 38.97595691680908
--------------------------------------------------
FP16 - Autocast
100%|██████████| 51/51 [00:19<00:00,  2.61it/s]
time 20.543357133865356
--------------------------------------------------
FP16
100%|██████████| 51/51 [00:14<00:00,  3.58it/s]
time 15.096464395523071
--------------------------------------------------

The saved images are all identical - in general it doesn't seem like fp16 is hurting the results (see: https://huggingface.co/datasets/patrickvonplaten/images/tree/main/to_delete) => so should we maybe change all of the docs to advertise native fp16, meaning we'll advertise the following in all our examples:

from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, revision="fp16").to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

It seems to make the most sense given that we already advertise everywhere that it should be run on a GPU, no?

What do you think @Narsil @pcuenca @patil-suraj @anton-l ?


anton-l commented Oct 5, 2022

@patrickvonplaten I agree!

But just to be on the safe side I think we need to throw a custom informative error when users try to use the fp16 weights on the CPU, as RuntimeError: "LayerNormKernelImpl" not implemented for 'Half' isn't too useful there. Happy to do that in another PR :)
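Something like the following could do it (a hypothetical sketch of such a check, not the actual implementation that ended up in the library):

import torch

def ensure_fp16_on_gpu(pipe):
    for module in (pipe.unet, pipe.vae, pipe.text_encoder):
        param = next(module.parameters())
        if param.dtype == torch.float16 and param.device.type == "cpu":
            raise ValueError(
                "Half-precision weights are not supported on CPU. "
                "Move the pipeline to a GPU with `pipe.to('cuda')` or reload the fp32 weights."
            )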


Narsil commented Oct 5, 2022

@Narsil I ran a quick benchmark which I think is in line with your findings:

Perfectly in line. The other kind of optimization, torch.jit.script, actually helped a lot too, and with the attention scripted, autocast led to a 7% slowdown instead of 25%. Just FYI (I'm guessing because the fp32 softmax can be done more efficiently). Also, the speedups/slowdowns did seem to vary from GPU to GPU (the overall direction + magnitude are always correlated, but not necessarily the same throughout).
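For anyone unfamiliar, torch.jit.script compiles a module ahead of time; a purely illustrative toy (not the actual diffusers attention, just the general idea) looks like this:

import math
import torch

class TinyAttention(torch.nn.Module):
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        scale = 1.0 / math.sqrt(q.shape[-1])
        scores = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1)
        return scores @ v

# scripting compiles and optimizes the module ahead of time
scripted = torch.jit.script(TinyAttention().cuda().half())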

@patrickvonplaten

Merging now and then updating the docs in a follow up PR

@patrickvonplaten patrickvonplaten merged commit 3dcc75c into main Oct 5, 2022
@patrickvonplaten patrickvonplaten deleted the optimizations branch October 5, 2022 13:33
shirayu added a commit to shirayu/purepale that referenced this pull request Oct 6, 2022

WASasquatch commented Oct 14, 2022

For the Tesla T4, on Google Colab, this results in a 2-second increase in time, from 28s to 30s (this is timing the actual process itself, not the time the pipe gives you).

Using PNDM, 50 steps, guidance scale 13.5 (what I usually use), resolution 512x768 (portrait).


Narsil commented Oct 14, 2022

@WASasquatch Most of the speedups were already included in #371, as there was some duplicated effort here.

If you're interested you can even try https://github.com/huggingface/diffusers/tree/optim_attempts where we went a bit further.
using https://pytorch.org/docs/stable/backends.html#torch.backends.cudnn.torch.backends.cudnn.benchmark (Set it to True in your script). That provides an additional few %.
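For reference, the flag mentioned above is a single line (standard PyTorch, nothing specific to that branch):

import torch

# let cuDNN benchmark convolution algorithms and pick the fastest ones
# (helps when input shapes stay constant between calls)
torch.backends.cudnn.benchmark = True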


WASasquatch commented Oct 15, 2022

@WASasquatch Most of the speedups were already included in #371, as there was some duplicated effort here.

If you're interested you can even try https://github.com/huggingface/diffusers/tree/optim_attempts where we went a bit further. using https://pytorch.org/docs/stable/backends.html#torch.backends.cudnn.torch.backends.cudnn.benchmark (Set it to True in your script). That provides an additional few %.

I'll check it out.

The speed I got after the Attention Slicing PR (over a month ago) was 28 seconds, all throughout these changes (I use the current GitHub branch and pull in changes as they happen). With this PR it's increased to 30 seconds.

Could it be that enable_attention_slicing() is having undesired effects with this PR?


Narsil commented Oct 17, 2022

Could it be that enable_attention_slicing() is having undesired effects with this PR?

Most likely. Attention slicing is for memory reduction, not inference speed.
But it's hard to tell without looking at the whole script.

You could also try using the PyTorch profiler to see what's wrong. Maybe the slow part lies somewhere else.
Can't recommend it enough: https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html
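As a starting point, here is a small sketch of profiling one pipeline call with the PyTorch profiler (illustrative; `pipe` and `prompt` are assumed to be set up as in the earlier snippets):

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    image = pipe(prompt).images[0]

# the ops dominating cuda_time_total are the ones worth optimizing
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))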

@WASasquatch

You could also try using the PyTorch profiler to see what's wrong. Maybe the slow part lies somewhere else.
Can't recommend it enough: https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html

@Narsil Thanks for that. I assume I'm looking at the profiler bit. Thanks again.

prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
…ful). (huggingface#511)

* Removing `autocast` for `35-25% speedup`.

* iQuality

* Adding a slow test.

* Fixing mps noise generation.

* Raising error on wrong device, instead of just casting on behalf of user.

* Quality.

* fix merge

Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023