diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index 466458c05dba7..71d35e37bbe94 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1323,26 +1323,6 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc, return result; } -mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, - const MathOperation &mathOp, - mlir::FunctionType mathLibFuncType, - llvm::ArrayRef args) { - if (mathRuntimeVersion == preciseVersion) - return genLibCall(builder, loc, mathOp, mathLibFuncType, args); - auto complexTy = mlir::cast(mathLibFuncType.getInput(0)); - mlir::Value exp = args[1]; - if (!mlir::isa(exp.getType())) { - auto realTy = complexTy.getElementType(); - mlir::Value realExp = builder.createConvert(loc, realTy, exp); - mlir::Value zero = builder.createRealConstant(loc, realTy, 0); - exp = - builder.create(loc, complexTy, realExp, zero); - } - mlir::Value result = builder.create(loc, args[0], exp); - result = builder.createConvert(loc, mathLibFuncType.getResult(0), result); - return result; -} - /// Mapping between mathematical intrinsic operations and MLIR operations /// of some appropriate dialect (math, complex, etc.) or libm calls. /// TODO: support remaining Fortran math intrinsics. @@ -1668,11 +1648,11 @@ static constexpr MathOperation mathOperations[] = { {"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call}, {"pow", "cpowf", genFuncType, Ty::Complex<4>, Ty::Complex<4>>, - genComplexPow}, + genMathOp}, {"pow", "cpow", genFuncType, Ty::Complex<8>, Ty::Complex<8>>, - genComplexPow}, + genMathOp}, {"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16, - genComplexPow}, + genMathOp}, {"pow", RTNAME_STRING(FPow4i), genFuncType, Ty::Real<4>, Ty::Integer<4>>, genMathOp}, @@ -1693,20 +1673,20 @@ static constexpr MathOperation mathOperations[] = { genMathOp}, {"pow", RTNAME_STRING(cpowi), genFuncType, Ty::Complex<4>, Ty::Integer<4>>, - genComplexPow}, + genMathOp}, {"pow", RTNAME_STRING(zpowi), genFuncType, Ty::Complex<8>, Ty::Integer<4>>, - genComplexPow}, + genMathOp}, {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4, - genComplexPow}, + genMathOp}, {"pow", RTNAME_STRING(cpowk), genFuncType, Ty::Complex<4>, Ty::Integer<8>>, - genComplexPow}, + genMathOp}, {"pow", RTNAME_STRING(zpowk), genFuncType, Ty::Complex<8>, Ty::Integer<8>>, - genComplexPow}, + genMathOp}, {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8, - genComplexPow}, + genMathOp}, {"pow-unsigned", RTNAME_STRING(UPow1), genFuncType, Ty::Integer<1>, Ty::Integer<1>>, genLibCall}, {"pow-unsigned", RTNAME_STRING(UPow2), diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp index 78f9d9e4f639a..127f8720ae524 100644 --- a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp +++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp @@ -47,39 +47,19 @@ static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc, return func; } -static bool isZero(Value v) { - if (auto cst = v.getDefiningOp()) - if (auto attr = dyn_cast(cst.getValue())) - return attr.getValue().isZero(); - return false; -} - void ConvertComplexPowPass::runOnOperation() { ModuleOp mod = getOperation(); fir::FirOpBuilder builder(mod, fir::getKindMapping(mod)); - mod.walk([&](complex::PowOp op) { - builder.setInsertionPoint(op); - Location loc = op.getLoc(); - auto complexTy = cast(op.getType()); - auto elemTy = complexTy.getElementType(); - - Value base = op.getLhs(); - Value rhs = op.getRhs(); - - Value intExp; - if (auto create = rhs.getDefiningOp()) { - if (isZero(create.getImaginary())) { - if (auto conv = create.getReal().getDefiningOp()) { - if (auto intTy = dyn_cast(conv.getValue().getType())) - intExp = conv.getValue(); - } - } - } - - func::FuncOp callee; - SmallVector args; - if (intExp) { + mod.walk([&](Operation *op) { + if (auto powIop = dyn_cast(op)) { + builder.setInsertionPoint(powIop); + Location loc = powIop.getLoc(); + auto complexTy = cast(powIop.getType()); + auto elemTy = complexTy.getElementType(); + Value base = powIop.getLhs(); + Value intExp = powIop.getRhs(); + func::FuncOp callee; unsigned realBits = cast(elemTy).getWidth(); unsigned intBits = cast(intExp.getType()).getWidth(); auto funcTy = builder.getFunctionType( @@ -98,9 +78,20 @@ void ConvertComplexPowPass::runOnOperation() { callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy); else return; - args = {base, intExp}; - } else { + auto call = fir::CallOp::create(builder, loc, callee, {base, intExp}); + if (auto fmf = powIop.getFastmathAttr()) + call.setFastmathAttr(fmf); + powIop.replaceAllUsesWith(call.getResult(0)); + powIop.erase(); + } + + if (auto powOp = dyn_cast(op)) { + builder.setInsertionPoint(powOp); + Location loc = powOp.getLoc(); + auto complexTy = cast(powOp.getType()); + auto elemTy = complexTy.getElementType(); unsigned realBits = cast(elemTy).getWidth(); + func::FuncOp callee; auto funcTy = builder.getFunctionType({complexTy, complexTy}, {complexTy}); if (realBits == 32) @@ -111,13 +102,12 @@ void ConvertComplexPowPass::runOnOperation() { callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy); else return; - args = {base, rhs}; + auto call = fir::CallOp::create(builder, loc, callee, + {powOp.getLhs(), powOp.getRhs()}); + if (auto fmf = powOp.getFastmathAttr()) + call.setFastmathAttr(fmf); + powOp.replaceAllUsesWith(call.getResult(0)); + powOp.erase(); } - - auto call = fir::CallOp::create(builder, loc, callee, args); - if (auto fmf = op.getFastmathAttr()) - call.setFastmathAttr(fmf); - op.replaceAllUsesWith(call.getResult(0)); - op.erase(); }); } diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90 index 1fbd333db37c3..b7695a761a0b8 100644 --- a/flang/test/Lower/HLFIR/binary-ops.f90 +++ b/flang/test/Lower/HLFIR/binary-ops.f90 @@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z) ! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) ! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref> ! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref -! CHECK: %[[VAL_8:.*]] = complex.pow +! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] fastmath : complex, i32 subroutine extremum(c, n, l) integer(8), intent(in) :: l diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90 index 1827863a57f43..ea18d67b75460 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16i.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90 @@ -4,7 +4,7 @@ ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s ! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex, i32) -> complex -! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath : complex +! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath : complex complex(16) :: a integer(4) :: b b = a ** b diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90 index 039dfd5152a06..d2b70185bda9f 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16k.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90 @@ -4,7 +4,7 @@ ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s ! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex, i64) -> complex -! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath : complex +! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath : complex complex(16) :: a integer(8) :: b b = a ** b diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90 index 4ee5de4d2842e..a28eaea82379b 100644 --- a/flang/test/Lower/amdgcn-complex.f90 +++ b/flang/test/Lower/amdgcn-complex.f90 @@ -25,3 +25,12 @@ subroutine pow_test(a, b, c) complex :: a, b, c a = b**c end subroutine pow_test + +! CHECK-LABEL: func @_QPpowi_test( +! CHECK: complex.powi +! CHECK-NOT: fir.call @_FortranAcpowi +subroutine powi_test(a, b, c) + complex :: a, b + integer :: i + b = a ** i +end subroutine powi_test diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90 index 3058927144248..9f74d172a6bb2 100644 --- a/flang/test/Lower/power-operator.f90 +++ b/flang/test/Lower/power-operator.f90 @@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z) complex :: x, z integer :: y z = x ** y - ! CHECK: complex.pow + ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex, i32 ! PRECISE: fir.call @_FortranAcpowi end subroutine @@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z) complex :: x, z integer(8) :: y z = x ** y - ! CHECK: complex.pow + ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex, i64 ! PRECISE: fir.call @_FortranAcpowk end subroutine @@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z) complex(8) :: x, z integer :: y z = x ** y - ! CHECK: complex.pow + ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex, i32 ! PRECISE: fir.call @_FortranAzpowi end subroutine @@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z) complex(8) :: x, z integer(8) :: y z = x ** y - ! CHECK: complex.pow + ! CHECK: complex.powi %{{.*}}, %{{.*}} : complex, i64 ! PRECISE: fir.call @_FortranAzpowk end subroutine @@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z) ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex ! PRECISE: fir.call @cpow end subroutine - diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir index e09fa7316c4b0..23316ed46d40f 100644 --- a/flang/test/Transforms/convert-complex-pow.fir +++ b/flang/test/Transforms/convert-complex-pow.fir @@ -2,51 +2,38 @@ module { func.func @pow_c4_i4(%arg0: complex, %arg1: i32) -> complex { - %c0 = arith.constant 0.0 : f32 - %0 = fir.convert %arg1 : (i32) -> f32 - %1 = complex.create %0, %c0 : complex - %2 = complex.pow %arg0, %1 : complex - return %2 : complex + %0 = complex.powi %arg0, %arg1 : complex, i32 + return %0 : complex + } + + func.func @pow_c4_i4_fast(%arg0: complex, %arg1: i32) -> complex { + %0 = complex.powi %arg0, %arg1 fastmath : complex, i32 + return %0 : complex } func.func @pow_c4_i8(%arg0: complex, %arg1: i64) -> complex { - %c0 = arith.constant 0.0 : f32 - %0 = fir.convert %arg1 : (i64) -> f32 - %1 = complex.create %0, %c0 : complex - %2 = complex.pow %arg0, %1 : complex - return %2 : complex + %0 = complex.powi %arg0, %arg1 : complex, i64 + return %0 : complex } func.func @pow_c8_i4(%arg0: complex, %arg1: i32) -> complex { - %c0 = arith.constant 0.0 : f64 - %0 = fir.convert %arg1 : (i32) -> f64 - %1 = complex.create %0, %c0 : complex - %2 = complex.pow %arg0, %1 : complex - return %2 : complex + %0 = complex.powi %arg0, %arg1 : complex, i32 + return %0 : complex } func.func @pow_c8_i8(%arg0: complex, %arg1: i64) -> complex { - %c0 = arith.constant 0.0 : f64 - %0 = fir.convert %arg1 : (i64) -> f64 - %1 = complex.create %0, %c0 : complex - %2 = complex.pow %arg0, %1 : complex - return %2 : complex + %0 = complex.powi %arg0, %arg1 : complex, i64 + return %0 : complex } func.func @pow_c16_i4(%arg0: complex, %arg1: i32) -> complex { - %c0 = arith.constant 0.0 : f128 - %0 = fir.convert %arg1 : (i32) -> f128 - %1 = complex.create %0, %c0 : complex - %2 = complex.pow %arg0, %1 : complex - return %2 : complex + %0 = complex.powi %arg0, %arg1 : complex, i32 + return %0 : complex } func.func @pow_c16_i8(%arg0: complex, %arg1: i64) -> complex { - %c0 = arith.constant 0.0 : f128 - %0 = fir.convert %arg1 : (i64) -> f128 - %1 = complex.create %0, %c0 : complex - %2 = complex.pow %arg0, %1 : complex - return %2 : complex + %0 = complex.powi %arg0, %arg1 : complex, i64 + return %0 : complex } func.func @pow_c4_fast(%arg0: complex, %arg1: f32) -> complex { @@ -74,26 +61,37 @@ module { // CHECK-LABEL: func.func @pow_c4_i4( // CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex // CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi + +// CHECK-LABEL: func.func @pow_c4_i4_fast( +// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) fastmath : (complex, i32) -> complex +// CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c4_i8( // CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex // CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c8_i4( // CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex // CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c8_i8( // CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex // CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c16_i4( // CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex // CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c16_i8( // CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex // CHECK-NOT: complex.pow +// CHECK-NOT: complex.powi // CHECK-LABEL: func.func @pow_c4_fast( // CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex @@ -108,4 +106,4 @@ module { // CHECK-LABEL: func.func @pow_c16_complex( // CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex // CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex, complex) -> complex -// CHECK-NOT: complex.pow \ No newline at end of file +// CHECK-NOT: complex.pow diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index 44590406301eb..828379ded14b3 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -443,6 +443,36 @@ def PowOp : ComplexArithmeticOp<"pow"> { }]; } +//===----------------------------------------------------------------------===// +// PowiOp +//===----------------------------------------------------------------------===// + +def PowiOp : Complex_Op<"powi", + [Pure, Elementwise, SameOperandsAndResultShape, + AllTypesMatch<["lhs", "result"]>, + DeclareOpInterfaceMethods]> { + let summary = "complex number raised to signed integer power"; + let description = [{ + The `powi` operation takes a `base` operand of complex type and a `power` + operand of signed integer type and returns one result of the same type + as `base`. The result is `base` raised to the power of `power`. + + Example: + + ```mlir + %a = complex.powi %b, %c : complex, i32 + ``` + }]; + + let arguments = (ins Complex:$lhs, + AnySignlessInteger:$rhs, + OptionalAttr:$fastmath); + let results = (outs Complex:$result); + + let assemblyFormat = + "$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result) `,` type($rhs)"; +} + //===----------------------------------------------------------------------===// // ReOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 72b1fa6e833f9..42099aaa6b574 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -7,9 +7,11 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -74,10 +76,39 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern { return success(); } }; + +// Rewrite complex.powi(z, n) -> complex.pow(z, complex(float(n), 0)) +struct PowiOpToROCDLLibraryCalls : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(complex::PowiOp op, + PatternRewriter &rewriter) const final { + auto complexType = cast(getElementTypeOrSelf(op.getType())); + Type elementType = complexType.getElementType(); + + Type exponentType = op.getRhs().getType(); + Type exponentFloatType = elementType; + if (auto shapedType = dyn_cast(exponentType)) + exponentFloatType = shapedType.cloneWith(std::nullopt, elementType); + + Location loc = op.getLoc(); + Value exponentReal = + rewriter.create(loc, exponentFloatType, op.getRhs()); + Value zeroImag = rewriter.create( + loc, rewriter.getZeroAttr(exponentFloatType)); + Value exponent = rewriter.create( + loc, op.getLhs().getType(), exponentReal, zeroImag); + + rewriter.replaceOpWithNewOp(op, op.getType(), op.getLhs(), + exponent, op.getFastmathAttr()); + return success(); + } +}; } // namespace void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); patterns.add>( patterns.getContext(), "__ocml_cabs_f32"); @@ -128,11 +159,12 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { populateComplexToROCDLLibraryCallsConversionPatterns(patterns); ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addLegalOp(); + target.addLegalDialect(); + target.addLegalOp(); target.addIllegalOp(); + complex::LogOp, complex::PowOp, complex::PowiOp, + complex::SinOp, complex::SqrtOp, complex::TanOp, + complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 5ad514d0f48e7..5613e021cd709 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -926,6 +926,30 @@ static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder, return cutoff4; } +struct PowiOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::PowiOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder builder(op.getLoc(), rewriter); + auto type = cast(op.getType()); + auto elementType = cast(type.getElementType()); + + Value floatExponent = + builder.create(elementType, adaptor.getRhs()); + Value zero = arith::ConstantOp::create( + builder, elementType, builder.getFloatAttr(elementType, 0.0)); + Value complexExponent = + complex::CreateOp::create(builder, type, floatExponent, zero); + + auto pow = builder.create( + type, adaptor.getLhs(), complexExponent, op.getFastmathAttr()); + rewriter.replaceOp(op, pow.getResult()); + return success(); + } +}; + struct PowOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -1070,6 +1094,7 @@ void mlir::populateComplexToStandardConversionPatterns( SqrtOpConversion, TanTanhOpConversion, TanTanhOpConversion, + PowiOpConversion, PowOpConversion, RsqrtOpConversion >(patterns.getContext()); diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp index 31785eb20a642..77b10cec48d8e 100644 --- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -175,12 +176,20 @@ PowIStrengthReduction::matchAndRewrite( Value one; Type opType = getElementTypeOrSelf(op.getType()); - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) { one = arith::ConstantOp::create(rewriter, loc, rewriter.getFloatAttr(opType, 1.0)); - else + } else if constexpr (std::is_same_v) { + auto complexTy = cast(opType); + Type elementType = complexTy.getElementType(); + auto realPart = rewriter.getFloatAttr(elementType, 1.0); + auto imagPart = rewriter.getFloatAttr(elementType, 0.0); + one = complex::ConstantOp::create( + rewriter, loc, complexTy, rewriter.getArrayAttr({realPart, imagPart})); + } else { one = arith::ConstantOp::create(rewriter, loc, rewriter.getIntegerAttr(opType, 1)); + } // Replace `[fi]powi(x, 0)` with `1`. if (exponentValue == 0) { @@ -208,13 +217,25 @@ PowIStrengthReduction::matchAndRewrite( // `[fi]powi(x, negative_exponent)` // with: // (1 / x) * (1 / x) * (1 / x) * ... + auto buildMul = [&](Value lhs, Value rhs) { + if constexpr (std::is_same_v) + return MulOpTy::create(rewriter, loc, op.getType(), lhs, rhs, + op.getFastmathAttr()); + else + return MulOpTy::create(rewriter, loc, lhs, rhs); + }; for (unsigned i = 1; i < exponentValue; ++i) - result = MulOpTy::create(rewriter, loc, result, base); + result = buildMul(result, base); // Inverse the base for negative exponent, i.e. for // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`. - if (exponentIsNegative) - result = DivOpTy::create(rewriter, loc, bcast(one), result); + if (exponentIsNegative) { + if constexpr (std::is_same_v) + result = DivOpTy::create(rewriter, loc, op.getType(), bcast(one), result, + op.getFastmathAttr()); + else + result = DivOpTy::create(rewriter, loc, bcast(one), result); + } rewriter.replaceOp(op, result); return success(); @@ -224,9 +245,10 @@ PowIStrengthReduction::matchAndRewrite( void mlir::populateMathAlgebraicSimplificationPatterns( RewritePatternSet &patterns) { - patterns - .add, - PowIStrengthReduction>( - patterns.getContext()); + patterns.add< + PowFStrengthReduction, + PowIStrengthReduction, + PowIStrengthReduction, + PowIStrengthReduction>( + patterns.getContext(), /*exponentThreshold=*/8); } diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index d37a056e8e158..ff62b515533c3 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMathTransforms LINK_LIBS PUBLIC MLIRArithDialect + MLIRComplexDialect MLIRDialectUtils MLIRIR MLIRMathDialect diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir index 080ba4f0ff67b..cf177528e532c 100644 --- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir +++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir @@ -68,6 +68,20 @@ func.func @pow_caller(%z: complex, %w: complex) -> complex { return %r : complex } +//CHECK-LABEL: @powi_caller +//CHECK: (%[[Z:.*]]: complex, %[[N:.*]]: i32) +func.func @powi_caller(%z: complex, %n: i32) -> complex { + // CHECK: %[[N_FP:.*]] = arith.sitofp %[[N]] : i32 to f32 + // CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[N_COMPLEX:.*]] = complex.create %[[N_FP]], %[[ZERO]] : complex + // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%[[Z]]) : (complex) -> complex + // CHECK: %[[MUL:.*]] = complex.mul %[[N_COMPLEX]], %[[LOG]] : complex + // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]]) : (complex) -> complex + // CHECK: return %[[EXP]] : complex + %r = complex.powi %z, %n : complex, i32 + return %r : complex +} + //CHECK-LABEL: @sin_caller func.func @sin_caller(%f: complex, %d: complex) -> (complex, complex) { // CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}}) diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index a4ddabbd0821a..dec62f92c7b2e 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -700,6 +700,36 @@ func.func @complex_pow_with_fmf(%lhs: complex, // ----- +// CHECK-LABEL: func.func @complex_powi +// CHECK-SAME: %[[LHS:.*]]: complex, %[[EXP:.*]]: i32 +func.func @complex_powi(%lhs: complex, %rhs: i32) -> complex { + %pow = complex.powi %lhs, %rhs : complex, i32 + return %pow : complex +} + +// CHECK: %[[FLOAT_EXP:.*]] = arith.sitofp %[[EXP]] : i32 to f32 +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[CPLX_EXP:.*]] = complex.create %[[FLOAT_EXP]], %[[ZERO]] : complex +// CHECK: math.atan2 +// CHECK-NOT: complex.powi + +// ----- + +// CHECK-LABEL: func.func @complex_powi_with_fmf +// CHECK-SAME: %[[LHS:.*]]: complex, %[[EXP:.*]]: i32 +func.func @complex_powi_with_fmf(%lhs: complex, %rhs: i32) -> complex { + %pow = complex.powi %lhs, %rhs fastmath : complex, i32 + return %pow : complex +} + +// CHECK: %[[FLOAT_EXP:.*]] = arith.sitofp %[[EXP]] : i32 to f32 +// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[CPLX_EXP:.*]] = complex.create %[[FLOAT_EXP]], %[[ZERO]] : complex +// CHECK: math.atan2 {{.*}} fastmath : f32 +// CHECK-NOT: complex.powi + +// ----- + // CHECK-LABEL: func.func @complex_rsqrt func.func @complex_rsqrt(%arg: complex) -> complex { %rsqrt = complex.rsqrt %arg : complex diff --git a/mlir/test/Dialect/Complex/powi-simplify.mlir b/mlir/test/Dialect/Complex/powi-simplify.mlir new file mode 100644 index 0000000000000..c7bb6a9d81479 --- /dev/null +++ b/mlir/test/Dialect/Complex/powi-simplify.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-opt %s -test-math-algebraic-simplification | FileCheck %s + +func.func @pow3(%arg0: complex) -> complex { + %c3 = arith.constant 3 : i32 + %0 = complex.powi %arg0, %c3 : complex, i32 + return %0 : complex +} +// CHECK-LABEL: func.func @pow3( +// CHECK-NOT: complex.powi +// CHECK: %[[M0:.+]] = complex.mul %{{.*}}, %{{.*}} : complex +// CHECK: %[[M1:.+]] = complex.mul %[[M0]], %{{.*}} : complex +// CHECK: return %[[M1]] : complex + +func.func @pow9(%arg0: complex) -> complex { + %c9 = arith.constant 9 : i32 + %0 = complex.powi %arg0, %c9 : complex, i32 + return %0 : complex +} +// CHECK-LABEL: func.func @pow9( +// CHECK: complex.powi %{{.*}}, %{{.*}} : complex, i32