[RISCV] Recursively split concat_vector into smaller LMULs when lowering (#85825)

This is a reimplementation of the combine added in #83035, but as a
lowering instead of a combine, so we don't regress the test case added
in e59f120 by interfering with the strided load combine.

Previously the combine had to concatenate the split vectors with
insert_subvector instead of concat_vectors to prevent an infinite
combine loop. The reasoning behind keeping it as a combine was that if
we emitted the insert_subvector during lowering, we didn't fold away
inserts of undef subvectors.
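
For reference, a minimal sketch of the shape the old combine had to emit (a
simplified reconstruction, not the verbatim code from #83035; Lo, Hi, and
HalfVT are placeholder names for the split halves and their type):

// Build the result by inserting each half into an undef base vector.
// Because the output is not itself a concat_vectors node, the combine
// cannot re-match its own output, which avoids the infinite loop.
SDValue Vec = DAG.getUNDEF(VT);
Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Vec, Lo,
                  DAG.getVectorIdxConstant(0, DL));
Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Vec, Hi,
                  DAG.getVectorIdxConstant(HalfVT.getVectorMinNumElements(), DL));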

However, it turns out we can avoid this if we do the split during lowering
and emit a concat_vectors directly, since we get the undef folding for
free with `DAG.getNode(ISD::CONCAT_VECTORS, ...)` via foldCONCAT_VECTORS.
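
To illustrate, here is a simplified model of two of the folds that
foldCONCAT_VECTORS performs when getNode is asked for an ISD::CONCAT_VECTORS
node (a sketch, not the full implementation in SelectionDAG.cpp):

// Sketch of the folding we get "for free" from DAG.getNode.
static SDValue foldConcatVectorsSketch(const SDLoc &DL, EVT VT,
                                       ArrayRef<SDValue> Ops,
                                       SelectionDAG &DAG) {
  // Concatenating a single vector is that vector.
  if (Ops.size() == 1)
    return Ops[0];
  // A concat where every operand is undef folds to one wide undef,
  // a fold the insert_subvector chain never got for free.
  if (llvm::all_of(Ops, [](SDValue Op) { return Op.isUndef(); }))
    return DAG.getUNDEF(VT);
  return SDValue(); // No fold applies; a real node gets created.
}
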
lukel97 committed Mar 21, 2024
1 parent dc74bf7 commit 06d2452
Showing 9 changed files with 844 additions and 753 deletions.
24 changes: 24 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -6611,6 +6611,30 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
    // better than going through the stack, as the default expansion does.
    SDLoc DL(Op);
    MVT VT = Op.getSimpleValueType();
    MVT ContainerVT = VT;
    if (VT.isFixedLengthVector())
      ContainerVT = ::getContainerForFixedLengthVector(DAG, VT, Subtarget);

    // Recursively split concat_vectors with more than 2 operands:
    //
    //   concat_vector op1, op2, op3, op4
    // ->
    //   concat_vector (concat_vector op1, op2), (concat_vector op3, op4)
    //
    // This reduces the length of the chain of vslideups and allows us to
    // perform the vslideups at a smaller LMUL, limited to MF2.
    if (Op.getNumOperands() > 2 &&
        ContainerVT.bitsGE(getLMUL1VT(ContainerVT))) {
      MVT HalfVT = VT.getHalfNumVectorElementsVT();
      assert(isPowerOf2_32(Op.getNumOperands()));
      size_t HalfNumOps = Op.getNumOperands() / 2;
      SDValue Lo = DAG.getNode(ISD::CONCAT_VECTORS, DL, HalfVT,
                               Op->ops().take_front(HalfNumOps));
      SDValue Hi = DAG.getNode(ISD::CONCAT_VECTORS, DL, HalfVT,
                               Op->ops().drop_front(HalfNumOps));
      return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
    }

    unsigned NumOpElts =
        Op.getOperand(0).getSimpleValueType().getVectorMinNumElements();
    SDValue Vec = DAG.getUNDEF(VT);
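
From here the lowering falls through to the pre-existing two-operand path,
whose opening lines are shown above; the diff elides the rest, so the
following continuation is an assumption sketched for context, not the
verbatim source:

    // Assumed continuation: insert each operand at its element offset,
    // skipping undef subvectors (inserting undef is a no-op).
    for (unsigned I = 0, E = Op.getNumOperands(); I != E; ++I) {
      SDValue SubVec = Op.getOperand(I);
      if (SubVec.isUndef())
        continue;
      Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Vec, SubVec,
                        DAG.getVectorIdxConstant(I * NumOpElts, DL));
    }
    return Vec;
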
91 changes: 45 additions & 46 deletions llvm/test/CodeGen/RISCV/rvv/active_lane_mask.ll
@@ -161,72 +161,71 @@ define <64 x i1> @fv64(ptr %p, i64 %index, i64 %tc) {
define <128 x i1> @fv128(ptr %p, i64 %index, i64 %tc) {
; CHECK-LABEL: fv128:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: lui a0, %hi(.LCPI10_0)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_0)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vid.v v16
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v0, v16, a2
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 4, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 2
; CHECK-NEXT: lui a0, %hi(.LCPI10_1)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_1)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 6, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 4
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v10, v16, a2
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v8, v16, a2
; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma
; CHECK-NEXT: vslideup.vi v8, v10, 2
; CHECK-NEXT: lui a0, %hi(.LCPI10_2)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_2)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 8, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 6
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma
; CHECK-NEXT: vslideup.vi v8, v9, 4
; CHECK-NEXT: lui a0, %hi(.LCPI10_3)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_3)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 10, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 8
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vslideup.vi v8, v9, 6
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: lui a0, %hi(.LCPI10_4)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_4)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 12, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 10
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vid.v v16
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v0, v16, a2
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma
; CHECK-NEXT: vslideup.vi v0, v9, 2
; CHECK-NEXT: lui a0, %hi(.LCPI10_5)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_5)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetivli zero, 14, e8, m1, tu, ma
; CHECK-NEXT: vslideup.vi v0, v16, 12
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma
; CHECK-NEXT: vslideup.vi v0, v9, 4
; CHECK-NEXT: lui a0, %hi(.LCPI10_6)
; CHECK-NEXT: addi a0, a0, %lo(.LCPI10_6)
; CHECK-NEXT: vsetivli zero, 16, e64, m8, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vsext.vf8 v16, v8
; CHECK-NEXT: vsaddu.vx v8, v16, a1
; CHECK-NEXT: vmsltu.vx v16, v8, a2
; CHECK-NEXT: vsetvli zero, zero, e8, m1, ta, ma
; CHECK-NEXT: vslideup.vi v0, v16, 14
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vsext.vf8 v16, v9
; CHECK-NEXT: vsaddu.vx v16, v16, a1
; CHECK-NEXT: vmsltu.vx v9, v16, a2
; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; CHECK-NEXT: vslideup.vi v0, v9, 6
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT: vslideup.vi v0, v8, 8
; CHECK-NEXT: ret
%mask = call <128 x i1> @llvm.get.active.lane.mask.v128i1.i64(i64 %index, i64 %tc)
ret <128 x i1> %mask
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/RISCV/rvv/combine-store-extract-crash.ll
@@ -19,7 +19,7 @@ define void @test(ptr %ref_array, ptr %sad_array) {
; RV32-NEXT: th.swia a0, (a1), 4, 0
; RV32-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
; RV32-NEXT: vle8.v v10, (a3)
; RV32-NEXT: vsetivli zero, 8, e8, m1, tu, ma
; RV32-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; RV32-NEXT: vslideup.vi v10, v9, 4
; RV32-NEXT: vsetivli zero, 16, e32, m4, ta, ma
; RV32-NEXT: vzext.vf4 v12, v10
@@ -42,7 +42,7 @@ define void @test(ptr %ref_array, ptr %sad_array) {
; RV64-NEXT: th.swia a0, (a1), 4, 0
; RV64-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
; RV64-NEXT: vle8.v v10, (a3)
; RV64-NEXT: vsetivli zero, 8, e8, m1, tu, ma
; RV64-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
; RV64-NEXT: vslideup.vi v10, v9, 4
; RV64-NEXT: vsetivli zero, 16, e32, m4, ta, ma
; RV64-NEXT: vzext.vf4 v12, v10
3 changes: 1 addition & 2 deletions llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll
@@ -469,9 +469,8 @@ define <vscale x 6 x half> @extract_nxv6f16_nxv12f16_6(<vscale x 12 x half> %in)
; CHECK: # %bb.0:
; CHECK-NEXT: csrr a0, vlenb
; CHECK-NEXT: srli a0, a0, 2
; CHECK-NEXT: vsetvli zero, a0, e16, m1, ta, ma
; CHECK-NEXT: vslidedown.vx v13, v10, a0
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
; CHECK-NEXT: vslidedown.vx v13, v10, a0
; CHECK-NEXT: vslidedown.vx v12, v9, a0
; CHECK-NEXT: add a1, a0, a0
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma