Skip to content

Commit

Permalink
[ASTMatchers] Fix hasParent while ignoring unwritten nodes
Browse files Browse the repository at this point in the history
For example, before this patch we can use has() to get from a
cxxRewrittenBinaryOperator to its operand, but hasParent doesn't get
back to the cxxRewrittenBinaryOperator.  This patch fixes that.

Differential Revision: https://reviews.llvm.org/D96113
  • Loading branch information
steveire committed Feb 18, 2021
1 parent 25aa0d1 commit e4d5f00
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 6 deletions.
3 changes: 2 additions & 1 deletion clang/include/clang/AST/ParentMapContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ class ParentMapContext {
Expr *traverseIgnored(Expr *E) const;
DynTypedNode traverseIgnored(const DynTypedNode &N) const;

class ParentMap;

private:
ASTContext &ASTCtx;
class ParentMap;
TraversalKind Traversal = TK_AsIs;
std::unique_ptr<ParentMap> Parents;
};
Expand Down
134 changes: 129 additions & 5 deletions clang/lib/AST/ParentMapContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,17 @@ DynTypedNode ParentMapContext::traverseIgnored(const DynTypedNode &N) const {
return N;
}

template <typename T, typename... U>
std::tuple<bool, DynTypedNodeList, const T *, const U *...>
matchParents(const DynTypedNodeList &NodeList,
ParentMapContext::ParentMap *ParentMap);

template <typename, typename...> struct MatchParents;

class ParentMapContext::ParentMap {

template <typename, typename...> friend struct ::MatchParents;

/// Contains parents of a node.
using ParentVector = llvm::SmallVector<DynTypedNode, 2>;

Expand Down Expand Up @@ -117,11 +127,72 @@ class ParentMapContext::ParentMap {
if (Node.getNodeKind().hasPointerIdentity()) {
auto ParentList =
getDynNodeFromMap(Node.getMemoizationData(), PointerParents);
if (ParentList.size() == 1 && TK == TK_IgnoreUnlessSpelledInSource) {
const auto *E = ParentList[0].get<Expr>();
const auto *Child = Node.get<Expr>();
if (E && Child)
return AscendIgnoreUnlessSpelledInSource(E, Child);
if (ParentList.size() > 0 && TK == TK_IgnoreUnlessSpelledInSource) {

const auto *ChildExpr = Node.get<Expr>();

{
// Don't match explicit node types because different stdlib
// implementations implement this in different ways and have
// different intermediate nodes.
// Look up 4 levels for a cxxRewrittenBinaryOperator as that is
// enough for the major stdlib implementations.
auto RewrittenBinOpParentsList = ParentList;
int I = 0;
while (ChildExpr && RewrittenBinOpParentsList.size() == 1 &&
I++ < 4) {
const auto *S = RewrittenBinOpParentsList[0].get<Stmt>();
if (!S)
break;

const auto *RWBO = dyn_cast<CXXRewrittenBinaryOperator>(S);
if (!RWBO) {
RewrittenBinOpParentsList = getDynNodeFromMap(S, PointerParents);
continue;
}
if (RWBO->getLHS()->IgnoreUnlessSpelledInSource() != ChildExpr &&
RWBO->getRHS()->IgnoreUnlessSpelledInSource() != ChildExpr)
break;
return DynTypedNode::create(*RWBO);
}
}

const auto *ParentExpr = ParentList[0].get<Expr>();
if (ParentExpr && ChildExpr)
return AscendIgnoreUnlessSpelledInSource(ParentExpr, ChildExpr);

{
auto AncestorNodes =
matchParents<DeclStmt, CXXForRangeStmt>(ParentList, this);
if (std::get<bool>(AncestorNodes) &&
std::get<const CXXForRangeStmt *>(AncestorNodes)
->getLoopVarStmt() ==
std::get<const DeclStmt *>(AncestorNodes))
return std::get<DynTypedNodeList>(AncestorNodes);
}
{
auto AncestorNodes = matchParents<VarDecl, DeclStmt, CXXForRangeStmt>(
ParentList, this);
if (std::get<bool>(AncestorNodes) &&
std::get<const CXXForRangeStmt *>(AncestorNodes)
->getRangeStmt() ==
std::get<const DeclStmt *>(AncestorNodes))
return std::get<DynTypedNodeList>(AncestorNodes);
}
{
auto AncestorNodes =
matchParents<CXXMethodDecl, CXXRecordDecl, LambdaExpr>(ParentList,
this);
if (std::get<bool>(AncestorNodes))
return std::get<DynTypedNodeList>(AncestorNodes);
}
{
auto AncestorNodes =
matchParents<FunctionTemplateDecl, CXXRecordDecl, LambdaExpr>(
ParentList, this);
if (std::get<bool>(AncestorNodes))
return std::get<DynTypedNodeList>(AncestorNodes);
}
}
return ParentList;
}
Expand Down Expand Up @@ -194,6 +265,59 @@ class ParentMapContext::ParentMap {
}
};

template <typename Tuple, std::size_t... Is>
auto tuple_pop_front_impl(const Tuple &tuple, std::index_sequence<Is...>) {
return std::make_tuple(std::get<1 + Is>(tuple)...);
}

template <typename Tuple> auto tuple_pop_front(const Tuple &tuple) {
return tuple_pop_front_impl(
tuple, std::make_index_sequence<std::tuple_size<Tuple>::value - 1>());
}

template <typename T, typename... U> struct MatchParents {
static std::tuple<bool, DynTypedNodeList, const T *, const U *...>
match(const DynTypedNodeList &NodeList,
ParentMapContext::ParentMap *ParentMap) {
if (const auto *TypedNode = NodeList[0].get<T>()) {
auto NextParentList =
ParentMap->getDynNodeFromMap(TypedNode, ParentMap->PointerParents);
if (NextParentList.size() == 1) {
auto TailTuple = MatchParents<U...>::match(NextParentList, ParentMap);
if (std::get<bool>(TailTuple)) {
return std::tuple_cat(
std::make_tuple(true, std::get<DynTypedNodeList>(TailTuple),
TypedNode),
tuple_pop_front(tuple_pop_front(TailTuple)));
}
}
}
return std::tuple_cat(std::make_tuple(false, NodeList),
std::tuple<const T *, const U *...>());
}
};

template <typename T> struct MatchParents<T> {
static std::tuple<bool, DynTypedNodeList, const T *>
match(const DynTypedNodeList &NodeList,
ParentMapContext::ParentMap *ParentMap) {
if (const auto *TypedNode = NodeList[0].get<T>()) {
auto NextParentList =
ParentMap->getDynNodeFromMap(TypedNode, ParentMap->PointerParents);
if (NextParentList.size() == 1)
return std::make_tuple(true, NodeList, TypedNode);
}
return std::make_tuple(false, NodeList, nullptr);
}
};

template <typename T, typename... U>
std::tuple<bool, DynTypedNodeList, const T *, const U *...>
matchParents(const DynTypedNodeList &NodeList,
ParentMapContext::ParentMap *ParentMap) {
return MatchParents<T, U...>::match(NodeList, ParentMap);
}

/// Template specializations to abstract away from pointers and TypeLocs.
/// @{
template <typename T> static DynTypedNode createDynTypedNode(const T &Node) {
Expand Down
64 changes: 64 additions & 0 deletions clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2933,6 +2933,37 @@ struct CtorInitsNonTrivial : NonTrivial
EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
}
{
auto M = ifStmt(hasParent(compoundStmt(hasParent(cxxForRangeStmt()))));
EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
}
{
auto M = cxxForRangeStmt(
has(varDecl(hasName("i"), hasParent(cxxForRangeStmt()))));
EXPECT_FALSE(matches(Code, traverse(TK_AsIs, M)));
EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
}
{
auto M = cxxForRangeStmt(hasDescendant(varDecl(
hasName("i"), hasParent(declStmt(hasParent(cxxForRangeStmt()))))));
EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
}
{
auto M = cxxForRangeStmt(hasRangeInit(declRefExpr(
to(varDecl(hasName("arr"))), hasParent(cxxForRangeStmt()))));
EXPECT_FALSE(matches(Code, traverse(TK_AsIs, M)));
EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
}

{
auto M = cxxForRangeStmt(hasRangeInit(declRefExpr(
to(varDecl(hasName("arr"))), hasParent(varDecl(hasParent(declStmt(
hasParent(cxxForRangeStmt()))))))));
EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
}

Code = R"cpp(
struct Range {
Expand Down Expand Up @@ -3035,6 +3066,15 @@ struct CtorInitsNonTrivial : NonTrivial
matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
true, {"-std=c++20"}));
}
{
auto M = cxxForRangeStmt(hasInitStatement(declStmt(
hasSingleDecl(varDecl(hasName("a"))), hasParent(cxxForRangeStmt()))));
EXPECT_TRUE(
matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
EXPECT_TRUE(
matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
true, {"-std=c++20"}));
}

Code = R"cpp(
struct Range {
Expand Down Expand Up @@ -3511,6 +3551,20 @@ void func15() {
forFunction(functionDecl(hasName("func13"))))))),
langCxx20OrLater()));

EXPECT_TRUE(matches(Code,
traverse(TK_IgnoreUnlessSpelledInSource,
compoundStmt(hasParent(lambdaExpr(forFunction(
functionDecl(hasName("func13"))))))),
langCxx20OrLater()));

EXPECT_TRUE(matches(
Code,
traverse(TK_IgnoreUnlessSpelledInSource,
templateTypeParmDecl(hasName("TemplateType"),
hasParent(lambdaExpr(forFunction(
functionDecl(hasName("func14"))))))),
langCxx20OrLater()));

EXPECT_TRUE(matches(
Code,
traverse(TK_IgnoreUnlessSpelledInSource,
Expand Down Expand Up @@ -3635,6 +3689,16 @@ void binop()
matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
true, {"-std=c++20"}));
}
{
auto M = cxxRewrittenBinaryOperator(
hasLHS(expr(hasParent(cxxRewrittenBinaryOperator()))),
hasRHS(expr(hasParent(cxxRewrittenBinaryOperator()))));
EXPECT_FALSE(
matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
EXPECT_TRUE(
matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
true, {"-std=c++20"}));
}
{
EXPECT_TRUE(matchesConditionally(
Code,
Expand Down

0 comments on commit e4d5f00

Please sign in to comment.