
[RISCV] Improve performCONCAT_VECTORSCombine stride matching #68726

Closed

Conversation

michaelmaitland
Contributor

I intend to commit the two commits in this patch independently:

Pre-commit concat-vectors-constant-stride.ll tests

This commit adds tests that can be optimized by improving
performCONCAT_VECTORSCombine to do a better job of decomposing the base
pointer and recognizing a constant offset.

Improve performCONCAT_VECTORSCombine stride matching

If the load pointers can be decomposed into a common (Base + Index) with a
common constant stride, then return the constant stride. This matcher
enables some additional optimization because BaseIndexOffset is capable of
decomposing the load pointers to (add (add Base, Index), Stride), instead of
the (add LastPtr, Stride) or (add NextPtr, Stride) forms that
matchForwardStrided and matchReverseStrided match, respectively.
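
For illustration only (not part of the patch): the pairwise check that BaseIndexOffset enables can be sketched as below, where constantStrideBetween is a name invented for this example.

```cpp
// Sketch, not patch code: BaseIndexOffset::match decomposes a load's pointer
// into (Base, Index, constant Offset). If two loads agree on (Base + Index)
// and both offsets are known, the stride is just the offset difference.
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include <optional>

using namespace llvm;

static std::optional<int64_t>
constantStrideBetween(LoadSDNode *Ld0, LoadSDNode *Ld1, SelectionDAG &DAG) {
  BaseIndexOffset B0 = BaseIndexOffset::match(Ld0, DAG);
  BaseIndexOffset B1 = BaseIndexOffset::match(Ld1, DAG);
  if (!B0.hasValidOffset() || !B1.hasValidOffset())
    return std::nullopt; // Offsets must be known constants.
  if (!B0.equalBaseIndex(B1, DAG))
    return std::nullopt; // Pointers must share a common (Base + Index).
  return B1.getOffset() - B0.getOffset();
}
```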

@llvmbot
Collaborator

llvmbot commented Oct 10, 2023

@llvm/pr-subscribers-backend-risc-v

Author: Michael Maitland (michaelmaitland)


Full diff: https://github.com/llvm/llvm-project/pull/68726.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+67-8)
  • (added) llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll (+231)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6be3fa71479be5c..955674bd13e1c79 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -27,6 +27,7 @@
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -13821,6 +13822,58 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     Align = std::min(Align, Ld->getAlign());
   }
 
+  // If the load ptrs can be decomposed into a common (Base + Index) with a
+  // common constant stride, then return the constant stride. This matcher
+  // enables some additional optimization since BaseIndexOffset is capable of
+  // decomposing the load ptrs to (add (add Base, Index), Stride) instead of
+  // (add LastPtr, Stride) or (add NextPtr, Stride) that matchForwardStrided
+  // and matchReverseStrided use, respectively.
+  auto matchConstantStride = [&DAG, &N](ArrayRef<SDUse> Loads) {
+    // Initialize match constraints based on the first load. Initialize
+    // ConstStride by taking the difference between the offset of the first two
+    // loads.
+    if (Loads.size() < 2)
+      return SDValue();
+    BaseIndexOffset BaseLdBIO =
+        BaseIndexOffset::match(cast<LoadSDNode>(Loads[0]), DAG);
+    BaseIndexOffset LastLdBIO =
+        BaseIndexOffset::match(cast<LoadSDNode>(Loads[1]), DAG);
+    bool AllValidOffset =
+        BaseLdBIO.hasValidOffset() && LastLdBIO.hasValidOffset();
+    if (!AllValidOffset)
+      return SDValue();
+    bool BaseIndexMatch = BaseLdBIO.equalBaseIndex(LastLdBIO, DAG);
+    if (!BaseIndexMatch)
+      return SDValue();
+    int64_t ConstStride = LastLdBIO.getOffset() - BaseLdBIO.getOffset();
+
+    // Check that constraints hold for all subsequent loads and the ConstStride
+    // is the same.
+    for (auto Idx : enumerate(Loads.drop_front(2))) {
+      auto *Ld = cast<LoadSDNode>(Idx.value());
+      BaseIndexOffset BIO = BaseIndexOffset::match(Ld, DAG);
+      AllValidOffset &= BIO.hasValidOffset();
+      if (!AllValidOffset)
+        return SDValue();
+      BaseIndexMatch &= BaseLdBIO.equalBaseIndex(BIO, DAG);
+      bool StrideMatches =
+          ConstStride == BIO.getOffset() - LastLdBIO.getOffset();
+      if (!BaseIndexMatch || !StrideMatches)
+        return SDValue();
+      LastLdBIO = BIO;
+    }
+
+    // The match is a success if all the constraints hold.
+    if (BaseIndexMatch && AllValidOffset)
+      return DAG.getConstant(
+          ConstStride, SDLoc(N),
+          cast<LoadSDNode>(N->getOperand(0))->getOffset().getValueType());
+
+    // The match failed.
+    return SDValue();
+  };
   auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
     SDValue Stride;
     for (auto Idx : enumerate(Ptrs)) {
@@ -13862,13 +13915,21 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   SDValue Stride = matchForwardStrided(Ptrs);
   if (!Stride) {
     Stride = matchReverseStrided(Ptrs);
-    Reversed = true;
-    // TODO: At this point, we've successfully matched a generalized gather
-    // load.  Maybe we should emit that, and then move the specialized
-    // matchers above and below into a DAG combine?
-    if (!Stride)
-      return SDValue();
+    if (Stride) {
+      Reversed = true;
+      Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+    } else {
+      Stride = matchConstantStride(N->ops());
+      if (Stride) {
+        Reversed = cast<ConstantSDNode>(Stride)->getSExtValue() < 0;
+      } else {
+        return SDValue();
+      }
+    }
   }
+  // TODO: At this point, we've successfully matched a generalized gather
+  // load.  Maybe we should emit that, and then move the specialized
+  // matchers above and below into a DAG combine?
 
   // Get the widened scalar type, e.g. v4i8 -> i64
   unsigned WideScalarBitWidth =
@@ -13888,8 +13949,6 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   SDValue IntID =
     DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
                           Subtarget.getXLenVT());
-  if (Reversed)
-    Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
   SDValue AllOneMask =
     DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
                  DAG.getConstant(1, DL, MVT::i1));
diff --git a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
new file mode 100644
index 000000000000000..611270ab98ebdaf
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
@@ -0,0 +1,231 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+v,+unaligned-vector-mem -target-abi=ilp32 \
+; RUN:     -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV32
+; RUN: llc -mtriple=riscv64 -mattr=+v,+unaligned-vector-mem -target-abi=lp64 \
+; RUN:     -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64
+
+define void @constant_forward_stride(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_forward_stride:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, 16
+; CHECK-NEXT:    addi a3, a0, 32
+; CHECK-NEXT:    addi a4, a0, 48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 16
+  %2 = getelementptr inbounds i8, ptr %s, i64 32
+  %3 = getelementptr inbounds i8, ptr %s, i64 48
+  %4 = load <2 x i8>, ptr %s, align 1
+  %5 = load <2 x i8>, ptr %1, align 1
+  %6 = load <2 x i8>, ptr %2, align 1
+  %7 = load <2 x i8>, ptr %3, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_forward_stride2(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_forward_stride2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, -16
+; CHECK-NEXT:    addi a3, a0, -32
+; CHECK-NEXT:    addi a4, a0, -48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a4)
+; CHECK-NEXT:    vle8.v v9, (a3)
+; CHECK-NEXT:    vle8.v v10, (a2)
+; CHECK-NEXT:    vle8.v v11, (a0)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 -16
+  %2 = getelementptr inbounds i8, ptr %s, i64 -32
+  %3 = getelementptr inbounds i8, ptr %s, i64 -48
+  %4 = load <2 x i8>, ptr %3, align 1
+  %5 = load <2 x i8>, ptr %2, align 1
+  %6 = load <2 x i8>, ptr %1, align 1
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_forward_stride3(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_forward_stride3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, 16
+; CHECK-NEXT:    addi a3, a0, 32
+; CHECK-NEXT:    addi a4, a0, 48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 16
+  %2 = getelementptr inbounds i8, ptr %s, i64 32
+  %3 = getelementptr inbounds i8, ptr %s, i64 48
+  %4 = getelementptr inbounds i8, ptr %1, i64 0
+  %5 = getelementptr inbounds i8, ptr %2, i64 0
+  %6 = getelementptr inbounds i8, ptr %3, i64 0
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = load <2 x i8>, ptr %4, align 1
+  %9 = load <2 x i8>, ptr %5, align 1
+  %10 = load <2 x i8>, ptr %6, align 1
+  %11 = shufflevector <2 x i8> %7, <2 x i8> %8, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %12 = shufflevector <2 x i8> %9, <2 x i8> %10, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %13 = shufflevector <4 x i8> %11, <4 x i8> %12, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %13, ptr %d, align 1
+  ret void
+}
+
+define void @constant_back_stride(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_back_stride:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, -16
+; CHECK-NEXT:    addi a3, a0, -32
+; CHECK-NEXT:    addi a4, a0, -48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 -16
+  %2 = getelementptr inbounds i8, ptr %s, i64 -32
+  %3 = getelementptr inbounds i8, ptr %s, i64 -48
+  %4 = load <2 x i8>, ptr %s, align 1
+  %5 = load <2 x i8>, ptr %1, align 1
+  %6 = load <2 x i8>, ptr %2, align 1
+  %7 = load <2 x i8>, ptr %3, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_back_stride2(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_back_stride2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, 16
+; CHECK-NEXT:    addi a3, a0, 32
+; CHECK-NEXT:    addi a4, a0, 48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a4)
+; CHECK-NEXT:    vle8.v v9, (a3)
+; CHECK-NEXT:    vle8.v v10, (a2)
+; CHECK-NEXT:    vle8.v v11, (a0)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 16
+  %2 = getelementptr inbounds i8, ptr %s, i64 32
+  %3 = getelementptr inbounds i8, ptr %s, i64 48
+  %4 = load <2 x i8>, ptr %3, align 1
+  %5 = load <2 x i8>, ptr %2, align 1
+  %6 = load <2 x i8>, ptr %1, align 1
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_back_stride3(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_back_stride3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, -16
+; CHECK-NEXT:    addi a3, a0, -32
+; CHECK-NEXT:    addi a4, a0, -48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 -16
+  %2 = getelementptr inbounds i8, ptr %s, i64 -32
+  %3 = getelementptr inbounds i8, ptr %s, i64 -48
+  %4 = getelementptr inbounds i8, ptr %1, i64 0
+  %5 = getelementptr inbounds i8, ptr %2, i64 0
+  %6 = getelementptr inbounds i8, ptr %3, i64 0
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = load <2 x i8>, ptr %4, align 1
+  %9 = load <2 x i8>, ptr %5, align 1
+  %10 = load <2 x i8>, ptr %6, align 1
+  %11 = shufflevector <2 x i8> %7, <2 x i8> %8, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %12 = shufflevector <2 x i8> %9, <2 x i8> %10, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %13 = shufflevector <4 x i8> %11, <4 x i8> %12, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %13, ptr %d, align 1
+  ret void
+}
+
+define void @constant_zero_stride(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_zero_stride:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-NEXT:    vmv1r.v v9, v8
+; CHECK-NEXT:    vslideup.vi v9, v8, 2
+; CHECK-NEXT:    vse8.v v9, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 0
+  %2 = load <2 x i8>, ptr %s, align 1
+  %3 = load <2 x i8>, ptr %1, align 1
+  %4 = shufflevector <2 x i8> %2, <2 x i8> %3, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %4, ptr %d, align 1
+  ret void
+}
+
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; RV32: {{.*}}
+; RV64: {{.*}}

@michaelmaitland changed the title from "Comcat vectors combine" to "Improve performCONCAT_VECTORSCombine stride matching" on Oct 10, 2023
@michaelmaitland changed the title from "Improve performCONCAT_VECTORSCombine stride matching" to "[RISCV] Improve performCONCAT_VECTORSCombine stride matching" on Oct 10, 2023
@github-actions

github-actions bot commented Oct 10, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@michaelmaitland force-pushed the comcat-vectors-combine branch 2 times, most recently from ce03922 to 9e87a89 on October 10, 2023 17:11
    // loads.
    if (Loads.size() < 2)
      return SDValue();
    BaseIndexOffset BaseLdBIO =
Collaborator

Hm, an idea...

Looking at this code, I realized we could factor out a helper routine of the form optional<int64_t> getPtrDiff(LD1, LD2). Then we could loop over all of the indices, and check that the offset is known and the same.

After that, it occurred to me that both of the previous two matchers can also be written as rules inside a getPtrDiff helper. (The return value has to become SDValue instead, but otherwise it works.)

Maybe we should switch over in a separate commit, and then implement this matcher in that style?
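
A rough sketch of that proposal (the getPtrDiff name comes from the comment above; matchCommonPtrDiff and the exact signatures are assumptions for illustration):

```cpp
// Pairwise helper as proposed above; a BaseIndexOffset-based rule could sit
// alongside rules for the existing forward/reverse (add Ptr, C) patterns.
static std::optional<int64_t> getPtrDiff(LoadSDNode *Ld1, LoadSDNode *Ld2,
                                         SelectionDAG &DAG);

// One pass over the loads: every adjacent pair must report the same known
// pointer difference, which is then the common constant stride.
static std::optional<int64_t> matchCommonPtrDiff(ArrayRef<LoadSDNode *> Lds,
                                                 SelectionDAG &DAG) {
  if (Lds.size() < 2)
    return std::nullopt;
  std::optional<int64_t> Diff = getPtrDiff(Lds[0], Lds[1], DAG);
  if (!Diff)
    return std::nullopt;
  for (size_t I = 2, E = Lds.size(); I != E; ++I)
    if (getPtrDiff(Lds[I - 1], Lds[I], DAG) != Diff)
      return std::nullopt;
  return Diff;
}
```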

Contributor Author

> I realized we could factor out a helper routine of the form optional<int64_t> getPtrDiff(LD1, LD2). Then we could loop over all of the indices, and check that the offset is known and the same.

This seems to be very close to what BaseIndexOffset::match does, since if the (Base + Index) matches then the pointer diff is the Offset. We could definitely add this helper method to simplify our checks in matchConstantStride.

> After that, it occurred to me that both of the previous two matchers can also be written as rules inside a getPtrDiff helper. (The return value has to become SDValue instead, but otherwise it works.)

What happens in the case where there is no constant stride in matchForward and matchReverse (i.e., the stride can be a common add SDValue that is not a ConstantSDNode)?

Contributor Author

I think I understand: getPtrDiff would return the canonicalized representation of the non-constant value. We wouldn't be able to use BaseIndexOffset here. Do you have any ideas on what we could use?

Contributor

I pulled out getPtrDiff myself when trying to wrap my head around this; hope this doesn't step on anyone's toes: #69068
We should be able to extend it to check BaseIndexOffsets and return DAG.getConstant like we do here if they match.

Contributor Author

@lukel97 thanks for your refactor. This change is very simple with getPtrDiff. I've updated this patch to be stacked on top of #69068.

lukel97 added a commit to lukel97/llvm-project that referenced this pull request Oct 14, 2023
Instead of doing a forward pass for positive strides and a reverse pass for
negative strides, we can just do one pass by negating the offset if the
pointers do happen to be in reverse order.

We can extend getPtrDiff later in llvm#68726 to handle more constant offset
sequences.
lukel97 added a commit that referenced this pull request Oct 16, 2023
Instead of doing a forward pass for positive strides and a reverse pass for
negative strides, we can just do one pass by negating the offset if the
pointers do happen to be in reverse order.

We can extend getPtrDiff later in #68726 to handle more constant offset
sequences.
    BaseIndexOffset BIO1 = BaseIndexOffset::match(Ld1, DAG);
    BaseIndexOffset BIO2 = BaseIndexOffset::match(Ld2, DAG);
    if (BIO1.equalBaseIndex(BIO2, DAG))
      return {{DAG.getConstant(BIO2.getOffset() - BIO1.getOffset(), DL,
Contributor

Just realising we should probably defer creating the constant here too. One thought would be to use a std::variant<uint64_t, SDValue> to represent either a constant or a variable stride. But then we would need to extract the constants out of the SDValues below so we can compare the two, etc. There's probably a simpler way of doing this.

Contributor Author

You are right. Let me take a look.

Contributor Author

I think what you propose makes sense. I thought about trying to do it using std::function to defer creation of the constant or the negation, but that doesn't allow us to compare getPtrDiff() != BaseDiff. We need to do a comparison by value on three cases:

  1. SDValue, true
  2. SDValue, false
  3. int64_t, false

And the pair<variant, bool> structure permits us to do this check.
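
A sketch of that shape (names and exact types are assumptions for illustration; the committed code may differ):

```cpp
#include <optional>
#include <utility>
#include <variant>

// (Stride, MustNegate): an int64_t stride carries its own sign, so its flag
// is always false; an SDValue stride from the reverse-strided matcher would
// carry true so the caller can negate it once when emitting the load.
using PtrDiff = std::pair<std::variant<int64_t, SDValue>, bool>;

static std::optional<PtrDiff> getPtrDiff(LoadSDNode *Ld1, LoadSDNode *Ld2,
                                         SelectionDAG &DAG) {
  BaseIndexOffset B1 = BaseIndexOffset::match(Ld1, DAG);
  BaseIndexOffset B2 = BaseIndexOffset::match(Ld2, DAG);
  if (B1.hasValidOffset() && B2.hasValidOffset() &&
      B1.equalBaseIndex(B2, DAG))
    return PtrDiff{B2.getOffset() - B1.getOffset(), /*MustNegate=*/false};
  // The forward/reverse (add Ptr, Stride) rules would return an SDValue
  // stride here, with MustNegate = true in the reverse case.
  return std::nullopt;
}
```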

This commit adds tests that can be optimized by improving
performCONCAT_VECTORSCombine to do a better job of decomposing the base
pointer and recognizing a constant offset.
If the load ptrs can be decomposed into a common (Base + Index) with a
common constant stride, then return the constant stride.
Contributor

@lukel97 left a comment

LGTM. Just leaving a note here that I think we'll fail to detect sequences of pointers that are a mixture of (base + index) and (add LastPtr, constStride), because the latter will be an SDValue and not an int64_t. But it sounds like an unlikely pattern to show up, and I doubt it's worth the complexity.

@michaelmaitland
Contributor Author

> LGTM. Just leaving a note here that I think we'll fail to detect sequences of pointers that are a mixture of (base + index) and (add LastPtr, constStride), because the latter will be an SDValue and not an int64_t. But it sounds like an unlikely pattern to show up, and I doubt it's worth the complexity.

My guess is that if BaseIndexOffset couldn't resolve to a match, then we would have no way to compare a constant to a non-constant anyway.

@michaelmaitland
Contributor Author

Committed in 30ca258 and c319c74 to preserve the pre-commit diff.

@michaelmaitland deleted the comcat-vectors-combine branch on October 16, 2023 23:52