diff --git a/subtle-encoding/src/base64.rs b/subtle-encoding/src/base64.rs index 0c7ad4f0..69fc07b2 100644 --- a/subtle-encoding/src/base64.rs +++ b/subtle-encoding/src/base64.rs @@ -9,7 +9,7 @@ use super::{ Encoding, - Error::{self, EncodingInvalid}, + Error::{self, *}, }; #[cfg(feature = "alloc")] use crate::prelude::*; @@ -41,6 +41,10 @@ pub fn decode>(encoded_bytes: B) -> Result, Error> { impl Encoding for Base64 { fn encode_to_slice(&self, src: &[u8], dst: &mut [u8]) -> Result { + if self.encoded_len(src) > dst.len() { + return Err(LengthInvalid); + } + let mut src_offset: usize = 0; let mut dst_offset: usize = 0; let mut src_length: usize = src.len(); @@ -79,6 +83,13 @@ impl Encoding for Base64 { } fn decode_to_slice(&self, src: &[u8], dst: &mut [u8]) -> Result { + // TODO: constant-time whitespace tolerance + if !src.is_empty() && char::from(src[src.len() - 1]).is_whitespace() { + return Err(TrailingWhitespace); + } + + ensure!(self.decoded_len(src)? <= dst.len(), LengthInvalid); + let mut src_offset: usize = 0; let mut dst_offset: usize = 0; let mut src_length: usize = src.len(); @@ -250,29 +261,28 @@ mod tests { raw: b"\xFF\xFF\xFF\xFF\xFF", base64: b"//////8=", }, + Base64Vector { + raw: b"\x40\xC1\x3F\xBD\x05\x4C\x72\x2A\xA3\xC2\xF2\x11\x73\xC0\x69\xEA\ + \x49\x7D\x35\x29\x6B\xCC\x24\x65\xF6\xF9\xD0\x41\x08\x7B\xD7\xA9", + base64: b"QME/vQVMciqjwvIRc8Bp6kl9NSlrzCRl9vnQQQh716k=", + }, ]; #[test] fn encode_test_vectors() { for vector in BASE64_TEST_VECTORS { - // 8 is the size of the largest encoded test vector - let mut out = [0u8; 8]; - let out_len = encoder().encode_to_slice(vector.raw, &mut out).unwrap(); - - assert_eq!(encoder().encoded_len(vector.raw), out_len); - assert_eq!(vector.base64, &out[..out_len]); + let out = encoder().encode(vector.raw); + assert_eq!(encoder().encoded_len(vector.raw), out.len()); + assert_eq!(vector.base64, &out[..]); } } #[test] fn decode_test_vectors() { for vector in BASE64_TEST_VECTORS { - // 5 is the size of the largest decoded test vector - let mut out = [0u8; 5]; - let out_len = encoder().decode_to_slice(vector.base64, &mut out).unwrap(); - - assert_eq!(encoder().decoded_len(vector.base64).unwrap(), out_len); - assert_eq!(vector.raw, &out[..out_len]); + let out = encoder().decode(vector.base64).unwrap(); + assert_eq!(encoder().decoded_len(vector.base64).unwrap(), out.len()); + assert_eq!(vector.raw, &out[..]); } } @@ -289,4 +299,12 @@ mod tests { assert_eq!(decoded.as_slice(), &data[..i]); } } + + #[test] + fn trailing_whitespace() { + assert_eq!( + encoder().decode(&b"QME/vQVMciqjwvIRc8Bp6kl9NSlrzCRl9vnQQQh716k=\n"[..]), + Err(TrailingWhitespace) + ); + } } diff --git a/subtle-encoding/src/bech32/base32.rs b/subtle-encoding/src/bech32/base32.rs index 6c7b1443..3f8ffd9e 100644 --- a/subtle-encoding/src/bech32/base32.rs +++ b/subtle-encoding/src/bech32/base32.rs @@ -1,6 +1,7 @@ -use crate::prelude::*; - -use super::Error; +use crate::{ + error::Error::{self, *}, + prelude::*, +}; /// Encode binary data as base32 pub fn encode(data: &[u8]) -> Vec { @@ -20,10 +21,7 @@ fn convert(data: &[u8], src_base: u32, dst_base: u32) -> Result, Error> for value in data { let v = u32::from(*value); - - if (v >> src_base) != 0 { - return Err(Error::EncodingInvalid); - } + ensure!(v >> src_base == 0, EncodingInvalid); acc = (acc << src_base) | v; bits += src_base; @@ -39,7 +37,7 @@ fn convert(data: &[u8], src_base: u32, dst_base: u32) -> Result, Error> result.push(((acc << (dst_base - bits)) & max) as u8); } } else if bits >= src_base || ((acc << (dst_base - bits)) & max) != 0 { - return Err(Error::PaddingInvalid); + return Err(PaddingInvalid); } Ok(result) diff --git a/subtle-encoding/src/bech32/mod.rs b/subtle-encoding/src/bech32/mod.rs index a1ee48ad..90aaceae 100644 --- a/subtle-encoding/src/bech32/mod.rs +++ b/subtle-encoding/src/bech32/mod.rs @@ -12,8 +12,10 @@ mod base32; mod checksum; use self::checksum::{Checksum, CHECKSUM_SIZE}; -use crate::error::Error; -use crate::prelude::*; +use crate::{ + error::Error::{self, *}, + prelude::*, +}; /// Default separator character pub const DEFAULT_SEPARATOR: char = '1'; @@ -168,37 +170,45 @@ impl Bech32 { let encoded_str = encoded.as_ref(); let encoded_len: usize = encoded_str.len(); - // TODO: support for longer strings - if encoded_len > MAX_LENGTH { - return Err(Error::LengthInvalid); + // TODO: constant-time whitespace tolerance + if encoded_str + .chars() + .last() + .map(|c| c.is_whitespace()) + .unwrap_or(false) + { + return Err(TrailingWhitespace); } + // TODO: support for longer strings + ensure!(encoded_len <= MAX_LENGTH, LengthInvalid); + let pos = encoded_str .rfind(self.separator) - .ok_or_else(|| Error::EncodingInvalid)?; + .ok_or_else(|| EncodingInvalid)?; if pos == encoded_str.len() { - return Err(Error::EncodingInvalid); + return Err(EncodingInvalid); } let hrp = encoded_str[..pos].to_lowercase(); if hrp.is_empty() { - return Err(Error::EncodingInvalid); + return Err(EncodingInvalid); } // Ensure all characters in the human readable part are in a valid range for c in hrp.chars() { match c { '!'...'@' | 'A'...'Z' | '['...'`' | 'a'...'z' | '{'...'~' => (), - _ => return Err(Error::EncodingInvalid), + _ => return Err(EncodingInvalid), } } let encoded_data = &encoded_str[(pos + 1)..]; if encoded_data.len() < CHECKSUM_SIZE { - return Err(Error::LengthInvalid); + return Err(LengthInvalid); } let mut base32_data = Vec::with_capacity(encoded_data.len()); @@ -208,7 +218,7 @@ impl Bech32 { .charset_inverse .get(encoded_byte as usize) .and_then(|byte| *byte) - .ok_or_else(|| Error::EncodingInvalid)?; + .ok_or_else(|| EncodingInvalid)?; base32_data.push(decoded_byte); } @@ -323,24 +333,21 @@ mod tests { #[test] fn hrp_character_out_of_range() { let bech32 = Bech32::default(); - assert_eq!(bech32.decode("\x201nwldj5"), Err(Error::EncodingInvalid)); - assert_eq!(bech32.decode("\x7F1axkwrx"), Err(Error::EncodingInvalid)); + assert_eq!(bech32.decode("\x201nwldj5"), Err(EncodingInvalid)); + assert_eq!(bech32.decode("\x7F1axkwrx"), Err(EncodingInvalid)); } #[test] fn overall_max_length_exceeded() { let too_long: &str = "an84characterslonghumanreadablepartthatcontainsthenumber1andtheexcludedcharactersbio1569pvx"; - assert_eq!( - Bech32::default().decode(too_long), - Err(Error::LengthInvalid) - ); + assert_eq!(Bech32::default().decode(too_long), Err(LengthInvalid)); } #[test] fn no_separator_character() { assert_eq!( Bech32::default().decode("pzry9x0s0muk"), - Err(Error::EncodingInvalid) + Err(EncodingInvalid) ); } @@ -349,32 +356,26 @@ mod tests { for empty_hrp_str in &["1pzry9x0s0muk", "10a06t8", "1qzzfhee"] { assert_eq!( Bech32::default().decode(empty_hrp_str), - Err(Error::EncodingInvalid) + Err(EncodingInvalid) ); } } #[test] fn invalid_data_character() { - assert_eq!( - Bech32::default().decode("x1b4n0q5v"), - Err(Error::EncodingInvalid) - ); + assert_eq!(Bech32::default().decode("x1b4n0q5v"), Err(EncodingInvalid)); } #[test] fn checksum_too_short() { - assert_eq!( - Bech32::default().decode("li1dgmt3"), - Err(Error::LengthInvalid) - ); + assert_eq!(Bech32::default().decode("li1dgmt3"), Err(LengthInvalid)); } #[test] fn invalid_character_in_checksum() { assert_eq!( Bech32::default().decode("de1lg7wt\x7F"), - Err(Error::EncodingInvalid) + Err(EncodingInvalid) ); } @@ -382,16 +383,13 @@ mod tests { fn checksum_calculated_with_uppercase_hrp() { assert_eq!( Bech32::upper_case().decode("A1G7SGD8"), - Err(Error::ChecksumInvalid) + Err(ChecksumInvalid) ); } // NOTE: not in test vectors but worth testing for anyway #[test] fn invalid_mixed_case() { - assert_eq!( - Bech32::default().decode("a12UEL5L"), - Err(Error::EncodingInvalid) - ); + assert_eq!(Bech32::default().decode("a12UEL5L"), Err(EncodingInvalid)); } } diff --git a/subtle-encoding/src/error.rs b/subtle-encoding/src/error.rs index 1a73236b..2ba56594 100644 --- a/subtle-encoding/src/error.rs +++ b/subtle-encoding/src/error.rs @@ -24,6 +24,11 @@ pub enum Error { /// Padding missing/invalid #[fail(display = "padding invalid")] PaddingInvalid, + + /// Trailing whitespace detected + // TODO: handle trailing whitespace? + #[fail(display = "trailing whitespace")] + TrailingWhitespace, } /// Assert that the provided condition is true, or else return the given error diff --git a/subtle-encoding/src/hex.rs b/subtle-encoding/src/hex.rs index 58db216c..e1b69012 100644 --- a/subtle-encoding/src/hex.rs +++ b/subtle-encoding/src/hex.rs @@ -41,7 +41,7 @@ use super::{ Encoding, - Error::{self, EncodingInvalid, LengthInvalid}, + Error::{self, *}, }; #[cfg(feature = "alloc")] use crate::prelude::*; @@ -93,11 +93,16 @@ impl Hex { impl Encoding for Hex { fn encode_to_slice(&self, src: &[u8], dst: &mut [u8]) -> Result { + if self.encoded_len(src) > dst.len() { + return Err(LengthInvalid); + } + for (i, src_byte) in src.iter().enumerate() { let offset = i * 2; dst[offset] = self.case.encode_nibble(src_byte >> 4); dst[offset + 1] = self.case.encode_nibble(src_byte & 0x0f); } + Ok(src.len() * 2) } @@ -106,6 +111,11 @@ impl Encoding for Hex { } fn decode_to_slice(&self, src: &[u8], dst: &mut [u8]) -> Result { + // TODO: constant-time whitespace tolerance + if !src.is_empty() && char::from(src[src.len() - 1]).is_whitespace() { + return Err(TrailingWhitespace); + } + let dst_length = self.decoded_len(src)?; ensure!(dst_length <= dst.len(), LengthInvalid);