diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 8e863939781a2..2066167345945 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20315,6 +20315,76 @@ Arguments:
 """"""""""
 The argument to this intrinsic must be a vector of floating-point values.
 
+Vector Partial Reduction Intrinsics
+-----------------------------------
+
+Partial reductions of vectors can be expressed using the following intrinsics.
+Each one reduces the concatenation of the two vector arguments down to the
+number of elements of the result vector type.
+
+Other than the reduction operator (e.g. add, fadd), the way in which the
+concatenated arguments are reduced is entirely unspecified. By their nature,
+these intrinsics are not expected to be useful in isolation but instead
+implement the first phase of an overall reduction operation.
+
+The typical use case is loop vectorization, where reductions are split into an
+in-loop phase, where maintaining an unordered vector result is important for
+performance, and an out-of-loop phase to calculate the final scalar result.
+
+By avoiding the introduction of new ordering constraints, these intrinsics
+enhance the ability to leverage a target's accumulation instructions.
+
+'``llvm.vector.partial.reduce.add.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v8i32(<4 x i32> %a, <8 x i32> %b)
+      declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v16i32(<4 x i32> %a, <16 x i32> %b)
+      declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv8i32(<vscale x 4 x i32> %a, <vscale x 8 x i32> %b)
+      declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv16i32(<vscale x 4 x i32> %a, <vscale x 16 x i32> %b)
+
+Arguments:
+""""""""""
+
+The first argument is an integer vector with the same type as the result.
+
+The second argument is a vector whose length is a known integer multiple of the
+result vector's length, and whose element type matches that of the result.
+
+'``llvm.vector.partial.reduce.fadd.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <4 x float> @llvm.vector.partial.reduce.fadd.v4f32.v8f32(<4 x float> %a, <8 x float> %b)
+      declare <vscale x 4 x float> @llvm.vector.partial.reduce.fadd.nxv4f32.nxv8f32(<vscale x 4 x float> %a, <vscale x 8 x float> %b)
+
+Arguments:
+""""""""""
+
+The first argument is a floating-point vector with the same type as the result.
+
+The second argument is a vector whose length is a known integer multiple of the
+result vector's length, and whose element type matches that of the result.
+
+Semantics:
+""""""""""
+
+Because the way in which the arguments to this floating-point intrinsic are
+reduced is unspecified, the intrinsic implicitly permits floating-point
+reassociation and contraction. The result may therefore vary due to reordering,
+or due to lowering to different instructions (including combining multiple
+instructions into a single one).
+
 '``llvm.vector.insert``' Intrinsic
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -20688,50 +20758,6 @@ Note that it has the following implications:
 - If ``%cnt`` is non-zero, the return value is non-zero as well.
 - If ``%cnt`` is less than or equal to ``%max_lanes``, the return value is
   equal to ``%cnt``.
 
-'``llvm.vector.partial.reduce.add.*``' Intrinsic
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Syntax:
-"""""""
-This is an overloaded intrinsic.
-
-::
-
-      declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v8i32(<4 x i32> %a, <8 x i32> %b)
-      declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v16i32(<4 x i32> %a, <16 x i32> %b)
-      declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv8i32(<vscale x 4 x i32> %a, <vscale x 8 x i32> %b)
-      declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv16i32(<vscale x 4 x i32> %a, <vscale x 16 x i32> %b)
-
-Overview:
-"""""""""
-
-The '``llvm.vector.partial.reduce.add.*``' intrinsics reduce the
-concatenation of the two vector arguments down to the number of elements of the
-result vector type.
-
-Arguments:
-""""""""""
-
-The first argument is an integer vector with the same type as the result.
-
-The second argument is a vector with a length that is a known integer multiple
-of the result's type, while maintaining the same element type.
-
-Semantics:
-""""""""""
-
-Other than the reduction operator (e.g. add) the way in which the concatenated
-arguments is reduced is entirely unspecified. By their nature these intrinsics
-are not expected to be useful in isolation but instead implement the first phase
-of an overall reduction operation.
-
-The typical use case is loop vectorization where reductions are split into an
-in-loop phase, where maintaining an unordered vector result is important for
-performance, and an out-of-loop phase to calculate the final scalar result.
-
-By avoiding the introduction of new ordering constraints, these intrinsics
-enhance the ability to leverage a target's accumulation instructions.
-
 '``llvm.experimental.vector.histogram.*``' Intrinsic
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
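For illustration (a sketch; all names and types below are chosen for the example
and are not taken from the patch): in the loop-vectorization use case described
in the LangRef text above, the in-loop phase keeps the accumulator as a vector
via ``llvm.vector.partial.reduce.fadd``, and the final scalar is only formed
after the loop, e.g. with an unordered ``llvm.vector.reduce.fadd``::

    loop:
      %acc = phi <4 x float> [ zeroinitializer, %entry ], [ %acc.next, %loop ]
      ; (loads of %a and %b elided)
      %a.wide = fpext <8 x half> %a to <8 x float>
      %b.wide = fpext <8 x half> %b to <8 x float>
      %mult = fmul <8 x float> %a.wide, %b.wide
      ; In-loop phase: fold the 8 products into the 4 accumulator lanes;
      ; the pairing of input lanes to accumulator lanes is unspecified.
      %acc.next = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult)
      br i1 %done, label %exit, label %loop

    exit:
      ; Out-of-loop phase: reduce the vector accumulator to the final scalar.
      %sum = call reassoc float @llvm.vector.reduce.fadd.v4f32(float 0.0, <4 x float> %acc.next)
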
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index c76c83d84b3c7..e1f6aab0040db 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1510,6 +1510,7 @@ enum NodeType {
   PARTIAL_REDUCE_SMLA,  // sext, sext
   PARTIAL_REDUCE_UMLA,  // zext, zext
   PARTIAL_REDUCE_SUMLA, // sext, zext
+  PARTIAL_REDUCE_FMLA,  // fpext, fpext
 
   // The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, , , [live0[, live1...]] diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index c45e03a7bdad8..8d75b0581618c 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -1672,7 +1672,7 @@ class LLVM_ABI TargetLoweringBase { LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT, EVT InputVT) const { assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA || - Opc == ISD::PARTIAL_REDUCE_SUMLA); + Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA); PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy, InputVT.getSimpleVT().SimpleTy}; auto It = PartialReduceMLAActions.find(Key); @@ -2774,7 +2774,7 @@ class LLVM_ABI TargetLoweringBase { void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT, LegalizeAction Action) { assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA || - Opc == ISD::PARTIAL_REDUCE_SUMLA); + Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA); assert(AccVT.isValid() && InputVT.isValid() && "setPartialReduceMLAAction types aren't valid"); PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy}; diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td index dba44e7c3c506..ce79a58053754 100644 --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -2802,6 +2802,10 @@ def int_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyvector_ty, llvm_anyvector_ty], [IntrNoMem]>; +def int_vector_partial_reduce_fadd : DefaultAttrsIntrinsic<[LLVMMatchType<0>], + [llvm_anyfloat_ty, llvm_anyfloat_ty], + [IntrNoMem]>; + //===----------------- Pointer Authentication Intrinsics ------------------===// // diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td index 5e57dcaa303f3..1f0800885937e 100644 --- a/llvm/include/llvm/Target/TargetSelectionDAG.td +++ b/llvm/include/llvm/Target/TargetSelectionDAG.td @@ -522,6 +522,8 @@ def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA", SDTPartialReduceMLA>; def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA", SDTPartialReduceMLA>; +def partial_reduce_fmla : SDNode<"ISD::PARTIAL_REDUCE_FMLA", + SDTPartialReduceMLA>; def fadd : SDNode<"ISD::FADD" , SDTFPBinOp, [SDNPCommutative]>; def fsub : SDNode<"ISD::FSUB" , SDTFPBinOp>; diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index c81568672de3c..b3228d904ee1e 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -2040,6 +2040,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: return visitPARTIAL_REDUCE_MLA(N); case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N); case ISD::LIFETIME_END: return visitLIFETIME_END(N); @@ -12989,6 +12990,12 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) { // // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) // -> partial_reduce_*mla(acc, x, C) +// +// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0)) +// -> partial_reduce_fmla(acc, a, b) +// +// partial_reduce_fmla(acc, fmul(fpext(x), splat(C)), splat(1.0)) +// -> partial_reduce_fmla(acc, x, C) SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) 
{ SDLoc DL(N); auto *Context = DAG.getContext(); @@ -12997,14 +13004,22 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { SDValue Op2 = N->getOperand(2); APInt C; - if (Op1->getOpcode() != ISD::MUL || - !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne()) + ConstantFPSDNode *CFP; + if (!(Op1->getOpcode() == ISD::MUL && + ISD::isConstantSplatVector(Op2.getNode(), C) && C.isOne()) && + !(Op1->getOpcode() == ISD::FMUL && + (CFP = llvm::isConstOrConstSplatFP(Op2, false)) && + CFP->isExactlyValue(1.0))) return SDValue(); + auto IsIntOrFPExtOpcode = [](unsigned int Opcode) { + return (ISD::isExtOpcode(Opcode) || Opcode == ISD::FP_EXTEND); + }; + SDValue LHS = Op1->getOperand(0); SDValue RHS = Op1->getOperand(1); unsigned LHSOpcode = LHS->getOpcode(); - if (!ISD::isExtOpcode(LHSOpcode)) + if (!IsIntOrFPExtOpcode(LHSOpcode)) return SDValue(); SDValue LHSExtOp = LHS->getOperand(0); @@ -13036,7 +13051,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { } unsigned RHSOpcode = RHS->getOpcode(); - if (!ISD::isExtOpcode(RHSOpcode)) + if (!IsIntOrFPExtOpcode(RHSOpcode)) return SDValue(); SDValue RHSExtOp = RHS->getOperand(0); @@ -13053,6 +13068,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) { NewOpc = ISD::PARTIAL_REDUCE_SUMLA; std::swap(LHSExtOp, RHSExtOp); + } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) { + NewOpc = ISD::PARTIAL_REDUCE_FMLA; } else return SDValue(); // For a 2-stage extend the signedness of both of the extends must match @@ -13080,6 +13097,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { // -> partial.reduce.smla(acc, op, splat(trunc(1))) // partial.reduce.sumla(acc, sext(op), splat(1)) // -> partial.reduce.smla(acc, op, splat(trunc(1))) +// partial.reduce.fmla(acc, fpext(op), splat(1.0)) +// -> partial.reduce.fmla(acc, op, splat(1.0)) SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { SDLoc DL(N); SDValue Acc = N->getOperand(0); @@ -13087,23 +13106,30 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { SDValue Op2 = N->getOperand(2); APInt ConstantOne; - if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) || - !ConstantOne.isOne()) + ConstantFPSDNode *C; + if (!(N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA && + (C = llvm::isConstOrConstSplatFP(Op2, false)) && + C->isExactlyValue(1.0)) && + !(ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) && + ConstantOne.isOne())) return SDValue(); unsigned Op1Opcode = Op1.getOpcode(); - if (!ISD::isExtOpcode(Op1Opcode)) + if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND) return SDValue(); bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND; bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA; EVT AccElemVT = Acc.getValueType().getVectorElementType(); - if (Op1IsSigned != NodeIsSigned && + if (N->getOpcode() != ISD::PARTIAL_REDUCE_FMLA && + Op1IsSigned != NodeIsSigned && Op1.getValueType().getVectorElementType() != AccElemVT) return SDValue(); - unsigned NewOpcode = - Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA; + unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA + ? ISD::PARTIAL_REDUCE_FMLA + : Op1IsSigned ? 
ISD::PARTIAL_REDUCE_SMLA + : ISD::PARTIAL_REDUCE_UMLA; SDValue UnextOp1 = Op1.getOperand(0); EVT UnextOp1VT = UnextOp1.getValueType(); @@ -13113,8 +13139,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { TLI.getTypeToTransformTo(*Context, UnextOp1VT))) return SDValue(); + auto Constant = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA + ? DAG.getConstantFP(1, DL, UnextOp1VT) + : DAG.getConstant(1, DL, UnextOp1VT); + return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1, - DAG.getConstant(1, DL, UnextOp1VT)); + Constant); } SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) { diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index 8e423c4f83b38..94751be5b7986 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -534,6 +534,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: Action = TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0), Node->getOperand(1).getValueType()); @@ -1243,6 +1244,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl &Results) { case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: Results.push_back(TLI.expandPartialReduceMLA(Node, DAG)); return; case ISD::VECREDUCE_SEQ_FADD: diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp index ff7cd665446cc..e6f19499b8f41 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -1459,6 +1459,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) { case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi); break; case ISD::GET_ACTIVE_LANE_MASK: @@ -3674,6 +3675,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) { case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: Res = SplitVecOp_PARTIAL_REDUCE_MLA(N); break; } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 7aa293af963e6..0ccaf9c7d887d 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -8394,7 +8394,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, } case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SMLA: - case ISD::PARTIAL_REDUCE_SUMLA: { + case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: { [[maybe_unused]] EVT AccVT = N1.getValueType(); [[maybe_unused]] EVT Input1VT = N2.getValueType(); [[maybe_unused]] EVT Input2VT = N3.getValueType(); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index b5201a311c591..544b4abd6b45a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -8114,6 +8114,19 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, Input, DAG.getConstant(1, sdl, Input.getValueType()))); return; } + case Intrinsic::vector_partial_reduce_fadd: { + if 
(!TLI.shouldExpandPartialReductionIntrinsic(cast(&I))) { + visitTargetIntrinsic(I, Intrinsic); + return; + } + SDValue Acc = getValue(I.getOperand(0)); + SDValue Input = getValue(I.getOperand(1)); + setValue(&I, + DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, sdl, Acc.getValueType(), Acc, + Input, + DAG.getConstantFP(1.0f, sdl, Input.getValueType()))); + return; + } case Intrinsic::experimental_cttz_elts: { auto DL = getCurSDLoc(); SDValue Op = getValue(I.getOperand(0)); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp index 4b2a00c2e2cfa..cf5c269c20761 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -587,6 +587,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const { return "partial_reduce_smla"; case ISD::PARTIAL_REDUCE_SUMLA: return "partial_reduce_sumla"; + case ISD::PARTIAL_REDUCE_FMLA: + return "partial_reduce_fmla"; case ISD::LOOP_DEPENDENCE_WAR_MASK: return "loop_dep_war"; case ISD::LOOP_DEPENDENCE_RAW_MASK: diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index dba5a8c0a7315..f1b5641db4b5d 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -12062,12 +12062,14 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N, EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(), MulOpVT.getVectorElementCount()); - unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA - ? ISD::ZERO_EXTEND - : ISD::SIGN_EXTEND; - unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA - ? ISD::SIGN_EXTEND - : ISD::ZERO_EXTEND; + unsigned ExtOpcLHS = + N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FP_EXTEND + : N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA ? ISD::ZERO_EXTEND + : ISD::SIGN_EXTEND; + unsigned ExtOpcRHS = + N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FP_EXTEND + : N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ? ISD::SIGN_EXTEND + : ISD::ZERO_EXTEND; if (ExtMulOpVT != MulOpVT) { MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS); @@ -12075,8 +12077,12 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N, } SDValue Input = MulLHS; APInt ConstantOne; - if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) || - !ConstantOne.isOne()) + ConstantFPSDNode *C; + if (!(N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA && + (C = llvm::isConstOrConstSplatFP(MulRHS, false)) && + C->isExactlyValue(1.0)) && + !(ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) && + ConstantOne.isOne())) Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS); unsigned Stride = AccVT.getVectorMinNumElements(); @@ -12087,10 +12093,13 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N, for (unsigned I = 0; I < ScaleFactor; I++) Subvectors.push_back(DAG.getExtractSubvector(DL, AccVT, Input, I * Stride)); + unsigned FlatNode = + N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? 
ISD::FADD : ISD::ADD; + // Flatten the subvector tree while (Subvectors.size() > 1) { Subvectors.push_back( - DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]})); + DAG.getNode(FlatNode, DL, AccVT, {Subvectors[0], Subvectors[1]})); Subvectors.pop_front(); Subvectors.pop_front(); } diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index b2e76cc7a8a90..7c3ca6e4574f3 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -6561,6 +6561,7 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) { } break; } + case Intrinsic::vector_partial_reduce_fadd: case Intrinsic::vector_partial_reduce_add: { VectorType *AccTy = cast(Call.getArgOperand(0)->getType()); VectorType *VecTy = cast(Call.getArgOperand(1)->getType()); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index a4c1e265f0e63..2733e46cf94bb 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1919,6 +1919,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, } } + // Handle floating-point partial reduction + if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) { + static const unsigned FMLAOps[] = {ISD::PARTIAL_REDUCE_FMLA}; + setPartialReduceMLAAction(FMLAOps, MVT::nxv4f32, MVT::nxv8f16, Legal); + } + // Handle non-aliasing elements mask if (Subtarget->hasSVE2() || (Subtarget->hasSME() && Subtarget->isStreaming())) { diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td index 9a23c35766cac..2323ab49dfba3 100644 --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -9457,6 +9457,7 @@ multiclass sve_float_dot { def NAME : sve_float_dot; def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; } multiclass sve_fp8_dot @fdot_wide_vl128( %acc, %a, %b) { +; CHECK-LABEL: fdot_wide_vl128: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: fdot z0.s, z1.h, z2.h +; CHECK-NEXT: ret +entry: + %a.wide = fpext %a to + %b.wide = fpext %b to + %mult = fmul %a.wide, %b.wide + %partial.reduce = call @llvm.vector.partial.reduce.fadd( %acc, %mult) + ret %partial.reduce +} + +define @fdot_splat_vl128( %acc, %a) { +; CHECK-LABEL: fdot_splat_vl128: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: fmov z2.h, #1.00000000 +; CHECK-NEXT: fdot z0.s, z1.h, z2.h +; CHECK-NEXT: ret +entry: + %a.wide = fpext %a to + %partial.reduce = call @llvm.vector.partial.reduce.fadd( %acc, %a.wide) + ret %partial.reduce +} + +define void @fdot_wide_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) { +; CHECK-LABEL: fdot_wide_vl256: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: ld1h { z0.s }, p0/z, [x1] +; CHECK-NEXT: ld1h { z1.s }, p0/z, [x2] +; CHECK-NEXT: ld1h { z2.s }, p0/z, [x1, #1, mul vl] +; CHECK-NEXT: ld1h { z3.s }, p0/z, [x2, #1, mul vl] +; CHECK-NEXT: fcvt z0.s, p0/m, z0.h +; CHECK-NEXT: fcvt z1.s, p0/m, z1.h +; CHECK-NEXT: fcvt z2.s, p0/m, z2.h +; CHECK-NEXT: fcvt z3.s, p0/m, z3.h +; CHECK-NEXT: fmul z0.s, z0.s, z1.s +; CHECK-NEXT: ldr z1, [x0] +; CHECK-NEXT: fmul z2.s, z2.s, z3.s +; CHECK-NEXT: fadd z0.s, z1.s, z0.s +; CHECK-NEXT: fadd z0.s, z0.s, z2.s +; CHECK-NEXT: str z0, [x0] +; CHECK-NEXT: ret +entry: + %acc = load <8 x float>, ptr %accptr + %a = load <16 x half>, ptr %aptr + %b = load <16 x half>, ptr %bptr + %a.wide = fpext <16 x half> %a to <16 x float> + %b.wide = fpext <16 x half> %b to <16 x float> + %mult = fmul <16 x float> %a.wide, 
%b.wide + %partial.reduce = call <8 x float> @llvm.vector.partial.reduce.fadd(<8 x float> %acc, <16 x float> %mult) + store <8 x float> %partial.reduce, ptr %accptr + ret void +} + +define <4 x float> @fixed_fdot_wide(<4 x float> %acc, <8 x half> %a, <8 x half> %b) { +; CHECK-LABEL: fixed_fdot_wide: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: fcvtl v3.4s, v1.4h +; CHECK-NEXT: fcvtl v4.4s, v2.4h +; CHECK-NEXT: fcvtl2 v1.4s, v1.8h +; CHECK-NEXT: fcvtl2 v2.4s, v2.8h +; CHECK-NEXT: fmul v3.4s, v3.4s, v4.4s +; CHECK-NEXT: fmul v1.4s, v1.4s, v2.4s +; CHECK-NEXT: fadd v0.4s, v0.4s, v3.4s +; CHECK-NEXT: fadd v0.4s, v0.4s, v1.4s +; CHECK-NEXT: ret +entry: + %a.wide = fpext <8 x half> %a to <8 x float> + %b.wide = fpext <8 x half> %b to <8 x float> + %mult = fmul <8 x float> %a.wide, %b.wide + %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult) + ret <4 x float> %partial.reduce +} + +define <8 x half> @partial_reduce_half(<8 x half> %acc, <16 x half> %a) { +; CHECK-LABEL: partial_reduce_half: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: fadd v0.8h, v0.8h, v1.8h +; CHECK-NEXT: fadd v0.8h, v0.8h, v2.8h +; CHECK-NEXT: ret +entry: + %partial.reduce = call <8 x half> @llvm.vector.partial.reduce.fadd(<8 x half> %acc, <16 x half> %a) + ret <8 x half> %partial.reduce +} + +define <4 x float> @partial_reduce_float(<4 x float> %acc, <8 x float> %a) { +; CHECK-LABEL: partial_reduce_float: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: fadd v0.4s, v0.4s, v1.4s +; CHECK-NEXT: fadd v0.4s, v0.4s, v2.4s +; CHECK-NEXT: ret +entry: + %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %a) + ret <4 x float> %partial.reduce +} + +define <2 x double> @partial_reduce_double(<2 x double> %acc, <4 x double> %a) { +; CHECK-LABEL: partial_reduce_double: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: fadd v0.2d, v0.2d, v1.2d +; CHECK-NEXT: fadd v0.2d, v0.2d, v2.2d +; CHECK-NEXT: ret +entry: + %partial.reduce = call <2 x double> @llvm.vector.partial.reduce.fadd(<2 x double> %acc, <4 x double> %a) + ret <2 x double> %partial.reduce +} + +define @partial_reduce_half_vl128( %acc, %a) { +; CHECK-LABEL: partial_reduce_half_vl128: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: fadd z0.h, z0.h, z1.h +; CHECK-NEXT: fadd z0.h, z0.h, z2.h +; CHECK-NEXT: ret +entry: + %partial.reduce = call @llvm.vector.partial.reduce.fadd( %acc, %a) + ret %partial.reduce +} + +define @partial_reduce_float_vl128( %acc, %a) { +; CHECK-LABEL: partial_reduce_float_vl128: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: fadd z0.s, z0.s, z1.s +; CHECK-NEXT: fadd z0.s, z0.s, z2.s +; CHECK-NEXT: ret +entry: + %partial.reduce = call @llvm.vector.partial.reduce.fadd( %acc, %a) + ret %partial.reduce +} + +define @partial_reduce_double_vl128( %acc, %a) { +; CHECK-LABEL: partial_reduce_double_vl128: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: fadd z0.d, z0.d, z1.d +; CHECK-NEXT: fadd z0.d, z0.d, z2.d +; CHECK-NEXT: ret +entry: + %partial.reduce = call @llvm.vector.partial.reduce.fadd( %acc, %a) + ret %partial.reduce +}
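
For illustration (a sketch, not part of the patch): for accumulator/input type
pairs with no suitable dot-product instruction, the generic lowering added in
``TargetLowering::expandPartialReduceMLA`` splits the (optionally extended and
multiplied) input into accumulator-sized subvectors and combines them with
unordered ``fadd``s, which is what the fixed-width and scalable
``partial_reduce_*`` CHECK lines above exercise. An equivalent sketch in IR,
using the existing ``llvm.vector.extract`` intrinsic and illustrative names::

    define <4 x float> @partial_fadd_expanded(<4 x float> %acc, <8 x float> %input) {
      ; One valid expansion of
      ;   call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %input)
      %lo = call <4 x float> @llvm.vector.extract.v4f32.v8f32(<8 x float> %input, i64 0)
      %hi = call <4 x float> @llvm.vector.extract.v4f32.v8f32(<8 x float> %input, i64 4)
      %t = fadd <4 x float> %acc, %lo
      %r = fadd <4 x float> %t, %hi
      ret <4 x float> %r
    }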