Skip to content

Commit

Permalink
[X86] Allow input vector extracted from larger vector when combining …
Browse files Browse the repository at this point in the history
…to VPMADDUBSW (#89584)

Failed on main trunk: https://godbolt.org/z/edWMz8chE
  • Loading branch information
phoebewang committed Apr 22, 2024
1 parent 35b292e commit c7e0f1e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
11 changes: 11 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51841,6 +51841,17 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
return SDValue();
}

auto ExtractVec = [&DAG, &DL, NumElems](SDValue &Ext) {
EVT ExtVT = Ext.getValueType();
if (ExtVT.getVectorNumElements() != NumElems * 2) {
MVT NVT = MVT::getVectorVT(MVT::i8, NumElems * 2);
Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, Ext,
DAG.getIntPtrConstant(0, DL));
}
};
ExtractVec(ZExtIn);
ExtractVec(SExtIn);

auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
ArrayRef<SDValue> Ops) {
// Shrink by adding truncate nodes and let DAGCombine fold with the
Expand Down
38 changes: 38 additions & 0 deletions llvm/test/CodeGen/X86/pmaddubsw.ll
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,41 @@ define <8 x i16> @pmaddubsw_bad_indices(ptr %Aptr, ptr %Bptr) {
%trunc = trunc <8 x i32> %min to <8 x i16>
ret <8 x i16> %trunc
}

define <8 x i16> @pmaddubsw_large_vector(ptr %p1, ptr %p2) {
; SSE-LABEL: pmaddubsw_large_vector:
; SSE: # %bb.0:
; SSE-NEXT: movdqa (%rdi), %xmm0
; SSE-NEXT: pmaddubsw (%rsi), %xmm0
; SSE-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
; SSE-NEXT: retq
;
; AVX-LABEL: pmaddubsw_large_vector:
; AVX: # %bb.0:
; AVX-NEXT: vmovdqa (%rdi), %xmm0
; AVX-NEXT: vpmaddubsw (%rsi), %xmm0, %xmm0
; AVX-NEXT: vpxor %xmm1, %xmm1, %xmm1
; AVX-NEXT: vpblendw {{.*#+}} xmm0 = xmm1[0,1],xmm0[2],xmm1[3,4],xmm0[5],xmm1[6],xmm0[7]
; AVX-NEXT: retq
%1 = load <64 x i8>, ptr %p1, align 64
%2 = shufflevector <64 x i8> %1, <64 x i8> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
%3 = shufflevector <64 x i8> %1, <64 x i8> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
%4 = load <32 x i8>, ptr %p2, align 64
%5 = shufflevector <32 x i8> %4, <32 x i8> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
%6 = shufflevector <32 x i8> %4, <32 x i8> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
%7 = sext <8 x i8> %5 to <8 x i32>
%8 = zext <8 x i8> %2 to <8 x i32>
%9 = mul nsw <8 x i32> %7, %8
%10 = sext <8 x i8> %6 to <8 x i32>
%11 = zext <8 x i8> %3 to <8 x i32>
%12 = mul nsw <8 x i32> %10, %11
%13 = add nsw <8 x i32> %9, %12
%14 = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> %13, <8 x i32> <i32 32767, i32 32767, i32 32767, i32 32767, i32 32767, i32 32767, i32 32767, i32 32767>)
%15 = tail call <8 x i32> @llvm.smax.v8i32(<8 x i32> %14, <8 x i32> <i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768>)
%16 = trunc <8 x i32> %15 to <8 x i16>
%17 = shufflevector <8 x i16> zeroinitializer, <8 x i16> %16, <8 x i32> <i32 0, i32 1, i32 10, i32 3, i32 4, i32 13, i32 6, i32 15>
ret <8 x i16> %17
}

declare <8 x i32> @llvm.smin.v8i32(<8 x i32>, <8 x i32>)
declare <8 x i32> @llvm.smax.v8i32(<8 x i32>, <8 x i32>)

0 comments on commit c7e0f1e

Please sign in to comment.