diff --git a/portal-bridge/src/bridge/state.rs b/portal-bridge/src/bridge/state.rs index 3fd5dae8c..824e67be2 100644 --- a/portal-bridge/src/bridge/state.rs +++ b/portal-bridge/src/bridge/state.rs @@ -159,13 +159,12 @@ impl StateBridge { .header .hash(); - let walk_diff = TrieWalker::new(root_with_trie_diff.root, root_with_trie_diff.trie_diff); + let walk_diff = + TrieWalker::new_partial_trie(root_with_trie_diff.root, root_with_trie_diff.trie_diff)?; // gossip block's new state transitions let mut content_idx = 0; - for node in walk_diff.nodes.keys() { - let account_proof = walk_diff.get_proof(*node); - + for account_proof in walk_diff { // gossip the account self.gossip_account(&account_proof, block_hash, content_idx) .await?; @@ -213,10 +212,10 @@ impl StateBridge { // gossip contract storage let storage_changed_nodes = trin_execution.database.get_storage_trie_diff(address_hash); - let storage_walk_diff = TrieWalker::new(account.storage_root, storage_changed_nodes); + let storage_walk_diff = + TrieWalker::new_partial_trie(account.storage_root, storage_changed_nodes)?; - for storage_node in storage_walk_diff.nodes.keys() { - let storage_proof = storage_walk_diff.get_proof(*storage_node); + for storage_proof in storage_walk_diff { self.gossip_storage( &account_proof, &storage_proof, diff --git a/trin-execution/src/trie_walker.rs b/trin-execution/src/trie_walker.rs deleted file mode 100644 index 4c05359e0..000000000 --- a/trin-execution/src/trie_walker.rs +++ /dev/null @@ -1,199 +0,0 @@ -use std::collections::VecDeque; - -use alloy::{consensus::EMPTY_ROOT_HASH, primitives::B256}; -use eth_trie::{decode_node, node::Node}; -use hashbrown::HashMap as BrownHashMap; -use serde::{Deserialize, Serialize}; - -use super::types::trie_proof::TrieProof; - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] -pub struct TrieWalkerNode { - /// The encoded version of the trie node. - pub encoded_node: Vec, - /// The hash of the parent node. It is `None` only for root of the trie. - pub parent_hash: Option, - /// Path from parent node to this node. - pub path_nibbles: Vec, -} - -impl TrieWalkerNode { - pub fn new(encoded_node: Vec, parent_hash: Option, path_nibbles: Vec) -> Self { - Self { - encoded_node, - parent_hash, - path_nibbles, - } - } -} - -/// This struct takes in a root hash and a hashmap of changed nodes, then you can call an iterator -/// which will return every proof to gossip -pub struct TrieWalker { - pub nodes: BrownHashMap, -} - -impl TrieWalker { - pub fn new(root_hash: B256, nodes: BrownHashMap>) -> Self { - // if the storage root is empty then there is no storage to gossip - if root_hash == EMPTY_ROOT_HASH { - return Self { - nodes: BrownHashMap::new(), - }; - } - - if nodes.is_empty() { - return Self { - nodes: BrownHashMap::new(), - }; - } - - let processed_nodes = Self::process_trie(root_hash, &nodes) - .expect("This shouldn't fail as we only pass valid tries"); - Self { - nodes: processed_nodes, - } - } - - fn process_trie( - root_hash: B256, - nodes: &BrownHashMap>, - ) -> anyhow::Result> { - let mut trie_walker_nodes: BrownHashMap = BrownHashMap::new(); - let mut stack = vec![root_hash]; - - trie_walker_nodes.insert( - root_hash, - TrieWalkerNode::new( - nodes - .get(&root_hash) - .expect("Failed to get encoded node for root node. This should never happen.") - .clone(), - None, - vec![], - ), - ); - while let Some(node_key) = stack.pop() { - let encoded_node = nodes - .get(&node_key) - .expect("The stack should only contain nodes that are in the changed nodes"); - - let decoded_node = decode_node(&mut encoded_node.as_slice()) - .expect("Should should only be passing valid encoded nodes"); - - match decoded_node { - Node::Extension(extension) => { - let extension = extension.read().expect("Reading an extension should work"); - // We look for hash nodes in order to connect them to the root. If this node is - // not a hash node, then neither is any of its children. - // We know this because any node that has a hash node as it's descendant would - // also become hash node, because its encoding would be longer than 32 bytes. - if let Node::Hash(hash_node) = &extension.node { - // Only process provided nodes (they belong to the partial trie that we care - // about) - if let Some(encoded_node) = nodes.get(&hash_node.hash) { - stack.push(hash_node.hash); - trie_walker_nodes.insert( - hash_node.hash, - TrieWalkerNode::new( - encoded_node.clone(), - Some(node_key), - extension.prefix.get_data().to_vec(), - ), - ); - } - } - } - Node::Branch(branch) => { - let branch = branch.read().expect("Reading a branch should work"); - for (i, child) in branch.children.iter().enumerate() { - // We look for hash nodes in order to connect them to the root. If this node - // is not a hash node, then neither is any of its children. - // We know this because any node that has a hash node as it's descendant - // would also become hash node, because its encoding would be longer than 32 - // bytes. - if let Node::Hash(hash_node) = child { - //Only process provided nodes (they belong to the partial trie that we - // care about) - if let Some(encoded_node) = nodes.get(&hash_node.hash) { - stack.push(hash_node.hash); - trie_walker_nodes.insert( - hash_node.hash, - TrieWalkerNode::new( - encoded_node.clone(), - Some(node_key), - vec![i as u8], - ), - ); - } - } - } - } - _ => {} - } - } - - Ok(trie_walker_nodes) - } - - pub fn get_proof(&self, node_hash: B256) -> TrieProof { - let mut path_parts = VecDeque::new(); - let mut proof = VecDeque::new(); - let mut next_node: Option = Some(node_hash); - while let Some(current_node) = next_node { - let Some(node) = self.nodes.get(¤t_node) else { - panic!("Node not found in trie walker nodes. This should never happen."); - }; - path_parts.push_front(node.path_nibbles.clone()); - proof.push_front(node.encoded_node.clone().into()); - next_node = node.parent_hash; - } - - TrieProof { - path: Vec::from(path_parts).concat(), - proof: Vec::from(proof), - } - } -} - -#[cfg(test)] -mod tests { - use std::str::FromStr; - - use alloy::primitives::{keccak256, Address, Bytes}; - use eth_trie::{RootWithTrieDiff, Trie}; - use trin_utils::dir::create_temp_test_dir; - - use crate::{config::StateConfig, execution::TrinExecution, trie_walker::TrieWalker}; - - #[tokio::test] - #[ignore = "This test downloads data from a remote server"] - async fn test_trie_walker_builds_valid_proof() { - let temp_directory = create_temp_test_dir().unwrap(); - let mut trin_execution = TrinExecution::new(temp_directory.path(), StateConfig::default()) - .await - .unwrap(); - let RootWithTrieDiff { trie_diff, .. } = trin_execution.process_next_block().await.unwrap(); - let root_hash = trin_execution.get_root().unwrap(); - let walk_diff = TrieWalker::new(root_hash, trie_diff); - - let address = Address::from_str("0x001d14804b399c6ef80e64576f657660804fec0b").unwrap(); - let valid_proof = trin_execution - .database - .trie - .lock() - .get_proof(keccak256(address).as_slice()) - .unwrap() - .into_iter() - .map(Bytes::from) - .collect::>(); - let last_node = valid_proof.last().expect("Missing proof!"); - - let account_proof = walk_diff.get_proof(keccak256(last_node)); - - assert_eq!(account_proof.path, [5, 9, 2, 13]); - assert_eq!(account_proof.proof, valid_proof); - - temp_directory.close().unwrap(); - } -} diff --git a/trin-execution/src/trie_walker/db.rs b/trin-execution/src/trie_walker/db.rs new file mode 100644 index 000000000..dbd4fc184 --- /dev/null +++ b/trin-execution/src/trie_walker/db.rs @@ -0,0 +1,24 @@ +use alloy::primitives::{Bytes, B256}; +use anyhow::anyhow; +use eth_trie::DB; +use hashbrown::HashMap; + +use crate::storage::trie_db::TrieRocksDB; + +pub trait TrieWalkerDb { + fn get(&self, key: &[u8]) -> anyhow::Result>; +} + +impl TrieWalkerDb for HashMap> { + fn get(&self, key: &[u8]) -> anyhow::Result> { + Ok(self.get(key).map(|vec| Bytes::copy_from_slice(vec))) + } +} + +impl TrieWalkerDb for TrieRocksDB { + fn get(&self, key: &[u8]) -> anyhow::Result> { + DB::get(self, key) + .map(|result| result.map(Bytes::from)) + .map_err(|err| anyhow!("Failed to read key value from TrieRocksDB {err}")) + } +} diff --git a/trin-execution/src/trie_walker/mod.rs b/trin-execution/src/trie_walker/mod.rs new file mode 100644 index 000000000..e8e0dad57 --- /dev/null +++ b/trin-execution/src/trie_walker/mod.rs @@ -0,0 +1,274 @@ +pub mod db; + +use std::sync::Arc; + +use alloy::primitives::{Bytes, B256}; +use anyhow::{anyhow, Ok}; +use db::TrieWalkerDb; +use eth_trie::{decode_node, node::Node}; + +use crate::types::trie_proof::TrieProof; + +/// Iterates over trie nodes from the whole or partial state trie +/// +/// Use cases are: +/// 1. Gossiping the whole state trie +/// 2. Gossiping the forward state diffs (partial state trie) +/// 3. Getting stats about the state trie +/// +/// Panics if the trie is corrupted +pub struct TrieWalker { + is_partial_trie: bool, + trie: Arc, + stack: Vec, +} + +impl TrieWalker { + pub fn new(root_hash: B256, trie: Arc) -> anyhow::Result { + let root_node_trie = match trie.get(root_hash.as_slice())? { + Some(root_node_trie) => root_node_trie, + None => return Err(anyhow!("Root node not found in the database")), + }; + let root_proof = TrieProof { + path: vec![], + proof: vec![root_node_trie], + }; + + Ok(Self { + is_partial_trie: false, + trie, + stack: vec![root_proof], + }) + } + + pub fn new_partial_trie(root_hash: B256, trie: DB) -> anyhow::Result { + let root_node_trie = match trie.get(root_hash.as_slice())? { + Some(root_node_trie) => root_node_trie, + None => { + // We are handling 2 potential cases here + // - If the storage root is empty then there is no storage to gossip + // - The trie db is empty so we can't walk it return an empty iterator + return Ok(Self { + is_partial_trie: true, + trie: Arc::new(trie), + stack: vec![], + }); + } + }; + + let root_proof = TrieProof { + path: vec![], + proof: vec![root_node_trie], + }; + + Ok(Self { + is_partial_trie: true, + trie: Arc::new(trie), + stack: vec![root_proof], + }) + } + + fn process_node( + &mut self, + node: Node, + partial_proof: Vec, + path: Vec, + ) -> anyhow::Result<()> { + // We only need to process hash nodes, because if the node isn't a hash node then none of + // its children is + if let Node::Hash(hash) = node { + let encoded_trie_node = match self.trie.get(hash.hash.as_slice())? { + Some(encoded_trie_node) => encoded_trie_node, + None => { + // If we are walking a partial trie, some nodes won't be available in the + // database + if self.is_partial_trie { + return Ok(()); + } + return Err(anyhow::anyhow!("Node not found in the database")); + } + }; + + // check that node decodes correctly and to correct variant + if matches!( + decode_node(&mut encoded_trie_node.as_ref())?, + Node::Empty | Node::Hash(_) + ) { + return Err(anyhow::anyhow!( + "A node hash should never lead to an empty node or a hash node" + )); + } + + let mut proof = partial_proof; + proof.push(encoded_trie_node); + self.stack.push(TrieProof { path, proof }); + } + Ok(()) + } +} + +impl Iterator for TrieWalker { + type Item = TrieProof; + + fn next(&mut self) -> Option { + let next_proof = match self.stack.pop() { + Some(next_proof) => next_proof, + None => return None, + }; + + let TrieProof { path, proof } = &next_proof; + let last_node = proof.last().expect("Proof is empty"); + let decoded_last_node = + decode_node(&mut last_node.as_ref()).expect("Failed to decode node"); + + // Process any children of the node + match decoded_last_node { + Node::Extension(extension) => { + let extension = extension.read().expect("Extension node must be readable"); + self.process_node( + extension.node.clone(), + proof.clone(), + [ + path.as_slice(), + extension.prefix.get_data().to_vec().as_slice(), + ] + .concat(), + ) + .expect("Failed to process node"); + } + Node::Branch(branch) => { + let branch = branch.read().expect("Branch node must be readable"); + + // We want to iterate over the children in reverse order so that we can push them to + // the stack in order + for (i, child) in branch.children.iter().enumerate().rev() { + self.process_node( + child.clone(), + proof.clone(), + [path.as_slice(), &[i as u8]].concat(), + ) + .expect("Failed to process node"); + } + } + // If the node is a leaf node, we don't need to go deeper + _ => {} + } + + Some(next_proof) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use alloy::primitives::{keccak256, Address, B256, U256}; + use eth_trie::{EthTrie, RootWithTrieDiff, Trie}; + use std::{str::FromStr, sync::Arc}; + use tracing_test::traced_test; + use trin_utils::dir::create_temp_test_dir; + + use crate::{ + config::StateConfig, + execution::TrinExecution, + storage::{trie_db::TrieRocksDB, utils::setup_rocksdb}, + utils::full_nibble_path_to_address_hash, + }; + + #[tokio::test] + #[traced_test] + async fn test_state_walker() { + let temp_directory = create_temp_test_dir().unwrap(); + let db = Arc::new(setup_rocksdb(temp_directory.path()).unwrap()); + let mut trie = EthTrie::new(Arc::new(TrieRocksDB::new(false, db.clone()))); + + for i in 1..=18 { + trie.insert( + B256::from(U256::from(i)).as_slice(), + B256::from(U256::from(i)).as_slice(), + ) + .unwrap(); + } + + let root_hash = trie.root_hash().unwrap(); + let walker = TrieWalker::new(root_hash, trie.db.clone()).unwrap(); + let mut count = 0; + let mut leaf_count = 0; + for proof in walker { + count += 1; + + let Some(encoded_last_node) = proof.proof.last() else { + panic!("Account proof is empty"); + }; + + let Node::Leaf(leaf) = + decode_node(&mut encoded_last_node.as_ref()).expect("Failed to decode node") + else { + continue; + }; + leaf_count += 1; + + // reconstruct the address hash from the path so we can call `get_proof` on the trie + let mut partial_key_path = leaf.key.get_data().to_vec(); + partial_key_path.pop(); + let full_key_path = [&proof.path.clone(), partial_key_path.as_slice()].concat(); + let key = full_nibble_path_to_address_hash(&full_key_path); + let valid_proof = trie.get_proof(key.as_slice()).expect("Proof not found"); + assert_eq!(valid_proof, proof.proof); + } + assert_eq!(leaf_count, 18); + assert_eq!(count, 22); + } + + #[tokio::test] + #[ignore = "This test downloads data from a remote server"] + async fn test_trie_walker_builds_valid_proof() { + let temp_directory = create_temp_test_dir().unwrap(); + let mut trin_execution = TrinExecution::new(temp_directory.path(), StateConfig::default()) + .await + .unwrap(); + let RootWithTrieDiff { trie_diff, .. } = trin_execution.process_next_block().await.unwrap(); + let root_hash = trin_execution.get_root().unwrap(); + let walk_diff = TrieWalker::new_partial_trie(root_hash, trie_diff).unwrap(); + + let address = Address::from_str("0x001d14804b399c6ef80e64576f657660804fec0b").unwrap(); + let address_hash = keccak256(address); + let valid_proof = trin_execution + .database + .trie + .lock() + .get_proof(address_hash.as_slice()) + .unwrap() + .into_iter() + .map(Bytes::from) + .collect::>(); + + let mut trie_iter = walk_diff.into_iter(); + let account_proof = loop { + let proof = trie_iter.next().expect("Proof not found"); + let Some(encoded_last_node) = proof.proof.last() else { + panic!("Account proof is empty"); + }; + + let Node::Leaf(leaf) = + decode_node(&mut encoded_last_node.as_ref()).expect("Failed to decode node") + else { + continue; + }; + + // reconstruct the address hash from the path so we can call `get_proof` on the trie + let mut partial_key_path = leaf.key.get_data().to_vec(); + partial_key_path.pop(); + let full_key_path = [&proof.path.clone(), partial_key_path.as_slice()].concat(); + let key = full_nibble_path_to_address_hash(&full_key_path); + if key == address_hash { + break proof; + } + }; + + assert_eq!(account_proof.path, [5, 9, 2, 13]); + assert_eq!(account_proof.proof, valid_proof); + + temp_directory.close().unwrap(); + } +} diff --git a/trin-execution/tests/content_generation.rs b/trin-execution/tests/content_generation.rs index ccee2eba5..ae9ad7804 100644 --- a/trin-execution/tests/content_generation.rs +++ b/trin-execution/tests/content_generation.rs @@ -83,7 +83,7 @@ impl Stats { /// Following command should be used for running: /// /// ``` -/// BLOCKS=1000000 cargo test -p trin-execution --test content_generation -- --include-ignored --nocapture +/// BLOCKS=1000000 cargo test --release -p trin-execution --test content_generation -- --include-ignored --nocapture /// ``` #[tokio::test] #[traced_test] @@ -126,10 +126,9 @@ async fn test_we_can_generate_content_key_values_up_to_x() -> Result<()> { "State root doesn't match" ); - let walk_diff = TrieWalker::new(root_hash, changed_nodes); - for node in walk_diff.nodes.keys() { + let walk_diff = TrieWalker::new_partial_trie(root_hash, changed_nodes)?; + for account_proof in walk_diff { let block_hash = block.header.hash(); - let account_proof = walk_diff.get_proof(*node); // check account content key/value let content_key = @@ -173,10 +172,9 @@ async fn test_we_can_generate_content_key_values_up_to_x() -> Result<()> { // check contract storage content key/value let storage_changed_nodes = trin_execution.database.get_storage_trie_diff(address_hash); - let storage_walk_diff = TrieWalker::new(account.storage_root, storage_changed_nodes); - for storage_node in storage_walk_diff.nodes.keys() { - let storage_proof = storage_walk_diff.get_proof(*storage_node); - + let storage_walk_diff = + TrieWalker::new_partial_trie(account.storage_root, storage_changed_nodes)?; + for storage_proof in storage_walk_diff { let content_key = create_storage_content_key(&storage_proof, address_hash) .expect("Content key should be present"); let content_value =