vmap behaving unexpectedly with function that inverts singular matrix #15429
Comments
I can reproduce this on CPU. On GPU it works correctly.
GPU:
Shorter repro:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1, 2, 2))  # a batch containing one singular 2x2 matrix
print(jnp.linalg.inv(x))
# [[[ inf -inf]
#   [-inf  inf]]]
print(jax.vmap(jnp.linalg.inv)(x))
# [[[ 2. -1.]
#   [-1.  1.]]]
```

It probably has something to do with the
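One way to narrow this kind of discrepancy down is to compare the vmapped result against an explicit Python loop over the batch. The loop calls the function on unbatched inputs, so it never goes through a batching rule. The helper name `vmap_matches_loop` below is hypothetical, a minimal sketch rather than anything in JAX itself.

```python
import jax
import jax.numpy as jnp
import numpy as np


def vmap_matches_loop(f, xs):
    """Hypothetical helper: True if jax.vmap(f)(xs) agrees with a per-example loop.

    The loop path applies f to one unbatched example at a time, so any
    disagreement points at the batching rule rather than at f itself.
    """
    batched = np.asarray(jax.vmap(f)(xs))
    looped = np.stack([np.asarray(f(x)) for x in xs])
    # equal_nan=True counts NaNs in matching positions as agreement;
    # infinities of the same sign already compare equal under allclose.
    return bool(np.allclose(batched, looped, equal_nan=True))


# A batch of invertible 2x2 matrices: both paths should agree here.
invertible = jnp.array([[[2., 1.], [1., 3.]],
                        [[4., 0.], [0., 5.]]])
print(vmap_matches_loop(jnp.linalg.inv, invertible))  # True

# The singular batch from the repro; on JAX versions affected by this
# issue the batched path returns finite values, so this reports False.
singular = jnp.ones((1, 2, 2))
print(vmap_matches_loop(jnp.linalg.inv, singular))
```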
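I think I'm getting closer: this looks like it somehow comes from the batching rule of

```python
import jax
import jax.numpy as jnp

def solve(x, y):
    # left_side=True solves x @ out = y. lower defaults to False, so only the
    # upper triangle of x, [[1, 1], [0, 0]], is used; its zero diagonal entry
    # makes the system singular.
    return jax.lax.linalg.triangular_solve(x, y, left_side=True)

x = jnp.array([[1., 1.], [1., 0.]])
y = jnp.array([[1.], [0.]])
print(solve(x, y))
# [[nan]
#  [nan]]
print(jax.vmap(solve)(x[None], y[None])[0])
# [[1.]
#  [0.]]
```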
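Until the batching rule is fixed, one possible workaround (not proposed in the thread, and only a sketch) is to detect the singular case explicitly and return NaN, so the batched and unbatched paths agree. `checked_triangular_solve` is a hypothetical name; it assumes an upper-triangular solve with the default `lower=False`, where a zero on the diagonal means the system is singular.

```python
import jax
import jax.numpy as jnp


def checked_triangular_solve(a, b):
    """Hypothetical wrapper: solve a @ out = b using the upper triangle of a,
    returning NaN explicitly when a has a zero on its diagonal."""
    out = jax.lax.linalg.triangular_solve(a, b, left_side=True)
    # With lower=False only the upper triangle of a is read, so a zero
    # diagonal entry means the triangular system is singular.
    singular = jnp.any(jnp.diagonal(a) == 0)
    # The scalar condition broadcasts over out; under vmap it is per example.
    return jnp.where(singular, jnp.nan, out)


x = jnp.array([[1., 1.], [1., 0.]])
y = jnp.array([[1.], [0.]])
print(checked_triangular_solve(x, y))                           # all NaN
print(jax.vmap(checked_triangular_solve)(x[None], y[None])[0])  # all NaN
```

Because `jnp.where` selects elementwise, whatever the batched solve produced for the singular example is discarded rather than propagated.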
This is still an issue as of JAX v0.4.26.
Description
I ran into unexpected behavior when using vmap:
which outputs:
The first test shows that vmap correctly vectorizes myfun when x.T@x is invertible.
However, when x.T@x is singular, myfun "correctly" returns infs, whereas the vmapped version of myfun returns different values.
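The report's `myfun` is not shown, so the sketch below uses a hypothetical stand-in that inverts `x.T @ x`. It contrasts `jax.vmap` with `jax.lax.map`, which applies the function to one example at a time via a scan and therefore does not go through the batching rule for the inverse. Using `lax.map` as a workaround is an assumption on our part, not something confirmed in the thread, and it gives up vectorization.

```python
import jax
import jax.numpy as jnp


def myfun(x):
    # Hypothetical stand-in for the report's myfun: invert x.T @ x.
    return jnp.linalg.inv(x.T @ x)


# One full-rank example and one rank-deficient example (x.T @ x singular).
xs = jnp.stack([jnp.array([[1., 0.], [0., 1.], [1., 1.]]),
                jnp.ones((3, 2))])

# lax.map traces myfun on a single unbatched example and loops over the
# leading axis, so inv sees the same unbatched shapes as a direct call.
mapped = jax.lax.map(myfun, xs)
vmapped = jax.vmap(myfun)(xs)

print(mapped[0])   # inverse of [[2, 1], [1, 2]], i.e. [[2, -1], [-1, 2]] / 3
print(vmapped[0])  # should match mapped[0]
print(mapped[1])   # non-finite, like calling myfun(xs[1]) directly
print(vmapped[1])  # may differ on JAX versions affected by this issue
```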
What jax/jaxlib version are you using?
jax 0.4.8
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.11.1, macOS-13.3-arm64-arm-64bit
NVIDIA GPU info
No response