Skip to content

Commit

Permalink
Extend round trip API to remaining LES protocol commands
Browse files Browse the repository at this point in the history
  • Loading branch information
hoi committed Oct 7, 2018
1 parent c413ec9 commit 353357c
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 3 deletions.
29 changes: 29 additions & 0 deletions trinity/protocol/les/exchanges.py
Expand Up @@ -11,20 +11,27 @@
from trinity.protocol.common.exchanges import (
BaseExchange,
)
# Q: What is the order to imports? and why sometimes new line between imports and sometimes not?
from trinity.protocol.common.types import BlockBodyBundles
from trinity.rlp.block_body import BlockBody
from trinity.utils.les import (
gen_request_id,
)

from .normalizers import (
BlockHeadersNormalizer,
GetBlockBodiesNormalizer,
)
from .requests import (
GetBlockHeadersRequest,
GetBlockBodiesRequest,
)
from .trackers import (
GetBlockHeadersTracker,
GetBlockBodiesTracker,
)
from .validators import (
GetBlockBodiesValidator,
GetBlockHeadersValidator,
match_payload_request_id,
)
Expand Down Expand Up @@ -61,3 +68,25 @@ async def __call__( # type: ignore
match_payload_request_id,
timeout,
)

# Q: Where can I find the correct signature for this class?
class GetBlockBodiesExchange(LESExchange[Tuple[BlockBody, ...]]):
_normalizer = GetBlockBodiesNormalizer()
request_class = GetBlockBodiesRequest
tracker_class = GetBlockBodiesTracker

async def __call__(self, # type: ignore
headers: Tuple[BlockHeader, ...],
timeout: float = None) -> BlockBodyBundles:
validator = GetBlockBodiesValidator(headers)

block_hashes = tuple(header.hash for header in headers)
request = self.request_class(block_hashes, gen_request_id())

return await self.get_result(
request,
self._normalizer,
validator,
match_payload_request_id,
timeout,
)
11 changes: 9 additions & 2 deletions trinity/protocol/les/handlers.py
Expand Up @@ -2,10 +2,17 @@
BaseChainExchangeHandler,
)

from .exchanges import GetBlockHeadersExchange
from .exchanges import (
GetBlockBodiesExchange,
GetBlockHeadersExchange,
)


class LESExchangeHandler(BaseChainExchangeHandler):
class ETHExchangeHandler(BaseChainExchangeHandler):
_exchange_config = {
'get_block_bodies': GetBlockBodiesExchange,
'get_block_headers': GetBlockHeadersExchange,
}

# These are needed only to please mypy.
get_block_bodies: GetBlockBodiesExchange
9 changes: 9 additions & 0 deletions trinity/protocol/les/normalizers.py
Expand Up @@ -8,13 +8,22 @@
from eth.rlp.headers import BlockHeader

from trinity.protocol.common.normalizers import BaseNormalizer
from trinity.rlp.block_body import BlockBody

TResult = TypeVar('TResult')
LESNormalizer = BaseNormalizer[Dict[str, Any], TResult]


# Q: Shouldn't this be named GetBlockHeadersNormalizer?
class BlockHeadersNormalizer(LESNormalizer[Tuple[BlockHeader, ...]]):
@staticmethod
def normalize_result(message: Dict[str, Any]) -> Tuple[BlockHeader, ...]:
result = message['headers']
return result


class GetBlockBodiesNormalizer(LESNormalizer[Tuple[BlockBody, ...]]):
@staticmethod
def normalize_result(message: Dict[str, Any]) -> Tuple[BlockBody, ...]:
result = message['bodies']
return result
21 changes: 20 additions & 1 deletion trinity/protocol/les/requests.py
@@ -1,9 +1,13 @@
from typing import (
Any,
Dict,
Tuple,
)

from eth_typing import BlockIdentifier
from eth_typing import (
BlockIdentifier,
Hash32,
)

from p2p.protocol import BaseRequest

Expand All @@ -16,6 +20,8 @@
BlockHeaders,
GetBlockHeaders,
GetBlockHeadersQuery,
BlockBodies,
GetBlockBodies,
)


Expand Down Expand Up @@ -59,3 +65,16 @@ def __init__(self,
reverse,
),
}


class GetBlockBodiesRequest(BaseRequest[Tuple[Hash32, ...]]):
cmd_type = GetBlockBodies
response_type = BlockBodies

def __init__(self,
block_hashes: Tuple[Hash32, ...],
request_id: int) -> None:
self.command_payload = {
'request_id': request_id,
'block_hashes': block_hashes,
}
20 changes: 20 additions & 0 deletions trinity/protocol/les/trackers.py
Expand Up @@ -6,10 +6,12 @@
from eth.rlp.headers import BlockHeader

from trinity.protocol.common.trackers import BasePerformanceTracker
from trinity.rlp.block_body import BlockBody
from trinity.utils.headers import sequence_builder

from .requests import (
GetBlockHeadersRequest,
GetBlockBodiesRequest,
)


Expand Down Expand Up @@ -37,3 +39,21 @@ def _get_result_size(self, result: Tuple[BlockHeader, ...]) -> int:

def _get_result_item_count(self, result: Tuple[BlockHeader, ...]) -> int:
return len(result)


BaseGetBlockBodiesTracker = BasePerformanceTracker[
GetBlockBodiesRequest,
Tuple[BlockBody, ...],
]


# Q: Where can I find the signature for this class?
class GetBlockBodiesTracker(BaseGetBlockBodiesTracker):
def _get_request_size(self, request: GetBlockBodiesRequest) -> Optional[int]:
return len(request.command_payload['block_hashes'])

def _get_result_size(self, result: Tuple[BlockBody, ...]) -> int:
return len(result)

def _get_result_item_count(self, result: Tuple[BlockBody, ...]) -> int:
return len(result)
27 changes: 27 additions & 0 deletions trinity/protocol/les/validators.py
@@ -1,22 +1,49 @@
from typing import (
Any,
Dict,
Tuple,
)

from eth.rlp.headers import BlockHeader

from eth_utils import (
ValidationError,
)

from trinity.protocol.common.validators import (
BaseValidator,
BaseBlockHeadersValidator,
)
from trinity.protocol.common.types import BlockBodyBundles

from . import constants


class GetBlockHeadersValidator(BaseBlockHeadersValidator):
protocol_max_request_size = constants.MAX_HEADERS_FETCH


class GetBlockBodiesValidator(BaseValidator[BlockBodyBundles]):
def __init__(self, headers: Tuple[BlockHeader, ...]) -> None:
self.headers = headers

def validate_result(self, response: BlockBodyBundles) -> None:
expected_keys = {
(header.transaction_root, header.uncles_hash)
for header in self.headers
}
actual_keys = {
(txn_root, uncles_hash)
for body, (txn_root, trie_data), uncles_hash
in response
}
unexpected_keys = actual_keys.difference(expected_keys)
if unexpected_keys:
raise ValidationError(
"Got {0} unexpected block bodies".format(len(unexpected_keys))
)


def match_payload_request_id(request: Dict[str, Any], response: Dict[str, Any]) -> None:
if request['request_id'] != response['request_id']:
raise ValidationError("Request `id` does not match")

0 comments on commit 353357c

Please sign in to comment.