[RISCV] Optimize splats of extracted vector elements
This patch adds an optimization for splat-like operations where the
splatted value is extracted from an identically-sized vector. On RVV we
can splat that value directly via vrgather.vx/vrgather.vi without first
dropping to scalar.

We already have a similar VECTOR_SHUFFLE-specific optimization, but it
only works on fixed-length vector types and only for constant splat
lanes. This patch extends the optimization to scalable-vector types and
to unknown (variable) extract indices.

The optimization is performed during BUILD_VECTOR lowering for
fixed-length vectors and during a new SPLAT_VECTOR DAGCombine for
scalable vectors.
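
For illustration, a minimal sketch of the pattern this targets
(mirroring the splat_idx_* tests updated below; the function name is
illustrative):

define <vscale x 4 x i32> @splat_idx(<vscale x 4 x i32> %v, i64 %idx) {
  ; Splat element %idx of %v: extract, insert into lane 0, then splat
  ; lane 0 with a zeroinitializer shuffle mask (the canonical
  ; scalable-vector splat idiom).
  %x = extractelement <vscale x 4 x i32> %v, i64 %idx
  %ins = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
  %splat = shufflevector <vscale x 4 x i32> %ins, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
  ret <vscale x 4 x i32> %splat
}

Previously this went through scalar (vslidedown.vx, vmv.x.s, vmv.v.x);
with this patch it lowers to a single vrgather.vx of the source vector.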

Reviewed By: craig.topper, khchen

Differential Revision: https://reviews.llvm.org/D118456
frasercrmck committed Feb 8, 2022
1 parent 215aba7 commit 62c4ac7
Showing 4 changed files with 75 additions and 52 deletions.
47 changes: 47 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1086,6 +1086,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine(ISD::SRL);
     setTargetDAGCombine(ISD::SHL);
     setTargetDAGCombine(ISD::STORE);
+    setTargetDAGCombine(ISD::SPLAT_VECTOR);
   }

   setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
@@ -2000,6 +2001,40 @@ static Optional<VIDSequence> isSimpleVIDSequence(SDValue Op) {
   return VIDSequence{*SeqStepNum, *SeqStepDenom, *SeqAddend};
 }

+// Match a splatted value (SPLAT_VECTOR/BUILD_VECTOR) of an EXTRACT_VECTOR_ELT
+// and lower it as a VRGATHER_VX_VL from the source vector.
+static SDValue matchSplatAsGather(SDValue SplatVal, MVT VT, const SDLoc &DL,
+                                  SelectionDAG &DAG,
+                                  const RISCVSubtarget &Subtarget) {
+  if (SplatVal.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+    return SDValue();
+  SDValue Vec = SplatVal.getOperand(0);
+  // Only perform this optimization on vectors of the same size for simplicity.
+  if (Vec.getValueType() != VT)
+    return SDValue();
+  SDValue Idx = SplatVal.getOperand(1);
+  // The index must be a legal type.
+  if (Idx.getValueType() != Subtarget.getXLenVT())
+    return SDValue();
+
+  MVT ContainerVT = VT;
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+    Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
+  }
+
+  SDValue Mask, VL;
+  std::tie(Mask, VL) = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+
+  SDValue Gather = DAG.getNode(RISCVISD::VRGATHER_VX_VL, DL, ContainerVT, Vec,
+                               Idx, Mask, VL);
+
+  if (!VT.isFixedLengthVector())
+    return Gather;
+
+  return convertFromScalableVector(VT, Gather, DAG, Subtarget);
+}
+
 static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   MVT VT = Op.getSimpleValueType();
@@ -2123,6 +2158,8 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
   }

   if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
+    if (auto Gather = matchSplatAsGather(Splat, VT, DL, DAG, Subtarget))
+      return Gather;
     unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL
                                         : RISCVISD::VMV_V_X_VL;
     Splat = DAG.getNode(Opc, DL, ContainerVT, Splat, VL);
@@ -8260,6 +8297,16 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,

     break;
   }
+  case ISD::SPLAT_VECTOR: {
+    EVT VT = N->getValueType(0);
+    // Only perform this combine on legal MVT types.
+    if (!isTypeLegal(VT))
+      break;
+    if (auto Gather = matchSplatAsGather(N->getOperand(0), VT.getSimpleVT(), N,
+                                         DAG, Subtarget))
+      return Gather;
+    break;
+  }
   }

   return SDValue();
12 changes: 4 additions & 8 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-buildvec.ll
@@ -217,11 +217,9 @@ define <4 x half> @splat_c3_v4f16(<4 x half> %v) {
 define <4 x half> @splat_idx_v4f16(<4 x half> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_v4f16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e16, mf2, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vfmv.f.s ft0, v8
 ; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, mu
-; CHECK-NEXT:    vfmv.v.f v8, ft0
+; CHECK-NEXT:    vrgather.vx v9, v8, a0
+; CHECK-NEXT:    vmv1r.v v8, v9
 ; CHECK-NEXT:    ret
   %x = extractelement <4 x half> %v, i64 %idx
   %ins = insertelement <4 x half> poison, half %x, i32 0
@@ -270,11 +268,9 @@ define <8 x float> @splat_idx_v8f32(<8 x float> %v, i64 %idx) {
 ;
 ; LMULMAX2-LABEL: splat_idx_v8f32:
 ; LMULMAX2:       # %bb.0:
-; LMULMAX2-NEXT:    vsetivli zero, 1, e32, m2, ta, mu
-; LMULMAX2-NEXT:    vslidedown.vx v8, v8, a0
-; LMULMAX2-NEXT:    vfmv.f.s ft0, v8
 ; LMULMAX2-NEXT:    vsetivli zero, 8, e32, m2, ta, mu
-; LMULMAX2-NEXT:    vfmv.v.f v8, ft0
+; LMULMAX2-NEXT:    vrgather.vx v10, v8, a0
+; LMULMAX2-NEXT:    vmv.v.v v8, v10
 ; LMULMAX2-NEXT:    ret
   %x = extractelement <8 x float> %v, i64 %idx
   %ins = insertelement <8 x float> poison, float %x, i32 0
12 changes: 4 additions & 8 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-int-buildvec.ll
@@ -665,11 +665,9 @@ define <4 x i32> @splat_c3_v4i32(<4 x i32> %v) {
 define <4 x i32> @splat_idx_v4i32(<4 x i32> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_v4i32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e32, m1, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, mu
-; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vrgather.vx v9, v8, a0
+; CHECK-NEXT:    vmv.v.v v8, v9
 ; CHECK-NEXT:    ret
   %x = extractelement <4 x i32> %v, i64 %idx
   %ins = insertelement <4 x i32> poison, i32 %x, i32 0
@@ -693,11 +691,9 @@ define <8 x i16> @splat_c4_v8i16(<8 x i16> %v) {
 define <8 x i16> @splat_idx_v8i16(<8 x i16> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_v8i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e16, m1, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, mu
-; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vrgather.vx v9, v8, a0
+; CHECK-NEXT:    vmv.v.v v8, v9
 ; CHECK-NEXT:    ret
   %x = extractelement <8 x i16> %v, i64 %idx
   %ins = insertelement <8 x i16> poison, i16 %x, i32 0
56 changes: 20 additions & 36 deletions llvm/test/CodeGen/RISCV/rvv/splat-vectors.ll
@@ -5,11 +5,9 @@
 define <vscale x 4 x i32> @splat_c3_nxv4i32(<vscale x 4 x i32> %v) {
 ; CHECK-LABEL: splat_c3_nxv4i32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e32, m2, ta, mu
-; CHECK-NEXT:    vslidedown.vi v8, v8, 3
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    vsetvli a1, zero, e32, m2, ta, mu
-; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, mu
+; CHECK-NEXT:    vrgather.vi v10, v8, 3
+; CHECK-NEXT:    vmv.v.v v8, v10
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 4 x i32> %v, i32 3
   %ins = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
@@ -20,11 +18,9 @@ define <vscale x 4 x i32> @splat_c3_nxv4i32(<vscale x 4 x i32> %v) {
 define <vscale x 4 x i32> @splat_idx_nxv4i32(<vscale x 4 x i32> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_nxv4i32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e32, m2, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    vsetvli a1, zero, e32, m2, ta, mu
-; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vrgather.vx v10, v8, a0
+; CHECK-NEXT:    vmv.v.v v8, v10
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 4 x i32> %v, i64 %idx
   %ins = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
@@ -35,11 +31,9 @@ define <vscale x 4 x i32> @splat_idx_nxv4i32(<vscale x 4 x i32> %v, i64 %idx) {
 define <vscale x 8 x i16> @splat_c4_nxv8i16(<vscale x 8 x i16> %v) {
 ; CHECK-LABEL: splat_c4_nxv8i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e16, m2, ta, mu
-; CHECK-NEXT:    vslidedown.vi v8, v8, 4
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, mu
-; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vsetvli a0, zero, e16, m2, ta, mu
+; CHECK-NEXT:    vrgather.vi v10, v8, 4
+; CHECK-NEXT:    vmv.v.v v8, v10
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 8 x i16> %v, i32 4
   %ins = insertelement <vscale x 8 x i16> poison, i16 %x, i32 0
@@ -50,11 +44,9 @@ define <vscale x 8 x i16> @splat_c4_nxv8i16(<vscale x 8 x i16> %v) {
 define <vscale x 8 x i16> @splat_idx_nxv8i16(<vscale x 8 x i16> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_nxv8i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e16, m2, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vmv.x.s a0, v8
 ; CHECK-NEXT:    vsetvli a1, zero, e16, m2, ta, mu
-; CHECK-NEXT:    vmv.v.x v8, a0
+; CHECK-NEXT:    vrgather.vx v10, v8, a0
+; CHECK-NEXT:    vmv.v.v v8, v10
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 8 x i16> %v, i64 %idx
   %ins = insertelement <vscale x 8 x i16> poison, i16 %x, i32 0
@@ -65,11 +57,9 @@ define <vscale x 8 x i16> @splat_idx_nxv8i16(<vscale x 8 x i16> %v, i64 %idx) {
 define <vscale x 2 x half> @splat_c1_nxv2f16(<vscale x 2 x half> %v) {
 ; CHECK-LABEL: splat_c1_nxv2f16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e16, mf2, ta, mu
-; CHECK-NEXT:    vslidedown.vi v8, v8, 1
-; CHECK-NEXT:    vfmv.f.s ft0, v8
 ; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, mu
-; CHECK-NEXT:    vfmv.v.f v8, ft0
+; CHECK-NEXT:    vrgather.vi v9, v8, 1
+; CHECK-NEXT:    vmv1r.v v8, v9
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 2 x half> %v, i32 1
   %ins = insertelement <vscale x 2 x half> poison, half %x, i32 0
@@ -80,11 +70,9 @@ define <vscale x 2 x half> @splat_c1_nxv2f16(<vscale x 2 x half> %v) {
 define <vscale x 2 x half> @splat_idx_nxv2f16(<vscale x 2 x half> %v, i64 %idx) {
 ; CHECK-LABEL: splat_idx_nxv2f16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e16, mf2, ta, mu
-; CHECK-NEXT:    vslidedown.vx v8, v8, a0
-; CHECK-NEXT:    vfmv.f.s ft0, v8
-; CHECK-NEXT:    vsetvli a0, zero, e16, mf2, ta, mu
-; CHECK-NEXT:    vfmv.v.f v8, ft0
+; CHECK-NEXT:    vsetvli a1, zero, e16, mf2, ta, mu
+; CHECK-NEXT:    vrgather.vx v9, v8, a0
+; CHECK-NEXT:    vmv1r.v v8, v9
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 2 x half> %v, i64 %idx
   %ins = insertelement <vscale x 2 x half> poison, half %x, i32 0
@@ -95,11 +83,9 @@ define <vscale x 2 x half> @splat_idx_nxv2f16(<vscale x 2 x half> %v, i64 %idx)
 define <vscale x 4 x float> @splat_c3_nxv4f32(<vscale x 4 x float> %v) {
 ; CHECK-LABEL: splat_c3_nxv4f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 1, e32, m2, ta, mu
-; CHECK-NEXT:    vslidedown.vi v8, v8, 3
-; CHECK-NEXT:    vfmv.f.s ft0, v8
 ; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, mu
-; CHECK-NEXT:    vfmv.v.f v8, ft0
+; CHECK-NEXT:    vrgather.vi v10, v8, 3
+; CHECK-NEXT:    vmv.v.v v8, v10
 ; CHECK-NEXT:    ret
   %x = extractelement <vscale x 4 x float> %v, i64 3
   %ins = insertelement <vscale x 4 x float> poison, float %x, i32 0
define <vscale x 4 x float> @splat_idx_nxv4f32(<vscale x 4 x float> %v, i64 %idx) {
; CHECK-LABEL: splat_idx_nxv4f32:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 1, e32, m2, ta, mu
; CHECK-NEXT: vslidedown.vx v8, v8, a0
; CHECK-NEXT: vfmv.f.s ft0, v8
; CHECK-NEXT: vsetvli a0, zero, e32, m2, ta, mu
; CHECK-NEXT: vfmv.v.f v8, ft0
; CHECK-NEXT: vsetvli a1, zero, e32, m2, ta, mu
; CHECK-NEXT: vrgather.vx v10, v8, a0
; CHECK-NEXT: vmv.v.v v8, v10
; CHECK-NEXT: ret
%x = extractelement <vscale x 4 x float> %v, i64 %idx
%ins = insertelement <vscale x 4 x float> poison, float %x, i32 0
