Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 9 additions & 31 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54459,6 +54459,7 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL,
static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
const X86Subtarget &Subtarget,
const SDLoc &DL) {
using namespace SDPatternMatch;
if (!VT.isVector() || !Subtarget.hasSSSE3())
return SDValue();

Expand All @@ -54468,42 +54469,19 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
return SDValue();

SDValue SSatVal = detectSSatPattern(In, VT);
if (!SSatVal || SSatVal.getOpcode() != ISD::ADD)
return SDValue();

// Ok this is a signed saturation of an ADD. See if this ADD is adding pairs
// of multiplies from even/odd elements.
SDValue N0 = SSatVal.getOperand(0);
SDValue N1 = SSatVal.getOperand(1);

if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
if (!SSatVal)
return SDValue();

SDValue N00 = N0.getOperand(0);
SDValue N01 = N0.getOperand(1);
SDValue N10 = N1.getOperand(0);
SDValue N11 = N1.getOperand(1);

// See if this is a signed saturation of an ADD, adding pairs of multiplies
// from even/odd elements, from zero_extend/sign_extend operands.
//
// TODO: Handle constant vectors and use knownbits/computenumsignbits?
// Canonicalize zero_extend to LHS.
if (N01.getOpcode() == ISD::ZERO_EXTEND)
std::swap(N00, N01);
if (N11.getOpcode() == ISD::ZERO_EXTEND)
std::swap(N10, N11);

// Ensure we have a zero_extend and a sign_extend.
if (N00.getOpcode() != ISD::ZERO_EXTEND ||
N01.getOpcode() != ISD::SIGN_EXTEND ||
N10.getOpcode() != ISD::ZERO_EXTEND ||
N11.getOpcode() != ISD::SIGN_EXTEND)
SDValue N00, N01, N10, N11;
if (!sd_match(SSatVal,
m_Add(m_Mul(m_ZExt(m_Value(N00)), m_SExt(m_Value(N01))),
m_Mul(m_ZExt(m_Value(N10)), m_SExt(m_Value(N11))))))
return SDValue();

// Peek through the extends.
N00 = N00.getOperand(0);
N01 = N01.getOperand(0);
N10 = N10.getOperand(0);
N11 = N11.getOperand(0);

// Ensure the extend is from vXi8.
if (N00.getValueType().getVectorElementType() != MVT::i8 ||
N01.getValueType().getVectorElementType() != MVT::i8 ||
Expand Down
Loading