Commit

[jax2tf] Update JAX limitations.
JAX has made progress in its coverage of primitives on TPU. This PR
updates the documented limitations accordingly.
gnecula committed May 9, 2022
1 parent ccccd8a commit 137fdb4
Showing 5 changed files with 27 additions and 63 deletions.
33 changes: 14 additions & 19 deletions jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md
@@ -1,11 +1,11 @@
# Primitives with limited JAX support

*Last generated on: 2021-07-31* (YYYY-MM-DD)
*Last generated on: 2022-05-09* (YYYY-MM-DD)

## Supported data types for primitives

We use a set of 2809 test harnesses to test
the implementation of 126 numeric JAX primitives.
We use a set of 3075 test harnesses to test
the implementation of 127 numeric JAX primitives.
We consider a JAX primitive supported for a particular data
type if it is supported on at least one device type.
The following table shows the dtypes at which primitives
@@ -64,14 +64,14 @@ be updated.
| complex | 4 | float32, float64 | bfloat16, bool, complex, float16, integer |
| concatenate | 17 | all | |
| conj | 5 | complex, float32, float64 | bfloat16, bool, float16, integer |
| conv_general_dilated | 73 | inexact, int16, int32, int8 | bool, int64, unsigned |
| conv_general_dilated | 96 | inexact, int16, int32, int8 | bool, int64, unsigned |
| convert_element_type | 201 | all | |
| cos | 6 | inexact | bool, integer |
| cosh | 6 | inexact | bool, integer |
| cummax | 17 | inexact, integer | bool |
| cummin | 17 | inexact, integer | bool |
| cumprod | 17 | inexact, integer | bool |
| cumsum | 17 | inexact, integer | bool |
| cummax | 34 | inexact, integer | bool |
| cummin | 34 | inexact, integer | bool |
| cumprod | 34 | inexact, integer | bool |
| cumsum | 34 | inexact, integer | bool |
| custom_linear_solve | 4 | float32, float64 | bfloat16, bool, complex, float16, integer |
| device_put | 16 | all | |
| digamma | 4 | floating | bool, complex, integer |
@@ -89,7 +89,7 @@ be updated.
| expm1 | 6 | inexact | bool, integer |
| fft | 20 | complex, float32, float64 | bfloat16, bool, float16, integer |
| floor | 4 | floating | bool, complex, integer |
| gather | 80 | all | |
| gather | 136 | all | |
| ge | 17 | all | |
| gt | 17 | all | |
| igamma | 6 | floating | bool, complex, integer |
@@ -111,7 +111,7 @@ be updated.
| neg | 14 | inexact, integer | bool |
| nextafter | 6 | floating | bool, complex, integer |
| or | 11 | bool, integer | inexact |
| pad | 120 | all | |
| pad | 180 | all | |
| population_count | 8 | integer | bool, inexact |
| pow | 10 | inexact | bool, integer |
| qr | 60 | inexact | bool, integer |
@@ -129,22 +129,23 @@ be updated.
| reduce_prod | 14 | inexact, integer | bool |
| reduce_sum | 14 | inexact, integer | bool |
| reduce_window_add | 33 | inexact, integer | bool |
| reduce_window_max | 37 | all | |
| reduce_window_max | 39 | all | |
| reduce_window_min | 15 | all | |
| reduce_window_mul | 42 | inexact, integer | bool |
| regularized_incomplete_beta | 4 | floating | bool, complex, integer |
| rem | 18 | floating, integer | bool, complex |
| reshape | 19 | all | |
| rev | 19 | all | |
| rng_bit_generator | 36 | uint32, uint64 | bool, inexact, signed, uint16, uint8 |
| round | 6 | floating | bool, complex, integer |
| rsqrt | 6 | inexact | bool, integer |
| scatter_add | 15 | all | |
| scatter_max | 15 | all | |
| scatter_min | 19 | all | |
| scatter_min | 24 | all | |
| scatter_mul | 15 | all | |
| select | 16 | all | |
| select_and_gather_add | 15 | floating | bool, complex, integer |
| select_and_scatter_add | 27 | bool, floating, integer | complex |
| select_n | 32 | all | |
| shift_left | 10 | integer | bool, inexact |
| shift_right_arithmetic | 10 | integer | bool, inexact |
| shift_right_logical | 10 | integer | bool, inexact |
@@ -191,11 +192,6 @@ and search for "limitation".
|cholesky|unimplemented|float16|cpu, gpu|
|clamp|unimplemented|bool, complex|cpu, gpu, tpu|
|conv_general_dilated|preferred_element_type not implemented for integers|int16, int32, int8|gpu|
|conv_general_dilated|preferred_element_type=c128 not implemented|complex64|tpu|
|conv_general_dilated|preferred_element_type=f64 not implemented|bfloat16, float16, float32|tpu|
|conv_general_dilated|preferred_element_type=i64 not implemented|int16, int32, int8|tpu|
|dot_general|preferred_element_type=c128 not implemented|complex64|tpu|
|dot_general|preferred_element_type=i64 not implemented|int16, int32, int8|tpu|
|eig|only supported on CPU in JAX|all|tpu, gpu|
|eig|unimplemented|bfloat16, float16|cpu|
|eigh|unimplemented|bfloat16, float16|cpu, gpu|
@@ -204,7 +200,6 @@ and search for "limitation".
|scatter_add|unimplemented|bool|cpu, gpu, tpu|
|scatter_mul|unimplemented|bool|cpu, gpu, tpu|
|select_and_scatter_add|works only for 2 or more inactive dimensions|all|tpu|
|svd|complex not implemented. Works in JAX for CPU and GPU with custom kernels|complex|tpu|
|svd|unimplemented|bfloat16, float16|cpu, gpu|
|triangular_solve|unimplemented|float16|gpu|
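
A minimal sketch (our illustration, not part of this commit) of how a row in this table reads in practice: the `eig` entry above says that eigendecomposition is implemented only on the CPU backend in JAX.

```python
# Illustration only: jnp.linalg.eig works on CPU but is not implemented
# on the GPU/TPU backends in JAX, matching the `eig` row above.
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32)
w, v = jnp.linalg.eig(a)  # fine on CPU; raises on GPU/TPU
print(w.dtype)            # complex64
```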

13 changes: 5 additions & 8 deletions jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md
@@ -1,6 +1,6 @@
# Primitives with limited support for jax2tf

*Last generated on (YYYY-MM-DD): 2021-12-06*
*Last generated on (YYYY-MM-DD): 2022-05-09*

This document summarizes known limitations of the jax2tf conversion.
There are several kinds of limitations.
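
As a point of reference for the modes named throughout this document ("eager", "graph", "compiled"), here is a minimal sketch (ours, not part of the commit) of how a JAX function is exercised through jax2tf in each mode:

```python
# Sketch: the three jax2tf test modes, under the assumption that `f` is
# any convertible JAX function.
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def f(x):
  return jnp.sin(x) + 1.0

f_tf = jax2tf.convert(f)
x = tf.constant([0.0, 1.0, 2.0], dtype=tf.float32)

y_eager = f_tf(x)                                    # "eager"
y_graph = tf.function(f_tf)(x)                       # "graph"
y_compiled = tf.function(f_tf, jit_compile=True)(x)  # "compiled"
```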
@@ -66,14 +66,9 @@ More detailed information can be found in the
| cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph |
| clamp | TF test skipped: Not implemented in JAX: unimplemented | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type not implemented for integers | int16, int32, int8 | gpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=c128 not implemented | complex64 | tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=f64 not implemented | bfloat16, float16, float32 | tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=i64 not implemented | int16, int32, int8 | tpu | compiled, eager, graph |
| digamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| div | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| dot_general | TF error: Numeric comparison disabled: Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598) | bfloat16, complex64, float16, float32 | gpu | compiled, eager, graph |
| dot_general | TF test skipped: Not implemented in JAX: preferred_element_type=c128 not implemented | complex64 | tpu | compiled, eager, graph |
| dot_general | TF test skipped: Not implemented in JAX: preferred_element_type=i64 not implemented | int16, int32, int8 | tpu | compiled, eager, graph |
| dot_general | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
| eig | TF test skipped: Not implemented in JAX: only supported on CPU in JAX | all | gpu, tpu | compiled, eager, graph |
| eig | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu | compiled, eager, graph |
@@ -102,10 +97,11 @@ More detailed information can be found in the
| scatter_mul | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph |
| select_and_gather_add | TF error: jax2tf unimplemented for 64-bit inputs because the current implementation relies on packing two values into a single value. This can be fixed by using a variadic XlaReduceWindow, when available | float64 | cpu, gpu | compiled, eager, graph |
| select_and_scatter_add | TF test skipped: Not implemented in JAX: works only for 2 or more inactive dimensions | all | tpu | compiled, eager, graph |
| svd | TF test skipped: Not implemented in JAX: complex not implemented. Works in JAX for CPU and GPU with custom kernels | complex | tpu | compiled, eager, graph |
| svd | TF error: Numeric comparison disabled: Large numerical discrepancy | float16 | tpu | compiled, eager, graph |
| svd | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
| svd | TF error: function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint` | complex | cpu, gpu | compiled |
| svd | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
| svd | TF error: op not defined for dtype | complex | tpu | compiled, graph |
| triangular_solve | TF test skipped: Not implemented in JAX: unimplemented | float16 | gpu | compiled, eager, graph |
| triangular_solve | TF error: op not defined for dtype | bfloat16 | cpu, gpu, tpu | compiled, eager, graph |
| triangular_solve | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph |
@@ -142,7 +138,8 @@ with jax2tf. The following table lists the cases when this does not quite hold:
| min | May return different values when one of the values is NaN. JAX always returns NaN, while TF returns the value NaN is compared with. | all | cpu, gpu, tpu | compiled, eager, graph |
| pow | custom numeric comparison | complex | cpu, gpu, tpu | eager, graph |
| sort | Numeric comparison disabled: TODO: TF non-stable multiple-array sort | all | gpu | compiled, eager, graph |
| svd | custom numeric comparison when compute_uv | all | cpu, gpu | compiled, eager, graph |
| svd | custom numeric comparison when compute_uv on CPU/GPU | all | cpu, gpu | compiled, eager, graph |
| svd | custom numeric comparison when compute_uv on TPU | complex, float32, float64 | tpu | compiled, eager, graph |
| top_k | Produces different results when the array contains `inf` and `NaN` (they are sorted differently in TF vs. XLA). | floating | cpu, gpu, tpu | eager, graph |
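
To make the `min` row above concrete, a small sketch (illustrative; the TF result can depend on operand order):

```python
# Illustration of the `min` discrepancy: JAX propagates NaN, while
# tf.minimum can return the operand NaN is compared with.
import jax.numpy as jnp
import tensorflow as tf

print(jnp.minimum(jnp.nan, 1.0))      # nan
print(tf.minimum(float("nan"), 1.0))  # 1.0 (order-dependent)
```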

## Updating the documentation
3 changes: 3 additions & 0 deletions jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -1138,6 +1138,9 @@ def dot_column_wise(a, b):
modes=("eager", "graph", "compiled"),
skip_comparison=True),
missing_tf_kernel(dtypes=[dtypes.bfloat16], devices="tpu"),
missing_tf_kernel(dtypes=[np.complex64, np.complex128],
modes=("compiled", "graph"),
devices="tpu"),
custom_numeric(
tol=1e-4,
dtypes=[np.float32, np.complex64],
10 changes: 5 additions & 5 deletions jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py
@@ -78,11 +78,11 @@ def test_jax_implemented(self, harness: primitive_harness.Harness):
logging.warning("Found no JAX error but expected JAX limitations: %s in "
"harness: %s",
[u.description for u in jax_unimpl], harness.fullname)
# We assert that we don't have too strict limitations. This assert can
# fail if somebody fixes a JAX or XLA limitation. In that case, you should
# find and remove the Limitation in primitive_harness. Alternatively,
# uncomment this assert and ping an OWNER of primitive_harness.
# self.assertEmpty(msg)
# We do not fail the test if we have too many limitations. If you want
# to find extraneous limitations, uncomment this assert and run the test
# on all platforms.
# self.assertEmpty(jax_unimpl,
#                  ("Found no JAX error but expected JAX limitations: "
#                   f"{[u.description for u in jax_unimpl]} in harness: {harness.fullname}"))

def test_generate_primitives_coverage_doc(self):
harnesses = primitive_harness.all_harnesses
31 changes: 0 additions & 31 deletions jax/experimental/jax2tf/tests/primitive_harness.py
@@ -1688,10 +1688,6 @@ def _fft_rng_factory(dtype):
"unimplemented",
devices=("cpu", "gpu"),
dtypes=[np.float16, dtypes.bfloat16]),
Limitation(
"complex not implemented. Works in JAX for CPU and GPU with custom kernels",
devices="tpu",
dtypes=[np.complex64, np.complex128])
],
shape=shape,
dtype=dtype,
@@ -2621,18 +2617,6 @@ def _make_dot_general_harness(name,
dimension_numbers=dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type,
jax_unimplemented=[
Limitation(
"preferred_element_type=c128 not implemented",
devices="tpu",
dtypes=np.complex64,
enabled=(preferred_element_type in [np.complex128])),
Limitation(
"preferred_element_type=i64 not implemented",
devices="tpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int64])),
],
)
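
The deletions above reflect that `preferred_element_type` now works on TPU for these dtype combinations. A short sketch (ours, not from this commit; the i64 case assumes `jax_enable_x64`) of the pattern the removed Limitation entries used to guard:

```python
# Sketch: int8 x int8 matmul accumulated into int64 via
# preferred_element_type, one of the formerly-limited TPU cases.
import jax
jax.config.update("jax_enable_x64", True)  # int64 outputs need x64 mode

import jax.numpy as jnp
from jax import lax

lhs = jnp.array([[1, 2], [3, 4]], dtype=jnp.int8)
rhs = jnp.array([[5, 6], [7, 8]], dtype=jnp.int8)
out = lax.dot_general(lhs, rhs,
                      dimension_numbers=(((1,), (0,)), ((), ())),
                      preferred_element_type=jnp.int64)
print(out.dtype)  # int64
```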


@@ -2794,28 +2778,13 @@ def _make_conv_harness(name,
preferred_element_type=preferred_element_type,
enable_xla=enable_xla,
jax_unimplemented=[
Limitation(
"preferred_element_type=i64 not implemented",
devices="tpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int64])),
# b/183565702 - no integer convolutions for GPU
Limitation(
"preferred_element_type not implemented for integers",
devices="gpu",
dtypes=(np.int8, np.int16, np.int32),
enabled=(preferred_element_type in [np.int16, np.int32,
np.int64])),
Limitation(
"preferred_element_type=f64 not implemented",
devices="tpu",
dtypes=(np.float16, jnp.bfloat16, np.float32),
enabled=(preferred_element_type in [np.float64])),
Limitation(
"preferred_element_type=c128 not implemented",
devices="tpu",
dtypes=np.complex64,
enabled=(preferred_element_type in [np.complex128])),
],
)
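
Similarly for convolutions: a short sketch (ours, not from this commit) of an f64-accumulating convolution, the kind of TPU case whose Limitation entries are deleted above:

```python
# Sketch: a float32 convolution accumulated in float64 via
# preferred_element_type (default NCHW/OIHW layouts).
import jax
jax.config.update("jax_enable_x64", True)  # float64 outputs need x64 mode

import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 1, 8, 8), dtype=jnp.float32)  # batch, features, H, W
rhs = jnp.ones((1, 1, 3, 3), dtype=jnp.float32)  # out-feat, in-feat, H, W
out = lax.conv_general_dilated(lhs, rhs,
                               window_strides=(1, 1), padding="SAME",
                               preferred_element_type=jnp.float64)
print(out.dtype)  # float64
```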

