diff --git a/base_layer/mmr/src/sparse_merkle_tree/tree.rs b/base_layer/mmr/src/sparse_merkle_tree/tree.rs index f5bf96dddc..bc3378ac59 100644 --- a/base_layer/mmr/src/sparse_merkle_tree/tree.rs +++ b/base_layer/mmr/src/sparse_merkle_tree/tree.rs @@ -255,7 +255,14 @@ impl> SparseMerkleTree { /// Returns true if the tree contains the key `key`. pub fn contains(&self, key: &NodeKey) -> bool { - self.search_node(key).ok().is_some() + match self.search_node(key) { + // In the case of a malformed tree where the search fails unexpectedly, play it safe + Err(_) => false, + // The node is either empty or is a leaf with an unexpected key + Ok(None) => false, + // The node is a leaf with the expected key + Ok(Some(_)) => true, + } } /// Returns the value at location `key` if it exists, or `None` otherwise. @@ -806,4 +813,38 @@ mod test { let _ = tree.delete(&short_key(65)).unwrap(); assert!(tree.is_empty()); } + + #[test] + fn contains() { + let mut tree = SparseMerkleTree::::default(); + + // An empty tree contains no keys + assert!(!tree.contains(&short_key(0))); + assert!(!tree.contains(&short_key(1))); + + // Add a key, which the tree must then contain + tree.upsert(short_key(1), ValueHash::from([1u8; 32])).unwrap(); + assert!(!tree.contains(&short_key(0))); + assert!(tree.contains(&short_key(1))); + + // Delete the key, which the tree must not contain + tree.delete(&short_key(1)).unwrap(); + assert!(!tree.contains(&short_key(0))); + assert!(!tree.contains(&short_key(1))); + + // Build a more complex tree with two keys, which the tree must then contain + tree.upsert(short_key(0), ValueHash::from([0u8; 32])).unwrap(); + tree.upsert(short_key(1), ValueHash::from([1u8; 32])).unwrap(); + assert!(tree.contains(&short_key(0))); + assert!(tree.contains(&short_key(1))); + + // Delete each key in turn + tree.delete(&short_key(0)).unwrap(); + assert!(!tree.contains(&short_key(0))); + tree.delete(&short_key(1)).unwrap(); + assert!(!tree.contains(&short_key(1))); + + // Sanity check that the tree is now empty + assert!(tree.is_empty()); + } }