diff --git a/tools/internal/navigable_ast_internal.h b/tools/internal/navigable_ast_internal.h index 3804ff79a..749bde68e 100644 --- a/tools/internal/navigable_ast_internal.h +++ b/tools/internal/navigable_ast_internal.h @@ -16,6 +16,7 @@ #include #include +#include #include "absl/log/absl_check.h" #include "absl/types/span.h" @@ -54,6 +55,13 @@ class NavigableAstRange { return RangeTraits::Adapt(*ptr_); } + template + std::enable_if_t::value, + std::add_pointer_t>> + operator->() const { + return &operator*(); + } + Iterator& operator++() { ++ptr_; return *this; diff --git a/tools/navigable_ast.cc b/tools/navigable_ast.cc index 8e2a5e262..d86d31dff 100644 --- a/tools/navigable_ast.cc +++ b/tools/navigable_ast.cc @@ -14,6 +14,7 @@ #include "tools/navigable_ast.h" +#include #include #include #include @@ -150,6 +151,7 @@ class NavigableExprBuilderVisitor node_data.parent_relation = ChildKind::kUnspecified; node_data.node_kind = GetNodeKind(*expr); node_data.tree_size = 1; + node_data.height = 1; node_data.index = index; node_data.metadata = metadata_.get(); @@ -174,6 +176,8 @@ class NavigableExprBuilderVisitor tools_internal::AstNodeData& parent_node_data = metadata_->NodeDataAt(parent_stack_.back()); parent_node_data.tree_size += node.tree_size; + parent_node_data.height = + std::max(parent_node_data.height, node.height + 1); } } diff --git a/tools/navigable_ast.h b/tools/navigable_ast.h index e5bc460a9..c3e57eadb 100644 --- a/tools/navigable_ast.h +++ b/tools/navigable_ast.h @@ -94,6 +94,7 @@ struct AstNodeData { const AstMetadata* metadata; size_t index; size_t tree_size; + size_t height; std::vector children; }; @@ -157,6 +158,13 @@ class AstNode { // The type of this node, analogous to Expr::ExprKindCase. NodeKind node_kind() const { return data_.node_kind; } + // The number of nodes in the tree rooted at this node (including self). + size_t tree_size() const { return data_.tree_size; } + + // The height of this node in the tree (the number of descendants including + // self on the longest path). + size_t height() const { return data_.height; } + absl::Span children() const { return absl::MakeConstSpan(data_.children); } diff --git a/tools/navigable_ast_test.cc b/tools/navigable_ast_test.cc index 63b4ebd5c..5e28d67f0 100644 --- a/tools/navigable_ast_test.cc +++ b/tools/navigable_ast_test.cc @@ -341,6 +341,36 @@ TEST(NavigableAst, DescendantsPreorderComprehension) { Pair(NodeKind::kIdent, ChildKind::kComprensionResult))); } +TEST(NavigableAst, TreeSize) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); + + NavigableAst ast = NavigableAst::Build(parsed_expr.expr()); + const AstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); + + std::vector> node_kinds; + + EXPECT_EQ(root.tree_size(), 14); + auto it = root.DescendantsPostorder().begin(); + EXPECT_EQ(it->tree_size(), 1); +} + +TEST(NavigableAst, Height) { + ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("[1, 2, 3].map(x, x + 1)")); + + NavigableAst ast = NavigableAst::Build(parsed_expr.expr()); + const AstNode& root = ast.Root(); + + EXPECT_EQ(root.node_kind(), NodeKind::kComprehension); + + std::vector> node_kinds; + + EXPECT_EQ(root.height(), 5); + auto it = root.DescendantsPostorder().begin(); + EXPECT_EQ(it->height(), 1); +} + TEST(NavigableAst, DescendantsPreorderCreateMap) { ASSERT_OK_AND_ASSIGN(auto parsed_expr, Parse("{'key1': 1, 'key2': 2}"));