Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions tools/branch_coverage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class BranchCoverageImpl : public BranchCoverage {
BranchCoverage::NodeCoverageStats StatsForNode(
int64_t expr_id) const override;

const NavigableAst& ast() const override;
const NavigableProtoAst& ast() const override;
const CheckedExpr& expr() const override;

// Initializes the coverage implementation. This should be called by the
Expand All @@ -124,10 +124,10 @@ class BranchCoverageImpl : public BranchCoverage {

// Infer it the node is boolean typed. Check the type map if available.
// Otherwise infer typing based on built-in functions.
bool InferredBoolType(const AstNode& node) const;
bool InferredBoolType(const NavigableProtoAstNode& node) const;

CheckedExpr expr_;
NavigableAst ast_;
NavigableProtoAst ast_;
mutable absl::Mutex coverage_nodes_mu_;
absl::flat_hash_map<int64_t, CoverageNode> coverage_nodes_
ABSL_GUARDED_BY(coverage_nodes_mu_);
Expand Down Expand Up @@ -167,11 +167,12 @@ BranchCoverage::NodeCoverageStats BranchCoverageImpl::StatsForNode(
return stats;
}

const NavigableAst& BranchCoverageImpl::ast() const { return ast_; }
const NavigableProtoAst& BranchCoverageImpl::ast() const { return ast_; }

const CheckedExpr& BranchCoverageImpl::expr() const { return expr_; }

bool BranchCoverageImpl::InferredBoolType(const AstNode& node) const {
bool BranchCoverageImpl::InferredBoolType(
const NavigableProtoAstNode& node) const {
int64_t expr_id = node.expr()->id();
const auto* checker_type = FindCheckerType(expr_, expr_id);
if (checker_type != nullptr) {
Expand All @@ -183,8 +184,8 @@ bool BranchCoverageImpl::InferredBoolType(const AstNode& node) const {
}

void BranchCoverageImpl::Init() ABSL_NO_THREAD_SAFETY_ANALYSIS {
ast_ = NavigableAst::Build(expr_.expr());
for (const AstNode& node : ast_.Root().DescendantsPreorder()) {
ast_ = NavigableProtoAst::Build(expr_.expr());
for (const NavigableProtoAstNode& node : ast_.Root().DescendantsPreorder()) {
int64_t expr_id = node.expr()->id();

CoverageNode& coverage_node = coverage_nodes_[expr_id];
Expand Down
3 changes: 2 additions & 1 deletion tools/branch_coverage.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class BranchCoverage {

virtual NodeCoverageStats StatsForNode(int64_t expr_id) const = 0;

virtual const NavigableAst& ast() const ABSL_ATTRIBUTE_LIFETIME_BOUND = 0;
virtual const NavigableProtoAst& ast() const
ABSL_ATTRIBUTE_LIFETIME_BOUND = 0;
virtual const cel::expr::CheckedExpr& expr() const
ABSL_ATTRIBUTE_LIFETIME_BOUND = 0;
};
Expand Down
50 changes: 28 additions & 22 deletions tools/navigable_ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ NodeKind GetNodeKind(const Expr& expr) {

// Get the traversal relationship from parent to the given node.
// Note: these depend on the ast_visitor utility's traversal ordering.
ChildKind GetChildKind(
const common_internal::NavigableAstNodeData<AstNode>& parent_node,
size_t child_index) {
ChildKind GetChildKind(const common_internal::NavigableAstNodeData<
NavigableProtoAstNode>& parent_node,
size_t child_index) {
constexpr size_t kComprehensionRangeArgIndex =
google::api::expr::runtime::ITER_RANGE;
constexpr size_t kComprehensionInitArgIndex =
Expand Down Expand Up @@ -122,26 +122,27 @@ class NavigableExprBuilderVisitor
: public google::api::expr::runtime::AstVisitorBase {
public:
NavigableExprBuilderVisitor(
absl::AnyInvocable<std::unique_ptr<AstNode>()> node_factory,
absl::AnyInvocable<
common_internal::NavigableAstNodeData<AstNode>&(AstNode&)>
absl::AnyInvocable<std::unique_ptr<NavigableProtoAstNode>()> node_factory,
absl::AnyInvocable<common_internal::NavigableAstNodeData<
NavigableProtoAstNode>&(NavigableProtoAstNode&)>
node_data_accessor)
: node_factory_(std::move(node_factory)),
node_data_accessor_(std::move(node_data_accessor)),
metadata_(std::make_unique<
common_internal::NavigableAstMetadata<AstNode>>()) {}
metadata_(std::make_unique<common_internal::NavigableAstMetadata<
NavigableProtoAstNode>>()) {}

common_internal::NavigableAstNodeData<AstNode>& NodeDataAt(size_t index) {
common_internal::NavigableAstNodeData<NavigableProtoAstNode>& NodeDataAt(
size_t index) {
return node_data_accessor_(*metadata_->nodes[index]);
}

void PreVisitExpr(const Expr* expr, const SourcePosition* position) override {
AstNode* parent = parent_stack_.empty()
? nullptr
: metadata_->nodes[parent_stack_.back()].get();
NavigableProtoAstNode* parent =
parent_stack_.empty() ? nullptr
: metadata_->nodes[parent_stack_.back()].get();
size_t index = metadata_->nodes.size();
metadata_->nodes.push_back(node_factory_());
AstNode* node = metadata_->nodes[index].get();
NavigableProtoAstNode* node = metadata_->nodes[index].get();
auto& node_data = NodeDataAt(index);
node_data.parent = parent;
node_data.expr = expr;
Expand Down Expand Up @@ -170,7 +171,8 @@ class NavigableExprBuilderVisitor
size_t idx = parent_stack_.back();
parent_stack_.pop_back();
metadata_->postorder.push_back(metadata_->nodes[idx].get());
common_internal::NavigableAstNodeData<AstNode>& node = NodeDataAt(idx);
common_internal::NavigableAstNodeData<NavigableProtoAstNode>& node =
NodeDataAt(idx);
if (!parent_stack_.empty()) {
auto& parent_node_data = NodeDataAt(parent_stack_.back());
parent_node_data.tree_size += node.tree_size;
Expand All @@ -179,28 +181,32 @@ class NavigableExprBuilderVisitor
}
}

std::unique_ptr<common_internal::NavigableAstMetadata<AstNode>> Consume() && {
std::unique_ptr<common_internal::NavigableAstMetadata<NavigableProtoAstNode>>
Consume() && {
return std::move(metadata_);
}

private:
absl::AnyInvocable<std::unique_ptr<AstNode>()> node_factory_;
absl::AnyInvocable<common_internal::NavigableAstNodeData<AstNode>&(AstNode&)>
absl::AnyInvocable<std::unique_ptr<NavigableProtoAstNode>()> node_factory_;
absl::AnyInvocable<common_internal::NavigableAstNodeData<
NavigableProtoAstNode>&(NavigableProtoAstNode&)>
node_data_accessor_;
std::unique_ptr<common_internal::NavigableAstMetadata<AstNode>> metadata_;
std::unique_ptr<common_internal::NavigableAstMetadata<NavigableProtoAstNode>>
metadata_;
std::vector<size_t> parent_stack_;
};

} // namespace

NavigableAst NavigableAst::Build(const Expr& expr) {
NavigableProtoAst NavigableProtoAst::Build(const Expr& expr) {
NavigableExprBuilderVisitor visitor(
[]() { return absl::WrapUnique(new AstNode()); },
[](AstNode& node) -> common_internal::NavigableAstNodeData<AstNode>& {
[]() { return absl::WrapUnique(new NavigableProtoAstNode()); },
[](NavigableProtoAstNode& node)
-> common_internal::NavigableAstNodeData<NavigableProtoAstNode>& {
return node.data_;
});
AstTraverse(&expr, /*source_info=*/nullptr, &visitor);
return NavigableAst(std::move(visitor).Consume());
return NavigableProtoAst(std::move(visitor).Consume());
}

} // namespace cel
73 changes: 41 additions & 32 deletions tools/navigable_ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,29 @@

namespace cel {

class NavigableAst;
class NavigableProtoAst;

// Wrapper around a CEL AST node that exposes traversal information.
class AstNode {
class NavigableProtoAstNode {
public:
using ExprType = const cel::expr::Expr;

// A const Span like type that provides pre-order traversal for a sub tree.
// provides .begin() and .end() returning bidirectional iterators to
// const AstNode&.
// const NavigableProtoAstNode&.
using PreorderRange = common_internal::NavigableAstRange<
common_internal::PreorderTraits<AstNode>>;
common_internal::PreorderTraits<NavigableProtoAstNode>>;

// A const Span like type that provides post-order traversal for a sub tree.
// provides .begin() and .end() returning bidirectional iterators to
// const AstNode&.
// const NavigableProtoAstNode&.
using PostorderRange = common_internal::NavigableAstRange<
common_internal::PostorderTraits<AstNode>>;
common_internal::PostorderTraits<NavigableProtoAstNode>>;

// The parent of this node or nullptr if it is a root.
const AstNode* absl_nullable parent() const { return data_.parent; }
const NavigableProtoAstNode* absl_nullable parent() const {
return data_.parent;
}

const cel::expr::Expr* absl_nonnull expr() const {
return data_.expr;
Expand All @@ -72,15 +74,16 @@ class AstNode {
// self on the longest path).
size_t height() const { return data_.height; }

absl::Span<const AstNode* const> children() const {
absl::Span<const NavigableProtoAstNode* const> children() const {
return absl::MakeConstSpan(data_.children);
}

// Range over the descendants of this node (including self) using preorder
// semantics. Each node is visited immediately before all of its descendants.
//
// example:
// for (const cel::AstNode& node : ast.Root().DescendantsPreorder()) {
// for (const cel::NavigableProtoAstNode& node :
// ast.Root().DescendantsPreorder()) {
// ...
// }
//
Expand All @@ -103,13 +106,13 @@ class AstNode {
}

private:
friend class NavigableAst;
friend class NavigableProtoAst;

AstNode() = default;
AstNode(const AstNode&) = delete;
AstNode& operator=(const AstNode&) = delete;
NavigableProtoAstNode() = default;
NavigableProtoAstNode(const NavigableProtoAstNode&) = delete;
NavigableProtoAstNode& operator=(const NavigableProtoAstNode&) = delete;

common_internal::NavigableAstNodeData<AstNode> data_;
common_internal::NavigableAstNodeData<NavigableProtoAstNode> data_;
};

// NavigableExpr provides a view over a CEL AST that allows for generalized
Expand All @@ -120,32 +123,32 @@ class AstNode {
//
// Pointers to AstNodes are owned by this instance and must not outlive it.
//
// `NavigableAst` and Navigable nodes are independent of the input Expr and may
// outlive it, but may contain dangling pointers if the input Expr is modified
// or destroyed.
class NavigableAst {
// `NavigableProtoAst` and Navigable nodes are independent of the input Expr and
// may outlive it, but may contain dangling pointers if the input Expr is
// modified or destroyed.
class NavigableProtoAst {
public:
static NavigableAst Build(const cel::expr::Expr& expr);
static NavigableProtoAst Build(const cel::expr::Expr& expr);

// Default constructor creates an empty instance.
//
// Operations other than equality are undefined on an empty instance.
//
// This is intended for composed object construction, a new NavigableAst
// This is intended for composed object construction, a new NavigableProtoAst
// should be obtained from the Build factory function.
NavigableAst() = default;
NavigableProtoAst() = default;

// Move only.
NavigableAst(const NavigableAst&) = delete;
NavigableAst& operator=(const NavigableAst&) = delete;
NavigableAst(NavigableAst&&) = default;
NavigableAst& operator=(NavigableAst&&) = default;
NavigableProtoAst(const NavigableProtoAst&) = delete;
NavigableProtoAst& operator=(const NavigableProtoAst&) = delete;
NavigableProtoAst(NavigableProtoAst&&) = default;
NavigableProtoAst& operator=(NavigableProtoAst&&) = default;

// Return ptr to the AST node with id if present. Otherwise returns nullptr.
//
// If ids are non-unique, the first pre-order node encountered with id is
// returned.
const AstNode* absl_nullable FindId(int64_t id) const {
const NavigableProtoAstNode* absl_nullable FindId(int64_t id) const {
auto it = metadata_->id_to_node.find(id);
if (it == metadata_->id_to_node.end()) {
return nullptr;
Expand All @@ -154,7 +157,7 @@ class NavigableAst {
}

// Return ptr to the AST node representing the given Expr protobuf node.
const AstNode* absl_nullable FindExpr(
const NavigableProtoAstNode* absl_nullable FindExpr(
const cel::expr::Expr* expr) const {
auto it = metadata_->expr_to_node.find(expr);
if (it == metadata_->expr_to_node.end()) {
Expand All @@ -164,7 +167,7 @@ class NavigableAst {
}

// The root of the AST.
const AstNode& Root() const { return *metadata_->nodes[0]; }
const NavigableProtoAstNode& Root() const { return *metadata_->nodes[0]; }

// Check whether the source AST used unique IDs for each node.
//
Expand All @@ -177,26 +180,32 @@ class NavigableAst {

// Equality operators test for identity. They are intended to distinguish
// moved-from or uninitialized instances from initialized.
bool operator==(const NavigableAst& other) const {
bool operator==(const NavigableProtoAst& other) const {
return metadata_ == other.metadata_;
}

bool operator!=(const NavigableAst& other) const {
bool operator!=(const NavigableProtoAst& other) const {
return metadata_ != other.metadata_;
}

// Return true if this instance is initialized.
explicit operator bool() const { return metadata_ != nullptr; }

private:
using AstMetadata = common_internal::NavigableAstMetadata<AstNode>;
using AstMetadata =
common_internal::NavigableAstMetadata<NavigableProtoAstNode>;

explicit NavigableAst(std::unique_ptr<AstMetadata> metadata)
explicit NavigableProtoAst(std::unique_ptr<AstMetadata> metadata)
: metadata_(std::move(metadata)) {}

std::unique_ptr<AstMetadata> metadata_;
};

// Type aliases for backwards compatibility.
// To be removed.
using AstNode = NavigableProtoAstNode;
using NavigableAst = NavigableProtoAst;

} // namespace cel

#endif // THIRD_PARTY_CEL_CPP_TOOLS_NAVIGABLE_AST_H_
Loading