Skip to content

Commit

Permalink
Make type checking return types sound
Browse files Browse the repository at this point in the history
When a method returns a type parameter or a type containing a type
parameter, without an explicit ownership, we no longer consider
references to the type parameter to be compatible with it. This means
code such as this is no longer valid:

    class Box[V] {
      let @value: V

      fn foo -> V {
        @value
      }
    }

Without these changes the above code is unsound, as for `Box[User]` the
return type of `Box.foo` is inferred as `User`, when in reality a
`ref User` is returned.

Changelog: fixed
  • Loading branch information
yorickpeterse committed Dec 13, 2023
1 parent 7138301 commit 6f2341e
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 36 deletions.
8 changes: 4 additions & 4 deletions compiler/src/type_check/expressions.rs
Expand Up @@ -1324,7 +1324,7 @@ impl<'a> CheckMethodBody<'a> {
return;
}

if !TypeChecker::check(self.db(), typ, returns) {
if !TypeChecker::check_return(self.db(), typ, returns) {
let loc =
nodes.last().map(|n| n.location()).unwrap_or(fallback_location);

Expand Down Expand Up @@ -3071,7 +3071,7 @@ impl<'a> CheckMethodBody<'a> {

let expected = scope.return_type;

if !TypeChecker::check(self.db(), returned, expected) {
if !TypeChecker::check_return(self.db(), returned, expected) {
self.state.diagnostics.type_error(
format_type(self.db(), returned),
format_type(self.db(), expected),
Expand Down Expand Up @@ -3117,7 +3117,7 @@ impl<'a> CheckMethodBody<'a> {
pid.assign(self.db(), typ);
}
ThrowKind::Result(ret_ok, ret_err) => {
if !TypeChecker::check(self.db(), throw_type, ret_err) {
if !TypeChecker::check_return(self.db(), throw_type, ret_err) {
self.state.diagnostics.invalid_throw(
ThrowKind::Result(ret_ok, expr)
.throw_type_name(self.db(), ret_ok),
Expand Down Expand Up @@ -4044,7 +4044,7 @@ impl<'a> CheckMethodBody<'a> {
ThrowKind::Result(ok, expr_err),
ThrowKind::Result(ret_ok, ret_err),
) => {
if TypeChecker::check(self.db(), expr_err, ret_err) {
if TypeChecker::check_return(self.db(), expr_err, ret_err) {
return ok;
}

Expand Down
2 changes: 1 addition & 1 deletion std/test/diagnostics/assign_async_field_recover_owned.inko
Expand Up @@ -2,7 +2,7 @@ class async Thing[A, B] {
let @a: A
let @b: B

fn static new(a: A, b: uni B) -> Thing[A, B] {
fn static new(a: A, b: uni B) -> Thing[A, move B] {
Thing { @a = a, @b = recover b }
}
}
Expand Down
20 changes: 20 additions & 0 deletions std/test/diagnostics/return_when_any_is_expected.inko
@@ -0,0 +1,20 @@
class Box[V] {
let @value: V

fn foo -> V {
@value
}

fn move bar -> V {
@value
}
}

impl Box if V: mut {
fn mut baz -> V {
@value
}
}

# return_when_any_is_expected.inko:5:5 error(invalid-type): expected a value of type 'V', found 'ref V'
# return_when_any_is_expected.inko:15:5 error(invalid-type): expected a value of type 'V: mut', found 'mut V: mut'
137 changes: 106 additions & 31 deletions types/src/check.rs
Expand Up @@ -12,6 +12,28 @@ enum Subtyping {
Once,
}

#[derive(Copy, Clone)]
enum Kind {
/// A regular type check.
Regular,

/// A type check as part of a type cast.
Cast,

/// A type check for a return value.
Return,
}

impl Kind {
fn is_return(self) -> bool {
matches!(self, Kind::Return)
}

fn is_cast(self) -> bool {
matches!(self, Kind::Cast)
}
}

#[derive(Copy, Clone)]
struct Rules {
/// The rules to apply when performing sub-typing checks.
Expand All @@ -28,10 +50,8 @@ struct Rules {
/// contexts (e.g. when comparing trait implementations).
rigid_parameters: bool,

/// If we're performing a type-check as part of a type cast.
///
/// When enabled, certain type-checking rules may be relaxed.
type_cast: bool,
/// What kind of type check we're performing.
kind: Kind,
}

impl Rules {
Expand All @@ -41,7 +61,7 @@ impl Rules {
implicit_root_ref: false,
uni_compatible_with_owned: true,
rigid_parameters: false,
type_cast: false,
kind: Kind::Regular,
}
}

Expand All @@ -63,8 +83,8 @@ impl Rules {
self
}

fn with_type_cast(mut self) -> Rules {
self.type_cast = true;
fn with_kind(mut self, kind: Kind) -> Rules {
self.kind = kind;
self
}

Expand Down Expand Up @@ -143,7 +163,20 @@ impl<'a> TypeChecker<'a> {
let mut env =
Environment::new(left.type_arguments(db), right.type_arguments(db));

let rules = Rules::new().with_type_cast().with_one_time_subtyping();
let rules =
Rules::new().with_kind(Kind::Cast).with_one_time_subtyping();

TypeChecker::new(db).check_type_ref(left, right, &mut env, rules)
}

pub fn check_return(
db: &'a Database,
left: TypeRef,
right: TypeRef,
) -> bool {
let rules = Rules::new().with_kind(Kind::Return);
let mut env =
Environment::new(left.type_arguments(db), right.type_arguments(db));

TypeChecker::new(db).check_type_ref(left, right, &mut env, rules)
}
Expand Down Expand Up @@ -317,7 +350,10 @@ impl<'a> TypeChecker<'a> {
_ => true,
},
TypeRef::Owned(left_id) => match right {
TypeRef::Owned(right_id) | TypeRef::Any(right_id) => {
TypeRef::Any(right_id) if !rules.kind.is_return() => {
self.check_type_id(left_id, right_id, env, rules)
}
TypeRef::Owned(right_id) => {
self.check_type_id(left_id, right_id, env, rules)
}
TypeRef::Ref(right_id) | TypeRef::Mut(right_id)
Expand All @@ -343,7 +379,7 @@ impl<'a> TypeChecker<'a> {
left, left_id, orig_right, id, env, rules,
)
}
TypeRef::Pointer(_) if rules.type_cast => match left_id {
TypeRef::Pointer(_) if rules.kind.is_cast() => match left_id {
TypeId::ClassInstance(ins) => ins.instance_of().0 == INT_ID,
TypeId::Foreign(ForeignType::Int(_, _)) => true,
_ => false,
Expand All @@ -357,7 +393,10 @@ impl<'a> TypeChecker<'a> {
{
self.check_type_id(left_id, right_id, env, rules)
}
TypeRef::Any(right_id) | TypeRef::Uni(right_id) => {
TypeRef::Any(right_id) if !rules.kind.is_return() => {
self.check_type_id(left_id, right_id, env, rules)
}
TypeRef::Uni(right_id) => {
self.check_type_id(left_id, right_id, env, rules)
}
TypeRef::Ref(right_id) | TypeRef::Mut(right_id) if is_val => {
Expand Down Expand Up @@ -403,7 +442,10 @@ impl<'a> TypeChecker<'a> {
{
false
}
TypeRef::Ref(right_id) | TypeRef::Any(right_id) => {
TypeRef::Any(right_id) if !rules.kind.is_return() => {
self.check_type_id(left_id, right_id, env, rules)
}
TypeRef::Ref(right_id) => {
self.check_type_id(left_id, right_id, env, rules)
}
TypeRef::Owned(right_id)
Expand Down Expand Up @@ -436,7 +478,10 @@ impl<'a> TypeChecker<'a> {
_ => false,
},
TypeRef::Mut(left_id) => match right {
TypeRef::Ref(right_id) | TypeRef::Any(right_id) => {
TypeRef::Any(right_id) if !rules.kind.is_return() => {
self.check_type_id(left_id, right_id, env, rules)
}
TypeRef::Ref(right_id) => {
self.check_type_id(left_id, right_id, env, rules)
}
TypeRef::Mut(right_id) => self.check_type_id(
Expand Down Expand Up @@ -509,7 +554,8 @@ impl<'a> TypeChecker<'a> {
}
}
(Any, _) => true,
(Owned, TypeRef::Owned(_) | TypeRef::Any(_)) => true,
(Owned, TypeRef::Any(_)) => !rules.kind.is_return(),
(Owned, TypeRef::Owned(_)) => true,
(Owned, TypeRef::Ref(_) | TypeRef::Mut(_)) => {
allow_ref || rval
}
Expand All @@ -521,15 +567,14 @@ impl<'a> TypeChecker<'a> {
(Ref, TypeRef::Any(TypeId::TypeParameter(pid))) => {
!pid.is_mutable(self.db) || rval
}
(Ref, TypeRef::Ref(_) | TypeRef::Any(_)) => true,
(Ref, TypeRef::Any(_)) => !rules.kind.is_return(),
(Ref, TypeRef::Ref(_)) => true,
(
Ref,
TypeRef::Owned(_) | TypeRef::Uni(_) | TypeRef::Mut(_),
) => rval,
(
Mut,
TypeRef::Any(_) | TypeRef::Ref(_) | TypeRef::Mut(_),
) => true,
(Mut, TypeRef::Any(_)) => !rules.kind.is_return(),
(Mut, TypeRef::Ref(_) | TypeRef::Mut(_)) => true,
(Mut, TypeRef::Owned(_) | TypeRef::Uni(_)) => rval,
_ => false,
};
Expand All @@ -542,14 +587,14 @@ impl<'a> TypeChecker<'a> {
}
TypeRef::Pointer(left_id) => match right {
TypeRef::Pointer(right_id) => {
rules.type_cast
rules.kind.is_cast()
|| self.check_type_id(left_id, right_id, env, rules)
}
TypeRef::Owned(TypeId::Foreign(ForeignType::Int(_, _))) => {
rules.type_cast
rules.kind.is_cast()
}
TypeRef::Owned(TypeId::ClassInstance(ins)) => {
rules.type_cast && ins.instance_of().0 == INT_ID
rules.kind.is_cast() && ins.instance_of().0 == INT_ID
}
TypeRef::Placeholder(right_id) => {
match right_id.ownership {
Expand Down Expand Up @@ -591,7 +636,7 @@ impl<'a> TypeChecker<'a> {
TypeId::ClassInstance(lhs) => match right_id {
TypeId::ClassInstance(rhs) => {
if lhs.instance_of != rhs.instance_of {
if rules.type_cast
if rules.kind.is_cast()
&& lhs.instance_of.is_numeric()
&& rhs.instance_of.is_numeric()
{
Expand Down Expand Up @@ -622,14 +667,15 @@ impl<'a> TypeChecker<'a> {
TypeId::TraitInstance(rhs)
if !lhs.instance_of().kind(self.db).is_extern() =>
{
if rules.type_cast && !lhs.instance_of().allow_cast(self.db)
if rules.kind.is_cast()
&& !lhs.instance_of().allow_cast(self.db)
{
return false;
}

self.check_class_with_trait(lhs, rhs, env, trait_rules)
}
TypeId::TypeParameter(_) if rules.type_cast => false,
TypeId::TypeParameter(_) if rules.kind.is_cast() => false,
TypeId::TypeParameter(rhs)
if !lhs.instance_of().kind(self.db).is_extern() =>
{
Expand All @@ -645,14 +691,14 @@ impl<'a> TypeChecker<'a> {
)
})
}
TypeId::Foreign(_) => rules.type_cast,
TypeId::Foreign(_) => rules.kind.is_cast(),
_ => false,
},
TypeId::TraitInstance(lhs) => match right_id {
TypeId::TraitInstance(rhs) => {
self.check_traits(lhs, rhs, env, rules)
}
TypeId::TypeParameter(_) if rules.type_cast => false,
TypeId::TypeParameter(_) if rules.kind.is_cast() => false,
TypeId::TypeParameter(rhs) => rhs
.requirements(self.db)
.into_iter()
Expand All @@ -663,7 +709,7 @@ impl<'a> TypeChecker<'a> {
TypeId::TypeParameter(rhs) => {
self.check_parameters(lhs, rhs, env, rules)
}
TypeId::Foreign(_) => rules.type_cast,
TypeId::Foreign(_) => rules.kind.is_cast(),
_ => false,
},
TypeId::RigidTypeParameter(lhs)
Expand Down Expand Up @@ -699,7 +745,7 @@ impl<'a> TypeChecker<'a> {
_ => false,
},
TypeId::Foreign(ForeignType::Int(lsize, lsigned)) => {
if rules.type_cast {
if rules.kind.is_cast() {
match right_id {
TypeId::Foreign(_) => true,
TypeId::ClassInstance(ins) => {
Expand All @@ -720,7 +766,7 @@ impl<'a> TypeChecker<'a> {
}
}
TypeId::Foreign(ForeignType::Float(lsize)) => {
if rules.type_cast {
if rules.kind.is_cast() {
match right_id {
TypeId::Foreign(_) => true,
TypeId::ClassInstance(ins) => {
Expand Down Expand Up @@ -758,7 +804,7 @@ impl<'a> TypeChecker<'a> {
self.check_parameter_with_trait(left, req, env, rules)
})
}
TypeId::Foreign(_) => rules.type_cast,
TypeId::Foreign(_) => rules.kind.is_cast(),
_ => false,
}
}
Expand Down Expand Up @@ -1206,6 +1252,16 @@ mod tests {
);
}

#[track_caller]
fn check_err_return(db: &Database, left: TypeRef, right: TypeRef) {
assert!(
!TypeChecker::check_return(db, left, right),
"Expected {} to not be compatible with {}",
format_type(db, left),
format_type(db, right)
);
}

#[test]
fn test_never() {
let mut db = Database::new();
Expand Down Expand Up @@ -2714,4 +2770,23 @@ mod tests {
&mut env
));
}

#[test]
fn test_check_return() {
let mut db = Database::new();
let thing = new_class(&mut db, "Thing");
let owned_var = TypePlaceholder::alloc(&mut db, None).as_owned();
let uni_var = TypePlaceholder::alloc(&mut db, None).as_uni();
let ref_var = TypePlaceholder::alloc(&mut db, None).as_ref();
let mut_var = TypePlaceholder::alloc(&mut db, None).as_mut();

check_err_return(&db, owned(instance(thing)), any(instance(thing)));
check_err_return(&db, uni(instance(thing)), any(instance(thing)));
check_err_return(&db, immutable(instance(thing)), any(instance(thing)));
check_err_return(&db, mutable(instance(thing)), any(instance(thing)));
check_err_return(&db, placeholder(owned_var), any(instance(thing)));
check_err_return(&db, placeholder(uni_var), any(instance(thing)));
check_err_return(&db, placeholder(ref_var), any(instance(thing)));
check_err_return(&db, placeholder(mut_var), any(instance(thing)));
}
}

0 comments on commit 6f2341e

Please sign in to comment.