-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang] Add fastmath attributes to complex arithmetic #70690
Conversation
These attributes (when propagated to LLVM) allow multiple operations to be merged into one, e.g. a fused multiply-add. I will add support for these attributes in CodeGen in my next patch.
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-codegen Author: Tom Eccles (tblah) Changes: Propagate fast math flags through complex number lowering (when lowering fir.*c directly to llvm floating point operations). The lowering path through the MLIR complex dialect is unchanged. This leads to a small improvement in spec2017 fotonik3d_r. Patch is 21.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70690.diff 7 Files Affected:
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index dd2e90c3b1a1fde..6e8064a63b7ae0a 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2538,12 +2538,18 @@ def fir_NegcOp : ComplexUnaryArithmeticOp<"negc">;
class ComplexArithmeticOp<string mnemonic, list<Trait> traits = []> :
fir_ArithmeticOp<mnemonic, traits>,
- Arguments<(ins fir_ComplexType:$lhs, fir_ComplexType:$rhs)>;
-
-def fir_AddcOp : ComplexArithmeticOp<"addc", [Commutative]>;
-def fir_SubcOp : ComplexArithmeticOp<"subc">;
-def fir_MulcOp : ComplexArithmeticOp<"mulc", [Commutative]>;
-def fir_DivcOp : ComplexArithmeticOp<"divc">;
+ Arguments<(ins fir_ComplexType:$lhs, fir_ComplexType:$rhs,
+ DefaultValuedAttr<Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath)>;
+
+def fir_AddcOp : ComplexArithmeticOp<"addc",
+ [Commutative, DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
+def fir_SubcOp : ComplexArithmeticOp<"subc",
+ [DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
+def fir_MulcOp : ComplexArithmeticOp<"mulc",
+ [Commutative, DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
+def fir_DivcOp : ComplexArithmeticOp<"divc",
+ [DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
// Pow is a builtin call and not a primitive
def fir_CmpcOp : fir_Op<"cmpc",
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 0f85f89f1a48138..3f6f2b0474d44b4 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -3502,6 +3503,8 @@ static mlir::LLVM::InsertValueOp
complexSum(OPTY sumop, mlir::ValueRange opnds,
mlir::ConversionPatternRewriter &rewriter,
const fir::LLVMTypeConverter &lowering) {
+ mlir::LLVM::FastmathFlags fastmathFlags =
+ mlir::arith::convertArithFastMathFlagsToLLVM(sumop.getFastmath());
mlir::Value a = opnds[0];
mlir::Value b = opnds[1];
auto loc = sumop.getLoc();
@@ -3512,7 +3515,9 @@ complexSum(OPTY sumop, mlir::ValueRange opnds,
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1);
+ rx.setFastmathFlags(fastmathFlags);
auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1);
+ ry.setFastmathFlags(fastmathFlags);
auto r0 = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r0, rx, 0);
return rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ry, 1);
@@ -3560,6 +3565,8 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
// TODO: Can we use a call to __muldc3 ?
// given: (x + iy) * (x' + iy')
// result: (xx'-yy')+i(xy'+yx')
+ mlir::LLVM::FastmathFlags fastmathFlags =
+ mlir::arith::convertArithFastMathFlagsToLLVM(mulc.getFastmath());
mlir::Value a = adaptor.getOperands()[0];
mlir::Value b = adaptor.getOperands()[1];
auto loc = mulc.getLoc();
@@ -3570,11 +3577,17 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
+ xx.setFastmathFlags(fastmathFlags);
auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
+ yx.setFastmathFlags(fastmathFlags);
auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
+ xy.setFastmathFlags(fastmathFlags);
auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx);
+ ri.setFastmathFlags(fastmathFlags);
auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
+ yy.setFastmathFlags(fastmathFlags);
auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy);
+ rr.setFastmathFlags(fastmathFlags);
auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
@@ -3594,6 +3607,8 @@ struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
// Just generate inline code for now.
// given: (x + iy) / (x' + iy')
// result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
+ mlir::LLVM::FastmathFlags fastmathFlags =
+ mlir::arith::convertArithFastMathFlagsToLLVM(divc.getFastmath());
mlir::Value a = adaptor.getOperands()[0];
mlir::Value b = adaptor.getOperands()[1];
auto loc = divc.getLoc();
@@ -3604,16 +3619,27 @@ struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
+ xx.setFastmathFlags(fastmathFlags);
auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
+ x1x1.setFastmathFlags(fastmathFlags);
auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
+ yx.setFastmathFlags(fastmathFlags);
auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
+ xy.setFastmathFlags(fastmathFlags);
auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
+ yy.setFastmathFlags(fastmathFlags);
auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
+ y1y1.setFastmathFlags(fastmathFlags);
auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
+ d.setFastmathFlags(fastmathFlags);
auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
+ rrn.setFastmathFlags(fastmathFlags);
auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
+ rin.setFastmathFlags(fastmathFlags);
auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
+ rr.setFastmathFlags(fastmathFlags);
auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
+ ri.setFastmathFlags(fastmathFlags);
auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index cecfbff7eac228b..c9a44914b987053 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -507,7 +507,7 @@ func.func @test_call_return_val() -> i32 {
// result: (x + x') + i(y + y')
func.func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
- %c = fir.addc %a, %b : !fir.complex<16>
+ %c = fir.addc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
return %c : !fir.complex<16>
}
@@ -518,8 +518,8 @@ func.func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK: %[[ADD_X0_X1:.*]] = llvm.fadd %[[X0]], %[[X1]] : f128
-// CHECK: %[[ADD_Y0_Y1:.*]] = llvm.fadd %[[Y0]], %[[Y1]] : f128
+// CHECK: %[[ADD_X0_X1:.*]] = llvm.fadd %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[ADD_Y0_Y1:.*]] = llvm.fadd %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[ADD_X0_X1]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[ADD_Y0_Y1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
@@ -532,7 +532,7 @@ func.func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// result: (x - x') + i(y - y')
func.func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
- %c = fir.subc %a, %b : !fir.complex<16>
+ %c = fir.subc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
return %c : !fir.complex<16>
}
@@ -543,8 +543,8 @@ func.func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK: %[[SUB_X0_X1:.*]] = llvm.fsub %[[X0]], %[[X1]] : f128
-// CHECK: %[[SUB_Y0_Y1:.*]] = llvm.fsub %[[Y0]], %[[Y1]] : f128
+// CHECK: %[[SUB_X0_X1:.*]] = llvm.fsub %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[SUB_Y0_Y1:.*]] = llvm.fsub %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[SUB_X0_X1]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[SUB_Y0_Y1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
@@ -557,7 +557,7 @@ func.func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// result: (xx'-yy')+i(xy'+yx')
func.func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
- %c = fir.mulc %a, %b : !fir.complex<16>
+ %c = fir.mulc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
return %c : !fir.complex<16>
}
@@ -568,12 +568,12 @@ func.func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : f128
-// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : f128
-// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : f128
-// CHECK: %[[ADD:.*]] = llvm.fadd %[[MUL_X0_Y1]], %[[MUL_Y0_X1]] : f128
-// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : f128
-// CHECK: %[[SUB:.*]] = llvm.fsub %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : f128
+// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[ADD:.*]] = llvm.fadd %[[MUL_X0_Y1]], %[[MUL_Y0_X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[SUB:.*]] = llvm.fsub %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[SUB]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[ADD]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
@@ -586,7 +586,7 @@ func.func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
func.func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
- %c = fir.divc %a, %b : !fir.complex<16>
+ %c = fir.divc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
return %c : !fir.complex<16>
}
@@ -597,17 +597,17 @@ func.func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
// CHECK: %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
// CHECK: %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
// CHECK: %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] : f128
-// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] : f128
-// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] : f128
-// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] : f128
-// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] : f128
-// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] : f128
-// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] : f128
-// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] : f128
-// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] : f128
-// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] : f128
-// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] : f128
+// CHECK: %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK: %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
// CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
// CHECK: %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 8db6da3de81b291..6b89577cc54581b 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -32,7 +32,7 @@ subroutine complex_add(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
-! CHECK: %[[VAL_8:.*]] = fir.addc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>
+! CHECK: %[[VAL_8:.*]] = fir.addc %[[VAL_6]], %[[VAL_7]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
subroutine int_sub(x, y, z)
integer :: x, y, z
@@ -65,7 +65,7 @@ subroutine complex_sub(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
-! CHECK: %[[VAL_8:.*]] = fir.subc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>
+! CHECK: %[[VAL_8:.*]] = fir.subc %[[VAL_6]], %[[VAL_7]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
subroutine int_mul(x, y, z)
integer :: x, y, z
@@ -98,7 +98,7 @@ subroutine complex_mul(x, y, z)
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
-! CHECK: %[[VAL_8:.*]] = fir.mulc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>
+! CHECK: %[[VAL_8:.*]] = fir.mulc %[[VAL_6]], %[[VAL_7]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
subroutine int_div(x, y, z)
integer :: x, y, z
diff --git a/flang/test/Lower/OpenACC/acc-reduction.f90 b/flang/test/Lower/OpenACC/acc-reduction.f90
index b874d5219625df8..8671c280c2fb314 100644
--- a/flang/test/Lower/OpenACC/acc-reduction.f90
+++ b/flang/test/Lower/OpenACC/acc-reduction.f90
@@ -163,7 +163,7 @@
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<!fir.complex<4>>, %[[ARG1:.*]]: !fir.ref<!fir.complex<4>>):
! CHECK: %[[LOAD0:.*]] = fir.load %[[ARG0]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[LOAD1:.*]] = fir.load %[[ARG1]] : !fir.ref<!fir.complex<4>>
-! CHECK: %[[COMBINED:.*]] = fir.mulc %[[LOAD0]], %[[LOAD1]] : !fir.complex<4>
+! CHECK: %[[COMBINED:.*]] = fir.mulc %[[LOAD0]], %[[LOAD1]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
! CHECK: fir.store %[[COMBINED]] to %[[ARG0]] : !fir.ref<!fir.complex<4>>
! CHECK: acc.yield %[[ARG0]] : !fir.ref<!fir.complex<4>>
! CHECK: }
@@ -183,7 +183,7 @@
! CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<!fir.complex<4>>, %[[ARG1:.*]]: !fir.ref<!fir.complex<4>>):
! CHECK: %[[LOAD0:.*]] = fir.load %[[ARG0]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[LOAD1:.*]] = fir.load %[[ARG1]] : !fir.ref<!fir.complex<4>>
-! CHECK: %[[COMBINED:.*]] = fir.addc %[[LOAD0]], %[[LOAD1]] : !fir.complex<4>
+! CHECK: %[[COMBINED:.*]] = fir.addc %[[LOAD0]], %[[LOAD1]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
! CHECK: fir.store %[[COMBINED]] to %[[ARG0]] : !fir.ref<!fir.complex<4>>
! CHECK: acc.yield %[[ARG0]] : !fir.ref<!fir.complex<4>>
! CHECK: }
diff --git a/flang/test/Lower/array-elemental-calls-2.f90 b/flang/test/Lower/array-elemental-calls-2.f90
index 94e24a9910bc267..0d6e34c6391c3df 100644
--- a/flang/test/Lower/array-elemental-calls-2.f90
+++ b/flang/test/Lower/array-elemental-calls-2.f90
@@ -144,7 +144,7 @@ subroutine check_cmplx_part()
! CHECK: %[[VAL_13:.*]] = fir.load %{{.*}} : !fir.ref<!fir.complex<8>>
! CHECK: fir.do_loop
! CHECK: %[[VAL_23:.*]] = fir.array_fetch %{{.*}}, %{{.*}} : (!fir.array<10x!fir.complex<8>>, index) -> !fir.complex<8>
-! CHECK: %[[VAL_24:.*]] = fir.addc %[[VAL_23]], %[[VAL_13]] : !fir.complex<8>
+! CHECK: %[[VAL_24:.*]] = fir.addc %[[VAL_23]], %[[VAL_13]] {fastmath = #arith.fastmath<contract>} : !fir.complex<8>
! CHECK: %[[VAL_25:.*]] = fir.extract_value %[[VAL_24]], [1 : index] : (!fir.complex<8>) -> f64
! CHECK: fir.call @_QPelem_func_real(%[[VAL_25]]) {{.*}}: (f64) -> i32
end subroutine
diff --git a/flang/test/Lower/assignment.f90 b/flang/test/Lower/assignment.f90
index 9b5039e3ea88ebd..058842828d2687a 100644
--- a/flang/test/Lower/assignment.f90
+++ b/flang/test/Lower/assignment.f90
@@ -203,7 +203,7 @@ real function divf(a, b)
! CHECK: %[[FCTRES:.*]] = fir.alloca !fir.complex<4>
! CHECK: %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<!fir.complex<4>>
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<!fir.complex<4>>
-! CHECK: %[[ADD:.*]] = fir.addc %[[A_VAL]], %[[B_VAL]] : !fir.complex<4>
+! CHECK: %[[ADD:.*]] = fir.addc %[[A_VAL]], %[[B_VAL]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
! CHECK: fir.store %[[ADD]] to %[[FCTRES]] : !fir.ref...
[truncated]
|
The patch that added the same support to the MLIR complex dialect is https://reviews.llvm.org/D156310. That patch creates an LLVM::FastmathFlagsAttr and passes it during construction of the LLVM ops rather than setting it separately afterwards.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG. Thanks @tblah
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you, Tom!
Propagate fast math flags through complex number lowering (when lowering fir.*c directly to llvm floating point operations).
The lowering path through the MLIR complex dialect is unchanged.
This leads to a small improvement in spec2017 fotonik3d_r.