Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] Add fastmath attributes to complex arithmetic #70690

Merged
merged 3 commits into from
Oct 31, 2023

Conversation

tblah
Copy link
Contributor

@tblah tblah commented Oct 30, 2023

Propagate fast math flags through complex number lowering (when lowering fir.*c directly to llvm floating point operations).

The lowering path through the MLIR complex dialect is unchanged.

This leads to a small improvement in spec2017 fotonik3d_r.

These attributes (when propagated to LLVM) allow multiple operations to
be merged into one e.g. fused-multiply-add.

I will add support for these attributes in CodeGen in my next patch.
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir openacc flang:codegen labels Oct 30, 2023
@llvmbot
Copy link

llvmbot commented Oct 30, 2023

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-flang-codegen

Author: Tom Eccles (tblah)

Changes

Propagate fast math flags through complex number lowering (when lowering fir.*c directly to llvm floating point operations).

The lowering path through the MLIR complex dialect is unchanged.

This leads to a small improvement in spec2017 fotonik3d_r.


Patch is 21.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70690.diff

7 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+12-6)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+26)
  • (modified) flang/test/Fir/convert-to-llvm.fir (+25-25)
  • (modified) flang/test/Lower/HLFIR/binary-ops.f90 (+3-3)
  • (modified) flang/test/Lower/OpenACC/acc-reduction.f90 (+2-2)
  • (modified) flang/test/Lower/array-elemental-calls-2.f90 (+1-1)
  • (modified) flang/test/Lower/assignment.f90 (+3-3)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index dd2e90c3b1a1fde..6e8064a63b7ae0a 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2538,12 +2538,18 @@ def fir_NegcOp : ComplexUnaryArithmeticOp<"negc">;
 
 class ComplexArithmeticOp<string mnemonic, list<Trait> traits = []> :
       fir_ArithmeticOp<mnemonic, traits>,
-      Arguments<(ins fir_ComplexType:$lhs, fir_ComplexType:$rhs)>;
-
-def fir_AddcOp : ComplexArithmeticOp<"addc", [Commutative]>;
-def fir_SubcOp : ComplexArithmeticOp<"subc">;
-def fir_MulcOp : ComplexArithmeticOp<"mulc", [Commutative]>;
-def fir_DivcOp : ComplexArithmeticOp<"divc">;
+      Arguments<(ins fir_ComplexType:$lhs, fir_ComplexType:$rhs,
+          DefaultValuedAttr<Arith_FastMathAttr,
+                            "::mlir::arith::FastMathFlags::none">:$fastmath)>;
+
+def fir_AddcOp : ComplexArithmeticOp<"addc",
+    [Commutative, DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
+def fir_SubcOp : ComplexArithmeticOp<"subc",
+    [DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
+def fir_MulcOp : ComplexArithmeticOp<"mulc",
+    [Commutative, DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
+def fir_DivcOp : ComplexArithmeticOp<"divc",
+    [DeclareOpInterfaceMethods<ArithFastMathInterface>]>;
 // Pow is a builtin call and not a primitive
 
 def fir_CmpcOp : fir_Op<"cmpc",
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 0f85f89f1a48138..3f6f2b0474d44b4 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -33,6 +33,7 @@
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/Transforms/AddComdats.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -3502,6 +3503,8 @@ static mlir::LLVM::InsertValueOp
 complexSum(OPTY sumop, mlir::ValueRange opnds,
            mlir::ConversionPatternRewriter &rewriter,
            const fir::LLVMTypeConverter &lowering) {
+  mlir::LLVM::FastmathFlags fastmathFlags =
+      mlir::arith::convertArithFastMathFlagsToLLVM(sumop.getFastmath());
   mlir::Value a = opnds[0];
   mlir::Value b = opnds[1];
   auto loc = sumop.getLoc();
@@ -3512,7 +3515,9 @@ complexSum(OPTY sumop, mlir::ValueRange opnds,
   auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
   auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
   auto rx = rewriter.create<LLVMOP>(loc, eleTy, x0, x1);
+  rx.setFastmathFlags(fastmathFlags);
   auto ry = rewriter.create<LLVMOP>(loc, eleTy, y0, y1);
+  ry.setFastmathFlags(fastmathFlags);
   auto r0 = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
   auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r0, rx, 0);
   return rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ry, 1);
@@ -3560,6 +3565,8 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
     // TODO: Can we use a call to __muldc3 ?
     // given: (x + iy) * (x' + iy')
     // result: (xx'-yy')+i(xy'+yx')
+    mlir::LLVM::FastmathFlags fastmathFlags =
+        mlir::arith::convertArithFastMathFlagsToLLVM(mulc.getFastmath());
     mlir::Value a = adaptor.getOperands()[0];
     mlir::Value b = adaptor.getOperands()[1];
     auto loc = mulc.getLoc();
@@ -3570,11 +3577,17 @@ struct MulcOpConversion : public FIROpConversion<fir::MulcOp> {
     auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
     auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
     auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
+    xx.setFastmathFlags(fastmathFlags);
     auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
+    yx.setFastmathFlags(fastmathFlags);
     auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
+    xy.setFastmathFlags(fastmathFlags);
     auto ri = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xy, yx);
+    ri.setFastmathFlags(fastmathFlags);
     auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
+    yy.setFastmathFlags(fastmathFlags);
     auto rr = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, xx, yy);
+    rr.setFastmathFlags(fastmathFlags);
     auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
     auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
     auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
@@ -3594,6 +3607,8 @@ struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
     // Just generate inline code for now.
     // given: (x + iy) / (x' + iy')
     // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
+    mlir::LLVM::FastmathFlags fastmathFlags =
+        mlir::arith::convertArithFastMathFlagsToLLVM(divc.getFastmath());
     mlir::Value a = adaptor.getOperands()[0];
     mlir::Value b = adaptor.getOperands()[1];
     auto loc = divc.getLoc();
@@ -3604,16 +3619,27 @@ struct DivcOpConversion : public FIROpConversion<fir::DivcOp> {
     auto x1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 0);
     auto y1 = rewriter.create<mlir::LLVM::ExtractValueOp>(loc, b, 1);
     auto xx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, x1);
+    xx.setFastmathFlags(fastmathFlags);
     auto x1x1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x1, x1);
+    x1x1.setFastmathFlags(fastmathFlags);
     auto yx = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, x1);
+    yx.setFastmathFlags(fastmathFlags);
     auto xy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, x0, y1);
+    xy.setFastmathFlags(fastmathFlags);
     auto yy = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y0, y1);
+    yy.setFastmathFlags(fastmathFlags);
     auto y1y1 = rewriter.create<mlir::LLVM::FMulOp>(loc, eleTy, y1, y1);
+    y1y1.setFastmathFlags(fastmathFlags);
     auto d = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, x1x1, y1y1);
+    d.setFastmathFlags(fastmathFlags);
     auto rrn = rewriter.create<mlir::LLVM::FAddOp>(loc, eleTy, xx, yy);
+    rrn.setFastmathFlags(fastmathFlags);
     auto rin = rewriter.create<mlir::LLVM::FSubOp>(loc, eleTy, yx, xy);
+    rin.setFastmathFlags(fastmathFlags);
     auto rr = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rrn, d);
+    rr.setFastmathFlags(fastmathFlags);
     auto ri = rewriter.create<mlir::LLVM::FDivOp>(loc, eleTy, rin, d);
+    ri.setFastmathFlags(fastmathFlags);
     auto ra = rewriter.create<mlir::LLVM::UndefOp>(loc, ty);
     auto r1 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, ra, rr, 0);
     auto r0 = rewriter.create<mlir::LLVM::InsertValueOp>(loc, r1, ri, 1);
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
index cecfbff7eac228b..c9a44914b987053 100644
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -507,7 +507,7 @@ func.func @test_call_return_val() -> i32 {
 // result: (x + x') + i(y + y')
 
 func.func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
-  %c = fir.addc %a, %b : !fir.complex<16>
+  %c = fir.addc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
   return %c : !fir.complex<16>
 }
 
@@ -518,8 +518,8 @@ func.func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // CHECK:         %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK:         %[[ADD_X0_X1:.*]] = llvm.fadd %[[X0]], %[[X1]]  : f128
-// CHECK:         %[[ADD_Y0_Y1:.*]] = llvm.fadd %[[Y0]], %[[Y1]]  : f128
+// CHECK:         %[[ADD_X0_X1:.*]] = llvm.fadd %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[ADD_Y0_Y1:.*]] = llvm.fadd %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
 // CHECK:         %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[ADD_X0_X1]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[ADD_Y0_Y1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
@@ -532,7 +532,7 @@ func.func @fir_complex_add(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // result: (x - x') + i(y - y')
 
 func.func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
-  %c = fir.subc %a, %b : !fir.complex<16>
+  %c = fir.subc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
   return %c : !fir.complex<16>
 }
 
@@ -543,8 +543,8 @@ func.func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // CHECK:         %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK:         %[[SUB_X0_X1:.*]] = llvm.fsub %[[X0]], %[[X1]]  : f128
-// CHECK:         %[[SUB_Y0_Y1:.*]] = llvm.fsub %[[Y0]], %[[Y1]]  : f128
+// CHECK:         %[[SUB_X0_X1:.*]] = llvm.fsub %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>}  : f128
+// CHECK:         %[[SUB_Y0_Y1:.*]] = llvm.fsub %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
 // CHECK:         %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[SUB_X0_X1]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[SUB_Y0_Y1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
@@ -557,7 +557,7 @@ func.func @fir_complex_sub(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // result: (xx'-yy')+i(xy'+yx')
 
 func.func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
-  %c = fir.mulc %a, %b : !fir.complex<16>
+  %c = fir.mulc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
   return %c : !fir.complex<16>
 }
 
@@ -568,12 +568,12 @@ func.func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // CHECK:         %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK:         %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]]  : f128
-// CHECK:         %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]]  : f128
-// CHECK:         %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]]  : f128
-// CHECK:         %[[ADD:.*]] = llvm.fadd %[[MUL_X0_Y1]], %[[MUL_Y0_X1]]  : f128
-// CHECK:         %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]]  : f128
-// CHECK:         %[[SUB:.*]] = llvm.fsub %[[MUL_X0_X1]], %[[MUL_Y0_Y1]]  : f128
+// CHECK:         %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[ADD:.*]] = llvm.fadd %[[MUL_X0_Y1]], %[[MUL_Y0_X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[SUB:.*]] = llvm.fsub %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
 // CHECK:         %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[SUB]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[ADD]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
@@ -586,7 +586,7 @@ func.func @fir_complex_mul(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // result: ((xx'+yy')/d) + i((yx'-xy')/d) where d = x'x' + y'y'
 
 func.func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.complex<16> {
-  %c = fir.divc %a, %b : !fir.complex<16>
+  %c = fir.divc %a, %b {fastmath = #arith.fastmath<fast>} : !fir.complex<16>
   return %c : !fir.complex<16>
 }
 
@@ -597,17 +597,17 @@ func.func @fir_complex_div(%a: !fir.complex<16>, %b: !fir.complex<16>) -> !fir.c
 // CHECK:         %[[Y0:.*]] = llvm.extractvalue %[[ARG0]][1] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[X1:.*]] = llvm.extractvalue %[[ARG1]][0] : !llvm.struct<(f128, f128)>
 // CHECK:         %[[Y1:.*]] = llvm.extractvalue %[[ARG1]][1] : !llvm.struct<(f128, f128)>
-// CHECK:         %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]]  : f128
-// CHECK:         %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]]  : f128
-// CHECK:         %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]]  : f128
-// CHECK:         %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]]  : f128
-// CHECK:         %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]]  : f128
-// CHECK:         %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]]  : f128
-// CHECK:         %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]]  : f128
-// CHECK:         %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]]  : f128
-// CHECK:         %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]]  : f128
-// CHECK:         %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]]  : f128
-// CHECK:         %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]]  : f128
+// CHECK:         %[[MUL_X0_X1:.*]] = llvm.fmul %[[X0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_X1_X1:.*]] = llvm.fmul %[[X1]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_Y0_X1:.*]] = llvm.fmul %[[Y0]], %[[X1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_X0_Y1:.*]] = llvm.fmul %[[X0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_Y0_Y1:.*]] = llvm.fmul %[[Y0]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[MUL_Y1_Y1:.*]] = llvm.fmul %[[Y1]], %[[Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[ADD_X1X1_Y1Y1:.*]] = llvm.fadd %[[MUL_X1_X1]], %[[MUL_Y1_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[ADD_X0X1_Y0Y1:.*]] = llvm.fadd %[[MUL_X0_X1]], %[[MUL_Y0_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[SUB_Y0X1_X0Y1:.*]] = llvm.fsub %[[MUL_Y0_X1]], %[[MUL_X0_Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[DIV0:.*]] = llvm.fdiv %[[ADD_X0X1_Y0Y1]], %[[ADD_X1X1_Y1Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
+// CHECK:         %[[DIV1:.*]] = llvm.fdiv %[[SUB_Y0X1_X0Y1]], %[[ADD_X1X1_Y1Y1]] {fastmathFlags = #llvm.fastmath<fast>} : f128
 // CHECK:         %{{.*}} = llvm.mlir.undef : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[DIV0]], %{{.*}}[0] : !llvm.struct<(f128, f128)>
 // CHECK:         %{{.*}} = llvm.insertvalue %[[DIV1]], %{{.*}}[1] : !llvm.struct<(f128, f128)>
diff --git a/flang/test/Lower/HLFIR/binary-ops.f90 b/flang/test/Lower/HLFIR/binary-ops.f90
index 8db6da3de81b291..6b89577cc54581b 100644
--- a/flang/test/Lower/HLFIR/binary-ops.f90
+++ b/flang/test/Lower/HLFIR/binary-ops.f90
@@ -32,7 +32,7 @@ subroutine complex_add(x, y, z)
 ! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
 ! CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
 ! CHECK:  %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
-! CHECK:  %[[VAL_8:.*]] = fir.addc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>
+! CHECK:  %[[VAL_8:.*]] = fir.addc %[[VAL_6]], %[[VAL_7]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 
 subroutine int_sub(x, y, z)
  integer :: x, y, z
@@ -65,7 +65,7 @@ subroutine complex_sub(x, y, z)
 ! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
 ! CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
 ! CHECK:  %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
-! CHECK:  %[[VAL_8:.*]] = fir.subc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>
+! CHECK:  %[[VAL_8:.*]] = fir.subc %[[VAL_6]], %[[VAL_7]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 
 subroutine int_mul(x, y, z)
  integer :: x, y, z
@@ -98,7 +98,7 @@ subroutine complex_mul(x, y, z)
 ! CHECK:  %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<!fir.complex<4>>) -> (!fir.ref<!fir.complex<4>>, !fir.ref<!fir.complex<4>>)
 ! CHECK:  %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<!fir.complex<4>>
 ! CHECK:  %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<!fir.complex<4>>
-! CHECK:  %[[VAL_8:.*]] = fir.mulc %[[VAL_6]], %[[VAL_7]] : !fir.complex<4>
+! CHECK:  %[[VAL_8:.*]] = fir.mulc %[[VAL_6]], %[[VAL_7]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 
 subroutine int_div(x, y, z)
  integer :: x, y, z
diff --git a/flang/test/Lower/OpenACC/acc-reduction.f90 b/flang/test/Lower/OpenACC/acc-reduction.f90
index b874d5219625df8..8671c280c2fb314 100644
--- a/flang/test/Lower/OpenACC/acc-reduction.f90
+++ b/flang/test/Lower/OpenACC/acc-reduction.f90
@@ -163,7 +163,7 @@
 ! CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<!fir.complex<4>>, %[[ARG1:.*]]: !fir.ref<!fir.complex<4>>):
 ! CHECK:   %[[LOAD0:.*]] = fir.load %[[ARG0]] : !fir.ref<!fir.complex<4>>
 ! CHECK:   %[[LOAD1:.*]] = fir.load %[[ARG1]] : !fir.ref<!fir.complex<4>>
-! CHECK:   %[[COMBINED:.*]] = fir.mulc %[[LOAD0]], %[[LOAD1]] : !fir.complex<4>
+! CHECK:   %[[COMBINED:.*]] = fir.mulc %[[LOAD0]], %[[LOAD1]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 ! CHECK:   fir.store %[[COMBINED]] to %[[ARG0]] : !fir.ref<!fir.complex<4>>
 ! CHECK:   acc.yield %[[ARG0]] : !fir.ref<!fir.complex<4>>
 ! CHECK: }
@@ -183,7 +183,7 @@
 ! CHECK: ^bb0(%[[ARG0:.*]]: !fir.ref<!fir.complex<4>>, %[[ARG1:.*]]: !fir.ref<!fir.complex<4>>):
 ! CHECK:   %[[LOAD0:.*]] = fir.load %[[ARG0]] : !fir.ref<!fir.complex<4>>
 ! CHECK:   %[[LOAD1:.*]] = fir.load %[[ARG1]] : !fir.ref<!fir.complex<4>>
-! CHECK:   %[[COMBINED:.*]] = fir.addc %[[LOAD0]], %[[LOAD1]] : !fir.complex<4>
+! CHECK:   %[[COMBINED:.*]] = fir.addc %[[LOAD0]], %[[LOAD1]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 ! CHECK:   fir.store %[[COMBINED]] to %[[ARG0]] : !fir.ref<!fir.complex<4>>
 ! CHECK:   acc.yield %[[ARG0]] : !fir.ref<!fir.complex<4>>
 ! CHECK: }
diff --git a/flang/test/Lower/array-elemental-calls-2.f90 b/flang/test/Lower/array-elemental-calls-2.f90
index 94e24a9910bc267..0d6e34c6391c3df 100644
--- a/flang/test/Lower/array-elemental-calls-2.f90
+++ b/flang/test/Lower/array-elemental-calls-2.f90
@@ -144,7 +144,7 @@ subroutine check_cmplx_part()
 ! CHECK:  %[[VAL_13:.*]] = fir.load %{{.*}} : !fir.ref<!fir.complex<8>>
 ! CHECK:  fir.do_loop
 ! CHECK:  %[[VAL_23:.*]] = fir.array_fetch %{{.*}}, %{{.*}} : (!fir.array<10x!fir.complex<8>>, index) -> !fir.complex<8>
-! CHECK:  %[[VAL_24:.*]] = fir.addc %[[VAL_23]], %[[VAL_13]] : !fir.complex<8>
+! CHECK:  %[[VAL_24:.*]] = fir.addc %[[VAL_23]], %[[VAL_13]] {fastmath = #arith.fastmath<contract>} : !fir.complex<8>
 ! CHECK:  %[[VAL_25:.*]] = fir.extract_value %[[VAL_24]], [1 : index] : (!fir.complex<8>) -> f64
 ! CHECK:  fir.call @_QPelem_func_real(%[[VAL_25]]) {{.*}}: (f64) -> i32
 end subroutine
diff --git a/flang/test/Lower/assignment.f90 b/flang/test/Lower/assignment.f90
index 9b5039e3ea88ebd..058842828d2687a 100644
--- a/flang/test/Lower/assignment.f90
+++ b/flang/test/Lower/assignment.f90
@@ -203,7 +203,7 @@ real function divf(a, b)
 ! CHECK:         %[[FCTRES:.*]] = fir.alloca !fir.complex<4>
 ! CHECK:         %[[A_VAL:.*]] = fir.load %[[A]] : !fir.ref<!fir.complex<4>>
 ! CHECK:         %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<!fir.complex<4>>
-! CHECK:         %[[ADD:.*]] = fir.addc %[[A_VAL]], %[[B_VAL]] : !fir.complex<4>
+! CHECK:         %[[ADD:.*]] = fir.addc %[[A_VAL]], %[[B_VAL]] {fastmath = #arith.fastmath<contract>} : !fir.complex<4>
 ! CHECK:         fir.store %[[ADD]] to %[[FCTRES]] : !fir.ref...
[truncated]

@kiranchandramohan
Copy link
Contributor

The patch that added the same in MLIR complex is the following. That patch creates an LLVM::FastmathFlagsAttr and passes that during construction of the LLVM Ops rather than setting it separately.

https://reviews.llvm.org/D156310

    arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        op.getContext(),
        convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
    Value real =
        rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
    Value imag =
        rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
    result.setReal(rewriter, loc, real);
    result.setImaginary(rewriter, loc, imag);

Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG. Thanks @tblah

Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you, Tom!

@tblah tblah merged commit 6be0e97 into llvm:main Oct 31, 2023
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants