-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[Flang] Add new ConvertComplexPow pass for Flang #158642
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-flang-driver Author: Akash Banerjee (TIFitis) ChangesThis PR forces lowering to complex.pow ops for flang when The primary motivation for this is to benefit from math optimisations such as x**3 -> xxx. I'll be adding the optimisation shortly in a subsequent PR. Patch is 23.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/158642.diff 15 Files Affected:
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index e3001454cdf19..0ed4bb66aff0d 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -551,6 +551,17 @@ def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
"Prefer expanding without using Fortran runtime calls.">];
}
+def ConvertComplexPow : Pass<"convert-complex-pow", "mlir::func::FuncOp"> {
+ let summary = "Convert complex.pow operations to library calls";
+ let description = [{
+ Replace `complex.pow` operations with calls to the appropriate
+ Fortran runtime or libm functions.
+ }];
+ let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect",
+ "mlir::complex::ComplexDialect",
+ "mlir::arith::ArithDialect"];
+}
+
def OptimizeArrayRepacking
: Pass<"optimize-array-repacking", "mlir::func::FuncOp"> {
let summary = "Optimizes redundant array repacking operations";
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index ce1376fd209cc..466458c05dba7 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1327,18 +1327,18 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
const MathOperation &mathOp,
mlir::FunctionType mathLibFuncType,
llvm::ArrayRef<mlir::Value> args) {
- bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
- if (!isAMDGPU)
+ if (mathRuntimeVersion == preciseVersion)
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
-
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
- auto realTy = complexTy.getElementType();
- mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
- mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
- mlir::Value complexExp =
- builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
- mlir::Value result =
- builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
+ mlir::Value exp = args[1];
+ if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
+ auto realTy = complexTy.getElementType();
+ mlir::Value realExp = builder.createConvert(loc, realTy, exp);
+ mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+ exp =
+ builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+ }
+ mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
return result;
}
@@ -1668,11 +1668,11 @@ static constexpr MathOperation mathOperations[] = {
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
{"pow", "cpowf",
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
- genComplexMathOp<mlir::complex::PowOp>},
+ genComplexPow},
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
- genComplexMathOp<mlir::complex::PowOp>},
+ genComplexPow},
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
- genLibF128Call},
+ genComplexPow},
{"pow", RTNAME_STRING(FPow4i),
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
genMathOp<mlir::math::FPowIOp>},
@@ -1698,7 +1698,7 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
genComplexPow},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
- genLibF128Call},
+ genComplexPow},
{"pow", RTNAME_STRING(cpowk),
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
genComplexPow},
@@ -1706,7 +1706,7 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
genComplexPow},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
- genLibF128Call},
+ genComplexPow},
{"pow-unsigned", RTNAME_STRING(UPow1),
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
{"pow-unsigned", RTNAME_STRING(UPow2),
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 7c2777baebef1..ddcfffc9f158f 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -225,6 +225,7 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
pm.addPass(mlir::createCanonicalizerPass(config));
pm.addPass(fir::createSimplifyRegionLite());
+ pm.addPass(fir::createConvertComplexPow());
pm.addPass(mlir::createCSEPass());
if (pc.OptLevel.isOptimizingForSpeed())
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index a8812e08c1ccd..4ec16274830fe 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -35,6 +35,7 @@ add_flang_library(FIRTransforms
GenRuntimeCallsForTest.cpp
SimplifyFIROperations.cpp
OptimizeArrayRepacking.cpp
+ ConvertComplexPow.cpp
DEPENDS
CUFAttrs
diff --git a/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
new file mode 100644
index 0000000000000..8b62237cf539d
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp
@@ -0,0 +1,125 @@
+//===- ConvertComplexPow.cpp - Convert complex.pow to library calls -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/static-multimap-view.h"
+#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Runtime/entry-names.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+
+namespace fir {
+#define GEN_PASS_DEF_CONVERTCOMPLEXPOW
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+class ConvertComplexPowPass
+ : public fir::impl::ConvertComplexPowBase<ConvertComplexPowPass> {
+public:
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<fir::FIROpsDialect, complex::ComplexDialect,
+ arith::ArithDialect, func::FuncDialect>();
+ }
+ void runOnOperation() override;
+};
+} // namespace
+
+// Helper to declare or get a math library function.
+static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
+ StringRef name, FunctionType type) {
+ if (auto func = builder.getNamedFunction(name))
+ return func;
+ auto func = builder.createFunction(loc, name, type);
+ func->setAttr(fir::getSymbolAttrName(), builder.getStringAttr(name));
+ func->setAttr(fir::FIROpsDialect::getFirRuntimeAttrName(),
+ builder.getUnitAttr());
+ return func;
+}
+
+static bool isZero(Value v) {
+ if (auto cst = v.getDefiningOp<arith::ConstantOp>())
+ if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
+ return attr.getValue().isZero();
+ return false;
+}
+
+void ConvertComplexPowPass::runOnOperation() {
+ auto func = getOperation();
+ auto mod = func->getParentOfType<ModuleOp>();
+ if (fir::getTargetTriple(mod).isAMDGCN())
+ return;
+
+ fir::FirOpBuilder builder(func, fir::getKindMapping(mod));
+
+ func.walk([&](complex::PowOp op) {
+ builder.setInsertionPoint(op);
+ Location loc = op.getLoc();
+ auto complexTy = cast<ComplexType>(op.getType());
+ auto elemTy = complexTy.getElementType();
+
+ Value base = op.getLhs();
+ Value rhs = op.getRhs();
+
+ Value intExp;
+ if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
+ if (isZero(create.getImaginary())) {
+ if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
+ if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
+ intExp = conv.getValue();
+ }
+ }
+ }
+
+ func::FuncOp callee;
+ SmallVector<Value> args;
+ if (intExp) {
+ unsigned realBits = cast<FloatType>(elemTy).getWidth();
+ unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
+ auto funcTy = builder.getFunctionType(
+ {complexTy, builder.getIntegerType(intBits)}, {complexTy});
+ if (realBits == 32 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowi), funcTy);
+ else if (realBits == 32 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cpowk), funcTy);
+ else if (realBits == 64 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowi), funcTy);
+ else if (realBits == 64 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(zpowk), funcTy);
+ else if (realBits == 128 && intBits == 32)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowi), funcTy);
+ else if (realBits == 128 && intBits == 64)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
+ else
+ return;
+ args = {base, intExp};
+ } else {
+ unsigned realBits = cast<FloatType>(elemTy).getWidth();
+ auto funcTy =
+ builder.getFunctionType({complexTy, complexTy}, {complexTy});
+ if (realBits == 32)
+ callee = getOrDeclare(builder, loc, "cpowf", funcTy);
+ else if (realBits == 64)
+ callee = getOrDeclare(builder, loc, "cpow", funcTy);
+ else if (realBits == 128)
+ callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
+ else
+ return;
+ args = {base, rhs};
+ }
+
+ auto call = fir::CallOp::create(builder, loc, callee, args);
+ op.replaceAllUsesWith(call.getResult(0));
+ op.erase();
+ });
+}
diff --git a/flang/test/Driver/bbc-mlir-pass-pipeline.f90 b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
index f3791fe9f8dc3..30cb97e4455ee 100644
--- a/flang/test/Driver/bbc-mlir-pass-pipeline.f90
+++ b/flang/test/Driver/bbc-mlir-pass-pipeline.f90
@@ -69,6 +69,8 @@
! CHECK-NEXT: SCFToControlFlow
! CHECK-NEXT: Canonicalizer
! CHECK-NEXT: SimplifyRegionLite
+! CHECK-NEXT: 'func.func' Pipeline
+! CHECK-NEXT: ConvertComplexPow
! CHECK-NEXT: CSE
! CHECK-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-debug-pass-pipeline.f90 b/flang/test/Driver/mlir-debug-pass-pipeline.f90
index 42a71b2d6adc3..bb6d5509c3269 100644
--- a/flang/test/Driver/mlir-debug-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-debug-pass-pipeline.f90
@@ -96,6 +96,8 @@
! ALL-NEXT: SCFToControlFlow
! ALL-NEXT: Canonicalizer
! ALL-NEXT: SimplifyRegionLite
+! ALL-NEXT: 'func.func' Pipeline
+! ALL-NEXT: ConvertComplexPow
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Driver/mlir-pass-pipeline.f90 b/flang/test/Driver/mlir-pass-pipeline.f90
index e85a7728fc9af..6006f6672ee72 100644
--- a/flang/test/Driver/mlir-pass-pipeline.f90
+++ b/flang/test/Driver/mlir-pass-pipeline.f90
@@ -127,6 +127,8 @@
! ALL-NEXT: SCFToControlFlow
! ALL-NEXT: Canonicalizer
! ALL-NEXT: SimplifyRegionLite
+! ALL-NEXT: 'func.func' Pipeline
+! ALL-NEXT: ConvertComplexPow
! ALL-NEXT: CSE
! ALL-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index 0a31397efb332..a2e3cda8f2325 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -125,6 +125,8 @@ func.func @_QQmain() {
// PASSES-NEXT: SCFToControlFlow
// PASSES-NEXT: Canonicalizer
// PASSES-NEXT: SimplifyRegionLite
+// PASSES-NEXT: 'func.func' Pipeline
+// PASSES-NEXT: ConvertComplexPow
// PASSES-NEXT: CSE
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 72cd048ea3615..1fbd333db37c3 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -168,7 +168,7 @@ subroutine complex_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<complex<f32>>, !fir.dscope) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<complex<f32>>
-! CHECK: %[[VAL_8:.*]] = fir.call @cpowf(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, complex<f32>) -> complex<f32>
+! CHECK: %[[VAL_8:.*]] = complex.pow %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>
subroutine real_to_int_power(x, y, z)
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, i32) -> complex<f32>
+! CHECK: %[[VAL_8:.*]] = complex.pow
subroutine extremum(c, n, l)
integer(8), intent(in) :: l
diff --git a/flang/test/Lower/Intrinsics/pow_complex16.f90 b/flang/test/Lower/Intrinsics/pow_complex16.f90
index 7467986832479..c026dd242e964 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16.f90
@@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
-! CHECK: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
+! PRECISE: fir.call @_FortranACPowF128({{.*}}){{.*}}: (complex<f128>, complex<f128>) -> complex<f128>
+! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a, b
b = a ** b
end
diff --git a/flang/test/Lower/Intrinsics/pow_complex16i.f90 b/flang/test/Lower/Intrinsics/pow_complex16i.f90
index 6f8684d9a663a..1827863a57f43 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16i.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16i.f90
@@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
-! CHECK: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
+! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
+! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(4) :: b
b = a ** b
diff --git a/flang/test/Lower/Intrinsics/pow_complex16k.f90 b/flang/test/Lower/Intrinsics/pow_complex16k.f90
index d3765050640ae..039dfd5152a06 100644
--- a/flang/test/Lower/Intrinsics/pow_complex16k.f90
+++ b/flang/test/Lower/Intrinsics/pow_complex16k.f90
@@ -1,9 +1,10 @@
! REQUIRES: flang-supports-f128-math
! RUN: bbc -emit-fir %s -o - | FileCheck %s
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
-! CHECK: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
+! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
+! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
complex(16) :: a
integer(8) :: b
b = a ** b
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 7436e031d20cb..3058927144248 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -1,10 +1,10 @@
-! RUN: bbc -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE"
-! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefixes="PRECISE"
-! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s --check-prefixes="FAST"
-! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,PRECISE"
-! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,FAST"
-! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="PRECISE"
-! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="FAST"
+! RUN: bbc -emit-fir %s -o - | FileCheck %s
+! RUN: bbc --math-runtime=precise -emit-fir %s -o - | FileCheck %s --check-prefix=PRECISE
+! RUN: bbc --force-mlir-complex -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefix=PRECISE
+! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s
! Test power operation lowering
@@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
- ! CHECK: call @_FortranAcpowi
+ ! CHECK: complex.pow
+ ! PRECISE: fir.call @_FortranAcpowi
end subroutine
! CHECK-LABEL: pow_c4_i8
@@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: call @_FortranAcpowk
+ ! CHECK: complex.pow
+ ! PRECISE: fir.call @_FortranAcpowk
end subroutine
! CHECK-LABEL: pow_c8_i4
@@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
- ! CHECK: call @_FortranAzpowi
+ ! CHECK: complex.pow
+ ! PRECISE: fir.call @_FortranAzpowi
end subroutine
! CHECK-LABEL: pow_c8_i8
@@ -120,22 +123,23 @@ subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: call @_FortranAzpowk
+ ! CHECK: complex.pow
+ ! PRECISE: fir.call @_FortranAzpowk
end subroutine
! CHECK-LABEL: pow_c4_c4
subroutine pow_c4_c4(x, y, z)
complex :: x, y, z
z = x ** y
- ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
- ! PRECISE: call @cpowf
+ ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f32>
+ ! PRECISE: fir.call @cpowf
end subroutine
! CHECK-LABEL: pow_c8_c8
subroutine pow_c8_c8(x, y, z)
complex(8) :: x, y, z
z = x ** y
- ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
- ! PRECISE: call @cpow
+ ! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
+ ! PRECISE: fir.call @cpow
end subroutine
diff --git a/flang/test/Transforms/convert-complex-pow.fir b/flang/test/Transforms/convert-complex-pow.fir
new file mode 100644
index 0000000000000..d980817aba9b9
--- /dev/null
+++ b/flang/test/Transforms/convert-complex-pow.fir
@@ -0,0 +1,102 @@
+// RUN: fir-opt --convert-complex-pow %s | FileCheck %s
+
+module {
+ func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
+ %c0 = arith.constant 0.000000e+00 : f32
+ %c1 = fir.convert %arg1 : (i32) -> f32
+ %c2 = complex.create %c1, %c0 : complex<f32>
+ %0 = complex.pow %arg0, %c2 : complex<f32>
+ return %0 : complex<f32>
+ }
+
+ func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
+ %c0 = arith.constant 0.000000e+00 : f32
+ %c1 = fir.convert %arg1 : (i64) -> f32
+ %c2 = complex.create %c1, %c0 : complex<f32>
+ %0...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR introduces a new ConvertComplexPow
pass for Flang that handles complex power operations. The change forces lowering to complex.pow
operations when --math-runtime=precise
is not used, then provides a new pass to convert these operations back to library calls.
- Adds a new
ConvertComplexPow
pass that convertscomplex.pow
ops to appropriate runtime library calls - Updates complex power lowering to use
complex.pow
operations by default instead of direct library calls - Updates test expectations to reflect the new lowering behavior
Reviewed Changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 2 comments.
Show a summary per file
File | Description |
---|---|
flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp | New pass implementation that converts complex.pow ops to library calls |
flang/lib/Optimizer/Builder/IntrinsicCall.cpp | Updated genComplexPow to generate complex.pow ops instead of direct calls |
flang/lib/Optimizer/Passes/Pipelines.cpp | Integrated the new pass into the compilation pipeline |
flang/include/flang/Optimizer/Transforms/Passes.td | Added pass definition and documentation |
flang/lib/Optimizer/Transforms/CMakeLists.txt | Added new source file to build |
flang/test/Transforms/convert-complex-pow.fir | New test file for the ConvertComplexPow pass |
Multiple test files | Updated to reflect new lowering behavior and test expectations |
else | ||
return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Early returns without error handling or logging make debugging difficult. Consider adding a diagnostic message or comment explaining why these combinations are unsupported.
Copilot uses AI. Check for mistakes.
else | ||
return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Early return without error handling or logging makes debugging difficult. Consider adding a diagnostic message or comment explaining why this bit width is unsupported.
else | |
return; | |
else { | |
emitWarning(loc, "Unsupported complex.pow bit width: realBits=" + | |
std::to_string(realBits)); | |
return; | |
} |
Copilot uses AI. Check for mistakes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the changes! Please see my comments inlined.
// CHECK-NOT: complex.pow | ||
|
||
// CHECK-LABEL: func.func @pow_c4_c4( | ||
// CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) : (complex<f32>, complex<f32>) -> complex<f32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please verify that the fast-math flags are properly propagated from complex.pow
to the call?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
|
||
void ConvertComplexPowPass::runOnOperation() { | ||
ModuleOp mod = getOperation(); | ||
if (fir::getTargetTriple(mod).isAMDGCN()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please make the pass generic, and instead not schedule it for AMDGCN in the Pipelines.cpp
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
return; | ||
args = {base, intExp}; | ||
} else { | ||
unsigned realBits = cast<FloatType>(elemTy).getWidth(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am really worried about dropping the imaginary part for these cases. Imagine, somewhere in Flang we start generating complex.pow
with a true complex exponent. This pass will just silently drop it and produce incorrect code. Please add a LIT test for this case.
I think we need to keep complex.pow
if we cannot prove that the imaginary part is zero.
Ideally, we should have powi
and powf
operations in the complex
dialect, so that we do not have to rely on the particular fir.convert
/complex.create
pattern generated by the lowering. Moreover, SSA values may become block arguments making it harder to recognize the specific pattern even more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally, we should have
powi
andpowf
operations in thecomplex
dialect, so that we do not have to rely on the particularfir.convert
/complex.create
pattern generated by the lowering. Moreover, SSA values may become block arguments making it harder to recognize the specific pattern even more.
I have added powi
in #158722. I'll address the rest of this comment along with other comments tomorrow.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for adding powi
! I think adding powf
is not required. I just misread the code.
genComplexPow}, | ||
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4, | ||
genLibF128Call}, | ||
genComplexPow}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When you have complex.powi
, I think we can just use genComplexMathOp<mlir::complex::powi>
or genMathOp<mlir::complex::powi>
here.
We can probably get rid of genComplexPow
and use genMathOp
instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
genComplexMathOp
would lower to libCall
if (!forceMlirComplex && !canUseApprox && !isAMDGPU)
which means we would restrict lowering to complex.pow
for some cases where we are currently forcing it. Is that okay?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think if (!forceMlirComplex && !canUseApprox && !isAMDGPU)
check is yet another workaround that has to be removed eventually (not in this PR).
I think genMathOp
should work here just fine or am I missing something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I glossed over genMathOp
, I've added this change in #158722 to removed genComplexPow
.
Propagate fast-math attributes, update tests.
Fix Targetmachine build error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM assuming that the final code will look like in #158722.
This PR adds a new complex.powi operation to MLIR's complex dialect for computing complex numbers raised to integer powers. Key changes include: - Addition of the new `PowiOp` operation definition in the Complex dialect - Integration with algebraic simplification passes for optimization - Support for conversion to ROCDL library calls - Updates to Flang frontend to generate the new operation This depends on #158642.
|
I noticed this when updating flang on my laptop. I'll try to get a stacktrace of flang with debug info enabled.
Unfortunately it doesn't fix it. EDIT: I confused this crash with another issue related to lld. Removed the non-related lld comments. |
The
|
This change fixes the issue for me:
|
@luporl Thanks for the stack trace. I'll report back on this tomorrow. |
Thanks! |
This PR introduces a new
ConvertComplexPow
pass for Flang that handles complex power operations. The change forces lowering to complex.pow operations when--math-runtime=precise
is not used, then uses theConvertComplexPow
pass to convert these operations back to library calls.ConvertComplexPow
pass that converts complex.pow ops to appropriate runtime library callscomplex.pow
operations by default instead of direct library calls#158722 Adds a new
complex.powi
op enabling algebraic optimisations.