diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 05058618d..9bcf73582 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -176,7 +176,6 @@ cc_test( "//common:expr", "//common:source", "//common:type", - "//common/ast:ast_impl", "//common/ast:expr", "//internal:status_macros", "//internal:testing", diff --git a/checker/internal/test_ast_helpers_test.cc b/checker/internal/test_ast_helpers_test.cc index ddaff082d..51fb8461a 100644 --- a/checker/internal/test_ast_helpers_test.cc +++ b/checker/internal/test_ast_helpers_test.cc @@ -19,19 +19,16 @@ #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "common/ast.h" -#include "common/ast/ast_impl.h" #include "internal/testing.h" namespace cel::checker_internal { namespace { using ::absl_testing::StatusIs; -using ::cel::ast_internal::AstImpl; TEST(MakeTestParsedAstTest, Works) { ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, MakeTestParsedAst("123")); - AstImpl& impl = AstImpl::CastFromPublicAst(*ast); - EXPECT_TRUE(impl.root_expr().has_const_expr()); + EXPECT_TRUE(ast->root_expr().has_const_expr()); } TEST(MakeTestParsedAstTest, ForwardsParseError) { diff --git a/checker/internal/type_checker_builder_impl_test.cc b/checker/internal/type_checker_builder_impl_test.cc index 5e4054128..9f885f79d 100644 --- a/checker/internal/type_checker_builder_impl_test.cc +++ b/checker/internal/type_checker_builder_impl_test.cc @@ -41,7 +41,6 @@ namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; -using ::cel::ast_internal::AstImpl; using AstType = cel::ast_internal::Type; @@ -67,9 +66,7 @@ TEST_P(ContextDeclsFieldsDefinedTest, ContextDeclsFieldsDefined) { ASSERT_TRUE(result.IsValid()); - const auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst()); - - EXPECT_EQ(ast_impl.GetReturnType(), GetParam().expected_type); + EXPECT_EQ(result.GetAst()->GetReturnType(), GetParam().expected_type); } INSTANTIATE_TEST_SUITE_P( @@ -324,9 +321,9 @@ TEST(TypeCheckerBuilderImplTest, ReplaceVariable) { ASSERT_TRUE(result.IsValid()); - const auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst()); + const auto& checked_ast = *result.GetAst(); - EXPECT_EQ(ast_impl.GetReturnType(), + EXPECT_EQ(checked_ast.GetReturnType(), AstType(ast_internal::PrimitiveType::kString)); } diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 9b8478556..b051032cd 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -58,8 +58,6 @@ namespace cel::checker_internal { namespace { -using cel::ast_internal::AstImpl; - using AstType = cel::ast_internal::Type; using Severity = TypeCheckIssue::Severity; @@ -69,7 +67,7 @@ std::string FormatCandidate(absl::Span qualifiers) { return absl::StrJoin(qualifiers, "."); } -SourceLocation ComputeSourceLocation(const AstImpl& ast, int64_t expr_id) { +SourceLocation ComputeSourceLocation(const Ast& ast, int64_t expr_id) { const auto& source_info = ast.source_info(); auto iter = source_info.positions().find(expr_id); if (iter == source_info.positions().end()) { @@ -248,7 +246,7 @@ class ResolveVisitor : public AstVisitorBase { ResolveVisitor(absl::string_view container, NamespaceGenerator namespace_generator, - const TypeCheckEnv& env, const AstImpl& ast, + const TypeCheckEnv& env, const Ast& ast, TypeInferenceContext& inference_context, std::vector& issues, google::protobuf::Arena* absl_nonnull arena) @@ -468,7 +466,7 @@ class ResolveVisitor : public AstVisitorBase { const TypeCheckEnv* absl_nonnull env_; TypeInferenceContext* absl_nonnull inference_context_; std::vector* absl_nonnull issues_; - const ast_internal::AstImpl* absl_nonnull ast_; + const Ast* absl_nonnull ast_; VariableScope root_scope_; google::protobuf::Arena* absl_nonnull arena_; @@ -1198,8 +1196,7 @@ class ResolveRewriter : public AstRewriterBase { explicit ResolveRewriter(const ResolveVisitor& visitor, const TypeInferenceContext& inference_context, const CheckerOptions& options, - AstImpl::ReferenceMap& references, - AstImpl::TypeMap& types) + Ast::ReferenceMap& references, Ast::TypeMap& types) : visitor_(visitor), inference_context_(inference_context), reference_map_(references), @@ -1264,8 +1261,8 @@ class ResolveRewriter : public AstRewriterBase { absl::Status status_; const ResolveVisitor& visitor_; const TypeInferenceContext& inference_context_; - AstImpl::ReferenceMap& reference_map_; - AstImpl::TypeMap& type_map_; + Ast::ReferenceMap& reference_map_; + Ast::TypeMap& type_map_; const CheckerOptions& options_; }; @@ -1273,7 +1270,6 @@ class ResolveRewriter : public AstRewriterBase { absl::StatusOr TypeCheckerImpl::Check( std::unique_ptr ast) const { - auto& ast_impl = AstImpl::CastFromPublicAst(*ast); google::protobuf::Arena type_arena; std::vector issues; @@ -1282,13 +1278,13 @@ absl::StatusOr TypeCheckerImpl::Check( TypeInferenceContext type_inference_context( &type_arena, options_.enable_legacy_null_assignment); - ResolveVisitor visitor(env_.container(), std::move(generator), env_, ast_impl, + ResolveVisitor visitor(env_.container(), std::move(generator), env_, *ast, type_inference_context, issues, &type_arena); TraversalOptions opts; opts.use_comprehension_callbacks = true; bool error_limit_reached = false; - auto traversal = AstTraversal::Create(ast_impl.root_expr(), opts); + auto traversal = AstTraversal::Create(ast->root_expr(), opts); for (int step = 0; step < options_.max_expression_node_count * 2; ++step) { bool has_next = traversal.Step(visitor); @@ -1315,7 +1311,7 @@ absl::StatusOr TypeCheckerImpl::Check( {}, absl::StrCat("maximum number of ERROR issues exceeded: ", options_.max_error_issues))); } else if (env_.expected_type().has_value()) { - visitor.AssertExpectedType(ast_impl.root_expr(), *env_.expected_type()); + visitor.AssertExpectedType(ast->root_expr(), *env_.expected_type()); } // If any issues are errors, return without an AST. @@ -1329,13 +1325,13 @@ absl::StatusOr TypeCheckerImpl::Check( // Happens in a second pass to simplify validating that pointers haven't // been invalidated by other updates. ResolveRewriter rewriter(visitor, type_inference_context, options_, - ast_impl.mutable_reference_map(), - ast_impl.mutable_type_map()); - AstRewrite(ast_impl.mutable_root_expr(), rewriter); + ast->mutable_reference_map(), + ast->mutable_type_map()); + AstRewrite(ast->mutable_root_expr(), rewriter); CEL_RETURN_IF_ERROR(rewriter.status()); - ast_impl.set_is_checked(true); + ast->set_is_checked(true); return ValidationResult(std::move(ast), std::move(issues)); } diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index 09a48fe37..9f2ccc69c 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -35,7 +35,6 @@ #include "checker/type_check_issue.h" #include "checker/validation_result.h" #include "common/ast.h" -#include "common/ast/ast_impl.h" #include "common/ast/expr.h" #include "common/decl.h" #include "common/expr.h" @@ -58,7 +57,6 @@ namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; -using ::cel::ast_internal::AstImpl; using ::cel::ast_internal::Reference; using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::internal::GetSharedTestingDescriptorPool; @@ -460,8 +458,7 @@ TEST(TypeCheckerImplTest, ResolveMostQualfiedIdent) { ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.reference_map(), + EXPECT_THAT(checked_ast->reference_map(), Contains(Pair(_, IsVariableReference("x.y")))); } @@ -547,11 +544,10 @@ TEST(TypeCheckerImplTest, NamespaceFunctionCallResolved) { EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_TRUE(ast_impl.root_expr().has_call_expr()) - << absl::StrCat("kind: ", ast_impl.root_expr().kind().index()); - EXPECT_EQ(ast_impl.root_expr().call_expr().function(), "x.foo"); - EXPECT_FALSE(ast_impl.root_expr().call_expr().has_target()); + EXPECT_TRUE(checked_ast->root_expr().has_call_expr()) + << absl::StrCat("kind: ", checked_ast->root_expr().kind().index()); + EXPECT_EQ(checked_ast->root_expr().call_expr().function(), "x.foo"); + EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); } TEST(TypeCheckerImplTest, NamespacedFunctionSkipsFieldCheck) { @@ -576,11 +572,10 @@ TEST(TypeCheckerImplTest, NamespacedFunctionSkipsFieldCheck) { EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_TRUE(ast_impl.root_expr().has_call_expr()) - << absl::StrCat("kind: ", ast_impl.root_expr().kind().index()); - EXPECT_EQ(ast_impl.root_expr().call_expr().function(), "x.y.foo"); - EXPECT_FALSE(ast_impl.root_expr().call_expr().has_target()); + EXPECT_TRUE(checked_ast->root_expr().has_call_expr()) + << absl::StrCat("kind: ", checked_ast->root_expr().kind().index()); + EXPECT_EQ(checked_ast->root_expr().call_expr().function(), "x.y.foo"); + EXPECT_FALSE(checked_ast->root_expr().call_expr().has_target()); } TEST(TypeCheckerImplTest, MixedListTypeToDyn) { @@ -596,8 +591,8 @@ TEST(TypeCheckerImplTest, MixedListTypeToDyn) { ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); - auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst()); - EXPECT_TRUE(ast_impl.type_map().at(1).list_type().elem_type().has_dyn()); + EXPECT_TRUE( + result.GetAst()->type_map().at(1).list_type().elem_type().has_dyn()); } TEST(TypeCheckerImplTest, FreeListTypeToDyn) { @@ -613,8 +608,8 @@ TEST(TypeCheckerImplTest, FreeListTypeToDyn) { ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); - auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst()); - EXPECT_TRUE(ast_impl.type_map().at(1).list_type().elem_type().has_dyn()); + EXPECT_TRUE( + result.GetAst()->type_map().at(1).list_type().elem_type().has_dyn()); } TEST(TypeCheckerImplTest, FreeMapValueTypeToDyn) { @@ -630,9 +625,8 @@ TEST(TypeCheckerImplTest, FreeMapValueTypeToDyn) { ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); - auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst()); - auto root_id = ast_impl.root_expr().id(); - EXPECT_TRUE(ast_impl.type_map().at(root_id).has_dyn()); + auto root_id = result.GetAst()->root_expr().id(); + EXPECT_TRUE(result.GetAst()->type_map().at(root_id).has_dyn()); } TEST(TypeCheckerImplTest, FreeMapTypeToDyn) { @@ -648,9 +642,9 @@ TEST(TypeCheckerImplTest, FreeMapTypeToDyn) { ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); - auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst()); - EXPECT_TRUE(ast_impl.type_map().at(1).map_type().key_type().has_dyn()); - EXPECT_TRUE(ast_impl.type_map().at(1).map_type().value_type().has_dyn()); + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_TRUE(checked_ast->type_map().at(1).map_type().key_type().has_dyn()); + EXPECT_TRUE(checked_ast->type_map().at(1).map_type().value_type().has_dyn()); } TEST(TypeCheckerImplTest, MapTypeWithMixedKeys) { @@ -666,9 +660,9 @@ TEST(TypeCheckerImplTest, MapTypeWithMixedKeys) { ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); - auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst()); - EXPECT_TRUE(ast_impl.type_map().at(1).map_type().key_type().has_dyn()); - EXPECT_EQ(ast_impl.type_map().at(1).map_type().value_type().primitive(), + const auto* checked_ast = result.GetAst(); + EXPECT_TRUE(checked_ast->type_map().at(1).map_type().key_type().has_dyn()); + EXPECT_EQ(checked_ast->type_map().at(1).map_type().value_type().primitive(), ast_internal::PrimitiveType::kInt64); } @@ -702,10 +696,10 @@ TEST(TypeCheckerImplTest, MapTypeWithMixedValues) { ASSERT_TRUE(result.IsValid()); EXPECT_THAT(result.GetIssues(), IsEmpty()); - auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst()); - EXPECT_EQ(ast_impl.type_map().at(1).map_type().key_type().primitive(), + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + EXPECT_EQ(checked_ast->type_map().at(1).map_type().key_type().primitive(), ast_internal::PrimitiveType::kString); - EXPECT_TRUE(ast_impl.type_map().at(1).map_type().value_type().has_dyn()); + EXPECT_TRUE(checked_ast->type_map().at(1).map_type().value_type().has_dyn()); } TEST(TypeCheckerImplTest, ComprehensionVariablesResolved) { @@ -775,8 +769,7 @@ TEST(TypeCheckerImplTest, ComprehensionVarsFollowNamespacePriorityRules) { EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.reference_map(), + EXPECT_THAT(checked_ast->reference_map(), Contains(Pair(_, IsVariableReference("com.x")))); } @@ -797,8 +790,7 @@ TEST(TypeCheckerImplTest, ComprehensionVarsFollowQualifiedIdentPriority) { EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.reference_map(), + EXPECT_THAT(checked_ast->reference_map(), Contains(Pair(_, IsVariableReference("x.y")))); } @@ -866,8 +858,7 @@ TEST_P(PrimitiveLiteralsTest, LiteralsTypeInferred) { ASSERT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_EQ(ast_impl.mutable_type_map()[1].primitive(), + EXPECT_EQ(checked_ast->mutable_type_map()[1].primitive(), test_case.expected_type); } @@ -917,8 +908,7 @@ TEST_P(AstTypeConversionTest, TypeConversion) { ASSERT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_EQ(ast_impl.mutable_type_map()[1], test_case.expected_type) + EXPECT_EQ(checked_ast->mutable_type_map()[1], test_case.expected_type) << GetParam().decl_type.DebugString(); } @@ -1041,8 +1031,7 @@ TEST(TypeCheckerImplTest, NullLiteral) { ASSERT_TRUE(result.IsValid()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_TRUE(ast_impl.mutable_type_map()[1].has_null()); + EXPECT_TRUE(checked_ast->mutable_type_map()[1].has_null()); } TEST(TypeCheckerImplTest, ExpressionLimitInclusive) { @@ -1114,8 +1103,7 @@ TEST(TypeCheckerImplTest, BasicOvlResolution) { // Assumes parser numbering: + should always be id 2. ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.mutable_reference_map()[2], + EXPECT_THAT(checked_ast->mutable_reference_map()[2], IsFunctionReference( "_+_", std::vector{"add_double_double"})); } @@ -1138,8 +1126,7 @@ TEST(TypeCheckerImplTest, OvlResolutionMultipleOverloads) { // Assumes parser numbering: + should always be id 3. ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.mutable_reference_map()[3], + EXPECT_THAT(checked_ast->mutable_reference_map()[3], IsFunctionReference("_+_", std::vector{ "add_double_double", "add_int_int", "add_list", "add_uint_uint"})); @@ -1164,15 +1151,14 @@ TEST(TypeCheckerImplTest, BasicFunctionResultTypeResolution) { // Assumes parser numbering: + should always be id 2 and 4. ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.mutable_reference_map()[2], + EXPECT_THAT(checked_ast->mutable_reference_map()[2], IsFunctionReference( "_+_", std::vector{"add_double_double"})); - EXPECT_THAT(ast_impl.mutable_reference_map()[4], + EXPECT_THAT(checked_ast->mutable_reference_map()[4], IsFunctionReference( "_+_", std::vector{"add_double_double"})); - int64_t root_id = ast_impl.root_expr().id(); - EXPECT_EQ(ast_impl.mutable_type_map()[root_id].primitive(), + int64_t root_id = checked_ast->root_expr().id(); + EXPECT_EQ(checked_ast->mutable_type_map()[root_id].primitive(), ast_internal::PrimitiveType::kDouble); } @@ -1253,13 +1239,12 @@ TEST(TypeCheckerImplTest, WellKnownTypeCreation) { ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); - const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.type_map(), - Contains(Pair(ast_impl.root_expr().id(), + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), Eq(AstType(ast_internal::PrimitiveTypeWrapper( ast_internal::PrimitiveType::kInt64)))))); - EXPECT_THAT(ast_impl.reference_map(), - Contains(Pair(ast_impl.root_expr().id(), + EXPECT_THAT(checked_ast->reference_map(), + Contains(Pair(checked_ast->root_expr().id(), Property(&ast_internal::Reference::name, "google.protobuf.Int32Value")))); } @@ -1274,13 +1259,11 @@ TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) { ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); - - const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); int64_t map_expr_id = - ast_impl.root_expr().struct_expr().fields().at(0).value().id(); + checked_ast->root_expr().struct_expr().fields().at(0).value().id(); ASSERT_NE(map_expr_id, 0); EXPECT_THAT( - ast_impl.type_map(), + checked_ast->type_map(), Contains(Pair( map_expr_id, Eq(AstType(ast_internal::MapType( @@ -1300,12 +1283,10 @@ TEST(TypeCheckerImplTest, ExpectedTypeMatches) { ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); - const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT( - ast_impl.type_map(), + checked_ast->type_map(), Contains(Pair( - ast_impl.root_expr().id(), + checked_ast->root_expr().id(), Eq(AstType(ast_internal::MapType( std::make_unique(ast_internal::PrimitiveType::kString), std::make_unique( @@ -1335,8 +1316,7 @@ TEST(TypeCheckerImplTest, BadSourcePosition) { TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo")); - auto& ast_impl = AstImpl::CastFromPublicAst(*ast); - ast_impl.mutable_source_info().mutable_positions()[1] = -42; + ast->mutable_source_info().mutable_positions()[1] = -42; ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); ASSERT_OK_AND_ASSIGN(auto source, NewSource("foo")); @@ -1361,12 +1341,11 @@ TEST(TypeCheckerImplTest, FailsIfNoTypeDeduced) { TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("a || b")); - auto& ast_impl = AstImpl::CastFromPublicAst(*ast); // Assume that an unspecified expr kind is not deducible. Expr unspecified_expr; unspecified_expr.set_id(3); - ast_impl.mutable_root_expr().mutable_call_expr().mutable_args()[1] = + ast->mutable_root_expr().mutable_call_expr().mutable_args()[1] = std::move(unspecified_expr); ASSERT_THAT(impl.Check(std::move(ast)), @@ -1382,8 +1361,7 @@ TEST(TypeCheckerImplTest, BadLineOffsets) { { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("\nfoo")); - auto& ast_impl = AstImpl::CastFromPublicAst(*ast); - ast_impl.mutable_source_info().mutable_line_offsets()[1] = 1; + ast->mutable_source_info().mutable_line_offsets()[1] = 1; ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); EXPECT_FALSE(result.IsValid()); @@ -1395,10 +1373,9 @@ TEST(TypeCheckerImplTest, BadLineOffsets) { } { ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("\nfoo")); - auto& ast_impl = AstImpl::CastFromPublicAst(*ast); - ast_impl.mutable_source_info().mutable_line_offsets().clear(); - ast_impl.mutable_source_info().mutable_line_offsets().push_back(-1); - ast_impl.mutable_source_info().mutable_line_offsets().push_back(2); + ast->mutable_source_info().mutable_line_offsets().clear(); + ast->mutable_source_info().mutable_line_offsets().push_back(-1); + ast->mutable_source_info().mutable_line_offsets().push_back(2); ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); @@ -1422,13 +1399,12 @@ TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); - const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.type_map(), - Contains(Pair(ast_impl.root_expr().id(), + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), Eq(AstType(ast_internal::PrimitiveTypeWrapper( ast_internal::PrimitiveType::kInt64)))))); - EXPECT_THAT(ast_impl.reference_map(), - Contains(Pair(ast_impl.root_expr().id(), + EXPECT_THAT(checked_ast->reference_map(), + Contains(Pair(checked_ast->root_expr().id(), Property(&ast_internal::Reference::name, "google.protobuf.Int32Value")))); } @@ -1446,16 +1422,15 @@ TEST(TypeCheckerImplTest, ContainerLookupForMessageCreationNoRewrite) { ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); - const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.type_map(), - Contains(Pair(ast_impl.root_expr().id(), + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), Eq(AstType(ast_internal::PrimitiveTypeWrapper( ast_internal::PrimitiveType::kInt64)))))); - EXPECT_THAT(ast_impl.reference_map(), - Contains(Pair(ast_impl.root_expr().id(), + EXPECT_THAT(checked_ast->reference_map(), + Contains(Pair(checked_ast->root_expr().id(), Property(&ast_internal::Reference::name, "google.protobuf.Int32Value")))); - EXPECT_THAT(ast_impl.root_expr().struct_expr(), + EXPECT_THAT(checked_ast->root_expr().struct_expr(), Property(&StructExpr::name, "Int32Value")); } @@ -1470,9 +1445,9 @@ TEST(TypeCheckerImplTest, EnumValueCopiedToReferenceMap) { ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); - const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - auto ref_iter = ast_impl.reference_map().find(ast_impl.root_expr().id()); - ASSERT_NE(ref_iter, ast_impl.reference_map().end()); + auto ref_iter = + checked_ast->reference_map().find(checked_ast->root_expr().id()); + ASSERT_NE(ref_iter, checked_ast->reference_map().end()); EXPECT_EQ(ref_iter->second.name(), "cel.expr.conformance.proto3.TestAllTypes.NestedEnum.BAZ"); EXPECT_EQ(ref_iter->second.value().int_value(), 2); @@ -1514,9 +1489,8 @@ TEST_P(WktCreationTest, MessageCreation) { ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); - const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.type_map(), - Contains(Pair(ast_impl.root_expr().id(), + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), Eq(test_case.expected_result_type)))); } @@ -1679,9 +1653,8 @@ TEST_P(GenericMessagesTest, TypeChecksProto3) { ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); - const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.type_map(), - Contains(Pair(ast_impl.root_expr().id(), + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), Eq(test_case.expected_result_type)))) << cel::test::FormatBaselineAst(*checked_ast); } @@ -2269,9 +2242,8 @@ TEST_P(StrictNullAssignmentTest, TypeChecksProto3) { ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); - const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(ast_impl.type_map(), - Contains(Pair(ast_impl.root_expr().id(), + EXPECT_THAT(checked_ast->type_map(), + Contains(Pair(checked_ast->root_expr().id(), Eq(test_case.expected_result_type)))); } diff --git a/checker/optional_test.cc b/checker/optional_test.cc index d49747f67..877abf08d 100644 --- a/checker/optional_test.cc +++ b/checker/optional_test.cc @@ -91,14 +91,13 @@ TEST(OptionalTest, OptSelectDoesNotAnnotateFieldType) { EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - const auto& ast_impl = ast_internal::AstImpl::CastFromPublicAst(*checked_ast); - ASSERT_THAT(ast_impl.root_expr().call_expr().args(), SizeIs(2)); - int64_t field_id = ast_impl.root_expr().call_expr().args()[1].id(); + ASSERT_THAT(checked_ast->root_expr().call_expr().args(), SizeIs(2)); + int64_t field_id = checked_ast->root_expr().call_expr().args()[1].id(); EXPECT_NE(field_id, 0); - EXPECT_THAT(ast_impl.type_map(), Not(Contains(Key(field_id)))); - EXPECT_THAT(ast_impl.GetTypeOrDyn(ast_impl.root_expr().id()), + EXPECT_THAT(checked_ast->type_map(), Not(Contains(Key(field_id)))); + EXPECT_THAT(checked_ast->GetTypeOrDyn(checked_ast->root_expr().id()), IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))); } @@ -138,11 +137,10 @@ TEST_P(OptionalTest, Runner) { EXPECT_THAT(result.GetIssues(), IsEmpty()) << "for expression: " << test_case.expr; ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - const auto& ast_impl = ast_internal::AstImpl::CastFromPublicAst(*checked_ast); - int64_t root_id = ast_impl.root_expr().id(); + int64_t root_id = checked_ast->root_expr().id(); - EXPECT_THAT(ast_impl.GetTypeOrDyn(root_id), test_case.result_type_matcher) + EXPECT_THAT(checked_ast->GetTypeOrDyn(root_id), test_case.result_type_matcher) << "for expression: " << test_case.expr; } @@ -311,11 +309,10 @@ TEST_P(OptionalStrictNullAssignmentTest, Runner) { EXPECT_THAT(result.GetIssues(), IsEmpty()) << "for expression: " << test_case.expr; ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - const auto& ast_impl = ast_internal::AstImpl::CastFromPublicAst(*checked_ast); - int64_t root_id = ast_impl.root_expr().id(); + int64_t root_id = checked_ast->root_expr().id(); - EXPECT_THAT(ast_impl.GetTypeOrDyn(root_id), test_case.result_type_matcher) + EXPECT_THAT(checked_ast->GetTypeOrDyn(root_id), test_case.result_type_matcher) << "for expression: " << test_case.expr; } diff --git a/checker/standard_library_test.cc b/checker/standard_library_test.cc index cc051cfe3..2b7533027 100644 --- a/checker/standard_library_test.cc +++ b/checker/standard_library_test.cc @@ -27,7 +27,6 @@ #include "checker/type_checker_builder_factory.h" #include "checker/validation_result.h" #include "common/ast.h" -#include "common/ast/ast_impl.h" #include "common/ast/expr.h" #include "common/constant.h" #include "common/decl.h" @@ -41,7 +40,6 @@ namespace { using ::absl_testing::IsOk; using ::absl_testing::StatusIs; -using ::cel::ast_internal::AstImpl; using ::cel::ast_internal::Reference; using ::cel::internal::GetSharedTestingDescriptorPool; using ::testing::IsEmpty; @@ -132,11 +130,8 @@ TEST(StandardLibraryTest, ComprehensionResultTypeIsSubstituted) { ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); - const ast_internal::AstImpl& checked_impl = - ast_internal::AstImpl::CastFromPublicAst(*checked_ast); - ast_internal::Type type = - checked_impl.GetTypeOrDyn(checked_impl.root_expr().id()); + checked_ast->GetTypeOrDyn(checked_ast->root_expr().id()); EXPECT_TRUE(type.has_primitive() && type.primitive() == ast_internal::PrimitiveType::kInt64); } @@ -160,7 +155,7 @@ class StdlibTypeVarDefinitionTest public testing::WithParamInterface {}; TEST_P(StdlibTypeVarDefinitionTest, DefinesTypeConstants) { - auto ast = std::make_unique(); + auto ast = std::make_unique(); ast->mutable_root_expr().mutable_ident_expr().set_name(GetParam()); ast->mutable_root_expr().set_id(1); @@ -169,10 +164,9 @@ TEST_P(StdlibTypeVarDefinitionTest, DefinesTypeConstants) { EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); - const auto& checked_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(checked_impl.GetReference(1), + EXPECT_THAT(checked_ast->GetReference(1), Pointee(Property(&Reference::name, GetParam()))); - EXPECT_THAT(checked_impl.GetTypeOrDyn(1), Property(&AstType::has_type, true)); + EXPECT_THAT(checked_ast->GetTypeOrDyn(1), Property(&AstType::has_type, true)); } INSTANTIATE_TEST_SUITE_P(StdlibTypeVarDefinitions, StdlibTypeVarDefinitionTest, @@ -184,7 +178,7 @@ INSTANTIATE_TEST_SUITE_P(StdlibTypeVarDefinitions, StdlibTypeVarDefinitionTest, }); TEST_F(StandardLibraryDefinitionsTest, DefinesProtoStructNull) { - auto ast = std::make_unique(); + auto ast = std::make_unique(); auto& enumerator = ast->mutable_root_expr(); enumerator.set_id(4); @@ -204,14 +198,13 @@ TEST_F(StandardLibraryDefinitionsTest, DefinesProtoStructNull) { EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); - const auto& checked_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(checked_impl.GetReference(4), + EXPECT_THAT(checked_ast->GetReference(4), Pointee(Property(&Reference::name, "google.protobuf.NullValue.NULL_VALUE"))); } TEST_F(StandardLibraryDefinitionsTest, DefinesTypeType) { - auto ast = std::make_unique(); + auto ast = std::make_unique(); auto& ident = ast->mutable_root_expr(); ident.set_id(1); @@ -222,10 +215,9 @@ TEST_F(StandardLibraryDefinitionsTest, DefinesTypeType) { EXPECT_THAT(result.GetIssues(), IsEmpty()); ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); - const auto& checked_impl = AstImpl::CastFromPublicAst(*checked_ast); - EXPECT_THAT(checked_impl.GetReference(1), + EXPECT_THAT(checked_ast->GetReference(1), Pointee(Property(&Reference::name, "type"))); - EXPECT_THAT(checked_impl.GetTypeOrDyn(1), Property(&AstType::has_type, true)); + EXPECT_THAT(checked_ast->GetTypeOrDyn(1), Property(&AstType::has_type, true)); } struct DefinitionsTestCase { diff --git a/common/ast/ast_impl.h b/common/ast/ast_impl.h index 9f424bc7e..2b40dffc7 100644 --- a/common/ast/ast_impl.h +++ b/common/ast/ast_impl.h @@ -22,7 +22,6 @@ #include "common/ast/expr.h" #include "common/ast/metadata.h" // IWYU pragma: export #include "common/expr.h" -#include "internal/casts.h" namespace cel::ast_internal { @@ -34,24 +33,6 @@ class AstImpl : public Ast { using ReferenceMap = Ast::ReferenceMap; using TypeMap = Ast::TypeMap; - // Overloads for down casting from the public interface to the internal - // implementation. - static AstImpl& CastFromPublicAst(Ast& ast) { - return cel::internal::down_cast(ast); - } - - static const AstImpl& CastFromPublicAst(const Ast& ast) { - return cel::internal::down_cast(ast); - } - - static AstImpl* CastFromPublicAst(Ast* ast) { - return cel::internal::down_cast(ast); - } - - static const AstImpl* CastFromPublicAst(const Ast* ast) { - return cel::internal::down_cast(ast); - } - AstImpl() = default; AstImpl(Expr expr, SourceInfo source_info) diff --git a/common/ast_proto_test.cc b/common/ast_proto_test.cc index e9b36a4bd..195d25edd 100644 --- a/common/ast_proto_test.cc +++ b/common/ast_proto_test.cc @@ -32,7 +32,6 @@ #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "common/ast.h" -#include "common/ast/ast_impl.h" #include "common/ast/expr.h" #include "common/decl.h" #include "common/expr.h" @@ -72,8 +71,7 @@ absl::StatusOr ConvertProtoTypeToNative( CEL_ASSIGN_OR_RETURN(auto ast, CreateAstFromCheckedExpr(checked_expr)); - const auto& type_map = - ast_internal::AstImpl::CastFromPublicAst(*ast).type_map(); + const auto& type_map = ast->type_map(); auto iter = type_map.find(1); if (iter != type_map.end()) { return iter->second; @@ -385,8 +383,7 @@ TEST(AstConvertersTest, ReferenceToNative) { &reference_wrapper)); ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromCheckedExpr(reference_wrapper)); - const auto& native_references = - ast_internal::AstImpl::CastFromPublicAst(*ast).reference_map(); + const auto& native_references = ast->reference_map(); auto native_reference = native_references.at(1); @@ -415,8 +412,7 @@ TEST(AstConvertersTest, SourceInfoToNative) { &source_info_wrapper)); ASSERT_OK_AND_ASSIGN(auto ast, CreateAstFromParsedExpr(source_info_wrapper)); - const auto& native_source_info = - ast_internal::AstImpl::CastFromPublicAst(*ast).source_info(); + const auto& native_source_info = ast->source_info(); EXPECT_EQ(native_source_info.syntax_version(), "version"); EXPECT_EQ(native_source_info.location(), "location"); @@ -467,7 +463,7 @@ TEST(AstConvertersTest, CheckedExprToAst) { } TEST(AstConvertersTest, AstToCheckedExprBasic) { - ast_internal::AstImpl ast; + Ast ast; ast.mutable_root_expr().set_id(1); ast.mutable_root_expr().mutable_ident_expr().set_name("expr"); @@ -482,9 +478,6 @@ TEST(AstConvertersTest, AstToCheckedExprBasic) { macro.mutable_ident_expr().set_name("name"); ast.mutable_source_info().mutable_macro_calls().insert({1, std::move(macro)}); - ast_internal::AstImpl::TypeMap type_map; - ast_internal::AstImpl::ReferenceMap reference_map; - ast_internal::Reference reference; reference.set_name("name"); reference.mutable_overload_id().push_back("id1"); @@ -673,7 +666,7 @@ TEST(AstConvertersTest, AstToParsedExprBasic) { macro.mutable_ident_expr().set_name("name"); source_info.mutable_macro_calls().insert({1, std::move(macro)}); - ast_internal::AstImpl ast(std::move(expr), std::move(source_info)); + Ast ast(std::move(expr), std::move(source_info)); ParsedExpr parsed_expr; ASSERT_THAT(AstToParsedExpr(ast, &parsed_expr), IsOk()); @@ -859,14 +852,12 @@ TEST_P(ConversionRoundTripTest, ParsedExprCopyable) { ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, CreateAstFromParsedExpr(parsed_expr)); - const auto& impl = ast_internal::AstImpl::CastFromPublicAst(*ast); - CheckedExpr expr_pb; - EXPECT_THAT(AstToCheckedExpr(impl, &expr_pb), + EXPECT_THAT(AstToCheckedExpr(*ast, &expr_pb), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("AST is not type-checked"))); ParsedExpr proto_out; - ASSERT_THAT(AstToParsedExpr(impl, &proto_out), IsOk()); + ASSERT_THAT(AstToParsedExpr(*ast, &proto_out), IsOk()); EXPECT_THAT(proto_out, EqualsProto(parsed_expr)); } diff --git a/common/ast_rewrite_test.cc b/common/ast_rewrite_test.cc index 209a305d9..679c4caa2 100644 --- a/common/ast_rewrite_test.cc +++ b/common/ast_rewrite_test.cc @@ -535,9 +535,8 @@ TEST(AstRewrite, SelectRewriteExample) { std::unique_ptr ast, CreateAstFromParsedExpr( google::api::expr::parser::Parse("com.google.Identifier").value())); - AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); RewriterExample example; - ASSERT_TRUE(AstRewrite(ast_impl.mutable_root_expr(), example)); + ASSERT_TRUE(AstRewrite(ast->mutable_root_expr(), example)); cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString( @@ -550,7 +549,7 @@ TEST(AstRewrite, SelectRewriteExample) { cel::Expr expected_native; ASSERT_THAT(ExprFromProto(expected_expr, expected_native), IsOk()); - EXPECT_EQ(ast_impl.root_expr(), expected_native); + EXPECT_EQ(ast->root_expr(), expected_native); } // Rewrites x -> y -> z to demonstrate traversal when a node is rewritten on @@ -591,8 +590,7 @@ TEST(AstRewrite, PreAndPostVisitExpample) { std::unique_ptr ast, CreateAstFromParsedExpr(google::api::expr::parser::Parse("x").value())); PreRewriterExample visitor; - AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); - ASSERT_TRUE(AstRewrite(ast_impl.mutable_root_expr(), visitor)); + ASSERT_TRUE(AstRewrite(ast->mutable_root_expr(), visitor)); cel::expr::Expr expected_expr; google::protobuf::TextFormat::ParseFromString( @@ -604,7 +602,7 @@ TEST(AstRewrite, PreAndPostVisitExpample) { cel::Expr expected_native; ASSERT_THAT(ExprFromProto(expected_expr, expected_native), IsOk()); - EXPECT_EQ(ast_impl.root_expr(), expected_native); + EXPECT_EQ(ast->root_expr(), expected_native); EXPECT_THAT(visitor.visited_idents(), ElementsAre("y")); } diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 68c6dbfca..118311455 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -326,11 +326,10 @@ cc_library( ":resolver", "//base:builtins", "//base:data", + "//common:ast", "//common:constant", "//common:expr", - "//common:kind", "//common:value", - "//common/ast:ast_impl", "//eval/eval:const_value_step", "//eval/eval:evaluator_core", "//internal:status_macros", @@ -340,7 +339,6 @@ cc_library( "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:variant", "@com_google_protobuf//:protobuf", ], ) @@ -396,10 +394,10 @@ cc_library( ":resolver", "//base:ast", "//base:builtins", + "//common:ast", "//common:ast_rewrite", "//common:expr", "//common:kind", - "//common/ast:ast_impl", "//common/ast:expr", "//runtime:runtime_issue", "//runtime/internal:issue_collector", @@ -517,6 +515,7 @@ cc_library( deps = [ ":flat_expr_builder_extensions", "//base:builtins", + "//common:ast", "//common:casting", "//common:expr", "//common:native_type", @@ -548,7 +547,7 @@ cc_test( ":flat_expr_builder_extensions", ":regex_precompilation_optimization", ":resolver", - "//common/ast:ast_impl", + "//common:ast", "//eval/eval:evaluator_core", "//eval/public:activation", "//eval/public:builtin_func_registrar", @@ -579,9 +578,9 @@ cc_library( deps = [ ":flat_expr_builder_extensions", "//base:builtins", + "//common:ast", "//common:constant", "//common:expr", - "//common/ast:ast_impl", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:variant", @@ -594,6 +593,7 @@ cc_library( hdrs = ["instrumentation.h"], deps = [ ":flat_expr_builder_extensions", + "//common:ast", "//common:expr", "//common:value", "//common/ast:ast_impl", @@ -613,8 +613,8 @@ cc_test( ":flat_expr_builder", ":instrumentation", ":regex_precompilation_optimization", + "//common:ast", "//common:value", - "//common/ast:ast_impl", "//eval/eval:evaluator_core", "//extensions/protobuf:ast_converters", "//internal:testing", diff --git a/eval/compiler/comprehension_vulnerability_check.cc b/eval/compiler/comprehension_vulnerability_check.cc index 6085c27b4..ca3905024 100644 --- a/eval/compiler/comprehension_vulnerability_check.cc +++ b/eval/compiler/comprehension_vulnerability_check.cc @@ -22,7 +22,7 @@ #include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "base/builtins.h" -#include "common/ast/ast_impl.h" +#include "common/ast.h" #include "common/constant.h" #include "common/expr.h" #include "eval/compiler/flat_expr_builder_extensions.h" @@ -267,7 +267,7 @@ class ComprehensionVulnerabilityCheck : public ProgramOptimizer { } // namespace ProgramOptimizerFactory CreateComprehensionVulnerabilityCheck() { - return [](PlannerContext&, const cel::ast_internal::AstImpl&) { + return [](PlannerContext&, const cel::Ast& ast) { return std::make_unique(); }; } diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index 22dacd81c..ff04379d2 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -23,13 +23,11 @@ #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/types/variant.h" #include "base/builtins.h" #include "base/type_provider.h" -#include "common/ast/ast_impl.h" +#include "common/ast.h" #include "common/constant.h" #include "common/expr.h" -#include "common/kind.h" #include "common/value.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" @@ -46,15 +44,7 @@ namespace cel::runtime_internal { namespace { -using ::cel::CallExpr; -using ::cel::ComprehensionExpr; -using ::cel::Constant; using ::cel::Expr; -using ::cel::IdentExpr; -using ::cel::ListExpr; -using ::cel::SelectExpr; -using ::cel::StructExpr; -using ::cel::ast_internal::AstImpl; using ::cel::builtin::kAnd; using ::cel::builtin::kOr; using ::cel::builtin::kTernary; @@ -257,7 +247,7 @@ ProgramOptimizerFactory CreateConstantFoldingOptimizer( [shared_arena = std::move(arena), shared_message_factory = std::move(message_factory)]( PlannerContext& context, - const AstImpl&) -> absl::StatusOr> { + const Ast&) -> absl::StatusOr> { // If one was explicitly provided during planning or none was explicitly // provided during configuration, request one from the planning context. // Otherwise use the one provided during configuration. diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index 07d2a8c7d..d1c0c31e0 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -25,7 +25,6 @@ #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "base/ast.h" -#include "common/ast/ast_impl.h" #include "common/expr.h" #include "common/value.h" #include "eval/compiler/flat_expr_builder_extensions.h" @@ -55,7 +54,6 @@ using ::absl_testing::IsOk; using ::absl_testing::StatusIs; using ::cel::Expr; using ::cel::RuntimeIssue; -using ::cel::ast_internal::AstImpl; using ::cel::runtime_internal::IssueCollector; using ::cel::runtime_internal::NewTestingRuntimeEnv; using ::cel::expr::ParsedExpr; @@ -107,9 +105,8 @@ TEST_F(UpdatedConstantFoldingTest, SkipsTernary) { // Arrange ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("true ? true : false")); - AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); - const Expr& call = ast_impl.root_expr(); + const Expr& call = ast->root_expr(); const Expr& condition = call.call_expr().args()[0]; const Expr& true_branch = call.call_expr().args()[1]; const Expr& false_branch = call.call_expr().args()[2]; @@ -151,7 +148,7 @@ TEST_F(UpdatedConstantFoldingTest, SkipsTernary) { // Act // Issue the visitation calls. ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, - constant_folder_factory(context, ast_impl)); + constant_folder_factory(context, *ast)); ASSERT_OK(constant_folder->OnPreVisit(context, call)); ASSERT_OK(constant_folder->OnPreVisit(context, condition)); ASSERT_OK(constant_folder->OnPostVisit(context, condition)); @@ -171,9 +168,8 @@ TEST_F(UpdatedConstantFoldingTest, SkipsOr) { // Arrange ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("false || true")); - AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); - const Expr& call = ast_impl.root_expr(); + const Expr& call = ast->root_expr(); const Expr& left_condition = call.call_expr().args()[0]; const Expr& right_condition = call.call_expr().args()[1]; @@ -211,7 +207,7 @@ TEST_F(UpdatedConstantFoldingTest, SkipsOr) { // Act // Issue the visitation calls. ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, - constant_folder_factory(context, ast_impl)); + constant_folder_factory(context, *ast)); ASSERT_OK(constant_folder->OnPreVisit(context, call)); ASSERT_OK(constant_folder->OnPreVisit(context, left_condition)); ASSERT_OK(constant_folder->OnPostVisit(context, left_condition)); @@ -229,9 +225,8 @@ TEST_F(UpdatedConstantFoldingTest, SkipsAnd) { // Arrange ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("true && false")); - AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); - const Expr& call = ast_impl.root_expr(); + const Expr& call = ast->root_expr(); const Expr& left_condition = call.call_expr().args()[0]; const Expr& right_condition = call.call_expr().args()[1]; @@ -268,7 +263,7 @@ TEST_F(UpdatedConstantFoldingTest, SkipsAnd) { // Act // Issue the visitation calls. ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, - constant_folder_factory(context, ast_impl)); + constant_folder_factory(context, *ast)); ASSERT_OK(constant_folder->OnPreVisit(context, call)); ASSERT_OK(constant_folder->OnPreVisit(context, left_condition)); ASSERT_OK(constant_folder->OnPostVisit(context, left_condition)); @@ -285,9 +280,8 @@ TEST_F(UpdatedConstantFoldingTest, SkipsAnd) { TEST_F(UpdatedConstantFoldingTest, CreatesList) { // Arrange ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("[1, 2]")); - AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); - const Expr& create_list = ast_impl.root_expr(); + const Expr& create_list = ast->root_expr(); const Expr& elem_one = create_list.list_expr().elements()[0].expr(); const Expr& elem_two = create_list.list_expr().elements()[1].expr(); @@ -323,7 +317,7 @@ TEST_F(UpdatedConstantFoldingTest, CreatesList) { // Act // Issue the visitation calls. ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, - constant_folder_factory(context, ast_impl)); + constant_folder_factory(context, *ast)); ASSERT_OK(constant_folder->OnPreVisit(context, create_list)); ASSERT_OK(constant_folder->OnPreVisit(context, elem_one)); ASSERT_OK(constant_folder->OnPostVisit(context, elem_one)); @@ -341,9 +335,8 @@ TEST_F(UpdatedConstantFoldingTest, CreatesLargeList) { // Arrange ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("[1, 2, 3, 4, 5]")); - AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); - const Expr& create_list = ast_impl.root_expr(); + const Expr& create_list = ast->root_expr(); const Expr& elem0 = create_list.list_expr().elements()[0].expr(); const Expr& elem1 = create_list.list_expr().elements()[1].expr(); const Expr& elem2 = create_list.list_expr().elements()[2].expr(); @@ -400,7 +393,7 @@ TEST_F(UpdatedConstantFoldingTest, CreatesLargeList) { // Act // Issue the visitation calls. ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, - constant_folder_factory(context, ast_impl)); + constant_folder_factory(context, *ast)); ASSERT_THAT(constant_folder->OnPreVisit(context, create_list), IsOk()); ASSERT_THAT(constant_folder->OnPreVisit(context, elem0), IsOk()); ASSERT_THAT(constant_folder->OnPostVisit(context, elem0), IsOk()); @@ -423,9 +416,8 @@ TEST_F(UpdatedConstantFoldingTest, CreatesLargeList) { TEST_F(UpdatedConstantFoldingTest, CreatesMap) { // Arrange ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("{1: 2}")); - AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); - const Expr& create_map = ast_impl.root_expr(); + const Expr& create_map = ast->root_expr(); const Expr& key = create_map.map_expr().entries()[0].key(); const Expr& value = create_map.map_expr().entries()[0].value(); @@ -462,7 +454,7 @@ TEST_F(UpdatedConstantFoldingTest, CreatesMap) { // Act // Issue the visitation calls. ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, - constant_folder_factory(context, ast_impl)); + constant_folder_factory(context, *ast)); ASSERT_OK(constant_folder->OnPreVisit(context, create_map)); ASSERT_OK(constant_folder->OnPreVisit(context, key)); ASSERT_OK(constant_folder->OnPostVisit(context, key)); @@ -479,9 +471,8 @@ TEST_F(UpdatedConstantFoldingTest, CreatesMap) { TEST_F(UpdatedConstantFoldingTest, CreatesInvalidMap) { // Arrange ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("{1.0: 2}")); - AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); - const Expr& create_map = ast_impl.root_expr(); + const Expr& create_map = ast->root_expr(); const Expr& key = create_map.map_expr().entries()[0].key(); const Expr& value = create_map.map_expr().entries()[0].value(); @@ -519,7 +510,7 @@ TEST_F(UpdatedConstantFoldingTest, CreatesInvalidMap) { // Act // Issue the visitation calls. ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, - constant_folder_factory(context, ast_impl)); + constant_folder_factory(context, *ast)); ASSERT_OK(constant_folder->OnPreVisit(context, create_map)); ASSERT_OK(constant_folder->OnPreVisit(context, key)); ASSERT_OK(constant_folder->OnPostVisit(context, key)); @@ -535,9 +526,8 @@ TEST_F(UpdatedConstantFoldingTest, ErrorsOnUnexpectedOrder) { // Arrange ASSERT_OK_AND_ASSIGN(std::unique_ptr ast, ParseFromCel("true && false")); - AstImpl& ast_impl = AstImpl::CastFromPublicAst(*ast); - const Expr& call = ast_impl.root_expr(); + const Expr& call = ast->root_expr(); const Expr& left_condition = call.call_expr().args()[0]; const Expr& right_condition = call.call_expr().args()[1]; @@ -573,7 +563,7 @@ TEST_F(UpdatedConstantFoldingTest, ErrorsOnUnexpectedOrder) { // Act / Assert ASSERT_OK_AND_ASSIGN(std::unique_ptr constant_folder, - constant_folder_factory(context, ast_impl)); + constant_folder_factory(context, *ast)); EXPECT_THAT(constant_folder->OnPostVisit(context, left_condition), StatusIs(absl::StatusCode::kInternal)); } diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index ff03bd287..dcf407a52 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -2560,16 +2560,14 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( PlannerContext extension_context(env_, resolver, options_, GetTypeProvider(), issue_collector, program_builder, arena); - auto& ast_impl = AstImpl::CastFromPublicAst(*ast); - for (const std::unique_ptr& transform : ast_transforms_) { - CEL_RETURN_IF_ERROR(transform->UpdateAst(extension_context, ast_impl)); + CEL_RETURN_IF_ERROR(transform->UpdateAst(extension_context, *ast)); } std::vector> optimizers; for (const ProgramOptimizerFactory& optimizer_factory : program_optimizers_) { CEL_ASSIGN_OR_RETURN(auto optimizer, - optimizer_factory(extension_context, ast_impl)); + optimizer_factory(extension_context, *ast)); if (optimizer != nullptr) { optimizers.push_back(std::move(optimizer)); } @@ -2578,13 +2576,13 @@ absl::StatusOr FlatExprBuilder::CreateExpressionImpl( // These objects are expected to remain scoped to one build call -- references // to them shouldn't be persisted in any part of the result expression. FlatExprVisitor visitor(resolver, options_, std::move(optimizers), - ast_impl.reference_map(), GetTypeProvider(), + ast->reference_map(), GetTypeProvider(), issue_collector, program_builder, extension_context, enable_optional_types_); cel::TraversalOptions opts; opts.use_comprehension_callbacks = true; - AstTraverse(ast_impl.root_expr(), visitor, opts); + AstTraverse(ast->root_expr(), visitor, opts); if (!visitor.progress_status().ok()) { return visitor.progress_status(); diff --git a/eval/compiler/flat_expr_builder_extensions.h b/eval/compiler/flat_expr_builder_extensions.h index b88cf54d0..21e37b2a8 100644 --- a/eval/compiler/flat_expr_builder_extensions.h +++ b/eval/compiler/flat_expr_builder_extensions.h @@ -38,7 +38,6 @@ #include "absl/types/variant.h" #include "base/ast.h" #include "base/type_provider.h" -#include "common/ast/ast_impl.h" #include "common/expr.h" #include "common/native_type.h" #include "common/type_reflector.h" @@ -441,7 +440,7 @@ class AstTransform { virtual ~AstTransform() = default; virtual absl::Status UpdateAst(PlannerContext& context, - cel::ast_internal::AstImpl& ast) const = 0; + cel::Ast& ast) const = 0; }; // Interface for program optimizers. @@ -475,7 +474,7 @@ class ProgramOptimizer { // it is called from a synchronous context. using ProgramOptimizerFactory = absl::AnyInvocable>( - PlannerContext&, const cel::ast_internal::AstImpl&) const>; + PlannerContext&, const cel::Ast&) const>; } // namespace google::api::expr::runtime diff --git a/eval/compiler/instrumentation.cc b/eval/compiler/instrumentation.cc index 3ee672e4a..3e37bdb45 100644 --- a/eval/compiler/instrumentation.cc +++ b/eval/compiler/instrumentation.cc @@ -20,7 +20,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "common/ast/ast_impl.h" +#include "common/ast.h" #include "common/expr.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/eval/evaluator_core.h" @@ -80,8 +80,7 @@ class InstrumentOptimizer : public ProgramOptimizer { ProgramOptimizerFactory CreateInstrumentationExtension( InstrumentationFactory factory) { - return [fac = std::move(factory)](PlannerContext&, - const cel::ast_internal::AstImpl& ast) + return [fac = std::move(factory)](PlannerContext&, const cel::Ast& ast) -> absl::StatusOr> { Instrumentation ins = fac(ast); if (ins) { diff --git a/eval/compiler/instrumentation.h b/eval/compiler/instrumentation.h index badcde360..9096830a0 100644 --- a/eval/compiler/instrumentation.h +++ b/eval/compiler/instrumentation.h @@ -23,7 +23,7 @@ #include "absl/functional/any_invocable.h" #include "absl/status/status.h" -#include "common/ast/ast_impl.h" +#include "common/ast.h" #include "common/value.h" #include "eval/compiler/flat_expr_builder_extensions.h" @@ -45,8 +45,8 @@ using Instrumentation = // // An empty function object may be returned to skip instrumenting the given // expression. -using InstrumentationFactory = absl::AnyInvocable; +using InstrumentationFactory = + absl::AnyInvocable; // Create a new Instrumentation extension. // diff --git a/eval/compiler/instrumentation_test.cc b/eval/compiler/instrumentation_test.cc index 69b78a3ba..cf0527fc9 100644 --- a/eval/compiler/instrumentation_test.cc +++ b/eval/compiler/instrumentation_test.cc @@ -23,7 +23,7 @@ #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" -#include "common/ast/ast_impl.h" +#include "common/ast.h" #include "common/value.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" @@ -89,9 +89,7 @@ TEST_F(InstrumentationTest, Basic) { }; builder.AddProgramOptimizer(CreateInstrumentationExtension( - [=](const cel::ast_internal::AstImpl&) -> Instrumentation { - return expr_id_recorder; - })); + [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; })); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("1 + 2 + 3")); ASSERT_OK_AND_ASSIGN(auto ast, @@ -130,9 +128,7 @@ TEST_F(InstrumentationTest, BasicWithConstFolding) { builder.AddProgramOptimizer( cel::runtime_internal::CreateConstantFoldingOptimizer()); builder.AddProgramOptimizer(CreateInstrumentationExtension( - [=](const cel::ast_internal::AstImpl&) -> Instrumentation { - return expr_id_recorder; - })); + [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; })); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("1 + 2 + 3")); ASSERT_OK_AND_ASSIGN(auto ast, @@ -175,9 +171,7 @@ TEST_F(InstrumentationTest, AndShortCircuit) { }; builder.AddProgramOptimizer(CreateInstrumentationExtension( - [=](const cel::ast_internal::AstImpl&) -> Instrumentation { - return expr_id_recorder; - })); + [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; })); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("a && b")); ASSERT_OK_AND_ASSIGN(auto ast, @@ -218,9 +212,7 @@ TEST_F(InstrumentationTest, OrShortCircuit) { }; builder.AddProgramOptimizer(CreateInstrumentationExtension( - [=](const cel::ast_internal::AstImpl&) -> Instrumentation { - return expr_id_recorder; - })); + [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; })); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("a || b")); ASSERT_OK_AND_ASSIGN(auto ast, @@ -261,9 +253,7 @@ TEST_F(InstrumentationTest, Ternary) { }; builder.AddProgramOptimizer(CreateInstrumentationExtension( - [=](const cel::ast_internal::AstImpl&) -> Instrumentation { - return expr_id_recorder; - })); + [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; })); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("(c)? a : b")); ASSERT_OK_AND_ASSIGN(auto ast, @@ -313,9 +303,7 @@ TEST_F(InstrumentationTest, OptimizedStepsNotEvaluated) { }; builder.AddProgramOptimizer(CreateInstrumentationExtension( - [=](const cel::ast_internal::AstImpl&) -> Instrumentation { - return expr_id_recorder; - })); + [=](const cel::Ast&) -> Instrumentation { return expr_id_recorder; })); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("r'test_string'.matches(r'[a-z_]+')")); @@ -341,9 +329,7 @@ TEST_F(InstrumentationTest, NoopSkipped) { FlatExprBuilder builder(env_, options_); builder.AddProgramOptimizer(CreateInstrumentationExtension( - [=](const cel::ast_internal::AstImpl&) -> Instrumentation { - return Instrumentation(); - })); + [=](const cel::Ast&) -> Instrumentation { return Instrumentation(); })); ASSERT_OK_AND_ASSIGN(ParsedExpr expr, Parse("(c)? a : b")); ASSERT_OK_AND_ASSIGN(auto ast, diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 40b88e341..cd653da3c 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -29,7 +29,6 @@ #include "absl/types/optional.h" #include "base/ast.h" #include "base/builtins.h" -#include "common/ast/ast_impl.h" #include "common/ast/expr.h" #include "common/ast_rewrite.h" #include "common/expr.h" @@ -315,7 +314,7 @@ class ReferenceResolverExtension : public AstTransform { explicit ReferenceResolverExtension(ReferenceResolverOption opt) : opt_(opt) {} absl::Status UpdateAst(PlannerContext& context, - cel::ast_internal::AstImpl& ast) const override { + cel::Ast& ast) const override { if (opt_ == ReferenceResolverOption::kCheckedOnly && ast.reference_map().empty()) { return absl::OkStatus(); @@ -331,8 +330,7 @@ class ReferenceResolverExtension : public AstTransform { } // namespace absl::StatusOr ResolveReferences(const Resolver& resolver, - IssueCollector& issues, - cel::ast_internal::AstImpl& ast) { + IssueCollector& issues, cel::Ast& ast) { ReferenceResolver ref_resolver(ast.reference_map(), resolver, issues); // Rewriting interface doesn't support failing mid traverse propagate first diff --git a/eval/compiler/qualified_reference_resolver.h b/eval/compiler/qualified_reference_resolver.h index 4bca1d532..673273084 100644 --- a/eval/compiler/qualified_reference_resolver.h +++ b/eval/compiler/qualified_reference_resolver.h @@ -18,8 +18,7 @@ #include #include "absl/status/statusor.h" -#include "base/ast.h" -#include "common/ast/ast_impl.h" +#include "common/ast.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" #include "runtime/internal/issue_collector.h" @@ -37,7 +36,7 @@ namespace google::api::expr::runtime { // points to an expr node that isn't a reference). absl::StatusOr ResolveReferences( const Resolver& resolver, cel::runtime_internal::IssueCollector& issues, - cel::ast_internal::AstImpl& ast); + cel::Ast& ast); enum class ReferenceResolverOption { // Always attempt to resolve references based on runtime types and functions. diff --git a/eval/compiler/regex_precompilation_optimization.cc b/eval/compiler/regex_precompilation_optimization.cc index 39f0715db..e139492f1 100644 --- a/eval/compiler/regex_precompilation_optimization.cc +++ b/eval/compiler/regex_precompilation_optimization.cc @@ -27,6 +27,7 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "base/builtins.h" +#include "common/ast.h" #include "common/ast/ast_impl.h" #include "common/ast/expr.h" #include "common/casting.h" @@ -45,6 +46,7 @@ namespace google::api::expr::runtime { namespace { +using ::cel::Ast; using ::cel::CallExpr; using ::cel::Cast; using ::cel::Expr; @@ -52,7 +54,6 @@ using ::cel::InstanceOf; using ::cel::NativeTypeId; using ::cel::StringValue; using ::cel::Value; -using ::cel::ast_internal::AstImpl; using ::cel::ast_internal::Reference; using ::cel::internal::down_cast; @@ -270,7 +271,7 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { ProgramOptimizerFactory CreateRegexPrecompilationExtension( int regex_max_program_size) { - return [=](PlannerContext& context, const AstImpl& ast) { + return [=](PlannerContext& context, const Ast& ast) { return std::make_unique( ast.reference_map(), regex_max_program_size); }; diff --git a/eval/compiler/regex_precompilation_optimization_test.cc b/eval/compiler/regex_precompilation_optimization_test.cc index dbeb77364..9666144b2 100644 --- a/eval/compiler/regex_precompilation_optimization_test.cc +++ b/eval/compiler/regex_precompilation_optimization_test.cc @@ -23,7 +23,7 @@ #include "cel/expr/syntax.pb.h" #include "absl/base/nullability.h" #include "absl/status/status.h" -#include "common/ast/ast_impl.h" +#include "common/ast.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" @@ -111,7 +111,7 @@ TEST_P(RegexPrecompilationExtensionTest, SmokeTest) { CreateRegexPrecompilationExtension(options_.regex_max_program_size); ExecutionPath path; ProgramBuilder program_builder; - cel::ast_internal::AstImpl ast_impl; + cel::Ast ast_impl; ast_impl.set_is_checked(true); std::shared_ptr arena; PlannerContext context(env_, resolver_, runtime_options_, diff --git a/extensions/BUILD b/extensions/BUILD index 7e03e314c..a6b9a0990 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -312,6 +312,7 @@ cc_library( deps = [ "//base:attributes", "//base:builtins", + "//common:ast", "//common:ast_rewrite", "//common:casting", "//common:constant", @@ -322,7 +323,6 @@ cc_library( "//common:type", "//common:value", "//common/ast:ast_impl", - "//common/ast:expr", "//eval/compiler:flat_expr_builder", "//eval/compiler:flat_expr_builder_extensions", "//eval/eval:attribute_trail", diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index bebc25cf8..30fe40355 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -36,8 +36,8 @@ #include "absl/types/variant.h" #include "base/attribute.h" #include "base/builtins.h" +#include "common/ast.h" #include "common/ast/ast_impl.h" -#include "common/ast/expr.h" #include "common/ast_rewrite.h" #include "common/casting.h" #include "common/constant.h" @@ -66,13 +66,13 @@ namespace cel::extensions { namespace { +using ::cel::Ast; using ::cel::AstRewriterBase; using ::cel::CallExpr; using ::cel::ConstantKind; using ::cel::Expr; using ::cel::ExprKind; using ::cel::SelectExpr; -using ::cel::ast_internal::AstImpl; using ::google::api::expr::runtime::AttributeTrail; using ::google::api::expr::runtime::DirectExpressionStep; using ::google::api::expr::runtime::ExecutionFrame; @@ -384,7 +384,7 @@ absl::StatusOr> SelectInstructionsFromCall( class RewriterImpl : public AstRewriterBase { public: - RewriterImpl(const AstImpl& ast, PlannerContext& planner_context) + RewriterImpl(const Ast& ast, PlannerContext& planner_context) : ast_(ast), planner_context_(planner_context) {} void PreVisitExpr(const Expr& expr) override { path_.push_back(&expr); } @@ -537,7 +537,7 @@ class RewriterImpl : public AstRewriterBase { } } - const AstImpl& ast_; + const Ast& ast_; PlannerContext& planner_context_; // ids of potentially optimizeable expr nodes. absl::flat_hash_map candidates_; @@ -905,7 +905,7 @@ google::api::expr::runtime::FlatExprBuilder* GetFlatExprBuilder( } // namespace absl::Status SelectOptimizationAstUpdater::UpdateAst(PlannerContext& context, - AstImpl& ast) const { + Ast& ast) const { RewriterImpl rewriter(ast, context); AstRewrite(ast.mutable_root_expr(), rewriter); return rewriter.GetProgressStatus(); @@ -914,7 +914,7 @@ absl::Status SelectOptimizationAstUpdater::UpdateAst(PlannerContext& context, google::api::expr::runtime::ProgramOptimizerFactory CreateSelectOptimizationProgramOptimizer( const SelectOptimizationOptions& options) { - return [=](PlannerContext& context, const cel::ast_internal::AstImpl& ast) { + return [=](PlannerContext& context, const Ast& ast) { return std::make_unique(options); }; } diff --git a/extensions/select_optimization.h b/extensions/select_optimization.h index d5b6799b3..344de11c9 100644 --- a/extensions/select_optimization.h +++ b/extensions/select_optimization.h @@ -16,7 +16,7 @@ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_SELECT_OPTIMIZATION_H_ #include "absl/status/status.h" -#include "common/ast/ast_impl.h" +#include "common/ast.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "runtime/runtime_builder.h" @@ -79,7 +79,7 @@ class SelectOptimizationAstUpdater SelectOptimizationAstUpdater() = default; absl::Status UpdateAst(google::api::expr::runtime::PlannerContext& context, - cel::ast_internal::AstImpl& ast) const override; + cel::Ast& ast) const override; }; google::api::expr::runtime::ProgramOptimizerFactory diff --git a/parser/parser_test.cc b/parser/parser_test.cc index 036d4f64c..e28030cca 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1914,8 +1914,7 @@ TEST(NewParserBuilderTest, CustomMacros) { EXPECT_FALSE(ast->IsChecked()); KindAndIdAdorner kind_and_id_adorner; ExprPrinter w(kind_and_id_adorner); - const auto& ast_impl = cel::ast_internal::AstImpl::CastFromPublicAst(*ast); - EXPECT_EQ(w.Print(ast_impl.root_expr()), + EXPECT_EQ(w.Print(ast->root_expr()), "_&&_(\n" " a^#2:Expr.Ident#.b~test-only~^#4:Expr.Select#,\n" " []^#5:Expr.CreateList#.map(\n" @@ -1945,8 +1944,7 @@ TEST(NewParserBuilderTest, StandardMacrosNotAddedWithStdlib) { EXPECT_FALSE(ast->IsChecked()); KindAndIdAdorner kind_and_id_adorner; ExprPrinter w(kind_and_id_adorner); - const auto& ast_impl = cel::ast_internal::AstImpl::CastFromPublicAst(*ast); - EXPECT_EQ(w.Print(ast_impl.root_expr()), + EXPECT_EQ(w.Print(ast->root_expr()), "_&&_(\n" " a^#2:Expr.Ident#.b~test-only~^#4:Expr.Select#,\n" " []^#5:Expr.CreateList#.map(\n" diff --git a/testutil/BUILD b/testutil/BUILD index 566fc26fd..6fd1a1a9c 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -25,10 +25,9 @@ cc_library( hdrs = ["expr_printer.h"], deps = [ "//common:ast", + "//common:ast_proto", "//common:constant", "//common:expr", - "//common/ast:ast_impl", - "//extensions/protobuf:ast_converters", "//internal:strings", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/log:absl_log", diff --git a/testutil/baseline_tests.cc b/testutil/baseline_tests.cc index c5001ed81..f5f725819 100644 --- a/testutil/baseline_tests.cc +++ b/testutil/baseline_tests.cc @@ -21,7 +21,6 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "common/ast.h" -#include "common/ast/ast_impl.h" #include "common/ast/expr.h" #include "common/expr.h" #include "extensions/protobuf/ast_converters.h" @@ -30,30 +29,26 @@ namespace cel::test { namespace { -using ::cel::ast_internal::AstImpl; - -using AstType = ast_internal::Type; - -std::string FormatPrimitive(ast_internal::PrimitiveType t) { +std::string FormatPrimitive(PrimitiveType t) { switch (t) { - case ast_internal::PrimitiveType::kBool: + case PrimitiveType::kBool: return "bool"; - case ast_internal::PrimitiveType::kInt64: + case PrimitiveType::kInt64: return "int"; - case ast_internal::PrimitiveType::kUint64: + case PrimitiveType::kUint64: return "uint"; - case ast_internal::PrimitiveType::kDouble: + case PrimitiveType::kDouble: return "double"; - case ast_internal::PrimitiveType::kString: + case PrimitiveType::kString: return "string"; - case ast_internal::PrimitiveType::kBytes: + case PrimitiveType::kBytes: return "bytes"; default: return ""; } } -std::string FormatType(const AstType& t) { +std::string FormatType(const TypeSpec& t) { if (t.has_dyn()) { return "dyn"; } else if (t.has_null()) { @@ -86,7 +81,7 @@ std::string FormatType(const AstType& t) { } return s; } else if (t.has_type()) { - if (t.type() == AstType()) { + if (t.type() == TypeSpec()) { return "type"; } return absl::StrCat("type(", FormatType(t.type()), ")"); @@ -112,7 +107,7 @@ std::string FormatReference(const cel::ast_internal::Reference& r) { class TypeAdorner : public ExpressionAdorner { public: - explicit TypeAdorner(const AstImpl& ast) : ast_(ast) {} + explicit TypeAdorner(const Ast& ast) : ast_(ast) {} std::string Adorn(const Expr& e) const override { std::string s; @@ -135,16 +130,15 @@ class TypeAdorner : public ExpressionAdorner { std::string AdornMapEntry(const MapExprEntry& e) const override { return ""; } private: - const AstImpl& ast_; + const Ast& ast_; }; } // namespace std::string FormatBaselineAst(const Ast& ast) { - const auto& ast_impl = ast_internal::AstImpl::CastFromPublicAst(ast); - TypeAdorner adorner(ast_impl); + TypeAdorner adorner(ast); ExprPrinter printer(adorner); - return printer.Print(ast_impl.root_expr()); + return printer.Print(ast.root_expr()); } std::string FormatBaselineCheckedExpr( diff --git a/testutil/expr_printer.cc b/testutil/expr_printer.cc index 7a0fb016a..40dea3c33 100644 --- a/testutil/expr_printer.cc +++ b/testutil/expr_printer.cc @@ -24,17 +24,14 @@ #include "absl/strings/match.h" #include "absl/strings/str_format.h" #include "common/ast.h" -#include "common/ast/ast_impl.h" +#include "common/ast_proto.h" #include "common/constant.h" #include "common/expr.h" -#include "extensions/protobuf/ast_converters.h" #include "internal/strings.h" namespace cel::test { namespace { -using ::cel::extensions::CreateAstFromParsedExpr; - class EmptyAdornerImpl : public ExpressionAdorner { public: std::string Adorn(const Expr& e) const override { return ""; } @@ -323,9 +320,7 @@ std::string ExprPrinter::PrintProto(const cel::expr::Expr& expr) const { if (!ast.ok()) { return std::string(ast.status().message()); } - const ast_internal::AstImpl& ast_impl = - ast_internal::AstImpl::CastFromPublicAst(*ast.value()); - return w.Print(ast_impl.root_expr()); + return w.Print(ast.value()->root_expr()); } std::string ExprPrinter::Print(const Expr& expr) const {