From 9773432b831e8ca375491266a1e73a148f8043dd Mon Sep 17 00:00:00 2001 From: Benjamin Chetioui <3920784+SIben@users.noreply.github.com> Date: Tue, 21 Jul 2020 15:36:35 +0200 Subject: [PATCH] [jax2tf] First draft of testing the QR conversion. (#3775) * [jax2tf] First draft of testing the QR conversion. QR decomposition is off by over 1e-6 in some instances, requiring custom atol and rtol values in testing code. There is an odd problem in which experimental compilation fails for complex types, although they are in principle supported. --- jax/experimental/jax2tf/jax2tf.py | 7 ++++++- .../jax2tf/tests/primitive_harness.py | 12 ++++++++++++ .../jax2tf/tests/primitives_test.py | 19 +++++++++++++++++++ 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index fc81c809ed7c..4bcf108627ff 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -358,7 +358,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): lax.linear_solve_p, lax_linalg.cholesky_p, lax_linalg.eig_p, lax_linalg.eigh_p, - lax_linalg.lu_p, lax_linalg.qr_p, lax_linalg.svd_p, + lax_linalg.lu_p, lax_linalg.svd_p, lax_linalg.triangular_solve_p, lax_fft.fft_p, lax.igamma_grad_a_p, @@ -1136,6 +1136,11 @@ def _sort(*operand: TfVal, dimension: int, is_stable: bool, num_keys: int) -> Tu tf_impl[lax.sort_p] = _sort +def _qr(operand, full_matrices): + return tf.linalg.qr(operand, full_matrices=full_matrices) + +tf_impl[lax_linalg.qr_p] = _qr + def _custom_jvp_call_jaxpr(*args: TfValOrUnit, fun_jaxpr: core.TypedJaxpr, jvp_jaxpr_thunk: Callable) -> Sequence[TfValOrUnit]: diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py index d39f3b79fda1..d6a4b124bc16 100644 --- a/jax/experimental/jax2tf/tests/primitive_harness.py +++ b/jax/experimental/jax2tf/tests/primitive_harness.py @@ -24,6 +24,7 @@ from jax import config from jax import test_util as jtu from jax import lax +from jax import lax_linalg from jax import numpy as jnp import numpy as np @@ -358,6 +359,17 @@ def parameterized(harness_group: Iterable[Harness], for is_stable in [False, True] ) +lax_linalg_qr = tuple( + Harness(f"multi_array_shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}", + lax_linalg.qr, + [RandArg(shape, dtype), StaticArg(full_matrices)], + shape=shape, + dtype=dtype, + full_matrices=full_matrices) + for dtype in jtu.dtypes.all + for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)] + for full_matrices in [False, True] +) lax_slice = tuple( Harness(f"_shape={shape}_start_indices={start_indices}_limit_indices={limit_indices}_strides={strides}", # type: ignore diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index f66f01f306ba..1fa5b84b330c 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -133,6 +133,25 @@ def test_sort(self, harness: primitive_harness.Harness): raise unittest.SkipTest("GPU tests are running TF on CPU") self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng())) + @primitive_harness.parameterized(primitive_harness.lax_linalg_qr) + def test_qr(self, harness: primitive_harness.Harness): + # See jax.lib.lapack.geqrf for the list of compatible types + if harness.params["dtype"] in [jnp.float32, jnp.float64]: + self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), + atol=1e-5, rtol=1e-5) + elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]: + # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824. + # - check_compiled=True breaks for complex types; + # - for now, the performance of the HLO QR implementation called when + # compiling with TF is expected to have worse performance than the + # custom calls made in JAX. + self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()), + expect_tf_exceptions=True, atol=1e-5, rtol=1e-5) + else: + expected_error = ValueError if jtu.device_under_test() == "gpu" else NotImplementedError + with self.assertRaisesRegex(expected_error, "Unsupported dtype"): + harness.dyn_fun(*harness.dyn_args_maker(self.rng())) + @primitive_harness.parameterized(primitive_harness.lax_unary_elementwise) def test_unary_elementwise(self, harness: primitive_harness.Harness): dtype = harness.params["dtype"]