diff --git a/crates/noirc_frontend/src/ast/statement.rs b/crates/noirc_frontend/src/ast/statement.rs index 2792d51c41..7f77716c5e 100644 --- a/crates/noirc_frontend/src/ast/statement.rs +++ b/crates/noirc_frontend/src/ast/statement.rs @@ -238,6 +238,61 @@ pub enum PathKind { Plain, } +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct UseTree { + pub prefix: Path, + pub kind: UseTreeKind, +} + +impl Display for UseTree { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.prefix)?; + + match &self.kind { + UseTreeKind::Path(name, alias) => { + write!(f, "{name}")?; + + while let Some(alias) = alias { + write!(f, " as {}", alias)?; + } + + Ok(()) + } + UseTreeKind::List(trees) => { + write!(f, "::{{")?; + let tree = vecmap(trees, ToString::to_string).join(", "); + write!(f, "{tree}}}") + } + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum UseTreeKind { + Path(Ident, Option), + List(Vec), +} + +impl UseTree { + pub fn desugar(self, root: Option) -> Vec { + let prefix = if let Some(mut root) = root { + root.segments.extend(self.prefix.segments); + root + } else { + self.prefix + }; + + match self.kind { + UseTreeKind::Path(name, alias) => { + vec![ImportStatement { path: prefix.join(name), alias }] + } + UseTreeKind::List(trees) => { + trees.into_iter().flat_map(|tree| tree.desugar(Some(prefix.clone()))).collect() + } + } + } +} + // Note: Path deliberately doesn't implement Recoverable. // No matter which default value we could give in Recoverable::error, // it would most likely cause further errors during name resolution @@ -248,6 +303,15 @@ pub struct Path { } impl Path { + pub fn pop(&mut self) -> Ident { + self.segments.pop().unwrap() + } + + fn join(mut self, ident: Ident) -> Path { + self.segments.push(ident); + self + } + /// Construct a PathKind::Plain from this single pub fn from_single(name: String, span: Span) -> Path { let segment = Ident::from(Spanned::from(span, name)); diff --git a/crates/noirc_frontend/src/parser/mod.rs b/crates/noirc_frontend/src/parser/mod.rs index a8b7f43fa5..e1aec84d56 100644 --- a/crates/noirc_frontend/src/parser/mod.rs +++ b/crates/noirc_frontend/src/parser/mod.rs @@ -18,7 +18,7 @@ use crate::{ast::ImportStatement, Expression, NoirStruct}; use crate::{ BlockExpression, ExpressionKind, ForExpression, Ident, IndexExpression, LetStatement, MethodCallExpression, NoirFunction, NoirImpl, Path, PathKind, Pattern, Recoverable, Statement, - UnresolvedType, + UnresolvedType, UseTree, }; use acvm::FieldElement; @@ -38,7 +38,7 @@ static UNIQUE_NAME_COUNTER: AtomicU32 = AtomicU32::new(0); pub(crate) enum TopLevelStatement { Function(NoirFunction), Module(Ident), - Import(ImportStatement), + Import(UseTree), Struct(NoirStruct), Impl(NoirImpl), SubModule(SubModule), @@ -252,8 +252,8 @@ impl ParsedModule { self.impls.push(r#impl); } - fn push_import(&mut self, import_stmt: ImportStatement) { - self.imports.push(import_stmt); + fn push_import(&mut self, import_stmt: UseTree) { + self.imports.extend(import_stmt.desugar(None)); } fn push_module_decl(&mut self, mod_name: Ident) { @@ -446,7 +446,7 @@ impl std::fmt::Display for TopLevelStatement { match self { TopLevelStatement::Function(fun) => fun.fmt(f), TopLevelStatement::Module(m) => write!(f, "mod {m}"), - TopLevelStatement::Import(i) => i.fmt(f), + TopLevelStatement::Import(tree) => write!(f, "use {tree}"), TopLevelStatement::Struct(s) => s.fmt(f), TopLevelStatement::Impl(i) => i.fmt(f), TopLevelStatement::SubModule(s) => s.fmt(f), diff --git a/crates/noirc_frontend/src/parser/parser.rs b/crates/noirc_frontend/src/parser/parser.rs index 2044a02c68..deaa045ccf 100644 --- a/crates/noirc_frontend/src/parser/parser.rs +++ b/crates/noirc_frontend/src/parser/parser.rs @@ -35,8 +35,8 @@ use crate::parser::{force, ignore_then_commit, statement_recovery}; use crate::token::{Attribute, Keyword, Token, TokenKind}; use crate::{ BinaryOp, BinaryOpKind, BlockExpression, CompTime, ConstrainStatement, FunctionDefinition, - Ident, IfExpression, ImportStatement, InfixExpression, LValue, Lambda, NoirFunction, NoirImpl, - NoirStruct, Path, PathKind, Pattern, Recoverable, UnaryOp, UnresolvedTypeExpression, + Ident, IfExpression, InfixExpression, LValue, Lambda, NoirFunction, NoirImpl, NoirStruct, Path, + PathKind, Pattern, Recoverable, UnaryOp, UnresolvedTypeExpression, UseTree, UseTreeKind, }; use chumsky::prelude::*; @@ -396,12 +396,7 @@ fn module_declaration() -> impl NoirParser { } fn use_statement() -> impl NoirParser { - let rename = ignore_then_commit(keyword(Keyword::As), ident()).or_not(); - - keyword(Keyword::Use) - .ignore_then(path()) - .then(rename) - .map(|(path, alias)| TopLevelStatement::Import(ImportStatement { path, alias })) + keyword(Keyword::Use).ignore_then(use_tree()).map(TopLevelStatement::Import) } fn keyword(keyword: Keyword) -> impl NoirParser { @@ -436,6 +431,39 @@ fn path() -> impl NoirParser { )) } +fn empty_path() -> impl NoirParser { + let make_path = |kind| move |_| Path { segments: Vec::new(), kind }; + let path_kind = |key, kind| keyword(key).map(make_path(kind)); + + choice((path_kind(Keyword::Crate, PathKind::Crate), path_kind(Keyword::Dep, PathKind::Dep))) +} + +fn rename() -> impl NoirParser> { + ignore_then_commit(keyword(Keyword::As), ident()).or_not() +} + +fn use_tree() -> impl NoirParser { + recursive(|use_tree| { + let simple = path().then(rename()).map(|(mut prefix, alias)| { + let ident = prefix.pop(); + UseTree { prefix, kind: UseTreeKind::Path(ident, alias) } + }); + + let list = { + let prefix = path().or(empty_path()).then_ignore(just(Token::DoubleColon)); + let tree = use_tree + .separated_by(just(Token::Comma)) + .allow_trailing() + .delimited_by(just(Token::LeftBrace), just(Token::RightBrace)) + .map(UseTreeKind::List); + + prefix.then(tree).map(|(prefix, kind)| UseTree { prefix, kind }) + }; + + choice((list, simple)) + }) +} + fn ident() -> impl NoirParser { token_kind(TokenKind::Ident).map_with_span(Ident::from_token) } @@ -1512,12 +1540,30 @@ mod test { fn parse_use() { parse_all( use_statement(), - vec!["use std::hash", "use std", "use foo::bar as hello", "use bar as bar"], + vec![ + "use std::hash", + "use std", + "use foo::bar as hello", + "use bar as bar", + "use foo::{}", + "use foo::{bar,}", + "use foo::{bar, hello}", + "use foo::{bar as bar2, hello}", + "use foo::{bar as bar2, hello::{foo}, nested::{foo, bar}}", + "use dep::{std::println, bar::baz}", + ], ); parse_all_failing( use_statement(), - vec!["use std as ;", "use foobar as as;", "use hello:: as foo;"], + vec![ + "use std as ;", + "use foobar as as;", + "use hello:: as foo;", + "use foo bar::baz", + "use foo bar::{baz}", + "use foo::{,}", + ], ); }