diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 50df19b3e6e47..aa44ef9c8e291 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -58033,7 +58033,8 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
   //   (extract_elt Mul, 3),
   //   (extract_elt Mul, 5),
   //     ...
-  // and identify Mul.
+  // and identify Mul. Mul must be either ISD::MUL, or can be ISD::SIGN_EXTEND
+  // in which case we add a trivial multiplication by an all-ones vector.
   SDValue Mul;
   for (unsigned i = 0, e = VT.getVectorNumElements(); i != e; i += 2) {
     SDValue Op0L = Op0->getOperand(i), Op1L = Op1->getOperand(i),
@@ -58064,7 +58065,8 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
       // with 2X number of vector elements than the BUILD_VECTOR.
       // Both extracts must be from same MUL.
       Mul = Vec0L;
-      if (Mul.getOpcode() != ISD::MUL ||
+      if ((Mul.getOpcode() != ISD::MUL &&
+           Mul.getOpcode() != ISD::SIGN_EXTEND) ||
           Mul.getValueType().getVectorNumElements() != 2 * e)
         return SDValue();
     }
@@ -58073,16 +58075,32 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
       return SDValue();
   }
 
-  // Check if the Mul source can be safely shrunk.
-  ShrinkMode Mode;
-  if (!canReduceVMulWidth(Mul.getNode(), DAG, Mode) ||
-      Mode == ShrinkMode::MULU16)
-    return SDValue();
+  SDValue N0, N1;
+  if (Mul.getOpcode() == ISD::MUL) {
+    // Check if the Mul source can be safely shrunk.
+    ShrinkMode Mode;
+    if (!canReduceVMulWidth(Mul.getNode(), DAG, Mode) ||
+        Mode == ShrinkMode::MULU16)
+      return SDValue();
+
+    EVT TruncVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
+                                   VT.getVectorNumElements() * 2);
+    N0 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mul.getOperand(0));
+    N1 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mul.getOperand(1));
+  } else {
+    assert(Mul.getOpcode() == ISD::SIGN_EXTEND);
 
-  EVT TruncVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
-                                 VT.getVectorNumElements() * 2);
-  SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mul.getOperand(0));
-  SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mul.getOperand(1));
+    // Add a trivial multiplication with an all-ones vector so that we can make
+    // use of VPMADDWD.
+    N0 = Mul.getOperand(0);
+    EVT SrcVT = N0.getValueType();
+
+    if (!SrcVT.isVector() || SrcVT.getVectorElementType() != MVT::i16 ||
+        SrcVT.getVectorNumElements() != 2 * VT.getVectorNumElements())
+      return SDValue();
+
+    N1 = DAG.getConstant(1, DL, SrcVT);
+  }
 
   auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
                          ArrayRef<SDValue> Ops) {
diff --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index d9283aa8591fc..53f1374669ca5 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -331,3 +331,32 @@ define i1 @pmaddwd_pcmpgt_infinite_loop() {
   %8 = icmp eq i4 %7, 0
   ret i1 %8
 }
+
+; If the shuffle matches, but there is no multiply, introduce a trivial multiply by an all-ones vector.
+define <8 x i32> @introduce_trivial_multiply(<16 x i16> %x) {
+; SSE-LABEL: introduce_trivial_multiply:
+; SSE:       # %bb.0:
+; SSE-NEXT:    pmovsxbw {{.*#+}} xmm2 = [1,1,1,1,1,1,1,1]
+; SSE-NEXT:    pmaddwd %xmm2, %xmm0
+; SSE-NEXT:    pmaddwd %xmm2, %xmm1
+; SSE-NEXT:    retq
+;
+; AVX1-LABEL: introduce_trivial_multiply:
+; AVX1:       # %bb.0:
+; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm1
+; AVX1-NEXT:    vbroadcastss {{.*#+}} xmm2 = [1,1,1,1,1,1,1,1]
+; AVX1-NEXT:    vpmaddwd %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vpmaddwd %xmm2, %xmm0, %xmm0
+; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm0, %ymm0
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: introduce_trivial_multiply:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 # [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]
+; AVX2-NEXT:    retq
+  %1 = sext <16 x i16> %x to <16 x i32>
+  %2 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %3 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %4 = add nsw <8 x i32> %2, %3
+  ret <8 x i32> %4
+}