Skip to content

Commit 212ba92

Browse files
authored
[X86] Recognise VPMADD52L pattern with AVX512IFMA/AVXIFMA (#153787) (#156714)
Match `(X * Y) + Z` in `combineAdd`. If target supports and we don't overflow (ie. we know the top 12 bits are unset), rewrite using VPMADD52L Have just done the `L` version for now at least, wanted to get feedback before continuing
1 parent dcaa29c commit 212ba92

File tree

2 files changed

+628
-0
lines changed

2 files changed

+628
-0
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57970,6 +57970,51 @@ static SDValue pushAddIntoCmovOfConsts(SDNode *N, const SDLoc &DL,
5797057970
Cmov.getOperand(3));
5797157971
}
5797257972

57973+
// Attempt to turn ADD(MUL(x, y), acc)) -> VPMADD52L
57974+
// When upper 12 bits of x, y and MUL(x, y) are known to be 0
57975+
static SDValue matchVPMADD52(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
57976+
EVT VT, const X86Subtarget &Subtarget) {
57977+
using namespace SDPatternMatch;
57978+
if (!VT.isVector() || VT.getScalarSizeInBits() != 64 ||
57979+
(!Subtarget.hasAVXIFMA() && !Subtarget.hasIFMA()))
57980+
return SDValue();
57981+
57982+
// Need AVX-512VL vector length extensions if operating on XMM/YMM registers
57983+
if (!Subtarget.hasAVXIFMA() && !Subtarget.hasVLX() &&
57984+
VT.getSizeInBits() < 512)
57985+
return SDValue();
57986+
57987+
const auto TotalSize = VT.getSizeInBits();
57988+
if (TotalSize < 128 || !isPowerOf2_64(TotalSize))
57989+
return SDValue();
57990+
57991+
SDValue X, Y, Acc;
57992+
if (!sd_match(N, m_Add(m_Mul(m_Value(X), m_Value(Y)), m_Value(Acc))))
57993+
return SDValue();
57994+
57995+
KnownBits KnownX = DAG.computeKnownBits(X);
57996+
if (KnownX.countMinLeadingZeros() < 12)
57997+
return SDValue();
57998+
KnownBits KnownY = DAG.computeKnownBits(Y);
57999+
if (KnownY.countMinLeadingZeros() < 12)
58000+
return SDValue();
58001+
KnownBits KnownMul = KnownBits::mul(KnownX, KnownY);
58002+
if (KnownMul.countMinLeadingZeros() < 12)
58003+
return SDValue();
58004+
58005+
auto VPMADD52Builder = [](SelectionDAG &G, SDLoc DL,
58006+
ArrayRef<SDValue> SubOps) {
58007+
EVT SubVT = SubOps[0].getValueType();
58008+
assert(SubVT.getScalarSizeInBits() == 64 &&
58009+
"Unexpected element size, only supports 64bit size");
58010+
return G.getNode(X86ISD::VPMADD52L, DL, SubVT, SubOps[1] /*X*/,
58011+
SubOps[2] /*Y*/, SubOps[0] /*Acc*/);
58012+
};
58013+
58014+
return SplitOpsAndApply(DAG, Subtarget, DL, VT, {Acc, X, Y}, VPMADD52Builder,
58015+
/*CheckBWI*/ false);
58016+
}
58017+
5797358018
static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
5797458019
TargetLowering::DAGCombinerInfo &DCI,
5797558020
const X86Subtarget &Subtarget) {
@@ -58073,6 +58118,9 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
5807358118
Op0.getOperand(0), Op0.getOperand(2));
5807458119
}
5807558120

58121+
if (SDValue IFMA52 = matchVPMADD52(N, DAG, DL, VT, Subtarget))
58122+
return IFMA52;
58123+
5807658124
return combineAddOrSubToADCOrSBB(N, DL, DAG);
5807758125
}
5807858126

0 commit comments

Comments
 (0)