-
Notifications
You must be signed in to change notification settings - Fork 11k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SelectionDAG][RISCV] Move VP_REDUCE* legalization to LegalizeDAG.cpp. #90522
Conversation
LegalizeVectorType is responsible for legalizing nodes that perform an operation on each element may need to scalarize. This is not true for nodes like VP_REDUCE.*, BUILD_VECTOR, SHUFFLE_VECTOR, EXTRACT_SUBVECTOR, etc. This patch drops any nodes with a scalar result from LegalizeVectorOps and handles them in LegalizeDAG instead. This required moving the reduction promotion to LegalizeDAG. I have removed the support integer promotion as it was incorrect for integer min/max reductions. Since it was untested, it was best to assert on it until it was really needed. There are a couple regressions that can be fixed with a small DAG combine which I will do as a follow up.
@llvm/pr-subscribers-llvm-selectiondag Author: Craig Topper (topperc) ChangesLegalizeVectorType is responsible for legalizing nodes that perform an operation on each element may need to scalarize. This is not true for nodes like VP_REDUCE.*, BUILD_VECTOR, SHUFFLE_VECTOR, EXTRACT_SUBVECTOR, etc. This patch drops any nodes with a scalar result from LegalizeVectorOps and handles them in LegalizeDAG instead. This required moving the reduction promotion to LegalizeDAG. I have removed the support integer promotion as it was incorrect for integer min/max reductions. Since it was untested, it was best to assert on it until it was really needed. There are a couple regressions that can be fixed with a small DAG combine which I will do as a follow up. Full diff: https://github.com/llvm/llvm-project/pull/90522.diff 4 Files Affected:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 46e54b5366d66a..9dd40531abb005 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -180,6 +180,13 @@ class SelectionDAGLegalize {
SmallVectorImpl<SDValue> &Results);
SDValue PromoteLegalFP_TO_INT_SAT(SDNode *Node, const SDLoc &dl);
+ /// Implements vector reduce operation promotion.
+ ///
+ /// All vector operands are promoted to a vector type with larger element
+ /// type, and the start value is promoted to a larger scalar type. Then the
+ /// result is truncated back to the original scalar type.
+ void PromoteReduction(SDNode *Node, SmallVectorImpl<SDValue> &Results);
+
SDValue ExpandPARITY(SDValue Op, const SDLoc &dl);
SDValue ExpandExtractFromVectorThroughStack(SDValue Op);
@@ -2979,6 +2986,49 @@ SDValue SelectionDAGLegalize::ExpandPARITY(SDValue Op, const SDLoc &dl) {
return DAG.getNode(ISD::AND, dl, VT, Result, DAG.getConstant(1, dl, VT));
}
+void SelectionDAGLegalize::PromoteReduction(SDNode *Node,
+ SmallVectorImpl<SDValue> &Results) {
+ MVT VecVT = Node->getOperand(1).getSimpleValueType();
+ MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
+ MVT ScalarVT = Node->getSimpleValueType(0);
+ MVT NewScalarVT = NewVecVT.getVectorElementType();
+
+ SDLoc DL(Node);
+ SmallVector<SDValue, 4> Operands(Node->getNumOperands());
+
+ // promote the initial value.
+ // FIXME: Support integer.
+ assert(Node->getOperand(0).getValueType().isFloatingPoint() &&
+ "Only FP promotion is supported");
+ Operands[0] =
+ DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0));
+
+ for (unsigned j = 1; j != Node->getNumOperands(); ++j)
+ if (Node->getOperand(j).getValueType().isVector() &&
+ !(ISD::isVPOpcode(Node->getOpcode()) &&
+ ISD::getVPMaskIdx(Node->getOpcode()) == j)) { // Skip mask operand.
+ // promote the vector operand.
+ // FIXME: Support integer.
+ assert(Node->getOperand(j).getValueType().isFloatingPoint() &&
+ "Only FP promotion is supported");
+ Operands[j] =
+ DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j));
+ } else {
+ Operands[j] = Node->getOperand(j); // Skip VL operand.
+ }
+
+ SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
+ Node->getFlags());
+
+ if (ScalarVT.isFloatingPoint())
+ Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
+ DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+ else
+ Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
+
+ Results.push_back(Res);
+}
+
bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
LLVM_DEBUG(dbgs() << "Trying to expand node\n");
SmallVector<SDValue, 8> Results;
@@ -4955,7 +5005,12 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
if (Node->getOpcode() == ISD::STRICT_UINT_TO_FP ||
Node->getOpcode() == ISD::STRICT_SINT_TO_FP ||
Node->getOpcode() == ISD::STRICT_FSETCC ||
- Node->getOpcode() == ISD::STRICT_FSETCCS)
+ Node->getOpcode() == ISD::STRICT_FSETCCS ||
+ Node->getOpcode() == ISD::VP_REDUCE_FADD ||
+ Node->getOpcode() == ISD::VP_REDUCE_FMUL ||
+ Node->getOpcode() == ISD::VP_REDUCE_FMAX ||
+ Node->getOpcode() == ISD::VP_REDUCE_FMIN ||
+ Node->getOpcode() == ISD::VP_REDUCE_SEQ_FADD)
OVT = Node->getOperand(1).getSimpleValueType();
if (Node->getOpcode() == ISD::BR_CC ||
Node->getOpcode() == ISD::SELECT_CC)
@@ -5613,6 +5668,13 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)));
break;
}
+ case ISD::VP_REDUCE_FADD:
+ case ISD::VP_REDUCE_FMUL:
+ case ISD::VP_REDUCE_FMAX:
+ case ISD::VP_REDUCE_FMIN:
+ case ISD::VP_REDUCE_SEQ_FADD:
+ PromoteReduction(Node, Results);
+ break;
}
// Replace the original node with the legalized result.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8f87ee8e09393a..423df9ae6b2a55 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -176,13 +176,6 @@ class VectorLegalizer {
/// truncated back to the original type.
void PromoteFP_TO_INT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
- /// Implements vector reduce operation promotion.
- ///
- /// All vector operands are promoted to a vector type with larger element
- /// type, and the start value is promoted to a larger scalar type. Then the
- /// result is truncated back to the original scalar type.
- void PromoteReduction(SDNode *Node, SmallVectorImpl<SDValue> &Results);
-
/// Implements vector setcc operation promotion.
///
/// All vector operands are promoted to a vector type with larger element
@@ -510,6 +503,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
if (Action != TargetLowering::Legal) \
break; \
} \
+ /* Defer non-vector results to LegalizeDAG. */ \
+ if (!Node->getValueType(0).isVector()) { \
+ Action = TargetLowering::Legal; \
+ break; \
+ } \
Action = TLI.getOperationAction(Node->getOpcode(), LegalizeVT); \
} break;
#include "llvm/IR/VPIntrinsics.def"
@@ -580,50 +578,6 @@ bool VectorLegalizer::LowerOperationWrapper(SDNode *Node,
return true;
}
-void VectorLegalizer::PromoteReduction(SDNode *Node,
- SmallVectorImpl<SDValue> &Results) {
- MVT VecVT = Node->getOperand(1).getSimpleValueType();
- MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
- MVT ScalarVT = Node->getSimpleValueType(0);
- MVT NewScalarVT = NewVecVT.getVectorElementType();
-
- SDLoc DL(Node);
- SmallVector<SDValue, 4> Operands(Node->getNumOperands());
-
- // promote the initial value.
- if (Node->getOperand(0).getValueType().isFloatingPoint())
- Operands[0] =
- DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0));
- else
- Operands[0] =
- DAG.getNode(ISD::ANY_EXTEND, DL, NewScalarVT, Node->getOperand(0));
-
- for (unsigned j = 1; j != Node->getNumOperands(); ++j)
- if (Node->getOperand(j).getValueType().isVector() &&
- !(ISD::isVPOpcode(Node->getOpcode()) &&
- ISD::getVPMaskIdx(Node->getOpcode()) == j)) // Skip mask operand.
- // promote the vector operand.
- if (Node->getOperand(j).getValueType().isFloatingPoint())
- Operands[j] =
- DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j));
- else
- Operands[j] =
- DAG.getNode(ISD::ANY_EXTEND, DL, NewVecVT, Node->getOperand(j));
- else
- Operands[j] = Node->getOperand(j); // Skip VL operand.
-
- SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
- Node->getFlags());
-
- if (ScalarVT.isFloatingPoint())
- Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
- DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
- else
- Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
-
- Results.push_back(Res);
-}
-
void VectorLegalizer::PromoteSETCC(SDNode *Node,
SmallVectorImpl<SDValue> &Results) {
MVT VecVT = Node->getOperand(0).getSimpleValueType();
@@ -708,23 +662,6 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
// Promote the operation by extending the operand.
PromoteFP_TO_INT(Node, Results);
return;
- case ISD::VP_REDUCE_ADD:
- case ISD::VP_REDUCE_MUL:
- case ISD::VP_REDUCE_AND:
- case ISD::VP_REDUCE_OR:
- case ISD::VP_REDUCE_XOR:
- case ISD::VP_REDUCE_SMAX:
- case ISD::VP_REDUCE_SMIN:
- case ISD::VP_REDUCE_UMAX:
- case ISD::VP_REDUCE_UMIN:
- case ISD::VP_REDUCE_FADD:
- case ISD::VP_REDUCE_FMUL:
- case ISD::VP_REDUCE_FMAX:
- case ISD::VP_REDUCE_FMIN:
- case ISD::VP_REDUCE_SEQ_FADD:
- // Promote the operation by extending the operand.
- PromoteReduction(Node, Results);
- return;
case ISD::VP_SETCC:
case ISD::SETCC:
// Promote the operation by extending the operand.
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
index 02a989a9699606..b874a4477f5d17 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
@@ -802,25 +802,27 @@ define signext i32 @vpreduce_xor_v64i32(i32 signext %s, <64 x i32> %v, <64 x i1>
; CHECK-LABEL: vpreduce_xor_v64i32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e8, mf2, ta, ma
-; CHECK-NEXT: li a3, 32
; CHECK-NEXT: vslidedown.vi v24, v0, 4
-; CHECK-NEXT: mv a2, a1
-; CHECK-NEXT: bltu a1, a3, .LBB49_2
+; CHECK-NEXT: addi a2, a1, -32
+; CHECK-NEXT: sltu a3, a1, a2
+; CHECK-NEXT: addi a3, a3, -1
+; CHECK-NEXT: li a4, 32
+; CHECK-NEXT: and a2, a3, a2
+; CHECK-NEXT: bltu a1, a4, .LBB49_2
; CHECK-NEXT: # %bb.1:
-; CHECK-NEXT: li a2, 32
+; CHECK-NEXT: li a1, 32
; CHECK-NEXT: .LBB49_2:
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
; CHECK-NEXT: vmv.s.x v25, a0
-; CHECK-NEXT: vsetvli zero, a2, e32, m8, ta, ma
+; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; CHECK-NEXT: vredxor.vs v25, v8, v25, v0.t
-; CHECK-NEXT: addi a0, a1, -32
-; CHECK-NEXT: sltu a1, a1, a0
-; CHECK-NEXT: addi a1, a1, -1
-; CHECK-NEXT: and a0, a1, a0
-; CHECK-NEXT: vsetvli zero, a0, e32, m8, ta, ma
-; CHECK-NEXT: vmv1r.v v0, v24
-; CHECK-NEXT: vredxor.vs v25, v16, v25, v0.t
; CHECK-NEXT: vmv.x.s a0, v25
+; CHECK-NEXT: vsetivli zero, 1, e32, m8, ta, ma
+; CHECK-NEXT: vmv.s.x v8, a0
+; CHECK-NEXT: vsetvli zero, a2, e32, m8, ta, ma
+; CHECK-NEXT: vmv1r.v v0, v24
+; CHECK-NEXT: vredxor.vs v8, v16, v8, v0.t
+; CHECK-NEXT: vmv.x.s a0, v8
; CHECK-NEXT: ret
%r = call i32 @llvm.vp.reduce.xor.v64i32(i32 %s, <64 x i32> %v, <64 x i1> %m, i32 %evl)
ret i32 %r
diff --git a/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
index 7bcf37b1af3c8f..95b64cb662a614 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
@@ -1115,10 +1115,13 @@ define signext i32 @vpreduce_umax_nxv32i32(i32 signext %s, <vscale x 32 x i32> %
; CHECK-NEXT: vmv.s.x v25, a0
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
; CHECK-NEXT: vredmaxu.vs v25, v8, v25, v0.t
+; CHECK-NEXT: vmv.x.s a0, v25
+; CHECK-NEXT: vsetivli zero, 1, e32, m8, ta, ma
+; CHECK-NEXT: vmv.s.x v8, a0
; CHECK-NEXT: vsetvli zero, a2, e32, m8, ta, ma
; CHECK-NEXT: vmv1r.v v0, v24
-; CHECK-NEXT: vredmaxu.vs v25, v16, v25, v0.t
-; CHECK-NEXT: vmv.x.s a0, v25
+; CHECK-NEXT: vredmaxu.vs v8, v16, v8, v0.t
+; CHECK-NEXT: vmv.x.s a0, v8
; CHECK-NEXT: ret
%r = call i32 @llvm.vp.reduce.umax.nxv32i32(i32 %s, <vscale x 32 x i32> %v, <vscale x 32 x i1> %m, i32 %evl)
ret i32 %r
|
The regressions will be fixed by #90524 |
Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res, | ||
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)); | ||
else | ||
Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a dead code now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
I notice we might be missing support for non-VP reductions for f16 vectors with Zfhmin. |
LegalizeVectorType is responsible for legalizing nodes that perform an operation on each element may need to scalarize.
This is not true for nodes like VP_REDUCE.*, BUILD_VECTOR, SHUFFLE_VECTOR, EXTRACT_SUBVECTOR, etc.
This patch drops any nodes with a scalar result from LegalizeVectorOps and handles them in LegalizeDAG instead.
This required moving the reduction promotion to LegalizeDAG. I have removed the support integer promotion as it was incorrect for integer min/max reductions. Since it was untested, it was best to assert on it until it was really needed.
There are a couple regressions that can be fixed with a small DAG combine which I will do as a follow up.