Skip to content

Commit

Permalink
Merge pull request #358 from levskaya/master
Browse files Browse the repository at this point in the history
added jvp rule for eigh, tests
  • Loading branch information
mattjj committed Feb 14, 2019
2 parents b0db87b + cd22050 commit fc4c8bd
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 5 deletions.
30 changes: 29 additions & 1 deletion jax/lax_linalg.py
Expand Up @@ -39,7 +39,10 @@ def cholesky(x, symmetrize_input=True):
x = symmetrize(x)
return cholesky_p.bind(x)

# Earlier one-line form of `eigh`: binds `x` to `eigh_p` directly, forwarding
# `lower`, with no symmetrization step.
def eigh(x, lower=True): return eigh_p.bind(x, lower=lower)
def eigh(x, lower=True, symmetrize_input=True):
  """Bind `x` to the `eigh_p` primitive, optionally symmetrizing it first.

  Args:
    x: the operand handed to `eigh_p`.
    lower: forwarded unchanged to `eigh_p.bind`.
    symmetrize_input: when true, `x` is first passed through `symmetrize`.

  Returns:
    Whatever `eigh_p.bind` returns for the (possibly symmetrized) operand.
  """
  operand = symmetrize(x) if symmetrize_input else x
  return eigh_p.bind(operand, lower=lower)

def lu(x):
  """Bind `x` to the `lu_p` primitive and return the result."""
  return lu_p.bind(x)

Expand Down Expand Up @@ -146,10 +149,35 @@ def eigh_cpu_translation_rule(c, operand, lower):
raise NotImplementedError(
"Only unbatched eigendecomposition is implemented on CPU")

def eigh_jvp_rule(primals, tangents, lower):
  """JVP rule for `eigh_p` in the simplest case of distinct eigenvalues.

  Args:
    primals: one-element sequence holding the input matrix `a`.
    tangents: one-element sequence holding `a_dot`, the tangent of `a`.
    lower: forwarded unchanged to `eigh_p.bind`.

  Returns:
    A pair `(core.pack((v, w)), core.pack((dv, dw)))`: the outputs of `eigh_p`
    on the symmetrized input, followed by their tangents.
  """
  # This is classic nondegenerate perturbation theory, but also see
  # https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  # The general solution treating the case of degenerate eigenvalues is
  # considerably more complicated. Ambitious readers may refer to the general
  # methods below or refer to degenerate perturbation theory in physics.
  # https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
  # https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
  a, = primals
  a_dot, = tangents
  v, w = eigh_p.bind(symmetrize(a), lower=lower)
  # for complex numbers we need eigenvalues to be full dtype of v, a:
  w = w.astype(a.dtype)
  eye_n = np.eye(a.shape[-1], dtype=a.dtype)
  # Carefully build the reciprocal delta-eigenvalue matrix, avoiding NaNs:
  #   Fmat[..., i, j] = 1 / (w[..., j] - w[..., i]) off the diagonal, 0 on it.
  # Adding the identity before the reciprocal keeps the zero diagonal
  # differences finite; subtracting it afterwards zeroes the diagonal again.
  # Both operands are indexed from the trailing axes so that the broadcast
  # stays correct if `w` ever carries leading batch dimensions.
  w_row = w[..., np.newaxis, :]
  w_col = w[..., np.newaxis]
  Fmat = np.reciprocal(eye_n + w_row - w_col) - eye_n
  # eigh impl doesn't support batch dims, but future-proof the grad.
  dot = lax.dot if a.ndim == 2 else lax.batch_matmul
  # NOTE(review): `_H` is presumably the conjugate transpose -- confirm
  # against its definition elsewhere in this module.
  vdag_adot_v = dot(dot(_H(v), a_dot), v)
  dv = dot(v, np.multiply(Fmat, vdag_adot_v))
  # Take the diagonal over the two trailing (matrix) axes; the default axes
  # (0, 1) would pick the wrong ones for batched input.
  dw = np.diagonal(vdag_adot_v, axis1=-2, axis2=-1)
  return core.pack((v, w)), core.pack((dv, dw))

# Create the `eigh` primitive and register its rules. The primitive object must
# exist before any of the registrations below can refer to it.
eigh_p = Primitive('eigh')
eigh_p.def_impl(eigh_impl)
eigh_p.def_abstract_eval(eigh_abstract_eval)
# Default XLA translation rule for the primitive.
xla.translations[eigh_p] = eigh_translation_rule
# Forward-mode differentiation rule (see `eigh_jvp_rule`).
ad.primitive_jvps[eigh_p] = eigh_jvp_rule
# Translation rule registered specifically under the 'Host' backend key.
xla.backend_specific_translations['Host'][eigh_p] = eigh_cpu_translation_rule


Expand Down
4 changes: 2 additions & 2 deletions jax/numpy/linalg.py
Expand Up @@ -93,7 +93,7 @@ def det(a):


@_wraps(onp.linalg.eigh)
def eigh(a, UPLO=None):
def eigh(a, UPLO=None, symmetrize_input=True):
if UPLO is None or UPLO == "L":
lower = True
elif UPLO == "U":
Expand All @@ -103,7 +103,7 @@ def eigh(a, UPLO=None):
raise ValueError(msg)

a = _promote_arg_dtypes(np.asarray(a))
v, w = lax_linalg.eigh(a, lower=lower)
v, w = lax_linalg.eigh(a, lower=lower, symmetrize_input=symmetrize_input)
return w, v


Expand Down
76 changes: 74 additions & 2 deletions tests/linalg_test.py
Expand Up @@ -132,14 +132,86 @@ def norm(x):

a, = args_maker()
a = (a + onp.conj(a.T)) / 2
w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a), UPLO=uplo)

w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a),
UPLO=uplo, symmetrize_input=False)
self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5)
self.assertTrue(norm(onp.matmul(a, v) - w * v) < 30)

self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo), args_maker,
check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
     "_shape={}_lower={}".format(jtu.format_shape_dtype_string(shape, dtype),
                                 lower),
     "shape": shape, "dtype": dtype, "rng": rng, "lower":lower}
    for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
    for dtype in float_types() | complex_types()
    for rng in [jtu.rand_default()]
    for lower in [True, False]))
# TODO(phawkins): enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
def testEighGrad(self, shape, dtype, rng, lower):
  """Check gradients of `np.linalg.eigh` through `jtu.check_grads`.

  Real dtypes check the full output of `eigh`; complex dtypes check only its
  first output (the eigenvalues).
  """
  if not hasattr(lapack, "jax_syevd"):
    self.skipTest("No symmetric eigendecomposition implementation available")
  keep_triangle = onp.tril if lower else onp.triu
  sample = rng(shape, dtype)
  hermitian = (sample + onp.conj(sample.T)) / 2
  a = keep_triangle(hermitian)
  # Gradient checks will fail without symmetrization as the eigh jvp rule
  # is only correct for tangents in the symmetric subspace, whereas the
  # checker checks against unconstrained (co)tangents.
  eigh = partial(np.linalg.eigh, UPLO="L" if lower else "U",
                 symmetrize_input=True)
  if dtype in complex_types():
    # only check eigenvalue grads for complex matrices
    f = lambda a: eigh(a)[0]
  else:
    f = eigh
  jtu.check_grads(f, (a,), 2, rtol=1e-1)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name":
     "_shape={}_lower={}".format(jtu.format_shape_dtype_string(shape, dtype),
                                 lower),
     "shape": shape, "dtype": dtype, "rng": rng, "lower":lower, "eps":eps}
    for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
    for dtype in complex_types()
    for rng in [jtu.rand_default()]
    for lower in [True, False]
    for eps in [1e-4]))
# TODO(phawkins): enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
def testEighGradVectorComplex(self, shape, dtype, rng, lower, eps):
  """Check complex eigenvector tangents via eigensystem properties.

  Exact eigenvector coordinate gradients are hard to test numerically for
  complex eigensystem solvers given the extra degrees of per-eigenvector phase
  freedom. Instead, we numerically verify the eigensystem properties on the
  perturbed eigenvectors. You only ever want to optimize eigenvector
  directions, not coordinates!
  """
  if not hasattr(lapack, "jax_syevd"):
    self.skipTest("No symmetric eigendecomposition implementation available")
  uplo = "L" if lower else "U"
  a = rng(shape, dtype)
  a = (a + onp.conj(a.T)) / 2
  a = onp.tril(a) if lower else onp.triu(a)
  a_dot = eps * rng(shape, dtype)
  a_dot = (a_dot + onp.conj(a_dot.T)) / 2
  a_dot = onp.tril(a_dot) if lower else onp.triu(a_dot)
  # Evaluate the eigenvector tangent, and the ground-truth eigenvalues for the
  # perturbed input matrix.
  f = partial(np.linalg.eigh, UPLO=uplo)
  (_, v), (_, dv) = jvp(f, primals=(a,), tangents=(a_dot,))
  new_a = a + a_dot
  new_w, _ = f(new_a)
  new_a = (new_a + onp.conj(new_a.T)) / 2
  rtol = 1e-2
  # First-order estimate of the perturbed eigenvectors, and its image under
  # the perturbed matrix; both are reused by the two checks below.
  v_new = v + dv
  a_v_new = onp.dot(new_a, v_new)
  # Relative eigenvalue error between the perturbed eigenvectors and the new
  # true eigenvalues. unittest assertions are used instead of bare `assert`
  # so the check survives `python -O` and reports the offending values.
  rayleigh = onp.diag(onp.dot(onp.conj(v_new.T), a_v_new))
  self.assertLess(onp.max(onp.abs((rayleigh - new_w) / new_w)), rtol)
  # Redundant to above, but also check the eigenvector property against the
  # new true eigenvalues, column by column.
  scaled = new_w * v_new
  residual = onp.linalg.norm(onp.abs(scaled - a_v_new), axis=0)
  magnitude = onp.linalg.norm(onp.abs(scaled), axis=0)
  self.assertLess(onp.max(residual / magnitude), rtol)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_ord={}_axis={}_keepdims={}".format(
jtu.format_shape_dtype_string(shape, dtype), ord, axis, keepdims),
Expand Down

0 comments on commit fc4c8bd

Please sign in to comment.