diff --git a/Cargo.toml b/Cargo.toml index f8ffe16..1398e61 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,5 @@ repository = "https://github.com/mozilla/rust-cascade" [dependencies] byteorder="1.3.1" -digest="0.8.0" murmurhash3 = "0.0.5" -sha2="^0.8" +sha2="^0.10.2" diff --git a/src/lib.rs b/src/lib.rs index 0c96b36..78f21fc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,8 @@ extern crate byteorder; -extern crate digest; extern crate murmurhash3; extern crate sha2; -use byteorder::ReadBytesExt; +use byteorder::{ByteOrder, LittleEndian, ReadBytesExt}; use murmurhash3::murmurhash3_x86_32; use sha2::{Digest, Sha256}; use std::convert::{TryFrom, TryInto}; @@ -88,7 +87,8 @@ struct Bloom { /// https://github.com/mozilla/filter-cascade/blob/v0.3.0/filtercascade/fileformats.py enum HashAlgorithm { MurmurHash3 = 1, - Sha256 = 2, + Sha256l32 = 2, // low 32 bits of sha256 + Sha256 = 3, // all 256 bits of sha256 } impl fmt::Display for HashAlgorithm { @@ -103,12 +103,147 @@ impl TryFrom for HashAlgorithm { match value { // Naturally, these need to match the enum declaration 1 => Ok(Self::MurmurHash3), - 2 => Ok(Self::Sha256), + 2 => Ok(Self::Sha256l32), + 3 => Ok(Self::Sha256), _ => Err(()), } } } +/// A CascadeIndexGenerator provides read-once access to a table +/// of numbers H_ij with 0 <= H_ij < r_i. +/// +/// A call to next_layer(r) increments i and sets r_i = r. +/// A call to next_index() increments j and outputs H_ij. +/// +trait CascadeIndexGenerator { + fn next_layer(&mut self, size: u32); + fn next_index(&mut self) -> usize; +} + +struct MurmurHash3IndexGenerator<'a> { + key: &'a [u8], + counter: u32, + depth: u32, + range: u32, +} + +impl<'a> MurmurHash3IndexGenerator<'a> { + fn new(key: &'a [u8], top_layer_size: u32) -> Self { + MurmurHash3IndexGenerator { + key, + counter: 0, + depth: 1, + range: top_layer_size, + } + } +} + +impl<'a> CascadeIndexGenerator for MurmurHash3IndexGenerator<'a> { + fn next_index(&mut self) -> usize { + let hash_seed = (self.counter << 16) + self.depth; + self.counter += 1; + let index = murmurhash3_x86_32(self.key, hash_seed); + (index % self.range) as usize + } + fn next_layer(&mut self, size: u32) { + self.counter = 0; + self.depth += 1; + self.range = size; + } +} + +struct SHA256l32IndexGenerator<'a> { + salt: &'a [u8], + key: &'a [u8], + counter: u32, + depth: u8, + range: u32, +} + +impl<'a> SHA256l32IndexGenerator<'a> { + fn new(salt: &'a [u8], key: &'a [u8], top_layer_size: u32) -> Self { + SHA256l32IndexGenerator { + salt, + key, + counter: 0, + depth: 1, + range: top_layer_size, + } + } +} + +impl<'a> CascadeIndexGenerator for SHA256l32IndexGenerator<'a> { + fn next_index(&mut self) -> usize { + let mut hasher = Sha256::new(); + hasher.update(self.salt); + hasher.update(self.counter.to_le_bytes()); + hasher.update(self.depth.to_le_bytes()); + hasher.update(self.key); + self.counter += 1; + let index = u32::from_le_bytes( + hasher.finalize()[0..4] + .try_into() + .expect("sha256 should have given enough bytes"), + ); + (index % self.range) as usize + } + fn next_layer(&mut self, size: u32) { + self.counter = 0; + self.depth += 1; + self.range = size; + } +} + +struct SHA256CtrIndexGenerator<'a> { + salt: &'a [u8], + key: &'a [u8], + counter: u32, + range: u32, + state: [u8; 32], + state_available: usize, +} + +impl<'a> SHA256CtrIndexGenerator<'a> { + fn new(salt: &'a [u8], key: &'a [u8], top_layer_size: u32) -> Self { + SHA256CtrIndexGenerator { + salt, + key, + counter: 0, + range: top_layer_size, + state: [0; 32], + state_available: 0, + } + } +} + +impl<'a> CascadeIndexGenerator for SHA256CtrIndexGenerator<'a> { + fn next_index(&mut self) -> usize { + // |bytes_needed| is the minimum number of bytes needed to represent a value in [0, range). + let bytes_needed = ((self.range.next_power_of_two().trailing_zeros() + 7) / 8) as usize; + let mut index_arr = [0u8; 4]; + for byte in index_arr.iter_mut().take(bytes_needed) { + if self.state_available == 0 { + let mut hasher = Sha256::new(); + hasher.update(self.counter.to_le_bytes()); + hasher.update(self.salt); + hasher.update(self.key); + let digest = &hasher.finalize()[..]; + self.state.copy_from_slice(digest); + self.state_available = 32; + self.counter += 1; + } + *byte = self.state[32 - self.state_available]; + self.state_available -= 1; + } + let index = LittleEndian::read_u32(&index_arr); + (index % self.range) as usize + } + fn next_layer(&mut self, size: u32) { + self.range = size; + } +} + impl Bloom { /// Attempts to decode the Bloom filter represented by the bytes in the given reader. /// @@ -145,12 +280,7 @@ impl Bloom { let n_hash_funcs = reader.read_u32::()?; let level = reader.read_u8()?; - let shifted_size = size.wrapping_shr(3) as usize; - let byte_count = if size % 8 != 0 { - shifted_size + 1 - } else { - shifted_size - }; + let byte_count = ((size + 7) / 8) as usize; let mut bits_bytes = vec![0; byte_count]; reader.read_exact(&mut bits_bytes)?; let bloom = Bloom { @@ -163,40 +293,13 @@ impl Bloom { Ok(Some(bloom)) } - fn hash(&self, n_fn: u32, key: &[u8], salt: Option<&Vec>) -> u32 { - match self.hash_algorithm { - HashAlgorithm::MurmurHash3 => { - if salt.is_some() { - panic!("murmur does not support salts") - } - let hash_seed = (n_fn << 16) + self.level as u32; - murmurhash3_x86_32(key, hash_seed) % self.size - } - HashAlgorithm::Sha256 => { - let mut hasher = Sha256::new(); - if let Some(salt_bytes) = salt { - hasher.input(salt_bytes) - } - hasher.input(n_fn.to_le_bytes()); - hasher.input(self.level.to_le_bytes()); - hasher.input(key); - - u32::from_le_bytes( - hasher.result()[0..4] - .try_into() - .expect("sha256 should have given enough bytes"), - ) % self.size - } - } - } - /// Test for the presence of a given sequence of bytes in this Bloom filter. /// /// # Arguments /// `item` - The slice of bytes to test for - fn has(&self, item: &[u8], salt: Option<&Vec>) -> bool { - for i in 0..self.n_hash_funcs { - if !self.bit_vector.get(self.hash(i, item, salt) as usize) { + fn has(&self, generator: &mut T) -> bool { + for _ in 0..self.n_hash_funcs { + if !self.bit_vector.get(generator.next_index()) { return false; } } @@ -221,7 +324,7 @@ pub struct Cascade { /// The next (lower) level in the cascade child_layer: Option>, /// The salt in use, if any - salt: Option>, + salt: Vec, /// Whether the logic should be inverted inverted: bool, } @@ -242,7 +345,7 @@ impl Cascade { } let mut reader = bytes.as_slice(); let version = reader.read_u16::()?; - let mut salt = None; + let mut salt = vec![]; let mut inverted = false; if version >= 2 { @@ -251,7 +354,7 @@ impl Cascade { if salt_len > 0 { let mut salt_bytes = vec![0; salt_len]; reader.read_exact(&mut salt_bytes)?; - salt = Some(salt_bytes); + salt.extend_from_slice(&salt_bytes); } } @@ -262,23 +365,43 @@ impl Cascade { )); } - Cascade::child_layer_from_bytes(reader, salt, inverted) + let top_layer = Cascade::child_layer_from_bytes(reader, &salt, inverted)?; + if let Some(ref c) = top_layer { + if c.filter.level != 1 { + return Err(Error::new( + ErrorKind::InvalidData, + format!("Top layer index {} != 1", c.filter.level), + )); + } + } + Ok(top_layer) } fn child_layer_from_bytes( mut reader: R, - salt: Option>, + salt: &[u8], inverted: bool, ) -> Result>, Error> { let filter = match Bloom::read(&mut reader)? { Some(filter) => filter, None => return Ok(None), }; - let our_salt = salt.as_ref().cloned(); + let child_layer = Cascade::child_layer_from_bytes(reader, salt, inverted)?; + if let Some(ref c) = child_layer { + if c.filter.level != filter.level + 1 { + return Err(Error::new( + ErrorKind::InvalidData, + format!( + "Irregular layer numbering: {} followed by {}", + filter.level, c.filter.level + ), + )); + } + } Ok(Some(Box::new(Cascade { filter, - child_layer: Cascade::child_layer_from_bytes(reader, salt, inverted)?, - salt: our_salt, + child_layer, + salt: salt.to_vec(), inverted, }))) } @@ -288,18 +411,34 @@ impl Cascade { /// # Arguments /// `entry` - The slice of bytes to test for pub fn has(&self, entry: &[u8]) -> bool { - let result = self.has_internal(entry); + let result = match self.filter.hash_algorithm { + HashAlgorithm::MurmurHash3 => { + assert!(self.salt.is_empty()); + self.has_internal(&mut MurmurHash3IndexGenerator::new(entry, self.filter.size)) + } + HashAlgorithm::Sha256l32 => self.has_internal(&mut SHA256l32IndexGenerator::new( + &self.salt, + entry, + self.filter.size, + )), + HashAlgorithm::Sha256 => self.has_internal(&mut SHA256CtrIndexGenerator::new( + &self.salt, + entry, + self.filter.size, + )), + }; if self.inverted { return !result; } result } - fn has_internal(&self, entry: &[u8]) -> bool { - if self.filter.has(entry, self.salt.as_ref()) { + fn has_internal(&self, generator: &mut T) -> bool { + if self.filter.has(generator) { match self.child_layer { Some(ref child) => { - let child_value = !child.has_internal(entry); + generator.next_layer(child.filter.size); + let child_value = !child.has_internal(generator); return child_value; } None => { @@ -323,7 +462,7 @@ impl Cascade { .child_layer .as_ref() .map_or(0, |child_layer| child_layer.approximate_size_of()) - + self.salt.as_ref().map_or(0, |salt| salt.len()) + + self.salt.len() } } @@ -345,6 +484,7 @@ impl fmt::Display for Cascade { mod tests { use Bloom; use Cascade; + use MurmurHash3IndexGenerator; #[test] fn bloom_v1_test_from_bytes() { @@ -355,9 +495,18 @@ mod tests { match Bloom::read(&mut reader) { Ok(Some(bloom)) => { - assert!(bloom.has(b"this", None) == true); - assert!(bloom.has(b"that", None) == true); - assert!(bloom.has(b"other", None) == false); + assert!(bloom.has(&mut MurmurHash3IndexGenerator::new( + b"this".as_ref(), + bloom.size + ))); + assert!(bloom.has(&mut MurmurHash3IndexGenerator::new( + b"that".as_ref(), + bloom.size + ))); + assert!(!bloom.has(&mut MurmurHash3IndexGenerator::new( + b"other".as_ref(), + bloom.size + ))); } Ok(None) => panic!("Parsing failed"), Err(_) => panic!("Parsing failed"), @@ -421,33 +570,33 @@ mod tests { } #[test] - fn cascade_v2_sha256_from_file_bytes_test() { - let v = include_bytes!("../test_data/test_v2_sha256_mlbf").to_vec(); + fn cascade_v2_sha256l32_from_file_bytes_test() { + let v = include_bytes!("../test_data/test_v2_sha256l32_mlbf").to_vec(); let cascade = Cascade::from_bytes(v) .expect("parsing Cascade should succeed") .expect("Cascade should be Some"); - assert!(cascade.salt == None); + assert!(cascade.salt.len() == 0); assert!(cascade.inverted == false); assert!(cascade.has(b"this") == true); assert!(cascade.has(b"that") == true); assert!(cascade.has(b"other") == false); - assert_eq!(cascade.approximate_size_of(), 10247); + assert_eq!(cascade.approximate_size_of(), 128314); } #[test] - fn cascade_v2_sha256_with_salt_from_file_bytes_test() { - let v = include_bytes!("../test_data/test_v2_sha256_salt_mlbf").to_vec(); + fn cascade_v2_sha256l32_with_salt_from_file_bytes_test() { + let v = include_bytes!("../test_data/test_v2_sha256l32_salt_mlbf").to_vec(); let cascade = Cascade::from_bytes(v) .expect("parsing Cascade should succeed") .expect("Cascade should be Some"); - assert!(cascade.salt == Some(b"nacl".to_vec())); + assert!(cascade.salt == b"nacl".to_vec()); assert!(cascade.inverted == false); assert!(cascade.has(b"this") == true); assert!(cascade.has(b"that") == true); assert!(cascade.has(b"other") == false); - assert_eq!(cascade.approximate_size_of(), 10251); + assert_eq!(cascade.approximate_size_of(), 128321); } #[test] @@ -457,12 +606,12 @@ mod tests { .expect("parsing Cascade should succeed") .expect("Cascade should be Some"); - assert!(cascade.salt == None); + assert!(cascade.salt.len() == 0); assert!(cascade.inverted == false); assert!(cascade.has(b"this") == true); assert!(cascade.has(b"that") == true); assert!(cascade.has(b"other") == false); - assert_eq!(cascade.approximate_size_of(), 10247); + assert_eq!(cascade.approximate_size_of(), 127914); } #[test] @@ -472,27 +621,42 @@ mod tests { .expect("parsing Cascade should succeed") .expect("Cascade should be Some"); - assert!(cascade.salt == None); + assert!(cascade.salt.len() == 0); assert!(cascade.inverted == true); assert!(cascade.has(b"this") == true); assert!(cascade.has(b"that") == true); assert!(cascade.has(b"other") == false); - assert_eq!(cascade.approximate_size_of(), 10247); + assert_eq!(cascade.approximate_size_of(), 128113); } #[test] - fn cascade_v2_sha256_inverted_from_file_bytes_test() { - let v = include_bytes!("../test_data/test_v2_sha256_inverted_mlbf").to_vec(); + fn cascade_v2_sha256l32_inverted_from_file_bytes_test() { + let v = include_bytes!("../test_data/test_v2_sha256l32_inverted_mlbf").to_vec(); let cascade = Cascade::from_bytes(v) .expect("parsing Cascade should succeed") .expect("Cascade should be Some"); - assert!(cascade.salt == None); + assert!(cascade.salt.len() == 0); assert!(cascade.inverted == true); assert!(cascade.has(b"this") == true); assert!(cascade.has(b"that") == true); assert!(cascade.has(b"other") == false); - assert_eq!(cascade.approximate_size_of(), 10247); + assert_eq!(cascade.approximate_size_of(), 128165); + } + + #[test] + fn cascade_v2_sha256ctr_from_file_bytes_test() { + let v = include_bytes!("../test_data/test_v2_sha256ctr_salt_mlbf").to_vec(); + let cascade = Cascade::from_bytes(v) + .expect("parsing Cascade should succeed") + .expect("Cascade should be Some"); + + assert!(cascade.salt == b"nacl".to_vec()); + assert!(cascade.inverted == false); + assert!(cascade.has(b"this") == true); + assert!(cascade.has(b"that") == true); + assert!(cascade.has(b"other") == false); + assert_eq!(cascade.approximate_size_of(), 128510); } #[test] @@ -500,4 +664,30 @@ mod tests { let cascade = Cascade::from_bytes(Vec::new()).expect("parsing Cascade should succeed"); assert!(cascade.is_none()); } + + #[test] + fn cascade_test_from_bytes() { + let unknown_version: Vec = vec![0xff, 0xff, 0x00, 0x00]; + match Cascade::from_bytes(unknown_version) { + Ok(_) => panic!("Cascade::from_bytes allows unknown version."), + Err(_) => (), + } + + let first_layer_is_zero: Vec = vec![ + 0x01, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + match Cascade::from_bytes(first_layer_is_zero) { + Ok(_) => panic!("Cascade::from_bytes allows zero indexed layers."), + Err(_) => (), + } + + let second_layer_is_three: Vec = vec![ + 0x01, 0x00, 0x01, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, + 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, + ]; + match Cascade::from_bytes(second_layer_is_three) { + Ok(_) => panic!("Cascade::from_bytes allows non-sequential layers."), + Err(_) => (), + } + } } diff --git a/test_data/make-sample-data.py b/test_data/make-sample-data.py index bbb73ec..bfd8987 100644 --- a/test_data/make-sample-data.py +++ b/test_data/make-sample-data.py @@ -2,9 +2,14 @@ import hashlib from pathlib import Path +import sys +import logging -def predictable_serial_gen(end): - counter = 0 +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) + + +def predictable_serial_gen(start, end): + counter = start while counter < end: counter += 1 m = hashlib.sha256() @@ -19,41 +24,83 @@ def store(fc, path): fc.tofile(f) -large_set = set(predictable_serial_gen(100_000)) +small_set = list(set(predictable_serial_gen(0, 100_000))) +large_set = set(predictable_serial_gen(100_000, 1_000_000)) -v2_sha256_with_salt = filtercascade.FilterCascade( - [], defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256, salt=b"nacl" +# filter parameters +growth_factor = 1.0 +min_filter_length = 177 # 177 * 1.44 ~ 256, so smallest filter will have 256 bits + +print("--- v2_sha256l32_with_salt ---") +v2_sha256l32_with_salt = filtercascade.FilterCascade( + [], + defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256, + salt=b"nacl", + growth_factor=growth_factor, + min_filter_length=min_filter_length, ) -v2_sha256_with_salt.initialize( - include=[b"this", b"that"], exclude=large_set | set([b"other"]) +v2_sha256l32_with_salt.initialize( + include=[b"this", b"that"] + small_set, exclude=large_set | set([b"other"]) ) -store(v2_sha256_with_salt, Path("test_v2_sha256_salt_mlbf")) +store(v2_sha256l32_with_salt, Path("test_v2_sha256l32_salt_mlbf")) -v2_sha256 = filtercascade.FilterCascade( - [], defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256 +print("--- v2_sha256l32 ---") +v2_sha256l32 = filtercascade.FilterCascade( + [], + defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256, + growth_factor=growth_factor, + min_filter_length=min_filter_length, +) +v2_sha256l32.initialize( + include=[b"this", b"that"] + small_set, exclude=large_set | set([b"other"]) ) -v2_sha256.initialize(include=[b"this", b"that"], exclude=large_set | set([b"other"])) -store(v2_sha256, Path("test_v2_sha256_mlbf")) +store(v2_sha256l32, Path("test_v2_sha256l32_mlbf")) +print("--- v2_murmur ---") v2_murmur = filtercascade.FilterCascade( - [], defaultHashAlg=filtercascade.fileformats.HashAlgorithm.MURMUR3 + [], + defaultHashAlg=filtercascade.fileformats.HashAlgorithm.MURMUR3, + growth_factor=growth_factor, + min_filter_length=min_filter_length, +) +v2_murmur.initialize( + include=[b"this", b"that"] + small_set, exclude=large_set | set([b"other"]) ) -v2_murmur.initialize(include=[b"this", b"that"], exclude=large_set | set([b"other"])) store(v2_murmur, Path("test_v2_murmur_mlbf")) +print("--- v2_murmur_inverted ---") v2_murmur_inverted = filtercascade.FilterCascade( - [], defaultHashAlg=filtercascade.fileformats.HashAlgorithm.MURMUR3 + [], + defaultHashAlg=filtercascade.fileformats.HashAlgorithm.MURMUR3, + growth_factor=growth_factor, + min_filter_length=min_filter_length, ) v2_murmur_inverted.initialize( - include=large_set | set([b"this", b"that"]), exclude=[b"other"] + include=large_set | set([b"this", b"that"]), exclude=[b"other"] + small_set ) store(v2_murmur_inverted, Path("test_v2_murmur_inverted_mlbf")) +print("--- v2_sha256l32_inverted ---") +v2_sha256l32_inverted = filtercascade.FilterCascade( + [], + defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256, + growth_factor=growth_factor, + min_filter_length=min_filter_length, +) +v2_sha256l32_inverted.initialize( + include=large_set | set([b"this", b"that"]), exclude=[b"other"] + small_set +) +store(v2_sha256l32_inverted, Path("test_v2_sha256l32_inverted_mlbf")) -v2_sha256_inverted = filtercascade.FilterCascade( - [], defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256 +print("--- v2_sha256ctr_with_salt ---") +v2_sha256ctr_with_salt = filtercascade.FilterCascade( + [], + defaultHashAlg=filtercascade.fileformats.HashAlgorithm.SHA256CTR, + salt=b"nacl", + growth_factor=growth_factor, + min_filter_length=min_filter_length, ) -v2_sha256_inverted.initialize( - include=large_set | set([b"this", b"that"]), exclude=[b"other"] +v2_sha256ctr_with_salt.initialize( + include=[b"this", b"that"] + small_set, exclude=large_set | set([b"other"]) ) -store(v2_sha256_inverted, Path("test_v2_sha256_inverted_mlbf")) +store(v2_sha256ctr_with_salt, Path("test_v2_sha256ctr_salt_mlbf")) diff --git a/test_data/test_v2_murmur_inverted_mlbf b/test_data/test_v2_murmur_inverted_mlbf index 0c0aecd..3f66409 100644 Binary files a/test_data/test_v2_murmur_inverted_mlbf and b/test_data/test_v2_murmur_inverted_mlbf differ diff --git a/test_data/test_v2_murmur_mlbf b/test_data/test_v2_murmur_mlbf index f994ac7..91aa602 100644 Binary files a/test_data/test_v2_murmur_mlbf and b/test_data/test_v2_murmur_mlbf differ diff --git a/test_data/test_v2_sha256_inverted_mlbf b/test_data/test_v2_sha256_inverted_mlbf deleted file mode 100644 index 3e1e7c1..0000000 Binary files a/test_data/test_v2_sha256_inverted_mlbf and /dev/null differ diff --git a/test_data/test_v2_sha256_mlbf b/test_data/test_v2_sha256_mlbf deleted file mode 100644 index e662a32..0000000 Binary files a/test_data/test_v2_sha256_mlbf and /dev/null differ diff --git a/test_data/test_v2_sha256_salt_mlbf b/test_data/test_v2_sha256_salt_mlbf deleted file mode 100644 index 330c487..0000000 Binary files a/test_data/test_v2_sha256_salt_mlbf and /dev/null differ diff --git a/test_data/test_v2_sha256ctr_salt_mlbf b/test_data/test_v2_sha256ctr_salt_mlbf new file mode 100644 index 0000000..08b425d Binary files /dev/null and b/test_data/test_v2_sha256ctr_salt_mlbf differ diff --git a/test_data/test_v2_sha256l32_inverted_mlbf b/test_data/test_v2_sha256l32_inverted_mlbf new file mode 100644 index 0000000..09a65e0 Binary files /dev/null and b/test_data/test_v2_sha256l32_inverted_mlbf differ diff --git a/test_data/test_v2_sha256l32_mlbf b/test_data/test_v2_sha256l32_mlbf new file mode 100644 index 0000000..6d1cfeb Binary files /dev/null and b/test_data/test_v2_sha256l32_mlbf differ diff --git a/test_data/test_v2_sha256l32_salt_mlbf b/test_data/test_v2_sha256l32_salt_mlbf new file mode 100644 index 0000000..b2785be Binary files /dev/null and b/test_data/test_v2_sha256l32_salt_mlbf differ