Skip to content

Commit

Permalink
improve auto bound
Browse files Browse the repository at this point in the history
  • Loading branch information
magiclen committed Dec 9, 2023
1 parent 10eb844 commit 8c261ff
Show file tree
Hide file tree
Showing 24 changed files with 99 additions and 99 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "educe"
version = "0.5.1"
version = "0.5.2"
authors = ["Magic Len <len@magiclen.org>"]
edition = "2021"
rust-version = "1.56"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ enum Enum<T: A> {

###### Generic Parameters Bound to the `Clone` Trait or Others

Generic parameters will be automatically bound to the `Clone` trait if necessary. If the `#[educe(Copy)]` attribute exists, all generic parameters will bound to the `Copy` trait.
Generic parameters will be automatically bound to the `Clone` trait if necessary. If the `#[educe(Copy)]` attribute exists, they will be bound to the `Copy` trait.

```rust
use educe::Educe;
Expand Down
2 changes: 1 addition & 1 deletion src/common/bound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl Bound {
params: &Punctuated<GenericParam, Comma>,
bound_trait: &Path,
types: &[&Type],
recursive: Option<(bool, bool)>,
recursive: Option<(bool, bool, bool)>,
) -> Punctuated<WherePredicate, Comma> {
match self {
Self::Disabled => Punctuated::new(),
Expand Down
130 changes: 53 additions & 77 deletions src/common/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::HashSet;
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
Ident, Meta, Token, Type, TypeParamBound,
GenericArgument, Ident, Meta, Path, PathArguments, Token, Type, TypeParamBound,
};

pub(crate) struct TypeWithPunctuatedMeta {
Expand Down Expand Up @@ -34,12 +34,48 @@ impl Parse for TypeWithPunctuatedMeta {
}
}

/// recursive (dereference, de_ptr)
/// recursive (dereference, de_ptr, de_param)
#[inline]
pub(crate) fn find_idents_in_path<'a>(
set: &mut HashSet<&'a Ident>,
path: &'a Path,
recursive: Option<(bool, bool, bool)>,
) {
if let Some((_, _, de_param)) = recursive {
if de_param {
if let Some(segment) = path.segments.iter().last() {
if let PathArguments::AngleBracketed(a) = &segment.arguments {
// the ident is definitely not a generic parameter, so we don't insert it

for arg in a.args.iter() {
match arg {
GenericArgument::Type(ty) => {
find_idents_in_type(set, ty, recursive);
},
GenericArgument::AssocType(ty) => {
find_idents_in_type(set, &ty.ty, recursive);
},
_ => (),
}
}

return;
}
}
}
}

if let Some(ty) = path.get_ident() {
set.insert(ty);
}
}

/// recursive (dereference, de_ptr, de_param)
#[inline]
pub(crate) fn find_idents_in_type<'a>(
set: &mut HashSet<&'a Ident>,
ty: &'a Type,
recursive: Option<(bool, bool)>,
recursive: Option<(bool, bool, bool)>,
) {
match ty {
Type::Array(ty) => {
Expand All @@ -53,21 +89,16 @@ pub(crate) fn find_idents_in_type<'a>(
}
},
Type::ImplTrait(ty) => {
if recursive.is_some() {
for b in &ty.bounds {
if let TypeParamBound::Trait(ty) = b {
if let Some(ty) = ty.path.get_ident() {
set.insert(ty);
}
}
// always recursive
for b in &ty.bounds {
if let TypeParamBound::Trait(ty) = b {
find_idents_in_path(set, &ty.path, recursive);
}
}
},
Type::Macro(ty) => {
if recursive.is_some() {
if let Some(ty) = ty.mac.path.get_ident() {
set.insert(ty);
}
find_idents_in_path(set, &ty.mac.path, recursive);
}
},
Type::Paren(ty) => {
Expand All @@ -76,22 +107,16 @@ pub(crate) fn find_idents_in_type<'a>(
}
},
Type::Path(ty) => {
if let Some(ty) = ty.path.get_ident() {
set.insert(ty);
}
find_idents_in_path(set, &ty.path, recursive);
},
Type::Ptr(ty) => {
if let Some((_, de_ptr)) = recursive {
if de_ptr {
find_idents_in_type(set, ty.elem.as_ref(), recursive);
}
if let Some((_, true, _)) = recursive {
find_idents_in_type(set, ty.elem.as_ref(), recursive);
}
},
Type::Reference(ty) => {
if let Some((dereference, _)) = recursive {
if dereference {
find_idents_in_type(set, ty.elem.as_ref(), recursive);
}
if let Some((true, ..)) = recursive {
find_idents_in_type(set, ty.elem.as_ref(), recursive);
}
},
Type::Slice(ty) => {
Expand All @@ -100,13 +125,10 @@ pub(crate) fn find_idents_in_type<'a>(
}
},
Type::TraitObject(ty) => {
if recursive.is_some() {
for b in &ty.bounds {
if let TypeParamBound::Trait(ty) = b {
if let Some(ty) = ty.path.get_ident() {
set.insert(ty);
}
}
// always recursive
for b in &ty.bounds {
if let TypeParamBound::Trait(ty) = b {
find_idents_in_path(set, &ty.path, recursive);
}
}
},
Expand All @@ -121,52 +143,6 @@ pub(crate) fn find_idents_in_type<'a>(
}
}

#[inline]
pub(crate) fn find_derivable_idents_in_type<'a>(set: &mut HashSet<&'a Ident>, ty: &'a Type) {
match ty {
Type::Array(ty) => find_derivable_idents_in_type(set, ty.elem.as_ref()),
Type::Group(ty) => find_derivable_idents_in_type(set, ty.elem.as_ref()),
Type::ImplTrait(ty) => {
for b in &ty.bounds {
if let TypeParamBound::Trait(ty) = b {
if let Some(ty) = ty.path.get_ident() {
set.insert(ty);
}
}
}
},
Type::Macro(ty) => {
if let Some(ty) = ty.mac.path.get_ident() {
set.insert(ty);
}
},
Type::Paren(ty) => find_derivable_idents_in_type(set, ty.elem.as_ref()),
Type::Path(ty) => {
if let Some(ty) = ty.path.get_ident() {
set.insert(ty);
}
},
Type::Ptr(_) => (),
Type::Reference(_) => (),
Type::Slice(ty) => find_derivable_idents_in_type(set, ty.elem.as_ref()),
Type::TraitObject(ty) => {
for b in &ty.bounds {
if let TypeParamBound::Trait(ty) = b {
if let Some(ty) = ty.path.get_ident() {
set.insert(ty);
}
}
}
},
Type::Tuple(ty) => {
for ty in &ty.elems {
find_derivable_idents_in_type(set, ty)
}
},
_ => (),
}
}

#[inline]
pub(crate) fn dereference(ty: &Type) -> &Type {
if let Type::Reference(ty) = ty {
Expand Down
4 changes: 2 additions & 2 deletions src/common/where_predicates_bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ pub(crate) fn create_where_predicates_from_generic_parameters_check_types(
params: &Punctuated<GenericParam, Comma>,
bound_trait: &Path,
types: &[&Type],
resursive: Option<(bool, bool)>,
recursive: Option<(bool, bool, bool)>,
) -> WherePredicates {
let mut where_predicates = Punctuated::new();

let mut set = HashSet::new();

for t in types {
find_idents_in_type(&mut set, t, resursive);
find_idents_in_type(&mut set, t, recursive);
}

for param in params {
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ enum Enum<T: A> {
###### Generic Parameters Bound to the `Clone` Trait or Others
Generic parameters will be automatically bound to the `Clone` trait if necessary. If the `#[educe(Copy)]` attribute exists, all generic parameters will bound to the `Copy` trait.
Generic parameters will be automatically bound to the `Clone` trait if necessary. If the `#[educe(Copy)]` attribute exists, they will be bound to the `Copy` trait.
```rust
# #[cfg(feature = "Clone")]
Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/clone/clone_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl TraitHandler for CloneEnumHandler {
})
.unwrap(),
&clone_types,
Some((false, false)),
Some((false, false, false)),
);
}

Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/clone/clone_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ impl TraitHandler for CloneStructHandler {
})
.unwrap(),
&clone_types,
Some((false, false)),
Some((false, false, false)),
);
}

Expand Down
12 changes: 12 additions & 0 deletions src/trait_handlers/copy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ impl TraitHandler for CopyHandler {

let ident = &ast.ident;

/*
#[derive(Clone)]
struct B<T> {
f1: PhantomData<T>,
}
impl<T> Copy for B<T> {
}
// The above code will throw a compile error because T have to be bound to `Copy`. However, it seems not to be necessary logically.
*/
let bound = type_attribute.bound.into_where_predicates_by_generic_parameters(
&ast.generics.params,
&syn::parse2(quote!(::core::marker::Copy)).unwrap(),
Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/debug/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub(crate) fn create_format_arg(
let ty = dereference(ty);

let mut idents = HashSet::new();
find_idents_in_type(&mut idents, ty, Some((true, false)));
find_idents_in_type(&mut idents, ty, Some((true, true, false)));

// simply support one level generics (without considering bounds that use other generics)
let mut filtered_params: Punctuated<GenericParam, Comma> = Punctuated::new();
Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/debug/debug_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ impl TraitHandler for DebugEnumHandler {
&ast.generics.params,
&syn::parse2(quote!(::core::fmt::Debug)).unwrap(),
&debug_types,
Some((true, false)),
Some((true, false, false)),
);

let where_clause = ast.generics.make_where_clause();
Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/debug/debug_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ impl TraitHandler for DebugStructHandler {
&ast.generics.params,
&syn::parse2(quote!(::core::fmt::Debug)).unwrap(),
&debug_types,
Some((true, false)),
Some((true, false, false)),
);

let where_clause = ast.generics.make_where_clause();
Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/default/default_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ impl TraitHandler for DefaultEnumHandler {
&ast.generics.params,
&syn::parse2(quote!(::core::default::Default)).unwrap(),
&default_types,
Some((false, false)),
Some((false, false, false)),
);

let where_clause = ast.generics.make_where_clause();
Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/default/default_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl TraitHandler for DefaultStructHandler {
&ast.generics.params,
&syn::parse2(quote!(::core::default::Default)).unwrap(),
&default_types,
Some((false, false)),
Some((false, false, false)),
);

let where_clause = ast.generics.make_where_clause();
Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/default/default_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl TraitHandler for DefaultUnionHandler {
&ast.generics.params,
&syn::parse2(quote!(::core::default::Default)).unwrap(),
&default_types,
Some((false, false)),
Some((false, false, false)),
);

let where_clause = ast.generics.make_where_clause();
Expand Down
12 changes: 12 additions & 0 deletions src/trait_handlers/eq/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ impl TraitHandler for EqHandler {

let ident = &ast.ident;

/*
#[derive(PartialEq)]
struct B<T> {
f1: PhantomData<T>,
}
impl<T> Eq for B<T> {
}
// The above code will throw a compile error because T have to be bound to `PartialEq`. However, it seems not to be necessary logically.
*/
let bound = type_attribute.bound.into_where_predicates_by_generic_parameters(
&ast.generics.params,
&syn::parse2(quote!(::core::cmp::PartialEq)).unwrap(),
Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/hash/hash_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl TraitHandler for HashEnumHandler {
&ast.generics.params,
&syn::parse2(quote!(::core::hash::Hash)).unwrap(),
&hash_types,
Some((true, false)),
Some((true, false, false)),
);

let where_clause = ast.generics.make_where_clause();
Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/hash/hash_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl TraitHandler for HashStructHandler {
&ast.generics.params,
&syn::parse2(quote!(::core::hash::Hash)).unwrap(),
&hash_types,
Some((true, false)),
Some((true, false, false)),
);

let where_clause = ast.generics.make_where_clause();
Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/ord/ord_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ impl TraitHandler for OrdEnumHandler {
&ast.generics.params,
&syn::parse2(quote!(::core::cmp::Ord)).unwrap(),
&ord_types,
Some((true, false)),
Some((true, false, false)),
);

let where_clause = ast.generics.make_where_clause();
Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/ord/ord_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl TraitHandler for OrdStructHandler {
&ast.generics.params,
&syn::parse2(quote!(::core::cmp::Ord)).unwrap(),
&ord_types,
Some((true, false)),
Some((true, false, false)),
);

let where_clause = ast.generics.make_where_clause();
Expand Down
2 changes: 1 addition & 1 deletion src/trait_handlers/partial_eq/partial_eq_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ impl TraitHandler for PartialEqEnumHandler {
&ast.generics.params,
&syn::parse2(quote!(::core::cmp::PartialEq)).unwrap(),
&partial_eq_types,
Some((true, false)),
Some((true, false, false)),
);

let where_clause = ast.generics.make_where_clause();
Expand Down
Loading

0 comments on commit 8c261ff

Please sign in to comment.