[RISCV] Recursively split concat_vector into smaller LMULs (#83035)
This is the concat_vector equivalent of #81312, in that we recursively
split concat_vectors with more than two operands into smaller
concat_vectors.

This allows us to break up the chain of vslideups and to perform the
vslideups at a smaller LMUL, which in turn reduces register pressure:
the previous lowering performed N vslideups at the highest result LMUL.
For now, splitting stops at MF2.

This is done as a DAG combine so that any undef operands are combined
away: if we did this during lowering, we would end up with unnecessary
vslideups of undefs.
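
As a rough sketch of the combine in SelectionDAG notation (the value types, operand names,
and the subvector index below are invented for illustration and are not taken from this
patch's tests), a four-operand concat_vectors is split into two half-width concat_vectors
whose results are glued back together with insert_subvector nodes; each half with more than
two operands can be split again, down to MF2:

    concat_vectors a, b, c, d             ; hypothetical v16i8 result from four v4i8 parts
      ->
    Lo  = concat_vectors a, b             ; v8i8
    Hi  = concat_vectors c, d             ; v8i8
    Vec = insert_subvector undef, Lo, 0
    Vec = insert_subvector Vec, Hi, 8     ; 8 = element count of the half type

The vslideups for each half then run at the smaller LMUL; in the updated active_lane_mask.ll
test below, only the last vslideup that joins the two halves still runs at m1, while the
others have moved from m1 to mf2.
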
lukel97 committed Mar 7, 2024
1 parent 99500e8 commit c59129a
Showing 9 changed files with 926 additions and 781 deletions.
60 changes: 56 additions & 4 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15283,13 +15283,62 @@ static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
  return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ConcatOps);
}

// Recursively split up concat_vectors with more than 2 operands:
//
// concat_vector op1, op2, op3, op4
// ->
// concat_vector (concat_vector op1, op2), (concat_vector op3, op4)
//
// This reduces the length of the chain of vslideups and allows us to perform
// the vslideups at a smaller LMUL, limited to MF2.
//
// We do this as a DAG combine rather than during lowering so that any undef
// operands can get combined away.
static SDValue
performCONCAT_VECTORSSplitCombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVTargetLowering &TLI) {
  SDLoc DL(N);

  if (N->getNumOperands() <= 2)
    return SDValue();

  if (!TLI.isTypeLegal(N->getValueType(0)))
    return SDValue();
  MVT VT = N->getSimpleValueType(0);

  // Don't split any further than MF2.
  MVT ContainerVT = VT;
  if (VT.isFixedLengthVector())
    ContainerVT = getContainerForFixedLengthVector(DAG, VT, TLI.getSubtarget());
  if (ContainerVT.bitsLT(getLMUL1VT(ContainerVT)))
    return SDValue();

  MVT HalfVT = VT.getHalfNumVectorElementsVT();
  assert(isPowerOf2_32(N->getNumOperands()));
  size_t HalfNumOps = N->getNumOperands() / 2;
  SDValue Lo = DAG.getNode(ISD::CONCAT_VECTORS, DL, HalfVT,
                           N->ops().take_front(HalfNumOps));
  SDValue Hi = DAG.getNode(ISD::CONCAT_VECTORS, DL, HalfVT,
                           N->ops().drop_front(HalfNumOps));

  // Lower to an insert_subvector directly so the concat_vectors don't get
  // recombined.
  SDValue Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT), Lo,
                            DAG.getVectorIdxConstant(0, DL));
  Vec = DAG.getNode(
      ISD::INSERT_SUBVECTOR, DL, VT, Vec, Hi,
      DAG.getVectorIdxConstant(HalfVT.getVectorMinNumElements(), DL));
  return Vec;
}

// If we're concatenating a series of vector loads like
// concat_vectors (load v4i8, p+0), (load v4i8, p+n), (load v4i8, p+n*2) ...
// Then we can turn this into a strided load by widening the vector elements
// vlse32 p, stride=n
static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
                                            const RISCVSubtarget &Subtarget,
                                            const RISCVTargetLowering &TLI) {
static SDValue
performCONCAT_VECTORSStridedLoadCombine(SDNode *N, SelectionDAG &DAG,
                                        const RISCVSubtarget &Subtarget,
                                        const RISCVTargetLowering &TLI) {
  SDLoc DL(N);
  EVT VT = N->getValueType(0);

@@ -16394,7 +16443,10 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
      return V;
    break;
  case ISD::CONCAT_VECTORS:
    if (SDValue V = performCONCAT_VECTORSCombine(N, DAG, Subtarget, *this))
    if (SDValue V =
            performCONCAT_VECTORSStridedLoadCombine(N, DAG, Subtarget, *this))
      return V;
    if (SDValue V = performCONCAT_VECTORSSplitCombine(N, DAG, *this))
      return V;
    break;
  case ISD::INSERT_VECTOR_ELT:
91 changes: 45 additions & 46 deletions llvm/test/CodeGen/RISCV/rvv/active_lane_mask.ll
@@ -161,72 +161,71 @@ define <64 x i1> @fv64(ptr %p, i64 %index, i64 %tc) {
define <128 x i1> @fv128(ptr %p, i64 %index, i64 %tc) {
; CHECK-LABEL: fv128:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: lui a0, %hi(.LCPI10_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_0)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vid.v v16
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v0, v16, a2
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 4, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 2
; CHECK-NEXT: lui a0, %hi(.LCPI10_1)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_1)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 6, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 4
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v10, v16, a2
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v8, v16, a2
; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma
; CHECK-NEXT: vslideup.vi v8, v10, 2
; CHECK-NEXT: lui a0, %hi(.LCPI10_2)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_2)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 8, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 6
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma
; CHECK-NEXT: vslideup.vi v8, v9, 4
; CHECK-NEXT: lui a0, %hi(.LCPI10_3)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_3)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 10, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 8
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vslideup.vi v8, v9, 6
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: lui a0, %hi(.LCPI10_4)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_4)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 12, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 10
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vid.v v16
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v0, v16, a2
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma
; CHECK-NEXT: vslideup.vi v0, v9, 2
; CHECK-NEXT: lui a0, %hi(.LCPI10_5)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_5)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 14, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 12
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma
; CHECK-NEXT: vslideup.vi v0, v9, 4
; CHECK-NEXT: lui a0, %hi(.LCPI10_6)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_6)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetvli zero, zero, e8, m1, ta, ma
; CHECK-NEXT: vslideup.vi v0, v16, 14
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vslideup.vi v0, v9, 6
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT: vslideup.vi v0, v8, 8
; CHECK-NEXT: ret
%mask = call <128 x i1> @llvm.get.active.lane.mask.v128i1.i64(i64 %index, i64 %tc)
ret <128 x i1> %mask
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/RISCV/rvv/combine-store-extract-crash.ll
@@ -19,7 +19,7 @@ define void @test(ptr %ref_array, ptr %sad_array) {
; RV32-NEXT: th.swia a0, (a1), 4, 0
; RV32-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
; RV32-NEXT: vle8.v v10, (a3)
; RV32-NEXT: vsetivli zero, 8, e8, m1, tu, ma
; RV32-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; RV32-NEXT: vslideup.vi v10, v9, 4
; RV32-NEXT: vsetivli zero, 16, e32, m4, ta, ma
; RV32-NEXT: vzext.vf4 v12, v10
@@ -42,7 +42,7 @@ define void @test(ptr %ref_array, ptr %sad_array) {
; RV64-NEXT: th.swia a0, (a1), 4, 0
; RV64-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
; RV64-NEXT: vle8.v v10, (a3)
; RV64-NEXT: vsetivli zero, 8, e8, m1, tu, ma
; RV64-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; RV64-NEXT: vslideup.vi v10, v9, 4
; RV64-NEXT: vsetivli zero, 16, e32, m4, ta, ma
; RV64-NEXT: vzext.vf4 v12, v10
3 changes: 1 addition & 2 deletions llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll
@@ -469,9 +469,8 @@ define <vscale x 6 x half> @extract_nxv6f16_nxv12f16_6(<vscale x 12 x half> %in)
; CHECK: # %bb.0:
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a0, a0, 2
; CHECK-NEXT: vsetvli zero, a0, e16, m1, ta, ma
; CHECK-NEXT: vslidedown.vx v13, v10, a0
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
; CHECK-NEXT: vslidedown.vx v13, v10, a0
; CHECK-NEXT: vslidedown.vx v12, v9, a0
; CHECK-NEXT: add a1, a0, a0
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
