Skip to content

Commit

Permalink
[RISCV] Teach lowerVECTOR_SHUFFLE to recognize some shuffles as vnsrl.
Browse files Browse the repository at this point in the history
Unary shuffles such as <0,2,4,6,8,10,12,14> or <1,3,5,7,9,11,13,15>,
where only half of the elements are returned, can be lowered using vnsrl.

SelectionDAGBuilder lowers such shuffles as a build_vector of
extract_elements since the mask has fewer elements than the source.
To fix this, I've enabled the isExtractSubvectorCheap hook to allow
DAGCombine to rebuild the shuffle using 2 extract_subvectors preceding
the shuffle.

I've gone very conservative on isExtractSubvectorCheap to minimize
test impact and match what we have test coverage for. This can be
improved in the future.

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D133736
  • Loading branch information
topperc committed Sep 13, 2022
1 parent 8989aa0 commit 8d7e73e
Show file tree
Hide file tree
Showing 5 changed files with 529 additions and 16 deletions.
118 changes: 118 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1398,6 +1398,39 @@ bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
return Imm.isZero();
}

// Decide whether DAGCombine may treat an EXTRACT_SUBVECTOR of ResVT from
// SrcVT at Index as cheap. TODO: This is very conservative.
bool RISCVTargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
                                                  unsigned Index) const {
  // The extract itself must be something we can lower.
  if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
    return false;

  // Only support extracting a fixed vector from a fixed vector for now.
  if (SrcVT.isScalableVector() || ResVT.isScalableVector())
    return false;

  // Conservatively only handle extracting half of a vector.
  // TODO: Relax this.
  unsigned NumResElts = ResVT.getVectorNumElements();
  if (SrcVT.getVectorNumElements() != NumResElts * 2)
    return false;

  // The smallest type we can slide is i8, so reject mask vectors.
  // TODO: We can extract index 0 from a mask vector without a slide.
  if (ResVT.getVectorElementType() == MVT::i1)
    return false;

  // A slide supports an arbitrary index, but only vslidedown.vi (an
  // immediate index below 32) is treated as cheap.
  if (Index >= 32)
    return false;

  // TODO: We can do arbitrary slidedowns, but for now only support extracting
  // the lower or upper half of a vector until we have more test coverage.
  return Index == 0 || Index == NumResElts;
}

bool RISCVTargetLowering::hasBitPreservingFPLogic(EVT VT) const {
return (VT == MVT::f16 && Subtarget.hasStdExtZfh()) ||
(VT == MVT::f32 && Subtarget.hasStdExtF()) ||
Expand Down Expand Up @@ -2629,6 +2662,86 @@ static int isElementRotate(int &LoSrc, int &HiSrc, ArrayRef<int> Mask) {
return Rotation;
}

// Lower the following shuffles to vnsrl.
//   t34: v8i8 = extract_subvector t11, Constant:i64<0>
//   t33: v8i8 = extract_subvector t11, Constant:i64<8>
//   a) t35: v8i8 = vector_shuffle<0,2,4,6,8,10,12,14> t34, t33
//   b) t35: v8i8 = vector_shuffle<1,3,5,7,9,11,13,15> t34, t33
// Returns SDValue() if the shuffle does not match this pattern.
static SDValue lowerVECTOR_SHUFFLEAsVNSRL(const SDLoc &DL, MVT VT,
                                          MVT ContainerVT, SDValue V1,
                                          SDValue V2, SDValue TrueMask,
                                          SDValue VL, ArrayRef<int> Mask,
                                          const RISCVSubtarget &Subtarget,
                                          SelectionDAG &DAG) {
  // Need to be able to widen the vector.
  if (VT.getScalarSizeInBits() >= Subtarget.getELEN())
    return SDValue();

  // Both inputs must be extracts.
  if (V1.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
      V2.getOpcode() != ISD::EXTRACT_SUBVECTOR)
    return SDValue();

  // Both must extract from the same source vector.
  SDValue Src = V1.getOperand(0);
  if (Src != V2.getOperand(0))
    return SDValue();

  // Src needs to have twice the number of elements.
  if (Src.getValueType().getVectorNumElements() != (Mask.size() * 2))
    return SDValue();

  // The extracts must extract the two halves of the source.
  if (V1.getConstantOperandVal(1) != 0 ||
      V2.getConstantOperandVal(1) != Mask.size())
    return SDValue();

  // First index must be the first even or odd element from V1.
  if (Mask[0] != 0 && Mask[0] != 1)
    return SDValue();

  // The others must increase by 2 each time (so we pick every other element
  // of the concatenated halves, i.e. all even or all odd source elements).
  // TODO: Support undef elements?
  for (unsigned i = 1; i != Mask.size(); ++i)
    if (Mask[i] != Mask[i - 1] + 2)
      return SDValue();

  // Convert the source using a container type with twice the elements. Since
  // source VT is legal and twice this VT, we know VT isn't LMUL=8 so it is
  // safe to double.
  MVT DoubleContainerVT =
      MVT::getVectorVT(ContainerVT.getVectorElementType(),
                       ContainerVT.getVectorElementCount() * 2);
  Src = convertToScalableVector(DoubleContainerVT, Src, DAG, Subtarget);

  // Convert the vector to a wider integer type with the original element
  // count. This also converts FP to int.
  unsigned EltBits = ContainerVT.getScalarSizeInBits();
  MVT WideIntEltVT = MVT::getIntegerVT(EltBits * 2);
  MVT WideIntContainerVT =
      MVT::getVectorVT(WideIntEltVT, ContainerVT.getVectorElementCount());
  Src = DAG.getBitcast(WideIntContainerVT, Src);

  // Convert to the integer version of the container type.
  MVT IntEltVT = MVT::getIntegerVT(EltBits);
  MVT IntContainerVT =
      MVT::getVectorVT(IntEltVT, ContainerVT.getVectorElementCount());

  // If we want even elements, then the shift amount is 0. Otherwise, shift by
  // the original element size.
  unsigned Shift = Mask[0] == 0 ? 0 : EltBits;
  // NOTE: the splat's passthru undef must be IntContainerVT, the node's
  // result type, not ContainerVT, which is an FP type for FP shuffles (the
  // VNSRL_VL node below already uses IntContainerVT for its undef operand).
  SDValue SplatShift = DAG.getNode(
      RISCVISD::VMV_V_X_VL, DL, IntContainerVT, DAG.getUNDEF(IntContainerVT),
      DAG.getConstant(Shift, DL, Subtarget.getXLenVT()), VL);
  SDValue Res =
      DAG.getNode(RISCVISD::VNSRL_VL, DL, IntContainerVT, Src, SplatShift,
                  DAG.getUNDEF(IntContainerVT), TrueMask, VL);
  // Cast back to FP if needed.
  Res = DAG.getBitcast(ContainerVT, Res);

  return convertFromScalableVector(VT, Res, DAG, Subtarget);
}

static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
SDValue V1 = Op.getOperand(0);
Expand Down Expand Up @@ -2760,6 +2873,10 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
return convertFromScalableVector(VT, Res, DAG, Subtarget);
}

if (SDValue V = lowerVECTOR_SHUFFLEAsVNSRL(
DL, VT, ContainerVT, V1, V2, TrueMask, VL, Mask, Subtarget, DAG))
return V;

// Detect an interleave shuffle and lower to
// (vmaccu.vx (vwaddu.vx lohalf(V1), lohalf(V2)), lohalf(V2), (2^eltbits - 1))
bool SwapSources;
Expand Down Expand Up @@ -12259,6 +12376,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VWADDU_W_VL)
NODE_NAME_CASE(VWSUB_W_VL)
NODE_NAME_CASE(VWSUBU_W_VL)
NODE_NAME_CASE(VNSRL_VL)
NODE_NAME_CASE(SETCC_VL)
NODE_NAME_CASE(VSELECT_VL)
NODE_NAME_CASE(VP_MERGE_VL)
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ enum NodeType : unsigned {
VWSUB_W_VL,
VWSUBU_W_VL,

VNSRL_VL,

// Vector compare producing a mask. Fourth operand is input mask. Fifth
// operand is VL.
SETCC_VL,
Expand Down Expand Up @@ -386,6 +388,8 @@ class RISCVTargetLowering : public TargetLowering {
bool isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const override;
bool isFPImmLegal(const APFloat &Imm, EVT VT,
bool ForCodeSize) const override;
bool isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
unsigned Index) const override;

bool isIntDivCheap(EVT VT, AttributeList Attr) const override;

Expand Down
30 changes: 30 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,16 @@ def riscv_vwaddu_vl : SDNode<"RISCVISD::VWADDU_VL", SDT_RISCVVWBinOp_VL, [SDNPCo
def riscv_vwsub_vl : SDNode<"RISCVISD::VWSUB_VL", SDT_RISCVVWBinOp_VL, [SDNPCommutative]>;
def riscv_vwsubu_vl : SDNode<"RISCVISD::VWSUBU_VL", SDT_RISCVVWBinOp_VL, [SDNPCommutative]>;

// Type profile for narrowing binary VL nodes (e.g. VNSRL_VL): the result's
// elements are narrower than operand 1 (the wide source, same element count);
// operands 2 and 3 (shift-amount splat and passthru) have the result type;
// operand 4 is the i1 mask with the result's element count; operand 5 is VL.
def SDT_RISCVVNBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>,
                                               SDTCisSameNumEltsAs<0, 1>,
                                               SDTCisOpSmallerThanOp<0, 1>,
                                               SDTCisSameAs<0, 2>,
                                               SDTCisSameAs<0, 3>,
                                               SDTCisSameNumEltsAs<0, 4>,
                                               SDTCVecEltisVT<4, i1>,
                                               SDTCisVT<5, XLenVT>]>;
// Narrowing logical shift right with explicit mask and VL operands.
def riscv_vnsrl_vl : SDNode<"RISCVISD::VNSRL_VL", SDT_RISCVVNBinOp_VL>;

def SDT_RISCVVWBinOpW_VL : SDTypeProfile<1, 5, [SDTCisVec<0>,
SDTCisSameAs<0, 1>,
SDTCisSameNumEltsAs<1, 2>,
Expand Down Expand Up @@ -446,6 +456,24 @@ multiclass VPatBinaryWVL_VV_VX_WV_WX<SDNode vop, SDNode vop_w,
}
}

// Selection patterns for a narrowing binary VL node: for every widenable
// integer vector type pair, match the wide-source/narrow-result node onto the
// instruction's WV (vector), WX (scalar splat) and WI (uimm5 splat) forms.
multiclass VPatBinaryNVL_WV_WX_WI<SDNode vop, string instruction_name> {
  foreach VtiToWti = AllWidenableIntVectors in {
    defvar vti = VtiToWti.Vti;
    defvar wti = VtiToWti.Wti;
    defm : VPatBinaryVL_V<vop, instruction_name, "WV",
                          vti.Vector, wti.Vector, vti.Vector, vti.Mask,
                          vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
    defm : VPatBinaryVL_XI<vop, instruction_name, "WX",
                           vti.Vector, wti.Vector, vti.Vector, vti.Mask,
                           vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
    defm : VPatBinaryVL_XI<vop, instruction_name, "WI",
                           vti.Vector, wti.Vector, vti.Vector, vti.Mask,
                           vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
                           !cast<ComplexPattern>(SplatPat#_#uimm5),
                           uimm5>;
  }
}

multiclass VPatBinaryVL_VF<SDNode vop,
string instruction_name,
ValueType result_type,
Expand Down Expand Up @@ -1110,6 +1138,8 @@ defm : VPatNarrowShiftSplatExt_WX<riscv_sra_vl, riscv_zext_vl_oneuse, "PseudoVNS
defm : VPatNarrowShiftSplatExt_WX<riscv_srl_vl, riscv_sext_vl_oneuse, "PseudoVNSRL">;
defm : VPatNarrowShiftSplatExt_WX<riscv_srl_vl, riscv_zext_vl_oneuse, "PseudoVNSRL">;

defm : VPatBinaryNVL_WV_WX_WI<riscv_vnsrl_vl, "PseudoVNSRL">;

foreach vtiTowti = AllWidenableIntVectors in {
defvar vti = vtiTowti.Vti;
defvar wti = vtiTowti.Wti;
Expand Down
26 changes: 10 additions & 16 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-buildvec.ll
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,16 @@ define <4 x float> @hang_when_merging_stores_after_legalization(<8 x float> %x,
;
; LMULMAX2-LABEL: hang_when_merging_stores_after_legalization:
; LMULMAX2: # %bb.0:
; LMULMAX2-NEXT: addi sp, sp, -16
; LMULMAX2-NEXT: .cfi_def_cfa_offset 16
; LMULMAX2-NEXT: addi a0, sp, 8
; LMULMAX2-NEXT: vsetivli zero, 1, e32, m2, ta, mu
; LMULMAX2-NEXT: vse32.v v10, (a0)
; LMULMAX2-NEXT: mv a0, sp
; LMULMAX2-NEXT: vse32.v v8, (a0)
; LMULMAX2-NEXT: vslidedown.vi v10, v10, 7
; LMULMAX2-NEXT: addi a1, sp, 12
; LMULMAX2-NEXT: vse32.v v10, (a1)
; LMULMAX2-NEXT: vslidedown.vi v8, v8, 7
; LMULMAX2-NEXT: addi a1, sp, 4
; LMULMAX2-NEXT: vse32.v v8, (a1)
; LMULMAX2-NEXT: vsetivli zero, 4, e32, m1, ta, mu
; LMULMAX2-NEXT: vle32.v v8, (a0)
; LMULMAX2-NEXT: addi sp, sp, 16
; LMULMAX2-NEXT: vsetivli zero, 8, e32, m2, ta, mu
; LMULMAX2-NEXT: vid.v v12
; LMULMAX2-NEXT: li a0, 7
; LMULMAX2-NEXT: vmul.vx v14, v12, a0
; LMULMAX2-NEXT: vrgather.vv v12, v8, v14
; LMULMAX2-NEXT: li a0, 12
; LMULMAX2-NEXT: vmv.s.x v0, a0
; LMULMAX2-NEXT: vadd.vi v8, v14, -14
; LMULMAX2-NEXT: vrgather.vv v12, v10, v8, v0.t
; LMULMAX2-NEXT: vmv1r.v v8, v12
; LMULMAX2-NEXT: ret
%z = shufflevector <8 x float> %x, <8 x float> %y, <4 x i32> <i32 0, i32 7, i32 8, i32 15>
ret <4 x float> %z
Expand Down
Loading

0 comments on commit 8d7e73e

Please sign in to comment.