Skip to content

Commit

Permalink
Use totalorder comparisons for sort
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 573289718
  • Loading branch information
majnemer authored and jax authors committed Oct 13, 2023
1 parent c568110 commit 8fe4fcc
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 43 deletions.
90 changes: 47 additions & 43 deletions jax/_src/lax/lax.py
Expand Up @@ -2247,7 +2247,7 @@ def _opaque_comparison_hlo(direction, reduction_op, identity, ctx,
base_aval_out = core.ShapedArray(base_aval_x.shape, aval_out.dtype)
reduce_axes = tuple(range(aval_out.ndim, base_aval_out.ndim))
res, = mlir.delegate_lowering(
ctx, partial(_compare_lower_hlo, direction),
ctx, partial(_compare_lower_hlo, direction, False),
x, y, avals_in=[base_aval_x, base_aval_y], avals_out=[base_aval_out])
return mlir.delegate_lowering(
ctx, partial(_unary_reduce_lower, reduction_op, identity,
Expand All @@ -2270,14 +2270,16 @@ def _compare_lower_hlo_opaque(direction: str, ctx, avals_in, aval_out, x, y):
raise NotImplementedError(
f"HLO comparison {direction} for extended dtype {avals_in[0].dtype}")

def _compare_lower_hlo(direction: str, ctx, x, y):

def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y):
avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
x_dtype = avals_in[0].dtype
x, y = mlir.multi_broadcast_in_dim(ctx, (x, y), avals_in, aval_out.shape)
if dtypes.issubdtype(x_dtype, dtypes.extended):
assert not total_order
return _compare_lower_hlo_opaque(direction, ctx, avals_in, aval_out, x, y)
if dtypes.issubdtype(x_dtype, np.inexact):
compare_type = "FLOAT"
compare_type = "TOTALORDER" if total_order else "FLOAT"
elif dtypes.issubdtype(x_dtype, np.signedinteger):
compare_type = "SIGNED"
else:
Expand All @@ -2286,27 +2288,39 @@ def _compare_lower_hlo(direction: str, ctx, x, y):

eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True)
ad.defjvp_zero(eq_p)
mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ"))
mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ", False))

ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne', allow_extended_dtype=True)
ad.defjvp_zero(ne_p)
mlir.register_lowering(ne_p, partial(_compare_lower_hlo, "NE"))
mlir.register_lowering(ne_p, partial(_compare_lower_hlo, "NE", False))

ge_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'ge')
ad.defjvp_zero(ge_p)
mlir.register_lowering(ge_p, partial(_compare_lower_hlo, "GE"))
mlir.register_lowering(ge_p, partial(_compare_lower_hlo, "GE", False))

gt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'gt')
ad.defjvp_zero(gt_p)
mlir.register_lowering(gt_p, partial(_compare_lower_hlo, "GT"))
mlir.register_lowering(gt_p, partial(_compare_lower_hlo, "GT", False))

le_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'le')
ad.defjvp_zero(le_p)
mlir.register_lowering(le_p, partial(_compare_lower_hlo, "LE"))
mlir.register_lowering(le_p, partial(_compare_lower_hlo, "LE", False))

lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt')
ad.defjvp_zero(lt_p)
mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT"))
mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT", False))

# Total-order variants of the eq/le/lt comparison primitives. They are lowered
# through `_compare_lower_hlo` with total_order=True, which selects the
# "TOTALORDER" compare type (instead of "FLOAT") for inexact dtypes.
# NOTE: unlike eq_p/ne_p these are not declared with allow_extended_dtype;
# `_compare_lower_hlo` asserts that total_order is False for extended dtypes.
eq_to_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq_to')
ad.defjvp_zero(eq_to_p)
mlir.register_lowering(eq_to_p, partial(_compare_lower_hlo, "EQ", True))

le_to_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'le_to')
ad.defjvp_zero(le_to_p)
mlir.register_lowering(le_to_p, partial(_compare_lower_hlo, "LE", True))

lt_to_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt_to')
ad.defjvp_zero(lt_to_p)
mlir.register_lowering(lt_to_p, partial(_compare_lower_hlo, "LT", True))


def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type):
Expand Down Expand Up @@ -4030,43 +4044,33 @@ def _sort_abstract_eval(*args, **kwargs):
return args


def _float_to_int_for_sort(x):
# Switch from a floating point value to a integer value in such a way that
# when using the integer value to compare, we get the same result for normal
# values, and -nan is treated as the smallest value, and nan is treated as
# the largest value.
# If f is a float, and
# x = bit_cast<int32>(f);
# y = x < 0 ? int32_max - x : x;
# then y is ordered as an int32 such that finite values have the obvious
# order. In this scheme, -0 would be before 0, and -NaN and NaN appear at
def _canonicalize_float_for_sort(x):
# In the sort comparator, we are going to use a comparison operator where -0
# would be before 0, and -NaN and NaN appear at the beginning and end of the
# ordering. This causes issues for stable sorts, so we avoid this by
# standardizing the representation of zeros and NaNs in the output.
# Note that in order to avoid -x to overflow, we calculate
# int32_max - x as unsigned, and then convert back to signed.
if x.dtype == dtypes.bfloat16:
x = convert_element_type(x, np.float32)
nbits = np.finfo(x).bits
signed_dtype = _INT_DTYPES[nbits]
nbits = dtypes.finfo(x.dtype).bits
unsigned_dtype = _UINT_DTYPES[nbits]

signed = bitcast_convert_type(x, signed_dtype)
unsigned = bitcast_convert_type(x, unsigned_dtype)

# We cannot standardize zeros in x because XLA elides this in some cases.
# We cannot standardize NaNs in x because it triggers jax.debug_nans.
# So instead we do these replacements in the signed integer representation.
# So instead we do these replacements in the unsigned integer representation.

unsigned = bitcast_convert_type(x, unsigned_dtype)
unsigned_zero = x.dtype.type(0.0).view(unsigned_dtype)
unsigned_nan = x.dtype.type(np.nan).view(unsigned_dtype)

# Standardize zeros:
signed = select(eq(x, _zero(x)), _zeros(signed), signed)
unsigned = select(eq(x, _zero(x)), full_like(unsigned, unsigned_zero), unsigned)

# Standardize nans:
signed_nan = x.dtype.type(np.nan).view(signed_dtype)
signed = select(_isnan(x), full_like(signed, signed_nan), signed)
unsigned = select(_isnan(x), full_like(unsigned, unsigned_nan), unsigned)

# Convert back to a floating-point representation.
return bitcast_convert_type(unsigned, x.dtype)

flipped = bitcast_convert_type(
sub(unsigned_dtype.type(np.iinfo(signed_dtype).max), unsigned), signed_dtype)
return select(lt(signed, _zero(signed)), flipped, signed)

# Default comparator that sorts the operands lexicographically on the
# first `num_keys` arguments.
Expand All @@ -4081,8 +4085,8 @@ def _sort_lt_comparator(*operands, num_keys=1):
x_keys, y_keys = _operands_to_keys(*operands, num_keys=num_keys)
p = None
for xk, yk in zip(x_keys[::-1], y_keys[::-1]):
p = (bitwise_or(lt(xk, yk), bitwise_and(eq(xk, yk), p)) if p is not None
else lt(xk, yk))
p = (bitwise_or(lt_to_p.bind(xk, yk), bitwise_and(eq_to_p.bind(xk, yk), p)) if p is not None
else lt_to_p.bind(xk, yk))
return p

# Similar to sort_lt_comparator, but implements less than or equal. Used by
Expand All @@ -4091,8 +4095,8 @@ def _sort_le_comparator(*operands, num_keys=1):
x_keys, y_keys = _operands_to_keys(*operands, num_keys=num_keys)
p = None
for xk, yk in zip(x_keys[::-1], y_keys[::-1]):
p = (bitwise_or(lt(xk, yk), bitwise_and(eq(xk, yk), p)) if p is not None
else le(xk, yk))
p = (bitwise_or(lt_to_p.bind(xk, yk), bitwise_and(eq_to_p.bind(xk, yk), p)) if p is not None
else le_to_p.bind(xk, yk))
return p

def _operands_to_keys(*operands, num_keys=1):
Expand All @@ -4102,11 +4106,11 @@ def _operands_to_keys(*operands, num_keys=1):
for x, y in zip(operands[:2*num_keys:2], operands[1:2*num_keys:2]):
assert x.dtype == y.dtype, (x.dtype, y.dtype)
if dtypes.issubdtype(x.dtype, np.complexfloating):
x_keys.extend([_float_to_int_for_sort(real(x)), _float_to_int_for_sort(imag(x))])
y_keys.extend([_float_to_int_for_sort(real(y)), _float_to_int_for_sort(imag(y))])
x_keys.extend([_canonicalize_float_for_sort(real(x)), _canonicalize_float_for_sort(imag(x))])
y_keys.extend([_canonicalize_float_for_sort(real(y)), _canonicalize_float_for_sort(imag(y))])
elif dtypes.issubdtype(x.dtype, np.floating):
x_keys.append(_float_to_int_for_sort(x))
y_keys.append(_float_to_int_for_sort(y))
x_keys.append(_canonicalize_float_for_sort(x))
y_keys.append(_canonicalize_float_for_sort(y))
else:
x_keys.append(x)
y_keys.append(y)
Expand Down
37 changes: 37 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -1901,6 +1901,37 @@ def cast_aval(aval):
tf_impl[lax.eq_p] = tf.math.equal
tf_impl[lax.ne_p] = tf.math.not_equal


def _total_order_adjustment(x):
  """Re-encode floating-point `x` as signed integers that sort in total order.

  Inputs whose dtype is not inexact are returned unchanged. Inexact inputs
  must be real floating point (complex is rejected by the assert); they are
  bit-cast to the signed integer type of the same width and, for values with
  the sign bit set, all non-sign bits are flipped.
  """
  np_dtype = x.dtype.as_numpy_dtype
  if not dtypes.issubdtype(np_dtype, np.inexact):
    return x
  assert dtypes.issubdtype(np_dtype, np.floating)

  # Switch from a floating point value to an integer value in such a way that
  # when using the integer value to compare, we get the same result for normal
  # values, and -nan is treated as the smallest value, and nan is treated as
  # the largest value.
  # If f is a float, and
  #   x = bit_cast<int32>(f);
  #   y = x < 0 ? int32_max - x : x;
  # then y is ordered as an int32 such that finite values have the obvious
  # order. In this scheme, -0 would be before 0, and -NaN and NaN appear at
  # the beginning and end of the ordering.
  nbits = dtypes.finfo(np_dtype).bits
  signed_dtype = lax_internal._INT_DTYPES[nbits]
  unsigned_dtype = lax_internal._UINT_DTYPES[nbits]

  as_signed = tf.bitcast(x, signed_dtype)

  # Shift the sign bit across the whole word while the value is still signed
  # (intended to give all-ones for negative inputs and zero otherwise), then
  # reinterpret as unsigned so that the next shift brings in a zero at the top.
  sign_fill = tf.bitcast(
      tf.bitwise.right_shift(as_signed, nbits - 1), unsigned_dtype)

  # Drop the top bit of the fill: the mask now covers every bit except the
  # sign bit for negative inputs, and is zero for non-negative inputs.
  flip_mask = tf.bitcast(tf.bitwise.right_shift(sign_fill, 1), signed_dtype)

  # XOR flips the non-sign bits of negative values only.
  return tf.bitwise.bitwise_xor(as_signed, flip_mask)

def _total_order_equal(x, y):
  """Elementwise total-order equality of `x` and `y`.

  Complex inputs are compared component-wise: the result is true where both
  the real parts and the imaginary parts are total-order equal. All other
  inputs are passed through `_total_order_adjustment` and compared with
  `tf.math.equal`.
  """
  if dtypes.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating):
    real_equal = _total_order_equal(tf.math.real(x), tf.math.real(y))
    imag_equal = _total_order_equal(tf.math.imag(x), tf.math.imag(y))
    # Combine with an elementwise op. Python's `and` would coerce the first
    # tensor to a Python bool, which is not an elementwise conjunction and is
    # rejected for graph tensors and for tensors with more than one element.
    return tf.math.logical_and(real_equal, imag_equal)
  return tf.math.equal(_total_order_adjustment(x), _total_order_adjustment(y))

tf_impl[lax.eq_to_p] = _total_order_equal

boolean_greater = lambda x,y: tf.logical_and(x, tf.logical_not(y)) # Only one combo: T,F -> T
boolean_less = lambda x,y: tf.logical_and(tf.logical_not(x), y) # Only one combo: F,T -> T
boolean_greater_or_equal = lambda x, y: tf.logical_not(boolean_less(x,y)) # All cases except F,T
Expand All @@ -1911,6 +1942,12 @@ def cast_aval(aval):
tf_impl[lax.ge_p] = handle_boolean_args(tf.math.greater_equal, argnums=(0, 1), boolean_f=boolean_greater_or_equal)
tf_impl[lax.le_p] = handle_boolean_args(tf.math.less_equal, argnums=(0, 1), boolean_f=boolean_less_or_equal)

def _total_order_cond(cond, x, y):
  """Apply the binary predicate `cond` to the total-order-adjusted operands."""
  adjusted_x = _total_order_adjustment(x)
  adjusted_y = _total_order_adjustment(y)
  return cond(adjusted_x, adjusted_y)

tf_impl[lax.lt_to_p] = handle_boolean_args(partial(_total_order_cond, tf.math.less), argnums=(0, 1), boolean_f=boolean_less)
tf_impl[lax.le_to_p] = handle_boolean_args(partial(_total_order_cond, tf.math.less_equal), argnums=(0, 1), boolean_f=boolean_less_or_equal)

tf_impl[lax.linalg.cholesky_p] = tf.linalg.cholesky


Expand Down
3 changes: 3 additions & 0 deletions jax/lax/__init__.py
Expand Up @@ -91,6 +91,7 @@
dtypes as _deprecated_dtypes,
eq as eq,
eq_p as eq_p,
eq_to_p as eq_to_p,
exp as exp,
exp_p as exp_p,
exp2 as exp2,
Expand Down Expand Up @@ -119,6 +120,7 @@
itertools as _deprecated_itertools,
le as le,
le_p as le_p,
le_to_p as le_to_p,
log as log,
log1p as log1p,
log1p_p as log1p_p,
Expand All @@ -127,6 +129,7 @@
logistic_p as logistic_p,
lt as lt,
lt_p as lt_p,
lt_to_p as lt_to_p,
max as max,
max_p as max_p,
min as min,
Expand Down

0 comments on commit 8fe4fcc

Please sign in to comment.