[jax2tf] Update the limitations
To account for progress on XLA and TF
gnecula committed Jul 31, 2021
1 parent 655a3e7 commit 022d2d6
Showing 4 changed files with 15 additions and 47 deletions.
19 changes: 11 additions & 8 deletions jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md
@@ -1,11 +1,11 @@
# Primitives with limited JAX support

*Last generated on: 2021-07-05* (YYYY-MM-DD)
*Last generated on: 2021-07-31* (YYYY-MM-DD)

## Supported data types for primitives

We use a set of 2667 test harnesses to test
the implementation of 121 numeric JAX primitives.
We use a set of 2809 test harnesses to test
the implementation of 126 numeric JAX primitives.
We consider a JAX primitive supported for a particular data
type if it is supported on at least one device type.
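The aggregation rule above (a primitive counts as supported for a dtype as soon as it passes on at least one device type) can be sketched in plain Python. The per-harness records below are purely illustrative, not the actual jax2tf test results or infrastructure:

```python
from collections import defaultdict

# Hypothetical per-harness outcomes: (primitive, dtype, device, passed).
results = [
    ("argmax", "float32", "cpu", True),
    ("argmax", "float32", "tpu", False),
    ("argmax", "complex64", "cpu", False),
    ("cholesky", "float16", "gpu", False),
]

def supported_dtypes(results):
    """A dtype counts as supported if it passes on at least one device type."""
    supported = defaultdict(set)
    for prim, dtype, device, passed in results:
        if passed:
            supported[prim].add(dtype)
    return dict(supported)

print(supported_dtypes(results))  # {'argmax': {'float32'}}
```

With these records, `argmax` is reported as supporting `float32` (it passes on CPU even though it fails on TPU), while `complex64` and the failing `cholesky` entry are left out.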
The following table shows the dtypes at which primitives
@@ -46,8 +46,8 @@ be updated.
| add | 16 | inexact, integer | bool |
| add_any | 14 | inexact, integer | bool |
| and | 11 | bool, integer | inexact |
| argmax | 22 | bool, floating, integer | complex |
| argmin | 22 | bool, floating, integer | complex |
| argmax | 64 | bool, floating, integer | complex |
| argmin | 64 | bool, floating, integer | complex |
| asin | 6 | inexact | bool, integer |
| asinh | 6 | inexact | bool, integer |
| atan | 6 | inexact | bool, integer |
@@ -56,8 +56,8 @@ be updated.
| bessel_i0e | 4 | floating | bool, complex, integer |
| bessel_i1e | 4 | floating | bool, complex, integer |
| bitcast_convert_type | 41 | all | |
| broadcast | 17 | all | |
| broadcast_in_dim | 19 | all | |
| cbrt | 4 | floating | bool, complex, integer |
| ceil | 4 | floating | bool, complex, integer |
| cholesky | 30 | inexact | bool, integer |
| clamp | 20 | all | |
@@ -115,9 +115,13 @@ be updated.
| population_count | 8 | integer | bool, inexact |
| pow | 10 | inexact | bool, integer |
| qr | 60 | inexact | bool, integer |
| random_categorical | 12 | floating | bool, complex, integer |
| random_gamma | 4 | float32, float64 | bfloat16, bool, complex, float16, integer |
| random_randint | 12 | signed | bool, inexact, unsigned |
| random_split | 5 | uint32 | all |
| random_uniform | 12 | floating | bool, complex, integer |
| real | 2 | complex | bool, floating, integer |
| reduce | 33 | all | |
| reduce_and | 1 | bool | inexact, integer |
| reduce_max | 15 | all | |
| reduce_min | 15 | all | |
@@ -159,6 +163,7 @@ be updated.
| top_k | 15 | bool, floating, integer | complex |
| transpose | 17 | all | |
| triangular_solve | 26 | inexact | bool, integer |
| tridiagonal_solve | 2 | float32, float64 | bfloat16, bool, complex, float16, integer |
| xor | 11 | bool, integer | inexact |
| zeros_like | 15 | all | |

@@ -197,8 +202,6 @@ and search for "limitation".
|lu|unimplemented|bfloat16, float16|cpu, gpu, tpu|
|qr|unimplemented|bfloat16, float16|cpu, gpu|
|scatter_add|unimplemented|bool|cpu, gpu, tpu|
|scatter_max|unimplemented|complex64|tpu|
|scatter_min|unimplemented|complex64|tpu|
|scatter_mul|unimplemented|bool|cpu, gpu, tpu|
|select_and_scatter_add|works only for 2 or more inactive dimensions|all|tpu|
|svd|complex not implemented. Works in JAX for CPU and GPU with custom kernels|complex|tpu|
13 changes: 2 additions & 11 deletions jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md
@@ -1,6 +1,6 @@
# Primitives with limited support for jax2tf

*Last generated on (YYYY-MM-DD): 2021-07-05*
*Last generated on (YYYY-MM-DD): 2021-07-31*

This document summarizes known limitations of the jax2tf conversion.
There are several kinds of limitations.
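Each limitation in the tables that follow is effectively a predicate over a (dtype, device, mode) combination. A minimal sketch of that idea, with field names modeled loosely on jax2tf's internal `Jax2TfLimitation` entries but simplified and hypothetical here:

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class Limitation:
    description: str
    dtypes: Tuple[str, ...]
    devices: Tuple[str, ...] = ("cpu", "gpu", "tpu")
    modes: Tuple[str, ...] = ("compiled", "eager", "graph")

    def applies(self, dtype: str, device: str, mode: str) -> bool:
        # A limitation is in effect only when all three scopes match.
        return (dtype in self.dtypes
                and device in self.devices
                and mode in self.modes)

# Example entry modeled on a typical row of the table below.
lim = Limitation("op not defined for dtype", dtypes=("bfloat16",),
                 devices=("cpu", "gpu"), modes=("eager", "graph"))
assert lim.applies("bfloat16", "cpu", "eager")
assert not lim.applies("bfloat16", "tpu", "compiled")
```

Reading a table row therefore means: the named error or skip occurs only for the listed dtypes, on the listed devices, in the listed modes.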
@@ -64,8 +64,6 @@ More detailed information can be found in the
| bessel_i0e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| bessel_i1e | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| cholesky | TF test skipped: Not implemented in JAX: unimplemented | float16 | cpu, gpu | compiled, eager, graph |
| cholesky | TF error: function not compilable | complex | cpu, gpu | compiled |
| cholesky | TF error: op not defined for dtype | complex | tpu | compiled, graph |
| clamp | TF test skipped: Not implemented in JAX: unimplemented | bool, complex | cpu, gpu, tpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type not implemented for integers | int16, int32, int8 | gpu | compiled, eager, graph |
| conv_general_dilated | TF test skipped: Not implemented in JAX: preferred_element_type=c128 not implemented | complex64 | tpu | compiled, eager, graph |
@@ -74,7 +72,7 @@ More detailed information can be found in the
| conv_general_dilated | TF error: jax2tf BUG: batch_group_count > 1 not yet converted | all | cpu, gpu, tpu | compiled, eager, graph |
| digamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| div | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| dot_general | TF error: Numeric comparision disabled: Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598) | bfloat16, complex64, float16, float32 | gpu | compiled, eager, graph |
| dot_general | TF error: Numeric comparison disabled: Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598) | bfloat16, complex64, float16, float32 | gpu | compiled, eager, graph |
| dot_general | TF test skipped: Not implemented in JAX: preferred_element_type=c128 not implemented | complex64 | tpu | compiled, eager, graph |
| dot_general | TF test skipped: Not implemented in JAX: preferred_element_type=i64 not implemented | int16, int32, int8 | tpu | compiled, eager, graph |
| dot_general | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
@@ -83,7 +81,6 @@ More detailed information can be found in the
| eig | TF error: TF Conversion of eig is not implemented when both compute_left_eigenvectors and compute_right_eigenvectors are set to True | all | cpu, gpu, tpu | compiled, eager, graph |
| eig | TF error: function not compilable | all | cpu | compiled |
| eigh | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
| eigh | TF test skipped: TF error: XLA lowering bug | complex | gpu | compiled |
| eigh | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
| erf | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| erf_inv | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu | eager, graph |
@@ -94,7 +91,6 @@ More detailed information can be found in the
| integer_pow | TF error: op not defined for dtype | int16, int8, unsigned | cpu, gpu | graph |
| lgamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| lu | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| lu | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| nextafter | TF error: op not defined for dtype | bfloat16, float16 | cpu, gpu, tpu | compiled, eager, graph |
| qr | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
| qr | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
@@ -104,18 +100,13 @@ More detailed information can be found in the
| rem | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| round | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| scatter_add | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_add | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| scatter_max | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| scatter_min | TF test skipped: Not implemented in JAX: unimplemented | complex64 | tpu | compiled, eager, graph |
| scatter_mul | TF test skipped: Not implemented in JAX: unimplemented | bool | cpu, gpu, tpu | compiled, eager, graph |
| scatter_mul | TF error: op not defined for dtype | complex64 | tpu | compiled, eager, graph |
| select_and_gather_add | TF error: jax2tf unimplemented for 64-bit inputs because the current implementation relies on packing two values into a single value. This can be fixed by using a variadic XlaReduceWindow, when available | float64 | cpu, gpu | compiled, eager, graph |
| select_and_scatter_add | TF test skipped: Not implemented in JAX: works only for 2 or more inactive dimensions | all | tpu | compiled, eager, graph |
| svd | TF test skipped: Not implemented in JAX: complex not implemented. Works in JAX for CPU and GPU with custom kernels | complex | tpu | compiled, eager, graph |
| svd | TF test skipped: Not implemented in JAX: unimplemented | bfloat16, float16 | cpu, gpu | compiled, eager, graph |
| svd | TF error: function not compilable. Implemented using `tf.linalg.svd` and `tf.linalg.adjoint` | complex | cpu, gpu | compiled |
| svd | TF error: op not defined for dtype | bfloat16 | tpu | compiled, eager, graph |
| top_k | TF error: op not defined for dtype | int64, uint64 | cpu, gpu | compiled |
| triangular_solve | TF test skipped: Not implemented in JAX: unimplemented | float16 | gpu | compiled, eager, graph |
| triangular_solve | TF error: op not defined for dtype | bfloat16 | cpu, gpu, tpu | compiled, eager, graph |
| triangular_solve | TF error: op not defined for dtype | float16 | cpu, gpu | eager, graph |
25 changes: 2 additions & 23 deletions jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -254,12 +254,6 @@ def custom_assert(tst, result_jax, result_tf, *, tol, err_msg, **_):
jnp.tril(result_jax), result_tf, atol=tol, err_msg=err_msg)

return [
# See https://github.com/google/jax/pull/3775#issuecomment-659407824;
Jax2TfLimitation(
"function not compilable",
dtypes=[np.complex64, np.complex128],
devices=("cpu", "gpu"),
modes="compiled"),
# TODO: very high tolerance
custom_numeric(
dtypes=[np.float32, np.complex64],
@@ -812,7 +806,6 @@ def _make_permutation_matrix(perm):
err_msg=err_msg)

return [
missing_tf_kernel(dtypes=[np.complex64], devices="tpu"),
custom_numeric(
dtypes=[np.float32, np.complex64], devices="tpu", tol=0.1),
custom_numeric(
@@ -937,21 +930,11 @@ def round(cls, harness: primitive_harness.Harness):

@classmethod
def scatter_add(cls, harness):
return [
missing_tf_kernel(
dtypes=[np.complex64],
devices="tpu",
)
]
return []

@classmethod
def scatter_mul(cls, harness):
return [
missing_tf_kernel(
dtypes=[np.complex64],
devices="tpu",
),
]
return []

@classmethod
def select_and_gather_add(cls, harness):
@@ -1067,10 +1050,6 @@ def custom_assert(tst, result_jax, result_tf, *, err_msg, **_):
first_arr_jax[~mask_jax], first_arr_tf[~mask_tf], err_msg=err_msg)

return [
missing_tf_kernel(
dtypes=[np.uint64, np.int64],
devices=("cpu", "gpu"),
modes="compiled"),
custom_numeric(
dtypes=[np.float16, dtypes.bfloat16, np.float32, np.float64],
custom_assert=custom_assert,
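The `custom_numeric(..., custom_assert=..., tol=...)` entries in this file attach a tolerance-aware comparison between the JAX result and the TF result instead of exact equality. A self-contained sketch of that kind of check, using only the standard library and hypothetical helper names:

```python
import math

def assert_allclose(result_jax, result_tf, tol, err_msg=""):
    """Compare two flat sequences of floats element-wise within tol."""
    for a, b in zip(result_jax, result_tf):
        if not math.isclose(a, b, rel_tol=tol, abs_tol=tol):
            raise AssertionError(f"{a} != {b} (tol={tol}): {err_msg}")

# Values within tolerance pass silently.
assert_allclose([1.0, 2.0], [1.0000001, 2.0000001], tol=1e-5)
```

The real comparisons additionally handle dtype-specific tolerances and structured outputs (e.g. comparing only `jnp.tril` of a Cholesky result, as in the hunk above); this sketch shows only the core idea.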
5 changes: 0 additions & 5 deletions jax/experimental/jax2tf/tests/primitive_harness.py
@@ -1211,11 +1211,6 @@ def _make_scatter_harness(name,
StaticArg(dimension_numbers)
],
jax_unimplemented=[
Limitation(
"unimplemented",
devices="tpu",
dtypes=np.complex64,
enabled=(f_lax in [lax.scatter_max, lax.scatter_min])),
Limitation(
"unimplemented",
dtypes=np.bool_,
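The deleted hunk used an `enabled=` predicate so that the single `_make_scatter_harness` helper could declare a `jax_unimplemented` limitation only for certain scatter variants. A simplified sketch of that gating pattern, with hypothetical names and descriptions:

```python
def make_limitations(op_name):
    """Return only the limitation descriptions enabled for this op."""
    candidates = [
        # (description, enabled-predicate evaluated at harness-construction time)
        ("unimplemented on tpu for complex64",
         op_name in ("scatter_max", "scatter_min")),
        ("unimplemented for bool", True),
    ]
    return [desc for desc, enabled in candidates if enabled]

print(make_limitations("scatter_max"))  # both limitations apply
print(make_limitations("scatter_add"))  # only the bool limitation applies
```

Removing the complex64/TPU entry from the candidate list (as this commit does) makes `scatter_max` and `scatter_min` harnesses run on TPU for complex64 instead of being marked unimplemented.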
