Skip to content

Commit

Permalink
[jax2tf] Fix bfloat16 bug in select_and_gather_add conversion. (#4058)
Browse files Browse the repository at this point in the history
* [jax2tf] Fix bfloat16 bug in select_and_gather_add conversion.

This fix makes it possible to run bfloat16 tests for the jax2tf
conversion of select_and_gather_add.
  • Loading branch information
bchetioui committed Aug 17, 2020
1 parent c7aff1d commit ec90c35
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 5 deletions.
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,8 +799,8 @@ def _select_and_gather_add(tangents: TfVal,
padding: Sequence[Tuple[int, int]]):
# Note: this function follows the pattern in
# jax.lax._select_and_gather_add_translation.
dtype = to_jax_dtype(operand.dtype)
nbits = dtypes.finfo(dtype).bits
dtype = operand.dtype
nbits = dtypes.finfo(dtype.as_numpy_dtype).bits

# Specializing the function for 64 bits. Only up to 32 bits are supported on TPU,
# we thus intend to let the code throw a different exception on this platform.
Expand Down
3 changes: 0 additions & 3 deletions jax/experimental/jax2tf/tests/primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,6 @@ def _reconstruct_operand(result, is_tf: bool):
def test_select_and_gather_add(self, harness: primitive_harness.Harness):
dtype = harness.params["dtype"]

if dtype is dtypes.bfloat16:
raise unittest.SkipTest("bfloat16 not implemented")

max_bits = 64
if jtu.device_under_test() == "tpu":
max_bits = 32
Expand Down

0 comments on commit ec90c35

Please sign in to comment.