diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 15b59afabc61f..4a6e193e09e6e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -50936,10 +50936,12 @@ static SDValue combineAndShuffleNot(SDNode *N, SelectionDAG &DAG,
 // Given a target type \p VT, we generate
 // or (and x, y), (xor z, zext(build_vector (constants)))
 // given x, y and z are of type \p VT. We can do so, if operands are either
-// truncates from VT types, the second operand is a vector of constants or can
-// be recursively promoted.
+// truncates from VT types, the second operand is a vector of constants, can
+// be recursively promoted or is an existing extension we can extend further.
 static SDValue PromoteMaskArithmetic(SDValue N, const SDLoc &DL, EVT VT,
-                                     SelectionDAG &DAG, unsigned Depth) {
+                                     SelectionDAG &DAG,
+                                     const X86Subtarget &Subtarget,
+                                     unsigned Depth) {
   // Limit recursion to avoid excessive compile times.
   if (Depth >= SelectionDAG::MaxRecursionDepth)
     return SDValue();
@@ -50954,7 +50956,8 @@ static SDValue PromoteMaskArithmetic(SDValue N, const SDLoc &DL, EVT VT,
   if (!TLI.isOperationLegalOrPromote(N.getOpcode(), VT))
     return SDValue();
 
-  if (SDValue NN0 = PromoteMaskArithmetic(N0, DL, VT, DAG, Depth + 1))
+  if (SDValue NN0 =
+          PromoteMaskArithmetic(N0, DL, VT, DAG, Subtarget, Depth + 1))
     N0 = NN0;
   else {
     // The left side has to be a 'trunc'.
@@ -50966,14 +50969,19 @@ static SDValue PromoteMaskArithmetic(SDValue N, const SDLoc &DL, EVT VT,
       return SDValue();
   }
 
-  if (SDValue NN1 = PromoteMaskArithmetic(N1, DL, VT, DAG, Depth + 1))
+  if (SDValue NN1 =
+          PromoteMaskArithmetic(N1, DL, VT, DAG, Subtarget, Depth + 1))
     N1 = NN1;
   else {
-    // The right side has to be a 'trunc' or a (foldable) constant.
+    // The right side has to be a 'trunc', a (foldable) constant or an
+    // existing extension we can extend further.
     bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
                     N1.getOperand(0).getValueType() == VT;
     if (RHSTrunc)
       N1 = N1.getOperand(0);
+    else if (ISD::isExtVecInRegOpcode(N1.getOpcode()) && VT.is256BitVector() &&
+             Subtarget.hasInt256() && N1.hasOneUse())
+      N1 = DAG.getNode(N1.getOpcode(), DL, VT, N1.getOperand(0));
     else if (SDValue Cst =
                  DAG.FoldConstantArithmetic(ISD::ZERO_EXTEND, DL, VT, {N1}))
       N1 = Cst;
@@ -51003,7 +51011,7 @@ static SDValue PromoteMaskArithmetic(SDValue N, const SDLoc &DL,
   EVT NarrowVT = Narrow.getValueType();
 
   // Generate the wide operation.
-  SDValue Op = PromoteMaskArithmetic(Narrow, DL, VT, DAG, 0);
+  SDValue Op = PromoteMaskArithmetic(Narrow, DL, VT, DAG, Subtarget, 0);
   if (!Op)
     return SDValue();
   switch (N.getOpcode()) {
diff --git a/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll b/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll
index 9e1686e19ce1b..474be4465d9b7 100644
--- a/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll
+++ b/llvm/test/CodeGen/X86/bitcast-int-to-vector-bool-sext.ll
@@ -733,11 +733,8 @@ define <8 x i32> @PR157382(ptr %p0, ptr %p1, ptr %p2) {
 ; AVX2-NEXT:    vpcmpeqb %xmm3, %xmm2, %xmm2
 ; AVX2-NEXT:    vpcmpeqd %xmm3, %xmm3, %xmm3
 ; AVX2-NEXT:    vpxor %xmm3, %xmm2, %xmm2
-; AVX2-NEXT:    vpmovsxbw %xmm2, %xmm2
-; AVX2-NEXT:    vextracti128 $1, %ymm1, %xmm3
-; AVX2-NEXT:    vpackssdw %xmm3, %xmm1, %xmm1
-; AVX2-NEXT:    vpor %xmm2, %xmm1, %xmm1
-; AVX2-NEXT:    vpmovsxwd %xmm1, %ymm1
+; AVX2-NEXT:    vpmovsxbd %xmm2, %ymm2
+; AVX2-NEXT:    vpor %ymm2, %ymm1, %ymm1
 ; AVX2-NEXT:    vpand %ymm0, %ymm1, %ymm0
 ; AVX2-NEXT:    retq
 ;
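
For context, a minimal IR sketch of the kind of shape this fold targets. The function name, operand types and body below are illustrative only (the actual @PR157382 body is not shown in this hunk): an inverted byte compare is sign-extended and or'd into a wider mask before the final and.

  ; Hypothetical reduction, not the real test body.
  define <8 x i32> @sketch(<8 x i32> %x, <8 x i32> %y, <8 x i8> %b) {
    %cmp = icmp eq <8 x i8> %b, zeroinitializer
    %not = xor <8 x i1> %cmp, <i1 true, i1 true, i1 true, i1 true,
                               i1 true, i1 true, i1 true, i1 true>
    %ext = sext <8 x i1> %not to <8 x i32>      ; existing in-reg extension
    %or  = or <8 x i32> %y, %ext                ; mask arithmetic to promote
    %and = and <8 x i32> %or, %x
    ret <8 x i32> %and
  }

Per the updated CHECK lines, the narrow mask arithmetic previously forced an extend-and-repack sequence (vpmovsxbw/vextracti128/vpackssdw/vpor/vpmovsxwd); widening the existing in-reg extension straight to 256 bits (legal on AVX2 via vpmovsxbd) lets the promoted or stay in ymm registers, saving three instructions.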