Skip to content

Commit

Permalink
Fixed thrown error for GPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
bchetioui committed Jul 21, 2020
1 parent 4c3fac6 commit 48ad8f0
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/experimental/jax2tf/tests/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def test_sort(self, harness: primitive_harness.Harness):
def test_qr(self, harness: primitive_harness.Harness):
# See jax.lib.lapack.geqrf for the list of compatible types
if not harness.params["dtype"] in [jnp.float32, jnp.float64, jnp.complex64, jnp.complex128]:
with self.assertRaisesRegex(NotImplementedError, "Unsupported dtype"):
expected_error = ValueError if jtu.device_under_test() == "gpu" else NotImplementedError
with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
elif harness.params["dtype"] in [jnp.float32, jnp.float64]:
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
Expand Down

0 comments on commit 48ad8f0

Please sign in to comment.