Skip to content

Commit

Permalink
Merge pull request #10095 from hawkinsp:mlir
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 438400696
  • Loading branch information
jax authors committed Mar 30, 2022
2 parents ee1ca3f + ade9f1a commit 0694dbd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 27 deletions.
27 changes: 7 additions & 20 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -1576,27 +1576,13 @@ def _sign_translation_rule(ctx, avals_in, avals_out, x):
sign_p = standard_unop(_num, 'sign', translation_rule=_sign_translation_rule)
ad.defjvp_zero(sign_p)

def _compare_mhlo(x, y, direction, type):
"""Creates mhlo.CompareOp."""
if jax._src.lib.mlir_api_version >= 5:
return mhlo.CompareOp(x, y, mhlo.ComparisonDirectionAttr.get(direction),
mhlo.ComparisonTypeAttr.get(type))
tensor_type = ir.RankedTensorType(x.type)
dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
if jax._src.lib.mlir_api_version >= 3:
return mhlo.CompareOp(bool_shape, x, y,
mhlo.ComparisonDirectionAttr.get(direction),
mhlo.ComparisonTypeAttr.get(type))
return mhlo.CompareOp(bool_shape, x, y, ir.StringAttr.get(direction),
ir.StringAttr.get(type))

def _sign_lower_mhlo(ctx, x):
x_aval, = ctx.avals_in
if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
return mhlo.SelectOp(
_compare_mhlo(x, mlir.full_like_aval(0, x_aval), 'EQ',
'UNSIGNED').result, mlir.full_like_aval(0, x_aval),
mlir.compare_mhlo(x, mlir.full_like_aval(0, x_aval), 'EQ',
'UNSIGNED').result,
mlir.full_like_aval(0, x_aval),
mlir.full_like_aval(1, x_aval)).results
return mhlo.SignOp(x).results

Expand Down Expand Up @@ -2228,7 +2214,7 @@ def _compare_lower_mhlo(direction: str, ctx, x, y):
compare_type = "SIGNED"
else:
compare_type = "UNSIGNED"
return _compare_mhlo(x, y, direction, compare_type).results
return mlir.compare_mhlo(x, y, direction, compare_type).results

eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq')
ad.defjvp_zero(eq_p)
Expand Down Expand Up @@ -3359,8 +3345,9 @@ def _select(offset, cases):
if len(cases) == 1:
return cases[0]
mid = len(cases) // 2
pred = _compare_mhlo(which, mlir.full_like_aval(offset + mid, which_aval),
lt, compare_type)
pred = mlir.compare_mhlo(which,
mlir.full_like_aval(offset + mid, which_aval),
lt, compare_type)
return mhlo.SelectOp(pred, _select(offset, cases[:mid]),
_select(offset + mid, cases[mid:])).result

Expand Down
13 changes: 6 additions & 7 deletions jax/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,13 +828,12 @@ def add_jaxvals_lowering(ctx, x, y):
register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])


def _compare_mhlo(x, y, direction, type):
def compare_mhlo(x, y, direction, type):
"""Creates mhlo.CompareOp."""
if jax._src.lib.mlir_api_version >= 5:
return mhlo.CompareOp(x, y, mhlo.ComparisonDirectionAttr.get(direction),
mhlo.ComparisonTypeAttr.get(type))
tensor_type = ir.RankedTensorType(x.type)
dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
dims = ir.RankedTensorType(x.type).shape
bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
if jax._src.lib.mlir_api_version >= 3:
return mhlo.CompareOp(bool_shape, x, y,
Expand All @@ -849,9 +848,9 @@ def _minmax_mhlo(op, cmp, x, y):
if ir.ComplexType.isinstance(tensor_type.element_type):
rx = mhlo.RealOp(x).result
ry = mhlo.RealOp(y).result
real_eq = _compare_mhlo(rx, ry, "EQ", "FLOAT")
real_cmp = _compare_mhlo(rx, ry, cmp, "FLOAT")
imag_cmp = _compare_mhlo(
real_eq = compare_mhlo(rx, ry, "EQ", "FLOAT")
real_cmp = compare_mhlo(rx, ry, cmp, "FLOAT")
imag_cmp = compare_mhlo(
mhlo.ImagOp(x).result,
mhlo.ImagOp(y).result, cmp, "FLOAT")
which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
Expand All @@ -875,7 +874,7 @@ def convert_mhlo(x, aval_in, aval_out):
compare_type = "SIGNED"
else:
compare_type = "UNSIGNED"
return _compare_mhlo(x, full_like_aval(0, aval_in), "NE",
return compare_mhlo(x, full_like_aval(0, aval_in), "NE",
compare_type).result
return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result

Expand Down

0 comments on commit 0694dbd

Please sign in to comment.