vmap behaves unexpectedly with a function that inverts a singular matrix #15429

Open
bjeurissen opened this issue Apr 6, 2023 · 4 comments
Labels
bug Something isn't working XLA

@bjeurissen

Description

I ran into unexpected behavior when using vmap:

import jax.numpy as jnp
from jax import vmap, __version__
import platform
print(platform.platform())
print(platform.python_version())
print(__version__)

def myfun(x):
    return jnp.linalg.inv(x.T@x)

myfun_many = vmap(myfun, in_axes=2, out_axes=2)

print('test1')
x1 = jnp.array([[1.0,3.0],[-5.0,1.0]])
y1 = myfun(x1)
print(y1)

x2 = jnp.stack((x1,x1,x1),2)
assert (x2[:,:,0]==x1).all()
assert (x2[:,:,1]==x1).all()
assert (x2[:,:,2]==x1).all()
y2 = myfun_many(x2)
print(y2[:,:,0]) # this is the same as y1
print(y2[:,:,1]) # this is the same as y1
print(y2[:,:,2]) # this is the same as y1


print('test2')
x1 = jnp.array([[1.0,1.0],[1.0,1.0]])
y1 = myfun(x1)
print(y1)

x2 = jnp.stack((x1,x1,x1),2)
assert (x2[:,:,0]==x1).all()
assert (x2[:,:,1]==x1).all()
assert (x2[:,:,2]==x1).all()
y2 = myfun_many(x2)
print(y2[:,:,0]) # this is not the same as y1!
print(y2[:,:,1]) # this is not the same as y1!
print(y2[:,:,2]) # this is not the same as y1!

which outputs:

macOS-13.3-arm64-arm-64bit
3.11.1
0.4.8
test1
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
test2
[[ inf -inf]
 [-inf  inf]]
[[ 1.5 -1. ]
 [-1.   1. ]]
[[ 1.5 -1. ]
 [-1.   1. ]]
[[ 1.5 -1. ]
 [-1.   1. ]]

In the first test, vmap vectorizes myfun correctly: x.T@x is invertible, and all three slices match y1.

However, in the second test x.T@x is singular. Here myfun "correctly" returns infs, whereas the vmapped myfun returns finite values ([[1.5, -1.], [-1., 1.]]) on CPU.
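
To make the difference between the two tests concrete, here is a minimal sketch in plain Python (no JAX involved) that forms x.T@x for both inputs and checks the determinant. The helpers gram, det2 and inv2 are hypothetical names introduced only for this illustration; the closed-form 2x2 inverse is used for transparency and is not claimed to be how jnp.linalg.inv computes its result.

```python
def gram(x):
    """Return x.T @ x for a matrix stored as a list of rows."""
    rows, cols = len(x), len(x[0])
    return [[sum(x[k][i] * x[k][j] for k in range(rows)) for j in range(cols)]
            for i in range(cols)]


def det2(m):
    """Determinant of a 2x2 matrix."""
    return m[0][0] * m[1][1] - m[0][1] * m[1][0]


def inv2(m):
    """Closed-form inverse of a 2x2 matrix; only valid when det2(m) != 0."""
    d = det2(m)
    return [[m[1][1] / d, -m[0][1] / d], [-m[1][0] / d, m[0][0] / d]]


g1 = gram([[1.0, 3.0], [-5.0, 1.0]])  # test1 input
g2 = gram([[1.0, 1.0], [1.0, 1.0]])   # test2 input

print(g1, det2(g1))  # [[26.0, -2.0], [-2.0, 10.0]] 256.0
print(inv2(g1))      # [[0.0390625, 0.0078125], [0.0078125, 0.1015625]]
print(g2, det2(g2))  # [[2.0, 2.0], [2.0, 2.0]] 0.0
```

The closed-form inverse for test1 agrees with the printed output up to float32 rounding. The zero determinant in test2 confirms that no finite inverse exists, so the finite values returned by the vmapped call cannot be a valid inverse.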

What jax/jaxlib version are you using?

jax 0.4.8

Which accelerator(s) are you using?

CPU

Additional system info

Python 3.11.1, macOS-13.3-arm64-arm-64bit

NVIDIA GPU info

No response

@bjeurissen bjeurissen added the bug Something isn't working label Apr 6, 2023
@alonfnt
Contributor

alonfnt commented May 13, 2023

I can reproduce this on CPU. On GPU it works correctly.

CPU:

Linux-5.19.0-41-generic-x86_64-with-glibc2.35
3.10.6
0.4.8
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
test1
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
test2
[[ inf -inf]
 [-inf  inf]]
[[ 1.5 -1. ]
 [-1.   1. ]]
[[ 1.5 -1. ]
 [-1.   1. ]]
[[ 1.5 -1. ]
 [-1.   1. ]]

GPU:

Linux-5.19.0-41-generic-x86_64-with-glibc2.35
3.10.6
0.4.8
[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
test1
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
[[0.0390625  0.0078125 ]
 [0.0078125  0.10156249]]
test2
[[ inf -inf]
 [-inf  inf]]
[[ inf -inf]
 [-inf  inf]]
[[ inf -inf]
 [-inf  inf]]
[[ inf -inf]
 [-inf  inf]]

@jakevdp jakevdp self-assigned this Jun 21, 2023
@jakevdp
Collaborator

jakevdp commented Jun 21, 2023

Shorter repro:

import jax
import jax.numpy as jnp

x = jnp.ones((1, 2, 2))
print(jnp.linalg.inv(x))
# [[[ inf -inf]
#   [-inf  inf]]]

print(jax.vmap(jnp.linalg.inv)(x))
# [[[ 2. -1.]
#   [-1.  1.]]]

It probably has something to do with the custom_linear_solve batching rule. I'm going to dig into it.

@jakevdp
Collaborator

jakevdp commented Jun 28, 2023

I think I'm getting closer – this looks like it somehow comes from the batching rule of jax.lax.triangular_solve:

import jax
import jax.numpy as jnp

def solve(x, y):
  return jax.lax.linalg.triangular_solve(x, y, left_side=True)

x = jnp.array([[1., 1.], [1., 0.]])
y = jnp.array([[1.], [0.]])

print(solve(x, y))
# [[nan]
#  [nan]]

print(jax.vmap(solve)(x[None], y[None])[0])
# [[1.]
#  [0.]]
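
For reference, the unbatched nan result is what textbook back substitution gives for this input. The sketch below is plain Python, not JAX's implementation. It assumes that triangular_solve reads only the upper triangle of x by default (lower=False), and the helpers ieee_div and back_substitute are hypothetical names introduced for this illustration.

```python
import math


def ieee_div(num, den):
    """Float division that follows IEEE 754 instead of raising ZeroDivisionError."""
    if den != 0.0:
        return num / den
    if num == 0.0 or math.isnan(num):
        return math.nan  # 0/0 and nan/0 are both nan
    return math.copysign(math.inf, num) * math.copysign(1.0, den)


def back_substitute(a, b):
    """Solve a @ x = b using only the upper triangle of a (assumed lower=False)."""
    n = len(b)
    x = [0.0] * n
    for i in reversed(range(n)):
        acc = b[i]
        for j in range(i + 1, n):
            acc -= a[i][j] * x[j]
        x[i] = ieee_div(acc, a[i][i])
    return x


a = [[1.0, 1.0], [1.0, 0.0]]  # same values as x in the repro above
b = [1.0, 0.0]                # same values as y, flattened to a vector
print(back_substitute(a, b))  # [nan, nan]
```

The zero on the diagonal forces 0/0 in the last row, and the nan then propagates upward, which matches the unbatched output above. It does not explain where the finite values in the vmapped result come from.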

@jakevdp
Collaborator

jakevdp commented Apr 10, 2024

This is still an issue as of JAX v0.4.26.
