Use exact GELU #4428

Closed
hendrycks opened this issue Oct 1, 2020 · 7 comments · Fixed by #4438
Labels: enhancement, performance

Comments

@hendrycks commented Oct 1, 2020

jax.nn.gelu uses the approximate form of the GELU, but TensorFlow and PyTorch use the exact version.

I believe the exact form is more numerically stable and similarly fast. Figure 16 of the Performer paper (@xingyousong) showed the GELU running into NaN issues, and I suspect this is because JAX uses the approximate version.
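
For reference, a minimal sketch of the two formulations under discussion, using the standard definitions (the function names below are illustrative, not JAX internals):

```python
import jax.numpy as jnp
from jax.scipy.special import erf

def gelu_exact(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return x * 0.5 * (1.0 + erf(x / jnp.sqrt(2.0)))

def gelu_tanh_approx(x):
    # Tanh-based approximation (the form jax.nn.gelu used at the time of this issue).
    return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))
```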

@hawkinsp (Collaborator) commented Oct 1, 2020

Interesting! #1556 switched to the approximate version. @trevorcai @jekbradbury

@jekbradbury (Contributor)

Exact GeLU is significantly slower on TPUs (easily noticeable even in end-to-end step time). We’d be happy to take a PR adding the exact implementation as an option, but keeping the approximate one as the default?

@hendrycks (Author) commented Oct 1, 2020

> Exact GeLU is significantly slower on TPUs

Interesting. For PyTorch this was not the case: pytorch/pytorch#39853 (comment). The exact version was slightly faster.

> We’d be happy to take a PR

Sadly, I don't know which optimizations would make them similarly fast, as in PyTorch.

@jekbradbury (Contributor)

In JAX under a JIT, both versions are quite well optimized (fused, etc.). But TPUs are not very fast at certain kinds of vector math, and the exact GeLU happens to hit some of those cases (I think).
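
One way to see what each version lowers to (a sketch, reusing the illustrative gelu_exact / gelu_tanh_approx definitions from the earlier comment; this prints the traced jaxpr, while the actual fusion happens later in XLA):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1024, 1024), dtype=jnp.float32)

# Both versions trace to a short chain of elementwise ops,
# which XLA then fuses into a single kernel under jit.
print(jax.make_jaxpr(gelu_exact)(x))
print(jax.make_jaxpr(gelu_tanh_approx)(x))
```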

@hendrycks (Author) commented Oct 1, 2020

#1556 (comment) says:

> Confirming that my benchmarks are showing jax.grad(jax.jarrett(gelu)) as slower than jax.grad(gelu) on GPU as well.

So do you think the issue affects both TPUs and GPUs, even though PyTorch on GPUs doesn't have this problem? Or do you think it's more of a TPU-specific issue?

@hawkinsp (Collaborator) commented Oct 1, 2020

I just tried some new JAX timings on TPUv2 (two generations old) and V100 (one generation old), mostly because I had easy access to them via Colab.

I found that on V100 when compiled with jax.jit, the approximate formulation is 1.12x faster on the forward pass (which seems relatively insignificant), but on TPUv2 the approximate formulation is 1.75x faster. The difference on the backward pass was much smaller. I would guess this is related to erf being much more expensive to compute than tanh on TPU. It's possible we could optimize our erf implementation, which would presumably improve the relative performance of the exact formulation on both platforms.

Since the performance differences are so large, at the moment it does seem like we might do best to let users choose which they want.
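
For anyone who wants to reproduce this kind of comparison, a rough benchmarking sketch (assuming the illustrative gelu_exact / gelu_tanh_approx definitions from the earlier comment; exact numbers will depend on backend, shapes, and JAX/XLA version):

```python
import time
import jax
import jax.numpy as jnp

x = jnp.ones((4096, 4096), dtype=jnp.float32)

fwd_exact = jax.jit(gelu_exact)
fwd_approx = jax.jit(gelu_tanh_approx)

def bench(fn, x, iters=100):
    fn(x).block_until_ready()  # compile and warm up
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(x)
    out.block_until_ready()
    return (time.perf_counter() - start) / iters

print("exact  :", bench(fwd_exact, x))
print("approx :", bench(fwd_approx, x))
```

The backward pass can be timed the same way by jitting the gradient of a scalar reduction, e.g. jax.jit(jax.grad(lambda x: gelu_exact(x).sum())).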

@hawkinsp (Collaborator) commented Oct 2, 2020

PR #4438 adds a jax.nn.gelu(..., approximate=True) keyword argument to select between the exact and approximate versions. It does not switch the default to the exact version yet.
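
Based on that description, usage would look roughly like this (a sketch; presumably approximate=False selects the exact, erf-based form):

```python
import jax.numpy as jnp
from jax import nn

x = jnp.linspace(-3.0, 3.0, 7)
y_approx = nn.gelu(x)                     # tanh approximation, still the default
y_exact = nn.gelu(x, approximate=False)   # exact, erf-based formulation
```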

Separately, I have a PR coming that optimizes the implementation of erf in XLA such that the exact formulation matches the performance of the approximation on V100, and comes very close on TPU.

However, I need to try some end-to-end benchmarking on TPU before switching the default; we are guessing that in the context of a larger model like BERT the approximation would still be faster.
