Skip to content

Commit

Permalink
[jax2tf] Ensure that the conversion of Round uses the same algorithm …
Browse files Browse the repository at this point in the history
…as JAX
  • Loading branch information
gnecula committed Jul 5, 2021
1 parent fbb9882 commit 2fe0226
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 17 deletions.
6 changes: 3 additions & 3 deletions jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md
@@ -1,10 +1,10 @@
# Primitives with limited JAX support

*Last generated on: 2021-06-28* (YYYY-MM-DD)
*Last generated on: 2021-07-05* (YYYY-MM-DD)

## Supported data types for primitives

We use a set of 2668 test harnesses to test
We use a set of 2667 test harnesses to test
the implementation of 121 numeric JAX primitives.
We consider a JAX primitive supported for a particular data
type if it is supported on at least one device type.
Expand Down Expand Up @@ -132,7 +132,7 @@ be updated.
| rem | 18 | floating, integer | bool, complex |
| reshape | 19 | all | |
| rev | 19 | all | |
| round | 7 | floating | bool, complex, integer |
| round | 6 | floating | bool, complex, integer |
| rsqrt | 6 | inexact | bool, integer |
| scatter_add | 15 | all | |
| scatter_max | 15 | all | |
Expand Down
@@ -1,6 +1,6 @@
# Primitives with limited support for jax2tf

*Last generated on (YYYY-MM-DD): 2021-06-28*
*Last generated on (YYYY-MM-DD): 2021-07-05*

This document summarizes known limitations of the jax2tf conversion.
There are several kinds of limitations.
Expand Down
14 changes: 9 additions & 5 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -1033,8 +1033,11 @@ def _sign(x: TfVal) -> TfVal:
tf_impl[lax.ceil_p] = tf.math.ceil


def _round(operand, *, rounding_method):
def _round(operand, *, rounding_method,
_in_avals: Sequence[core.AbstractValue],
_out_aval: core.AbstractValue):
if rounding_method is lax.RoundingMethod.AWAY_FROM_ZERO:
# JAX uses a single HLO op Round here
sign = _sign(operand)
operand *= sign
floor = tf.math.floor(operand)
Expand All @@ -1043,11 +1046,12 @@ def _round(operand, *, rounding_method):
return sign * (
tf.where(cond, tf.constant(np.array(1), operand.dtype),
tf.math.round(operand)) + floor)
else:
return tf.math.round(operand)

else: # rounding_method is RoundingMethod.TO_NEAREST_EVEN
rounding_fun = _convert_jax_impl(
lax._round_to_nearest_even, multiple_results=False)
return rounding_fun(operand, _in_avals=_in_avals, _out_aval=_out_aval)

tf_impl[lax.round_p] = _round
tf_impl_with_avals[lax.round_p] = _round
tf_impl[lax.nextafter_p] = tf.math.nextafter


Expand Down
8 changes: 0 additions & 8 deletions jax/experimental/jax2tf/tests/primitive_harness.py
Expand Up @@ -481,14 +481,6 @@ def _make_round_harness(name,
_make_round_harness(
"rounding_methods", operand=operand, rounding_method=rounding_method)

# Validate edge cases
for name, operand in [
# Checks that https://github.com/google/jax/issues/4952 is resolved
("round_away_from_0",
np.array([[0.5, 1.5, 2.5], [-0.5, -1.5, -2.5]], dtype=np.float32)),
]:
_make_round_harness(f"edge_case_{name}", operand=operand)


def _make_convert_element_type_harness(name,
*,
Expand Down

0 comments on commit 2fe0226

Please sign in to comment.