Skip to content

Commit

Permalink
Support generic bounds with bound attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
Dessix committed Aug 2, 2021
1 parent 031fea6 commit a95fd83
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 5 deletions.
28 changes: 27 additions & 1 deletion impl/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use proc_macro2::{Delimiter, Group, Span, TokenStream, TokenTree};
use quote::{format_ident, quote, ToTokens};
use std::iter::FromIterator;
use syn::parse::{Nothing, ParseStream};
use syn::punctuated::Punctuated;
use syn::{
braced, bracketed, parenthesized, token, Attribute, Error, Ident, Index, LitInt, LitStr,
Result, Token,
Result, Token, TypeParamBound,
};

pub struct Attrs<'a> {
Expand All @@ -13,6 +14,7 @@ pub struct Attrs<'a> {
pub backtrace: Option<&'a Attribute>,
pub from: Option<&'a Attribute>,
pub transparent: Option<Transparent<'a>>,
pub bound: Option<Bound<'a>>,
}

#[derive(Clone)]
Expand All @@ -29,13 +31,20 @@ pub struct Transparent<'a> {
pub span: Span,
}

#[derive(Clone)]
pub struct Bound<'a> {
pub original: &'a Attribute,
pub bounds: Punctuated<TypeParamBound, token::Add>,
}

pub fn get(input: &[Attribute]) -> Result<Attrs> {
let mut attrs = Attrs {
display: None,
source: None,
backtrace: None,
from: None,
transparent: None,
bound: None,
};

for attr in input {
Expand Down Expand Up @@ -70,6 +79,7 @@ pub fn get(input: &[Attribute]) -> Result<Attrs> {

fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> {
syn::custom_keyword!(transparent);
syn::custom_keyword!(bound);

attr.parse_args_with(|input: ParseStream| {
if let Some(kw) = input.parse::<Option<transparent>>()? {
Expand All @@ -84,6 +94,22 @@ fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Resu
span: kw.span,
});
return Ok(());
} else if input.parse::<Option<bound>>()?.is_some() {
if attrs.bound.is_some() {
return Err(Error::new_spanned(
attr,
"duplicate #[error(bound)] attribute",
));
}
input.parse::<token::Eq>().map_err(|_| {
Error::new_spanned(attr, "\"bound\" keyword must be followed by '='")
})?;
let bound = Bound {
original: attr,
bounds: Punctuated::<TypeParamBound, token::Add>::parse_separated_nonempty(input)?,
};
attrs.bound = Some(bound);
return Ok(());
}

let display = Display {
Expand Down
112 changes: 108 additions & 4 deletions impl/src/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::ast::{Enum, Field, Input, Struct};
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{Data, DeriveInput, Member, PathArguments, Result, Type, Visibility};
use syn::{
token, Data, DeriveInput, Member, PathArguments, Result, Type, Visibility, WherePredicate,
};

pub fn derive(node: &DeriveInput) -> Result<TokenStream> {
let input = Input::from_syn(node)?;
Expand Down Expand Up @@ -112,9 +114,27 @@ fn impl_struct(input: Struct) -> TokenStream {
None
};
let display_impl = display_body.map(|body| {
let display_impl_generics = {
let mut lifetime_params = input.generics.lifetimes().peekable();
let mut type_params = input.generics.type_params().peekable();
if let Some(bounds) = &input.attrs.bound {
let bounds = std::iter::repeat(bounds).map(|x| &x.bounds);
quote! {
<#(#type_params: #bounds),*>
}
} else if lifetime_params.peek().is_none() || type_params.peek().is_none() {
quote! {
<#(#lifetime_params),* #(#type_params),*>
}
} else {
quote! {
<#(#lifetime_params),* , #(#type_params),*>
}
}
};
quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
impl #display_impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
#[allow(
// Clippy bug: https://github.com/rust-lang/rust-clippy/issues/7422
clippy::nonstandard_macro_braces,
Expand Down Expand Up @@ -143,6 +163,17 @@ fn impl_struct(input: Struct) -> TokenStream {
});

let error_trait = spanned_error_trait(input.original);
let bounded_where_predicates: Vec<syn::WherePredicate> = input
.attrs
.bound
.as_ref()
.map(|bound| apply_type_bounds(input.generics.type_params(), bound))
.unwrap_or_default();
let where_clause = extend_where_clause(
where_clause.cloned(),
input.original.span(),
bounded_where_predicates.into_iter(),
);

quote! {
#[allow(unused_qualifications)]
Expand Down Expand Up @@ -302,9 +333,27 @@ fn impl_enum(input: Enum) -> TokenStream {
#ty::#ident #pat => #display
}
});
let display_impl_generics = {
let mut lifetime_params = input.generics.lifetimes().peekable();
let mut type_params = input.generics.type_params().peekable();
if let Some(bounds) = &input.attrs.bound {
let bounds = std::iter::repeat(bounds).map(|x| &x.bounds);
quote! {
<#(#type_params: #bounds),*>
}
} else if lifetime_params.peek().is_none() || type_params.peek().is_none() {
quote! {
<#(#lifetime_params),* #(#type_params),*>
}
} else {
quote! {
<#(#lifetime_params),* , #(#type_params),*>
}
}
};
Some(quote! {
#[allow(unused_qualifications)]
impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
impl #display_impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
#use_as_display
#[allow(
Expand Down Expand Up @@ -342,7 +391,17 @@ fn impl_enum(input: Enum) -> TokenStream {
});

let error_trait = spanned_error_trait(input.original);

let bounded_where_predicates: Vec<syn::WherePredicate> = input
.attrs
.bound
.as_ref()
.map(|bound| apply_type_bounds(input.generics.type_params(), bound))
.unwrap_or_default();
let where_clause = extend_where_clause(
where_clause.cloned(),
input.original.span(),
bounded_where_predicates.into_iter(),
);
quote! {
#[allow(unused_qualifications)]
impl #impl_generics #error_trait for #ty #ty_generics #where_clause {
Expand Down Expand Up @@ -424,3 +483,48 @@ fn spanned_error_trait(input: &DeriveInput) -> TokenStream {
let error = quote_spanned!(last_span=> Error);
quote!(#path #error)
}

/// Enhance a where clause with the given predicates, or create one with them if needed.
/// When no new predicates are provided, return without alteration.
fn extend_where_clause<TPredicates: std::iter::ExactSizeIterator<Item = syn::WherePredicate>>(
// Clause to be enhanced; created if absent when predicates are provided
where_clause: Option<syn::WhereClause>,
// Used to create span for new where clause if populating
where_span: proc_macro2::Span,
predicates: TPredicates,
) -> Option<syn::WhereClause> {
// If we don't have any predicates to add, it doesn't matter if we
// have a where clause to extend or not; return whatever was given
if predicates.len() == 0 {
return where_clause;
}
Some(match where_clause {
// Extend the existing clause with the new predicates
Some(mut where_clause) => {
where_clause.predicates.extend(predicates);
where_clause
}
// No where clause provided; create a new one with the provided span
None => syn::WhereClause {
where_token: token::Where(where_span),
predicates: predicates.collect(),
},
})
}

fn apply_type_bounds<'a, TTypeParams: std::iter::Iterator<Item = &'a syn::TypeParam>>(
type_params: TTypeParams,
bound_attr: &'a crate::attr::Bound<'_>,
) -> Vec<WherePredicate> {
let bounds = &bound_attr.bounds;
if bounds.is_empty() {
return Vec::new();
}
type_params
.map(move |p| {
let predicate = quote! { #p: #bounds };
syn::parse2::<syn::WherePredicate>(predicate)
.expect("quasiquote must create predicate bounds")
})
.collect()
}
24 changes: 24 additions & 0 deletions impl/src/valid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@ impl Struct<'_> {
));
}
}
if let Some(crate::attr::Bound {
original: bound_span,
..
}) = self.attrs.bound
{
if self.generics.params.is_empty() {
return Err(Error::new_spanned(
bound_span,
"#[error(bound = ...)] requires generics to apply bounds against",
));
}
}
check_field_attrs(&self.fields)?;
for field in &self.fields {
field.validate()?;
Expand All @@ -52,6 +64,18 @@ impl Enum<'_> {
));
}
}
if let Some(crate::attr::Bound {
original: bound_span,
..
}) = self.attrs.bound
{
if self.generics.params.is_empty() {
return Err(Error::new_spanned(
bound_span,
"#[error(bound = ...)] requires at least one generic type parameter",
));
}
}
let mut from_types = Set::new();
for variant in &self.variants {
if let Some(from_field) = variant.from_field() {
Expand Down
48 changes: 48 additions & 0 deletions tests/test_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ use std::io;
use thiserror::Error;

macro_rules! unimplemented_display {
(($tp:tt), $ty:ty) => {
impl<$tp> Display for $ty {
fn fmt(&self, _formatter: &mut fmt::Formatter) -> fmt::Result {
unimplemented!()
}
}
};
($ty:ty) => {
impl Display for $ty {
fn fmt(&self, _formatter: &mut fmt::Formatter) -> fmt::Result {
Expand Down Expand Up @@ -49,9 +56,50 @@ enum EnumError {
Unit,
}

#[derive(Error, Debug)]
#[error(bound = std::fmt::Display + std::error::Error + 'static)]
enum WithGeneric<T> {
Variant,
Generic(T),
}

#[derive(Error, Debug)]
#[error(bound = std::fmt::Debug + std::error::Error + 'static)]
enum WithGenericFrom<T> {
Variant,
Generic(#[from] T),
}

#[derive(Error, Debug)]
#[error(bound = std::fmt::Display + std::fmt::Debug + std::error::Error + 'static)]
enum WithGenericTransparent<T> {
#[error("variant")]
Variant,
#[error(transparent)]
Generic(#[from] T),
}

#[derive(Error, Debug)]
#[error(bound = std::error::Error + 'static)]
struct WithGenericStruct<T> {
#[from]
inner: T,
}

#[derive(Error, Debug)]
#[error(bound = std::fmt::Display + std::error::Error + 'static)]
#[error(transparent)]
struct WithGenericStructTransparent<T> {
#[from]
inner: T,
}

unimplemented_display!(BracedError);
unimplemented_display!(TupleError);
unimplemented_display!(UnitError);
unimplemented_display!(WithSource);
unimplemented_display!(WithAnyhow);
unimplemented_display!(EnumError);
unimplemented_display!((T), WithGeneric<T>);
unimplemented_display!((T), WithGenericFrom<T>);
unimplemented_display!((T), WithGenericStruct<T>);
15 changes: 15 additions & 0 deletions tests/ui/bound-enum-without-generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use thiserror::Error;

#[derive(Error, Debug)]
#[error(bound = std::error::Error + 'static)]
enum BoundsWithoutGeneric {
Variant(u32),
}

impl std::fmt::Display for BoundsWithoutGeneric {
fn fmt(&self, _formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
unimplemented!()
}
}

fn main() {}
5 changes: 5 additions & 0 deletions tests/ui/bound-enum-without-generic.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: #[error(bound = ...)] requires at least one generic type parameter
--> $DIR/bound-enum-without-generic.rs:4:1
|
4 | #[error(bound = std::error::Error + 'static)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
15 changes: 15 additions & 0 deletions tests/ui/bound-struct-without-generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use thiserror::Error;

#[derive(Error, Debug)]
#[error(bound = std::error::Error + 'static)]
struct BoundsWithoutGeneric {
inner: u32,
}

impl std::fmt::Display for BoundsWithoutGeneric {
fn fmt(&self, _formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
unimplemented!()
}
}

fn main() {}
5 changes: 5 additions & 0 deletions tests/ui/bound-struct-without-generic.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
error: #[error(bound = ...)] requires generics to apply bounds against
--> $DIR/bound-struct-without-generic.rs:4:1
|
4 | #[error(bound = std::error::Error + 'static)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

0 comments on commit a95fd83

Please sign in to comment.