diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index fd01363bed709..007074c3ffc82 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -54459,6 +54459,7 @@ static SDValue combinePMULH(SDValue Src, EVT VT, const SDLoc &DL, static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG, const X86Subtarget &Subtarget, const SDLoc &DL) { + using namespace SDPatternMatch; if (!VT.isVector() || !Subtarget.hasSSSE3()) return SDValue(); @@ -54468,42 +54469,19 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG, return SDValue(); SDValue SSatVal = detectSSatPattern(In, VT); - if (!SSatVal || SSatVal.getOpcode() != ISD::ADD) - return SDValue(); - - // Ok this is a signed saturation of an ADD. See if this ADD is adding pairs - // of multiplies from even/odd elements. - SDValue N0 = SSatVal.getOperand(0); - SDValue N1 = SSatVal.getOperand(1); - - if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL) + if (!SSatVal) return SDValue(); - SDValue N00 = N0.getOperand(0); - SDValue N01 = N0.getOperand(1); - SDValue N10 = N1.getOperand(0); - SDValue N11 = N1.getOperand(1); - + // See if this is a signed saturation of an ADD, adding pairs of multiplies + // from even/odd elements, from zero_extend/sign_extend operands. + // // TODO: Handle constant vectors and use knownbits/computenumsignbits? - // Canonicalize zero_extend to LHS. - if (N01.getOpcode() == ISD::ZERO_EXTEND) - std::swap(N00, N01); - if (N11.getOpcode() == ISD::ZERO_EXTEND) - std::swap(N10, N11); - - // Ensure we have a zero_extend and a sign_extend. - if (N00.getOpcode() != ISD::ZERO_EXTEND || - N01.getOpcode() != ISD::SIGN_EXTEND || - N10.getOpcode() != ISD::ZERO_EXTEND || - N11.getOpcode() != ISD::SIGN_EXTEND) + SDValue N00, N01, N10, N11; + if (!sd_match(SSatVal, + m_Add(m_Mul(m_ZExt(m_Value(N00)), m_SExt(m_Value(N01))), + m_Mul(m_ZExt(m_Value(N10)), m_SExt(m_Value(N11)))))) return SDValue(); - // Peek through the extends. - N00 = N00.getOperand(0); - N01 = N01.getOperand(0); - N10 = N10.getOperand(0); - N11 = N11.getOperand(0); - // Ensure the extend is from vXi8. if (N00.getValueType().getVectorElementType() != MVT::i8 || N01.getValueType().getVectorElementType() != MVT::i8 ||