Handle degenerate eigenvalues in JVP rule for eigh #1665
base: main
Conversation
onp.testing.assert_allclose(abs(v), onp.ones((2, 2)) / onp.sqrt(2))
onp.testing.assert_allclose(w, onp.ones((2,)))
onp.testing.assert_allclose(abs(dv), onp.zeros((2, 2)))
onp.testing.assert_allclose(dw, onp.array([-1, 1]))
TODO: more test cases. Currently nothing checks eigenvector derivatives in the degenerate case when they are non-zero.
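For example (a hedged sketch of such a test, not part of this PR — it assumes the JVP rule from this branch, and only asserts on the eigenvector column whose derivative is independent of the basis chosen inside the degenerate subspace):

```python
from jax.config import config; config.update("jax_enable_x64", True)
import numpy as onp
import jax
import jax.numpy as np

# diag(1, 1, 2): the first two eigenvalues are degenerate.
a = onp.diag(onp.array([1.0, 1.0, 2.0]))
# A tangent that couples the degenerate subspace to the isolated eigenvalue.
a_dot = onp.zeros((3, 3))
a_dot[0, 2] = a_dot[2, 0] = 1.0

(w, v), (dw, dv) = jax.jvp(np.linalg.eigh, (a,), (a_dot,))

# Eigenvalues of [[1, 0, t], [0, 1, 0], [t, 0, 2]] are 1, 1 - t**2 + O(t**4)
# and 2 + t**2 + O(t**4), so all first derivatives are zero.
onp.testing.assert_allclose(dw, onp.zeros(3), atol=1e-10)
# The eigenvector of the isolated eigenvalue 2 rotates towards e_0 at rate
# 1 / (2 - 1) = 1; this column of dv does not depend on the basis choice.
onp.testing.assert_allclose(abs(dv[:, -1]), onp.array([1.0, 0.0, 0.0]), atol=1e-10)
```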
jax/lax_linalg.py (outdated)
# TODO(shoyer): consider rewriting with an explicit loop over degenerate
# subspaces instead?
v2 = dot(v, A)
w2 = np.einsum('ij,jk,ki->i', _H(v2), a_sym, v2).real
TODO: handle batching here.
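For reference, a self-contained sketch of what the batched reordering could look like (just the `...`-subscript pattern, the same one used in the snippet pasted further down this thread; variable names are placeholders):

```python
import numpy as onp

# Toy batch of symmetric matrices and candidate eigenvector bases.
batch, n = 4, 3
rng = onp.random.RandomState(0)
a_sym = rng.randn(batch, n, n)
a_sym = (a_sym + onp.swapaxes(a_sym, -1, -2)) / 2
_, v2 = onp.linalg.eigh(a_sym)

def _H(x):
    return onp.conj(onp.swapaxes(x, -1, -2))

# Rayleigh quotient of each column of v2, broadcast over leading batch dims.
w2 = onp.einsum('...ij,...jk,...ki->...i', _H(v2), a_sym, v2).real
order = onp.argsort(w2, axis=-1)
# Reorder eigenvector columns (and eigenvalues) per batch element.
v_sorted = onp.take_along_axis(v2, order[..., onp.newaxis, :], axis=-1)
w_sorted = onp.take_along_axis(w2, order, axis=-1)
```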
@sethaxen's test case from GH-669 currently fails:

```python
import numpy as onp
import jax.numpy as np
from jax import grad

def test(x):
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(vec))

grad_test = grad(test)
x = onp.eye(3, dtype=onp.double)
print(test(x))
print(grad_test(x))
```

The error is:

I'm guessing this has something to do with how the primal depends on the tangent value...
I haven't had a chance to look at these changes yet, but just in case it's useful, I recently added FluxML/Zygote.jl#355 to Zygote to handle the degeneracy cases in eigh-based functions of Hermitian matrices.
Interesting, thanks for the tip! Could you perhaps point to a specific example of how you handle degenerate eigenvalues?
Sure. The PR was a bit complex due to supporting so many functions. The core function is […]. Degeneracy is handled within the function […]. The bounds here aren't rigorously worked out; I expect they'll break down for very large or very small […].
@sethaxen thanks for the pointers, this is very helpful. I'm still wrapping my head around the cases where backwards-mode auto-diff of degenerate eigenvector problems is valid and how to fit that into JAX's auto-diff system (which incidentally is quite different from Zygote's, because JAX defines the backwards auto-diff indirectly via transposing the forward pass). It's becoming clear to me that backwards-mode auto-diff of degenerate eigenvector problems is not well defined in general for arbitrary loss functions, but you can do it if you have a uniquely defined function of the eigenvectors that doesn't depend on the particular choice of basis within degenerate subspaces. Your PR is a good example of this, because […].
No problem! Let me know if you'd like any more info or would like another set of eyes on some math. I haven't studied JAX's approach, but I imagine there might be some overlap in some of the tricks even if they're using different modes.
Yes, that's right: that approach only supports exactly those functions that can be expressed with a power series on the eigenvectors. And even though the backwards autodiff of the eigendecomposition is undefined for degenerate matrices, it can still be defined for these functions, which is pretty cool.
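To make that concrete, here is a small numpy illustration (my own example, not from either PR): a power-series-style function assembled as V f(w) V^H cannot see which basis eigh picked inside a degenerate subspace, whereas something like sum(vec) can.

```python
import numpy as onp

w = onp.array([1.0, 1.0, 2.0])   # degenerate pair
v = onp.eye(3)                   # one valid eigenbasis for diag(w)
# Any rotation within the degenerate subspace gives another valid eigenbasis.
theta = 0.3
R = onp.array([[onp.cos(theta), -onp.sin(theta), 0.0],
               [onp.sin(theta),  onp.cos(theta), 0.0],
               [0.0,             0.0,            1.0]])
v_rot = v @ R

def expm_from_eigh(w, v):
    # f(A) = V diag(exp(w)) V^H depends on A alone, not on the basis choice.
    return v @ onp.diag(onp.exp(w)) @ v.conj().T

print(onp.allclose(expm_from_eigh(w, v), expm_from_eigh(w, v_rot)))  # True
print(onp.isclose(v.sum(), v_rot.sum()))                             # False
```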
For future reference: The strategy used in this PR basically works, but only for forward mode (JVPs). In the case of degeneracies, the resulting computational graph uses non-linear operations in the JVP calculation, and thus cannot be transposed to calculate a VJP. This makes it useless for backward-mode autodiff, which was my intended use case. I'm not sure it's possible to do backwards-mode autodiff in the presence of degeneracies. Fundamentally the problem is that the choice of basis for the primal eigenvalue calculation needs to depend on the pullbacks.
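A small sketch of that asymmetry, assuming the JVP rule from this branch (the example and function names are mine):

```python
from jax.config import config; config.update("jax_enable_x64", True)
import numpy as onp
import jax
import jax.numpy as np

a = onp.diag(onp.array([1.0, 1.0, 2.0]))   # degenerate pair of eigenvalues

def top_eigenvalue(a):
    w, _ = np.linalg.eigh(a)
    return w[-1]

# Forward mode builds the JVP graph directly, so the degenerate handling in
# this PR applies and the Jacobian can be computed one tangent at a time.
print(jax.jacfwd(top_eigenvalue)(a))
# Reverse mode would have to transpose that (non-linear) JVP graph, which is
# where this approach breaks down:
# print(jax.grad(top_eigenvalue)(a))   # not expected to work on this branch
```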
Thank you so much @shoyer for writing this. I've been doing some testing on the effects of increasing the "epsilon" value that determines same_subspace. The default value epsilon = 10 * np.finfo(a.dtype).resolution is equivalent to 1e-14 for double precision and 1e-5 for single precision. I know it's probably generally terrible to override machine epsilon, but it may be okay in this case. Although the calculation itself is in 64-bit, in practice my input tensor only ever has around 7 significant digits of precision, so for all practical purposes on my end I consider machine precision to be that of the single-precision equivalent. I'd also be happy to make this trade-off in exchange for stabilizing my simulation. In general, though, I don't advise doing this.

```python
from jax.config import config; config.update("jax_enable_x64", True)

import numpy as onp
import jax.numpy as np
import jax
import functools


def _T(x): return np.swapaxes(x, -1, -2)
def _H(x): return np.conj(_T(x))
def symmetrize(x): return (x + _H(x)) / 2


# from shoyer
def eigh_jvp(a, a_tangent, w, v, epsilon):
    a_dot = a_tangent
    a_sym = symmetrize(a)
    w = w.astype(a.dtype)
    dot = np.dot
    vdag_adot = dot(_H(v), a_dot)
    vdag_adot_v = dot(vdag_adot, v)
    deltas = w[..., np.newaxis, :] - w[..., np.newaxis]
    handle_degeneracies = True
    same_subspace = (abs(deltas) < epsilon
                     if handle_degeneracies
                     else np.eye(a.shape[-1], dtype=bool))
    if handle_degeneracies:
        w_dot, v_dot = np.linalg.eigh(vdag_adot_v * same_subspace)
        # Reorder these into sorted order of the original eigenvalues.
        # TODO(shoyer): consider rewriting with an explicit loop over degenerate
        # subspaces instead?
        v2 = dot(v, v_dot)
        w2 = np.einsum('...ij,...jk,...ki->...i', _H(v2), a_sym, v2).real
        order = np.argsort(w2, axis=-1)
        v = np.take_along_axis(v2, order[..., np.newaxis, :], axis=-1)
        dw = np.take_along_axis(w_dot, order, axis=-1)
        deltas = w[..., np.newaxis, :] - w[..., np.newaxis]
        same_subspace = abs(deltas) < epsilon
    else:
        dw = np.diagonal(vdag_adot_v, axis1=-2, axis2=-1)
    Fmat = np.where(same_subspace, 0.0, 1.0 / deltas)
    C = Fmat * vdag_adot_v
    dv = dot(v, C)
    return dw, dv


def test():
    tensor = onp.array([
        [ 0.06683436, -0.05386403,  0.0127573 ],
        [-0.05386403,  0.10799386,  0.00877524],
        [ 0.0127573 ,  0.00877524,  0.14300045]])
    w, v = np.linalg.eigh(tensor)
    # print(w)
    # print(v)
    for i in range(3):
        for j in range(3):
            tensor_tangent = onp.array([
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0]])
            tensor_tangent[i][j] = 1
            dw, dv = eigh_jvp(tensor, tensor_tangent, w, v, epsilon=1e-4)
            print(dv)
            dw, dv = eigh_jvp(tensor, tensor_tangent, w, v, epsilon=1e-7)
            print(dv)


test()
```
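(For reference, the two default thresholds quoted above check out directly:)

```python
import numpy as onp
print(10 * onp.finfo(onp.float64).resolution)   # ~1e-14
print(10 * onp.finfo(onp.float32).resolution)   # ~1e-05
```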
Fixes #669
Googlers: please ping me for a copy of the referenced paper.