Skip to content

Commit

Permalink
[jax2tf] Fix the conversion of reduce_sum and reduce_prod for booleans
Browse files Browse the repository at this point in the history
Also update the documentation
  • Loading branch information
gnecula committed Jun 28, 2021
1 parent a50c273 commit 44b9542
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
8 changes: 4 additions & 4 deletions jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Primitives with limited JAX support

*Last generated on: 2021-06-14* (YYYY-MM-DD)
*Last generated on: 2021-06-28* (YYYY-MM-DD)

## Supported data types for primitives

We use a set of 2604 test harnesses to test
We use a set of 2668 test harnesses to test
the implementation of 121 numeric JAX primitives.
We consider a JAX primitive supported for a particular data
type if it is supported on at least one device type.
Expand Down Expand Up @@ -78,7 +78,7 @@ be updated.
| div | 20 | inexact, integer | bool |
| dot_general | 245 | all | |
| dynamic_slice | 64 | all | |
| dynamic_update_slice | 21 | all | |
| dynamic_update_slice | 42 | all | |
| eig | 72 | inexact | bool, integer |
| eigh | 36 | inexact | bool, integer |
| eq | 17 | all | |
Expand All @@ -89,7 +89,7 @@ be updated.
| expm1 | 6 | inexact | bool, integer |
| fft | 20 | complex, float32, float64 | bfloat16, bool, float16, integer |
| floor | 4 | floating | bool, complex, integer |
| gather | 37 | all | |
| gather | 80 | all | |
| ge | 17 | all | |
| gt | 17 | all | |
| igamma | 6 | floating | bool, complex, integer |
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Primitives with limited support for jax2tf

*Last generated on (YYYY-MM-DD): 2021-06-15*
*Last generated on (YYYY-MM-DD): 2021-06-28*

This document summarizes known limitations of the jax2tf conversion.
There are several kinds of limitations.
Expand Down Expand Up @@ -71,6 +71,7 @@ More detailed information can be found in the
| conv_general_dilated | TF error: jax2tf BUG: batch_group_count > 1 not yet converted | all | cpu, gpu, tpu | compiled, eager, graph |
| digamma | TF error: op not defined for dtype | bfloat16 | cpu, gpu | eager, graph |
| div | TF error: TF integer division fails if divisor contains 0; JAX returns NaN | integer | cpu, gpu, tpu | compiled, eager, graph |
| dot_general | TF error: Numeric comparision disabled: Non-deterministic NaN for dot_general with preferred_element_type on GPU (b/189287598) | bfloat16, complex64, float16, float32 | gpu | compiled, eager, graph |
| dot_general | TF test skipped: Not implemented in JAX: preferred_element_type=c128 not implemented | complex64 | tpu | compiled, eager, graph |
| dot_general | TF test skipped: Not implemented in JAX: preferred_element_type=i64 not implemented | int16, int32, int8 | tpu | compiled, eager, graph |
| dot_general | TF error: op not defined for dtype | bool | cpu, gpu, tpu | compiled, eager, graph |
Expand Down
7 changes: 3 additions & 4 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,10 +1805,9 @@ def _transpose(operand, *, permutation):

axes_to_axis = lambda func: lambda operand, axes: func(operand, axis=axes)

tf_impl[lax.reduce_sum_p] = (
bool_to_int8(axes_to_axis(tf.reduce_sum), argnums=[0]))
tf_impl[lax.reduce_prod_p] = (
bool_to_int8(axes_to_axis(tf.reduce_prod), argnums=[0]))
# reduce_sum and reduce_prod are not supported for bool
tf_impl[lax.reduce_sum_p] = axes_to_axis(tf.reduce_sum)
tf_impl[lax.reduce_prod_p] = axes_to_axis(tf.reduce_prod)
tf_impl[lax.reduce_max_p] = (
bool_to_int8(axes_to_axis(tf.reduce_max), argnums=[0]))
tf_impl[lax.reduce_min_p] = (
Expand Down

0 comments on commit 44b9542

Please sign in to comment.