diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp index 25af2f3be3067..b5323597b7ca4 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -371,35 +371,38 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult DivOp::fold(FoldAdaptor adaptor) { - auto rhs = adaptor.getRhs(); - auto lhs = adaptor.getLhs(); - - // We can't fold without knowing that LHS isn't NaN - if (!rhs || !lhs) - return {}; + Attribute rhs = adaptor.getRhs(); + Attribute lhs = adaptor.getLhs(); + + // complex.div(complex.constant, a) -> complex.constant + // complex.div(complex.constant, b) -> complex.constant + // complex.div(complex.constant, b) -> complex.constant + bool isLhsComplexHasNan = false; + ArrayAttr lhsArrayAttr = dyn_cast_if_present(lhs); + if (lhsArrayAttr && lhsArrayAttr.size() == 2) { + APFloat lhsReal = cast(lhsArrayAttr[0]).getValue(); + APFloat lhsImag = cast(lhsArrayAttr[1]).getValue(); + isLhsComplexHasNan = lhsReal.isNaN() || lhsImag.isNaN(); + if (isLhsComplexHasNan) { + Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1]; + return ArrayAttr::get(getContext(), {nanValue, nanValue}); + } + } - ArrayAttr rhsArrayAttr = dyn_cast(rhs); + ArrayAttr rhsArrayAttr = dyn_cast_if_present(rhs); if (!rhsArrayAttr || rhsArrayAttr.size() != 2) return {}; - ArrayAttr lhsArrayAttr = dyn_cast(lhs); - if (!lhsArrayAttr || lhsArrayAttr.size() != 2) - return {}; - + // Fold only if RHS is complex.constant<1.0, 0.0> APFloat rhsImag = cast(rhsArrayAttr[1]).getValue(); - if (!rhsImag.isZero()) + APFloat rhsReal = cast(rhsArrayAttr[0]).getValue(); + if (!rhsImag.isZero() || rhsReal != APFloat(rhsReal.getSemantics(), 1)) return {}; - APFloat lhsReal = cast(lhsArrayAttr[0]).getValue(); - APFloat lhsImag = cast(lhsArrayAttr[1]).getValue(); - if (lhsReal.isNaN() || lhsImag.isNaN()) { - Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1]; - return ArrayAttr::get(getContext(), {nanValue, nanValue}); - } - - // complex.div(a, complex.constant<1.0, 0.0>) -> a - APFloat rhsReal = cast(rhsArrayAttr[0]).getValue(); - if (rhsReal == APFloat(rhsReal.getSemantics(), 1)) + // Fold to LHS if it doesn't contains NaNs or fast math flag nan is set + // complex.div(a, complex.constant<1.0, 0.0>) fastmath -> a + if ((lhsArrayAttr && !isLhsComplexHasNan) || + arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan)) return getLhs(); return {}; diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir index b3f49eb3f44c1..1c5216c82e5c3 100644 --- a/mlir/test/Dialect/Complex/canonicalize.mlir +++ b/mlir/test/Dialect/Complex/canonicalize.mlir @@ -327,8 +327,8 @@ func.func @div_one_f128() -> complex { return %div : complex } -// CHECK-LABEL: div_op_with_rhs_has_nan -func.func @div_op_with_rhs_has_nan() -> complex { +// CHECK-LABEL: div_op_with_rhs_has_nan_real +func.func @div_op_with_rhs_has_nan_real() -> complex { %a = complex.constant [0x7fffffff : f32, 1.0 : f32]: complex %b = complex.constant [1.0: f32, 0.0 : f32]: complex %div = complex.div %a, %b : complex @@ -336,3 +336,45 @@ func.func @div_op_with_rhs_has_nan() -> complex { // CHECK: return %[[DIV]] : complex return %div : complex } + +// CHECK-LABEL: div_op_with_rhs_has_nan_imag +func.func @div_op_with_rhs_has_nan_imag() -> complex { + %a = complex.constant [1.0 : f32, 0x7fffffff : f32]: complex + %b = complex.constant [1.0: f32, 0.0 : f32]: complex + %div = complex.div %a, %b : complex + // CHECK: %[[DIV:.*]] = complex.constant [0x7FFFFFFF : f32, 0x7FFFFFFF : f32] : complex + // CHECK: return %[[DIV]] : complex + return %div : complex +} + +// CHECK-LABEL: div_op_with_rhs_has_nan_real_imag +func.func @div_op_with_rhs_has_nan_real_imag() -> complex { + %a = complex.constant [0x7fffffff : f32, 0x7fffffff : f32]: complex + %b = complex.constant [1.0: f32, 0.0 : f32]: complex + %div = complex.div %a, %b : complex + // CHECK: %[[DIV:.*]] = complex.constant [0x7FFFFFFF : f32, 0x7FFFFFFF : f32] : complex + // CHECK: return %[[DIV]] : complex + return %div : complex +} + +// CHECK-LABEL: div_op_non_constant_lhs_with_fast_math +func.func @div_op_non_constant_lhs_with_fast_math(%arg0: f32, %arg1: f32) -> complex { + %a = complex.create %arg0, %arg1 : complex + %b = complex.constant [1.0 : f32, 0.0 : f32] : complex + %div = complex.div %a, %b fastmath : complex + // CHECK: %[[COMPLEX:.*]] = complex.create %arg0, %arg1 : complex + // CHECK: return %[[COMPLEX]] : complex + return %div: complex +} + +// CHECK-LABEL: div_op_non_constant_lhs_without_fast_math +func.func @div_op_non_constant_lhs_without_fast_math(%arg0: f32, %arg1: f32) -> complex { + %a = complex.create %arg0, %arg1 : complex + %b = complex.constant [1.0 : f32, 0.0 : f32] : complex + %div = complex.div %a, %b : complex + // CHECK: %[[B:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex + // CHECK: %[[A:.*]] = complex.create %arg0, %arg1 : complex + // CHECK: %[[DIV:.*]] = complex.div %[[A]], %[[B]] : complex + // CHECK: return %[[DIV]] : complex + return %div: complex +}