-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR] Add cpow support in ComplexToROCDLLibraryCalls #153183
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds support for complex power operations (cpow
) in the ComplexToROCDLLibraryCalls conversion pass, specifically targeting AMDGPU architectures. The implementation optimizes complex exponentiation by using mathematical identities and special-case handling for small integer powers.
- Force lowering to
complex.pow
operations for theamdgcn-amd-amdhsa
target instead of using library calls - Convert
complex.pow(z, w)
tocomplex.exp(w * complex.log(z))
using mathematical identity - Optimize integer powers 2-8 by expanding to multiplication chains (e.g.,
x²
becomesx * x
)
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
Show a summary per file
File | Description |
---|---|
mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp |
Implements PowOpToROCDLLibraryCalls pattern with integer power optimizations and exp/log conversion |
mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir |
Adds test cases for both general complex power and optimized integer power cases |
flang/lib/Optimizer/Builder/IntrinsicCall.cpp |
Updates complex power intrinsic handling to use complex.pow operations for AMDGPU targets |
flang/test/Lower/power-operator.f90 |
Updates test expectations to handle both precise and fast modes for complex power operations |
flang/test/Lower/amdgcn-complex.f90 |
Adds test for complex power operations on AMDGPU target |
flang/test/Lower/HLFIR/binary-ops.f90 |
Updates binary operation tests to handle different lowering strategies for complex power |
mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
Show resolved
Hide resolved
@llvm/pr-subscribers-mlir-complex @llvm/pr-subscribers-mlir Author: Akash Banerjee (TIFitis) ChangesThis PR contributes the following changes: Full diff: https://github.com/llvm/llvm-project/pull/153183.diff 6 Files Affected:
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 3e6fbafe8a6b3..2f8965adfb320 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1276,6 +1276,28 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
return result;
}
+mlir::Value genComplexPowI(fir::FirOpBuilder &builder, mlir::Location loc,
+ const MathOperation &mathOp,
+ mlir::FunctionType mathLibFuncType,
+ llvm::ArrayRef<mlir::Value> args) {
+ bool canUseApprox = mlir::arith::bitEnumContainsAny(
+ builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
+ bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
+ if (!forceMlirComplex && !canUseApprox && !isAMDGPU)
+ return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
+
+ auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
+ auto realTy = complexTy.getElementType();
+ mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
+ mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+ mlir::Value complexExp =
+ builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+ mlir::Value result =
+ builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
+ result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
+ return result;
+}
+
/// Mapping between mathematical intrinsic operations and MLIR operations
/// of some appropriate dialect (math, complex, etc.) or libm calls.
/// TODO: support remaining Fortran math intrinsics.
@@ -1625,15 +1647,19 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>,
genMathOp<mlir::math::FPowIOp>},
{"pow", RTNAME_STRING(cpowi),
- genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall},
+ genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
+ genComplexPowI},
{"pow", RTNAME_STRING(zpowi),
- genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, genLibCall},
+ genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
+ genComplexPowI},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
genLibF128Call},
{"pow", RTNAME_STRING(cpowk),
- genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, genLibCall},
+ genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
+ genComplexPowI},
{"pow", RTNAME_STRING(zpowk),
- genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, genLibCall},
+ genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
+ genComplexPowI},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
genLibF128Call},
{"remainder", "remainderf",
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 5855d5ad00036..fbbb596e97e1b 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -1,5 +1,6 @@
! Test lowering of binary intrinsic operations to HLFIR
-! RUN: bbc -emit-hlfir -o - %s 2>&1 | FileCheck %s
+! RUN: bbc -emit-hlfir -o - %s 2>&1 | FileCheck %s --check-prefixes=CHECK,PRECISE
+! RUN: bbc --force-mlir-complex -emit-hlfir -o - %s 2>&1 | FileCheck %s --check-prefixes=CHECK,FAST
subroutine int_add(x, y, z)
integer :: x, y, z
@@ -193,7 +194,8 @@ subroutine complex_to_int_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, i32) -> complex<f32>
+! PRECISE: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, i32) -> complex<f32>
+! FAST: %[[VAL_8:.*]] = complex.pow %[[VAL_6]], %{{.*}} : complex<f32>
subroutine extremum(c, n, l)
integer(8), intent(in) :: l
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index f15c7db2b7316..3d52355d3d50a 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -1,21 +1,27 @@
! REQUIRES: amdgpu-registered-target
-! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir -flang-deprecated-no-hlfir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir %s -o - | FileCheck %s
+! CHECK-LABEL: func @_QPcabsf_test(
+! CHECK: complex.abs
+! CHECK-NOT: fir.call @cabsf
subroutine cabsf_test(a, b)
complex :: a
real :: b
b = abs(a)
end subroutine
-! CHECK-LABEL: func @_QPcabsf_test(
-! CHECK: complex.abs
-! CHECK-NOT: fir.call @cabsf
-
+! CHECK-LABEL: func @_QPcexpf_test(
+! CHECK: complex.exp
+! CHECK-NOT: fir.call @cexpf
subroutine cexpf_test(a, b)
complex :: a, b
b = exp(a)
end subroutine
-! CHECK-LABEL: func @_QPcexpf_test(
-! CHECK: complex.exp
-! CHECK-NOT: fir.call @cexpf
+! CHECK-LABEL: func @_QPpow_test(
+! CHECK: complex.pow
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine pow_test(a, b)
+ complex :: a, b
+ a = b**2
+end subroutine pow_test
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 7436e031d20cb..2a0a09e090dde 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
- ! CHECK: call @_FortranAcpowi
+ ! PRECISE: call @_FortranAcpowi
+ ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
end subroutine
! CHECK-LABEL: pow_c4_i8
@@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: call @_FortranAcpowk
+ ! PRECISE: call @_FortranAcpowk
+ ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
end subroutine
! CHECK-LABEL: pow_c8_i4
@@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
- ! CHECK: call @_FortranAzpowi
+ ! PRECISE: call @_FortranAzpowi
+ ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
end subroutine
! CHECK-LABEL: pow_c8_i8
@@ -120,7 +123,8 @@ subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: call @_FortranAzpowk
+ ! PRECISE: call @_FortranAzpowk
+ ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
end subroutine
! CHECK-LABEL: pow_c4_c4
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index b3d6d59e25bd0..558fcdf782800 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -56,10 +56,43 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
private:
std::string funcName;
};
+
+// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
+struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
+ using OpRewritePattern<complex::PowOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(complex::PowOp op,
+ PatternRewriter &rewriter) const final {
+ auto loc = op.getLoc();
+ if (auto constOp = op.getRhs().getDefiningOp<complex::ConstantOp>()) {
+ ArrayAttr value = constOp.getValue();
+ if (value.size() == 2) {
+ auto real = dyn_cast<FloatAttr>(value[0]);
+ auto imag = dyn_cast<FloatAttr>(value[1]);
+ if (real && imag && imag.getValue().isZero())
+ for (int i = 2; i <= 8; ++i)
+ if (real.getValue().isExactlyValue(i)) {
+ Value base = op.getLhs();
+ Value result = base;
+ for (int j = 1; j < i; ++j)
+ result = rewriter.create<complex::MulOp>(loc, result, base);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+ }
+ }
+ Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
+ Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
+ Value exp = rewriter.create<complex::ExpOp>(loc, mul);
+ rewriter.replaceOp(op, exp);
+ return success();
+ }
+};
} // namespace
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
RewritePatternSet &patterns) {
+ patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
@@ -110,9 +143,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
+ target.addLegalOp<complex::MulOp>();
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
- complex::LogOp, complex::SinOp, complex::SqrtOp,
- complex::TanOp, complex::TanhOp>();
+ complex::LogOp, complex::PowOp, complex::SinOp,
+ complex::SqrtOp, complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index b9eec8e24a0b9..e229462b70b98 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -63,6 +63,32 @@ func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
return %lf, %ld : complex<f32>, complex<f64>
}
+//CHECK-LABEL: @pow_caller
+func.func @pow_caller(%f: complex<f32>, %g: complex<f32>) -> complex<f32> {
+ // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%{{.*}})
+ // CHECK: %[[MUL:.*]] = complex.mul %[[LOG]], %{{.*}} : complex<f32>
+ // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]])
+ // CHECK: return %[[EXP]]
+ %r = complex.pow %f, %g : complex<f32>
+ return %r : complex<f32>
+}
+
+// CHECK-LABEL: @pow_int_caller
+func.func @pow_int_caller(%f : complex<f32>, %d : complex<f64>)
+ ->(complex<f32>, complex<f64>) {
+ // CHECK-NOT: call @__ocml
+ // CHECK: %[[M2:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
+ %c2 = complex.constant [2.0 : f32, 0.0 : f32] : complex<f32>
+ %p2 = complex.pow %f, %c2 : complex<f32>
+ // CHECK-NOT: call @__ocml
+ // CHECK: %[[M3A:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f64>
+ // CHECK: %[[M3B:.*]] = complex.mul %[[M3A]], %{{.*}} : complex<f64>
+ %c3 = complex.constant [3.0 : f64, 0.0 : f64] : complex<f64>
+ %p3 = complex.pow %d, %c3 : complex<f64>
+ // CHECK: return %[[M2]], %[[M3B]]
+ return %p2, %p3 : complex<f32>, complex<f64>
+}
+
//CHECK-LABEL: @sin_caller
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Akash Banerjee (TIFitis) ChangesThis PR contributes the following changes: Full diff: https://github.com/llvm/llvm-project/pull/153183.diff 6 Files Affected:
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 3e6fbafe8a6b3..2f8965adfb320 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -1276,6 +1276,28 @@ mlir::Value genComplexMathOp(fir::FirOpBuilder &builder, mlir::Location loc,
return result;
}
+mlir::Value genComplexPowI(fir::FirOpBuilder &builder, mlir::Location loc,
+ const MathOperation &mathOp,
+ mlir::FunctionType mathLibFuncType,
+ llvm::ArrayRef<mlir::Value> args) {
+ bool canUseApprox = mlir::arith::bitEnumContainsAny(
+ builder.getFastMathFlags(), mlir::arith::FastMathFlags::afn);
+ bool isAMDGPU = fir::getTargetTriple(builder.getModule()).isAMDGCN();
+ if (!forceMlirComplex && !canUseApprox && !isAMDGPU)
+ return genLibCall(builder, loc, mathOp, mathLibFuncType, args);
+
+ auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
+ auto realTy = complexTy.getElementType();
+ mlir::Value realExp = builder.createConvert(loc, realTy, args[1]);
+ mlir::Value zero = builder.createRealConstant(loc, realTy, 0);
+ mlir::Value complexExp =
+ builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero);
+ mlir::Value result =
+ builder.create<mlir::complex::PowOp>(loc, args[0], complexExp);
+ result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
+ return result;
+}
+
/// Mapping between mathematical intrinsic operations and MLIR operations
/// of some appropriate dialect (math, complex, etc.) or libm calls.
/// TODO: support remaining Fortran math intrinsics.
@@ -1625,15 +1647,19 @@ static constexpr MathOperation mathOperations[] = {
genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>,
genMathOp<mlir::math::FPowIOp>},
{"pow", RTNAME_STRING(cpowi),
- genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall},
+ genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>,
+ genComplexPowI},
{"pow", RTNAME_STRING(zpowi),
- genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>, genLibCall},
+ genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<4>>,
+ genComplexPowI},
{"pow", RTNAME_STRING(cqpowi), FuncTypeComplex16Complex16Integer4,
genLibF128Call},
{"pow", RTNAME_STRING(cpowk),
- genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>, genLibCall},
+ genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<8>>,
+ genComplexPowI},
{"pow", RTNAME_STRING(zpowk),
- genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>, genLibCall},
+ genFuncType<Ty::Complex<8>, Ty::Complex<8>, Ty::Integer<8>>,
+ genComplexPowI},
{"pow", RTNAME_STRING(cqpowk), FuncTypeComplex16Complex16Integer8,
genLibF128Call},
{"remainder", "remainderf",
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 5855d5ad00036..fbbb596e97e1b 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -1,5 +1,6 @@
! Test lowering of binary intrinsic operations to HLFIR
-! RUN: bbc -emit-hlfir -o - %s 2>&1 | FileCheck %s
+! RUN: bbc -emit-hlfir -o - %s 2>&1 | FileCheck %s --check-prefixes=CHECK,PRECISE
+! RUN: bbc --force-mlir-complex -emit-hlfir -o - %s 2>&1 | FileCheck %s --check-prefixes=CHECK,FAST
subroutine int_add(x, y, z)
integer :: x, y, z
@@ -193,7 +194,8 @@ subroutine complex_to_int_power(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
-! CHECK: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, i32) -> complex<f32>
+! PRECISE: %[[VAL_8:.*]] = fir.call @_FortranAcpowi(%[[VAL_6]], %[[VAL_7]]) fastmath<contract> : (complex<f32>, i32) -> complex<f32>
+! FAST: %[[VAL_8:.*]] = complex.pow %[[VAL_6]], %{{.*}} : complex<f32>
subroutine extremum(c, n, l)
integer(8), intent(in) :: l
diff --git a/flang/test/Lower/amdgcn-complex.f90 b/flang/test/Lower/amdgcn-complex.f90
index f15c7db2b7316..3d52355d3d50a 100644
--- a/flang/test/Lower/amdgcn-complex.f90
+++ b/flang/test/Lower/amdgcn-complex.f90
@@ -1,21 +1,27 @@
! REQUIRES: amdgpu-registered-target
-! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir -flang-deprecated-no-hlfir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-fir %s -o - | FileCheck %s
+! CHECK-LABEL: func @_QPcabsf_test(
+! CHECK: complex.abs
+! CHECK-NOT: fir.call @cabsf
subroutine cabsf_test(a, b)
complex :: a
real :: b
b = abs(a)
end subroutine
-! CHECK-LABEL: func @_QPcabsf_test(
-! CHECK: complex.abs
-! CHECK-NOT: fir.call @cabsf
-
+! CHECK-LABEL: func @_QPcexpf_test(
+! CHECK: complex.exp
+! CHECK-NOT: fir.call @cexpf
subroutine cexpf_test(a, b)
complex :: a, b
b = exp(a)
end subroutine
-! CHECK-LABEL: func @_QPcexpf_test(
-! CHECK: complex.exp
-! CHECK-NOT: fir.call @cexpf
+! CHECK-LABEL: func @_QPpow_test(
+! CHECK: complex.pow
+! CHECK-NOT: fir.call @_FortranAcpowi
+subroutine pow_test(a, b)
+ complex :: a, b
+ a = b**2
+end subroutine pow_test
diff --git a/flang/test/Lower/power-operator.f90 b/flang/test/Lower/power-operator.f90
index 7436e031d20cb..2a0a09e090dde 100644
--- a/flang/test/Lower/power-operator.f90
+++ b/flang/test/Lower/power-operator.f90
@@ -96,7 +96,8 @@ subroutine pow_c4_i4(x, y, z)
complex :: x, z
integer :: y
z = x ** y
- ! CHECK: call @_FortranAcpowi
+ ! PRECISE: call @_FortranAcpowi
+ ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
end subroutine
! CHECK-LABEL: pow_c4_i8
@@ -104,7 +105,8 @@ subroutine pow_c4_i8(x, y, z)
complex :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: call @_FortranAcpowk
+ ! PRECISE: call @_FortranAcpowk
+ ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f32>
end subroutine
! CHECK-LABEL: pow_c8_i4
@@ -112,7 +114,8 @@ subroutine pow_c8_i4(x, y, z)
complex(8) :: x, z
integer :: y
z = x ** y
- ! CHECK: call @_FortranAzpowi
+ ! PRECISE: call @_FortranAzpowi
+ ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
end subroutine
! CHECK-LABEL: pow_c8_i8
@@ -120,7 +123,8 @@ subroutine pow_c8_i8(x, y, z)
complex(8) :: x, z
integer(8) :: y
z = x ** y
- ! CHECK: call @_FortranAzpowk
+ ! PRECISE: call @_FortranAzpowk
+ ! FAST: complex.pow %{{.*}}, %{{.*}} : complex<f64>
end subroutine
! CHECK-LABEL: pow_c4_c4
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index b3d6d59e25bd0..558fcdf782800 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -56,10 +56,43 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> {
private:
std::string funcName;
};
+
+// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z))
+struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
+ using OpRewritePattern<complex::PowOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(complex::PowOp op,
+ PatternRewriter &rewriter) const final {
+ auto loc = op.getLoc();
+ if (auto constOp = op.getRhs().getDefiningOp<complex::ConstantOp>()) {
+ ArrayAttr value = constOp.getValue();
+ if (value.size() == 2) {
+ auto real = dyn_cast<FloatAttr>(value[0]);
+ auto imag = dyn_cast<FloatAttr>(value[1]);
+ if (real && imag && imag.getValue().isZero())
+ for (int i = 2; i <= 8; ++i)
+ if (real.getValue().isExactlyValue(i)) {
+ Value base = op.getLhs();
+ Value result = base;
+ for (int j = 1; j < i; ++j)
+ result = rewriter.create<complex::MulOp>(loc, result, base);
+ rewriter.replaceOp(op, result);
+ return success();
+ }
+ }
+ }
+ Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs());
+ Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase);
+ Value exp = rewriter.create<complex::ExpOp>(loc, mul);
+ rewriter.replaceOp(op, exp);
+ return success();
+ }
+};
} // namespace
void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
RewritePatternSet &patterns) {
+ patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext());
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
@@ -110,9 +143,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
+ target.addLegalOp<complex::MulOp>();
target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
- complex::LogOp, complex::SinOp, complex::SqrtOp,
- complex::TanOp, complex::TanhOp>();
+ complex::LogOp, complex::PowOp, complex::SinOp,
+ complex::SqrtOp, complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
index b9eec8e24a0b9..e229462b70b98 100644
--- a/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
+++ b/mlir/test/Conversion/ComplexToROCDLLibraryCalls/complex-to-rocdl-library-calls.mlir
@@ -63,6 +63,32 @@ func.func @log_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, comp
return %lf, %ld : complex<f32>, complex<f64>
}
+//CHECK-LABEL: @pow_caller
+func.func @pow_caller(%f: complex<f32>, %g: complex<f32>) -> complex<f32> {
+ // CHECK: %[[LOG:.*]] = call @__ocml_clog_f32(%{{.*}})
+ // CHECK: %[[MUL:.*]] = complex.mul %[[LOG]], %{{.*}} : complex<f32>
+ // CHECK: %[[EXP:.*]] = call @__ocml_cexp_f32(%[[MUL]])
+ // CHECK: return %[[EXP]]
+ %r = complex.pow %f, %g : complex<f32>
+ return %r : complex<f32>
+}
+
+// CHECK-LABEL: @pow_int_caller
+func.func @pow_int_caller(%f : complex<f32>, %d : complex<f64>)
+ ->(complex<f32>, complex<f64>) {
+ // CHECK-NOT: call @__ocml
+ // CHECK: %[[M2:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f32>
+ %c2 = complex.constant [2.0 : f32, 0.0 : f32] : complex<f32>
+ %p2 = complex.pow %f, %c2 : complex<f32>
+ // CHECK-NOT: call @__ocml
+ // CHECK: %[[M3A:.*]] = complex.mul %{{.*}}, %{{.*}} : complex<f64>
+ // CHECK: %[[M3B:.*]] = complex.mul %[[M3A]], %{{.*}} : complex<f64>
+ %c3 = complex.constant [3.0 : f64, 0.0 : f64] : complex<f64>
+ %p3 = complex.pow %d, %c3 : complex<f64>
+ // CHECK: return %[[M2]], %[[M3B]]
+ return %p2, %p3 : complex<f32>, complex<f64>
+}
+
//CHECK-LABEL: @sin_caller
func.func @sin_caller(%f: complex<f32>, %d: complex<f64>) -> (complex<f32>, complex<f64>) {
// CHECK: %[[SF:.*]] = call @__ocml_csin_f32(%{{.*}})
|
mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
Outdated
Show resolved
Hide resolved
@@ -1625,15 +1647,19 @@ static constexpr MathOperation mathOperations[] = { | |||
genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>, | |||
genMathOp<mlir::math::FPowIOp>}, | |||
{"pow", RTNAME_STRING(cpowi), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please clarify the benefits of expanding these functions in MLIR vs implementing the same logic in Fortran runtime compiled for AMD GPU device? I do not have any concerns, I am just curious.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it's more modular in the sense any frontend can just lower to complex dialect and have the conversion pass take care of the rest, rather than have every frontend lower specifically for amdgcn.
But Flang is the only concern at the moment so I'm happy to move it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, the "runtime" approach that I meant was that we generate the cpowi
etc. calls here, and then the runtime implementation for AMD GPU device uses the __ocml_*
intrinsics. Would that be a viable solution? I guess the benefit of having the complex operations is some special case handling, like the constant exponent optimizations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, but I didn't understand. This pass only runs on device pass for amdgpu and converts the complex ops to relevant ocml library calls. Are you suggesting we delay this conversion to something like mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I meant that we could just have AMD GPU specific versions of _FortranAcpowi
and other functions in flang-rt/lib/runtime
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, I'll look into it if that's a possibility. But I guess this is a workaround at the moment for lib functions that are not available on the GPU. mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp also does something similar by converting things to ROCDL calls.
mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MLIR-side LGTM, might want to wait on a flang reviewer before merging
mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
Show resolved
Hide resolved
mlir::FunctionType mathLibFuncType, | ||
llvm::ArrayRef<mlir::Value> args) { | ||
if (auto expInt = fir::getIntIfConstant(args[1])) | ||
if (*expInt >= 2 && *expInt <= 8) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My original thought that this could me a rewrite pattern over in mlir/lib/Dialect/Complex/Transforms
, but here works too if you don't think it'll be generally useful.
(Because I'm not sure that this is actually an AMDGPU- or flang-specific optimization once you've relaxed IEEE-bit-accuracy)
@tblah Hi, would you be please be able to review the Thanks. |
@@ -1625,15 +1647,19 @@ static constexpr MathOperation mathOperations[] = { | |||
genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>, | |||
genMathOp<mlir::math::FPowIOp>}, | |||
{"pow", RTNAME_STRING(cpowi), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, the "runtime" approach that I meant was that we generate the cpowi
etc. calls here, and then the runtime implementation for AMD GPU device uses the __ocml_*
intrinsics. Would that be a viable solution? I guess the benefit of having the complex operations is some special case handling, like the constant exponent optimizations.
@@ -1625,15 +1655,19 @@ static constexpr MathOperation mathOperations[] = { | |||
genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>, | |||
genMathOp<mlir::math::FPowIOp>}, | |||
{"pow", RTNAME_STRING(cpowi), | |||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall}, | |||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, | |||
genComplexPow}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it will be better to handle these cases same way as the other intrinsics where we can generate either a lib call or an MLIR operation, e.g.:
{"pow", RTNAME_STRING(FPow4i),
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
genMathOp<mlir::math::FPowIOp>},
Then, you may have a ROCDL specific pass that converts the complex operations into AMD GPU code, and a Flang pipeline pass that converts the complex operations into the runtime calls. You may also have a pass that does the canonicalization/folding for the constant exponent cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, but I don't understand the change. Currently @_FortranAcpowi calls are generated for all cases, and this PR adds lowering to complex.pow op for amdgpu device pass. The complex.pow gets later converted to ocml calls. Can you please clarify what you are suggesting instead?
Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please look at how this case works:
{"pow", RTNAME_STRING(FPow4i),
genFuncType<Ty::Real<4>, Ty::Real<4>, Ty::Integer<4>>,
genMathOp<mlir::math::FPowIOp>},
Depending on the mathRuntimeVersion
Flang either generates a call to _FortranAFPow4i
or an mlir::math::FPowIOp
operation. You can do the same for _FortranAcpowi
vs complex.pow
, and then handle complex.pow
any way you wish later in the pipeline. So for AMD GPU you may convert it to the ocml calls, and otherwise you may convert it to _FortranAcpowi
late in Flang pass pipeline. This way, we get all the benefits of not having a call with side effects at MLIR level, and we can apply folding/canonicalization to complex.pow
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've updated the code to reflect this change. Let me know if it's what you wanted or would like to see any further changes.
Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update! It seems to be the right direction to me, though there is a couple of missing things:
- I think we need to make sure that we still call the
_FortranAcpowi
and other Fortran runtime functions for Flang, so I think we need to have a pass that will convert the complex pow operations back to Fortran runtime calls (unless the ROCDL conversion converts them to AMD GPU specific code). - I would suggest introducing powi operation in the Complex dialect, so that we know that the exponent argument is an integer value 100%. If there is a way to guarantee that we always recognize
complex.pow
's integer exponent argument whenever Flang created such an operation, thenpowi
is redundant. So it depends on how reliable the recognition is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the optimisation has deviated the PR too far. The original motivation of this PR is only to add support for cpow lowering on AMDGPU.
I've reverted the PR to an older revision which was already accepted by @krzysz00 , and dropped the optimisation entirely.
I'll start a separate PR soon for the Flang lowering changes along with the optimisation.
Please let me know if you would like to see any changes to this PR before I merge it.
Thanks.
This PR contributes the following changes: 1. Force lowering to complex.pow ops for the amdgcn-amd-amdhsa target. 2. Convert complex.pow(z, w) -> complex.exp(w * complex.log(z)). 3. Convert x ** 2 -> x * x, x ** 3 -> x * x * x, ... x ** 8 -> x * x... .
… complex::CreateOp.
@@ -1625,15 +1655,19 @@ static constexpr MathOperation mathOperations[] = { | |||
genFuncType<Ty::Real<16>, Ty::Real<16>, Ty::Integer<8>>, | |||
genMathOp<mlir::math::FPowIOp>}, | |||
{"pow", RTNAME_STRING(cpowi), | |||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, genLibCall}, | |||
genFuncType<Ty::Complex<4>, Ty::Complex<4>, Ty::Integer<4>>, | |||
genComplexPow}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update! It seems to be the right direction to me, though there is a couple of missing things:
- I think we need to make sure that we still call the
_FortranAcpowi
and other Fortran runtime functions for Flang, so I think we need to have a pass that will convert the complex pow operations back to Fortran runtime calls (unless the ROCDL conversion converts them to AMD GPU specific code). - I would suggest introducing powi operation in the Complex dialect, so that we know that the exponent argument is an integer value 100%. If there is a way to guarantee that we always recognize
complex.pow
's integer exponent argument whenever Flang created such an operation, thenpowi
is redundant. So it depends on how reliable the recognition is.
return castIntToFloat(cast); | ||
if (auto cast = v.getDefiningOp<arith::UIToFPOp>()) | ||
return castIntToFloat(cast); | ||
if (v.getDefiningOp()->getName().getStringRef() == "fir.convert") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should not be relying on fir.convert
here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MLIR side approved
mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
Outdated
Show resolved
Hide resolved
@vzakhari Are you okay with me merging this? Thanks. |
mlir::Value complexExp = | ||
builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp, zero); | ||
mlir::Value result = | ||
builder.create<mlir::complex::PowOp>(loc, args[0], complexExp); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please make sure that complex.pow
is only generated when isAMDGPU
is true, otherwise, I would expect performance regressions in afn
compilations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
This PR adds support for complex power operations (
cpow
) in theComplexToROCDLLibraryCalls
conversion pass, specifically targeting AMDGPU architectures. The implementation optimises complex exponentiation by using mathematical identities and special-case handling for small integer powers.complex.pow
operations for theamdgcn-amd-amdhsa
target instead of using library callscomplex.pow(z, w)
tocomplex.exp(w * complex.log(z))
using mathematical identityx²
becomesx * x
)