Skip to content

Commit

Permalink
[SDPatternMatch] Add m_CondCode, m_NoneOf, and some SExt improvements (
Browse files Browse the repository at this point in the history
…#90762)

  - Add m_CondCode to match the ISD::CondCode value from CondCodeSDNode
  - Add m_NoneOf combinator
  - m_SExt now recognizes sext_inreg
  • Loading branch information
mshockwave committed May 2, 2024
1 parent fbaba78 commit 0638e22
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 10 deletions.
73 changes: 63 additions & 10 deletions llvm/include/llvm/CodeGen/SDPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,24 @@ struct Or<Pred, Preds...> : Or<Preds...> {
}
};

template <typename Pred> struct Not {
Pred P;

explicit Not(const Pred &P) : P(P) {}

template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
return !P.match(Ctx, N);
}
};
// Explicit deduction guide.
template <typename Pred> Not(const Pred &P) -> Not<Pred>;

/// Match if the inner pattern does NOT match.
template <typename Pred> inline Not<Pred> m_Unless(const Pred &P) {
return Not{P};
}

template <typename... Preds> And<Preds...> m_AllOf(Preds &&...preds) {
return And<Preds...>(std::forward<Preds>(preds)...);
}
Expand All @@ -366,6 +384,10 @@ template <typename... Preds> Or<Preds...> m_AnyOf(Preds &&...preds) {
return Or<Preds...>(std::forward<Preds>(preds)...);
}

template <typename... Preds> auto m_NoneOf(Preds &&...preds) {
return m_Unless(m_AnyOf(std::forward<Preds>(preds)...));
}

// === Generic node matching ===
template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
template <typename MatchContext>
Expand Down Expand Up @@ -620,8 +642,10 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
}

template <typename Opnd> inline UnaryOpc_match<Opnd> m_SExt(const Opnd &Op) {
return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
template <typename Opnd> inline auto m_SExt(Opnd &&Op) {
return m_AnyOf(
UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op),
m_Node(ISD::SIGN_EXTEND_INREG, std::forward<Opnd>(Op), m_Value()));
}

template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) {
Expand All @@ -634,18 +658,14 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) {

/// Match a zext or identity
/// Allows to peek through optional extensions
template <typename Opnd>
inline Or<UnaryOpc_match<Opnd>, Opnd> m_ZExtOrSelf(Opnd &&Op) {
return Or<UnaryOpc_match<Opnd>, Opnd>(m_ZExt(std::forward<Opnd>(Op)),
std::forward<Opnd>(Op));
template <typename Opnd> inline auto m_ZExtOrSelf(Opnd &&Op) {
return m_AnyOf(m_ZExt(std::forward<Opnd>(Op)), std::forward<Opnd>(Op));
}

/// Match a sext or identity
/// Allows to peek through optional extensions
template <typename Opnd>
inline Or<UnaryOpc_match<Opnd>, Opnd> m_SExtOrSelf(Opnd &&Op) {
return Or<UnaryOpc_match<Opnd>, Opnd>(m_SExt(std::forward<Opnd>(Op)),
std::forward<Opnd>(Op));
template <typename Opnd> inline auto m_SExtOrSelf(Opnd &&Op) {
return m_AnyOf(m_SExt(std::forward<Opnd>(Op)), std::forward<Opnd>(Op));
}

/// Match a aext or identity
Expand Down Expand Up @@ -768,6 +788,39 @@ inline auto m_False() {
m_Value()};
}

struct CondCode_match {
std::optional<ISD::CondCode> CCToMatch;
ISD::CondCode *BindCC = nullptr;

explicit CondCode_match(ISD::CondCode CC) : CCToMatch(CC) {}

explicit CondCode_match(ISD::CondCode *CC) : BindCC(CC) {}

template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
if (auto *CC = dyn_cast<CondCodeSDNode>(N.getNode())) {
if (CCToMatch && *CCToMatch != CC->get())
return false;

if (BindCC)
*BindCC = CC->get();
return true;
}

return false;
}
};

/// Match any conditional code SDNode.
inline CondCode_match m_CondCode() { return CondCode_match(nullptr); }
/// Match any conditional code SDNode and return its ISD::CondCode value.
inline CondCode_match m_CondCode(ISD::CondCode &CC) {
return CondCode_match(&CC);
}
/// Match a conditional code SDNode with a specific ISD::CondCode.
inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
return CondCode_match(CC);
}

/// Match a negate as a sub(0, v)
template <typename ValTy>
inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
Expand Down
13 changes: 13 additions & 0 deletions llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchConstants) {
SDValue Zero = DAG->getConstant(0, DL, Int32VT);
SDValue One = DAG->getConstant(1, DL, Int32VT);
SDValue AllOnes = DAG->getConstant(APInt::getAllOnes(32), DL, Int32VT);
SDValue SetCC = DAG->getSetCC(DL, Int32VT, Arg0, Const3, ISD::SETULT);

using namespace SDPatternMatch;
EXPECT_TRUE(sd_match(Const87, m_ConstInt()));
Expand All @@ -233,6 +234,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchConstants) {
EXPECT_TRUE(sd_match(Zero, DAG.get(), m_False()));
EXPECT_TRUE(sd_match(One, DAG.get(), m_True()));
EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_True()));

ISD::CondCode CC;
EXPECT_TRUE(sd_match(
SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(), m_CondCode(CC))));
EXPECT_EQ(CC, ISD::SETULT);
EXPECT_TRUE(sd_match(SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(),
m_SpecificCondCode(ISD::SETULT))));
}

TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
Expand All @@ -249,6 +257,7 @@ TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
EXPECT_TRUE(sd_match(
Sub, m_AnyOf(m_Opc(ISD::ADD), m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
EXPECT_TRUE(sd_match(Add, m_AllOf(m_Opc(ISD::ADD), m_OneUse())));
EXPECT_TRUE(sd_match(Add, m_NoneOf(m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
}

TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
Expand All @@ -260,6 +269,8 @@ TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
SDValue Op64 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op32);
SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op32);
SDValue SExtInReg = DAG->getNode(ISD::SIGN_EXTEND_INREG, DL, Int64VT, Op64,
DAG->getValueType(Int32VT));
SDValue AExt = DAG->getNode(ISD::ANY_EXTEND, DL, Int64VT, Op32);
SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op64);

Expand All @@ -273,6 +284,8 @@ TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
EXPECT_TRUE(A == Op64);
EXPECT_TRUE(sd_match(SExt, m_SExtOrSelf(m_Value(A))));
EXPECT_TRUE(A == Op32);
EXPECT_TRUE(sd_match(SExtInReg, m_SExtOrSelf(m_Value(A))));
EXPECT_TRUE(A == Op64);
EXPECT_TRUE(sd_match(Op32, m_AExtOrSelf(m_Value(A))));
EXPECT_TRUE(A == Op32);
EXPECT_TRUE(sd_match(AExt, m_AExtOrSelf(m_Value(A))));
Expand Down

0 comments on commit 0638e22

Please sign in to comment.