diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 1af860d524ee..cce22823a2ed 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1579,13 +1579,23 @@ def _sign_translation_rule(ctx, avals_in, avals_out, x): def _sign_lower_mhlo(ctx, x): x_aval, = ctx.avals_in if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger): - return mhlo.SelectOp( - mhlo.CompareOp( - mlir.aval_to_ir_type(x_aval.update(dtype=np.dtype(np.bool_))), - x, mlir.full_like_aval(0, x_aval), ir.StringAttr.get("EQ"), - ir.StringAttr.get("UNSIGNED")).result, - mlir.full_like_aval(0, x_aval), - mlir.full_like_aval(1, x_aval)).results + if jax._src.lib.mlir_api_version >= 3: + return mhlo.SelectOp( + mhlo.CompareOp( + mlir.aval_to_ir_type(x_aval.update(dtype=np.dtype(np.bool_))), x, + mlir.full_like_aval(0, x_aval), + mhlo.ComparisonDirectionAttr.get('EQ'), + mhlo.ComparisonTypeAttr.get('UNSIGNED')).result, + mlir.full_like_aval(0, x_aval), mlir.full_like_aval(1, + x_aval)).results + else: + return mhlo.SelectOp( + mhlo.CompareOp( + mlir.aval_to_ir_type(x_aval.update(dtype=np.dtype(np.bool_))), x, + mlir.full_like_aval(0, x_aval), ir.StringAttr.get('EQ'), + ir.StringAttr.get('UNSIGNED')).result, + mlir.full_like_aval(0, x_aval), mlir.full_like_aval(1, + x_aval)).results return mhlo.SignOp(x).results mlir.register_lowering(sign_p, _sign_lower_mhlo) @@ -2216,9 +2226,15 @@ def _compare_lower_mhlo(direction: str, ctx, x, y): compare_type = "SIGNED" else: compare_type = "UNSIGNED" - return mhlo.CompareOp(mlir.aval_to_ir_type(aval_out), x, y, - ir.StringAttr.get(direction), - ir.StringAttr.get(compare_type)).results + if jax._src.lib.mlir_api_version >= 3: + return mhlo.CompareOp( + mlir.aval_to_ir_type(aval_out), x, y, + mhlo.ComparisonDirectionAttr.get(direction), + mhlo.ComparisonTypeAttr.get(compare_type)).results + else: + return mhlo.CompareOp( + mlir.aval_to_ir_type(aval_out), x, y, ir.StringAttr.get(direction), + ir.StringAttr.get(compare_type)).results eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq') ad.defjvp_zero(eq_p) @@ -2630,7 +2646,13 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr: full_precision = (precision, precision) else: full_precision = precision - return ir.ArrayAttr.get([ir.StringAttr.get(str(p)) for p in full_precision]) + if jax._src.lib.mlir_api_version >= 3: + return ir.ArrayAttr.get( + [mhlo.PrecisionAttr.get(str(p)) for p in full_precision]) + else: + return ir.ArrayAttr.get([ir.StringAttr.get(str(p)) for p in full_precision]) + + def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, precision, preferred_element_type: Optional[np.dtype]): @@ -3335,19 +3357,26 @@ def _select_mhlo_lowering(ctx, which, *cases): bool_shape = ir.RankedTensorType.get(which_aval.shape, ir.IntegerType.get_signless(1)) if dtypes.issubdtype(which_aval.dtype, np.signedinteger): - compare_type = ir.StringAttr.get("SIGNED") + compare_type = 'SIGNED' else: - compare_type = ir.StringAttr.get("UNSIGNED") - lt = ir.StringAttr.get("LT") + compare_type = 'UNSIGNED' + lt = 'LT' def _select(offset, cases): assert len(cases) > 0 if len(cases) == 1: return cases[0] mid = len(cases) // 2 - pred = mhlo.CompareOp( - bool_shape, which, mlir.full_like_aval(offset + mid, which_aval), - lt, compare_type) + if jax._src.lib.mlir_api_version >= 3: + pred = mhlo.CompareOp(bool_shape, which, + mlir.full_like_aval(offset + mid, which_aval), + mhlo.ComparisonDirectionAttr.get(lt), + mhlo.ComparisonTypeAttr.get(compare_type)) + else: + pred = mhlo.CompareOp(bool_shape, which, + mlir.full_like_aval(offset + mid, which_aval), + ir.StringAttr.get(lt), + ir.StringAttr.get(compare_type)) return mhlo.SelectOp(pred, _select(offset, cases[:mid]), _select(offset + mid, cases[mid:])).result diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 9364b0ebc4dd..7a365ca14a80 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -28,6 +28,7 @@ from typing_extensions import Protocol import warnings +import jax from jax import core from jax import linear_util as lu from jax._src import ad_util @@ -820,15 +821,27 @@ def _minmax_mhlo(op, cmp, x, y): ry = mhlo.RealOp(y).result dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)] bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1)) - real_eq = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get("EQ"), - ir.StringAttr.get("FLOAT")) - real_cmp = mhlo.CompareOp(bool_shape, rx, ry, - ir.StringAttr.get(cmp), - ir.StringAttr.get("FLOAT")) - imag_cmp = mhlo.CompareOp(bool_shape, mhlo.ImagOp(x).result, - mhlo.ImagOp(y).result, - ir.StringAttr.get(cmp), - ir.StringAttr.get("FLOAT")) + if jax._src.lib.mlir_api_version >= 3: + real_eq = mhlo.CompareOp(bool_shape, rx, ry, + mhlo.ComparisonDirectionAttr.get("EQ"), + mhlo.ComparisonTypeAttr.get("FLOAT")) + real_cmp = mhlo.CompareOp(bool_shape, rx, ry, + mhlo.ComparisonDirectionAttr.get(cmp), + mhlo.ComparisonTypeAttr.get("FLOAT")) + imag_cmp = mhlo.CompareOp(bool_shape, + mhlo.ImagOp(x).result, + mhlo.ImagOp(y).result, + mhlo.ComparisonDirectionAttr.get(cmp), + mhlo.ComparisonTypeAttr.get("FLOAT")) + else: + real_eq = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get("EQ"), + ir.StringAttr.get("FLOAT")) + real_cmp = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get(cmp), + ir.StringAttr.get("FLOAT")) + imag_cmp = mhlo.CompareOp(bool_shape, + mhlo.ImagOp(x).result, + mhlo.ImagOp(y).result, ir.StringAttr.get(cmp), + ir.StringAttr.get("FLOAT")) which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result return mhlo.SelectOp(which, x, y) else: @@ -850,9 +863,15 @@ def convert_mhlo(x, aval_in, aval_out): compare_type = "SIGNED" else: compare_type = "UNSIGNED" - return mhlo.CompareOp( - aval_to_ir_type(aval_out), x, full_like_aval(0, aval_in), - ir.StringAttr.get("NE"), ir.StringAttr.get(compare_type)).result + if jax._src.lib.mlir_api_version >= 3: + return mhlo.CompareOp( + aval_to_ir_type(aval_out), x, full_like_aval(0, aval_in), + mhlo.ComparisonDirectionAttr.get("NE"), + mhlo.ComparisonTypeAttr.get(compare_type)).result + else: + return mhlo.CompareOp( + aval_to_ir_type(aval_out), x, full_like_aval(0, aval_in), + ir.StringAttr.get("NE"), ir.StringAttr.get(compare_type)).result return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result def _wrap_with_spmd_op(name: str, diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 83cf78c1f148..a06d07931ab0 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -44,6 +44,7 @@ from absl import logging import numpy as np +import jax from jax._src.config import config from jax import core from jax import linear_util as lu @@ -1753,10 +1754,16 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform): # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU if convert_bool: float_zero = mlir.full_like_aval(0, padded_aval) - out = mhlo.CompareOp( - mlir.aval_to_ir_type(padded_aval.update(dtype=np.dtype(np.bool_))), - out, float_zero, ir.StringAttr.get("NE"), - ir.StringAttr.get("FLOAT")).result + if jax._src.lib.mlir_api_version >= 3: + out = mhlo.CompareOp( + mlir.aval_to_ir_type(padded_aval.update(dtype=np.dtype(np.bool_))), + out, float_zero, mhlo.ComparisonDirectionAttr.get("NE"), + mhlo.ComparisonTypeAttr.get("FLOAT")).result + else: + out = mhlo.CompareOp( + mlir.aval_to_ir_type(padded_aval.update(dtype=np.dtype(np.bool_))), + out, float_zero, ir.StringAttr.get("NE"), + ir.StringAttr.get("FLOAT")).result return out else: raise TypeError(aval) diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index d4384fb54ed1..f7f4957c7ff5 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -199,29 +199,29 @@ def main(_): # CHECK-LABEL: TEST: eq float32[] float32[] # CHECK: mhlo.compare - # CHECK-SAME: compare_type = "FLOAT" - # CHECK-SAME: comparison_direction = "EQ" + # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT"> + # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ"> # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.eq) # CHECK-LABEL: TEST: eq complex128[] complex128[] # CHECK: mhlo.compare - # CHECK-SAME: compare_type = "FLOAT" - # CHECK-SAME: comparison_direction = "EQ" + # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT"> + # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ"> # CHECK-SAME: tensor> print_ir(np.complex128(1), np.complex128(2))(lax.eq) # CHECK-LABEL: TEST: eq int64[] int64[] # CHECK: mhlo.compare - # CHECK-SAME: compare_type = "SIGNED" - # CHECK-SAME: comparison_direction = "EQ" + # CHECK-SAME: compare_type = #mhlo<"comparison_type SIGNED"> + # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ"> # CHECK-SAME: tensor print_ir(np.int64(1), np.int64(2))(lax.eq) # CHECK-LABEL: TEST: eq uint16[] uint16[] # CHECK: mhlo.compare - # CHECK-SAME: compare_type = "UNSIGNED" - # CHECK-SAME: comparison_direction = "EQ" + # CHECK-SAME: compare_type = #mhlo<"comparison_type UNSIGNED"> + # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ"> # CHECK-SAME: tensor print_ir(np.uint16(1), np.uint16(2))(lax.eq) @@ -257,15 +257,15 @@ def main(_): # CHECK-LABEL: TEST: ge float32[] float32[] # CHECK: mhlo.compare - # CHECK-SAME: compare_type = "FLOAT" - # CHECK-SAME: comparison_direction = "GE" + # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT"> + # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GE"> # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.ge) # CHECK-LABEL: TEST: gt float32[] float32[] # CHECK: mhlo.compare - # CHECK-SAME: compare_type = "FLOAT" - # CHECK-SAME: comparison_direction = "GT" + # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT"> + # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GT"> # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.gt) @@ -302,8 +302,8 @@ def integer_pow(x): return lax.integer_pow(x, 3) # CHECK-LABEL: TEST: le float32[] float32[] # CHECK: mhlo.compare - # CHECK-SAME: compare_type = "FLOAT" - # CHECK-SAME: comparison_direction = "LE" + # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT"> + # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LE"> # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.le) @@ -324,8 +324,8 @@ def integer_pow(x): return lax.integer_pow(x, 3) # CHECK-LABEL: TEST: lt float32[] float32[] # CHECK: mhlo.compare - # CHECK-SAME: compare_type = "FLOAT" - # CHECK-SAME: comparison_direction = "LT" + # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT"> + # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LT"> # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.lt) @@ -346,8 +346,8 @@ def integer_pow(x): return lax.integer_pow(x, 3) # CHECK-LABEL: TEST: ne float32[] float32[] # CHECK: mhlo.compare - # CHECK-SAME: compare_type = "FLOAT" - # CHECK-SAME: comparison_direction = "NE" + # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT"> + # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction NE"> # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.ne) @@ -418,7 +418,7 @@ def integer_pow(x): return lax.integer_pow(x, 3) # CHECK-SAME: tensor print_ir(np.uint32(0), np.uint32(0))(lax.shift_left) - # CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[] + # CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[] # CHECK: mhlo.shift_right_arithmetic # CHECK-SAME: tensor print_ir(np.uint8(0), np.uint8(0))(lax.shift_right_arithmetic)