Experimental thick-restart Lanczos method for sparse eigenvalue problems #3114
Conversation
I'm not sure if it actually makes sense to merge this code. It probably needs some additional work/testing to make it more robust, but maybe `jax.experimental` would be a suitable incubator? Either way, hopefully it's a useful point of reference. xref GH3112
Very nice. I am interested in automatic differentiation of sparse eigensolvers, but new to JAX -- did you already have something in mind about this? My only point of reference is this recent paper (https://arxiv.org/abs/2001.04121), where they reduce computing the adjoint to solving a low-rank linear problem. Maybe your new CG implementation could be used to that end?
Indeed, if the eigenvalues are unique, you can calculate eigenvalue/eigenvector derivatives in a straightforward fashion. That paper appears to have rediscovered a very old method for efficiently calculating eigenvector derivatives when not all eigenvectors are known, e.g., see this 1976 paper: https://arc.aiaa.org/doi/abs/10.2514/3.7211. There's actually a rather large literature on this problem from the field of computational mechanics. There are some other formulations that are supposed to be more numerically stable, e.g., so you don't need to solve a non-full-rank system of linear equations: https://onlinelibrary.wiley.com/doi/abs/10.1002/cnm.895. I suspect the resulting linear equations are not necessarily positive definite, so we would need more sophisticated iterative solvers such as (L)GMRES.

The challenge is how to define derivatives in the case of degenerate eigenvalues. In that case the eigenvalue problem isn't even a well-defined function, so calculating derivatives is quite challenging (maybe even impossible in reverse mode). But we probably don't need to worry about this for a first-pass solution.
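For what it's worth, here's a rough sketch of what that adjoint rule could look like in JAX for a dense symmetric `A` with a simple (non-degenerate) smallest eigenpair `(lam, v)` and unit-normalized `v`. This is purely illustrative -- the names are made up, not code from this PR. It uses the projected linear solve discussed above, which is positive definite on the subspace orthogonal to `v` when `lam` is the smallest eigenvalue, so plain CG suffices:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def smallest_eigpair_vjp(A, lam, v, lam_bar, v_bar):
  """Cotangent of A given cotangents of its smallest eigenpair (lam, v)."""
  proj = lambda x: x - v * (v @ x)  # projection onto the complement of v
  # Solve (A - lam*I) w = P v_bar restricted to span{v}^perp. For the
  # smallest eigenvalue this restricted operator is positive definite.
  operator = lambda x: proj(A @ proj(x) - lam * proj(x))
  w, _ = cg(operator, proj(v_bar))
  # d(lam) = v^T (dA) v and dv = -(A - lam*I)^+ P (dA) v together give:
  A_bar = lam_bar * jnp.outer(v, v) - jnp.outer(w, v)
  return 0.5 * (A_bar + A_bar.T)  # symmetrize, treating A as symmetric
```

For degenerate eigenvalues the projected system becomes singular in additional directions and this breaks down, which is exactly the ill-posedness mentioned above.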
def _lanczos_restart(A, k, m, Q, alpha, beta):
  """The inner loop of the restarted Lanczos method."""

  def body_fun(i, carray):
nit: change to `carry` to match docs?
  # With JAX, it makes even more sense to do full orthogonalization, since
  # making use of matrix-multiplication for orthogonalization requires a
  # statically sized set of Lanczos vectors.
  Q_valid = (i >= jnp.arange(m + 1)) * Q
Maybe better/clearer to use `Q_valid = Q.at[:, :i].set(0)`? Also, perhaps add a TODO to use masking when available.
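To make the masking concrete, here's a hypothetical illustration (assumed names and shapes, not the PR's actual code) of full reorthogonalization against a statically shaped block of Lanczos vectors:

```python
import jax.numpy as jnp

def full_reorthogonalize(r, Q, i, m):
  # Zero out the not-yet-computed columns so that a single, statically
  # shaped matmul can orthogonalize r against all valid vectors at once.
  Q_valid = jnp.where(jnp.arange(m + 1) <= i, Q, 0.0)
  # One classical Gram-Schmidt step against the masked block.
  return r - Q_valid @ (Q_valid.T @ r)
```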
    return r, q


def _lanczos_restart(A, k, m, Q, alpha, beta):
A comment describing the meaning of each of these variables would be helpful.
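For example, something along these lines, based on standard thick-restart Lanczos conventions (the actual semantics in this PR may differ):

```python
def _lanczos_restart(A, k, m, Q, alpha, beta):
  """The inner loop of the restarted Lanczos method.

  Args:
    A: callable implementing the symmetric linear map ``x -> A(x)``.
    k: number of Ritz vectors retained ("locked") from the previous restart.
    m: dimension of the Krylov subspace built in each restart cycle.
    Q: array of shape (n, m + 1) whose columns are the Lanczos vectors.
    alpha: diagonal entries of the projected tridiagonal matrix.
    beta: off-diagonal entries of the projected tridiagonal matrix.
  """
```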
    tolerance=None,
    return_info=False,
):
  """Find the `num_desired` smallest eigenvalues of the linear map A.
My experience with Lanczos solvers is that they perform poorly for finding the smallest eigenvalues except in shift-invert mode, in which the smallest eigenvalues are transformed into the largest. Have you tried modifying this to find the largest eigenvalues?
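For reference, a minimal sketch of the shift-invert transformation (hypothetical names, assuming a dense symmetric `A` and a shift `sigma` below its spectrum so the inner system is positive definite; otherwise something like GMRES would be needed for the inner solve):

```python
from jax.scipy.sparse.linalg import cg

def shift_invert(A, sigma=0.0):
  """Return x -> (A - sigma*I)^{-1} x, applied via an inner CG solve."""
  def apply(x):
    y, _ = cg(lambda z: A @ z - sigma * z, x)
    return y
  return apply

# Ritz values theta of the inverted operator map back to eigenvalues of A
# via lam = sigma + 1/theta, so the largest theta correspond to the
# eigenvalues of A nearest the shift.
```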
        [-1, -2, 3, -3],
        [0, 0, -3, 4],
    ])
    np.testing.assert_array_equal(actual, expected)
I think we have `self.assertArraysEqual` now
    # note: error tolerances here (and below) are set based on how well we do
    # in float32 precision
    np.testing.assert_allclose(w_actual, self.w_expected, atol=1e-5)
    np.testing.assert_allclose(abs(v_actual), abs(self.v_expected), atol=1e-4)
`self.assertArraysAllclose`, here and below
  def body_fun(i, carray):
    Q, alphas, betas = carray
    q_prev = Q[:, i]
Hi! I noticed that the Krylov vectors are stored in column-major format here. Would it make sense to switch to row-major? I'm wondering if, for large-dimensional vectors, zero-padding on TPUs could lead to an unnecessarily large memory footprint.
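For concreteness, a sketch of the row-major alternative (a hypothetical helper, not code from this PR): storing one Krylov vector per row puts any padding along the small static dimension `m + 1` rather than the large vector dimension `n`:

```python
import jax.numpy as jnp

def append_krylov_vector(Q, i, q):
  # Q has shape (m + 1, n): Q[i] is the i-th Lanczos vector.
  return Q.at[i].set(q)
```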
This has become stale - I'm going to close. Feel free to re-open if you'd like to land this in `jax.experimental`.