Out of curiosity, what's the performance compared to torch.compile? #2

Closed · Chillee opened this issue Oct 17, 2023 · 8 comments

@Chillee commented Oct 17, 2023

I think torch.compile should work on HF diffusers: https://huggingface.co/docs/diffusers/optimization/torch2.0#a100-batch-size-1
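For reference, the pattern from that guide looks roughly like this (an illustrative sketch; the model id and options below follow the usual HF docs example, not something benchmarked in this thread):

```python
# Sketch of compiling the UNet of a diffusers pipeline with torch.compile,
# following the pattern in the linked HF guide (illustrative only).
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Most of the compute is in the UNet, so compiling just that module
# captures most of the benefit.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

image = pipe("a photo of an astronaut riding a horse").images[0]
```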

@isidentical

I have made an attempt at reproducing the results, but always take these with a grain of salt since there might be differences in CUDA/library versions/hardware (especially the memory bandwidth of the 40G and 80G variants being different), etc.

On one of our A100 40G cards, using model=SD1.5, batch_size=1, steps=100, I get:

-> SD1.5 out of the box with torch 2.1 SDPA is ~32 it/s
-> SD1.5 + torch.compile is ~51 it/s
-> SD1.5 + stable-fast is ~55 it/s
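
For context on how an it/s figure like this can be measured, here is a minimal timing sketch (illustrative; the numbers above may simply come from the diffusers progress bar, and the model id is an assumption):

```python
# Rough it/s measurement sketch (illustrative, not the exact script used above).
import time
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut riding a horse"
steps = 100

# Warmup run: triggers lazy initialization / any compilation first.
pipe(prompt, num_inference_steps=10)

torch.cuda.synchronize()
start = time.perf_counter()
pipe(prompt, num_inference_steps=steps)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

# Rough denoising iterations per second (includes text-encoder/VAE overhead).
print(f"~{steps / elapsed:.1f} it/s")
```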

@chengzeyi (Owner)

> I have made an attempt at reproducing the results, but always take these with a grain of salt since there might be differences in CUDA/library versions/hardware (especially the memory bandwidth of the 40G and 80G variants being different), etc.
>
> On one of our A100 40G cards, using model=SD1.5, batch_size=1, steps=100, I get:
>
> -> SD1.5 out of the box with torch 2.1 SDPA is ~32 it/s
> -> SD1.5 + torch.compile is ~51 it/s
> -> SD1.5 + stable-fast is ~55 it/s

Yes, the performance of torch.compile on HF diffusers is discussed in detail in this doc:

https://huggingface.co/docs/diffusers/optimization/torch2.0

@Chillee (Author) commented Oct 18, 2023

So it's about on par with TensorRT but a bit slower than OneFlow and this repo?

@isidentical

> So it's about on par with TensorRT but a bit slower than OneFlow and this repo?

Since I have not tested either TensorRT or OneFlow on the same hardware, it's hard to compare the reported it/s numbers here directly against mine. But as a relative performance point, torch.compile on the A100 40G is about 8% slower than stable-fast (~51 vs ~55 it/s), which puts it in the same ballpark as OneFlow, which is reported as 9% slower than stable-fast.

@chengzeyi (Owner)

> So it's about on par with TensorRT but a bit slower than OneFlow and this repo?
>
> Since I have not tested either TensorRT or OneFlow on the same hardware, it's hard to compare the reported it/s numbers here directly against mine. But as a relative performance point, torch.compile on the A100 40G is about 8% slower than stable-fast, which puts it in the same ballpark as OneFlow, which is reported as 9% slower than stable-fast.

Commercial GPUs are expensive and hard to get in my country, and we also have strict limitations on international Internet connections here.

So testing cutting-edge ML models is not an easy task for me; please be patient.

@chengzeyi (Owner) commented Oct 20, 2023

> So it's about on par with TensorRT but a bit slower than OneFlow and this repo?
>
> Since I have not tested either TensorRT or OneFlow on the same hardware, it's hard to compare the reported it/s numbers here directly against mine. But as a relative performance point, torch.compile on the A100 40G is about 8% slower than stable-fast, which puts it in the same ballpark as OneFlow, which is reported as 9% slower than stable-fast.

I ran accurate benchmarks on the RTX 4090 and RTX 3090 myself today.

Performance varies greatly across different hardware/software/platform/driver configurations, so it is very hard to benchmark accurately, and preparing the environment for benchmarking is also a lot of work. I have tested on some platforms before, but those results may still be inaccurate.

Currently the A100 is hard and expensive to rent from cloud providers in my region, so A100 benchmark results will be available when I have access to one again.

RTX 4090 (512x512, batch size 1, fp16, tcmalloc enabled)

| Framework | SD 1.5 | SD 2.1 | SD 1.5 ControlNet |
| --- | --- | --- | --- |
| Vanilla PyTorch (2.1.0+cu118) | 24.9 it/s | 27.1 it/s | 18.9 it/s |
| torch.compile (2.1.0+cu118, NHWC UNet) | 33.5 it/s | 38.2 it/s | 22.7 it/s |
| AITemplate | 65.7 it/s | 71.6 it/s | untested |
| OneFlow | 60.1 it/s | 12.9 it/s (??) | untested |
| TensorRT | untested | untested | untested |
| Stable Fast (with xformers & triton) | 61.8 it/s | 61.6 it/s | 42.3 it/s |

RTX 3090 (512x512, batch size 1, fp16, tcmalloc enabled)

| Framework | SD 1.5 |
| --- | --- |
| Vanilla PyTorch (2.1.0+cu118) | 22.5 it/s |
| torch.compile (2.1.0+cu118, NHWC UNet) | 25.3 it/s |
| AITemplate | 34.6 it/s |
| OneFlow | 38.8 it/s |
| TensorRT | untested |
| Stable Fast (with xformers & triton) | 31.5 it/s |
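
For readers wondering what "Stable Fast (with xformers & triton)" refers to in the tables: roughly, the pipeline is compiled once with those backends enabled. The sketch below follows my recollection of the stable-fast README of this era; the module path and config attribute names are assumptions and may not match the current API exactly.

```python
# Hedged sketch only: the sfast module path and CompilationConfig attribute
# names below are assumptions based on the project's README, not verified here.
import torch
from diffusers import StableDiffusionPipeline
from sfast.compilers.stable_diffusion_pipeline_compiler import (
    CompilationConfig,
    compile,
)

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

config = CompilationConfig.Default()
config.enable_xformers = True    # use xformers attention kernels
config.enable_triton = True      # use Triton-generated kernels
config.enable_cuda_graph = True  # capture with CUDA graphs to cut launch overhead

pipe = compile(pipe, config)
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]
```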

@Chillee (Author) commented Oct 20, 2023

Thanks for the benchmarks!

I wonder if the tile sizes for torch.compile are just very untuned for consumer hardware 🤔

If you happen to have some more time, could you try:

  1. Making sure you turn on mode="reduce-overhead" for torch.compile.
  2. Running with TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 (see the sketch below).

Thanks a lot for the benchmarks! I don't happen to have any consumer cards immediately available, so it's good to see torch.compile performance on consumer hardware.
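
For concreteness, the two suggestions above translate to roughly the following (an illustrative sketch, not something run in this thread):

```python
# Sketch of the two torch.compile tweaks suggested above (illustrative).
import os

# 2. Enable Inductor's coordinate-descent tuning. Set the env var before
#    torch/inductor is imported so it is picked up at compile time
#    (alternatively, export it in the shell).
os.environ["TORCHINDUCTOR_COORDINATE_DESCENT_TUNING"] = "1"

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# 1. Use the CUDA-graph-backed "reduce-overhead" mode.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")

image = pipe("a photo of an astronaut riding a horse").images[0]
```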

@chengzeyi (Owner) commented Oct 29, 2023

> Thanks for the benchmarks!
>
> I wonder if the tile sizes for torch.compile are just very untuned for consumer hardware 🤔
>
> If you happen to have some more time, could you try:
>
> 1. Making sure you turn on mode="reduce-overhead" for torch.compile.
> 2. Running with TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1
>
> Thanks a lot for the benchmarks! I don't happen to have any consumer cards immediately available, so it's good to see torch.compile performance on consumer hardware.

In my own development environment, with 'reduce-overhead', the model just generates buggy outputs...
