diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 7b2c5c6f7876..1fa5b84b330c 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -136,14 +136,10 @@ def test_sort(self, harness: primitive_harness.Harness): @primitive_harness.parameterized(primitive_harness.lax_linalg_qr) def test_qr(self, harness: primitive_harness.Harness): # See jax.lib.lapack.geqrf for the list of compatible types - if not harness.params["dtype"] in [jnp.float32, jnp.float64, jnp.complex64, jnp.complex128]: - expected_error = ValueError if jtu.device_under_test() == "gpu" else NotImplementedError - with self.assertRaisesRegex(expected_error, "Unsupported dtype"): - harness.dyn_fun(*harness.dyn_args_maker(self.rng())) - elif harness.params["dtype"] in [jnp.float32, jnp.float64]: + if harness.params["dtype"] in [jnp.float32, jnp.float64]: self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), atol=1e-5, rtol=1e-5) - else: + elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]: # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824. # - check_compiled=True breaks for complex types; # - for now, the performance of the HLO QR implementation called when @@ -151,6 +147,10 @@ def test_qr(self, harness: primitive_harness.Harness): # custom calls made in JAX. self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), expect_tf_exceptions=True, atol=1e-5, rtol=1e-5) + else: + expected_error = ValueError if jtu.device_under_test() == "gpu" else NotImplementedError + with self.assertRaisesRegex(expected_error, "Unsupported dtype"): + harness.dyn_fun(*harness.dyn_args_maker(self.rng())) @primitive_harness.parameterized(primitive_harness.lax_unary_elementwise) def test_unary_elementwise(self, harness: primitive_harness.Harness):