From cff94d4b078b04b26855f59e62bf439a8b2f2b1d Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Tue, 28 Nov 2023 08:22:59 -0800 Subject: [PATCH] [RISCV] Work on subreg for insert_vector_elt when vlen is known (#72666) If we have a constant index and a known vlen, then we can identify which registers out of a register group is being accessed. Given this, we can reuse the (slightly generalized) existing handling for working on sub-register groups. This results in all constant index extracts with known vlen becoming m1 operations. One bit of weirdness to highlight and explain: the existing code uses the VL from the original vector type, not the inner vector type. This is correct because the inner register group must be smaller than the original (possibly fixed length) vector type. Overall, this seems to a reasonable codegen tradeoff as it biases us towards immediate AVLs, which avoids needing the vsetvli form which clobbers a GPR for no real purpose. The downside is that for large fixed length vectors, we end up materializing an immediate in register for little value. We should probably generalize this idea and try to optimize the large fixed length vector case, but that can be done in separate work. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 48 ++++++++++++++----- .../CodeGen/RISCV/rvv/fixed-vectors-insert.ll | 40 +++++++--------- 2 files changed, 54 insertions(+), 34 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index bd5b1a879f32b..72b2e5e78c299 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -7739,17 +7739,41 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op, Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); } - MVT OrigContainerVT = ContainerVT; - SDValue OrigVec = Vec; // If we know the index we're going to insert at, we can shrink Vec so that // we're performing the scalar inserts and slideup on a smaller LMUL. - if (auto *CIdx = dyn_cast(Idx)) { - if (auto ShrunkVT = getSmallestVTForIndex(ContainerVT, CIdx->getZExtValue(), + MVT OrigContainerVT = ContainerVT; + SDValue OrigVec = Vec; + SDValue AlignedIdx; + if (auto *IdxC = dyn_cast(Idx)) { + const unsigned OrigIdx = IdxC->getZExtValue(); + // Do we know an upper bound on LMUL? + if (auto ShrunkVT = getSmallestVTForIndex(ContainerVT, OrigIdx, DL, DAG, Subtarget)) { ContainerVT = *ShrunkVT; + AlignedIdx = DAG.getVectorIdxConstant(0, DL); + } + + // If we're compiling for an exact VLEN value, we can always perform + // the insert in m1 as we can determine the register corresponding to + // the index in the register group. + const unsigned MinVLen = Subtarget.getRealMinVLen(); + const unsigned MaxVLen = Subtarget.getRealMaxVLen(); + const MVT M1VT = getLMUL1VT(ContainerVT); + if (MinVLen == MaxVLen && ContainerVT.bitsGT(M1VT)) { + EVT ElemVT = VecVT.getVectorElementType(); + unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits(); + unsigned RemIdx = OrigIdx % ElemsPerVReg; + unsigned SubRegIdx = OrigIdx / ElemsPerVReg; + unsigned ExtractIdx = + SubRegIdx * M1VT.getVectorElementCount().getKnownMinValue(); + AlignedIdx = DAG.getVectorIdxConstant(ExtractIdx, DL); + Idx = DAG.getVectorIdxConstant(RemIdx, DL); + ContainerVT = M1VT; + } + + if (AlignedIdx) Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ContainerVT, Vec, - DAG.getVectorIdxConstant(0, DL)); - } + AlignedIdx); } MVT XLenVT = Subtarget.getXLenVT(); @@ -7779,9 +7803,9 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op, Val = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Val); Vec = DAG.getNode(Opc, DL, ContainerVT, Vec, Val, VL); - if (ContainerVT != OrigContainerVT) + if (AlignedIdx) Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec, - Vec, DAG.getVectorIdxConstant(0, DL)); + Vec, AlignedIdx); if (!VecVT.isFixedLengthVector()) return Vec; return convertFromScalableVector(VecVT, Vec, DAG, Subtarget); @@ -7814,10 +7838,10 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op, // Bitcast back to the right container type. ValInVec = DAG.getBitcast(ContainerVT, ValInVec); - if (ContainerVT != OrigContainerVT) + if (AlignedIdx) ValInVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec, - ValInVec, DAG.getVectorIdxConstant(0, DL)); + ValInVec, AlignedIdx); if (!VecVT.isFixedLengthVector()) return ValInVec; return convertFromScalableVector(VecVT, ValInVec, DAG, Subtarget); @@ -7849,9 +7873,9 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op, SDValue Slideup = getVSlideup(DAG, Subtarget, DL, ContainerVT, Vec, ValInVec, Idx, Mask, InsertVL, Policy); - if (ContainerVT != OrigContainerVT) + if (AlignedIdx) Slideup = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, OrigContainerVT, OrigVec, - Slideup, DAG.getVectorIdxConstant(0, DL)); + Slideup, AlignedIdx); if (!VecVT.isFixedLengthVector()) return Slideup; return convertFromScalableVector(VecVT, Slideup, DAG, Subtarget); diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll index de5c4fbc08764..a3f41fd842222 100644 --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-insert.ll @@ -614,9 +614,8 @@ define <16 x i32> @insertelt_c3_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_ra define <16 x i32> @insertelt_c12_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) { ; CHECK-LABEL: insertelt_c12_v16xi32_exact: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetivli zero, 13, e32, m4, tu, ma -; CHECK-NEXT: vmv.s.x v12, a0 -; CHECK-NEXT: vslideup.vi v8, v12, 12 +; CHECK-NEXT: vsetivli zero, 16, e32, m1, tu, ma +; CHECK-NEXT: vmv.s.x v11, a0 ; CHECK-NEXT: ret %v = insertelement <16 x i32> %vin, i32 %a, i32 12 ret <16 x i32> %v @@ -625,9 +624,9 @@ define <16 x i32> @insertelt_c12_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r define <16 x i32> @insertelt_c13_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) { ; CHECK-LABEL: insertelt_c13_v16xi32_exact: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetivli zero, 14, e32, m4, tu, ma +; CHECK-NEXT: vsetivli zero, 2, e32, m1, tu, ma ; CHECK-NEXT: vmv.s.x v12, a0 -; CHECK-NEXT: vslideup.vi v8, v12, 13 +; CHECK-NEXT: vslideup.vi v11, v12, 1 ; CHECK-NEXT: ret %v = insertelement <16 x i32> %vin, i32 %a, i32 13 ret <16 x i32> %v @@ -636,9 +635,9 @@ define <16 x i32> @insertelt_c13_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r define <16 x i32> @insertelt_c14_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) { ; CHECK-LABEL: insertelt_c14_v16xi32_exact: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetivli zero, 15, e32, m4, tu, ma +; CHECK-NEXT: vsetivli zero, 3, e32, m1, tu, ma ; CHECK-NEXT: vmv.s.x v12, a0 -; CHECK-NEXT: vslideup.vi v8, v12, 14 +; CHECK-NEXT: vslideup.vi v11, v12, 2 ; CHECK-NEXT: ret %v = insertelement <16 x i32> %vin, i32 %a, i32 14 ret <16 x i32> %v @@ -647,9 +646,9 @@ define <16 x i32> @insertelt_c14_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r define <16 x i32> @insertelt_c15_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_range(2,2) { ; CHECK-LABEL: insertelt_c15_v16xi32_exact: ; CHECK: # %bb.0: -; CHECK-NEXT: vsetivli zero, 16, e32, m4, ta, ma +; CHECK-NEXT: vsetivli zero, 4, e32, m1, tu, ma ; CHECK-NEXT: vmv.s.x v12, a0 -; CHECK-NEXT: vslideup.vi v8, v12, 15 +; CHECK-NEXT: vslideup.vi v11, v12, 3 ; CHECK-NEXT: ret %v = insertelement <16 x i32> %vin, i32 %a, i32 15 ret <16 x i32> %v @@ -658,18 +657,15 @@ define <16 x i32> @insertelt_c15_v16xi32_exact(<16 x i32> %vin, i32 %a) vscale_r define <8 x i64> @insertelt_c4_v8xi64_exact(<8 x i64> %vin, i64 %a) vscale_range(2,2) { ; RV32-LABEL: insertelt_c4_v8xi64_exact: ; RV32: # %bb.0: -; RV32-NEXT: vsetivli zero, 2, e32, m4, ta, ma -; RV32-NEXT: vslide1down.vx v12, v8, a0 -; RV32-NEXT: vslide1down.vx v12, v12, a1 -; RV32-NEXT: vsetivli zero, 5, e64, m4, tu, ma -; RV32-NEXT: vslideup.vi v8, v12, 4 +; RV32-NEXT: vsetivli zero, 2, e32, m1, tu, ma +; RV32-NEXT: vslide1down.vx v10, v10, a0 +; RV32-NEXT: vslide1down.vx v10, v10, a1 ; RV32-NEXT: ret ; ; RV64-LABEL: insertelt_c4_v8xi64_exact: ; RV64: # %bb.0: -; RV64-NEXT: vsetivli zero, 5, e64, m4, tu, ma -; RV64-NEXT: vmv.s.x v12, a0 -; RV64-NEXT: vslideup.vi v8, v12, 4 +; RV64-NEXT: vsetivli zero, 8, e64, m1, tu, ma +; RV64-NEXT: vmv.s.x v10, a0 ; RV64-NEXT: ret %v = insertelement <8 x i64> %vin, i64 %a, i32 4 ret <8 x i64> %v @@ -678,18 +674,18 @@ define <8 x i64> @insertelt_c4_v8xi64_exact(<8 x i64> %vin, i64 %a) vscale_range define <8 x i64> @insertelt_c5_v8xi64_exact(<8 x i64> %vin, i64 %a) vscale_range(2,2) { ; RV32-LABEL: insertelt_c5_v8xi64_exact: ; RV32: # %bb.0: -; RV32-NEXT: vsetivli zero, 2, e32, m4, ta, ma +; RV32-NEXT: vsetivli zero, 2, e32, m1, ta, ma ; RV32-NEXT: vslide1down.vx v12, v8, a0 ; RV32-NEXT: vslide1down.vx v12, v12, a1 -; RV32-NEXT: vsetivli zero, 6, e64, m4, tu, ma -; RV32-NEXT: vslideup.vi v8, v12, 5 +; RV32-NEXT: vsetivli zero, 2, e64, m1, tu, ma +; RV32-NEXT: vslideup.vi v10, v12, 1 ; RV32-NEXT: ret ; ; RV64-LABEL: insertelt_c5_v8xi64_exact: ; RV64: # %bb.0: -; RV64-NEXT: vsetivli zero, 6, e64, m4, tu, ma +; RV64-NEXT: vsetivli zero, 2, e64, m1, tu, ma ; RV64-NEXT: vmv.s.x v12, a0 -; RV64-NEXT: vslideup.vi v8, v12, 5 +; RV64-NEXT: vslideup.vi v10, v12, 1 ; RV64-NEXT: ret %v = insertelement <8 x i64> %vin, i64 %a, i32 5 ret <8 x i64> %v