From 69baabc9c0a72276e5b61d7061bddcacbee52ad4 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Fri, 22 Aug 2025 15:59:15 -0700 Subject: [PATCH] Minor naming changes to AstImpl for consistency In a follow-up, we will replace the opaque Ast with AstImpl. PiperOrigin-RevId: 798373303 --- checker/internal/type_checker_impl.cc | 5 +- checker/internal/type_checker_impl_test.cc | 29 ++--- checker/optional_test.cc | 108 +++++++++--------- checker/standard_library_test.cc | 15 +-- common/ast/BUILD | 3 + common/ast/ast_impl.cc | 22 +++- common/ast/ast_impl.h | 84 ++++++++------ common/ast/ast_impl_test.cc | 6 +- common/ast/metadata.h | 2 + common/ast_proto_test.cc | 22 ++-- common/ast_rewrite_test.cc | 4 +- eval/compiler/qualified_reference_resolver.cc | 2 +- .../qualified_reference_resolver_test.cc | 68 +++++------ extensions/select_optimization.cc | 4 +- testutil/baseline_tests_test.cc | 49 ++++---- 15 files changed, 230 insertions(+), 193 deletions(-) diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 0967c7326..9b8478556 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -1329,8 +1329,9 @@ absl::StatusOr TypeCheckerImpl::Check( // Happens in a second pass to simplify validating that pointers haven't // been invalidated by other updates. ResolveRewriter rewriter(visitor, type_inference_context, options_, - ast_impl.reference_map(), ast_impl.type_map()); - AstRewrite(ast_impl.root_expr(), rewriter); + ast_impl.mutable_reference_map(), + ast_impl.mutable_type_map()); + AstRewrite(ast_impl.mutable_root_expr(), rewriter); CEL_RETURN_IF_ERROR(rewriter.status()); diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 66fb7b57c..09a48fe37 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -867,7 +867,8 @@ TEST_P(PrimitiveLiteralsTest, LiteralsTypeInferred) { ASSERT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_EQ(ast_impl.type_map()[1].primitive(), test_case.expected_type); + EXPECT_EQ(ast_impl.mutable_type_map()[1].primitive(), + test_case.expected_type); } INSTANTIATE_TEST_SUITE_P( @@ -917,7 +918,7 @@ TEST_P(AstTypeConversionTest, TypeConversion) { ASSERT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_EQ(ast_impl.type_map()[1], test_case.expected_type) + EXPECT_EQ(ast_impl.mutable_type_map()[1], test_case.expected_type) << GetParam().decl_type.DebugString(); } @@ -1041,7 +1042,7 @@ TEST(TypeCheckerImplTest, NullLiteral) { ASSERT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_TRUE(ast_impl.type_map()[1].has_null()); + EXPECT_TRUE(ast_impl.mutable_type_map()[1].has_null()); } TEST(TypeCheckerImplTest, ExpressionLimitInclusive) { @@ -1114,7 +1115,7 @@ TEST(TypeCheckerImplTest, BasicOvlResolution) { // Assumes parser numbering: + should always be id 2. ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.reference_map()[2], + EXPECT_THAT(ast_impl.mutable_reference_map()[2], IsFunctionReference( "_+_", std::vector{"add_double_double"})); } @@ -1138,7 +1139,7 @@ TEST(TypeCheckerImplTest, OvlResolutionMultipleOverloads) { // Assumes parser numbering: + should always be id 3. ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.reference_map()[3], + EXPECT_THAT(ast_impl.mutable_reference_map()[3], IsFunctionReference("_+_", std::vector{ "add_double_double", "add_int_int", "add_list", "add_uint_uint"})); @@ -1164,14 +1165,14 @@ TEST(TypeCheckerImplTest, BasicFunctionResultTypeResolution) { // Assumes parser numbering: + should always be id 2 and 4. ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.reference_map()[2], + EXPECT_THAT(ast_impl.mutable_reference_map()[2], IsFunctionReference( "_+_", std::vector{"add_double_double"})); - EXPECT_THAT(ast_impl.reference_map()[4], + EXPECT_THAT(ast_impl.mutable_reference_map()[4], IsFunctionReference( "_+_", std::vector{"add_double_double"})); int64_t root_id = ast_impl.root_expr().id(); - EXPECT_EQ(ast_impl.type_map()[root_id].primitive(), + EXPECT_EQ(ast_impl.mutable_type_map()[root_id].primitive(), ast_internal::PrimitiveType::kDouble); } @@ -1335,7 +1336,7 @@ TEST(TypeCheckerImplTest, BadSourcePosition) { TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo")); auto& ast_impl = AstImpl::CastFromPublicAst(*ast); - ast_impl.source_info().mutable_positions()[1] = -42; + ast_impl.mutable_source_info().mutable_positions()[1] = -42; ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_OK_AND_ASSIGN(auto source, NewSource("foo")); @@ -1365,7 +1366,7 @@ TEST(TypeCheckerImplTest, FailsIfNoTypeDeduced) { // Assume that an unspecified expr kind is not deducible. Expr unspecified_expr; unspecified_expr.set_id(3); - ast_impl.root_expr().mutable_call_expr().mutable_args()[1] = + ast_impl.mutable_root_expr().mutable_call_expr().mutable_args()[1] = std::move(unspecified_expr); ASSERT_THAT(impl.Check(std::move(ast)), @@ -1382,7 +1383,7 @@ TEST(TypeCheckerImplTest, BadLineOffsets) { { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("\nfoo")); auto& ast_impl = AstImpl::CastFromPublicAst(*ast); - ast_impl.source_info().mutable_line_offsets()[1] = 1; + ast_impl.mutable_source_info().mutable_line_offsets()[1] = 1; ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); @@ -1395,9 +1396,9 @@ TEST(TypeCheckerImplTest, BadLineOffsets) { { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("\nfoo")); auto& ast_impl = AstImpl::CastFromPublicAst(*ast); - ast_impl.source_info().mutable_line_offsets().clear(); - ast_impl.source_info().mutable_line_offsets().push_back(-1); - ast_impl.source_info().mutable_line_offsets().push_back(2); + ast_impl.mutable_source_info().mutable_line_offsets().clear(); + ast_impl.mutable_source_info().mutable_line_offsets().push_back(-1); + ast_impl.mutable_source_info().mutable_line_offsets().push_back(2); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); diff --git a/checker/optional_test.cc b/checker/optional_test.cc index 85f621591..d49747f67 100644 --- a/checker/optional_test.cc +++ b/checker/optional_test.cc @@ -49,10 +49,8 @@ using ::testing::Not; using ::testing::Property; using ::testing::SizeIs; -using AstType = ast_internal::Type; - MATCHER_P(IsOptionalType, inner_type, "") { - const ast_internal::Type& type = arg; + const TypeSpec& type = arg; if (!type.has_abstract_type()) { return false; } @@ -100,13 +98,13 @@ TEST(OptionalTest, OptSelectDoesNotAnnotateFieldType) { EXPECT_NE(field_id, 0); EXPECT_THAT(ast_impl.type_map(), Not(Contains(Key(field_id)))); - EXPECT_THAT(ast_impl.GetType(ast_impl.root_expr().id()), - IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))); + EXPECT_THAT(ast_impl.GetTypeOrDyn(ast_impl.root_expr().id()), + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))); } struct TestCase { std::string expr; - testing::Matcher result_type_matcher; + testing::Matcher result_type_matcher; std::string error_substring; }; @@ -144,7 +142,7 @@ TEST_P(OptionalTest, Runner) { int64_t root_id = ast_impl.root_expr().id(); - EXPECT_THAT(ast_impl.GetType(root_id), test_case.result_type_matcher) + EXPECT_THAT(ast_impl.GetTypeOrDyn(root_id), test_case.result_type_matcher) << "for expression: " << test_case.expr; } @@ -153,130 +151,132 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values( TestCase{ "optional.of('abc')", - IsOptionalType(AstType(ast_internal::PrimitiveType::kString)), + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString)), }, TestCase{ "optional.ofNonZeroValue('')", - IsOptionalType(AstType(ast_internal::PrimitiveType::kString)), + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString)), }, TestCase{ "optional.none()", - IsOptionalType(AstType(ast_internal::DynamicType())), + IsOptionalType(TypeSpec(ast_internal::DynamicType())), }, TestCase{ "optional.of('abc').hasValue()", - Eq(AstType(ast_internal::PrimitiveType::kBool)), + Eq(TypeSpec(ast_internal::PrimitiveType::kBool)), }, TestCase{ "optional.of('abc').value()", - Eq(AstType(ast_internal::PrimitiveType::kString)), + Eq(TypeSpec(ast_internal::PrimitiveType::kString)), }, TestCase{ "type(optional.of('abc')) == optional_type", - Eq(AstType(ast_internal::PrimitiveType::kBool)), + Eq(TypeSpec(ast_internal::PrimitiveType::kBool)), }, TestCase{ "type(optional.of('abc')) == optional_type", - Eq(AstType(ast_internal::PrimitiveType::kBool)), + Eq(TypeSpec(ast_internal::PrimitiveType::kBool)), }, TestCase{ "optional.of('abc').or(optional.of('def'))", - IsOptionalType(AstType(ast_internal::PrimitiveType::kString)), + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString)), }, TestCase{"optional.of('abc').or(optional.of(1))", _, "no matching overload for 'or'"}, TestCase{ "optional.of('abc').orValue('def')", - Eq(AstType(ast_internal::PrimitiveType::kString)), + Eq(TypeSpec(ast_internal::PrimitiveType::kString)), }, TestCase{"optional.of('abc').orValue(1)", _, "no matching overload for 'orValue'"}, TestCase{ "{'k': 'v'}.?k", - IsOptionalType(AstType(ast_internal::PrimitiveType::kString)), + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString)), }, TestCase{"1.?k", _, "expression of type 'int' cannot be the operand of a select " "operation"}, TestCase{ "{'k': {'k': 'v'}}.?k.?k2", - IsOptionalType(AstType(ast_internal::PrimitiveType::kString)), + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString)), }, TestCase{ "{'k': {'k': 'v'}}.?k.k2", - IsOptionalType(AstType(ast_internal::PrimitiveType::kString)), + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString)), }, TestCase{"{?'k': optional.of('v')}", - Eq(AstType(ast_internal::MapType( - std::unique_ptr( - new AstType(ast_internal::PrimitiveType::kString)), - std::unique_ptr( - new AstType(ast_internal::PrimitiveType::kString)))))}, + Eq(TypeSpec(ast_internal::MapType( + std::unique_ptr( + new TypeSpec(ast_internal::PrimitiveType::kString)), + std::unique_ptr(new TypeSpec( + ast_internal::PrimitiveType::kString)))))}, TestCase{"{'k': 'v', ?'k2': optional.none()}", - Eq(AstType(ast_internal::MapType( - std::unique_ptr( - new AstType(ast_internal::PrimitiveType::kString)), - std::unique_ptr( - new AstType(ast_internal::PrimitiveType::kString)))))}, + Eq(TypeSpec(ast_internal::MapType( + std::unique_ptr( + new TypeSpec(ast_internal::PrimitiveType::kString)), + std::unique_ptr(new TypeSpec( + ast_internal::PrimitiveType::kString)))))}, TestCase{"{'k': 'v', ?'k2': 'v'}", _, "expected type 'optional_type(string)' but found 'string'"}, TestCase{"[?optional.of('v')]", - Eq(AstType(ast_internal::ListType(std::unique_ptr( - new AstType(ast_internal::PrimitiveType::kString)))))}, + Eq(TypeSpec(ast_internal::ListType(std::unique_ptr( + new TypeSpec(ast_internal::PrimitiveType::kString)))))}, TestCase{"['v', ?optional.none()]", - Eq(AstType(ast_internal::ListType(std::unique_ptr( - new AstType(ast_internal::PrimitiveType::kString)))))}, + Eq(TypeSpec(ast_internal::ListType(std::unique_ptr( + new TypeSpec(ast_internal::PrimitiveType::kString)))))}, TestCase{"['v1', ?'v2']", _, "expected type 'optional_type(string)' but found 'string'"}, TestCase{"[optional.of(dyn('1')), optional.of('2')][0]", - IsOptionalType(AstType(ast_internal::DynamicType()))}, + IsOptionalType(TypeSpec(ast_internal::DynamicType()))}, TestCase{"[optional.of('1'), optional.of(dyn('2'))][0]", - IsOptionalType(AstType(ast_internal::DynamicType()))}, + IsOptionalType(TypeSpec(ast_internal::DynamicType()))}, TestCase{"[{1: optional.of(1)}, {1: optional.of(dyn(1))}][0][1]", - IsOptionalType(AstType(ast_internal::DynamicType()))}, + IsOptionalType(TypeSpec(ast_internal::DynamicType()))}, TestCase{"[{1: optional.of(dyn(1))}, {1: optional.of(1)}][0][1]", - IsOptionalType(AstType(ast_internal::DynamicType()))}, + IsOptionalType(TypeSpec(ast_internal::DynamicType()))}, TestCase{"[optional.of('1'), optional.of(2)][0]", - Eq(AstType(ast_internal::DynamicType()))}, + Eq(TypeSpec(ast_internal::DynamicType()))}, TestCase{"['v1', ?'v2']", _, "expected type 'optional_type(string)' but found 'string'"}, TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_int64: " "optional.of(1)}", - Eq(AstType(ast_internal::MessageType( + Eq(TypeSpec(ast_internal::MessageType( "cel.expr.conformance.proto3.TestAllTypes")))}, TestCase{"[0][?1]", - IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))}, + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))}, TestCase{"[[0]][?1][?1]", - IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))}, + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))}, TestCase{"[[0]][?1][1]", - IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))}, + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))}, TestCase{"{0: 1}[?1]", - IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))}, + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))}, TestCase{"{0: {0: 1}}[?1][?1]", - IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))}, + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))}, TestCase{"{0: {0: 1}}[?1][1]", - IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))}, + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))}, TestCase{"{0: {0: 1}}[?1]['']", _, "no matching overload for '_[_]'"}, TestCase{"{0: {0: 1}}[?1][?'']", _, "no matching overload for '_[?_]'"}, - TestCase{"optional.of('abc').optMap(x, x + 'def')", - IsOptionalType(AstType(ast_internal::PrimitiveType::kString))}, - TestCase{"optional.of('abc').optFlatMap(x, optional.of(x + 'def'))", - IsOptionalType(AstType(ast_internal::PrimitiveType::kString))}, + TestCase{ + "optional.of('abc').optMap(x, x + 'def')", + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString))}, + TestCase{ + "optional.of('abc').optFlatMap(x, optional.of(x + 'def'))", + IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString))}, // Legacy nullability behaviors. TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: " "optional.of(0)}", - Eq(AstType(ast_internal::MessageType( + Eq(TypeSpec(ast_internal::MessageType( "cel.expr.conformance.proto3.TestAllTypes")))}, TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: null}", - Eq(AstType(ast_internal::MessageType( + Eq(TypeSpec(ast_internal::MessageType( "cel.expr.conformance.proto3.TestAllTypes")))}, TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: " "optional.of(null)}", - Eq(AstType(ast_internal::MessageType( + Eq(TypeSpec(ast_internal::MessageType( "cel.expr.conformance.proto3.TestAllTypes")))}, TestCase{"cel.expr.conformance.proto3.TestAllTypes{}.?single_int64 " "== null", - Eq(AstType(ast_internal::PrimitiveType::kBool))})); + Eq(TypeSpec(ast_internal::PrimitiveType::kBool))})); class OptionalStrictNullAssignmentTest : public testing::TestWithParam {}; @@ -315,7 +315,7 @@ TEST_P(OptionalStrictNullAssignmentTest, Runner) { int64_t root_id = ast_impl.root_expr().id(); - EXPECT_THAT(ast_impl.GetType(root_id), test_case.result_type_matcher) + EXPECT_THAT(ast_impl.GetTypeOrDyn(root_id), test_case.result_type_matcher) << "for expression: " << test_case.expr; } diff --git a/checker/standard_library_test.cc b/checker/standard_library_test.cc index 77694e37c..cc051cfe3 100644 --- a/checker/standard_library_test.cc +++ b/checker/standard_library_test.cc @@ -135,7 +135,8 @@ TEST(StandardLibraryTest, ComprehensionResultTypeIsSubstituted) { const ast_internal::AstImpl& checked_impl = ast_internal::AstImpl::CastFromPublicAst(*checked_ast); - ast_internal::Type type = checked_impl.GetType(checked_impl.root_expr().id()); + ast_internal::Type type = + checked_impl.GetTypeOrDyn(checked_impl.root_expr().id()); EXPECT_TRUE(type.has_primitive() && type.primitive() == ast_internal::PrimitiveType::kInt64); } @@ -160,8 +161,8 @@ class StdlibTypeVarDefinitionTest TEST_P(StdlibTypeVarDefinitionTest, DefinesTypeConstants) { auto ast = std::make_unique(); - ast->root_expr().mutable_ident_expr().set_name(GetParam()); - ast->root_expr().set_id(1); + ast->mutable_root_expr().mutable_ident_expr().set_name(GetParam()); + ast->mutable_root_expr().set_id(1); ASSERT_OK_AND_ASSIGN(ValidationResult result, stdlib_type_checker_->Check(std::move(ast))); @@ -171,7 +172,7 @@ TEST_P(StdlibTypeVarDefinitionTest, DefinesTypeConstants) { const auto& checked_impl = AstImpl::CastFromPublicAst(*checked_ast); EXPECT_THAT(checked_impl.GetReference(1), Pointee(Property(&Reference::name, GetParam()))); - EXPECT_THAT(checked_impl.GetType(1), Property(&AstType::has_type, true)); + EXPECT_THAT(checked_impl.GetTypeOrDyn(1), Property(&AstType::has_type, true)); } INSTANTIATE_TEST_SUITE_P(StdlibTypeVarDefinitions, StdlibTypeVarDefinitionTest, @@ -185,7 +186,7 @@ INSTANTIATE_TEST_SUITE_P(StdlibTypeVarDefinitions, StdlibTypeVarDefinitionTest, TEST_F(StandardLibraryDefinitionsTest, DefinesProtoStructNull) { auto ast = std::make_unique(); - auto& enumerator = ast->root_expr(); + auto& enumerator = ast->mutable_root_expr(); enumerator.set_id(4); enumerator.mutable_select_expr().set_field("NULL_VALUE"); auto& enumeration = enumerator.mutable_select_expr().mutable_operand(); @@ -212,7 +213,7 @@ TEST_F(StandardLibraryDefinitionsTest, DefinesProtoStructNull) { TEST_F(StandardLibraryDefinitionsTest, DefinesTypeType) { auto ast = std::make_unique(); - auto& ident = ast->root_expr(); + auto& ident = ast->mutable_root_expr(); ident.set_id(1); ident.mutable_ident_expr().set_name("type"); @@ -224,7 +225,7 @@ TEST_F(StandardLibraryDefinitionsTest, DefinesTypeType) { const auto& checked_impl = AstImpl::CastFromPublicAst(*checked_ast); EXPECT_THAT(checked_impl.GetReference(1), Pointee(Property(&Reference::name, "type"))); - EXPECT_THAT(checked_impl.GetType(1), Property(&AstType::has_type, true)); + EXPECT_THAT(checked_impl.GetTypeOrDyn(1), Property(&AstType::has_type, true)); } struct DefinitionsTestCase { diff --git a/common/ast/BUILD b/common/ast/BUILD index 07d4185ba..19d00dde6 100644 --- a/common/ast/BUILD +++ b/common/ast/BUILD @@ -79,8 +79,11 @@ cc_library( hdrs = ["ast_impl.h"], deps = [ ":expr", + ":metadata", "//common:ast", + "//common:expr", "//internal:casts", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:string_view", ], diff --git a/common/ast/ast_impl.cc b/common/ast/ast_impl.cc index dad62e257..9177f53d3 100644 --- a/common/ast/ast_impl.cc +++ b/common/ast/ast_impl.cc @@ -16,27 +16,39 @@ #include +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" +#include "common/ast/expr.h" +#include "common/ast/metadata.h" namespace cel::ast_internal { namespace { const Type& DynSingleton() { - static auto* singleton = new Type(TypeKind(DynamicType())); + static auto* singleton = new TypeSpec(TypeKind(DynamicType())); return *singleton; } } // namespace -const Type& AstImpl::GetType(int64_t expr_id) const { +const TypeSpec* absl_nullable AstImpl::GetType(int64_t expr_id) const { auto iter = type_map_.find(expr_id); if (iter == type_map_.end()) { - return DynSingleton(); + return nullptr; + } + return &iter->second; +} + +const TypeSpec& AstImpl::GetTypeOrDyn(int64_t expr_id) const { + if (const TypeSpec* type = GetType(expr_id); type != nullptr) { + return *type; } - return iter->second; + return DynSingleton(); } -const Type& AstImpl::GetReturnType() const { return GetType(root_expr().id()); } +const TypeSpec& AstImpl::GetReturnType() const { + return GetTypeOrDyn(root_expr().id()); +} const Reference* AstImpl::GetReference(int64_t expr_id) const { auto iter = reference_map_.find(expr_id); diff --git a/common/ast/ast_impl.h b/common/ast/ast_impl.h index 53e210acb..3c5dc7bae 100644 --- a/common/ast/ast_impl.h +++ b/common/ast/ast_impl.h @@ -23,14 +23,20 @@ #include "absl/strings/string_view.h" #include "common/ast.h" #include "common/ast/expr.h" +#include "common/ast/metadata.h" // IWYU pragma: export +#include "common/expr.h" #include "internal/casts.h" namespace cel::ast_internal { -// Runtime implementation of a CEL abstract syntax tree. -// CEL users should not use this directly. -// If AST inspection is needed, prefer to use an existing tool or traverse the -// the protobuf representation. +// In memory representation of a CEL abstract syntax tree. +// +// If AST inspection or manipulation is needed, prefer to use an existing tool +// or traverse the protobuf representation rather than directly manipulating +// through this class. See `cel::NavigableAst` and `cel::AstTraverse`. +// +// Type and reference maps are only populated if the AST is checked. Any changes +// to the AST are not automatically reflected in the type or reference maps. class AstImpl : public Ast { public: using ReferenceMap = absl::flat_hash_map; @@ -79,39 +85,33 @@ class AstImpl : public Ast { // Implement public Ast APIs. bool IsChecked() const override { return is_checked_; } - // CEL internal functions. + bool is_checked() const { return is_checked_; } void set_is_checked(bool is_checked) { is_checked_ = is_checked; } + // The root expression of the AST. + // + // This is the entry point for evaluation and determines the overall result + // of the expression given a context. const Expr& root_expr() const { return root_expr_; } - Expr& root_expr() { return root_expr_; } + Expr& mutable_root_expr() { return root_expr_; } + // Metadata about the source expression. const SourceInfo& source_info() const { return source_info_; } - SourceInfo& source_info() { return source_info_; } - - const Type& GetType(int64_t expr_id) const; - const Type& GetReturnType() const; - const Reference* GetReference(int64_t expr_id) const; - - const absl::flat_hash_map& reference_map() const { - return reference_map_; - } - - ReferenceMap& reference_map() { return reference_map_; } - - const TypeMap& type_map() const { return type_map_; } + SourceInfo& mutable_source_info() { return source_info_; } - TypeMap& type_map() { return type_map_; } + // Returns the type of the expression with the given `expr_id`. + // + // Returns `nullptr` if the expression node is not found or has dynamic type. + const TypeSpec* absl_nullable GetType(int64_t expr_id) const; + const TypeSpec& GetTypeOrDyn(int64_t expr_id) const; + const TypeSpec& GetReturnType() const; - absl::string_view expr_version() const { return expr_version_; } - void set_expr_version(absl::string_view expr_version) { - expr_version_ = expr_version; - } + // Returns the resolved reference for the expression with the given `expr_id`. + // + // Returns `nullptr` if the expression node is not found or no reference was + // resolved. + const Reference* absl_nullable GetReference(int64_t expr_id) const; - private: - Expr root_expr_; - // The source info derived from input that generated the parsed `expr` and - // any optimizations made during the type-checking pass. - SourceInfo source_info_; // A map from expression ids to resolved references. // // The following entries are in this table: @@ -127,22 +127,38 @@ class AstImpl : public Ast { // called. // - Every CreateStruct expression for a message has an entry, identifying // the message. - ReferenceMap reference_map_; + // + // Unpopulated if the AST is not checked. + const ReferenceMap& reference_map() const { return reference_map_; } + ReferenceMap& mutable_reference_map() { return reference_map_; } + // A map from expression ids to types. // // Every expression node which has a type different than DYN has a mapping // here. If an expression has type DYN, it is omitted from this map to save // space. - TypeMap type_map_; + // + // Unpopulated if the AST is not checked. + const TypeMap& type_map() const { return type_map_; } + TypeMap& mutable_type_map() { return type_map_; } + // The expr version indicates the major / minor version number of the `expr` // representation. // // The most common reason for a version change will be to indicate to the CEL // runtimes that transformations have been performed on the expr during static - // analysis. In some cases, this will save the runtime the work of applying - // the same or similar transformations prior to evaluation. - std::string expr_version_; + // analysis. + absl::string_view expr_version() const { return expr_version_; } + void set_expr_version(absl::string_view expr_version) { + expr_version_ = expr_version; + } + private: + Expr root_expr_; + SourceInfo source_info_; + ReferenceMap reference_map_; + TypeMap type_map_; + std::string expr_version_; bool is_checked_; }; diff --git a/common/ast/ast_impl_test.cc b/common/ast/ast_impl_test.cc index 2f5c7a47e..e9eb96e37 100644 --- a/common/ast/ast_impl_test.cc +++ b/common/ast/ast_impl_test.cc @@ -56,7 +56,7 @@ TEST(AstImpl, RawExprCtor) { // assert ASSERT_FALSE(ast.IsChecked()); - EXPECT_EQ(ast_impl.GetType(1), Type(DynamicType())); + EXPECT_EQ(ast_impl.GetTypeOrDyn(1), Type(DynamicType())); EXPECT_EQ(ast_impl.GetReturnType(), Type(DynamicType())); EXPECT_EQ(ast_impl.GetReference(1), nullptr); EXPECT_TRUE(ast_impl.root_expr().has_call_expr()); @@ -83,7 +83,7 @@ TEST(AstImpl, CheckedExprCtor) { Ast& ast = ast_impl; ASSERT_TRUE(ast.IsChecked()); - EXPECT_EQ(ast_impl.GetType(1), Type(PrimitiveType::kInt64)); + EXPECT_EQ(ast_impl.GetTypeOrDyn(1), Type(PrimitiveType::kInt64)); EXPECT_THAT(ast_impl.GetReference(1), Pointee(Truly([&ref](const Reference& arg) { return arg.name() == ref.name(); @@ -126,7 +126,7 @@ TEST(AstImpl, CheckedExprDeepCopy) { Ast& ast = ast_impl; ASSERT_TRUE(ast.IsChecked()); - EXPECT_EQ(ast_impl.GetType(1), Type(PrimitiveType::kInt64)); + EXPECT_EQ(ast_impl.GetTypeOrDyn(1), Type(PrimitiveType::kInt64)); EXPECT_THAT(ast_impl.GetReference(1), Pointee(Truly([](const Reference& arg) { return arg.name() == "com.int_value"; }))); diff --git a/common/ast/metadata.h b/common/ast/metadata.h index 7e6d4d182..707240b9a 100644 --- a/common/ast/metadata.h +++ b/common/ast/metadata.h @@ -15,6 +15,8 @@ // Type definitions for auxiliary structures in the AST. // // These are more direct equivalents to the public protobuf definitions. +// +// IWYU pragma: private, include "common/ast/ast_impl.h" #ifndef THIRD_PARTY_CEL_CPP_COMMON_AST_METADATA_H_ #define THIRD_PARTY_CEL_CPP_COMMON_AST_METADATA_H_ diff --git a/common/ast_proto_test.cc b/common/ast_proto_test.cc index 4837a413d..cc39c1156 100644 --- a/common/ast_proto_test.cc +++ b/common/ast_proto_test.cc @@ -461,19 +461,19 @@ TEST(AstConvertersTest, CheckedExprToAst) { TEST(AstConvertersTest, AstToCheckedExprBasic) { ast_internal::AstImpl ast; - ast.root_expr().set_id(1); - ast.root_expr().mutable_ident_expr().set_name("expr"); + ast.mutable_root_expr().set_id(1); + ast.mutable_root_expr().mutable_ident_expr().set_name("expr"); - ast.source_info().set_syntax_version("version"); - ast.source_info().set_location("location"); - ast.source_info().mutable_line_offsets().push_back(1); - ast.source_info().mutable_line_offsets().push_back(2); - ast.source_info().mutable_positions().insert({1, 2}); - ast.source_info().mutable_positions().insert({3, 4}); + ast.mutable_source_info().set_syntax_version("version"); + ast.mutable_source_info().set_location("location"); + ast.mutable_source_info().mutable_line_offsets().push_back(1); + ast.mutable_source_info().mutable_line_offsets().push_back(2); + ast.mutable_source_info().mutable_positions().insert({1, 2}); + ast.mutable_source_info().mutable_positions().insert({3, 4}); Expr macro; macro.mutable_ident_expr().set_name("name"); - ast.source_info().mutable_macro_calls().insert({1, std::move(macro)}); + ast.mutable_source_info().mutable_macro_calls().insert({1, std::move(macro)}); ast_internal::AstImpl::TypeMap type_map; ast_internal::AstImpl::ReferenceMap reference_map; @@ -487,8 +487,8 @@ TEST(AstConvertersTest, AstToCheckedExprBasic) { ast_internal::Type type; type.set_type_kind(ast_internal::DynamicType()); - ast.reference_map().insert({1, std::move(reference)}); - ast.type_map().insert({1, std::move(type)}); + ast.mutable_reference_map().insert({1, std::move(reference)}); + ast.mutable_type_map().insert({1, std::move(type)}); ast.set_expr_version("version"); ast.set_is_checked(true); diff --git a/common/ast_rewrite_test.cc b/common/ast_rewrite_test.cc index 1e7ca74af..209a305d9 100644 --- a/common/ast_rewrite_test.cc +++ b/common/ast_rewrite_test.cc @@ -537,7 +537,7 @@ TEST(AstRewrite, SelectRewriteExample) { google::api::expr::parser::Parse("com.google.Identifier").value())); AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); RewriterExample example; - ASSERT_TRUE(AstRewrite(ast_impl.root_expr(), example)); + ASSERT_TRUE(AstRewrite(ast_impl.mutable_root_expr(), example)); cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString( @@ -592,7 +592,7 @@ TEST(AstRewrite, PreAndPostVisitExpample) { CreateAstFromParsedExpr(google::api::expr::parser::Parse("x").value())); PreRewriterExample visitor; AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); - ASSERT_TRUE(AstRewrite(ast_impl.root_expr(), visitor)); + ASSERT_TRUE(AstRewrite(ast_impl.mutable_root_expr(), visitor)); cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString( diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 2fc4e95e4..40b88e341 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -337,7 +337,7 @@ absl::StatusOr ResolveReferences(const Resolver& resolver, // Rewriting interface doesn't support failing mid traverse propagate first // error encountered if fail fast enabled. - bool was_rewritten = cel::AstRewrite(ast.root_expr(), ref_resolver); + bool was_rewritten = cel::AstRewrite(ast.mutable_root_expr(), ref_resolver); if (!ref_resolver.GetProgressStatus().ok()) { return ref_resolver.GetProgressStatus(); } diff --git a/eval/compiler/qualified_reference_resolver_test.cc b/eval/compiler/qualified_reference_resolver_test.cc index aa9518ae2..d470c3ccb 100644 --- a/eval/compiler/qualified_reference_resolver_test.cc +++ b/eval/compiler/qualified_reference_resolver_test.cc @@ -132,8 +132,8 @@ cel::expr::Expr ExprToProtoOrDie(const Expr& expr) { TEST(ResolveReferences, Basic) { std::unique_ptr expr_ast = ParseTestProto(kExpr); - expr_ast->reference_map()[2].set_name("foo.bar.var1"); - expr_ast->reference_map()[5].set_name("bar.foo.var2"); + expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); + expr_ast->mutable_reference_map()[5].set_name("bar.foo.var2"); IssueCollector issues(RuntimeIssue::Severity::kError); CelFunctionRegistry func_registry; cel::TypeRegistry type_registry; @@ -170,8 +170,8 @@ TEST(ResolveReferences, ReturnsFalseIfNoChanges) { ASSERT_THAT(result, IsOkAndHolds(false)); // reference to the same name also doesn't count as a rewrite. - expr_ast->reference_map()[4].set_name("foo"); - expr_ast->reference_map()[7].set_name("bar"); + expr_ast->mutable_reference_map()[4].set_name("foo"); + expr_ast->mutable_reference_map()[7].set_name("bar"); result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(false)); @@ -185,8 +185,8 @@ TEST(ResolveReferences, NamespacedIdent) { cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); - expr_ast->reference_map()[2].set_name("foo.bar.var1"); - expr_ast->reference_map()[7].set_name("namespace_x.bar"); + expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); + expr_ast->mutable_reference_map()[7].set_name("namespace_x.bar"); auto result = ResolveReferences(registry, issues, *expr_ast); ASSERT_THAT(result, IsOkAndHolds(true)); @@ -242,7 +242,7 @@ TEST(ResolveReferences, WarningOnPresenceTest) { cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); - expr_ast->reference_map()[1].set_name("foo.bar.var1"); + expr_ast->mutable_reference_map()[1].set_name("foo.bar.var1"); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -291,9 +291,9 @@ TEST(ResolveReferences, EnumConstReferenceUsed) { cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); - expr_ast->reference_map()[2].set_name("foo.bar.var1"); - expr_ast->reference_map()[5].set_name("bar.foo.Enum.ENUM_VAL1"); - expr_ast->reference_map()[5].mutable_value().set_int64_value(9); + expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); + expr_ast->mutable_reference_map()[5].set_name("bar.foo.Enum.ENUM_VAL1"); + expr_ast->mutable_reference_map()[5].mutable_value().set_int64_value(9); IssueCollector issues(RuntimeIssue::Severity::kError); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -324,10 +324,10 @@ TEST(ResolveReferences, EnumConstReferenceUsedSelect) { cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); - expr_ast->reference_map()[2].set_name("foo.bar.var1"); - expr_ast->reference_map()[2].mutable_value().set_int64_value(2); - expr_ast->reference_map()[5].set_name("bar.foo.Enum.ENUM_VAL1"); - expr_ast->reference_map()[5].mutable_value().set_int64_value(9); + expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); + expr_ast->mutable_reference_map()[2].mutable_value().set_int64_value(2); + expr_ast->mutable_reference_map()[5].set_name("bar.foo.Enum.ENUM_VAL1"); + expr_ast->mutable_reference_map()[5].mutable_value().set_int64_value(9); IssueCollector issues(RuntimeIssue::Severity::kError); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -357,9 +357,9 @@ TEST(ResolveReferences, ConstReferenceSkipped) { cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); - expr_ast->reference_map()[2].set_name("foo.bar.var1"); - expr_ast->reference_map()[2].mutable_value().set_bool_value(true); - expr_ast->reference_map()[5].set_name("bar.foo.var2"); + expr_ast->mutable_reference_map()[2].set_name("foo.bar.var1"); + expr_ast->mutable_reference_map()[2].mutable_value().set_bool_value(true); + expr_ast->mutable_reference_map()[5].set_name("bar.foo.var2"); IssueCollector issues(RuntimeIssue::Severity::kError); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -426,7 +426,7 @@ TEST(ResolveReferences, FunctionReferenceBasic) { Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kError); - expr_ast->reference_map()[1].mutable_overload_id().push_back( + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -443,7 +443,7 @@ TEST(ResolveReferences, FunctionReferenceMissingOverloadDetected) { Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kError); - expr_ast->reference_map()[1].mutable_overload_id().push_back( + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -479,9 +479,9 @@ TEST(ResolveReferences, SpecialBuiltinsNotWarned) { Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kError); - expr_ast->reference_map()[1].mutable_overload_id().push_back( + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( absl::StrCat("builtin.", builtin_fn)); - expr_ast->root_expr().mutable_call_expr().set_function(builtin_fn); + expr_ast->mutable_root_expr().mutable_call_expr().set_function(builtin_fn); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -500,7 +500,7 @@ TEST(ResolveReferences, Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kError); - expr_ast->reference_map()[1].set_name("udf_boolean_and"); + expr_ast->mutable_reference_map()[1].set_name("udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -523,7 +523,7 @@ TEST(ResolveReferences, EmulatesEagerFailing) { Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); IssueCollector issues(RuntimeIssue::Severity::kWarning); - expr_ast->reference_map()[1].set_name("udf_boolean_and"); + expr_ast->mutable_reference_map()[1].set_name("udf_boolean_and"); EXPECT_THAT( ResolveReferences(registry, issues, *expr_ast), @@ -540,7 +540,7 @@ TEST(ResolveReferences, FunctionReferenceToWrongExprKind) { cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); - expr_ast->reference_map()[2].mutable_overload_id().push_back( + expr_ast->mutable_reference_map()[2].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -580,7 +580,7 @@ TEST(ResolveReferences, FunctionReferenceWithTargetNoChange) { cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); - expr_ast->reference_map()[1].mutable_overload_id().push_back( + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -600,7 +600,7 @@ TEST(ResolveReferences, cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); - expr_ast->reference_map()[1].mutable_overload_id().push_back( + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -622,7 +622,7 @@ TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunction) { cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); - expr_ast->reference_map()[1].mutable_overload_id().push_back( + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -647,7 +647,7 @@ TEST(ResolveReferences, ParseTestProto(kReceiverCallExtensionAndExpr); SourceInfo source_info; - expr_ast->reference_map()[1].mutable_overload_id().push_back( + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); IssueCollector issues(RuntimeIssue::Severity::kError); CelFunctionRegistry func_registry; @@ -714,7 +714,7 @@ TEST(ResolveReferences, FunctionReferenceWithHasTargetNoChange) { cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); - expr_ast->reference_map()[1].mutable_overload_id().push_back( + expr_ast->mutable_reference_map()[1].mutable_overload_id().push_back( "udf_boolean_and"); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -803,10 +803,10 @@ TEST(ResolveReferences, EnumConstReferenceUsedInComprehension) { cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); - expr_ast->reference_map()[3].set_name("ENUM"); - expr_ast->reference_map()[3].mutable_value().set_int64_value(2); - expr_ast->reference_map()[7].set_name("ENUM"); - expr_ast->reference_map()[7].mutable_value().set_int64_value(2); + expr_ast->mutable_reference_map()[3].set_name("ENUM"); + expr_ast->mutable_reference_map()[3].mutable_value().set_int64_value(2); + expr_ast->mutable_reference_map()[7].set_name("ENUM"); + expr_ast->mutable_reference_map()[7].mutable_value().set_int64_value(2); IssueCollector issues(RuntimeIssue::Severity::kError); auto result = ResolveReferences(registry, issues, *expr_ast); @@ -905,7 +905,7 @@ TEST(ResolveReferences, ReferenceToId0Warns) { cel::TypeRegistry type_registry; Resolver registry("", func_registry.InternalGetRegistry(), type_registry, type_registry.GetComposedTypeProvider()); - expr_ast->reference_map()[0].set_name("pkg.var"); + expr_ast->mutable_reference_map()[0].set_name("pkg.var"); IssueCollector issues(RuntimeIssue::Severity::kError); auto result = ResolveReferences(registry, issues, *expr_ast); diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 21dffe56a..bebc25cf8 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -394,7 +394,7 @@ class RewriterImpl : public AstRewriterBase { const std::string& field_name = select.field(); // Select optimization can generalize to lists and maps, but for now only // support message traversal. - const ast_internal::Type& checker_type = ast_.GetType(operand.id()); + const TypeSpec checker_type = ast_.GetTypeOrDyn(operand.id()); absl::optional rt_type = (checker_type.has_message_type()) @@ -907,7 +907,7 @@ google::api::expr::runtime::FlatExprBuilder* GetFlatExprBuilder( absl::Status SelectOptimizationAstUpdater::UpdateAst(PlannerContext& context, AstImpl& ast) const { RewriterImpl rewriter(ast, context); - AstRewrite(ast.root_expr(), rewriter); + AstRewrite(ast.mutable_root_expr(), rewriter); return rewriter.GetProgressStatus(); } diff --git a/testutil/baseline_tests_test.cc b/testutil/baseline_tests_test.cc index cf7027982..3c7f00cb5 100644 --- a/testutil/baseline_tests_test.cc +++ b/testutil/baseline_tests_test.cc @@ -32,54 +32,55 @@ using AstType = ast_internal::Type; TEST(FormatBaselineAst, Basic) { AstImpl impl; - impl.root_expr().mutable_ident_expr().set_name("foo"); - impl.root_expr().set_id(1); - impl.type_map()[1] = AstType(ast_internal::PrimitiveType::kInt64); - impl.reference_map()[1].set_name("foo"); + impl.mutable_root_expr().mutable_ident_expr().set_name("foo"); + impl.mutable_root_expr().set_id(1); + impl.mutable_type_map()[1] = AstType(ast_internal::PrimitiveType::kInt64); + impl.mutable_reference_map()[1].set_name("foo"); EXPECT_EQ(FormatBaselineAst(impl), "foo~int^foo"); } TEST(FormatBaselineAst, NoType) { AstImpl impl; - impl.root_expr().mutable_ident_expr().set_name("foo"); - impl.root_expr().set_id(1); - impl.reference_map()[1].set_name("foo"); + impl.mutable_root_expr().mutable_ident_expr().set_name("foo"); + impl.mutable_root_expr().set_id(1); + impl.mutable_reference_map()[1].set_name("foo"); EXPECT_EQ(FormatBaselineAst(impl), "foo^foo"); } TEST(FormatBaselineAst, NoReference) { AstImpl impl; - impl.root_expr().mutable_ident_expr().set_name("foo"); - impl.root_expr().set_id(1); - impl.type_map()[1] = AstType(ast_internal::PrimitiveType::kInt64); + impl.mutable_root_expr().mutable_ident_expr().set_name("foo"); + impl.mutable_root_expr().set_id(1); + impl.mutable_type_map()[1] = AstType(ast_internal::PrimitiveType::kInt64); EXPECT_EQ(FormatBaselineAst(impl), "foo~int"); } TEST(FormatBaselineAst, MutlipleReferences) { AstImpl impl; - impl.root_expr().mutable_call_expr().set_function("_+_"); - impl.root_expr().set_id(1); - impl.type_map()[1] = AstType(ast_internal::DynamicType()); - impl.reference_map()[1].mutable_overload_id().push_back( + impl.mutable_root_expr().mutable_call_expr().set_function("_+_"); + impl.mutable_root_expr().set_id(1); + impl.mutable_type_map()[1] = AstType(ast_internal::DynamicType()); + impl.mutable_reference_map()[1].mutable_overload_id().push_back( "add_timestamp_duration"); - impl.reference_map()[1].mutable_overload_id().push_back( + impl.mutable_reference_map()[1].mutable_overload_id().push_back( "add_duration_duration"); { - auto& arg1 = impl.root_expr().mutable_call_expr().add_args(); + auto& arg1 = impl.mutable_root_expr().mutable_call_expr().add_args(); arg1.mutable_ident_expr().set_name("a"); arg1.set_id(2); - impl.type_map()[2] = AstType(ast_internal::DynamicType()); - impl.reference_map()[2].set_name("a"); + impl.mutable_type_map()[2] = AstType(ast_internal::DynamicType()); + impl.mutable_reference_map()[2].set_name("a"); } { - auto& arg2 = impl.root_expr().mutable_call_expr().add_args(); + auto& arg2 = impl.mutable_root_expr().mutable_call_expr().add_args(); arg2.mutable_ident_expr().set_name("b"); arg2.set_id(3); - impl.type_map()[3] = AstType(ast_internal::WellKnownType::kDuration); - impl.reference_map()[3].set_name("b"); + impl.mutable_type_map()[3] = + AstType(ast_internal::WellKnownType::kDuration); + impl.mutable_reference_map()[3].set_name("b"); } EXPECT_EQ(FormatBaselineAst(impl), @@ -153,9 +154,9 @@ class FormatBaselineAstTypeTest : public testing::TestWithParam {}; TEST_P(FormatBaselineAstTypeTest, Runner) { AstImpl impl; - impl.root_expr().set_id(1); - impl.root_expr().mutable_ident_expr().set_name("x"); - impl.type_map()[1] = GetParam().type; + impl.mutable_root_expr().set_id(1); + impl.mutable_root_expr().mutable_ident_expr().set_name("x"); + impl.mutable_type_map()[1] = GetParam().type; EXPECT_EQ(FormatBaselineAst(impl), GetParam().expected_string); }