diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 0b7a4c1fe5b0d..3828dcf5cbc56 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -3811,6 +3811,25 @@ static SDValue widenSubVector(SDValue Vec, bool ZeroNewElements,
   return widenSubVector(VT, Vec, ZeroNewElements, Subtarget, DAG, dl);
 }
 
+/// Return the mask type to widen \p VT to: v8i1 (with DQI) or v16i1 (without)
+/// if it has fewer elements, else \p VT. Enables KSHIFT and integer bitcasts.
+static MVT widenMaskVectorType(MVT VT, const X86Subtarget &Subtarget) {
+  assert(VT.getVectorElementType() == MVT::i1 && "Expected bool vector");
+  unsigned NumElts = VT.getVectorNumElements();
+  if ((!Subtarget.hasDQI() && NumElts == 8) || NumElts < 8)
+    return Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
+  return VT;
+}
+
+/// Widen mask vector \p Vec to widenMaskVectorType's result so KSHIFT/integer
+/// bitcasts can be used; \p ZeroNewElements is forwarded to widenSubVector.
+static SDValue widenMaskVector(SDValue Vec, bool ZeroNewElements,
+                               const X86Subtarget &Subtarget, SelectionDAG &DAG,
+                               const SDLoc &dl) {
+  MVT VT = widenMaskVectorType(Vec.getSimpleValueType(), Subtarget);
+  return widenSubVector(VT, Vec, ZeroNewElements, Subtarget, DAG, dl);
+}
+
 // Helper function to collect subvector ops that are concatenated together,
 // either by ISD::CONCAT_VECTORS or a ISD::INSERT_SUBVECTOR series.
 // The subvectors in Ops are guaranteed to be the same type.
@@ -4100,9 +4119,7 @@ static SDValue insert1BitVector(SDValue Op, SelectionDAG &DAG,
   SDValue ZeroIdx = DAG.getIntPtrConstant(0, dl);
 
   // Extend to natively supported kshift.
-  MVT WideOpVT = OpVT;
-  if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8)
-    WideOpVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
+  MVT WideOpVT = widenMaskVectorType(OpVT, Subtarget);
 
   // Inserting into the lsbs of a zero vector is legal. ISel will insert shifts
   // if necessary.
@@ -9008,16 +9025,12 @@ static SDValue LowerCONCAT_VECTORSvXi1(SDValue Op,
   // insert_subvector will give us two kshifts.
   if (isPowerOf2_64(NonZeros) && Zeros != 0 && NonZeros > Zeros &&
       Log2_64(NonZeros) != NumOperands - 1) {
-    MVT ShiftVT = ResVT;
-    if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8)
-      ShiftVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
     unsigned Idx = Log2_64(NonZeros);
     SDValue SubVec = Op.getOperand(Idx);
     unsigned SubVecNumElts = SubVec.getSimpleValueType().getVectorNumElements();
-    SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, ShiftVT,
-                         DAG.getUNDEF(ShiftVT), SubVec,
-                         DAG.getIntPtrConstant(0, dl));
-    Op = DAG.getNode(X86ISD::KSHIFTL, dl, ShiftVT, SubVec,
+    MVT ShiftVT = widenMaskVectorType(ResVT, Subtarget);
+    Op = widenSubVector(ShiftVT, SubVec, false, Subtarget, DAG, dl);
+    Op = DAG.getNode(X86ISD::KSHIFTL, dl, ShiftVT, Op,
                      DAG.getTargetConstant(Idx * SubVecNumElts, dl, MVT::i8));
     return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ResVT, Op,
                        DAG.getIntPtrConstant(0, dl));
@@ -17004,13 +17017,8 @@ static SDValue lower1BitShuffleAsKSHIFTR(const SDLoc &DL, ArrayRef Mask,
   assert(ShiftAmt >= 0 && "All undef?");
 
   // Great we found a shift right.
-  MVT WideVT = VT;
-  if ((!Subtarget.hasDQI() && NumElts == 8) || NumElts < 8)
-    WideVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
-  SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideVT,
-                            DAG.getUNDEF(WideVT), V1,
-                            DAG.getIntPtrConstant(0, DL));
-  Res = DAG.getNode(X86ISD::KSHIFTR, DL, WideVT, Res,
+  SDValue Res = widenMaskVector(V1, false, Subtarget, DAG, DL);
+  Res = DAG.getNode(X86ISD::KSHIFTR, DL, Res.getValueType(), Res,
                     DAG.getTargetConstant(ShiftAmt, DL, MVT::i8));
   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
                      DAG.getIntPtrConstant(0, DL));
@@ -17107,12 +17115,8 @@ static SDValue lower1BitShuffle(const SDLoc &DL, ArrayRef Mask,
   unsigned Opcode;
   int ShiftAmt = match1BitShuffleAsKSHIFT(Opcode, Mask, Offset, Zeroable);
   if (ShiftAmt >= 0) {
-    MVT WideVT = VT;
-    if ((!Subtarget.hasDQI() && NumElts == 8) || NumElts < 8)
-      WideVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
-    SDValue Res = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideVT,
-                              DAG.getUNDEF(WideVT), V,
-                              DAG.getIntPtrConstant(0, DL));
+    SDValue Res = widenMaskVector(V, false, Subtarget, DAG, DL);
+    MVT WideVT = Res.getSimpleValueType(); // Compared against VT below.
     // Widened right shifts need two shifts to ensure we shift in zeroes.
     if (Opcode == X86ISD::KSHIFTR && WideVT != VT) {
       int WideElts = WideVT.getVectorNumElements();
@@ -17650,17 +17654,9 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG,
     // Extending v8i1/v16i1 to 512-bit get better performance on KNL
     // than extending to 128/256bit.
     if (NumElts == 1) {
-      if (Subtarget.hasDQI()) {
-        Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v8i1,
-                          DAG.getUNDEF(MVT::v8i1), Vec,
-                          DAG.getIntPtrConstant(0, dl));
-        return DAG.getBitcast(MVT::i8, Vec);
-      }
-      Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, MVT::v16i1,
-                        DAG.getUNDEF(MVT::v16i1), Vec,
-                        DAG.getIntPtrConstant(0, dl));
-      return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8,
-                         DAG.getBitcast(MVT::i16, Vec));
+      Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
+      MVT IntVT = MVT::getIntegerVT(Vec.getValueType().getVectorNumElements());
+      return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, DAG.getBitcast(IntVT, Vec));
     }
     MVT ExtEltVT = (NumElts <= 8) ? MVT::getIntegerVT(128 / NumElts) : MVT::i8;
     MVT ExtVecVT = MVT::getVectorVT(ExtEltVT, NumElts);
@@ -17674,17 +17670,10 @@ static SDValue ExtractBitFromMaskVector(SDValue Op, SelectionDAG &DAG,
     return Op;
 
   // Extend to natively supported kshift.
-  unsigned NumElems = VecVT.getVectorNumElements();
-  MVT WideVecVT = VecVT;
-  if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) {
-    WideVecVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
-    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVecVT,
-                      DAG.getUNDEF(WideVecVT), Vec,
-                      DAG.getIntPtrConstant(0, dl));
-  }
+  Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
 
   // Use kshiftr instruction to move to the lower element.
-  Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideVecVT, Vec,
+  Vec = DAG.getNode(X86ISD::KSHIFTR, dl, Vec.getSimpleValueType(), Vec,
                     DAG.getTargetConstant(IdxVal, dl, MVT::i8));
 
   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op.getValueType(), Vec,
@@ -18176,20 +18165,11 @@ static SDValue LowerEXTRACT_SUBVECTOR(SDValue Op, const X86Subtarget &Subtarget,
   if (IdxVal == 0) // the operation is legal
     return Op;
 
-  MVT VecVT = Vec.getSimpleValueType();
-  unsigned NumElems = VecVT.getVectorNumElements();
-
   // Extend to natively supported kshift.
-  MVT WideVecVT = VecVT;
-  if ((!Subtarget.hasDQI() && NumElems == 8) || NumElems < 8) {
-    WideVecVT = Subtarget.hasDQI() ? MVT::v8i1 : MVT::v16i1;
-    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVecVT,
-                      DAG.getUNDEF(WideVecVT), Vec,
-                      DAG.getIntPtrConstant(0, dl));
-  }
+  Vec = widenMaskVector(Vec, false, Subtarget, DAG, dl);
 
   // Shift to the LSB.
-  Vec = DAG.getNode(X86ISD::KSHIFTR, dl, WideVecVT, Vec,
+  Vec = DAG.getNode(X86ISD::KSHIFTR, dl, Vec.getSimpleValueType(), Vec,
                     DAG.getTargetConstant(IdxVal, dl, MVT::i8));
 
   return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op.getValueType(), Vec,