diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index 61e524793ec70b..1426e569eabe1b 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -468,6 +468,8 @@ template class RecursiveASTVisitor { DEF_TRAVERSE_TMPL_INST(Function) #undef DEF_TRAVERSE_TMPL_INST + bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue); + private: // These are helper methods used by more than one Traverse* method. bool TraverseTemplateParameterListHelper(TemplateParameterList *TPL); @@ -497,7 +499,6 @@ template class RecursiveASTVisitor { bool VisitOMPClauseWithPreInit(OMPClauseWithPreInit *Node); bool VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *Node); - bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue); bool PostVisitStmt(Stmt *S); }; diff --git a/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/clang/lib/ASTMatchers/ASTMatchFinder.cpp index cc953714452423..762885fa00527c 100644 --- a/clang/lib/ASTMatchers/ASTMatchFinder.cpp +++ b/clang/lib/ASTMatchers/ASTMatchFinder.cpp @@ -463,6 +463,22 @@ class MatchASTVisitor : public RecursiveASTVisitor, bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit); bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL); + bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) { + if (auto *RF = dyn_cast(S)) { + for (auto *SubStmt : RF->children()) { + if (SubStmt == RF->getInit() || SubStmt == RF->getLoopVarStmt() || + SubStmt == RF->getRangeInit() || SubStmt == RF->getBody()) { + TraverseStmt(SubStmt, Queue); + } else { + ASTNodeNotSpelledInSourceScope RAII(this, true); + TraverseStmt(SubStmt, Queue); + } + } + return true; + } + return RecursiveASTVisitor::dataTraverseNode(S, Queue); + } + // Matches children or descendants of 'Node' with 'BaseMatcher'. bool memoizedMatchesRecursively(const DynTypedNode &Node, ASTContext &Ctx, const DynTypedMatcher &Matcher, diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp index 10d2d6ec3916a1..a3a3a911b85c2b 100644 --- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp +++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp @@ -2580,6 +2580,31 @@ struct CtorInitsNonTrivial : NonTrivial EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); } + { + auto M = binaryOperator(hasOperatorName("!=")); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } + { + auto M = unaryOperator(hasOperatorName("++")); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } + { + auto M = declStmt(hasSingleDecl(varDecl(matchesName("__range")))); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } + { + auto M = declStmt(hasSingleDecl(varDecl(matchesName("__begin")))); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } + { + auto M = declStmt(hasSingleDecl(varDecl(matchesName("__end")))); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } Code = R"cpp( void rangeFor()