Skip to content

Commit 0ab24f8

Browse files
committed
[ISel] DAGCombine clmul -> clmul[hr]
1 parent 1cf2f7d commit 0ab24f8

File tree

4 files changed

+8314
-7
lines changed

4 files changed

+8314
-7
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,11 @@ inline BinaryOpc_match<LHS, RHS> m_Rotr(const LHS &L, const RHS &R) {
919919
return BinaryOpc_match<LHS, RHS>(ISD::ROTR, L, R);
920920
}
921921

922+
template <typename LHS, typename RHS>
923+
inline BinaryOpc_match<LHS, RHS> m_Clmul(const LHS &L, const RHS &R) {
924+
return BinaryOpc_match<LHS, RHS>(ISD::CLMUL, L, R);
925+
}
926+
922927
template <typename LHS, typename RHS>
923928
inline BinaryOpc_match<LHS, RHS, true> m_FAdd(const LHS &L, const RHS &R) {
924929
return BinaryOpc_match<LHS, RHS, true>(ISD::FADD, L, R);

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10315,6 +10315,25 @@ SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
1031510315
if (SDValue R = combineShiftOfShiftedLogic(N, DAG))
1031610316
return R;
1031710317

10318+
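// Fold clmul(zext(x), zext(y)) >> (BW - 1) -> zext(clmulr(x, y)) and clmul(zext(x), zext(y)) >> BW -> zext(clmulh(x, y)), where BW is half the scalar bit width of the shifted type.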
10319+
SDLoc DL(N);
10320+
EVT VT = N->getValueType(0);
10321+
SDValue X, Y;
10322+
if (sd_match(N, m_Srl(m_Clmul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y))),
10323+
m_SpecificInt(VT.getScalarSizeInBits() / 2 - 1))))
10324+
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT,
10325+
DAG.getNode(ISD::CLMULR, DL, X.getValueType(), X, Y));
10326+
if (sd_match(N, m_Srl(m_Clmul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y))),
10327+
m_SpecificInt(VT.getScalarSizeInBits() / 2))))
10328+
return DAG.getNode(ISD::ZERO_EXTEND, DL, VT,
10329+
DAG.getNode(ISD::CLMULH, DL, X.getValueType(), X, Y));
10330+
10331+
// Fold bitreverse(clmul(bitreverse(x), bitreverse(y))) >> 1 -> clmulh(x, y).
10332+
if (sd_match(N, m_Srl(m_BitReverse(m_Clmul(m_BitReverse(m_Value(X)),
10333+
m_BitReverse(m_Value(Y)))),
10334+
m_SpecificInt(1))))
10335+
return DAG.getNode(ISD::CLMULH, DL, VT, X, Y);
10336+
1031810337
// We want to pull some binops through shifts, so that we have (and (shift))
1031910338
// instead of (shift (and)), likewise for add, or, xor, etc. This sort of
1032010339
// thing happens with address calculations, so it's important to canonicalize
@@ -10350,8 +10369,6 @@ SDValue DAGCombiner::visitShiftByConstant(SDNode *N) {
1035010369
return SDValue();
1035110370

1035210371
// Attempt to fold the constants, shifting the binop RHS by the shift amount.
10353-
SDLoc DL(N);
10354-
EVT VT = N->getValueType(0);
1035510372
if (SDValue NewRHS = DAG.FoldConstantArithmetic(
1035610373
N->getOpcode(), DL, VT, {LHS.getOperand(1), N->getOperand(1)})) {
1035710374
SDValue NewShift = DAG.getNode(N->getOpcode(), DL, VT, LHS.getOperand(0),
@@ -11771,6 +11788,11 @@ SDValue DAGCombiner::visitBITREVERSE(SDNode *N) {
1177111788
sd_match(N, m_BitReverse(m_Shl(m_BitReverse(m_Value(X)), m_Value(Y)))))
1177211789
return DAG.getNode(ISD::SRL, DL, VT, X, Y);
1177311790

11791+
// Fold bitreverse(clmul(bitreverse(x), bitreverse(y))) -> clmulr(x, y).
11792+
if (sd_match(N, m_BitReverse(m_Clmul(m_BitReverse(m_Value(X)),
11793+
m_BitReverse(m_Value(Y))))))
11794+
return DAG.getNode(ISD::CLMULR, DL, VT, X, Y);
11795+
1177411796
return SDValue();
1177511797
}
1177611798

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8307,13 +8307,14 @@ SDValue TargetLowering::expandCLMUL(SDNode *Node, SelectionDAG &DAG) const {
83078307
SDValue X = Node->getOperand(0);
83088308
SDValue Y = Node->getOperand(1);
83098309
unsigned BW = VT.getScalarSizeInBits();
8310+
unsigned Opcode = Node->getOpcode();
83108311

8311-
if (VT.isVector() && isOperationLegalOrCustomOrPromote(
8312-
Node->getOpcode(), VT.getVectorElementType()))
8312+
if (VT.isVector() &&
8313+
isOperationLegalOrCustomOrPromote(Opcode, VT.getVectorElementType()))
83138314
return DAG.UnrollVectorOp(Node);
83148315

83158316
SDValue Res = DAG.getConstant(0, DL, VT);
8316-
switch (Node->getOpcode()) {
8317+
switch (Opcode) {
83178318
case ISD::CLMUL: {
83188319
for (unsigned I = 0; I < BW; ++I) {
83198320
SDValue Mask = DAG.getConstant(APInt::getOneBitSet(BW, I), DL, VT);
@@ -8326,12 +8327,26 @@ SDValue TargetLowering::expandCLMUL(SDNode *Node, SelectionDAG &DAG) const {
83268327
case ISD::CLMULR:
83278328
case ISD::CLMULH: {
83288329
EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), 2 * BW);
8330+
// ZERO_EXTEND or SRL on ExtVT may be neither legal nor custom (for example,
8331+
// i64 operations on rv32); use the bitreverse-based lowering in that case.
8332+
if (!isOperationLegalOrCustom(ISD::ZERO_EXTEND, ExtVT) ||
8333+
!isOperationLegalOrCustom(ISD::SRL, ExtVT)) {
8334+
SDValue XRev = DAG.getNode(ISD::BITREVERSE, DL, VT, X);
8335+
SDValue YRev = DAG.getNode(ISD::BITREVERSE, DL, VT, Y);
8336+
SDValue ClMul = DAG.getNode(ISD::CLMUL, DL, VT, XRev, YRev);
8337+
Res = DAG.getNode(ISD::BITREVERSE, DL, VT, ClMul);
8338+
Res = Opcode == ISD::CLMULR
8339+
? Res
8340+
: DAG.getNode(ISD::SRL, DL, VT, Res,
8341+
DAG.getShiftAmountConstant(1, VT, DL));
8342+
break;
8343+
}
83298344
SDValue XExt = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtVT, X);
83308345
SDValue YExt = DAG.getNode(ISD::ZERO_EXTEND, DL, ExtVT, Y);
83318346
SDValue ClMul = DAG.getNode(ISD::CLMUL, DL, ExtVT, XExt, YExt);
8332-
unsigned ShtAmt = Node->getOpcode() == ISD::CLMULR ? BW - 1 : BW;
8347+
unsigned ShtAmt = Opcode == ISD::CLMULR ? BW - 1 : BW;
83338348
SDValue HiBits = DAG.getNode(ISD::SRL, DL, ExtVT, ClMul,
8334-
DAG.getShiftAmountConstant(ShtAmt, VT, DL));
8349+
DAG.getShiftAmountConstant(ShtAmt, ExtVT, DL));
83358350
Res = DAG.getNode(ISD::TRUNCATE, DL, VT, HiBits);
83368351
break;
83378352
}

0 commit comments

Comments
 (0)