Skip to content

Commit

Permalink
Use BaseChainPeer instead of HeaderRequestingPeer
Browse files Browse the repository at this point in the history
  • Loading branch information
carver committed Oct 4, 2018
1 parent 3061a48 commit 5bc7acf
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 33 deletions.
16 changes: 16 additions & 0 deletions trinity/protocol/common/handlers.py
@@ -1,11 +1,16 @@
from abc import ABC, abstractmethod
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterator,
Tuple,
Type,
)
from mypy_extensions import DefaultArg

from eth.rlp.headers import BlockHeader
from p2p.peer import BasePeer

from trinity.protocol.common.exchanges import (
Expand Down Expand Up @@ -46,3 +51,14 @@ def get_stats(self) -> Dict[str, str]:
for exchange
in self
}


# This class is only needed to please mypy for type checking
BlockHeadersCallable = Callable[
[BaseExchangeHandler, int, int, DefaultArg(int, 'skip'), DefaultArg(int, 'reverse')],
Awaitable[Tuple[BlockHeader, ...]]
]


class BaseChainExchangeHandler(BaseExchangeHandler):
get_block_headers: BlockHeadersCallable
12 changes: 12 additions & 0 deletions trinity/protocol/common/peer.py
@@ -1,3 +1,4 @@
from abc import abstractmethod
import operator
import random
from typing import (
Expand Down Expand Up @@ -28,6 +29,7 @@
)

from trinity.db.header import BaseAsyncHeaderDB
from trinity.protocol.common.handlers import BaseChainExchangeHandler

from .boot import DAOCheckBootManager
from .context import ChainContext
Expand All @@ -47,6 +49,16 @@ class BaseChainPeer(BasePeer):
head_td: int = None
head_hash: Hash32 = None

@property
@abstractmethod
def requests(self) -> BaseChainExchangeHandler:
pass

@property
@abstractmethod
def max_headers_fetch(self) -> int:
pass

@property
def headerdb(self) -> BaseAsyncHeaderDB:
return self.context.headerdb
Expand Down
5 changes: 2 additions & 3 deletions trinity/protocol/eth/handlers.py
@@ -1,5 +1,5 @@
from trinity.protocol.common.handlers import (
BaseExchangeHandler,
BaseChainExchangeHandler,
)

from .exchanges import (
Expand All @@ -10,7 +10,7 @@
)


class ETHExchangeHandler(BaseExchangeHandler):
class ETHExchangeHandler(BaseChainExchangeHandler):
_exchange_config = {
'get_block_bodies': GetBlockBodiesExchange,
'get_block_headers': GetBlockHeadersExchange,
Expand All @@ -20,6 +20,5 @@ class ETHExchangeHandler(BaseExchangeHandler):

# These are needed only to please mypy.
get_block_bodies: GetBlockBodiesExchange
get_block_headers: GetBlockHeadersExchange
get_node_data: GetNodeDataExchange
get_receipts: GetReceiptsExchange
6 changes: 2 additions & 4 deletions trinity/protocol/les/handlers.py
@@ -1,13 +1,11 @@
from trinity.protocol.common.handlers import (
BaseExchangeHandler,
BaseChainExchangeHandler,
)

from .exchanges import GetBlockHeadersExchange


class LESExchangeHandler(BaseExchangeHandler):
class LESExchangeHandler(BaseChainExchangeHandler):
_exchange_config = {
'get_block_headers': GetBlockHeadersExchange,
}

get_block_headers: GetBlockHeadersExchange
28 changes: 12 additions & 16 deletions trinity/sync/common/chain.py
Expand Up @@ -6,7 +6,6 @@
AsyncIterator,
Iterator,
Tuple,
Union,
cast,
)

Expand Down Expand Up @@ -34,13 +33,10 @@

from trinity.db.header import AsyncHeaderDB
from trinity.p2p.handlers import PeerRequestHandler
from trinity.protocol.eth.peer import ETHPeer, ETHPeerPool
from trinity.protocol.les.peer import LESPeer, LESPeerPool
from trinity.protocol.common.peer import BaseChainPeer, BaseChainPeerPool
from trinity.protocol.eth.peer import ETHPeer
from trinity.utils.datastructures import TaskQueue

HeaderRequestingPeer = Union[ETHPeer, LESPeer]
AnyPeerPool = Union[ETHPeerPool, LESPeerPool]


class BaseHeaderChainSyncer(BaseService, PeerSubscriber):
"""
Expand All @@ -57,14 +53,14 @@ class BaseHeaderChainSyncer(BaseService, PeerSubscriber):
def __init__(self,
chain: AsyncChain,
db: AsyncHeaderDB,
peer_pool: AnyPeerPool,
peer_pool: BaseChainPeerPool,
token: CancelToken = None) -> None:
super().__init__(token)
self.chain = chain
self.db = db
self.peer_pool = peer_pool
self._handler = PeerRequestHandler(self.db, self.logger, self.cancel_token)
self._sync_requests: asyncio.Queue[HeaderRequestingPeer] = asyncio.Queue()
self._sync_requests: asyncio.Queue[BaseChainPeer] = asyncio.Queue()
self._peer_header_syncer: 'PeerHeaderSyncer' = None
self._last_target_header_hash = None

Expand All @@ -89,17 +85,17 @@ def get_target_header_hash(self) -> Hash32:
return self._last_target_header_hash

def register_peer(self, peer: BasePeer) -> None:
self._sync_requests.put_nowait(cast(HeaderRequestingPeer, self.peer_pool.highest_td_peer))
self._sync_requests.put_nowait(cast(BaseChainPeer, self.peer_pool.highest_td_peer))

async def _handle_msg_loop(self) -> None:
while self.is_operational:
peer, cmd, msg = await self.wait(self.msg_queue.get())
# Our handle_msg() method runs cpu-intensive tasks in sub-processes so that the main
# loop can keep processing msgs, and that's why we use self.run_task() instead of
# awaiting for it to finish here.
self.run_task(self.handle_msg(cast(HeaderRequestingPeer, peer), cmd, msg))
self.run_task(self.handle_msg(cast(BaseChainPeer, peer), cmd, msg))

async def handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
async def handle_msg(self, peer: BaseChainPeer, cmd: protocol.Command,
msg: protocol._DecodedMsgType) -> None:
try:
await self._handle_msg(peer, cmd, msg)
Expand Down Expand Up @@ -128,7 +124,7 @@ def _syncing(self) -> bool:
return self._peer_header_syncer is not None

@contextmanager
def _get_peer_header_syncer(self, peer: HeaderRequestingPeer) -> Iterator['PeerHeaderSyncer']:
def _get_peer_header_syncer(self, peer: BaseChainPeer) -> Iterator['PeerHeaderSyncer']:
if self._syncing:
raise ValidationError("Cannot sync headers from two peers at the same time")

Expand All @@ -150,7 +146,7 @@ def _get_peer_header_syncer(self, peer: HeaderRequestingPeer) -> Iterator['PeerH
self._last_target_header_hash = self._peer_header_syncer.get_target_header_hash()
self._peer_header_syncer = None

async def sync(self, peer: HeaderRequestingPeer) -> None:
async def sync(self, peer: BaseChainPeer) -> None:
if self._syncing:
self.logger.debug(
"Got a NewBlock or a new peer, but already syncing so doing nothing")
Expand All @@ -167,7 +163,7 @@ async def sync(self, peer: HeaderRequestingPeer) -> None:
await self.wait(self.header_queue.add(new_headers))

@abstractmethod
async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
async def _handle_msg(self, peer: BaseChainPeer, cmd: protocol.Command,
msg: protocol._DecodedMsgType) -> None:
raise NotImplementedError("Must be implemented by subclasses")

Expand All @@ -184,7 +180,7 @@ class PeerHeaderSyncer(BaseService):
def __init__(self,
chain: AsyncChain,
db: AsyncHeaderDB,
peer: HeaderRequestingPeer,
peer: BaseChainPeer,
token: CancelToken = None) -> None:
super().__init__(token)
self.chain = chain
Expand Down Expand Up @@ -346,7 +342,7 @@ async def next_header_batch(self) -> AsyncIterator[Tuple[BlockHeader, ...]]:
start_at = last_received_header.block_number + 1

async def _request_headers(
self, peer: HeaderRequestingPeer, start_at: int) -> Tuple[BlockHeader, ...]:
self, peer: BaseChainPeer, start_at: int) -> Tuple[BlockHeader, ...]:
"""Fetch a batch of headers starting at start_at and return the ones we're missing."""
self.logger.debug("Requsting chain of headers from %s starting at #%d", peer, start_at)

Expand Down
6 changes: 2 additions & 4 deletions trinity/sync/full/chain.py
Expand Up @@ -15,7 +15,6 @@
Set,
Tuple,
Type,
Union,
cast,
)

Expand Down Expand Up @@ -49,6 +48,7 @@

from trinity.db.chain import AsyncChainDB
from trinity.db.header import AsyncHeaderDB
from trinity.protocol.common.peer import BaseChainPeer
from trinity.protocol.eth import commands
from trinity.protocol.eth.constants import (
MAX_BODIES_FETCH,
Expand All @@ -57,7 +57,6 @@
)
from trinity.protocol.eth.peer import ETHPeer, ETHPeerPool
from trinity.protocol.eth.requests import HeaderRequest
from trinity.protocol.les.peer import LESPeer
from trinity.rlp.block_body import BlockBody
from trinity.sync.common.chain import BaseHeaderChainSyncer
from trinity.utils.datastructures import (
Expand All @@ -69,7 +68,6 @@
)
from trinity.utils.timer import Timer

HeaderRequestingPeer = Union[LESPeer, ETHPeer]
# (ReceiptBundle, (Receipt, (root_hash, receipt_trie_data))
ReceiptBundle = Tuple[Tuple[Receipt, ...], Tuple[Hash32, Dict[Hash32, bytes]]]
# (BlockBody, (txn_root, txn_trie_data), uncles_hash)
Expand Down Expand Up @@ -334,7 +332,7 @@ async def _request_block_bodies(

return block_body_bundles

async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
async def _handle_msg(self, peer: BaseChainPeer, cmd: protocol.Command,
msg: protocol._DecodedMsgType) -> None:
peer = cast(ETHPeer, peer)

Expand Down
7 changes: 2 additions & 5 deletions trinity/sync/light/chain.py
Expand Up @@ -4,23 +4,20 @@
Dict,
Set,
Type,
Union,
)

from p2p.protocol import (
Command,
_DecodedMsgType,
)

from trinity.protocol.eth.peer import ETHPeer
from trinity.protocol.common.peer import BaseChainPeer
from trinity.protocol.les import commands
from trinity.protocol.les.peer import LESPeer
from trinity.protocol.les.requests import HeaderRequest
from trinity.sync.common.chain import BaseHeaderChainSyncer
from trinity.utils.timer import Timer

HeaderRequestingPeer = Union[ETHPeer, LESPeer]


class LightChainSyncer(BaseHeaderChainSyncer):
_exit_on_sync_complete = False
Expand All @@ -35,7 +32,7 @@ async def _run(self) -> None:
self.run_task(self._persist_headers())
await super()._run()

async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: Command,
async def _handle_msg(self, peer: BaseChainPeer, cmd: Command,
msg: _DecodedMsgType) -> None:
if isinstance(cmd, commands.Announce):
self._sync_requests.put_nowait(peer)
Expand Down
7 changes: 6 additions & 1 deletion trinity/sync/light/service.py
Expand Up @@ -346,7 +346,12 @@ async def _get_block_header_by_hash(self, block_hash: Hash32, peer: LESPeer) ->
"""
self.logger.debug("Fetching header %s from %s", encode_hex(block_hash), peer)
max_headers = 1
headers = await peer.requests.get_block_headers(block_hash, max_headers, 0, False)
headers = await peer.requests.get_block_headers(
block_hash,
max_headers,
skip=0,
reverse=False,
)
if not headers:
raise HeaderNotFound("Peer {} has no block with hash {}".format(peer, block_hash))
header = headers[0]
Expand Down

0 comments on commit 5bc7acf

Please sign in to comment.