sparse eigenvalue solvers #3112
By "solvers" are you referring specifically to eigenvalue solvers? Yes, I think we would be interested in adding some of these to JAX, but they would need to be reasonably robust and scalable, ideally comparable to the SciPy algorithms. The "basic" Lanczos method would probably not qualify, but restarted Arnoldi (used by SciPy in ARPACK) probably would. It would also be nice to include auto-diff rules for these operations, but that could come later. |
I played around with writing a thick-restart Lanczos solver. As a point of reference, you can find it in #3114, but I'm not sure we actually want to merge it.
We have an implementation at https://github.com/google/TensorNetwork/blob/d4ec0a381dbf1a7d453d97525ccc857fc416b575/tensornetwork/backends/jax/jax_backend.py#L230 Not sure if it would be useful to have in JAX proper.
This version does not support restarts right now, and is in general geared towards tensornetwork applications. It could serve as a starting point for an implicitly restarted Lanczos.
I had a look at #3114, and this would indeed be interesting for us. The three problems that pop up most frequently in our applications are sparse symmetric (mostly SA or LA eigenvalues, usually only one to a few), sparse non-symmetric (with LR eigenvalues), and linear system solvers (symmetric and non-symmetric, i.e. gmres or lgmres), all of which we are currently working on.
GMRES and L-GMRES for linear solves would also be quite welcome in JAX. These would fit in quite naturally alongside our new CG solver.
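(For reference, the CG solver mentioned here lives in jax.scipy.sparse.linalg and follows the SciPy calling convention, accepting either a dense matrix or a matrix-free matvec callable; a GMRES with the same interface would slot in next to it. A quick sketch of that interface:)

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

# The operator can be a dense matrix or a matrix-free callable.
x, _ = cg(lambda v: A @ v, b, tol=1e-6)
print(x)  # solves A x = b, here approximately [0.0909, 0.6364]
```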
Just a note on this: here I have defined an …
Wowww cool, that is really helpful @momchilmm !
Indeed, very cool to see! I also made a note with a few references about methods for calculating eigenvector derivatives from a partial decomposition over in #3114.
@shoyer oh this is very interesting! I might consider adding something like this in my …
@shoyer so I wrote a method to compute the eigenvector derivatives from a partial decomposition, along the lines of the works you had pointed out, and it works! Here's the vjp, and a test. The works mentioned in #3114 are overly complex because they consider degenerate eigenvalues, so I ended up following Steven Johnson's notes here, which I extended to the case of a Hermitian matrix with a small modification. So yeah, this seems to work for non-degenerate eigenvalues. I do think (similarly to you, I believe) that the gradient is not a well-defined quantity in the case of degenerate eigenvalues, specifically when we have more than one input parameter. Basically, for a single parameter, you can define the derivative of each eigenvalue by choosing eigenvectors in your subspace that are also eigenvectors of the system with the corresponding small perturbation. However, this choice of eigenvectors will be different for every different input parameter (corresponding to a different matrix perturbation), and so a full gradient cannot be defined.
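(To spell out the relations being differentiated here: for a Hermitian A with a simple eigenpair (λ, v), normalized so that vᴴv = 1, first-order perturbation theory — as in the Johnson notes referenced above — gives:)

```latex
% First-order perturbation of a simple eigenpair (\lambda, v) of a
% Hermitian A, with v^H v = 1 and the gauge fixed by requiring dv \perp v:
d\lambda = v^{H} (dA)\, v,
\qquad
(A - \lambda I)\, dv = -\left(I - v v^{H}\right)(dA)\, v .
% The singular system on the right is solvable because its right-hand side
% is orthogonal to v, which spans the kernel of (A - \lambda I).
```

Transposing these relations yields the vjp; only products involving the known eigenpairs are needed, which is what makes the partial-decomposition approach workable.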
Update: this was due to an error with my eigenvector calculation (signs were inconsistent across evaluations in the forward-difference calculation I was using to check the gradient implementation).
@jackd I think you are right! I have not encountered it myself since it seems to just work in some cases, but I can see how that might be a problem in others. In the Steven Johnson notes there's a footnote: …
However, I don't think that would guarantee convergence, unless I'm missing something? I'm also not sure what the best method to use for this matrix is - we only know that it's Hermitian. But you could try just changing the solver to bicg or gmres and see if your troubles are gone. If you see it becoming stable, I encourage you to submit a PR!
@momchilmm glad I'm not going crazy :). I'm not having trouble with the smallest eigenvalue - and if I were, I suspect a small diagonal shift might be enough to resolve it - but I am finding my tests often break as I increase the number of eigenvectors solved for. I'm looking at MINRES-QLP as a substitute, as it explicitly caters for singular symmetric matrices with preconditioning - alas, there's no jax implementation of that, so it's not just a one-line change :S.
Yeah, there doesn't seem to be anything in scipy.sparse.linalg that's expected to always work. I think the footnote I quoted above means that in most cases you don't have to worry that the matrix is singular: if you start with an initial guess that's orthogonal to the eigenvector (which spans the kernel), you'll always stay outside of the kernel. So the fact that the matrix is not positive-definite is the issue. The reason you don't have trouble with the smallest eigenvalue is that in the iteration the vector is restricted to the subspace orthogonal to the corresponding eigenvector, and all the eigenvalues mapping to that space are larger than 0 (they are shifted by the smallest eigenvalue). MINRES-QLP seems to be the only thing I can find for Hermitian matrices too. There's a freely available python implementation if you want to try your luck...
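(A sketch of the projected solve being discussed, with hypothetical names of my own: v is the known eigenvector spanning the kernel of A - λI, and every CG iterate is kept in the subspace orthogonal to v. As noted above, CG is only justified when the deflated operator is positive definite on that subspace — true for the extremal eigenvalue, but not for interior ones, which is where a MINRES-type solver really is needed.)

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def solve_deflated(A, lam, v, rhs, tol=1e-8):
    """Solve (A - lam*I) x = rhs restricted to the subspace orthogonal to v."""
    v = v / jnp.linalg.norm(v)
    project = lambda x: x - v * jnp.vdot(v, x)   # P = I - v v^H

    def op(x):
        # Project on both sides so CG never sees the kernel direction.
        y = project(x)
        return project(A @ y - lam * y)

    x, _ = cg(op, project(rhs), tol=tol)
    return project(x)
```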
Sigh, false alarm. Turns out my errors were due to my …
Ah. Yeah. If your objective function depends on the sign (or complex phase more generally) of the eigenvectors, and you're not deterministically setting it in your solver, then you're in trouble, yeah. That's one of the main reasons this current issue exists.
If anyone's looking for a LOBPCG implementation, I've hacked up a basic implementation here. The sparse implementation works without … There's also some …
See also https://github.com/eserie/jju |
Hey @jackd, are you planning on putting together a PR to land a version of lobpcg into jax core?
@vlad17 no plans right now. I'm willing to bet it wouldn't be a high priority for the jax core team even if it fit in the package (from memory, mine is based on a pytorch implementation as opposed to the scipy one). Anyone else can feel free to use it however they want (e.g. for their own PR), though I haven't looked at it for a while, so @lobpcg's more recent adaptation is probably in a better state.
@lobpcg I've been toying with a simplified but fairly stable version (as in, it seems to do better than scipy) of LOBPCG which solves only the standard eigenvalue problem. I think JAX users would benefit from a PR which puts this into an experimental jax directory. Would you be able to help review it for correctness?
@mganahl
This initial version is f32-only for accelerators, since it relies on an eigh call (which itself is f32 at most) in its inner loop. For details, see jax.experimental.linalg.standard_lobpcg documentation. This is a partial implementation of the similar [scipy lobpcg function](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lobpcg.html).
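(Minimal usage of the experimental entry point, as a sketch; the exact signature is whatever the linked documentation says, since experimental APIs move. Here the test matrix and m, the iteration cap, are my own choices.)

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse.linalg import lobpcg_standard

n, k = 100, 4
A = jnp.diag(jnp.arange(1.0, n + 1.0))                 # stand-in Hermitian matrix
X = jax.random.normal(jax.random.PRNGKey(0), (n, k))   # initial search block

# Returns approximate top-k eigenvalues, eigenvectors, and iteration count.
theta, U, iters = lobpcg_standard(A, X, m=100)
print(theta)  # approximately [100., 99., 98., 97.]
```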
I had erroneously assumed that GPU would deliver full f64 accuracy (both in general numerics and in eigh) when submitting google#3112, so I did not disable f64 tests on that platform. This is of course not the case, so those tests should be disabled.
Thanks a ton for your work on this. On something else - I know preconditioning has been discussed in the main matrix-free thread (#1531), but I also want to express my interest in something like an XLA-accelerated multigrid preconditioner.
bottom-k and preconditioning are reasonable next steps; I could add them (or review) if there are enough use cases that they'd be high priority. @rmlarsen mentioned some interest in bottom-k, though I'm not sure if he knows of possible users for it.
@rmlarsen @choltz95 The bottom-k eigenvalues of a matrix A are trivial to obtain with the existing code: just run it on the negated matrix -A and flip the sign of the computed eigenvalues, all at the user level (see the sketch below). @vlad17 Alternatively, modify the code to select the smallest eigenvalues instead of the largest in the Rayleigh-Ritz step; cf. the option "largest" in https://github.com/scipy/scipy/blob/main/scipy/sparse/linalg/_eigen/lobpcg/lobpcg.py
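(Concretely, the user-level sign-flip looks like this, assuming the lobpcg_standard interface shown earlier: the largest eigenvalues of -A are the negated smallest eigenvalues of A, and the eigenvectors are shared, so U needs no adjustment.)

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse.linalg import lobpcg_standard

A = jnp.diag(jnp.arange(1.0, 101.0))
X = jax.random.normal(jax.random.PRNGKey(0), (100, 4))

theta_neg, U, _ = lobpcg_standard(-A, X)
bottom_k = -theta_neg  # smallest eigenvalues of A
```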
When I said "smallest" I meant smallest in absolute value. Of course flipping the sign of …
How hard would it be to add a shift-invert mode to the current lobpcg solver?
The way I read the paper, the beautiful thing about LOBPCG is its ability to solve for 1/lambda by a change of variables in the generalized eigenvalue problem. It may not be super efficient, but it's easy to use and presumably much more efficient than a naive shift-invert implementation(?)
For my specific use case with Laplacians, there isn't too much ambiguity about the eigenvalues (large, small, etc. are all nonnegative). So flipping the order of RR as previously mentioned is great. At least for me, the important thing is that preconditioning (which is already there, just commented out) can be valuable for large ill-conditioned problems.
@rmlarsen Solving Ax = lambda Bx in LOBPCG, the matrix B must be positive definite (PD). If A is not PD, you cannot plug Bx = 1/lambda Ax into LOBPCG.
@lobpcg So if B is the identity, I can use LOBPCG for this, right?
@jakevdp LOBPCG is matrix-free, so to solve Ax = lambda Bx in LOBPCG one needs only the multiplications by A and B, e.g., via external functions, but B must be PD. One can solve Ax = lambda x in shift-invert mode with a fixed shift alpha in LOBPCG, e.g., by plugging in inv(A - alpha I) x = gamma x, providing a function that solves (A - alpha I) x = b, and going for the largest or smallest gamma, which gives lambdas around the shift alpha, separately below and above alpha. Since in this case only the action of inv(A - alpha I) is provided to the code, alpha should be fixed and the solve (A - alpha I) x = b should be done accurately enough. I have not tested the algorithm for this scenario, so I am unsure how accurate exactly. @rmlarsen Alternatively, one can plug in -(A - alpha I)^2 x = gamma x and go for the largest (negative in this case) gamma, e.g., see https://www.researchgate.net/publication/343531770_Eigenvalue_solvers_for_computer_simulations_of_efficient_solar_cell_materials
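(A hedged sketch of the fixed-shift variant described here, assuming lobpcg_standard accepts a callable operator applied to the whole (n, k) search block. The inner system (A - alpha I) y = b is generally indefinite, so a MINRES-type method would be the principled inner solver; since JAX doesn't ship one, GMRES stands in below. All matrices, shifts, and names are illustrative.)

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres
from jax.experimental.sparse.linalg import lobpcg_standard

A = jnp.diag(jnp.arange(1.0, 101.0))
alpha = 10.5  # fixed shift: we target eigenvalues near 10.5

def inv_shift(X):
    # Apply inv(A - alpha*I) column-by-column to the (n, k) block, solving
    # (A - alpha*I) y = x "accurately enough", as cautioned above.
    solve = lambda x: gmres(lambda v: A @ v - alpha * v, x, tol=1e-10)[0]
    return jax.vmap(solve, in_axes=1, out_axes=1)(X)

X0 = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
gamma, U, _ = lobpcg_standard(inv_shift, X0, m=200)
lam = alpha + 1.0 / gamma  # largest gamma -> eigenvalues just above alpha
```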
Let me clarify: LOBPCG solves for the smallest or largest eigenvalues of Ax = lambda Bx, where A is symmetric and B is symmetric positive definite, e.g., B = I. Matrices A and B also need to be more-or-less fixed, e.g., if you plug the problem inv(A - alpha I) x = gamma x into LOBPCG and change alpha during the iterations, you will likely break the code. Whatever eigenvalue problem you want to solve originally, you need to transform it mathematically into this form and satisfy these assumptions. If these assumptions are violated, the code will most likely break.
@lobpcg and by "smallest lambda" you mean the "largest of 1/lambda"?
Remember that A does not need to be positive definite, only B must be, so lambda can be positive, negative, or zero. One could only plug Bx = 1/lambda Ax into LOBPCG if A is positive definite.
@lobpcg ah, thanks for clarifying. So in the case where A is positive definite, is there any advantage to rewriting to Bx = 1/lambda Ax, rather than solving for the largest eigenvalues of -A (when B = I)? I suppose the gap structure, and thus the speed of convergence, comes into play then?
Good question. Without doing any math, I would guess that the speed should probably be the same, since in both cases one only multiplies by A anyway. The convergence speed is technically determined by the relative gap, so one should be able to write it down and compare. Too lazy to do it myself at the moment :=)
@lobpcg you are right. The relative gap size should be the same to first order.
On a much more mundane level, I just wanted to point out that …

Script:

```python
import jax.numpy as jnp
from jax.experimental.sparse.linalg import lobpcg_standard

for dtype in ("float32", "complex64"):
    matrix = jnp.zeros((6, 6), dtype=dtype).at[0, 0].set(1)
    init_vec = jnp.ones((6, 1), dtype=dtype)
    output = lobpcg_standard(matrix, init_vec)
    print(f"LOBPCG works with {dtype}")
```

Output: …
@jemisjoky good catch! Are you interested in doing a pull request with the fix? If not, I can look into it.
@jakevdp, I'd love to! I honestly haven't done a PR with a major project like JAX before, but the contributor guide is pretty clear here. Will let you know if I have any questions in the process!
Sounds good, thanks!
Hi, apologies for the bump. Is there any news on the implementation of sparse eigensolvers for smallest eigenvalues? Currently, …
Just run LOBPCG for -A, where A is the matrix of the eigenvalue problem, and multiply the largest eigenvalues of -A by -1.
I agree that running LOBPCG is probably your best choice. If by smallest you mean smallest in magnitude, you can try running LOBPCG on a functor that applies the inverse of A to a vector, if that is feasible. Notice that we recently added a mechanism … The …
The inverse is typically infeasible, although it may give much faster convergence in theory. A function multiplying by -A is always feasible. That's the approach used, e.g., in LOBPCG's scipy implementation. Contrary to a popular misunderstanding, LOBPCG does not require the matrix to be positive definite.
Hi!
Our team is currently developing sparse symmetric and non-symmetric solvers in JAX (implicitly restarted Arnoldi and Lanczos). I have heard that there are efforts to support this natively in JAX, and wanted to ask what the status is. It would be great if JAX provided these.