diff --git a/clang/unittests/AST/ASTPrint.h b/clang/unittests/AST/ASTPrint.h index 0e35846c86f47..c3b6b842316d9 100644 --- a/clang/unittests/AST/ASTPrint.h +++ b/clang/unittests/AST/ASTPrint.h @@ -19,88 +19,72 @@ namespace clang { -using PrintingPolicyAdjuster = llvm::function_ref; - -template -using NodePrinter = - std::function; - -template -using NodeFilter = std::function; +using PolicyAdjusterType = + Optional>; + +static void PrintStmt(raw_ostream &Out, const ASTContext *Context, + const Stmt *S, PolicyAdjusterType PolicyAdjuster) { + assert(S != nullptr && "Expected non-null Stmt"); + PrintingPolicy Policy = Context->getPrintingPolicy(); + if (PolicyAdjuster) + (*PolicyAdjuster)(Policy); + S->printPretty(Out, /*Helper*/ nullptr, Policy); +} -template class PrintMatch : public ast_matchers::MatchFinder::MatchCallback { - using PrinterT = NodePrinter; - using FilterT = NodeFilter; - SmallString<1024> Printed; - unsigned NumFoundNodes; - PrinterT Printer; - FilterT Filter; - PrintingPolicyAdjuster PolicyAdjuster; + unsigned NumFoundStmts; + PolicyAdjusterType PolicyAdjuster; public: - PrintMatch(PrinterT Printer, PrintingPolicyAdjuster PolicyAdjuster, - FilterT Filter) - : NumFoundNodes(0), Printer(std::move(Printer)), - Filter(std::move(Filter)), PolicyAdjuster(PolicyAdjuster) {} + PrintMatch(PolicyAdjusterType PolicyAdjuster) + : NumFoundStmts(0), PolicyAdjuster(PolicyAdjuster) {} void run(const ast_matchers::MatchFinder::MatchResult &Result) override { - const NodeType *N = Result.Nodes.getNodeAs("id"); - if (!N || !Filter(N)) + const Stmt *S = Result.Nodes.getNodeAs("id"); + if (!S) return; - NumFoundNodes++; - if (NumFoundNodes > 1) + NumFoundStmts++; + if (NumFoundStmts > 1) return; llvm::raw_svector_ostream Out(Printed); - Printer(Out, Result.Context, N, PolicyAdjuster); + PrintStmt(Out, Result.Context, S, PolicyAdjuster); } StringRef getPrinted() const { return Printed; } - unsigned getNumFoundNodes() const { return NumFoundNodes; } + unsigned getNumFoundStmts() const { return NumFoundStmts; } }; -template -::testing::AssertionResult PrintedNodeMatches( - StringRef Code, const std::vector &Args, - const Matcher &NodeMatch, StringRef ExpectedPrinted, StringRef FileName, - NodePrinter Printer, - PrintingPolicyAdjuster PolicyAdjuster = nullptr, bool AllowError = false, - NodeFilter Filter = [](const NodeType *) { return true; }) { +template +::testing::AssertionResult +PrintedStmtMatches(StringRef Code, const std::vector &Args, + const T &NodeMatch, StringRef ExpectedPrinted, + PolicyAdjusterType PolicyAdjuster = None) { - PrintMatch Callback(Printer, PolicyAdjuster, Filter); + PrintMatch Printer(PolicyAdjuster); ast_matchers::MatchFinder Finder; - Finder.addMatcher(NodeMatch, &Callback); + Finder.addMatcher(NodeMatch, &Printer); std::unique_ptr Factory( tooling::newFrontendActionFactory(&Finder)); - bool ToolResult; - if (FileName.empty()) { - ToolResult = tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args); - } else { - ToolResult = - tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName); - } - if (!ToolResult && !AllowError) + if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args)) return testing::AssertionFailure() << "Parsing error in \"" << Code.str() << "\""; - if (Callback.getNumFoundNodes() == 0) - return testing::AssertionFailure() << "Matcher didn't find any nodes"; + if (Printer.getNumFoundStmts() == 0) + return testing::AssertionFailure() << "Matcher didn't find any statements"; - if (Callback.getNumFoundNodes() > 1) + if (Printer.getNumFoundStmts() > 1) return testing::AssertionFailure() - << "Matcher should match only one node (found " - << Callback.getNumFoundNodes() << ")"; + << "Matcher should match only one statement (found " + << Printer.getNumFoundStmts() << ")"; - if (Callback.getPrinted() != ExpectedPrinted) + if (Printer.getPrinted() != ExpectedPrinted) return ::testing::AssertionFailure() << "Expected \"" << ExpectedPrinted.str() << "\", got \"" - << Callback.getPrinted().str() << "\""; + << Printer.getPrinted().str() << "\""; return ::testing::AssertionSuccess(); } diff --git a/clang/unittests/AST/DeclPrinterTest.cpp b/clang/unittests/AST/DeclPrinterTest.cpp index bdc23f33f39b0..e70d2bef72121 100644 --- a/clang/unittests/AST/DeclPrinterTest.cpp +++ b/clang/unittests/AST/DeclPrinterTest.cpp @@ -18,7 +18,6 @@ // //===----------------------------------------------------------------------===// -#include "ASTPrint.h" #include "clang/AST/ASTContext.h" #include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/ASTMatchers/ASTMatchers.h" @@ -33,8 +32,10 @@ using namespace tooling; namespace { +using PrintingPolicyModifier = void (*)(PrintingPolicy &policy); + void PrintDecl(raw_ostream &Out, const ASTContext *Context, const Decl *D, - PrintingPolicyAdjuster PolicyModifier) { + PrintingPolicyModifier PolicyModifier) { PrintingPolicy Policy = Context->getPrintingPolicy(); Policy.TerseOutput = true; Policy.Indentation = 0; @@ -43,23 +44,74 @@ void PrintDecl(raw_ostream &Out, const ASTContext *Context, const Decl *D, D->print(Out, Policy, /*Indentation*/ 0, /*PrintInstantiation*/ false); } +class PrintMatch : public MatchFinder::MatchCallback { + SmallString<1024> Printed; + unsigned NumFoundDecls; + PrintingPolicyModifier PolicyModifier; + +public: + PrintMatch(PrintingPolicyModifier PolicyModifier) + : NumFoundDecls(0), PolicyModifier(PolicyModifier) {} + + void run(const MatchFinder::MatchResult &Result) override { + const Decl *D = Result.Nodes.getNodeAs("id"); + if (!D || D->isImplicit()) + return; + NumFoundDecls++; + if (NumFoundDecls > 1) + return; + + llvm::raw_svector_ostream Out(Printed); + PrintDecl(Out, Result.Context, D, PolicyModifier); + } + + StringRef getPrinted() const { + return Printed; + } + + unsigned getNumFoundDecls() const { + return NumFoundDecls; + } +}; + ::testing::AssertionResult PrintedDeclMatches(StringRef Code, const std::vector &Args, const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted, StringRef FileName, - PrintingPolicyAdjuster PolicyModifier = nullptr, + PrintingPolicyModifier PolicyModifier = nullptr, bool AllowError = false) { - return PrintedNodeMatches( - Code, Args, NodeMatch, ExpectedPrinted, FileName, PrintDecl, - PolicyModifier, AllowError, - // Filter out implicit decls - [](const Decl *D) { return !D->isImplicit(); }); + PrintMatch Printer(PolicyModifier); + MatchFinder Finder; + Finder.addMatcher(NodeMatch, &Printer); + std::unique_ptr Factory( + newFrontendActionFactory(&Finder)); + + if (!runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName) && + !AllowError) + return testing::AssertionFailure() + << "Parsing error in \"" << Code.str() << "\""; + + if (Printer.getNumFoundDecls() == 0) + return testing::AssertionFailure() + << "Matcher didn't find any declarations"; + + if (Printer.getNumFoundDecls() > 1) + return testing::AssertionFailure() + << "Matcher should match only one declaration " + "(found " << Printer.getNumFoundDecls() << ")"; + + if (Printer.getPrinted() != ExpectedPrinted) + return ::testing::AssertionFailure() + << "Expected \"" << ExpectedPrinted.str() << "\", " + "got \"" << Printer.getPrinted().str() << "\""; + + return ::testing::AssertionSuccess(); } ::testing::AssertionResult PrintedDeclCXX98Matches(StringRef Code, StringRef DeclName, StringRef ExpectedPrinted, - PrintingPolicyAdjuster PolicyModifier = nullptr) { + PrintingPolicyModifier PolicyModifier = nullptr) { std::vector Args(1, "-std=c++98"); return PrintedDeclMatches(Code, Args, namedDecl(hasName(DeclName)).bind("id"), ExpectedPrinted, "input.cc", PolicyModifier); @@ -68,7 +120,7 @@ PrintedDeclCXX98Matches(StringRef Code, StringRef DeclName, ::testing::AssertionResult PrintedDeclCXX98Matches(StringRef Code, const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted, - PrintingPolicyAdjuster PolicyModifier = nullptr) { + PrintingPolicyModifier PolicyModifier = nullptr) { std::vector Args(1, "-std=c++98"); return PrintedDeclMatches(Code, Args, @@ -113,7 +165,7 @@ ::testing::AssertionResult PrintedDeclCXX11nonMSCMatches( ::testing::AssertionResult PrintedDeclCXX17Matches(StringRef Code, const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted, - PrintingPolicyAdjuster PolicyModifier = nullptr) { + PrintingPolicyModifier PolicyModifier = nullptr) { std::vector Args{"-std=c++17", "-fno-delayed-template-parsing"}; return PrintedDeclMatches(Code, Args, NodeMatch, ExpectedPrinted, "input.cc", PolicyModifier); @@ -122,7 +174,7 @@ PrintedDeclCXX17Matches(StringRef Code, const DeclarationMatcher &NodeMatch, ::testing::AssertionResult PrintedDeclC11Matches(StringRef Code, const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted, - PrintingPolicyAdjuster PolicyModifier = nullptr) { + PrintingPolicyModifier PolicyModifier = nullptr) { std::vector Args(1, "-std=c11"); return PrintedDeclMatches(Code, Args, NodeMatch, ExpectedPrinted, "input.c", PolicyModifier); diff --git a/clang/unittests/AST/NamedDeclPrinterTest.cpp b/clang/unittests/AST/NamedDeclPrinterTest.cpp index cd833725b448d..1042312e8a730 100644 --- a/clang/unittests/AST/NamedDeclPrinterTest.cpp +++ b/clang/unittests/AST/NamedDeclPrinterTest.cpp @@ -15,7 +15,6 @@ // //===----------------------------------------------------------------------===// -#include "ASTPrint.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Decl.h" #include "clang/AST/PrettyPrinter.h" @@ -67,11 +66,31 @@ ::testing::AssertionResult PrintedDeclMatches( const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted, StringRef FileName, std::function Print) { - return PrintedNodeMatches( - Code, Args, NodeMatch, ExpectedPrinted, FileName, - [Print](llvm::raw_ostream &Out, const ASTContext *Context, - const NamedDecl *ND, - PrintingPolicyAdjuster PolicyAdjuster) { Print(Out, ND); }); + PrintMatch Printer(std::move(Print)); + MatchFinder Finder; + Finder.addMatcher(NodeMatch, &Printer); + std::unique_ptr Factory = + newFrontendActionFactory(&Finder); + + if (!runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName)) + return testing::AssertionFailure() + << "Parsing error in \"" << Code.str() << "\""; + + if (Printer.getNumFoundDecls() == 0) + return testing::AssertionFailure() + << "Matcher didn't find any named declarations"; + + if (Printer.getNumFoundDecls() > 1) + return testing::AssertionFailure() + << "Matcher should match only one named declaration " + "(found " << Printer.getNumFoundDecls() << ")"; + + if (Printer.getPrinted() != ExpectedPrinted) + return ::testing::AssertionFailure() + << "Expected \"" << ExpectedPrinted.str() << "\", " + "got \"" << Printer.getPrinted().str() << "\""; + + return ::testing::AssertionSuccess(); } ::testing::AssertionResult diff --git a/clang/unittests/AST/StmtPrinterTest.cpp b/clang/unittests/AST/StmtPrinterTest.cpp index 65dfec4cc5b4a..29cdbf75a00c8 100644 --- a/clang/unittests/AST/StmtPrinterTest.cpp +++ b/clang/unittests/AST/StmtPrinterTest.cpp @@ -38,29 +38,11 @@ DeclarationMatcher FunctionBodyMatcher(StringRef ContainingFunction) { has(compoundStmt(has(stmt().bind("id"))))); } -static void PrintStmt(raw_ostream &Out, const ASTContext *Context, - const Stmt *S, PrintingPolicyAdjuster PolicyAdjuster) { - assert(S != nullptr && "Expected non-null Stmt"); - PrintingPolicy Policy = Context->getPrintingPolicy(); - if (PolicyAdjuster) - PolicyAdjuster(Policy); - S->printPretty(Out, /*Helper*/ nullptr, Policy); -} - -template -::testing::AssertionResult -PrintedStmtMatches(StringRef Code, const std::vector &Args, - const Matcher &NodeMatch, StringRef ExpectedPrinted, - PrintingPolicyAdjuster PolicyAdjuster = nullptr) { - return PrintedNodeMatches(Code, Args, NodeMatch, ExpectedPrinted, "", - PrintStmt, PolicyAdjuster); -} - template ::testing::AssertionResult PrintedStmtCXXMatches(StdVer Standard, StringRef Code, const T &NodeMatch, StringRef ExpectedPrinted, - PrintingPolicyAdjuster PolicyAdjuster = nullptr) { + PolicyAdjusterType PolicyAdjuster = None) { const char *StdOpt; switch (Standard) { case StdVer::CXX98: StdOpt = "-std=c++98"; break; @@ -82,7 +64,7 @@ template ::testing::AssertionResult PrintedStmtMSMatches(StringRef Code, const T &NodeMatch, StringRef ExpectedPrinted, - PrintingPolicyAdjuster PolicyAdjuster = nullptr) { + PolicyAdjusterType PolicyAdjuster = None) { std::vector Args = { "-std=c++98", "-target", "i686-pc-win32", @@ -97,7 +79,7 @@ template ::testing::AssertionResult PrintedStmtObjCMatches(StringRef Code, const T &NodeMatch, StringRef ExpectedPrinted, - PrintingPolicyAdjuster PolicyAdjuster = nullptr) { + PolicyAdjusterType PolicyAdjuster = None) { std::vector Args = { "-ObjC", "-fobjc-runtime=macosx-10.12.0", @@ -220,10 +202,10 @@ class A { }; )"; // No implicit 'this'. - ASSERT_TRUE(PrintedStmtCXXMatches( - StdVer::CXX11, CPPSource, memberExpr(anything()).bind("id"), "field", - - [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; })); + ASSERT_TRUE(PrintedStmtCXXMatches(StdVer::CXX11, + CPPSource, memberExpr(anything()).bind("id"), "field", + PolicyAdjusterType( + [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; }))); // Print implicit 'this'. ASSERT_TRUE(PrintedStmtCXXMatches(StdVer::CXX11, CPPSource, memberExpr(anything()).bind("id"), "this->field")); @@ -240,10 +222,11 @@ class A { @end )"; // No implicit 'self'. - ASSERT_TRUE(PrintedStmtObjCMatches( - ObjCSource, returnStmt().bind("id"), "return ivar;\n", - - [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; })); + ASSERT_TRUE(PrintedStmtObjCMatches(ObjCSource, returnStmt().bind("id"), + "return ivar;\n", + PolicyAdjusterType([](PrintingPolicy &PP) { + PP.SuppressImplicitBase = true; + }))); // Print implicit 'self'. ASSERT_TRUE(PrintedStmtObjCMatches(ObjCSource, returnStmt().bind("id"), "return self->ivar;\n")); @@ -260,6 +243,5 @@ TEST(StmtPrinter, TerseOutputWithLambdas) { // body not printed when TerseOutput is on. ASSERT_TRUE(PrintedStmtCXXMatches( StdVer::CXX11, CPPSource, lambdaExpr(anything()).bind("id"), "[] {}", - - [](PrintingPolicy &PP) { PP.TerseOutput = true; })); + PolicyAdjusterType([](PrintingPolicy &PP) { PP.TerseOutput = true; }))); }