diff --git a/borsh/src/de/mod.rs b/borsh/src/de/mod.rs index b015778bd..01e42185b 100644 --- a/borsh/src/de/mod.rs +++ b/borsh/src/de/mod.rs @@ -21,6 +21,7 @@ use crate::__private::maybestd::{ #[cfg(feature = "rc")] use crate::__private::maybestd::{rc::Rc, sync::Arc}; +use crate::error::check_zst; mod hint; @@ -388,12 +389,7 @@ where { #[inline] fn deserialize_reader(reader: &mut R) -> Result { - if size_of::() == 0 { - return Err(Error::new( - ErrorKind::InvalidData, - "Vectors of zero-sized types are not allowed due to deny-of-service concerns on deserialization.", - )); - } + check_zst::()?; let len = u32::deserialize_reader(reader)?; if len == 0 { @@ -493,6 +489,7 @@ pub mod hashes { const ERROR_WRONG_ORDER_OF_KEYS: &str = "keys were not serialized in ascending order"; #[cfg(feature = "de_strict_order")] use crate::__private::maybestd::io::{Error, ErrorKind}; + use crate::error::check_zst; #[cfg(feature = "de_strict_order")] use core::cmp::Ordering; @@ -537,6 +534,7 @@ pub mod hashes { { #[inline] fn deserialize_reader(reader: &mut R) -> Result { + check_zst::()?; // NOTE: deserialize-as-you-go approach as once was in HashSet is better in the sense // that it allows to fail early, and not allocate memory for all the entries // which may fail `cmp()` checks @@ -604,6 +602,7 @@ where { #[inline] fn deserialize_reader(reader: &mut R) -> Result { + check_zst::()?; // NOTE: deserialize-as-you-go approach as once was in HashSet is better in the sense // that it allows to fail early, and not allocate memory for all the entries // which may fail `cmp()` checks diff --git a/borsh/src/error.rs b/borsh/src/error.rs new file mode 100644 index 000000000..9da39e3f9 --- /dev/null +++ b/borsh/src/error.rs @@ -0,0 +1,10 @@ +use crate::__private::maybestd::io::{Error, ErrorKind, Result}; +use core::mem::size_of; +pub const ERROR_ZST_FORBIDDEN: &str = "Collections of zero-sized types are not allowed due to deny-of-service concerns on deserialization."; + +pub(crate) fn check_zst() -> Result<()> { + if size_of::() == 0 { + return Err(Error::new(ErrorKind::InvalidData, ERROR_ZST_FORBIDDEN)); + } + Ok(()) +} diff --git a/borsh/src/lib.rs b/borsh/src/lib.rs index b052518d0..70aedaa76 100644 --- a/borsh/src/lib.rs +++ b/borsh/src/lib.rs @@ -87,6 +87,7 @@ pub use schema::BorshSchema; pub use schema_helpers::{try_from_slice_with_schema, try_to_vec_with_schema}; pub use ser::helpers::{to_vec, to_writer}; pub use ser::BorshSerialize; +pub mod error; #[cfg(all(feature = "std", feature = "hashbrown"))] compile_error!("feature \"std\" and feature \"hashbrown\" don't make sense at the same time"); diff --git a/borsh/src/ser/mod.rs b/borsh/src/ser/mod.rs index d3a55b43d..1b8c1ca75 100644 --- a/borsh/src/ser/mod.rs +++ b/borsh/src/ser/mod.rs @@ -1,14 +1,16 @@ use core::convert::TryFrom; use core::marker::PhantomData; -use core::mem::size_of; - -use crate::__private::maybestd::{ - borrow::{Cow, ToOwned}, - boxed::Box, - collections::{BTreeMap, BTreeSet, LinkedList, VecDeque}, - io::{Error, ErrorKind, Result, Write}, - string::String, - vec::Vec, + +use crate::{ + __private::maybestd::{ + borrow::{Cow, ToOwned}, + boxed::Box, + collections::{BTreeMap, BTreeSet, LinkedList, VecDeque}, + io::{ErrorKind, Result, Write}, + string::String, + vec::Vec, + }, + error::check_zst, }; #[cfg(feature = "rc")] @@ -278,12 +280,8 @@ where { #[inline] fn serialize(&self, writer: &mut W) -> Result<()> { - if size_of::() == 0 { - return Err(Error::new( - ErrorKind::InvalidData, - "Vectors of zero-sized types are not allowed due to deny-of-service concerns on deserialization.", - )); - } + check_zst::()?; + self.as_slice().serialize(writer) } } @@ -318,6 +316,8 @@ where { #[inline] fn serialize(&self, writer: &mut W) -> Result<()> { + check_zst::()?; + writer.write_all( &(u32::try_from(self.len()).map_err(|_| ErrorKind::InvalidData)?).to_le_bytes(), )?; @@ -333,6 +333,8 @@ where { #[inline] fn serialize(&self, writer: &mut W) -> Result<()> { + check_zst::()?; + writer.write_all( &(u32::try_from(self.len()).map_err(|_| ErrorKind::InvalidData)?).to_le_bytes(), )?; @@ -350,6 +352,7 @@ pub mod hashes { //! Module defines [BorshSerialize](crate::ser::BorshSerialize) implementation for //! [HashMap](std::collections::HashMap)/[HashSet](std::collections::HashSet). use crate::__private::maybestd::vec::Vec; + use crate::error::check_zst; use crate::{ BorshSerialize, __private::maybestd::collections::{HashMap, HashSet}, @@ -367,6 +370,8 @@ pub mod hashes { { #[inline] fn serialize(&self, writer: &mut W) -> Result<()> { + check_zst::()?; + let mut vec = self.iter().collect::>(); vec.sort_by(|(a, _), (b, _)| a.partial_cmp(b).unwrap()); u32::try_from(vec.len()) @@ -387,6 +392,8 @@ pub mod hashes { { #[inline] fn serialize(&self, writer: &mut W) -> Result<()> { + check_zst::()?; + let mut vec = self.iter().collect::>(); vec.sort_by(|a, b| a.partial_cmp(b).unwrap()); u32::try_from(vec.len()) @@ -407,6 +414,7 @@ where { #[inline] fn serialize(&self, writer: &mut W) -> Result<()> { + check_zst::()?; // NOTE: BTreeMap iterates over the entries that are sorted by key, so the serialization // result will be consistent without a need to sort the entries as we do for HashMap // serialization. @@ -427,6 +435,7 @@ where { #[inline] fn serialize(&self, writer: &mut W) -> Result<()> { + check_zst::()?; // NOTE: BTreeSet iterates over the items that are sorted, so the serialization result will // be consistent without a need to sort the entries as we do for HashSet serialization. u32::try_from(self.len()) diff --git a/borsh/tests/test_zero_size.rs b/borsh/tests/test_zero_size.rs deleted file mode 100644 index 598377b67..000000000 --- a/borsh/tests/test_zero_size.rs +++ /dev/null @@ -1,45 +0,0 @@ -#![cfg_attr(not(feature = "std"), no_std)] -#![cfg(feature = "derive")] - -#[cfg(not(feature = "std"))] -extern crate alloc; -#[cfg(not(feature = "std"))] -use alloc::{vec, vec::Vec}; - -use borsh::from_slice; -use borsh::to_vec; -use borsh::BorshDeserialize; -use borsh::BorshSerialize; - -#[derive(BorshDeserialize, BorshSerialize, PartialEq, Debug)] -struct A(); - -#[test] -fn test_deserialize_zero_size() { - let v = [0u8, 0u8, 0u8, 64u8]; - let res = from_slice::>(&v); - assert!(res.is_err()); -} - -#[test] -fn test_serialize_zero_size() { - let v = vec![A()]; - let res = to_vec(&v); - assert!(res.is_err()); -} - -#[derive(BorshDeserialize, BorshSerialize, PartialEq, Debug)] -struct B(u32); -#[test] -fn test_deserialize_non_zero_size() { - let v = [1, 0, 0, 0, 64, 0, 0, 0]; - let res = Vec::::try_from_slice(&v); - assert!(res.is_ok()); -} - -#[test] -fn test_serialize_non_zero_size() { - let v = vec![B(1)]; - let res = to_vec(&v); - assert!(res.is_ok()); -} diff --git a/borsh/tests/test_zero_sized_types.rs b/borsh/tests/test_zero_sized_types.rs new file mode 100644 index 000000000..4e74d6a8a --- /dev/null +++ b/borsh/tests/test_zero_sized_types.rs @@ -0,0 +1,142 @@ +#![cfg_attr(not(feature = "std"), no_std)] +#![cfg(feature = "derive")] + +#[cfg(not(feature = "std"))] +extern crate alloc; +#[cfg(not(feature = "std"))] +use alloc::{string::ToString, vec, vec::Vec}; + +#[cfg(feature = "std")] +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, LinkedList, VecDeque}; + +#[cfg(not(feature = "std"))] +use alloc::collections::{BTreeMap, BTreeSet, LinkedList, VecDeque}; +#[cfg(feature = "hashbrown")] +use hashbrown::{HashMap, HashSet}; + +use borsh::from_slice; +use borsh::to_vec; +use borsh::BorshDeserialize; +use borsh::BorshSerialize; + +use borsh::error::ERROR_ZST_FORBIDDEN; +#[derive(BorshDeserialize, BorshSerialize, PartialEq, Debug, Eq, PartialOrd, Ord, Hash)] +struct A(); + +#[test] +fn test_deserialize_vec_of_zst() { + let v = [0u8, 0u8, 0u8, 64u8]; + let res = from_slice::>(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[test] +fn test_serialize_vec_of_zst() { + let v = vec![A()]; + let res = to_vec(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[test] +fn test_deserialize_vec_deque_of_zst() { + let v = [0u8, 0u8, 0u8, 64u8]; + let res = from_slice::>(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[test] +fn test_serialize_vec_deque_of_zst() { + let v: VecDeque = vec![A()].into(); + let res = to_vec(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[test] +fn test_deserialize_linked_list_of_zst() { + let v = [0u8, 0u8, 0u8, 64u8]; + let res = from_slice::>(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[test] +fn test_serialize_linked_list_of_zst() { + let v: LinkedList = vec![A()].into_iter().collect(); + let res = to_vec(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[test] +fn test_deserialize_btreeset_of_zst() { + let v = [0u8, 0u8, 0u8, 64u8]; + let res = from_slice::>(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[test] +fn test_serialize_btreeset_of_zst() { + let v: BTreeSet = vec![A()].into_iter().collect(); + let res = to_vec(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[cfg(hash_collections)] +#[test] +fn test_deserialize_hashset_of_zst() { + let v = [0u8, 0u8, 0u8, 64u8]; + let res = from_slice::>(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[cfg(hash_collections)] +#[test] +fn test_serialize_hashset_of_zst() { + let v: HashSet = vec![A()].into_iter().collect(); + let res = to_vec(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[test] +fn test_deserialize_btreemap_of_zst() { + let v = [0u8, 0u8, 0u8, 64u8]; + let res = from_slice::>(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[test] +fn test_serialize_btreemap_of_zst() { + let v: BTreeMap = vec![(A(), 42u64)].into_iter().collect(); + let res = to_vec(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[cfg(hash_collections)] +#[test] +fn test_deserialize_hashmap_of_zst() { + let v = [0u8, 0u8, 0u8, 64u8]; + let res = from_slice::>(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[cfg(hash_collections)] +#[test] +fn test_serialize_hashmap_of_zst() { + let v: HashMap = vec![(A(), 42u64)].into_iter().collect(); + let res = to_vec(&v); + assert_eq!(res.unwrap_err().to_string(), ERROR_ZST_FORBIDDEN); +} + +#[derive(BorshDeserialize, BorshSerialize, PartialEq, Debug)] +struct B(u32); +#[test] +fn test_deserialize_non_zst() { + let v = [1, 0, 0, 0, 64, 0, 0, 0]; + let res = Vec::::try_from_slice(&v); + assert!(res.is_ok()); +} + +#[test] +fn test_serialize_non_zst() { + let v = vec![B(1)]; + let res = to_vec(&v); + assert!(res.is_ok()); +}