diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp index 46e54b5366d66..398b5fee990b5 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -180,6 +180,13 @@ class SelectionDAGLegalize { SmallVectorImpl &Results); SDValue PromoteLegalFP_TO_INT_SAT(SDNode *Node, const SDLoc &dl); + /// Implements vector reduce operation promotion. + /// + /// All vector operands are promoted to a vector type with larger element + /// type, and the start value is promoted to a larger scalar type. Then the + /// result is truncated back to the original scalar type. + void PromoteReduction(SDNode *Node, SmallVectorImpl &Results); + SDValue ExpandPARITY(SDValue Op, const SDLoc &dl); SDValue ExpandExtractFromVectorThroughStack(SDValue Op); @@ -2979,6 +2986,47 @@ SDValue SelectionDAGLegalize::ExpandPARITY(SDValue Op, const SDLoc &dl) { return DAG.getNode(ISD::AND, dl, VT, Result, DAG.getConstant(1, dl, VT)); } +void SelectionDAGLegalize::PromoteReduction(SDNode *Node, + SmallVectorImpl &Results) { + MVT VecVT = Node->getOperand(1).getSimpleValueType(); + MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT); + MVT ScalarVT = Node->getSimpleValueType(0); + MVT NewScalarVT = NewVecVT.getVectorElementType(); + + SDLoc DL(Node); + SmallVector Operands(Node->getNumOperands()); + + // promote the initial value. + // FIXME: Support integer. + assert(Node->getOperand(0).getValueType().isFloatingPoint() && + "Only FP promotion is supported"); + Operands[0] = + DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0)); + + for (unsigned j = 1; j != Node->getNumOperands(); ++j) + if (Node->getOperand(j).getValueType().isVector() && + !(ISD::isVPOpcode(Node->getOpcode()) && + ISD::getVPMaskIdx(Node->getOpcode()) == j)) { // Skip mask operand. + // promote the vector operand. + // FIXME: Support integer. + assert(Node->getOperand(j).getValueType().isFloatingPoint() && + "Only FP promotion is supported"); + Operands[j] = + DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j)); + } else { + Operands[j] = Node->getOperand(j); // Skip VL operand. + } + + SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands, + Node->getFlags()); + + assert(ScalarVT.isFloatingPoint() && "Only FP promotion is supported"); + Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res, + DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)); + + Results.push_back(Res); +} + bool SelectionDAGLegalize::ExpandNode(SDNode *Node) { LLVM_DEBUG(dbgs() << "Trying to expand node\n"); SmallVector Results; @@ -4955,7 +5003,12 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) { if (Node->getOpcode() == ISD::STRICT_UINT_TO_FP || Node->getOpcode() == ISD::STRICT_SINT_TO_FP || Node->getOpcode() == ISD::STRICT_FSETCC || - Node->getOpcode() == ISD::STRICT_FSETCCS) + Node->getOpcode() == ISD::STRICT_FSETCCS || + Node->getOpcode() == ISD::VP_REDUCE_FADD || + Node->getOpcode() == ISD::VP_REDUCE_FMUL || + Node->getOpcode() == ISD::VP_REDUCE_FMAX || + Node->getOpcode() == ISD::VP_REDUCE_FMIN || + Node->getOpcode() == ISD::VP_REDUCE_SEQ_FADD) OVT = Node->getOperand(1).getSimpleValueType(); if (Node->getOpcode() == ISD::BR_CC || Node->getOpcode() == ISD::SELECT_CC) @@ -5613,6 +5666,13 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) { DAG.getIntPtrConstant(0, dl, /*isTarget=*/true))); break; } + case ISD::VP_REDUCE_FADD: + case ISD::VP_REDUCE_FMUL: + case ISD::VP_REDUCE_FMAX: + case ISD::VP_REDUCE_FMIN: + case ISD::VP_REDUCE_SEQ_FADD: + PromoteReduction(Node, Results); + break; } // Replace the original node with the legalized result. diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index 8f87ee8e09393..423df9ae6b2a5 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -176,13 +176,6 @@ class VectorLegalizer { /// truncated back to the original type. void PromoteFP_TO_INT(SDNode *Node, SmallVectorImpl &Results); - /// Implements vector reduce operation promotion. - /// - /// All vector operands are promoted to a vector type with larger element - /// type, and the start value is promoted to a larger scalar type. Then the - /// result is truncated back to the original scalar type. - void PromoteReduction(SDNode *Node, SmallVectorImpl &Results); - /// Implements vector setcc operation promotion. /// /// All vector operands are promoted to a vector type with larger element @@ -510,6 +503,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { if (Action != TargetLowering::Legal) \ break; \ } \ + /* Defer non-vector results to LegalizeDAG. */ \ + if (!Node->getValueType(0).isVector()) { \ + Action = TargetLowering::Legal; \ + break; \ + } \ Action = TLI.getOperationAction(Node->getOpcode(), LegalizeVT); \ } break; #include "llvm/IR/VPIntrinsics.def" @@ -580,50 +578,6 @@ bool VectorLegalizer::LowerOperationWrapper(SDNode *Node, return true; } -void VectorLegalizer::PromoteReduction(SDNode *Node, - SmallVectorImpl &Results) { - MVT VecVT = Node->getOperand(1).getSimpleValueType(); - MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT); - MVT ScalarVT = Node->getSimpleValueType(0); - MVT NewScalarVT = NewVecVT.getVectorElementType(); - - SDLoc DL(Node); - SmallVector Operands(Node->getNumOperands()); - - // promote the initial value. - if (Node->getOperand(0).getValueType().isFloatingPoint()) - Operands[0] = - DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0)); - else - Operands[0] = - DAG.getNode(ISD::ANY_EXTEND, DL, NewScalarVT, Node->getOperand(0)); - - for (unsigned j = 1; j != Node->getNumOperands(); ++j) - if (Node->getOperand(j).getValueType().isVector() && - !(ISD::isVPOpcode(Node->getOpcode()) && - ISD::getVPMaskIdx(Node->getOpcode()) == j)) // Skip mask operand. - // promote the vector operand. - if (Node->getOperand(j).getValueType().isFloatingPoint()) - Operands[j] = - DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j)); - else - Operands[j] = - DAG.getNode(ISD::ANY_EXTEND, DL, NewVecVT, Node->getOperand(j)); - else - Operands[j] = Node->getOperand(j); // Skip VL operand. - - SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands, - Node->getFlags()); - - if (ScalarVT.isFloatingPoint()) - Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res, - DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)); - else - Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res); - - Results.push_back(Res); -} - void VectorLegalizer::PromoteSETCC(SDNode *Node, SmallVectorImpl &Results) { MVT VecVT = Node->getOperand(0).getSimpleValueType(); @@ -708,23 +662,6 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl &Results) { // Promote the operation by extending the operand. PromoteFP_TO_INT(Node, Results); return; - case ISD::VP_REDUCE_ADD: - case ISD::VP_REDUCE_MUL: - case ISD::VP_REDUCE_AND: - case ISD::VP_REDUCE_OR: - case ISD::VP_REDUCE_XOR: - case ISD::VP_REDUCE_SMAX: - case ISD::VP_REDUCE_SMIN: - case ISD::VP_REDUCE_UMAX: - case ISD::VP_REDUCE_UMIN: - case ISD::VP_REDUCE_FADD: - case ISD::VP_REDUCE_FMUL: - case ISD::VP_REDUCE_FMAX: - case ISD::VP_REDUCE_FMIN: - case ISD::VP_REDUCE_SEQ_FADD: - // Promote the operation by extending the operand. - PromoteReduction(Node, Results); - return; case ISD::VP_SETCC: case ISD::SETCC: // Promote the operation by extending the operand. diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll index 02a989a969960..b874a4477f5d1 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll @@ -802,25 +802,27 @@ define signext i32 @vpreduce_xor_v64i32(i32 signext %s, <64 x i32> %v, <64 x i1> ; CHECK-LABEL: vpreduce_xor_v64i32: ; CHECK: # %bb.0: ; CHECK-NEXT: vsetivli zero, 4, e8, mf2, ta, ma -; CHECK-NEXT: li a3, 32 ; CHECK-NEXT: vslidedown.vi v24, v0, 4 -; CHECK-NEXT: mv a2, a1 -; CHECK-NEXT: bltu a1, a3, .LBB49_2 +; CHECK-NEXT: addi a2, a1, -32 +; CHECK-NEXT: sltu a3, a1, a2 +; CHECK-NEXT: addi a3, a3, -1 +; CHECK-NEXT: li a4, 32 +; CHECK-NEXT: and a2, a3, a2 +; CHECK-NEXT: bltu a1, a4, .LBB49_2 ; CHECK-NEXT: # %bb.1: -; CHECK-NEXT: li a2, 32 +; CHECK-NEXT: li a1, 32 ; CHECK-NEXT: .LBB49_2: ; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma ; CHECK-NEXT: vmv.s.x v25, a0 -; CHECK-NEXT: vsetvli zero, a2, e32, m8, ta, ma +; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma ; CHECK-NEXT: vredxor.vs v25, v8, v25, v0.t -; CHECK-NEXT: addi a0, a1, -32 -; CHECK-NEXT: sltu a1, a1, a0 -; CHECK-NEXT: addi a1, a1, -1 -; CHECK-NEXT: and a0, a1, a0 -; CHECK-NEXT: vsetvli zero, a0, e32, m8, ta, ma -; CHECK-NEXT: vmv1r.v v0, v24 -; CHECK-NEXT: vredxor.vs v25, v16, v25, v0.t ; CHECK-NEXT: vmv.x.s a0, v25 +; CHECK-NEXT: vsetivli zero, 1, e32, m8, ta, ma +; CHECK-NEXT: vmv.s.x v8, a0 +; CHECK-NEXT: vsetvli zero, a2, e32, m8, ta, ma +; CHECK-NEXT: vmv1r.v v0, v24 +; CHECK-NEXT: vredxor.vs v8, v16, v8, v0.t +; CHECK-NEXT: vmv.x.s a0, v8 ; CHECK-NEXT: ret %r = call i32 @llvm.vp.reduce.xor.v64i32(i32 %s, <64 x i32> %v, <64 x i1> %m, i32 %evl) ret i32 %r diff --git a/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll index 7bcf37b1af3c8..95b64cb662a61 100644 --- a/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll @@ -1115,10 +1115,13 @@ define signext i32 @vpreduce_umax_nxv32i32(i32 signext %s, % ; CHECK-NEXT: vmv.s.x v25, a0 ; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma ; CHECK-NEXT: vredmaxu.vs v25, v8, v25, v0.t +; CHECK-NEXT: vmv.x.s a0, v25 +; CHECK-NEXT: vsetivli zero, 1, e32, m8, ta, ma +; CHECK-NEXT: vmv.s.x v8, a0 ; CHECK-NEXT: vsetvli zero, a2, e32, m8, ta, ma ; CHECK-NEXT: vmv1r.v v0, v24 -; CHECK-NEXT: vredmaxu.vs v25, v16, v25, v0.t -; CHECK-NEXT: vmv.x.s a0, v25 +; CHECK-NEXT: vredmaxu.vs v8, v16, v8, v0.t +; CHECK-NEXT: vmv.x.s a0, v8 ; CHECK-NEXT: ret %r = call i32 @llvm.vp.reduce.umax.nxv32i32(i32 %s, %v, %m, i32 %evl) ret i32 %r