diff --git a/hivemind/p2p/p2p_daemon.py b/hivemind/p2p/p2p_daemon.py index 3083c70e5..1f441c5d1 100644 --- a/hivemind/p2p/p2p_daemon.py +++ b/hivemind/p2p/p2p_daemon.py @@ -1,45 +1,197 @@ +import asyncio +import contextlib +import copy +from pathlib import Path +import pickle +import socket import subprocess import typing as tp +import warnings + +from multiaddr import Multiaddr +import p2pclient +from libp2p.peer.id import ID class P2P(object): """ Forks a child process and executes p2pd command with given arguments. - Sends SIGKILL to the child in destructor and on exit from contextmanager. + Can be used for peer to peer communication and procedure calls. + Sends SIGKILL to the child in destructor. """ - LIBP2P_CMD = 'p2pd' + P2PD_RELATIVE_PATH = 'hivemind_cli/p2pd' + NUM_RETRIES = 3 + RETRY_DELAY = 0.4 + HEADER_LEN = 8 + BYTEORDER = 'big' - def __init__(self, *args, **kwargs): - self._child = subprocess.Popen(args=self._make_process_args(args, kwargs)) - try: - stdout, stderr = self._child.communicate(timeout=0.2) - except subprocess.TimeoutExpired: - pass - else: - raise RuntimeError(f'p2p daemon exited with stderr: {stderr}') + def __init__(self): + self._child = None + self._listen_task = None + self._server_stopped = asyncio.Event() + self._buffer = bytearray() - def __enter__(self): - return self._child + @classmethod + async def create(cls, *args, quic=1, tls=1, conn_manager=1, dht_client=1, + nat_port_map=True, auto_nat=True, bootstrap=True, + host_port: int = None, daemon_listen_port: int = None, **kwargs): + self = cls() + p2pd_path = Path(__file__).resolve().parents[1] / P2P.P2PD_RELATIVE_PATH + proc_args = self._make_process_args( + str(p2pd_path), *args, + quic=quic, tls=tls, connManager=conn_manager, + dhtClient=dht_client, natPortMap=nat_port_map, + autonat=auto_nat, b=bootstrap, **kwargs) + self._assign_daemon_ports(host_port, daemon_listen_port) + for try_count in range(self.NUM_RETRIES): + try: + self._initialize(proc_args) + await self._identify_client(P2P.RETRY_DELAY * (2 ** try_count)) + except Exception as exc: + warnings.warn("Failed to initialize p2p daemon: " + str(exc), RuntimeWarning) + self._kill_child() + if try_count == P2P.NUM_RETRIES - 1: + raise + self._assign_daemon_ports() + continue + break + return self - def __exit__(self, exc_type, exc_val, exc_tb): - self._kill_child() + def _initialize(self, proc_args: tp.List[str]) -> None: + proc_args = copy.deepcopy(proc_args) + proc_args.extend(self._make_process_args( + hostAddrs=f'/ip4/0.0.0.0/tcp/{self._host_port},/ip4/0.0.0.0/udp/{self._host_port}/quic', + listen=f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}' + )) + self._child = subprocess.Popen( + args=proc_args, + stdin=subprocess.PIPE, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, encoding="utf8" + ) + self._client_listen_port = find_open_port() + self._client = p2pclient.Client( + Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'), + Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}')) + + async def _identify_client(self, delay): + await asyncio.sleep(delay) + encoded = await self._client.identify() + self.id = encoded[0].to_base58() + + def _assign_daemon_ports(self, host_port=None, daemon_listen_port=None): + self._host_port, self._daemon_listen_port = host_port, daemon_listen_port + if host_port is None: + self._host_port = find_open_port() + if daemon_listen_port is None: + self._daemon_listen_port = find_open_port() + while self._daemon_listen_port == self._host_port: + self._daemon_listen_port = find_open_port() + + @staticmethod + async def send_data(data, stream): + byte_str = pickle.dumps(data) + request = len(byte_str).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + byte_str + await stream.send_all(request) + + class IncompleteRead(Exception): + pass + + async def _receive_exactly(self, stream, n_bytes, max_bytes=1 << 16): + while len(self._buffer) < n_bytes: + data = await stream.receive_some(max_bytes) + if len(data) == 0: + raise P2P.IncompleteRead() + self._buffer.extend(data) + + result = self._buffer[:n_bytes] + self._buffer = self._buffer[n_bytes:] + return bytes(result) + + async def receive_data(self, stream, max_bytes=(1 < 16)): + header = await self._receive_exactly(stream, P2P.HEADER_LEN) + content_length = int.from_bytes(header, P2P.BYTEORDER) + data = await self._receive_exactly(stream, content_length) + return pickle.loads(data) + + def _handle_stream(self, handle): + async def do_handle_stream(stream_info, stream): + try: + request = await self.receive_data(stream) + except P2P.IncompleteRead: + warnings.warn("Incomplete read while receiving request from peer", RuntimeWarning) + return + finally: + stream.close() + try: + result = handle(request) + await self.send_data(result, stream) + except Exception as exc: + await self.send_data(exc, stream) + finally: + await stream.close() + + return do_handle_stream + + def start_listening(self): + async def listen(): + async with self._client.listen(): + await self._server_stopped.wait() + + self._listen_task = asyncio.create_task(listen()) + + async def stop_listening(self): + if self._listen_task is not None: + self._server_stopped.set() + self._listen_task.cancel() + try: + await self._listen_task + except asyncio.CancelledError: + self._listen_task = None + self._server_stopped.clear() + + async def add_stream_handler(self, name, handle): + if self._listen_task is None: + self.start_listening() + + await self._client.stream_handler(name, self._handle_stream(handle)) + + async def call_peer_handler(self, peer_id, handler_name, input_data): + libp2p_peer_id = ID.from_base58(peer_id) + stream_info, stream = await self._client.stream_open(libp2p_peer_id, (handler_name,)) + try: + await self.send_data(input_data, stream) + return await self.receive_data(stream) + finally: + await stream.close() def __del__(self): self._kill_child() def _kill_child(self): - if self._child.poll() is None: + if self._child is not None and self._child.poll() is None: self._child.kill() self._child.wait() - def _make_process_args(self, args: tp.Tuple[tp.Any], - kwargs: tp.Dict[str, tp.Any]) -> tp.List[str]: - proc_args = [self.LIBP2P_CMD] + def _make_process_args(self, *args, **kwargs) -> tp.List[str]: + proc_args = [] proc_args.extend( str(entry) for entry in args ) proc_args.extend( - f'-{key}={str(value)}' for key, value in kwargs.items() + f'-{key}={value}' if value is not None else f'-{key}' + for key, value in kwargs.items() ) return proc_args + + +def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), + opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)): + """ Finds a tcp port that can be occupied with a socket with *params and use *opt options """ + try: + with contextlib.closing(socket.socket(*params)) as sock: + sock.bind(('', 0)) + sock.setsockopt(*opt) + return sock.getsockname()[1] + except Exception: + raise diff --git a/tests/test_p2p_daemon.py b/tests/test_p2p_daemon.py index ac57e9e2f..75fd51cdc 100644 --- a/tests/test_p2p_daemon.py +++ b/tests/test_p2p_daemon.py @@ -1,6 +1,8 @@ +import asyncio +import multiprocessing as mp import subprocess -from time import perf_counter +import numpy as np import pytest import hivemind.p2p @@ -23,33 +25,131 @@ def is_process_running(pid: int) -> bool: return subprocess.check_output(cmd, shell=True).decode('utf-8').strip() == RUNNING -@pytest.fixture() -def mock_p2p_class(): - P2P.LIBP2P_CMD = "sleep" - - -def test_daemon_killed_on_del(mock_p2p_class): - start = perf_counter() - p2p_daemon = P2P('10s') +@pytest.mark.asyncio +async def test_daemon_killed_on_del(): + p2p_daemon = await P2P.create() child_pid = p2p_daemon._child.pid assert is_process_running(child_pid) - del p2p_daemon + p2p_daemon.__del__() assert not is_process_running(child_pid) - assert perf_counter() - start < 1 -def test_daemon_killed_on_exit(mock_p2p_class): - start = perf_counter() - with P2P('10s') as daemon: - child_pid = daemon.pid - assert is_process_running(child_pid) +def handle_square(x): + return x ** 2 - assert not is_process_running(child_pid) - assert perf_counter() - start < 1 + +def handle_add(args): + result = args[0] + for i in range(1, len(args)): + result = result + args[i] + return result + + +@pytest.mark.parametrize( + "test_input,handle", + [ + pytest.param(10, handle_square, id="square_integer"), + pytest.param((1, 2), handle_add, id="add_integers"), + pytest.param(([1, 2, 3], [12, 13]), handle_add, id="add_lists"), + pytest.param(2, lambda x: x ** 3, id="lambda") + ] +) +@pytest.mark.asyncio +async def test_call_peer_single_process(test_input, handle, handler_name="handle"): + server = await P2P.create() + server_pid = server._child.pid + await server.add_stream_handler(handler_name, handle) + assert is_process_running(server_pid) + + client = await P2P.create() + client_pid = client._child.pid + assert is_process_running(client_pid) + + await asyncio.sleep(1) + result = await client.call_peer_handler(server.id, handler_name, test_input) + assert result == handle(test_input) + + await server.stop_listening() + server.__del__() + assert not is_process_running(server_pid) + + client.__del__() + assert not is_process_running(client_pid) + + +@pytest.mark.asyncio +async def test_call_peer_different_processes(): + handler_name = "square" + test_input = np.random.randn(2, 3) + + server_side, client_side = mp.Pipe() + response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32)) + response_received.value = 0 + + async def run_server(): + server = await P2P.create() + server_pid = server._child.pid + await server.add_stream_handler(handler_name, handle_square) + assert is_process_running(server_pid) + + server_side.send(server.id) + while response_received.value == 0: + await asyncio.sleep(0.5) + + await server.stop_listening() + server.__del__() + assert not is_process_running(server_pid) + + def server_target(): + asyncio.run(run_server()) + + proc = mp.Process(target=server_target) + proc.start() + + client = await P2P.create() + client_pid = client._child.pid + assert is_process_running(client_pid) + + await asyncio.sleep(1) + peer_id = client_side.recv() + + result = await client.call_peer_handler(peer_id, handler_name, test_input) + assert np.allclose(result, handle_square(test_input)) + response_received.value = 1 + + client.__del__() + assert not is_process_running(client_pid) + + proc.join() -def test_daemon_raises_on_faulty_args(): - with pytest.raises(RuntimeError): - P2P(faulty='argument') +@pytest.mark.parametrize( + "test_input,handle", + [ + pytest.param(np.random.randn(2, 3), handle_square, id="square"), + pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, id="add"), + ] +) +@pytest.mark.asyncio +async def test_call_peer_numpy(test_input, handle, handler_name="handle"): + server = await P2P.create() + await server.add_stream_handler(handler_name, handle) + client = await P2P.create() + + await asyncio.sleep(1) + result = await client.call_peer_handler(server.id, handler_name, test_input) + assert np.allclose(result, handle(test_input)) + + +@pytest.mark.asyncio +async def test_call_peer_error(handler_name="handle"): + server = await P2P.create() + await server.add_stream_handler(handler_name, handle_add) + client = await P2P.create() + + await asyncio.sleep(1) + result = await client.call_peer_handler(server.id, handler_name, + [np.zeros((2, 3)), np.zeros((3, 2))]) + assert type(result) == ValueError