Skip to content

Conversation

dheaton-arm
Copy link
Contributor

With this intrinsic, and supporting SelectionDAG nodes, we can better make use of instructions such as AArch64's FDOT.

@llvmbot
Copy link
Member

llvmbot commented Sep 19, 2025

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-selectiondag

Author: Damian Heaton (dheaton-arm)

Changes

With this intrinsic, and supporting SelectionDAG nodes, we can better make use of instructions such as AArch64's FDOT.


Patch is 21.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159776.diff

15 Files Affected:

  • (modified) llvm/docs/LangRef.rst (+42)
  • (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+2-1)
  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+2-2)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+4)
  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+16-6)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+2-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+13)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+13-8)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+19-2)
  • (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+3)
  • (added) llvm/test/CodeGen/AArch64/sve2p1-fdot.ll (+66)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 5fd0f6573bb97..069fcd29d808b 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20613,6 +20613,48 @@ performance, and an out-of-loop phase to calculate the final scalar result.
 By avoiding the introduction of new ordering constraints, these intrinsics
 enhance the ability to leverage a target's accumulation instructions.
 
+'``llvm.vector.partial.reduce.fadd.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <4 x f32> @llvm.vector.partial.reduce.fadd.v4f32.v4f32.v8f32(<4 x f32> %a, <8 x f32> %b)
+      declare <vscale x 4 x f32> @llvm.vector.partial.reduce.add.nxv4f32.nxv4f32.nxv8f32(<vscale x 4 x f32> %a, <vscale x 8 x f32> %b)
+
+Overview:
+"""""""""
+
+The '``llvm.vector.partial.reduce.fadd.*``' intrinsics reduce the
+concatenation of the two vector arguments down to the number of elements of the
+result vector type.
+
+Arguments:
+""""""""""
+
+The first argument is a floating-point vector with the same type as the result.
+
+The second argument is a vector with a length that is a known integer multiple
+of the result's type, while maintaining the same element type.
+
+Semantics:
+""""""""""
+
+Other than the reduction operator (e.g. add) the way in which the concatenated
+arguments is reduced is entirely unspecified. By their nature these intrinsics
+are not expected to be useful in isolation but instead implement the first phase
+of an overall reduction operation.
+
+The typical use case is loop vectorization where reductions are split into an
+in-loop phase, where maintaining an unordered vector result is important for
+performance, and an out-of-loop phase to calculate the final scalar result.
+
+By avoiding the introduction of new ordering constraints, these intrinsics
+enhance the ability to leverage a target's accumulation instructions.
+
 '``llvm.experimental.vector.histogram.*``' Intrinsic
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index c76c83d84b3c7..83ee6ff677e3d 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1510,6 +1510,7 @@ enum NodeType {
   PARTIAL_REDUCE_SMLA,  // sext, sext
   PARTIAL_REDUCE_UMLA,  // zext, zext
   PARTIAL_REDUCE_SUMLA, // sext, zext
+  PARTIAL_REDUCE_FMLA,  // fpext, fpext
 
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
@@ -1761,7 +1762,7 @@ LLVM_ABI CondCode getSetCCInverse(CondCode Operation, EVT Type);
 
 inline bool isExtOpcode(unsigned Opcode) {
   return Opcode == ISD::ANY_EXTEND || Opcode == ISD::ZERO_EXTEND ||
-         Opcode == ISD::SIGN_EXTEND;
+         Opcode == ISD::SIGN_EXTEND || Opcode == ISD::FP_EXTEND;
 }
 
 inline bool isExtVecInRegOpcode(unsigned Opcode) {
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 46be271320fdd..26fb5f087d6ba 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1672,7 +1672,7 @@ class LLVM_ABI TargetLoweringBase {
   LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
                                            EVT InputVT) const {
     assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
-           Opc == ISD::PARTIAL_REDUCE_SUMLA);
+           Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
     PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
                                     InputVT.getSimpleVT().SimpleTy};
     auto It = PartialReduceMLAActions.find(Key);
@@ -2774,7 +2774,7 @@ class LLVM_ABI TargetLoweringBase {
   void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
                                  LegalizeAction Action) {
     assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
-           Opc == ISD::PARTIAL_REDUCE_SUMLA);
+           Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
     assert(AccVT.isValid() && InputVT.isValid() &&
            "setPartialReduceMLAAction types aren't valid");
     PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 585371a6a4423..1ecfe284e05fa 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2801,6 +2801,10 @@ def int_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
                                                           [llvm_anyvector_ty, llvm_anyvector_ty],
                                                           [IntrNoMem]>;
 
+def int_vector_partial_reduce_fadd : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+                                                                        [llvm_anyfloat_ty, llvm_anyfloat_ty],
+                                                                        [IntrNoMem]>;
+
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index ef88c9507c86d..1e36cc0a00505 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -523,6 +523,8 @@ def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
                                  SDTPartialReduceMLA>;
 def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA",
                                  SDTPartialReduceMLA>;
+def partial_reduce_fmla : SDNode<"ISD::PARTIAL_REDUCE_FMLA",
+                                 SDTPartialReduceMLA>;
 
 def fadd       : SDNode<"ISD::FADD"       , SDTFPBinOp, [SDNPCommutative]>;
 def fsub       : SDNode<"ISD::FSUB"       , SDTFPBinOp>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 4b20b756f8a15..7347d77172054 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2040,6 +2040,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
                                 return visitPARTIAL_REDUCE_MLA(N);
   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
@@ -12942,8 +12943,11 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue Op2 = N->getOperand(2);
 
   APInt C;
-  if (Op1->getOpcode() != ISD::MUL ||
-      !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
+  if (!(Op1->getOpcode() == ISD::MUL &&
+        ISD::isConstantSplatVector(Op2.getNode(), C) && C.isOne()) &&
+      !(Op1->getOpcode() == ISD::FMUL &&
+        ISD::isConstantSplatVector(Op2.getNode(), C) &&
+        C == APFloat(1.0f).bitcastToAPInt().trunc(C.getBitWidth())))
     return SDValue();
 
   SDValue LHS = Op1->getOperand(0);
@@ -12998,6 +13002,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
     NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
     std::swap(LHSExtOp, RHSExtOp);
+  } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
+    NewOpc = ISD::PARTIAL_REDUCE_FMLA;
   } else
     return SDValue();
   // For a 2-stage extend the signedness of both of the extends must match
@@ -13033,22 +13039,26 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
 
   APInt ConstantOne;
   if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+      !(ConstantOne.isOne() ||
+        ConstantOne ==
+            APFloat(1.0f).bitcastToAPInt().trunc(ConstantOne.getBitWidth())))
     return SDValue();
 
   unsigned Op1Opcode = Op1.getOpcode();
   if (!ISD::isExtOpcode(Op1Opcode))
     return SDValue();
 
-  bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+  bool Op1IsSigned = Op1Opcode != ISD::ZERO_EXTEND;
   bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
   if (Op1IsSigned != NodeIsSigned &&
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
-  unsigned NewOpcode =
-      Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+                           ? ISD::PARTIAL_REDUCE_FMLA
+                       : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA
+                                     : ISD::PARTIAL_REDUCE_UMLA;
 
   SDValue UnextOp1 = Op1.getOperand(0);
   EVT UnextOp1VT = UnextOp1.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8e423c4f83b38..94751be5b7986 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -534,6 +534,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Action =
         TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
                                       Node->getOperand(1).getValueType());
@@ -1243,6 +1244,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
     return;
   case ISD::VECREDUCE_SEQ_FADD:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index ff7cd665446cc..e6f19499b8f41 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1459,6 +1459,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
     break;
   case ISD::GET_ACTIVE_LANE_MASK:
@@ -3674,6 +3675,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
     break;
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 029eb025ff1de..bc082513786ef 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8329,7 +8329,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   }
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
-  case ISD::PARTIAL_REDUCE_SUMLA: {
+  case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA: {
     [[maybe_unused]] EVT AccVT = N1.getValueType();
     [[maybe_unused]] EVT Input1VT = N2.getValueType();
     [[maybe_unused]] EVT Input2VT = N3.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 070d7978ce48f..448e3bbd02038 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8114,6 +8114,19 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
                          Input, DAG.getConstant(1, sdl, Input.getValueType())));
     return;
   }
+  case Intrinsic::vector_partial_reduce_fadd: {
+    if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
+    }
+    SDValue Acc = getValue(I.getOperand(0));
+    SDValue Input = getValue(I.getOperand(1));
+    setValue(&I,
+             DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, sdl, Acc.getValueType(), Acc,
+                         Input,
+                         DAG.getConstantFP(1.0f, sdl, Input.getValueType())));
+    return;
+  }
   case Intrinsic::experimental_cttz_elts: {
     auto DL = getCurSDLoc();
     SDValue Op = getValue(I.getOperand(0));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 4b2a00c2e2cfa..cf5c269c20761 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -587,6 +587,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
     return "partial_reduce_smla";
   case ISD::PARTIAL_REDUCE_SUMLA:
     return "partial_reduce_sumla";
+  case ISD::PARTIAL_REDUCE_FMLA:
+    return "partial_reduce_fmla";
   case ISD::LOOP_DEPENDENCE_WAR_MASK:
     return "loop_dep_war";
   case ISD::LOOP_DEPENDENCE_RAW_MASK:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index fd6d20e146bb2..50aa41b77691d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12046,12 +12046,14 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
       EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
                        MulOpVT.getVectorElementCount());
 
-  unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
-                      ? ISD::ZERO_EXTEND
-                      : ISD::SIGN_EXTEND;
-  unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
-                      ? ISD::SIGN_EXTEND
-                      : ISD::ZERO_EXTEND;
+  unsigned ExtOpcLHS =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA   ? ISD::FP_EXTEND
+      : N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA ? ISD::ZERO_EXTEND
+                                                   : ISD::SIGN_EXTEND;
+  unsigned ExtOpcRHS =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA   ? ISD::FP_EXTEND
+      : N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ? ISD::SIGN_EXTEND
+                                                   : ISD::ZERO_EXTEND;
 
   if (ExtMulOpVT != MulOpVT) {
     MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
@@ -12060,7 +12062,7 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   SDValue Input = MulLHS;
   APInt ConstantOne;
   if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+      !(ConstantOne.isOne() || ConstantOne == APFloat(1.0f).bitcastToAPInt()))
     Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
 
   unsigned Stride = AccVT.getVectorMinNumElements();
@@ -12071,10 +12073,13 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   for (unsigned I = 0; I < ScaleFactor; I++)
     Subvectors.push_back(DAG.getExtractSubvector(DL, AccVT, Input, I * Stride));
 
+  unsigned FlatNode =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FADD : ISD::ADD;
+
   // Flatten the subvector tree
   while (Subvectors.size() > 1) {
     Subvectors.push_back(
-        DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]}));
+        DAG.getNode(FlatNode, DL, AccVT, {Subvectors[0], Subvectors[1]}));
     Subvectors.pop_front();
     Subvectors.pop_front();
   }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fc3efb072d57b..0ca596b634d11 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1919,6 +1919,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     }
   }
 
+  // Handle floating-point partial reduction
+  if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
+    static const unsigned FMLAOps[] = {ISD::PARTIAL_REDUCE_FMLA};
+    setPartialReduceMLAAction(FMLAOps, MVT::nxv4f32, MVT::nxv8f16, Legal);
+  }
+
   // Handle non-aliasing elements mask
   if (Subtarget->hasSVE2() ||
       (Subtarget->hasSME() && Subtarget->isStreaming())) {
@@ -2184,7 +2190,8 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
 
 bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     const IntrinsicInst *I) const {
-  assert(I->getIntrinsicID() == Intrinsic::vector_partial_reduce_add &&
+  assert((I->getIntrinsicID() == Intrinsic::vector_partial_reduce_add ||
+          I->getIntrinsicID() == Intrinsic::vector_partial_reduce_fadd) &&
          "Unexpected intrinsic!");
   return true;
 }
@@ -22519,7 +22526,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
                                       SelectionDAG &DAG) {
 
   assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add &&
+         (getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add ||
+          getIntrinsicID(N) == Intrinsic::vector_partial_reduce_fadd) &&
          "Expected a partial reduction node");
 
   bool Scalable = N->getValueType(0).isScalableVector();
@@ -22689,6 +22697,15 @@ static SDValue performIntrinsicCombine(SDNode *N,
                        N->getOperand(1), Input,
                        DAG.getConstant(1, DL, Input.getValueType()));
   }
+  case Intrinsic::vector_partial_reduce_fadd: {
+    if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+      return Dot;
+    SDLoc DL(N);
+    SDValue Input = N->getOperand(2);
+    return DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, DL, N->getValueType(0),
+                       N->getOperand(1), Input,
+                       DAG.getConstantFP(1.0f, DL, Input.getValueType()));
+  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 7fe4f7acdbd49..8ef69ad13abc5 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4228,6 +4228,9 @@ defm FCLAMP_ZZZ : sve_fp_clamp<"fclamp", AArch64fclamp>;
 defm FDOT_ZZZ_S  : sve_float_dot<0b0, 0b0, ZPR32, ZPR16, "fdot", nxv8f16, int_aarch64_sve_fdot_x2>;
 defm FDOT_ZZZI_S : sve_float_dot_indexed<0b0, 0b00, ZPR16, ZPR3b16, "fdot", nxv8f16, int_aarch64_sve_fdot_lane_x2>;
 
+def : Pat<(nxv4f32 (partial_reduce_fmla nxv4f32:$Acc, nxv8f16:$LHS, nxv8f16:$RHS)),
+          (FDOT_ZZZ_S $Acc, $LHS, $RHS)>;
+
 defm BFMLSLB_ZZZ_S : sve2_fp_mla_long<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb>;
 defm BFMLSLT_ZZZ_S : sve2_fp_mla_long<0b111, "bfmlslt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslt>;
 defm BFMLSLB_ZZZI_S : sve2_fp_mla_long_by_indexed_elem<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb_lane>;
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
new file mode 100644
index 0000000000000..5bb1fae43392f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
+
+define <vscale x 4 x float> @fdot_wide_vl128(<vscale x 4 x float> %acc, <vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fdot_wide_vl128:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+  %b.wide = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+  %mult = fmul <vscale x 8 x float> %a.wide, %b.wide
+  %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %mult)
+  ret <vscale x 4 x float> %partial.reduce
+}
+
+define void @fdot_wide_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
+; CHECK-LABEL: fdot_wide_vl256:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s
+; C...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Sep 19, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Damian Heaton (dheaton-arm)

Changes

With this intrinsic, and supporting SelectionDAG nodes, we can better make use of instructions such as AArch64's FDOT.


Patch is 21.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159776.diff

15 Files Affected:

  • (modified) llvm/docs/LangRef.rst (+42)
  • (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+2-1)
  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+2-2)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+4)
  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+16-6)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+2-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+13)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+13-8)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+19-2)
  • (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+3)
  • (added) llvm/test/CodeGen/AArch64/sve2p1-fdot.ll (+66)
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 5fd0f6573bb97..069fcd29d808b 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20613,6 +20613,48 @@ performance, and an out-of-loop phase to calculate the final scalar result.
 By avoiding the introduction of new ordering constraints, these intrinsics
 enhance the ability to leverage a target's accumulation instructions.
 
+'``llvm.vector.partial.reduce.fadd.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <4 x f32> @llvm.vector.partial.reduce.fadd.v4f32.v4f32.v8f32(<4 x f32> %a, <8 x f32> %b)
+      declare <vscale x 4 x f32> @llvm.vector.partial.reduce.add.nxv4f32.nxv4f32.nxv8f32(<vscale x 4 x f32> %a, <vscale x 8 x f32> %b)
+
+Overview:
+"""""""""
+
+The '``llvm.vector.partial.reduce.fadd.*``' intrinsics reduce the
+concatenation of the two vector arguments down to the number of elements of the
+result vector type.
+
+Arguments:
+""""""""""
+
+The first argument is a floating-point vector with the same type as the result.
+
+The second argument is a vector with a length that is a known integer multiple
+of the result's type, while maintaining the same element type.
+
+Semantics:
+""""""""""
+
+Other than the reduction operator (e.g. add) the way in which the concatenated
+arguments is reduced is entirely unspecified. By their nature these intrinsics
+are not expected to be useful in isolation but instead implement the first phase
+of an overall reduction operation.
+
+The typical use case is loop vectorization where reductions are split into an
+in-loop phase, where maintaining an unordered vector result is important for
+performance, and an out-of-loop phase to calculate the final scalar result.
+
+By avoiding the introduction of new ordering constraints, these intrinsics
+enhance the ability to leverage a target's accumulation instructions.
+
 '``llvm.experimental.vector.histogram.*``' Intrinsic
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index c76c83d84b3c7..83ee6ff677e3d 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1510,6 +1510,7 @@ enum NodeType {
   PARTIAL_REDUCE_SMLA,  // sext, sext
   PARTIAL_REDUCE_UMLA,  // zext, zext
   PARTIAL_REDUCE_SUMLA, // sext, zext
+  PARTIAL_REDUCE_FMLA,  // fpext, fpext
 
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
@@ -1761,7 +1762,7 @@ LLVM_ABI CondCode getSetCCInverse(CondCode Operation, EVT Type);
 
 inline bool isExtOpcode(unsigned Opcode) {
   return Opcode == ISD::ANY_EXTEND || Opcode == ISD::ZERO_EXTEND ||
-         Opcode == ISD::SIGN_EXTEND;
+         Opcode == ISD::SIGN_EXTEND || Opcode == ISD::FP_EXTEND;
 }
 
 inline bool isExtVecInRegOpcode(unsigned Opcode) {
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 46be271320fdd..26fb5f087d6ba 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1672,7 +1672,7 @@ class LLVM_ABI TargetLoweringBase {
   LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
                                            EVT InputVT) const {
     assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
-           Opc == ISD::PARTIAL_REDUCE_SUMLA);
+           Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
     PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
                                     InputVT.getSimpleVT().SimpleTy};
     auto It = PartialReduceMLAActions.find(Key);
@@ -2774,7 +2774,7 @@ class LLVM_ABI TargetLoweringBase {
   void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
                                  LegalizeAction Action) {
     assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
-           Opc == ISD::PARTIAL_REDUCE_SUMLA);
+           Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
     assert(AccVT.isValid() && InputVT.isValid() &&
            "setPartialReduceMLAAction types aren't valid");
     PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 585371a6a4423..1ecfe284e05fa 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2801,6 +2801,10 @@ def int_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
                                                           [llvm_anyvector_ty, llvm_anyvector_ty],
                                                           [IntrNoMem]>;
 
+def int_vector_partial_reduce_fadd : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+                                                                        [llvm_anyfloat_ty, llvm_anyfloat_ty],
+                                                                        [IntrNoMem]>;
+
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index ef88c9507c86d..1e36cc0a00505 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -523,6 +523,8 @@ def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
                                  SDTPartialReduceMLA>;
 def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA",
                                  SDTPartialReduceMLA>;
+def partial_reduce_fmla : SDNode<"ISD::PARTIAL_REDUCE_FMLA",
+                                 SDTPartialReduceMLA>;
 
 def fadd       : SDNode<"ISD::FADD"       , SDTFPBinOp, [SDNPCommutative]>;
 def fsub       : SDNode<"ISD::FSUB"       , SDTFPBinOp>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 4b20b756f8a15..7347d77172054 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2040,6 +2040,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
                                 return visitPARTIAL_REDUCE_MLA(N);
   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
@@ -12942,8 +12943,11 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue Op2 = N->getOperand(2);
 
   APInt C;
-  if (Op1->getOpcode() != ISD::MUL ||
-      !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
+  if (!(Op1->getOpcode() == ISD::MUL &&
+        ISD::isConstantSplatVector(Op2.getNode(), C) && C.isOne()) &&
+      !(Op1->getOpcode() == ISD::FMUL &&
+        ISD::isConstantSplatVector(Op2.getNode(), C) &&
+        C == APFloat(1.0f).bitcastToAPInt().trunc(C.getBitWidth())))
     return SDValue();
 
   SDValue LHS = Op1->getOperand(0);
@@ -12998,6 +13002,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
     NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
     std::swap(LHSExtOp, RHSExtOp);
+  } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
+    NewOpc = ISD::PARTIAL_REDUCE_FMLA;
   } else
     return SDValue();
   // For a 2-stage extend the signedness of both of the extends must match
@@ -13033,22 +13039,26 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
 
   APInt ConstantOne;
   if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+      !(ConstantOne.isOne() ||
+        ConstantOne ==
+            APFloat(1.0f).bitcastToAPInt().trunc(ConstantOne.getBitWidth())))
     return SDValue();
 
   unsigned Op1Opcode = Op1.getOpcode();
   if (!ISD::isExtOpcode(Op1Opcode))
     return SDValue();
 
-  bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+  bool Op1IsSigned = Op1Opcode != ISD::ZERO_EXTEND;
   bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
   if (Op1IsSigned != NodeIsSigned &&
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
-  unsigned NewOpcode =
-      Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+                           ? ISD::PARTIAL_REDUCE_FMLA
+                       : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA
+                                     : ISD::PARTIAL_REDUCE_UMLA;
 
   SDValue UnextOp1 = Op1.getOperand(0);
   EVT UnextOp1VT = UnextOp1.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8e423c4f83b38..94751be5b7986 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -534,6 +534,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Action =
         TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
                                       Node->getOperand(1).getValueType());
@@ -1243,6 +1244,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
     return;
   case ISD::VECREDUCE_SEQ_FADD:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index ff7cd665446cc..e6f19499b8f41 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1459,6 +1459,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
     break;
   case ISD::GET_ACTIVE_LANE_MASK:
@@ -3674,6 +3675,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
     break;
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 029eb025ff1de..bc082513786ef 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8329,7 +8329,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   }
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
-  case ISD::PARTIAL_REDUCE_SUMLA: {
+  case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA: {
     [[maybe_unused]] EVT AccVT = N1.getValueType();
     [[maybe_unused]] EVT Input1VT = N2.getValueType();
     [[maybe_unused]] EVT Input2VT = N3.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 070d7978ce48f..448e3bbd02038 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8114,6 +8114,19 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
                          Input, DAG.getConstant(1, sdl, Input.getValueType())));
     return;
   }
+  case Intrinsic::vector_partial_reduce_fadd: {
+    if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
+    }
+    SDValue Acc = getValue(I.getOperand(0));
+    SDValue Input = getValue(I.getOperand(1));
+    setValue(&I,
+             DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, sdl, Acc.getValueType(), Acc,
+                         Input,
+                         DAG.getConstantFP(1.0f, sdl, Input.getValueType())));
+    return;
+  }
   case Intrinsic::experimental_cttz_elts: {
     auto DL = getCurSDLoc();
     SDValue Op = getValue(I.getOperand(0));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 4b2a00c2e2cfa..cf5c269c20761 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -587,6 +587,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
     return "partial_reduce_smla";
   case ISD::PARTIAL_REDUCE_SUMLA:
     return "partial_reduce_sumla";
+  case ISD::PARTIAL_REDUCE_FMLA:
+    return "partial_reduce_fmla";
   case ISD::LOOP_DEPENDENCE_WAR_MASK:
     return "loop_dep_war";
   case ISD::LOOP_DEPENDENCE_RAW_MASK:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index fd6d20e146bb2..50aa41b77691d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12046,12 +12046,14 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
       EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
                        MulOpVT.getVectorElementCount());
 
-  unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
-                      ? ISD::ZERO_EXTEND
-                      : ISD::SIGN_EXTEND;
-  unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
-                      ? ISD::SIGN_EXTEND
-                      : ISD::ZERO_EXTEND;
+  unsigned ExtOpcLHS =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA   ? ISD::FP_EXTEND
+      : N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA ? ISD::ZERO_EXTEND
+                                                   : ISD::SIGN_EXTEND;
+  unsigned ExtOpcRHS =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA   ? ISD::FP_EXTEND
+      : N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ? ISD::SIGN_EXTEND
+                                                   : ISD::ZERO_EXTEND;
 
   if (ExtMulOpVT != MulOpVT) {
     MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
@@ -12060,7 +12062,7 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   SDValue Input = MulLHS;
   APInt ConstantOne;
   if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+      !(ConstantOne.isOne() || ConstantOne == APFloat(1.0f).bitcastToAPInt()))
     Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
 
   unsigned Stride = AccVT.getVectorMinNumElements();
@@ -12071,10 +12073,13 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   for (unsigned I = 0; I < ScaleFactor; I++)
     Subvectors.push_back(DAG.getExtractSubvector(DL, AccVT, Input, I * Stride));
 
+  unsigned FlatNode =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FADD : ISD::ADD;
+
   // Flatten the subvector tree
   while (Subvectors.size() > 1) {
     Subvectors.push_back(
-        DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]}));
+        DAG.getNode(FlatNode, DL, AccVT, {Subvectors[0], Subvectors[1]}));
     Subvectors.pop_front();
     Subvectors.pop_front();
   }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fc3efb072d57b..0ca596b634d11 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1919,6 +1919,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     }
   }
 
+  // Handle floating-point partial reduction
+  if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
+    static const unsigned FMLAOps[] = {ISD::PARTIAL_REDUCE_FMLA};
+    setPartialReduceMLAAction(FMLAOps, MVT::nxv4f32, MVT::nxv8f16, Legal);
+  }
+
   // Handle non-aliasing elements mask
   if (Subtarget->hasSVE2() ||
       (Subtarget->hasSME() && Subtarget->isStreaming())) {
@@ -2184,7 +2190,8 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
 
 bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     const IntrinsicInst *I) const {
-  assert(I->getIntrinsicID() == Intrinsic::vector_partial_reduce_add &&
+  assert((I->getIntrinsicID() == Intrinsic::vector_partial_reduce_add ||
+          I->getIntrinsicID() == Intrinsic::vector_partial_reduce_fadd) &&
          "Unexpected intrinsic!");
   return true;
 }
@@ -22519,7 +22526,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
                                       SelectionDAG &DAG) {
 
   assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add &&
+         (getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add ||
+          getIntrinsicID(N) == Intrinsic::vector_partial_reduce_fadd) &&
          "Expected a partial reduction node");
 
   bool Scalable = N->getValueType(0).isScalableVector();
@@ -22689,6 +22697,15 @@ static SDValue performIntrinsicCombine(SDNode *N,
                        N->getOperand(1), Input,
                        DAG.getConstant(1, DL, Input.getValueType()));
   }
+  case Intrinsic::vector_partial_reduce_fadd: {
+    if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+      return Dot;
+    SDLoc DL(N);
+    SDValue Input = N->getOperand(2);
+    return DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, DL, N->getValueType(0),
+                       N->getOperand(1), Input,
+                       DAG.getConstantFP(1.0f, DL, Input.getValueType()));
+  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 7fe4f7acdbd49..8ef69ad13abc5 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4228,6 +4228,9 @@ defm FCLAMP_ZZZ : sve_fp_clamp<"fclamp", AArch64fclamp>;
 defm FDOT_ZZZ_S  : sve_float_dot<0b0, 0b0, ZPR32, ZPR16, "fdot", nxv8f16, int_aarch64_sve_fdot_x2>;
 defm FDOT_ZZZI_S : sve_float_dot_indexed<0b0, 0b00, ZPR16, ZPR3b16, "fdot", nxv8f16, int_aarch64_sve_fdot_lane_x2>;
 
+def : Pat<(nxv4f32 (partial_reduce_fmla nxv4f32:$Acc, nxv8f16:$LHS, nxv8f16:$RHS)),
+          (FDOT_ZZZ_S $Acc, $LHS, $RHS)>;
+
 defm BFMLSLB_ZZZ_S : sve2_fp_mla_long<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb>;
 defm BFMLSLT_ZZZ_S : sve2_fp_mla_long<0b111, "bfmlslt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslt>;
 defm BFMLSLB_ZZZI_S : sve2_fp_mla_long_by_indexed_elem<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb_lane>;
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
new file mode 100644
index 0000000000000..5bb1fae43392f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
+
+define <vscale x 4 x float> @fdot_wide_vl128(<vscale x 4 x float> %acc, <vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fdot_wide_vl128:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+  %b.wide = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+  %mult = fmul <vscale x 8 x float> %a.wide, %b.wide
+  %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %mult)
+  ret <vscale x 4 x float> %partial.reduce
+}
+
+define void @fdot_wide_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
+; CHECK-LABEL: fdot_wide_vl256:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s
+; C...
[truncated]

@dheaton-arm
Copy link
Contributor Author

@dheaton-arm dheaton-arm changed the title Add llvm.vector.partial.reduction.fadd intrinsic Add llvm.vector.partial.reduce.fadd intrinsic Sep 19, 2025
@dheaton-arm dheaton-arm marked this pull request as draft September 19, 2025 19:07
@dheaton-arm dheaton-arm marked this pull request as ready for review September 23, 2025 09:52
@paulwalker-arm
Copy link
Collaborator

paulwalker-arm commented Sep 23, 2025

You'll need to roll out the reassoc change to the LangRef as well, so the unordered nature is activated by this option but then add a separate note stating the flag is required.
Based on feedback I'm retracting my reassoc suggestion.

With this intrinsic, and supporting SelectionDAG nodes, we can better make use of instructions such as AArch64's `FDOT`.
Corrected LangRef typos, improved const
comparisons for fadd, and add direct tests.
This reverts commit 319852132602f685aea6228f10418370fd530aa7.
C->isExactlyValue(1.0)) &&
!(ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) &&
ConstantOne.isOne()))
Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can only work for integer types, so this is also missing test coverage?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants