Skip to content

Commit

Permalink
Reordered tests for clarity.
Browse files Browse the repository at this point in the history
  • Loading branch information
bchetioui committed Jul 21, 2020
1 parent 48ad8f0 commit 0e788f4
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions jax/experimental/jax2tf/tests/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,21 +136,21 @@ def test_sort(self, harness: primitive_harness.Harness):
@primitive_harness.parameterized(primitive_harness.lax_linalg_qr)
def test_qr(self, harness: primitive_harness.Harness):
# See jax.lib.lapack.geqrf for the list of compatible types
if not harness.params["dtype"] in [jnp.float32, jnp.float64, jnp.complex64, jnp.complex128]:
expected_error = ValueError if jtu.device_under_test() == "gpu" else NotImplementedError
with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
elif harness.params["dtype"] in [jnp.float32, jnp.float64]:
if harness.params["dtype"] in [jnp.float32, jnp.float64]:
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
atol=1e-5, rtol=1e-5)
else:
elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
# TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
# - check_compiled=True breaks for complex types;
# - for now, the performance of the HLO QR implementation called when
# compiling with TF is expected to have worse performance than the
# custom calls made in JAX.
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
expect_tf_exceptions=True, atol=1e-5, rtol=1e-5)
else:
expected_error = ValueError if jtu.device_under_test() == "gpu" else NotImplementedError
with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))

@primitive_harness.parameterized(primitive_harness.lax_unary_elementwise)
def test_unary_elementwise(self, harness: primitive_harness.Harness):
Expand Down

0 comments on commit 0e788f4

Please sign in to comment.