diff --git a/borsh-derive-internal/src/enum_de.rs b/borsh-derive-internal/src/enum_de.rs index 97cfde852..37f2d34cb 100644 --- a/borsh-derive-internal/src/enum_de.rs +++ b/borsh-derive-internal/src/enum_de.rs @@ -40,7 +40,7 @@ pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result ); variant_header.extend(quote! { - #field_name: #cratename::BorshDeserialize::deserialize(buf)?, + #field_name: #cratename::BorshDeserialize::deserialize_reader(reader)?, }); } } @@ -59,8 +59,9 @@ pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result .unwrap(), ); - variant_header - .extend(quote! { #cratename::BorshDeserialize::deserialize(buf)?, }); + variant_header.extend( + quote! { #cratename::BorshDeserialize::deserialize_reader(reader)?, }, + ); } } variant_header = quote! { ( #variant_header )}; @@ -72,12 +73,12 @@ pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result }); } let variant_idx = quote! { - let variant_idx: u8 = #cratename::BorshDeserialize::deserialize(buf)?; + let variant_idx: u8 = #cratename::BorshDeserialize::deserialize_reader(reader)?; }; if let Some(method_ident) = init_method { Ok(quote! { impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause { - fn deserialize(buf: &mut &[u8]) -> ::core::result::Result { + fn deserialize_reader(reader: &mut R) -> ::core::result::Result { #variant_idx let mut return_value = match variant_idx { #variant_arms @@ -98,7 +99,7 @@ pub fn enum_de(input: &ItemEnum, cratename: Ident) -> syn::Result } else { Ok(quote! { impl #impl_generics #cratename::de::BorshDeserialize for #name #ty_generics #where_clause { - fn deserialize(buf: &mut &[u8]) -> ::core::result::Result { + fn deserialize_reader(reader: &mut R) -> ::core::result::Result { #variant_idx let return_value = match variant_idx { #variant_arms diff --git a/borsh-derive-internal/src/struct_de.rs b/borsh-derive-internal/src/struct_de.rs index d26192d2f..78a2460fd 100644 --- a/borsh-derive-internal/src/struct_de.rs +++ b/borsh-derive-internal/src/struct_de.rs @@ -34,7 +34,7 @@ pub fn struct_de(input: &ItemStruct, cratename: Ident) -> syn::Result syn::Result syn::Result ::core::result::Result { + fn deserialize_reader(reader: &mut R) -> ::core::result::Result { let mut return_value = #return_value; return_value.#method_ident(); Ok(return_value) @@ -74,7 +74,7 @@ pub fn struct_de(input: &ItemStruct, cratename: Ident) -> syn::Result ::core::result::Result { + fn deserialize_reader(reader: &mut R) -> ::core::result::Result { Ok(#return_value) } } diff --git a/borsh/src/de/mod.rs b/borsh/src/de/mod.rs index b6d089b54..0c6299944 100644 --- a/borsh/src/de/mod.rs +++ b/borsh/src/de/mod.rs @@ -11,7 +11,7 @@ use crate::maybestd::{ boxed::Box, collections::{BTreeMap, BTreeSet, BinaryHeap, HashMap, HashSet, LinkedList, VecDeque}, format, - io::{Error, ErrorKind, Result}, + io::{Error, ErrorKind, Read, Result}, string::{String, ToString}, vec, vec::Vec, @@ -32,7 +32,11 @@ const ERROR_INVALID_ZERO_VALUE: &str = "Expected a non-zero value"; pub trait BorshDeserialize: Sized { /// Deserializes this instance from a given slice of bytes. /// Updates the buffer to point at the remaining bytes. - fn deserialize(buf: &mut &[u8]) -> Result; + fn deserialize(buf: &mut &[u8]) -> Result { + Self::deserialize_reader(&mut *buf) + } + + fn deserialize_reader(reader: &mut R) -> Result; /// Deserialize this instance from a slice of bytes. fn try_from_slice(v: &[u8]) -> Result { @@ -44,64 +48,68 @@ pub trait BorshDeserialize: Sized { Ok(result) } + fn try_from_reader(reader: &mut R) -> Result { + let result = Self::deserialize_reader(reader)?; + let mut buf = [0u8; 1]; + match reader.read_exact(&mut buf) { + Err(f) if f.kind() == ErrorKind::UnexpectedEof => Ok(result), + _ => Err(Error::new(ErrorKind::InvalidData, ERROR_NOT_ALL_BYTES_READ)), + } + } + #[inline] #[doc(hidden)] - fn vec_from_bytes(len: u32, buf: &mut &[u8]) -> Result>> { + fn vec_from_reader(len: u32, reader: &mut R) -> Result>> { let _ = len; - let _ = buf; + let _ = reader; Ok(None) } #[inline] #[doc(hidden)] - fn array_from_bytes(buf: &mut &[u8]) -> Result> { - let _ = buf; + fn array_from_reader(reader: &mut R) -> Result> { + let _ = reader; Ok(None) } } +fn unexpected_eof_to_unexpected_length_of_input(e: Error) -> Error { + if e.kind() == ErrorKind::UnexpectedEof { + Error::new(ErrorKind::InvalidInput, ERROR_UNEXPECTED_LENGTH_OF_INPUT) + } else { + e + } +} + impl BorshDeserialize for u8 { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - if buf.is_empty() { - return Err(Error::new( - ErrorKind::InvalidInput, - ERROR_UNEXPECTED_LENGTH_OF_INPUT, - )); - } - let res = buf[0]; - *buf = &buf[1..]; - Ok(res) + fn deserialize_reader(reader: &mut R) -> Result { + let mut buf = [0u8; 1]; + reader + .read_exact(&mut buf) + .map_err(unexpected_eof_to_unexpected_length_of_input)?; + Ok(buf[0]) } #[inline] #[doc(hidden)] - fn vec_from_bytes(len: u32, buf: &mut &[u8]) -> Result>> { - let len = len.try_into().map_err(|_| ErrorKind::InvalidInput)?; - if buf.len() < len { - return Err(Error::new( - ErrorKind::InvalidInput, - ERROR_UNEXPECTED_LENGTH_OF_INPUT, - )); - } - let (front, rest) = buf.split_at(len); - *buf = rest; - Ok(Some(front.to_vec())) + fn vec_from_reader(len: u32, reader: &mut R) -> Result>> { + let len: usize = len.try_into().map_err(|_| ErrorKind::InvalidInput)?; + let mut vec = vec![0u8; len]; + reader + .read_exact(vec.as_mut_slice()) + .map_err(unexpected_eof_to_unexpected_length_of_input)?; + Ok(Some(vec)) } #[inline] #[doc(hidden)] - fn array_from_bytes(buf: &mut &[u8]) -> Result> { - if buf.len() < N { - return Err(Error::new( - ErrorKind::InvalidInput, - ERROR_UNEXPECTED_LENGTH_OF_INPUT, - )); - } - let (front, rest) = buf.split_at(N); - *buf = rest; - let front: [u8; N] = front.try_into().unwrap(); - Ok(Some(front)) + fn array_from_reader(reader: &mut R) -> Result> { + let mut arr = [0u8; N]; + reader + .read_exact(&mut arr) + .map_err(unexpected_eof_to_unexpected_length_of_input)?; + Ok(Some(arr)) } } @@ -109,15 +117,12 @@ macro_rules! impl_for_integer { ($type: ident) => { impl BorshDeserialize for $type { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - if buf.len() < size_of::<$type>() { - return Err(Error::new( - ErrorKind::InvalidInput, - ERROR_UNEXPECTED_LENGTH_OF_INPUT, - )); - } - let res = $type::from_le_bytes(buf[..size_of::<$type>()].try_into().unwrap()); - *buf = &buf[size_of::<$type>()..]; + fn deserialize_reader(reader: &mut R) -> Result { + let mut buf = [0u8; size_of::<$type>()]; + reader + .read_exact(&mut buf) + .map_err(unexpected_eof_to_unexpected_length_of_input)?; + let res = $type::from_le_bytes(buf.try_into().unwrap()); Ok(res) } } @@ -138,8 +143,8 @@ macro_rules! impl_for_nonzero_integer { ($type: ty) => { impl BorshDeserialize for $type { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - <$type>::new(BorshDeserialize::deserialize(buf)?) + fn deserialize_reader(reader: &mut R) -> Result { + <$type>::new(BorshDeserialize::deserialize_reader(reader)?) .ok_or_else(|| Error::new(ErrorKind::InvalidData, ERROR_INVALID_ZERO_VALUE)) } } @@ -159,8 +164,8 @@ impl_for_nonzero_integer!(core::num::NonZeroU128); impl_for_nonzero_integer!(core::num::NonZeroUsize); impl BorshDeserialize for isize { - fn deserialize(buf: &mut &[u8]) -> Result { - let i: i64 = BorshDeserialize::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let i: i64 = BorshDeserialize::deserialize_reader(reader)?; let i = isize::try_from(i).map_err(|_| { Error::new( ErrorKind::InvalidInput, @@ -172,8 +177,8 @@ impl BorshDeserialize for isize { } impl BorshDeserialize for usize { - fn deserialize(buf: &mut &[u8]) -> Result { - let u: u64 = BorshDeserialize::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let u: u64 = BorshDeserialize::deserialize_reader(reader)?; let u = usize::try_from(u).map_err(|_| { Error::new( ErrorKind::InvalidInput, @@ -190,17 +195,12 @@ macro_rules! impl_for_float { ($type: ident, $int_type: ident) => { impl BorshDeserialize for $type { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - if buf.len() < size_of::<$type>() { - return Err(Error::new( - ErrorKind::InvalidInput, - ERROR_UNEXPECTED_LENGTH_OF_INPUT, - )); - } - let res = $type::from_bits($int_type::from_le_bytes( - buf[..size_of::<$int_type>()].try_into().unwrap(), - )); - *buf = &buf[size_of::<$int_type>()..]; + fn deserialize_reader(reader: &mut R) -> Result { + let mut buf = [0u8; size_of::<$type>()]; + reader + .read_exact(&mut buf) + .map_err(unexpected_eof_to_unexpected_length_of_input)?; + let res = $type::from_bits($int_type::from_le_bytes(buf.try_into().unwrap())); if res.is_nan() { return Err(Error::new( ErrorKind::InvalidInput, @@ -218,15 +218,8 @@ impl_for_float!(f64, u64); impl BorshDeserialize for bool { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - if buf.is_empty() { - return Err(Error::new( - ErrorKind::InvalidInput, - ERROR_UNEXPECTED_LENGTH_OF_INPUT, - )); - } - let b = buf[0]; - *buf = &buf[1..]; + fn deserialize_reader(reader: &mut R) -> Result { + let b: u8 = BorshDeserialize::deserialize_reader(reader)?; if b == 0 { Ok(false) } else if b == 1 { @@ -244,19 +237,12 @@ where T: BorshDeserialize, { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - if buf.is_empty() { - return Err(Error::new( - ErrorKind::InvalidInput, - ERROR_UNEXPECTED_LENGTH_OF_INPUT, - )); - } - let flag = buf[0]; - *buf = &buf[1..]; + fn deserialize_reader(reader: &mut R) -> Result { + let flag: u8 = BorshDeserialize::deserialize_reader(reader)?; if flag == 0 { Ok(None) } else if flag == 1 { - Ok(Some(T::deserialize(buf)?)) + Ok(Some(T::deserialize_reader(reader)?)) } else { let msg = format!( "Invalid Option representation: {}. The first byte must be 0 or 1", @@ -274,19 +260,12 @@ where E: BorshDeserialize, { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - if buf.is_empty() { - return Err(Error::new( - ErrorKind::InvalidInput, - ERROR_UNEXPECTED_LENGTH_OF_INPUT, - )); - } - let flag = buf[0]; - *buf = &buf[1..]; + fn deserialize_reader(reader: &mut R) -> Result { + let flag: u8 = BorshDeserialize::deserialize_reader(reader)?; if flag == 0 { - Ok(Err(E::deserialize(buf)?)) + Ok(Err(E::deserialize_reader(reader)?)) } else if flag == 1 { - Ok(Ok(T::deserialize(buf)?)) + Ok(Ok(T::deserialize_reader(reader)?)) } else { let msg = format!( "Invalid Result representation: {}. The first byte must be 0 or 1", @@ -300,8 +279,8 @@ where impl BorshDeserialize for String { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - String::from_utf8(Vec::::deserialize(buf)?).map_err(|err| { + fn deserialize_reader(reader: &mut R) -> Result { + String::from_utf8(Vec::::deserialize_reader(reader)?).map_err(|err| { let msg = err.to_string(); Error::new(ErrorKind::InvalidData, msg) }) @@ -313,14 +292,14 @@ where T: BorshDeserialize, { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - let len = u32::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let len = u32::deserialize_reader(reader)?; if len == 0 { Ok(Vec::new()) - } else if let Some(vec_bytes) = T::vec_from_bytes(len, buf)? { + } else if let Some(vec_bytes) = T::vec_from_reader(len, reader)? { Ok(vec_bytes) } else if size_of::() == 0 { - let mut result = vec![T::deserialize(buf)?]; + let mut result = vec![T::deserialize_reader(reader)?]; let p = result.as_mut_ptr(); unsafe { @@ -333,7 +312,7 @@ where // TODO(16): return capacity allocation when we can safely do that. let mut result = Vec::with_capacity(hint::cautious::(len)); for _ in 0..len { - result.push(T::deserialize(buf)?); + result.push(T::deserialize_reader(reader)?); } Ok(result) } @@ -346,8 +325,8 @@ where T::Owned: BorshDeserialize, { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - Ok(Cow::Owned(BorshDeserialize::deserialize(buf)?)) + fn deserialize_reader(reader: &mut R) -> Result { + Ok(Cow::Owned(BorshDeserialize::deserialize_reader(reader)?)) } } @@ -356,8 +335,8 @@ where T: BorshDeserialize, { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - let vec = >::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let vec = >::deserialize_reader(reader)?; Ok(vec.into()) } } @@ -367,8 +346,8 @@ where T: BorshDeserialize, { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - let vec = >::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let vec = >::deserialize_reader(reader)?; Ok(vec.into_iter().collect::>()) } } @@ -378,8 +357,8 @@ where T: BorshDeserialize + Ord, { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - let vec = >::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let vec = >::deserialize_reader(reader)?; Ok(vec.into_iter().collect::>()) } } @@ -390,8 +369,8 @@ where H: BuildHasher + Default, { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - let vec = >::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let vec = >::deserialize_reader(reader)?; Ok(vec.into_iter().collect::>()) } } @@ -403,13 +382,13 @@ where H: BuildHasher + Default, { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - let len = u32::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let len = u32::deserialize_reader(reader)?; // TODO(16): return capacity allocation when we can safely do that. let mut result = HashMap::with_hasher(H::default()); for _ in 0..len { - let key = K::deserialize(buf)?; - let value = V::deserialize(buf)?; + let key = K::deserialize_reader(reader)?; + let value = V::deserialize_reader(reader)?; result.insert(key, value); } Ok(result) @@ -421,8 +400,8 @@ where T: BorshDeserialize + Ord, { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - let vec = >::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let vec = >::deserialize_reader(reader)?; Ok(vec.into_iter().collect::>()) } } @@ -433,12 +412,12 @@ where V: BorshDeserialize, { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - let len = u32::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let len = u32::deserialize_reader(reader)?; let mut result = BTreeMap::new(); for _ in 0..len { - let key = K::deserialize(buf)?; - let value = V::deserialize(buf)?; + let key = K::deserialize_reader(reader)?; + let value = V::deserialize_reader(reader)?; result.insert(key, value); } Ok(result) @@ -448,11 +427,11 @@ where #[cfg(feature = "std")] impl BorshDeserialize for std::net::SocketAddr { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - let kind = u8::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let kind = u8::deserialize_reader(reader)?; match kind { - 0 => std::net::SocketAddrV4::deserialize(buf).map(std::net::SocketAddr::V4), - 1 => std::net::SocketAddrV6::deserialize(buf).map(std::net::SocketAddr::V6), + 0 => std::net::SocketAddrV4::deserialize_reader(reader).map(std::net::SocketAddr::V4), + 1 => std::net::SocketAddrV6::deserialize_reader(reader).map(std::net::SocketAddr::V6), value => Err(Error::new( ErrorKind::InvalidInput, format!("Invalid SocketAddr variant: {}", value), @@ -464,9 +443,9 @@ impl BorshDeserialize for std::net::SocketAddr { #[cfg(feature = "std")] impl BorshDeserialize for std::net::SocketAddrV4 { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - let ip = std::net::Ipv4Addr::deserialize(buf)?; - let port = u16::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let ip = std::net::Ipv4Addr::deserialize_reader(reader)?; + let port = u16::deserialize_reader(reader)?; Ok(std::net::SocketAddrV4::new(ip, port)) } } @@ -474,9 +453,9 @@ impl BorshDeserialize for std::net::SocketAddrV4 { #[cfg(feature = "std")] impl BorshDeserialize for std::net::SocketAddrV6 { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - let ip = std::net::Ipv6Addr::deserialize(buf)?; - let port = u16::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let ip = std::net::Ipv6Addr::deserialize_reader(reader)?; + let port = u16::deserialize_reader(reader)?; Ok(std::net::SocketAddrV6::new(ip, port, 0, 0)) } } @@ -484,34 +463,24 @@ impl BorshDeserialize for std::net::SocketAddrV6 { #[cfg(feature = "std")] impl BorshDeserialize for std::net::Ipv4Addr { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - if buf.len() < 4 { - return Err(Error::new( - ErrorKind::InvalidInput, - ERROR_UNEXPECTED_LENGTH_OF_INPUT, - )); - } - let bytes: [u8; 4] = buf[..4].try_into().unwrap(); - let res = std::net::Ipv4Addr::from(bytes); - *buf = &buf[4..]; - Ok(res) + fn deserialize_reader(reader: &mut R) -> Result { + let mut buf = [0u8; 4]; + reader + .read_exact(&mut buf) + .map_err(unexpected_eof_to_unexpected_length_of_input)?; + Ok(std::net::Ipv4Addr::from(buf)) } } #[cfg(feature = "std")] impl BorshDeserialize for std::net::Ipv6Addr { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { - if buf.len() < 16 { - return Err(Error::new( - ErrorKind::InvalidInput, - ERROR_UNEXPECTED_LENGTH_OF_INPUT, - )); - } - let bytes: [u8; 16] = buf[..16].try_into().unwrap(); - let res = std::net::Ipv6Addr::from(bytes); - *buf = &buf[16..]; - Ok(res) + fn deserialize_reader(reader: &mut R) -> Result { + let mut buf = [0u8; 16]; + reader + .read_exact(&mut buf) + .map_err(unexpected_eof_to_unexpected_length_of_input)?; + Ok(std::net::Ipv6Addr::from(buf)) } } @@ -521,8 +490,8 @@ where T: ToOwned + ?Sized, T::Owned: BorshDeserialize, { - fn deserialize(buf: &mut &[u8]) -> Result { - Ok(T::Owned::deserialize(buf)?.into()) + fn deserialize_reader(reader: &mut R) -> Result { + Ok(T::Owned::deserialize_reader(reader)?.into()) } } @@ -531,7 +500,7 @@ where T: BorshDeserialize, { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { + fn deserialize_reader(reader: &mut R) -> Result { struct ArrayDropGuard { buffer: [MaybeUninit; N], init_count: usize, @@ -568,7 +537,7 @@ where } } - if let Some(arr) = T::array_from_bytes(buf)? { + if let Some(arr) = T::array_from_reader(reader)? { Ok(arr) } else { let mut result = ArrayDropGuard { @@ -576,7 +545,7 @@ where init_count: 0, }; - result.fill_buffer(|| T::deserialize(buf))?; + result.fill_buffer(|| T::deserialize_reader(reader))?; // SAFETY: The elements up to `i` have been initialized in `fill_buffer`. Ok(unsafe { result.transmute_to_array() }) @@ -593,8 +562,8 @@ fn array_deserialization_doesnt_leak() { struct MyType(u8); impl BorshDeserialize for MyType { - fn deserialize(buf: &mut &[u8]) -> Result { - let val = u8::deserialize(buf)?; + fn deserialize_reader(reader: &mut R) -> Result { + let val = u8::deserialize_reader(reader)?; let v = DESERIALIZE_COUNT.fetch_add(1, Ordering::SeqCst); if v >= 7 { panic!("panic in deserialize"); @@ -629,7 +598,7 @@ fn array_deserialization_doesnt_leak() { } impl BorshDeserialize for () { - fn deserialize(_buf: &mut &[u8]) -> Result { + fn deserialize_reader(_reader: &mut R) -> Result { Ok(()) } } @@ -640,9 +609,9 @@ macro_rules! impl_tuple { where $($name: BorshDeserialize,)+ { #[inline] - fn deserialize(buf: &mut &[u8]) -> Result { + fn deserialize_reader(reader: &mut R) -> Result { - Ok(($($name::deserialize(buf)?,)+)) + Ok(($($name::deserialize_reader(reader)?,)+)) } } }; @@ -676,8 +645,8 @@ where T: ToOwned + ?Sized, T::Owned: BorshDeserialize, { - fn deserialize(buf: &mut &[u8]) -> Result { - Ok(T::Owned::deserialize(buf)?.into()) + fn deserialize_reader(reader: &mut R) -> Result { + Ok(T::Owned::deserialize_reader(reader)?.into()) } } @@ -688,13 +657,13 @@ where T: ToOwned + ?Sized, T::Owned: BorshDeserialize, { - fn deserialize(buf: &mut &[u8]) -> Result { - Ok(T::Owned::deserialize(buf)?.into()) + fn deserialize_reader(reader: &mut R) -> Result { + Ok(T::Owned::deserialize_reader(reader)?.into()) } } impl BorshDeserialize for PhantomData { - fn deserialize(_: &mut &[u8]) -> Result { + fn deserialize_reader(_: &mut R) -> Result { Ok(Self::default()) } } diff --git a/borsh/src/nostd_io.rs b/borsh/src/nostd_io.rs index 08c66fa56..bb0bed9fc 100644 --- a/borsh/src/nostd_io.rs +++ b/borsh/src/nostd_io.rs @@ -688,3 +688,327 @@ impl Write for alloc::vec::Vec { Ok(()) } } + +/// The `Read` trait allows for reading bytes from a source. +/// +/// Implementors of the `Read` trait are called 'readers'. +/// +/// Readers are defined by one required method, [`read()`]. Each call to [`read()`] +/// will attempt to pull bytes from this source into a provided buffer. A +/// number of other methods are implemented in terms of [`read()`], giving +/// implementors a number of ways to read bytes while only needing to implement +/// a single method. +/// +/// Readers are intended to be composable with one another. Many implementors +/// throughout [`std::io`] take and provide types which implement the `Read` +/// trait. +/// +/// Please note that each call to [`read()`] may involve a system call, and +/// therefore, using something that implements [`BufRead`], such as +/// [`BufReader`], will be more efficient. +/// +/// # Examples +/// +/// [`File`]s implement `Read`: +/// +/// ```no_run +/// use std::io; +/// use std::io::prelude::*; +/// use std::fs::File; +/// +/// fn main() -> io::Result<()> { +/// let mut f = File::open("foo.txt")?; +/// let mut buffer = [0; 10]; +/// +/// // read up to 10 bytes +/// f.read(&mut buffer)?; +/// +/// let mut buffer = Vec::new(); +/// // read the whole file +/// f.read_to_end(&mut buffer)?; +/// +/// // read into a String, so that you don't need to do the conversion. +/// let mut buffer = String::new(); +/// f.read_to_string(&mut buffer)?; +/// +/// // and more! See the other methods for more details. +/// Ok(()) +/// } +/// ``` +/// +/// Read from [`&str`] because [`&[u8]`][prim@slice] implements `Read`: +/// +/// ```no_run +/// # use std::io; +/// use std::io::prelude::*; +/// +/// fn main() -> io::Result<()> { +/// let mut b = "This string will be read".as_bytes(); +/// let mut buffer = [0; 10]; +/// +/// // read up to 10 bytes +/// b.read(&mut buffer)?; +/// +/// // etc... it works exactly as a File does! +/// Ok(()) +/// } +/// ``` +/// +/// [`read()`]: Read::read +/// [`&str`]: prim@str +/// [`std::io`]: self +/// [`File`]: crate::fs::File +pub trait Read { + /// Pull some bytes from this source into the specified buffer, returning + /// how many bytes were read. + /// + /// This function does not provide any guarantees about whether it blocks + /// waiting for data, but if an object needs to block for a read and cannot, + /// it will typically signal this via an [`Err`] return value. + /// + /// If the return value of this method is [`Ok(n)`], then implementations must + /// guarantee that `0 <= n <= buf.len()`. A nonzero `n` value indicates + /// that the buffer `buf` has been filled in with `n` bytes of data from this + /// source. If `n` is `0`, then it can indicate one of two scenarios: + /// + /// 1. This reader has reached its "end of file" and will likely no longer + /// be able to produce bytes. Note that this does not mean that the + /// reader will *always* no longer be able to produce bytes. As an example, + /// on Linux, this method will call the `recv` syscall for a [`TcpStream`], + /// where returning zero indicates the connection was shut down correctly. While + /// for [`File`], it is possible to reach the end of file and get zero as result, + /// but if more data is appended to the file, future calls to `read` will return + /// more data. + /// 2. The buffer specified was 0 bytes in length. + /// + /// It is not an error if the returned value `n` is smaller than the buffer size, + /// even when the reader is not at the end of the stream yet. + /// This may happen for example because fewer bytes are actually available right now + /// (e. g. being close to end-of-file) or because read() was interrupted by a signal. + /// + /// As this trait is safe to implement, callers cannot rely on `n <= buf.len()` for safety. + /// Extra care needs to be taken when `unsafe` functions are used to access the read bytes. + /// Callers have to ensure that no unchecked out-of-bounds accesses are possible even if + /// `n > buf.len()`. + /// + /// No guarantees are provided about the contents of `buf` when this + /// function is called, implementations cannot rely on any property of the + /// contents of `buf` being true. It is recommended that *implementations* + /// only write data to `buf` instead of reading its contents. + /// + /// Correspondingly, however, *callers* of this method must not assume any guarantees + /// about how the implementation uses `buf`. The trait is safe to implement, + /// so it is possible that the code that's supposed to write to the buffer might also read + /// from it. It is your responsibility to make sure that `buf` is initialized + /// before calling `read`. Calling `read` with an uninitialized `buf` (of the kind one + /// obtains via [`MaybeUninit`]) is not safe, and can lead to undefined behavior. + /// + /// [`MaybeUninit`]: crate::mem::MaybeUninit + /// + /// # Errors + /// + /// If this function encounters any form of I/O or other error, an error + /// variant will be returned. If an error is returned then it must be + /// guaranteed that no bytes were read. + /// + /// An error of the [`ErrorKind::Interrupted`] kind is non-fatal and the read + /// operation should be retried if there is nothing else to do. + /// + /// # Examples + /// + /// [`File`]s implement `Read`: + /// + /// [`Ok(n)`]: Ok + /// [`File`]: crate::fs::File + /// [`TcpStream`]: crate::net::TcpStream + /// + /// ```no_run + /// use std::io; + /// use std::io::prelude::*; + /// use std::fs::File; + /// + /// fn main() -> io::Result<()> { + /// let mut f = File::open("foo.txt")?; + /// let mut buffer = [0; 10]; + /// + /// // read up to 10 bytes + /// let n = f.read(&mut buffer[..])?; + /// + /// println!("The bytes: {:?}", &buffer[..n]); + /// Ok(()) + /// } + /// ``` + fn read(&mut self, buf: &mut [u8]) -> Result; + + /// Read the exact number of bytes required to fill `buf`. + /// + /// This function reads as many bytes as necessary to completely fill the + /// specified buffer `buf`. + /// + /// No guarantees are provided about the contents of `buf` when this + /// function is called, implementations cannot rely on any property of the + /// contents of `buf` being true. It is recommended that implementations + /// only write data to `buf` instead of reading its contents. The + /// documentation on [`read`] has a more detailed explanation on this + /// subject. + /// + /// # Errors + /// + /// If this function encounters an error of the kind + /// [`ErrorKind::Interrupted`] then the error is ignored and the operation + /// will continue. + /// + /// If this function encounters an "end of file" before completely filling + /// the buffer, it returns an error of the kind [`ErrorKind::UnexpectedEof`]. + /// The contents of `buf` are unspecified in this case. + /// + /// If any other read error is encountered then this function immediately + /// returns. The contents of `buf` are unspecified in this case. + /// + /// If this function returns an error, it is unspecified how many bytes it + /// has read, but it will never read more than would be necessary to + /// completely fill the buffer. + /// + /// # Examples + /// + /// [`File`]s implement `Read`: + /// + /// [`read`]: Read::read + /// [`File`]: crate::fs::File + /// + /// ```no_run + /// use std::io; + /// use std::io::prelude::*; + /// use std::fs::File; + /// + /// fn main() -> io::Result<()> { + /// let mut f = File::open("foo.txt")?; + /// let mut buffer = [0; 10]; + /// + /// // read exactly 10 bytes + /// f.read_exact(&mut buffer)?; + /// Ok(()) + /// } + /// ``` + fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> { + default_read_exact(self, buf) + } + + /// Creates a "by reference" adaptor for this instance of `Read`. + /// + /// The returned adapter also implements `Read` and will simply borrow this + /// current reader. + /// + /// # Examples + /// + /// [`File`]s implement `Read`: + /// + /// [`File`]: crate::fs::File + /// + /// ```no_run + /// use std::io; + /// use std::io::Read; + /// use std::fs::File; + /// + /// fn main() -> io::Result<()> { + /// let mut f = File::open("foo.txt")?; + /// let mut buffer = Vec::new(); + /// let mut other_buffer = Vec::new(); + /// + /// { + /// let reference = f.by_ref(); + /// + /// // read at most 5 bytes + /// reference.take(5).read_to_end(&mut buffer)?; + /// + /// } // drop our &mut reference so we can use f again + /// + /// // original file still usable, read the rest + /// f.read_to_end(&mut other_buffer)?; + /// Ok(()) + /// } + /// ``` + fn by_ref(&mut self) -> &mut Self + where + Self: Sized, + { + self + } +} + +fn default_read_exact(this: &mut R, mut buf: &mut [u8]) -> Result<()> { + while !buf.is_empty() { + match this.read(buf) { + Ok(0) => break, + Ok(n) => { + let tmp = buf; + buf = &mut tmp[n..]; + } + Err(ref e) if e.kind() == ErrorKind::Interrupted => {} + Err(e) => return Err(e), + } + } + if !buf.is_empty() { + Err(Error::new( + ErrorKind::UnexpectedEof, + "failed to fill whole buffer", + )) + } else { + Ok(()) + } +} + +impl Read for &mut R { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> Result { + (**self).read(buf) + } + + #[inline] + fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> { + (**self).read_exact(buf) + } +} + +impl Read for &[u8] { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> Result { + let amt = core::cmp::min(buf.len(), self.len()); + let (a, b) = self.split_at(amt); + + // First check if the amount of bytes we want to read is small: + // `copy_from_slice` will generally expand to a call to `memcpy`, and + // for a single byte the overhead is significant. + if amt == 1 { + buf[0] = a[0]; + } else { + buf[..amt].copy_from_slice(a); + } + + *self = b; + Ok(amt) + } + + #[inline] + fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> { + if buf.len() > self.len() { + return Err(Error::new( + ErrorKind::UnexpectedEof, + "failed to fill whole buffer", + )); + } + let (a, b) = self.split_at(buf.len()); + + // First check if the amount of bytes we want to read is small: + // `copy_from_slice` will generally expand to a call to `memcpy`, and + // for a single byte the overhead is significant. + if buf.len() == 1 { + buf[0] = a[0]; + } else { + buf.copy_from_slice(a); + } + + *self = b; + Ok(()) + } +} diff --git a/borsh/tests/test_custom_reader.rs b/borsh/tests/test_custom_reader.rs new file mode 100644 index 000000000..4fb63cc70 --- /dev/null +++ b/borsh/tests/test_custom_reader.rs @@ -0,0 +1,139 @@ +use borsh::{BorshDeserialize, BorshSerialize}; + +const ERROR_UNEXPECTED_LENGTH_OF_INPUT: &str = "Unexpected length of input"; + +#[derive(BorshSerialize, BorshDeserialize, Debug)] +struct Serializable { + item1: i32, + item2: String, + item3: f64, +} + +#[test] +fn test_custom_reader() { + let s = Serializable { + item1: 100, + item2: "foo".into(), + item3: 1.2345, + }; + let bytes = s.try_to_vec().unwrap(); + let mut reader = CustomReader { + data: bytes, + read_index: 0, + }; + let de: Serializable = BorshDeserialize::deserialize_reader(&mut reader).unwrap(); + assert_eq!(de.item1, s.item1); + assert_eq!(de.item2, s.item2); + assert_eq!(de.item3, s.item3); +} + +#[test] +fn test_custom_reader_with_insufficient_data() { + let s = Serializable { + item1: 100, + item2: "foo".into(), + item3: 1.2345, + }; + let mut bytes = s.try_to_vec().unwrap(); + bytes.pop().unwrap(); + let mut reader = CustomReader { + data: bytes, + read_index: 0, + }; + assert_eq!( + ::deserialize_reader(&mut reader) + .unwrap_err() + .to_string(), + ERROR_UNEXPECTED_LENGTH_OF_INPUT + ); +} + +#[test] +fn test_custom_reader_with_too_much_data() { + let s = Serializable { + item1: 100, + item2: "foo".into(), + item3: 1.2345, + }; + let mut bytes = s.try_to_vec().unwrap(); + bytes.pop().unwrap(); + let mut reader = CustomReader { + data: bytes, + read_index: 0, + }; + assert_eq!( + ::try_from_reader(&mut reader) + .unwrap_err() + .to_string(), + ERROR_UNEXPECTED_LENGTH_OF_INPUT + ); +} + +struct CustomReader { + data: Vec, + read_index: usize, +} + +impl borsh::maybestd::io::Read for CustomReader { + fn read(&mut self, buf: &mut [u8]) -> borsh::maybestd::io::Result { + let len = buf.len().min(self.data.len() - self.read_index); + buf[0..len].copy_from_slice(&self.data[self.read_index..self.read_index + len]); + self.read_index += len; + Ok(len) + } +} + +#[test] +fn test_custom_reader_that_doesnt_fill_slices() { + let s = Serializable { + item1: 100, + item2: "foo".into(), + item3: 1.2345, + }; + let bytes = s.try_to_vec().unwrap(); + let mut reader = CustomReaderThatDoesntFillSlices { + data: bytes, + read_index: 0, + }; + let de: Serializable = BorshDeserialize::deserialize_reader(&mut reader).unwrap(); + assert_eq!(de.item1, s.item1); + assert_eq!(de.item2, s.item2); + assert_eq!(de.item3, s.item3); +} + +struct CustomReaderThatDoesntFillSlices { + data: Vec, + read_index: usize, +} + +impl borsh::maybestd::io::Read for CustomReaderThatDoesntFillSlices { + fn read(&mut self, buf: &mut [u8]) -> borsh::maybestd::io::Result { + let len = buf.len().min(self.data.len() - self.read_index); + let len = if len <= 1 { len } else { len / 2 }; + buf[0..len].copy_from_slice(&self.data[self.read_index..self.read_index + len]); + self.read_index += len; + Ok(len) + } +} + +#[test] +fn test_custom_reader_that_fails_preserves_error_information() { + let mut reader = CustomReaderThatFails; + let err = ::try_from_reader(&mut reader).unwrap_err(); + assert_eq!(err.to_string(), "I don't like to run"); + assert_eq!( + err.kind(), + borsh::maybestd::io::ErrorKind::ConnectionAborted + ); +} + +struct CustomReaderThatFails; + +impl borsh::maybestd::io::Read for CustomReaderThatFails { + fn read(&mut self, _buf: &mut [u8]) -> borsh::maybestd::io::Result { + Err(borsh::maybestd::io::Error::new( + borsh::maybestd::io::ErrorKind::ConnectionAborted, + "I don't like to run", + )) + } +}