Skip to content

Commit

Permalink
Fix test failure for jax.numpy.signbit(bfloat16) on TPU. (#1735)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp committed Nov 21, 2019
1 parent c60f3fd commit 2b0cde3
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions jax/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,9 +571,7 @@ def exp2(x):
@_wraps(onp.signbit)
def signbit(x):
x, = _promote_shapes("signbit", x)

dtype = _dtype(x)

if issubdtype(dtype, integer):
return lax.lt(x, _constant_like(x, 0))
elif issubdtype(dtype, bool_):
Expand All @@ -582,8 +580,13 @@ def signbit(x):
raise ValueError(
"jax.numpy.signbit is not well defined for %s" % dtype)

info = finfo(_dtype(x))
# TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
# F32.
if dtype == bfloat16:
dtype = float32
x = lax.convert_element_type(x, float32)

info = finfo(dtype)
if info.bits == 16:
int_type = onp.int16
elif info.bits == 32:
Expand Down

0 comments on commit 2b0cde3

Please sign in to comment.