Skip to content

Commit

Permalink
Fix conditional in eig and expand eig test suite. (jax-ml#4320)
Browse files Browse the repository at this point in the history
* Fix conditional in eig and expand eig test suite.
  • Loading branch information
bchetioui committed Sep 18, 2020
1 parent b3a0987 commit d478e34
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
3 changes: 2 additions & 1 deletion jax/lax_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,9 @@ def eig_abstract_eval(operand, *, compute_left_eigenvectors,
output = [w]
if compute_left_eigenvectors:
output.append(vl)
elif compute_right_eigenvectors:
if compute_right_eigenvectors:
output.append(vr)

return tuple(output)

_cpu_geev = lapack.geev
Expand Down
38 changes: 32 additions & 6 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,26 @@ def testIssue1213(self):
tol=1e-3)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
jtu.format_shape_dtype_string(shape, dtype)),
"shape": shape, "dtype": dtype, "rng_factory": rng_factory}
{"testcase_name": "_shape={}_leftvectors={}_rightvectors={}".format(
jtu.format_shape_dtype_string(shape, dtype),
compute_left_eigenvectors, compute_right_eigenvectors),
"shape": shape, "dtype": dtype, "rng_factory": rng_factory,
"compute_left_eigenvectors": compute_left_eigenvectors,
"compute_right_eigenvectors": compute_right_eigenvectors}
for shape in [(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)]
for dtype in float_types + complex_types
for compute_left_eigenvectors, compute_right_eigenvectors in [
(False, False),
(True, False),
(False, True),
(True, True)
]
for rng_factory in [jtu.rand_default]))
# TODO(phawkins): enable when there is an eigendecomposition implementation
# for GPU/TPU.
@jtu.skip_on_devices("gpu", "tpu")
def testEig(self, shape, dtype, rng_factory):
def testEig(self, shape, dtype, compute_left_eigenvectors,
compute_right_eigenvectors, rng_factory):
rng = rng_factory(self.rng())
jtu.skip_if_unsupported_type(dtype)
n = shape[-1]
Expand All @@ -242,9 +252,25 @@ def norm(x):
norm = np.linalg.norm(x, axis=(-2, -1))
return norm / ((n + 1) * jnp.finfo(dtype).eps)

def check_right_eigenvectors(a, w, vr):
self.assertTrue(
np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))

def check_left_eigenvectors(a, w, vl):
rank = len(a.shape)
aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
wC = jnp.conj(w)
check_right_eigenvectors(aH, wC, vl)

a, = args_maker()
w, v = jnp.linalg.eig(a)
self.assertTrue(np.all(norm(np.matmul(a, v) - w[..., None, :] * v) < 100))
results = lax_linalg.eig(a, compute_left_eigenvectors,
compute_right_eigenvectors)
w = results[0]

if compute_left_eigenvectors:
check_left_eigenvectors(a, w, results[1])
if compute_right_eigenvectors:
check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors])

self._CompileAndCheck(partial(jnp.linalg.eig), args_maker,
rtol=1e-3)
Expand Down

0 comments on commit d478e34

Please sign in to comment.