From 9e1d4152a84f56725dbab10876f235409d5d46d4 Mon Sep 17 00:00:00 2001 From: Michal Nazarewicz Date: Fri, 10 Feb 2023 05:14:13 +0100 Subject: [PATCH] decode: use exact decoded length rather than estimation Fixes: https://github.com/marshallpierce/rust-base64/issues/210 Fixes: https://github.com/marshallpierce/rust-base64/issues/212 --- src/decode.rs | 86 +++++++++++++- src/encode.rs | 25 ++-- src/engine/general_purpose/decode.rs | 72 +----------- src/engine/general_purpose/decode_suffix.rs | 10 +- src/engine/general_purpose/mod.rs | 14 +-- src/engine/mod.rs | 79 ++++--------- src/engine/naive.rs | 119 ++++++-------------- src/engine/tests.rs | 85 +++----------- src/lib.rs | 4 +- src/read/decoder.rs | 11 +- 10 files changed, 186 insertions(+), 319 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index 7d29fdc..c2cf2f3 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -1,4 +1,4 @@ -use crate::engine::{general_purpose::STANDARD, DecodeEstimate, Engine}; +use crate::engine::{general_purpose::STANDARD, Engine}; #[cfg(any(feature = "alloc", feature = "std", test))] use alloc::vec::Vec; use core::fmt; @@ -130,6 +130,73 @@ pub fn decode_engine_slice>( engine.decode_slice(input, output) } +/// Returns the decoded size of the `encoded` input assuming the input is valid +/// base64 string. +/// +/// Assumes input is a valid base64-encoded string. Result is unspecified if it +/// isn’t. +/// +/// If you don’t need a precise length of the decoded string, you can use +/// [`decoded_len_estimate`] function instead. It’s faster and provides an +/// estimate which is only at most two bytes off from the real length. +/// +/// # Examples +/// +/// ``` +/// use base64::decoded_len; +/// +/// assert_eq!(0, decoded_len(b"")); +/// assert_eq!(1, decoded_len(b"AA")); +/// assert_eq!(2, decoded_len(b"AAA")); +/// assert_eq!(3, decoded_len(b"AAAA")); +/// assert_eq!(1, decoded_len(b"AA==")); +/// assert_eq!(2, decoded_len(b"AAA=")); +/// ``` +pub fn decoded_len(encoded: impl AsRef<[u8]>) -> usize { + let encoded = encoded.as_ref(); + if encoded.len() < 2 { + return 0; + } + let is_pad = |idx| (encoded[encoded.len() - idx] == b'=') as usize; + let len = encoded.len() - is_pad(1) - is_pad(2); + match len % 4 { + 0 => len / 4 * 3, + remainder => len / 4 * 3 + remainder - 1, + } +} + +#[test] +fn test_decoded_len() { + for chunks in 0..25 { + let mut input = vec![b'A'; chunks * 4 + 4]; + assert_eq!(chunks * 3 + 0, decoded_len(&input[..chunks * 4])); + assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 2])); + assert_eq!(chunks * 3 + 2, decoded_len(&input[..chunks * 4 + 3])); + assert_eq!(chunks * 3 + 3, decoded_len(&input[..chunks * 4 + 4])); + + input[chunks * 4 + 3] = b'='; + assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 2])); + assert_eq!(chunks * 3 + 2, decoded_len(&input[..chunks * 4 + 3])); + assert_eq!(chunks * 3 + 2, decoded_len(&input[..chunks * 4 + 4])); + input[chunks * 4 + 2] = b'='; + assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 2])); + assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 3])); + assert_eq!(chunks * 3 + 1, decoded_len(&input[..chunks * 4 + 4])); + } + + // Mustn’t panic or overflow if given bogus input. + for len in 1..100 { + let mut input = vec![b'A'; len]; + let got = decoded_len(&input); + debug_assert!(got <= len); + for padding in 1..=len.min(10) { + input[len - padding] = b'='; + let got = decoded_len(&input); + debug_assert!(got <= len); + } + } +} + /// Returns a conservative estimate of the decoded size of `encoded_len` base64 symbols (rounded up /// to the next group of 3 decoded bytes). /// @@ -141,6 +208,7 @@ pub fn decode_engine_slice>( /// ``` /// use base64::decoded_len_estimate; /// +/// assert_eq!(0, decoded_len_estimate(0)); /// assert_eq!(3, decoded_len_estimate(1)); /// assert_eq!(3, decoded_len_estimate(2)); /// assert_eq!(3, decoded_len_estimate(3)); @@ -149,9 +217,19 @@ pub fn decode_engine_slice>( /// assert_eq!(6, decoded_len_estimate(5)); /// ``` pub fn decoded_len_estimate(encoded_len: usize) -> usize { - STANDARD - .internal_decoded_len_estimate(encoded_len) - .decoded_len_estimate() + (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3 +} + +#[test] +fn test_decode_len_estimate() { + for chunks in 0..250 { + assert_eq!(chunks * 3, decoded_len_estimate(chunks * 4)); + assert_eq!(chunks * 3 + 3, decoded_len_estimate(chunks * 4 + 1)); + assert_eq!(chunks * 3 + 3, decoded_len_estimate(chunks * 4 + 2)); + assert_eq!(chunks * 3 + 3, decoded_len_estimate(chunks * 4 + 3)); + } + // Mustn’t panic or overflow. + assert_eq!(usize::MAX / 4 * 3 + 3, decoded_len_estimate(usize::MAX)); } #[cfg(test)] diff --git a/src/encode.rs b/src/encode.rs index 15b903d..5d9016f 100644 --- a/src/encode.rs +++ b/src/encode.rs @@ -96,24 +96,15 @@ pub(crate) fn encode_with_padding( /// input lengths in approximately the top quarter of the range of `usize`. pub fn encoded_len(bytes_len: usize, padding: bool) -> Option { let rem = bytes_len % 3; - - let complete_input_chunks = bytes_len / 3; - let complete_chunk_output = complete_input_chunks.checked_mul(4); - - if rem > 0 { - if padding { - complete_chunk_output.and_then(|c| c.checked_add(4)) - } else { - let encoded_rem = match rem { - 1 => 2, - 2 => 3, - _ => unreachable!("Impossible remainder"), - }; - complete_chunk_output.and_then(|c| c.checked_add(encoded_rem)) - } + let chunks = bytes_len / 3 + (rem > 0 && padding) as usize; + let encoded_len = chunks.checked_mul(4)?; + Some(if !padding && rem > 0 { + // This doesn’t overflow. encoded_len is divisible by four thus it’s at + // most usize::MAX - 3. rem ≤ 2 so we’re adding at most three. + encoded_len + rem + 1 } else { - complete_chunk_output - } + encoded_len + }) } /// Write padding characters. diff --git a/src/engine/general_purpose/decode.rs b/src/engine/general_purpose/decode.rs index 5e30e45..0d9b13e 100644 --- a/src/engine/general_purpose/decode.rs +++ b/src/engine/general_purpose/decode.rs @@ -1,5 +1,5 @@ use crate::{ - engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodePaddingMode}, + engine::{general_purpose::INVALID_VALUE, DecodePaddingMode}, DecodeError, PAD_BYTE, }; @@ -21,30 +21,6 @@ const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN; const DECODED_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX; -#[doc(hidden)] -pub struct GeneralPurposeEstimate { - /// Total number of decode chunks, including a possibly partial last chunk - num_chunks: usize, - decoded_len_estimate: usize, -} - -impl GeneralPurposeEstimate { - pub(crate) fn new(encoded_len: usize) -> Self { - // Formulas that won't overflow - Self { - num_chunks: encoded_len / INPUT_CHUNK_LEN - + (encoded_len % INPUT_CHUNK_LEN > 0) as usize, - decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3, - } - } -} - -impl DecodeEstimate for GeneralPurposeEstimate { - fn decoded_len_estimate(&self) -> usize { - self.decoded_len_estimate - } -} - /// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs. /// Returns the number of bytes written, or an error. // We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is @@ -53,12 +29,11 @@ impl DecodeEstimate for GeneralPurposeEstimate { #[inline] pub(crate) fn decode_helper( input: &[u8], - estimate: GeneralPurposeEstimate, output: &mut [u8], decode_table: &[u8; 256], decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, -) -> Result { +) -> Result<(), DecodeError> { let remainder_len = input.len() % INPUT_CHUNK_LEN; // Because the fast decode loop writes in groups of 8 bytes (unrolled to @@ -99,7 +74,8 @@ pub(crate) fn decode_helper( }; // rounded up to include partial chunks - let mut remaining_chunks = estimate.num_chunks; + let mut remaining_chunks = + input.len() / INPUT_CHUNK_LEN + (input.len() % INPUT_CHUNK_LEN > 0) as usize; let mut input_index = 0; let mut output_index = 0; @@ -340,44 +316,4 @@ mod tests { decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output); } - - #[test] - fn estimate_short_lengths() { - for (range, (num_chunks, decoded_len_estimate)) in [ - (0..=0, (0, 0)), - (1..=4, (1, 3)), - (5..=8, (1, 6)), - (9..=12, (2, 9)), - (13..=16, (2, 12)), - (17..=20, (3, 15)), - ] { - for encoded_len in range { - let estimate = GeneralPurposeEstimate::new(encoded_len); - assert_eq!(num_chunks, estimate.num_chunks); - assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate); - } - } - } - - #[test] - fn estimate_via_u128_inflation() { - // cover both ends of usize - (0..1000) - .chain(usize::MAX - 1000..=usize::MAX) - .for_each(|encoded_len| { - // inflate to 128 bit type to be able to safely use the easy formulas - let len_128 = encoded_len as u128; - - let estimate = GeneralPurposeEstimate::new(encoded_len); - assert_eq!( - ((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128)) - as usize, - estimate.num_chunks - ); - assert_eq!( - ((len_128 + 3) / 4 * 3) as usize, - estimate.decoded_len_estimate - ); - }) - } } diff --git a/src/engine/general_purpose/decode_suffix.rs b/src/engine/general_purpose/decode_suffix.rs index 5652035..28f5f20 100644 --- a/src/engine/general_purpose/decode_suffix.rs +++ b/src/engine/general_purpose/decode_suffix.rs @@ -6,8 +6,9 @@ use crate::{ /// Decode the last 1-8 bytes, checking for trailing set bits and padding per the provided /// parameters. /// -/// Returns the total number of bytes decoded, including the ones indicated as already written by -/// `output_index`. +/// Expects output to be large enough to fit decoded data exactly without any +/// unused space. In debug builds panics if final output length (`output_index` +/// plus any bytes written by this function) doesn’t equal length of the output. pub(crate) fn decode_suffix( input: &[u8], input_index: usize, @@ -16,7 +17,7 @@ pub(crate) fn decode_suffix( decode_table: &[u8; 256], decode_allow_trailing_bits: bool, padding_mode: DecodePaddingMode, -) -> Result { +) -> Result<(), DecodeError> { // Decode any leftovers that aren't a complete input block of 8 bytes. // Use a u64 as a stack-resident 8 byte buffer. let mut leftover_bits: u64 = 0; @@ -157,5 +158,6 @@ pub(crate) fn decode_suffix( leftover_bits_appended_to_buf += 8; } - Ok(output_index) + debug_assert_eq!(output.len(), output_index); + Ok(()) } diff --git a/src/engine/general_purpose/mod.rs b/src/engine/general_purpose/mod.rs index af8897b..e010d9b 100644 --- a/src/engine/general_purpose/mod.rs +++ b/src/engine/general_purpose/mod.rs @@ -9,7 +9,6 @@ use core::convert::TryInto; mod decode; pub(crate) mod decode_suffix; -pub use decode::GeneralPurposeEstimate; pub(crate) const INVALID_VALUE: u8 = 255; @@ -40,7 +39,6 @@ impl GeneralPurpose { impl super::Engine for GeneralPurpose { type Config = GeneralPurposeConfig; - type DecodeEstimate = GeneralPurposeEstimate; fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize { let mut input_index: usize = 0; @@ -161,19 +159,9 @@ impl super::Engine for GeneralPurpose { output_index } - fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate { - GeneralPurposeEstimate::new(input_len) - } - - fn internal_decode( - &self, - input: &[u8], - output: &mut [u8], - estimate: Self::DecodeEstimate, - ) -> Result { + fn internal_decode(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecodeError> { decode::decode_helper( input, - estimate, output, &self.decode_table, self.config.decode_allow_trailing_bits, diff --git a/src/engine/mod.rs b/src/engine/mod.rs index 7467a91..9388640 100644 --- a/src/engine/mod.rs +++ b/src/engine/mod.rs @@ -39,8 +39,6 @@ pub use general_purpose::{GeneralPurpose, GeneralPurposeConfig}; pub trait Engine: Send + Sync { /// The config type used by this engine type Config: Config; - /// The decode estimate used by this engine - type DecodeEstimate: DecodeEstimate; /// This is not meant to be called directly; it is only for `Engine` implementors. /// See the other `encode*` functions on this trait. @@ -57,23 +55,11 @@ pub trait Engine: Send + Sync { #[doc(hidden)] fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize; - /// This is not meant to be called directly; it is only for `Engine` implementors. - /// - /// As an optimization to prevent the decoded length from being calculated twice, it is - /// sometimes helpful to have a conservative estimate of the decoded size before doing the - /// decoding, so this calculation is done separately and passed to [Engine::decode()] as needed. - #[doc(hidden)] - fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate; - /// This is not meant to be called directly; it is only for `Engine` implementors. /// See the other `decode*` functions on this trait. /// - /// Decode `input` base64 bytes into the `output` buffer. - /// - /// `decode_estimate` is the result of [Engine::internal_decoded_len_estimate()], which is passed in to avoid - /// calculating it again (expensive on short inputs).` - /// - /// Returns the number of bytes written to `output`. + /// Decode `input` base64 bytes into the `output` buffer. `output` will + /// have the exact amount of space to fit the decoded encoded value. /// /// Each complete 4-byte chunk of encoded data decodes to 3 bytes of decoded data, but this /// function must also handle the final possibly partial chunk. @@ -81,8 +67,6 @@ pub trait Engine: Send + Sync { /// the trailing 2 or 3 bytes must decode to 1 or 2 bytes, respectively, as per the /// [RFC](https://tools.ietf.org/html/rfc4648#section-3.5). /// - /// Decoding must not write any bytes into the output slice other than the decoded data. - /// /// Non-canonical trailing bits in the final tokens or non-canonical padding must be reported as /// errors unless the engine is configured otherwise. /// @@ -90,12 +74,7 @@ pub trait Engine: Send + Sync { /// /// Panics if `output` is too small. #[doc(hidden)] - fn internal_decode( - &self, - input: &[u8], - output: &mut [u8], - decode_estimate: Self::DecodeEstimate, - ) -> Result; + fn internal_decode(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecodeError>; /// Returns the config for this engine. fn config(&self) -> &Self::Config; @@ -224,13 +203,8 @@ pub trait Engine: Send + Sync { #[cfg(any(feature = "alloc", feature = "std", test))] fn decode>(&self, input: T) -> Result, DecodeError> { let input_bytes = input.as_ref(); - - let estimate = self.internal_decoded_len_estimate(input_bytes.len()); - let mut buffer = vec![0; estimate.decoded_len_estimate()]; - - let bytes_written = self.internal_decode(input_bytes, &mut buffer, estimate)?; - buffer.truncate(bytes_written); - + let mut buffer = vec![0; crate::decoded_len(input_bytes)]; + self.internal_decode(input_bytes, &mut buffer)?; Ok(buffer) } @@ -273,28 +247,25 @@ pub trait Engine: Send + Sync { let starting_output_len = buffer.len(); - let estimate = self.internal_decoded_len_estimate(input_bytes.len()); - let total_len_estimate = estimate - .decoded_len_estimate() + let decoded_len = crate::decoded_len(input_bytes); + let total_len = decoded_len .checked_add(starting_output_len) .expect("Overflow when calculating output buffer length"); - buffer.resize(total_len_estimate, 0); + buffer.resize(total_len, 0); let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..]; - let bytes_written = self.internal_decode(input_bytes, buffer_slice, estimate)?; - - buffer.truncate(starting_output_len + bytes_written); + self.internal_decode(input_bytes, buffer_slice)?; Ok(()) } /// Decode the input into the provided output slice. /// - /// Returns an error if `output` is smaller than the estimated decoded length. + /// Returns an error if `output` is smaller than the decoded length. /// /// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end). /// - /// See [crate::decoded_len_estimate] for calculating buffer sizes. + /// See [crate::decoded_len] for calculating buffer sizes. /// /// See [Engine::decode_slice_unchecked] for a version that panics instead of returning an error /// if the output buffer is too small. @@ -305,20 +276,20 @@ pub trait Engine: Send + Sync { ) -> Result { let input_bytes = input.as_ref(); - let estimate = self.internal_decoded_len_estimate(input_bytes.len()); - if output.len() < estimate.decoded_len_estimate() { + let decoded_len = crate::decoded_len(input_bytes); + if output.len() < decoded_len { return Err(DecodeSliceError::OutputSliceTooSmall); } - self.internal_decode(input_bytes, output, estimate) - .map_err(|e| e.into()) + self.internal_decode(input_bytes, &mut output[..decoded_len])?; + Ok(decoded_len) } /// Decode the input into the provided output slice. /// /// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end). /// - /// See [crate::decoded_len_estimate] for calculating buffer sizes. + /// See [crate::decoded_len] for calculating buffer sizes. /// /// See [Engine::decode_slice] for a version that returns an error instead of panicking if the output /// buffer is too small. @@ -333,11 +304,9 @@ pub trait Engine: Send + Sync { ) -> Result { let input_bytes = input.as_ref(); - self.internal_decode( - input_bytes, - output, - self.internal_decoded_len_estimate(input_bytes.len()), - ) + let decoded_len = crate::decoded_len(input_bytes); + self.internal_decode(input_bytes, &mut output[..decoded_len])?; + Ok(decoded_len) } } @@ -359,13 +328,11 @@ pub trait Config { /// /// Implementors may store relevant data here when constructing this to avoid having to calculate /// them again during actual decoding. -pub trait DecodeEstimate { - /// Returns a conservative (err on the side of too big) estimate of the decoded length to use - /// for pre-allocating buffers, etc. +pub trait DecodedLength { + /// Returns the decoded length to use for pre-allocating buffers, etc. /// - /// The estimate must be no larger than the next largest complete triple of decoded bytes. - /// That is, the final quad of tokens to decode may be assumed to be complete with no padding. - fn decoded_len_estimate(&self) -> usize; + /// The value must be exactly equal the length of the decoded value. + fn decoded_len(&self) -> usize; } /// Controls how pad bytes are handled when decoding. diff --git a/src/engine/naive.rs b/src/engine/naive.rs index 6665c5e..eb213c3 100644 --- a/src/engine/naive.rs +++ b/src/engine/naive.rs @@ -2,7 +2,7 @@ use crate::{ alphabet::Alphabet, engine::{ general_purpose::{self, decode_table, encode_table}, - Config, DecodeEstimate, DecodePaddingMode, Engine, + Config, DecodePaddingMode, Engine, }, DecodeError, PAD_BYTE, }; @@ -41,7 +41,6 @@ impl Naive { impl Engine for Naive { type Config = NaiveConfig; - type DecodeEstimate = NaiveEstimate; fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize { // complete chunks first @@ -103,70 +102,51 @@ impl Engine for Naive { output_index } - fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate { - NaiveEstimate::new(input_len) - } - - fn internal_decode( - &self, - input: &[u8], - output: &mut [u8], - estimate: Self::DecodeEstimate, - ) -> Result { - if estimate.rem == 1 { - // trailing whitespace is so common that it's worth it to check the last byte to - // possibly return a better error message - if let Some(b) = input.last() { - if *b != PAD_BYTE - && self.decode_table[*b as usize] == general_purpose::INVALID_VALUE - { - return Err(DecodeError::InvalidByte(input.len() - 1, *b)); + fn internal_decode(&self, input: &[u8], output: &mut [u8]) -> Result<(), DecodeError> { + let full_chunks = match input.len() % 4 { + 0 => { + if input.is_empty() { + debug_assert!(output.is_empty()); + return Ok(()); + } else { + input.len() / Self::DECODE_INPUT_CHUNK_SIZE - 1 } } - - return Err(DecodeError::InvalidLength); - } - - let mut input_index = 0_usize; + 1 => { + // Trailing whitespace is so common that it's worth it to check + // the last byte to possibly return a better error message + let last = input[input.len() - 1]; + let value = self.decode_table[last as usize]; + if last != PAD_BYTE && value == general_purpose::INVALID_VALUE { + return Err(DecodeError::InvalidByte(input.len() - 1, last)); + } else { + return Err(DecodeError::InvalidLength); + } + } + _ => input.len() / Self::DECODE_INPUT_CHUNK_SIZE, + }; + let full_bytes = full_chunks * Self::DECODE_INPUT_CHUNK_SIZE; let mut output_index = 0_usize; const BOTTOM_BYTE: u32 = 0xFF; - // can only use the main loop on non-trailing chunks - if input.len() > Self::DECODE_INPUT_CHUNK_SIZE { - // skip the last chunk, whether it's partial or full, since it might - // have padding, and start at the beginning of the chunk before that - let last_complete_chunk_start_index = estimate.complete_chunk_len - - if estimate.rem == 0 { - // Trailing chunk is also full chunk, so there must be at least 2 chunks, and - // this won't underflow - Self::DECODE_INPUT_CHUNK_SIZE * 2 - } else { - // Trailing chunk is partial, so it's already excluded in - // complete_chunk_len - Self::DECODE_INPUT_CHUNK_SIZE - }; - - while input_index <= last_complete_chunk_start_index { - let chunk = &input[input_index..input_index + Self::DECODE_INPUT_CHUNK_SIZE]; - let decoded_int: u32 = self.decode_byte_into_u32(input_index, chunk[0])?.shl(18) - | self - .decode_byte_into_u32(input_index + 1, chunk[1])? - .shl(12) - | self.decode_byte_into_u32(input_index + 2, chunk[2])?.shl(6) - | self.decode_byte_into_u32(input_index + 3, chunk[3])?; - - output[output_index] = decoded_int.shr(16_u8).bitand(BOTTOM_BYTE) as u8; - output[output_index + 1] = decoded_int.shr(8_u8).bitand(BOTTOM_BYTE) as u8; - output[output_index + 2] = decoded_int.bitand(BOTTOM_BYTE) as u8; - - input_index += Self::DECODE_INPUT_CHUNK_SIZE; - output_index += 3; - } + for input_index in (0..full_bytes).step_by(Self::DECODE_INPUT_CHUNK_SIZE) { + let chunk = &input[input_index..input_index + Self::DECODE_INPUT_CHUNK_SIZE]; + let decoded_int: u32 = self.decode_byte_into_u32(input_index, chunk[0])?.shl(18) + | self + .decode_byte_into_u32(input_index + 1, chunk[1])? + .shl(12) + | self.decode_byte_into_u32(input_index + 2, chunk[2])?.shl(6) + | self.decode_byte_into_u32(input_index + 3, chunk[3])?; + + output[output_index] = decoded_int.shr(16_u8).bitand(BOTTOM_BYTE) as u8; + output[output_index + 1] = decoded_int.shr(8_u8).bitand(BOTTOM_BYTE) as u8; + output[output_index + 2] = decoded_int.bitand(BOTTOM_BYTE) as u8; + output_index += 3; } general_purpose::decode_suffix::decode_suffix( input, - input_index, + full_bytes, output, output_index, &self.decode_table, @@ -180,31 +160,6 @@ impl Engine for Naive { } } -pub struct NaiveEstimate { - /// remainder from dividing input by `Naive::DECODE_CHUNK_SIZE` - rem: usize, - /// Length of input that is in complete `Naive::DECODE_CHUNK_SIZE`-length chunks - complete_chunk_len: usize, -} - -impl NaiveEstimate { - fn new(input_len: usize) -> Self { - let rem = input_len % Naive::DECODE_INPUT_CHUNK_SIZE; - let complete_chunk_len = input_len - rem; - - Self { - rem, - complete_chunk_len, - } - } -} - -impl DecodeEstimate for NaiveEstimate { - fn decoded_len_estimate(&self) -> usize { - ((self.complete_chunk_len / 4) + ((self.rem > 0) as usize)) * 3 - } -} - #[derive(Clone, Copy, Debug)] pub struct NaiveConfig { pub encode_padding: bool, diff --git a/src/engine/tests.rs b/src/engine/tests.rs index d2851ec..15c8c12 100644 --- a/src/engine/tests.rs +++ b/src/engine/tests.rs @@ -14,7 +14,7 @@ use crate::{ alphabet::{Alphabet, STANDARD}, encode::add_padding, encoded_len, - engine::{general_purpose, naive, Config, DecodeEstimate, DecodePaddingMode, Engine}, + engine::{general_purpose, naive, Config, DecodePaddingMode, Engine}, tests::{assert_encode_sanity, random_alphabet, random_config}, DecodeError, PAD_BYTE, }; @@ -1154,81 +1154,32 @@ fn decode_into_slice_fits_in_precisely_sized_slice(engine_wrap let mut rng = rngs::SmallRng::from_entropy(); for _ in 0..10_000 { - orig_data.clear(); - encoded_data.clear(); - decode_buf.clear(); - - let input_len = input_len_range.sample(&mut rng); - - for _ in 0..input_len { - orig_data.push(rng.gen()); - } + orig_data.resize(input_len_range.sample(&mut rng), 0); + rng.fill(&mut orig_data[..]); let engine = E::random(&mut rng); + encoded_data.clear(); engine.encode_string(&orig_data, &mut encoded_data); - assert_encode_sanity(&encoded_data, engine.config().encode_padding(), input_len); - - decode_buf.resize(input_len, 0); - - // decode into the non-empty buf - let decode_bytes_written = engine - .decode_slice_unchecked(encoded_data.as_bytes(), &mut decode_buf[..]) - .unwrap(); + assert_encode_sanity( + &encoded_data, + engine.config().encode_padding(), + orig_data.len(), + ); - assert_eq!(orig_data.len(), decode_bytes_written); + decode_buf.clear(); + decode_buf.resize(orig_data.len(), 0); + let res = engine.decode_slice(encoded_data.as_bytes(), &mut decode_buf[..]); + assert_eq!(Ok(orig_data.len()), res); assert_eq!(orig_data, decode_buf); - } -} -#[apply(all_engines)] -fn decode_length_estimate_delta(engine_wrapper: E) { - for engine in [E::standard(), E::standard_unpadded()] { - for &padding in &[true, false] { - for orig_len in 0..1000 { - let encoded_len = encoded_len(orig_len, padding).unwrap(); - - let decoded_estimate = engine - .internal_decoded_len_estimate(encoded_len) - .decoded_len_estimate(); - assert!(decoded_estimate >= orig_len); - assert!( - decoded_estimate - orig_len < 3, - "estimate: {}, encoded: {}, orig: {}", - decoded_estimate, - encoded_len, - orig_len - ); - } - } + decode_buf.clear(); + decode_buf.resize(orig_data.len(), 0); + let res = engine.decode_slice_unchecked(encoded_data.as_bytes(), &mut decode_buf[..]); + assert_eq!(Ok(orig_data.len()), res); + assert_eq!(orig_data, decode_buf); } } -#[apply(all_engines)] -fn estimate_via_u128_inflation(engine_wrapper: E) { - // cover both ends of usize - (0..1000) - .chain(usize::MAX - 1000..=usize::MAX) - .for_each(|encoded_len| { - // inflate to 128 bit type to be able to safely use the easy formulas - let len_128 = encoded_len as u128; - - let estimate = E::standard() - .internal_decoded_len_estimate(encoded_len) - .decoded_len_estimate(); - - // This check is a little too strict: it requires using the (len + 3) / 4 * 3 formula - // or equivalent, but until other engines come along that use a different formula - // requiring that we think more carefully about what the allowable criteria are, this - // will do. - assert_eq!( - ((len_128 + 3) / 4 * 3) as usize, - estimate, - "enc len {}", - encoded_len - ); - }) -} - /// Returns a tuple of the original data length, the encoded data length (just data), and the length including padding. /// /// Vecs provided should be empty. diff --git a/src/lib.rs b/src/lib.rs index cc9d628..2e89011 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -169,7 +169,9 @@ mod decode; #[cfg(any(feature = "alloc", feature = "std", test))] pub use crate::decode::{decode, decode_engine, decode_engine_vec}; #[allow(deprecated)] -pub use crate::decode::{decode_engine_slice, decoded_len_estimate, DecodeError, DecodeSliceError}; +pub use crate::decode::{ + decode_engine_slice, decoded_len, decoded_len_estimate, DecodeError, DecodeSliceError, +}; pub mod prelude; diff --git a/src/read/decoder.rs b/src/read/decoder.rs index 4888c9c..091c089 100644 --- a/src/read/decoder.rs +++ b/src/read/decoder.rs @@ -132,13 +132,10 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> { debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE); debug_assert!(!buf.is_empty()); - let decoded = self - .engine - .internal_decode( - &self.b64_buffer[self.b64_offset..self.b64_offset + num_bytes], - buf, - self.engine.internal_decoded_len_estimate(num_bytes), - ) + let input_bytes = &self.b64_buffer[self.b64_offset..self.b64_offset + num_bytes]; + let decoded = crate::decoded_len(input_bytes); + self.engine + .internal_decode(input_bytes, &mut buf[..decoded]) .map_err(|e| match e { DecodeError::InvalidByte(offset, byte) => { DecodeError::InvalidByte(self.total_b64_decoded + offset, byte)