diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index a9db53251d9f..62fe2b7d1f6d 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -142,7 +142,12 @@ def test_qr(self, harness: primitive_harness.Harness): elif harness.params["dtype"] in [jnp.float32, jnp.float64]: self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), atol=1e-5, rtol=1e-5) - else: # TODO: figure out why check_compiled=True breaks for complex types. + else: + # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824. + # - check_compiled=True breaks for complex types; + # - for now, the performance of the HLO QR implementation called when + # compiling with TF is expected to have worse performance than the + # custom calls made in JAX. self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), expect_tf_exceptions=True, atol=1e-5, rtol=1e-5)