Skip to content

Commit

Permalink
Rollup merge of rust-lang#60437 - davidtwco:issue-60236, r=nikomatsakis
Browse files Browse the repository at this point in the history
Ensure that drop order of `async fn` matches `fn` and that users cannot refer to generated arguments.

Fixes rust-lang#60236 and fixes rust-lang#60438.

This PR modifies the lowering of `async fn` arguments so that the
drop order matches the equivalent `fn`.

Previously, async function arguments were lowered as shown below:

    async fn foo(<pattern>: <ty>) {
      async move {
      }
    } // <-- dropped as you "exit" the fn

    // ...becomes...
    fn foo(__arg0: <ty>) {
      async move {
        let <pattern> = __arg0;
      } // <-- dropped as you "exit" the async block
    }

After this PR, async function arguments will be lowered as:

    async fn foo(<pattern>: <ty>, <pattern>: <ty>, <pattern>: <ty>) {
      async move {
      }
    } // <-- dropped as you "exit" the fn

    // ...becomes...
    fn foo(__arg0: <ty>, __arg1: <ty>, __arg2: <ty>) {
      async move {
        let __arg2 = __arg2;
        let <pattern> = __arg2;
        let __arg1 = __arg1;
        let <pattern> = __arg1;
        let __arg0 = __arg0;
        let <pattern> = __arg0;
      } // <-- dropped as you "exit" the async block
    }

If `<pattern>` is a simple ident, then it is lowered to a single
`let <pattern> = <pattern>;` statement as an optimization.

This PR also stops users from referring to the generated `__argN`
identifiers.

r? @nikomatsakis
  • Loading branch information
Centril committed May 1, 2019
2 parents 12bf981 + 1fedb0a commit 16939a5
Show file tree
Hide file tree
Showing 13 changed files with 505 additions and 232 deletions.
39 changes: 36 additions & 3 deletions src/librustc/hir/lowering.rs
Expand Up @@ -2996,8 +2996,33 @@ impl<'a> LoweringContext<'a> {
if let IsAsync::Async { closure_id, ref arguments, .. } = asyncness {
let mut body = body.clone();

// Async function arguments are lowered into the closure body so that they are
// captured and so that the drop order matches the equivalent non-async functions.
//
// async fn foo(<pattern>: <ty>, <pattern>: <ty>, <pattern>: <ty>) {
// async move {
// }
// }
//
// // ...becomes...
// fn foo(__arg0: <ty>, __arg1: <ty>, __arg2: <ty>) {
// async move {
// let __arg2 = __arg2;
// let <pattern> = __arg2;
// let __arg1 = __arg1;
// let <pattern> = __arg1;
// let __arg0 = __arg0;
// let <pattern> = __arg0;
// }
// }
//
// If `<pattern>` is a simple ident, then it is lowered to a single
// `let <pattern> = <pattern>;` statement as an optimization.
for a in arguments.iter().rev() {
body.stmts.insert(0, a.stmt.clone());
if let Some(pat_stmt) = a.pat_stmt.clone() {
body.stmts.insert(0, pat_stmt);
}
body.stmts.insert(0, a.move_stmt.clone());
}

let async_expr = this.make_async_expr(
Expand Down Expand Up @@ -3093,7 +3118,11 @@ impl<'a> LoweringContext<'a> {
let mut decl = decl.clone();
// Replace the arguments of this async function with the generated
// arguments that will be moved into the closure.
decl.inputs = arguments.clone().drain(..).map(|a| a.arg).collect();
for (i, a) in arguments.clone().drain(..).enumerate() {
if let Some(arg) = a.arg {
decl.inputs[i] = arg;
}
}
lower_fn(&decl)
} else {
lower_fn(decl)
Expand Down Expand Up @@ -3590,7 +3619,11 @@ impl<'a> LoweringContext<'a> {
let mut sig = sig.clone();
// Replace the arguments of this async function with the generated
// arguments that will be moved into the closure.
sig.decl.inputs = arguments.clone().drain(..).map(|a| a.arg).collect();
for (i, a) in arguments.clone().drain(..).enumerate() {
if let Some(arg) = a.arg {
sig.decl.inputs[i] = arg;
}
}
lower_method(&sig)
} else {
lower_method(sig)
Expand Down
13 changes: 9 additions & 4 deletions src/librustc/hir/map/def_collector.rs
Expand Up @@ -94,7 +94,9 @@ impl<'a> DefCollector<'a> {
// Walk the generated arguments for the `async fn`.
for a in arguments {
use visit::Visitor;
this.visit_ty(&a.arg.ty);
if let Some(arg) = &a.arg {
this.visit_ty(&arg.ty);
}
}

// We do not invoke `walk_fn_decl` as this will walk the arguments that are being
Expand All @@ -105,10 +107,13 @@ impl<'a> DefCollector<'a> {
*closure_id, DefPathData::ClosureExpr, REGULAR_SPACE, span,
);
this.with_parent(closure_def, |this| {
use visit::Visitor;
// Walk each of the generated statements before the regular block body.
for a in arguments {
use visit::Visitor;
// Walk each of the generated statements before the regular block body.
this.visit_stmt(&a.stmt);
this.visit_stmt(&a.move_stmt);
if let Some(pat_stmt) = &a.pat_stmt {
this.visit_stmt(&pat_stmt);
}
}

visit::walk_block(this, &body);
Expand Down
15 changes: 10 additions & 5 deletions src/librustc/lint/context.rs
Expand Up @@ -1334,14 +1334,19 @@ impl<'a, T: EarlyLintPass> ast_visit::Visitor<'a> for EarlyContextAndPass<'a, T>
if let ast::IsAsync::Async { ref arguments, .. } = header.asyncness.node {
for a in arguments {
// Visit the argument..
self.visit_pat(&a.arg.pat);
if let ast::ArgSource::AsyncFn(pat) = &a.arg.source {
self.visit_pat(pat);
if let Some(arg) = &a.arg {
self.visit_pat(&arg.pat);
if let ast::ArgSource::AsyncFn(pat) = &arg.source {
self.visit_pat(pat);
}
self.visit_ty(&arg.ty);
}
self.visit_ty(&a.arg.ty);

// ..and the statement.
self.visit_stmt(&a.stmt);
self.visit_stmt(&a.move_stmt);
if let Some(pat_stmt) = &a.pat_stmt {
self.visit_stmt(&pat_stmt);
}
}
}
}
Expand Down
38 changes: 25 additions & 13 deletions src/librustc_resolve/lib.rs
Expand Up @@ -862,7 +862,13 @@ impl<'a, 'tcx> Visitor<'tcx> for Resolver<'a> {
// Walk the generated async arguments if this is an `async fn`, otherwise walk the
// normal arguments.
if let IsAsync::Async { ref arguments, .. } = asyncness {
for a in arguments { add_argument(&a.arg); }
for (i, a) in arguments.iter().enumerate() {
if let Some(arg) = &a.arg {
add_argument(&arg);
} else {
add_argument(&declaration.inputs[i]);
}
}
} else {
for a in &declaration.inputs { add_argument(a); }
}
Expand All @@ -882,8 +888,11 @@ impl<'a, 'tcx> Visitor<'tcx> for Resolver<'a> {
let mut body = body.clone();
// Insert the generated statements into the body before attempting to
// resolve names.
for a in arguments {
body.stmts.insert(0, a.stmt.clone());
for a in arguments.iter().rev() {
if let Some(pat_stmt) = a.pat_stmt.clone() {
body.stmts.insert(0, pat_stmt);
}
body.stmts.insert(0, a.move_stmt.clone());
}
self.visit_block(&body);
} else {
Expand Down Expand Up @@ -4174,7 +4183,7 @@ impl<'a> Resolver<'a> {
let add_module_candidates = |module: Module<'_>, names: &mut Vec<TypoSuggestion>| {
for (&(ident, _), resolution) in module.resolutions.borrow().iter() {
if let Some(binding) = resolution.borrow().binding {
if filter_fn(binding.def()) {
if !ident.name.is_gensymed() && filter_fn(binding.def()) {
names.push(TypoSuggestion {
candidate: ident.name,
article: binding.def().article(),
Expand All @@ -4192,7 +4201,7 @@ impl<'a> Resolver<'a> {
for rib in self.ribs[ns].iter().rev() {
// Locals and type parameters
for (ident, def) in &rib.bindings {
if filter_fn(*def) {
if !ident.name.is_gensymed() && filter_fn(*def) {
names.push(TypoSuggestion {
candidate: ident.name,
article: def.article(),
Expand All @@ -4219,7 +4228,7 @@ impl<'a> Resolver<'a> {
index: CRATE_DEF_INDEX,
});

if filter_fn(crate_mod) {
if !ident.name.is_gensymed() && filter_fn(crate_mod) {
Some(TypoSuggestion {
candidate: ident.name,
article: "a",
Expand All @@ -4242,13 +4251,16 @@ impl<'a> Resolver<'a> {
// Add primitive types to the mix
if filter_fn(Def::PrimTy(Bool)) {
names.extend(
self.primitive_type_table.primitive_types.iter().map(|(name, _)| {
TypoSuggestion {
candidate: *name,
article: "a",
kind: "primitive type",
}
})
self.primitive_type_table.primitive_types
.iter()
.filter(|(name, _)| !name.is_gensymed())
.map(|(name, _)| {
TypoSuggestion {
candidate: *name,
article: "a",
kind: "primitive type",
}
})
)
}
} else {
Expand Down
12 changes: 8 additions & 4 deletions src/libsyntax/ast.rs
Expand Up @@ -1865,10 +1865,14 @@ pub enum Unsafety {
pub struct AsyncArgument {
/// `__arg0`
pub ident: Ident,
/// `__arg0: <ty>` argument to replace existing function argument `<pat>: <ty>`.
pub arg: Arg,
/// `let <pat>: <ty> = __arg0;` statement to be inserted at the start of the block.
pub stmt: Stmt,
/// `__arg0: <ty>` argument to replace existing function argument `<pat>: <ty>`. Only if
/// argument is not a simple binding.
pub arg: Option<Arg>,
/// `let __arg0 = __arg0;` statement to be inserted at the start of the block.
pub move_stmt: Stmt,
/// `let <pat> = __arg0;` statement to be inserted at the start of the block, after matching
/// move statement. Only if argument is not a simple binding.
pub pat_stmt: Option<Stmt>,
}

#[derive(Clone, RustcEncodable, RustcDecodable, Debug)]
Expand Down
5 changes: 4 additions & 1 deletion src/libsyntax/ext/placeholders.rs
Expand Up @@ -199,7 +199,10 @@ impl<'a, 'b> MutVisitor for PlaceholderExpander<'a, 'b> {

if let ast::IsAsync::Async { ref mut arguments, .. } = a {
for argument in arguments.iter_mut() {
self.next_id(&mut argument.stmt.id);
self.next_id(&mut argument.move_stmt.id);
if let Some(ref mut pat_stmt) = &mut argument.pat_stmt {
self.next_id(&mut pat_stmt.id);
}
}
}
}
Expand Down
14 changes: 11 additions & 3 deletions src/libsyntax/mut_visit.rs
Expand Up @@ -694,13 +694,21 @@ pub fn noop_visit_asyncness<T: MutVisitor>(asyncness: &mut IsAsync, vis: &mut T)
IsAsync::Async { closure_id, return_impl_trait_id, ref mut arguments } => {
vis.visit_id(closure_id);
vis.visit_id(return_impl_trait_id);
for AsyncArgument { ident, arg, stmt } in arguments.iter_mut() {
for AsyncArgument { ident, arg, pat_stmt, move_stmt } in arguments.iter_mut() {
vis.visit_ident(ident);
vis.visit_arg(arg);
visit_clobber(stmt, |stmt| {
if let Some(arg) = arg {
vis.visit_arg(arg);
}
visit_clobber(move_stmt, |stmt| {
vis.flat_map_stmt(stmt)
.expect_one("expected visitor to produce exactly one item")
});
visit_opt(pat_stmt, |stmt| {
visit_clobber(stmt, |stmt| {
vis.flat_map_stmt(stmt)
.expect_one("expected visitor to produce exactly one item")
})
});
}
}
IsAsync::NotAsync => {}
Expand Down
64 changes: 49 additions & 15 deletions src/libsyntax/parse/parser.rs
Expand Up @@ -8720,27 +8720,46 @@ impl<'a> Parser<'a> {

// Construct a name for our temporary argument.
let name = format!("__arg{}", index);
let ident = Ident::from_str(&name);
let ident = Ident::from_str(&name).gensym();

// Check if this is a ident pattern, if so, we can optimize and avoid adding a
// `let <pat> = __argN;` statement, instead just adding a `let <pat> = <pat>;`
// statement.
let (ident, is_simple_pattern) = match input.pat.node {
PatKind::Ident(_, ident, _) => (ident, true),
_ => (ident, false),
};

// Construct an argument representing `__argN: <ty>` to replace the argument of the
// async function.
let arg = Arg {
ty: input.ty.clone(),
id,
// async function if it isn't a simple pattern.
let arg = if is_simple_pattern {
None
} else {
Some(Arg {
ty: input.ty.clone(),
id,
pat: P(Pat {
id,
node: PatKind::Ident(
BindingMode::ByValue(Mutability::Immutable), ident, None,
),
span,
}),
source: ArgSource::AsyncFn(input.pat.clone()),
})
};

// Construct a `let __argN = __argN;` statement to insert at the top of the
// async closure. This makes sure that the argument is captured by the closure and
// that the drop order is correct.
let move_local = Local {
pat: P(Pat {
id,
node: PatKind::Ident(
BindingMode::ByValue(Mutability::Immutable), ident, None,
),
span,
}),
source: ArgSource::AsyncFn(input.pat.clone()),
};

// Construct a `let <pat> = __argN;` statement to insert at the top of the
// async closure.
let local = P(Local {
pat: input.pat.clone(),
// We explicitly do not specify the type for this statement. When the user's
// argument type is `impl Trait` then this would require the
// `impl_trait_in_bindings` feature to also be present for that same type to
Expand All @@ -8760,10 +8779,25 @@ impl<'a> Parser<'a> {
span,
attrs: ThinVec::new(),
source: LocalSource::AsyncFn,
});
let stmt = Stmt { id, node: StmtKind::Local(local), span, };
};

// Construct a `let <pat> = __argN;` statement to insert at the top of the
// async closure if this isn't a simple pattern.
let pat_stmt = if is_simple_pattern {
None
} else {
Some(Stmt {
id,
node: StmtKind::Local(P(Local {
pat: input.pat.clone(),
..move_local.clone()
})),
span,
})
};

arguments.push(AsyncArgument { ident, arg, stmt });
let move_stmt = Stmt { id, node: StmtKind::Local(P(move_local)), span };
arguments.push(AsyncArgument { ident, arg, pat_stmt, move_stmt });
}
}
}
Expand Down

0 comments on commit 16939a5

Please sign in to comment.