diff --git a/num_enum/src/lib.rs b/num_enum/src/lib.rs index 425b706..86a0b73 100644 --- a/num_enum/src/lib.rs +++ b/num_enum/src/lib.rs @@ -21,10 +21,11 @@ pub trait FromPrimitive: Sized { pub trait TryFromPrimitive: Sized { type Primitive: Copy + Eq + fmt::Debug; + type Error; const NAME: &'static str; - fn try_from_primitive(number: Self::Primitive) -> Result>; + fn try_from_primitive(number: Self::Primitive) -> Result; } pub trait UnsafeFromPrimitive: Sized { @@ -56,6 +57,12 @@ pub struct TryFromPrimitiveError { pub number: Enum::Primitive, } +impl TryFromPrimitiveError { + pub fn new(number: Enum::Primitive) -> Self { + Self { number } + } +} + impl fmt::Debug for TryFromPrimitiveError { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fmt.debug_struct("TryFromPrimitiveError") diff --git a/num_enum/tests/renamed_num_enum.rs b/num_enum/tests/renamed_num_enum.rs index 44784ed..d05fbf6 100644 --- a/num_enum/tests/renamed_num_enum.rs +++ b/num_enum/tests/renamed_num_enum.rs @@ -1,3 +1,5 @@ +use std::process::Stdio; + #[test] fn no_std() { assert!(::std::process::Command::new("cargo") @@ -9,6 +11,8 @@ fn no_std() { "/../renamed_num_enum/Cargo.toml", ), ]) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) .status() .unwrap() .success()) @@ -27,6 +31,8 @@ fn std() { "--features", "std", ]) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()) .status() .unwrap() .success()) diff --git a/num_enum/tests/try_build/compile_fail/custom_error_type_parsing.rs b/num_enum/tests/try_build/compile_fail/custom_error_type_parsing.rs new file mode 100644 index 0000000..7abaf63 --- /dev/null +++ b/num_enum/tests/try_build/compile_fail/custom_error_type_parsing.rs @@ -0,0 +1,56 @@ +#[derive(num_enum::TryFromPrimitive)] +#[num_enum(error_type(name = CustomError))] +#[repr(u8)] +enum MissingConstructor { + Zero, + One, + Two, +} + +#[derive(num_enum::TryFromPrimitive)] +#[num_enum(error_type(constructor = CustomError::new))] +#[repr(u8)] +enum MissingName { + Zero, + One, + Two, +} + +#[derive(num_enum::TryFromPrimitive)] +#[num_enum(error_type(name = CustomError, constructor = CustomError::new, extra = something))] +#[repr(u8)] +enum ExtraAttr { + Zero, + One, + Two, +} + +#[derive(num_enum::TryFromPrimitive)] +#[num_enum(error_type(name = CustomError, constructor = CustomError::new), error_type(name = CustomError, constructor = CustomError::new))] +#[repr(u8)] +enum TwoErrorTypes { + Zero, + One, + Two, +} + +#[derive(num_enum::TryFromPrimitive)] +#[num_enum(error_type(name = CustomError, constructor = CustomError::new))] +#[num_enum(error_type(name = CustomError, constructor = CustomError::new))] +#[repr(u8)] +enum TwoAttrs { + Zero, + One, + Two, +} + +struct CustomError {} + +impl CustomError { + fn new(_: u8) -> CustomError { + CustomError{} + } +} + +fn main() { +} diff --git a/num_enum/tests/try_build/compile_fail/custom_error_type_parsing.stderr b/num_enum/tests/try_build/compile_fail/custom_error_type_parsing.stderr new file mode 100644 index 0000000..fef9618 --- /dev/null +++ b/num_enum/tests/try_build/compile_fail/custom_error_type_parsing.stderr @@ -0,0 +1,29 @@ +error: num_enum error_type attribute requires `constructor` value + --> tests/try_build/compile_fail/custom_error_type_parsing.rs:2:12 + | +2 | #[num_enum(error_type(name = CustomError))] + | ^^^^^^^^^^ + +error: num_enum error_type attribute requires `name` value + --> tests/try_build/compile_fail/custom_error_type_parsing.rs:11:12 + | +11 | #[num_enum(error_type(constructor = CustomError::new))] + | ^^^^^^^^^^ + +error: expected `name` or `constructor` + --> tests/try_build/compile_fail/custom_error_type_parsing.rs:20:75 + | +20 | #[num_enum(error_type(name = CustomError, constructor = CustomError::new, extra = something))] + | ^^^^^ + +error: num_enum attribute must have at most one error_type + --> tests/try_build/compile_fail/custom_error_type_parsing.rs:29:76 + | +29 | #[num_enum(error_type(name = CustomError, constructor = CustomError::new), error_type(name = CustomError, constructor = CustomError::new))] + | ^^^^^^^^^^ + +error: At most one num_enum error_type attribute may be specified + --> tests/try_build/compile_fail/custom_error_type_parsing.rs:39:1 + | +39 | #[num_enum(error_type(name = CustomError, constructor = CustomError::new))] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/num_enum/tests/try_from_primitive.rs b/num_enum/tests/try_from_primitive.rs index b6454be..f65b542 100644 --- a/num_enum/tests/try_from_primitive.rs +++ b/num_enum/tests/try_from_primitive.rs @@ -502,6 +502,47 @@ fn try_from_primitive_number() { assert_eq!(try_from, Ok(Enum::Whatever)); } +#[test] +fn custom_error() { + #[derive(Debug, Eq, PartialEq, TryFromPrimitive)] + #[num_enum(error_type(name = CustomError, constructor = CustomError::new))] + #[repr(u8)] + enum FirstNumber { + Zero, + One, + Two, + } + + #[derive(Debug, Eq, PartialEq, TryFromPrimitive)] + #[num_enum(error_type(constructor = CustomError::new, name = CustomError))] + #[repr(u8)] + enum SecondNumber { + Zero, + One, + Two, + } + + #[derive(Debug, PartialEq, Eq)] + struct CustomError { + bad_value: u8, + } + + impl CustomError { + fn new(value: u8) -> CustomError { + CustomError { bad_value: value } + } + } + + let zero: Result = 0u8.try_into(); + assert_eq!(zero, Ok(FirstNumber::Zero)); + + let three: Result = 3u8.try_into(); + assert_eq!(three.unwrap_err(), CustomError { bad_value: 3u8 }); + + let three: Result = 3u8.try_into(); + assert_eq!(three.unwrap_err(), CustomError { bad_value: 3u8 }); +} + // #[derive(FromPrimitive)] generates implementations for the following traits: // // - `FromPrimitive` diff --git a/num_enum_derive/Cargo.toml b/num_enum_derive/Cargo.toml index 2133828..fdb57cb 100644 --- a/num_enum_derive/Cargo.toml +++ b/num_enum_derive/Cargo.toml @@ -36,3 +36,6 @@ proc-macro2 = "1.0.60" proc-macro-crate = { version = "1", optional = true } quote = "1" syn = { version = "2", features = ["parsing"] } + +[dev-dependencies] +syn = { version = "2", features = ["extra-traits", "parsing"] } diff --git a/num_enum_derive/src/enum_attributes.rs b/num_enum_derive/src/enum_attributes.rs new file mode 100644 index 0000000..fd19a31 --- /dev/null +++ b/num_enum_derive/src/enum_attributes.rs @@ -0,0 +1,274 @@ +use crate::utils::die; +use proc_macro2::Span; +use syn::{ + parse::{Parse, ParseStream}, + Error, Result, +}; + +mod kw { + syn::custom_keyword!(constructor); + syn::custom_keyword!(error_type); + syn::custom_keyword!(name); +} + +// Example: error_type(name = Foo, constructor = Foo::new) +#[cfg_attr(test, derive(Debug))] +pub(crate) struct Attributes { + pub(crate) error_type: Option, +} + +// Example: error_type(name = Foo, constructor = Foo::new) +#[cfg_attr(test, derive(Debug))] +pub(crate) enum AttributeItem { + ErrorType(ErrorTypeAttribute), +} + +impl Parse for Attributes { + fn parse(input: ParseStream<'_>) -> Result { + let attribute_items = input.parse_terminated(AttributeItem::parse, syn::Token![,])?; + let mut maybe_error_type = None; + for attribute_item in &attribute_items { + match attribute_item { + AttributeItem::ErrorType(error_type) => { + if maybe_error_type.is_some() { + return Err(Error::new( + error_type.span, + "num_enum attribute must have at most one error_type", + )); + } + maybe_error_type = Some(error_type.clone()); + } + } + } + Ok(Self { + error_type: maybe_error_type, + }) + } +} + +impl Parse for AttributeItem { + fn parse(input: ParseStream<'_>) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::error_type) { + input.parse().map(Self::ErrorType) + } else { + Err(lookahead.error()) + } + } +} + +// Example: error_type(name = Foo, constructor = Foo::new) +#[derive(Clone)] +#[cfg_attr(test, derive(Debug))] +pub(crate) struct ErrorTypeAttribute { + pub(crate) name: ErrorTypeNameAttribute, + pub(crate) constructor: ErrorTypeConstructorAttribute, + + span: Span, +} + +impl Parse for ErrorTypeAttribute { + fn parse(input: ParseStream) -> Result { + let keyword: kw::error_type = input.parse()?; + let span = keyword.span; + let content; + syn::parenthesized!(content in input); + let attribute_values = + content.parse_terminated(ErrorTypeAttributeNamedArgument::parse, syn::Token![,])?; + let mut name = None; + let mut constructor = None; + for attribute_value in &attribute_values { + match attribute_value { + ErrorTypeAttributeNamedArgument::Name(name_attr) => { + if name.is_some() { + die!("num_enum error_type attribute must have exactly one `name` value"); + } + name = Some(name_attr.clone()); + } + ErrorTypeAttributeNamedArgument::Constructor(constructor_attr) => { + if constructor.is_some() { + die!("num_enum error_type attribute must have exactly one `constructor` value") + } + constructor = Some(constructor_attr.clone()); + } + } + } + match (name, constructor) { + (None, None) => Err(Error::new( + span, + "num_enum error_type attribute requires `name` and `constructor` values", + )), + (Some(_), None) => Err(Error::new( + span, + "num_enum error_type attribute requires `constructor` value", + )), + (None, Some(_)) => Err(Error::new( + span, + "num_enum error_type attribute requires `name` value", + )), + (Some(name), Some(constructor)) => Ok(Self { + name, + constructor, + span, + }), + } + } +} + +// Examples: +// * name = Foo +// * constructor = Foo::new +pub(crate) enum ErrorTypeAttributeNamedArgument { + Name(ErrorTypeNameAttribute), + Constructor(ErrorTypeConstructorAttribute), +} + +impl Parse for ErrorTypeAttributeNamedArgument { + fn parse(input: ParseStream<'_>) -> Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::name) { + input.parse().map(Self::Name) + } else if lookahead.peek(kw::constructor) { + input.parse().map(Self::Constructor) + } else { + Err(lookahead.error()) + } + } +} + +// Example: name = Foo +#[derive(Clone)] +#[cfg_attr(test, derive(Debug))] +pub(crate) struct ErrorTypeNameAttribute { + pub(crate) path: syn::Path, +} + +impl Parse for ErrorTypeNameAttribute { + fn parse(input: ParseStream) -> Result { + input.parse::()?; + input.parse::()?; + let path = input.parse()?; + Ok(Self { path }) + } +} + +// Example: constructor = Foo::new +#[derive(Clone)] +#[cfg_attr(test, derive(Debug))] +pub(crate) struct ErrorTypeConstructorAttribute { + pub(crate) path: syn::Path, +} + +impl Parse for ErrorTypeConstructorAttribute { + fn parse(input: ParseStream) -> Result { + input.parse::()?; + input.parse::()?; + let path = input.parse()?; + Ok(Self { path }) + } +} + +#[cfg(test)] +mod test { + use crate::enum_attributes::Attributes; + use quote::ToTokens; + use syn::{parse_quote, Path}; + + #[test] + fn parse_num_enum_attr() { + let expected_name: Path = parse_quote! { Foo }; + let expected_constructor: Path = parse_quote! { ::foo::Foo::::new }; + + let attributes: Attributes = + syn::parse_str("error_type(name = Foo, constructor = ::foo::Foo::::new)").unwrap(); + assert!(attributes.error_type.is_some()); + let error_type = attributes.error_type.unwrap(); + assert_eq!( + error_type.name.path.to_token_stream().to_string(), + expected_name.to_token_stream().to_string() + ); + assert_eq!( + error_type.constructor.path.to_token_stream().to_string(), + expected_constructor.to_token_stream().to_string() + ); + } + + #[test] + fn parse_num_enum_attr_swapped_order() { + let expected_name: Path = parse_quote! { Foo }; + let expected_constructor: Path = parse_quote! { ::foo::Foo::::new }; + + let attributes: Attributes = + syn::parse_str("error_type(constructor = ::foo::Foo::::new, name = Foo)").unwrap(); + assert!(attributes.error_type.is_some()); + let error_type = attributes.error_type.unwrap(); + assert_eq!( + error_type.name.path.to_token_stream().to_string(), + expected_name.to_token_stream().to_string() + ); + assert_eq!( + error_type.constructor.path.to_token_stream().to_string(), + expected_constructor.to_token_stream().to_string() + ); + } + + #[test] + fn missing_constructor() { + let err = syn::parse_str::("error_type(name = Foo)").unwrap_err(); + assert_eq!( + err.to_string(), + "num_enum error_type attribute requires `constructor` value" + ); + } + + #[test] + fn missing_name() { + let err = syn::parse_str::("error_type(constructor = Foo::new)").unwrap_err(); + assert_eq!( + err.to_string(), + "num_enum error_type attribute requires `name` value" + ); + } + + #[test] + fn missing_both() { + let err = syn::parse_str::("error_type()").unwrap_err(); + assert_eq!( + err.to_string(), + "num_enum error_type attribute requires `name` and `constructor` values" + ); + } + + #[test] + fn extra_attr() { + let err = syn::parse_str::( + "error_type(name = Foo, constructor = Foo::new, extra = unneeded)", + ) + .unwrap_err(); + assert_eq!(err.to_string(), "expected `name` or `constructor`"); + } + + #[test] + fn multiple_names() { + let err = syn::parse_str::( + "error_type(name = Foo, name = Foo, constructor = Foo::new)", + ) + .unwrap_err(); + assert_eq!( + err.to_string(), + "num_enum error_type attribute must have exactly one `name` value" + ); + } + + #[test] + fn multiple_constructors() { + let err = syn::parse_str::( + "error_type(name = Foo, constructor = Foo::new, constructor = Foo::new)", + ) + .unwrap_err(); + assert_eq!( + err.to_string(), + "num_enum error_type attribute must have exactly one `constructor` value" + ); + } +} diff --git a/num_enum_derive/src/lib.rs b/num_enum_derive/src/lib.rs index e183d66..4fc4927 100644 --- a/num_enum_derive/src/lib.rs +++ b/num_enum_derive/src/lib.rs @@ -8,8 +8,10 @@ use proc_macro2::Span; use quote::quote; use syn::{parse_macro_input, Expr, Ident}; +mod enum_attributes; mod parsing; use parsing::{get_crate_name, EnumInfo}; +mod utils; mod variant_attributes; /// Implements `Into` for a `#[repr(Primitive)] enum`. @@ -191,7 +193,10 @@ pub fn derive_try_from_primitive(input: TokenStream) -> TokenStream { let krate = Ident::new(&get_crate_name(), Span::call_site()); let EnumInfo { - ref name, ref repr, .. + ref name, + ref repr, + ref error_type_info, + .. } = enum_info; let variant_idents: Vec = enum_info.variant_idents(); @@ -200,9 +205,13 @@ pub fn derive_try_from_primitive(input: TokenStream) -> TokenStream { debug_assert_eq!(variant_idents.len(), variant_expressions.len()); + let error_type = &error_type_info.name; + let error_constructor = &error_type_info.constructor; + TokenStream::from(quote! { impl ::#krate::TryFromPrimitive for #name { type Primitive = #repr; + type Error = #error_type; const NAME: &'static str = stringify!(#name); @@ -210,7 +219,7 @@ pub fn derive_try_from_primitive(input: TokenStream) -> TokenStream { number: Self::Primitive, ) -> ::core::result::Result< Self, - ::#krate::TryFromPrimitiveError + #error_type > { // Use intermediate const(s) so that enums defined like // `Two = ONE + 1u8` work properly. @@ -228,19 +237,19 @@ pub fn derive_try_from_primitive(input: TokenStream) -> TokenStream { )* #[allow(unreachable_patterns)] _ => ::core::result::Result::Err( - ::#krate::TryFromPrimitiveError { number } + #error_constructor ( number ) ), } } } impl ::core::convert::TryFrom<#repr> for #name { - type Error = ::#krate::TryFromPrimitiveError; + type Error = #error_type; #[inline] fn try_from ( number: #repr, - ) -> ::core::result::Result> + ) -> ::core::result::Result { ::#krate::TryFromPrimitive::try_from_primitive(number) } diff --git a/num_enum_derive/src/parsing.rs b/num_enum_derive/src/parsing.rs index 1dd0523..9d511b5 100644 --- a/num_enum_derive/src/parsing.rs +++ b/num_enum_derive/src/parsing.rs @@ -1,31 +1,20 @@ +use crate::enum_attributes::ErrorTypeAttribute; +use crate::utils::die; use crate::variant_attributes::{NumEnumVariantAttributeItem, NumEnumVariantAttributes}; use proc_macro2::Span; use quote::{format_ident, ToTokens}; use std::collections::BTreeSet; use syn::{ parse::{Parse, ParseStream}, - parse_quote, Attribute, Data, DeriveInput, Error, Expr, ExprLit, ExprUnary, Fields, Ident, Lit, - LitInt, Meta, Result, UnOp, + parse_quote, Attribute, Data, DeriveInput, Expr, ExprLit, ExprUnary, Fields, Ident, Lit, + LitInt, Meta, Path, Result, UnOp, }; -macro_rules! die { - ($spanned:expr=> - $msg:expr - ) => { - return Err(Error::new_spanned($spanned, $msg)) - }; - - ( - $msg:expr - ) => { - return Err(Error::new(Span::call_site(), $msg)) - }; -} - pub(crate) struct EnumInfo { pub(crate) name: Ident, pub(crate) repr: Ident, pub(crate) variants: Vec, + pub(crate) error_type_info: ErrorType, } impl EnumInfo { @@ -95,6 +84,51 @@ impl EnumInfo { .map(|variant| variant.all_values().cloned().collect()) .collect() } + + fn parse_attrs>( + mut attrs: Attrs, + ) -> Result<(Ident, Option)> { + let mut maybe_repr = None; + let mut maybe_error_type = None; + while let Some(attr) = attrs.next() { + if let Meta::List(meta_list) = &attr.meta { + if let Some(ident) = meta_list.path.get_ident() { + if ident == "repr" { + let mut nested = meta_list.tokens.clone().into_iter(); + let repr_tree = match (nested.next(), nested.next()) { + (Some(repr_tree), None) => repr_tree, + _ => die!(attr => + "Expected exactly one `repr` argument" + ), + }; + let repr_ident: Ident = parse_quote! { + #repr_tree + }; + if repr_ident == "C" { + die!(repr_ident => + "repr(C) doesn't have a well defined size" + ); + } else { + maybe_repr = Some(repr_ident); + } + } else if ident == "num_enum" { + let attributes = + attr.parse_args_with(crate::enum_attributes::Attributes::parse)?; + if let Some(error_type) = attributes.error_type { + if maybe_error_type.is_some() { + die!(attr => "At most one num_enum error_type attribute may be specified"); + } + maybe_error_type = Some(error_type.into()); + } + } + } + } + } + if maybe_repr.is_none() { + die!("Missing `#[repr({Integer})]` attribute"); + } + Ok((maybe_repr.unwrap(), maybe_error_type)) + } } impl Parse for EnumInfo { @@ -108,38 +142,7 @@ impl Parse for EnumInfo { Data::Struct(data) => die!(data.struct_token => "Expected enum but found struct"), }; - let repr: Ident = { - let mut attrs = input.attrs.into_iter(); - loop { - if let Some(attr) = attrs.next() { - if let Meta::List(meta_list) = &attr.meta { - if let Some(ident) = meta_list.path.get_ident() { - if ident == "repr" { - let mut nested = meta_list.tokens.clone().into_iter(); - let repr = match (nested.next(), nested.next()) { - (Some(repr), None) => repr, - _ => die!(attr => - "Expected exactly one `repr` argument" - ), - }; - let repr: Ident = parse_quote! { - #repr - }; - if repr == "C" { - die!(repr => - "repr(C) doesn't have a well defined size" - ); - } else { - break repr; - } - } - } - } - } else { - die!("Missing `#[repr({Integer})]` attribute"); - } - } - }; + let (repr, maybe_error_type) = Self::parse_attrs(input.attrs.into_iter())?; let mut variants: Vec = vec![]; let mut has_default_variant: bool = false; @@ -380,10 +383,23 @@ impl Parse for EnumInfo { } } + let error_type_info = maybe_error_type.unwrap_or_else(|| { + let crate_name = Ident::new(&get_crate_name(), Span::call_site()); + ErrorType { + name: parse_quote! { + ::#crate_name::TryFromPrimitiveError + }, + constructor: parse_quote! { + ::#crate_name::TryFromPrimitiveError::::new + }, + } + }); + EnumInfo { name, repr, variants, + error_type_info, } }) } @@ -497,6 +513,20 @@ impl VariantInfo { } } +pub(crate) struct ErrorType { + pub(crate) name: Path, + pub(crate) constructor: Path, +} + +impl From for ErrorType { + fn from(attribute: ErrorTypeAttribute) -> Self { + Self { + name: attribute.name.path, + constructor: attribute.constructor.path, + } + } +} + #[cfg(feature = "proc-macro-crate")] pub(crate) fn get_crate_name() -> String { let found_crate = proc_macro_crate::crate_name("num_enum").unwrap_or_else(|err| { diff --git a/num_enum_derive/src/utils.rs b/num_enum_derive/src/utils.rs new file mode 100644 index 0000000..ed6ab87 --- /dev/null +++ b/num_enum_derive/src/utils.rs @@ -0,0 +1,15 @@ +macro_rules! die { + ($spanned:expr=> + $msg:expr + ) => { + return Err(::syn::Error::new_spanned($spanned, $msg)) + }; + + ( + $msg:expr + ) => { + return Err(::syn::Error::new(::proc_macro2::Span::call_site(), $msg)) + }; +} + +pub(crate) use die;