From 5bc7acf2857284d12f6349cae8165bf085183421 Mon Sep 17 00:00:00 2001 From: Jason Carver Date: Thu, 4 Oct 2018 14:55:56 -0700 Subject: [PATCH] Use BaseChainPeer instead of HeaderRequestingPeer --- trinity/protocol/common/handlers.py | 16 ++++++++++++++++ trinity/protocol/common/peer.py | 12 ++++++++++++ trinity/protocol/eth/handlers.py | 5 ++--- trinity/protocol/les/handlers.py | 6 ++---- trinity/sync/common/chain.py | 28 ++++++++++++---------------- trinity/sync/full/chain.py | 6 ++---- trinity/sync/light/chain.py | 7 ++----- trinity/sync/light/service.py | 7 ++++++- 8 files changed, 54 insertions(+), 33 deletions(-) diff --git a/trinity/protocol/common/handlers.py b/trinity/protocol/common/handlers.py index 200f3e2bac..df11d72238 100644 --- a/trinity/protocol/common/handlers.py +++ b/trinity/protocol/common/handlers.py @@ -1,11 +1,16 @@ from abc import ABC, abstractmethod from typing import ( Any, + Awaitable, + Callable, Dict, Iterator, + Tuple, Type, ) +from mypy_extensions import DefaultArg +from eth.rlp.headers import BlockHeader from p2p.peer import BasePeer from trinity.protocol.common.exchanges import ( @@ -46,3 +51,14 @@ def get_stats(self) -> Dict[str, str]: for exchange in self } + + +# This class is only needed to please mypy for type checking +BlockHeadersCallable = Callable[ + [BaseExchangeHandler, int, int, DefaultArg(int, 'skip'), DefaultArg(int, 'reverse')], + Awaitable[Tuple[BlockHeader, ...]] +] + + +class BaseChainExchangeHandler(BaseExchangeHandler): + get_block_headers: BlockHeadersCallable diff --git a/trinity/protocol/common/peer.py b/trinity/protocol/common/peer.py index fcccacba2d..45429ed27f 100644 --- a/trinity/protocol/common/peer.py +++ b/trinity/protocol/common/peer.py @@ -1,3 +1,4 @@ +from abc import abstractmethod import operator import random from typing import ( @@ -28,6 +29,7 @@ ) from trinity.db.header import BaseAsyncHeaderDB +from trinity.protocol.common.handlers import BaseChainExchangeHandler from .boot import DAOCheckBootManager from .context import ChainContext @@ -47,6 +49,16 @@ class BaseChainPeer(BasePeer): head_td: int = None head_hash: Hash32 = None + @property + @abstractmethod + def requests(self) -> BaseChainExchangeHandler: + pass + + @property + @abstractmethod + def max_headers_fetch(self) -> int: + pass + @property def headerdb(self) -> BaseAsyncHeaderDB: return self.context.headerdb diff --git a/trinity/protocol/eth/handlers.py b/trinity/protocol/eth/handlers.py index d5c45b200c..0f67c78359 100644 --- a/trinity/protocol/eth/handlers.py +++ b/trinity/protocol/eth/handlers.py @@ -1,5 +1,5 @@ from trinity.protocol.common.handlers import ( - BaseExchangeHandler, + BaseChainExchangeHandler, ) from .exchanges import ( @@ -10,7 +10,7 @@ ) -class ETHExchangeHandler(BaseExchangeHandler): +class ETHExchangeHandler(BaseChainExchangeHandler): _exchange_config = { 'get_block_bodies': GetBlockBodiesExchange, 'get_block_headers': GetBlockHeadersExchange, @@ -20,6 +20,5 @@ class ETHExchangeHandler(BaseExchangeHandler): # These are needed only to please mypy. get_block_bodies: GetBlockBodiesExchange - get_block_headers: GetBlockHeadersExchange get_node_data: GetNodeDataExchange get_receipts: GetReceiptsExchange diff --git a/trinity/protocol/les/handlers.py b/trinity/protocol/les/handlers.py index b1555c5159..ce12300016 100644 --- a/trinity/protocol/les/handlers.py +++ b/trinity/protocol/les/handlers.py @@ -1,13 +1,11 @@ from trinity.protocol.common.handlers import ( - BaseExchangeHandler, + BaseChainExchangeHandler, ) from .exchanges import GetBlockHeadersExchange -class LESExchangeHandler(BaseExchangeHandler): +class LESExchangeHandler(BaseChainExchangeHandler): _exchange_config = { 'get_block_headers': GetBlockHeadersExchange, } - - get_block_headers: GetBlockHeadersExchange diff --git a/trinity/sync/common/chain.py b/trinity/sync/common/chain.py index 4bc0ef51fc..bb6a01b144 100644 --- a/trinity/sync/common/chain.py +++ b/trinity/sync/common/chain.py @@ -6,7 +6,6 @@ AsyncIterator, Iterator, Tuple, - Union, cast, ) @@ -34,13 +33,10 @@ from trinity.db.header import AsyncHeaderDB from trinity.p2p.handlers import PeerRequestHandler -from trinity.protocol.eth.peer import ETHPeer, ETHPeerPool -from trinity.protocol.les.peer import LESPeer, LESPeerPool +from trinity.protocol.common.peer import BaseChainPeer, BaseChainPeerPool +from trinity.protocol.eth.peer import ETHPeer from trinity.utils.datastructures import TaskQueue -HeaderRequestingPeer = Union[ETHPeer, LESPeer] -AnyPeerPool = Union[ETHPeerPool, LESPeerPool] - class BaseHeaderChainSyncer(BaseService, PeerSubscriber): """ @@ -57,14 +53,14 @@ class BaseHeaderChainSyncer(BaseService, PeerSubscriber): def __init__(self, chain: AsyncChain, db: AsyncHeaderDB, - peer_pool: AnyPeerPool, + peer_pool: BaseChainPeerPool, token: CancelToken = None) -> None: super().__init__(token) self.chain = chain self.db = db self.peer_pool = peer_pool self._handler = PeerRequestHandler(self.db, self.logger, self.cancel_token) - self._sync_requests: asyncio.Queue[HeaderRequestingPeer] = asyncio.Queue() + self._sync_requests: asyncio.Queue[BaseChainPeer] = asyncio.Queue() self._peer_header_syncer: 'PeerHeaderSyncer' = None self._last_target_header_hash = None @@ -89,7 +85,7 @@ def get_target_header_hash(self) -> Hash32: return self._last_target_header_hash def register_peer(self, peer: BasePeer) -> None: - self._sync_requests.put_nowait(cast(HeaderRequestingPeer, self.peer_pool.highest_td_peer)) + self._sync_requests.put_nowait(cast(BaseChainPeer, self.peer_pool.highest_td_peer)) async def _handle_msg_loop(self) -> None: while self.is_operational: @@ -97,9 +93,9 @@ async def _handle_msg_loop(self) -> None: # Our handle_msg() method runs cpu-intensive tasks in sub-processes so that the main # loop can keep processing msgs, and that's why we use self.run_task() instead of # awaiting for it to finish here. - self.run_task(self.handle_msg(cast(HeaderRequestingPeer, peer), cmd, msg)) + self.run_task(self.handle_msg(cast(BaseChainPeer, peer), cmd, msg)) - async def handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command, + async def handle_msg(self, peer: BaseChainPeer, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None: try: await self._handle_msg(peer, cmd, msg) @@ -128,7 +124,7 @@ def _syncing(self) -> bool: return self._peer_header_syncer is not None @contextmanager - def _get_peer_header_syncer(self, peer: HeaderRequestingPeer) -> Iterator['PeerHeaderSyncer']: + def _get_peer_header_syncer(self, peer: BaseChainPeer) -> Iterator['PeerHeaderSyncer']: if self._syncing: raise ValidationError("Cannot sync headers from two peers at the same time") @@ -150,7 +146,7 @@ def _get_peer_header_syncer(self, peer: HeaderRequestingPeer) -> Iterator['PeerH self._last_target_header_hash = self._peer_header_syncer.get_target_header_hash() self._peer_header_syncer = None - async def sync(self, peer: HeaderRequestingPeer) -> None: + async def sync(self, peer: BaseChainPeer) -> None: if self._syncing: self.logger.debug( "Got a NewBlock or a new peer, but already syncing so doing nothing") @@ -167,7 +163,7 @@ async def sync(self, peer: HeaderRequestingPeer) -> None: await self.wait(self.header_queue.add(new_headers)) @abstractmethod - async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command, + async def _handle_msg(self, peer: BaseChainPeer, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None: raise NotImplementedError("Must be implemented by subclasses") @@ -184,7 +180,7 @@ class PeerHeaderSyncer(BaseService): def __init__(self, chain: AsyncChain, db: AsyncHeaderDB, - peer: HeaderRequestingPeer, + peer: BaseChainPeer, token: CancelToken = None) -> None: super().__init__(token) self.chain = chain @@ -346,7 +342,7 @@ async def next_header_batch(self) -> AsyncIterator[Tuple[BlockHeader, ...]]: start_at = last_received_header.block_number + 1 async def _request_headers( - self, peer: HeaderRequestingPeer, start_at: int) -> Tuple[BlockHeader, ...]: + self, peer: BaseChainPeer, start_at: int) -> Tuple[BlockHeader, ...]: """Fetch a batch of headers starting at start_at and return the ones we're missing.""" self.logger.debug("Requsting chain of headers from %s starting at #%d", peer, start_at) diff --git a/trinity/sync/full/chain.py b/trinity/sync/full/chain.py index c85d785d1a..2d8e4f0654 100644 --- a/trinity/sync/full/chain.py +++ b/trinity/sync/full/chain.py @@ -15,7 +15,6 @@ Set, Tuple, Type, - Union, cast, ) @@ -49,6 +48,7 @@ from trinity.db.chain import AsyncChainDB from trinity.db.header import AsyncHeaderDB +from trinity.protocol.common.peer import BaseChainPeer from trinity.protocol.eth import commands from trinity.protocol.eth.constants import ( MAX_BODIES_FETCH, @@ -57,7 +57,6 @@ ) from trinity.protocol.eth.peer import ETHPeer, ETHPeerPool from trinity.protocol.eth.requests import HeaderRequest -from trinity.protocol.les.peer import LESPeer from trinity.rlp.block_body import BlockBody from trinity.sync.common.chain import BaseHeaderChainSyncer from trinity.utils.datastructures import ( @@ -69,7 +68,6 @@ ) from trinity.utils.timer import Timer -HeaderRequestingPeer = Union[LESPeer, ETHPeer] # (ReceiptBundle, (Receipt, (root_hash, receipt_trie_data)) ReceiptBundle = Tuple[Tuple[Receipt, ...], Tuple[Hash32, Dict[Hash32, bytes]]] # (BlockBody, (txn_root, txn_trie_data), uncles_hash) @@ -334,7 +332,7 @@ async def _request_block_bodies( return block_body_bundles - async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command, + async def _handle_msg(self, peer: BaseChainPeer, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None: peer = cast(ETHPeer, peer) diff --git a/trinity/sync/light/chain.py b/trinity/sync/light/chain.py index 2f5cb1f0c9..7076e40135 100644 --- a/trinity/sync/light/chain.py +++ b/trinity/sync/light/chain.py @@ -4,7 +4,6 @@ Dict, Set, Type, - Union, ) from p2p.protocol import ( @@ -12,15 +11,13 @@ _DecodedMsgType, ) -from trinity.protocol.eth.peer import ETHPeer +from trinity.protocol.common.peer import BaseChainPeer from trinity.protocol.les import commands from trinity.protocol.les.peer import LESPeer from trinity.protocol.les.requests import HeaderRequest from trinity.sync.common.chain import BaseHeaderChainSyncer from trinity.utils.timer import Timer -HeaderRequestingPeer = Union[ETHPeer, LESPeer] - class LightChainSyncer(BaseHeaderChainSyncer): _exit_on_sync_complete = False @@ -35,7 +32,7 @@ async def _run(self) -> None: self.run_task(self._persist_headers()) await super()._run() - async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: Command, + async def _handle_msg(self, peer: BaseChainPeer, cmd: Command, msg: _DecodedMsgType) -> None: if isinstance(cmd, commands.Announce): self._sync_requests.put_nowait(peer) diff --git a/trinity/sync/light/service.py b/trinity/sync/light/service.py index ce62a2901f..ba801c8bd7 100644 --- a/trinity/sync/light/service.py +++ b/trinity/sync/light/service.py @@ -346,7 +346,12 @@ async def _get_block_header_by_hash(self, block_hash: Hash32, peer: LESPeer) -> """ self.logger.debug("Fetching header %s from %s", encode_hex(block_hash), peer) max_headers = 1 - headers = await peer.requests.get_block_headers(block_hash, max_headers, 0, False) + headers = await peer.requests.get_block_headers( + block_hash, + max_headers, + skip=0, + reverse=False, + ) if not headers: raise HeaderNotFound("Peer {} has no block with hash {}".format(peer, block_hash)) header = headers[0]