Skip to content

Commit

Permalink
[x86] enhance matching of pmaddwd
Browse files Browse the repository at this point in the history
This was crashing with the example from:
https://llvm.org/PR49716
...and that was avoided with a283d72,
but as we can see from the SSE vs. AVX test code diff,
we can try harder to match the pattern.

This matcher code was adapted from another pmadd pattern
match in D49636, but it needs different ops to deal with
size mismatches.

Differential Revision: https://reviews.llvm.org/D99531
  • Loading branch information
rotateright committed Mar 30, 2021
1 parent c5109d3 commit e694e19
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 57 deletions.
23 changes: 17 additions & 6 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49045,10 +49045,10 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
In0 = N00In;
In1 = N01In;

// The input vector sizes must match the output.
// TODO: Insert cast ops to allow different types.
if (In0.getValueSizeInBits() != VT.getSizeInBits() ||
In1.getValueSizeInBits() != VT.getSizeInBits())
// The input vectors must be at least as wide as the output.
// If they are larger than the output, we extract subvector below.
if (In0.getValueSizeInBits() < VT.getSizeInBits() ||
In1.getValueSizeInBits() < VT.getSizeInBits())
return SDValue();
}
// Mul is commutative so the input vectors can be in any order.
Expand All @@ -49063,8 +49063,6 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,

auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
ArrayRef<SDValue> Ops) {
// Shrink by adding truncate nodes and let DAGCombine fold with the
// sources.
EVT OpVT = Ops[0].getValueType();
assert(OpVT.getScalarType() == MVT::i16 &&
"Unexpected scalar element type");
Expand All @@ -49073,6 +49071,19 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
OpVT.getVectorNumElements() / 2);
return DAG.getNode(X86ISD::VPMADDWD, DL, ResVT, Ops[0], Ops[1]);
};

// If the output is narrower than an input, extract the low part of the input
// vector.
EVT OutVT16 = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
VT.getVectorNumElements() * 2);
if (OutVT16.bitsLT(In0.getValueType())) {
In0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OutVT16, In0,
DAG.getIntPtrConstant(0, DL));
}
if (OutVT16.bitsLT(In1.getValueType())) {
In1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OutVT16, In1,
DAG.getIntPtrConstant(0, DL));
}
return SplitOpsAndApply(DAG, Subtarget, DL, VT, { In0, In1 },
PMADDBuilder);
}
Expand Down
112 changes: 61 additions & 51 deletions llvm/test/CodeGen/X86/madd.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3052,48 +3052,12 @@ middle.block:
define <4 x i32> @input_size_mismatch(<16 x i16> %x, <16 x i16>* %p) {
; SSE2-LABEL: input_size_mismatch:
; SSE2: # %bb.0:
; SSE2-NEXT: movdqa (%rdi), %xmm1
; SSE2-NEXT: pshuflw {{.*#+}} xmm2 = xmm0[0,2,2,3,4,5,6,7]
; SSE2-NEXT: pshufhw {{.*#+}} xmm2 = xmm2[0,1,2,3,4,6,6,7]
; SSE2-NEXT: pshufd {{.*#+}} xmm2 = xmm2[0,2,2,3]
; SSE2-NEXT: pshuflw {{.*#+}} xmm0 = xmm0[3,1,2,3,4,5,6,7]
; SSE2-NEXT: pshufhw {{.*#+}} xmm0 = xmm0[0,1,2,3,7,5,6,7]
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
; SSE2-NEXT: pshuflw {{.*#+}} xmm0 = xmm0[1,0,3,2,4,5,6,7]
; SSE2-NEXT: pshuflw {{.*#+}} xmm3 = xmm1[0,2,2,3,4,5,6,7]
; SSE2-NEXT: pshufhw {{.*#+}} xmm3 = xmm3[0,1,2,3,4,6,6,7]
; SSE2-NEXT: pshufd {{.*#+}} xmm3 = xmm3[0,2,2,3]
; SSE2-NEXT: pshuflw {{.*#+}} xmm1 = xmm1[3,1,2,3,4,5,6,7]
; SSE2-NEXT: pshufhw {{.*#+}} xmm1 = xmm1[0,1,2,3,7,5,6,7]
; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm1[0,2,2,3]
; SSE2-NEXT: pshuflw {{.*#+}} xmm1 = xmm1[1,0,3,2,4,5,6,7]
; SSE2-NEXT: movdqa %xmm2, %xmm4
; SSE2-NEXT: pmulhw %xmm3, %xmm4
; SSE2-NEXT: pmullw %xmm3, %xmm2
; SSE2-NEXT: punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm4[0],xmm2[1],xmm4[1],xmm2[2],xmm4[2],xmm2[3],xmm4[3]
; SSE2-NEXT: movdqa %xmm0, %xmm3
; SSE2-NEXT: pmulhw %xmm1, %xmm3
; SSE2-NEXT: pmullw %xmm1, %xmm0
; SSE2-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm3[0],xmm0[1],xmm3[1],xmm0[2],xmm3[2],xmm0[3],xmm3[3]
; SSE2-NEXT: paddd %xmm2, %xmm0
; SSE2-NEXT: pmaddwd (%rdi), %xmm0
; SSE2-NEXT: retq
;
; AVX-LABEL: input_size_mismatch:
; AVX: # %bb.0:
; AVX-NEXT: vmovdqa {{.*#+}} xmm1 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
; AVX-NEXT: vpshufb %xmm1, %xmm0, %xmm2
; AVX-NEXT: vmovdqa {{.*#+}} xmm3 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
; AVX-NEXT: vpshufb %xmm3, %xmm0, %xmm0
; AVX-NEXT: vmovdqa (%rdi), %xmm4
; AVX-NEXT: vpshufb %xmm1, %xmm4, %xmm1
; AVX-NEXT: vpshufb %xmm3, %xmm4, %xmm3
; AVX-NEXT: vpmovsxwd %xmm2, %xmm2
; AVX-NEXT: vpmovsxwd %xmm0, %xmm0
; AVX-NEXT: vpmovsxwd %xmm1, %xmm1
; AVX-NEXT: vpmulld %xmm1, %xmm2, %xmm1
; AVX-NEXT: vpmovsxwd %xmm3, %xmm2
; AVX-NEXT: vpmulld %xmm2, %xmm0, %xmm0
; AVX-NEXT: vpaddd %xmm0, %xmm1, %xmm0
; AVX-NEXT: vpmaddwd (%rdi), %xmm0, %xmm0
; AVX-NEXT: vzeroupper
; AVX-NEXT: retq
%y = load <16 x i16>, <16 x i16>* %p, align 32
Expand All @@ -3119,19 +3083,7 @@ define <4 x i32> @output_size_mismatch(<16 x i16> %x, <16 x i16> %y) {
;
; AVX-LABEL: output_size_mismatch:
; AVX: # %bb.0:
; AVX-NEXT: vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
; AVX-NEXT: vpshufb %xmm2, %xmm0, %xmm3
; AVX-NEXT: vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
; AVX-NEXT: vpshufb %xmm4, %xmm0, %xmm0
; AVX-NEXT: vpshufb %xmm2, %xmm1, %xmm2
; AVX-NEXT: vpshufb %xmm4, %xmm1, %xmm1
; AVX-NEXT: vpmovsxwd %xmm3, %xmm3
; AVX-NEXT: vpmovsxwd %xmm0, %xmm0
; AVX-NEXT: vpmovsxwd %xmm2, %xmm2
; AVX-NEXT: vpmulld %xmm2, %xmm3, %xmm2
; AVX-NEXT: vpmovsxwd %xmm1, %xmm1
; AVX-NEXT: vpmulld %xmm1, %xmm0, %xmm0
; AVX-NEXT: vpaddd %xmm0, %xmm2, %xmm0
; AVX-NEXT: vpmaddwd %xmm1, %xmm0, %xmm0
; AVX-NEXT: vzeroupper
; AVX-NEXT: retq
%x0 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
Expand All @@ -3147,3 +3099,61 @@ define <4 x i32> @output_size_mismatch(<16 x i16> %x, <16 x i16> %y) {
%r = add <4 x i32> %m0, %m1
ret <4 x i32> %r
}

define <4 x i32> @output_size_mismatch_high_subvector(<16 x i16> %x, <16 x i16> %y) {
; SSE2-LABEL: output_size_mismatch_high_subvector:
; SSE2: # %bb.0:
; SSE2-NEXT: movdqa %xmm1, %xmm0
; SSE2-NEXT: pmaddwd %xmm2, %xmm0
; SSE2-NEXT: retq
;
; AVX1-LABEL: output_size_mismatch_high_subvector:
; AVX1: # %bb.0:
; AVX1-NEXT: vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm0
; AVX1-NEXT: vpshufb %xmm2, %xmm0, %xmm3
; AVX1-NEXT: vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
; AVX1-NEXT: vpshufb %xmm4, %xmm0, %xmm0
; AVX1-NEXT: vpshufb %xmm2, %xmm1, %xmm2
; AVX1-NEXT: vpshufb %xmm4, %xmm1, %xmm1
; AVX1-NEXT: vpmovsxwd %xmm3, %xmm3
; AVX1-NEXT: vpmovsxwd %xmm0, %xmm0
; AVX1-NEXT: vpmovsxwd %xmm2, %xmm2
; AVX1-NEXT: vpmulld %xmm2, %xmm3, %xmm2
; AVX1-NEXT: vpmovsxwd %xmm1, %xmm1
; AVX1-NEXT: vpmulld %xmm1, %xmm0, %xmm0
; AVX1-NEXT: vpaddd %xmm0, %xmm2, %xmm0
; AVX1-NEXT: vzeroupper
; AVX1-NEXT: retq
;
; AVX256-LABEL: output_size_mismatch_high_subvector:
; AVX256: # %bb.0:
; AVX256-NEXT: vmovdqa {{.*#+}} xmm2 = [0,1,4,5,8,9,12,13,8,9,12,13,12,13,14,15]
; AVX256-NEXT: vextracti128 $1, %ymm0, %xmm0
; AVX256-NEXT: vpshufb %xmm2, %xmm0, %xmm3
; AVX256-NEXT: vmovdqa {{.*#+}} xmm4 = [2,3,6,7,10,11,14,15,14,15,10,11,12,13,14,15]
; AVX256-NEXT: vpshufb %xmm4, %xmm0, %xmm0
; AVX256-NEXT: vpshufb %xmm2, %xmm1, %xmm2
; AVX256-NEXT: vpshufb %xmm4, %xmm1, %xmm1
; AVX256-NEXT: vpmovsxwd %xmm3, %xmm3
; AVX256-NEXT: vpmovsxwd %xmm0, %xmm0
; AVX256-NEXT: vpmovsxwd %xmm2, %xmm2
; AVX256-NEXT: vpmulld %xmm2, %xmm3, %xmm2
; AVX256-NEXT: vpmovsxwd %xmm1, %xmm1
; AVX256-NEXT: vpmulld %xmm1, %xmm0, %xmm0
; AVX256-NEXT: vpaddd %xmm0, %xmm2, %xmm0
; AVX256-NEXT: vzeroupper
; AVX256-NEXT: retq
%x0 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 8, i32 10, i32 12, i32 14>
%x1 = shufflevector <16 x i16> %x, <16 x i16> undef, <4 x i32> <i32 9, i32 11, i32 13, i32 15>
%y0 = shufflevector <16 x i16> %y, <16 x i16> undef, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
%y1 = shufflevector <16 x i16> %y, <16 x i16> undef, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
%sx0 = sext <4 x i16> %x0 to <4 x i32>
%sx1 = sext <4 x i16> %x1 to <4 x i32>
%sy0 = sext <4 x i16> %y0 to <4 x i32>
%sy1 = sext <4 x i16> %y1 to <4 x i32>
%m0 = mul <4 x i32> %sx0, %sy0
%m1 = mul <4 x i32> %sx1, %sy1
%r = add <4 x i32> %m0, %m1
ret <4 x i32> %r
}

0 comments on commit e694e19

Please sign in to comment.