diff --git a/clang/lib/Tooling/Transformer/SourceCode.cpp b/clang/lib/Tooling/Transformer/SourceCode.cpp index 6aae834b0db56..ab7184c2c069e 100644 --- a/clang/lib/Tooling/Transformer/SourceCode.cpp +++ b/clang/lib/Tooling/Transformer/SourceCode.cpp @@ -101,6 +101,54 @@ static bool spelledInMacroDefinition(SourceLocation Loc, return false; } +// Returns the expansion char-range of `Loc` if `Loc` is a split token. For +// example, `>>` in nested templates needs the first `>` to be split, otherwise +// the `SourceLocation` of the token would lex as `>>` instead of `>`. +static std::optional +getExpansionForSplitToken(SourceLocation Loc, const SourceManager &SM, + const LangOptions &LangOpts) { + if (Loc.isMacroID()) { + bool Invalid = false; + auto &SLoc = SM.getSLocEntry(SM.getFileID(Loc), &Invalid); + if (Invalid) + return std::nullopt; + if (auto &Expansion = SLoc.getExpansion(); + !Expansion.isExpansionTokenRange()) { + // A char-range expansion is only used where a token-range would be + // incorrect, and so identifies this as a split token (and importantly, + // not as a macro). + return Expansion.getExpansionLocRange(); + } + } + return std::nullopt; +} + +// If `Range` covers a split token, returns the expansion range, otherwise +// returns `Range`. +static CharSourceRange getRangeForSplitTokens(CharSourceRange Range, + const SourceManager &SM, + const LangOptions &LangOpts) { + if (Range.isTokenRange()) { + auto BeginToken = getExpansionForSplitToken(Range.getBegin(), SM, LangOpts); + auto EndToken = getExpansionForSplitToken(Range.getEnd(), SM, LangOpts); + if (EndToken) { + SourceLocation BeginLoc = + BeginToken ? BeginToken->getBegin() : Range.getBegin(); + // We can't use the expansion location with a token-range, because that + // will incorrectly lex the end token, so use a char-range that ends at + // the split. + return CharSourceRange::getCharRange(BeginLoc, EndToken->getEnd()); + } else if (BeginToken) { + // Since the end token is not split, the whole range covers the split, so + // the only adjustment we make is to use the expansion location of the + // begin token. + return CharSourceRange::getTokenRange(BeginToken->getBegin(), + Range.getEnd()); + } + } + return Range; +} + static CharSourceRange getRange(const CharSourceRange &EditRange, const SourceManager &SM, const LangOptions &LangOpts, @@ -109,13 +157,14 @@ static CharSourceRange getRange(const CharSourceRange &EditRange, if (IncludeMacroExpansion) { Range = Lexer::makeFileCharRange(EditRange, SM, LangOpts); } else { - if (spelledInMacroDefinition(EditRange.getBegin(), SM) || - spelledInMacroDefinition(EditRange.getEnd(), SM)) + auto AdjustedRange = getRangeForSplitTokens(EditRange, SM, LangOpts); + if (spelledInMacroDefinition(AdjustedRange.getBegin(), SM) || + spelledInMacroDefinition(AdjustedRange.getEnd(), SM)) return {}; - auto B = SM.getSpellingLoc(EditRange.getBegin()); - auto E = SM.getSpellingLoc(EditRange.getEnd()); - if (EditRange.isTokenRange()) + auto B = SM.getSpellingLoc(AdjustedRange.getBegin()); + auto E = SM.getSpellingLoc(AdjustedRange.getEnd()); + if (AdjustedRange.isTokenRange()) E = Lexer::getLocForEndOfToken(E, 0, SM, LangOpts); Range = CharSourceRange::getCharRange(B, E); } diff --git a/clang/unittests/Tooling/SourceCodeTest.cpp b/clang/unittests/Tooling/SourceCodeTest.cpp index 3641d2ee453f4..3c24b6220a224 100644 --- a/clang/unittests/Tooling/SourceCodeTest.cpp +++ b/clang/unittests/Tooling/SourceCodeTest.cpp @@ -8,6 +8,7 @@ #include "clang/Tooling/Transformer/SourceCode.h" #include "TestVisitor.h" +#include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/Basic/Diagnostic.h" #include "clang/Basic/SourceLocation.h" #include "clang/Lex/Lexer.h" @@ -18,10 +19,12 @@ #include using namespace clang; +using namespace clang::ast_matchers; using llvm::Failed; using llvm::Succeeded; using llvm::ValueIs; +using testing::Optional; using tooling::getAssociatedRange; using tooling::getExtendedRange; using tooling::getExtendedText; @@ -50,6 +53,15 @@ struct CallsVisitor : TestVisitor { std::function OnCall; }; +struct TypeLocVisitor : TestVisitor { + bool VisitTypeLoc(TypeLoc TL) { + OnTypeLoc(TL, Context); + return true; + } + + std::function OnTypeLoc; +}; + // Equality matcher for `clang::CharSourceRange`, which lacks `operator==`. MATCHER_P(EqualsRange, R, "") { return arg.isTokenRange() == R.isTokenRange() && @@ -510,6 +522,54 @@ int c = M3(3); Visitor.runOver(Code.code()); } +TEST(SourceCodeTest, InnerNestedTemplate) { + llvm::Annotations Code(R"cpp( + template + struct A {}; + template + struct B {}; + template + struct C {}; + + void f(A$r[[>>]]); + )cpp"); + + TypeLocVisitor Visitor; + Visitor.OnTypeLoc = [&](TypeLoc TL, ASTContext *Context) { + if (TL.getSourceRange().isInvalid()) + return; + + // There are no macros, so every TypeLoc's range should be valid. + auto Range = CharSourceRange::getTokenRange(TL.getSourceRange()); + auto LastTokenRange = CharSourceRange::getTokenRange(TL.getEndLoc()); + EXPECT_TRUE(getFileRangeForEdit(Range, *Context, + /*IncludeMacroExpansion=*/false)) + << TL.getSourceRange().printToString(Context->getSourceManager()); + EXPECT_TRUE(getFileRangeForEdit(LastTokenRange, *Context, + /*IncludeMacroExpansion=*/false)) + << TL.getEndLoc().printToString(Context->getSourceManager()); + + if (auto matches = match( + templateSpecializationTypeLoc( + loc(templateSpecializationType( + hasDeclaration(cxxRecordDecl(hasName("A"))))), + hasTemplateArgumentLoc( + 0, templateArgumentLoc(hasTypeLoc(typeLoc().bind("b"))))), + TL, *Context); + !matches.empty()) { + // A range where the start token is split, but the end token is not. + auto OuterTL = TL; + auto MiddleTL = *matches[0].getNodeAs("b"); + EXPECT_THAT( + getFileRangeForEdit(CharSourceRange::getTokenRange( + MiddleTL.getEndLoc(), OuterTL.getEndLoc()), + *Context, /*IncludeMacroExpansion=*/false), + Optional(EqualsAnnotatedRange(Context, Code.range("r")))); + } + }; + Visitor.runOver(Code.code(), TypeLocVisitor::Lang_CXX11); +} + TEST_P(GetFileRangeForEditTest, EditPartialMacroExpansionShouldFail) { std::string Code = R"cpp( #define BAR 10+