Skip to content

Commit

Permalink
added jvp rule for eigh, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
levskaya committed Feb 14, 2019
1 parent 9ba27be commit 8a84ae8
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 5 deletions.
29 changes: 28 additions & 1 deletion jax/lax_linalg.py
Expand Up @@ -39,7 +39,10 @@ def cholesky(x, symmetrize_input=True):
x = symmetrize(x)
return cholesky_p.bind(x)

def eigh(x, lower=True): return eigh_p.bind(x, lower=lower)
def eigh(x, lower=True, symmetrize=True):
if symmetrize:
x = (x + _H(x)) / 2 # orthogonal projection onto self-adjoint matrices
return eigh_p.bind(x, lower=lower)

def lu(x): return lu_p.bind(x)

Expand Down Expand Up @@ -146,10 +149,34 @@ def eigh_cpu_translation_rule(c, operand, lower):
raise NotImplementedError(
"Only unbatched eigendecomposition is implemented on CPU")

def eigh_jvp_rule(primals, tangents, lower):
# Derivative for eigh in the simplest case of distinct eigenvalues.
# Simple case from https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
# The general solution treating the case of degenerate eigenvalues is
# considerably more complicated. Ambitious readers may refer to the general
# methods at:
# https://www.win.tue.nl/analysis/reports/rana06-33.pdf and
# https://people.orie.cornell.edu/aslewis/publications/99-clarke.pdf
a, = primals
a_dot, = tangents
v, w = eigh_p.bind((a + _H(a)) / 2.0, lower=lower)
# for complex numbers we need eigenvalues to be full dtype of v, a:
w = w.astype(a.dtype)
eye_n = np.eye(a.shape[-1], dtype=a.dtype)
# carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
Fmat = np.reciprocal(eye_n + w - w[..., np.newaxis]) - eye_n
# eigh impl doesn't support batch dims, but future-proof the grad.
dot = lax.dot if a.ndim == 2 else lax.batch_matmul
vdag_adot_v = dot(dot(_H(v), a_dot), v)
dv = dot(v, np.multiply(Fmat, vdag_adot_v))
dw = np.diagonal(np.multiply(eye_n, vdag_adot_v))
return core.pack((v, w)), core.pack((dv, dw))

eigh_p = Primitive('eigh')
eigh_p.def_impl(eigh_impl)
eigh_p.def_abstract_eval(eigh_abstract_eval)
xla.translations[eigh_p] = eigh_translation_rule
ad.primitive_jvps[eigh_p] = eigh_jvp_rule
xla.backend_specific_translations['Host'][eigh_p] = eigh_cpu_translation_rule


Expand Down
4 changes: 2 additions & 2 deletions jax/numpy/linalg.py
Expand Up @@ -93,7 +93,7 @@ def det(a):


@_wraps(onp.linalg.eigh)
def eigh(a, UPLO=None):
def eigh(a, UPLO=None, symmetrize=True):
if UPLO is None or UPLO == "L":
lower = True
elif UPLO == "U":
Expand All @@ -103,7 +103,7 @@ def eigh(a, UPLO=None):
raise ValueError(msg)

a = _promote_arg_dtypes(np.asarray(a))
v, w = lax_linalg.eigh(a, lower=lower)
v, w = lax_linalg.eigh(a, lower=lower, symmetrize=symmetrize)
return w, v


Expand Down
29 changes: 27 additions & 2 deletions tests/linalg_test.py
Expand Up @@ -132,14 +132,39 @@ def norm(x):

a, = args_maker()
a = (a + onp.conj(a.T)) / 2
w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a), UPLO=uplo)

w, v = np.linalg.eigh(onp.tril(a) if lower else onp.triu(a),
UPLO=uplo, symmetrize=False)
self.assertTrue(norm(onp.eye(n) - onp.matmul(onp.conj(T(v)), v)) < 5)
self.assertTrue(norm(onp.matmul(a, v) - w * v) < 30)

self._CompileAndCheck(partial(np.linalg.eigh, UPLO=uplo), args_maker,
check_dtypes=True)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name":
"_shape={}_lower={}".format(jtu.format_shape_dtype_string(shape, dtype),
lower),
"shape": shape, "dtype": dtype, "rng": rng, "lower":lower}
for shape in [(1, 1), (4, 4), (5, 5), (50, 50)]
for dtype in float_types() | complex_types()
for rng in [jtu.rand_default()]
for lower in [True, False]))
# TODO(phawkins): enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
def testEighGrad(self, shape, dtype, rng, lower):
if not hasattr(lapack, "jax_syevd"):
self.skipTest("No symmetric eigendecomposition implementation available")
uplo = "L" if lower else "U"
a = rng(shape, dtype)
a = (a + onp.conj(a.T)) / 2
a = onp.tril(a) if lower else onp.triu(a)
# Gradient checks will fail without symmetrization as the eigh jvp rule
# is only correct for tangents in the symmetric subspace, whereas the
# checker checks against unconstrained (co)tangents.
f = partial(np.linalg.eigh, UPLO=uplo, symmetrize=True)
jtu.check_grads(f, (a,), 2, rtol=1e-1)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_ord={}_axis={}_keepdims={}".format(
jtu.format_shape_dtype_string(shape, dtype), ord, axis, keepdims),
Expand Down

0 comments on commit 8a84ae8

Please sign in to comment.