Skip to content

Commit

Permalink
Remove bitwise conversions from _canonicalize_float_for_sort
Browse files Browse the repository at this point in the history
This lets us compute entirely in the float domain.

PiperOrigin-RevId: 576613806
  • Loading branch information
majnemer authored and jax authors committed Oct 25, 2023
1 parent edbe49f commit ba9fd77
Showing 1 changed file with 3 additions and 17 deletions.
20 changes: 3 additions & 17 deletions jax/_src/lax/lax.py
Expand Up @@ -4051,25 +4051,11 @@ def _canonicalize_float_for_sort(x):
# the beginning and end of the ordering. This causes issues for stable
# sorts, so we avoid this by standardizing the representation of zeros
# and NaNs in the output.
nbits = dtypes.finfo(x.dtype).bits
unsigned_dtype = _UINT_DTYPES[nbits]

# We cannot standardize zeros in x because XLA elides this is some cases.
# We cannot standardize NaNs in x because it triggers jax.debug_nans
# So instead we do these replacements in the unsigned integer representation.
result = select(eq(x, _zero(x)), _zeros(x), x)
result = select(_isnan(x), full_like(result, np.nan), result)

unsigned = bitcast_convert_type(x, unsigned_dtype)
unsigned_zero = x.dtype.type(0.0).view(unsigned_dtype)
unsigned_nan = x.dtype.type(np.nan).view(unsigned_dtype)

# Standardize zeros:
unsigned = select(eq(x, _zero(x)), full_like(unsigned, unsigned_zero), unsigned)

# Standardize nans:
unsigned = select(_isnan(x), full_like(unsigned, unsigned_nan), unsigned)

# Convert back to a floating-point representation.
return bitcast_convert_type(unsigned, x.dtype)
return result


# Default comparator that sorts the operands lexicographically on the
Expand Down

0 comments on commit ba9fd77

Please sign in to comment.