Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion checker/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ cc_test(
"//common:expr",
"//common:source",
"//common:type",
"//common/ast:ast_impl",
"//common/ast:expr",
"//internal:status_macros",
"//internal:testing",
Expand Down
5 changes: 1 addition & 4 deletions checker/internal/test_ast_helpers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,16 @@
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "common/ast.h"
#include "common/ast/ast_impl.h"
#include "internal/testing.h"

namespace cel::checker_internal {
namespace {

using ::absl_testing::StatusIs;
using ::cel::ast_internal::AstImpl;

TEST(MakeTestParsedAstTest, Works) {
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> ast, MakeTestParsedAst("123"));
AstImpl& impl = AstImpl::CastFromPublicAst(*ast);
EXPECT_TRUE(impl.root_expr().has_const_expr());
EXPECT_TRUE(ast->root_expr().has_const_expr());
}

TEST(MakeTestParsedAstTest, ForwardsParseError) {
Expand Down
9 changes: 3 additions & 6 deletions checker/internal/type_checker_builder_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ namespace {

using ::absl_testing::IsOk;
using ::absl_testing::StatusIs;
using ::cel::ast_internal::AstImpl;

using AstType = cel::ast_internal::Type;

Expand All @@ -67,9 +66,7 @@ TEST_P(ContextDeclsFieldsDefinedTest, ContextDeclsFieldsDefined) {

ASSERT_TRUE(result.IsValid());

const auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst());

EXPECT_EQ(ast_impl.GetReturnType(), GetParam().expected_type);
EXPECT_EQ(result.GetAst()->GetReturnType(), GetParam().expected_type);
}

INSTANTIATE_TEST_SUITE_P(
Expand Down Expand Up @@ -324,9 +321,9 @@ TEST(TypeCheckerBuilderImplTest, ReplaceVariable) {

ASSERT_TRUE(result.IsValid());

const auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst());
const auto& checked_ast = *result.GetAst();

EXPECT_EQ(ast_impl.GetReturnType(),
EXPECT_EQ(checked_ast.GetReturnType(),
AstType(ast_internal::PrimitiveType::kString));
}

Expand Down
30 changes: 13 additions & 17 deletions checker/internal/type_checker_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@
namespace cel::checker_internal {
namespace {

using cel::ast_internal::AstImpl;

using AstType = cel::ast_internal::Type;
using Severity = TypeCheckIssue::Severity;

Expand All @@ -69,7 +67,7 @@ std::string FormatCandidate(absl::Span<const std::string> qualifiers) {
return absl::StrJoin(qualifiers, ".");
}

SourceLocation ComputeSourceLocation(const AstImpl& ast, int64_t expr_id) {
SourceLocation ComputeSourceLocation(const Ast& ast, int64_t expr_id) {
const auto& source_info = ast.source_info();
auto iter = source_info.positions().find(expr_id);
if (iter == source_info.positions().end()) {
Expand Down Expand Up @@ -248,7 +246,7 @@ class ResolveVisitor : public AstVisitorBase {

ResolveVisitor(absl::string_view container,
NamespaceGenerator namespace_generator,
const TypeCheckEnv& env, const AstImpl& ast,
const TypeCheckEnv& env, const Ast& ast,
TypeInferenceContext& inference_context,
std::vector<TypeCheckIssue>& issues,
google::protobuf::Arena* absl_nonnull arena)
Expand Down Expand Up @@ -468,7 +466,7 @@ class ResolveVisitor : public AstVisitorBase {
const TypeCheckEnv* absl_nonnull env_;
TypeInferenceContext* absl_nonnull inference_context_;
std::vector<TypeCheckIssue>* absl_nonnull issues_;
const ast_internal::AstImpl* absl_nonnull ast_;
const Ast* absl_nonnull ast_;
VariableScope root_scope_;
google::protobuf::Arena* absl_nonnull arena_;

Expand Down Expand Up @@ -1198,8 +1196,7 @@ class ResolveRewriter : public AstRewriterBase {
explicit ResolveRewriter(const ResolveVisitor& visitor,
const TypeInferenceContext& inference_context,
const CheckerOptions& options,
AstImpl::ReferenceMap& references,
AstImpl::TypeMap& types)
Ast::ReferenceMap& references, Ast::TypeMap& types)
: visitor_(visitor),
inference_context_(inference_context),
reference_map_(references),
Expand Down Expand Up @@ -1264,16 +1261,15 @@ class ResolveRewriter : public AstRewriterBase {
absl::Status status_;
const ResolveVisitor& visitor_;
const TypeInferenceContext& inference_context_;
AstImpl::ReferenceMap& reference_map_;
AstImpl::TypeMap& type_map_;
Ast::ReferenceMap& reference_map_;
Ast::TypeMap& type_map_;
const CheckerOptions& options_;
};

} // namespace

absl::StatusOr<ValidationResult> TypeCheckerImpl::Check(
std::unique_ptr<Ast> ast) const {
auto& ast_impl = AstImpl::CastFromPublicAst(*ast);
google::protobuf::Arena type_arena;

std::vector<TypeCheckIssue> issues;
Expand All @@ -1282,13 +1278,13 @@ absl::StatusOr<ValidationResult> TypeCheckerImpl::Check(

TypeInferenceContext type_inference_context(
&type_arena, options_.enable_legacy_null_assignment);
ResolveVisitor visitor(env_.container(), std::move(generator), env_, ast_impl,
ResolveVisitor visitor(env_.container(), std::move(generator), env_, *ast,
type_inference_context, issues, &type_arena);

TraversalOptions opts;
opts.use_comprehension_callbacks = true;
bool error_limit_reached = false;
auto traversal = AstTraversal::Create(ast_impl.root_expr(), opts);
auto traversal = AstTraversal::Create(ast->root_expr(), opts);

for (int step = 0; step < options_.max_expression_node_count * 2; ++step) {
bool has_next = traversal.Step(visitor);
Expand All @@ -1315,7 +1311,7 @@ absl::StatusOr<ValidationResult> TypeCheckerImpl::Check(
{}, absl::StrCat("maximum number of ERROR issues exceeded: ",
options_.max_error_issues)));
} else if (env_.expected_type().has_value()) {
visitor.AssertExpectedType(ast_impl.root_expr(), *env_.expected_type());
visitor.AssertExpectedType(ast->root_expr(), *env_.expected_type());
}

// If any issues are errors, return without an AST.
Expand All @@ -1329,13 +1325,13 @@ absl::StatusOr<ValidationResult> TypeCheckerImpl::Check(
// Happens in a second pass to simplify validating that pointers haven't
// been invalidated by other updates.
ResolveRewriter rewriter(visitor, type_inference_context, options_,
ast_impl.mutable_reference_map(),
ast_impl.mutable_type_map());
AstRewrite(ast_impl.mutable_root_expr(), rewriter);
ast->mutable_reference_map(),
ast->mutable_type_map());
AstRewrite(ast->mutable_root_expr(), rewriter);

CEL_RETURN_IF_ERROR(rewriter.status());

ast_impl.set_is_checked(true);
ast->set_is_checked(true);

return ValidationResult(std::move(ast), std::move(issues));
}
Expand Down
Loading