Skip to content

Commit

Permalink
fix: Allow trait method references from the trait name (#3774)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #3773
Part of #2568 

## Summary\*

This PR implements the ability to be able to call trait methods such as
`Default::default()` where previously we required
`StructName::default()`. The specific impl is selected via type
inference. Previously, this resulted in a compiler panic.

## Additional Context

When a trait method's type isn't constrained enough or is otherwise
still ambiguous after type inference, the compiler currently just
selects the first trait implementation that matches. E.g. `let _ =
Default::default();`. This is a separate issue, so I'll create a new
issue for this.

## Documentation\*

Check one:
- [ ] No documentation needed. 
- [ ] Documentation included in this PR.
- [x] **[Exceptional Case]** Documentation to be submitted in a separate
PR.
- I think traits are finally stable enough to start writing
documentation for. I'm going to submit another PR to remove some of the
experimental warnings the compiler has for traits, and with it add
documentation for traits as well. This wouldn't stabilize all trait
features (e.g. associated types are still unimplemented), but users will
now be able to use basic traits with functions without warnings. Any
unimplemented trait items (associated types again) will not be included
in the documentation.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: kevaundray <kevtheappdev@gmail.com>
  • Loading branch information
jfecher and kevaundray committed Dec 13, 2023
1 parent ce80f5a commit cfa34d4
Show file tree
Hide file tree
Showing 9 changed files with 209 additions and 136 deletions.
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pub struct UnresolvedTrait {
pub module_id: LocalModuleId,
pub crate_id: CrateId,
pub trait_def: NoirTrait,
pub method_ids: HashMap<String, FuncId>,
pub fns_with_default_impl: UnresolvedFunctions,
}

Expand Down
7 changes: 6 additions & 1 deletion compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::vec;
use std::{collections::HashMap, vec};

use acvm::acir::acir_field::FieldOptions;
use fm::FileId;
Expand Down Expand Up @@ -378,6 +378,8 @@ impl<'a> ModCollector<'a> {
functions: Vec::new(),
trait_id: None,
};

let mut method_ids = HashMap::new();
for trait_item in &trait_definition.items {
match trait_item {
TraitItem::Function {
Expand All @@ -389,6 +391,8 @@ impl<'a> ModCollector<'a> {
body,
} => {
let func_id = context.def_interner.push_empty_fn();
method_ids.insert(name.to_string(), func_id);

let modifiers = FunctionModifiers {
name: name.to_string(),
visibility: crate::FunctionVisibility::Public,
Expand Down Expand Up @@ -473,6 +477,7 @@ impl<'a> ModCollector<'a> {
module_id: self.module_id,
crate_id: krate,
trait_def: trait_definition,
method_ids,
fns_with_default_impl: unresolved_functions,
};
self.def_collector.collected_traits.insert(trait_id, unresolved);
Expand Down
99 changes: 79 additions & 20 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::hir_def::expr::{
};

use crate::hir_def::traits::{Trait, TraitConstraint};
use crate::token::FunctionAttribute;
use crate::token::{Attributes, FunctionAttribute};
use regex::Regex;
use std::collections::{BTreeMap, HashSet};
use std::rc::Rc;
Expand All @@ -37,11 +37,11 @@ use crate::{
StatementKind,
};
use crate::{
ArrayLiteral, ContractFunctionType, Distinctness, ForRange, FunctionVisibility, Generics,
LValue, NoirStruct, NoirTypeAlias, Param, Path, PathKind, Pattern, Shared, StructType, Type,
TypeAliasType, TypeBinding, TypeVariable, UnaryOp, UnresolvedGenerics,
UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression,
Visibility, ERROR_IDENT,
ArrayLiteral, ContractFunctionType, Distinctness, ForRange, FunctionDefinition,
FunctionReturnType, FunctionVisibility, Generics, LValue, NoirStruct, NoirTypeAlias, Param,
Path, PathKind, Pattern, Shared, StructType, Type, TypeAliasType, TypeBinding, TypeVariable,
UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData,
UnresolvedTypeExpression, Visibility, ERROR_IDENT,
};
use fm::FileId;
use iter_extended::vecmap;
Expand Down Expand Up @@ -200,6 +200,52 @@ impl<'a> Resolver<'a> {
(hir_func, func_meta, self.errors)
}

pub fn resolve_trait_function(
&mut self,
name: &Ident,
parameters: &[(Ident, UnresolvedType)],
return_type: &FunctionReturnType,
where_clause: &[UnresolvedTraitConstraint],
func_id: FuncId,
) -> (HirFunction, FuncMeta) {
self.scopes.start_function();

// Check whether the function has globals in the local module and add them to the scope
self.resolve_local_globals();

self.trait_bounds = where_clause.to_vec();

let kind = FunctionKind::Normal;
let def = FunctionDefinition {
name: name.clone(),
attributes: Attributes::empty(),
is_open: false,
is_internal: false,
is_unconstrained: false,
visibility: FunctionVisibility::Public, // Trait functions are always public
generics: Vec::new(), // self.generics should already be set
parameters: vecmap(parameters, |(name, typ)| Param {
visibility: Visibility::Private,
pattern: Pattern::Identifier(name.clone()),
typ: typ.clone(),
span: name.span(),
}),
body: BlockExpression(Vec::new()),
span: name.span(),
where_clause: where_clause.to_vec(),
return_type: return_type.clone(),
return_visibility: Visibility::Private,
return_distinctness: Distinctness::DuplicationAllowed,
};

let (hir_func, func_meta) = self.intern_function(NoirFunction { kind, def }, func_id);
let func_scope_tree = self.scopes.end_function();
self.check_for_unused_variables_in_scope_tree(func_scope_tree);

self.trait_bounds.clear();
(hir_func, func_meta)
}

fn check_for_unused_variables_in_scope_tree(&mut self, scope_decls: ScopeTree) {
let mut unused_vars = Vec::new();
for scope in scope_decls.0.into_iter() {
Expand Down Expand Up @@ -1584,27 +1630,39 @@ impl<'a> Resolver<'a> {
&mut self,
path: &Path,
) -> Option<(HirExpression, Type)> {
if let Some(trait_id) = self.trait_id {
if path.kind == PathKind::Plain && path.segments.len() == 2 {
let name = &path.segments[0].0.contents;
let method = &path.segments[1];
let trait_id = self.trait_id?;

if name == SELF_TYPE_NAME {
let the_trait = self.interner.get_trait(trait_id);
if path.kind == PathKind::Plain && path.segments.len() == 2 {
let name = &path.segments[0].0.contents;
let method = &path.segments[1];

if let Some(method) = the_trait.find_method(method.0.contents.as_str()) {
let self_type = Type::TypeVariable(
the_trait.self_type_typevar.clone(),
crate::TypeVariableKind::Normal,
);
return Some((HirExpression::TraitMethodReference(method), self_type));
}
}
if name == SELF_TYPE_NAME {
let the_trait = self.interner.get_trait(trait_id);
let method = the_trait.find_method(method.0.contents.as_str())?;
let self_type = self.self_type.clone()?;
return Some((HirExpression::TraitMethodReference(method), self_type));
}
}
None
}

// this resolves TraitName::some_static_method
fn resolve_trait_static_method(&mut self, path: &Path) -> Option<(HirExpression, Type)> {
if path.kind == PathKind::Plain && path.segments.len() == 2 {
let method = &path.segments[1];

let mut trait_path = path.clone();
trait_path.pop();
let trait_id = self.lookup(trait_path).ok()?;
let the_trait = self.interner.get_trait(trait_id);

let method = the_trait.find_method(method.0.contents.as_str())?;
let self_type = Type::type_variable(the_trait.self_type_typevar_id);
return Some((HirExpression::TraitMethodReference(method), self_type));
}
None
}

// this resolves a static trait method T::trait_method by iterating over the where clause
fn resolve_trait_method_by_named_generic(
&mut self,
Expand Down Expand Up @@ -1641,6 +1699,7 @@ impl<'a> Resolver<'a> {

fn resolve_trait_generic_path(&mut self, path: &Path) -> Option<(HirExpression, Type)> {
self.resolve_trait_static_method_by_self(path)
.or_else(|| self.resolve_trait_static_method(path))
.or_else(|| self.resolve_trait_method_by_named_generic(path))
}

Expand Down
54 changes: 22 additions & 32 deletions compiler/noirc_frontend/src/hir/resolution/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::{BTreeMap, HashSet};

use fm::FileId;
use iter_extended::vecmap;
use noirc_errors::{Location, Span};
use noirc_errors::Location;

use crate::{
graph::CrateId,
Expand All @@ -22,9 +22,7 @@ use crate::{
};

use super::{
errors::ResolverError,
functions, get_module_mut, get_struct_type,
import::PathResolutionError,
path_resolver::{PathResolver, StandardPathResolver},
resolver::Resolver,
take_errors,
Expand Down Expand Up @@ -92,13 +90,14 @@ fn resolve_trait_methods(

let mut functions = vec![];
let mut resolver_errors = vec![];

for item in &unresolved_trait.trait_def.items {
if let TraitItem::Function {
name,
generics,
parameters,
return_type,
where_clause: _,
where_clause,
body: _,
} = item
{
Expand All @@ -110,6 +109,16 @@ fn resolve_trait_methods(
resolver.add_generics(generics);
resolver.set_self_type(Some(self_type));

let func_id = unresolved_trait.method_ids[&name.0.contents];
let (_, func_meta) = resolver.resolve_trait_function(
name,
parameters,
return_type,
where_clause,
func_id,
);
resolver.interner.push_fn_meta(func_meta, func_id);

let arguments = vecmap(parameters, |param| resolver.resolve_type(param.1.clone()));
let return_type = resolver.resolve_type(return_type.get_type().into_owned());

Expand All @@ -124,14 +133,13 @@ fn resolve_trait_methods(
let the_trait = resolver.interner.get_trait(trait_id);
generics.push((the_trait.self_type_typevar_id, the_trait.self_type_typevar.clone()));

let name = name.clone();
let span: Span = name.span();
let default_impl_list: Vec<_> = unresolved_trait
.fns_with_default_impl
.functions
.iter()
.filter(|(_, _, q)| q.name() == name.0.contents)
.collect();

let default_impl = if default_impl_list.len() == 1 {
Some(Box::new(default_impl_list[0].2.clone()))
} else {
Expand All @@ -140,18 +148,18 @@ fn resolve_trait_methods(

let no_environment = Box::new(Type::Unit);
let function_type = Type::Function(arguments, Box::new(return_type), no_environment);
let typ = Type::Forall(generics, Box::new(function_type));

let f = TraitFunction {
name,
typ,
span,
functions.push(TraitFunction {
name: name.clone(),
typ: Type::Forall(generics, Box::new(function_type)),
span: name.span(),
default_impl,
default_impl_file_id: unresolved_trait.file_id,
default_impl_module_id: unresolved_trait.module_id,
};
functions.push(f);
resolver_errors.extend(take_errors_filter_self_not_resolved(file, resolver));
});

let errors = resolver.take_errors().into_iter();
resolver_errors.extend(errors.map(|resolution_error| (resolution_error.into(), file)));
}
}
(functions, resolver_errors)
Expand Down Expand Up @@ -451,21 +459,3 @@ pub(crate) fn resolve_trait_impls(

methods
}

pub(crate) fn take_errors_filter_self_not_resolved(
file_id: FileId,
resolver: Resolver<'_>,
) -> Vec<(CompilationError, FileId)> {
resolver
.take_errors()
.iter()
.filter(|resolution_error| match resolution_error {
ResolverError::PathResolutionError(PathResolutionError::Unresolved(ident)) => {
&ident.0.contents != "Self"
}
_ => true,
})
.cloned()
.map(|resolution_error| (resolution_error.into(), file_id))
.collect()
}
21 changes: 18 additions & 3 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
},
types::Type,
},
node_interner::{DefinitionKind, ExprId, FuncId, TraitId, TraitMethodId},
node_interner::{DefinitionKind, ExprId, FuncId, TraitId, TraitImplKind, TraitMethodId},
BinaryOpKind, Signedness, TypeBinding, TypeBindings, TypeVariableKind, UnaryOp,
};

Expand Down Expand Up @@ -289,8 +289,23 @@ impl<'interner> TypeChecker<'interner> {
}
HirExpression::TraitMethodReference(method) => {
let the_trait = self.interner.get_trait(method.trait_id);
let typ = &the_trait.methods[method.method_index].typ;
let (typ, bindings) = typ.instantiate(self.interner);
let typ2 = &the_trait.methods[method.method_index].typ;
let (typ, mut bindings) = typ2.instantiate(self.interner);

// We must also remember to apply these substitutions to the object_type
// referenced by the selected trait impl, if one has yet to be selected.
let impl_kind = self.interner.get_selected_impl_for_ident(*expr_id);
if let Some(TraitImplKind::Assumed { object_type }) = impl_kind {
let the_trait = self.interner.get_trait(method.trait_id);
let object_type = object_type.substitute(&bindings);
bindings.insert(
the_trait.self_type_typevar_id,
(the_trait.self_type_typevar.clone(), object_type.clone()),
);
self.interner
.select_impl_for_ident(*expr_id, TraitImplKind::Assumed { object_type });
}

self.interner.store_instantiation_bindings(*expr_id, bindings);
typ
}
Expand Down
31 changes: 8 additions & 23 deletions compiler/noirc_frontend/src/hir_def/traits.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;

use crate::{
graph::CrateId,
node_interner::{FuncId, TraitId, TraitMethodId},
Expand Down Expand Up @@ -43,6 +45,12 @@ pub struct Trait {

pub methods: Vec<TraitFunction>,

/// Maps method_name -> method id.
/// This map is separate from methods since TraitFunction ids
/// are created during collection where we don't yet have all
/// the information needed to create the full TraitFunction.
pub method_ids: HashMap<String, FuncId>,

pub constants: Vec<TraitConstant>,
pub types: Vec<TraitType>,

Expand Down Expand Up @@ -98,29 +106,6 @@ impl PartialEq for Trait {
}

impl Trait {
pub fn new(
id: TraitId,
name: Ident,
crate_id: CrateId,
span: Span,
generics: Generics,
self_type_typevar_id: TypeVariableId,
self_type_typevar: TypeVariable,
) -> Trait {
Trait {
id,
name,
crate_id,
span,
methods: Vec::new(),
constants: Vec::new(),
types: Vec::new(),
generics,
self_type_typevar_id,
self_type_typevar,
}
}

pub fn set_methods(&mut self, methods: Vec<TraitFunction>) {
self.methods = methods;
}
Expand Down
Loading

0 comments on commit cfa34d4

Please sign in to comment.