diff --git a/trinity/protocol/les/exchanges.py b/trinity/protocol/les/exchanges.py index ba0b249c2e..f32b5c5b0b 100644 --- a/trinity/protocol/les/exchanges.py +++ b/trinity/protocol/les/exchanges.py @@ -11,20 +11,27 @@ from trinity.protocol.common.exchanges import ( BaseExchange, ) +# Q: What is the order to imports? and why sometimes new line between imports and sometimes not? +from trinity.protocol.common.types import BlockBodyBundles +from trinity.rlp.block_body import BlockBody from trinity.utils.les import ( gen_request_id, ) from .normalizers import ( BlockHeadersNormalizer, + GetBlockBodiesNormalizer, ) from .requests import ( GetBlockHeadersRequest, + GetBlockBodiesRequest, ) from .trackers import ( GetBlockHeadersTracker, + GetBlockBodiesTracker, ) from .validators import ( + GetBlockBodiesValidator, GetBlockHeadersValidator, match_payload_request_id, ) @@ -61,3 +68,25 @@ async def __call__( # type: ignore match_payload_request_id, timeout, ) + +# Q: Where can I find the correct signature for this class? +class GetBlockBodiesExchange(LESExchange[Tuple[BlockBody, ...]]): + _normalizer = GetBlockBodiesNormalizer() + request_class = GetBlockBodiesRequest + tracker_class = GetBlockBodiesTracker + + async def __call__(self, # type: ignore + headers: Tuple[BlockHeader, ...], + timeout: float = None) -> BlockBodyBundles: + validator = GetBlockBodiesValidator(headers) + + block_hashes = tuple(header.hash for header in headers) + request = self.request_class(block_hashes, gen_request_id()) + + return await self.get_result( + request, + self._normalizer, + validator, + match_payload_request_id, + timeout, + ) diff --git a/trinity/protocol/les/handlers.py b/trinity/protocol/les/handlers.py index ce12300016..36abb01a70 100644 --- a/trinity/protocol/les/handlers.py +++ b/trinity/protocol/les/handlers.py @@ -2,10 +2,17 @@ BaseChainExchangeHandler, ) -from .exchanges import GetBlockHeadersExchange +from .exchanges import ( + GetBlockBodiesExchange, + GetBlockHeadersExchange, +) -class LESExchangeHandler(BaseChainExchangeHandler): +class ETHExchangeHandler(BaseChainExchangeHandler): _exchange_config = { + 'get_block_bodies': GetBlockBodiesExchange, 'get_block_headers': GetBlockHeadersExchange, } + + # These are needed only to please mypy. + get_block_bodies: GetBlockBodiesExchange diff --git a/trinity/protocol/les/normalizers.py b/trinity/protocol/les/normalizers.py index 6618b059aa..4a7920641a 100644 --- a/trinity/protocol/les/normalizers.py +++ b/trinity/protocol/les/normalizers.py @@ -8,13 +8,22 @@ from eth.rlp.headers import BlockHeader from trinity.protocol.common.normalizers import BaseNormalizer +from trinity.rlp.block_body import BlockBody TResult = TypeVar('TResult') LESNormalizer = BaseNormalizer[Dict[str, Any], TResult] +# Q: Shouldn't this be named GetBlockHeadersNormalizer? class BlockHeadersNormalizer(LESNormalizer[Tuple[BlockHeader, ...]]): @staticmethod def normalize_result(message: Dict[str, Any]) -> Tuple[BlockHeader, ...]: result = message['headers'] return result + + +class GetBlockBodiesNormalizer(LESNormalizer[Tuple[BlockBody, ...]]): + @staticmethod + def normalize_result(message: Dict[str, Any]) -> Tuple[BlockBody, ...]: + result = message['bodies'] + return result diff --git a/trinity/protocol/les/requests.py b/trinity/protocol/les/requests.py index ced779d016..950d25dd16 100644 --- a/trinity/protocol/les/requests.py +++ b/trinity/protocol/les/requests.py @@ -1,9 +1,13 @@ from typing import ( Any, Dict, + Tuple, ) -from eth_typing import BlockIdentifier +from eth_typing import ( + BlockIdentifier, + Hash32, +) from p2p.protocol import BaseRequest @@ -16,6 +20,8 @@ BlockHeaders, GetBlockHeaders, GetBlockHeadersQuery, + BlockBodies, + GetBlockBodies, ) @@ -59,3 +65,16 @@ def __init__(self, reverse, ), } + + +class GetBlockBodiesRequest(BaseRequest[Tuple[Hash32, ...]]): + cmd_type = GetBlockBodies + response_type = BlockBodies + + def __init__(self, + block_hashes: Tuple[Hash32, ...], + request_id: int) -> None: + self.command_payload = { + 'request_id': request_id, + 'block_hashes': block_hashes, + } diff --git a/trinity/protocol/les/trackers.py b/trinity/protocol/les/trackers.py index e054b7d2a3..31729aaf3d 100644 --- a/trinity/protocol/les/trackers.py +++ b/trinity/protocol/les/trackers.py @@ -6,10 +6,12 @@ from eth.rlp.headers import BlockHeader from trinity.protocol.common.trackers import BasePerformanceTracker +from trinity.rlp.block_body import BlockBody from trinity.utils.headers import sequence_builder from .requests import ( GetBlockHeadersRequest, + GetBlockBodiesRequest, ) @@ -37,3 +39,21 @@ def _get_result_size(self, result: Tuple[BlockHeader, ...]) -> int: def _get_result_item_count(self, result: Tuple[BlockHeader, ...]) -> int: return len(result) + + +BaseGetBlockBodiesTracker = BasePerformanceTracker[ + GetBlockBodiesRequest, + Tuple[BlockBody, ...], +] + + +# Q: Where can I find the signature for this class? +class GetBlockBodiesTracker(BaseGetBlockBodiesTracker): + def _get_request_size(self, request: GetBlockBodiesRequest) -> Optional[int]: + return len(request.command_payload['block_hashes']) + + def _get_result_size(self, result: Tuple[BlockBody, ...]) -> int: + return len(result) + + def _get_result_item_count(self, result: Tuple[BlockBody, ...]) -> int: + return len(result) diff --git a/trinity/protocol/les/validators.py b/trinity/protocol/les/validators.py index 38e492b921..a61e3f560f 100644 --- a/trinity/protocol/les/validators.py +++ b/trinity/protocol/les/validators.py @@ -1,15 +1,21 @@ from typing import ( Any, Dict, + Tuple, ) +from eth.rlp.headers import BlockHeader + from eth_utils import ( ValidationError, ) from trinity.protocol.common.validators import ( + BaseValidator, BaseBlockHeadersValidator, ) +from trinity.protocol.common.types import BlockBodyBundles + from . import constants @@ -17,6 +23,27 @@ class GetBlockHeadersValidator(BaseBlockHeadersValidator): protocol_max_request_size = constants.MAX_HEADERS_FETCH +class GetBlockBodiesValidator(BaseValidator[BlockBodyBundles]): + def __init__(self, headers: Tuple[BlockHeader, ...]) -> None: + self.headers = headers + + def validate_result(self, response: BlockBodyBundles) -> None: + expected_keys = { + (header.transaction_root, header.uncles_hash) + for header in self.headers + } + actual_keys = { + (txn_root, uncles_hash) + for body, (txn_root, trie_data), uncles_hash + in response + } + unexpected_keys = actual_keys.difference(expected_keys) + if unexpected_keys: + raise ValidationError( + "Got {0} unexpected block bodies".format(len(unexpected_keys)) + ) + + def match_payload_request_id(request: Dict[str, Any], response: Dict[str, Any]) -> None: if request['request_id'] != response['request_id']: raise ValidationError("Request `id` does not match")