Skip to content

Conversation

preames
Copy link
Collaborator

@preames preames commented Oct 9, 2023

Very straight forward, but worth lnading on it's own in advance of a more complicated generalization.

Very straight forward, but worth lnading on it's own in advance of a more complicated generalization.
@llvmbot
Copy link
Member

llvmbot commented Oct 9, 2023

@llvm/pr-subscribers-backend-risc-v

Changes

Very straight forward, but worth lnading on it's own in advance of a more complicated generalization.


Full diff: https://github.com/llvm/llvm-project/pull/68634.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+28-23)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll (+62-4)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6be3fa71479be5c..b0fc99f6eff860b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11363,16 +11363,20 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   const unsigned ReduceOpc = getVecReduceOpcode(Opc);
   assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
          "Inconsistent mappings");
-  const SDValue LHS = N->getOperand(0);
-  const SDValue RHS = N->getOperand(1);
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
 
   if (!LHS.hasOneUse() || !RHS.hasOneUse())
     return SDValue();
 
+  if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+    std::swap(LHS, RHS);
+
   if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
       !isa<ConstantSDNode>(RHS.getOperand(1)))
     return SDValue();
 
+  uint64_t RHSIdx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
   SDValue SrcVec = RHS.getOperand(0);
   EVT SrcVecVT = SrcVec.getValueType();
   assert(SrcVecVT.getVectorElementType() == VT);
@@ -11385,14 +11389,17 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   // match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
   // reduce_op (extract_subvector [2 x VT] from V).  This will form the
   // root of our reduction tree. TODO: We could extend this to any two
-  // adjacent constant indices if desired.
+  // adjacent aligned constant indices if desired.
   if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-      LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
-      isOneConstant(RHS.getOperand(1))) {
-    EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
-    SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
-                              DAG.getVectorIdxConstant(0, DL));
-    return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
+      LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) {
+    uint64_t LHSIdx =
+      cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
+    if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) {
+      EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
+      SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+                                DAG.getVectorIdxConstant(0, DL));
+      return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
+    }
   }
 
   // Match (binop (reduce (extract_subvector V, 0),
@@ -11404,20 +11411,18 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   SDValue ReduceVec = LHS.getOperand(0);
   if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
       ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
-      isNullConstant(ReduceVec.getOperand(1))) {
-    uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
-    if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
-      // For illegal types (e.g. 3xi32), most will be combined again into a
-      // wider (hopefully legal) type.  If this is a terminal state, we are
-      // relying on type legalization here to produce something reasonable
-      // and this lowering quality could probably be improved. (TODO)
-      EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
-      SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
-                                DAG.getVectorIdxConstant(0, DL));
-      auto Flags = ReduceVec->getFlags();
-      Flags.intersectWith(N->getFlags());
-      return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
-    }
+      isNullConstant(ReduceVec.getOperand(1)) &&
+      ReduceVec.getValueType().getVectorNumElements() == RHSIdx) {
+    // For illegal types (e.g. 3xi32), most will be combined again into a
+    // wider (hopefully legal) type.  If this is a terminal state, we are
+    // relying on type legalization here to produce something reasonable
+    // and this lowering quality could probably be improved. (TODO)
+    EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, RHSIdx + 1);
+    SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+                              DAG.getVectorIdxConstant(0, DL));
+    auto Flags = ReduceVec->getFlags();
+    Flags.intersectWith(N->getFlags());
+    return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
   }
 
   return SDValue();
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index 76df097a7697162..fd4a54b468f15fd 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -34,7 +34,6 @@ define i32 @reduce_sum_4xi32(<4 x i32> %v) {
   ret i32 %add2
 }
 
-
 define i32 @reduce_sum_8xi32(<8 x i32> %v) {
 ; CHECK-LABEL: reduce_sum_8xi32:
 ; CHECK:       # %bb.0:
@@ -449,6 +448,68 @@ define i32 @reduce_sum_16xi32_prefix15(ptr %p) {
   ret i32 %add13
 }
 
+; Check that we can match with the operand ordered reversed, but the
+; reduction order unchanged.
+define i32 @reduce_sum_4xi32_op_order(<4 x i32> %v) {
+; CHECK-LABEL: reduce_sum_4xi32_op_order:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v9
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+  %e0 = extractelement <4 x i32> %v, i32 0
+  %e1 = extractelement <4 x i32> %v, i32 1
+  %e2 = extractelement <4 x i32> %v, i32 2
+  %e3 = extractelement <4 x i32> %v, i32 3
+  %add0 = add i32 %e1, %e0
+  %add1 = add i32 %e2, %add0
+  %add2 = add i32 %add1, %e3
+  ret i32 %add2
+}
+
+; Negative test - Reduction order isn't compatibile with current
+; incremental matching scheme.
+define i32 @reduce_sum_4xi32_reduce_order(<4 x i32> %v) {
+; RV32-LABEL: reduce_sum_4xi32_reduce_order:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; RV32-NEXT:    vmv.x.s a0, v8
+; RV32-NEXT:    vslidedown.vi v9, v8, 1
+; RV32-NEXT:    vmv.x.s a1, v9
+; RV32-NEXT:    vslidedown.vi v9, v8, 2
+; RV32-NEXT:    vmv.x.s a2, v9
+; RV32-NEXT:    vslidedown.vi v8, v8, 3
+; RV32-NEXT:    vmv.x.s a3, v8
+; RV32-NEXT:    add a1, a1, a2
+; RV32-NEXT:    add a0, a0, a3
+; RV32-NEXT:    add a0, a0, a1
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: reduce_sum_4xi32_reduce_order:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; RV64-NEXT:    vmv.x.s a0, v8
+; RV64-NEXT:    vslidedown.vi v9, v8, 1
+; RV64-NEXT:    vmv.x.s a1, v9
+; RV64-NEXT:    vslidedown.vi v9, v8, 2
+; RV64-NEXT:    vmv.x.s a2, v9
+; RV64-NEXT:    vslidedown.vi v8, v8, 3
+; RV64-NEXT:    vmv.x.s a3, v8
+; RV64-NEXT:    add a1, a1, a2
+; RV64-NEXT:    add a0, a0, a3
+; RV64-NEXT:    addw a0, a0, a1
+; RV64-NEXT:    ret
+  %e0 = extractelement <4 x i32> %v, i32 0
+  %e1 = extractelement <4 x i32> %v, i32 1
+  %e2 = extractelement <4 x i32> %v, i32 2
+  %e3 = extractelement <4 x i32> %v, i32 3
+  %add0 = add i32 %e1, %e2
+  %add1 = add i32 %e0, %add0
+  %add2 = add i32 %add1, %e3
+  ret i32 %add2
+}
+
 ;; Most of the cornercases are exercised above, the following just
 ;; makes sure that other opcodes work as expected.
 
@@ -923,6 +984,3 @@ define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
 }
 
 
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; RV32: {{.*}}
-; RV64: {{.*}}

@github-actions
Copy link

github-actions bot commented Oct 9, 2023

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 07d2e90f28e36ac3c0a79d208ab74610f4b98546 4659721e895bd52eda40e7d500c4bfec018c57a2 -- llvm/lib/Target/RISCV/RISCVISelLowering.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b0fc99f6eff8..93579c2c1f69 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11393,7 +11393,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
       LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) {
     uint64_t LHSIdx =
-      cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
+        cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
     if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) {
       EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
       SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,

@preames
Copy link
Collaborator Author

preames commented Oct 17, 2023

ping

@lukel97 lukel97 requested review from lukel97 and removed request for luke957 October 17, 2023 20:11
Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@preames preames merged commit 25da9bb into llvm:main Oct 23, 2023
@preames preames deleted the pr-riscv-reduction-formation-swapped-operands branch October 23, 2023 17:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants