
Experimental thick-restart Lanczos method for sparse eigenvalue problems #3114

Closed

wants to merge 2 commits
Conversation

@shoyer (Collaborator) commented May 15, 2020

I'm not sure if it actually makes sense to merge this code. It probably needs some additional work/testing to make it more robust, but maybe `jax.experimental` would be a suitable incubator? Either way, hopefully it's a useful point of reference.

xref #3112

@simbilod commented

Very nice. I am interested in automatic differentiation of sparse eigensolvers, but I'm new to JAX -- did you already have something in mind for this? My only point of reference is this recent paper (https://arxiv.org/abs/2001.04121), where they reduce computing the adjoint to solving a low-rank linear problem. Maybe your new CG implementation could be used to that end?

@shoyer (Collaborator, Author) commented May 21, 2020

Indeed, if the eigenvalues are unique, you can calculate eigenvalue/eigenvector derivatives in a straightforward fashion. That paper appears to have rediscovered a very old method for efficiently calculating eigenvector derivatives when not all eigenvectors are known; see, e.g., this 1976 paper: https://arc.aiaa.org/doi/abs/10.2514/3.7211

There's actually a rather large literature on this problem from the field of computational mechanics. There are some other formulations that are supposed to be more numerically stable, e.g., so you don't need to solve a rank-deficient system of linear equations: https://onlinelibrary.wiley.com/doi/abs/10.1002/cnm.895. I suspect the resulting linear systems are not necessarily positive definite, so we would need more sophisticated iterative solvers such as (L)GMRES.

The challenge is how to define derivatives in the case of degenerate eigenvalues. In that case the eigenvalue problem isn't a well-defined function, so calculating derivatives is quite challenging (maybe even impossible in reverse mode). But we probably don't need to worry about this for a first-pass solution.
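For concreteness, here is a minimal sketch of the perturbation formulas mentioned above, for a dense symmetric matrix. The helper name `eigenpair_jvp` is invented for illustration, and it assumes a simple (non-degenerate), unit-normalized eigenpair:

```python
import jax.numpy as jnp

def eigenpair_jvp(A, dA, lam, v):
  # Textbook first-order perturbation of a simple eigenpair of symmetric A:
  # dlam = v^T dA v for a unit-norm eigenvector v.
  dlam = v @ (dA @ v)
  # dv solves the singular system (A - lam*I) dv = -(dA - dlam*I) v; the
  # minimum-norm least-squares solution is the one orthogonal to v.
  rhs = -(dA @ v - dlam * v)
  dv = jnp.linalg.lstsq(A - lam * jnp.eye(A.shape[0]), rhs)[0]
  return dlam, dv

# Example usage on a tiny symmetric matrix and symmetric perturbation:
A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
dA = jnp.array([[0.1, 0.0], [0.0, -0.1]])
w, V = jnp.linalg.eigh(A)
dlam, dv = eigenpair_jvp(A, dA, w[0], V[:, 0])
```

A sparse version would replace the dense `lstsq` with an iterative solve, which is where the positive-definiteness concern above comes in.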

@jakevdp self-assigned this May 28, 2020
def _lanczos_restart(A, k, m, Q, alpha, beta):
"""The inner loop of the restarted Lanzcos method."""

def body_fun(i, carray):
Review comment (Collaborator):
nit: change to `carry` to match docs?

# With JAX, it makes even more sense to do full orthogonalization, since
# making use of matrix-multiplication for orthogonalization requires a
# statically sized set of Lanczos vectors.
Q_valid = (i >= jnp.arange(m + 1)) * Q
Review comment (Collaborator):
Maybe better/clearer to use `Q_valid = Q.at[:, :i].set(0)`? Also, perhaps add a TODO to use masking when available.
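To make the masking trick above concrete, a small self-contained sketch of the masked full-orthogonalization step (the helper name `full_orthogonalize` is invented; shapes follow my reading of the snippet, with `Q` of shape `(n, m + 1)`):

```python
import jax.numpy as jnp

def full_orthogonalize(r, Q, i, m):
  # Keep only the Lanczos vectors computed so far: the boolean mask
  # (i >= jnp.arange(m + 1)) broadcasts over rows and zeroes columns j > i.
  Q_valid = (i >= jnp.arange(m + 1)) * Q
  # Project the valid basis out of r with a single matrix multiplication
  # (classical Gram-Schmidt against all valid vectors at once).
  return r - Q_valid @ (Q_valid.T @ r)
```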

return r, q


def _lanczos_restart(A, k, m, Q, alpha, beta):
Review comment (Collaborator):
A comment describing the meaning of each of these variables would be helpful.
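Something along these lines, perhaps; this is my reading of the standard thick-restart scheme, and the exact roles should be checked against the implementation:

```python
def _lanczos_restart(A, k, m, Q, alpha, beta):
  """The inner loop of the restarted Lanczos method.

  Args (assumed meanings, to be verified):
    A: matvec callable (or matrix) for the symmetric linear operator.
    k: number of retained Ritz vectors carried across restarts.
    m: maximum Krylov subspace dimension per restart cycle.
    Q: array of shape (n, m + 1) whose columns are the basis vectors.
    alpha: diagonal entries of the tridiagonal projection T.
    beta: off-diagonal entries of T (the last is the residual norm).
  """
```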

tolerance=None,
return_info=False,
):
"""Find the `num_desired` smallest eigenvalues of the linear map A.
Review comment (Collaborator):

My experience with Lanczos solvers is that they perform poorly for finding the smallest eigenvalues except in shift-invert mode, in which the smallest eigenvalues are transformed to the largest.

Have you tried modifying this to find the largest eigenvalues?
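For reference, a rough sketch of what shift-invert could look like here, reusing the CG solver for the inner linear solves. `shift_invert` is a hypothetical helper; CG assumes the shifted operator is positive definite, and otherwise something like GMRES would be needed:

```python
from jax.scipy.sparse.linalg import cg

def shift_invert(matvec, sigma):
  # Eigenvalues of A closest to sigma become the largest-magnitude
  # eigenvalues of (A - sigma*I)^{-1}, where Lanczos converges fastest.
  def inv_matvec(b):
    x, _ = cg(lambda v: matvec(v) - sigma * v, b)
    return x
  return inv_matvec
```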

[-1, -2, 3, -3],
[0, 0, -3, 4],
])
np.testing.assert_array_equal(actual, expected)
Review comment (Collaborator):
I think we have `self.assertArraysEqual` now

# note: error tolerances here (and below) are set based on how well we do
# in float32 precision
np.testing.assert_allclose(w_actual, self.w_expected, atol=1e-5)
np.testing.assert_allclose(abs(v_actual), abs(self.v_expected), atol=1e-4)
Review comment (Collaborator):
`self.assertArraysAllclose`, here and below


def body_fun(i, carray):
Q, alphas, betas = carray
q_prev = Q[:, i]
Review comment:
Hi! I noticed that Krylov vectors are stored in column-major format here. Would it make sense to switch to row-major? Wondering if, for large-dimensional vectors, zero-padding on TPUs could lead to an unnecessarily large memory footprint.
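For illustration, the two layouts being compared (sizes invented for the example):

```python
import jax.numpy as jnp

n, m = 1024, 20
Q_cols = jnp.zeros((n, m + 1))  # column-major: each Lanczos vector is a column
q0 = Q_cols[:, 0]
Q_rows = jnp.zeros((m + 1, n))  # row-major: each vector is a row, so the
q0 = Q_rows[0]                  # large dimension n is the trailing axis
```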

@jakevdp (Collaborator) commented Nov 3, 2023

This has become stale - I'm going to close. Feel free to re-open if you'd like to land this in `jax.scipy.sparse.linalg`.

@jakevdp closed this Nov 3, 2023