diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index ee2918e419404..9235042c54a4a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -21891,6 +21891,109 @@ static SDValue combineConcatVectorOfCasts(SDNode *N, SelectionDAG &DAG) { return DAG.getNode(CastOpcode, DL, VT, NewConcat); } +// See if this is a simple CONCAT_VECTORS with no UNDEF operands, and if one of +// the operands is a SHUFFLE_VECTOR, and all other operands are also operands +// to that SHUFFLE_VECTOR, create wider SHUFFLE_VECTOR. +static SDValue combineConcatVectorOfShuffleAndItsOperands( + SDNode *N, SelectionDAG &DAG, const TargetLowering &TLI, bool LegalTypes, + bool LegalOperations) { + EVT VT = N->getValueType(0); + EVT OpVT = N->getOperand(0).getValueType(); + if (VT.isScalableVector()) + return SDValue(); + + // For now, only allow simple 2-operand concatenations. + if (N->getNumOperands() != 2) + return SDValue(); + + // Don't create illegal types/shuffles when not allowed to. + if ((LegalTypes && !TLI.isTypeLegal(VT)) || + (LegalOperations && + !TLI.isOperationLegalOrCustom(ISD::VECTOR_SHUFFLE, VT))) + return SDValue(); + + // Analyze all of the operands of the CONCAT_VECTORS. Out of all of them, + // we want to find one that is: (1) a SHUFFLE_VECTOR (2) only used by us, + // and (3) all operands of CONCAT_VECTORS must be either that SHUFFLE_VECTOR, + // or one of the operands of that SHUFFLE_VECTOR (but not UNDEF!). + // (4) and for now, the SHUFFLE_VECTOR must be unary. + ShuffleVectorSDNode *SVN = nullptr; + for (SDValue Op : N->ops()) { + if (auto *CurSVN = dyn_cast<ShuffleVectorSDNode>(Op); + CurSVN && CurSVN->getOperand(1).isUndef() && N->isOnlyUserOf(CurSVN) && + all_of(N->ops(), [CurSVN](SDValue Op) { + // FIXME: can we allow UNDEF operands?
+ return !Op.isUndef() && + (Op.getNode() == CurSVN || is_contained(CurSVN->ops(), Op)); + })) { + SVN = CurSVN; + break; + } + } + if (!SVN) + return SDValue(); + + // We are going to pad the shuffle operands, so any index that was picking + // from the second operand must be adjusted. + SmallVector<int> AdjustedMask; + AdjustedMask.reserve(SVN->getMask().size()); + assert(SVN->getOperand(1).isUndef() && "Expected unary shuffle!"); + append_range(AdjustedMask, SVN->getMask()); + + // Identity masks for the operands of the (padded) shuffle. + SmallVector<int> IdentityMask(2 * OpVT.getVectorNumElements()); + MutableArrayRef<int> FirstShufOpIdentityMask = + MutableArrayRef<int>(IdentityMask) + .take_front(OpVT.getVectorNumElements()); + MutableArrayRef<int> SecondShufOpIdentityMask = + MutableArrayRef<int>(IdentityMask).take_back(OpVT.getVectorNumElements()); + std::iota(FirstShufOpIdentityMask.begin(), FirstShufOpIdentityMask.end(), 0); + std::iota(SecondShufOpIdentityMask.begin(), SecondShufOpIdentityMask.end(), + VT.getVectorNumElements()); + + // New combined shuffle mask. + SmallVector<int> Mask; + Mask.reserve(VT.getVectorNumElements()); + for (SDValue Op : N->ops()) { + assert(!Op.isUndef() && "Not expecting to concatenate UNDEF."); + if (Op.getNode() == SVN) { + append_range(Mask, AdjustedMask); + continue; + } + if (Op == SVN->getOperand(0)) { + append_range(Mask, FirstShufOpIdentityMask); + continue; + } + if (Op == SVN->getOperand(1)) { + append_range(Mask, SecondShufOpIdentityMask); + continue; + } + llvm_unreachable("Unexpected operand!"); + } + + // Don't create illegal shuffle masks. + if (!TLI.isShuffleMaskLegal(Mask, VT)) + return SDValue(); + + // Pad the shuffle operands with UNDEF.
+ SDLoc dl(N); + std::array<SDValue, 2> ShufOps; + for (auto I : zip(SVN->ops(), ShufOps)) { + SDValue ShufOp = std::get<0>(I); + SDValue &NewShufOp = std::get<1>(I); + if (ShufOp.isUndef()) + NewShufOp = DAG.getUNDEF(VT); + else { + SmallVector<SDValue, 2> ShufOpParts(N->getNumOperands(), + DAG.getUNDEF(OpVT)); + ShufOpParts[0] = ShufOp; + NewShufOp = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, ShufOpParts); + } + } + // Finally, create the new wide shuffle. + return DAG.getVectorShuffle(VT, dl, ShufOps[0], ShufOps[1], Mask); +} + SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { // If we only have one input vector, we don't need to do any concatenation. if (N->getNumOperands() == 1) @@ -22026,6 +22129,10 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { if (SDValue V = combineConcatVectorOfCasts(N, DAG)) return V; + if (SDValue V = combineConcatVectorOfShuffleAndItsOperands( + N, DAG, TLI, LegalTypes, LegalOperations)) + return V; + // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR // nodes often generate nop CONCAT_VECTOR nodes.
Scan the CONCAT_VECTOR // operands and look for a CONCAT operations that place the incoming vectors diff --git a/llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll b/llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll index 4ffe97e6de236..52fc059cc6818 100644 --- a/llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll +++ b/llvm/test/CodeGen/X86/vector-shuffle-concatenation.ll @@ -23,32 +23,33 @@ define void @concat_a_to_shuf_of_a(ptr %a.ptr, ptr %dst) { ; AVX: # %bb.0: ; AVX-NEXT: vmovaps (%rdi), %xmm0 ; AVX-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX-NEXT: vmovaps %xmm0, 16(%rsi) -; AVX-NEXT: vmovaps %xmm1, (%rsi) +; AVX-NEXT: vinsertf128 $1, %xmm0, %ymm1, %ymm0 +; AVX-NEXT: vmovaps %ymm0, (%rsi) +; AVX-NEXT: vzeroupper ; AVX-NEXT: retq ; ; AVX2-LABEL: concat_a_to_shuf_of_a: ; AVX2: # %bb.0: ; AVX2-NEXT: vmovaps (%rdi), %xmm0 -; AVX2-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX2-NEXT: vmovaps %xmm0, 16(%rsi) -; AVX2-NEXT: vmovaps %xmm1, (%rsi) +; AVX2-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[1,0,0,1] +; AVX2-NEXT: vmovaps %ymm0, (%rsi) +; AVX2-NEXT: vzeroupper ; AVX2-NEXT: retq ; ; AVX512F-LABEL: concat_a_to_shuf_of_a: ; AVX512F: # %bb.0: ; AVX512F-NEXT: vmovaps (%rdi), %xmm0 -; AVX512F-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX512F-NEXT: vmovaps %xmm0, 16(%rsi) -; AVX512F-NEXT: vmovaps %xmm1, (%rsi) +; AVX512F-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[1,0,0,1] +; AVX512F-NEXT: vmovaps %ymm0, (%rsi) +; AVX512F-NEXT: vzeroupper ; AVX512F-NEXT: retq ; ; AVX512BW-LABEL: concat_a_to_shuf_of_a: ; AVX512BW: # %bb.0: ; AVX512BW-NEXT: vmovaps (%rdi), %xmm0 -; AVX512BW-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX512BW-NEXT: vmovaps %xmm0, 16(%rsi) -; AVX512BW-NEXT: vmovaps %xmm1, (%rsi) +; AVX512BW-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[1,0,0,1] +; AVX512BW-NEXT: vmovaps %ymm0, (%rsi) +; AVX512BW-NEXT: vzeroupper ; AVX512BW-NEXT: retq %a = load <2 x i64>, ptr %a.ptr, align 64 %shuffle = shufflevector <2 x i64> %a, <2 x i64> poison, <2 x i32> @@ 
-69,32 +70,33 @@ define void @concat_shuf_of_a_to_a(ptr %a.ptr, ptr %b.ptr, ptr %dst) { ; AVX: # %bb.0: ; AVX-NEXT: vmovaps (%rdi), %xmm0 ; AVX-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX-NEXT: vmovaps %xmm0, (%rdx) -; AVX-NEXT: vmovaps %xmm1, 16(%rdx) +; AVX-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm0 +; AVX-NEXT: vmovaps %ymm0, (%rdx) +; AVX-NEXT: vzeroupper ; AVX-NEXT: retq ; ; AVX2-LABEL: concat_shuf_of_a_to_a: ; AVX2: # %bb.0: ; AVX2-NEXT: vmovaps (%rdi), %xmm0 -; AVX2-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX2-NEXT: vmovaps %xmm0, (%rdx) -; AVX2-NEXT: vmovaps %xmm1, 16(%rdx) +; AVX2-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,1,1,0] +; AVX2-NEXT: vmovaps %ymm0, (%rdx) +; AVX2-NEXT: vzeroupper ; AVX2-NEXT: retq ; ; AVX512F-LABEL: concat_shuf_of_a_to_a: ; AVX512F: # %bb.0: ; AVX512F-NEXT: vmovaps (%rdi), %xmm0 -; AVX512F-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX512F-NEXT: vmovaps %xmm0, (%rdx) -; AVX512F-NEXT: vmovaps %xmm1, 16(%rdx) +; AVX512F-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,1,1,0] +; AVX512F-NEXT: vmovaps %ymm0, (%rdx) +; AVX512F-NEXT: vzeroupper ; AVX512F-NEXT: retq ; ; AVX512BW-LABEL: concat_shuf_of_a_to_a: ; AVX512BW: # %bb.0: ; AVX512BW-NEXT: vmovaps (%rdi), %xmm0 -; AVX512BW-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1] -; AVX512BW-NEXT: vmovaps %xmm0, (%rdx) -; AVX512BW-NEXT: vmovaps %xmm1, 16(%rdx) +; AVX512BW-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[0,1,1,0] +; AVX512BW-NEXT: vmovaps %ymm0, (%rdx) +; AVX512BW-NEXT: vzeroupper ; AVX512BW-NEXT: retq %a = load <2 x i64>, ptr %a.ptr, align 64 %b = load <2 x i64>, ptr %b.ptr, align 64 @@ -567,29 +569,33 @@ define void @concat_shuf_of_a_to_itself(ptr %a.ptr, ptr %dst) { ; AVX-LABEL: concat_shuf_of_a_to_itself: ; AVX: # %bb.0: ; AVX-NEXT: vpermilps {{.*#+}} xmm0 = mem[2,3,0,1] -; AVX-NEXT: vmovaps %xmm0, 16(%rsi) -; AVX-NEXT: vmovaps %xmm0, (%rsi) +; AVX-NEXT: vinsertf128 $1, %xmm0, %ymm0, %ymm0 +; AVX-NEXT: vmovaps %ymm0, (%rsi) +; AVX-NEXT: vzeroupper ; AVX-NEXT: retq ; ; AVX2-LABEL: 
concat_shuf_of_a_to_itself: ; AVX2: # %bb.0: -; AVX2-NEXT: vpermilps {{.*#+}} xmm0 = mem[2,3,0,1] -; AVX2-NEXT: vmovaps %xmm0, 16(%rsi) -; AVX2-NEXT: vmovaps %xmm0, (%rsi) +; AVX2-NEXT: vmovaps (%rdi), %xmm0 +; AVX2-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[1,0,1,0] +; AVX2-NEXT: vmovaps %ymm0, (%rsi) +; AVX2-NEXT: vzeroupper ; AVX2-NEXT: retq ; ; AVX512F-LABEL: concat_shuf_of_a_to_itself: ; AVX512F: # %bb.0: -; AVX512F-NEXT: vpermilps {{.*#+}} xmm0 = mem[2,3,0,1] -; AVX512F-NEXT: vmovaps %xmm0, 16(%rsi) -; AVX512F-NEXT: vmovaps %xmm0, (%rsi) +; AVX512F-NEXT: vmovaps (%rdi), %xmm0 +; AVX512F-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[1,0,1,0] +; AVX512F-NEXT: vmovaps %ymm0, (%rsi) +; AVX512F-NEXT: vzeroupper ; AVX512F-NEXT: retq ; ; AVX512BW-LABEL: concat_shuf_of_a_to_itself: ; AVX512BW: # %bb.0: -; AVX512BW-NEXT: vpermilps {{.*#+}} xmm0 = mem[2,3,0,1] -; AVX512BW-NEXT: vmovaps %xmm0, 16(%rsi) -; AVX512BW-NEXT: vmovaps %xmm0, (%rsi) +; AVX512BW-NEXT: vmovaps (%rdi), %xmm0 +; AVX512BW-NEXT: vpermpd {{.*#+}} ymm0 = ymm0[1,0,1,0] +; AVX512BW-NEXT: vmovaps %ymm0, (%rsi) +; AVX512BW-NEXT: vzeroupper ; AVX512BW-NEXT: retq %a = load <2 x i64>, ptr %a.ptr, align 64 %shuffle = shufflevector <2 x i64> %a, <2 x i64> poison, <2 x i32> @@ -613,19 +619,18 @@ define void @concat_aaa_to_shuf_of_a(ptr %a.ptr, ptr %dst) { ; AVX: # %bb.0: ; AVX-NEXT: vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1] ; AVX-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1] +; AVX-NEXT: vinsertf128 $1, %xmm0, %ymm1, %ymm1 ; AVX-NEXT: vmovaps %ymm0, 32(%rsi) -; AVX-NEXT: vmovaps %xmm0, 16(%rsi) -; AVX-NEXT: vmovaps %xmm1, (%rsi) +; AVX-NEXT: vmovaps %ymm1, (%rsi) ; AVX-NEXT: vzeroupper ; AVX-NEXT: retq ; ; AVX2-LABEL: concat_aaa_to_shuf_of_a: ; AVX2: # %bb.0: ; AVX2-NEXT: vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1] -; AVX2-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1] +; AVX2-NEXT: vpermpd {{.*#+}} ymm1 = ymm0[1,0,0,1] ; AVX2-NEXT: vmovaps %ymm0, 32(%rsi) -; AVX2-NEXT: vmovaps %xmm0, 16(%rsi) -; AVX2-NEXT: vmovaps %xmm1, (%rsi) 
+; AVX2-NEXT: vmovaps %ymm1, (%rsi) ; AVX2-NEXT: vzeroupper ; AVX2-NEXT: retq ; @@ -671,19 +676,18 @@ define void @concat_shuf_of_a_to_aaa(ptr %a.ptr, ptr %dst) { ; AVX: # %bb.0: ; AVX-NEXT: vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1] ; AVX-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1] +; AVX-NEXT: vinsertf128 $1, %xmm1, %ymm0, %ymm1 ; AVX-NEXT: vmovaps %ymm0, (%rsi) -; AVX-NEXT: vmovaps %xmm0, 32(%rsi) -; AVX-NEXT: vmovaps %xmm1, 48(%rsi) +; AVX-NEXT: vmovaps %ymm1, 32(%rsi) ; AVX-NEXT: vzeroupper ; AVX-NEXT: retq ; ; AVX2-LABEL: concat_shuf_of_a_to_aaa: ; AVX2: # %bb.0: ; AVX2-NEXT: vbroadcastf128 {{.*#+}} ymm0 = mem[0,1,0,1] -; AVX2-NEXT: vpermilps {{.*#+}} xmm1 = xmm0[2,3,0,1] +; AVX2-NEXT: vpermpd {{.*#+}} ymm1 = ymm0[0,1,1,0] ; AVX2-NEXT: vmovaps %ymm0, (%rsi) -; AVX2-NEXT: vmovaps %xmm0, 32(%rsi) -; AVX2-NEXT: vmovaps %xmm1, 48(%rsi) +; AVX2-NEXT: vmovaps %ymm1, 32(%rsi) ; AVX2-NEXT: vzeroupper ; AVX2-NEXT: retq ;