diff --git a/clang-tools-extra/clangd/AST.cpp b/clang-tools-extra/clangd/AST.cpp index 0dcff2eae05e7..4b73bdba8aaa9 100644 --- a/clang-tools-extra/clangd/AST.cpp +++ b/clang-tools-extra/clangd/AST.cpp @@ -1040,5 +1040,26 @@ bool isExpandedFromParameterPack(const ParmVarDecl *D) { return getUnderlyingPackType(D) != nullptr; } +bool isLikelyForwardingFunction(FunctionTemplateDecl *FT) { + const auto *FD = FT->getTemplatedDecl(); + const auto NumParams = FD->getNumParams(); + // Check whether its last parameter is a parameter pack... + if (NumParams > 0) { + const auto *LastParam = FD->getParamDecl(NumParams - 1); + if (const auto *PET = dyn_cast(LastParam->getType())) { + // ... of the type T&&... or T... + const auto BaseType = PET->getPattern().getNonReferenceType(); + if (const auto *TTPT = + dyn_cast(BaseType.getTypePtr())) { + // ... whose template parameter comes from the function directly + if (FT->getTemplateParameters()->getDepth() == TTPT->getDepth()) { + return true; + } + } + } + } + return false; +} + } // namespace clangd } // namespace clang diff --git a/clang-tools-extra/clangd/AST.h b/clang-tools-extra/clangd/AST.h index 2b83595e5b8e9..af45ae2d9022d 100644 --- a/clang-tools-extra/clangd/AST.h +++ b/clang-tools-extra/clangd/AST.h @@ -253,6 +253,10 @@ resolveForwardingParameters(const FunctionDecl *D, unsigned MaxDepth = 10); /// reference to one (e.g. `Args&...` or `Args&&...`). bool isExpandedFromParameterPack(const ParmVarDecl *D); +/// Heuristic that checks if FT is forwarding a parameter pack to another +/// function. (e.g. `make_unique`). +bool isLikelyForwardingFunction(FunctionTemplateDecl *FT); + } // namespace clangd } // namespace clang diff --git a/clang-tools-extra/clangd/Preamble.cpp b/clang-tools-extra/clangd/Preamble.cpp index 8af9e4649218d..09aaf3290b585 100644 --- a/clang-tools-extra/clangd/Preamble.cpp +++ b/clang-tools-extra/clangd/Preamble.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "Preamble.h" +#include "AST.h" #include "CollectMacros.h" #include "Compiler.h" #include "Config.h" @@ -166,27 +167,6 @@ class CppFilePreambleCallbacks : public PreambleCallbacks { collectPragmaMarksCallback(*SourceMgr, Marks)); } - static bool isLikelyForwardingFunction(FunctionTemplateDecl *FT) { - const auto *FD = FT->getTemplatedDecl(); - const auto NumParams = FD->getNumParams(); - // Check whether its last parameter is a parameter pack... - if (NumParams > 0) { - const auto *LastParam = FD->getParamDecl(NumParams - 1); - if (const auto *PET = dyn_cast(LastParam->getType())) { - // ... of the type T&&... or T... - const auto BaseType = PET->getPattern().getNonReferenceType(); - if (const auto *TTPT = - dyn_cast(BaseType.getTypePtr())) { - // ... whose template parameter comes from the function directly - if (FT->getTemplateParameters()->getDepth() == TTPT->getDepth()) { - return true; - } - } - } - } - return false; - } - bool shouldSkipFunctionBody(Decl *D) override { // Usually we don't need to look inside the bodies of header functions // to understand the program. However when forwarding function like diff --git a/clang-tools-extra/clangd/index/SymbolCollector.cpp b/clang-tools-extra/clangd/index/SymbolCollector.cpp index 39c479b5f4d5b..bc5ebbef31fbd 100644 --- a/clang-tools-extra/clangd/index/SymbolCollector.cpp +++ b/clang-tools-extra/clangd/index/SymbolCollector.cpp @@ -29,6 +29,7 @@ #include "clang/AST/DeclTemplate.h" #include "clang/AST/DeclarationName.h" #include "clang/AST/Expr.h" +#include "clang/AST/RecursiveASTVisitor.h" #include "clang/Basic/FileEntry.h" #include "clang/Basic/LangOptions.h" #include "clang/Basic/SourceLocation.h" @@ -576,6 +577,22 @@ SymbolCollector::getRefContainer(const Decl *Enclosing, return Enclosing; } +class ForwardingToConstructorVisitor + : public RecursiveASTVisitor { +public: + ForwardingToConstructorVisitor() {} + + bool VisitCXXConstructExpr(CXXConstructExpr *E) { + if (auto *Callee = E->getConstructor()) { + Constructors.push_back(Callee); + } + return true; + } + + // Output of this visitor + std::vector Constructors{}; +}; + // Always return true to continue indexing. bool SymbolCollector::handleDeclOccurrence( const Decl *D, index::SymbolRoleSet Roles, @@ -669,6 +686,21 @@ bool SymbolCollector::handleDeclOccurrence( addRef(ID, SymbolRef{FileLoc, FID, Roles, index::getSymbolInfo(ND).Kind, getRefContainer(ASTNode.Parent, Opts), isSpelled(FileLoc, *ND)}); + // Also collect indirect constructor calls like `make_unique` + if (auto *FD = llvm::dyn_cast(ASTNode.OrigD); + FD && FD->isTemplateInstantiation()) { + if (auto *PT = FD->getPrimaryTemplate(); + PT && isLikelyForwardingFunction(PT)) { + ForwardingToConstructorVisitor Visitor{}; + Visitor.TraverseStmt(FD->getBody()); + for (auto *Constructor : Visitor.Constructors) { + addRef(getSymbolIDCached(Constructor), + SymbolRef{FileLoc, FID, Roles, index::getSymbolInfo(ND).Kind, + getRefContainer(ASTNode.Parent, Opts), + isSpelled(FileLoc, *ND)}); + } + } + } } } // Don't continue indexing if this is a mere reference. diff --git a/clang-tools-extra/clangd/unittests/XRefsTests.cpp b/clang-tools-extra/clangd/unittests/XRefsTests.cpp index 7ed08d7cce3d3..d3b6faa5fa25f 100644 --- a/clang-tools-extra/clangd/unittests/XRefsTests.cpp +++ b/clang-tools-extra/clangd/unittests/XRefsTests.cpp @@ -5,14 +5,15 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#include "Annotations.h" #include "AST.h" +#include "Annotations.h" #include "ParsedAST.h" #include "Protocol.h" #include "SourceCode.h" #include "SyncAPI.h" #include "TestFS.h" #include "TestTU.h" +#include "TestWorkspace.h" #include "XRefs.h" #include "index/MemIndex.h" #include "clang/AST/Decl.h" @@ -2713,6 +2714,36 @@ TEST(FindReferences, NoQueryForLocalSymbols) { } } +TEST(FindReferences, ForwardingInIndex) { + Annotations Header(R"cpp( + namespace std { + template T &&forward(T &t); + template T *make_unique(Args &&...args) { + return new T(std::forward(args)...); + } + } + struct Test { + [[T^est]](){} + }; + )cpp"); + Annotations Main(R"cpp( + #include "header.hpp" + int main() { + auto a = std::[[make_unique]](); + } + )cpp"); + TestWorkspace TW; + TW.addMainFile("header.hpp", Header.code()); + TW.addMainFile("main.cpp", Main.code()); + auto AST = TW.openFile("header.hpp").value(); + auto Index = TW.index(); + + EXPECT_THAT(findReferences(AST, Header.point(), 0, Index.get(), + /*AddContext*/ true) + .References, + ElementsAre(rangeIs(Header.range()), rangeIs(Main.range()))); +} + TEST(GetNonLocalDeclRefs, All) { struct Case { llvm::StringRef AnnotatedCode;