
Torch compile settings #2

Closed
hal-314 opened this issue Mar 21, 2024 · 12 comments

Comments

hal-314 commented Mar 21, 2024

Hi there,

First of all, thanks for the benchmark! It's very useful to see this nice comparison between Keras backends and native PyTorch.

I have two questions:

  • Why did you choose 'reduce-overhead' rather than 'max-autotune' to compile the native torch code? The PyTorch docs recommend max-autotune for the best performance. (See the sketch just below this list.)
  • The Keras torch backend is generally slower than native torch, sometimes more than 2 times slower. Is Keras using torch.compile? If so, which mode is it using?
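
For reference, a minimal sketch (not this repo's benchmark code; the model and input are placeholders) of how the two compile modes would be selected:

```python
import torch
import torch.nn as nn

# Placeholder model and input, only to illustrate the compile modes.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).cuda()
x = torch.randn(64, 512, device="cuda")

# "reduce-overhead" uses CUDA graphs to cut per-call launch overhead;
# "max-autotune" additionally autotunes matmul/conv kernel choices
# (longer compile time, potentially faster steady-state throughput).
compiled_ro = torch.compile(model, mode="reduce-overhead")
compiled_ma = torch.compile(model, mode="max-autotune")

# The first calls trigger compilation, so warm up before timing anything.
with torch.no_grad():
    for compiled in (compiled_ro, compiled_ma):
        for _ in range(3):
            _ = compiled(x)
```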

Finally, it could be useful to report how long compilation / the first model run takes for every model with each backend, so readers can tell which backend suits quick prototyping and which suits long training jobs.

Thanks

lezcano commented Apr 2, 2024

Also, I didn't follow up in the first issue, but if you are using CUDA graphs, you need to run the model twice during warm-up for the tracing to happen. This is not currently done in these benchmarks.
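
For concreteness, a minimal sketch of that warm-up pattern, with a placeholder model and input rather than the benchmark's actual ones:

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 512).cuda()
x = torch.randn(64, 512, device="cuda")

# "reduce-overhead" records and replays CUDA graphs, so the compiled model
# should run at least twice during warm-up so the capture happens before timing.
compiled = torch.compile(model, mode="reduce-overhead")
with torch.no_grad():
    for _ in range(2):
        _ = compiled(x)
torch.cuda.synchronize()  # ensure warm-up work has finished before measurement starts
```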

haifeng-jin (Owner) commented Apr 3, 2024

Hi Everyone,

Thanks for all the comments!
I am trying to make the comparison as fair as possible.
So, if you find any unfair settings, please do let us know.

Here are my replies:

For the initial questions of this issue:

Why did you choose 'reduce-overhead' rather than 'max-autotune' to compile the native torch code? The PyTorch docs recommend max-autotune for the best performance.

I mainly followed the suggestions in #1. reduce-overhead seems more transparent, since it is unclear what max-autotune does under the hood. You are welcome to try other modes and report the results.

The Keras torch backend is generally slower than native torch, sometimes more than 2 times slower. Is Keras using torch.compile? If so, which mode is it using?

Keras is not using torch.compile() by default. We have added the eager mode numbers for native PyTorch.
We still have a gap, but a much smaller one. I think it is mainly because Keras uses ops at a finer granularity instead of calling the fused ops.
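
To illustrate the difference (a hypothetical example, not Keras's actual code): composing fine-grained ops launches several kernels and materializes intermediates, while a fused op such as PyTorch's scaled_dot_product_attention does the same math in a single kernel:

```python
import math
import torch
import torch.nn.functional as F

q = torch.randn(8, 16, 128, 64, device="cuda")  # (batch, heads, seq, dim)
k = torch.randn(8, 16, 128, 64, device="cuda")
v = torch.randn(8, 16, 128, 64, device="cuda")

# Fine-grained composition: each op launches its own kernels and materializes
# intermediates (the score matrix, the softmax output) in memory.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
out_fine = torch.softmax(scores, dim=-1) @ v

# Fused op: one kernel (e.g. FlashAttention) computes the same result
# without materializing the intermediate attention matrix.
out_fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_fine, out_fused, atol=1e-3))
```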

Some other concerns from people:

  1. "You have ignored an explicit warning about torch.set_float32_matmul_precision('high') for performance."

We use float32 for SAM on all frameworks.
Setting it to "high" instead of "highest" would lower the precision for one framework and lead to an unfair comparison.

  1. "you are not using cudagraph [for SAM]"

It is because of the model implementation from Meta Research.
We don't change the source code from Meta Research.
Changing the code would contradict our purpose of measuring the "out-of-the-box" performance.

The implementation is not optimal, but it is representative of a common way to use PyTorch.

So we benchmark it and are explicit about it in the post, where it is referred to as the "less manually-optimized model".

Chillee commented Apr 3, 2024

Setting it to "high" instead of "highest" would lower the precision for one framework and lead to an unfair comparison.

In this case, JAX and TensorFlow turn on TensorFloat-32 by default. So, in order to have a fair comparison, you need to enable TensorFloat-32 in PyTorch as well.
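
For reference, a sketch of the PyTorch-side switches that enable TF32 matmuls (not taken from the benchmark code):

```python
import torch

# Either of these enables TensorFloat-32 for float32 matmuls on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
# Equivalently, via the higher-level knob ("high" allows TF32, "highest" does not):
torch.set_float32_matmul_precision("high")

# Convolutions have a separate flag in cuDNN (already True by default):
torch.backends.cudnn.allow_tf32 = True
```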

@haifeng-jin (Owner)

Also, I didn't follow up in the first issue, but if you are using CUDA graphs, you need to run the model twice during warm-up for the tracing to happen. This is not currently done in these benchmarks.

Ah, sorry for the oversight. We only used one batch for the warm-up.
The code is updated; a blog post update is on the way.

haifeng-jin (Owner) commented Apr 3, 2024

Setting it to "high" instead of "highest" would lower the precision for one framework and lead to an unfair comparison.

In this case, JAX and TensorFlow turn on TensorFloat-32 by default. So, in order to have a fair comparison, you need to enable TensorFloat-32 in PyTorch as well.

@Chillee
Thanks for pointing this out!
I didn't know this.

Is it just torch.backends.cuda.matmul.allow_tf32 = True?
If it is that simple, I can measure it again and update the results.

Just found another caveat: I think Keras explicitly specifies the dtypes for the backend it is using.
So it passes float32 to TF/JAX, and I wonder whether TensorFloat-32 is still used in that case.
Is there a way to check if TF/JAX is actually using TensorFloat-32?
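
One way to at least inspect the TensorFlow-side toggle, assuming a recent TF version (whether Keras's explicit float32 dtype changes this is not verified here):

```python
import tensorflow as tf

# TensorFlow exposes TF32 execution as a global toggle, separate from the
# float32 dtype of the tensors themselves; it is on by default on Ampere+ GPUs.
print(tf.config.experimental.tensor_float_32_execution_enabled())

# It can be turned off explicitly for a strict float32 comparison:
tf.config.experimental.enable_tensor_float_32_execution(False)
```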

lezcano commented May 19, 2024

Was the TF32 point fixed in the end?

@haifeng-jin (Owner)

Hi @lezcano,

Thanks for following up!

Keras uses "float32" explicitly when creating tensors. (code link) (code link) (code link)

So, to my understanding, Keras forces TF and JAX to use "float32" instead of leaving it to the default of the corresponding backend.

Therefore, the benchmarking code uses the same "float32" dtype for different backends. There is no fix required.

Everyone is welcome to check the actual dtype during runtime (if there is a way to do so) and post the method and results here!

@lezcano
Copy link

lezcano commented May 20, 2024

I don't think that is correct. TF32 is not a dtype per se in Keras or PyTorch; both use it as a mode to perform fast float32 x float32 matrix multiplication. As Horace mentioned, I'm pretty sure that JAX uses it by default. OTOH, in PyTorch it is off by default, so you'd need to run the PyTorch code with torch.backends.cuda.matmul.allow_tf32 = True to have a fair comparison.

haifeng-jin (Owner) commented May 20, 2024

I see. And I trust your expertise in PyTorch.

However, I am unable to verify whether TF32 is actually enabled for TF/JAX.
I would prefer to just turn it off completely for all frameworks with export NVIDIA_TF32_OVERRIDE=0 as suggested here.
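
A sketch of doing that from Python instead of the shell; the key assumption is that the variable is set before any framework initializes CUDA:

```python
import os

# NVIDIA_TF32_OVERRIDE=0 tells the CUDA math libraries (cuBLAS, cuDNN) to refuse
# TF32 even if a framework requests it. It must be set before those libraries are
# loaded, i.e. before importing torch / tensorflow / jax.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

import torch  # imported only after the override is in place  # noqa: E402
```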

What do you think?

lezcano commented May 20, 2024

Either that, or enable it in both. Given that you are benchmarking neural networks, I think the more reasonable comparison would be to enable it in both, as that is closer to how these models are run in practice (even better, networks would be executed in bfloat16 or with AMP, but yeah).

You can probably check that it exercises the TF32 cores by profiling the program with ncu and then looking at the cuBLAS/CUTLASS kernels that it calls.

lezcano commented May 20, 2024

To make sure they both use the same precision, you can manually set the equivalent flag in PyTorch (see above) and in JAX:

jax.config.update("jax_default_matmul_precision", "highest")
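
Putting the two together, a sketch of forcing strict float32 matmuls in both frameworks, using the flags mentioned above:

```python
import torch
import jax

# PyTorch: keep TF32 off for matmuls (this is also the default behavior).
torch.backends.cuda.matmul.allow_tf32 = False
torch.set_float32_matmul_precision("highest")

# JAX: force full-precision float32 matmuls instead of the faster default.
jax.config.update("jax_default_matmul_precision", "highest")
```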

@haifeng-jin (Owner)

@lezcano Thank you so much for teaching me all of this! I have just turned it off.

To everyone:

Unfortunately, I do not have the bandwidth to keep working on these benchmarks.
It seems more responsible to remove them than to leave them as they are.

Therefore, I removed all the code for the native PyTorch/HuggingFace part (keeping only Keras with its different backends) and made a new snapshot and release of the repo.

I saw a lot of concerns from people about the code in this repo.
I thought these benchmarks would help people gain insights into different frameworks,
but they only caused more trouble instead.

It was mainly because of my limited understanding of the frameworks involved.
My sincere apologies if these benchmarks have caused any trouble for you.
