From 033986e74e54d1aae8d33a8db9fbe418852ee58c Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Thu, 9 Oct 2025 13:10:15 -0700 Subject: [PATCH] [checker] Fix bug where recursive type could be inferred The cycle check during inference had a bug where certain recursive definitions could be inferred. This would lead to the checker eventually crashing when it tried to realize a concrete type from the substitution map. ex. `[optional.none()].map(x, [?x, null, x])` PiperOrigin-RevId: 817307359 --- checker/internal/BUILD | 1 + checker/internal/type_inference_context.cc | 70 ++++++++++++++++------ checker/internal/type_inference_context.h | 14 +---- checker/optional_test.cc | 10 ++++ 4 files changed, 64 insertions(+), 31 deletions(-) diff --git a/checker/internal/BUILD b/checker/internal/BUILD index a264628d8..0f8f28f66 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -218,6 +218,7 @@ cc_library( srcs = ["type_inference_context.cc"], hdrs = ["type_inference_context.h"], deps = [ + ":format_type_name", "//common:decl", "//common:type", "//common:type_kind", diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index dd43be990..96d985071 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -23,9 +23,12 @@ #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "checker/internal/format_type_name.h" #include "common/decl.h" #include "common/type.h" #include "common/type_kind.h" @@ -267,14 +270,15 @@ bool TypeInferenceContext::IsAssignableInternal( // Checking assignability to a specific type var // that has a prospective type assignment. to.kind() == TypeKind::kTypeParam && - prospective_substitutions.contains(to.AsTypeParam()->name())) { - auto prospective_subs_cpy(prospective_substitutions); + prospective_substitutions.contains(to.GetTypeParam().name())) { + SubstitutionMap prospective_subs_cpy = prospective_substitutions; if (CompareGenerality(from_subs, to_subs, prospective_subs_cpy) == RelativeGenerality::kMoreGeneral) { if (IsAssignableInternal(to_subs, from_subs, prospective_subs_cpy) && - !OccursWithin(to.name(), from_subs, prospective_subs_cpy)) { - prospective_subs_cpy[to.AsTypeParam()->name()] = from_subs; - prospective_substitutions = prospective_subs_cpy; + !OccursWithin(to.GetTypeParam().name(), from_subs, + prospective_subs_cpy)) { + prospective_subs_cpy[to.GetTypeParam().name()] = from_subs; + prospective_substitutions = std::move(prospective_subs_cpy); return true; // otherwise, continue with normal assignability check. } @@ -454,17 +458,35 @@ bool TypeInferenceContext::OccursWithin( // // This check guarantees that we don't introduce a recursive type definition // (a cycle in the substitution map). - if (type.kind() == TypeKind::kTypeParam) { - if (type.AsTypeParam()->name() == var_name) { + // + // We can't reuse Substitute here because it does the pointer chasing and + // might hide a cycle. + // + // E.g. + // T2 in T3 when + // T3 -> T2 -> null_type; + Type substitution = type; + while (substitution.kind() == TypeKind::kTypeParam) { + absl::string_view param_name = substitution.AsTypeParam()->name(); + if (param_name == var_name) { return true; } - auto typeSubs = Substitute(type, substitutions); - if (typeSubs != type && OccursWithin(var_name, typeSubs, substitutions)) { - return true; + + if (auto it = substitutions.find(param_name); it != substitutions.end()) { + substitution = it->second; + continue; + } + if (auto it = type_parameter_bindings_.find(param_name); + it != type_parameter_bindings_.end() && it->second.type.has_value()) { + substitution = it->second.type.value(); + continue; } + + // Type parameter is free. + return false; } - for (const auto& param : type.GetParameters()) { + for (const auto& param : substitution.GetParameters()) { if (OccursWithin(var_name, param, substitutions)) { return true; } @@ -526,11 +548,10 @@ TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, ABSL_DCHECK_EQ(argument_types.size(), call_type_instance.param_types.size()); bool is_match = true; - SubstitutionMap prospective_substitutions; + AssignabilityContext assignability_context = CreateAssignabilityContext(); for (int i = 0; i < argument_types.size(); ++i) { - if (!IsAssignableInternal(argument_types[i], - call_type_instance.param_types[i], - prospective_substitutions)) { + if (!assignability_context.IsAssignable( + argument_types[i], call_type_instance.param_types[i])) { is_match = false; break; } @@ -538,7 +559,7 @@ TypeInferenceContext::ResolveOverload(const FunctionDecl& decl, if (is_match) { matching_overloads.push_back(ovl); - UpdateTypeParameterBindings(prospective_substitutions); + assignability_context.UpdateInferredTypeAssignments(); if (!result_type.has_value()) { result_type = call_type_instance.result_type; } else { @@ -625,10 +646,23 @@ bool TypeInferenceContext::AssignabilityContext::IsAssignable(const Type& from, prospective_substitutions_); } +std::string TypeInferenceContext::DebugString() const { + return absl::StrCat( + "type_parameter_bindings: ", + absl::StrJoin( + type_parameter_bindings_, "\n ", + [](std::string* out, const auto& binding) { + absl::StrAppend( + out, binding.first, " (", binding.second.name, ") -> ", + checker_internal::FormatTypeName( + binding.second.type.value_or(Type(TypeParamType("none"))))); + })); +} + void TypeInferenceContext::AssignabilityContext:: UpdateInferredTypeAssignments() { - inference_context_.UpdateTypeParameterBindings( - std::move(prospective_substitutions_)); + inference_context_.UpdateTypeParameterBindings(prospective_substitutions_); + prospective_substitutions_.clear(); } void TypeInferenceContext::AssignabilityContext::Reset() { diff --git a/checker/internal/type_inference_context.h b/checker/internal/type_inference_context.h index 644e87d9a..1a1043047 100644 --- a/checker/internal/type_inference_context.h +++ b/checker/internal/type_inference_context.h @@ -23,7 +23,6 @@ #include "absl/container/node_hash_map.h" #include "absl/log/absl_check.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -141,18 +140,7 @@ class TypeInferenceContext { // Checks if `from` is assignable to `to`. bool IsAssignable(const Type& from, const Type& to); - std::string DebugString() const { - return absl::StrCat( - "type_parameter_bindings: ", - absl::StrJoin( - type_parameter_bindings_, "\n ", - [](std::string* out, const auto& binding) { - absl::StrAppend( - out, binding.first, " (", binding.second.name, ") -> ", - binding.second.type.value_or(Type(TypeParamType("none"))) - .DebugString()); - })); - } + std::string DebugString() const; private: struct TypeVar { diff --git a/checker/optional_test.cc b/checker/optional_test.cc index be05eccd8..8285e51df 100644 --- a/checker/optional_test.cc +++ b/checker/optional_test.cc @@ -158,6 +158,16 @@ INSTANTIATE_TEST_SUITE_P( "optional.none()", IsOptionalType(TypeSpec(DynTypeSpec())), }, + // Odd case -- the correct result might be a bespoke recursively-defined + // type but CEL doesn't support that. Null is used because it is + // implicitly assignable to optional types. This allows for a recursive + // type to be non-trivial and verify the checker is actually avoiding + // introducing a cyclic type. + TestCase{ + "[optional.none()].map(x, [?x, null, x])", + Eq(TypeSpec(ListTypeSpec(std::make_unique( + ListTypeSpec(std::make_unique(NullTypeSpec())))))), + }, TestCase{ "optional.of('abc').hasValue()", Eq(TypeSpec(PrimitiveType::kBool)),