[mhlo] Clean up ops that can use InferTensorTypeWithReify
This means we can get rid of custom builders and return type inference. This
all goes through inferReturnTypeComponents now, so fix obvious bugs in those
implementations.

There should be no behavioral change. However, the Python bindings no longer
generate a result type builder for mhlo.CompareOp, which is unfortunate.

PiperOrigin-RevId: 438341237
d0k authored and jax authors committed Mar 30, 2022
1 parent b1a50fd commit a04b777
Showing 3 changed files with 61 additions and 80 deletions.
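
For context, here is a minimal sketch (not part of the commit; x and y are assumed to be ranked-tensor-typed MLIR values, with the mhlo and ir bindings imported as in the files below) of what the shared _compare_mhlo helper introduced here saves each lowering rule from writing by hand:

# Before: every call site built the i1 result type and picked the attribute
# kinds according to the MLIR API version itself.
tensor_type = ir.RankedTensorType(x.type)
dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
eq = mhlo.CompareOp(bool_shape, x, y,
                    mhlo.ComparisonDirectionAttr.get("EQ"),
                    mhlo.ComparisonTypeAttr.get("SIGNED")).result

# After: call sites go through the helper; with mlir_api_version >= 5 the
# result type is inferred through inferReturnTypeComponents, so no explicit
# bool result type is constructed by hand.
eq = _compare_mhlo(x, y, "EQ", "SIGNED").result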
60 changes: 22 additions & 38 deletions jax/_src/lax/lax.py
@@ -1576,26 +1576,28 @@ def _sign_translation_rule(ctx, avals_in, avals_out, x):
sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule)
ad.defjvp_zero(sign_p)

+def _compare_mhlo(x, y, direction, type):
+  """Creates mhlo.CompareOp."""
+  if jax._src.lib.mlir_api_version >= 5:
+    return mhlo.CompareOp(x, y, mhlo.ComparisonDirectionAttr.get(direction),
+                          mhlo.ComparisonTypeAttr.get(type))
+  tensor_type = ir.RankedTensorType(x.type)
+  dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
+  bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
+  if jax._src.lib.mlir_api_version >= 3:
+    return mhlo.CompareOp(bool_shape, x, y,
+                          mhlo.ComparisonDirectionAttr.get(direction),
+                          mhlo.ComparisonTypeAttr.get(type))
+  return mhlo.CompareOp(bool_shape, x, y, ir.StringAttr.get(direction),
+                        ir.StringAttr.get(type))
+
def _sign_lower_mhlo(ctx, x):
  x_aval, = ctx.avals_in
  if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
-    if jax._src.lib.mlir_api_version >= 3:
-      return mhlo.SelectOp(
-          mhlo.CompareOp(
-              mlir.aval_to_ir_type(x_aval.update(dtype=np.dtype(np.bool_))), x,
-              mlir.full_like_aval(0, x_aval),
-              mhlo.ComparisonDirectionAttr.get('EQ'),
-              mhlo.ComparisonTypeAttr.get('UNSIGNED')).result,
-          mlir.full_like_aval(0, x_aval), mlir.full_like_aval(1,
-                                                              x_aval)).results
-    else:
-      return mhlo.SelectOp(
-          mhlo.CompareOp(
-              mlir.aval_to_ir_type(x_aval.update(dtype=np.dtype(np.bool_))), x,
-              mlir.full_like_aval(0, x_aval), ir.StringAttr.get('EQ'),
-              ir.StringAttr.get('UNSIGNED')).result,
-          mlir.full_like_aval(0, x_aval), mlir.full_like_aval(1,
-                                                              x_aval)).results
+    return mhlo.SelectOp(
+        _compare_mhlo(x, mlir.full_like_aval(0, x_aval), 'EQ',
+                      'UNSIGNED').result, mlir.full_like_aval(0, x_aval),
+        mlir.full_like_aval(1, x_aval)).results
  return mhlo.SignOp(x).results

mlir.register_lowering(sign_p, _sign_lower_mhlo)
@@ -2226,15 +2228,7 @@ def _compare_lower_mhlo(direction: str, ctx, x, y):
compare_type = "SIGNED"
else:
compare_type = "UNSIGNED"
if jax._src.lib.mlir_api_version >= 3:
return mhlo.CompareOp(
mlir.aval_to_ir_type(aval_out), x, y,
mhlo.ComparisonDirectionAttr.get(direction),
mhlo.ComparisonTypeAttr.get(compare_type)).results
else:
return mhlo.CompareOp(
mlir.aval_to_ir_type(aval_out), x, y, ir.StringAttr.get(direction),
ir.StringAttr.get(compare_type)).results
return _compare_mhlo(x, y, direction, compare_type).results

eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq')
ad.defjvp_zero(eq_p)
@@ -3354,8 +3348,6 @@ def _select_mhlo_lowering(ctx, which, *cases):
if len(cases) == 1: return cases
return mhlo.SelectOp(which, cases[1], cases[0]).results

-  bool_shape = ir.RankedTensorType.get(which_aval.shape,
-                                       ir.IntegerType.get_signless(1))
if dtypes.issubdtype(which_aval.dtype, np.signedinteger):
compare_type = 'SIGNED'
else:
@@ -3367,16 +3359,8 @@ def _select(offset, cases):
if len(cases) == 1:
return cases[0]
mid = len(cases) // 2
-    if jax._src.lib.mlir_api_version >= 3:
-      pred = mhlo.CompareOp(bool_shape, which,
-                            mlir.full_like_aval(offset + mid, which_aval),
-                            mhlo.ComparisonDirectionAttr.get(lt),
-                            mhlo.ComparisonTypeAttr.get(compare_type))
-    else:
-      pred = mhlo.CompareOp(bool_shape, which,
-                            mlir.full_like_aval(offset + mid, which_aval),
-                            ir.StringAttr.get(lt),
-                            ir.StringAttr.get(compare_type))
+    pred = _compare_mhlo(which, mlir.full_like_aval(offset + mid, which_aval),
+                         lt, compare_type)
return mhlo.SelectOp(pred, _select(offset, cases[:mid]),
_select(offset + mid, cases[mid:])).result

54 changes: 22 additions & 32 deletions jax/interpreters/mlir.py
@@ -828,35 +828,32 @@ def add_jaxvals_lowering(ctx, x, y):
register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])


+def _compare_mhlo(x, y, direction, type):
+  """Creates mhlo.CompareOp."""
+  if jax._src.lib.mlir_api_version >= 5:
+    return mhlo.CompareOp(x, y, mhlo.ComparisonDirectionAttr.get(direction),
+                          mhlo.ComparisonTypeAttr.get(type))
+  tensor_type = ir.RankedTensorType(x.type)
+  dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
+  bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
+  if jax._src.lib.mlir_api_version >= 3:
+    return mhlo.CompareOp(bool_shape, x, y,
+                          mhlo.ComparisonDirectionAttr.get(direction),
+                          mhlo.ComparisonTypeAttr.get(type))
+  return mhlo.CompareOp(bool_shape, x, y, ir.StringAttr.get(direction),
+                        ir.StringAttr.get(type))
+
def _minmax_mhlo(op, cmp, x, y):
"""Min/max that compares complex values lexicographically as pairs."""
tensor_type = ir.RankedTensorType(x.type)
if ir.ComplexType.isinstance(tensor_type.element_type):
rx = mhlo.RealOp(x).result
ry = mhlo.RealOp(y).result
-    dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
-    bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
-    if jax._src.lib.mlir_api_version >= 3:
-      real_eq = mhlo.CompareOp(bool_shape, rx, ry,
-                               mhlo.ComparisonDirectionAttr.get("EQ"),
-                               mhlo.ComparisonTypeAttr.get("FLOAT"))
-      real_cmp = mhlo.CompareOp(bool_shape, rx, ry,
-                                mhlo.ComparisonDirectionAttr.get(cmp),
-                                mhlo.ComparisonTypeAttr.get("FLOAT"))
-      imag_cmp = mhlo.CompareOp(bool_shape,
-                                mhlo.ImagOp(x).result,
-                                mhlo.ImagOp(y).result,
-                                mhlo.ComparisonDirectionAttr.get(cmp),
-                                mhlo.ComparisonTypeAttr.get("FLOAT"))
-    else:
-      real_eq = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get("EQ"),
-                               ir.StringAttr.get("FLOAT"))
-      real_cmp = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get(cmp),
-                                ir.StringAttr.get("FLOAT"))
-      imag_cmp = mhlo.CompareOp(bool_shape,
-                                mhlo.ImagOp(x).result,
-                                mhlo.ImagOp(y).result, ir.StringAttr.get(cmp),
-                                ir.StringAttr.get("FLOAT"))
+    real_eq = _compare_mhlo(rx, ry, "EQ", "FLOAT")
+    real_cmp = _compare_mhlo(rx, ry, cmp, "FLOAT")
+    imag_cmp = _compare_mhlo(
+        mhlo.ImagOp(x).result,
+        mhlo.ImagOp(y).result, cmp, "FLOAT")
    which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
    return mhlo.SelectOp(which, x, y)
  else:
@@ -878,15 +875,8 @@ def convert_mhlo(x, aval_in, aval_out):
compare_type = "SIGNED"
else:
compare_type = "UNSIGNED"
if jax._src.lib.mlir_api_version >= 3:
return mhlo.CompareOp(
aval_to_ir_type(aval_out), x, full_like_aval(0, aval_in),
mhlo.ComparisonDirectionAttr.get("NE"),
mhlo.ComparisonTypeAttr.get(compare_type)).result
else:
return mhlo.CompareOp(
aval_to_ir_type(aval_out), x, full_like_aval(0, aval_in),
ir.StringAttr.get("NE"), ir.StringAttr.get(compare_type)).result
return _compare_mhlo(x, full_like_aval(0, aval_in), "NE",
compare_type).result
return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result

def _wrap_with_spmd_op(name: str,
27 changes: 17 additions & 10 deletions jax/interpreters/pxla.py
@@ -1729,6 +1729,22 @@ def _mhlo_shard(aval, axis_env, xs, in_axis):
else:
raise TypeError(aval)


+def _compare_mhlo(x, y, direction, type):
+  """Creates mhlo.CompareOp."""
+  if jax._src.lib.mlir_api_version >= 5:
+    return mhlo.CompareOp(x, y, mhlo.ComparisonDirectionAttr.get(direction),
+                          mhlo.ComparisonTypeAttr.get(type))
+  tensor_type = ir.RankedTensorType(x.type)
+  dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
+  bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
+  if jax._src.lib.mlir_api_version >= 3:
+    return mhlo.CompareOp(bool_shape, x, y,
+                          mhlo.ComparisonDirectionAttr.get(direction),
+                          mhlo.ComparisonTypeAttr.get(type))
+  return mhlo.CompareOp(bool_shape, x, y, ir.StringAttr.get(direction),
+                        ir.StringAttr.get(type))
+
# TODO(b/110096942): more efficient gather
def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
if aval is core.abstract_unit:
@@ -1770,16 +1786,7 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
    # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
    if convert_bool:
      float_zero = mlir.full_like_aval(0, padded_aval)
-      if jax._src.lib.mlir_api_version >= 3:
-        out = mhlo.CompareOp(
-            mlir.aval_to_ir_type(padded_aval.update(dtype=np.dtype(np.bool_))),
-            out, float_zero, mhlo.ComparisonDirectionAttr.get("NE"),
-            mhlo.ComparisonTypeAttr.get("FLOAT")).result
-      else:
-        out = mhlo.CompareOp(
-            mlir.aval_to_ir_type(padded_aval.update(dtype=np.dtype(np.bool_))),
-            out, float_zero, ir.StringAttr.get("NE"),
-            ir.StringAttr.get("FLOAT")).result
+      out = _compare_mhlo(out, float_zero, "NE", "FLOAT").result
    return out
  else:
    raise TypeError(aval)