Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 159 additions & 14 deletions common/ast/navigable_ast_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,46 +103,47 @@ class NavigableAstRange {

explicit NavigableAstRange(SpanType span) : span_(span) {}

Iterator begin() { return Iterator(span_, 0); }
Iterator end() { return Iterator(span_, span_.size()); }
Iterator begin() const { return Iterator(span_, 0); }
Iterator end() const { return Iterator(span_, span_.size()); }

explicit operator bool() const { return !span_.empty(); }

private:
SpanType span_;
};

template <typename AstNode>
template <typename AstTraits>
struct NavigableAstMetadata;

// Internal implementation for data-structures handling cross-referencing nodes.
//
// This is exposed separately to allow building up the AST relationships
// without exposing too much mutable state on the client facing classes.
template <typename AstNode>
template <typename AstTraits>
struct NavigableAstNodeData {
AstNode* parent;
const typename AstNode::ExprType* expr;
typename AstTraits::NodeType* parent;
const typename AstTraits::ExprType* expr;
ChildKind parent_relation;
NodeKind node_kind;
const NavigableAstMetadata<AstNode>* absl_nonnull metadata;
const NavigableAstMetadata<AstTraits>* absl_nonnull metadata;
size_t index;
size_t tree_size;
size_t height;
int child_index;
std::vector<AstNode*> children;
std::vector<typename AstTraits::NodeType* absl_nonnull> children;
};

template <typename AstNode>
template <typename AstTraits>
struct NavigableAstMetadata {
// The nodes in the AST in preorder.
//
// unique_ptr is used to guarantee pointer stability in the other tables.
std::vector<std::unique_ptr<AstNode>> nodes;
std::vector<const AstNode* absl_nonnull> postorder;
absl::flat_hash_map<int64_t, const AstNode* absl_nonnull> id_to_node;
absl::flat_hash_map<const typename AstNode::ExprType*,
const AstNode* absl_nonnull>
std::vector<std::unique_ptr<typename AstTraits::NodeType>> nodes;
std::vector<const typename AstTraits::NodeType* absl_nonnull> postorder;
absl::flat_hash_map<int64_t, const typename AstTraits::NodeType* absl_nonnull>
id_to_node;
absl::flat_hash_map<const typename AstTraits::ExprType*,
const typename AstTraits::NodeType* absl_nonnull>
expr_to_node;
};

Expand All @@ -161,6 +162,150 @@ struct PreorderTraits {
}
};

// Base class for NavigableAstNode and NavigableProtoAstNode.
template <typename AstTraits>
class NavigableAstNodeBase {
private:
using MetadataType = NavigableAstMetadata<AstTraits>;
using NodeDataType = NavigableAstNodeData<AstTraits>;
using Derived = typename AstTraits::NodeType;
using ExprType = typename AstTraits::ExprType;

public:
using PreorderRange = NavigableAstRange<PreorderTraits<Derived>>;
using PostorderRange = NavigableAstRange<PostorderTraits<Derived>>;

// The parent of this node or nullptr if it is a root.
const Derived* absl_nullable parent() const { return data_.parent; }

const ExprType* absl_nonnull expr() const { return data_.expr; }

// The index of this node in the parent's children. -1 if this is a root.
int child_index() const { return data_.child_index; }

// The type of traversal from parent to this node.
ChildKind parent_relation() const { return data_.parent_relation; }

// The type of this node, analogous to Expr::ExprKindCase.
NodeKind node_kind() const { return data_.node_kind; }

// The number of nodes in the tree rooted at this node (including self).
size_t tree_size() const { return data_.tree_size; }

// The height of this node in the tree (the number of descendants including
// self on the longest path).
size_t height() const { return data_.height; }

absl::Span<const Derived* const> children() const {
return absl::MakeConstSpan(data_.children);
}

// Range over the descendants of this node (including self) using preorder
// semantics. Each node is visited immediately before all of its descendants.
PreorderRange DescendantsPreorder() const {
return PreorderRange(absl::MakeConstSpan(data_.metadata->nodes)
.subspan(data_.index, data_.tree_size));
}

// Range over the descendants of this node (including self) using postorder
// semantics. Each node is visited immediately after all of its descendants.
PostorderRange DescendantsPostorder() const {
return PostorderRange(absl::MakeConstSpan(data_.metadata->postorder)
.subspan(data_.index, data_.tree_size));
}

private:
friend Derived;

NavigableAstNodeBase() = default;
NavigableAstNodeBase(const NavigableAstNodeBase&) = delete;
NavigableAstNodeBase& operator=(const NavigableAstNodeBase&) = delete;

protected:
NodeDataType data_;
};

// Shared implementation for NavigableAst and NavigableProtoAst.
//
// AstTraits provides type info for the derived classes that implement building
// the traversal metadata. It provides the following types:
//
// ExprType is the expression node type of the source AST.
//
// AstType is the subclass of NavigableAstBase for the implementation.
//
// NodeType is the subclass of NavigableAstNodeBase for the implementation.
template <class AstTraits>
class NavigableAstBase {
private:
using MetadataType = NavigableAstMetadata<AstTraits>;
using Derived = typename AstTraits::AstType;
using NodeType = typename AstTraits::NodeType;
using ExprType = typename AstTraits::ExprType;

public:
NavigableAstBase(const NavigableAstBase&) = delete;
NavigableAstBase& operator=(const NavigableAstBase&) = delete;
NavigableAstBase(NavigableAstBase&&) = default;
NavigableAstBase& operator=(NavigableAstBase&&) = default;

// Return ptr to the AST node with id if present. Otherwise returns nullptr.
//
// If ids are non-unique, the first pre-order node encountered with id is
// returned.
const NodeType* absl_nullable FindId(int64_t id) const {
auto it = metadata_->id_to_node.find(id);
if (it == metadata_->id_to_node.end()) {
return nullptr;
}
return it->second;
}

// Return ptr to the AST node representing the given Expr protobuf node.
const NodeType* absl_nullable FindExpr(
const ExprType* absl_nonnull expr) const {
auto it = metadata_->expr_to_node.find(expr);
if (it == metadata_->expr_to_node.end()) {
return nullptr;
}
return it->second;
}

// The root of the AST.
const NodeType& Root() const { return *metadata_->nodes[0]; }

// Check whether the source AST used unique IDs for each node.
//
// This is typically the case, but older versions of the parsers didn't
// guarantee uniqueness for nodes generated by some macros and ASTs modified
// outside of CEL's parse/type check may not have unique IDs.
bool IdsAreUnique() const {
return metadata_->id_to_node.size() == metadata_->nodes.size();
}

// Equality operators test for identity. They are intended to distinguish
// moved-from or uninitialized instances from initialized.
bool operator==(const NavigableAstBase& other) const {
return metadata_ == other.metadata_;
}

bool operator!=(const NavigableAstBase& other) const {
return metadata_ != other.metadata_;
}

// Return true if this instance is initialized.
explicit operator bool() const { return metadata_ != nullptr; }

private:
friend Derived;

NavigableAstBase() = default;
explicit NavigableAstBase(std::unique_ptr<MetadataType> metadata)
: metadata_(std::move(metadata)) {}

std::unique_ptr<MetadataType> metadata_;
};

} // namespace cel::common_internal

#endif // THIRD_PARTY_CEL_CPP_COMMON_AST_NAVIGABLE_AST_INTERNAL_H_
42 changes: 17 additions & 25 deletions tools/navigable_ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ using ::cel::expr::Expr;
using ::google::api::expr::runtime::AstTraverse;
using ::google::api::expr::runtime::SourcePosition;

using NavigableAstNodeData =
common_internal::NavigableAstNodeData<common_internal::ProtoAstTraits>;
using NavigableAstMetadata =
common_internal::NavigableAstMetadata<common_internal::ProtoAstTraits>;

NodeKind GetNodeKind(const Expr& expr) {
switch (expr.expr_kind_case()) {
case Expr::kConstExpr:
Expand Down Expand Up @@ -67,8 +72,7 @@ NodeKind GetNodeKind(const Expr& expr) {

// Get the traversal relationship from parent to the given node.
// Note: these depend on the ast_visitor utility's traversal ordering.
ChildKind GetChildKind(const common_internal::NavigableAstNodeData<
NavigableProtoAstNode>& parent_node,
ChildKind GetChildKind(const NavigableAstNodeData& parent_node,
size_t child_index) {
constexpr size_t kComprehensionRangeArgIndex =
google::api::expr::runtime::ITER_RANGE;
Expand Down Expand Up @@ -122,17 +126,13 @@ class NavigableExprBuilderVisitor
: public google::api::expr::runtime::AstVisitorBase {
public:
NavigableExprBuilderVisitor(
absl::AnyInvocable<std::unique_ptr<NavigableProtoAstNode>()> node_factory,
absl::AnyInvocable<common_internal::NavigableAstNodeData<
NavigableProtoAstNode>&(NavigableProtoAstNode&)>
node_data_accessor)
absl::AnyInvocable<std::unique_ptr<AstNode>()> node_factory,
absl::AnyInvocable<NavigableAstNodeData&(AstNode&)> node_data_accessor)
: node_factory_(std::move(node_factory)),
node_data_accessor_(std::move(node_data_accessor)),
metadata_(std::make_unique<common_internal::NavigableAstMetadata<
NavigableProtoAstNode>>()) {}
metadata_(std::make_unique<NavigableAstMetadata>()) {}

common_internal::NavigableAstNodeData<NavigableProtoAstNode>& NodeDataAt(
size_t index) {
NavigableAstNodeData& NodeDataAt(size_t index) {
return node_data_accessor_(*metadata_->nodes[index]);
}

Expand Down Expand Up @@ -171,8 +171,7 @@ class NavigableExprBuilderVisitor
size_t idx = parent_stack_.back();
parent_stack_.pop_back();
metadata_->postorder.push_back(metadata_->nodes[idx].get());
common_internal::NavigableAstNodeData<NavigableProtoAstNode>& node =
NodeDataAt(idx);
NavigableAstNodeData& node = NodeDataAt(idx);
if (!parent_stack_.empty()) {
auto& parent_node_data = NodeDataAt(parent_stack_.back());
parent_node_data.tree_size += node.tree_size;
Expand All @@ -181,30 +180,23 @@ class NavigableExprBuilderVisitor
}
}

std::unique_ptr<common_internal::NavigableAstMetadata<NavigableProtoAstNode>>
Consume() && {
std::unique_ptr<NavigableAstMetadata> Consume() && {
return std::move(metadata_);
}

private:
absl::AnyInvocable<std::unique_ptr<NavigableProtoAstNode>()> node_factory_;
absl::AnyInvocable<common_internal::NavigableAstNodeData<
NavigableProtoAstNode>&(NavigableProtoAstNode&)>
node_data_accessor_;
std::unique_ptr<common_internal::NavigableAstMetadata<NavigableProtoAstNode>>
metadata_;
absl::AnyInvocable<std::unique_ptr<AstNode>()> node_factory_;
absl::AnyInvocable<NavigableAstNodeData&(AstNode&)> node_data_accessor_;
std::unique_ptr<NavigableAstMetadata> metadata_;
std::vector<size_t> parent_stack_;
};

} // namespace

NavigableProtoAst NavigableProtoAst::Build(const Expr& expr) {
NavigableExprBuilderVisitor visitor(
[]() { return absl::WrapUnique(new NavigableProtoAstNode()); },
[](NavigableProtoAstNode& node)
-> common_internal::NavigableAstNodeData<NavigableProtoAstNode>& {
return node.data_;
});
[]() { return absl::WrapUnique(new AstNode()); },
[](AstNode& node) -> NavigableAstNodeData& { return node.data_; });
AstTraverse(&expr, /*source_info=*/nullptr, &visitor);
return NavigableProtoAst(std::move(visitor).Consume());
}
Expand Down
Loading