Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jax2tf] First draft of testing the QR conversion. #3775

Merged
merged 6 commits into from
Jul 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):

lax.linear_solve_p,
lax_linalg.cholesky_p, lax_linalg.eig_p, lax_linalg.eigh_p,
lax_linalg.lu_p, lax_linalg.qr_p, lax_linalg.svd_p,
lax_linalg.lu_p, lax_linalg.svd_p,
lax_linalg.triangular_solve_p,

lax_fft.fft_p, lax.igamma_grad_a_p,
Expand Down Expand Up @@ -1130,6 +1130,11 @@ def _sort(*operand: TfVal, dimension: int, is_stable: bool, num_keys: int) -> Tu

tf_impl[lax.sort_p] = _sort

def _qr(operand, full_matrices):
return tf.linalg.qr(operand, full_matrices=full_matrices)

tf_impl[lax_linalg.qr_p] = _qr

def _custom_jvp_call_jaxpr(*args: TfValOrUnit,
fun_jaxpr: core.TypedJaxpr,
jvp_jaxpr_thunk: Callable) -> Sequence[TfValOrUnit]:
Expand Down
12 changes: 12 additions & 0 deletions jax/experimental/jax2tf/tests/primitive_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from jax import config
from jax import test_util as jtu
from jax import lax
from jax import lax_linalg
from jax import numpy as jnp

import numpy as np
Expand Down Expand Up @@ -358,6 +359,17 @@ def parameterized(harness_group: Iterable[Harness],
for is_stable in [False, True]
)

lax_linalg_qr = tuple(
Harness(f"multi_array_shape={jtu.format_shape_dtype_string(shape, dtype)}_fullmatrices={full_matrices}",
lax_linalg.qr,
[RandArg(shape, dtype), StaticArg(full_matrices)],
shape=shape,
dtype=dtype,
full_matrices=full_matrices)
for dtype in jtu.dtypes.all
for shape in [(1, 1), (3, 3), (3, 4), (2, 10, 5), (2, 200, 100)]
for full_matrices in [False, True]
)

lax_slice = tuple(
Harness(f"_shape={shape}_start_indices={start_indices}_limit_indices={limit_indices}_strides={strides}", # type: ignore
Expand Down
19 changes: 19 additions & 0 deletions jax/experimental/jax2tf/tests/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,25 @@ def test_sort(self, harness: primitive_harness.Harness):
raise unittest.SkipTest("GPU tests are running TF on CPU")
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()))

@primitive_harness.parameterized(primitive_harness.lax_linalg_qr)
def test_qr(self, harness: primitive_harness.Harness):
# See jax.lib.lapack.geqrf for the list of compatible types
if harness.params["dtype"] in [jnp.float32, jnp.float64]:
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
atol=1e-5, rtol=1e-5)
elif harness.params["dtype"] in [jnp.complex64, jnp.complex128]:
# TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good strategy to keep track of TODOs that are discussed in comments.

# - check_compiled=True breaks for complex types;
# - for now, the performance of the HLO QR implementation called when
# compiling with TF is expected to have worse performance than the
# custom calls made in JAX.
self.ConvertAndCompare(harness.dyn_fun, *harness.dyn_args_maker(self.rng()),
expect_tf_exceptions=True, atol=1e-5, rtol=1e-5)
else:
expected_error = ValueError if jtu.device_under_test() == "gpu" else NotImplementedError
with self.assertRaisesRegex(expected_error, "Unsupported dtype"):
harness.dyn_fun(*harness.dyn_args_maker(self.rng()))

@primitive_harness.parameterized(primitive_harness.lax_unary_elementwise)
def test_unary_elementwise(self, harness: primitive_harness.Harness):
dtype = harness.params["dtype"]
Expand Down