Skip to content

Commit

Permalink
Merge pull request #1666 from dtolnay/foldhelper
Browse files Browse the repository at this point in the history
Generate fewer monomorphizations in Fold
  • Loading branch information
dtolnay committed May 19, 2024
2 parents b2ee932 + 8124c0e commit 7273aa7
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 159 deletions.
29 changes: 16 additions & 13 deletions codegen/src/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ use syn_codegen::{Data, Definitions, Features, Node, Type};

const FOLD_SRC: &str = "src/gen/fold.rs";

fn simple_visit(item: &str, name: &TokenStream) -> TokenStream {
fn method_name(item: &str) -> Ident {
let ident = gen::under_name(item);
let method = format_ident!("fold_{}", ident);
quote! {
f.#method(#name)
}
format_ident!("fold_{}", ident)
}

fn simple_fold(item: &str, name: &TokenStream) -> TokenStream {
let method = method_name(item);
quote! { f.#method(#name) }
}

fn visit(
Expand All @@ -29,17 +31,18 @@ fn visit(
})
}
Type::Vec(t) => {
let operand = quote!(it);
let val = visit(t, features, defs, &operand)?;
let Type::Syn(t) = &**t else { unimplemented!() };
let method = method_name(t);
Some(quote! {
FoldHelper::lift(#name, |it| #val)
FoldHelper::lift(#name, f, F::#method)
})
}
Type::Punctuated(p) => {
let operand = quote!(it);
let val = visit(&p.element, features, defs, &operand)?;
let t = &*p.element;
let Type::Syn(t) = t else { unimplemented!() };
let method = method_name(t);
Some(quote! {
FoldHelper::lift(#name, |it| #val)
FoldHelper::lift(#name, f, F::#method)
})
}
Type::Option(t) => {
Expand All @@ -66,14 +69,14 @@ fn visit(
fn requires_full(features: &Features) -> bool {
features.any.contains("full") && features.any.len() == 1
}
let mut res = simple_visit(t, name);
let mut res = simple_fold(t, name);
let target = defs.types.iter().find(|ty| ty.ident == *t).unwrap();
if requires_full(&target.features) && !requires_full(features) {
res = quote!(full!(#res));
}
Some(res)
}
Type::Ext(t) if gen::TERMINAL_TYPES.contains(&&t[..]) => Some(simple_visit(t, name)),
Type::Ext(t) if gen::TERMINAL_TYPES.contains(&&t[..]) => Some(simple_fold(t, name)),
Type::Ext(_) | Type::Std(_) | Type::Token(_) | Type::Group(_) => None,
}
}
Expand Down

0 comments on commit 7273aa7

Please sign in to comment.