Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions checker/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ cc_library(
srcs = ["type_inference_context.cc"],
hdrs = ["type_inference_context.h"],
deps = [
":format_type_name",
"//common:decl",
"//common:type",
"//common:type_kind",
Expand Down
70 changes: 52 additions & 18 deletions checker/internal/type_inference_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "checker/internal/format_type_name.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_kind.h"
Expand Down Expand Up @@ -267,14 +270,15 @@ bool TypeInferenceContext::IsAssignableInternal(
// Checking assignability to a specific type var
// that has a prospective type assignment.
to.kind() == TypeKind::kTypeParam &&
prospective_substitutions.contains(to.AsTypeParam()->name())) {
auto prospective_subs_cpy(prospective_substitutions);
prospective_substitutions.contains(to.GetTypeParam().name())) {
SubstitutionMap prospective_subs_cpy = prospective_substitutions;
if (CompareGenerality(from_subs, to_subs, prospective_subs_cpy) ==
RelativeGenerality::kMoreGeneral) {
if (IsAssignableInternal(to_subs, from_subs, prospective_subs_cpy) &&
!OccursWithin(to.name(), from_subs, prospective_subs_cpy)) {
prospective_subs_cpy[to.AsTypeParam()->name()] = from_subs;
prospective_substitutions = prospective_subs_cpy;
!OccursWithin(to.GetTypeParam().name(), from_subs,
prospective_subs_cpy)) {
prospective_subs_cpy[to.GetTypeParam().name()] = from_subs;
prospective_substitutions = std::move(prospective_subs_cpy);
return true;
// otherwise, continue with normal assignability check.
}
Expand Down Expand Up @@ -454,17 +458,35 @@ bool TypeInferenceContext::OccursWithin(
//
// This check guarantees that we don't introduce a recursive type definition
// (a cycle in the substitution map).
if (type.kind() == TypeKind::kTypeParam) {
if (type.AsTypeParam()->name() == var_name) {
//
// We can't reuse Substitute here because it does the pointer chasing and
// might hide a cycle.
//
// E.g.
// T2 in T3 when
// T3 -> T2 -> null_type;
Type substitution = type;
while (substitution.kind() == TypeKind::kTypeParam) {
absl::string_view param_name = substitution.AsTypeParam()->name();
if (param_name == var_name) {
return true;
}
auto typeSubs = Substitute(type, substitutions);
if (typeSubs != type && OccursWithin(var_name, typeSubs, substitutions)) {
return true;

if (auto it = substitutions.find(param_name); it != substitutions.end()) {
substitution = it->second;
continue;
}
if (auto it = type_parameter_bindings_.find(param_name);
it != type_parameter_bindings_.end() && it->second.type.has_value()) {
substitution = it->second.type.value();
continue;
}

// Type parameter is free.
return false;
}

for (const auto& param : type.GetParameters()) {
for (const auto& param : substitution.GetParameters()) {
if (OccursWithin(var_name, param, substitutions)) {
return true;
}
Expand Down Expand Up @@ -526,19 +548,18 @@ TypeInferenceContext::ResolveOverload(const FunctionDecl& decl,
ABSL_DCHECK_EQ(argument_types.size(),
call_type_instance.param_types.size());
bool is_match = true;
SubstitutionMap prospective_substitutions;
AssignabilityContext assignability_context = CreateAssignabilityContext();
for (int i = 0; i < argument_types.size(); ++i) {
if (!IsAssignableInternal(argument_types[i],
call_type_instance.param_types[i],
prospective_substitutions)) {
if (!assignability_context.IsAssignable(
argument_types[i], call_type_instance.param_types[i])) {
is_match = false;
break;
}
}

if (is_match) {
matching_overloads.push_back(ovl);
UpdateTypeParameterBindings(prospective_substitutions);
assignability_context.UpdateInferredTypeAssignments();
if (!result_type.has_value()) {
result_type = call_type_instance.result_type;
} else {
Expand Down Expand Up @@ -625,10 +646,23 @@ bool TypeInferenceContext::AssignabilityContext::IsAssignable(const Type& from,
prospective_substitutions_);
}

std::string TypeInferenceContext::DebugString() const {
return absl::StrCat(
"type_parameter_bindings: ",
absl::StrJoin(
type_parameter_bindings_, "\n ",
[](std::string* out, const auto& binding) {
absl::StrAppend(
out, binding.first, " (", binding.second.name, ") -> ",
checker_internal::FormatTypeName(
binding.second.type.value_or(Type(TypeParamType("none")))));
}));
}

void TypeInferenceContext::AssignabilityContext::
UpdateInferredTypeAssignments() {
inference_context_.UpdateTypeParameterBindings(
std::move(prospective_substitutions_));
inference_context_.UpdateTypeParameterBindings(prospective_substitutions_);
prospective_substitutions_.clear();
}

void TypeInferenceContext::AssignabilityContext::Reset() {
Expand Down
14 changes: 1 addition & 13 deletions checker/internal/type_inference_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include "absl/container/node_hash_map.h"
#include "absl/log/absl_check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -141,18 +140,7 @@ class TypeInferenceContext {
// Checks if `from` is assignable to `to`.
bool IsAssignable(const Type& from, const Type& to);

std::string DebugString() const {
return absl::StrCat(
"type_parameter_bindings: ",
absl::StrJoin(
type_parameter_bindings_, "\n ",
[](std::string* out, const auto& binding) {
absl::StrAppend(
out, binding.first, " (", binding.second.name, ") -> ",
binding.second.type.value_or(Type(TypeParamType("none")))
.DebugString());
}));
}
std::string DebugString() const;

private:
struct TypeVar {
Expand Down
10 changes: 10 additions & 0 deletions checker/optional_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ INSTANTIATE_TEST_SUITE_P(
"optional.none()",
IsOptionalType(TypeSpec(DynTypeSpec())),
},
// Odd case -- the correct result might be a bespoke recursively-defined
// type but CEL doesn't support that. Null is used because it is
// implicitly assignable to optional types. This allows for a recursive
// type to be non-trivial and verify the checker is actually avoiding
// introducing a cyclic type.
TestCase{
"[optional.none()].map(x, [?x, null, x])",
Eq(TypeSpec(ListTypeSpec(std::make_unique<TypeSpec>(
ListTypeSpec(std::make_unique<TypeSpec>(NullTypeSpec())))))),
},
TestCase{
"optional.of('abc').hasValue()",
Eq(TypeSpec(PrimitiveType::kBool)),
Expand Down