Replace (deprecated) StrEnumAttr with EnumAttr.
ref: https://reviews.llvm.org/D120834
PiperOrigin-RevId: 435550738
sdasgup3 authored and jax authors committed Mar 18, 2022
1 parent 1f95273 commit 6cd9804
Showing 4 changed files with 107 additions and 52 deletions.
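
Every file in this change applies the same version gate: when the installed MLIR Python bindings are new enough (jax._src.lib.mlir_api_version >= 3), comparison directions, comparison types, and precision values are built with the typed mhlo enum-attribute wrappers from D120834; on older bindings the previous plain ir.StringAttr values are still emitted. Below is a minimal sketch of that gate, assuming the import paths used by jax/interpreters/mlir.py at this revision; the helper name _compare_attrs is illustrative and not part of the commit.

import jax._src.lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo

def _compare_attrs(direction: str, compare_type: str):
  # Hypothetical helper: newer bindings expect typed enum attributes on
  # mhlo.CompareOp; older bindings still take plain string attributes.
  if jax._src.lib.mlir_api_version >= 3:
    return (mhlo.ComparisonDirectionAttr.get(direction),
            mhlo.ComparisonTypeAttr.get(compare_type))
  return (ir.StringAttr.get(direction), ir.StringAttr.get(compare_type))

With the new bindings these attributes also print differently in the textual IR, which is why the FileCheck expectations in tests/filecheck/math.filecheck.py are updated alongside the lowering rules.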
63 changes: 46 additions & 17 deletions jax/_src/lax/lax.py
@@ -1579,13 +1579,23 @@ def _sign_translation_rule(ctx, avals_in, avals_out, x):
 def _sign_lower_mhlo(ctx, x):
   x_aval, = ctx.avals_in
   if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
-    return mhlo.SelectOp(
-        mhlo.CompareOp(
-            mlir.aval_to_ir_type(x_aval.update(dtype=np.dtype(np.bool_))),
-            x, mlir.full_like_aval(0, x_aval), ir.StringAttr.get("EQ"),
-            ir.StringAttr.get("UNSIGNED")).result,
-        mlir.full_like_aval(0, x_aval),
-        mlir.full_like_aval(1, x_aval)).results
+    if jax._src.lib.mlir_api_version >= 3:
+      return mhlo.SelectOp(
+          mhlo.CompareOp(
+              mlir.aval_to_ir_type(x_aval.update(dtype=np.dtype(np.bool_))), x,
+              mlir.full_like_aval(0, x_aval),
+              mhlo.ComparisonDirectionAttr.get('EQ'),
+              mhlo.ComparisonTypeAttr.get('UNSIGNED')).result,
+          mlir.full_like_aval(0, x_aval), mlir.full_like_aval(1,
+                                                              x_aval)).results
+    else:
+      return mhlo.SelectOp(
+          mhlo.CompareOp(
+              mlir.aval_to_ir_type(x_aval.update(dtype=np.dtype(np.bool_))), x,
+              mlir.full_like_aval(0, x_aval), ir.StringAttr.get('EQ'),
+              ir.StringAttr.get('UNSIGNED')).result,
+          mlir.full_like_aval(0, x_aval), mlir.full_like_aval(1,
+                                                              x_aval)).results
   return mhlo.SignOp(x).results
 
 mlir.register_lowering(sign_p, _sign_lower_mhlo)
@@ -2216,9 +2226,15 @@ def _compare_lower_mhlo(direction: str, ctx, x, y):
     compare_type = "SIGNED"
   else:
     compare_type = "UNSIGNED"
-  return mhlo.CompareOp(mlir.aval_to_ir_type(aval_out), x, y,
-                        ir.StringAttr.get(direction),
-                        ir.StringAttr.get(compare_type)).results
+  if jax._src.lib.mlir_api_version >= 3:
+    return mhlo.CompareOp(
+        mlir.aval_to_ir_type(aval_out), x, y,
+        mhlo.ComparisonDirectionAttr.get(direction),
+        mhlo.ComparisonTypeAttr.get(compare_type)).results
+  else:
+    return mhlo.CompareOp(
+        mlir.aval_to_ir_type(aval_out), x, y, ir.StringAttr.get(direction),
+        ir.StringAttr.get(compare_type)).results
 
 eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq')
 ad.defjvp_zero(eq_p)
@@ -2630,7 +2646,13 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr:
     full_precision = (precision, precision)
   else:
     full_precision = precision
-  return ir.ArrayAttr.get([ir.StringAttr.get(str(p)) for p in full_precision])
+  if jax._src.lib.mlir_api_version >= 3:
+    return ir.ArrayAttr.get(
+        [mhlo.PrecisionAttr.get(str(p)) for p in full_precision])
+  else:
+    return ir.ArrayAttr.get([ir.StringAttr.get(str(p)) for p in full_precision])
+
+
 
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: Optional[np.dtype]):
@@ -3335,19 +3357,26 @@ def _select_mhlo_lowering(ctx, which, *cases):
   bool_shape = ir.RankedTensorType.get(which_aval.shape,
                                        ir.IntegerType.get_signless(1))
   if dtypes.issubdtype(which_aval.dtype, np.signedinteger):
-    compare_type = ir.StringAttr.get("SIGNED")
+    compare_type = 'SIGNED'
   else:
-    compare_type = ir.StringAttr.get("UNSIGNED")
-  lt = ir.StringAttr.get("LT")
+    compare_type = 'UNSIGNED'
+  lt = 'LT'
 
   def _select(offset, cases):
     assert len(cases) > 0
     if len(cases) == 1:
       return cases[0]
     mid = len(cases) // 2
-    pred = mhlo.CompareOp(
-        bool_shape, which, mlir.full_like_aval(offset + mid, which_aval),
-        lt, compare_type)
+    if jax._src.lib.mlir_api_version >= 3:
+      pred = mhlo.CompareOp(bool_shape, which,
+                            mlir.full_like_aval(offset + mid, which_aval),
+                            mhlo.ComparisonDirectionAttr.get(lt),
+                            mhlo.ComparisonTypeAttr.get(compare_type))
+    else:
+      pred = mhlo.CompareOp(bool_shape, which,
+                            mlir.full_like_aval(offset + mid, which_aval),
+                            ir.StringAttr.get(lt),
+                            ir.StringAttr.get(compare_type))
     return mhlo.SelectOp(pred, _select(offset, cases[:mid]),
                          _select(offset + mid, cases[mid:])).result
 
43 changes: 31 additions & 12 deletions jax/interpreters/mlir.py
@@ -28,6 +28,7 @@
 from typing_extensions import Protocol
 import warnings
 
+import jax
 from jax import core
 from jax import linear_util as lu
 from jax._src import ad_util
@@ -820,15 +821,27 @@ def _minmax_mhlo(op, cmp, x, y):
     ry = mhlo.RealOp(y).result
     dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
     bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
-    real_eq = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get("EQ"),
-                             ir.StringAttr.get("FLOAT"))
-    real_cmp = mhlo.CompareOp(bool_shape, rx, ry,
-                              ir.StringAttr.get(cmp),
-                              ir.StringAttr.get("FLOAT"))
-    imag_cmp = mhlo.CompareOp(bool_shape, mhlo.ImagOp(x).result,
-                              mhlo.ImagOp(y).result,
-                              ir.StringAttr.get(cmp),
-                              ir.StringAttr.get("FLOAT"))
+    if jax._src.lib.mlir_api_version >= 3:
+      real_eq = mhlo.CompareOp(bool_shape, rx, ry,
+                               mhlo.ComparisonDirectionAttr.get("EQ"),
+                               mhlo.ComparisonTypeAttr.get("FLOAT"))
+      real_cmp = mhlo.CompareOp(bool_shape, rx, ry,
+                                mhlo.ComparisonDirectionAttr.get(cmp),
+                                mhlo.ComparisonTypeAttr.get("FLOAT"))
+      imag_cmp = mhlo.CompareOp(bool_shape,
+                                mhlo.ImagOp(x).result,
+                                mhlo.ImagOp(y).result,
+                                mhlo.ComparisonDirectionAttr.get(cmp),
+                                mhlo.ComparisonTypeAttr.get("FLOAT"))
+    else:
+      real_eq = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get("EQ"),
+                               ir.StringAttr.get("FLOAT"))
+      real_cmp = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get(cmp),
+                                ir.StringAttr.get("FLOAT"))
+      imag_cmp = mhlo.CompareOp(bool_shape,
+                                mhlo.ImagOp(x).result,
+                                mhlo.ImagOp(y).result, ir.StringAttr.get(cmp),
+                                ir.StringAttr.get("FLOAT"))
     which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
     return mhlo.SelectOp(which, x, y)
   else:
@@ -850,9 +863,15 @@ def convert_mhlo(x, aval_in, aval_out):
       compare_type = "SIGNED"
     else:
       compare_type = "UNSIGNED"
-    return mhlo.CompareOp(
-        aval_to_ir_type(aval_out), x, full_like_aval(0, aval_in),
-        ir.StringAttr.get("NE"), ir.StringAttr.get(compare_type)).result
+    if jax._src.lib.mlir_api_version >= 3:
+      return mhlo.CompareOp(
+          aval_to_ir_type(aval_out), x, full_like_aval(0, aval_in),
+          mhlo.ComparisonDirectionAttr.get("NE"),
+          mhlo.ComparisonTypeAttr.get(compare_type)).result
+    else:
+      return mhlo.CompareOp(
+          aval_to_ir_type(aval_out), x, full_like_aval(0, aval_in),
+          ir.StringAttr.get("NE"), ir.StringAttr.get(compare_type)).result
   return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result
 
 def _wrap_with_spmd_op(name: str,
15 changes: 11 additions & 4 deletions jax/interpreters/pxla.py
@@ -44,6 +44,7 @@
 from absl import logging
 import numpy as np
 
+import jax
 from jax._src.config import config
 from jax import core
 from jax import linear_util as lu
@@ -1753,10 +1754,16 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
     # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
     if convert_bool:
       float_zero = mlir.full_like_aval(0, padded_aval)
-      out = mhlo.CompareOp(
-          mlir.aval_to_ir_type(padded_aval.update(dtype=np.dtype(np.bool_))),
-          out, float_zero, ir.StringAttr.get("NE"),
-          ir.StringAttr.get("FLOAT")).result
+      if jax._src.lib.mlir_api_version >= 3:
+        out = mhlo.CompareOp(
+            mlir.aval_to_ir_type(padded_aval.update(dtype=np.dtype(np.bool_))),
+            out, float_zero, mhlo.ComparisonDirectionAttr.get("NE"),
+            mhlo.ComparisonTypeAttr.get("FLOAT")).result
+      else:
+        out = mhlo.CompareOp(
+            mlir.aval_to_ir_type(padded_aval.update(dtype=np.dtype(np.bool_))),
+            out, float_zero, ir.StringAttr.get("NE"),
+            ir.StringAttr.get("FLOAT")).result
     return out
   else:
     raise TypeError(aval)
38 changes: 19 additions & 19 deletions tests/filecheck/math.filecheck.py
@@ -199,29 +199,29 @@ def main(_):
 
   # CHECK-LABEL: TEST: eq float32[] float32[]
   # CHECK: mhlo.compare
-  # CHECK-SAME: compare_type = "FLOAT"
-  # CHECK-SAME: comparison_direction = "EQ"
+  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
+  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
   # CHECK-SAME: tensor<f32>
   print_ir(np.float32(1), np.float32(2))(lax.eq)
 
   # CHECK-LABEL: TEST: eq complex128[] complex128[]
   # CHECK: mhlo.compare
-  # CHECK-SAME: compare_type = "FLOAT"
-  # CHECK-SAME: comparison_direction = "EQ"
+  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
+  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
   # CHECK-SAME: tensor<complex<f64>>
   print_ir(np.complex128(1), np.complex128(2))(lax.eq)
 
   # CHECK-LABEL: TEST: eq int64[] int64[]
   # CHECK: mhlo.compare
-  # CHECK-SAME: compare_type = "SIGNED"
-  # CHECK-SAME: comparison_direction = "EQ"
+  # CHECK-SAME: compare_type = #mhlo<"comparison_type SIGNED">
+  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
   # CHECK-SAME: tensor<i64>
   print_ir(np.int64(1), np.int64(2))(lax.eq)
 
   # CHECK-LABEL: TEST: eq uint16[] uint16[]
   # CHECK: mhlo.compare
-  # CHECK-SAME: compare_type = "UNSIGNED"
-  # CHECK-SAME: comparison_direction = "EQ"
+  # CHECK-SAME: compare_type = #mhlo<"comparison_type UNSIGNED">
+  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction EQ">
   # CHECK-SAME: tensor<ui16>
   print_ir(np.uint16(1), np.uint16(2))(lax.eq)
 
@@ -257,15 +257,15 @@ def main(_):
 
   # CHECK-LABEL: TEST: ge float32[] float32[]
   # CHECK: mhlo.compare
-  # CHECK-SAME: compare_type = "FLOAT"
-  # CHECK-SAME: comparison_direction = "GE"
+  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
+  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GE">
   # CHECK-SAME: tensor<f32>
   print_ir(np.float32(1), np.float32(2))(lax.ge)
 
   # CHECK-LABEL: TEST: gt float32[] float32[]
   # CHECK: mhlo.compare
-  # CHECK-SAME: compare_type = "FLOAT"
-  # CHECK-SAME: comparison_direction = "GT"
+  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
+  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction GT">
   # CHECK-SAME: tensor<f32>
   print_ir(np.float32(1), np.float32(2))(lax.gt)
 
@@ -302,8 +302,8 @@ def integer_pow(x): return lax.integer_pow(x, 3)
 
   # CHECK-LABEL: TEST: le float32[] float32[]
   # CHECK: mhlo.compare
-  # CHECK-SAME: compare_type = "FLOAT"
-  # CHECK-SAME: comparison_direction = "LE"
+  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
+  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LE">
   # CHECK-SAME: tensor<f32>
   print_ir(np.float32(1), np.float32(2))(lax.le)
 
@@ -324,8 +324,8 @@ def integer_pow(x): return lax.integer_pow(x, 3)
 
   # CHECK-LABEL: TEST: lt float32[] float32[]
  # CHECK: mhlo.compare
-  # CHECK-SAME: compare_type = "FLOAT"
-  # CHECK-SAME: comparison_direction = "LT"
+  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
+  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction LT">
   # CHECK-SAME: tensor<f32>
   print_ir(np.float32(1), np.float32(2))(lax.lt)
 
@@ -346,8 +346,8 @@ def integer_pow(x): return lax.integer_pow(x, 3)
 
   # CHECK-LABEL: TEST: ne float32[] float32[]
   # CHECK: mhlo.compare
-  # CHECK-SAME: compare_type = "FLOAT"
-  # CHECK-SAME: comparison_direction = "NE"
+  # CHECK-SAME: compare_type = #mhlo<"comparison_type FLOAT">
+  # CHECK-SAME: comparison_direction = #mhlo<"comparison_direction NE">
   # CHECK-SAME: tensor<f32>
   print_ir(np.float32(1), np.float32(2))(lax.ne)
 
@@ -418,7 +418,7 @@ def integer_pow(x): return lax.integer_pow(x, 3)
   # CHECK-SAME: tensor<ui32>
   print_ir(np.uint32(0), np.uint32(0))(lax.shift_left)
 
-  # CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[]
+  # CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[]
   # CHECK: mhlo.shift_right_arithmetic
   # CHECK-SAME: tensor<ui8>
   print_ir(np.uint8(0), np.uint8(0))(lax.shift_right_arithmetic)
