Skip to content

Commit

Permalink
[RISCV][NFC] Reuse getDeinterleaveViaVNSRL to lower deinterleave intrinsics
Browse files Browse the repository at this point in the history

This modifies getDeinterleaveViaVNSRL to work on both scalable and fixed-length vectors.

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D144584
  • Loading branch information
lukel97 committed Feb 23, 2023
1 parent 8d15e72 commit e340e9e
Showing 1 changed file with 43 additions and 48 deletions.
91 changes: 43 additions & 48 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Expand Up @@ -3113,27 +3113,36 @@ static int isElementRotate(int &LoSrc, int &HiSrc, ArrayRef<int> Mask) {
}

// Lower a deinterleave shuffle to vnsrl.
static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT,
MVT ContainerVT,
SDValue Src, bool EvenElts,
SDValue TrueMask, SDValue VL,
// [a, p, b, q, c, r, d, s] -> [a, b, c, d] (EvenElts == true)
// -> [p, q, r, s] (EvenElts == false)
// VT is the type of the vector to return, <[vscale x ]n x ty>
// Src is the vector to deinterleave of type <[vscale x ]n*2 x ty>
static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT, SDValue Src,
bool EvenElts,
const RISCVSubtarget &Subtarget,
SelectionDAG &DAG) {
// Convert the source using a container type with twice the elements. Since
// source VT is legal and twice this VT, we know VT isn't LMUL=8 so it is
// safe to double.
MVT DoubleContainerVT =
MVT::getVectorVT(ContainerVT.getVectorElementType(),
ContainerVT.getVectorElementCount() * 2);
Src = convertToScalableVector(DoubleContainerVT, Src, DAG, Subtarget);

// Convert the vector to a wider integer type with the original element
// count. This also converts FP to int.
// The result is a vector of type <m x n x ty>
MVT ContainerVT = VT;
// Convert fixed vectors to scalable if needed
if (ContainerVT.isFixedLengthVector()) {
assert(Src.getSimpleValueType().isFixedLengthVector());
ContainerVT = getContainerForFixedLengthVector(DAG, ContainerVT, Subtarget);

// The source is a vector of type <m x n*2 x ty>
MVT SrcContainerVT =
MVT::getVectorVT(ContainerVT.getVectorElementType(),
ContainerVT.getVectorElementCount() * 2);
Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
}

auto [TrueMask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);

// Bitcast the source vector from <m x n*2 x ty> -> <m x n x ty*2>
// This also converts FP to int.
unsigned EltBits = ContainerVT.getScalarSizeInBits();
MVT WideIntContainerVT =
MVT::getVectorVT(MVT::getIntegerVT(EltBits * 2),
ContainerVT.getVectorElementCount());
Src = DAG.getBitcast(WideIntContainerVT, Src);
MVT WideSrcContainerVT = MVT::getVectorVT(
MVT::getIntegerVT(EltBits * 2), ContainerVT.getVectorElementCount());
Src = DAG.getBitcast(WideSrcContainerVT, Src);

// The integer version of the container type.
MVT IntContainerVT = ContainerVT.changeVectorElementTypeToInteger();
Expand All @@ -3150,7 +3159,9 @@ static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT,
// Cast back to FP if needed.
Res = DAG.getBitcast(ContainerVT, Res);

return convertFromScalableVector(VT, Res, DAG, Subtarget);
if (VT.isFixedLengthVector())
Res = convertFromScalableVector(VT, Res, DAG, Subtarget);
return Res;
}

static SDValue
Expand Down Expand Up @@ -3461,9 +3472,12 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
return convertFromScalableVector(VT, Res, DAG, Subtarget);
}

if (isDeinterleaveShuffle(VT, ContainerVT, V1, V2, Mask, Subtarget))
return getDeinterleaveViaVNSRL(DL, VT, ContainerVT, V1.getOperand(0),
Mask[0] == 0, TrueMask, VL, Subtarget, DAG);
// If this is a deinterleave and we can widen the vector, then we can use
// vnsrl to deinterleave.
if (isDeinterleaveShuffle(VT, ContainerVT, V1, V2, Mask, Subtarget)) {
return getDeinterleaveViaVNSRL(DL, VT, V1.getOperand(0), Mask[0] == 0,
Subtarget, DAG);
}

// Detect an interleave shuffle and lower to
// (vmaccu.vx (vwaddu.vx lohalf(V1), lohalf(V2)), lohalf(V2), (2^eltbits - 1))
Expand Down Expand Up @@ -6619,33 +6633,14 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
auto [Mask, VL] = getDefaultScalableVLOps(ConcatVT, DL, DAG, Subtarget);
SDValue Passthru = DAG.getUNDEF(ConcatVT);

// If the element type is smaller than ELEN, then we can deinterleave
// through vnsrl.wi
// We can deinterleave through vnsrl.wi if the element type is smaller than
// ELEN
if (VecVT.getScalarSizeInBits() < Subtarget.getELEN()) {
// Bitcast the concatenated vector from <n x m x ty> -> <n x m / 2 x ty * 2>
// This also casts FPs to ints
MVT WideVT = MVT::getVectorVT(
MVT::getIntegerVT(ConcatVT.getScalarSizeInBits() * 2),
ConcatVT.getVectorElementCount().divideCoefficientBy(2));
SDValue Wide = DAG.getBitcast(WideVT, Concat);

MVT NarrowVT = VecVT.changeVectorElementTypeToInteger();
SDValue Passthru = DAG.getUNDEF(VecVT);

SDValue Even = DAG.getNode(
RISCVISD::VNSRL_VL, DL, NarrowVT, Wide,
DAG.getSplatVector(NarrowVT, DL, DAG.getConstant(0, DL, XLenVT)),
Passthru, Mask, VL);
SDValue Odd = DAG.getNode(
RISCVISD::VNSRL_VL, DL, NarrowVT, Wide,
DAG.getSplatVector(
NarrowVT, DL,
DAG.getConstant(VecVT.getScalarSizeInBits(), DL, XLenVT)),
Passthru, Mask, VL);

// Bitcast the results back in case it was casted from an FP vector
return DAG.getMergeValues(
{DAG.getBitcast(VecVT, Even), DAG.getBitcast(VecVT, Odd)}, DL);
SDValue Even =
getDeinterleaveViaVNSRL(DL, VecVT, Concat, true, Subtarget, DAG);
SDValue Odd =
getDeinterleaveViaVNSRL(DL, VecVT, Concat, false, Subtarget, DAG);
return DAG.getMergeValues({Even, Odd}, DL);
}

// For the indices, use the same SEW to avoid an extra vsetvli
Expand Down

0 comments on commit e340e9e

Please sign in to comment.