From 6d74dbe1e8b2a0e4643cc08d7e1386d5104b9b28 Mon Sep 17 00:00:00 2001 From: Noah Goldstein Date: Wed, 10 Jul 2024 16:13:31 +0800 Subject: [PATCH] [PatternMatch] Add matchers for `m_{I,F,}Cmp` and `m_{I,F,}SpecificCmp`; NFC These matchers either take no predicate argument or match a specific predicate respectively. We have a lot of cases where the Pred argument is either unused and requiring the argument reduces code clarity. Likewise we have a lot of cases where we only pass in Pred to test equality which the new `*Specific*` helpers can simplify. --- llvm/include/llvm/IR/PatternMatch.h | 77 +++++++++++++-- llvm/unittests/IR/PatternMatch.cpp | 144 +++++++++++++++++++++++++++- 2 files changed, 214 insertions(+), 7 deletions(-) diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h index d4e355431a27a2..f9f473fb4276ed 100644 --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -1548,25 +1548,38 @@ template inline Exact_match m_Exact(const T &SubPattern) { // template + bool Commutable = false, bool MatchExistingPred = false> struct CmpClass_match { - PredicateTy &Predicate; + static_assert(!Commutable || !MatchExistingPred, + "Can't match predicate when using commutable matcher"); + + // Make predicate ty const ref if we are matching. Not strictly necessary but + // will cause a compilation warning if we accidentally try to set it with + // MatchExistingPred enabled. + using InternalPredTy = + std::conditional_t; + InternalPredTy Predicate; LHS_t L; RHS_t R; // The evaluation order is always stable, regardless of Commutability. // The LHS is always matched first. - CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) + CmpClass_match(InternalPredTy Pred, const LHS_t &LHS, const RHS_t &RHS) : Predicate(Pred), L(LHS), R(RHS) {} template bool match(OpTy *V) { if (auto *I = dyn_cast(V)) { if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { - Predicate = I->getPredicate(); - return true; + if constexpr (MatchExistingPred) + return I->getPredicate() == Predicate; + else { + Predicate = I->getPredicate(); + return true; + } } else if (Commutable && L.match(I->getOperand(1)) && R.match(I->getOperand(0))) { - Predicate = I->getSwappedPredicate(); + if constexpr (!MatchExistingPred) + Predicate = I->getSwappedPredicate(); return true; } } @@ -1592,6 +1605,50 @@ m_FCmp(FCmpInst::Predicate &Pred, const LHS &L, const RHS &R) { return CmpClass_match(Pred, L, R); } +template +inline CmpClass_match +m_Cmp(const LHS &L, const RHS &R) { + CmpInst::Predicate Unused; + return CmpClass_match(Unused, L, R); +} + +template +inline CmpClass_match +m_ICmp(const LHS &L, const RHS &R) { + ICmpInst::Predicate Unused; + return CmpClass_match(Unused, L, R); +} + +template +inline CmpClass_match +m_FCmp(const LHS &L, const RHS &R) { + FCmpInst::Predicate Unused; + return CmpClass_match(Unused, L, R); +} + +template +inline CmpClass_match +m_SpecificCmp(const CmpInst::Predicate &MatchPred, const LHS &L, const RHS &R) { + return CmpClass_match( + MatchPred, L, R); +} + +template +inline CmpClass_match +m_SpecificICmp(const ICmpInst::Predicate &MatchPred, const LHS &L, + const RHS &R) { + return CmpClass_match( + MatchPred, L, R); +} + +template +inline CmpClass_match +m_SpecificFCmp(const FCmpInst::Predicate &MatchPred, const LHS &L, + const RHS &R) { + return CmpClass_match( + MatchPred, L, R); +} + //===----------------------------------------------------------------------===// // Matchers for instructions with a given opcode and number of operands. // @@ -2617,6 +2674,14 @@ m_c_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) { R); } +template +inline CmpClass_match +m_c_ICmp(const LHS &L, const RHS &R) { + ICmpInst::Predicate Unused; + return CmpClass_match(Unused, + L, R); +} + /// Matches a specific opcode with LHS and RHS in either order. template inline SpecificBinaryOp_match diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp index 9f91b4f3f9939f..309fcc93996bc5 100644 --- a/llvm/unittests/IR/PatternMatch.cpp +++ b/llvm/unittests/IR/PatternMatch.cpp @@ -2250,9 +2250,151 @@ TYPED_TEST(MutableConstTest, ICmp) { ICmpInst::Predicate MatchPred; EXPECT_TRUE(m_ICmp(MatchPred, m_Value(MatchL), m_Value(MatchR)) - .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); EXPECT_EQ(L, MatchL); EXPECT_EQ(R, MatchR); + + EXPECT_TRUE(m_Cmp(MatchPred, m_Value(MatchL), m_Value(MatchR)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + EXPECT_EQ(L, MatchL); + EXPECT_EQ(R, MatchR); + + EXPECT_TRUE(m_ICmp(m_Specific(L), m_Specific(R)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + + EXPECT_TRUE(m_Cmp(m_Specific(L), m_Specific(R)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + + EXPECT_FALSE(m_ICmp(m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + EXPECT_FALSE(m_Cmp(m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + + EXPECT_TRUE(m_c_ICmp(m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + + EXPECT_FALSE(m_c_ICmp(m_Specific(R), m_Specific(R)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + + EXPECT_TRUE(m_SpecificICmp(Pred, m_Specific(L), m_Specific(R)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + EXPECT_TRUE(m_SpecificCmp(Pred, m_Specific(L), m_Specific(R)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + + EXPECT_FALSE(m_SpecificICmp(Pred, m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + EXPECT_FALSE(m_SpecificCmp(Pred, m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + + MatchL = nullptr; + MatchR = nullptr; + EXPECT_TRUE(m_SpecificICmp(Pred, m_Value(MatchL), m_Value(MatchR)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + EXPECT_EQ(L, MatchL); + EXPECT_EQ(R, MatchR); + MatchL = nullptr; + MatchR = nullptr; + EXPECT_TRUE(m_SpecificCmp(Pred, m_Value(MatchL), m_Value(MatchR)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + EXPECT_EQ(L, MatchL); + EXPECT_EQ(R, MatchR); + + EXPECT_FALSE(m_SpecificICmp(Pred, m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + EXPECT_FALSE(m_SpecificCmp(Pred, m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + + EXPECT_FALSE(m_SpecificICmp(ICmpInst::getInversePredicate(Pred), + m_Specific(L), m_Specific(R)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + EXPECT_FALSE(m_SpecificCmp(ICmpInst::getInversePredicate(Pred), m_Specific(L), + m_Specific(R)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + + EXPECT_FALSE(m_SpecificICmp(ICmpInst::getInversePredicate(Pred), + m_Value(MatchL), m_Value(MatchR)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); + EXPECT_FALSE(m_SpecificCmp(ICmpInst::getInversePredicate(Pred), + m_Value(MatchL), m_Value(MatchR)) + .match((InstructionType)IRB.CreateICmp(Pred, L, R))); +} + +TYPED_TEST(MutableConstTest, FCmp) { + auto &IRB = PatternMatchTest::IRB; + + typedef std::tuple_element_t<0, TypeParam> ValueType; + typedef std::tuple_element_t<1, TypeParam> InstructionType; + + Value *L = Constant::getNullValue(IRB.getFloatTy()); + Value *R = ConstantFP::getInfinity(IRB.getFloatTy(), true); + FCmpInst::Predicate Pred = FCmpInst::FCMP_OGT; + + ValueType MatchL; + ValueType MatchR; + FCmpInst::Predicate MatchPred; + + EXPECT_TRUE(m_FCmp(MatchPred, m_Value(MatchL), m_Value(MatchR)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + EXPECT_EQ(L, MatchL); + EXPECT_EQ(R, MatchR); + + EXPECT_TRUE(m_Cmp(MatchPred, m_Value(MatchL), m_Value(MatchR)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + EXPECT_EQ(L, MatchL); + EXPECT_EQ(R, MatchR); + + EXPECT_TRUE(m_FCmp(m_Specific(L), m_Specific(R)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + + EXPECT_TRUE(m_Cmp(m_Specific(L), m_Specific(R)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + + EXPECT_FALSE(m_FCmp(m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + EXPECT_FALSE(m_Cmp(m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + + EXPECT_TRUE(m_SpecificFCmp(Pred, m_Specific(L), m_Specific(R)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + EXPECT_TRUE(m_SpecificCmp(Pred, m_Specific(L), m_Specific(R)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + + EXPECT_FALSE(m_SpecificFCmp(Pred, m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + EXPECT_FALSE(m_SpecificCmp(Pred, m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + + MatchL = nullptr; + MatchR = nullptr; + EXPECT_TRUE(m_SpecificFCmp(Pred, m_Value(MatchL), m_Value(MatchR)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + EXPECT_EQ(L, MatchL); + EXPECT_EQ(R, MatchR); + MatchL = nullptr; + MatchR = nullptr; + EXPECT_TRUE(m_SpecificCmp(Pred, m_Value(MatchL), m_Value(MatchR)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + EXPECT_EQ(L, MatchL); + EXPECT_EQ(R, MatchR); + + EXPECT_FALSE(m_SpecificFCmp(Pred, m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + EXPECT_FALSE(m_SpecificCmp(Pred, m_Specific(R), m_Specific(L)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + + EXPECT_FALSE(m_SpecificFCmp(FCmpInst::getInversePredicate(Pred), + m_Specific(L), m_Specific(R)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + EXPECT_FALSE(m_SpecificCmp(FCmpInst::getInversePredicate(Pred), m_Specific(L), + m_Specific(R)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + + EXPECT_FALSE(m_SpecificFCmp(FCmpInst::getInversePredicate(Pred), + m_Value(MatchL), m_Value(MatchR)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); + EXPECT_FALSE(m_SpecificCmp(FCmpInst::getInversePredicate(Pred), + m_Value(MatchL), m_Value(MatchR)) + .match((InstructionType)IRB.CreateFCmp(Pred, L, R))); } TEST_F(PatternMatchTest, ConstExpr) {