diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 6eb253cc51466..4dc3f6137e306 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -27,6 +27,7 @@ #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h" #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" #include "llvm/CodeGen/ValueTypes.h" #include "llvm/IR/DiagnosticInfo.h" @@ -13803,9 +13804,17 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, Align = std::min(Align, Ld->getAlign()); } - using PtrDiff = std::pair; - auto GetPtrDiff = [](LoadSDNode *Ld1, - LoadSDNode *Ld2) -> std::optional { + using PtrDiff = std::pair, bool>; + auto GetPtrDiff = [&DAG](LoadSDNode *Ld1, + LoadSDNode *Ld2) -> std::optional { + // If the load ptrs can be decomposed into a common (Base + Index) with a + // common constant stride, then return the constant stride. + BaseIndexOffset BIO1 = BaseIndexOffset::match(Ld1, DAG); + BaseIndexOffset BIO2 = BaseIndexOffset::match(Ld2, DAG); + if (BIO1.equalBaseIndex(BIO2, DAG)) + return {{BIO2.getOffset() - BIO1.getOffset(), false}}; + + // Otherwise try to match (add LastPtr, Stride) or (add NextPtr, Stride) SDValue P1 = Ld1->getBasePtr(); SDValue P2 = Ld2->getBasePtr(); if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1) @@ -13844,7 +13853,11 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG, if (!TLI.isLegalStridedLoadStore(WideVecVT, Align)) return SDValue(); - auto [Stride, MustNegateStride] = *BaseDiff; + auto [StrideVariant, MustNegateStride] = *BaseDiff; + SDValue Stride = std::holds_alternative(StrideVariant) + ? std::get(StrideVariant) + : DAG.getConstant(std::get(StrideVariant), DL, + Lds[0]->getOffset().getValueType()); if (MustNegateStride) Stride = DAG.getNegative(Stride, DL, Stride.getValueType()); diff --git a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll index 611270ab98ebd..ff35043dbd7e7 100644 --- a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll +++ b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll @@ -7,21 +7,10 @@ define void @constant_forward_stride(ptr %s, ptr %d) { ; CHECK-LABEL: constant_forward_stride: ; CHECK: # %bb.0: -; CHECK-NEXT: addi a2, a0, 16 -; CHECK-NEXT: addi a3, a0, 32 -; CHECK-NEXT: addi a4, a0, 48 -; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma -; CHECK-NEXT: vle8.v v8, (a0) -; CHECK-NEXT: vle8.v v9, (a2) -; CHECK-NEXT: vle8.v v10, (a3) -; CHECK-NEXT: vle8.v v11, (a4) -; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma -; CHECK-NEXT: vslideup.vi v8, v9, 2 -; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma -; CHECK-NEXT: vslideup.vi v8, v10, 4 -; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma -; CHECK-NEXT: vslideup.vi v8, v11, 6 -; CHECK-NEXT: vse8.v v8, (a1) +; CHECK-NEXT: li a2, 16 +; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; CHECK-NEXT: vlse16.v v8, (a0), a2 +; CHECK-NEXT: vse16.v v8, (a1) ; CHECK-NEXT: ret %1 = getelementptr inbounds i8, ptr %s, i64 16 %2 = getelementptr inbounds i8, ptr %s, i64 32 @@ -40,21 +29,11 @@ define void @constant_forward_stride(ptr %s, ptr %d) { define void @constant_forward_stride2(ptr %s, ptr %d) { ; CHECK-LABEL: constant_forward_stride2: ; CHECK: # %bb.0: -; CHECK-NEXT: addi a2, a0, -16 -; CHECK-NEXT: addi a3, a0, -32 -; CHECK-NEXT: addi a4, a0, -48 -; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma -; CHECK-NEXT: vle8.v v8, (a4) -; CHECK-NEXT: vle8.v v9, (a3) -; CHECK-NEXT: vle8.v v10, (a2) -; CHECK-NEXT: vle8.v v11, (a0) -; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma -; CHECK-NEXT: vslideup.vi v8, v9, 2 -; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma -; CHECK-NEXT: vslideup.vi v8, v10, 4 -; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma -; CHECK-NEXT: vslideup.vi v8, v11, 6 -; CHECK-NEXT: vse8.v v8, (a1) +; CHECK-NEXT: addi a0, a0, -48 +; CHECK-NEXT: li a2, 16 +; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; CHECK-NEXT: vlse16.v v8, (a0), a2 +; CHECK-NEXT: vse16.v v8, (a1) ; CHECK-NEXT: ret %1 = getelementptr inbounds i8, ptr %s, i64 -16 %2 = getelementptr inbounds i8, ptr %s, i64 -32 @@ -73,21 +52,10 @@ define void @constant_forward_stride2(ptr %s, ptr %d) { define void @constant_forward_stride3(ptr %s, ptr %d) { ; CHECK-LABEL: constant_forward_stride3: ; CHECK: # %bb.0: -; CHECK-NEXT: addi a2, a0, 16 -; CHECK-NEXT: addi a3, a0, 32 -; CHECK-NEXT: addi a4, a0, 48 -; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma -; CHECK-NEXT: vle8.v v8, (a0) -; CHECK-NEXT: vle8.v v9, (a2) -; CHECK-NEXT: vle8.v v10, (a3) -; CHECK-NEXT: vle8.v v11, (a4) -; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma -; CHECK-NEXT: vslideup.vi v8, v9, 2 -; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma -; CHECK-NEXT: vslideup.vi v8, v10, 4 -; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma -; CHECK-NEXT: vslideup.vi v8, v11, 6 -; CHECK-NEXT: vse8.v v8, (a1) +; CHECK-NEXT: li a2, 16 +; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; CHECK-NEXT: vlse16.v v8, (a0), a2 +; CHECK-NEXT: vse16.v v8, (a1) ; CHECK-NEXT: ret %1 = getelementptr inbounds i8, ptr %s, i64 16 %2 = getelementptr inbounds i8, ptr %s, i64 32 @@ -109,21 +77,10 @@ define void @constant_forward_stride3(ptr %s, ptr %d) { define void @constant_back_stride(ptr %s, ptr %d) { ; CHECK-LABEL: constant_back_stride: ; CHECK: # %bb.0: -; CHECK-NEXT: addi a2, a0, -16 -; CHECK-NEXT: addi a3, a0, -32 -; CHECK-NEXT: addi a4, a0, -48 -; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma -; CHECK-NEXT: vle8.v v8, (a0) -; CHECK-NEXT: vle8.v v9, (a2) -; CHECK-NEXT: vle8.v v10, (a3) -; CHECK-NEXT: vle8.v v11, (a4) -; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma -; CHECK-NEXT: vslideup.vi v8, v9, 2 -; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma -; CHECK-NEXT: vslideup.vi v8, v10, 4 -; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma -; CHECK-NEXT: vslideup.vi v8, v11, 6 -; CHECK-NEXT: vse8.v v8, (a1) +; CHECK-NEXT: li a2, -16 +; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; CHECK-NEXT: vlse16.v v8, (a0), a2 +; CHECK-NEXT: vse16.v v8, (a1) ; CHECK-NEXT: ret %1 = getelementptr inbounds i8, ptr %s, i64 -16 %2 = getelementptr inbounds i8, ptr %s, i64 -32 @@ -142,21 +99,11 @@ define void @constant_back_stride(ptr %s, ptr %d) { define void @constant_back_stride2(ptr %s, ptr %d) { ; CHECK-LABEL: constant_back_stride2: ; CHECK: # %bb.0: -; CHECK-NEXT: addi a2, a0, 16 -; CHECK-NEXT: addi a3, a0, 32 -; CHECK-NEXT: addi a4, a0, 48 -; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma -; CHECK-NEXT: vle8.v v8, (a4) -; CHECK-NEXT: vle8.v v9, (a3) -; CHECK-NEXT: vle8.v v10, (a2) -; CHECK-NEXT: vle8.v v11, (a0) -; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma -; CHECK-NEXT: vslideup.vi v8, v9, 2 -; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma -; CHECK-NEXT: vslideup.vi v8, v10, 4 -; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma -; CHECK-NEXT: vslideup.vi v8, v11, 6 -; CHECK-NEXT: vse8.v v8, (a1) +; CHECK-NEXT: addi a0, a0, 48 +; CHECK-NEXT: li a2, -16 +; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; CHECK-NEXT: vlse16.v v8, (a0), a2 +; CHECK-NEXT: vse16.v v8, (a1) ; CHECK-NEXT: ret %1 = getelementptr inbounds i8, ptr %s, i64 16 %2 = getelementptr inbounds i8, ptr %s, i64 32 @@ -175,21 +122,10 @@ define void @constant_back_stride2(ptr %s, ptr %d) { define void @constant_back_stride3(ptr %s, ptr %d) { ; CHECK-LABEL: constant_back_stride3: ; CHECK: # %bb.0: -; CHECK-NEXT: addi a2, a0, -16 -; CHECK-NEXT: addi a3, a0, -32 -; CHECK-NEXT: addi a4, a0, -48 -; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma -; CHECK-NEXT: vle8.v v8, (a0) -; CHECK-NEXT: vle8.v v9, (a2) -; CHECK-NEXT: vle8.v v10, (a3) -; CHECK-NEXT: vle8.v v11, (a4) -; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma -; CHECK-NEXT: vslideup.vi v8, v9, 2 -; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma -; CHECK-NEXT: vslideup.vi v8, v10, 4 -; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma -; CHECK-NEXT: vslideup.vi v8, v11, 6 -; CHECK-NEXT: vse8.v v8, (a1) +; CHECK-NEXT: li a2, -16 +; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma +; CHECK-NEXT: vlse16.v v8, (a0), a2 +; CHECK-NEXT: vse16.v v8, (a1) ; CHECK-NEXT: ret %1 = getelementptr inbounds i8, ptr %s, i64 -16 %2 = getelementptr inbounds i8, ptr %s, i64 -32