Skip to content

Commit

Permalink
[X86] Add widenMaskVector helper function to remove duplicated code f…
Browse files Browse the repository at this point in the history
…or widening mask vectors for KSHIFT etc.
  • Loading branch information
RKSimon committed Aug 31, 2023
1 parent e87d2d2 commit 81dc54e
Showing 1 changed file with 34 additions and 54 deletions.
88 changes: 34 additions & 54 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3811,6 +3811,25 @@ static SDValue widenSubVector(SDValue Vec, bool ZeroNewElements,
return widenSubVector(VT, Vec, ZeroNewElements, Subtarget, DAG, dl);
}

/// Pick the mask vector type to use when \p VT has to be operated on with
/// KSHIFT or bitcast to an integer type.
///
/// \p VT must be a vector of i1. Types with more than 8 elements, and 8-element
/// types when the subtarget reports DQI, are returned unchanged. Everything
/// else is widened to v8i1 (DQI available) or v16i1 (no DQI).
static MVT widenMaskVectorType(MVT VT, const X86Subtarget &Subtarget) {
  assert(VT.getVectorElementType() == MVT::i1 && "Expected bool vector");
  const bool HasDQI = Subtarget.hasDQI();
  const unsigned EltCount = VT.getVectorNumElements();

  // Already wide enough: nothing to do.
  if (EltCount > 8 || (EltCount == 8 && HasDQI))
    return VT;

  // Fewer than 8 elements, or exactly 8 without DQI: use the minimum width.
  return HasDQI ? MVT::v8i1 : MVT::v16i1;
}

/// Widen the mask vector \p Vec to the type chosen by widenMaskVectorType()
/// (a minimum of v8i1/v16i1) so it can be used with KSHIFT and bitcast to an
/// integer type.
///
/// \p ZeroNewElements, \p Subtarget, \p DAG and \p dl are forwarded unchanged
/// to widenSubVector(), which builds the widened value.
static SDValue widenMaskVector(SDValue Vec, bool ZeroNewElements,
                               const X86Subtarget &Subtarget, SelectionDAG &DAG,
                               const SDLoc &dl) {
  const MVT NarrowVT = Vec.getSimpleValueType();
  const MVT WideVT = widenMaskVectorType(NarrowVT, Subtarget);
  return widenSubVector(WideVT, Vec, ZeroNewElements, Subtarget, DAG, dl);
}

// Helper function to collect subvector ops that are concatenated together,
// either by ISD::CONCAT_VECTORS or a ISD::INSERT_SUBVECTOR series.
// The subvectors in Ops are guaranteed to be the same type.
Expand Down Expand Up @@ -4100,9 +4119,7 @@ static SDValue insert1BitVector(SDValue Op, SelectionDAG &DAG,
SDValue ZeroIdx = DAG.getIntPtrConstant(0, dl);

// Extend to natively supported kshift.
MVT WideOpVT = OpVT;
if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8)
WideOpVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
MVT WideOpVT = widenMaskVectorType(OpVT, Subtarget);

// Inserting into the lsbs of a zero vector is legal. ISel will insert shifts
// if necessary.
Expand Down Expand Up @@ -9008,16 +9025,12 @@ static SDValue LowerCONCAT_VECTORSvXi1(SDValue Op,
// insert_subvector will give us two kshifts.
if (isPowerOf2_64(NonZeros) && Zeros != 0 && NonZeros > Zeros &&
Log2_64(NonZeros) != NumOperands - 1) {
MVT ShiftVT = ResVT;
if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8)
ShiftVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
unsigned Idx = Log2_64(NonZeros);
SDValue SubVec = Op.getOperand(Idx);
unsigned SubVecNumElts = SubVec.getSimpleValueType().getVectorNumElements();
SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ShiftVT,
DAG.getUNDEF(ShiftVT), SubVec,
DAG.getIntPtrConstant(0, dl));
Op = DAG.getNode(X86ISD::KSHIFTL, dl, ShiftVT, SubVec,
MVT ShiftVT = widenMaskVectorType(ResVT, Subtarget);
Op = widenSubVector(ShiftVT, SubVec, false, Subtarget, DAG, dl);
Op = DAG.getNode(X86ISD::KSHIFTL, dl, ShiftVT, Op,
DAG.getTargetConstant(Idx * SubVecNumElts, dl, MVT::i8));
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ResVT, Op,
DAG.getIntPtrConstant(0, dl));
Expand Down Expand Up @@ -17004,13 +17017,8 @@ static SDValue lower1BitShuffleAsKSHIFTR(const SDLoc &DL, ArrayRef<int> Mask,
assert(ShiftAmt >= 0 && "All undef?");

// Great we found a shift right.
MVT WideVT = VT;
if ((!Subtarget.hasDQI() && NumElts == 8) || NumElts < 8)
WideVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideVT,
DAG.getUNDEF(WideVT), V1,
DAG.getIntPtrConstant(0, DL));
Res = DAG.getNode(X86ISD::KSHIFTR, DL, WideVT, Res,
SDValue Res = widenMaskVector(V1, false, Subtarget, DAG, DL);
Res = DAG.getNode(X86ISD::KSHIFTR, DL, Res.getValueType(), Res,
DAG.getTargetConstant(ShiftAmt, DL, MVT::i8));
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
DAG.getIntPtrConstant(0, DL));
Expand Down Expand Up @@ -17107,12 +17115,8 @@ static SDValue lower1BitShuffle(const SDLoc &DL, ArrayRef<int> Mask,
unsigned Opcode;
int ShiftAmt = match1BitShuffleAsKSHIFT(Opcode, Mask, Offset, Zeroable);
if (ShiftAmt >= 0) {
MVT WideVT = VT;
if ((!Subtarget.hasDQI() && NumElts == 8) || NumElts < 8)
WideVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideVT,
DAG.getUNDEF(WideVT), V,
DAG.getIntPtrConstant(0, DL));
SDValue Res = widenMaskVector(V, false, Subtarget, DAG, DL);
MVT WideVT = Res.getSimpleValueType();
// Widened right shifts need two shifts to ensure we shift in zeroes.
if (Opcode == X86ISD::KSHIFTR && WideVT != VT) {
int WideElts = WideVT.getVectorNumElements();
Expand Down Expand Up @@ -17650,17 +17654,9 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG,
// Extending v8i1/v16i1 to 512-bit get better performance on KNL
// than extending to 128/256bit.
if (NumElts == 1) {
if (Subtarget.hasDQI()) {
Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1,
DAG.getUNDEF(MVT::v8i1), Vec,
DAG.getIntPtrConstant(0, dl));
return DAG.getBitcast(MVT::i8, Vec);
}
Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1,
DAG.getUNDEF(MVT::v16i1), Vec,
DAG.getIntPtrConstant(0, dl));
return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8,
DAG.getBitcast(MVT::i16, Vec));
Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
MVT IntVT = MVT::getIntegerVT(Vec.getValueType().getVectorNumElements());
return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, DAG.getBitcast(IntVT, Vec));
}
MVT ExtEltVT = (NumElts <= 8) ? MVT::getIntegerVT(128 / NumElts) : MVT::i8;
MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts);
Expand All @@ -17674,17 +17670,10 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG,
return Op;

// Extend to natively supported kshift.
unsigned NumElems = VecVT.getVectorNumElements();
MVT WideVecVT = VecVT;
if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) {
WideVecVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVecVT,
DAG.getUNDEF(WideVecVT), Vec,
DAG.getIntPtrConstant(0, dl));
}
Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);

// Use kshiftr instruction to move to the lower element.
Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideVecVT, Vec,
Vec = DAG.getNode(X86ISD::KSHIFTR, dl, Vec.getSimpleValueType(), Vec,
DAG.getTargetConstant(IdxVal, dl, MVT::i8));

return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op.getValueType(), Vec,
Expand Down Expand Up @@ -18176,20 +18165,11 @@ static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget,
if (IdxVal == 0) // the operation is legal
return Op;

MVT VecVT = Vec.getSimpleValueType();
unsigned NumElems = VecVT.getVectorNumElements();

// Extend to natively supported kshift.
MVT WideVecVT = VecVT;
if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) {
WideVecVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVecVT,
DAG.getUNDEF(WideVecVT), Vec,
DAG.getIntPtrConstant(0, dl));
}
Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);

// Shift to the LSB.
Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideVecVT, Vec,
Vec = DAG.getNode(X86ISD::KSHIFTR, dl, Vec.getSimpleValueType(), Vec,
DAG.getTargetConstant(IdxVal, dl, MVT::i8));

return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op.getValueType(), Vec,
Expand Down

0 comments on commit 81dc54e

Please sign in to comment.