Skip to content

Commit

Permalink
Merge pull request ethereum#1695 from hwwhww/fix_get_block_root
Browse files Browse the repository at this point in the history
Sync `get_block_root` and `get_shard_committees_at_slot ` and fix `test_demo`
  • Loading branch information
hwwhww committed Jan 8, 2019
2 parents b6da0b4 + d29578a commit 000333a
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 113 deletions.
98 changes: 54 additions & 44 deletions eth/beacon/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import (
Any,
Iterable,
Sequence,
Tuple,
Expand Down Expand Up @@ -63,47 +62,50 @@
from eth.beacon.types.validator_records import ValidatorRecord # noqa: F401


def _get_element_from_recent_list(
target_list: Sequence[Any],
target_slot: SlotNumber,
slot_relative_position: SlotNumber) -> Any:
#
# Get block root
#
def _get_block_root(
latest_block_roots: Sequence[Hash32],
state_slot: SlotNumber,
slot: SlotNumber,
latest_block_roots_length: int) -> Hash32:
"""
Return the element from ``target_list`` by the ``target_slot`` number,
where the the element should be at ``target_slot - slot_relative_position``th
element of the given ``target_list``.
Return the block root at a recent ``slot``.
"""
target_list_length = len(target_list)

if target_slot < slot_relative_position:
raise ValueError(
"target_slot (%s) should be greater than or equal to slot_relative_position (%s)" %
(target_slot, slot_relative_position)
if state_slot > slot + latest_block_roots_length:
raise ValidationError(
"state.slot ({}) should be less than or equal to "
"(slot + latest_block_roots_length) ({}), "
"where slot={}, latest_block_roots_length={}".format(
state_slot,
slot + latest_block_roots_length,
slot,
latest_block_roots_length,
)
)

if target_slot >= slot_relative_position + target_list_length:
raise ValueError(
"target_slot (%s) should be less than "
"slot_relative_position (%s) + target_list_length (%s)" %
(target_slot, slot_relative_position, target_list_length)
if slot >= state_slot:
raise ValidationError(
"slot ({}) should be less than state.slot ({})".format(
slot,
state_slot,
)
)
return target_list[target_slot - slot_relative_position]
return latest_block_roots[slot % latest_block_roots_length]


#
# Get block root
#
def get_block_root(
latest_block_roots: Sequence[Hash32],
current_slot: SlotNumber,
slot: SlotNumber) -> Hash32:
state: 'BeaconState',
slot: SlotNumber,
latest_block_roots_length: int) -> Hash32:
"""
Returns the block root at a recent ``slot``.
Return the block root at a recent ``slot``.
"""
slot_relative_position = SlotNumber(current_slot - len(latest_block_roots))
return _get_element_from_recent_list(
latest_block_roots,
return _get_block_root(
state.latest_block_roots,
state.slot,
slot,
slot_relative_position,
latest_block_roots_length,
)


Expand All @@ -116,21 +118,29 @@ def _get_shard_committees_at_slot(
shard_committees_at_slots: Sequence[Sequence[ShardCommittee]],
slot: SlotNumber,
epoch_length: int) -> Iterable[ShardCommittee]:
if len(shard_committees_at_slots) != epoch_length * 2:
raise ValueError(
"Length of shard_committees_at_slots != epoch_length * 2"
"\texpected: %s, found: %s" % (
epoch_length * 2, len(shard_committees_at_slots)

earliest_slot_in_array = state_slot - (state_slot % epoch_length) - epoch_length

if earliest_slot_in_array > slot:
raise ValidationError(
"earliest_slot_in_array ({}) should be less than or equal to slot ({})".format(
earliest_slot_in_array,
slot,
)
)
if slot >= earliest_slot_in_array + epoch_length * 2:
raise ValidationError(
"slot ({}) should be less than "
"(earliest_slot_in_array + epoch_length * 2) ({}), "
"where earliest_slot_in_array={}, epoch_length={}".format(
slot,
earliest_slot_in_array + epoch_length * 2,
earliest_slot_in_array,
epoch_length,
)
)

slot_relative_position = SlotNumber(state_slot - epoch_length)

yield from _get_element_from_recent_list(
shard_committees_at_slots,
slot,
slot_relative_position,
)
return shard_committees_at_slots[slot - earliest_slot_in_array]


def get_shard_committees_at_slot(state: 'BeaconState',
Expand Down
1 change: 1 addition & 0 deletions eth/beacon/state_machines/forks/serenity/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def process_attestations(state: BeaconState,
attestation,
config.EPOCH_LENGTH,
config.MIN_ATTESTATION_INCLUSION_DELAY,
config.LATEST_BLOCK_ROOTS_LENGTH,
)

# update_latest_attestations
Expand Down
12 changes: 6 additions & 6 deletions eth/beacon/state_machines/forks/serenity/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
ValidationError,
)
import rlp
from typing import Type

from eth.constants import (
ZERO_HASH32,
Expand Down Expand Up @@ -69,10 +68,11 @@ def validate_serenity_proposer_signature(state: BeaconState,
#
# Attestation validation
#
def validate_serenity_attestation(state: Type[BeaconState],
def validate_serenity_attestation(state: BeaconState,
attestation: Attestation,
epoch_length: int,
min_attestation_inclusion_delay: int) -> None:
min_attestation_inclusion_delay: int,
latest_block_roots_length: int) -> None:
"""
Validate the given ``attestation``.
Raise ``ValidationError`` if it's invalid.
Expand All @@ -96,9 +96,9 @@ def validate_serenity_attestation(state: Type[BeaconState],
validate_serenity_attestation_justified_block_root(
attestation.data,
justified_block_root=get_block_root(
state.latest_block_roots,
current_slot=state.slot,
state=state,
slot=attestation.data.justified_slot,
latest_block_roots_length=latest_block_roots_length,
),
)

Expand Down Expand Up @@ -237,7 +237,7 @@ def validate_serenity_attestation_shard_block_root(attestation_data: Attestation
)


def validate_serenity_attestation_aggregate_signature(state: Type[BeaconState],
def validate_serenity_attestation_aggregate_signature(state: BeaconState,
attestation: Attestation,
epoch_length: int) -> None:
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/beacon/state_machines/forks/test_serenity_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def create_mock_signed_attestations_at_slot(state,
shard=shard_committee.shard,
justified_slot=state.previous_justified_slot,
justified_block_root=get_block_root(
state.latest_block_roots,
state.slot,
state,
state.previous_justified_slot,
config.LATEST_BLOCK_ROOTS_LENGTH,
),
latest_crosslink_root=latest_crosslink_root,
shard_block_root=ZERO_HASH32,
Expand Down
7 changes: 2 additions & 5 deletions tests/beacon/state_machines/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
'shard_count'
),
[
(10, 2, 1, 2, 2)
(10, 10, 1, 2, 2)
]
)
def test_demo(base_db,
Expand All @@ -45,10 +45,7 @@ def test_demo(base_db,

# Sign block
beacon_proposer_index = get_beacon_proposer_index(
# TODO: use `state` when the bug of `get_shard_committees_at_slot` is fixed.
state.copy(
slot=state.slot + 2,
),
state,
block.slot,
config.EPOCH_LENGTH,
)
Expand Down
87 changes: 31 additions & 56 deletions tests/beacon/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,16 @@
from eth.beacon.types.states import BeaconState
from eth.beacon.types.validator_records import ValidatorRecord
from eth.beacon.helpers import (
_get_element_from_recent_list,
_get_block_root,
_get_shard_committees_at_slot,
get_active_validator_indices,
get_attestation_participants,
get_beacon_proposer_index,
get_block_root,
get_effective_balance,
get_domain,
get_fork_version,
get_new_shuffling,
get_new_validator_registry_delta_chain_tip,
_get_shard_committees_at_slot,
get_block_committees_info,
get_pubkey_for_indices,
generate_aggregate_pubkeys,
Expand Down Expand Up @@ -88,56 +87,29 @@ def get_sample_shard_committees_at_slots(num_slot,

def generate_mock_latest_block_roots(
genesis_block,
current_block_number,
epoch_length):
chain_length = (current_block_number // epoch_length + 1) * epoch_length
current_slot,
epoch_length,
latest_block_roots_length):
assert current_slot < latest_block_roots_length

chain_length = (current_slot // epoch_length + 1) * epoch_length
blocks = get_pseudo_chain(chain_length, genesis_block)
latest_block_roots = [
b'\x00' * 32
for i
in range(epoch_length * 2 - current_block_number)
] + [block.root for block in blocks[:current_block_number]]
block.hash
for block in blocks[:current_slot]
] + [
ZERO_HASH32
for _ in range(latest_block_roots_length - current_slot)
]
return blocks, latest_block_roots


@pytest.mark.parametrize(
(
'target_list,target_slot,slot_relative_position,result'
),
[
([i for i in range(5)], 10, 7, 3),
([], 1, 1, ValueError()),
# target_slot < slot_relative_position
([i for i in range(5)], 1, 2, ValueError()),
# target_slot >= slot_relative_position + target_list_length
([i for i in range(5)], 6, 1, ValueError()),
],
)
def test_get_element_from_recent_list(target_list,
target_slot,
slot_relative_position,
result):
if isinstance(result, Exception):
with pytest.raises(ValueError):
_get_element_from_recent_list(
target_list,
target_slot,
slot_relative_position,
)
else:
assert result == _get_element_from_recent_list(
target_list,
target_slot,
slot_relative_position,
)


#
# Get block rootes
#
@pytest.mark.parametrize(
(
'current_block_number,target_slot,success'
'current_slot,target_slot,success'
),
[
(10, 0, True),
Expand All @@ -148,30 +120,34 @@ def test_get_element_from_recent_list(target_list,
(128, 128, False),
],
)
def test_get_block_root(current_block_number,
def test_get_block_root(current_slot,
target_slot,
success,
epoch_length,
latest_block_roots_length,
sample_block):
blocks, latest_block_roots = generate_mock_latest_block_roots(
sample_block,
current_block_number,
current_slot,
epoch_length,
latest_block_roots_length,
)

if success:
block_root = get_block_root(
block_root = _get_block_root(
latest_block_roots,
current_block_number,
current_slot,
target_slot,
latest_block_roots_length,
)
assert block_root == blocks[target_slot].root
else:
with pytest.raises(ValueError):
get_block_root(
with pytest.raises(ValidationError):
_get_block_root(
latest_block_roots,
current_block_number,
current_slot,
target_slot,
latest_block_roots_length,
)


Expand Down Expand Up @@ -207,15 +183,14 @@ def test_get_block_root(current_block_number,
64,
True,
),
# The length of shard_committees_at_slots != epoch_length * 2
(
100,
64,
64,
127,
1,
128,
10,
0,
False,
1,
True,
),
# slot is too small
(
Expand Down Expand Up @@ -265,7 +240,7 @@ def test_get_shard_committees_at_slot(
assert len(shard_committees) > 0
assert len(shard_committees[0].committee) > 0
else:
with pytest.raises(ValueError):
with pytest.raises(ValidationError):
_get_shard_committees_at_slot(
state_slot=state_slot,
shard_committees_at_slots=shard_committees_at_slots,
Expand Down

0 comments on commit 000333a

Please sign in to comment.