diff --git a/p2p/peer.py b/p2p/peer.py index abdb679058..85a34e9a7a 100644 --- a/p2p/peer.py +++ b/p2p/peer.py @@ -2,6 +2,7 @@ import collections import contextlib import datetime +import functools import logging import operator import random @@ -20,6 +21,7 @@ Dict, Iterator, List, + Set, TYPE_CHECKING, Tuple, Type, @@ -381,14 +383,23 @@ def handle_p2p_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) - raise UnexpectedMessage("Unexpected msg: {} ({})".format(cmd, msg)) def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None: + cmd_type = type(cmd) + if self._subscribers: - for subscriber in self._subscribers: + was_added = tuple( subscriber.add_msg((self, cmd, msg)) + for subscriber + in self._subscribers + ) + if not any(was_added): + self.logger.warn( + "Peer %s has no subscribers for msg type %s", + self, + cmd_type.__name__, + ) else: self.logger.warn("Peer %s has no subscribers, discarding %s msg", self, cmd) - cmd_type = type(cmd) - if cmd_type in self.pending_requests: request, future = self.pending_requests[cmd_type] try: @@ -547,10 +558,28 @@ def __hash__(self) -> int: class PeerSubscriber(ABC): _msg_queue: 'asyncio.Queue[PEER_MSG_TYPE]' = None + @property + @abstractmethod + def subscription_msg_types(self) -> Set[Type[protocol.Command]]: + """ + The `p2p.protocol.Command` types that this class subscribes to. Any + command which is not in this set will not be passed to this subscriber. + + The base command class `p2p.protocol.Command` can be used to enable + **all** command types. + """ + pass + + @functools.lru_cache(maxsize=64) + def is_subscription_command(self, cmd_type: Type[protocol.Command]) -> bool: + return bool(self.subscription_msg_types.intersection( + {cmd_type, protocol.Command} + )) + @property @abstractmethod def msg_queue_maxsize(self) -> int: - raise NotImplementedError("Must be implemented by subclasses") + pass def register_peer(self, peer: BasePeer) -> None: """ @@ -577,16 +606,30 @@ def msg_queue(self) -> 'asyncio.Queue[PEER_MSG_TYPE]': def queue_size(self) -> int: return self.msg_queue.qsize() - def add_msg(self, msg: 'PEER_MSG_TYPE') -> None: + def add_msg(self, msg: 'PEER_MSG_TYPE') -> bool: peer, cmd, _ = msg + + if not self.is_subscription_command(type(cmd)): + if hasattr(self, 'logger'): + self.logger.trace( # type: ignore + "Discarding %s msg from %s; not subscribed to msg type; " + "subscriptions: %s", + cmd, peer, self.subscription_msg_types, + ) + return False + try: - self.logger.trace( # type: ignore - "Adding %s msg from %s to queue; queue_size=%d", cmd, peer, self.queue_size) + if hasattr(self, 'logger'): + self.logger.trace( # type: ignore + "Adding %s msg from %s to queue; queue_size=%d", cmd, peer, self.queue_size) self.msg_queue.put_nowait(msg) + return True except asyncio.queues.QueueFull: - self.logger.warn( # type: ignore - "%s msg queue is full; discarding %s msg from %s", - self.__class__.__name__, cmd, peer) + if hasattr(self, 'logger'): + self.logger.warn( # type: ignore + "%s msg queue is full; discarding %s msg from %s", + self.__class__.__name__, cmd, peer) + return False @contextlib.contextmanager def subscribe(self, peer_pool: 'PeerPool') -> Iterator[None]: diff --git a/tests/p2p/test_peer_subscriber.py b/tests/p2p/test_peer_subscriber.py new file mode 100644 index 0000000000..49d2d62da9 --- /dev/null +++ b/tests/p2p/test_peer_subscriber.py @@ -0,0 +1,51 @@ +import asyncio + +import pytest + +from p2p.peer import PeerSubscriber +from p2p.protocol import Command + +from trinity.protocol.eth.peer import ETHPeer +from trinity.protocol.eth.commands import GetBlockHeaders + +from tests.trinity.core.peer_helpers import ( + get_directly_linked_peers, +) + + +class HeadersSubscriber(PeerSubscriber): + msg_queue_maxsize = 10 + subscription_msg_types = {GetBlockHeaders} + + +class AllSubscriber(PeerSubscriber): + msg_queue_maxsize = 10 + subscription_msg_types = {Command} + + +@pytest.mark.asyncio +async def test_peer_subscriber_filters_messages(request, event_loop): + peer, remote = await get_directly_linked_peers( + request, + event_loop, + peer1_class=ETHPeer, + peer2_class=ETHPeer, + ) + + header_subscriber = HeadersSubscriber() + all_subscriber = AllSubscriber() + + peer.add_subscriber(header_subscriber) + peer.add_subscriber(all_subscriber) + + remote.sub_proto.send_get_node_data([b'\x00' * 32]) + remote.sub_proto.send_get_block_headers(0, 1, 0, False) + remote.sub_proto.send_get_node_data([b'\x00' * 32]) + remote.sub_proto.send_get_block_headers(1, 1, 0, False) + remote.sub_proto.send_get_node_data([b'\x00' * 32]) + + # yeild to let remote and peer transmit. + await asyncio.sleep(0.01) + + assert header_subscriber.queue_size == 2 + assert all_subscriber.queue_size == 5 diff --git a/tests/trinity/core/peer_helpers.py b/tests/trinity/core/peer_helpers.py index d5d07b07a6..008b461ca2 100644 --- a/tests/trinity/core/peer_helpers.py +++ b/tests/trinity/core/peer_helpers.py @@ -1,6 +1,8 @@ import asyncio import os -from typing import List +from typing import ( + List, +) from eth_hash.auto import keccak @@ -16,6 +18,7 @@ from p2p import kademlia from p2p.auth import decode_authentication from p2p.peer import BasePeer, PeerPool, PeerSubscriber +from p2p.protocol import Command from trinity.protocol.les.peer import LESPeer @@ -174,6 +177,8 @@ async def _run(self) -> None: class SamplePeerSubscriber(PeerSubscriber): logger = TraceLogger("") + subscription_msg_types = {Command} + @property def msg_queue_maxsize(self) -> int: return 100 diff --git a/trinity/plugins/builtin/tx_pool/pool.py b/trinity/plugins/builtin/tx_pool/pool.py index 9ceb2347c7..759368d973 100644 --- a/trinity/plugins/builtin/tx_pool/pool.py +++ b/trinity/plugins/builtin/tx_pool/pool.py @@ -2,7 +2,9 @@ cast, Callable, Iterable, - List + List, + Set, + Type, ) import uuid @@ -21,6 +23,7 @@ PeerPool, PeerSubscriber, ) +from p2p.protocol import Command from p2p.service import ( BaseService ) @@ -59,12 +62,12 @@ def __init__(self, self._bloom = BloomFilter(max_elements=1000000) self._bloom_salt = str(uuid.uuid4()) - @property - def msg_queue_maxsize(self) -> int: - # This is a rather arbitrary value, but when the sync is operating normally we never see - # the msg queue grow past a few hundred items, so this should be a reasonable limit for - # now. - return 2000 + subscription_msg_types: Set[Type[Command]] = {Transactions} + + # This is a rather arbitrary value, but when the sync is operating normally we never see + # the msg queue grow past a few hundred items, so this should be a reasonable limit for + # now. + msg_queue_maxsize: int = 2000 async def _run(self) -> None: self.logger.info("Running Tx Pool") @@ -74,8 +77,8 @@ async def _run(self) -> None: peer, cmd, msg = await self.wait( self.msg_queue.get(), token=self.cancel_token) peer = cast(ETHPeer, peer) - msg = cast(List[BaseTransactionFields], msg) if isinstance(cmd, Transactions): + msg = cast(List[BaseTransactionFields], msg) await self._handle_tx(peer, msg) async def _handle_tx(self, peer: ETHPeer, txs: List[BaseTransactionFields]) -> None: diff --git a/trinity/sync/full/chain.py b/trinity/sync/full/chain.py index 7a9549f24e..8beeca70dc 100644 --- a/trinity/sync/full/chain.py +++ b/trinity/sync/full/chain.py @@ -6,7 +6,9 @@ Dict, List, NamedTuple, + Set, Tuple, + Type, Union, cast, ) @@ -32,9 +34,15 @@ from p2p.exceptions import NoEligiblePeers from p2p.p2p_proto import DisconnectReason from p2p.peer import PeerPool +from p2p.protocol import Command from trinity.db.chain import AsyncChainDB +from trinity.protocol.eth import commands +from trinity.protocol.eth import ( + constants as eth_constants, +) from trinity.protocol.eth.peer import ETHPeer +from trinity.protocol.eth.requests import HeaderRequest from trinity.protocol.les.peer import LESPeer from trinity.rlp.block_body import BlockBody from trinity.sync.base_chain_syncer import BaseHeaderChainSyncer @@ -66,6 +74,19 @@ def __init__(self, self._downloaded_receipts: asyncio.Queue[Tuple[ETHPeer, List[DownloadedBlockPart]]] = asyncio.Queue() # noqa: E501 self._downloaded_bodies: asyncio.Queue[Tuple[ETHPeer, List[DownloadedBlockPart]]] = asyncio.Queue() # noqa: E501 + subscription_msg_types: Set[Type[Command]] = { + commands.BlockBodies, + commands.Receipts, + commands.NewBlock, + commands.GetBlockHeaders, + commands.BlockHeaders, + commands.GetBlockBodies, + commands.GetReceipts, + commands.GetNodeData, + commands.Transactions, + commands.NodeData, + } + async def _calculate_td(self, headers: Tuple[BlockHeader, ...]) -> int: """Return the score (total difficulty) of the last header in the given list. @@ -191,7 +212,6 @@ def _request_block_parts( target_td: int, headers: List[BlockHeader], request_func: Callable[[ETHPeer, List[BlockHeader]], None]) -> int: - from trinity.protocol.eth.peer import ETHPeer # noqa: F811 peers = self.peer_pool.get_peers(target_td) if not peers: raise NoEligiblePeers() @@ -235,12 +255,6 @@ def request_receipts(self, target_td: int, headers: List[BlockHeader]) -> int: async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None: - from trinity.protocol.eth.peer import ETHPeer # noqa: F811 - from trinity.protocol.eth import commands - from trinity.protocol.eth import ( - constants as eth_constants, - ) - peer = cast(ETHPeer, peer) if isinstance(cmd, commands.BlockBodies): @@ -318,8 +332,6 @@ async def _handle_get_block_headers( self, peer: ETHPeer, query: Dict[str, Any]) -> None: - from trinity.protocol.eth.requests import HeaderRequest # noqa: F811 - self.logger.debug("Peer %s made header request: %s", peer, query) request = HeaderRequest( query['block_number_or_hash'], diff --git a/trinity/sync/full/state.py b/trinity/sync/full/state.py index 5d8aa16f4a..284cf7c280 100644 --- a/trinity/sync/full/state.py +++ b/trinity/sync/full/state.py @@ -11,6 +11,7 @@ List, Set, Tuple, + Type, Union, ) @@ -88,12 +89,21 @@ def __init__(self, self._peer_missing_nodes: Dict[ETHPeer, Set[Hash32]] = collections.defaultdict(set) self._executor = get_asyncio_executor() - @property - def msg_queue_maxsize(self) -> int: - # This is a rather arbitrary value, but when the sync is operating normally we never see - # the msg queue grow past a few hundred items, so this should be a reasonable limit for - # now. - return 2000 + # Throughout the whole state sync our chain head is fixed, so it makes sense to ignore + # messages related to new blocks/transactions, but we must handle requests for data from + # other peers or else they will disconnect from us. + subscription_msg_types: Set[Type[Command]] = { + commands.NodeData, + commands.GetBlockHeaders, + commands.GetBlockBodies, + commands.GetReceipts, + commands.GetNodeData, + } + + # This is a rather arbitrary value, but when the sync is operating normally we never see + # the msg queue grow past a few hundred items, so this should be a reasonable limit for + # now. + msg_queue_maxsize: int = 2000 def deregister_peer(self, peer: BasePeer) -> None: # Use .pop() with a default value as it's possible we never requested anything to this @@ -154,13 +164,7 @@ async def _process_nodes(self, nodes: Iterable[Tuple[Hash32, bytes]]) -> None: async def _handle_msg( self, peer: ETHPeer, cmd: Command, msg: _DecodedMsgType) -> None: - # Throughout the whole state sync our chain head is fixed, so it makes sense to ignore - # messages related to new blocks/transactions, but we must handle requests for data from - # other peers or else they will disconnect from us. - ignored_commands = (commands.Transactions, commands.NewBlock, commands.NewBlockHashes) - if isinstance(cmd, ignored_commands): - pass - elif isinstance(cmd, commands.NodeData): + if isinstance(cmd, commands.NodeData): msg = cast(List[bytes], msg) if peer not in self.request_tracker.active_requests: # This is probably a batch that we retried after a timeout and ended up receiving diff --git a/trinity/sync/light/chain.py b/trinity/sync/light/chain.py index e66db77e73..eaf3812e95 100644 --- a/trinity/sync/light/chain.py +++ b/trinity/sync/light/chain.py @@ -2,7 +2,9 @@ Any, cast, Dict, + Set, Tuple, + Type, Union, ) @@ -14,7 +16,9 @@ ) from trinity.protocol.eth.peer import ETHPeer +from trinity.protocol.les import commands from trinity.protocol.les.peer import LESPeer +from trinity.protocol.les.requests import HeaderRequest from trinity.sync.base_chain_syncer import BaseHeaderChainSyncer from trinity.utils.timer import Timer @@ -25,10 +29,14 @@ class LightChainSyncer(BaseHeaderChainSyncer): _exit_on_sync_complete = False + subscription_msg_types: Set[Type[Command]] = { + commands.Announce, + commands.GetBlockHeaders, + commands.BlockHeaders, + } + async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: Command, msg: _DecodedMsgType) -> None: - from trinity.protocol.les import commands - from trinity.protocol.les.peer import LESPeer # noqa: F811 if isinstance(cmd, commands.Announce): self._sync_requests.put_nowait(peer) elif isinstance(cmd, commands.GetBlockHeaders): @@ -41,7 +49,6 @@ async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: Command, self.logger.debug("Ignoring %s message from %s", cmd, peer) async def _handle_get_block_headers(self, peer: LESPeer, msg: Dict[str, Any]) -> None: - from trinity.protocol.les.requests import HeaderRequest self.logger.debug("Peer %s made header request: %s", peer, msg) request = HeaderRequest( msg['query'].block_number_or_hash, diff --git a/trinity/sync/light/service.py b/trinity/sync/light/service.py index 58cc7bca6f..1c362ee62d 100644 --- a/trinity/sync/light/service.py +++ b/trinity/sync/light/service.py @@ -8,6 +8,8 @@ cast, Dict, List, + Set, + Type, ) from async_lru import alru_cache @@ -57,6 +59,7 @@ PeerPool, PeerSubscriber, ) +from p2p.protocol import Command from p2p.service import ( BaseService, service_timeout, @@ -83,11 +86,12 @@ def __init__( self.peer_pool = peer_pool self._pending_replies: Dict[int, Callable[[protocol._DecodedMsgType], None]] = {} - @property - def msg_queue_maxsize(self) -> int: - # Here we only care about replies to our requests, ignoring most msgs (which are supposed - # to be handled by the chain syncer), so our queue should never grow too much. - return 500 + # TODO: be more specific about what messages we want. + subscription_msg_types: Set[Type[Command]] = {Command} + + # Here we only care about replies to our requests, ignoring most msgs (which are supposed + # to be handled by the chain syncer), so our queue should never grow too much. + msg_queue_maxsize = 500 async def _run(self) -> None: with self.subscribe(self.peer_pool): diff --git a/trinity/sync/sharding/service.py b/trinity/sync/sharding/service.py index 1e3902a6f8..71c63742aa 100644 --- a/trinity/sync/sharding/service.py +++ b/trinity/sync/sharding/service.py @@ -7,6 +7,7 @@ cast, Dict, Set, + Type, ) from cytoolz import ( @@ -42,6 +43,7 @@ CollationBodyNotFound, ) +from p2p.protocol import Command from p2p.service import BaseService from p2p.peer import ( PeerPool, @@ -77,12 +79,12 @@ def __init__(self, shard: Shard, peer_pool: PeerPool, token: CancelToken=None) - self.start_time = time.time() - @property - def msg_queue_maxsize(self) -> int: - # This is a rather arbitrary value, but when the sync is operating normally we never see - # the msg queue grow past a few hundred items, so this should be a reasonable limit for - # now. - return 2000 + subscription_msg_types: Set[Type[Command]] = {Collations, GetCollations, NewCollationHashes} + + # This is a rather arbitrary value, but when the sync is operating normally we never see + # the msg queue grow past a few hundred items, so this should be a reasonable limit for + # now. + msg_queue_maxsize = 2000 async def _cleanup(self) -> None: pass