[jax2tf] First draft of testing the QR conversion. (#3775)

QR decomposition results can be off by more than 1e-6 between JAX and the
converted TF function in some instances, requiring custom atol and rtol
values in the testing code. There is an odd problem in which experimental
compilation fails for complex types, even though they are in principle
supported.
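A minimal sketch of the comparison these tests automate, assuming jax2tf.convert in eager mode; the shape, dtype, and direct use of np.testing here are illustrative, not the harness code itself:

    import numpy as np
    import jax.numpy as jnp
    from jax.experimental import jax2tf

    x = np.random.randn(3, 4).astype(np.float32)
    q_jax, r_jax = jnp.linalg.qr(x)
    q_tf, r_tf = jax2tf.convert(jnp.linalg.qr)(x)
    # Default tolerances can fail: the JAX and TF results can disagree by
    # more than 1e-6, hence the custom atol/rtol of 1e-5 in the tests.
    np.testing.assert_allclose(q_jax, np.array(q_tf), atol=1e-5, rtol=1e-5)
    np.testing.assert_allclose(r_jax, np.array(r_tf), atol=1e-5, rtol=1e-5)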
bchetioui committed Jul 21, 2020
1 parent 71f80a5 commit 9773432
Showing 3 changed files with 37 additions and 1 deletion.
7 changes: 6 additions & 1 deletion jax/experimental/jax2tf/jax2tf.py
@@ -358,7 +358,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):

   lax.linear_solve_p,
   lax_linalg.cholesky_p, lax_linalg.eig_p, lax_linalg.eigh_p,
-  lax_linalg.lu_p, lax_linalg.qr_p, lax_linalg.svd_p,
+  lax_linalg.lu_p, lax_linalg.svd_p,
   lax_linalg.triangular_solve_p,
 
   lax_fft.fft_p, lax.igamma_grad_a_p,
@@ -1136,6 +1136,11 @@ def _sort(*operand: TfVal, dimension: int, is_stable: bool, num_keys: int) -> Tu

 tf_impl[lax.sort_p] = _sort
 
+def _qr(operand, full_matrices):
+  return tf.linalg.qr(operand, full_matrices=full_matrices)
+
+tf_impl[lax_linalg.qr_p] = _qr
+
 def _custom_jvp_call_jaxpr(*args: TfValOrUnit,
                            fun_jaxpr: core.TypedJaxpr,
                            jvp_jaxpr_thunk: Callable) -> Sequence[TfValOrUnit]:
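With this change lax_linalg.qr_p comes off the not-yet-implemented list above and is mapped directly onto tf.linalg.qr, whose (q, r) output pair matches what the JAX primitive produces. A standalone sanity check of that mapping, with an illustrative shape:

    import tensorflow as tf

    a = tf.random.normal((3, 4))
    q, r = tf.linalg.qr(a, full_matrices=False)  # q: (3, 3), r: (3, 4)
    # q has orthonormal columns and q @ r reconstructs a.
    tf.debugging.assert_near(tf.matmul(q, r), a, atol=1e-5)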
12 changes: 12 additions & 0 deletions jax/experimental/jax2tf/tests/primitive_harness.py
@@ -24,6 +24,7 @@
 from jax import config
 from jax import test_util as jtu
 from jax import lax
+from jax import lax_linalg
 from jax import numpy as jnp
 
 import numpy as np
@@ -358,6 +359,17 @@ def parameterized(harness_group: Iterable[Harness],
   for is_stable in [False, True]
 )
 
+lax_linalg_qr = tuple(
+  Harness(f"multi_array_shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}",
+          lax_linalg.qr,
+          [RandArg(shape, dtype), StaticArg(full_matrices)],
+          shape=shape,
+          dtype=dtype,
+          full_matrices=full_matrices)
+  for dtype in jtu.dtypes.all
+  for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)]
+  for full_matrices in [False, True]
+)
 
 lax_slice = tuple(
   Harness(f"_shape={shape}_start_indices={start_indices}_limit_indices={limit_indices}_strides={strides}",  # type: ignore
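Each entry above expands into one test case: RandArg describes the randomly generated operand, and StaticArg bakes the full_matrices flag into the function under test. What a single generated case effectively runs, sketched standalone for the (3, 4) float32 configuration:

    import numpy as np
    from jax import lax_linalg

    operand = np.random.randn(3, 4).astype(np.float32)
    q, r = lax_linalg.qr(operand, True)  # full_matrices=True
    assert q.shape == (3, 3) and r.shape == (3, 4)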
19 changes: 19 additions & 0 deletions jax/experimental/jax2tf/tests/primitives_test.py
@@ -133,6 +133,25 @@ def test_sort(self, harness: primitive_harness.Harness):
       raise unittest.SkipTest("GPU tests are running TF on CPU")
     self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))
 
+  @primitive_harness.parameterized(primitive_harness.lax_linalg_qr)
+  def test_qr(self, harness: primitive_harness.Harness):
+    # See jax.lib.lapack.geqrf for the list of compatible types
+    if harness.params["dtype"] in [jnp.float32, jnp.float64]:
+      self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
+                             atol=1e-5, rtol=1e-5)
+    elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
+      # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
+      # - check_compiled=True breaks for complex types;
+      # - for now, the HLO QR implementation called when compiling with TF
+      #   is expected to have worse performance than the custom calls
+      #   made in JAX.
+      self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
+                             expect_tf_exceptions=True, atol=1e-5, rtol=1e-5)
+    else:
+      expected_error = ValueError if jtu.device_under_test() == "gpu" else NotImplementedError
+      with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
+        harness.dyn_fun(*harness.dyn_args_maker(self.rng()))
+
   @primitive_harness.parameterized(primitive_harness.lax_unary_elementwise)
   def test_unary_elementwise(self, harness: primitive_harness.Harness):
     dtype = harness.params["dtype"]
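For the final branch, a standalone illustration of the unsupported-dtype failure the test expects; the exact exception type depends on the backend (ValueError on GPU, NotImplementedError otherwise), and any message wording beyond "Unsupported dtype" is an assumption:

    import numpy as np
    from jax import lax_linalg

    try:
      lax_linalg.qr(np.zeros((3, 3), dtype=np.int32), True)
    except (NotImplementedError, ValueError) as e:
      print(e)  # expected to match "Unsupported dtype", per the test above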
