Support autodiff of Eigendecomposition with repeated eigenvalues #669
Comments
Might not be related, but even without jit compilation and complex inputs, gradient computation of a function of eigenvectors fails.

```python
import numpy as onp
import jax.numpy as np
from jax import grad

def test(x):
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(vec))

grad_test = grad(test)

x = onp.eye(3, dtype=onp.double)
print(test(x))
print(grad_test(x))
```
PR #670 fixes the first bug; we were incorrectly declaring that the eigenvalues of a complex matrix were complex, leading to a type error when under a `jit`. Not sure about the second bug yet.
I'm wondering if the second case is happening because one of the assumptions of the JVP rule for eigh is that the eigenvalues are distinct. In this case, all the eigenvalues are 1. This is at the limits of my linear algebra knowledge; @mattjj, do you have insights here?

See the discussion on the implementation here: https://github.com/google/jax/blob/master/jax/lax_linalg.py#L155

Wow, that's a nice comment! Thanks @levskaya. I don't have any insights. That comment taught me things. I think the case of repeated eigenvalues might come down to "contributions welcome".
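For reference, here is a rough sketch (mine, not the actual `lax_linalg` implementation) of the eigh JVP rule that comment derives. The eigenvector tangent involves factors `1 / (w[j] - w[i])`, which blow up when eigenvalues are repeated, as at the identity matrix:

```python
import jax.numpy as jnp

def eigh_jvp_sketch(a, da):
    # a: symmetric matrix, da: symmetric tangent direction.
    w, v = jnp.linalg.eigh(a)
    vt_da_v = v.T @ da @ v
    dw = jnp.diagonal(vt_da_v)      # eigenvalue tangents: always well-defined
    eij = w[None, :] - w[:, None]   # w[j] - w[i]
    fij = jnp.where(jnp.eye(a.shape[0], dtype=bool), 0.0, 1.0 / eij)
    dv = v @ (fij * vt_da_v)        # 1/(w[j]-w[i]) -> inf for repeated eigenvalues
    return (w, v), (dw, dv)

a = jnp.eye(3)                      # all eigenvalues equal to 1
_, (dw, dv) = eigh_jvp_sketch(a, jnp.ones((3, 3)))
print(dw)                           # finite
print(dv)                           # contains inf/nan
```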
Thanks @hawkinsp for the quick response and fix! Looks like repeated eigenvalues are probably the cause of the second bug:

```python
import numpy as onp
import jax.numpy as np
from jax import grad

def test(x):
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(vec))

grad_test = grad(test)

onp.random.seed(42)
x = onp.diag(
    onp.ones(3, dtype=onp.double) +
    onp.random.normal(0, 1e-6, size=3)
)
x2 = onp.diag(
    onp.ones(3, dtype=onp.double) +
    onp.random.normal(0, 1e-8, size=3)
)
print(grad_test(x))
print(grad_test(x2))
```
@sdaxen do you need autodiff of eigendecomposition with repeated eigenvalues? If not, we should probably close this issue until someone actually asks for it. (That way we can keep all the "enhancement" issues tracking things that users have specifically asked for.)

@mattjj no, it's not a priority. I do need repeated eigenvalues, but I'm only test driving jax for the moment while doing my main work with a different system. Feel free to close.

Thanks for the info! We're very interested to hear about JAX's shortcomings so that we can work to fix them, so if there's something about JAX that makes it unsuitable for your work, please let us know! @hawkinsp is it a better policy to close this issue, or leave it open and just keep in mind that we don't have any users specifically asking for it yet?

I think we should leave these kinds of issues open. It makes them more easily searchable should someone else have the same problem; I'd rather have one issue than two.

Cool, makes sense to me! At least we clarified how to prioritize it, then.

I also observe that TF has the same limitation: https://github.com/tensorflow/tensorflow/blob/f33aa592f92e233aeb00198d0caab80eaa89afe9/tensorflow/python/ops/linalg_grad.py#L314
Forgive my general lack of knowledge, I'm just beginning to look into jax/remind myself of some linear algebra. But on this issue:

```python
import numpy as onp
import jax.numpy as np
from jax import grad

def test(x):
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(vec))

grad_test = grad(test)
x = onp.eye(3, dtype=onp.double)
print(test(x))
print(grad_test(x))
```

Is the derivative here even well defined? It seems to me like in the degenerate case there is no unique eigenvector, so the np.real(np.sum(vec)) could be a whole range of possible values depending on your choice of basis, no? This is interesting stuff, I'd be curious how you go about learning more.
As you note, the general problem is quite tricky. To use physics parlance, there are two cases: you can have a matrix with degenerate eigenvalues where the perturbation (gradient direction) "breaks the symmetry" and causes the degenerate eigenvalues to split, and then you have the case where the perturbation preserves the degeneracy, which generally makes talking about eigenvector derivatives very tricky / ill-defined with simple approaches. Especially if you're dealing with the general complex case, where the eigenvector phase has additional freedom. There are a few papers that seem to offer general algorithmic approaches, but they're complicated enough that no one has sat down to try to implement them to see how they'd work:
The hard case is differentiating eigenvectors in the presence of degeneracies. Eigenvalue derivatives are still fine either way.

I believe my pull request #1665 actually has a working JVP (forward mode) gradient implementation for eigh with degeneracies, but it can't be transposed, which means it doesn't work for backward-mode differentiation. In general, I don't think it's possible to define backward-mode gradients of eigenvectors for arbitrary functions of degenerate eigenvectors -- the gradients simply don't always exist. I'll see if I can work up a good counter-example.

From a practical perspective, it seems like the better idea is to differentiate a higher-level function like a power series that always has well-defined derivatives. Typically a program that uses eigenvectors corresponding to degenerate eigenvalues is ultimately using the eigenvectors to calculate something like this anyway, because otherwise its output would depend on arbitrary choices made by the linear algebra library.

EDIT: to clarify, by "power series" I really mean "matrix valued function" here
@shoyer Just so I'm clear, you're saying that if we have a matrix A = P D P^-1, typically the reason you would want to do the eigendecomposition is so that you can evaluate a function f(A) by doing f(A) = P f(D) P^-1, which is independent of the arbitrary choice of eigenvector (and other uses would be out of scope)? If so, I'm still confused as to how the … Or are you talking about an entirely different application of power series that I'm unfamiliar with?
That's exactly right. I hypothesize that every real-world use case for calculating eigenvectors is using them in order to evaluate a matrix-valued function of some form.
The example in the first post was with eigenvalue derivatives. As noted in #669 (comment), it's been fixed.
It's probably worth noting that the example failure case for eigenvector derivatives from #669 (comment) is not a well-defined matrix-valued function:
E.g., suppose …
FWIW, I originally encountered this while playing around with the matrix exponential of Hermitian matrices, which is a power series function. For such functions that internally use the eigendecomposition, we can nevertheless write forward- and reverse-mode rules that almost completely account for the degeneracy of the eigendecomposition.
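As a concrete illustration of the matrix-valued-function point (my own sketch, not code from the thread): the same scalar can be computed either through the eigendecomposition or directly as a polynomial in the matrix, and only the second path gives a usable gradient at a degenerate point.

```python
import numpy as onp
import jax.numpy as np
from jax import grad

def via_eigh(x):
    # sum of the entries of x @ x, reconstructed from the eigendecomposition
    w, v = np.linalg.eigh(x)
    return np.sum(v @ np.diag(w ** 2) @ v.T)

def via_matmul(x):
    # the same matrix-valued function, evaluated without eigenvectors
    return np.sum(x @ x)

x = onp.eye(3, dtype=onp.double)  # repeated eigenvalues
print(grad(via_eigh)(x))          # expect inf/nan from the degenerate eigh rule
print(grad(via_matmul)(x))        # well-defined gradient (all entries 2.0)
```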
I recently stumbled across this paper, which seems to provide exactly the algorithm we need here. It looks like we could use the forward derivative for a JVP rule in JAX, which would suffice for auto-diff as long as we know how to implement and transpose a Sylvester solve (i.e., …).
Shouldn't be too hard. The JVP for a Sylvester solve is just another Sylvester solve: if `A X + X B = C`, then `A dX + dX B = dC - dA X - X dB`. But Sylvester solvers usually compute a Schur decomposition of `A` and `B` internally, which we'd want to reuse for the tangent solve rather than recompute.
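A quick numerical check of that identity using SciPy's solve_sylvester (my sketch, not code from the thread): the tangent of the solution of `A X + X B = C` is itself the solution of a Sylvester equation with the same `A` and `B`.

```python
import numpy as onp
from scipy.linalg import solve_sylvester

onp.random.seed(0)
n = 4
A, B, C = (onp.random.randn(n, n) for _ in range(3))
dA, dB, dC = (onp.random.randn(n, n) for _ in range(3))

# primal solve: A X + X B = C
X = solve_sylvester(A, B, C)

# differentiate A X + X B = C  =>  A dX + dX B = dC - dA X - X dB
dX = solve_sylvester(A, B, dC - dA @ X - X @ dB)

# finite-difference check
eps = 1e-6
X_eps = solve_sylvester(A + eps * dA, B + eps * dB, C + eps * dC)
print(onp.max(onp.abs((X_eps - X) / eps - dX)))  # should be small (O(eps))
```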
I spent some time going through the paper today, and as far as I can tell this approach only handles exactly degenerate matrices, not almost-degenerate matrices, and in the case of exactly degenerate matrices a Sylvester solver is not actually necessary and does not help things. In the paper's notation, for a standard eigenvalue problem, the usual JVP in Giles' paper and elsewhere is …

After simplification, their contribution with respect to degeneracy amounts to a simple modification to the matrix …

When …

But we can take this further by multiplying both sides of Sylvester's equation by …

This Sylvester equation appears in the standard derivation, the same as in Giles and elsewhere, and we can solve for … So as far as I can tell, this does not resolve the issue of degeneracy, especially since even if a matrix is constructed to have exactly equal eigenvalues, due to floating-point error the computed eigenvalues will usually be nonequal. All it does is ensure that, for certain programs, when a matrix with exact eigenvalues is factorized to produce exact eigenvalues, and when the tangent has a certain structure, 0/0 does not happen. That is useful, but I think it will only matter in very extreme and rare cases. One last point is that this modification applies to the standard eigendecomposition as well, not just to symmetric or real matrices.
I noticed that the author of the paper, @mfkasim1, is on GitHub and might be interested in this discussion.
Thanks for the tag. The same issue is also discussed in pytorch: pytorch/pytorch#47599 with PR pytorch/pytorch#50942. The algorithm in the paper basically only works if the loss function does not depend directly on the degenerate eigenvectors, but it can depend on the space spanned by the degenerate eigenvectors. I have tried the algorithm in my differentiable density functional theory (DFT) simulation (there are a lot of degenerate eigenvalues) and it works nicely (i.e. it passes pytorch's gradcheck and gradgradcheck).
@mfkasim1 thanks for joining us! I share @sethaxen's concern about almost-degenerate eigenvalues. In such cases, the standard auto-diff rules for …
@shoyer The case in my DFT application is that it is supposed to have exactly the same eigenvalues theoretically, but numerically the retrieved eigenvalues are only close to each other (near-degenerate). In my DFT case this is sufficient, because the loss function does not depend directly on the degenerate eigenvectors (it depends on the space spanned by the eigenvectors), so the numerator (eq 2.13) is supposed to be 0.
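To spell out why "depends on the space spanned by the eigenvectors" matters: within a degenerate eigenspace any rotation of the eigenvectors is equally valid, so a quantity like sum(vec) is basis-dependent, while the projector onto the eigenspace is not. A small numpy illustration (mine, not from the thread):

```python
import numpy as onp

# Two orthonormal eigenvectors spanning a degenerate eigenspace (e.g. of the identity).
v1 = onp.array([1.0, 0.0, 0.0])
v2 = onp.array([0.0, 1.0, 0.0])

# Any rotation within the eigenspace gives equally valid eigenvectors.
theta = 0.3
w1 = onp.cos(theta) * v1 + onp.sin(theta) * v2
w2 = -onp.sin(theta) * v1 + onp.cos(theta) * v2

print(v1.sum() + v2.sum(), w1.sum() + w2.sum())   # basis-dependent: values differ
print(onp.allclose(onp.outer(v1, v1) + onp.outer(v2, v2),
                   onp.outer(w1, w1) + onp.outer(w2, w2)))  # projector unchanged: True
```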
I think it is fine to restrict the autodiff rules for …
I came across this conversation and wanted to leave this note as a reference for near-degenerate eigenvectors, specifically the transformation of Eq. 10 into Eq. 11 to account for near-degeneracy: https://github.com/mitmath/18335/blob/master/notes/adjoint/eigenvalue-adjoint.pdf . Hopefully you find it useful, though some translation may be in order.
(posting a comment from last week that I thought I had already submitted!)

Yes, indeed, this does look pretty straightforward! (Note for Googlers: here's a chat thread that contains links to a prototype for a differentiable Sylvester solve in JAX, written in terms of SciPy operations.)

I think we can use the same trick we use for linear solves: compute the matrix factorization/decompositions first (without gradients), and then pass them into the auto-diff primitive. Technically, the auto-diff primitive becomes "Sylvester solve from a Schur decomposition" rather than "Sylvester solve from scratch".
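A minimal forward-mode sketch of that shape (my own code, not the prototype mentioned above; the names `sylvester_solve` and `_solve_sylvester_host` are made up). It calls SciPy on the host via `pure_callback`, and its JVP is simply another Sylvester solve with the same coefficients; a real implementation would pass the Schur factors of `A` and `B` into the primitive instead of refactorizing, and reverse mode would additionally need a transpose rule.

```python
import jax
import jax.numpy as jnp
import numpy as onp
import scipy.linalg

def _solve_sylvester_host(a, b, c):
    # Host-side Bartels-Stewart solve of a @ x + x @ b = c via SciPy.
    x = scipy.linalg.solve_sylvester(onp.asarray(a), onp.asarray(b), onp.asarray(c))
    return x.astype(c.dtype)

@jax.custom_jvp
def sylvester_solve(a, b, c):
    out_shape = jax.ShapeDtypeStruct(c.shape, c.dtype)
    return jax.pure_callback(_solve_sylvester_host, out_shape, a, b, c)

@sylvester_solve.defjvp
def sylvester_solve_jvp(primals, tangents):
    a, b, c = primals
    da, db, dc = tangents
    x = sylvester_solve(a, b, c)
    # a @ dx + dx @ b = dc - da @ x - x @ db: another Sylvester solve with the
    # same a and b, so their Schur factorizations could be reused here.
    dx = sylvester_solve(a, b, dc - da @ x - x @ db)
    return x, dx
```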
Hi! I was wondering if there are any plans to implement the Sylvester solver in JAX, perhaps based on the draft implementation you mentioned, or on jax.scipy.linalg.schur and jax.scipy.linalg.solve_triangular as in the Bartels-Stewart algorithm? I would be interested in an implementation. Any tips or status update would be appreciated!
Any progress on this? Encountering problems with almost degenerate eigenvalues...
Hi, there is a regularization technique to circumvent this degeneracy; you can refer to the references in the docstring. Code from my repo (https://github.com/kc-ml2/meent/blob/main/meent/on_jax/emsolver/primitives.py):

```python
import jax
import jax.numpy as jnp
from functools import partial


def conj(arr):
    return arr.real + arr.imag * -1j
    # return arr.conj()


@partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3))
def eig(x, type_complex=jnp.complex128, perturbation=1E-10, device='cpu'):
    _eig = jax.jit(jnp.linalg.eig, device=jax.devices('cpu')[0])

    eigenvalues_shape = jax.ShapeDtypeStruct(x.shape[:-1], type_complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(x.shape, type_complex)
    result_shape_dtype = (eigenvalues_shape, eigenvectors_shape)

    if device == 'cpu':
        res = _eig(x)
    else:
        res = jax.pure_callback(_eig, result_shape_dtype, x)
    return res


def eig_fwd(x, type_complex, perturbation, device):
    return eig(x, type_complex, perturbation), (eig(x, type_complex, perturbation), x)


def eig_bwd(type_complex, perturbation, device, res, g):
    """
    Gradient of a general square (complex valued) matrix
    Eq 2~5 in https://www.nature.com/articles/s42005-021-00568-6
    Eq 4.77 in https://arxiv.org/pdf/1701.00392.pdf
    Eq. 30~32 in https://www.sciencedirect.com/science/article/abs/pii/S0010465522002715
    https://github.com/kch3782/torcwa
    https://github.com/weiliangjinca/grcwa
    https://github.com/pytorch/pytorch/issues/41857
    https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation
    https://discuss.pytorch.org/t/autograd-on-complex-numbers/144687/3
    """
    (eig_val, eig_vector), x = res
    grad_eigval, grad_eigvec = g

    grad_eigval = jnp.diag(grad_eigval)
    W_H = eig_vector.T.conj()

    # Regularized reciprocal of the eigenvalue gaps: the perturbation keeps the
    # denominator away from zero when eigenvalues are (near-)degenerate.
    Fij = eig_val.reshape((1, -1)) - eig_val.reshape((-1, 1))
    Fij = Fij / (jnp.abs(Fij) ** 2 + perturbation)
    Fij = Fij.at[jnp.diag_indices_from(Fij)].set(0)

    # diag_indices = jnp.arange(len(eig_val))
    # Eij = eig_val.reshape((1, -1)) - eig_val.reshape((-1, 1))
    # Eij = Eij.at[diag_indices, diag_indices].set(1)
    # Fij = 1 / Eij
    # Fij = Fij.at[diag_indices, diag_indices].set(0)

    grad = jnp.linalg.inv(W_H) @ (grad_eigval.conj() + Fij * (W_H @ grad_eigvec.conj())) @ W_H
    grad = grad.conj()

    if not jnp.iscomplexobj(x):
        grad = grad.real

    return grad,


eig.defvjp(eig_fwd, eig_bwd)
```
This code works fine! BTW, I'm wondering whether there is a "jvp" version, since I'm working on a framework that only supports forward-mode autodiff.
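For what it's worth, a forward-mode analogue with the same perturbation trick might look roughly like the sketch below (my own code, not from the repo above). It uses the gauge choice that the diagonal of V^{-1} dV is zero, and jnp.linalg.eig itself only runs on CPU, so route it through pure_callback as above if needed.

```python
import jax
import jax.numpy as jnp

PERTURBATION = 1e-10  # plays the same role as `perturbation` in the VJP code above

@jax.custom_jvp
def eig(x):
    return jnp.linalg.eig(x)

@eig.defjvp
def eig_jvp(primals, tangents):
    # Standard forward-mode rule for a diagonalizable A = V diag(w) V^{-1}
    # (e.g. Giles 2008): with C = V^{-1} dA V,
    #   dw = diag(C),  dV = V (F * C),  F_ij = 1 / (w_j - w_i) for i != j,
    # using the same regularized reciprocal as the VJP above.
    (a,), (da,) = primals, tangents
    w, v = eig(a)
    c = jnp.linalg.solve(v, da.astype(v.dtype) @ v)   # C = V^{-1} dA V
    dw = jnp.diagonal(c)
    eij = w[None, :] - w[:, None]                     # w_j - w_i
    fij = eij.conj() / (jnp.abs(eij) ** 2 + PERTURBATION)
    fij = fij.at[jnp.diag_indices_from(fij)].set(0)
    dv = v @ (fij * c)
    return (w, v), (dw, dv)
```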
I am also running into issues with needing degenerate eigenvalues for the computation of molecular energies using wavefunction methods. Degenerate eigenvalues are quite common in quantum chemistry, as evidenced by Kasim's work on differentiable density functional theory and by the symmetry present in many molecules. Is the fact that backward-mode AD can't handle degenerate eigenvalues a hindrance to implementing the forward-mode JVP?
On v0.1.25 on OSX, I get the following error when computing gradients of the following jit-compiled function.

So far so good, but computing the gradient of the jit-compiled function with complex inputs raises an error.

JAX built from source produced the same error.