diff --git a/CHANGELOG.md b/CHANGELOG.md index ff4afc4..5daf2b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,8 @@ * Support integration with [`arbitrary`](https://crates.io/crates/arbitrary) crate (see `arbitrary` feature). * Support `Arbitrary` for integer types -* Ability to specify boundaries (`greater`, `greater_or_equal`, `less`, `less_or_equal`, `len_char_min`, `len_char_max`) with expressions or constants. + * Support `Arbitrary` for float types +* Ability to specify boundaries (`greater`, `greater_or_equal`, `less`, `less_or_equal`, `len_char_min`, `len_char_max`) with expressions or named constants. ### v0.4.0 - 2023-11-21 * Support of arbitrary inner types with custom sanitizers and validators. diff --git a/Cargo.lock b/Cargo.lock index 139a40e..f74e3a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -103,6 +103,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68b0cf012f1230e43cd00ebb729c6bb58707ecfa8ad08b52ef3a4ccd2697fc30" +[[package]] +name = "float_arbitrary" +version = "0.1.0" +dependencies = [ + "arbitrary", + "arbtest", + "nutype", +] + [[package]] name = "float_sortable" version = "0.1.0" diff --git a/examples/float_arbitrary/Cargo.toml b/examples/float_arbitrary/Cargo.toml new file mode 100644 index 0000000..a74ae45 --- /dev/null +++ b/examples/float_arbitrary/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "float_arbitrary" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arbitrary = "1.3.2" +arbtest = "0.2.0" +nutype = { path = "../../nutype", features = ["arbitrary"] } diff --git a/examples/float_arbitrary/src/main.rs b/examples/float_arbitrary/src/main.rs new file mode 100644 index 0000000..eb44b0d --- /dev/null +++ b/examples/float_arbitrary/src/main.rs @@ -0,0 +1,114 @@ +use arbitrary::Arbitrary; +use nutype::nutype; + +#[nutype( + derive(Debug, Arbitrary), + sanitize(with = |x| x), +)] +struct UnrestrictedFloatNumber(f64); + +#[nutype(derive(Debug, Arbitrary), validate(finite))] +struct FiniteF64(f64); + +#[nutype(derive(Debug, Arbitrary), validate(finite))] +struct FiniteF32(f32); + +#[nutype(derive(Debug, Arbitrary), validate(greater_or_equal = -64.4))] +struct GreaterOrEqualF64(f64); + +#[nutype(derive(Debug, Arbitrary), validate(greater_or_equal = 32.2))] +struct GreaterOrEqualF32(f32); + +#[nutype(derive(Debug, Arbitrary), validate(greater = -64.0))] +struct GreaterF64(f64); + +#[nutype(derive(Debug, Arbitrary), validate(greater = 32.0))] +struct GreaterF32(f32); + +#[nutype(derive(Debug, Arbitrary), validate(greater = -1.0, less = 1.0))] +struct GreaterAndLessF64(f64); + +#[nutype(derive(Debug, Arbitrary), validate(greater = -10.0, less = 10.0))] +struct GreaterAndLessF32(f32); + +#[nutype(derive(Debug, Arbitrary), validate(greater_or_equal = -10.0, less = 10.0))] +struct GreaterOrEqualAndLessF32(f32); + +#[nutype(derive(Debug, Arbitrary), validate(greater = -10.0, less_or_equal = 10.0))] +struct GreaterAndLessOrEqualF32(f32); + +#[nutype(derive(Debug, Arbitrary), validate(greater = -1.0, less_or_equal = -0.5))] +struct GreaterOrEqualAndLessOrEqualF64(f64); + +fn main() { + arbtest::builder().run(|u| { + let _num = UnrestrictedFloatNumber::arbitrary(u)?.into_inner(); + Ok(()) + }); + + arbtest::builder().run(|u| { + let value: f64 = FiniteF64::arbitrary(u)?.into_inner(); + assert!(value.is_finite()); + Ok(()) + }); + + arbtest::builder().run(|u| { + let value: f32 = FiniteF32::arbitrary(u)?.into_inner(); + assert!(value.is_finite()); + Ok(()) + }); + + arbtest::builder().run(|u| { + let value: f64 = GreaterOrEqualF64::arbitrary(u)?.into_inner(); + assert!(value >= -64.4); + Ok(()) + }); + + arbtest::builder().run(|u| { + let value: f32 = GreaterOrEqualF32::arbitrary(u)?.into_inner(); + assert!(value >= 32.2); + Ok(()) + }); + + arbtest::builder().run(|u| { + let value: f64 = GreaterF64::arbitrary(u)?.into_inner(); + assert!(value > -64.0); + Ok(()) + }); + + arbtest::builder().run(|u| { + let value: f32 = GreaterF32::arbitrary(u)?.into_inner(); + assert!(value > 32.0); + Ok(()) + }); + + arbtest::builder().run(|u| { + let value: f64 = GreaterAndLessF64::arbitrary(u)?.into_inner(); + assert!(value > -1.0 && value < 1.0); + Ok(()) + }); + + arbtest::builder().run(|u| { + let value: f32 = GreaterAndLessF32::arbitrary(u)?.into_inner(); + assert!(value > -10.0 && value < 10.0); + Ok(()) + }); + + arbtest::builder().run(|u| { + let value: f32 = GreaterOrEqualAndLessF32::arbitrary(u)?.into_inner(); + assert!((-10.0..10.0).contains(&value)); + Ok(()) + }); + + arbtest::builder().run(|u| { + let value: f32 = GreaterAndLessOrEqualF32::arbitrary(u)?.into_inner(); + assert!(value > -10.0 && value <= 10.0); + Ok(()) + }); + + arbtest::builder().run(|u| { + let value: f64 = GreaterOrEqualAndLessOrEqualF64::arbitrary(u)?.into_inner(); + assert!((-1.0..=-0.5).contains(&value)); + Ok(()) + }); +} diff --git a/nutype_macros/src/common/validate.rs b/nutype_macros/src/common/validate.rs index 427b897..477a4be 100644 --- a/nutype_macros/src/common/validate.rs +++ b/nutype_macros/src/common/validate.rs @@ -78,6 +78,15 @@ where return Err(err); } + // less VS greater + if let (Some(lower), Some(upper)) = (maybe_greater.clone(), maybe_less.clone()) { + if lower.item >= upper.item { + let msg = "The lower bound (`greater`) cannot be equal or higher than the upper bound (`less`)."; + let err = syn::Error::new(upper.span(), msg); + return Err(err); + } + } + let maybe_lower_bound = maybe_greater.or(maybe_greater_or_equal); let maybe_upper_bound = maybe_less.or(maybe_less_or_equal); diff --git a/nutype_macros/src/float/gen/mod.rs b/nutype_macros/src/float/gen/mod.rs index 3ca79a2..4dda353 100644 --- a/nutype_macros/src/float/gen/mod.rs +++ b/nutype_macros/src/float/gen/mod.rs @@ -140,15 +140,16 @@ where maybe_error_type_name: Option, traits: HashSet, maybe_default_value: Option, - _guard: &FloatGuard, + guard: &FloatGuard, ) -> Result { - Ok(gen_traits( + gen_traits( type_name, inner_type, maybe_error_type_name, maybe_default_value, traits, - )) + guard, + ) } fn gen_tests( diff --git a/nutype_macros/src/float/gen/traits/arbitrary.rs b/nutype_macros/src/float/gen/traits/arbitrary.rs new file mode 100644 index 0000000..4050311 --- /dev/null +++ b/nutype_macros/src/float/gen/traits/arbitrary.rs @@ -0,0 +1,394 @@ +use proc_macro2::{Span, TokenStream}; +use quote::{quote, ToTokens}; + +use crate::{ + common::models::TypeName, + float::models::{ + FloatGuard, FloatInnerType, FloatSanitizer, FloatSanitizerKind, FloatValidator, + FloatValidatorKind, + }, + utils::issue_reporter::{build_github_link_with_issue, Issue}, +}; + +pub fn gen_impl_trait_arbitrary( + type_name: &TypeName, + inner_type: &FloatInnerType, + guard: &FloatGuard, +) -> Result { + let construct_value = if guard.has_validation() { + // If by some reason we generate an invalid value, make it very easy for the user to report + let report_issue_msg = + build_github_link_with_issue(&Issue::ArbitraryGeneratedInvalidValue { + inner_type: inner_type.to_string(), + }); + let type_name = type_name.to_string(); + quote!( + Self::new(inner_value).unwrap_or_else(|err| { + // Panic with the maximum details about what went wrong + panic!("\nArbitrary generated an invalid value for {}.\nInvalid inner value: {:?}\nValidation error: {:?}\n\n{}", #type_name, inner_value, err, #report_issue_msg); + }) + ) + } else { + quote!(Self::new(inner_value)) + }; + + let generate_inner_value = gen_generate_valid_inner_value(inner_type, guard)?; + + Ok(quote!( + impl ::arbitrary::Arbitrary<'_> for #type_name { + fn arbitrary(u: &mut ::arbitrary::Unstructured<'_>) -> ::arbitrary::Result { + let inner_value: #inner_type = { + #generate_inner_value + }; + Ok(#construct_value) + } + + #[inline] + fn size_hint(_depth: usize) -> (usize, Option) { + let n = ::core::mem::size_of::<#inner_type>(); + (n, Some(n)) + } + } + )) +} + +/// Generates a code that generates a valid inner value. +fn gen_generate_valid_inner_value( + inner_type: &FloatInnerType, + guard: &FloatGuard, +) -> Result { + match guard { + FloatGuard::WithoutValidation { .. } => { + // When there is no validation, then we can just simply delegate to the arbitrary + // crate, and the job is done. + Ok(quote!(u.arbitrary()?)) + } + FloatGuard::WithValidation { + sanitizers, + validators, + } => { + // When there is validation, then we need to generate a valid value. + gen_generate_valid_inner_value_with_validators(inner_type, sanitizers, validators) + } + } +} + +fn gen_generate_valid_inner_value_with_validators( + inner_type: &FloatInnerType, + sanitizers: &[FloatSanitizer], + validators: &[FloatValidator], +) -> Result { + let validator_kinds: Vec = validators.iter().map(|v| v.kind()).collect(); + let sanitizer_kinds: Vec = sanitizers.iter().map(|s| s.kind()).collect(); + + if validator_kinds.contains(&FloatValidatorKind::Predicate) { + let span = Span::call_site(); + let msg = "It's not possible to derive `Arbitrary` trait for a type with `predicate` validator.\nYou have to implement `Arbitrary` trait on you own."; + return Err(syn::Error::new(span, msg)); + } + if sanitizer_kinds.contains(&FloatSanitizerKind::With) { + let span = Span::call_site(); + let msg = "It's not possible to derive `Arbitrary` trait for a type with `with` sanitizer and validations.\nYou have to implement `Arbitrary` trait on you own."; + return Err(syn::Error::new(span, msg)); + } + + let basic_value_kind = compute_basic_value_kind(&validator_kinds); + let basic_value = generate_basic_value(inner_type, basic_value_kind); + let boundaries = compute_boundaries(validators); + + Ok(normalize_basic_value_for_boundaries( + inner_type, + basic_value, + boundaries, + )) +} + +fn normalize_basic_value_for_boundaries( + inner_type: &FloatInnerType, + basic_value: TokenStream, + boundaries: Boundaries, +) -> TokenStream { + match (boundaries.lower, boundaries.upper) { + (Some(lower), Some(upper)) => { + // In this case we don't use `basic_value` we generate a new value that lays in between + // 0.0 and 1.0 and then scale it to the range of the boundaries. + let arbitrary_in_01_range = gen_in_01_range(inner_type); + + let lower_value = &lower.value; + let upper_value = &upper.value; + let adjust_x_lower = gen_adjust_x_for_lower_boundary(inner_type, &lower); + let adjust_x_upper = gen_adjust_x_for_upper_boundary(inner_type, &lower); + quote! { + let from0to1 = #arbitrary_in_01_range; + + // Scale range [0; 1] to the range of the boundaries + let range = (#upper_value - #lower_value).abs(); + let x = #lower_value + from0to1 * range; + + // Make sure we satisfy the exclusive boundaries + let x = #adjust_x_lower; + let x = #adjust_x_upper; + x + } + } + (Some(lower), None) => { + let lower_value = &lower.value; + let adjust_x = gen_adjust_x_for_lower_boundary(inner_type, &lower); + quote! { + // Compute initial basic value + let basic_value = #basic_value; + let positive_basic_value = basic_value.abs(); + let x = positive_basic_value + #lower_value; + #adjust_x + } + } + (None, Some(upper)) => { + let upper_value = &upper.value; + let adjust_x = gen_adjust_x_for_upper_boundary(inner_type, &upper); + quote! { + // Compute initial basic value + let basic_value = #basic_value; + let negative_basic_value = -basic_value.abs(); + let x = negative_basic_value + #upper_value; + #adjust_x + } + } + (None, None) => basic_value, + } +} + +fn gen_adjust_x_for_upper_boundary( + float_type: &FloatInnerType, + upper_boundary: &Boundary, +) -> TokenStream { + if upper_boundary.is_inclusive { + quote! { x } + } else { + let upper_value = &upper_boundary.value; + let correction_delta = correction_delta_for_float_type(float_type); + quote! { + if x >= #upper_value { + x - #correction_delta + } else { + x + } + } + } +} + +fn gen_adjust_x_for_lower_boundary( + float_type: &FloatInnerType, + lower_boundary: &Boundary, +) -> TokenStream { + if lower_boundary.is_inclusive { + quote! { x } + } else { + let lower_value = &lower_boundary.value; + let correction_delta = correction_delta_for_float_type(float_type); + quote! { + if x <= #lower_value { + // Since there is no upper boundary, we are free to add any positive value here + // to adjust so we can satisfy the exclusive lower boundary. + x + #correction_delta + } else { + x + } + } + } +} + +/// A tiny value that is used to correct the value to satisfy the exclusive boundaries if +/// necessary. +/// For example, if the constraint is `greater = 0.0`, then and we obtain exactly `0.0` when +/// generating a pseudo-random value, then we need to add a tiny value to it to make it +/// satisfy `x > 0.0` check. +/// +/// Unfortunately things like `f32::EPSILON` or `f64::EPSILON` are not suitable for this purpose. +/// The constants are found experimentally. +fn correction_delta_for_float_type(float_type: &FloatInnerType) -> TokenStream { + match float_type { + FloatInnerType::F32 => quote!(0.000_002), + FloatInnerType::F64 => quote!(0.000_000_000_000_004), + } +} + +/// Generate a code snippet that generates a random float in range [0; 1]. +/// Assumptions +/// * There is variable `u` in the given context which is a value of `arbitrary::Unstructured`. +fn gen_in_01_range(float_type: &FloatInnerType) -> TokenStream { + let int_type = match float_type { + FloatInnerType::F32 => quote!(u32), + FloatInnerType::F64 => quote!(u64), + }; + + quote! ( + { // { + let random_int: #int_type = u.arbitrary()?; // let random_int: u32 = u.arbitrary()?; + (random_int as #float_type / #int_type::MAX as #float_type) as #float_type // (random_int as f32 / u32::MAX as f32) as f32 + } // } + ) +} + +struct Boundaries { + lower: Option, + upper: Option, +} + +struct Boundary { + value: TokenStream, + is_inclusive: bool, +} + +/// Describes a type of initial basic value that has to be generated. +enum BasicValueKind { + /// All float values including NaN and Infinity. + All, + + /// All float values except NaN + NotNaN, + + /// All float values except NaN and infinities + Finite, +} + +fn compute_basic_value_kind(validators: &[FloatValidatorKind]) -> BasicValueKind { + let has_boundaries = || { + validators.contains(&FloatValidatorKind::Greater) + || validators.contains(&FloatValidatorKind::GreaterOrEqual) + || validators.contains(&FloatValidatorKind::Less) + || validators.contains(&FloatValidatorKind::LessOrEqual) + }; + + if validators.contains(&FloatValidatorKind::Finite) { + BasicValueKind::Finite + } else if has_boundaries() { + BasicValueKind::NotNaN + } else { + BasicValueKind::All + } +} + +fn compute_boundaries(validators: &[FloatValidator]) -> Boundaries { + let mut lower = None; + let mut upper = None; + + // NOTE: It's guaranteed that either Greater or GreaterOrEqual present, but not both, + // Same for Less and LessOrEqual. + // This handled by prior validation. + for validator in validators { + match validator { + FloatValidator::Greater(expr) => { + let value = quote!(#expr); + let is_inclusive = false; + lower = Some(Boundary { + value, + is_inclusive, + }); + } + FloatValidator::GreaterOrEqual(expr) => { + let value = quote!(#expr); + let is_inclusive = true; + lower = Some(Boundary { + value, + is_inclusive, + }); + } + FloatValidator::Less(expr) => { + let value = quote!(#expr); + let is_inclusive = false; + upper = Some(Boundary { + value, + is_inclusive, + }); + } + FloatValidator::LessOrEqual(expr) => { + let value = quote!(#expr); + let is_inclusive = true; + upper = Some(Boundary { + value, + is_inclusive, + }); + } + FloatValidator::Finite | FloatValidator::Predicate(..) => { + // We don't care about these validators here. + } + } + } + + Boundaries { lower, upper } +} + +fn generate_basic_value(inner_type: &FloatInnerType, kind: BasicValueKind) -> TokenStream { + match kind { + BasicValueKind::All => quote!(u.arbitrary()?), + BasicValueKind::NotNaN => generate_not_nan_float(inner_type), + BasicValueKind::Finite => generate_finite_float(inner_type), + } +} + +// Generate a code that generates a float which is not infinite and not NaN. +fn generate_finite_float(inner_type: &FloatInnerType) -> TokenStream { + let condition = quote!(value.is_finite()); + generate_float_with_condition(inner_type, condition) +} + +// Generate a code that generates a float which is not NaN +fn generate_not_nan_float(inner_type: &FloatInnerType) -> TokenStream { + let condition = quote!(!value.is_nan()); + generate_float_with_condition(inner_type, condition) +} + +// Generates a block of code that uses arbitrary and deterministic mutations to find a value that +// matches the given condition. +// IMPORTANT: The condition must be something like check against NaN or infinity and should not be +// a check against a range of values (otherwise it may loop forever). +// +// The generated code takes the following assumptions: +// * There is variable `u` in the given context which is a value of `arbitrary::Unstructured`. +// * `condition` is a closure that does a check against variable `value` and returns a bool. +fn generate_float_with_condition( + inner_type: &FloatInnerType, + condition: TokenStream, +) -> TokenStream { + quote! ( + { + let condition = |value: #inner_type| #condition; + + 'outer: loop { + let original_value: #inner_type = u.arbitrary()?; + + if condition(original_value) { + break original_value; + } else { + // If the original value obtained from arbitrary does not match the condition, + // we try to mangle/randomize it deterministically in a loop 1000 times, until we + // reach out for another value from arbitrary. + // Generally it must be more than enough, cause what we typically need is to avoid + // NaN and infinity. + + // This returns + // * [u8; 4] for f32 + // * [u8; 8] for f64 + let mut bytes = original_value.to_be_bytes(); + for i in 0..1000 { + // With every iteration we modify next single byte by adding `i` value to + // it. + let index = i % std::mem::size_of::<#inner_type>(); + bytes[index] = bytes[index].wrapping_add((i % 256) as u8); + + // Try to convert the bytes back to float in both BE and NE formats and see + // if we get something what we need + let new_float_be = #inner_type::from_be_bytes(bytes); + if condition(new_float_be) { + break 'outer new_float_be; + } + let new_float_ne = #inner_type::from_ne_bytes(bytes); + if condition(new_float_ne) { + break 'outer new_float_ne; + } + } + } + } + } + ) +} diff --git a/nutype_macros/src/float/gen/traits.rs b/nutype_macros/src/float/gen/traits/mod.rs similarity index 76% rename from nutype_macros/src/float/gen/traits.rs rename to nutype_macros/src/float/gen/traits/mod.rs index 68fb9ab..324f1b4 100644 --- a/nutype_macros/src/float/gen/traits.rs +++ b/nutype_macros/src/float/gen/traits/mod.rs @@ -1,3 +1,4 @@ +pub mod arbitrary; use std::collections::HashSet; use proc_macro2::TokenStream; @@ -14,7 +15,7 @@ use crate::{ }, models::{ErrorTypeName, TypeName}, }, - float::models::{FloatDeriveTrait, FloatInnerType}, + float::models::{FloatDeriveTrait, FloatGuard, FloatInnerType}, }; type FloatGeneratableTrait = GeneratableTrait; @@ -47,6 +48,7 @@ enum FloatIrregularTrait { Default, SerdeSerialize, SerdeDeserialize, + ArbitraryArbitrary, } impl From for FloatGeneratableTrait { @@ -94,6 +96,9 @@ impl From for FloatGeneratableTrait { FloatDeriveTrait::SerdeDeserialize => { FloatGeneratableTrait::Irregular(FloatIrregularTrait::SerdeDeserialize) } + FloatDeriveTrait::ArbitraryArbitrary => { + FloatGeneratableTrait::Irregular(FloatIrregularTrait::ArbitraryArbitrary) + } FloatDeriveTrait::SchemarsJsonSchema => { FloatGeneratableTrait::Transparent(FloatTransparentTrait::SchemarsJsonSchema) } @@ -115,13 +120,14 @@ impl ToTokens for FloatTransparentTrait { } } -pub fn gen_traits( +pub fn gen_traits( type_name: &TypeName, inner_type: &FloatInnerType, maybe_error_type_name: Option, maybe_default_value: Option, traits: HashSet, -) -> GeneratedTraits { + guard: &FloatGuard, +) -> Result { let GeneratableTraits { transparent_traits, irregular_traits, @@ -139,55 +145,60 @@ pub fn gen_traits( maybe_error_type_name, maybe_default_value, irregular_traits, - ); + guard, + )?; - GeneratedTraits { + Ok(GeneratedTraits { derive_transparent_traits, implement_traits, - } + }) } -fn gen_implemented_traits( +fn gen_implemented_traits( type_name: &TypeName, inner_type: &FloatInnerType, maybe_error_type_name: Option, maybe_default_value: Option, impl_traits: Vec, -) -> TokenStream { + guard: &FloatGuard, +) -> Result { impl_traits .iter() .map(|t| match t { - FloatIrregularTrait::AsRef => gen_impl_trait_as_ref(type_name, inner_type), - FloatIrregularTrait::Deref => gen_impl_trait_deref(type_name, inner_type), + FloatIrregularTrait::AsRef => Ok(gen_impl_trait_as_ref(type_name, inner_type)), + FloatIrregularTrait::Deref => Ok(gen_impl_trait_deref(type_name, inner_type)), FloatIrregularTrait::FromStr => { - gen_impl_trait_from_str(type_name, inner_type, maybe_error_type_name.as_ref()) + Ok(gen_impl_trait_from_str(type_name, inner_type, maybe_error_type_name.as_ref())) } - FloatIrregularTrait::From => gen_impl_trait_from(type_name, inner_type), - FloatIrregularTrait::Into => gen_impl_trait_into(type_name, inner_type), + FloatIrregularTrait::From => Ok(gen_impl_trait_from(type_name, inner_type)), + FloatIrregularTrait::Into => Ok(gen_impl_trait_into(type_name, inner_type)), FloatIrregularTrait::TryFrom => { - gen_impl_trait_try_from(type_name, inner_type, maybe_error_type_name.as_ref()) + Ok(gen_impl_trait_try_from(type_name, inner_type, maybe_error_type_name.as_ref())) } - FloatIrregularTrait::Borrow => gen_impl_trait_borrow(type_name, inner_type), - FloatIrregularTrait::Display => gen_impl_trait_display(type_name), + FloatIrregularTrait::Borrow => Ok(gen_impl_trait_borrow(type_name, inner_type)), + FloatIrregularTrait::Display => Ok(gen_impl_trait_display(type_name)), FloatIrregularTrait::Default => match maybe_default_value { Some(ref default_value) => { let has_validation = maybe_error_type_name.is_some(); - gen_impl_trait_default(type_name, default_value, has_validation) + Ok(gen_impl_trait_default(type_name, default_value, has_validation)) } None => { - panic!( - "Default trait is derived for type {type_name}, but `default = ` is missing" - ); + let span = proc_macro2::Span::call_site(); + let msg = format!("Trait `Default` is derived for type {type_name}, but `default = ` parameter is missing in #[nutype] macro"); + Err(syn::Error::new(span, msg)) } }, - FloatIrregularTrait::SerdeSerialize => gen_impl_trait_serde_serialize(type_name), - FloatIrregularTrait::SerdeDeserialize => gen_impl_trait_serde_deserialize( + FloatIrregularTrait::SerdeSerialize => Ok(gen_impl_trait_serde_serialize(type_name)), + FloatIrregularTrait::SerdeDeserialize => Ok(gen_impl_trait_serde_deserialize( type_name, inner_type, maybe_error_type_name.as_ref(), - ), - FloatIrregularTrait::Eq => gen_impl_trait_eq(type_name), - FloatIrregularTrait::Ord => gen_impl_trait_ord(type_name), + )), + FloatIrregularTrait::Eq => Ok(gen_impl_trait_eq(type_name)), + FloatIrregularTrait::Ord => Ok(gen_impl_trait_ord(type_name)), + FloatIrregularTrait::ArbitraryArbitrary => { + arbitrary::gen_impl_trait_arbitrary(type_name, inner_type, guard) + } }) .collect() } diff --git a/nutype_macros/src/float/models.rs b/nutype_macros/src/float/models.rs index b6650c3..28bc074 100644 --- a/nutype_macros/src/float/models.rs +++ b/nutype_macros/src/float/models.rs @@ -1,6 +1,5 @@ use kinded::Kinded; use proc_macro2::TokenStream; -use std::fmt::Debug; use crate::common::models::{ impl_numeric_bound_on_vec_of, impl_numeric_bound_validator, Guard, RawGuard, SpannedItem, @@ -11,7 +10,7 @@ use crate::common::models::{ // #[derive(Debug, Kinded)] -#[kinded(display = "snake_case")] +#[kinded(display = "snake_case", derive(Hash))] pub enum FloatSanitizer { With(TypedCustomFunction), _Phantom(std::marker::PhantomData), @@ -23,7 +22,7 @@ pub type SpannedFloatSanitizer = SpannedItem>; // #[derive(Debug, Kinded)] -#[kinded(display = "snake_case")] +#[kinded(display = "snake_case", derive(Hash))] pub enum FloatValidator { Greater(ValueOrExpr), GreaterOrEqual(ValueOrExpr), @@ -64,7 +63,7 @@ pub enum FloatDeriveTrait { SerdeSerialize, SerdeDeserialize, SchemarsJsonSchema, - // Arbitrary, + ArbitraryArbitrary, } impl TypeTrait for FloatDeriveTrait { @@ -105,6 +104,16 @@ macro_rules! define_float_inner_type { type_stream.to_tokens(token_stream); } } + + impl ::core::fmt::Display for FloatInnerType { + fn fmt(&self, f: &mut ::core::fmt::Formatter) -> Result<(), ::core::fmt::Error> { + match self { + $( + Self::$variant => stringify!($tp).fmt(f), + )* + } + } + } } } diff --git a/nutype_macros/src/float/validate.rs b/nutype_macros/src/float/validate.rs index 58d4e40..c058baf 100644 --- a/nutype_macros/src/float/validate.rs +++ b/nutype_macros/src/float/validate.rs @@ -196,10 +196,6 @@ fn to_float_derive_trait( DeriveTrait::SerdeSerialize => Ok(FloatDeriveTrait::SerdeSerialize), DeriveTrait::SerdeDeserialize => Ok(FloatDeriveTrait::SerdeDeserialize), DeriveTrait::SchemarsJsonSchema => Ok(FloatDeriveTrait::SchemarsJsonSchema), - DeriveTrait::ArbitraryArbitrary => { - // TODO: Implement deriving Arbitrary - let msg = "Deriving Arbitrary trait for float types is not yet implemented"; - Err(syn::Error::new(span, msg)) - } + DeriveTrait::ArbitraryArbitrary => Ok(FloatDeriveTrait::ArbitraryArbitrary), } } diff --git a/nutype_macros/src/integer/gen/traits/arbitrary.rs b/nutype_macros/src/integer/gen/traits/arbitrary.rs index 4b9504b..ff646ff 100644 --- a/nutype_macros/src/integer/gen/traits/arbitrary.rs +++ b/nutype_macros/src/integer/gen/traits/arbitrary.rs @@ -36,6 +36,12 @@ pub fn gen_impl_trait_arbitrary( Ok(#construct_value) } } + + #[inline] + fn size_hint(_depth: usize) -> (usize, Option) { + let n = ::core::mem::size_of::<#inner_type>(); + (n, Some(n)) + } )) } diff --git a/nutype_macros/src/integer/models.rs b/nutype_macros/src/integer/models.rs index cbfdd75..5505c8b 100644 --- a/nutype_macros/src/integer/models.rs +++ b/nutype_macros/src/integer/models.rs @@ -59,7 +59,7 @@ pub enum IntegerDeriveTrait { Default, Deref, - // // External crates + // External crates SerdeSerialize, SerdeDeserialize, SchemarsJsonSchema,