Skip to content

Commit

Permalink
[PatternMatch] Add matchers for m_{I,F,}Cmp and `m_{I,F,}SpecificCm…
Browse files Browse the repository at this point in the history
…p`; NFC

These matchers either take no predicate argument or match a specific
predicate respectively.

We have a lot of cases where the Pred argument is either unused and
requiring the argument reduces code clarity.

Likewise we have a lot of cases where we only pass in Pred to test
equality which the new `*Specific*` helpers can simplify.
  • Loading branch information
goldsteinn committed Jul 10, 2024
1 parent 015526b commit 6d74dbe
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 7 deletions.
77 changes: 71 additions & 6 deletions llvm/include/llvm/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -1548,25 +1548,38 @@ template <typename T> inline Exact_match<T> m_Exact(const T &SubPattern) {
//

template <typename LHS_t, typename RHS_t, typename Class, typename PredicateTy,
bool Commutable = false>
bool Commutable = false, bool MatchExistingPred = false>
struct CmpClass_match {
PredicateTy &Predicate;
static_assert(!Commutable || !MatchExistingPred,
"Can't match predicate when using commutable matcher");

// Make predicate ty const ref if we are matching. Not strictly necessary but
// will cause a compilation warning if we accidentally try to set it with
// MatchExistingPred enabled.
using InternalPredTy =
std::conditional_t<MatchExistingPred, const PredicateTy &, PredicateTy &>;
InternalPredTy Predicate;
LHS_t L;
RHS_t R;

// The evaluation order is always stable, regardless of Commutability.
// The LHS is always matched first.
CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS)
CmpClass_match(InternalPredTy Pred, const LHS_t &LHS, const RHS_t &RHS)
: Predicate(Pred), L(LHS), R(RHS) {}

template <typename OpTy> bool match(OpTy *V) {
if (auto *I = dyn_cast<Class>(V)) {
if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) {
Predicate = I->getPredicate();
return true;
if constexpr (MatchExistingPred)
return I->getPredicate() == Predicate;
else {
Predicate = I->getPredicate();
return true;
}
} else if (Commutable && L.match(I->getOperand(1)) &&
R.match(I->getOperand(0))) {
Predicate = I->getSwappedPredicate();
if constexpr (!MatchExistingPred)
Predicate = I->getSwappedPredicate();
return true;
}
}
Expand All @@ -1592,6 +1605,50 @@ m_FCmp(FCmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>(Pred, L, R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>
m_Cmp(const LHS &L, const RHS &R) {
CmpInst::Predicate Unused;
return CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>(Unused, L, R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>
m_ICmp(const LHS &L, const RHS &R) {
ICmpInst::Predicate Unused;
return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>(Unused, L, R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>
m_FCmp(const LHS &L, const RHS &R) {
FCmpInst::Predicate Unused;
return CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>(Unused, L, R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate, false, true>
m_SpecificCmp(const CmpInst::Predicate &MatchPred, const LHS &L, const RHS &R) {
return CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate, false, true>(
MatchPred, L, R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, false, true>
m_SpecificICmp(const ICmpInst::Predicate &MatchPred, const LHS &L,
const RHS &R) {
return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, false, true>(
MatchPred, L, R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate, false, true>
m_SpecificFCmp(const FCmpInst::Predicate &MatchPred, const LHS &L,
const RHS &R) {
return CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate, false, true>(
MatchPred, L, R);
}

//===----------------------------------------------------------------------===//
// Matchers for instructions with a given opcode and number of operands.
//
Expand Down Expand Up @@ -2617,6 +2674,14 @@ m_c_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
R);
}

template <typename LHS, typename RHS>
inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>
m_c_ICmp(const LHS &L, const RHS &R) {
ICmpInst::Predicate Unused;
return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(Unused,
L, R);
}

/// Matches a specific opcode with LHS and RHS in either order.
template <typename LHS, typename RHS>
inline SpecificBinaryOp_match<LHS, RHS, true>
Expand Down
144 changes: 143 additions & 1 deletion llvm/unittests/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2250,9 +2250,151 @@ TYPED_TEST(MutableConstTest, ICmp) {
ICmpInst::Predicate MatchPred;

EXPECT_TRUE(m_ICmp(MatchPred, m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
EXPECT_EQ(L, MatchL);
EXPECT_EQ(R, MatchR);

EXPECT_TRUE(m_Cmp(MatchPred, m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
EXPECT_EQ(L, MatchL);
EXPECT_EQ(R, MatchR);

EXPECT_TRUE(m_ICmp(m_Specific(L), m_Specific(R))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));

EXPECT_TRUE(m_Cmp(m_Specific(L), m_Specific(R))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));

EXPECT_FALSE(m_ICmp(m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
EXPECT_FALSE(m_Cmp(m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));

EXPECT_TRUE(m_c_ICmp(m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));

EXPECT_FALSE(m_c_ICmp(m_Specific(R), m_Specific(R))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));

EXPECT_TRUE(m_SpecificICmp(Pred, m_Specific(L), m_Specific(R))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
EXPECT_TRUE(m_SpecificCmp(Pred, m_Specific(L), m_Specific(R))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));

EXPECT_FALSE(m_SpecificICmp(Pred, m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
EXPECT_FALSE(m_SpecificCmp(Pred, m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));

MatchL = nullptr;
MatchR = nullptr;
EXPECT_TRUE(m_SpecificICmp(Pred, m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
EXPECT_EQ(L, MatchL);
EXPECT_EQ(R, MatchR);
MatchL = nullptr;
MatchR = nullptr;
EXPECT_TRUE(m_SpecificCmp(Pred, m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
EXPECT_EQ(L, MatchL);
EXPECT_EQ(R, MatchR);

EXPECT_FALSE(m_SpecificICmp(Pred, m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
EXPECT_FALSE(m_SpecificCmp(Pred, m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));

EXPECT_FALSE(m_SpecificICmp(ICmpInst::getInversePredicate(Pred),
m_Specific(L), m_Specific(R))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
EXPECT_FALSE(m_SpecificCmp(ICmpInst::getInversePredicate(Pred), m_Specific(L),
m_Specific(R))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));

EXPECT_FALSE(m_SpecificICmp(ICmpInst::getInversePredicate(Pred),
m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
EXPECT_FALSE(m_SpecificCmp(ICmpInst::getInversePredicate(Pred),
m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
}

TYPED_TEST(MutableConstTest, FCmp) {
auto &IRB = PatternMatchTest::IRB;

typedef std::tuple_element_t<0, TypeParam> ValueType;
typedef std::tuple_element_t<1, TypeParam> InstructionType;

Value *L = Constant::getNullValue(IRB.getFloatTy());
Value *R = ConstantFP::getInfinity(IRB.getFloatTy(), true);
FCmpInst::Predicate Pred = FCmpInst::FCMP_OGT;

ValueType MatchL;
ValueType MatchR;
FCmpInst::Predicate MatchPred;

EXPECT_TRUE(m_FCmp(MatchPred, m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));
EXPECT_EQ(L, MatchL);
EXPECT_EQ(R, MatchR);

EXPECT_TRUE(m_Cmp(MatchPred, m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));
EXPECT_EQ(L, MatchL);
EXPECT_EQ(R, MatchR);

EXPECT_TRUE(m_FCmp(m_Specific(L), m_Specific(R))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));

EXPECT_TRUE(m_Cmp(m_Specific(L), m_Specific(R))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));

EXPECT_FALSE(m_FCmp(m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));
EXPECT_FALSE(m_Cmp(m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));

EXPECT_TRUE(m_SpecificFCmp(Pred, m_Specific(L), m_Specific(R))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));
EXPECT_TRUE(m_SpecificCmp(Pred, m_Specific(L), m_Specific(R))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));

EXPECT_FALSE(m_SpecificFCmp(Pred, m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));
EXPECT_FALSE(m_SpecificCmp(Pred, m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));

MatchL = nullptr;
MatchR = nullptr;
EXPECT_TRUE(m_SpecificFCmp(Pred, m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));
EXPECT_EQ(L, MatchL);
EXPECT_EQ(R, MatchR);
MatchL = nullptr;
MatchR = nullptr;
EXPECT_TRUE(m_SpecificCmp(Pred, m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));
EXPECT_EQ(L, MatchL);
EXPECT_EQ(R, MatchR);

EXPECT_FALSE(m_SpecificFCmp(Pred, m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));
EXPECT_FALSE(m_SpecificCmp(Pred, m_Specific(R), m_Specific(L))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));

EXPECT_FALSE(m_SpecificFCmp(FCmpInst::getInversePredicate(Pred),
m_Specific(L), m_Specific(R))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));
EXPECT_FALSE(m_SpecificCmp(FCmpInst::getInversePredicate(Pred), m_Specific(L),
m_Specific(R))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));

EXPECT_FALSE(m_SpecificFCmp(FCmpInst::getInversePredicate(Pred),
m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));
EXPECT_FALSE(m_SpecificCmp(FCmpInst::getInversePredicate(Pred),
m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));
}

TEST_F(PatternMatchTest, ConstExpr) {
Expand Down

0 comments on commit 6d74dbe

Please sign in to comment.