Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions checker/internal/type_check_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ class TypeCheckEnv {
return variables_.insert({decl.name(), std::move(decl)}).second;
}

// Inserts a variable declaration into the environment of the current scope.
// Parent scopes are not searched.
void InsertOrReplaceVariable(VariableDecl decl) {
variables_[decl.name()] = std::move(decl);
}

const absl::flat_hash_map<std::string, FunctionDecl>& functions() const {
return functions_;
}
Expand Down
34 changes: 28 additions & 6 deletions checker/internal/type_checker_builder_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig(
}
break;
}
case AddSemantic::kTryMerge:
case AddSemantic::kTryMerge: {
const FunctionDecl* existing_decl = env.LookupFunction(decl.name());
FunctionDecl to_add = std::move(decl);
if (existing_decl != nullptr) {
Expand All @@ -190,17 +190,33 @@ absl::Status TypeCheckerBuilderImpl::ApplyConfig(
}
env.InsertOrReplaceFunction(std::move(to_add));
break;
}
default:
return absl::InternalError(absl::StrCat(
"unsupported function add semantic: ", fn.add_semantic));
}
}

for (const google::protobuf::Descriptor* context_type : config.context_types) {
CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(context_type, env));
}

for (VariableDecl& var : config.variables) {
if (!env.InsertVariableIfAbsent(var)) {
return absl::AlreadyExistsError(
absl::StrCat("variable '", var.name(), "' declared multiple times"));
for (VariableDeclRecord& var : config.variables) {
switch (var.add_semantic) {
case AddSemantic::kInsertIfAbsent: {
if (!env.InsertVariableIfAbsent(var.decl)) {
return absl::AlreadyExistsError(absl::StrCat(
"variable '", var.decl.name(), "' declared multiple times"));
}
break;
}
case AddSemantic::kInsertOrReplace: {
env.InsertOrReplaceVariable(var.decl);
break;
}
default:
return absl::InternalError(absl::StrCat(
"unsupported variable add semantic: ", var.add_semantic));
}
}

Expand Down Expand Up @@ -274,7 +290,13 @@ absl::Status TypeCheckerBuilderImpl::AddLibrarySubset(
}

absl::Status TypeCheckerBuilderImpl::AddVariable(const VariableDecl& decl) {
target_config_->variables.push_back(std::move(decl));
target_config_->variables.push_back({decl, AddSemantic::kInsertIfAbsent});
return absl::OkStatus();
}

absl::Status TypeCheckerBuilderImpl::AddOrReplaceVariable(
const VariableDecl& decl) {
target_config_->variables.push_back({decl, AddSemantic::kInsertOrReplace});
return absl::OkStatus();
}

Expand Down
13 changes: 10 additions & 3 deletions checker/internal/type_checker_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,14 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder {
absl::Status AddLibrarySubset(TypeCheckerSubset subset) override;

absl::Status AddVariable(const VariableDecl& decl) override;
absl::Status AddOrReplaceVariable(const VariableDecl& decl) override;
absl::Status AddContextDeclaration(absl::string_view type) override;

absl::Status AddFunction(const FunctionDecl& decl) override;
absl::Status MergeFunction(const FunctionDecl& decl) override;

void SetExpectedType(const Type& type) override;

absl::Status MergeFunction(const FunctionDecl& decl) override;

void AddTypeProvider(std::unique_ptr<TypeIntrospector> provider) override;

void set_container(absl::string_view container) override;
Expand All @@ -92,11 +93,17 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder {
// Sematic for adding a possibly duplicated declaration.
enum class AddSemantic {
kInsertIfAbsent,
kInsertOrReplace,
// Attempts to merge with any existing overloads for the same function.
// Will fail if any of the IDs or signatures collide.
kTryMerge,
};

struct VariableDeclRecord {
VariableDecl decl;
AddSemantic add_semantic;
};

struct FunctionDeclRecord {
FunctionDecl decl;
AddSemantic add_semantic;
Expand All @@ -106,7 +113,7 @@ class TypeCheckerBuilderImpl : public TypeCheckerBuilder {
// Used to replay the configuration in calls to Build().
struct ConfigRecord {
std::string id = "";
std::vector<VariableDecl> variables;
std::vector<VariableDeclRecord> variables;
std::vector<FunctionDeclRecord> functions;
std::vector<std::shared_ptr<const TypeIntrospector>> type_providers;
std::vector<const google::protobuf::Descriptor*> context_types;
Expand Down
24 changes: 24 additions & 0 deletions checker/internal/type_checker_builder_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,5 +205,29 @@ TEST(ContextDeclsTest, ErrorOnOverlappingVariableDeclaration) {
"variable 'single_int64' declared multiple times"));
}

TEST(ContextDeclsTest, ReplaceVariable) {
TypeCheckerBuilderImpl builder(internal::GetSharedTestingDescriptorPool(),
{});
ASSERT_THAT(
builder.AddContextDeclaration("cel.expr.conformance.proto3.TestAllTypes"),
IsOk());
ASSERT_THAT(builder.AddOrReplaceVariable(
MakeVariableDecl("single_int64", StringType())),
IsOk());

ASSERT_OK_AND_ASSIGN(std::unique_ptr<TypeChecker> type_checker,
builder.Build());
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("single_int64"));
ASSERT_OK_AND_ASSIGN(ValidationResult result,
type_checker->Check(std::move(ast)));

ASSERT_TRUE(result.IsValid());

const auto& ast_impl = AstImpl::CastFromPublicAst(*result.GetAst());

EXPECT_EQ(ast_impl.GetReturnType(),
AstType(ast_internal::PrimitiveType::kString));
}

} // namespace
} // namespace cel::checker_internal
20 changes: 13 additions & 7 deletions checker/type_checker_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ class TypeCheckerBuilder {
// with the resulting type checker.
virtual absl::Status AddVariable(const VariableDecl& decl) = 0;

// Adds a variable declaration that may be referenced in expressions checked
// with the resulting type checker.
//
// This version replaces any existing variable declaration with the same name.
virtual absl::Status AddOrReplaceVariable(const VariableDecl& decl) = 0;

// Declares struct type by fully qualified name as a context declaration.
//
// Context declarations are a way to declare a group of variables based on the
Expand All @@ -100,6 +106,13 @@ class TypeCheckerBuilder {
// with the resulting TypeChecker.
virtual absl::Status AddFunction(const FunctionDecl& decl) = 0;

// Adds function declaration overloads to the TypeChecker being built.
//
// Attempts to merge with any existing overloads for a function decl with the
// same name. If the overloads are not compatible, an error is returned and
// no change is made.
virtual absl::Status MergeFunction(const FunctionDecl& decl) = 0;

// Sets the expected type for checked expressions.
//
// Validation will fail with an ERROR level issue if the deduced type of the
Expand All @@ -108,13 +121,6 @@ class TypeCheckerBuilder {
// Note: if set multiple times, the last value is used.
virtual void SetExpectedType(const Type& type) = 0;

// Adds function declaration overloads to the TypeChecker being built.
//
// Attempts to merge with any existing overloads for a function decl with the
// same name. If the overloads are not compatible, an error is returned and
// no change is made.
virtual absl::Status MergeFunction(const FunctionDecl& decl) = 0;

// Adds a type provider to the TypeChecker being built.
//
// Type providers are used to describe custom types with typed field
Expand Down