Torch compile settings #2
Also, I didn't follow up in the first issue, but if you are using CUDA graphs, you need to run the model twice during warm-up for the tracing to happen. This is not currently done in these benchmarks.
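For reference, a minimal sketch of the double warm-up pattern described above (illustrative only, not the benchmark's actual code; the model and input sizes are placeholders, and `mode="reduce-overhead"` is what opts `torch.compile` into CUDA graphs on GPU):

```python
import torch

model = torch.nn.Linear(64, 64)       # placeholder model
x = torch.randn(8, 64)                # placeholder input

# mode="reduce-overhead" enables CUDA graphs when running on a GPU.
compiled = torch.compile(model, mode="reduce-overhead")

# Run at least two warm-up iterations: the first triggers compilation,
# the second lets CUDA-graph capture/tracing happen before timing starts.
for _ in range(2):
    out = compiled(x)

# ... only now start the timed benchmark loop ...
print(out.shape)  # torch.Size([8, 64])
```

On CPU the CUDA-graph part is a no-op, but the two-pass warm-up structure is the same.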
Hi everyone, thanks for all the comments! Here are my replies.

For the initial questions of this issue: I mainly followed the suggestions in #1. The

Keras is not using

Some other concerns from people:

We use

It is because of the model implementation from Meta Research. The implementation is not optimal, but it is representative of a common way to use PyTorch. So we benchmark it and are explicit about that in the post.
In this case, JAX and TensorFlow turn TensorFloat-32 on by default. So, in order to have a fair comparison, you need to enable TensorFloat-32 in PyTorch as well.
Ah, sorry for the oversight. The code is updated; a blog post update is on the way.
@Chillee Is it just

Just found another caveat: I think Keras explicitly specifies the dtypes to the backend it is using.
Was the TF32 point fixed in the end?
Hi @lezcano, thanks for following up!

Keras uses

So, to my understanding, Keras enforces TF and JAX to use

Therefore, the benchmarking code uses the same

Everyone is welcome to check the actual dtype during runtime (if there is a way to do so) and post the method and results here!
I don't think that is correct. TF32 is not a dtype per se in Keras or PyTorch; both of them use it as a mode to perform fast float32 × float32 matrix multiplication. As Horace mentioned, I'm pretty sure that JAX uses this by default. OTOH, in PyTorch it is off by default, so you'd need to execute the PyTorch code with `torch.backends.cuda.matmul.allow_tf32 = True`.
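Concretely, the TF32 mode in PyTorch is controlled by a couple of global flags (a sketch; the flags below are real PyTorch settings and are harmless no-ops on CPU):

```python
import torch

# Allow TF32 for matmuls and cuDNN convolutions. The matmul flag is off
# by default (since PyTorch 1.12), which is what makes the comparison uneven.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Equivalent higher-level switch for matmuls:
# "high" allows TF32, "highest" keeps full float32 precision.
torch.set_float32_matmul_precision("high")

print(torch.backends.cuda.matmul.allow_tf32)  # True
```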
I see, and I trust your expertise in PyTorch. However, I am unable to verify whether TF32 is actually enabled for TF/JAX. What do you think?
Either that, or enable it in both. I think a more reasonable comparison, given that you are benchmarking neural networks, would be to enable it in both, as that's what would be closer to reality (even better, networks would be executed in bfloat16 or with AMP, but yeah). You can probably check that it exercises the TF32 cores by profiling the program with ncu and then looking at the cuBLAS/CUTLASS kernels that it calls.
To make sure they both use the same precision, you can manually enable the same flag in PyTorch (see above), and in JAX set `jax.config.update("jax_default_matmul_precision", "highest")`.
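Putting the two together, a sketch of pinning both frameworks to full float32 matmuls (assumes both JAX and PyTorch are installed; `"highest"` disables the TF32-style fast path in JAX, so the matching PyTorch setting is TF32 off):

```python
import jax
import torch

# JAX: force full-float32 matmuls (TF32 path disabled).
jax.config.update("jax_default_matmul_precision", "highest")

# PyTorch: the matching setting. TF32 off is already the default for
# matmuls, shown explicitly here for symmetry.
torch.backends.cuda.matmul.allow_tf32 = False
torch.set_float32_matmul_precision("highest")
```

The mirror-image choice, enabling TF32 on both sides, is what the thread above argues is closer to real training workloads.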
@lezcano Thank you so much for teaching me all these! I have just turned it off.

To everyone: I saw a lot of concerns from people about the code in this repo. Unfortunately, I do not have the bandwidth to keep working on these benchmarks.
Hi there,
First of all, thanks for the benchmark! It's very useful to see this nice comparison between Keras backends and native PyTorch.
I have two doubts:
Finally, it could be useful to add how long compilation / the first model run takes for every model with each backend. That way, you would know which backend to use for quick prototyping and which one for long training jobs.
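A hypothetical sketch of what measuring that separately could look like for the PyTorch backend (the model and sizes are placeholders; the same split applies to any backend with a compile step):

```python
import time

import torch

model = torch.nn.Linear(128, 128)     # placeholder model
x = torch.randn(32, 128)              # placeholder input
compiled = torch.compile(model)

# First call pays the full compilation cost.
t0 = time.perf_counter()
compiled(x)
first_call = time.perf_counter() - t0

# Subsequent calls reflect steady-state speed.
t0 = time.perf_counter()
for _ in range(10):
    compiled(x)
steady = (time.perf_counter() - t0) / 10

print(f"first call: {first_call:.3f}s, steady-state: {steady * 1e3:.3f}ms")
```

Reporting both numbers per backend would answer exactly the prototyping-vs-training question raised above.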
Thanks