sparse eigenvalue solvers #3112

Open
mganahl opened this issue May 15, 2020 · 47 comments
Labels
contributions welcome (The JAX team has not prioritized work on this. Community contributions are welcome.) · enhancement (New feature or request)

Comments

mganahl commented May 15, 2020

Hi!

Our team is currently developing sparse symmetric and non-symmetric solvers in JAX (implicitly restarted Arnoldi and Lanczos). I have heard that there are efforts to support this natively in JAX, and wanted to ask what the status is. It would actually be great if JAX provided these.

shoyer commented May 15, 2020

By "solvers" are you referring specifically to eigenvalue solvers?

Yes, I think we would be interested in adding some of these to JAX, but they would need to be reasonably robust and scalable, ideally comparable to the SciPy algorithms. The "basic" Lanczos method would probably not qualify, but restarted Arnoldi (used by SciPy in ARPACK) probably would.

It would also be nice to include auto-diff rules for these operations, but that could come later.

shoyer commented May 15, 2020

I played around with writing a thick-restart Lanczos solver. As a point of reference, you can find it in #3114 but I'm not sure we actually want to merge it.

@shoyer changed the title from "sparse solvers" to "sparse eigenvalue solvers" May 15, 2020
@chaserileyroberts

We have an implementation at https://github.com/google/TensorNetwork/blob/d4ec0a381dbf1a7d453d97525ccc857fc416b575/tensornetwork/backends/jax/jax_backend.py#L230

Not sure if it would be useful to have in JAX proper.


mganahl commented May 16, 2020

This version does not support restarts right now, and is in general geared towards tensor network applications. It could serve as a starting point for an implicitly restarted Lanczos.


mganahl commented May 16, 2020

I had a look at #3114, and this would indeed be interesting for us. The three problems that pop up most frequently in our applications are sparse symmetric eigenproblems (mostly SA or LA eigenvalues, usually only one to a few), sparse non-symmetric eigenproblems (LR eigenvalues), and linear system solvers (symmetric and non-symmetric, i.e. GMRES or LGMRES), all of which we are currently working on.
The requirements in terms of robustness and accuracy are often less stringent for tensor networks than for other applications, so our implementations are more mundane. But if the JAX team is also interested in supporting these, it could make sense to join efforts.


shoyer commented May 16, 2020

GMRES and L-GMRES for linear solves would also be quite welcome in JAX. These would fit in quite naturally alongside our new CG solver.
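For reference, here is a minimal sketch (not code from this thread) of the matrix-free calling convention the existing CG solver uses, which a GMRES/LGMRES implementation would presumably mirror; the tridiagonal operator below is purely illustrative:

import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def matvec(x):
    # 1-D Laplacian applied matrix-free: (A x)_i = 2 x_i - x_{i-1} - x_{i+1}, with zero boundaries.
    left = jnp.roll(x, 1).at[0].set(0.0)
    right = jnp.roll(x, -1).at[-1].set(0.0)
    return 2.0 * x - left - right

b = jnp.ones(64)
x, info = cg(matvec, b, tol=1e-6)  # A is passed as a function and never materialized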

@shoyer added the contributions welcome label May 17, 2020

momchilmm commented May 28, 2020

It would also be nice to include auto-diff rules for these operations, but that could come later.

Just a note on this: here I have defined an autograd primitive wrapping scipy.sparse.linalg.eigsh (applied to regular Numpy matrices, not sparse ones). The vjp is almost identical to the one for numpy.linalg.eigh, but the summation in the backprop of the eigenvector gradient is restricted to only the numeig computed eigenvectors. This makes the eigenvector gradient approximate (it becomes exact in the limit of numeig going to the linear size of the matrix). The gradient w.r.t. the eigenvalues, however, is exact. So you'll have to decide whether to only support auto-diff w.r.t. the eigenvalues (I assume you don't want approximate results...).
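For concreteness, here is a rough sketch of the truncated rule just described (this is not the linked autograd code; names and shapes are illustrative): the usual eigh-style VJP, but with the eigenvector backprop summed over only the numeig computed pairs, so that part is approximate while the eigenvalue part stays exact.

import jax.numpy as jnp

def truncated_eigh_vjp(w, v, w_bar, v_bar):
    # w: (k,) computed eigenvalues, v: (n, k) computed eigenvectors,
    # w_bar / v_bar: cotangents of the same shapes; k = numeig < n, eigenvalues assumed non-degenerate.
    a_bar = (v * w_bar) @ v.T                      # eigenvalue term: exact
    diff = w[None, :] - w[:, None]                 # w_j - w_i
    off_diag = ~jnp.eye(w.size, dtype=bool)
    f = jnp.where(off_diag, 1.0 / jnp.where(off_diag, diff, 1.0), 0.0)
    a_bar += v @ (f * (v.T @ v_bar)) @ v.T         # eigenvector term: sums over the k computed pairs only
    return 0.5 * (a_bar + a_bar.T)                 # symmetrize, since the primal input is symmetric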


mattjj commented May 28, 2020

Wowww cool, that is really helpful @momchilmm !


shoyer commented May 28, 2020

Indeed, very cool to see! I also made a note with a few references about methods for calculating eigenvector derivatives from a partial decomposition over in #3114.

@momchilmm

@shoyer oh this is very interesting! I might consider adding something like this in my autograd implementation because storing all (or many of) the eigenvectors to get the exact gradient is sometimes a significant memory overhead.

@momchilmm

@shoyer so I wrote a method to compute the eigenvector derivatives from a partial decomposition, along the lines of the works you had pointed out, and it works! Here's the vjp, and a test. The works mentioned in #3114 are overly complex because they consider degenerate eigenvalues, so I ended up following Steven Johnson's notes here, which I extended to the case of a Hermitian matrix with a small modification.

So this seems to work for non-degenerate eigenvalues. I do think (similarly to you, I believe) that the gradient is not a well-defined quantity in the case of degenerate eigenvalues, specifically when we have more than one input parameter. Basically, for a single parameter, you can define the derivative of each eigenvalue by choosing eigenvectors in your subspace that are also eigenvectors of the system under the corresponding small perturbation. However, this choice of eigenvectors will be different for every input parameter (each corresponding to a different matrix perturbation), and so a full gradient cannot be defined.
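As a rough illustration of the approach (this is not the code in the linked vjp, and the function name is made up for the example): for a single non-degenerate eigenpair (alpha, x) of a Hermitian A, the backprop solve from the notes can be written with the projector P = I - x x^H and an iterative solver such as CG.

import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def solve_adjoint_eigvec(a, alpha, x, x_bar):
    # P projects onto the complement of x; starting from a right-hand side in that
    # subspace keeps the CG iterates (nearly) orthogonal to x.
    proj = lambda v: v - x * jnp.vdot(x, v)
    op = lambda v: proj((a - alpha * jnp.eye(a.shape[0], dtype=a.dtype)) @ v)
    lam0, _ = cg(op, -proj(x_bar))   # solve (A - alpha I) lam0 = -P x_bar within the subspace
    return proj(lam0)                # apply P once more to clean up roundoff

The resulting lam0 is then combined with x (via outer products) to assemble the cotangent of A, as in the notes.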


jackd commented Feb 2, 2021

@momchilmm I've been playing around with your vjp, and I occasionally get wildly inaccurate gradients. I'm guessing this is related to the cg-solve operating on a necessarily non-positive-definite matrix. Have you encountered this? Any thoughts?

Update: this was due to an error with my eigenvector calculation (signs were inconsistent between evaluations in the forward-difference calculation I was using to check the gradient implementation).

@momchilmm

@jackd I think you are right! I have not encountered it myself since it seems to just work in some cases, but I can see how that might be a problem in others. In the Steven Johnson notes there's a footnote:

Since P commutes with A − α, we can solve for λ₀ easily by an iterative method such as conjugate gradient: if we start with an initial guess orthogonal to x, all subsequent iterates will also be orthogonal to x and will thus converge to λ₀ (except for roundoff, which can be corrected by multiplying the final result by P).

However, I don't think that would guarantee convergence, unless I'm missing something? I'm also not sure what the best method to use for this matrix is - we only know that it's Hermitian. But you could try just changing the solver to bicg or gmres and see if your troubles are gone. If you see it becoming stable, I encourage you to submit a PR!


jackd commented Feb 2, 2021

@momchilmm glad I'm not going crazy :). I'm not having trouble with the smallest eigenvalue - and if I were I suspect a small diagonal shift might be enough to resolve it - but I am finding my tests often break as I increase the number of eigenvectors solved.

I'm looking at MINRES-QLP as a substitute, as it explicitly caters for singular symmetric matrices with pre-conditioning - alas there's no jax implementation of that though, so it's not just a one-line change :S.

@momchilmm

Yeah there doesn't seem to be anything in scipy.sparse.linalg that's expected to always work. I think the footnote that I quoted above means that in most cases you don't have to worry that the matrix is singular, because if you start with an initial guess that's orthogonal to the eigenvector (which spans the kernel), you'll always stay outside of the kernel. So the fact that the matrix is not positive-definite is the issue. The fact that you don't have trouble with the smallest eigenvalue is because in the iteration the vector is restricted to the subspace orthogonal to the corresponding eigenvector, and all the eigenvalues mapping to that space are larger than 0 (they are shifted by the smallest eigenvalue).

MINRES-QLP seems to be the only thing I can find for Hermitian matrices too. There's a freely available python implementation if you want to try your luck...


jackd commented Feb 2, 2021

Sigh, false alarm. Turns out my errors were due to my lobpcg implementation returning eigenvectors with different signs in the forward difference gradient check. I'll keep this in mind if I get any more NaN errors / gradient mis-matches in the future, but I have no evidence to suggest this is a problem at this stage. Thanks anyway :)

@momchilmm

Ah. Yeah. If your objective function depends on the sign (or, more generally, the complex phase) of the eigenvectors, and you're not setting it deterministically in your solver, then you're in trouble. That's one of the main reasons this issue exists.


jackd commented Feb 3, 2021

If anyone's looking for a LOBPCG implementation I've hacked up a basic implementation here. The sparse implementation works without jitting - but I need to work out how to resolve this issue before the sparse jitted version will work.

There's also some vjps here based on @momchilmm's work above. Needs some more documentation, but hopefully the tests show integration with lobpcg.


lobpcg commented Mar 28, 2022

If anyone's looking for a LOBPCG implementation I've hacked up a basic implementation here. The sparse implementation works without jitting - but I need to work out how to resolve this issue before the sparse jitted version will work.

There's also some vjps here based on @momchilmm's work above. Needs some more documentation, but hopefully the tests show integration with lobpcg.

See also https://github.com/eserie/jju


vlad17 commented Apr 1, 2022

Hey @jackd , are you planning on putting together a PR to land a version of lobpcg into jax core?


jackd commented Apr 1, 2022

@vlad17 no plan to right now. I'm willing to bet it wouldn't be a high priority for the jax core team even if it fit in the package (from memory, mine is based on a pytorch implementation as opposed to the scipy one). Anyone else can feel free to use it however they want (e.g. for their own PR), though I haven't looked at it for a while, so @lobpcg's more recent adaptation is probably in a better state.


vlad17 commented Apr 12, 2022

@lobpcg I've been toying with a simplified but fairly stable version (as in, it seems to do better than scipy) of LOBPCG which solves only the standard eigenvalue problem. I think JAX users would benefit from a PR which puts this into an experimental jax directory. Would you be able to help review it for correctness?

vlad17 pushed a commit to vlad17/jax that referenced this issue Jun 3, 2022
This initial version is f32-only for accelerators, since it relies on an eigh call (which itself is f32 at most) in its inner loop.

For details, see jax.experimental.linalg.standard_lobpcg documentation.
vlad17 pushed a commit to vlad17/jax that referenced this issue Jun 3, 2022
This initial version is f32-only for accelerators, since it relies on an eigh call (which itself is f32 at most) in its inner loop.

For details, see jax.experimental.linalg.standard_lobpcg documentation.
vlad17 added a commit to vlad17/jax that referenced this issue Jun 3, 2022
This initial version is f32-only for accelerators, since it relies on an eigh call (which itself is f32 at most) in its inner loop.

For details, see jax.experimental.linalg.standard_lobpcg documentation.

vlad17 commented Jul 6, 2022

@mganahl jax.experimental.sparse.linalg.lobpcg_standard is in master. Right now it's top-k only, no jvp, so pretty barebones, but feel free to give it a spin and let me know what you think.
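A minimal usage sketch for anyone wanting to try it (the random test matrix is only for illustration, and the three-value return convention, eigenvalues / eigenvectors / iteration count, is what the experimental API appears to provide):

import jax
import jax.numpy as jnp
from jax.experimental.sparse.linalg import lobpcg_standard

n, k = 100, 5
a = jax.random.normal(jax.random.PRNGKey(0), (n, n), dtype=jnp.float32)
a = a @ a.T + n * jnp.eye(n, dtype=jnp.float32)            # symmetric, well-conditioned test matrix
x0 = jax.random.normal(jax.random.PRNGKey(1), (n, k), dtype=jnp.float32)

theta, u, iters = lobpcg_standard(a, x0)                   # top-k eigenvalues and eigenvectors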

LenaMartens pushed a commit to LenaMartens/jax that referenced this issue Jul 8, 2022
This initial version is f32-only for accelerators, since it relies on an eigh call (which itself is f32 at most) in its inner loop.

For details, see jax.experimental.linalg.standard_lobpcg documentation.

This is a partial implementation of the similar [scipy lobpcg
function](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lobpcg.html).
LenaMartens pushed a commit to LenaMartens/jax that referenced this issue Jul 8, 2022
I had erroneously assumed that GPU would be just as accurate for f64 (both in numerics and eigh) when submitting google#3112, so I did not disable f64 tests on that platform. This is of course not the case, so those tests should be disabled.

choltz95 commented Jan 2, 2023

Thanks a ton for your work on this. jax.experimental.sparse has been generally really helpful for my projects & I appreciate that it's actively developed. I think something that would be nice is having an argument for a preconditioner + bottom-k eigenvalues. I saw that there are already some comments in master about it. I have been doing this with some success on big Laplacians & happy to submit a PR, but unsure if there are design considerations.

On something else - I know preconditioning has been discussed in the main matrix-free thread (#1531), but I also want to express my interest in something like an xla accelerated multigrid preconditioner.


vlad17 commented Jan 3, 2023

Bottom-k and a preconditioner argument are reasonable next steps; I could add them (or review) if there are enough use cases that they'd be high priority.

@rmlarsen mentioned some interest in bottom-k, though I'm not sure if he knows of possible users for it.


lobpcg commented Jan 3, 2023

@rmlarsen @choltz95 The bottom-k eigenvalues of a matrix A are trivial to get with the existing code: just run it on the negated matrix -A and flip the signs of the resulting eigenvalues, all at the user level.

@vlad17 Or modify the code to select the smallest eigenvalues instead of the largest in the Rayleigh-Ritz (RR) step; cf. the "largest" option in https://github.com/scipy/scipy/blob/main/scipy/sparse/linalg/_eigen/lobpcg/lobpcg.py
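For example, a purely illustrative user-level wrapper along these lines gives bottom-k with the existing top-k solver:

from jax.experimental.sparse.linalg import lobpcg_standard

def lobpcg_bottom_k(a, x0, **kwargs):
    theta, u, iters = lobpcg_standard(-a, x0, **kwargs)    # top-k eigenpairs of -A
    return -theta, u, iters                                # flip signs back: bottom-k eigenvalues of A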


rmlarsen commented Jan 3, 2023

When I said "smallest" I meant smallest in absolute value. Of course flipping the sign of A is trivial. Computing the smallest-magnitude eigenvalues is extremely useful when dealing with ill-conditioned matrices in a variety of applications. When the matrix is indeed Hermitian, computing these directly, instead of "naively" using an SVD solver, can be much faster or more accurate.


jakevdp commented Jan 3, 2023

How hard would it be to add a shift-invert mode to the current lobpcg solver?


rmlarsen commented Jan 3, 2023

The way I read the paper, the beautiful thing about LOBPCG is its ability to solve for 1/lambda by a change of variables in the generalized eigenvalue problem. It may not be super efficient, but it's easy to use and presumably much more efficient than a naive shift-invert implementation(?)


choltz95 commented Jan 3, 2023

For my specific use case w/ Laplacians, there isn't too much ambiguity about the eigenvalues (large, small, etc. are all nonnegative). So flipping the ordering in the RR step, as previously mentioned, works great. At least for me, the important thing is that preconditioning (which is already there, just commented out) can be valuable for large ill-conditioned problems.


lobpcg commented Jan 3, 2023

@rmlarsen When solving Ax = lambda Bx with LOBPCG, the matrix B must be positive definite (PD). If A is not PD, you cannot plug Bx = (1/lambda) Ax into LOBPCG.


rmlarsen commented Jan 3, 2023

@lobpcg So if B is the identity, I can use LOBPCG for this, right?


lobpcg commented Jan 3, 2023

@jakevdp LOBPCG is matrix-free, so to solve Ax = lambda Bx with LOBPCG one needs only the multiplications by A and B, e.g., via external functions, but B must be PD.

One can solve Ax = lambda x in shift-invert mode with a fixed shift alpha in LOBPCG, e.g., by plugging in inv(A - alpha I) x = gamma x: provide a function that solves (A - alpha I) x = b and go for the largest or smallest gamma, which gives lambdas around the shift alpha, separately above and below it. Since in this case only the action of inv(A - alpha I) is provided to the code, alpha should be fixed and the solve (A - alpha I) x = b should be done accurately enough. I have not tested the algorithm for this scenario, so I'm unsure exactly how accurate it needs to be.

@rmlarsen Alternatively, one can plug in -(A - alpha I)^2 x = gamma x and go for the largest (negative, in this case) gamma; e.g., see https://www.researchgate.net/publication/343531770_Eigenvalue_solvers_for_computer_simulations_of_efficient_solar_cell_materials
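An illustrative sketch of the fixed-shift recipe above; the helper name and the dense inverse are assumptions made only for the example, and in practice one would supply an accurate solve with A - alpha I instead of materializing the inverse:

import jax.numpy as jnp
from jax.experimental.sparse.linalg import lobpcg_standard

def eigs_near_shift(a, alpha, x0):
    n = a.shape[0]
    # Dense inverse purely to illustrate inv(A - alpha I) x = gamma x.
    op = jnp.linalg.inv(a - alpha * jnp.eye(n, dtype=a.dtype))
    gamma, u, _ = lobpcg_standard(op, x0)   # largest gamma <-> eigenvalues of A just above alpha
    return alpha + 1.0 / gamma, u           # map gamma back to eigenvalues of A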


lobpcg commented Jan 3, 2023

@lobpcg So if B is the identity, I can use LOBPCG for this, right?

Let me clarify: LOBPCG solves for the smallest or largest eigenvalues of Ax = lambda Bx, where A is symmetric and B is symmetric positive definite, e.g., B = I. Matrices A and B also need to be more or less fixed; e.g., if you plug the problem inv(A - alpha I) x = gamma x into LOBPCG and change alpha during the iterations, you will likely break the code.

Whatever eigenvalue problem you want to solve originally, you need to transform it mathematically into this form and satisfy these assumptions. If these assumptions are violated, the code would most likely break.


rmlarsen commented Jan 3, 2023

@lobpcg and by "smallest lambda" you mean the "largest of 1/lambda"?


lobpcg commented Jan 3, 2023

@lobpcg and by "smallest lambda" you mean the "largest of 1/lambda"?

Remember that A does not need to be positive definite, only B does, so lambda can be positive, negative, or zero. One could only plug Bx = (1/lambda) Ax into LOBPCG if A is positive definite.


rmlarsen commented Jan 3, 2023

@lobpcg ah, thanks for clarifying. So in the case where A is positive definite, is there any advantage to rewriting the problem as Bx = (1/lambda) Ax, rather than solving for the largest eigenvalues of -A (when B = I)? I suppose the gap structure, and thus the speed of convergence, comes into play then?


lobpcg commented Jan 3, 2023

@lobpcg ah, thanks for clarifying. So in the case where A is positive definite, is there any advantage to rewriting the problem as Bx = (1/lambda) Ax, rather than solving for the largest eigenvalues of -A (when B = I)? I suppose the gap structure, and thus the speed of convergence, comes into play then?

Good question. Without doing any math, I would guess that the speed should probably be the same, since in both cases one only multiplies by A anyway. The convergence speed is technically determined by the relative gap, so one should be able to write it down and compare. Too lazy to do it myself at the moment :=)


rmlarsen commented Jan 3, 2023

@lobpcg you are right. The relative gap size should be the same to first order.


jemisjoky commented Jan 11, 2023

On a much more mundane level, I just wanted to point out that jax.experimental.sparse.linalg.lobpcg_standard currently crashes when a complex Hermitian matrix is input. I'm including a script below that reproduces this issue, but the fix is simply to change this line to theta = jnp.sum(X * AX, axis=0, keepdims=True).real.

Script and output

SCRIPT

import jax.numpy as jnp
from jax.experimental.sparse.linalg import lobpcg_standard

for dtype in ("float32", "complex64"):
    matrix = jnp.zeros((6, 6), dtype=dtype).at[0, 0].set(1)
    init_vec = jnp.ones((6, 1), dtype=dtype)

    output = lobpcg_standard(matrix, init_vec)
    print(f"LOBPCG works with {dtype}")

OUTPUT

LOBPCG works with float32
Traceback (most recent call last):
  File "jax_bug_test.py", line 9, in <module>
    eig_info = lobpcg_standard(matrix, init_vec)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/experimental/sparse/linalg.py", line 97, in lobpcg_standard
    return _lobpcg_standard_matrix(A, X, m, tol, debug=False)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/_src/api.py", line 622, in cache_miss
    execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/_src/dispatch.py", line 241, in _xla_call_impl_lazy
    return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/linear_util.py", line 303, in memoized_fun
    ans = call(fun, *args)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/_src/dispatch.py", line 357, in _xla_callable_uncached
    computation = sharded_lowering(fun, device, backend, name, donated_invars,
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/_src/dispatch.py", line 348, in sharded_lowering
    return pxla.lower_sharding_computation(
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 2792, in lower_sharding_computation
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2065, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/experimental/sparse/linalg.py", line 108, in _lobpcg_standard_matrix
    return _lobpcg_standard_callable(
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/_src/api.py", line 626, in cache_miss
    top_trace.process_call(primitive, fun_, tracers, params))
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1739, in process_call
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(f, self.main, debug_info=dbg)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/experimental/sparse/linalg.py", line 231, in _lobpcg_standard_callable
    state = jax.lax.while_loop(cond, body, state)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 1119, in while_loop
    _check_tree_and_avals("body_fun output and input",
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 108, in _check_tree_and_avals
    raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', 'ShapedArray(complex64[6,1])', 'ShapedArray(complex64[6,1])', 'ShapedArray(complex64[6,1])', 'ShapedArray(int32[])', 'DIFFERENT ShapedArray(float32[1,1]) vs. ShapedArray(complex64[1,1])').

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "jax_bug_test.py", line 9, in <module>
    eig_info = lobpcg_standard(matrix, init_vec)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/experimental/sparse/linalg.py", line 97, in lobpcg_standard
    return _lobpcg_standard_matrix(A, X, m, tol, debug=False)
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/experimental/sparse/linalg.py", line 108, in _lobpcg_standard_matrix
    return _lobpcg_standard_callable(
  File "/Users/jemis/opt/miniconda3/envs/umps38/lib/python3.8/site-packages/jax/experimental/sparse/linalg.py", line 231, in _lobpcg_standard_callable
    state = jax.lax.while_loop(cond, body, state)
TypeError: body_fun output and input must have identical types, got
('ShapedArray(int32[], weak_type=True)', 'ShapedArray(complex64[6,1])', 'ShapedArray(complex64[6,1])', 'ShapedArray(complex64[6,1])', 'ShapedArray(int32[])', 'DIFFERENT ShapedArray(float32[1,1]) vs. ShapedArray(complex64[1,1])').


jakevdp commented Jan 11, 2023

@jemisjoky good catch! Are you interested in doing a pull request with the fix? If not I can look into it.

@jemisjoky

@jakevdp, I'd love to! I honestly haven't done a PR with a major project like JAX before, but the contributor guide is pretty clear here. Will let you know if I have any questions in the process!


jakevdp commented Jan 12, 2023

Sounds good, thanks!

@gautierronan

Hi, apologies for the bump. Is there any news on the implementation of sparse eigensolvers for smallest eigenvalues? Currently, jax.experimental.sparse.linalg.lobpcg_standard seems to only support largest eigenvalues. Thanks!


lobpcg commented May 23, 2024

Hi, apologies for the bump. Is there any news on the implementation of sparse eigensolvers for smallest eigenvalues? Currently, jax.experimental.sparse.linalg.lobpcg_standard seems to only support largest eigenvalues. Thanks!

Just run LOBPCG on -A, where A is the matrix of the eigenvalue problem, and multiply the resulting largest eigenvalues of -A by -1.


rmlarsen commented May 23, 2024

I agree that running LOBPCG is probably your best choice. If by smallest you mean of smallest magnitude, you can try running LOBPCG on a functor that applies the inverse of A to a vector, if that is feasible.

Notice that we recently added a mechanism subset_by_index for selecting a subset of eigenvalues in the spectral bisection eigensolver here: https://github.com/google/jax/blob/main/jax/_src/lax/eigh.py#L506
It is also available for the corresponding SVD. There is no such thing for general non-Hermitian matrices.

The subset_by_index mechanism has not been plumbed through to the default eigensolvers used on CPU and GPU, so I think you have to call the LAX version directly.
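A minimal sketch of calling the LAX version directly, assuming the subset_by_index keyword on jax.lax.linalg.eigh and a backend where it is supported (illustrative only); this requests the two smallest eigenpairs of a dense symmetric matrix:

import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (50, 50))
a = 0.5 * (a + a.T)                                    # symmetrize

# Note: lax.linalg.eigh returns (eigenvectors, eigenvalues).
v, w = jax.lax.linalg.eigh(a, subset_by_index=(0, 2))  # indices [0, 2): the two smallest eigenpairs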


lobpcg commented May 23, 2024

The inverse is typically infeasible, although it may give much faster convergence in theory.

A function multiplying by -A is always feasible. That's the approach used, e.g., in LOBPCG's scipy implementation.

Contrary to a popular misunderstanding, LOBPCG does not require the matrix to be positive definite.
