diff --git a/flang/test/Lower/Intrinsics/abs.f90 b/flang/test/Lower/Intrinsics/abs.f90 index d2288a140ad43..7986eeee00030 100644 --- a/flang/test/Lower/Intrinsics/abs.f90 +++ b/flang/test/Lower/Intrinsics/abs.f90 @@ -4,7 +4,7 @@ ! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s --check-prefixes="CHECK,CMPLX,CMPLX-PRECISE" ! RUN: %flang_fc1 -emit-fir -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-PRECISE" ! RUN: %flang_fc1 -emit-fir -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST" -! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST" +! RUN: %flang_fc1 -fapprox-func -emit-fir %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-APPROX" ! Test abs intrinsic for various types (int, float, complex) @@ -100,7 +100,9 @@ subroutine abs_testr16(a, b) subroutine abs_testzr(a, b) ! CMPLX: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref> ! CMPLX-FAST: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.complex<4>) -> complex -! CMPLX-FAST: %[[VAL_4:.*]] = complex.abs %[[VAL_3]] : complex +! CMPLX-FAST: %[[VAL_4:.*]] = complex.abs %[[VAL_3]] fastmath : complex +! CMPLX-APPROX: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.complex<4>) -> complex +! CMPLX-APPROX: %[[VAL_4:.*]] = complex.abs %[[VAL_3]] fastmath : complex ! CMPLX-PRECISE: %[[VAL_4:.*]] = fir.call @cabsf(%[[VAL_2]]) {{.*}}: (!fir.complex<4>) -> f32 ! CMPLX: fir.store %[[VAL_4]] to %[[VAL_1]] : !fir.ref ! CMPLX: return @@ -114,7 +116,9 @@ end subroutine abs_testzr subroutine abs_testzd(a, b) ! CMPLX: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref> ! CMPLX-FAST: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.complex<8>) -> complex -! CMPLX-FAST: %[[VAL_4:.*]] = complex.abs %[[VAL_3]] : complex +! CMPLX-FAST: %[[VAL_4:.*]] = complex.abs %[[VAL_3]] fastmath : complex +! CMPLX-APPROX: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.complex<8>) -> complex +! CMPLX-APPROX: %[[VAL_4:.*]] = complex.abs %[[VAL_3]] fastmath : complex ! CMPLX-PRECISE: %[[VAL_4:.*]] = fir.call @cabs(%[[VAL_2]]) {{.*}}: (!fir.complex<8>) -> f64 ! CMPLX: fir.store %[[VAL_4]] to %[[VAL_1]] : !fir.ref ! CMPLX: return diff --git a/flang/test/Lower/Intrinsics/exp.f90 b/flang/test/Lower/Intrinsics/exp.f90 index f128e548edeb0..49ed25a5b8e6c 100644 --- a/flang/test/Lower/Intrinsics/exp.f90 +++ b/flang/test/Lower/Intrinsics/exp.f90 @@ -2,7 +2,7 @@ ! RUN: bbc -emit-fir --math-runtime=precise -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-PRECISE" ! RUN: bbc -emit-fir --force-mlir-complex -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST,CMPLX-MLIR" ! RUN: %flang_fc1 -emit-fir -mllvm -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CHECK,CMPLX,CMPLX-PRECISE" -! RUN: %flang_fc1 -fapprox-func -emit-fir -mllvm -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST,CMPLX-APPROX" +! RUN: %flang_fc1 -fapprox-func -emit-fir -mllvm -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-APPROX" ! RUN: %flang_fc1 -emit-fir -mllvm -outline-intrinsics -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-PRECISE" ! RUN: %flang_fc1 -emit-fir -mllvm -outline-intrinsics -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST,CMPLX-MLIR" @@ -61,8 +61,11 @@ subroutine exp_testcd(a, b) ! CMPLX-MLIR-LABEL: private @fir.exp.contract.z4.z4 ! CMPLX-SAME: (%[[ARG32_OUTLINE:.*]]: !fir.complex<4>) -> !fir.complex<4> ! CMPLX-FAST: %[[C:.*]] = fir.convert %[[ARG32_OUTLINE]] : (!fir.complex<4>) -> complex -! CMPLX-FAST: %[[E:.*]] = complex.exp %[[C]] : complex +! CMPLX-FAST: %[[E:.*]] = complex.exp %[[C]] fastmath : complex ! CMPLX-FAST: %[[RESULT32_OUTLINE:.*]] = fir.convert %[[E]] : (complex) -> !fir.complex<4> +! CMPLX-APPROX: %[[C:.*]] = fir.convert %[[ARG32_OUTLINE]] : (!fir.complex<4>) -> complex +! CMPLX-APPROX: %[[E:.*]] = complex.exp %[[C]] fastmath : complex +! CMPLX-APPROX: %[[RESULT32_OUTLINE:.*]] = fir.convert %[[E]] : (complex) -> !fir.complex<4> ! CMPLX-PRECISE: %[[RESULT32_OUTLINE:.*]] = fir.call @cexpf(%[[ARG32_OUTLINE]]) fastmath : (!fir.complex<4>) -> !fir.complex<4> ! CMPLX: return %[[RESULT32_OUTLINE]] : !fir.complex<4> @@ -71,7 +74,10 @@ subroutine exp_testcd(a, b) ! CMPLX-MLIR-LABEL: private @fir.exp.contract.z8.z8 ! CMPLX-SAME: (%[[ARG64_OUTLINE:.*]]: !fir.complex<8>) -> !fir.complex<8> ! CMPLX-FAST: %[[C:.*]] = fir.convert %[[ARG64_OUTLINE]] : (!fir.complex<8>) -> complex -! CMPLX-FAST: %[[E:.*]] = complex.exp %[[C]] : complex +! CMPLX-FAST: %[[E:.*]] = complex.exp %[[C]] fastmath : complex ! CMPLX-FAST: %[[RESULT64_OUTLINE:.*]] = fir.convert %[[E]] : (complex) -> !fir.complex<8> +! CMPLX-APPROX: %[[C:.*]] = fir.convert %[[ARG64_OUTLINE]] : (!fir.complex<8>) -> complex +! CMPLX-APPROX: %[[E:.*]] = complex.exp %[[C]] fastmath : complex +! CMPLX-APPROX: %[[RESULT64_OUTLINE:.*]] = fir.convert %[[E]] : (complex) -> !fir.complex<8> ! CMPLX-PRECISE: %[[RESULT64_OUTLINE:.*]] = fir.call @cexp(%[[ARG64_OUTLINE]]) fastmath : (!fir.complex<8>) -> !fir.complex<8> ! CMPLX: return %[[RESULT64_OUTLINE]] : !fir.complex<8> diff --git a/flang/test/Lower/Intrinsics/log.f90 b/flang/test/Lower/Intrinsics/log.f90 index 49be4d968c890..08dbd4218d64f 100644 --- a/flang/test/Lower/Intrinsics/log.f90 +++ b/flang/test/Lower/Intrinsics/log.f90 @@ -4,7 +4,7 @@ ! RUN: %flang_fc1 -emit-fir -mllvm -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CHECK,CMPLX,CMPLX-PRECISE" ! RUN: %flang_fc1 -emit-fir -mllvm -outline-intrinsics -mllvm --math-runtime=precise %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-PRECISE" ! RUN: %flang_fc1 -emit-fir -mllvm -outline-intrinsics -mllvm --force-mlir-complex %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST,CMPLX-MLIR" -! RUN: %flang_fc1 -fapprox-func -emit-fir -mllvm -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-FAST,CMPLX-APPROX" +! RUN: %flang_fc1 -fapprox-func -emit-fir -mllvm -outline-intrinsics %s -o - | FileCheck %s --check-prefixes="CMPLX,CMPLX-APPROX" ! CHECK-LABEL: log_testr ! CHECK-SAME: (%[[AREF:.*]]: !fir.ref {{.*}}, %[[BREF:.*]]: !fir.ref {{.*}}) @@ -81,8 +81,11 @@ subroutine log10_testd(a, b) ! CMPLX-MLIR-LABEL: private @fir.log.contract.z4.z4 ! CMPLX-SAME: (%[[ARG32_OUTLINE:.*]]: !fir.complex<4>) -> !fir.complex<4> ! CMPLX-FAST: %[[C:.*]] = fir.convert %[[ARG32_OUTLINE]] : (!fir.complex<4>) -> complex -! CMPLX-FAST: %[[E:.*]] = complex.log %[[C]] : complex +! CMPLX-FAST: %[[E:.*]] = complex.log %[[C]] fastmath : complex ! CMPLX-FAST: %[[RESULT32_OUTLINE:.*]] = fir.convert %[[E]] : (complex) -> !fir.complex<4> +! CMPLX-APPROX: %[[C:.*]] = fir.convert %[[ARG32_OUTLINE]] : (!fir.complex<4>) -> complex +! CMPLX-APPROX: %[[E:.*]] = complex.log %[[C]] fastmath : complex +! CMPLX-APPROX: %[[RESULT32_OUTLINE:.*]] = fir.convert %[[E]] : (complex) -> !fir.complex<4> ! CMPLX-PRECISE: %[[RESULT32_OUTLINE:.*]] = fir.call @clogf(%[[ARG32_OUTLINE]]) fastmath : (!fir.complex<4>) -> !fir.complex<4> ! CMPLX: return %[[RESULT32_OUTLINE]] : !fir.complex<4> @@ -91,8 +94,11 @@ subroutine log10_testd(a, b) ! CMPLX-MLIR-LABEL: private @fir.log.contract.z8.z8 ! CMPLX-SAME: (%[[ARG64_OUTLINE:.*]]: !fir.complex<8>) -> !fir.complex<8> ! CMPLX-FAST: %[[C:.*]] = fir.convert %[[ARG64_OUTLINE]] : (!fir.complex<8>) -> complex -! CMPLX-FAST: %[[E:.*]] = complex.log %[[C]] : complex +! CMPLX-FAST: %[[E:.*]] = complex.log %[[C]] fastmath : complex ! CMPLX-FAST: %[[RESULT64_OUTLINE:.*]] = fir.convert %[[E]] : (complex) -> !fir.complex<8> +! CMPLX-APPROX: %[[C:.*]] = fir.convert %[[ARG64_OUTLINE]] : (!fir.complex<8>) -> complex +! CMPLX-APPROX: %[[E:.*]] = complex.log %[[C]] fastmath : complex +! CMPLX-APPROX: %[[RESULT64_OUTLINE:.*]] = fir.convert %[[E]] : (complex) -> !fir.complex<8> ! CMPLX-PRECISE: %[[RESULT64_OUTLINE:.*]] = fir.call @clog(%[[ARG64_OUTLINE]]) fastmath : (!fir.complex<8>) -> !fir.complex<8> ! CMPLX: return %[[RESULT64_OUTLINE]] : !fir.complex<8> diff --git a/flang/test/Lower/complex-operations.f90 b/flang/test/Lower/complex-operations.f90 index c686671c7a112..42cdac0dc2a21 100644 --- a/flang/test/Lower/complex-operations.f90 +++ b/flang/test/Lower/complex-operations.f90 @@ -33,7 +33,7 @@ end subroutine mul_test ! CHECK: %[[CVAL:.*]] = fir.load %[[CREF]] : !fir.ref> ! CHECK: %[[BVAL_CVT:.*]] = fir.convert %[[BVAL]] : (!fir.complex<2>) -> complex ! CHECK: %[[CVAL_CVT:.*]] = fir.convert %[[CVAL]] : (!fir.complex<2>) -> complex -! CHECK: %[[AVAL_CVT:.*]] = complex.div %[[BVAL_CVT]], %[[CVAL_CVT]] : complex +! CHECK: %[[AVAL_CVT:.*]] = complex.div %[[BVAL_CVT]], %[[CVAL_CVT]] fastmath : complex ! CHECK: %[[AVAL:.*]] = fir.convert %[[AVAL_CVT]] : (complex) -> !fir.complex<2> ! CHECK: fir.store %[[AVAL]] to %[[AREF]] : !fir.ref> subroutine div_test_half(a,b,c) @@ -47,7 +47,7 @@ end subroutine div_test_half ! CHECK: %[[CVAL:.*]] = fir.load %[[CREF]] : !fir.ref> ! CHECK: %[[BVAL_CVT:.*]] = fir.convert %[[BVAL]] : (!fir.complex<3>) -> complex ! CHECK: %[[CVAL_CVT:.*]] = fir.convert %[[CVAL]] : (!fir.complex<3>) -> complex -! CHECK: %[[AVAL_CVT:.*]] = complex.div %[[BVAL_CVT]], %[[CVAL_CVT]] : complex +! CHECK: %[[AVAL_CVT:.*]] = complex.div %[[BVAL_CVT]], %[[CVAL_CVT]] fastmath : complex ! CHECK: %[[AVAL:.*]] = fir.convert %[[AVAL_CVT]] : (complex) -> !fir.complex<3> ! CHECK: fir.store %[[AVAL]] to %[[AREF]] : !fir.ref> subroutine div_test_bfloat(a,b,c) diff --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h index 663e81c71d860..fb024fa2e951e 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/Complex.h +++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_ #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/InferTypeOpInterface.h" diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td index b80d77996a20f..a829fa88efa89 100644 --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -9,6 +9,8 @@ #ifndef COMPLEX_OPS #define COMPLEX_OPS +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Dialect/Complex/IR/ComplexBase.td" include "mlir/IR/OpAsmInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -22,19 +24,21 @@ class Complex_Op traits = []> // one result, all of which must be complex numbers of the same type. class ComplexArithmeticOp traits = []> : Complex_Op { - let arguments = (ins Complex:$lhs, Complex:$rhs); + Elementwise, DeclareOpInterfaceMethods]> { + let arguments = (ins Complex:$lhs, Complex:$rhs, DefaultValuedAttr< + Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath); let results = (outs Complex:$result); - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; + let assemblyFormat = "$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result)"; } // Base class for standard unary operations on complex numbers with a // floating-point element type. These operations take one operand and return // one result; the operand must be a complex number. class ComplexUnaryOp traits = []> : - Complex_Op { - let arguments = (ins Complex:$complex); - let assemblyFormat = "$complex attr-dict `:` type($complex)"; + Complex_Op]> { + let arguments = (ins Complex:$complex, DefaultValuedAttr< + Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath); + let assemblyFormat = "$complex (`fastmath` `` $fastmath^)? attr-dict `:` type($complex)"; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp index b219de7fef8f8..68406125ba526 100644 --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -23,6 +24,7 @@ namespace mlir { using namespace mlir; using namespace mlir::LLVM; +using namespace mlir::arith; //===----------------------------------------------------------------------===// // ComplexStructBuilder implementation. @@ -73,7 +75,10 @@ struct AbsOpConversion : public ConvertOpToLLVMPattern { Value real = complexStruct.real(rewriter, op.getLoc()); Value imag = complexStruct.imaginary(rewriter, op.getLoc()); - auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); + arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); + LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( + op.getContext(), + convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value sqNorm = rewriter.create( loc, rewriter.create(loc, real, real, fmf), rewriter.create(loc, imag, imag, fmf), fmf); @@ -181,7 +186,10 @@ struct AddOpConversion : public ConvertOpToLLVMPattern { auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); + arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); + LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( + op.getContext(), + convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = @@ -209,7 +217,10 @@ struct DivOpConversion : public ConvertOpToLLVMPattern { auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); + arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); + LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( + op.getContext(), + convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value rhsRe = arg.rhs.real(); Value rhsIm = arg.rhs.imag(); Value lhsRe = arg.lhs.real(); @@ -254,7 +265,10 @@ struct MulOpConversion : public ConvertOpToLLVMPattern { auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); + arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); + LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( + op.getContext(), + convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value rhsRe = arg.rhs.real(); Value rhsIm = arg.rhs.imag(); Value lhsRe = arg.lhs.real(); @@ -291,7 +305,10 @@ struct SubOpConversion : public ConvertOpToLLVMPattern { auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to substract complex numbers. - auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); + arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr(); + LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( + op.getContext(), + convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = diff --git a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt index a90f34ec1684d..3ee0d26f3225f 100644 --- a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRComplexDialect MLIRComplexAttributesIncGen LINK_LIBS PUBLIC + MLIRArithAttrToLLVMConversion MLIRArithDialect MLIRDialect MLIRInferTypeOpInterface diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir index 3b8ed25d6073c..a60b974e374d3 100644 --- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir @@ -154,3 +154,125 @@ func.func @complex_abs(%arg: complex) -> f32 { // CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32 // CHECK: return %[[NORM]] : f32 +// CHECK-LABEL: func @complex_addition_with_fmf +// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)> +// CHECK: %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[C_REAL:.*]] = llvm.fadd %[[A_REAL]], %[[B_REAL]] {fastmathFlags = #llvm.fastmath} : f64 +// CHECK-DAG: %[[C_IMAG:.*]] = llvm.fadd %[[A_IMAG]], %[[B_IMAG]] {fastmathFlags = #llvm.fastmath} : f64 +// CHECK: %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)> +// CHECK: %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)> +func.func @complex_addition_with_fmf() { + %a_re = arith.constant 1.2 : f64 + %a_im = arith.constant 3.4 : f64 + %a = complex.create %a_re, %a_im : complex + %b_re = arith.constant 5.6 : f64 + %b_im = arith.constant 7.8 : f64 + %b = complex.create %b_re, %b_im : complex + %c = complex.add %a, %b fastmath : complex + return +} + +// CHECK-LABEL: func @complex_substraction_with_fmf +// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)> +// CHECK: %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[C_REAL:.*]] = llvm.fsub %[[A_REAL]], %[[B_REAL]] {fastmathFlags = #llvm.fastmath} : f64 +// CHECK-DAG: %[[C_IMAG:.*]] = llvm.fsub %[[A_IMAG]], %[[B_IMAG]] {fastmathFlags = #llvm.fastmath} : f64 +// CHECK: %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)> +// CHECK: %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)> +func.func @complex_substraction_with_fmf() { + %a_re = arith.constant 1.2 : f64 + %a_im = arith.constant 3.4 : f64 + %a = complex.create %a_re, %a_im : complex + %b_re = arith.constant 5.6 : f64 + %b_im = arith.constant 7.8 : f64 + %b = complex.create %b_re, %b_im : complex + %c = complex.sub %a, %b fastmath : complex + return +} + +// CHECK-LABEL: func @complex_div_with_fmf +// CHECK-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex +// CHECK-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex to ![[C_TY:.*>]] +// CHECK-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex to ![[C_TY]] + +// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]] +// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]] +// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]] +// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]] + +// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]] + +// CHECK-DAG: %[[RHS_RE_SQ:.*]] = llvm.fmul %[[RHS_RE]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK-DAG: %[[RHS_IM_SQ:.*]] = llvm.fmul %[[RHS_IM]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[RHS_RE_SQ]], %[[RHS_IM_SQ]] {fastmathFlags = #llvm.fastmath} : f32 + +// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[REAL_TMP_2:.*]] = llvm.fadd %[[REAL_TMP_0]], %[[REAL_TMP_1]] {fastmathFlags = #llvm.fastmath} : f32 + +// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] {fastmathFlags = #llvm.fastmath} : f32 + +// CHECK: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]] +// CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]] +// +// CHECK: %[[CASTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_2]] : ![[C_TY]] to complex +// CHECK: return %[[CASTED_RESULT]] : complex +func.func @complex_div_with_fmf(%lhs: complex, %rhs: complex) -> complex { + %div = complex.div %lhs, %rhs fastmath : complex + return %div : complex +} + + +// CHECK-LABEL: func @complex_mul_with_fmf +// CHECK-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex +// CHECK-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex to ![[C_TY:.*>]] +// CHECK-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex to ![[C_TY]] + +// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]] +// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]] +// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]] +// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]] +// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]] + +// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[RHS_IM]], %[[LHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[REAL:.*]] = llvm.fsub %[[REAL_TMP_0]], %[[REAL_TMP_1]] {fastmathFlags = #llvm.fastmath} : f32 + +// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[IMAG:.*]] = llvm.fadd %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] {fastmathFlags = #llvm.fastmath} : f32 + +// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] +// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] + +// CHECK: %[[CASTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_2]] : ![[C_TY]] to complex +// CHECK: return %[[CASTED_RESULT]] : complex +func.func @complex_mul_with_fmf(%lhs: complex, %rhs: complex) -> complex { + %mul = complex.mul %lhs, %rhs fastmath : complex + return %mul : complex +} + +// CHECK-LABEL: func @complex_abs_with_fmf +// CHECK-SAME: %[[ARG:.*]]: complex +// CHECK: %[[CASTED_ARG:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : complex to ![[C_TY:.*>]] +// CHECK: %[[REAL:.*]] = llvm.extractvalue %[[CASTED_ARG]][0] : ![[C_TY]] +// CHECK: %[[IMAG:.*]] = llvm.extractvalue %[[CASTED_ARG]][1] : ![[C_TY]] +// CHECK-DAG: %[[REAL_SQ:.*]] = llvm.fmul %[[REAL]], %[[REAL]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK-DAG: %[[IMAG_SQ:.*]] = llvm.fmul %[[IMAG]], %[[IMAG]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[REAL_SQ]], %[[IMAG_SQ]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32 +// CHECK: return %[[NORM]] : f32 +func.func @complex_abs_with_fmf(%arg: complex) -> f32 { + %abs = complex.abs %arg fastmath : complex + return %abs : f32 +}