Skip to content

Commit

Permalink
[X86] Add pattern matching for PMADDUBSW
Browse files Browse the repository at this point in the history
Summary:
Similar to D49636, but for PMADDUBSW. This instruction has the additional complexity that the addition of the two products saturates to 16-bits rather than wrapping around. And one operand is treated as signed and the other as unsigned.

A C example that triggers this pattern

```
static const int N = 128;

int8_t A[2*N];
uint8_t B[2*N];
int16_t C[N];

void foo() {
  for (int i = 0; i != N; ++i)
    C[i] = MIN(MAX((int16_t)A[2*i]*(int16_t)B[2*i] + (int16_t)A[2*i+1]*(int16_t)B[2*i+1], -32768), 32767);
}
```

Reviewers: RKSimon, spatel, zvi

Reviewed By: RKSimon, zvi

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D49829

llvm-svn: 338402
  • Loading branch information
topperc committed Jul 31, 2018
1 parent d03d44e commit bef126f
Show file tree
Hide file tree
Showing 2 changed files with 251 additions and 1,788 deletions.
143 changes: 143 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Expand Up @@ -36753,6 +36753,145 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
return DAG.getNode(Opc, DL, VT, LHS, RHS);
}

// Attempt to match PMADDUBSW, which multiplies corresponding unsigned bytes
// from one vector with signed bytes from another vector, adds together
// adjacent pairs of 16-bit products, and saturates the result before
// truncating to 16-bits.
//
// Which looks something like this:
// (i16 (ssat (add (mul (zext (even elts (i8 A))), (sext (even elts (i8 B)))),
// (mul (zext (odd elts (i8 A)), (sext (odd elts (i8 B))))))))
static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
const X86Subtarget &Subtarget,
const SDLoc &DL) {
if (!VT.isVector() || !Subtarget.hasSSSE3())
return SDValue();

unsigned NumElems = VT.getVectorNumElements();
EVT ScalarVT = VT.getVectorElementType();
if (ScalarVT != MVT::i16 || NumElems < 8 || !isPowerOf2_32(NumElems))
return SDValue();

SDValue SSatVal = detectSSatPattern(In, VT);
if (!SSatVal || SSatVal.getOpcode() != ISD::ADD)
return SDValue();

// Ok this is a signed saturation of an ADD. See if this ADD is adding pairs
// of multiplies from even/odd elements.
SDValue N0 = SSatVal.getOperand(0);
SDValue N1 = SSatVal.getOperand(1);

if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
return SDValue();

SDValue N00 = N0.getOperand(0);
SDValue N01 = N0.getOperand(1);
SDValue N10 = N1.getOperand(0);
SDValue N11 = N1.getOperand(1);

// TODO: Handle constant vectors and use knownbits/computenumsignbits?
// Canonicalize zero_extend to LHS.
if (N01.getOpcode() == ISD::ZERO_EXTEND)
std::swap(N00, N01);
if (N11.getOpcode() == ISD::ZERO_EXTEND)
std::swap(N10, N11);

// Ensure we have a zero_extend and a sign_extend.
if (N00.getOpcode() != ISD::ZERO_EXTEND ||
N01.getOpcode() != ISD::SIGN_EXTEND ||
N10.getOpcode() != ISD::ZERO_EXTEND ||
N11.getOpcode() != ISD::SIGN_EXTEND)
return SDValue();

// Peek through the extends.
N00 = N00.getOperand(0);
N01 = N01.getOperand(0);
N10 = N10.getOperand(0);
N11 = N11.getOperand(0);

// Ensure the extend is from vXi8.
if (N00.getValueType().getVectorElementType() != MVT::i8 ||
N01.getValueType().getVectorElementType() != MVT::i8 ||
N10.getValueType().getVectorElementType() != MVT::i8 ||
N11.getValueType().getVectorElementType() != MVT::i8)
return SDValue();

// All inputs should be build_vectors.
if (N00.getOpcode() != ISD::BUILD_VECTOR ||
N01.getOpcode() != ISD::BUILD_VECTOR ||
N10.getOpcode() != ISD::BUILD_VECTOR ||
N11.getOpcode() != ISD::BUILD_VECTOR)
return SDValue();

// N00/N10 are zero extended. N01/N11 are sign extended.

// For each element, we need to ensure we have an odd element from one vector
// multiplied by the odd element of another vector and the even element from
// one of the same vectors being multiplied by the even element from the
// other vector. So we need to make sure for each element i, this operator
// is being performed:
// A[2 * i] * B[2 * i] + A[2 * i + 1] * B[2 * i + 1]
SDValue ZExtIn, SExtIn;
for (unsigned i = 0; i != NumElems; ++i) {
SDValue N00Elt = N00.getOperand(i);
SDValue N01Elt = N01.getOperand(i);
SDValue N10Elt = N10.getOperand(i);
SDValue N11Elt = N11.getOperand(i);
// TODO: Be more tolerant to undefs.
if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
return SDValue();
auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1));
if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt)
return SDValue();
unsigned IdxN00 = ConstN00Elt->getZExtValue();
unsigned IdxN01 = ConstN01Elt->getZExtValue();
unsigned IdxN10 = ConstN10Elt->getZExtValue();
unsigned IdxN11 = ConstN11Elt->getZExtValue();
// Add is commutative so indices can be reordered.
if (IdxN00 > IdxN10) {
std::swap(IdxN00, IdxN10);
std::swap(IdxN01, IdxN11);
}
// N0 indices be the even element. N1 indices must be the next odd element.
if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
return SDValue();
SDValue N00In = N00Elt.getOperand(0);
SDValue N01In = N01Elt.getOperand(0);
SDValue N10In = N10Elt.getOperand(0);
SDValue N11In = N11Elt.getOperand(0);
// First time we find an input capture it.
if (!ZExtIn) {
ZExtIn = N00In;
SExtIn = N01In;
}
if (ZExtIn != N00In || SExtIn != N01In ||
ZExtIn != N10In || SExtIn != N11In)
return SDValue();
}

auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
ArrayRef<SDValue> Ops) {
// Shrink by adding truncate nodes and let DAGCombine fold with the
// sources.
EVT InVT = Ops[0].getValueType();
assert(InVT.getScalarType() == MVT::i8 &&
"Unexpected scalar element type");
assert(InVT == Ops[1].getValueType() && "Operands' types mismatch");
EVT ResVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
InVT.getVectorNumElements() / 2);
return DAG.getNode(X86ISD::VPMADDUBSW, DL, ResVT, Ops[0], Ops[1]);
};
return SplitOpsAndApply(DAG, Subtarget, DL, VT, { ZExtIn, SExtIn },
PMADDBuilder);
}

static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
Expand All @@ -36767,6 +36906,10 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
if (SDValue Avg = detectAVGPattern(Src, VT, DAG, Subtarget, DL))
return Avg;

// Try to detect PMADD
if (SDValue PMAdd = detectPMADDUBSW(Src, VT, DAG, Subtarget, DL))
return PMAdd;

// Try to combine truncation with signed/unsigned saturation.
if (SDValue Val = combineTruncateWithSat(Src, VT, DL, DAG, Subtarget))
return Val;
Expand Down

0 comments on commit bef126f

Please sign in to comment.