Skip to content

Commit

Permalink
[X86] Teach combineLoopMAddPattern to handle cases where there is no …
Browse files Browse the repository at this point in the history
…loop and the add has two multiply inputs

Differential Revision: https://reviews.llvm.org/D50868

llvm-svn: 340631
  • Loading branch information
topperc committed Aug 24, 2018
1 parent 3c78622 commit 4058e29
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 41 deletions.
54 changes: 33 additions & 21 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Expand Up @@ -39059,17 +39059,8 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG,
if (!Subtarget.hasSSE2())
return SDValue();

SDValue MulOp = N->getOperand(0);
SDValue Phi = N->getOperand(1);

if (MulOp.getOpcode() != ISD::MUL)
std::swap(MulOp, Phi);
if (MulOp.getOpcode() != ISD::MUL)
return SDValue();

ShrinkMode Mode;
if (!canReduceVMulWidth(MulOp.getNode(), DAG, Mode) || Mode == MULU16)
return SDValue();
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);

EVT VT = N->getValueType(0);

Expand All @@ -39078,28 +39069,49 @@ static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG,
if (VT.getVectorNumElements() < 8)
return SDValue();

if (Op0.getOpcode() != ISD::MUL)
std::swap(Op0, Op1);
if (Op0.getOpcode() != ISD::MUL)
return SDValue();

ShrinkMode Mode;
if (!canReduceVMulWidth(Op0.getNode(), DAG, Mode) || Mode == MULU16)
return SDValue();

SDLoc DL(N);
EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
VT.getVectorNumElements());
EVT MAddVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
VT.getVectorNumElements() / 2);

// Shrink the operands of mul.
SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(0));
SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(1));

// Madd vector size is half of the original vector size
auto PMADDWDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
ArrayRef<SDValue> Ops) {
MVT VT = MVT::getVectorVT(MVT::i32, Ops[0].getValueSizeInBits() / 32);
return DAG.getNode(X86ISD::VPMADDWD, DL, VT, Ops);
};
SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 },
PMADDWDBuilder);
// Fill the rest of the output with 0
SDValue Zero = getZeroVector(Madd.getSimpleValueType(), Subtarget, DAG, DL);
SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, Zero);
return DAG.getNode(ISD::ADD, DL, VT, Concat, Phi);

auto BuildPMADDWD = [&](SDValue Mul) {
// Shrink the operands of mul.
SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(0));
SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, Mul.getOperand(1));

SDValue Madd = SplitOpsAndApply(DAG, Subtarget, DL, MAddVT, { N0, N1 },
PMADDWDBuilder);
// Fill the rest of the output with 0
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd,
DAG.getConstant(0, DL, MAddVT));
};

Op0 = BuildPMADDWD(Op0);

// It's possible that Op1 is also a mul we can reduce.
if (Op1.getOpcode() == ISD::MUL &&
canReduceVMulWidth(Op1.getNode(), DAG, Mode) && Mode != MULU16) {
Op1 = BuildPMADDWD(Op1);
}

return DAG.getNode(ISD::ADD, DL, VT, Op0, Op1);
}

static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG,
Expand Down
28 changes: 8 additions & 20 deletions llvm/test/CodeGen/X86/madd.ll
Expand Up @@ -2671,17 +2671,11 @@ define i32 @madd_double_reduction(<8 x i16>* %arg, <8 x i16>* %arg1, <8 x i16>*
; SSE2: # %bb.0:
; SSE2-NEXT: movdqu (%rdi), %xmm0
; SSE2-NEXT: movdqu (%rsi), %xmm1
; SSE2-NEXT: movdqa %xmm0, %xmm2
; SSE2-NEXT: pmulhw %xmm1, %xmm2
; SSE2-NEXT: pmullw %xmm1, %xmm0
; SSE2-NEXT: movdqa %xmm0, %xmm1
; SSE2-NEXT: punpckhwd {{.*#+}} xmm1 = xmm1[4],xmm2[4],xmm1[5],xmm2[5],xmm1[6],xmm2[6],xmm1[7],xmm2[7]
; SSE2-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
; SSE2-NEXT: paddd %xmm1, %xmm0
; SSE2-NEXT: movdqu (%rdx), %xmm1
; SSE2-NEXT: pmaddwd %xmm0, %xmm1
; SSE2-NEXT: movdqu (%rdx), %xmm0
; SSE2-NEXT: movdqu (%rcx), %xmm2
; SSE2-NEXT: pmaddwd %xmm1, %xmm2
; SSE2-NEXT: paddd %xmm0, %xmm2
; SSE2-NEXT: pmaddwd %xmm0, %xmm2
; SSE2-NEXT: paddd %xmm1, %xmm2
; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm2[2,3,0,1]
; SSE2-NEXT: paddd %xmm2, %xmm0
; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm0[1,1,2,3]
Expand All @@ -2691,14 +2685,9 @@ define i32 @madd_double_reduction(<8 x i16>* %arg, <8 x i16>* %arg1, <8 x i16>*
;
; AVX1-LABEL: madd_double_reduction:
; AVX1: # %bb.0:
; AVX1-NEXT: vpmovsxwd (%rdi), %xmm0
; AVX1-NEXT: vpmovsxwd 8(%rdi), %xmm1
; AVX1-NEXT: vpmovsxwd (%rsi), %xmm2
; AVX1-NEXT: vpmulld %xmm2, %xmm0, %xmm0
; AVX1-NEXT: vpmovsxwd 8(%rsi), %xmm2
; AVX1-NEXT: vpmulld %xmm2, %xmm1, %xmm1
; AVX1-NEXT: vpaddd %xmm1, %xmm0, %xmm0
; AVX1-NEXT: vmovdqu (%rdi), %xmm0
; AVX1-NEXT: vmovdqu (%rdx), %xmm1
; AVX1-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0
; AVX1-NEXT: vpmaddwd (%rcx), %xmm1, %xmm1
; AVX1-NEXT: vpaddd %xmm0, %xmm1, %xmm0
; AVX1-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
Expand All @@ -2709,10 +2698,9 @@ define i32 @madd_double_reduction(<8 x i16>* %arg, <8 x i16>* %arg1, <8 x i16>*
;
; AVX256-LABEL: madd_double_reduction:
; AVX256: # %bb.0:
; AVX256-NEXT: vpmovsxwd (%rdi), %ymm0
; AVX256-NEXT: vpmovsxwd (%rsi), %ymm1
; AVX256-NEXT: vpmulld %ymm1, %ymm0, %ymm0
; AVX256-NEXT: vmovdqu (%rdi), %xmm0
; AVX256-NEXT: vmovdqu (%rdx), %xmm1
; AVX256-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0
; AVX256-NEXT: vpmaddwd (%rcx), %xmm1, %xmm1
; AVX256-NEXT: vpaddd %ymm0, %ymm1, %ymm0
; AVX256-NEXT: vextracti128 $1, %ymm0, %xmm1
Expand Down

0 comments on commit 4058e29

Please sign in to comment.