Skip to content

Commit

Permalink
feat: add de-sugaring for impl Trait in function parameters (#4919)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #4540

## Summary\*

In the resolver, when `x: impl Trait` is encountered as a function
parameter,
it is desugared to:
```rust
fn function_name<T_impl_trait>(x: T_impl_trait) where T_impl_trait: Trait
```

## Additional Context

In the current iteration of this PR, `T[func_id]_impl_[trait_name]` is
checked for collisions and
incremented (`T[func_id]_impl_[trait_name]_N`) until there's no
collision.

Note that this PR adds a few ~unrelated cases where `TraitAsType` could
have arguments that were previously skipped:

https://github.com/noir-lang/noir/pull/4919/files#diff-52c7ae61478bab09a4d23320128eead3a88cde6a7be16fb8b070b5512d690bbdR1203

## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
michaeljklein committed Apr 29, 2024
1 parent 8a3c7f1 commit 8aad2e4
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 17 deletions.
8 changes: 8 additions & 0 deletions compiler/noirc_arena/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#![warn(unreachable_pub)]
#![warn(clippy::semicolon_if_nothing_returned)]

use std::fmt;

#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd, Hash)]
pub struct Index(usize);

Expand All @@ -25,6 +27,12 @@ impl Index {
}
}

impl fmt::Display for Index {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}

#[derive(Clone, Debug)]
pub struct Arena<T> {
pub vec: Vec<T>,
Expand Down
60 changes: 56 additions & 4 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use crate::ast::{
ArrayLiteral, BinaryOpKind, BlockExpression, Distinctness, Expression, ExpressionKind,
ForRange, FunctionDefinition, FunctionKind, FunctionReturnType, Ident, ItemVisibility, LValue,
LetStatement, Literal, NoirFunction, NoirStruct, NoirTypeAlias, Param, Path, PathKind, Pattern,
Statement, StatementKind, UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint,
Statement, StatementKind, TraitBound, UnaryOp, UnresolvedGenerics, UnresolvedTraitConstraint,
UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, Visibility, ERROR_IDENT,
};
use crate::graph::CrateId;
Expand Down Expand Up @@ -202,23 +202,70 @@ impl<'a> Resolver<'a> {
self.errors.push(err);
}

/// This turns function parameters of the form:
/// fn foo(x: impl Bar)
///
/// into
/// fn foo<T0_impl_Bar>(x: T0_impl_Bar) where T0_impl_Bar: Bar
fn desugar_impl_trait_args(&mut self, func: &mut NoirFunction, func_id: FuncId) {
let mut impl_trait_generics = HashSet::new();
let mut counter: usize = 0;
for parameter in func.def.parameters.iter_mut() {
if let UnresolvedTypeData::TraitAsType(path, args) = &parameter.typ.typ {
let mut new_generic_ident: Ident =
format!("T{}_impl_{}", func_id, path.as_string()).into();
let mut new_generic_path = Path::from_ident(new_generic_ident.clone());
while impl_trait_generics.contains(&new_generic_ident)
|| self.lookup_generic_or_global_type(&new_generic_path).is_some()
{
new_generic_ident =
format!("T{}_impl_{}_{}", func_id, path.as_string(), counter).into();
new_generic_path = Path::from_ident(new_generic_ident.clone());
counter += 1;
}
impl_trait_generics.insert(new_generic_ident.clone());

let is_synthesized = true;
let new_generic_type_data =
UnresolvedTypeData::Named(new_generic_path, vec![], is_synthesized);
let new_generic_type =
UnresolvedType { typ: new_generic_type_data.clone(), span: None };
let new_trait_bound = TraitBound {
trait_path: path.clone(),
trait_id: None,
trait_generics: args.to_vec(),
};
let new_trait_constraint = UnresolvedTraitConstraint {
typ: new_generic_type,
trait_bound: new_trait_bound,
};

parameter.typ.typ = new_generic_type_data;
func.def.generics.push(new_generic_ident);
func.def.where_clause.push(new_trait_constraint);
}
}
self.add_generics(&impl_trait_generics.into_iter().collect());
}

/// Resolving a function involves interning the metadata
/// interning any statements inside of the function
/// and interning the function itself
/// We resolve and lower the function at the same time
/// Since lowering would require scope data, unless we add an extra resolution field to the AST
pub fn resolve_function(
mut self,
func: NoirFunction,
mut func: NoirFunction,
func_id: FuncId,
) -> (HirFunction, FuncMeta, Vec<ResolverError>) {
self.scopes.start_function();
self.current_item = Some(DependencyId::Function(func_id));

// Check whether the function has globals in the local module and add them to the scope
self.resolve_local_globals();

self.add_generics(&func.def.generics);

self.desugar_impl_trait_args(&mut func, func_id);
self.trait_bounds = func.def.where_clause.clone();

let is_low_level_or_oracle = func
Expand Down Expand Up @@ -1150,10 +1197,15 @@ impl<'a> Resolver<'a> {
| Type::TypeVariable(_, _)
| Type::Constant(_)
| Type::NamedGeneric(_, _)
| Type::TraitAsType(..)
| Type::Code
| Type::Forall(_, _) => (),

Type::TraitAsType(_, _, args) => {
for arg in args {
Self::find_numeric_generics_in_type(arg, found);
}
}

Type::Array(length, element_type) => {
if let Type::NamedGeneric(type_variable, name) = length.as_ref() {
found.insert(name.to_string(), type_variable.clone());
Expand Down
33 changes: 20 additions & 13 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -643,9 +643,11 @@ impl Type {
| Type::Constant(_)
| Type::NamedGeneric(_, _)
| Type::Forall(_, _)
| Type::Code
| Type::TraitAsType(..) => false,
| Type::Code => false,

Type::TraitAsType(_, _, args) => {
args.iter().any(|generic| generic.contains_numeric_typevar(target_id))
}
Type::Array(length, elem) => {
elem.contains_numeric_typevar(target_id) || named_generic_id_matches_target(length)
}
Expand Down Expand Up @@ -1591,11 +1593,17 @@ impl Type {
element.substitute_helper(type_bindings, substitute_bound_typevars),
)),

Type::TraitAsType(s, name, args) => {
let args = vecmap(args, |arg| {
arg.substitute_helper(type_bindings, substitute_bound_typevars)
});
Type::TraitAsType(*s, name.clone(), args)
}

Type::FieldElement
| Type::Integer(_, _)
| Type::Bool
| Type::Constant(_)
| Type::TraitAsType(..)
| Type::Error
| Type::Code
| Type::Unit => self.clone(),
Expand All @@ -1613,7 +1621,9 @@ impl Type {
let field_occurs = fields.occurs(target_id);
len_occurs || field_occurs
}
Type::Struct(_, generic_args) | Type::Alias(_, generic_args) => {
Type::Struct(_, generic_args)
| Type::Alias(_, generic_args)
| Type::TraitAsType(_, _, generic_args) => {
generic_args.iter().any(|arg| arg.occurs(target_id))
}
Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)),
Expand All @@ -1637,7 +1647,6 @@ impl Type {
| Type::Integer(_, _)
| Type::Bool
| Type::Constant(_)
| Type::TraitAsType(..)
| Type::Error
| Type::Code
| Type::Unit => false,
Expand Down Expand Up @@ -1689,16 +1698,14 @@ impl Type {

MutableReference(element) => MutableReference(Box::new(element.follow_bindings())),

TraitAsType(s, name, args) => {
let args = vecmap(args, |arg| arg.follow_bindings());
TraitAsType(*s, name.clone(), args)
}

// Expect that this function should only be called on instantiated types
Forall(..) => unreachable!(),
TraitAsType(..)
| FieldElement
| Integer(_, _)
| Bool
| Constant(_)
| Unit
| Code
| Error => self.clone(),
FieldElement | Integer(_, _) | Bool | Constant(_) | Unit | Code | Error => self.clone(),
}
}

Expand Down
7 changes: 7 additions & 0 deletions compiler/noirc_frontend/src/node_interner.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt;
use std::ops::Deref;

use fm::FileId;
Expand Down Expand Up @@ -314,6 +315,12 @@ impl FuncId {
}
}

impl fmt::Display for FuncId {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.0.fmt(f)
}
}

#[derive(Debug, Eq, PartialEq, Hash, Copy, Clone, PartialOrd, Ord)]
pub struct StructId(ModuleId);

Expand Down
5 changes: 5 additions & 0 deletions compiler/noirc_frontend/src/parser/parser/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ mod test {
"fn func_name<T>(f: Field, y : T) where T: SomeTrait + {}",
// The following should produce compile error on later stage. From the parser's perspective it's fine
"fn func_name<A>(f: Field, y : Field, z : Field) where T: SomeTrait {}",
// TODO: this fails with known EOF != EOF error
// https://github.com/noir-lang/noir/issues/4763
// fn func_name(x: impl Eq) {} with error Expected an end of input but found end of input
// "fn func_name(x: impl Eq) {}",
"fn func_name<T>(x: impl Eq, y : T) where T: SomeTrait + Eq {}",
],
);

Expand Down
64 changes: 64 additions & 0 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,70 @@ mod test {
}
}

#[test]
fn check_trait_as_type_as_fn_parameter() {
let src = "
trait Eq {
fn eq(self, other: Self) -> bool;
}
struct Foo {
a: u64,
}
impl Eq for Foo {
fn eq(self, other: Foo) -> bool { self.a == other.a }
}
fn test_eq(x: impl Eq) -> bool {
x.eq(x)
}
fn main(a: Foo) -> pub bool {
test_eq(a)
}";

let errors = get_program_errors(src);
errors.iter().for_each(|err| println!("{:?}", err));
assert!(errors.is_empty());
}

#[test]
fn check_trait_as_type_as_two_fn_parameters() {
let src = "
trait Eq {
fn eq(self, other: Self) -> bool;
}
trait Test {
fn test(self) -> bool;
}
struct Foo {
a: u64,
}
impl Eq for Foo {
fn eq(self, other: Foo) -> bool { self.a == other.a }
}
impl Test for u64 {
fn test(self) -> bool { self == self }
}
fn test_eq(x: impl Eq, y: impl Test) -> bool {
x.eq(x) == y.test()
}
fn main(a: Foo, b: u64) -> pub bool {
test_eq(a, b)
}";

let errors = get_program_errors(src);
errors.iter().for_each(|err| println!("{:?}", err));
assert!(errors.is_empty());
}

fn get_program_captures(src: &str) -> Vec<Vec<String>> {
let (program, context, _errors) = get_program(src);
let interner = context.def_interner;
Expand Down
3 changes: 3 additions & 0 deletions tooling/nargo_fmt/tests/expected/impl_trait_fn_parameter.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fn func_name(x: impl Eq) {}

fn func_name<T>(x: impl Eq, y: T) where T: SomeTrait + Eq {}
3 changes: 3 additions & 0 deletions tooling/nargo_fmt/tests/input/impl_trait_fn_parameter.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
fn func_name(x: impl Eq) {}

fn func_name<T>(x: impl Eq, y: T) where T: SomeTrait + Eq {}

0 comments on commit 8aad2e4

Please sign in to comment.