diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index b2147d37031ab..1f0b39cd85e1e 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -191,6 +191,8 @@ class VectorLegalizer { /// rounding of the result does not affect its value. void PromoteFloatVECREDUCE(SDNode *Node, SmallVectorImpl &Results, bool NonArithmetic); + void PromoteFloatVECREDUCE_SEQ(SDNode *Node, + SmallVectorImpl &Results); void PromoteVECTOR_COMPRESS(SDNode *Node, SmallVectorImpl &Results); @@ -722,6 +724,23 @@ void VectorLegalizer::PromoteFloatVECREDUCE(SDNode *Node, Results.push_back(Res); } +void VectorLegalizer::PromoteFloatVECREDUCE_SEQ( + SDNode *Node, SmallVectorImpl &Results) { + MVT OrigVecVT = Node->getOperand(1).getSimpleValueType(); + assert(OrigVecVT.isFloatingPoint() && "Expected floating point reduction!"); + MVT VecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OrigVecVT); + MVT EltVT = VecVT.getVectorElementType(); + + SDLoc DL(Node); + SDValue EltOp = DAG.getNode(ISD::FP_EXTEND, DL, EltVT, Node->getOperand(0)); + SDValue VecOp = DAG.getNode(ISD::FP_EXTEND, DL, VecVT, Node->getOperand(1)); + SDValue Rdx = + DAG.getNode(Node->getOpcode(), DL, EltVT, EltOp, VecOp, Node->getFlags()); + SDValue Res = DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx, + DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)); + Results.push_back(Res); +} + void VectorLegalizer::PromoteVECTOR_COMPRESS( SDNode *Node, SmallVectorImpl &Results) { SDLoc DL(Node); @@ -790,6 +809,9 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl &Results) { case ISD::VECREDUCE_FMINIMUM: PromoteFloatVECREDUCE(Node, Results, /*NonArithmetic=*/true); return; + case ISD::VECREDUCE_SEQ_FADD: + PromoteFloatVECREDUCE_SEQ(Node, Results); + return; case ISD::VECTOR_COMPRESS: PromoteVECTOR_COMPRESS(Node, Results); return; diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 48fe1eed9093f..a365c346d7cfc 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -12659,7 +12659,7 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const { } SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) const { - SDLoc dl(Node); + SDLoc DL(Node); SDValue AccOp = Node->getOperand(0); SDValue VecOp = Node->getOperand(1); SDNodeFlags Flags = Node->getFlags(); @@ -12667,6 +12667,18 @@ SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) cons EVT VT = VecOp.getValueType(); EVT EltVT = VT.getVectorElementType(); + if (VT.getVectorElementCount().isKnownMultipleOf(2)) { + auto [LoVecVT, HiVecVT] = DAG.GetSplitDestVTs(VT); + if (isOperationLegalOrCustomOrPromote(Node->getOpcode(), LoVecVT) && + isOperationLegalOrCustomOrPromote(Node->getOpcode(), HiVecVT)) { + auto [LoVecOp, HiVecOp] = DAG.SplitVector(VecOp, DL, LoVecVT, HiVecVT); + + unsigned Opcode = Node->getOpcode(); + SDValue ReduceLo = DAG.getNode(Opcode, DL, EltVT, AccOp, LoVecOp, Flags); + return DAG.getNode(Opcode, DL, EltVT, ReduceLo, HiVecOp, Flags); + } + } + if (VT.isScalableVector()) report_fatal_error( "Expanding reductions for scalable vectors is undefined."); @@ -12680,7 +12692,7 @@ SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) cons SDValue Res = AccOp; for (unsigned i = 0; i < NumElts; i++) - Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Flags); + Res = DAG.getNode(BaseOpcode, DL, EltVT, Res, Ops[i], Flags); return Res; } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 08f3d4e0d30ac..22cd8d0db512d 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -2134,6 +2134,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, MVT::v2f32, MVT::v4f32, MVT::v2f64}) setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom); + setOperationPromotedToType(ISD::VECREDUCE_SEQ_FADD, MVT::nxv2bf16, + MVT::nxv2f32); + setOperationPromotedToType(ISD::VECREDUCE_SEQ_FADD, MVT::nxv4bf16, + MVT::nxv4f32); + setOperationPromotedToType(ISD::VECREDUCE_SEQ_FADD, MVT::nxv8bf16, + MVT::nxv8f32); + // We can lower types that have elements to compact. for (auto VT : {MVT::nxv4i32, MVT::nxv2i64, MVT::nxv2f32, MVT::nxv4f32, MVT::nxv2f64}) diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions-ordered.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions-ordered.ll new file mode 100644 index 0000000000000..e6bca48283760 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions-ordered.ll @@ -0,0 +1,49 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s + +target triple = "aarch64-unknown-linux-gnu" + +define bfloat @fadd_ordered_nxv2bf16(bfloat %a, %b) { +; CHECK-LABEL: fadd_ordered_nxv2bf16: +; CHECK: // %bb.0: +; CHECK-NEXT: // kill: def $h0 killed $h0 def $d0 +; CHECK-NEXT: lsl z1.s, z1.s, #16 +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: shll v0.4s, v0.4h, #16 +; CHECK-NEXT: fadda s0, p0, s0, z1.s +; CHECK-NEXT: bfcvt h0, s0 +; CHECK-NEXT: ret + %res = call bfloat @llvm.vector.reduce.fadd(bfloat %a, %b) + ret bfloat %res +} + +define bfloat @fadd_ordered_nxv4bf16(bfloat %a, %b) { +; CHECK-LABEL: fadd_ordered_nxv4bf16: +; CHECK: // %bb.0: +; CHECK-NEXT: // kill: def $h0 killed $h0 def $d0 +; CHECK-NEXT: lsl z1.s, z1.s, #16 +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: shll v0.4s, v0.4h, #16 +; CHECK-NEXT: fadda s0, p0, s0, z1.s +; CHECK-NEXT: bfcvt h0, s0 +; CHECK-NEXT: ret + %res = call bfloat @llvm.vector.reduce.fadd(bfloat %a, %b) + ret bfloat %res +} + +define bfloat @fadd_ordered_nxv8bf16(bfloat %a, %b) { +; CHECK-LABEL: fadd_ordered_nxv8bf16: +; CHECK: // %bb.0: +; CHECK-NEXT: movi v2.2d, #0000000000000000 +; CHECK-NEXT: // kill: def $h0 killed $h0 def $d0 +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: shll v0.4s, v0.4h, #16 +; CHECK-NEXT: zip1 z3.h, z2.h, z1.h +; CHECK-NEXT: zip2 z1.h, z2.h, z1.h +; CHECK-NEXT: fadda s0, p0, s0, z3.s +; CHECK-NEXT: fadda s0, p0, s0, z1.s +; CHECK-NEXT: bfcvt h0, s0 +; CHECK-NEXT: ret + %res = call bfloat @llvm.vector.reduce.fadd(bfloat %a, %b) + ret bfloat %res +}