diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index e1eca6181dff9..194e1669a86c9 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -44,6 +44,49 @@ struct AbsOpConversion : public OpConversionPattern { } }; +// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2)) +struct Atan2OpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto type = op.getType().cast(); + Type elementType = type.getElementType(); + + Value lhs = adaptor.getLhs(); + Value rhs = adaptor.getRhs(); + + Value rhsSquared = b.create(type, rhs, rhs); + Value lhsSquared = b.create(type, lhs, lhs); + Value rhsSquaredPlusLhsSquared = + b.create(type, rhsSquared, lhsSquared); + Value sqrtOfRhsSquaredPlusLhsSquared = + b.create(type, rhsSquaredPlusLhsSquared); + + Value zero = + b.create(elementType, b.getZeroAttr(elementType)); + Value one = b.create(elementType, + b.getFloatAttr(elementType, 1)); + Value i = b.create(type, zero, one); + Value iTimesLhs = b.create(i, lhs); + Value rhsPlusILhs = b.create(rhs, iTimesLhs); + + Value divResult = + b.create(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared); + Value logResult = b.create(divResult); + + Value negativeOne = b.create( + elementType, b.getFloatAttr(elementType, -1)); + Value negativeI = b.create(type, zero, negativeOne); + + rewriter.replaceOpWithNewOp(op, negativeI, logResult); + return success(); + } +}; + template struct ComparisonOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -700,6 +743,72 @@ struct SinOpConversion : public TrigonometricOpConversion { } }; +// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780. +struct SqrtOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto type = op.getType().cast(); + Type elementType = type.getElementType(); + Value arg = adaptor.getComplex(); + + Value zero = + b.create(elementType, b.getZeroAttr(elementType)); + + Value real = b.create(elementType, adaptor.getComplex()); + Value imag = b.create(elementType, adaptor.getComplex()); + + Value absLhs = b.create(real); + Value absArg = b.create(elementType, arg); + Value addAbs = b.create(absLhs, absArg); + Value sqrtAddAbs = b.create(addAbs); + Value sqrtAddAbsDivTwo = b.create( + sqrtAddAbs, b.create( + elementType, b.getFloatAttr(elementType, 2))); + + Value realIsNegative = + b.create(arith::CmpFPredicate::OLT, real, zero); + Value imagIsNegative = + b.create(arith::CmpFPredicate::OLT, imag, zero); + + Value resultReal = sqrtAddAbsDivTwo; + + Value imagDivTwoResultReal = b.create( + imag, b.create(resultReal, resultReal)); + + Value negativeResultReal = b.create(resultReal); + + Value resultImag = b.create( + realIsNegative, + b.create(imagIsNegative, negativeResultReal, + resultReal), + imagDivTwoResultReal); + + resultReal = b.create( + realIsNegative, + b.create( + imag, b.create(resultImag, resultImag)), + resultReal); + + Value realIsZero = + b.create(arith::CmpFPredicate::OEQ, real, zero); + Value imagIsZero = + b.create(arith::CmpFPredicate::OEQ, imag, zero); + Value argIsZero = b.create(realIsZero, imagIsZero); + + resultReal = b.create(argIsZero, zero, resultReal); + resultImag = b.create(argIsZero, zero, resultImag); + + rewriter.replaceOpWithNewOp(op, type, resultReal, + resultImag); + return success(); + } +}; + struct SignOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -735,6 +844,7 @@ void mlir::populateComplexToStandardConversionPatterns( // clang-format off patterns.add< AbsOpConversion, + Atan2OpConversion, ComparisonOpConversion, ComparisonOpConversion, BinaryComplexOpConversion, @@ -748,7 +858,8 @@ void mlir::populateComplexToStandardConversionPatterns( MulOpConversion, NegOpConversion, SignOpConversion, - SinOpConversion>(patterns.getContext()); + SinOpConversion, + SqrtOpConversion>(patterns.getContext()); // clang-format on } diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index 6f57e722b520e..bf41028718517 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -pass-pipeline="func.func(convert-complex-to-standard)" | FileCheck %s +// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s // CHECK-LABEL: func @complex_abs // CHECK-SAME: %[[ARG:.*]]: complex @@ -14,6 +14,17 @@ func.func @complex_abs(%arg: complex) -> f32 { // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 // CHECK: return %[[NORM]] : f32 +// ----- + +// CHECK-LABEL: func @complex_atan2 +func.func @complex_atan2(%lhs: complex, + %rhs: complex) -> complex { + %atan2 = complex.atan2 %lhs, %rhs : complex + return %atan2 : complex +} + +// ----- + // CHECK-LABEL: func @complex_add // CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) func.func @complex_add(%lhs: complex, %rhs: complex) -> complex { @@ -29,6 +40,8 @@ func.func @complex_add(%lhs: complex, %rhs: complex) -> complex { // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_cos // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_cos(%arg: complex) -> complex { @@ -50,6 +63,8 @@ func.func @complex_cos(%arg: complex) -> complex { // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] +// ----- + // CHECK-LABEL: func @complex_div // CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) func.func @complex_div(%lhs: complex, %rhs: complex) -> complex { @@ -159,6 +174,8 @@ func.func @complex_div(%lhs: complex, %rhs: complex) -> complex { // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_eq // CHECK-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex func.func @complex_eq(%lhs: complex, %rhs: complex) -> i1 { @@ -174,6 +191,8 @@ func.func @complex_eq(%lhs: complex, %rhs: complex) -> i1 { // CHECK: %[[EQUAL:.*]] = arith.andi %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1 // CHECK: return %[[EQUAL]] : i1 +// ----- + // CHECK-LABEL: func @complex_exp // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_exp(%arg: complex) -> complex { @@ -190,6 +209,8 @@ func.func @complex_exp(%arg: complex) -> complex { // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func.func @complex_expm1( // CHECK-SAME: %[[ARG:.*]]: complex) -> complex { func.func @complex_expm1(%arg: complex) -> complex { @@ -211,6 +232,8 @@ func.func @complex_expm1(%arg: complex) -> complex { // CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex // CHECK: return %[[RES]] : complex +// ----- + // CHECK-LABEL: func @complex_log // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_log(%arg: complex) -> complex { @@ -230,6 +253,8 @@ func.func @complex_log(%arg: complex) -> complex { // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_log1p // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_log1p(%arg: complex) -> complex { @@ -254,6 +279,8 @@ func.func @complex_log1p(%arg: complex) -> complex { // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_mul // CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) func.func @complex_mul(%lhs: complex, %rhs: complex) -> complex { @@ -372,6 +399,8 @@ func.func @complex_mul(%lhs: complex, %rhs: complex) -> complex { // CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_neg // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_neg(%arg: complex) -> complex { @@ -385,6 +414,8 @@ func.func @complex_neg(%arg: complex) -> complex { // CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_neq // CHECK-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex func.func @complex_neq(%lhs: complex, %rhs: complex) -> i1 { @@ -400,6 +431,8 @@ func.func @complex_neq(%lhs: complex, %rhs: complex) -> i1 { // CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1 // CHECK: return %[[NOT_EQUAL]] : i1 +// ----- + // CHECK-LABEL: func @complex_sin // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_sin(%arg: complex) -> complex { @@ -421,6 +454,8 @@ func.func @complex_sin(%arg: complex) -> complex { // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] +// ----- + // CHECK-LABEL: func @complex_sign // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_sign(%arg: complex) -> complex { @@ -445,6 +480,8 @@ func.func @complex_sign(%arg: complex) -> complex { // CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_sub // CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) func.func @complex_sub(%lhs: complex, %rhs: complex) -> complex { @@ -459,3 +496,11 @@ func.func @complex_sub(%lhs: complex, %rhs: complex) -> complex { // CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex + +// ----- + +// CHECK-LABEL: func @complex_sqrt +func.func @complex_sqrt(%arg: complex) -> complex { + %sqrt = complex.sqrt %arg : complex + return %sqrt : complex +}