diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 0806a8eb75..b9fd51dbb0 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -410,12 +410,15 @@ pub(crate) fn check_methods_signatures( trait_impl_generic_count: usize, errors: &mut Vec<(CompilationError, FileId)>, ) { - let the_trait = resolver.interner.get_trait(trait_id); - - let self_type = resolver.get_self_type().expect("trait impl must have a Self type"); + let self_type = resolver.get_self_type().expect("trait impl must have a Self type").clone(); // Temporarily bind the trait's Self type to self_type so we can type check - the_trait.self_type_typevar.bind(self_type.clone()); + let the_trait = resolver.interner.get_trait_mut(trait_id); + the_trait.self_type_typevar.bind(self_type); + + // Temporarily take the trait's methods so we can use both them and a mutable reference + // to the interner within the loop. + let trait_methods = std::mem::take(&mut the_trait.methods); for (file_id, func_id) in impl_methods { let impl_method = resolver.interner.function_meta(func_id); @@ -427,7 +430,7 @@ pub(crate) fn check_methods_signatures( // If that's the case, a `MethodNotInTrait` error has already been thrown, and we can ignore // the impl method, since there's nothing in the trait to match its signature against. if let Some(trait_method) = - the_trait.methods.iter().find(|method| method.name.0.contents == func_name) + trait_methods.iter().find(|method| method.name.0.contents == func_name) { let impl_function_type = impl_method.typ.instantiate(resolver.interner); @@ -442,7 +445,7 @@ pub(crate) fn check_methods_signatures( let error = DefCollectorErrorKind::MismatchTraitImplementationNumGenerics { impl_method_generic_count, trait_method_generic_count, - trait_name: the_trait.name.to_string(), + trait_name: resolver.interner.get_trait(trait_id).name.to_string(), method_name: func_name.to_string(), span: impl_method.location.span, }; @@ -472,7 +475,7 @@ pub(crate) fn check_methods_signatures( let error = DefCollectorErrorKind::MismatchTraitImplementationNumParameters { actual_num_parameters: impl_method.parameters.0.len(), expected_num_parameters: trait_method.arguments().len(), - trait_name: the_trait.name.to_string(), + trait_name: resolver.interner.get_trait(trait_id).name.to_string(), method_name: func_name.to_string(), span: impl_method.location.span, }; @@ -498,5 +501,7 @@ pub(crate) fn check_methods_signatures( } } + let the_trait = resolver.interner.get_trait_mut(trait_id); + the_trait.set_methods(trait_methods); the_trait.self_type_typevar.unbind(the_trait.self_type_typevar_id); } diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index 6cf2d669b9..4cf910221e 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -349,7 +349,7 @@ impl<'a> ModCollector<'a> { let name = trait_definition.name.clone(); // Create the corresponding module for the trait namespace - let id = match self.push_child_module(&name, self.file_id, false, false) { + let trait_id = match self.push_child_module(&name, self.file_id, false, false) { Ok(local_id) => TraitId(ModuleId { krate, local_id }), Err(error) => { errors.push((error.into(), self.file_id)); @@ -359,7 +359,7 @@ impl<'a> ModCollector<'a> { // Add the trait to scope so its path can be looked up later let result = - self.def_collector.def_map.modules[self.module_id.0].declare_trait(name, id); + self.def_collector.def_map.modules[self.module_id.0].declare_trait(name, trait_id); if let Err((first_def, second_def)) = result { let error = DefCollectorErrorKind::Duplicate { @@ -400,9 +400,9 @@ impl<'a> ModCollector<'a> { let location = Location::new(name.span(), self.file_id); context .def_interner - .push_function_definition(func_id, modifiers, id.0, location); + .push_function_definition(func_id, modifiers, trait_id.0, location); - match self.def_collector.def_map.modules[id.0.local_id.0] + match self.def_collector.def_map.modules[trait_id.0.local_id.0] .declare_function(name.clone(), func_id) { Ok(()) => { @@ -437,7 +437,7 @@ impl<'a> ModCollector<'a> { let stmt_id = context.def_interner.push_empty_global(); if let Err((first_def, second_def)) = self.def_collector.def_map.modules - [id.0.local_id.0] + [trait_id.0.local_id.0] .declare_global(name.clone(), stmt_id) { let error = DefCollectorErrorKind::Duplicate { @@ -451,7 +451,7 @@ impl<'a> ModCollector<'a> { TraitItem::Type { name } => { // TODO(nickysn or alexvitkov): implement context.def_interner.push_empty_type_alias and get an id, instead of using TypeAliasId::dummy_id() if let Err((first_def, second_def)) = self.def_collector.def_map.modules - [id.0.local_id.0] + [trait_id.0.local_id.0] .declare_type_alias(name.clone(), TypeAliasId::dummy_id()) { let error = DefCollectorErrorKind::Duplicate { @@ -473,7 +473,7 @@ impl<'a> ModCollector<'a> { trait_def: trait_definition, fns_with_default_impl: unresolved_functions, }; - self.def_collector.collected_traits.insert(id, unresolved); + self.def_collector.collected_traits.insert(trait_id, unresolved); } errors } diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 68c33c93c4..9657141931 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -523,7 +523,7 @@ impl<'a> Resolver<'a> { _new_variables: &mut Generics, ) -> Type { if let Some(t) = self.lookup_trait_or_error(path) { - Type::TraitAsType(t) + Type::TraitAsType(t.id, Rc::new(t.name.to_string())) } else { Type::Error } @@ -938,7 +938,7 @@ impl<'a> Resolver<'a> { | Type::Constant(_) | Type::NamedGeneric(_, _) | Type::NotConstant - | Type::TraitAsType(_) + | Type::TraitAsType(..) | Type::Forall(_, _) => (), Type::Array(length, element_type) => { @@ -1498,8 +1498,8 @@ impl<'a> Resolver<'a> { self.interner.get_struct(type_id) } - pub fn get_trait(&self, trait_id: TraitId) -> Trait { - self.interner.get_trait(trait_id) + pub fn get_trait_mut(&mut self, trait_id: TraitId) -> &mut Trait { + self.interner.get_trait_mut(trait_id) } fn lookup(&mut self, path: Path) -> Result { @@ -1542,9 +1542,9 @@ impl<'a> Resolver<'a> { } /// Lookup a given trait by name/path. - fn lookup_trait_or_error(&mut self, path: Path) -> Option { + fn lookup_trait_or_error(&mut self, path: Path) -> Option<&mut Trait> { match self.lookup(path) { - Ok(trait_id) => Some(self.get_trait(trait_id)), + Ok(trait_id) => Some(self.get_trait_mut(trait_id)), Err(error) => { self.push_err(error); None @@ -1592,9 +1592,9 @@ impl<'a> Resolver<'a> { if name == SELF_TYPE_NAME { let the_trait = self.interner.get_trait(trait_id); - if let Some(method) = the_trait.find_method(method.clone()) { + if let Some(method) = the_trait.find_method(method.0.contents.as_str()) { let self_type = Type::TypeVariable( - the_trait.self_type_typevar, + the_trait.self_type_typevar.clone(), crate::TypeVariableKind::Normal, ); return Some((HirExpression::TraitMethodReference(method), self_type)); @@ -1628,7 +1628,7 @@ impl<'a> Resolver<'a> { { let the_trait = self.interner.get_trait(trait_id); if let Some(method) = - the_trait.find_method(path.segments.last().unwrap().clone()) + the_trait.find_method(path.segments.last().unwrap().0.contents.as_str()) { let self_type = self.resolve_type(typ.clone()); return Some((HirExpression::TraitMethodReference(method), self_type)); diff --git a/compiler/noirc_frontend/src/hir/resolution/traits.rs b/compiler/noirc_frontend/src/hir/resolution/traits.rs index 84f51f7ff6..d4969c52c7 100644 --- a/compiler/noirc_frontend/src/hir/resolution/traits.rs +++ b/compiler/noirc_frontend/src/hir/resolution/traits.rs @@ -16,7 +16,7 @@ use crate::{ def_map::{CrateDefMap, ModuleDefId, ModuleId}, Context, }, - hir_def::traits::{Trait, TraitConstant, TraitFunction, TraitImpl, TraitType}, + hir_def::traits::{TraitConstant, TraitFunction, TraitImpl, TraitType}, node_interner::{FuncId, NodeInterner, TraitId}, Path, Shared, TraitItem, Type, TypeBinding, TypeVariableKind, }; @@ -90,7 +90,7 @@ fn resolve_trait_methods( }); let file = def_maps[&crate_id].file_id(unresolved_trait.module_id); - let mut res = vec![]; + let mut functions = vec![]; let mut resolver_errors = vec![]; for item in &unresolved_trait.trait_def.items { if let TraitItem::Function { @@ -121,7 +121,8 @@ fn resolve_trait_methods( }); // Ensure the trait is generic over the Self type as well - generics.push((the_trait.self_type_typevar_id, the_trait.self_type_typevar)); + let the_trait = resolver.interner.get_trait(trait_id); + generics.push((the_trait.self_type_typevar_id, the_trait.self_type_typevar.clone())); let name = name.clone(); let span: Span = name.span(); @@ -149,11 +150,11 @@ fn resolve_trait_methods( default_impl_file_id: unresolved_trait.file_id, default_impl_module_id: unresolved_trait.module_id, }; - res.push(f); + functions.push(f); resolver_errors.extend(take_errors_filter_self_not_resolved(file, resolver)); } } - (res, resolver_errors) + (functions, resolver_errors) } fn collect_trait_impl_methods( @@ -167,15 +168,18 @@ fn collect_trait_impl_methods( // for a particular method, the default implementation will be added at that slot. let mut ordered_methods = Vec::new(); - let the_trait = interner.get_trait(trait_id); - // check whether the trait implementation is in the same crate as either the trait or the type let mut errors = - check_trait_impl_crate_coherence(interner, &the_trait, trait_impl, crate_id, def_maps); + check_trait_impl_crate_coherence(interner, trait_id, trait_impl, crate_id, def_maps); // set of function ids that have a corresponding method in the trait let mut func_ids_in_trait = HashSet::new(); - for method in &the_trait.methods { + // Temporarily take ownership of the trait's methods so we can iterate over them + // while also mutating the interner + let the_trait = interner.get_trait_mut(trait_id); + let methods = std::mem::take(&mut the_trait.methods); + + for method in &methods { let overrides: Vec<_> = trait_impl .methods .functions @@ -197,7 +201,7 @@ fn collect_trait_impl_methods( )); } else { let error = DefCollectorErrorKind::TraitMissingMethod { - trait_name: the_trait.name.clone(), + trait_name: interner.get_trait(trait_id).name.clone(), method_name: method.name.clone(), trait_impl_span: trait_impl.object_type.span.expect("type must have a span"), }; @@ -221,6 +225,10 @@ fn collect_trait_impl_methods( } } + // Restore the methods that were taken before the for loop + let the_trait = interner.get_trait_mut(trait_id); + the_trait.set_methods(methods); + // Emit MethodNotInTrait error for methods in the impl block that // don't have a corresponding method signature defined in the trait for (_, func_id, func) in &trait_impl.methods.functions { @@ -299,7 +307,7 @@ pub(crate) fn collect_trait_impls( fn check_trait_impl_crate_coherence( interner: &mut NodeInterner, - the_trait: &Trait, + trait_id: TraitId, trait_impl: &UnresolvedTraitImpl, current_crate: CrateId, def_maps: &BTreeMap, @@ -316,6 +324,7 @@ fn check_trait_impl_crate_coherence( _ => CrateId::Dummy, }; + let the_trait = interner.get_trait(trait_id); if current_crate != the_trait.crate_id && current_crate != object_crate { let error = DefCollectorErrorKind::TraitImplOrphaned { span: trait_impl.object_type.span.expect("object type must have a span"), diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 5263434a35..720ed8d5b5 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -891,7 +891,9 @@ impl<'interner> TypeChecker<'interner> { } } } - Type::TraitAsType(_trait) => { + // TODO: We should allow method calls on `impl Trait`s eventually. + // For now it is fine since they are only allowed on return types. + Type::TraitAsType(..) => { self.errors.push(TypeCheckError::UnresolvedMethodCall { method_name: method_name.to_string(), object_type: object_type.clone(), diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index 6108e74917..9599104709 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -101,8 +101,8 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec, + pub constants: Vec, pub types: Vec, @@ -124,9 +125,9 @@ impl Trait { self.methods = methods; } - pub fn find_method(&self, name: Ident) -> Option { + pub fn find_method(&self, name: &str) -> Option { for (idx, method) in self.methods.iter().enumerate() { - if method.name == name { + if &method.name == name { return Some(TraitMethodId { trait_id: self.id, method_index: idx }); } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 2c403d8d19..977293ec67 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -6,7 +6,7 @@ use std::{ use crate::{ hir::type_check::TypeCheckError, - node_interner::{ExprId, NodeInterner, TypeAliasId}, + node_interner::{ExprId, NodeInterner, TraitId, TypeAliasId}, }; use iter_extended::vecmap; use noirc_errors::{Location, Span}; @@ -14,10 +14,7 @@ use noirc_printable_type::PrintableType; use crate::{node_interner::StructId, Ident, Signedness}; -use super::{ - expr::{HirCallExpression, HirExpression, HirIdent}, - traits::Trait, -}; +use super::expr::{HirCallExpression, HirExpression, HirIdent}; #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub enum Type { @@ -65,7 +62,10 @@ pub enum Type { /// different argument types each time. TypeVariable(TypeVariable, TypeVariableKind), - TraitAsType(Trait), + /// `impl Trait` when used in a type position. + /// These are only matched based on the TraitId. The trait name paramer is only + /// used for displaying error messages using the name of the trait. + TraitAsType(TraitId, /*name:*/ Rc), /// NamedGenerics are the 'T' or 'U' in a user-defined generic function /// like `fn foo(...) {}`. Unlike TypeVariables, they cannot be bound over. @@ -132,7 +132,7 @@ impl Type { Type::FmtString(_, _) | Type::Unit | Type::TypeVariable(_, _) - | Type::TraitAsType(_) + | Type::TraitAsType(..) | Type::NamedGeneric(_, _) | Type::Function(_, _, _) | Type::MutableReference(_) @@ -575,7 +575,7 @@ impl Type { | Type::NamedGeneric(_, _) | Type::NotConstant | Type::Forall(_, _) - | Type::TraitAsType(_) => false, + | Type::TraitAsType(..) => false, Type::Array(length, elem) => { elem.contains_numeric_typevar(target_id) || named_generic_id_matches_target(length) @@ -714,8 +714,8 @@ impl std::fmt::Display for Type { write!(f, "{}<{}>", s.borrow(), args.join(", ")) } } - Type::TraitAsType(tr) => { - write!(f, "impl {}", tr.name) + Type::TraitAsType(_id, name) => { + write!(f, "impl {}", name) } Type::Tuple(elements) => { let elements = vecmap(elements, ToString::to_string); @@ -1279,7 +1279,7 @@ impl Type { | Type::Integer(_, _) | Type::Bool | Type::Unit - | Type::TraitAsType(_) + | Type::TraitAsType(..) | Type::Constant(_) | Type::NotConstant | Type::Error => (), @@ -1372,7 +1372,6 @@ impl Type { let fields = vecmap(fields, |field| field.substitute(type_bindings)); Type::Tuple(fields) } - Type::TraitAsType(_) => todo!(), Type::Forall(typevars, typ) => { // Trying to substitute a variable defined within a nested Forall // is usually impossible and indicative of an error in the type checker somewhere. @@ -1396,6 +1395,7 @@ impl Type { | Type::Integer(_, _) | Type::Bool | Type::Constant(_) + | Type::TraitAsType(..) | Type::Error | Type::NotConstant | Type::Unit => self.clone(), @@ -1412,7 +1412,6 @@ impl Type { let field_occurs = fields.occurs(target_id); len_occurs || field_occurs } - Type::TraitAsType(_) => todo!(), Type::Struct(_, generic_args) => generic_args.iter().any(|arg| arg.occurs(target_id)), Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)), Type::NamedGeneric(binding, _) | Type::TypeVariable(binding, _) => { @@ -1435,6 +1434,7 @@ impl Type { | Type::Integer(_, _) | Type::Bool | Type::Constant(_) + | Type::TraitAsType(..) | Type::Error | Type::NotConstant | Type::Unit => false, @@ -1482,7 +1482,7 @@ impl Type { // Expect that this function should only be called on instantiated types Forall(..) => unreachable!(), - TraitAsType(_) + TraitAsType(..) | FieldElement | Integer(_, _) | Bool @@ -1590,7 +1590,7 @@ impl From<&Type> for PrintableType { let fields = vecmap(fields, |(name, typ)| (name, typ.into())); PrintableType::Struct { fields, name: struct_type.name.to_string() } } - Type::TraitAsType(_) => unreachable!(), + Type::TraitAsType(..) => unreachable!(), Type::Tuple(_) => todo!("printing tuple types is not yet implemented"), Type::TypeVariable(_, _) => unreachable!(), Type::NamedGeneric(..) => unreachable!(), diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 52ed0c746e..78cde11593 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -231,7 +231,7 @@ impl<'interner> Monomorphizer<'interner> { let body_expr_id = *self.interner.function(&f).as_expr(); let body_return_type = self.interner.id_type(body_expr_id); let return_type = self.convert_type(match meta.return_type() { - Type::TraitAsType(_) => &body_return_type, + Type::TraitAsType(..) => &body_return_type, _ => meta.return_type(), }); @@ -720,7 +720,7 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::Slice(element) } } - HirType::TraitAsType(_) => { + HirType::TraitAsType(..) => { unreachable!("All TraitAsType should be replaced before calling convert_type"); } HirType::NamedGeneric(binding, _) => { diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 8df31a51fc..d49e236c68 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -863,12 +863,16 @@ impl NodeInterner { self.structs[&id].clone() } - pub fn get_trait(&self, id: TraitId) -> Trait { - self.traits[&id].clone() + pub fn get_trait(&self, id: TraitId) -> &Trait { + &self.traits[&id] } - pub fn try_get_trait(&self, id: TraitId) -> Option { - self.traits.get(&id).cloned() + pub fn get_trait_mut(&mut self, id: TraitId) -> &mut Trait { + self.traits.get_mut(&id).expect("get_trait_mut given invalid TraitId") + } + + pub fn try_get_trait(&self, id: TraitId) -> Option<&Trait> { + self.traits.get(&id) } pub fn get_type_alias(&self, id: TypeAliasId) -> &TypeAliasType { @@ -892,7 +896,7 @@ impl NodeInterner { let typ = self.id_type(def_id); if let Type::Function(args, ret, env) = &typ { let def = self.definition(def_id); - if let Type::TraitAsType(_trait) = ret.as_ref() { + if let Type::TraitAsType(..) = ret.as_ref() { if let DefinitionKind::Function(func_id) = def.kind { let f = self.function(&func_id); let func_body = f.as_expr(); @@ -1382,6 +1386,6 @@ fn get_type_method_key(typ: &Type) -> Option { | Type::Error | Type::NotConstant | Type::Struct(_, _) - | Type::TraitAsType(_) => None, + | Type::TraitAsType(..) => None, } } diff --git a/tooling/noirc_abi/src/lib.rs b/tooling/noirc_abi/src/lib.rs index 41ef7cde62..884c49c010 100644 --- a/tooling/noirc_abi/src/lib.rs +++ b/tooling/noirc_abi/src/lib.rs @@ -161,7 +161,7 @@ impl AbiType { Type::Error => unreachable!(), Type::Unit => unreachable!(), Type::Constant(_) => unreachable!(), - Type::TraitAsType(_) => unreachable!(), + Type::TraitAsType(..) => unreachable!(), Type::Struct(def, ref args) => { let struct_type = def.borrow(); let fields = struct_type.get_fields(args);