Skip to content

[DAG] SDPatternMatch - add matchers for reassociatable binops #119985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions llvm/include/llvm/CodeGen/SDPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
#define LLVM_CODEGEN_SDPATTERNMATCH_H

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetLowering.h"
Expand Down Expand Up @@ -1134,6 +1136,87 @@ inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
return m_Xor(V, m_AllOnes());
}

template <typename... PatternTs> struct ReassociatableOpc_match {
unsigned Opcode;
std::tuple<PatternTs...> Patterns;

/// Builds a matcher for trees of \p Opcode nodes. Each sub-pattern in
/// \p Patterns is copied into the Patterns tuple; the number of sub-patterns
/// is the number of leaves the flattened tree must have for match() to
/// succeed.
ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
: Opcode(Opcode), Patterns(Patterns...) {}

template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
SmallVector<SDValue> Leaves;
collectLeaves(N, Leaves);
if (Leaves.size() != std::tuple_size_v<std::tuple<PatternTs...>>)
return false;

// Matches[I][J] == true iff sd_context_match(Leaves[I], Ctx,
// std::get<J>(Patterns)) == true
std::array<SmallBitVector, std::tuple_size_v<std::tuple<PatternTs...>>>
Matches;
for (size_t I = 0, N = Leaves.size(); I < N; I++) {
SmallVector<bool> MatchResults;
std::apply(
[&](auto &...P) {
(Matches[I].push_back(sd_context_match(Leaves[I], Ctx, P)), ...);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Esan5 I'm hitting an issue with this when m_Value() leaves can match in multiple positions: it looks like the last leaf is binding to every m_Value() instead of the matches being shared correctly. Do we need to run the matches again after reassociatableMatchHelper to ensure that each leaf is allocated and matched to a single pattern? I think technically this could still fail with m_Deferred() matches - wdyt?

define i64 @test_lsb_i64(i64 %a0, i64 %a1) nounwind {
  %s0 = lshr i64 %a0, 1
  %s1 = lshr i64 %a1, 1
  %s = add i64 %s1, %s0
  %m0 = and i64 %a0, 1
  %m1 = and i64 %m0, %a1
  %res = add i64 %s, %m1
  ret i64 %res
}
    if (sd_match(N, m_ReassociatableAdd(m_Srl(m_Value(A), m_SpecificInt(1)),
                                        m_Srl(m_Value(B), m_SpecificInt(1)),
                                        m_Value(C))))
      if (sd_match(C, m_ReassociatableAnd(m_Specific(A), m_Specific(B),
                                          m_SpecificInt(1))))
        return DAG.getNode(ISD::AVGFLOORU, DL, VT, A, B);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may be misunderstanding, but is the intention here for

if (sd_match(N, m_ReassociatableAdd(m_Srl(m_Value(A), m_SpecificInt(1)),
                                        m_Srl(m_Value(B), m_SpecificInt(1)),
                                        m_Value(C))))

to match

%s0 = lshr i64 %a0, 1
%s1 = lshr i64 %a1, 1
%s = add i64 %s1, %s0

The code in this helper function should prevent a given sub-pattern from being used to match more than one leaf by tracking which patterns have already been used.

Unfortunately, I don't have enough experience with LLVM to identify the intended behavior, can you clarify what this test case should do?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was expecting m_Value(A) to match one of %s0 or %s1 and m_Value(B) to match the other - but instead both are matching to the same one.

I was really hoping to do this, but I doubt the current implementation can manage it:

if (sd_match(N, m_ReassociatableAdd(m_Srl(m_Value(A), m_SpecificInt(1)),
                                    m_Srl(m_Value(B), m_SpecificInt(1)),
                                    m_ReassociatableAnd(m_Deferred(A), m_Deferred(B), m_SpecificInt(1)))))
  return DAG.getNode(ISD::AVGFLOORU, DL, VT, A, B);

},
Patterns);
}

SmallBitVector Used(std::tuple_size_v<std::tuple<PatternTs...>>);
return reassociatableMatchHelper(Matches, Used);
}

/// Flattens the tree rooted at \p V: every node whose opcode equals Opcode is
/// expanded recursively through its operands (in operand order), and every
/// other node is appended to \p Leaves as a leaf.
void collectLeaves(SDValue V, SmallVector<SDValue> &Leaves) {
  // Anything that is not an Opcode node terminates the walk as a leaf.
  if (V->getOpcode() != Opcode) {
    Leaves.emplace_back(V);
    return;
  }

  const size_t NumOperands = V->getNumOperands();
  for (size_t OpIdx = 0; OpIdx < NumOperands; OpIdx++)
    collectLeaves(V->getOperand(OpIdx), Leaves);
}

[[nodiscard]] inline bool
reassociatableMatchHelper(const ArrayRef<SmallBitVector> Matches,
SmallBitVector &Used, size_t Curr = 0) {
if (Curr == Matches.size())
return true;
for (size_t Match = 0, N = Matches[Curr].size(); Match < N; Match++) {
if (!Matches[Curr][Match] || Used[Match])
continue;
Used[Match] = true;
if (reassociatableMatchHelper(Matches, Used, Curr + 1))
return true;
Used[Match] = false;
}
return false;
}
};

/// Creates a ReassociatableOpc_match over ISD::ADD: the flattened leaves of
/// an ADD tree must pair one-to-one, in any order, with \p Patterns.
template <typename... PatternTs>
inline ReassociatableOpc_match<PatternTs...>
m_ReassociatableAdd(const PatternTs &...Patterns) {
  using MatcherT = ReassociatableOpc_match<PatternTs...>;
  return MatcherT(ISD::ADD, Patterns...);
}

/// Creates a ReassociatableOpc_match over ISD::OR: the flattened leaves of an
/// OR tree must pair one-to-one, in any order, with \p Patterns.
template <typename... PatternTs>
inline ReassociatableOpc_match<PatternTs...>
m_ReassociatableOr(const PatternTs &...Patterns) {
  using MatcherT = ReassociatableOpc_match<PatternTs...>;
  return MatcherT(ISD::OR, Patterns...);
}

/// Creates a ReassociatableOpc_match over ISD::AND: the flattened leaves of
/// an AND tree must pair one-to-one, in any order, with \p Patterns.
template <typename... PatternTs>
inline ReassociatableOpc_match<PatternTs...>
m_ReassociatableAnd(const PatternTs &...Patterns) {
  using MatcherT = ReassociatableOpc_match<PatternTs...>;
  return MatcherT(ISD::AND, Patterns...);
}

/// Creates a ReassociatableOpc_match over ISD::MUL: the flattened leaves of a
/// MUL tree must pair one-to-one, in any order, with \p Patterns.
template <typename... PatternTs>
inline ReassociatableOpc_match<PatternTs...>
m_ReassociatableMul(const PatternTs &...Patterns) {
  using MatcherT = ReassociatableOpc_match<PatternTs...>;
  return MatcherT(ISD::MUL, Patterns...);
}

} // namespace SDPatternMatch
} // namespace llvm
#endif
125 changes: 125 additions & 0 deletions llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -651,3 +651,128 @@ TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
EXPECT_TRUE(sd_match(Add, DAG.get(),
m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value())))));
}

// Exercises the m_Reassociatable{Add,Mul,And,Or} matchers: balanced and
// right-leaning trees of a single opcode must match with one pattern per
// leaf, while nodes of any other opcode must be treated as opaque leaves.
TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
  using namespace SDPatternMatch;

  SDLoc DL;
  auto Int32VT = EVT::getIntegerVT(Context, 32);

  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
  SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
  SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);

  // (Op0 + Op1) + (Op2 + Op3)
  SDValue ADD01 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
  SDValue ADD23 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, Op3);
  SDValue ADD = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, ADD23);

  EXPECT_FALSE(sd_match(ADD01, m_ReassociatableAdd(m_Value())));
  EXPECT_TRUE(sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(ADD23, m_ReassociatableAdd(m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(
      ADD, m_ReassociatableAdd(m_Value(), m_Value(), m_Value(), m_Value())));

  // Op0 + (Op1 + (Op2 + Op3))
  SDValue ADD123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op1, ADD23);
  SDValue ADD0123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, ADD123);
  EXPECT_TRUE(
      sd_match(ADD123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(ADD0123, m_ReassociatableAdd(m_Value(), m_Value(),
                                                    m_Value(), m_Value())));

  // (Op0 - Op1) + (Op2 - Op3)
  SDValue SUB01 = DAG->getNode(ISD::SUB, DL, Int32VT, Op0, Op1);
  SDValue SUB23 = DAG->getNode(ISD::SUB, DL, Int32VT, Op2, Op3);
  SDValue ADDS0123 = DAG->getNode(ISD::ADD, DL, Int32VT, SUB01, SUB23);

  EXPECT_FALSE(sd_match(SUB01, m_ReassociatableAdd(m_Value(), m_Value())));
  EXPECT_FALSE(sd_match(ADDS0123, m_ReassociatableAdd(m_Value(), m_Value(),
                                                      m_Value(), m_Value())));

  // SUB + SUB matches (Op0 - Op1) + (Op2 - Op3)
  EXPECT_TRUE(
      sd_match(ADDS0123, m_ReassociatableAdd(m_Sub(m_Value(), m_Value()),
                                             m_Sub(m_Value(), m_Value()))));
  EXPECT_FALSE(sd_match(ADDS0123, m_ReassociatableAdd(m_Value(), m_Value(),
                                                      m_Value(), m_Value())));

  // (Op0 * Op1) * (Op2 * Op3)
  SDValue MUL01 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, Op1);
  SDValue MUL23 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, Op3);
  SDValue MUL = DAG->getNode(ISD::MUL, DL, Int32VT, MUL01, MUL23);

  EXPECT_TRUE(sd_match(MUL01, m_ReassociatableMul(m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(MUL23, m_ReassociatableMul(m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(
      MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));

  // Op0 * (Op1 * (Op2 * Op3))
  SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op1, MUL23);
  SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);
  EXPECT_TRUE(
      sd_match(MUL123, m_ReassociatableMul(m_Value(), m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(MUL0123, m_ReassociatableMul(m_Value(), m_Value(),
                                                    m_Value(), m_Value())));

  // (Op0 - Op1) * (Op2 - Op3)
  SDValue MULS0123 = DAG->getNode(ISD::MUL, DL, Int32VT, SUB01, SUB23);
  EXPECT_TRUE(
      sd_match(MULS0123, m_ReassociatableMul(m_Sub(m_Value(), m_Value()),
                                             m_Sub(m_Value(), m_Value()))));
  EXPECT_FALSE(sd_match(MULS0123, m_ReassociatableMul(m_Value(), m_Value(),
                                                      m_Value(), m_Value())));

  // (Op0 && Op1) && (Op2 && Op3)
  SDValue AND01 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
  SDValue AND23 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, Op3);
  SDValue AND = DAG->getNode(ISD::AND, DL, Int32VT, AND01, AND23);

  EXPECT_TRUE(sd_match(AND01, m_ReassociatableAnd(m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(AND23, m_ReassociatableAnd(m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(
      AND, m_ReassociatableAnd(m_Value(), m_Value(), m_Value(), m_Value())));

  // Op0 && (Op1 && (Op2 && Op3))
  SDValue AND123 = DAG->getNode(ISD::AND, DL, Int32VT, Op1, AND23);
  SDValue AND0123 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, AND123);
  EXPECT_TRUE(
      sd_match(AND123, m_ReassociatableAnd(m_Value(), m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(AND0123, m_ReassociatableAnd(m_Value(), m_Value(),
                                                    m_Value(), m_Value())));

  // (Op0 - Op1) && (Op2 - Op3)
  SDValue ANDS0123 = DAG->getNode(ISD::AND, DL, Int32VT, SUB01, SUB23);
  EXPECT_TRUE(
      sd_match(ANDS0123, m_ReassociatableAnd(m_Sub(m_Value(), m_Value()),
                                             m_Sub(m_Value(), m_Value()))));
  EXPECT_FALSE(sd_match(ANDS0123, m_ReassociatableAnd(m_Value(), m_Value(),
                                                      m_Value(), m_Value())));

  // (Op0 || Op1) || (Op2 || Op3)
  SDValue OR01 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
  SDValue OR23 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, Op3);
  SDValue OR = DAG->getNode(ISD::OR, DL, Int32VT, OR01, OR23);

  EXPECT_TRUE(sd_match(OR01, m_ReassociatableOr(m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(OR23, m_ReassociatableOr(m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(
      OR, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));

  // Op0 || (Op1 || (Op2 || Op3))
  SDValue OR123 = DAG->getNode(ISD::OR, DL, Int32VT, Op1, OR23);
  SDValue OR0123 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, OR123);
  EXPECT_TRUE(
      sd_match(OR123, m_ReassociatableOr(m_Value(), m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(
      OR0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));

  // (Op0 - Op1) || (Op2 - Op3)
  SDValue ORS0123 = DAG->getNode(ISD::OR, DL, Int32VT, SUB01, SUB23);
  EXPECT_TRUE(
      sd_match(ORS0123, m_ReassociatableOr(m_Sub(m_Value(), m_Value()),
                                           m_Sub(m_Value(), m_Value()))));
  EXPECT_FALSE(sd_match(
      ORS0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));

  // Negative coverage: non-associative or differently-typed operations inside
  // the expression tree must stop the flattening and count as a single leaf.

  // (Op0 + Op1) - Op2: the root is a SUB, so it is one opaque leaf for the
  // ADD matcher regardless of what is underneath it.
  SDValue SUBA012 = DAG->getNode(ISD::SUB, DL, Int32VT, ADD01, Op2);
  EXPECT_FALSE(sd_match(SUBA012, m_ReassociatableAdd(m_Value(), m_Value())));
  EXPECT_FALSE(
      sd_match(SUBA012, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));

  // (Op0 - Op1) + Op2: two leaves, one of which is the SUB.
  SDValue ADDS012 = DAG->getNode(ISD::ADD, DL, Int32VT, SUB01, Op2);
  EXPECT_TRUE(sd_match(ADDS012, m_ReassociatableAdd(m_Value(), m_Value())));
  EXPECT_FALSE(
      sd_match(ADDS012, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(
      ADDS012, m_ReassociatableAdd(m_Sub(m_Value(), m_Value()), m_Value())));
  // Pattern order must not matter; this ordering forces the assignment
  // search to backtrack (the SUB leaf first claims m_Value()).
  EXPECT_TRUE(sd_match(
      ADDS012, m_ReassociatableAdd(m_Value(), m_Sub(m_Value(), m_Value()))));
  // Only one leaf is a SUB, so two SUB patterns cannot both be satisfied.
  EXPECT_FALSE(
      sd_match(ADDS012, m_ReassociatableAdd(m_Sub(m_Value(), m_Value()),
                                            m_Sub(m_Value(), m_Value()))));

  // Op0 + ((Op1 - Op2) + Op3): the SUB is buried inside the ADD chain.
  SDValue SUB12 = DAG->getNode(ISD::SUB, DL, Int32VT, Op1, Op2);
  SDValue ADDS123 = DAG->getNode(ISD::ADD, DL, Int32VT, SUB12, Op3);
  SDValue ADD0S123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, ADDS123);
  EXPECT_TRUE(
      sd_match(ADD0S123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
  EXPECT_TRUE(sd_match(
      ADD0S123, m_ReassociatableAdd(m_Value(), m_Sub(m_Value(), m_Value()),
                                    m_Value())));
  EXPECT_FALSE(sd_match(ADD0S123, m_ReassociatableAdd(m_Value(), m_Value(),
                                                      m_Value(), m_Value())));

  // (Op0 * Op1) + (Op2 * Op3): a different associative opcode below the root
  // must not be flattened into the ADD tree, and the ADD root must not match
  // the MUL matcher.
  SDValue ADDM0123 = DAG->getNode(ISD::ADD, DL, Int32VT, MUL01, MUL23);
  EXPECT_TRUE(sd_match(ADDM0123, m_ReassociatableAdd(m_Value(), m_Value())));
  EXPECT_FALSE(sd_match(ADDM0123, m_ReassociatableAdd(m_Value(), m_Value(),
                                                      m_Value(), m_Value())));
  EXPECT_FALSE(sd_match(ADDM0123, m_ReassociatableMul(m_Value(), m_Value())));
  EXPECT_FALSE(sd_match(ADDM0123, m_ReassociatableMul(m_Value(), m_Value(),
                                                      m_Value(), m_Value())));
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add some negative tests showing that it works as expected even when there are non-associative operations in the expression tree?