Skip to content
21 changes: 21 additions & 0 deletions clang-tools-extra/clangd/AST.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1040,5 +1040,26 @@ bool isExpandedFromParameterPack(const ParmVarDecl *D) {
return getUnderlyingPackType(D) != nullptr;
}

bool isLikelyForwardingFunction(FunctionTemplateDecl *FT) {
const auto *FD = FT->getTemplatedDecl();
const auto NumParams = FD->getNumParams();
// Check whether its last parameter is a parameter pack...
if (NumParams > 0) {
const auto *LastParam = FD->getParamDecl(NumParams - 1);
if (const auto *PET = dyn_cast<PackExpansionType>(LastParam->getType())) {
// ... of the type T&&... or T...
const auto BaseType = PET->getPattern().getNonReferenceType();
if (const auto *TTPT =
dyn_cast<TemplateTypeParmType>(BaseType.getTypePtr())) {
// ... whose template parameter comes from the function directly
if (FT->getTemplateParameters()->getDepth() == TTPT->getDepth()) {
return true;
}
}
}
}
return false;
}

} // namespace clangd
} // namespace clang
4 changes: 4 additions & 0 deletions clang-tools-extra/clangd/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ resolveForwardingParameters(const FunctionDecl *D, unsigned MaxDepth = 10);
/// reference to one (e.g. `Args&...` or `Args&&...`).
bool isExpandedFromParameterPack(const ParmVarDecl *D);

/// Heuristic that checks if FT is forwarding a parameter pack to another
/// function. (e.g. `make_unique`).
bool isLikelyForwardingFunction(FunctionTemplateDecl *FT);

} // namespace clangd
} // namespace clang

Expand Down
22 changes: 1 addition & 21 deletions clang-tools-extra/clangd/Preamble.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "Preamble.h"
#include "AST.h"
#include "CollectMacros.h"
#include "Compiler.h"
#include "Config.h"
Expand Down Expand Up @@ -166,27 +167,6 @@ class CppFilePreambleCallbacks : public PreambleCallbacks {
collectPragmaMarksCallback(*SourceMgr, Marks));
}

static bool isLikelyForwardingFunction(FunctionTemplateDecl *FT) {
const auto *FD = FT->getTemplatedDecl();
const auto NumParams = FD->getNumParams();
// Check whether its last parameter is a parameter pack...
if (NumParams > 0) {
const auto *LastParam = FD->getParamDecl(NumParams - 1);
if (const auto *PET = dyn_cast<PackExpansionType>(LastParam->getType())) {
// ... of the type T&&... or T...
const auto BaseType = PET->getPattern().getNonReferenceType();
if (const auto *TTPT =
dyn_cast<TemplateTypeParmType>(BaseType.getTypePtr())) {
// ... whose template parameter comes from the function directly
if (FT->getTemplateParameters()->getDepth() == TTPT->getDepth()) {
return true;
}
}
}
}
return false;
}

bool shouldSkipFunctionBody(Decl *D) override {
// Usually we don't need to look inside the bodies of header functions
// to understand the program. However when forwarding function like
Expand Down
32 changes: 32 additions & 0 deletions clang-tools-extra/clangd/index/SymbolCollector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "clang/AST/DeclTemplate.h"
#include "clang/AST/DeclarationName.h"
#include "clang/AST/Expr.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Basic/FileEntry.h"
#include "clang/Basic/LangOptions.h"
#include "clang/Basic/SourceLocation.h"
Expand Down Expand Up @@ -576,6 +577,22 @@ SymbolCollector::getRefContainer(const Decl *Enclosing,
return Enclosing;
}

class ForwardingToConstructorVisitor
: public RecursiveASTVisitor<ForwardingToConstructorVisitor> {
public:
ForwardingToConstructorVisitor() {}

bool VisitCXXConstructExpr(CXXConstructExpr *E) {
if (auto *Callee = E->getConstructor()) {
Constructors.push_back(Callee);
}
return true;
}

// Output of this visitor
std::vector<CXXConstructorDecl *> Constructors{};
};

// Always return true to continue indexing.
bool SymbolCollector::handleDeclOccurrence(
const Decl *D, index::SymbolRoleSet Roles,
Expand Down Expand Up @@ -669,6 +686,21 @@ bool SymbolCollector::handleDeclOccurrence(
addRef(ID, SymbolRef{FileLoc, FID, Roles, index::getSymbolInfo(ND).Kind,
getRefContainer(ASTNode.Parent, Opts),
isSpelled(FileLoc, *ND)});
// Also collect indirect constructor calls like `make_unique`
if (auto *FD = llvm::dyn_cast<clang::FunctionDecl>(ASTNode.OrigD);
FD && FD->isTemplateInstantiation()) {
if (auto *PT = FD->getPrimaryTemplate();
PT && isLikelyForwardingFunction(PT)) {
ForwardingToConstructorVisitor Visitor{};
Visitor.TraverseStmt(FD->getBody());
for (auto *Constructor : Visitor.Constructors) {
addRef(getSymbolIDCached(Constructor),
SymbolRef{FileLoc, FID, Roles, index::getSymbolInfo(ND).Kind,
getRefContainer(ASTNode.Parent, Opts),
isSpelled(FileLoc, *ND)});
}
}
}
}
}
// Don't continue indexing if this is a mere reference.
Expand Down
33 changes: 32 additions & 1 deletion clang-tools-extra/clangd/unittests/XRefsTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Annotations.h"
#include "AST.h"
#include "Annotations.h"
#include "ParsedAST.h"
#include "Protocol.h"
#include "SourceCode.h"
#include "SyncAPI.h"
#include "TestFS.h"
#include "TestTU.h"
#include "TestWorkspace.h"
#include "XRefs.h"
#include "index/MemIndex.h"
#include "clang/AST/Decl.h"
Expand Down Expand Up @@ -2713,6 +2714,36 @@ TEST(FindReferences, NoQueryForLocalSymbols) {
}
}

TEST(FindReferences, ForwardingInIndex) {
Annotations Header(R"cpp(
namespace std {
template <class T> T &&forward(T &t);
template <class T, class... Args> T *make_unique(Args &&...args) {
return new T(std::forward<Args>(args)...);
}
}
struct Test {
[[T^est]](){}
};
)cpp");
Annotations Main(R"cpp(
#include "header.hpp"
int main() {
auto a = std::[[make_unique]]<Test>();
}
)cpp");
TestWorkspace TW;
TW.addMainFile("header.hpp", Header.code());
TW.addMainFile("main.cpp", Main.code());
auto AST = TW.openFile("header.hpp").value();
auto Index = TW.index();

EXPECT_THAT(findReferences(AST, Header.point(), 0, Index.get(),
/*AddContext*/ true)
.References,
ElementsAre(rangeIs(Header.range()), rangeIs(Main.range())));
}

TEST(GetNonLocalDeclRefs, All) {
struct Case {
llvm::StringRef AnnotatedCode;
Expand Down