Skip to content

Commit

Permalink
Fix runaway recursion when inferring cyclic types
Browse files Browse the repository at this point in the history
Instead of one-off hacks, we now simply limit the recursion depth when
trying to infer/resolve a type into it's final type. There are various
ways (due to the use of type placeholders and type parameters) one can
introduce cyclic types, and simply limiting the recursion depth is the
only consistent way of handling this (short of not supporting type
inference).

Changelog: fixed
  • Loading branch information
yorickpeterse committed Sep 29, 2022
1 parent 923cb3a commit a620693
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 66 deletions.
4 changes: 2 additions & 2 deletions compiler/src/mir/pattern_matching.rs
Expand Up @@ -799,7 +799,7 @@ impl<'a> Compiler<'a> {
return types.into_iter().map(|t| self.new_variable(t)).collect();
}

let ctx = TypeContext::with_arguments(
let mut ctx = TypeContext::with_arguments(
self.self_type,
instance.type_arguments(self.db()).clone(),
);
Expand All @@ -808,7 +808,7 @@ impl<'a> Compiler<'a> {
.into_iter()
.map(|raw_type| {
let inferred = raw_type
.inferred(self.db_mut(), &ctx, false)
.inferred(self.db_mut(), &mut ctx, false)
.cast_according_to(source_variable_type, self.db());

self.new_variable(inferred)
Expand Down
32 changes: 16 additions & 16 deletions compiler/src/type_check/expressions.rs
Expand Up @@ -487,7 +487,7 @@ impl MethodCall {
) -> TypeRef {
let typ = self.method.throw_type(&state.db).inferred(
&mut state.db,
&self.context,
&mut self.context,
false,
);

Expand All @@ -512,7 +512,7 @@ impl MethodCall {
) -> TypeRef {
let typ = self.method.return_type(&state.db).inferred(
&mut state.db,
&self.context,
&mut self.context,
false,
);

Expand Down Expand Up @@ -2111,7 +2111,7 @@ impl<'a> CheckMethodBody<'a> {
}

let immutable = value_type.is_ref(self.db());
let ctx = TypeContext::for_class_instance(self.db(), ins);
let mut ctx = TypeContext::for_class_instance(self.db(), ins);

for node in &mut node.values {
let name = &node.field.name;
Expand Down Expand Up @@ -2139,7 +2139,7 @@ impl<'a> CheckMethodBody<'a> {

let field_type = field
.value_type(self.db())
.inferred(self.db_mut(), &ctx, immutable)
.inferred(self.db_mut(), &mut ctx, immutable)
.cast_according_to(value_type, self.db());

node.field_id = Some(field);
Expand Down Expand Up @@ -2283,7 +2283,7 @@ impl<'a> CheckMethodBody<'a> {

for (patt, member) in node.values.iter_mut().zip(members.into_iter()) {
let typ = member
.inferred(self.db_mut(), &ctx, immutable)
.inferred(self.db_mut(), &mut ctx, immutable)
.cast_according_to(value_type, self.db());

self.pattern(patt, typ, pattern);
Expand Down Expand Up @@ -2456,7 +2456,7 @@ impl<'a> CheckMethodBody<'a> {
fn closure(
&mut self,
node: &mut hir::Closure,
expected: Option<(ClosureId, TypeRef, &mut TypeContext)>,
mut expected: Option<(ClosureId, TypeRef, &mut TypeContext)>,
scope: &mut LexicalScope,
) -> TypeRef {
let self_type = self.self_type;
Expand All @@ -2472,9 +2472,9 @@ impl<'a> CheckMethodBody<'a> {
let db = &mut self.state.db;

expected
.as_ref()
.as_mut()
.map(|(id, _, context)| {
id.throw_type(db).inferred(db, context, false)
id.throw_type(db).inferred(db, *context, false)
})
.unwrap_or_else(|| TypeRef::placeholder(self.db_mut()))
};
Expand All @@ -2485,9 +2485,9 @@ impl<'a> CheckMethodBody<'a> {
let db = &mut self.state.db;

expected
.as_ref()
.as_mut()
.map(|(id, _, context)| {
id.return_type(db).inferred(db, context, false)
id.return_type(db).inferred(db, *context, false)
})
.unwrap_or_else(|| TypeRef::placeholder(self.db_mut()))
};
Expand Down Expand Up @@ -2520,7 +2520,7 @@ impl<'a> CheckMethodBody<'a> {
let db = &mut self.state.db;

expected
.as_ref()
.as_mut()
.and_then(|(id, _, context)| {
id.positional_argument_input_type(db, index)
.map(|t| t.inferred(db, context, false))
Expand Down Expand Up @@ -3337,7 +3337,7 @@ impl<'a> CheckMethodBody<'a> {
);

let var_type =
var_type.inferred(self.db_mut(), &ctx, false);
var_type.inferred(self.db_mut(), &mut ctx, false);

node.kind = CallKind::SetField(FieldInfo {
id: field,
Expand Down Expand Up @@ -3482,12 +3482,12 @@ impl<'a> CheckMethodBody<'a> {
let throws = closure
.throw_type(self.db())
.as_rigid_type(&mut self.state.db, self.bounds)
.inferred(self.db_mut(), &ctx, false);
.inferred(self.db_mut(), &mut ctx, false);

let returns = closure
.return_type(self.db())
.as_rigid_type(&mut self.state.db, self.bounds)
.inferred(self.db_mut(), &ctx, false);
.inferred(self.db_mut(), &mut ctx, false);

if let Some(block) = node.else_block.as_mut() {
if throws.is_never(self.db()) {
Expand Down Expand Up @@ -3576,15 +3576,15 @@ impl<'a> CheckMethodBody<'a> {
.copy_into(&mut ctx.type_arguments);

let returns = if raw_typ.is_owned_or_uni(db) {
let typ = raw_typ.inferred(db, &ctx, false);
let typ = raw_typ.inferred(db, &mut ctx, false);

if receiver.is_ref(db) {
typ.as_ref(db)
} else {
typ.as_mut(db)
}
} else {
raw_typ.inferred(db, &ctx, receiver.is_ref(db))
raw_typ.inferred(db, &mut ctx, receiver.is_ref(db))
};

if receiver.require_sendable_arguments(self.db())
Expand Down
115 changes: 67 additions & 48 deletions types/src/lib.rs
Expand Up @@ -85,6 +85,17 @@ pub const FIELDS_LIMIT: usize = u8::MAX as usize;

const MAX_FORMATTING_DEPTH: usize = 8;

/// The maximum recursion/depth to restrict ourselves to when inferring types or
/// checking if they are inferred.
///
/// In certain cases we may end up with cyclic types, where the cycles are
/// non-trivial (e.g. `A -> B -> C -> D -> A`). To prevent runaway recursion we
/// limit such operations to a certain depth.
///
/// The depth here is sufficiently large that no sane program should run into
/// it, but we also won't blow the stack.
const MAX_TYPE_DEPTH: usize = 64;

pub fn format_type<T: FormatType>(db: &Database, typ: T) -> String {
TypeFormatter::new(db, None, None).format(typ)
}
Expand Down Expand Up @@ -324,6 +335,14 @@ impl TypePlaceholderId {
}

fn assign(self, db: &Database, value: TypeRef) {
// Assigning placeholders to themselves creates cycles that aren't
// useful, so we ignore those.
if let TypeRef::Placeholder(id) = value {
if id.0 == self.0 {
return;
}
}

self.get(db).value.set(value);

for &mut id in self.depending(db) {
Expand Down Expand Up @@ -371,11 +390,21 @@ pub struct TypeContext {
/// When type-checking a method call, this table contains the type
/// parameters and values of both the receiver and the method itself.
pub type_arguments: TypeArguments,

/// The nesting/recursion depth when e.g. inferring a type.
///
/// This value is used to prevent runaway recursion that can occur when
/// dealing with (complex) cyclic types.
depth: usize,
}

impl TypeContext {
pub fn new(self_type_id: TypeId) -> Self {
Self { self_type: self_type_id, type_arguments: TypeArguments::new() }
Self {
self_type: self_type_id,
type_arguments: TypeArguments::new(),
depth: 0,
}
}

pub fn for_class_instance(db: &Database, instance: ClassInstance) -> Self {
Expand All @@ -385,14 +414,18 @@ impl TypeContext {
TypeArguments::new()
};

Self { self_type: TypeId::ClassInstance(instance), type_arguments }
Self {
self_type: TypeId::ClassInstance(instance),
type_arguments,
depth: 0,
}
}

pub fn with_arguments(
self_type_id: TypeId,
type_arguments: TypeArguments,
) -> Self {
Self { self_type: self_type_id, type_arguments }
Self { self_type: self_type_id, type_arguments, depth: 0 }
}
}

Expand Down Expand Up @@ -1081,7 +1114,7 @@ impl TraitInstance {
fn inferred(
self,
db: &mut Database,
context: &TypeContext,
context: &mut TypeContext,
immutable: bool,
) -> Self {
if !self.instance_of.is_generic(db) {
Expand All @@ -1091,25 +1124,7 @@ impl TraitInstance {
let mut new_args = TypeArguments::new();

for (arg, val) in self.type_arguments(db).pairs() {
let new_val = match val {
TypeRef::Placeholder(id) => match id.value(db) {
Some(TypeRef::Owned(TypeId::TraitInstance(ins)))
| Some(TypeRef::Mut(TypeId::TraitInstance(ins)))
| Some(TypeRef::Ref(TypeId::TraitInstance(ins)))
| Some(TypeRef::Uni(TypeId::TraitInstance(ins)))
| Some(TypeRef::RefUni(TypeId::TraitInstance(ins)))
| Some(TypeRef::MutUni(TypeId::TraitInstance(ins)))
| Some(TypeRef::Infer(TypeId::TraitInstance(ins)))
if ins == self =>
{
TypeRef::Unknown
}
_ => val.inferred(db, context, immutable),
},
_ => val.inferred(db, context, immutable),
};

new_args.assign(arg, new_val);
new_args.assign(arg, val.inferred(db, context, immutable));
}

Self::generic(db, self.instance_of, new_args)
Expand Down Expand Up @@ -1952,7 +1967,7 @@ impl ClassInstance {
fn inferred(
self,
db: &mut Database,
context: &TypeContext,
context: &mut TypeContext,
immutable: bool,
) -> Self {
if !self.instance_of.is_generic(db) {
Expand All @@ -1962,25 +1977,7 @@ impl ClassInstance {
let mut new_args = TypeArguments::new();

for (param, val) in self.type_arguments(db).pairs() {
let new_val = match val {
TypeRef::Placeholder(id) => match id.value(db) {
Some(TypeRef::Owned(TypeId::ClassInstance(ins)))
| Some(TypeRef::Mut(TypeId::ClassInstance(ins)))
| Some(TypeRef::Ref(TypeId::ClassInstance(ins)))
| Some(TypeRef::Uni(TypeId::ClassInstance(ins)))
| Some(TypeRef::RefUni(TypeId::ClassInstance(ins)))
| Some(TypeRef::MutUni(TypeId::ClassInstance(ins)))
| Some(TypeRef::Infer(TypeId::ClassInstance(ins)))
if ins == self =>
{
TypeRef::Unknown
}
_ => val.inferred(db, context, immutable),
},
_ => val.inferred(db, context, immutable),
};

new_args.assign(param, new_val);
new_args.assign(param, val.inferred(db, context, immutable));
}

Self::generic(db, self.instance_of, new_args)
Expand Down Expand Up @@ -3137,7 +3134,7 @@ impl ClosureId {
fn inferred(
self,
db: &mut Database,
context: &TypeContext,
context: &mut TypeContext,
immutable: bool,
) -> Self {
let mut new_func = self.get(db).clone();
Expand Down Expand Up @@ -3790,10 +3787,16 @@ impl TypeRef {
pub fn inferred(
self,
db: &mut Database,
context: &TypeContext,
context: &mut TypeContext,
immutable: bool,
) -> TypeRef {
match self {
if context.depth == MAX_TYPE_DEPTH {
return TypeRef::Unknown;
}

context.depth += 1;

let result = match self {
TypeRef::OwnedSelf => TypeRef::owned_or_ref(
self.infer_self_type_id(db, context),
immutable,
Expand Down Expand Up @@ -3955,7 +3958,10 @@ impl TypeRef {
|| if immutable { self.as_ref(db) } else { self },
),
_ => self,
}
};

context.depth -= 1;
result
}

pub fn as_enum_instance(
Expand Down Expand Up @@ -4546,7 +4552,7 @@ impl TypeRef {
self,
type_parameter: TypeParameterId,
db: &mut Database,
context: &TypeContext,
context: &mut TypeContext,
immutable: bool,
) -> TypeRef {
if let Some(arg) = context.type_arguments.get(type_parameter) {
Expand Down Expand Up @@ -8484,4 +8490,17 @@ mod tests {

db.module("foo");
}

#[test]
fn test_type_placeholder_id_assign() {
let mut db = Database::new();
let p1 = TypePlaceholder::alloc(&mut db);
let p2 = TypePlaceholder::alloc(&mut db);

p1.assign(&db, TypeRef::Any);
p2.assign(&db, TypeRef::Placeholder(p2));

assert_eq!(p1.value(&db), Some(TypeRef::Any));
assert!(p2.value(&db).is_none());
}
}

0 comments on commit a620693

Please sign in to comment.