Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JVP rule for lax.pow() #12041

Merged
merged 1 commit into from
Aug 23, 2022
Merged

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Aug 22, 2022

Fixes #12033

@jakevdp
Copy link
Collaborator Author

jakevdp commented Aug 22, 2022

I double-checked against the analog of the code in HIPS/autograd#116, which inspired the introduction of this select statement. I get the same result before and after the change in this PR:

import jax
import jax.numpy as jnp

def fun(x):
    return jnp.exp(-(x*x))
d1 = jax.grad(fun)
d2 = jax.grad(d1)
d3 = jax.grad(d2)
d4 = jax.grad(d3)
print(d1(0.), d2(0.), d3(0.), d4(0.))
# -0.0 -2.0 0.0 12.0

def fun(x):
    return jnp.exp(-(x**2))
d1 = jax.grad(fun)
d2 = jax.grad(d1)
d3 = jax.grad(d2)
d4 = jax.grad(d3)
print(d1(0.), d2(0.), d3(0.), d4(0.))
# -0.0 -2.0 0.0 12.0

@jakevdp
Copy link
Collaborator Author

jakevdp commented Aug 22, 2022

So, it turns out there is one (untested) case that this PR changes:

jax.grad(jax.lax.pow, 0)(0.0, 0.0)

On main, this returns 0.0; with this PR, it returns NaN.

Chatting offline with @mattjj, we realized this is similar in spirit to the sinc zero issue in (#5054); the resolution there was to do a custom JVP rule using the maclaurin series (#5077), and I think the most rigorously-correct solution here will probably be similar.

But as the current approach returns incorrect results, we'll merge this change before working on the 0**0 corner case, which will take a bit more thought.

@jakevdp jakevdp force-pushed the fix-pow-jvp branch 2 times, most recently from 088fcd9 to 73dcd74 Compare August 22, 2022 18:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

The second order gradient of lax.pow is wrong
3 participants