diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td index e3001454cdf19..093d5de028048 100644 --- a/flang/include/flang/Optimizer/Transforms/Passes.td +++ b/flang/include/flang/Optimizer/Transforms/Passes.td @@ -551,6 +551,17 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> { "Prefer expanding without using Fortran runtime calls.">]; } +def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::ModuleOp"> { + let summary = "Convert complex.pow operations to library calls"; + let description = [{ + Replace `complex.pow` operations with calls to the appropriate + Fortran runtime or libm functions. + }]; + let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect", + "mlir::complex::ComplexDialect", + "mlir::arith::ArithDialect"]; +} + def OptimizeArrayRepacking : Pass<"optimize-array-repacking", "mlir::func::FuncOp"> { let summary = "Optimizes redundant array repacking operations"; diff --git a/flang/include/flang/Tools/CrossToolHelpers.h b/flang/include/flang/Tools/CrossToolHelpers.h index 335f0a45531c8..c2a4e082b129d 100644 --- a/flang/include/flang/Tools/CrossToolHelpers.h +++ b/flang/include/flang/Tools/CrossToolHelpers.h @@ -134,6 +134,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks { bool NSWOnLoopVarInc = true; ///< Add nsw flag to loop variable increments. bool EnableOpenMP = false; ///< Enable OpenMP lowering. bool EnableOpenMPSimd = false; ///< Enable OpenMP simd-only mode. + bool SkipConvertComplexPow = false; ///< Do not run complex pow conversion. std::string InstrumentFunctionEntry = ""; ///< Name of the instrument-function that is called on each ///< function-entry diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp index 23cc1e63e773d..6ebea5f8501b4 100644 --- a/flang/lib/Frontend/FrontendActions.cpp +++ b/flang/lib/Frontend/FrontendActions.cpp @@ -738,6 +738,8 @@ void CodeGenAction::generateLLVMIR() { pm.enableVerifier(/*verifyPasses=*/true); MLIRToLLVMPassPipelineConfig config(level, opts, mathOpts); + llvm::Triple pipelineTriple(invoc.getTargetOpts().triple); + config.SkipConvertComplexPow = pipelineTriple.isAMDGCN(); fir::registerDefaultInlinerPass(config); if (auto vsr = getVScaleRange(ci)) { diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index ce1376fd209cc..466458c05dba7 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -1327,18 +1327,18 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc, const MathOperation &mathOp, mlir::FunctionType mathLibFuncType, llvm::ArrayRef args) { - bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN(); - if (!isAMDGPU) + if (mathRuntimeVersion == preciseVersion) return genLibCall(builder, loc, mathOp, mathLibFuncType, args); - auto complexTy = mlir::cast(mathLibFuncType.getInput(0)); - auto realTy = complexTy.getElementType(); - mlir::Value realExp = builder.createConvert(loc, realTy, args[1]); - mlir::Value zero = builder.createRealConstant(loc, realTy, 0); - mlir::Value complexExp = - builder.create(loc, complexTy, realExp, zero); - mlir::Value result = - builder.create(loc, args[0], complexExp); + mlir::Value exp = args[1]; + if (!mlir::isa(exp.getType())) { + auto realTy = complexTy.getElementType(); + mlir::Value realExp = builder.createConvert(loc, realTy, exp); + mlir::Value zero = builder.createRealConstant(loc, realTy, 0); + exp = + builder.create(loc, complexTy, realExp, zero); + } + mlir::Value result = builder.create(loc, args[0], exp); result = builder.createConvert(loc, mathLibFuncType.getResult(0), result); return result; } @@ -1668,11 +1668,11 @@ static constexpr MathOperation mathOperations[] = { {"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call}, {"pow", "cpowf", genFuncType, Ty::Complex<4>, Ty::Complex<4>>, - genComplexMathOp}, + genComplexPow}, {"pow", "cpow", genFuncType, Ty::Complex<8>, Ty::Complex<8>>, - genComplexMathOp}, + genComplexPow}, {"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16, - genLibF128Call}, + genComplexPow}, {"pow", RTNAME_STRING(FPow4i), genFuncType, Ty::Real<4>, Ty::Integer<4>>, genMathOp}, @@ -1698,7 +1698,7 @@ static constexpr MathOperation mathOperations[] = { genFuncType, Ty::Complex<8>, Ty::Integer<4>>, genComplexPow}, {"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4, - genLibF128Call}, + genComplexPow}, {"pow", RTNAME_STRING(cpowk), genFuncType, Ty::Complex<4>, Ty::Integer<8>>, genComplexPow}, @@ -1706,7 +1706,7 @@ static constexpr MathOperation mathOperations[] = { genFuncType, Ty::Complex<8>, Ty::Integer<8>>, genComplexPow}, {"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8, - genLibF128Call}, + genComplexPow}, {"pow-unsigned", RTNAME_STRING(UPow1), genFuncType, Ty::Integer<1>, Ty::Integer<1>>, genLibCall}, {"pow-unsigned", RTNAME_STRING(UPow2), diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp index 7c2777baebef1..805f84e888798 100644 --- a/flang/lib/Optimizer/Passes/Pipelines.cpp +++ b/flang/lib/Optimizer/Passes/Pipelines.cpp @@ -225,6 +225,8 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm, pm.addPass(mlir::createCanonicalizerPass(config)); pm.addPass(fir::createSimplifyRegionLite()); + if (!pc.SkipConvertComplexPow) + pm.addPass(fir::createConvertComplexPow()); pm.addPass(mlir::createCSEPass()); if (pc.OptLevel.isOptimizingForSpeed()) diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt index a8812e08c1ccd..4ec16274830fe 100644 --- a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -35,6 +35,7 @@ add_flang_library(FIRTransforms GenRuntimeCallsForTest.cpp SimplifyFIROperations.cpp OptimizeArrayRepacking.cpp + ConvertComplexPow.cpp DEPENDS CUFAttrs diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp new file mode 100644 index 0000000000000..78f9d9e4f639a --- /dev/null +++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp @@ -0,0 +1,123 @@ +//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "flang/Common/static-multimap-view.h" +#include "flang/Optimizer/Builder/FIRBuilder.h" +#include "flang/Optimizer/Dialect/FIRDialect.h" +#include "flang/Optimizer/Transforms/Passes.h" +#include "flang/Runtime/entry-names.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" + +namespace fir { +#define GEN_PASS_DEF_CONVERTCOMPLEXPOW +#include "flang/Optimizer/Transforms/Passes.h.inc" +} // namespace fir + +using namespace mlir; + +namespace { +class ConvertComplexPowPass + : public fir::impl::ConvertComplexPowBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override; +}; +} // namespace + +// Helper to declare or get a math library function. +static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc, + StringRef name, FunctionType type) { + if (auto func = builder.getNamedFunction(name)) + return func; + auto func = builder.createFunction(loc, name, type); + func->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name)); + func->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(), + builder.getUnitAttr()); + return func; +} + +static bool isZero(Value v) { + if (auto cst = v.getDefiningOp()) + if (auto attr = dyn_cast(cst.getValue())) + return attr.getValue().isZero(); + return false; +} + +void ConvertComplexPowPass::runOnOperation() { + ModuleOp mod = getOperation(); + fir::FirOpBuilder builder(mod, fir::getKindMapping(mod)); + + mod.walk([&](complex::PowOp op) { + builder.setInsertionPoint(op); + Location loc = op.getLoc(); + auto complexTy = cast(op.getType()); + auto elemTy = complexTy.getElementType(); + + Value base = op.getLhs(); + Value rhs = op.getRhs(); + + Value intExp; + if (auto create = rhs.getDefiningOp()) { + if (isZero(create.getImaginary())) { + if (auto conv = create.getReal().getDefiningOp()) { + if (auto intTy = dyn_cast(conv.getValue().getType())) + intExp = conv.getValue(); + } + } + } + + func::FuncOp callee; + SmallVector args; + if (intExp) { + unsigned realBits = cast(elemTy).getWidth(); + unsigned intBits = cast(intExp.getType()).getWidth(); + auto funcTy = builder.getFunctionType( + {complexTy, builder.getIntegerType(intBits)}, {complexTy}); + if (realBits == 32 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy); + else if (realBits == 32 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy); + else if (realBits == 64 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy); + else if (realBits == 64 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy); + else if (realBits == 128 && intBits == 32) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy); + else if (realBits == 128 && intBits == 64) + callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy); + else + return; + args = {base, intExp}; + } else { + unsigned realBits = cast(elemTy).getWidth(); + auto funcTy = + builder.getFunctionType({complexTy, complexTy}, {complexTy}); + if (realBits == 32) + callee = getOrDeclare(builder, loc, "cpowf", funcTy); + else if (realBits == 64) + callee = getOrDeclare(builder, loc, "cpow", funcTy); + else if (realBits == 128) + callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy); + else + return; + args = {base, rhs}; + } + + auto call = fir::CallOp::create(builder, loc, callee, args); + if (auto fmf = op.getFastmathAttr()) + call.setFastmathAttr(fmf); + op.replaceAllUsesWith(call.getResult(0)); + op.erase(); + }); +} diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90 index f3791fe9f8dc3..bf2712d547a82 100644 --- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 +++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90 @@ -69,6 +69,7 @@ ! CHECK-NEXT: SCFToControlFlow ! CHECK-NEXT: Canonicalizer ! CHECK-NEXT: SimplifyRegionLite +! CHECK-NEXT: ConvertComplexPow ! CHECK-NEXT: CSE ! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd ! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90 index 42a71b2d6adc3..5943a3c61c342 100644 --- a/flang/test/Driver/mlir-debug-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90 @@ -96,6 +96,7 @@ ! ALL-NEXT: SCFToControlFlow ! ALL-NEXT: Canonicalizer ! ALL-NEXT: SimplifyRegionLite +! ALL-NEXT: ConvertComplexPow ! ALL-NEXT: CSE ! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd ! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90 index e85a7728fc9af..4fd89d6f15d46 100644 --- a/flang/test/Driver/mlir-pass-pipeline.f90 +++ b/flang/test/Driver/mlir-pass-pipeline.f90 @@ -127,6 +127,7 @@ ! ALL-NEXT: SCFToControlFlow ! ALL-NEXT: Canonicalizer ! ALL-NEXT: SimplifyRegionLite +! ALL-NEXT: ConvertComplexPow ! ALL-NEXT: CSE ! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd ! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir index 0a31397efb332..195e5ad7f9dc8 100644 --- a/flang/test/Fir/basic-program.fir +++ b/flang/test/Fir/basic-program.fir @@ -125,6 +125,7 @@ func.func @_QQmain() { // PASSES-NEXT: SCFToControlFlow // PASSES-NEXT: Canonicalizer // PASSES-NEXT: SimplifyRegionLite +// PASSES-NEXT: ConvertComplexPow // PASSES-NEXT: CSE // PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd // PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90 index 72cd048ea3615..1fbd333db37c3 100644 --- a/flang/test/Lower/HLFIR/binary-ops.f90 +++ b/flang/test/Lower/HLFIR/binary-ops.f90 @@ -168,7 +168,7 @@ subroutine complex_power(x, y, z) ! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref>, !fir.dscope) -> (!fir.ref>, !fir.ref>) ! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref> ! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref> -! CHECK: %[[VAL_8:.*]] = fir.call @cpowf(%[[VAL_6]], %[[VAL_7]]) fastmath : (complex, complex) -> complex +! CHECK: %[[VAL_8:.*]] = complex.pow %[[VAL_6]], %[[VAL_7]] fastmath : complex subroutine real_to_int_power(x, y, z) @@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z) ! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref, !fir.dscope) -> (!fir.ref, !fir.ref) ! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref> ! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref -! CHECK: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath : (complex, i32) -> complex +! CHECK: %[[VAL_8:.*]] = complex.pow subroutine extremum(c, n, l) integer(8), intent(in) :: l diff --git a/flang/test/Lower/Intrinsics/pow_complex16.f90 b/flang/test/Lower/Intrinsics/pow_complex16.f90 index 7467986832479..c026dd242e964 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16.f90 @@ -1,9 +1,10 @@ ! REQUIRES: flang-supports-f128-math ! RUN: bbc -emit-fir %s -o - | FileCheck %s -! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s +! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE" ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s -! CHECK: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex, complex) -> complex +! PRECISE: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex, complex) -> complex +! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath : complex complex(16) :: a, b b = a ** b end diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90 index 6f8684d9a663a..1827863a57f43 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16i.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90 @@ -1,9 +1,10 @@ ! REQUIRES: flang-supports-f128-math ! RUN: bbc -emit-fir %s -o - | FileCheck %s -! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s +! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE" ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s -! CHECK: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex, i32) -> complex +! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex, i32) -> complex +! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath : complex complex(16) :: a integer(4) :: b b = a ** b diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90 index d3765050640ae..039dfd5152a06 100644 --- a/flang/test/Lower/Intrinsics/pow_complex16k.f90 +++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90 @@ -1,9 +1,10 @@ ! REQUIRES: flang-supports-f128-math ! RUN: bbc -emit-fir %s -o - | FileCheck %s -! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s +! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE" ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s -! CHECK: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex, i64) -> complex +! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex, i64) -> complex +! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath : complex complex(16) :: a integer(8) :: b b = a ** b diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90 index 7436e031d20cb..3058927144248 100644 --- a/flang/test/Lower/power-operator.f90 +++ b/flang/test/Lower/power-operator.f90 @@ -1,10 +1,10 @@ -! RUN: bbc -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE" -! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE" -! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s --check-prefixes="FAST" -! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE" -! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,FAST" -! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="PRECISE" -! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="FAST" +! RUN: bbc -emit-fir %s -o - | FileCheck %s +! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefix=PRECISE +! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s +! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s +! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefix=PRECISE +! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s ! Test power operation lowering @@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z) complex :: x, z integer :: y z = x ** y - ! CHECK: call @_FortranAcpowi + ! CHECK: complex.pow + ! PRECISE: fir.call @_FortranAcpowi end subroutine ! CHECK-LABEL: pow_c4_i8 @@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z) complex :: x, z integer(8) :: y z = x ** y - ! CHECK: call @_FortranAcpowk + ! CHECK: complex.pow + ! PRECISE: fir.call @_FortranAcpowk end subroutine ! CHECK-LABEL: pow_c8_i4 @@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z) complex(8) :: x, z integer :: y z = x ** y - ! CHECK: call @_FortranAzpowi + ! CHECK: complex.pow + ! PRECISE: fir.call @_FortranAzpowi end subroutine ! CHECK-LABEL: pow_c8_i8 @@ -120,22 +123,23 @@ subroutine pow_c8_i8(x, y, z) complex(8) :: x, z integer(8) :: y z = x ** y - ! CHECK: call @_FortranAzpowk + ! CHECK: complex.pow + ! PRECISE: fir.call @_FortranAzpowk end subroutine ! CHECK-LABEL: pow_c4_c4 subroutine pow_c4_c4(x, y, z) complex :: x, y, z z = x ** y - ! FAST: complex.pow %{{.*}}, %{{.*}} : complex - ! PRECISE: call @cpowf + ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex + ! PRECISE: fir.call @cpowf end subroutine ! CHECK-LABEL: pow_c8_c8 subroutine pow_c8_c8(x, y, z) complex(8) :: x, y, z z = x ** y - ! FAST: complex.pow %{{.*}}, %{{.*}} : complex - ! PRECISE: call @cpow + ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex + ! PRECISE: fir.call @cpow end subroutine diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir new file mode 100644 index 0000000000000..e09fa7316c4b0 --- /dev/null +++ b/flang/test/Transforms/convert-complex-pow.fir @@ -0,0 +1,111 @@ +// RUN: fir-opt --convert-complex-pow %s | FileCheck %s + +module { + func.func @pow_c4_i4(%arg0: complex, %arg1: i32) -> complex { + %c0 = arith.constant 0.0 : f32 + %0 = fir.convert %arg1 : (i32) -> f32 + %1 = complex.create %0, %c0 : complex + %2 = complex.pow %arg0, %1 : complex + return %2 : complex + } + + func.func @pow_c4_i8(%arg0: complex, %arg1: i64) -> complex { + %c0 = arith.constant 0.0 : f32 + %0 = fir.convert %arg1 : (i64) -> f32 + %1 = complex.create %0, %c0 : complex + %2 = complex.pow %arg0, %1 : complex + return %2 : complex + } + + func.func @pow_c8_i4(%arg0: complex, %arg1: i32) -> complex { + %c0 = arith.constant 0.0 : f64 + %0 = fir.convert %arg1 : (i32) -> f64 + %1 = complex.create %0, %c0 : complex + %2 = complex.pow %arg0, %1 : complex + return %2 : complex + } + + func.func @pow_c8_i8(%arg0: complex, %arg1: i64) -> complex { + %c0 = arith.constant 0.0 : f64 + %0 = fir.convert %arg1 : (i64) -> f64 + %1 = complex.create %0, %c0 : complex + %2 = complex.pow %arg0, %1 : complex + return %2 : complex + } + + func.func @pow_c16_i4(%arg0: complex, %arg1: i32) -> complex { + %c0 = arith.constant 0.0 : f128 + %0 = fir.convert %arg1 : (i32) -> f128 + %1 = complex.create %0, %c0 : complex + %2 = complex.pow %arg0, %1 : complex + return %2 : complex + } + + func.func @pow_c16_i8(%arg0: complex, %arg1: i64) -> complex { + %c0 = arith.constant 0.0 : f128 + %0 = fir.convert %arg1 : (i64) -> f128 + %1 = complex.create %0, %c0 : complex + %2 = complex.pow %arg0, %1 : complex + return %2 : complex + } + + func.func @pow_c4_fast(%arg0: complex, %arg1: f32) -> complex { + %c1 = arith.constant 1.0 : f32 + %0 = complex.create %arg1, %c1 : complex + %1 = complex.pow %arg0, %0 fastmath : complex + return %1 : complex + } + + func.func @pow_c8_complex(%arg0: complex, %arg1: f64) -> complex { + %c2 = arith.constant 2.0 : f64 + %0 = complex.create %arg1, %c2 : complex + %1 = complex.pow %arg0, %0 : complex + return %1 : complex + } + + func.func @pow_c16_complex(%arg0: complex, %arg1: f128) -> complex { + %c3 = arith.constant 3.0 : f128 + %0 = complex.create %arg1, %c3 : complex + %1 = complex.pow %arg0, %0 : complex + return %1 : complex + } +} + +// CHECK-LABEL: func.func @pow_c4_i4( +// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c4_i8( +// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c8_i4( +// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c8_i8( +// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c16_i4( +// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex, i32) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c16_i8( +// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex, i64) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c4_fast( +// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex +// CHECK: fir.call @cpowf(%{{.*}}, %[[EXP]]) fastmath : (complex, complex) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c8_complex( +// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex +// CHECK: fir.call @cpow(%{{.*}}, %[[EXP]]) : (complex, complex) -> complex +// CHECK-NOT: complex.pow + +// CHECK-LABEL: func.func @pow_c16_complex( +// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex +// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex, complex) -> complex +// CHECK-NOT: complex.pow \ No newline at end of file diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp index 82dff2653ad09..69a45c66a079a 100644 --- a/flang/tools/bbc/bbc.cpp +++ b/flang/tools/bbc/bbc.cpp @@ -538,6 +538,7 @@ static llvm::LogicalResult convertFortranSourceToMLIR( // Add O2 optimizer pass pipeline. MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2); + config.SkipConvertComplexPow = targetMachine.getTargetTriple().isAMDGCN(); if (enableOpenMP) config.EnableOpenMP = true; config.NSWOnLoopVarInc = !integerWrapAround; diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 0372f32d6b6df..72b1fa6e833f9 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -64,9 +64,12 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern { LogicalResult matchAndRewrite(complex::PowOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); - Value logBase = complex::LogOp::create(rewriter, loc, op.getLhs()); - Value mul = complex::MulOp::create(rewriter, loc, op.getRhs(), logBase); - Value exp = complex::ExpOp::create(rewriter, loc, mul); + auto fastmath = op.getFastmathAttr(); + Value logBase = + complex::LogOp::create(rewriter, loc, op.getLhs(), fastmath); + Value mul = + complex::MulOp::create(rewriter, loc, op.getRhs(), logBase, fastmath); + Value exp = complex::ExpOp::create(rewriter, loc, mul, fastmath); rewriter.replaceOp(op, exp); return success(); }