[CPU] np.linalg.eigh outputs NaNs for matrix side size >=32767 #10420
I'm reasonably sure that what's happening here is that JAX is passing parameters to LAPACK correctly, but we're using a LAPACK built with 32-bit integers. I think we probably need to switch to an ILP64-built LAPACK. However, that means we'll have to get LAPACK from somewhere other than SciPy's Cython exports. Another option that might work for this specific function is to use a different LAPACK driver function. I believe
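The cutoff at 32767 is consistent with the 32-bit explanation, under one assumption: that the CPU eigh path calls LAPACK's `?syevd` driver with `JOBZ='V'`. LAPACK documents the minimum `LWORK` for that case as `1 + 6*N + 2*N**2`. This sketch finds the first `N` at which that value no longer fits in a signed 32-bit integer.

```python
# Largest value a 32-bit (LP64) LAPACK INTEGER can hold.
INT32_MAX = 2**31 - 1


def syevd_min_lwork(n: int) -> int:
    """Minimum LWORK for ?syevd with JOBZ='V' and N > 1 (assumed driver)."""
    return 1 + 6 * n + 2 * n * n


def first_overflowing_n() -> int:
    """Smallest n whose minimum workspace size exceeds INT32_MAX."""
    n = 2
    # syevd_min_lwork is increasing in n, so a linear scan finds the cutoff.
    while syevd_min_lwork(n) <= INT32_MAX:
        n += 1
    return n


for n in (32_766, 32_767):
    lwork = syevd_min_lwork(n)
    print(n, lwork, lwork <= INT32_MAX)
# 32766 2147418109 True
# 32767 2147549181 False

print(first_overflowing_n())
# 32767
```

If this assumption holds, the matrix itself (N*N elements, about 1.07e9 at N = 32767) still fits in 32 bits. It is the workspace size that overflows first.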
As a first step toward resolving this issue, I suggest catching 32-bit integer overflow before executing LAPACK functions, since on overflow they will most likely produce garbage or crash anyway. The overflow problem applies to all LAPACK functions, not just the ones used in linalg.eigh. Switching to an ILP64-built LAPACK would be the next step. There are many options, all of which require some effort. For instance, we could use an ILP64-enabled SciPy (this likely requires some work on the SciPy side, which has ILP64 support but not in its releases), or use some other LAPACK library that provides ILP64 support, such as Intel MKL, as a dependency.
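A minimal sketch of the pre-call overflow check suggested above. The helper names (`check_int32_args`, `check_syevd_args`, `LapackIntOverflowError`) are hypothetical and not part of JAX or jaxlib, and the workspace formulas again assume the `?syevd` driver with `JOBZ='V'` (`LWORK >= 1 + 6*N + 2*N**2`, `LIWORK >= 3 + 5*N`).

```python
# Hypothetical helpers, not a JAX/jaxlib API.
INT32_MAX = 2**31 - 1


class LapackIntOverflowError(ValueError):
    """Raised when a LAPACK size argument would not fit in a 32-bit INTEGER."""


def check_int32_args(routine: str, **sizes: int) -> None:
    """Raise before calling `routine` if any size argument overflows int32."""
    for name, value in sizes.items():
        if not 0 <= value <= INT32_MAX:
            raise LapackIntOverflowError(
                f"{routine}: {name}={value} does not fit in a 32-bit LAPACK integer"
            )


def check_syevd_args(n: int) -> None:
    """Validate the sizes an (assumed) ?syevd call with JOBZ='V' would receive."""
    check_int32_args(
        "syevd",
        n=n,
        lda=n,  # leading dimension of a dense, square input
        lwork=1 + 6 * n + 2 * n * n,
        liwork=3 + 5 * n,
    )


check_syevd_args(32_766)  # passes silently

try:
    check_syevd_args(32_767)
except LapackIntOverflowError as err:
    print(err)
# syevd: lwork=2147549181 does not fit in a 32-bit LAPACK integer
```

Such a check would sit wherever the workspace sizes are computed, just before the LAPACK call, so that the user gets a clear error instead of NaNs or a crash.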
Does SciPy have plans to make an ILP64 release?
I'm not excited to add our own build of LAPACK, Fortran toolchain and all, but another possibility is to scavenge a different LAPACK from the environment, e.g., perhaps if the user installs an ILP64 LAPACK through
There is https://pypi.org/project/scipy-openblas64/, which provides LAPACK in ILP64 mode.
Versions: jax-0.3.7, jaxlib-0.3.7, numpy-1.21.6
Outputs, for n in 1000, 10_000, 32_766, and 32_767:

The correct answer that should be printed for each n is:

Interestingly, for n = 50_000, a different issue, described in #10411, occurs: the outputs are printed correctly, but the program also prints new error messages and is terminated by signal SIGABRT (Abort). I think this is what causes the same computation to crash in Colab (#10411).

Related to: #4358.