From 6fbe921640c755ab59c3ba7f31a932945c202c8a Mon Sep 17 00:00:00 2001 From: Marshall Pierce <575695+marshallpierce@users.noreply.github.com> Date: Fri, 6 Aug 2021 13:53:57 -0600 Subject: [PATCH] Make constructing an Alphabet from a str `const`. Not useful yet since unwrap() and friends aren't const, but some future rust version can make use of it. --- src/alphabet.rs | 234 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 160 insertions(+), 74 deletions(-) diff --git a/src/alphabet.rs b/src/alphabet.rs index 650de43..834f402 100644 --- a/src/alphabet.rs +++ b/src/alphabet.rs @@ -1,164 +1,250 @@ //! Provides [Alphabet] and constants for alphabets commonly used in the wild. +#[cfg(any(feature = "std", test))] +use std::{ + convert, + fmt, + error, +}; + +const ALPHABET_SIZE: usize = 64; + /// An alphabet defines the 64 ASCII characters (symbols) used for base64. /// /// Common alphabets are provided as constants, and custom alphabets -/// can be made via the [From](#impl-From) implementation. +/// can be made via `from_str` or the `TryFrom` implementation. /// /// ``` -/// let custom = base64::alphabet::Alphabet::from("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"); +/// let custom = base64::alphabet::Alphabet::from_str("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/").unwrap(); /// /// let engine = base64::engine::fast_portable::FastPortable::from( /// &custom, /// base64::engine::fast_portable::PAD); /// ``` -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct Alphabet { - pub(crate) symbols: [u8; 64], + pub(crate) symbols: [u8; ALPHABET_SIZE], } impl Alphabet { /// Performs no checks so that it can be const. /// Used only for known-valid strings. - const fn from_unchecked(alphabet: &str) -> Alphabet { - let mut symbols = [0_u8; 64]; + const fn from_str_unchecked(alphabet: &str) -> Alphabet { + let mut symbols = [0_u8; ALPHABET_SIZE]; let source_bytes = alphabet.as_bytes(); // a way to copy that's allowed in const fn let mut index = 0; - while index < 64 { + while index < ALPHABET_SIZE { symbols[index] = source_bytes[index]; index += 1; } Alphabet { symbols } } -} -impl> From for Alphabet { - /// Create a `CharacterSet` from a string of 64 ASCII bytes. Each byte must be - /// unique, and the `=` byte is not allowed as it is used for padding. + /// Create a `CharacterSet` from a string of 64 unique printable ASCII bytes. /// - /// # Errors + /// The `=` byte is not allowed as it is used for padding. /// - /// Panics if the text is an invalid base64 alphabet since the alphabet is - /// likely to be hardcoded, and therefore errors are generally unrecoverable - /// programmer errors. - fn from(string: T) -> Self { - let alphabet = string.as_ref(); - assert_eq!( - 64, - alphabet.as_bytes().len(), - "Base64 char set length must be 64" - ); + /// The `const`-ness of this function isn't useful as of rust 1.54.0 since `const` `unwrap()`, + /// etc, haven't shipped yet, but that's [on the roadmap](https://github.com/rust-lang/rust/issues/85194). + pub const fn from_str(alphabet: &str) -> Result { + let bytes = alphabet.as_bytes(); + if bytes.len() != ALPHABET_SIZE { + return Err(ParseAlphabetError::InvalidLength); + } - // scope just to ensure not accidentally using the sorted copy { - // Check uniqueness without allocating since this must be no_std. - // Could pull in heapless and use IndexSet, but this seems simple enough. - let mut bytes = [0_u8; 64]; - alphabet - .as_bytes() - .iter() - .enumerate() - .for_each(|(index, &byte)| bytes[index] = byte); - - bytes.sort_unstable(); - - // iterate over the sorted bytes, offset by one - bytes.iter().zip(bytes[1..].iter()).for_each(|(b1, b2)| { - // if any byte is the same as the next byte, there's a duplicate - assert_ne!(b1, b2, "Duplicate bytes"); - }); + let mut index = 0; + while index < ALPHABET_SIZE { + let byte = bytes[index]; + + // must be ascii printable. 127 (DEL) is commonly considered printable + // for some reason but clearly unsuitable for base64. + if !(byte >= 32_u8 && byte <= 126_u8) { + return Err(ParseAlphabetError::UnprintableByte(byte)); + } + // = is assumed to be padding, so cannot be used as a symbol + if b'=' == byte { + return Err(ParseAlphabetError::ReservedByte(byte)); + } + + // Check for duplicates while staying within what const allows. + // It's n^2, but only over 64 hot bytes, and only once, so it's likely in the single digit + // microsecond range. + + let mut probe_index = 0; + while probe_index < ALPHABET_SIZE { + if probe_index == index { + probe_index += 1; + continue; + } + + let probe_byte = bytes[probe_index]; + + if byte == probe_byte { + return Err(ParseAlphabetError::DuplicatedByte(byte)); + } + + probe_index += 1; + } + + index += 1; + } } - for &byte in alphabet.as_bytes() { - // must be ascii printable. 127 (DEL) is commonly considered printable - // for some reason but clearly unsuitable for base64. - assert!(byte >= 32_u8 && byte < 127_u8, "Bytes must be printable"); - // = is assumed to be padding, so cannot be used as a symbol - assert_ne!(b'=', byte, "Padding byte '=' is reserved"); - } + Ok(Self::from_str_unchecked(alphabet)) + } +} + +#[cfg(any(feature = "std", test))] +impl convert::TryFrom<&str> for Alphabet { + type Error = ParseAlphabetError; + + fn try_from(value: &str) -> Result { + Alphabet::from_str(value) + } +} + +/// Possible errors when constructing an [Alphabet] from a `str`. +#[derive(Debug, Eq, PartialEq)] +pub enum ParseAlphabetError { + /// Alphabets must be 64 ASCII bytes + InvalidLength, + /// All bytes must be unique + DuplicatedByte(u8), + /// All bytes must be printable (in the range `[32, 126]`). + UnprintableByte(u8), + /// `=` cannot be used + ReservedByte(u8), +} - Self::from_unchecked(alphabet) +#[cfg(any(feature = "std", test))] +impl fmt::Display for ParseAlphabetError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ParseAlphabetError::InvalidLength => write!(f, "Invalid length - must be 64 bytes"), + ParseAlphabetError::DuplicatedByte(b) => write!(f, "Duplicated byte: {}", b), + ParseAlphabetError::UnprintableByte(b) => write!(f, "Unprintable byte: {}", b), + ParseAlphabetError::ReservedByte(b) => write!(f, "Reserved byte: {}", b) + } } } +#[cfg(any(feature = "std", test))] +impl error::Error for ParseAlphabetError {} + /// The standard alphabet (uses `+` and `/`). /// /// See [RFC 3548](https://tools.ietf.org/html/rfc3548#section-3). -pub const STANDARD: Alphabet = - Alphabet::from_unchecked("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"); +pub const STANDARD: Alphabet = Alphabet::from_str_unchecked( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/", +); /// The URL safe alphabet (uses `-` and `_`). /// /// See [RFC 3548](https://tools.ietf.org/html/rfc3548#section-4). -pub const URL_SAFE: Alphabet = - Alphabet::from_unchecked("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"); +pub const URL_SAFE: Alphabet = Alphabet::from_str_unchecked( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_", +); /// The `crypt(3)` alphabet (uses `.` and `/` as the first two values). /// /// Not standardized, but folk wisdom on the net asserts that this alphabet is what crypt uses. -pub const CRYPT: Alphabet = - Alphabet::from_unchecked("./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); +pub const CRYPT: Alphabet = Alphabet::from_str_unchecked( + "./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", +); /// The bcrypt alphabet. -pub const BCRYPT: Alphabet = - Alphabet::from_unchecked("./ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"); +pub const BCRYPT: Alphabet = Alphabet::from_str_unchecked( + "./ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789", +); /// The alphabet used in IMAP-modified UTF-7 (uses `+` and `,`). /// /// See [RFC 3501](https://tools.ietf.org/html/rfc3501#section-5.1.3) -pub const IMAP_MUTF7: Alphabet = - Alphabet::from_unchecked("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+,"); +pub const IMAP_MUTF7: Alphabet = Alphabet::from_str_unchecked( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+,", +); /// The alphabet used in BinHex 4.0 files. /// /// See [BinHex 4.0 Definition](http://files.stairways.com/other/binhex-40-specs-info.txt) -pub const BIN_HEX: Alphabet = - Alphabet::from_unchecked("!\"#$%&'()*+,-0123456789@ABCDEFGHIJKLMNPQRSTUVXYZ[`abcdehijklmpqr"); +pub const BIN_HEX: Alphabet = Alphabet::from_str_unchecked( + "!\"#$%&'()*+,-0123456789@ABCDEFGHIJKLMNPQRSTUVXYZ[`abcdehijklmpqr", +); #[cfg(test)] mod tests { - use crate::alphabet::Alphabet; + use crate::alphabet::*; + use std::convert::TryFrom as _; - #[should_panic(expected = "Duplicate bytes")] #[test] fn detects_duplicate_start() { - let _ = Alphabet::from("AACDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"); + assert_eq!( + ParseAlphabetError::DuplicatedByte(b'A'), + Alphabet::from_str("AACDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") + .unwrap_err() + ); } - #[should_panic(expected = "Duplicate bytes")] #[test] fn detects_duplicate_end() { - let _ = Alphabet::from("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789//"); + assert_eq!( + ParseAlphabetError::DuplicatedByte(b'/'), + Alphabet::from_str("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789//") + .unwrap_err() + ); } - #[should_panic(expected = "Duplicate bytes")] #[test] fn detects_duplicate_middle() { - let _ = Alphabet::from("ABCDEFGHIJKLMNOPQRSTUVWXYZZbcdefghijklmnopqrstuvwxyz0123456789+/"); + assert_eq!( + ParseAlphabetError::DuplicatedByte(b'Z'), + Alphabet::from_str("ABCDEFGHIJKLMNOPQRSTUVWXYZZbcdefghijklmnopqrstuvwxyz0123456789+/") + .unwrap_err() + ); } - #[should_panic(expected = "Base64 char set length must be 64")] #[test] fn detects_length() { - let _ = Alphabet::from( - "xxxxxxxxxABCDEFGHIJKLMNOPQRSTUVWXYZZbcdefghijklmnopqrstuvwxyz0123456789+/", + assert_eq!( + ParseAlphabetError::InvalidLength, + Alphabet::from_str( + "xxxxxxxxxABCDEFGHIJKLMNOPQRSTUVWXYZZbcdefghijklmnopqrstuvwxyz0123456789+/", + ) + .unwrap_err() ); } - #[should_panic(expected = "Padding byte '=' is reserved")] #[test] fn detects_padding() { - let _ = Alphabet::from("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+="); + assert_eq!( + ParseAlphabetError::ReservedByte(b'='), + Alphabet::from_str("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+=") + .unwrap_err() + ); } - #[should_panic(expected = "Bytes must be printable")] #[test] fn detects_unprintable() { // form feed - let _ = - Alphabet::from("\x0cBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"); + assert_eq!( + ParseAlphabetError::UnprintableByte(0xc), + Alphabet::from_str( + "\x0cBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + ) + .unwrap_err() + ); + } + + #[test] + fn same_as_unchecked() { + assert_eq!( + STANDARD, + Alphabet::try_from("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/") + .unwrap() + ) } }