From 6a9e920348a685f3c0286f4ae8f4865217f612b0 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 31 Mar 2025 16:55:44 -0700 Subject: [PATCH] Remove temporary vectors in resolver logic for the planner. PiperOrigin-RevId: 742452902 --- eval/compiler/BUILD | 2 + eval/compiler/flat_expr_builder.cc | 7 ++- eval/compiler/resolver.cc | 80 +++++++++++++++++++++++------- eval/compiler/resolver.h | 12 +++++ runtime/function_registry.cc | 65 +++++++++++++++++++++--- runtime/function_registry.h | 8 +++ 6 files changed, 147 insertions(+), 27 deletions(-) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 707464ea3..340353259 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -419,10 +419,12 @@ cc_library( "//runtime:function_overload_reference", "//runtime:function_registry", "//runtime:type_registry", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", ], ) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 414c9a0f7..d66037a50 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -1637,12 +1637,11 @@ class FlatExprVisitor : public cel::AstVisitor { // Establish the search criteria for a given function. bool receiver_style = call_expr->has_target(); size_t num_args = call_expr->args().size() + (receiver_style ? 1 : 0); - auto arguments_matcher = ArgumentsMatcher(num_args); // First, search for lazily defined function overloads. // Lazy functions shadow eager functions with the same signature. auto lazy_overloads = resolver_.FindLazyOverloads( - function, call_expr->has_target(), arguments_matcher, expr->id()); + function, call_expr->has_target(), num_args, expr->id()); if (!lazy_overloads.empty()) { auto depth = RecursionEligible(); if (depth.has_value()) { @@ -1659,8 +1658,8 @@ class FlatExprVisitor : public cel::AstVisitor { } // Second, search for eagerly defined function overloads. - auto overloads = resolver_.FindOverloads(function, receiver_style, - arguments_matcher, expr->id()); + auto overloads = + resolver_.FindOverloads(function, receiver_style, num_args, expr->id()); if (overloads.empty()) { // Create a warning that the overload could not be found. Depending on the // builder_warnings configuration, this could result in termination of the diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index d1d20e30a..d63067257 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -14,11 +14,14 @@ #include "eval/compiler/resolver.h" +#include #include +#include #include #include #include +#include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" @@ -27,6 +30,7 @@ #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "absl/types/optional.h" +#include "absl/types/span.h" #include "common/kind.h" #include "common/type.h" #include "common/type_reflector.h" @@ -96,37 +100,41 @@ std::vector Resolver::FullyQualifiedNames(absl::string_view name, // and handle the case where this id is in the reference map as either a // function name or identifier name. std::vector names; - // Handle the case where the name contains a leading '.' indicating it is - // already fully-qualified. - if (absl::StartsWith(name, ".")) { - std::string fully_qualified_name = std::string(name.substr(1)); - names.push_back(fully_qualified_name); - return names; - } - // namespace prefixes is guaranteed to contain at least empty string, so this - // function will always produce at least one result. - for (const auto& prefix : namespace_prefixes_) { + auto prefixes = GetPrefixesFor(name); + for (const auto& prefix : prefixes) { std::string fully_qualified_name = absl::StrCat(prefix, name); names.push_back(fully_qualified_name); } return names; } +absl::Span Resolver::GetPrefixesFor( + absl::string_view& name) const { + static const absl::NoDestructor kEmptyPrefix(""); + if (absl::StartsWith(name, ".")) { + name = name.substr(1); + return absl::MakeConstSpan(kEmptyPrefix.get(), 1); + } + return namespace_prefixes_; +} + absl::optional Resolver::FindConstant(absl::string_view name, int64_t expr_id) const { - auto names = FullyQualifiedNames(name, expr_id); - for (const auto& name : names) { + auto prefixes = GetPrefixesFor(name); + for (const auto& prefix : prefixes) { + std::string qualified_name = absl::StrCat(prefix, name); // Attempt to resolve the fully qualified name to a known enum. - auto enum_entry = enum_value_map_.find(name); + auto enum_entry = enum_value_map_.find(qualified_name); if (enum_entry != enum_value_map_.end()) { return enum_entry->second; } // Conditionally resolve fully qualified names as type values if the option // to do so is configured in the expression builder. If the type name is // not qualified, then it too may be returned as a constant value. - if (resolve_qualified_type_identifiers_ || !absl::StrContains(name, ".")) { - auto type_value = type_reflector_.FindType(name); + if (resolve_qualified_type_identifiers_ || + !absl::StrContains(qualified_name, ".")) { + auto type_value = type_reflector_.FindType(qualified_name); if (type_value.ok() && type_value->has_value()) { return TypeValue(**type_value); } @@ -157,6 +165,27 @@ std::vector Resolver::FindOverloads( return funcs; } +std::vector Resolver::FindOverloads( + absl::string_view name, bool receiver_style, size_t arity, + int64_t expr_id) const { + std::vector funcs; + auto prefixes = GetPrefixesFor(name); + for (const auto& prefix : prefixes) { + std::string qualified_name = absl::StrCat(prefix, name); + // Only one set of overloads is returned along the namespace hierarchy as + // the function name resolution follows the same behavior as variable name + // resolution, meaning the most specific definition wins. This is different + // from how C++ namespaces work, as they will accumulate the overload set + // over the namespace hierarchy. + funcs = function_registry_.FindStaticOverloadsByArity( + qualified_name, receiver_style, arity); + if (!funcs.empty()) { + return funcs; + } + } + return funcs; +} + std::vector Resolver::FindLazyOverloads( absl::string_view name, bool receiver_style, const std::vector& types, int64_t expr_id) const { @@ -173,10 +202,27 @@ std::vector Resolver::FindLazyOverloads( return funcs; } +std::vector Resolver::FindLazyOverloads( + absl::string_view name, bool receiver_style, size_t arity, + int64_t expr_id) const { + std::vector funcs; + auto prefixes = GetPrefixesFor(name); + for (const auto& prefix : prefixes) { + std::string qualified_name = absl::StrCat(prefix, name); + funcs = function_registry_.FindLazyOverloadsByArity(name, receiver_style, + arity); + if (!funcs.empty()) { + return funcs; + } + } + return funcs; +} + absl::StatusOr>> Resolver::FindType(absl::string_view name, int64_t expr_id) const { - auto qualified_names = FullyQualifiedNames(name, expr_id); - for (auto& qualified_name : qualified_names) { + auto prefixes = GetPrefixesFor(name); + for (auto& prefix : prefixes) { + std::string qualified_name = absl::StrCat(prefix, name); CEL_ASSIGN_OR_RETURN(auto maybe_type, type_reflector_.FindType(qualified_name)); if (maybe_type.has_value()) { diff --git a/eval/compiler/resolver.h b/eval/compiler/resolver.h index 240635901..c36fcafb9 100644 --- a/eval/compiler/resolver.h +++ b/eval/compiler/resolver.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ +#include #include #include #include @@ -24,6 +25,7 @@ #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "absl/types/span.h" #include "common/kind.h" #include "common/type_reflector.h" #include "common/value.h" @@ -75,18 +77,28 @@ class Resolver { absl::string_view name, bool receiver_style, const std::vector& types, int64_t expr_id = -1) const; + std::vector FindLazyOverloads( + absl::string_view name, bool receiver_style, size_t arity, + int64_t expr_id = -1) const; + // FindOverloads returns the set, possibly empty, of eager function overloads // matching the given function signature. std::vector FindOverloads( absl::string_view name, bool receiver_style, const std::vector& types, int64_t expr_id = -1) const; + std::vector FindOverloads( + absl::string_view name, bool receiver_style, size_t arity, + int64_t expr_id = -1) const; + // FullyQualifiedNames returns the set of fully qualified names which may be // derived from the base_name within the specified expression container. std::vector FullyQualifiedNames(absl::string_view base_name, int64_t expr_id = -1) const; private: + absl::Span GetPrefixesFor(absl::string_view& name) const; + std::vector namespace_prefixes_; absl::flat_hash_map enum_value_map_; const cel::FunctionRegistry& function_registry_; diff --git a/runtime/function_registry.cc b/runtime/function_registry.cc index 6959b22c0..ac1e53eb5 100644 --- a/runtime/function_registry.cc +++ b/runtime/function_registry.cc @@ -14,6 +14,7 @@ #include "runtime/function_registry.h" +#include #include #include #include @@ -134,6 +135,27 @@ FunctionRegistry::FindStaticOverloads(absl::string_view name, return matched_funcs; } +std::vector +FunctionRegistry::FindStaticOverloadsByArity(absl::string_view name, + bool receiver_style, + size_t arity) const { + std::vector matched_funcs; + + auto overloads = functions_.find(name); + if (overloads == functions_.end()) { + return matched_funcs; + } + + for (const auto& overload : overloads->second.static_overloads) { + if (overload.descriptor->receiver_style() == receiver_style && + overload.descriptor->types().size() == arity) { + matched_funcs.push_back({*overload.descriptor, *overload.implementation}); + } + } + + return matched_funcs; +} + std::vector FunctionRegistry::FindLazyOverloads( absl::string_view name, bool receiver_style, absl::Span types) const { @@ -153,6 +175,27 @@ std::vector FunctionRegistry::FindLazyOverloads( return matched_funcs; } +std::vector +FunctionRegistry::FindLazyOverloadsByArity(absl::string_view name, + bool receiver_style, + size_t arity) const { + std::vector matched_funcs; + + auto overloads = functions_.find(name); + if (overloads == functions_.end()) { + return matched_funcs; + } + + for (const auto& entry : overloads->second.lazy_overloads) { + if (entry.descriptor->receiver_style() == receiver_style && + entry.descriptor->types().size() == arity) { + matched_funcs.push_back({*entry.descriptor, *entry.function_provider}); + } + } + + return matched_funcs; +} + absl::node_hash_map> FunctionRegistry::ListFunctions() const { absl::node_hash_map> @@ -177,12 +220,22 @@ FunctionRegistry::ListFunctions() const { bool FunctionRegistry::DescriptorRegistered( const cel::FunctionDescriptor& descriptor) const { - return !(FindStaticOverloads(descriptor.name(), descriptor.receiver_style(), - descriptor.types()) - .empty()) || - !(FindLazyOverloads(descriptor.name(), descriptor.receiver_style(), - descriptor.types()) - .empty()); + auto overloads = functions_.find(descriptor.name()); + if (overloads == functions_.end()) { + return false; + } + const RegistryEntry& entry = overloads->second; + for (const auto& static_ovl : entry.static_overloads) { + if (static_ovl.descriptor->ShapeMatches(descriptor)) { + return true; + } + } + for (const auto& lazy_ovl : entry.lazy_overloads) { + if (lazy_ovl.descriptor->ShapeMatches(descriptor)) { + return true; + } + } + return false; } bool FunctionRegistry::ValidateNonStrictOverload( diff --git a/runtime/function_registry.h b/runtime/function_registry.h index 5d8943ccc..6a227978d 100644 --- a/runtime/function_registry.h +++ b/runtime/function_registry.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_REGISTRY_H_ #define THIRD_PARTY_CEL_CPP_RUNTIME_FUNCTION_REGISTRY_H_ +#include #include #include #include @@ -83,6 +84,9 @@ class FunctionRegistry { absl::string_view name, bool receiver_style, absl::Span types) const; + std::vector FindStaticOverloadsByArity( + absl::string_view name, bool receiver_style, size_t arity) const; + // Find subset of cel::Function providers that match overload conditions. // As types may not be available during expression compilation, // further narrowing of this subset will happen at evaluation stage. @@ -98,6 +102,10 @@ class FunctionRegistry { absl::string_view name, bool receiver_style, absl::Span types) const; + std::vector FindLazyOverloadsByArity(absl::string_view name, + bool receiver_style, + size_t arity) const; + // Retrieve list of registered function descriptors. This includes both // static and lazy functions. absl::node_hash_map>