Patch summary: moves the static helper combineToExtendBoolVectorInReg in
llvm/lib/Target/X86/X86ISelLowering.cpp from just after combineToExtendCMOV
to just after combineExtractVectorElt (ahead of the vector-select combines).
The moved body is unchanged apart from the wrapping of one BroadcastVT
statement, so no functional change is expected.

NOTE(review): the template arguments of the two SmallVector declarations were
missing from the copy of this patch under review. The element types (int for
the shuffle mask, SDValue for the per-element bit constants) follow from how
the vectors are used; the inline capacity of 32 is assumed - confirm it against
the pre-image before applying. Leading whitespace was likewise reconstructed
from LLVM's formatting conventions.

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 84c7ff58ae9b0..e91f68425522f 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -43123,6 +43123,104 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
+// This is more or less the reverse of combineBitcastvxi1.
+static SDValue combineToExtendBoolVectorInReg(
+    unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N0, SelectionDAG &DAG,
+    TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) {
+  if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND &&
+      Opcode != ISD::ANY_EXTEND)
+    return SDValue();
+  if (!DCI.isBeforeLegalizeOps())
+    return SDValue();
+  if (!Subtarget.hasSSE2() || Subtarget.hasAVX512())
+    return SDValue();
+
+  EVT SVT = VT.getScalarType();
+  EVT InSVT = N0.getValueType().getScalarType();
+  unsigned EltSizeInBits = SVT.getSizeInBits();
+
+  // Input type must be extending a bool vector (bit-casted from a scalar
+  // integer) to legal integer types.
+  if (!VT.isVector())
+    return SDValue();
+  if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8)
+    return SDValue();
+  if (InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST)
+    return SDValue();
+
+  SDValue N00 = N0.getOperand(0);
+  EVT SclVT = N00.getValueType();
+  if (!SclVT.isScalarInteger())
+    return SDValue();
+
+  SDValue Vec;
+  SmallVector<int, 32> ShuffleMask;
+  unsigned NumElts = VT.getVectorNumElements();
+  assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size");
+
+  // Broadcast the scalar integer to the vector elements.
+  if (NumElts > EltSizeInBits) {
+    // If the scalar integer is greater than the vector element size, then we
+    // must split it down into sub-sections for broadcasting. For example:
+    //   i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections.
+    //   i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections.
+    assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale");
+    unsigned Scale = NumElts / EltSizeInBits;
+    EVT BroadcastVT = EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits);
+    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
+    Vec = DAG.getBitcast(VT, Vec);
+
+    for (unsigned i = 0; i != Scale; ++i)
+      ShuffleMask.append(EltSizeInBits, i);
+    Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
+  } else if (Subtarget.hasAVX2() && NumElts < EltSizeInBits &&
+             (SclVT == MVT::i8 || SclVT == MVT::i16 || SclVT == MVT::i32)) {
+    // If we have register broadcast instructions, use the scalar size as the
+    // element type for the shuffle. Then cast to the wider element type. The
+    // widened bits won't be used, and this might allow the use of a broadcast
+    // load.
+    assert((EltSizeInBits % NumElts) == 0 && "Unexpected integer scale");
+    unsigned Scale = EltSizeInBits / NumElts;
+    EVT BroadcastVT =
+        EVT::getVectorVT(*DAG.getContext(), SclVT, NumElts * Scale);
+    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
+    ShuffleMask.append(NumElts * Scale, 0);
+    Vec = DAG.getVectorShuffle(BroadcastVT, DL, Vec, Vec, ShuffleMask);
+    Vec = DAG.getBitcast(VT, Vec);
+  } else {
+    // For smaller scalar integers, we can simply any-extend it to the vector
+    // element size (we don't care about the upper bits) and broadcast it to all
+    // elements.
+    SDValue Scl = DAG.getAnyExtOrTrunc(N00, DL, SVT);
+    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl);
+    ShuffleMask.append(NumElts, 0);
+    Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
+  }
+
+  // Now, mask the relevant bit in each element.
+  SmallVector<SDValue, 32> Bits;
+  for (unsigned i = 0; i != NumElts; ++i) {
+    int BitIdx = (i % EltSizeInBits);
+    APInt Bit = APInt::getBitsSet(EltSizeInBits, BitIdx, BitIdx + 1);
+    Bits.push_back(DAG.getConstant(Bit, DL, SVT));
+  }
+  SDValue BitMask = DAG.getBuildVector(VT, DL, Bits);
+  Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask);
+
+  // Compare against the bitmask and extend the result.
+  EVT CCVT = VT.changeVectorElementType(MVT::i1);
+  Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ);
+  Vec = DAG.getSExtOrTrunc(Vec, DL, VT);
+
+  // For SEXT, this is now done, otherwise shift the result down for
+  // zero-extension.
+  if (Opcode == ISD::SIGN_EXTEND)
+    return Vec;
+  return DAG.getNode(ISD::SRL, DL, VT, Vec,
+                     DAG.getConstant(EltSizeInBits - 1, DL, VT));
+}
+
 /// If a vector select has an operand that is -1 or 0, try to simplify the
 /// select to a bitwise logic operation.
 /// TODO: Move to DAGCombiner, possibly using TargetLowering::hasAndNot()?
@@ -50420,105 +50518,6 @@ static SDValue combineToExtendCMOV(SDNode *Extend, SelectionDAG &DAG) {
   return Res;
 }
 
-// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
-// This is more or less the reverse of combineBitcastvxi1.
-static SDValue combineToExtendBoolVectorInReg(
-    unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N0, SelectionDAG &DAG,
-    TargetLowering::DAGCombinerInfo &DCI, const X86Subtarget &Subtarget) {
-  if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND &&
-      Opcode != ISD::ANY_EXTEND)
-    return SDValue();
-  if (!DCI.isBeforeLegalizeOps())
-    return SDValue();
-  if (!Subtarget.hasSSE2() || Subtarget.hasAVX512())
-    return SDValue();
-
-  EVT SVT = VT.getScalarType();
-  EVT InSVT = N0.getValueType().getScalarType();
-  unsigned EltSizeInBits = SVT.getSizeInBits();
-
-  // Input type must be extending a bool vector (bit-casted from a scalar
-  // integer) to legal integer types.
-  if (!VT.isVector())
-    return SDValue();
-  if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8)
-    return SDValue();
-  if (InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST)
-    return SDValue();
-
-  SDValue N00 = N0.getOperand(0);
-  EVT SclVT = N00.getValueType();
-  if (!SclVT.isScalarInteger())
-    return SDValue();
-
-  SDValue Vec;
-  SmallVector<int, 32> ShuffleMask;
-  unsigned NumElts = VT.getVectorNumElements();
-  assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size");
-
-  // Broadcast the scalar integer to the vector elements.
-  if (NumElts > EltSizeInBits) {
-    // If the scalar integer is greater than the vector element size, then we
-    // must split it down into sub-sections for broadcasting. For example:
-    //   i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections.
-    //   i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections.
-    assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale");
-    unsigned Scale = NumElts / EltSizeInBits;
-    EVT BroadcastVT =
-        EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits);
-    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
-    Vec = DAG.getBitcast(VT, Vec);
-
-    for (unsigned i = 0; i != Scale; ++i)
-      ShuffleMask.append(EltSizeInBits, i);
-    Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
-  } else if (Subtarget.hasAVX2() && NumElts < EltSizeInBits &&
-             (SclVT == MVT::i8 || SclVT == MVT::i16 || SclVT == MVT::i32)) {
-    // If we have register broadcast instructions, use the scalar size as the
-    // element type for the shuffle. Then cast to the wider element type. The
-    // widened bits won't be used, and this might allow the use of a broadcast
-    // load.
-    assert((EltSizeInBits % NumElts) == 0 && "Unexpected integer scale");
-    unsigned Scale = EltSizeInBits / NumElts;
-    EVT BroadcastVT =
-        EVT::getVectorVT(*DAG.getContext(), SclVT, NumElts * Scale);
-    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
-    ShuffleMask.append(NumElts * Scale, 0);
-    Vec = DAG.getVectorShuffle(BroadcastVT, DL, Vec, Vec, ShuffleMask);
-    Vec = DAG.getBitcast(VT, Vec);
-  } else {
-    // For smaller scalar integers, we can simply any-extend it to the vector
-    // element size (we don't care about the upper bits) and broadcast it to all
-    // elements.
-    SDValue Scl = DAG.getAnyExtOrTrunc(N00, DL, SVT);
-    Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, VT, Scl);
-    ShuffleMask.append(NumElts, 0);
-    Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
-  }
-
-  // Now, mask the relevant bit in each element.
-  SmallVector<SDValue, 32> Bits;
-  for (unsigned i = 0; i != NumElts; ++i) {
-    int BitIdx = (i % EltSizeInBits);
-    APInt Bit = APInt::getBitsSet(EltSizeInBits, BitIdx, BitIdx + 1);
-    Bits.push_back(DAG.getConstant(Bit, DL, SVT));
-  }
-  SDValue BitMask = DAG.getBuildVector(VT, DL, Bits);
-  Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask);
-
-  // Compare against the bitmask and extend the result.
-  EVT CCVT = VT.changeVectorElementType(MVT::i1);
-  Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ);
-  Vec = DAG.getSExtOrTrunc(Vec, DL, VT);
-
-  // For SEXT, this is now done, otherwise shift the result down for
-  // zero-extension.
-  if (Opcode == ISD::SIGN_EXTEND)
-    return Vec;
-  return DAG.getNode(ISD::SRL, DL, VT, Vec,
-                     DAG.getConstant(EltSizeInBits - 1, DL, VT));
-}
-
 // Attempt to combine a (sext/zext (setcc)) to a setcc with a xmm/ymm/zmm
 // result type.
 static SDValue combineExtSetcc(SDNode *N, SelectionDAG &DAG,