diff --git a/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/clang/lib/ASTMatchers/ASTMatchFinder.cpp index f9bd1354fa8dc..0bac2ed63a927 100644 --- a/clang/lib/ASTMatchers/ASTMatchFinder.cpp +++ b/clang/lib/ASTMatchers/ASTMatchFinder.cpp @@ -18,8 +18,10 @@ #include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/AST/ASTConsumer.h" #include "clang/AST/ASTContext.h" +#include "clang/AST/DeclCXX.h" #include "clang/AST/RecursiveASTVisitor.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Timer.h" @@ -651,11 +653,20 @@ class MatchASTVisitor : public RecursiveASTVisitor, BoundNodesTreeBuilder *Builder, bool Directly) override; +private: + bool + classIsDerivedFromImpl(const CXXRecordDecl *Declaration, + const Matcher &Base, + BoundNodesTreeBuilder *Builder, bool Directly, + llvm::SmallPtrSetImpl &Visited); + +public: bool objcClassIsDerivedFrom(const ObjCInterfaceDecl *Declaration, const Matcher &Base, BoundNodesTreeBuilder *Builder, bool Directly) override; +public: // Implements ASTMatchFinder::matchesChildOf. bool matchesChildOf(const DynTypedNode &Node, ASTContext &Ctx, const DynTypedMatcher &Matcher, @@ -1361,8 +1372,18 @@ bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration, const Matcher &Base, BoundNodesTreeBuilder *Builder, bool Directly) { + llvm::SmallPtrSet Visited; + return classIsDerivedFromImpl(Declaration, Base, Builder, Directly, Visited); +} + +bool MatchASTVisitor::classIsDerivedFromImpl( + const CXXRecordDecl *Declaration, const Matcher &Base, + BoundNodesTreeBuilder *Builder, bool Directly, + llvm::SmallPtrSetImpl &Visited) { if (!Declaration->hasDefinition()) return false; + if (!Visited.insert(Declaration).second) + return false; for (const auto &It : Declaration->bases()) { const Type *TypeNode = It.getType().getTypePtr(); @@ -1384,7 +1405,8 @@ bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration, *Builder = std::move(Result); return true; } - if (!Directly && classIsDerivedFrom(ClassDecl, Base, Builder, Directly)) + if (!Directly && + classIsDerivedFromImpl(ClassDecl, Base, Builder, Directly, Visited)) return true; } return false; diff --git a/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp index 7a6d6ef0a9555..8f0dd5602307c 100644 --- a/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp +++ b/clang/unittests/ASTMatchers/ASTMatchersNodeTest.cpp @@ -2369,6 +2369,80 @@ TEST_P(ASTMatchersTest, LambdaCaptureTest_BindsToCaptureOfReferenceType) { "}", matcher)); } +TEST_P(ASTMatchersTest, IsDerivedFromRecursion) { + if (!GetParam().isCXX11OrLater()) + return; + + // Check we don't crash on cycles in the traversal and inheritance hierarchy. + // Clang will normally enforce there are no cycles, but matchers opted to + // traverse primary template for dependent specializations, spuriously + // creating the cycles. + DeclarationMatcher matcher = cxxRecordDecl(isDerivedFrom("X")); + EXPECT_TRUE(notMatches(R"cpp( + template + struct M; + + template + struct M {}; + + template + struct L : M {}; + + template + struct M : L, M> {}; + )cpp", + matcher)); + + // Check the running time is not exponential. The number of subojects to + // traverse grows as fibonacci numbers even though the number of bases to + // traverse is quadratic. + // The test will hang if implementation of matchers traverses all subojects. + EXPECT_TRUE(notMatches(R"cpp( + template struct A0 {}; + template struct A1 : A0 {}; + template struct A2 : A1, A0 {}; + template struct A3 : A2, A1 {}; + template struct A4 : A3, A2 {}; + template struct A5 : A4, A3 {}; + template struct A6 : A5, A4 {}; + template struct A7 : A6, A5 {}; + template struct A8 : A7, A6 {}; + template struct A9 : A8, A7 {}; + template struct A10 : A9, A8 {}; + template struct A11 : A10, A9 {}; + template struct A12 : A11, A10 {}; + template struct A13 : A12, A11 {}; + template struct A14 : A13, A12 {}; + template struct A15 : A14, A13 {}; + template struct A16 : A15, A14 {}; + template struct A17 : A16, A15 {}; + template struct A18 : A17, A16 {}; + template struct A19 : A18, A17 {}; + template struct A20 : A19, A18 {}; + template struct A21 : A20, A19 {}; + template struct A22 : A21, A20 {}; + template struct A23 : A22, A21 {}; + template struct A24 : A23, A22 {}; + template struct A25 : A24, A23 {}; + template struct A26 : A25, A24 {}; + template struct A27 : A26, A25 {}; + template struct A28 : A27, A26 {}; + template struct A29 : A28, A27 {}; + template struct A30 : A29, A28 {}; + template struct A31 : A30, A29 {}; + template struct A32 : A31, A30 {}; + template struct A33 : A32, A31 {}; + template struct A34 : A33, A32 {}; + template struct A35 : A34, A33 {}; + template struct A36 : A35, A34 {}; + template struct A37 : A36, A35 {}; + template struct A38 : A37, A36 {}; + template struct A39 : A38, A37 {}; + template struct A40 : A39, A38 {}; +)cpp", + matcher)); +} + TEST(ASTMatchersTestObjC, ObjCMessageCalees) { StatementMatcher MessagingFoo = objcMessageExpr(callee(objcMethodDecl(hasName("foo"))));