Skip to content

Commit

Permalink
[clangd] add inlay hints for std::forward-ed parameter packs
Browse files Browse the repository at this point in the history
This adds special-case treatment for parameter packs in
make_unique-like functions to forward parameter names to inlay hints.
The parameter packs are being resolved recursively by traversing the
function body of forwarding functions looking for expressions matching
the (std::forwarded) parameters expanded from a pack.
The implementation checks whether parameters are being passed by
(rvalue) reference or value and adds reference inlay hints accordingly.
The traversal has a limited recursion stack depth, and recursive calls
like std::make_tuple are cut off to avoid hinting duplicate parameter
names.

Reviewed By: sammccall

Differential Revision: https://reviews.llvm.org/D124690
  • Loading branch information
upsj committed Jul 6, 2022
1 parent a60360f commit a638648
Show file tree
Hide file tree
Showing 4 changed files with 903 additions and 57 deletions.
299 changes: 299 additions & 0 deletions clang-tools-extra/clangd/AST.cpp
Expand Up @@ -16,19 +16,23 @@
#include "clang/AST/DeclCXX.h"
#include "clang/AST/DeclTemplate.h"
#include "clang/AST/DeclarationName.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/PrettyPrinter.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/TemplateBase.h"
#include "clang/AST/TypeLoc.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Basic/Specifiers.h"
#include "clang/Index/USRGeneration.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -667,5 +671,300 @@ bool isDeeplyNested(const Decl *D, unsigned MaxDepth) {
}
return false;
}

namespace {

// returns true for `X` in `template <typename... X> void foo()`
bool isTemplateTypeParameterPack(NamedDecl *D) {
if (const auto *TTPD = dyn_cast<TemplateTypeParmDecl>(D)) {
return TTPD->isParameterPack();
}
return false;
}

// Returns the template parameter pack type from an instantiated function
// template, if it exists, nullptr otherwise.
const TemplateTypeParmType *getFunctionPackType(const FunctionDecl *Callee) {
if (const auto *TemplateDecl = Callee->getPrimaryTemplate()) {
auto TemplateParams = TemplateDecl->getTemplateParameters()->asArray();
// find the template parameter pack from the back
const auto It = std::find_if(TemplateParams.rbegin(), TemplateParams.rend(),
isTemplateTypeParameterPack);
if (It != TemplateParams.rend()) {
const auto *TTPD = dyn_cast<TemplateTypeParmDecl>(*It);
return TTPD->getTypeForDecl()->castAs<TemplateTypeParmType>();
}
}
return nullptr;
}

// Returns the template parameter pack type that this parameter was expanded
// from (if in the Args... or Args&... or Args&&... form), if this is the case,
// nullptr otherwise.
const TemplateTypeParmType *getUnderylingPackType(const ParmVarDecl *Param) {
const auto *PlainType = Param->getType().getTypePtr();
if (auto *RT = dyn_cast<ReferenceType>(PlainType))
PlainType = RT->getPointeeTypeAsWritten().getTypePtr();
if (const auto *SubstType = dyn_cast<SubstTemplateTypeParmType>(PlainType)) {
const auto *ReplacedParameter = SubstType->getReplacedParameter();
if (ReplacedParameter->isParameterPack()) {
return dyn_cast<TemplateTypeParmType>(
ReplacedParameter->getCanonicalTypeUnqualified()->getTypePtr());
}
}
return nullptr;
}

// This visitor walks over the body of an instantiated function template.
// The template accepts a parameter pack and the visitor records whether
// the pack parameters were forwarded to another call. For example, given:
//
// template <typename T, typename... Args>
// auto make_unique(Args... args) {
// return unique_ptr<T>(new T(args...));
// }
//
// When called as `make_unique<std::string>(2, 'x')` this yields a function
// `make_unique<std::string, int, char>` with two parameters.
// The visitor records that those two parameters are forwarded to the
// `constructor std::string(int, char);`.
//
// This information is recorded in the `ForwardingInfo` split into fully
// resolved parameters (passed as argument to a parameter that is not an
// expanded template type parameter pack) and forwarding parameters (passed to a
// parameter that is an expanded template type parameter pack).
class ForwardingCallVisitor
: public RecursiveASTVisitor<ForwardingCallVisitor> {
public:
ForwardingCallVisitor(ArrayRef<const ParmVarDecl *> Parameters)
: Parameters{Parameters}, PackType{getUnderylingPackType(
Parameters.front())} {}

bool VisitCallExpr(CallExpr *E) {
auto *Callee = getCalleeDeclOrUniqueOverload(E);
if (Callee) {
handleCall(Callee, E->arguments());
}
return !Info.hasValue();
}

bool VisitCXXConstructExpr(CXXConstructExpr *E) {
auto *Callee = E->getConstructor();
if (Callee) {
handleCall(Callee, E->arguments());
}
return !Info.hasValue();
}

// The expanded parameter pack to be resolved
ArrayRef<const ParmVarDecl *> Parameters;
// The type of the parameter pack
const TemplateTypeParmType *PackType;

struct ForwardingInfo {
// If the parameters were resolved to another FunctionDecl, these are its
// first non-variadic parameters (i.e. the first entries of the parameter
// pack that are passed as arguments bound to a non-pack parameter.)
ArrayRef<const ParmVarDecl *> Head;
// If the parameters were resolved to another FunctionDecl, these are its
// variadic parameters (i.e. the entries of the parameter pack that are
// passed as arguments bound to a pack parameter.)
ArrayRef<const ParmVarDecl *> Pack;
// If the parameters were resolved to another FunctionDecl, these are its
// last non-variadic parameters (i.e. the last entries of the parameter pack
// that are passed as arguments bound to a non-pack parameter.)
ArrayRef<const ParmVarDecl *> Tail;
// If the parameters were resolved to another forwarding FunctionDecl, this
// is it.
Optional<FunctionDecl *> PackTarget;
};

// The output of this visitor
Optional<ForwardingInfo> Info;

private:
// inspects the given callee with the given args to check whether it
// contains Parameters, and sets Info accordingly.
void handleCall(FunctionDecl *Callee, typename CallExpr::arg_range Args) {
if (std::any_of(Args.begin(), Args.end(), [](const Expr *E) {
return dyn_cast<PackExpansionExpr>(E) != nullptr;
})) {
return;
}
auto OptPackLocation = findPack(Args);
if (OptPackLocation) {
size_t PackLocation = OptPackLocation.getValue();
ArrayRef<ParmVarDecl *> MatchingParams =
Callee->parameters().slice(PackLocation, Parameters.size());
// Check whether the function has a parameter pack as the last template
// parameter
if (const auto *TTPT = getFunctionPackType(Callee)) {
// In this case: Separate the parameters into head, pack and tail
auto IsExpandedPack = [&](const ParmVarDecl *P) {
return getUnderylingPackType(P) == TTPT;
};
ForwardingInfo FI;
FI.Head = MatchingParams.take_until(IsExpandedPack);
FI.Pack = MatchingParams.drop_front(FI.Head.size())
.take_while(IsExpandedPack);
FI.Tail = MatchingParams.drop_front(FI.Head.size() + FI.Pack.size());
FI.PackTarget = Callee;
Info = FI;
return;
}
// Default case: assume all parameters were fully resolved
ForwardingInfo FI;
FI.Head = MatchingParams;
Info = FI;
}
}

// Returns the beginning of the expanded pack represented by Parameters
// in the given arguments, if it is there.
llvm::Optional<size_t> findPack(typename CallExpr::arg_range Args) {
// find the argument directly referring to the first parameter
auto FirstMatch = std::find_if(Args.begin(), Args.end(), [&](Expr *Arg) {
const auto *RefArg = unwrapArgument(Arg);
if (RefArg) {
if (Parameters.front() == dyn_cast<ParmVarDecl>(RefArg->getDecl())) {
return true;
}
}
return false;
});
if (FirstMatch == Args.end()) {
return llvm::None;
}
return std::distance(Args.begin(), FirstMatch);
}

static FunctionDecl *getCalleeDeclOrUniqueOverload(CallExpr *E) {
Decl *CalleeDecl = E->getCalleeDecl();
auto *Callee = dyn_cast_or_null<FunctionDecl>(CalleeDecl);
if (!Callee) {
if (auto *Lookup = dyn_cast<UnresolvedLookupExpr>(E->getCallee())) {
Callee = resolveOverload(Lookup, E);
}
}
// Ignore the callee if the number of arguments is wrong (deal with va_args)
if (Callee->getNumParams() == E->getNumArgs())
return Callee;
return nullptr;
}

static FunctionDecl *resolveOverload(UnresolvedLookupExpr *Lookup,
CallExpr *E) {
FunctionDecl *MatchingDecl = nullptr;
if (!Lookup->requiresADL()) {
// Check whether there is a single overload with this number of
// parameters
for (auto *Candidate : Lookup->decls()) {
if (auto *FuncCandidate = dyn_cast_or_null<FunctionDecl>(Candidate)) {
if (FuncCandidate->getNumParams() == E->getNumArgs()) {
if (MatchingDecl) {
// there are multiple candidates - abort
return nullptr;
}
MatchingDecl = FuncCandidate;
}
}
}
}
return MatchingDecl;
}

// Removes any implicit cast expressions around the given expression.
static const Expr *unwrapImplicitCast(const Expr *E) {
while (const auto *Cast = dyn_cast<ImplicitCastExpr>(E)) {
E = Cast->getSubExpr();
}
return E;
}

// Maps std::forward(E) to E, nullptr otherwise
static const Expr *unwrapForward(const Expr *E) {
if (const auto *Call = dyn_cast<CallExpr>(E)) {
const auto Callee = Call->getBuiltinCallee();
if (Callee == Builtin::BIforward) {
return Call->getArg(0);
}
}
return E;
}

// Maps std::forward(DeclRefExpr) to DeclRefExpr, removing any intermediate
// implicit casts, nullptr otherwise
static const DeclRefExpr *unwrapArgument(const Expr *E) {
E = unwrapImplicitCast(E);
E = unwrapForward(E);
E = unwrapImplicitCast(E);
return dyn_cast<DeclRefExpr>(E);
}
};

} // namespace

SmallVector<const ParmVarDecl *>
resolveForwardingParameters(const FunctionDecl *D, unsigned MaxDepth) {
auto Parameters = D->parameters();
// If the function has a template parameter pack
if (const auto *TTPT = getFunctionPackType(D)) {
// Split the parameters into head, pack and tail
auto IsExpandedPack = [TTPT](const ParmVarDecl *P) {
return getUnderylingPackType(P) == TTPT;
};
ArrayRef<const ParmVarDecl *> Head = Parameters.take_until(IsExpandedPack);
ArrayRef<const ParmVarDecl *> Pack =
Parameters.drop_front(Head.size()).take_while(IsExpandedPack);
ArrayRef<const ParmVarDecl *> Tail =
Parameters.drop_front(Head.size() + Pack.size());
SmallVector<const ParmVarDecl *> Result(Parameters.size());
// Fill in non-pack parameters
auto HeadIt = std::copy(Head.begin(), Head.end(), Result.begin());
auto TailIt = std::copy(Tail.rbegin(), Tail.rend(), Result.rbegin());
// Recurse on pack parameters
size_t Depth = 0;
const FunctionDecl *CurrentFunction = D;
llvm::SmallSet<const FunctionTemplateDecl *, 4> SeenTemplates;
if (const auto *Template = D->getPrimaryTemplate()) {
SeenTemplates.insert(Template);
}
while (!Pack.empty() && CurrentFunction && Depth < MaxDepth) {
// Find call expressions involving the pack
ForwardingCallVisitor V{Pack};
V.TraverseStmt(CurrentFunction->getBody());
if (!V.Info) {
break;
}
// If we found something: Fill in non-pack parameters
auto Info = V.Info.getValue();
HeadIt = std::copy(Info.Head.begin(), Info.Head.end(), HeadIt);
TailIt = std::copy(Info.Tail.rbegin(), Info.Tail.rend(), TailIt);
// Prepare next recursion level
Pack = Info.Pack;
CurrentFunction = Info.PackTarget.getValueOr(nullptr);
Depth++;
// If we are recursing into a previously encountered function: Abort
if (CurrentFunction) {
if (const auto *Template = CurrentFunction->getPrimaryTemplate()) {
bool NewFunction = SeenTemplates.insert(Template).second;
if (!NewFunction) {
return {Parameters.begin(), Parameters.end()};
}
}
}
}
// Fill in the remaining unresolved pack parameters
HeadIt = std::copy(Pack.begin(), Pack.end(), HeadIt);
assert(TailIt.base() == HeadIt);
return Result;
}
return {Parameters.begin(), Parameters.end()};
}

bool isExpandedFromParameterPack(const ParmVarDecl *D) {
return getUnderylingPackType(D) != nullptr;
}

} // namespace clangd
} // namespace clang
12 changes: 12 additions & 0 deletions clang-tools-extra/clangd/AST.h
Expand Up @@ -199,6 +199,18 @@ bool hasUnstableLinkage(const Decl *D);
/// reasonable.
bool isDeeplyNested(const Decl *D, unsigned MaxDepth = 10);

/// Recursively resolves the parameters of a FunctionDecl that forwards its
/// parameters to another function via variadic template parameters. This can
/// for example be used to retrieve the constructor parameter ParmVarDecl for a
/// make_unique or emplace_back call.
llvm::SmallVector<const ParmVarDecl *>
resolveForwardingParameters(const FunctionDecl *D, unsigned MaxDepth = 10);

/// Checks whether D is instantiated from a function parameter pack
/// whose type is a bare type parameter pack (e.g. `Args...`), or a
/// reference to one (e.g. `Args&...` or `Args&&...`).
bool isExpandedFromParameterPack(const ParmVarDecl *D);

} // namespace clangd
} // namespace clang

Expand Down

0 comments on commit a638648

Please sign in to comment.