Skip to content

Commit

Permalink
Implement ShardUId
Browse files Browse the repository at this point in the history
  • Loading branch information
Min Zhang committed Aug 15, 2021
1 parent df2a421 commit b0d9267
Show file tree
Hide file tree
Showing 43 changed files with 744 additions and 391 deletions.
13 changes: 9 additions & 4 deletions chain/chain/src/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1453,8 +1453,11 @@ impl Chain {
root_proofs.push(root_proofs_cur);
}

let state_root_node =
self.runtime_adapter.get_state_root_node(shard_id, &chunk_header.prev_state_root())?;
let state_root_node = self.runtime_adapter.get_state_root_node(
shard_id,
&sync_hash,
&chunk_header.prev_state_root(),
)?;

let shard_state_header = match chunk {
ShardChunk::V1(chunk) => {
Expand Down Expand Up @@ -1527,19 +1530,21 @@ impl Chain {
if shard_id as usize >= sync_prev_block.chunks().len() {
return Err(ErrorKind::InvalidStateRequest("shard_id out of bounds".into()).into());
}
// TODO: why is this prev_state_root here?
let state_root = sync_prev_block.chunks()[shard_id as usize].prev_state_root();
let state_root_node = self
.runtime_adapter
.get_state_root_node(shard_id, &state_root)
.get_state_root_node(shard_id, &sync_hash, &state_root)
.log_storage_error("get_state_root_node fail")?;
let num_parts = get_num_state_parts(state_root_node.memory_usage);

if part_id >= num_parts {
return Err(ErrorKind::InvalidStateRequest("part_id out of bound".to_string()).into());
}
// TODO: this part of logic may change when shards may change?
let state_part = self
.runtime_adapter
.obtain_state_part(shard_id, &state_root, part_id, num_parts)
.obtain_state_part(shard_id, &sync_hash, &state_root, part_id, num_parts)
.log_storage_error("obtain_state_part fail")?;

// Before saving State Part data, we need to make sure we can calculate and save State Header
Expand Down
16 changes: 10 additions & 6 deletions chain/chain/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use near_primitives::errors::InvalidTxError;
use near_primitives::hash::CryptoHash;
use near_primitives::merkle::{MerklePath, PartialMerkleTree};
use near_primitives::receipt::{Receipt, ReceiptResult};
use near_primitives::shard_layout::{get_block_shard_uid, ShardUId};
use near_primitives::sharding::{
ChunkHash, EncodedShardChunk, PartialEncodedChunk, ReceiptProof, ShardChunk, ShardChunkHeader,
StateSyncInfo,
Expand Down Expand Up @@ -2007,15 +2008,17 @@ impl<'a> ChainStoreUpdate<'a> {
GCMode::Fork(tries) => {
// If the block is on a fork, we delete the state that's the result of applying this block
for shard_id in 0..header.chunk_mask().len() as ShardId {
// TODO: pass in the actual shard version that this block uses
let shard_uid = ShardUId { version: 0, shard_id: shard_id as u32 };
self.store()
.get_ser(ColTrieChanges, &get_block_shard_id(&block_hash, shard_id))?
.get_ser(ColTrieChanges, &get_block_shard_uid(&block_hash, shard_uid))?
.map(|trie_changes: TrieChanges| {
tries
.revert_insertions(&trie_changes, shard_id, &mut store_update)
.revert_insertions(&trie_changes, shard_uid, &mut store_update)
.map(|_| {
self.gc_col(
ColTrieChanges,
&get_block_shard_id(&block_hash, shard_id),
&get_block_shard_uid(&block_hash, shard_uid),
);
self.inc_gc_col_state();
})
Expand All @@ -2027,15 +2030,16 @@ impl<'a> ChainStoreUpdate<'a> {
GCMode::Canonical(tries) => {
// If the block is on canonical chain, we delete the state that's before applying this block
for shard_id in 0..header.chunk_mask().len() as ShardId {
let shard_uid = ShardUId { version: 0, shard_id: shard_id as u32 };
self.store()
.get_ser(ColTrieChanges, &get_block_shard_id(&block_hash, shard_id))?
.get_ser(ColTrieChanges, &get_block_shard_uid(&block_hash, shard_uid))?
.map(|trie_changes: TrieChanges| {
tries
.apply_deletions(&trie_changes, shard_id, &mut store_update)
.apply_deletions(&trie_changes, shard_uid, &mut store_update)
.map(|_| {
self.gc_col(
ColTrieChanges,
&get_block_shard_id(&block_hash, shard_id),
&get_block_shard_uid(&block_hash, shard_uid),
);
self.inc_gc_col_state();
})
Expand Down
2 changes: 1 addition & 1 deletion chain/chain/src/store_validator/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ pub(crate) fn trie_changes_chunk_extra_exists(
block_hash,
shard_id
);
let trie = sv.runtime_adapter.get_trie_for_shard(*shard_id);
let trie = sv.runtime_adapter.get_trie_for_shard(*shard_id, block.header().prev_hash());
let trie_iterator = unwrap_or_err!(
TrieIterator::new(&trie, &new_root),
"Trie Node Missing for ShardChunk {:?}",
Expand Down
24 changes: 14 additions & 10 deletions chain/chain/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,23 @@ use chrono::Utc;
use num_rational::Rational;
use tracing::debug;

use near_chain_configs::ProtocolConfig;
use near_chain_primitives::{Error, ErrorKind};
use near_crypto::{KeyType, PublicKey, SecretKey, Signature};
use near_pool::types::PoolIterator;
use near_primitives::account::{AccessKey, Account};
#[cfg(feature = "protocol_feature_block_header_v3")]
use near_primitives::block_header::{Approval, ApprovalInner};
use near_primitives::challenge::ChallengesResult;
use near_primitives::epoch_manager::block_info::BlockInfo;
use near_primitives::epoch_manager::epoch_info::EpochInfo;
use near_primitives::errors::InvalidTxError;
use near_primitives::hash::{hash, CryptoHash};
use near_primitives::receipt::{ActionReceipt, Receipt, ReceiptEnum};
use near_primitives::serialize::to_base;
use near_primitives::shard_layout::ShardUId;
use near_primitives::sharding::ChunkHash;
use near_primitives::state_record::StateRecord;
use near_primitives::transaction::{
Action, ExecutionMetadata, ExecutionOutcome, ExecutionOutcomeWithId, ExecutionStatus,
SignedTransaction, TransferAction,
Expand Down Expand Up @@ -49,10 +54,6 @@ use crate::types::{
#[cfg(feature = "protocol_feature_block_header_v3")]
use crate::Doomslug;
use crate::{BlockHeader, DoomslugThresholdMode, RuntimeAdapter};
use near_chain_configs::ProtocolConfig;
#[cfg(feature = "protocol_feature_block_header_v3")]
use near_primitives::block_header::{Approval, ApprovalInner};
use near_primitives::state_record::StateRecord;

#[derive(BorshSerialize, BorshDeserialize, Hash, PartialEq, Eq, Ord, PartialOrd, Clone, Debug)]
struct AccountNonce(AccountId, Nonce);
Expand Down Expand Up @@ -137,7 +138,7 @@ impl KeyValueRuntime {
epoch_length: u64,
no_gc: bool,
) -> Self {
let tries = ShardTries::new(store.clone(), num_shards);
let tries = ShardTries::new(store.clone(), 0, num_shards);
let mut initial_amounts = HashMap::new();
for (i, validator) in validators.iter().flatten().enumerate() {
initial_amounts.insert(validator.clone(), (1000 + 100 * i) as u128);
Expand Down Expand Up @@ -303,12 +304,12 @@ impl RuntimeAdapter for KeyValueRuntime {
self.tries.clone()
}

fn get_trie_for_shard(&self, shard_id: ShardId) -> Trie {
self.tries.get_trie_for_shard(shard_id)
fn get_trie_for_shard(&self, shard_id: ShardId, _block_hash: &CryptoHash) -> Trie {
self.tries.get_trie_for_shard(ShardUId { version: 0, shard_id: shard_id as u32 })
}

fn get_view_trie_for_shard(&self, shard_id: ShardId) -> Trie {
self.tries.get_view_trie_for_shard(shard_id)
fn get_view_trie_for_shard(&self, shard_id: ShardId, _block_hash: &CryptoHash) -> Trie {
self.tries.get_view_trie_for_shard(ShardUId { version: 0, shard_id: shard_id as u32 })
}

fn verify_block_vrf(
Expand Down Expand Up @@ -555,6 +556,7 @@ impl RuntimeAdapter for KeyValueRuntime {
&self,
_gas_price: Balance,
_gas_limit: Gas,
_epoch_id: &EpochId,
_shard_id: ShardId,
_state_root: StateRoot,
_next_block_height: BlockHeight,
Expand Down Expand Up @@ -759,7 +761,7 @@ impl RuntimeAdapter for KeyValueRuntime {
Ok(ApplyTransactionResult {
trie_changes: WrappedTrieChanges::new(
self.get_tries(),
shard_id,
ShardUId { version: 0, shard_id: shard_id as u32 },
TrieChanges::empty(state_root),
Default::default(),
block_hash.clone(),
Expand Down Expand Up @@ -869,6 +871,7 @@ impl RuntimeAdapter for KeyValueRuntime {
fn obtain_state_part(
&self,
_shard_id: ShardId,
_block_hash: &CryptoHash,
state_root: &StateRoot,
part_id: u64,
num_parts: u64,
Expand Down Expand Up @@ -917,6 +920,7 @@ impl RuntimeAdapter for KeyValueRuntime {
fn get_state_root_node(
&self,
_shard_id: ShardId,
_block_hash: &CryptoHash,
state_root: &StateRoot,
) -> Result<StateRootNode, Error> {
Ok(StateRootNode {
Expand Down
11 changes: 8 additions & 3 deletions chain/chain/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,13 @@ pub trait RuntimeAdapter: Send + Sync {

fn get_tries(&self) -> ShardTries;

/// Returns trie.
fn get_trie_for_shard(&self, shard_id: ShardId) -> Trie;
/// Returns trie. Since shard layout may change from epoch to epoch, `shard_id` itself is
/// not enough to identify the trie. `prev_hash` is used to identify the epoch the given
/// `shard_id` is at.
fn get_trie_for_shard(&self, shard_id: ShardId, prev_hash: &CryptoHash) -> Trie;

/// Returns trie with view cache
fn get_view_trie_for_shard(&self, shard_id: ShardId) -> Trie;
fn get_view_trie_for_shard(&self, shard_id: ShardId, prev_hash: &CryptoHash) -> Trie;

fn verify_block_vrf(
&self,
Expand Down Expand Up @@ -285,6 +287,7 @@ pub trait RuntimeAdapter: Send + Sync {
&self,
gas_price: Balance,
gas_limit: Gas,
epoch_id: &EpochId,
shard_id: ShardId,
state_root: StateRoot,
next_block_height: BlockHeight,
Expand Down Expand Up @@ -609,6 +612,7 @@ pub trait RuntimeAdapter: Send + Sync {
fn obtain_state_part(
&self,
shard_id: ShardId,
block_hash: &CryptoHash,
state_root: &StateRoot,
part_id: u64,
num_parts: u64,
Expand Down Expand Up @@ -641,6 +645,7 @@ pub trait RuntimeAdapter: Send + Sync {
fn get_state_root_node(
&self,
shard_id: ShardId,
block_hash: &CryptoHash,
state_root: &StateRoot,
) -> Result<StateRootNode, Error>;

Expand Down
19 changes: 9 additions & 10 deletions chain/chain/tests/gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod tests {
use near_crypto::KeyType;
use near_primitives::block::Block;
use near_primitives::merkle::PartialMerkleTree;
use near_primitives::shard_layout::ShardUId;
use near_primitives::types::{NumBlocks, NumShards, StateRoot};
use near_primitives::validator_signer::InMemoryValidatorSigner;
use near_store::test_utils::{create_test_store, gen_changes};
Expand Down Expand Up @@ -78,9 +79,10 @@ mod tests {

let mut trie_changes_shards = Vec::new();
for shard_id in 0..num_shards {
let shard_uid = ShardUId { version: 0, shard_id: shard_id as u32 };
let trie_changes_data = gen_changes(&mut rng, max_changes);
let state_root = prev_state_roots[shard_id as usize];
let trie = tries.get_trie_for_shard(shard_id);
let trie = tries.get_trie_for_shard(shard_uid);
let trie_changes =
trie.update(&state_root, trie_changes_data.iter().cloned()).unwrap();
if verbose {
Expand All @@ -90,7 +92,7 @@ mod tests {
let new_root = trie_changes.new_root;
let wrapped_trie_changes = WrappedTrieChanges::new(
tries.clone(),
shard_id,
shard_uid,
trie_changes,
Default::default(),
*block.hash(),
Expand Down Expand Up @@ -127,7 +129,8 @@ mod tests {
let tries1 = chain1.runtime_adapter.get_tries();
let mut rng = rand::thread_rng();
let shard_to_check_trie = rng.gen_range(0, num_shards);
let trie1 = tries1.get_trie_for_shard(shard_to_check_trie);
let shard_uid = ShardUId { version: 0, shard_id: shard_to_check_trie as u32 };
let trie1 = tries1.get_trie_for_shard(shard_uid);
let genesis1 = chain1.get_block_by_height(0).unwrap().clone();
let mut states1 = vec![];
states1.push((
Expand Down Expand Up @@ -159,7 +162,7 @@ mod tests {

let mut chain2 = get_chain(num_shards);
let tries2 = chain2.runtime_adapter.get_tries();
let trie2 = tries2.get_trie_for_shard(shard_to_check_trie);
let trie2 = tries2.get_trie_for_shard(shard_uid);

// Find gc_height
let mut gc_height = simple_chains[0].length - 51;
Expand Down Expand Up @@ -202,18 +205,14 @@ mod tests {
if block1.header().height() > gc_height || i == gc_height {
let mut trie_store_update2 = StoreUpdate::new_with_tries(tries2.clone());
tries2
.apply_insertions(
&trie_changes2,
shard_to_check_trie,
&mut trie_store_update2,
)
.apply_insertions(&trie_changes2, shard_uid, &mut trie_store_update2)
.unwrap();
state_root2 = trie_changes2.new_root;
assert_eq!(state_root1[shard_to_check_trie as usize], state_root2);
store_update2.merge(trie_store_update2);
} else {
let (trie_store_update2, new_root2) =
tries2.apply_all(&trie_changes2, shard_to_check_trie).unwrap();
tries2.apply_all(&trie_changes2, shard_uid).unwrap();
state_root2 = new_root2;
store_update2.merge(trie_store_update2);
}
Expand Down
1 change: 1 addition & 0 deletions chain/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ impl Client {
runtime_adapter.prepare_transactions(
prev_block_header.gas_price(),
chunk_extra.gas_limit(),
&next_epoch_id,
shard_id,
*chunk_extra.state_root(),
// while the height of the next block that includes the chunk might not be prev_height + 1,
Expand Down
23 changes: 17 additions & 6 deletions chain/epoch_manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub use crate::types::RngSeed;

pub use crate::reward_calculator::NUM_SECONDS_IN_A_YEAR;
use near_chain::types::{BlockHeaderInfo, ValidatorInfoIdentifier};
use near_primitives::shard_layout::ShardLayout;
use near_primitives::shard_layout::{ShardLayout, ShardVersion};
use near_store::db::DBCol::ColEpochValidatorInfo;

mod proposals;
Expand Down Expand Up @@ -1188,10 +1188,17 @@ impl EpochManager {
Ok(EpochId(*first_block_info.prev_hash()))
}

pub fn get_shard_layout(&mut self, epoch_id: &EpochId) -> Result<ShardLayout, EpochError> {
pub fn get_shard_layout(&mut self, epoch_id: &EpochId) -> Result<&ShardLayout, EpochError> {
let protocol_version = self.get_epoch_info(epoch_id)?.protocol_version();
let shard_layout = &self.config.for_protocol_version(protocol_version).shard_layout;
Ok(shard_layout.clone())
Ok(shard_layout)
}

pub fn get_shard_version(&mut self, epoch_id: &EpochId) -> Result<ShardVersion, EpochError> {
let protocol_version = self.get_epoch_info(epoch_id)?.protocol_version();
let shard_version =
self.config.for_protocol_version(protocol_version).shard_layout.version();
Ok(shard_version)
}

pub fn get_epoch_info(&mut self, epoch_id: &EpochId) -> Result<&EpochInfo, EpochError> {
Expand Down Expand Up @@ -3616,6 +3623,7 @@ mod tests {
vec!["aurora".parse().unwrap()],
vec!["hhhh", "oooo"].into_iter().map(|x| x.parse().unwrap()).collect(),
Some(vec![0, 0, 0, 0]),
1,
);
let shard_config = ShardConfig {
num_block_producer_seats_per_shard: get_num_seats_per_shard(4, 2),
Expand Down Expand Up @@ -3661,12 +3669,15 @@ mod tests {
epoch_manager.get_epoch_info(&EpochId(h[2])).unwrap().protocol_version(),
new_protocol_version - 1
);
assert_eq!(epoch_manager.get_shard_layout(&EpochId(h[2])).unwrap(), ShardLayout::v0(1),);
assert_eq!(
*epoch_manager.get_shard_layout(&EpochId(h[2])).unwrap(),
ShardLayout::default(),
);
assert_eq!(
epoch_manager.get_epoch_info(&EpochId(h[4])).unwrap().protocol_version(),
new_protocol_version
);
assert_eq!(epoch_manager.get_shard_layout(&EpochId(h[4])).unwrap(), shard_layout);
assert_eq!(*epoch_manager.get_shard_layout(&EpochId(h[4])).unwrap(), shard_layout);
}

#[test]
Expand All @@ -3686,7 +3697,7 @@ mod tests {
protocol_upgrade_stake_threshold: Rational::new(80, 100),
protocol_upgrade_num_epochs: 2,
minimum_stake_divisor: 1,
shard_layout: ShardLayout::v0(1),
shard_layout: ShardLayout::default(),
};
let config = AllEpochConfig::new(epoch_config, None);
let amount_staked = 1_000_000;
Expand Down
2 changes: 1 addition & 1 deletion chain/epoch_manager/src/proposals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ mod tests {
minimum_stake_divisor: 1,
protocol_upgrade_stake_threshold: Rational::new(80, 100),
protocol_upgrade_num_epochs: 2,
shard_layout: ShardLayout::v0(5),
shard_layout: ShardLayout::v0(5, 0),
},
[0; 32],
&EpochInfo::default(),
Expand Down
2 changes: 1 addition & 1 deletion chain/epoch_manager/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ pub fn epoch_config(
protocol_upgrade_stake_threshold: Rational::new(80, 100),
protocol_upgrade_num_epochs: 2,
minimum_stake_divisor: 1,
shard_layout: ShardLayout::v0(num_shards),
shard_layout: ShardLayout::v0(num_shards, 0),
};
AllEpochConfig::new(epoch_config, simple_nightshade_shard_config)
}
Expand Down
Loading

0 comments on commit b0d9267

Please sign in to comment.