Use exact GELU #4428

jax.nn.gelu uses the approximate form of the GELU, but TensorFlow and PyTorch use the exact version.

I believe the exact form is more numerically stable and similarly fast. Figure 16 of the Performer paper (@xingyousong) showed the GELU running into NaN issues, and I suspect this is because JAX uses the approximate version.
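For reference, a minimal sketch of the two forms under discussion, written directly in terms of `jax.numpy` and `jax.scipy.special.erf` (the function names `gelu_exact` and `gelu_approx` are illustrative only, not part of `jax.nn`):

```python
import jax.numpy as jnp
from jax.scipy.special import erf

def gelu_exact(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + erf(x / jnp.sqrt(2.0)))

def gelu_approx(x):
    # Tanh approximation from the original GELU paper (the form jax.nn.gelu currently computes).
    return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))

x = jnp.linspace(-6.0, 6.0, 1001)
# The two forms agree closely on moderate inputs; any stability differences
# would show up in lower precision or at more extreme values.
print(jnp.max(jnp.abs(gelu_exact(x) - gelu_approx(x))))
```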
Comments
Interesting! #1556 switched to the approximate version. @trevorcai @jekbradbury
Exact GeLU is significantly slower on TPUs (easily noticeable even in end-to-end step time). We’d be happy to take a PR adding the exact implementation as an option, but keeping the approximate one as the default?
Interesting. For PyTorch this was not the case: pytorch/pytorch#39853 (comment). The exact version was slightly faster. Sadly, I don't know what optimizations would make them similarly fast here, as in PyTorch.
In JAX under a JIT, both versions are quite well optimized (fused, etc.). But TPUs are not very fast at certain kinds of vector math, and the exact GeLU happens to hit some of those cases (I think).
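One way to see the fusion this refers to is to inspect the compiled HLO, assuming a recent JAX where the ahead-of-time `jit(...).lower(...)` / `compile()` API is available (older releases exposed similar information via `jax.xla_computation`):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024), jnp.float32)  # shape chosen arbitrarily for illustration

lowered = jax.jit(jax.nn.gelu).lower(x)  # pre-optimization IR for the jitted GELU
compiled = lowered.compile()             # backend-optimized executable

# Optimized HLO: on GPU/TPU backends the elementwise GELU ops typically
# appear grouped into fusion computations.
print(compiled.as_text())
```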
#1556 (comment) says [...] So do you think the issue affects both TPUs and GPUs, yet PyTorch on GPUs doesn't have a problem? Or do you think it's more of a TPU-specific issue?
I just tried some new JAX timings on TPUv2 (two generations old) and V100 (one generation old), mostly because I had easy access to them via Colab. I found that on V100, when compiled with [...]. Since the performance differences are so large, at the moment it does seem like we might do best to let users choose which they want.
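A minimal timing sketch along the lines of such a Colab experiment, comparing the current `jax.nn.gelu` (approximate) against an inline erf-based exact version; the shape, dtype, and iteration count are arbitrary choices for illustration:

```python
import timeit
import jax
import jax.numpy as jnp
from jax.scipy.special import erf

x = jnp.ones((4096, 4096), jnp.float32)

approx_fn = jax.jit(jax.nn.gelu)  # current jax.nn.gelu: tanh approximation
exact_fn = jax.jit(lambda x: 0.5 * x * (1.0 + erf(x / jnp.sqrt(2.0))))

# First calls trigger compilation; block_until_ready() forces completion of
# JAX's asynchronous dispatch so we measure actual device time.
approx_fn(x).block_until_ready()
exact_fn(x).block_until_ready()

print("approx:", timeit.timeit(lambda: approx_fn(x).block_until_ready(), number=100))
print("exact :", timeit.timeit(lambda: exact_fn(x).block_until_ready(), number=100))
```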
PR #4438 adds an [...]. Separately, I have a PR coming that optimizes the implementation of [...]. However, I need to try some end-to-end benchmarking on TPU before switching the default; we are guessing that in the context of a larger model like BERT the approximation would still be faster.
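Assuming the option is exposed as a boolean `approximate` keyword on `jax.nn.gelu` (an assumption here; check the merged PR for the final name), usage would look like:

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 7)

y_default = jax.nn.gelu(x)                   # approximate (tanh) form remains the default
y_exact = jax.nn.gelu(x, approximate=False)  # opt into the exact, erf-based form
```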