Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Add intrinsics to represent complex multiply and divide operations #68742

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

jcranmer-intel
Copy link
Contributor

This patch represents the first in a series of patches to bring a more standardized version of complex values into LLVM. Representation of the complex multiply and division instructions are added as intrinsics, and their precise behavior (with regards to potential range overflow) is controlled via attributes and fast-math flags.

With the three commits that are added here, the intrinsics are specified in LLVM IR, methods to construct them are added in IR builder, and CodeGen is implemented, both to expand them into libcalls (to __mulsc3/__divsc3 and friends) or branchy code, or to use existing complex multiply instructions. CodeGen is only verified correct for the x86 platform, though. Later commits are not included in the PR, but available for viewing at https://github.com/jcranmer-intel/llvm-project/tree/complex-patches, which adds support for pattern matching complex multiply intrinsics in InstCombine, and also adds uses of these intrinsics in the clang frontend.

These changes were previously present on Phabricator at https://reviews.llvm.org/D119284, https://reviews.llvm.org/D119286, and https://reviews.llvm.org/D119287.

…ions.

This patch represents the first in a series of patches to bring a more
standardized version of complex values into LLVM. Representation of the complex
multiply and division instructions are added as intrinsics, and their precise
behavior is controlled via attributes and fast-math flags.
For architectures without complex multiply or divide intrinsics (most of them),
a pass is needed to expand these intrinsics before codegen.

The tricky thing here is that where the intrinsics need to expand into a
compiler-rt helper function (e.g., __mulsc3), the ABI of complex floating point
types needs to be retrieved from the target. However, this target information
isn't fully validated for all targets, only x86.

This also adds support for lowering the complex multiply intrinsic directly to
instructions for the x86 backend.
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 10, 2023

@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-backend-loongarch
@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-backend-x86
@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-backend-arm

Author: Joshua Cranmer (jcranmer-intel)

Changes

This patch represents the first in a series of patches to bring a more standardized version of complex values into LLVM. Representation of the complex multiply and division instructions are added as intrinsics, and their precise behavior (with regards to potential range overflow) is controlled via attributes and fast-math flags.

With the three commits that are added here, the intrinsics are specified in LLVM IR, methods to construct them are added in IR builder, and CodeGen is implemented, both to expand them into libcalls (to __mulsc3/__divsc3 and friends) or branchy code, or to use existing complex multiply instructions. CodeGen is only verified correct for the x86 platform, though. Later commits are not included in the PR, but available for viewing at https://github.com/jcranmer-intel/llvm-project/tree/complex-patches, which adds support for pattern matching complex multiply intrinsics in InstCombine, and also adds uses of these intrinsics in the clang frontend.

These changes were previously present on Phabricator at https://reviews.llvm.org/D119284, https://reviews.llvm.org/D119286, and https://reviews.llvm.org/D119287.


Patch is 105.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68742.diff

39 Files Affected:

  • (modified) llvm/docs/LangRef.rst (+171)
  • (added) llvm/include/llvm/CodeGen/ExpandComplex.h (+22)
  • (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+3)
  • (modified) llvm/include/llvm/CodeGen/Passes.h (+6)
  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+19)
  • (modified) llvm/include/llvm/IR/IRBuilder.h (+37)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+10)
  • (modified) llvm/include/llvm/InitializePasses.h (+1)
  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+1)
  • (modified) llvm/lib/CodeGen/CMakeLists.txt (+1)
  • (added) llvm/lib/CodeGen/ExpandComplex.cpp (+294)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+6)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+1)
  • (modified) llvm/lib/CodeGen/TargetPassConfig.cpp (+4)
  • (modified) llvm/lib/IR/IRBuilder.cpp (+32)
  • (modified) llvm/lib/IR/Verifier.cpp (+12)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+135)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.h (+4)
  • (modified) llvm/test/CodeGen/AArch64/O0-pipeline.ll (+1)
  • (modified) llvm/test/CodeGen/AArch64/O3-pipeline.ll (+1)
  • (modified) llvm/test/CodeGen/AMDGPU/llc-pipeline.ll (+5)
  • (modified) llvm/test/CodeGen/ARM/O3-pipeline.ll (+1)
  • (modified) llvm/test/CodeGen/LoongArch/O0-pipeline.ll (+1)
  • (modified) llvm/test/CodeGen/LoongArch/opt-pipeline.ll (+1)
  • (modified) llvm/test/CodeGen/PowerPC/O0-pipeline.ll (+1)
  • (modified) llvm/test/CodeGen/PowerPC/O3-pipeline.ll (+1)
  • (modified) llvm/test/CodeGen/RISCV/O0-pipeline.ll (+1)
  • (modified) llvm/test/CodeGen/RISCV/O3-pipeline.ll (+1)
  • (modified) llvm/test/CodeGen/X86/O0-pipeline.ll (+1)
  • (added) llvm/test/CodeGen/X86/complex-32bit.ll (+173)
  • (added) llvm/test/CodeGen/X86/complex-64bit.ll (+103)
  • (added) llvm/test/CodeGen/X86/complex-divide.ll (+92)
  • (added) llvm/test/CodeGen/X86/complex-multiply.ll (+525)
  • (added) llvm/test/CodeGen/X86/complex-win32.ll (+59)
  • (added) llvm/test/CodeGen/X86/complex-win64.ll (+44)
  • (added) llvm/test/CodeGen/X86/fp16-complex-multiply.ll (+231)
  • (modified) llvm/test/CodeGen/X86/opt-pipeline.ll (+1)
  • (added) llvm/test/Verifier/complex-intrinsics.ll (+39)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 1883e9f6290b151..3d6323cee63b193 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -18448,6 +18448,177 @@ will be on any later loop iteration.
 This intrinsic will only return 0 if the input count is also 0. A non-zero input
 count will produce a non-zero result.
 
+Complex Intrinsics
+------------------
+
+Complex numbers are currently represented, for intrinsic purposes, as vectors of
+floating-point numbers. A scalar complex type is represented using the type
+``<2 x floatty>``, with index ``0`` corresponding to the real part of the number
+and index ``1`` corresponding the imaginary part of the number. A vector complex
+type can be represented by an even-length vector of floating-point numbers,
+with even indices (``0``, ``2``, etc.) corresponding to real parts of numbers
+and the indices one larger (``1``, ``3``, etc.) the corresponding imaginary
+parts.
+
+The precise semantics of these intrinsics depends on the value of the
+``complex-range`` attribute provided as a call-site attribute. This attribute
+takes on three possible values:
+
+``"full"``
+  The semantics has the full expansion as given in Annex G of the C
+  specification. In general, this means it needs to be expanded using the call
+  to the appropriate routine in compiler-rt (e.g., __mulsc3).
+
+``"no-nan"``
+  This code is permitted to allow complex infinities to be represented as NaNs
+  instead, as if the code for the appropriate routine were compiled in a manner
+  that allowed ``isnan(x)`` or ``isinf(x)`` to be optimized as false.
+
+``"limited"``
+  The semantics are equivalent to the naive arithmetic expansion operations
+  (specific expansion is detailed for each arithmetic expression).
+
+When this attribute is not present, it is presumed to be ``"full"`` if no
+fast-math flags are set, and ``"no-nan"`` if ``nnan`` or ``ninf`` flags are
+present.
+
+Fast-math flags are additionally relevant for these intrinsics, particularly in
+the case of ``complex-range=limited`` variants, as those will be likely to be
+expanded in code generation and fast-math flags will propagate to the expanded
+IR in such circumstances.
+
+Intrinsics for complex addition and subtraction are not provided, as these are
+equivalent to ``fadd`` and ``fsub`` instructions, respectively.
+
+'``llvm.experimental.complex.fmul.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <2 x float> @llvm.experimental.complex.fmul.v2f32(<2 x float> <op1>, <2 x float> <op2>)
+      declare <2 x double> @llvm.experimental.complex.fmul.v2f64(<2 x double> <op1>, <2 x double> <op2>)
+      declare <4 x float> @llvm.experimental.complex.fmul.v4f32(<4 x float> <op1>, <4 x float> <op2>)
+
+Overview:
+"""""""""
+
+The '``llvm.experimental.complex.fmul``' intrinsic returns the product of its
+two operands.
+
+Arguments:
+""""""""""
+
+The arguments to the '``llvm.experimental.complex.fmul``' intrinsic must be a
+:ref:`vector <t_vector>` of :ref:`floating-point <t_floating>` types of length
+divisible by 2.
+
+Semantics:
+""""""""""
+
+The value produced is the complex product of the two inputs.
+
+If the value of ``complex-range`` attribute is ``no-nan`` or ``limited``, or if
+the ``noinf`` or ``nonan`` fast math flags are provided, the output may be
+equivalent to the following code:
+
+.. code-block:: llvm
+
+      declare <2 x float> limited_complex_mul(<2 x float> %op1, <2 x float> %op2) {
+        %x = extractelement <2 x float> %op1, i32 0 ; real of %op1
+        %y = extractelement <2 x float> %op1, i32 1 ; imag of %op1
+        %u = extractelement <2 x float> %op2, i32 0 ; real of %op2
+        %v = extractelement <2 x float> %op2, i32 1 ; imag of %op2
+        %xu = fmul float %x, %u
+        %yv = fmul float %y, %v
+        %yu = fmul float %y, %u
+        %xv = fmul float %x, %v
+        %out_real = fsub float %xu, %yv
+        %out_imag = fadd float %yu, %xv
+        %ret.0 = insertelement <2 x float> undef, i32 0, %out_real
+        %ret.1 = insertelement <2 x float> %ret.0, i32 1, %out_imag
+        return <2 x float> %ret.1
+      }
+
+When the ``complex-range`` attribute is set to ``full`` or is missing, the above
+code is insufficient to handle the result. Instead, code must be added to
+check for infinities if either the real or imaginary component of the result is
+a NaN value.
+
+
+'``llvm.experimental.complex.fdiv.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <2 x float> @llvm.experimental.complex.fdiv.v2f32(<2 x float> <op1>, <2 x float> <op2>)
+      declare <2 x double> @llvm.experimental.complex.fdiv.v2f64(<2 x double> <op1>, <2 x double> <op2>)
+      declare <4 x float> @llvm.experimental.complex.fdiv.v4f32(<4 x float> <op1>, <4 x float> <op2>)
+
+Overview:
+"""""""""
+
+The '``llvm.experimental.complex.fdiv``' intrinsic returns the quotient of its
+two operands.
+
+Arguments:
+""""""""""
+
+The arguments to the '``llvm.experimental.complex.fdiv``' intrinsic must be a
+:ref:`vector <t_vector>` of :ref:`floating-point <t_floating>` types of length
+divisible by 2.
+
+Semantics:
+""""""""""
+
+The value produced is the complex quotient of the two inputs.
+
+If the ``complex-range`` attribute is set to ``limited``, the output will be
+equivalent to the following code:
+
+.. code-block:: llvm
+
+      declare <2 x float> limited_complex_div(<2 x float> %op1, <2 x float> %op2) {
+        %x = extractelement <2 x float> %op1, i32 0 ; real of %op1
+        %y = extractelement <2 x float> %op1, i32 1 ; imag of %op1
+        %u = extractelement <2 x float> %op2, i32 0 ; real of %op2
+        %v = extractelement <2 x float> %op2, i32 1 ; imag of %op2
+        %xu = fmul float %x, %u
+        %yv = fmul float %y, %v
+        %yu = fmul float %y, %u
+        %xv = fmul float %x, %v
+        %uu = fmul float %u, %u
+        %vv = fmul float %v, %v
+        %unscaled_real = fadd float %xu, %yv
+        %unscaled_imag = fsub float %yu, %xv
+        %scale = fadd float %uu, %vv
+        %out_real = fdiv float %unscaled_real, %scale
+        %out_imag = fdiv float %unscaled_imag, %scale
+        %ret.0 = insertelement <2 x float> undef, i32 0, %out_real
+        %ret.1 = insertelement <2 x float> %ret.0, i32 1, %out_imag
+        return <2 x float> %ret.1
+      }
+
+If the ``complex-range`` attribute is set to ``no-nan`` (or the ``nnan`` or
+``ninf`` flags are specified), an additional range reduction step is necessary.
+
+If the ``complex-range`` attribute is set to ``full``, or is missing entirely,
+then an additional check is necessary after the computation that is necessary
+to recover infinites that are instead represented as NaN values.
+
+Note that when ``complex-range`` is set to ``limited``, and the code is being
+expanded to the IR provided above, the fast-math flags are duplicated onto the
+expanded code. In particular, the ``arcp`` fast math flag may also be useful, as
+it will permit the divisions to be replaced with multiplications with a
+reciprocal instead.
+
 Matrix Intrinsics
 -----------------
 
diff --git a/llvm/include/llvm/CodeGen/ExpandComplex.h b/llvm/include/llvm/CodeGen/ExpandComplex.h
new file mode 100644
index 000000000000000..0186fa75ee395ab
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/ExpandComplex.h
@@ -0,0 +1,22 @@
+//===---- ExpandComplex.h - Expand experimental complex intrinsics --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_EXPANDCOMPLEX_H
+#define LLVM_CODEGEN_EXPANDCOMPLEX_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class ExpandComplexPass : public PassInfoMixin<ExpandComplexPass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+} // end namespace llvm
+
+#endif // LLVM_CODEGEN_EXPANDCOMPLEX_H
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 67779a23a191313..4f72e6bd979d77e 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1371,6 +1371,9 @@ enum NodeType {
   // Outputs: [rv], output chain, glue
   PATCHPOINT,
 
+  /// COMPLEX_MUL - Do a naive complex multiplication.
+  COMPLEX_MUL,
+
 // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(VPSDID, ...) VPSDID,
 #include "llvm/IR/VPIntrinsics.def"
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
index befa8a6eb9a27ce..353c053ee5d626b 100644
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -506,6 +506,12 @@ namespace llvm {
   /// printing assembly.
   ModulePass *createMachineOutlinerPass(bool RunOnAllFunctions = true);
 
+  /// This pass expands the experimental complex intrinsics into regular
+  /// floating-point arithmetic or calls to __mulsc3 (or similar) functions.
+  FunctionPass *createExpandComplexPass();
+
+  /// This pass expands the experimental reduction intrinsics into sequences of
+  /// shuffles.
   /// This pass expands the reduction intrinsics into sequences of shuffles.
   FunctionPass *createExpandReductionsPass();
 
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 187e000d0272d2e..19e28999d18ec00 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -676,6 +676,24 @@ class TargetLoweringBase {
     return false;
   }
 
+  /// Enum that specifies how a C complex type is lowered (in LLVM type terms).
+  enum class ComplexABI {
+    Memory,  ///< Indicates that a pointer to the struct is passed.
+    Vector,  ///< Indicates that T _Complex can be passed as <2 x T>.
+    Struct,  ///< Indicates that T _Complex can be passed as {T, T}.
+    Integer, ///< Indicates that an integer of the same size is passed.
+  };
+
+  /// Returns how a C complex type is lowered when used as the return value.
+  virtual ComplexABI getComplexReturnABI(Type *ScalarFloatTy) const {
+    return ComplexABI::Struct;
+  }
+
+  /// Returns true if the target can match the @llvm.experimental.complex.fmul
+  /// intrinsic with the given type. Such an intrinsic is assumed will only be
+  /// matched when "complex-range" is "limited" or "no-nan".
+  virtual bool CustomLowerComplexMultiply(Type *FloatTy) const { return false; }
+
   /// Return if the target supports combining a
   /// chain like:
   /// \code
@@ -2783,6 +2801,7 @@ class TargetLoweringBase {
     case ISD::AVGCEILU:
     case ISD::ABDS:
     case ISD::ABDU:
+    case ISD::COMPLEX_MUL:
       return true;
     default: return false;
     }
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index c9f243fdb12e404..dacdfec0d5da756 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -1762,6 +1762,43 @@ class IRBuilderBase {
   Value *CreateNAryOp(unsigned Opc, ArrayRef<Value *> Ops,
                       const Twine &Name = "", MDNode *FPMathTag = nullptr);
 
+  /// Construct a complex value out of a pair of real and imaginary values.
+  /// The resulting value will be a vector, with lane 0 being the real value and
+  /// lane 1 being the complex value.
+  /// Either the \p Real or \p Imag parameter may be null, if the input is a
+  /// pure real or pure imaginary number.
+  Value *CreateComplexValue(Value *Real, Value *Imag, const Twine &Name = "") {
+    Type *ScalarTy = (Real ? Real : Imag)->getType();
+    assert(ScalarTy->isFloatingPointTy() &&
+           "Only floating-point types may be complex values.");
+    Type *ComplexTy = FixedVectorType::get(ScalarTy, 2);
+    Value *Base = PoisonValue::get(ComplexTy);
+    if (Real)
+      Base = CreateInsertElement(Base, Real, uint64_t(0), Name);
+    if (Imag)
+      Base = CreateInsertElement(Base, Imag, uint64_t(1), Name);
+    return Base;
+  }
+
+  /// Construct a complex multiply operation, setting fast-math flags and the
+  /// complex-range attribute as appropriate.
+  Value *CreateComplexMul(Value *L, Value *R, bool CxLimitedRange,
+                          const Twine &Name = "");
+
+  /// Construct a complex divide operation, setting fast-math flags and the
+  /// complex-range attribute as appropriate.
+  /// The complex-range attribute is set from the \p IgnoreNaNs and
+  /// \p DisableScaling as follows:
+  ///
+  /// \p IgnoreNans | \p DisableScaling | complex-range value
+  /// ------------- | ----------------- | -------------------
+  /// false         | false             | full
+  /// false         | true              | (illegal combination)
+  /// true          | false             | no-nan
+  /// true          | true              | limited
+  Value *CreateComplexDiv(Value *L, Value *R, bool IgnoreNaNs,
+                          bool DisableScaling = false, const Twine &Name = "");
+
   //===--------------------------------------------------------------------===//
   // Instruction creation methods: Memory Instructions
   //===--------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index ab15b1f1e0ee888..35e3c281861dfd8 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2350,6 +2350,16 @@ let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
                                          [llvm_anyvector_ty]>;
 }
 
+//===----- Complex math intrinsics ----------------------------------------===//
+
+def int_experimental_complex_fmul: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+                                            [LLVMMatchType<0>,LLVMMatchType<0>],
+                                            [IntrNoMem]>;
+
+def int_experimental_complex_fdiv: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+                                            [LLVMMatchType<0>,LLVMMatchType<0>],
+                                            [IntrNoMem]>;
+
 //===----- Matrix intrinsics ---------------------------------------------===//
 
 def int_matrix_transpose
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
index db653fff71ba95a..f1855763937037a 100644
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -111,6 +111,7 @@ void initializeEdgeBundlesPass(PassRegistry&);
 void initializeEHContGuardCatchretPass(PassRegistry &);
 void initializeExpandLargeFpConvertLegacyPassPass(PassRegistry&);
 void initializeExpandLargeDivRemLegacyPassPass(PassRegistry&);
+void initializeExpandComplexPass(PassRegistry &);
 void initializeExpandMemCmpPassPass(PassRegistry&);
 void initializeExpandPostRAPass(PassRegistry&);
 void initializeExpandReductionsPass(PassRegistry&);
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index fa5761c3a199a56..11515063cdbc4e7 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -770,6 +770,7 @@ def assertsext : SDNode<"ISD::AssertSext", SDT_assert>;
 def assertzext : SDNode<"ISD::AssertZext", SDT_assert>;
 def assertalign : SDNode<"ISD::AssertAlign", SDT_assert>;
 
+def COMPLEX_MUL : SDNode<"ISD::COMPLEX_MUL", SDTFPBinOp, [SDNPCommutative]>;
 //===----------------------------------------------------------------------===//
 // Selection DAG Condition Codes
 
diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt
index 389c70d04f17ba3..df214361abe9588 100644
--- a/llvm/lib/CodeGen/CMakeLists.txt
+++ b/llvm/lib/CodeGen/CMakeLists.txt
@@ -68,6 +68,7 @@ add_llvm_component_library(LLVMCodeGen
   EdgeBundles.cpp
   EHContGuardCatchret.cpp
   ExecutionDomainFix.cpp
+  ExpandComplex.cpp
   ExpandLargeDivRem.cpp
   ExpandLargeFpConvert.cpp
   ExpandMemCmp.cpp
diff --git a/llvm/lib/CodeGen/ExpandComplex.cpp b/llvm/lib/CodeGen/ExpandComplex.cpp
new file mode 100644
index 000000000000000..253de368cc7d4cf
--- /dev/null
+++ b/llvm/lib/CodeGen/ExpandComplex.cpp
@@ -0,0 +1,294 @@
+//===-- ExpandComplex.cpp - Expand experimental complex intrinsics --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements IR expansion for complex intrinsics, allowing targets
+// to enable the intrinsics until just before codegen.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ExpandComplex.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Target/TargetMachine.h"
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+namespace {
+
+bool expandComplexInstruction(IntrinsicInst *CI, const TargetLowering *TLI,
+                              const DataLayout &DL) {
+  Intrinsic::ID Opcode = CI->getIntrinsicID();
+  assert((Opcode == Intrinsic::experimental_complex_fmul ||
+          Opcode == Intrinsic::experimental_complex_fdiv) &&
+         "Expected a complex instruction");
+
+  // Break the input values up into real and imaginary pieces.
+  Type *ComplexVectorTy = CI->getArgOperand(0)->getType();
+  Type *FloatTy = ComplexVectorTy->getScalarType();
+  IRBuilder<> Builder(CI);
+  Builder.setFastMathFlags(CI->getFastMathFlags());
+  Value *LhsR = Builder.CreateExtractElement(CI->getArgOperand(0), uint64_t(0));
+  Value *LhsI = Builder.CreateExtractElement(CI->getArgOperand(0), uint64_t(1));
+  Value *RhsR = nullptr, *RhsI = nullptr;
+  RhsR = Builder.CreateExtractElement(CI->getArgOperand(1), uint64_t(0));
+  RhsI = Builder.CreateExtractElement(CI->getArgOperand(1), uint64_t(1));
+
+  // The expansion has three pieces: the naive arithmetic, a possible prescaling
+  // (not relevant for multiplication), and a step to convert NaN output values
+  // to infinity values in certain situations (see Annex G of the C
+  // specification for more details). The "complex-range" attribute determines
+  // how many we need: "limited" has just the first one, "no-nan" the first two,
+  // and "full" for all three.
+
+  // Get the "complex-range" attribute, setting a default based on the presence
+  // of fast-math flags.
+  StringRef Range = CI->getFnAttr("complex-range").getValueAsString();
+  if (Range.empty()) {
+    Range = CI->getFastMathFlags().noNaNs() || CI->getFastMathFlags().noInfs()
+                ? "no-nan"
+                : "full";
+  }
+
+  // We can expand to naive arithmetic code if we only need the first piece. For
+  // multiplication, we can also accept "no-nan", since there is no semantic
+  // difference between "limited" and "no-nan" in that case.
+  bool CanExpand =
+      Range == "limited" ||
+      (Range == "no-nan" && Opcode == Intrinsic::experimental_complex_fmul);
+
+  Value *OutReal, *OutImag;
+  if (!CanExpand) {
+    // Do a call directly to the compiler-rt library here.
+    const char *Name = nullptr;
+    if (Opcode == Intrinsic::experimental_complex_fmul) {
+      if (FloatTy->isHalfTy())
+        Name = "__mulhc3";
+      else if (FloatTy->isFloatTy())
+        Name = "__mulsc3";
+      else if (FloatTy->isDoubleTy())
+        Name = "__muldc3";
+      else if (FloatTy->isX86_FP80Ty())
+        Name = "__mulxc3";
+      else if (FloatTy->isFP128Ty() || FloatTy->isPPC_FP128Ty())
+        Name = "__multc3";
+    } else if (Opcode == Intrinsic::experimental_complex_fdiv) {
+      if (FloatTy->isHalfTy())
+        Name = "__divhc3";
+      else if (FloatTy->isFloatTy())
+        Name = "__divsc3";
+      else ...
[truncated]

@github-actions
Copy link

github-actions bot commented Oct 10, 2023

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff d4ae7ee662d2f318c0e4105c674e0634733b48eb c43971843a566d4aeb5cddaa3e3a90c7be52390a -- llvm/include/llvm/CodeGen/ExpandComplex.h llvm/lib/CodeGen/ExpandComplex.cpp llvm/include/llvm/CodeGen/ISDOpcodes.h llvm/include/llvm/CodeGen/Passes.h llvm/include/llvm/CodeGen/TargetLowering.h llvm/include/llvm/IR/IRBuilder.h llvm/include/llvm/InitializePasses.h llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp llvm/lib/CodeGen/TargetPassConfig.cpp llvm/lib/IR/IRBuilder.cpp llvm/lib/IR/Verifier.cpp llvm/lib/Target/X86/X86ISelLowering.cpp llvm/lib/Target/X86/X86ISelLowering.h
View the diff from clang-format here.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index cd2556d12483..65575bdb3dd3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -513,7 +513,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
     return "stackmap";
   case ISD::PATCHPOINT:
     return "patchpoint";
-  case ISD::COMPLEX_FMUL:               return "complex_fmul";
+  case ISD::COMPLEX_FMUL:
+    return "complex_fmul";
 
     // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...)                    \
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 018d9497fcbf..371dd696b517 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -32019,7 +32019,8 @@ SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::ADDRSPACECAST:      return LowerADDRSPACECAST(Op, DAG);
   case X86ISD::CVTPS2PH:        return LowerCVTPS2PH(Op, DAG);
   case ISD::PREFETCH:           return LowerPREFETCH(Op, Subtarget, DAG);
-  case ISD::COMPLEX_FMUL:       return LowerComplexMUL(Op, DAG, Subtarget);
+  case ISD::COMPLEX_FMUL:
+    return LowerComplexMUL(Op, DAG, Subtarget);
   }
 }
 

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some initial feedback

@@ -1371,6 +1371,9 @@ enum NodeType {
// Outputs: [rv], output chain, glue
PATCHPOINT,

/// COMPLEX_MUL - Do a naive complex multiplication.
COMPLEX_MUL,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be COMPLEX_FMUL?

@@ -31750,6 +31806,68 @@ bool X86TargetLowering::isInlineAsmTargetBranch(
return Inst.equals_insensitive("call") || Inst.equals_insensitive("jmp");
}

bool X86TargetLowering::CustomLowerComplexMultiply(Type *FloatTy) const {
auto VecTy = cast<FixedVectorType>(FloatTy);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto *

return ComplexABI::Integer;
} else {
return ComplexABI::Memory;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(style) remove the unnecessary braces from the method

declare <2 x x86_fp80> @llvm.experimental.complex.fmul.v2f80(<2 x x86_fp80>, <2 x x86_fp80>)
declare <2 x fp128> @llvm.experimental.complex.fmul.v2f128(<2 x fp128>, <2 x fp128>)

define <2 x half> @intrinsic_f16(<2 x half> %z, <2 x half> %w) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add nounwind to avoid the cfi noise

return Subtarget.hasAnyFMA() ||
(Subtarget.hasAVX512() && Subtarget.hasVLX());
}
return false;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we use the TLI isOperationLegal to do this? It seems to match the DAG setOperationAction setup earlier in the file.

@efriedma-quic
Copy link
Collaborator

Do we care whether the results of complex multiply/divide are correctly rounded? This isn't mentioned anywhere, but the naive expansion double-rounds the result. And if we use FMA, we still double-round, but produce a different result in some cases. (If we don't care, that's fine, I guess, but we should explicitly state the expected precision somewhere.)

How does this interact with strictfp? Do we need separate strictfp intrinsics?

@jcranmer-intel
Copy link
Contributor Author

Do we care whether the results of complex multiply/divide are correctly rounded? This isn't mentioned anywhere, but the naive expansion double-rounds the result. And if we use FMA, we still double-round, but produce a different result in some cases. (If we don't care, that's fine, I guess, but we should explicitly state the expected precision somewhere.)

All implementations of __mulsc3 I can find double-round the results. The story is quite different for __divsc3, of which there's at least 4 different implementations I've seen (naive version, naive version with scalbn, Smith's algorithm, and naive version in next-larger-float-size), all of which have different rounding implications. I view these largely like the libm intrinsics in terms of rounding guarantees.

For hardware implementations, x86-64's VFCMULCSH explicitly is implemented as fmul + fma, with intermediate rounding, whereas I can't quite tell from AArch64's manual whether or not FCMLA double-rounds or not.

How does this interact with strictfp? Do we need separate strictfp intrinsics?

This does need some form of strictfp support, but I've started to align with @arsenm in that I'm not sure that constrained intrinsics is necessarily the best way to expose strictfp support.

@efriedma-quic
Copy link
Collaborator

I can't quite tell from AArch64's manual whether or not FCMLA double-rounds or not.

Despite the name, fcmla is just an fma that "rotates" the input (bitwise manipulation of the input floats). The manual suggests you can implement a complex multiply using two fcmla instructions (so there's a rounding step between the first instruction and the second).

How does this interact with strictfp? Do we need separate strictfp intrinsics?

This does need some form of strictfp support, but I've started to align with @arsenm in that I'm not sure that constrained intrinsics is necessarily the best way to expose strictfp support.

Sure, we just need some way to indicate that it isn't actually readnone in strictfp mode...

@efriedma-quic
Copy link
Collaborator

I dug a bit into academic sources around the rounding... apparently it's a well-studied problem. If anyone else is interested in going down the rabbit hole, see https://inria.hal.science/inria-00120352v2/document ("Error Bounds on Complex Floating-Point Multiplication") and https://ens-lyon.hal.science/ensl-00649347v4/file/JeannerodLouvetMuller13.pdf ("Further analysis of Kahan’s algorithm for the accurate computation of 2 x 2 determinants").


Complex numbers are currently represented, for intrinsic purposes, as vectors of
floating-point numbers. A scalar complex type is represented using the type
``<2 x floatty>``, with index ``0`` corresponding to the real part of the number
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a struct instead?

@@ -770,6 +770,7 @@ def assertsext : SDNode<"ISD::AssertSext", SDT_assert>;
def assertzext : SDNode<"ISD::AssertZext", SDT_assert>;
def assertalign : SDNode<"ISD::AssertAlign", SDT_assert>;

def COMPLEX_FMUL : SDNode<"ISD::COMPLEX_FMUL", SDTFPBinOp, [SDNPCommutative]>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(style) lowercase for tablegen operator - complex_fmul

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants