-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[RISCV] Allow swapped operands in reduction formation #68634
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV] Allow swapped operands in reduction formation #68634
Conversation
Very straight forward, but worth lnading on it's own in advance of a more complicated generalization.
@llvm/pr-subscribers-backend-risc-v ChangesVery straight forward, but worth lnading on it's own in advance of a more complicated generalization. Full diff: https://github.com/llvm/llvm-project/pull/68634.diff 2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6be3fa71479be5c..b0fc99f6eff860b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11363,16 +11363,20 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
const unsigned ReduceOpc = getVecReduceOpcode(Opc);
assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
"Inconsistent mappings");
- const SDValue LHS = N->getOperand(0);
- const SDValue RHS = N->getOperand(1);
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
if (!LHS.hasOneUse() || !RHS.hasOneUse())
return SDValue();
+ if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+ std::swap(LHS, RHS);
+
if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
!isa<ConstantSDNode>(RHS.getOperand(1)))
return SDValue();
+ uint64_t RHSIdx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
SDValue SrcVec = RHS.getOperand(0);
EVT SrcVecVT = SrcVec.getValueType();
assert(SrcVecVT.getVectorElementType() == VT);
@@ -11385,14 +11389,17 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
// match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
// reduce_op (extract_subvector [2 x VT] from V). This will form the
// root of our reduction tree. TODO: We could extend this to any two
- // adjacent constant indices if desired.
+ // adjacent aligned constant indices if desired.
if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
- LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
- isOneConstant(RHS.getOperand(1))) {
- EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
- SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
- DAG.getVectorIdxConstant(0, DL));
- return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
+ LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) {
+ uint64_t LHSIdx =
+ cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
+ if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) {
+ EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
+ SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+ DAG.getVectorIdxConstant(0, DL));
+ return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
+ }
}
// Match (binop (reduce (extract_subvector V, 0),
@@ -11404,20 +11411,18 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
SDValue ReduceVec = LHS.getOperand(0);
if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
- isNullConstant(ReduceVec.getOperand(1))) {
- uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
- if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
- // For illegal types (e.g. 3xi32), most will be combined again into a
- // wider (hopefully legal) type. If this is a terminal state, we are
- // relying on type legalization here to produce something reasonable
- // and this lowering quality could probably be improved. (TODO)
- EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
- SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
- DAG.getVectorIdxConstant(0, DL));
- auto Flags = ReduceVec->getFlags();
- Flags.intersectWith(N->getFlags());
- return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
- }
+ isNullConstant(ReduceVec.getOperand(1)) &&
+ ReduceVec.getValueType().getVectorNumElements() == RHSIdx) {
+ // For illegal types (e.g. 3xi32), most will be combined again into a
+ // wider (hopefully legal) type. If this is a terminal state, we are
+ // relying on type legalization here to produce something reasonable
+ // and this lowering quality could probably be improved. (TODO)
+ EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, RHSIdx + 1);
+ SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+ DAG.getVectorIdxConstant(0, DL));
+ auto Flags = ReduceVec->getFlags();
+ Flags.intersectWith(N->getFlags());
+ return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
}
return SDValue();
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index 76df097a7697162..fd4a54b468f15fd 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -34,7 +34,6 @@ define i32 @reduce_sum_4xi32(<4 x i32> %v) {
ret i32 %add2
}
-
define i32 @reduce_sum_8xi32(<8 x i32> %v) {
; CHECK-LABEL: reduce_sum_8xi32:
; CHECK: # %bb.0:
@@ -449,6 +448,68 @@ define i32 @reduce_sum_16xi32_prefix15(ptr %p) {
ret i32 %add13
}
+; Check that we can match with the operand ordered reversed, but the
+; reduction order unchanged.
+define i32 @reduce_sum_4xi32_op_order(<4 x i32> %v) {
+; CHECK-LABEL: reduce_sum_4xi32_op_order:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT: vmv.s.x v9, zero
+; CHECK-NEXT: vredsum.vs v8, v8, v9
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+ %e0 = extractelement <4 x i32> %v, i32 0
+ %e1 = extractelement <4 x i32> %v, i32 1
+ %e2 = extractelement <4 x i32> %v, i32 2
+ %e3 = extractelement <4 x i32> %v, i32 3
+ %add0 = add i32 %e1, %e0
+ %add1 = add i32 %e2, %add0
+ %add2 = add i32 %add1, %e3
+ ret i32 %add2
+}
+
+; Negative test - Reduction order isn't compatibile with current
+; incremental matching scheme.
+define i32 @reduce_sum_4xi32_reduce_order(<4 x i32> %v) {
+; RV32-LABEL: reduce_sum_4xi32_reduce_order:
+; RV32: # %bb.0:
+; RV32-NEXT: vsetivli zero, 1, e32, m1, ta, ma
+; RV32-NEXT: vmv.x.s a0, v8
+; RV32-NEXT: vslidedown.vi v9, v8, 1
+; RV32-NEXT: vmv.x.s a1, v9
+; RV32-NEXT: vslidedown.vi v9, v8, 2
+; RV32-NEXT: vmv.x.s a2, v9
+; RV32-NEXT: vslidedown.vi v8, v8, 3
+; RV32-NEXT: vmv.x.s a3, v8
+; RV32-NEXT: add a1, a1, a2
+; RV32-NEXT: add a0, a0, a3
+; RV32-NEXT: add a0, a0, a1
+; RV32-NEXT: ret
+;
+; RV64-LABEL: reduce_sum_4xi32_reduce_order:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetivli zero, 1, e32, m1, ta, ma
+; RV64-NEXT: vmv.x.s a0, v8
+; RV64-NEXT: vslidedown.vi v9, v8, 1
+; RV64-NEXT: vmv.x.s a1, v9
+; RV64-NEXT: vslidedown.vi v9, v8, 2
+; RV64-NEXT: vmv.x.s a2, v9
+; RV64-NEXT: vslidedown.vi v8, v8, 3
+; RV64-NEXT: vmv.x.s a3, v8
+; RV64-NEXT: add a1, a1, a2
+; RV64-NEXT: add a0, a0, a3
+; RV64-NEXT: addw a0, a0, a1
+; RV64-NEXT: ret
+ %e0 = extractelement <4 x i32> %v, i32 0
+ %e1 = extractelement <4 x i32> %v, i32 1
+ %e2 = extractelement <4 x i32> %v, i32 2
+ %e3 = extractelement <4 x i32> %v, i32 3
+ %add0 = add i32 %e1, %e2
+ %add1 = add i32 %e0, %add0
+ %add2 = add i32 %add1, %e3
+ ret i32 %add2
+}
+
;; Most of the cornercases are exercised above, the following just
;; makes sure that other opcodes work as expected.
@@ -923,6 +984,3 @@ define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
}
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; RV32: {{.*}}
-; RV64: {{.*}}
|
You can test this locally with the following command:git-clang-format --diff 07d2e90f28e36ac3c0a79d208ab74610f4b98546 4659721e895bd52eda40e7d500c4bfec018c57a2 -- llvm/lib/Target/RISCV/RISCVISelLowering.cpp View the diff from clang-format here.diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b0fc99f6eff8..93579c2c1f69 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11393,7 +11393,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) {
uint64_t LHSIdx =
- cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
+ cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) {
EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
|
ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Very straight forward, but worth lnading on it's own in advance of a more complicated generalization.