Skip to content

Commit

Permalink
Merge branch 'alex/iterative-mht-digest' into 'master'
Browse files Browse the repository at this point in the history
perf(crypto): CRP-2019 iterative MixedHashTree::digest()

`MixedHashTree::digest()` is a recursive function, which, naturally, can produce a stack overflow. We had two options to mitigate that: 1) make the function iterative, or 2) add a depth check with error handling. This MR implements both and uses 1) in production, whereas 2) turned out to be ~15% less efficient. Let's count this as a performance improvement rather than a fix, since we don't expect to have trees that could produce a stack overflow in production.

The recursive implementation with error handling is moved to test utils, and this MR also adds a new proptest that checks that the recursive and iterative functions produce equal outputs.

See merge request dfinity-lab/public/ic!13918
  • Loading branch information
altkdf committed Aug 9, 2023
2 parents d48dfc0 + e4c886e commit 865a18f
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 16 additions & 2 deletions rs/crypto/tree_hash/benches/tree_hash.rs
Expand Up @@ -3,7 +3,9 @@ use ic_crypto_tree_hash::{
flatmap, lookup_path, FlatMap, HashTree, HashTreeBuilder, Label, LabeledTree, LabeledTree::*,
MixedHashTree, WitnessGenerator,
};
use ic_crypto_tree_hash_test_utils::hash_tree_builder_from_labeled_tree;
use ic_crypto_tree_hash_test_utils::{
hash_tree_builder_from_labeled_tree, mixed_hash_tree_digest_recursive,
};
use ic_types_test_utils::ids::message_test_id;

fn new_request_status_tree(num_subtrees: usize) -> LabeledTree<Vec<u8>> {
Expand Down Expand Up @@ -198,9 +200,21 @@ pub fn criterion_benchmark(c: &mut Criterion) {
c.benchmark_group("compute_digest");

g.bench_function(BenchmarkId::new("mixed_hash_tree", num_subtrees), |b| {
b.iter(|| black_box(mixed_hash_tree.digest()));
b.iter(|| {
black_box(
mixed_hash_tree_digest_recursive(&mixed_hash_tree)
.expect("too deep recursion"),
)
});
});

g.bench_function(
BenchmarkId::new("mixed_hash_tree_iterative", num_subtrees),
|b| {
b.iter(|| black_box(mixed_hash_tree.digest()));
},
);

g.bench_function(BenchmarkId::new("witness", num_subtrees), |b| {
b.iter(|| {
black_box(ic_crypto_tree_hash::recompute_digest(
Expand Down
63 changes: 56 additions & 7 deletions rs/crypto/tree_hash/src/lib.rs
Expand Up @@ -447,15 +447,64 @@ impl MixedHashTree {
/// Recomputes root hash of the full tree that this mixed tree was
/// constructed from.
pub fn digest(&self) -> Digest {
match self {
Self::Empty => tree_hash::empty_subtree_hash(),
Self::Fork(lr) => tree_hash::compute_fork_digest(&lr.0.digest(), &lr.1.digest()),
Self::Labeled(label, subtree) => {
tree_hash::compute_node_digest(label, &subtree.digest())
// Work item for the explicit traversal stack used below:
// `Expand` is a node whose children still have to be scheduled,
// `Collect` is a node whose own digest can now be produced.
#[derive(Debug)]
enum StackItem<'a> {
Expand(&'a MixedHashTree),
Collect(&'a MixedHashTree),
}

impl<'a> StackItem<'a> {
// Turns an `Expand` item into a `Collect` item for the same node.
// Panics when called on an item that is already `Collect`.
fn to_collect(&self) -> Self {
match self {
Self::Expand(t) => Self::Collect(t),
Self::Collect(_) => panic!("expected Expand, got Collect"),
}
}
Self::Leaf(buf) => tree_hash::compute_leaf_digest(&buf[..]),
Self::Pruned(digest) => digest.clone(),
}

// `stack` holds the pending work items; `digests` holds the digests of
// subtrees that are already finished, most recently finished last.
let mut stack: Vec<StackItem<'_>> = Vec::new();
let mut digests: Vec<Digest> = Vec::new();

stack.push(StackItem::Expand(self));

while let Some(t) = stack.pop() {
match t {
StackItem::Expand(Self::Fork(lr)) => {
// Re-queue the fork itself below its children so that it is
// collected only after both of them. `lr.0` is pushed last and is
// therefore popped (and fully processed) before `lr.1`.
stack.push(t.to_collect());
stack.push(StackItem::Expand(&lr.1));
stack.push(StackItem::Expand(&lr.0));
}
StackItem::Expand(Self::Labeled(_, subtree)) => {
stack.push(t.to_collect());
stack.push(StackItem::Expand(subtree));
}
StackItem::Collect(Self::Fork(_)) => {
// The `lr.0` subtree finished first, so the digest of `lr.1` is on top.
let right = digests.pop().expect("bug: missing right subtree digest");
let left = digests.pop().expect("bug: missing left subtree digest");
digests.push(tree_hash::compute_fork_digest(&left, &right));
}
StackItem::Collect(Self::Labeled(label, _)) => {
let subtree_digest = digests.pop().expect("bug: missing subtree digest");
let labeled_digest = tree_hash::compute_node_digest(label, &subtree_digest);
digests.push(labeled_digest);
}
StackItem::Collect(Self::Leaf(buf)) => {
digests.push(tree_hash::compute_leaf_digest(&buf[..]))
}
StackItem::Collect(Self::Pruned(digest)) => digests.push(digest.clone()),
StackItem::Collect(Self::Empty) => digests.push(tree_hash::empty_subtree_hash()),
// Nodes without children have nothing to expand: they are re-queued as
// `Collect` and handled by one of the arms above on the next iteration.
t /* Expand of Leaf, Pruned or Empty */ => stack.push(t.to_collect()),
}
}

// After the loop every node has been collected, so exactly one digest
// (that of the node the traversal started from) must remain.
assert_eq!(
digests.len(),
1,
"bug: reduced tree to not exactly one digest: {digests:?}"
);
assert!(stack.is_empty(), "bug: stack is not empty: {stack:?}");

digests[0].clone()
}

/// Finds a label in a hash tree.
Expand Down
1 change: 1 addition & 0 deletions rs/crypto/tree_hash/test_utils/BUILD.bazel
Expand Up @@ -7,6 +7,7 @@ DEPENDENCIES = [
"//rs/crypto/tree_hash",
"@crate_index//:proptest",
"@crate_index//:rand_0_8_4",
"@crate_index//:thiserror",
]

DEV_DEPENDENCIES = [
Expand Down
1 change: 1 addition & 0 deletions rs/crypto/tree_hash/test_utils/Cargo.toml
Expand Up @@ -8,6 +8,7 @@ assert_matches = "1.5.0"
ic-crypto-tree-hash = { path = ".." }
proptest = "1.0"
rand = "0.8.4"
thiserror = "1.0"

[dev-dependencies]
ic-crypto-test-utils-reproducible-rng = { path = "../../test_utils/reproducible_rng" }
30 changes: 30 additions & 0 deletions rs/crypto/tree_hash/test_utils/src/lib.rs
Expand Up @@ -455,3 +455,33 @@ pub fn compute_fork_digest(left_digest: &Digest, right_digest: &Digest) -> Diges
pub fn empty_subtree_hash() -> Digest {
// The hasher is finalized right after construction, so nothing beyond what
// `Hasher::for_domain` itself sets up for this domain is hashed.
Hasher::for_domain(DOMAIN_HASHTREE_EMPTY_SUBTREE).finalize()
}

/// This error indicates that the algorithm exceeded the recursion depth limit.
///
/// The wrapped value is interpolated into the error message as `depth={0}`;
/// `mixed_hash_tree_digest_recursive` stores in it the depth at which the
/// limit was exceeded.
#[derive(thiserror::Error, Debug, PartialEq)]
#[error("The algorithm failed due to too deep recursion (depth={0})")]
pub struct TooDeepRecursion(pub u32);

/// Computes the root hash of the full tree from which the mixed tree `tree`
/// was constructed, using a depth-limited recursive traversal.
///
/// # Errors
///
/// Returns [`TooDeepRecursion`] carrying the offending depth as soon as the
/// traversal visits a node nested deeper than `MAX_HASH_TREE_DEPTH` levels,
/// where the root counts as level 1.
pub fn mixed_hash_tree_digest_recursive(tree: &MixedHashTree) -> Result<Digest, TooDeepRecursion> {
    /// Hashes `node`, which sits at nesting level `level` of the traversal.
    fn visit(node: &MixedHashTree, level: u32) -> Result<Digest, TooDeepRecursion> {
        // Refuse to descend further once the limit is exceeded.
        if level as usize > MAX_HASH_TREE_DEPTH {
            return Err(TooDeepRecursion(level));
        }

        match node {
            // Nodes without children are hashed directly.
            MixedHashTree::Empty => Ok(empty_subtree_hash()),
            MixedHashTree::Leaf(bytes) => Ok(compute_leaf_digest(&bytes[..])),
            MixedHashTree::Pruned(known) => Ok(known.clone()),
            // Inner nodes first hash their children one level deeper; the
            // first error aborts the whole computation.
            MixedHashTree::Labeled(label, child) => {
                let child_digest = visit(child, level + 1)?;
                Ok(compute_node_digest(label, &child_digest))
            }
            MixedHashTree::Fork(pair) => {
                let left = visit(&pair.0, level + 1)?;
                let right = visit(&pair.1, level + 1)?;
                Ok(compute_fork_digest(&left, &right))
            }
        }
    }

    visit(tree, 1)
}
21 changes: 21 additions & 0 deletions rs/crypto/tree_hash/test_utils/tests/mixed_hash_tree.rs
@@ -0,0 +1,21 @@
use assert_matches::assert_matches;
use ic_crypto_tree_hash::MixedHashTree;
use ic_crypto_tree_hash_test_utils::{mixed_hash_tree_digest_recursive, TooDeepRecursion};

// NOTE(review): presumably mirrors the depth limit enforced inside
// `mixed_hash_tree_digest_recursive`; if the two values diverge, the
// boundary expectations in the test below no longer hold — confirm.
const MAX_HASH_TREE_DEPTH: usize = 128;

/// Checks the depth limit of the recursive digest computation: a tree that is
/// `MAX_HASH_TREE_DEPTH` levels deep is accepted, while adding one more level
/// yields `TooDeepRecursion` reporting depth `MAX_HASH_TREE_DEPTH + 1`.
#[test]
fn mixed_hash_tree_recursive_digest_errors_on_too_deep_trees() {
    // Wraps `inner` as the first child of a new fork, adding one level.
    // Taking `inner` by value means growing the chain never deep-copies the
    // tree built so far (cloning it on every step would be quadratic).
    fn nest(inner: MixedHashTree) -> MixedHashTree {
        MixedHashTree::Fork(Box::new((inner, MixedHashTree::Empty)))
    }

    // A single node plus `MAX_HASH_TREE_DEPTH - 1` wrapping forks.
    let mut tree = MixedHashTree::Empty;
    for _ in 1..MAX_HASH_TREE_DEPTH {
        tree = nest(tree);
    }

    assert_matches!(mixed_hash_tree_digest_recursive(&tree), Ok(_));

    // One more level pushes the deepest node past the limit.
    tree = nest(tree);
    assert_eq!(
        mixed_hash_tree_digest_recursive(&tree),
        Err(TooDeepRecursion(MAX_HASH_TREE_DEPTH as u32 + 1))
    );
}
15 changes: 15 additions & 0 deletions rs/crypto/tree_hash/tests/tree_hash.rs
Expand Up @@ -4,6 +4,7 @@ use ic_crypto_sha2::Sha256;
use ic_crypto_test_utils_reproducible_rng::reproducible_rng;
use ic_crypto_tree_hash::*;
use ic_crypto_tree_hash_test_utils::*;
use proptest::prelude::*;
use rand::Rng;
use rand::{CryptoRng, RngCore};
use std::collections::BTreeMap;
Expand Down Expand Up @@ -1482,6 +1483,20 @@ fn tree_with_three_levels() -> HashTreeBuilderImpl {
builder
}

proptest! {
    /// For arbitrary well-formed mixed hash trees, `MixedHashTree::digest()`
    /// must agree with the recursive reference implementation.
    #[test]
    fn recompute_digest_for_mixed_hash_tree_iteratively_and_recursively_produces_same_digest(
        tree in arbitrary::arbitrary_well_formed_mixed_hash_tree()
    ) {
        // The recursive implementation may reject a tree with an error, while
        // the iterative one is infallible; rejected trees are not compared.
        if let Ok(expected) = mixed_hash_tree_digest_recursive(&tree) {
            let actual = tree.digest();
            assert_eq!(expected, actual);
        }
    }
}

#[test]
fn witness_for_simple_path_in_a_big_tree() {
// Simple path : label_b -> label_b_5 -> label_b_5_1 -> leaf
Expand Down

0 comments on commit 865a18f

Please sign in to comment.