
Handle degenerate eigenvalues in JVP rule for eigh #1665

Draft · wants to merge 4 commits into main
Conversation

@shoyer (Collaborator) commented Nov 11, 2019

Fixes #669

Googlers: please ping me for a copy of the referenced paper.

onp.testing.assert_allclose(abs(v), onp.ones((2, 2)) / onp.sqrt(2))
onp.testing.assert_allclose(w, onp.ones((2,)))
onp.testing.assert_allclose(abs(dv), onp.zeros((2, 2)))
onp.testing.assert_allclose(dw, onp.array([-1, 1]))
@shoyer (Collaborator, Author) commented on the diff above:
TODO: more test cases. Currently nothing checks eigenvector derivatives in the degenerate case when they are non-zero.
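As an aside (not from the PR), the values asserted in the excerpt above are consistent with perturbing the 2x2 identity along a symmetric off-diagonal tangent; that specific input is my assumption. A quick finite-difference check with plain numpy:

```python
# Hedged sketch: assume a = I (2x2) with tangent da = [[0, 1], [1, 0]]
# (an assumption, not necessarily the PR's test). Then eigh(a + t*da) has
# eigenvalues [1 - t, 1 + t] and eigenvectors [1, -1]/sqrt(2), [1, 1]/sqrt(2)
# for every t > 0, so abs(v) == ones((2, 2))/sqrt(2), w == ones(2) at t = 0,
# dw == [-1, 1], and the eigenvector derivative is zero, matching abs(dv) == 0.
import numpy as onp

a = onp.eye(2)
da = onp.array([[0.0, 1.0], [1.0, 0.0]])
t = 1e-6

w0 = onp.linalg.eigvalsh(a)                  # [1., 1.]
w_t, v_t = onp.linalg.eigh(a + t * da)

dw = (w_t - w0) / t                          # forward difference, approx. [-1, 1]
print(onp.abs(v_t))                          # approx. ones((2, 2)) / sqrt(2)
print(dw)
```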

# TODO(shoyer): consider rewriting with an explicit loop over degenerate
# subspaces instead?
v2 = dot(v, A)
w2 = np.einsum('ij,jk,ki->i', _H(v2), a_sym, v2).real
@shoyer (Collaborator, Author) commented on the diff above:
TODO: handle batching here.
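Purely as a sketch of what handling batching could look like (my illustration, not the PR diff; `projected_eigvals` is a hypothetical name): the same contraction with leading batch dimensions handled via `'...'`, matching the `'...ij,...jk,...ki->...i'` signature that appears in the code quoted later in this thread.

```python
import jax.numpy as np

def _H(x):
    return np.conj(np.swapaxes(x, -1, -2))

def projected_eigvals(a_sym, v2):
    # diagonal of v2^H @ a_sym @ v2 for each matrix in a (possibly batched) stack
    return np.einsum('...ij,...jk,...ki->...i', _H(v2), a_sym, v2).real
```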

@shoyer (Collaborator, Author) commented Nov 11, 2019

@sethaxen's test case from GH-669 currently fails:

import numpy as onp
import jax.numpy as np
from jax import grad

def test(x):
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(vec))

grad_test = grad(test)

x = onp.eye(3, dtype=onp.double)

print(test(x))
print(grad_test(x))

The error is:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-1-2a12f7a482f8> in <module>
     12 
     13 print(test(x))
---> 14 print(grad_test(x))

~/open-source/jax/jax/api.py in grad_f(*args, **kwargs)
    341   @wraps(fun, docstr=docstr, argnums=argnums)
    342   def grad_f(*args, **kwargs):
--> 343     _, g = value_and_grad_f(*args, **kwargs)
    344     return g
    345 

~/open-source/jax/jax/api.py in value_and_grad_f(*args, **kwargs)
    389     f_partial, dyn_args = _argnums_partial(f, argnums, args)
    390     if not has_aux:
--> 391       ans, vjp_py = vjp(f_partial, *dyn_args)
    392     else:
    393       ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)

~/open-source/jax/jax/api.py in vjp(fun, *primals, **kwargs)
   1147   if not has_aux:
   1148     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1149     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1150     out_tree = out_tree()
   1151   else:

~/open-source/jax/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    105 def vjp(traceable, primals, has_aux=False):
    106   if not has_aux:
--> 107     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    108   else:
    109     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/open-source/jax/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     97   pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
     98   aval_primals, const_primals = unzip2(pval_primals)
---> 99   assert all(aval_primal is None for aval_primal in aval_primals)
    100   if not has_aux:
    101     return const_primals, pval_tangents, jaxpr, consts

AssertionError: 

I'm guessing this has something to do with how the primal depends on the tangent value...

@sethaxen commented:
I haven't had a chance to look at these changes yet, but just in case it's useful, I recently added FluxML/Zygote.jl#355 to Zygote to handle the degeneracy cases in eigh-based functions of Hermitian matrices.

@shoyer (Collaborator, Author) commented Nov 12, 2019

Interesting, thanks for the tip! Could you perhaps point to a specific example of how you handle degenerate eigenvalues?

@sethaxen commented:
Sure. The PR was a bit complex due to supporting so many functions. The core function is _apply_series_fun here: https://github.com/FluxML/Zygote.jl/pull/355/files#diff-32dcbbbe9b541dd76dce481e5ff1e2f1R573-R588

Degeneracy is handled within the function _pairdiffquot (https://github.com/FluxML/Zygote.jl/pull/355/files#diff-32dcbbbe9b541dd76dce481e5ff1e2f1R460-R471). This computes the elements P_{ij} = \frac{f(x_i) - f(x_j)}{x_i - x_j} of a matrix P. As eigenvalues become degenerate, these quotients approach the first derivative f'(x_i), which is in fact what the diagonal entries are. When \Delta x for an element is less than machine epsilon, the off-diagonal entry is expanded around \Delta x \to 0, which introduces the first and second derivatives. In the absence of second derivatives, we expand those as well and reuse the first derivatives, which gives a better approximation. I can provide rough derivations if useful.

The bounds here aren't rigorously worked out; I expect they'll break down for very large or very small f(x) or if the derivatives are not roughly on the same order of magnitude.
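For readers following along, here is a rough sketch of that idea (mine, not the Zygote code; `pairdiffquot` and `tol` are just illustrative names), keeping only the first-derivative fallback and omitting the second-order terms described above:

```python
# P_ij = (f(x_i) - f(x_j)) / (x_i - x_j), falling back to the first derivative
# f'((x_i + x_j) / 2) wherever |x_i - x_j| drops below a tolerance.
import jax.numpy as np

def pairdiffquot(f, df, x, tol=1e-10):
  xi = x[:, np.newaxis]
  xj = x[np.newaxis, :]
  dx = xi - xj
  close = np.abs(dx) < tol
  safe_dx = np.where(close, 1.0, dx)            # avoid 0/0 in the discarded branch
  quot = (f(xi) - f(xj)) / safe_dx
  return np.where(close, df((xi + xj) / 2), quot)

# Example with f = exp (so f' = exp): the diagonal and the degenerate pair
# both pick up the derivative value instead of a 0/0 quotient.
print(pairdiffquot(np.exp, np.exp, np.array([0.0, 0.0, 1.0])))
```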

@shoyer (Collaborator, Author) commented Nov 12, 2019

@sethaxen thanks for the pointers, this is very helpful.

I'm still wrapping my head around the cases where backwards mode auto-diff of degenerate eigenvector problems is valid and how to fit that into JAX's auto-diff system (which incidentally is quite different from Zygote's, because JAX defines the backwards auto-diff indirectly via transposing the forward pass).

It's becoming clear to me that backwards mode auto-diff of degenerate eigenvector problems is not well defined in general for arbitrary loss functions, but you can do it if you have a uniquely defined function of the eigenvectors that doesn't depend on the particular choice of basis within degenerate subspaces. Your PR is a good example of this, because exp, cos, sin, pow, etc can all be defined as power series.
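To make the basis-independence point concrete, here is a small standalone illustration (mine, not from the PR; `matfun` and `r` are just names for this sketch): a power-series-style matrix function built from the eigendecomposition gives the same result for any orthonormal basis of a degenerate subspace.

```python
import numpy as onp

a = onp.diag([1.0, 1.0, 2.0])        # eigenvalue 1.0 is degenerate
w, v = onp.linalg.eigh(a)

# Rotate the basis of the two-dimensional degenerate subspace.
theta = 0.3
r = onp.eye(3)
r[:2, :2] = [[onp.cos(theta), -onp.sin(theta)],
             [onp.sin(theta),  onp.cos(theta)]]
v_rot = v @ r

def matfun(w, v, g=onp.exp):
  # f(A) = V diag(g(w)) V^H, a function g applied to the eigenvalues
  return v @ onp.diag(g(w)) @ v.conj().T

print(onp.allclose(matfun(w, v), matfun(w, v_rot)))   # True: basis-independent
```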

@sethaxen commented:
> @sethaxen thanks for the pointers, this is very helpful.

No problem! Let me know if you'd like any more info or would like another set of eyes on some math. I haven't studied JAX's approach, but I imagine there might be some overlap in some of the tricks even if they're using different modes.

> It's becoming clear to me that backwards mode auto-diff of degenerate eigenvector problems is not well defined in general for arbitrary loss functions, but you can do it if you have a uniquely defined function of the eigenvectors that doesn't depend on the particular choice of basis within degenerate subspaces. Your PR is a good example of this, because exp, cos, sin, pow, etc can all be defined as power series.

Yes, that's right: that approach supports exactly those functions that can be expressed as a power series on the eigenvectors. And even if the backwards autodiff of the eigendecomposition is undefined for degenerate matrices, it can still be defined for these functions, which is pretty cool.

(Resolved review comment on jax/lax_linalg.py; outdated.)
@shoyer (Collaborator, Author) commented Nov 17, 2020

For future reference: The strategy used in this PR basically works, but only for forward-mode (JVPs). In the case of degeneracies, the resulting computational graph uses non-linear operations in the JVP calculation, and thus cannot be transposed to calculate a VJP. This makes it useless for backward-mode autodiff, which was my intended use-case.

I'm not sure it's possible to do backwards-mode autodiff in the presence of degeneracies. Fundamentally, the problem is that the choice of basis for the primal eigenvalue calculation needs to depend on the pullbacks.
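A small illustration of the obstruction (my sketch, not code from this PR): in the degenerate branch the eigenvalue tangents come from an eigh of the projected input tangent, and that map is nonlinear in the tangent, whereas transposition requires a linear tangent map.

```python
import numpy as onp

# At the 2x2 identity the eigenvectors are just the standard basis.
v = onp.eye(2)

# Sketch of the tangent -> eigenvalue-tangent map in the degenerate branch.
tangent_map = lambda da: onp.linalg.eigvalsh(v.T @ da @ v)

da1 = onp.array([[0.0, 1.0], [1.0, 0.0]])
da2 = onp.array([[1.0, 0.0], [0.0, -1.0]])

print(tangent_map(da1 + da2))                  # [-sqrt(2), sqrt(2)]
print(tangent_map(da1) + tangent_map(da2))     # [-2., 2.] -- not the same: nonlinear
```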

@proteneer (Contributor) commented Nov 18, 2020

Thank you so much @shoyer for writing this. I've been doing some testing on the effect of increasing the "epsilon" value that determines same_subspace. The default value epsilon = 10 * np.finfo(a.dtype).resolution is equivalent to 1e-14 for double precision and 1e-5 for single precision. I know it's probably generally terrible to override machine epsilon, but it may be okay in this case. Although the calculation itself is in 64-bit, in practice my input tensor only ever has around 7 significant digits of precision, so for all practical purposes on my end I consider machine precision to be that of single precision. I'd also be happy to make this trade-off in exchange for stabilizing my simulation.

In general though, I don't advise doing this.
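As a quick cross-check on those numbers (an aside of mine, not part of the code below):

```python
# 10 * resolution for the two floating-point dtypes mentioned above.
import numpy as onp
print(10 * onp.finfo(onp.float64).resolution)   # ~1e-14
print(10 * onp.finfo(onp.float32).resolution)   # ~1e-05
```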

from jax.config import config; config.update("jax_enable_x64", True)
import numpy as onp
import jax.numpy as np
import jax
import functools

def _T(x): return np.swapaxes(x, -1, -2)
def _H(x): return np.conj(_T(x))
def symmetrize(x): return (x + _H(x)) / 2

# from shoyer's PR
def eigh_jvp(a, a_tangent, w, v, epsilon):

  a_dot = a_tangent
  a_sym = symmetrize(a)
  w = w.astype(a.dtype)
  dot = np.dot
  vdag_adot = dot(_H(v), a_dot)
  vdag_adot_v = dot(vdag_adot, v)

  deltas = w[..., np.newaxis, :] - w[..., np.newaxis]
  handle_degeneracies = True
  same_subspace = (abs(deltas) < epsilon
                   if handle_degeneracies
                   else np.eye(a.shape[-1], dtype=bool))

  if handle_degeneracies:
    w_dot, v_dot = np.linalg.eigh(vdag_adot_v * same_subspace)
    # Reorder these into sorted order of the original eigenvalues.
    # TODO(shoyer): consider rewriting with an explicit loop over degenerate
    # subspaces instead?
    v2 = dot(v, v_dot)
    w2 = np.einsum('...ij,...jk,...ki->...i', _H(v2), a_sym, v2).real
    order = np.argsort(w2, axis=-1)
    v = np.take_along_axis(v2, order[..., np.newaxis, :], axis=-1)
    dw = np.take_along_axis(w_dot, order, axis=-1)
    deltas = w[..., np.newaxis, :] - w[..., np.newaxis]
    same_subspace = abs(deltas) < epsilon
  else:
    dw = np.diagonal(vdag_adot_v, axis1=-2, axis2=-1)

  Fmat = np.where(same_subspace, 0.0, 1.0 / deltas)
  C = Fmat * vdag_adot_v
  dv = dot(v, C)
  return dw, dv

def test():

    tensor = onp.array([
        [ 0.06683436, -0.05386403,  0.0127573 ],
        [-0.05386403,  0.10799386,  0.00877524],
        [ 0.0127573 ,  0.00877524,  0.14300045]]
    )

    w, v = np.linalg.eigh(tensor)

    # print(w)
    # print(v)

    for i in range(3):
        for j in range(3):

            tensor_tangent = onp.array([
                [0,0,0],
                [0,0,0],
                [0,0,0]
            ])

            tensor_tangent[i][j] = 1

            dw, dv = eigh_jvp(tensor, tensor_tangent, w, v, epsilon=1e-4)
            print(dv)
            dw, dv = eigh_jvp(tensor, tensor_tangent, w, v, epsilon=1e-7)
            print(dv)

test()

Successfully merging this pull request may close: Support autodiff of Eigendecomposition with repeated eigenvalues (#669)

4 participants