Skip to content

Commit fdb1f48

Browse files
authored
[MLIR] Add new complex.powi op (#158722)
This PR adds a new complex.powi operation to MLIR's complex dialect for computing complex numbers raised to integer powers. Key changes include: - Addition of the new `PowiOp` operation definition in the Complex dialect - Integration with algebraic simplification passes for optimization - Support for conversion to ROCDL library calls - Updates to Flang frontend to generate the new operation This depends on #158642.
1 parent 1ad5d63 commit fdb1f48

File tree

16 files changed

+270
-120
lines changed

16 files changed

+270
-120
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,26 +1323,6 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
13231323
return result;
13241324
}
13251325

1326-
mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
1327-
const MathOperation &mathOp,
1328-
mlir::FunctionType mathLibFuncType,
1329-
llvm::ArrayRef<mlir::Value> args) {
1330-
if (mathRuntimeVersion == preciseVersion)
1331-
return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
1332-
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
1333-
mlir::Value exp = args[1];
1334-
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
1335-
auto realTy = complexTy.getElementType();
1336-
mlir::Value realExp = builder.createConvert(loc, realTy, exp);
1337-
mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
1338-
exp =
1339-
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
1340-
}
1341-
mlir::Value result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
1342-
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
1343-
return result;
1344-
}
1345-
13461326
/// Mapping between mathematical intrinsic operations and MLIR operations
13471327
/// of some appropriate dialect (math, complex, etc.) or libm calls.
13481328
/// TODO: support remaining Fortran math intrinsics.
@@ -1668,11 +1648,11 @@ static constexpr MathOperation mathOperations[] = {
16681648
{"pow", RTNAME_STRING(PowF128), FuncTypeReal16Real16Real16, genLibF128Call},
16691649
{"pow", "cpowf",
16701650
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Complex<4>>,
1671-
genComplexPow},
1651+
genMathOp<mlir::complex::PowOp>},
16721652
{"pow", "cpow", genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Complex<8>>,
1673-
genComplexPow},
1653+
genMathOp<mlir::complex::PowOp>},
16741654
{"pow", RTNAME_STRING(CPowF128), FuncTypeComplex16Complex16Complex16,
1675-
genComplexPow},
1655+
genMathOp<mlir::complex::PowOp>},
16761656
{"pow", RTNAME_STRING(FPow4i),
16771657
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
16781658
genMathOp<mlir::math::FPowIOp>},
@@ -1693,20 +1673,20 @@ static constexpr MathOperation mathOperations[] = {
16931673
genMathOp<mlir::math::FPowIOp>},
16941674
{"pow", RTNAME_STRING(cpowi),
16951675
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
1696-
genComplexPow},
1676+
genMathOp<mlir::complex::PowiOp>},
16971677
{"pow", RTNAME_STRING(zpowi),
16981678
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
1699-
genComplexPow},
1679+
genMathOp<mlir::complex::PowiOp>},
17001680
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
1701-
genComplexPow},
1681+
genMathOp<mlir::complex::PowiOp>},
17021682
{"pow", RTNAME_STRING(cpowk),
17031683
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
1704-
genComplexPow},
1684+
genMathOp<mlir::complex::PowiOp>},
17051685
{"pow", RTNAME_STRING(zpowk),
17061686
genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
1707-
genComplexPow},
1687+
genMathOp<mlir::complex::PowiOp>},
17081688
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
1709-
genComplexPow},
1689+
genMathOp<mlir::complex::PowiOp>},
17101690
{"pow-unsigned", RTNAME_STRING(UPow1),
17111691
genFuncType<Ty::Integer<1>, Ty::Integer<1>, Ty::Integer<1>>, genLibCall},
17121692
{"pow-unsigned", RTNAME_STRING(UPow2),

flang/lib/Optimizer/Transforms/ConvertComplexPow.cpp

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -47,39 +47,19 @@ static func::FuncOp getOrDeclare(fir::FirOpBuilder &builder, Location loc,
4747
return func;
4848
}
4949

50-
static bool isZero(Value v) {
51-
if (auto cst = v.getDefiningOp<arith::ConstantOp>())
52-
if (auto attr = dyn_cast<FloatAttr>(cst.getValue()))
53-
return attr.getValue().isZero();
54-
return false;
55-
}
56-
5750
void ConvertComplexPowPass::runOnOperation() {
5851
ModuleOp mod = getOperation();
5952
fir::FirOpBuilder builder(mod, fir::getKindMapping(mod));
6053

61-
mod.walk([&](complex::PowOp op) {
62-
builder.setInsertionPoint(op);
63-
Location loc = op.getLoc();
64-
auto complexTy = cast<ComplexType>(op.getType());
65-
auto elemTy = complexTy.getElementType();
66-
67-
Value base = op.getLhs();
68-
Value rhs = op.getRhs();
69-
70-
Value intExp;
71-
if (auto create = rhs.getDefiningOp<complex::CreateOp>()) {
72-
if (isZero(create.getImaginary())) {
73-
if (auto conv = create.getReal().getDefiningOp<fir::ConvertOp>()) {
74-
if (auto intTy = dyn_cast<IntegerType>(conv.getValue().getType()))
75-
intExp = conv.getValue();
76-
}
77-
}
78-
}
79-
80-
func::FuncOp callee;
81-
SmallVector<Value> args;
82-
if (intExp) {
54+
mod.walk([&](Operation *op) {
55+
if (auto powIop = dyn_cast<complex::PowiOp>(op)) {
56+
builder.setInsertionPoint(powIop);
57+
Location loc = powIop.getLoc();
58+
auto complexTy = cast<ComplexType>(powIop.getType());
59+
auto elemTy = complexTy.getElementType();
60+
Value base = powIop.getLhs();
61+
Value intExp = powIop.getRhs();
62+
func::FuncOp callee;
8363
unsigned realBits = cast<FloatType>(elemTy).getWidth();
8464
unsigned intBits = cast<IntegerType>(intExp.getType()).getWidth();
8565
auto funcTy = builder.getFunctionType(
@@ -98,9 +78,20 @@ void ConvertComplexPowPass::runOnOperation() {
9878
callee = getOrDeclare(builder, loc, RTNAME_STRING(cqpowk), funcTy);
9979
else
10080
return;
101-
args = {base, intExp};
102-
} else {
81+
auto call = fir::CallOp::create(builder, loc, callee, {base, intExp});
82+
if (auto fmf = powIop.getFastmathAttr())
83+
call.setFastmathAttr(fmf);
84+
powIop.replaceAllUsesWith(call.getResult(0));
85+
powIop.erase();
86+
}
87+
88+
if (auto powOp = dyn_cast<complex::PowOp>(op)) {
89+
builder.setInsertionPoint(powOp);
90+
Location loc = powOp.getLoc();
91+
auto complexTy = cast<ComplexType>(powOp.getType());
92+
auto elemTy = complexTy.getElementType();
10393
unsigned realBits = cast<FloatType>(elemTy).getWidth();
94+
func::FuncOp callee;
10495
auto funcTy =
10596
builder.getFunctionType({complexTy, complexTy}, {complexTy});
10697
if (realBits == 32)
@@ -111,13 +102,12 @@ void ConvertComplexPowPass::runOnOperation() {
111102
callee = getOrDeclare(builder, loc, RTNAME_STRING(CPowF128), funcTy);
112103
else
113104
return;
114-
args = {base, rhs};
105+
auto call = fir::CallOp::create(builder, loc, callee,
106+
{powOp.getLhs(), powOp.getRhs()});
107+
if (auto fmf = powOp.getFastmathAttr())
108+
call.setFastmathAttr(fmf);
109+
powOp.replaceAllUsesWith(call.getResult(0));
110+
powOp.erase();
115111
}
116-
117-
auto call = fir::CallOp::create(builder, loc, callee, args);
118-
if (auto fmf = op.getFastmathAttr())
119-
call.setFastmathAttr(fmf);
120-
op.replaceAllUsesWith(call.getResult(0));
121-
op.erase();
122112
});
123113
}

flang/test/Lower/HLFIR/binary-ops.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
193193
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
194194
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
195195
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
196-
! CHECK: %[[VAL_8:.*]] = complex.pow
196+
! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>, i32
197197

198198
subroutine extremum(c, n, l)
199199
integer(8), intent(in) :: l

flang/test/Lower/Intrinsics/pow_complex16i.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
55

66
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
7-
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
7+
! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
88
complex(16) :: a
99
integer(4) :: b
1010
b = a ** b

flang/test/Lower/Intrinsics/pow_complex16k.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
55

66
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
7-
! CHECK: complex.pow %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
7+
! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
88
complex(16) :: a
99
integer(8) :: b
1010
b = a ** b

flang/test/Lower/amdgcn-complex.f90

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,12 @@ subroutine pow_test(a, b, c)
2525
complex :: a, b, c
2626
a = b**c
2727
end subroutine pow_test
28+
29+
! CHECK-LABEL: func @_QPpowi_test(
30+
! CHECK: complex.powi
31+
! CHECK-NOT: fir.call @_FortranAcpowi
32+
subroutine powi_test(a, b, c)
33+
complex :: a, b
34+
integer :: i
35+
b = a ** i
36+
end subroutine powi_test

flang/test/Lower/power-operator.f90

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ subroutine pow_c4_i4(x, y, z)
9696
complex :: x, z
9797
integer :: y
9898
z = x ** y
99-
! CHECK: complex.pow
99+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i32
100100
! PRECISE: fir.call @_FortranAcpowi
101101
end subroutine
102102

@@ -105,7 +105,7 @@ subroutine pow_c4_i8(x, y, z)
105105
complex :: x, z
106106
integer(8) :: y
107107
z = x ** y
108-
! CHECK: complex.pow
108+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f32>, i64
109109
! PRECISE: fir.call @_FortranAcpowk
110110
end subroutine
111111

@@ -114,7 +114,7 @@ subroutine pow_c8_i4(x, y, z)
114114
complex(8) :: x, z
115115
integer :: y
116116
z = x ** y
117-
! CHECK: complex.pow
117+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i32
118118
! PRECISE: fir.call @_FortranAzpowi
119119
end subroutine
120120

@@ -123,7 +123,7 @@ subroutine pow_c8_i8(x, y, z)
123123
complex(8) :: x, z
124124
integer(8) :: y
125125
z = x ** y
126-
! CHECK: complex.pow
126+
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f64>, i64
127127
! PRECISE: fir.call @_FortranAzpowk
128128
end subroutine
129129

@@ -142,4 +142,3 @@ subroutine pow_c8_c8(x, y, z)
142142
! CHECK: complex.pow %{{.*}}, %{{.*}} : complex<f64>
143143
! PRECISE: fir.call @cpow
144144
end subroutine
145-

flang/test/Transforms/convert-complex-pow.fir

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,51 +2,38 @@
22

33
module {
44
func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
5-
%c0 = arith.constant 0.0 : f32
6-
%0 = fir.convert %arg1 : (i32) -> f32
7-
%1 = complex.create %0, %c0 : complex<f32>
8-
%2 = complex.pow %arg0, %1 : complex<f32>
9-
return %2 : complex<f32>
5+
%0 = complex.powi %arg0, %arg1 : complex<f32>, i32
6+
return %0 : complex<f32>
7+
}
8+
9+
func.func @pow_c4_i4_fast(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
10+
%0 = complex.powi %arg0, %arg1 fastmath<fast> : complex<f32>, i32
11+
return %0 : complex<f32>
1012
}
1113

1214
func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
13-
%c0 = arith.constant 0.0 : f32
14-
%0 = fir.convert %arg1 : (i64) -> f32
15-
%1 = complex.create %0, %c0 : complex<f32>
16-
%2 = complex.pow %arg0, %1 : complex<f32>
17-
return %2 : complex<f32>
15+
%0 = complex.powi %arg0, %arg1 : complex<f32>, i64
16+
return %0 : complex<f32>
1817
}
1918

2019
func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
21-
%c0 = arith.constant 0.0 : f64
22-
%0 = fir.convert %arg1 : (i32) -> f64
23-
%1 = complex.create %0, %c0 : complex<f64>
24-
%2 = complex.pow %arg0, %1 : complex<f64>
25-
return %2 : complex<f64>
20+
%0 = complex.powi %arg0, %arg1 : complex<f64>, i32
21+
return %0 : complex<f64>
2622
}
2723

2824
func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
29-
%c0 = arith.constant 0.0 : f64
30-
%0 = fir.convert %arg1 : (i64) -> f64
31-
%1 = complex.create %0, %c0 : complex<f64>
32-
%2 = complex.pow %arg0, %1 : complex<f64>
33-
return %2 : complex<f64>
25+
%0 = complex.powi %arg0, %arg1 : complex<f64>, i64
26+
return %0 : complex<f64>
3427
}
3528

3629
func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
37-
%c0 = arith.constant 0.0 : f128
38-
%0 = fir.convert %arg1 : (i32) -> f128
39-
%1 = complex.create %0, %c0 : complex<f128>
40-
%2 = complex.pow %arg0, %1 : complex<f128>
41-
return %2 : complex<f128>
30+
%0 = complex.powi %arg0, %arg1 : complex<f128>, i32
31+
return %0 : complex<f128>
4232
}
4333

4434
func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
45-
%c0 = arith.constant 0.0 : f128
46-
%0 = fir.convert %arg1 : (i64) -> f128
47-
%1 = complex.create %0, %c0 : complex<f128>
48-
%2 = complex.pow %arg0, %1 : complex<f128>
49-
return %2 : complex<f128>
35+
%0 = complex.powi %arg0, %arg1 : complex<f128>, i64
36+
return %0 : complex<f128>
5037
}
5138

5239
func.func @pow_c4_fast(%arg0: complex<f32>, %arg1: f32) -> complex<f32> {
@@ -74,26 +61,37 @@ module {
7461
// CHECK-LABEL: func.func @pow_c4_i4(
7562
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
7663
// CHECK-NOT: complex.pow
64+
// CHECK-NOT: complex.powi
65+
66+
// CHECK-LABEL: func.func @pow_c4_i4_fast(
67+
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) fastmath<fast> : (complex<f32>, i32) -> complex<f32>
68+
// CHECK-NOT: complex.pow
69+
// CHECK-NOT: complex.powi
7770

7871
// CHECK-LABEL: func.func @pow_c4_i8(
7972
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
8073
// CHECK-NOT: complex.pow
74+
// CHECK-NOT: complex.powi
8175

8276
// CHECK-LABEL: func.func @pow_c8_i4(
8377
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
8478
// CHECK-NOT: complex.pow
79+
// CHECK-NOT: complex.powi
8580

8681
// CHECK-LABEL: func.func @pow_c8_i8(
8782
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
8883
// CHECK-NOT: complex.pow
84+
// CHECK-NOT: complex.powi
8985

9086
// CHECK-LABEL: func.func @pow_c16_i4(
9187
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
9288
// CHECK-NOT: complex.pow
89+
// CHECK-NOT: complex.powi
9390

9491
// CHECK-LABEL: func.func @pow_c16_i8(
9592
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
9693
// CHECK-NOT: complex.pow
94+
// CHECK-NOT: complex.powi
9795

9896
// CHECK-LABEL: func.func @pow_c4_fast(
9997
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f32>
@@ -108,4 +106,4 @@ module {
108106
// CHECK-LABEL: func.func @pow_c16_complex(
109107
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f128>
110108
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex<f128>, complex<f128>) -> complex<f128>
111-
// CHECK-NOT: complex.pow
109+
// CHECK-NOT: complex.pow

mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,36 @@ def PowOp : ComplexArithmeticOp<"pow"> {
443443
}];
444444
}
445445

446+
//===----------------------------------------------------------------------===//
447+
// PowiOp
448+
//===----------------------------------------------------------------------===//
449+
450+
def PowiOp : Complex_Op<"powi",
451+
[Pure, Elementwise, SameOperandsAndResultShape,
452+
AllTypesMatch<["lhs", "result"]>,
453+
DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
454+
let summary = "complex number raised to signed integer power";
455+
let description = [{
456+
The `powi` operation takes a `base` operand of complex type and a `power`
457+
operand of signed integer type and returns one result of the same type
458+
as `base`. The result is `base` raised to the power of `power`.
459+
460+
Example:
461+
462+
```mlir
463+
%a = complex.powi %b, %c : complex<f32>, i32
464+
```
465+
}];
466+
467+
let arguments = (ins Complex<AnyFloat>:$lhs,
468+
AnySignlessInteger:$rhs,
469+
OptionalAttr<Arith_FastMathAttr>:$fastmath);
470+
let results = (outs Complex<AnyFloat>:$result);
471+
472+
let assemblyFormat =
473+
"$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result) `,` type($rhs)";
474+
}
475+
446476
//===----------------------------------------------------------------------===//
447477
// ReOp
448478
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)