Skip to content

Commit

Permalink
feat p2p_daemon: add API to call peer handle (#181)
Browse files Browse the repository at this point in the history
* Extend P2P api

* Add tests for new api

* Add p2pclient dependencies

* Test P2P from different processes

* Fix typo in tests

* Add default initialization

* Fix daemon ports assignment

* Replace del with __del__ in tests

* Read from input stream with receive_exactly

Co-authored-by: Ilya Kobelev <ilya.kobellev@gmail.com>
  • Loading branch information
2 people authored and justheuristic committed Apr 13, 2021
1 parent 0535efe commit 3595c94
Show file tree
Hide file tree
Showing 2 changed files with 292 additions and 40 deletions.
190 changes: 171 additions & 19 deletions hivemind/p2p/p2p_daemon.py
@@ -1,45 +1,197 @@
import asyncio
import contextlib
import copy
from pathlib import Path
import pickle
import socket
import subprocess
import typing as tp
import warnings

from multiaddr import Multiaddr
import p2pclient
from libp2p.peer.id import ID


class P2P(object):
"""
Forks a child process and executes p2pd command with given arguments.
Sends SIGKILL to the child in destructor and on exit from contextmanager.
Can be used for peer to peer communication and procedure calls.
Sends SIGKILL to the child in destructor.
"""

LIBP2P_CMD = 'p2pd'
P2PD_RELATIVE_PATH = 'hivemind_cli/p2pd'
NUM_RETRIES = 3
RETRY_DELAY = 0.4
HEADER_LEN = 8
BYTEORDER = 'big'

def __init__(self, *args, **kwargs):
self._child = subprocess.Popen(args=self._make_process_args(args, kwargs))
try:
stdout, stderr = self._child.communicate(timeout=0.2)
except subprocess.TimeoutExpired:
pass
else:
raise RuntimeError(f'p2p daemon exited with stderr: {stderr}')
def __init__(self):
self._child = None
self._listen_task = None
self._server_stopped = asyncio.Event()
self._buffer = bytearray()

def __enter__(self):
return self._child
@classmethod
async def create(cls, *args, quic=1, tls=1, conn_manager=1, dht_client=1,
nat_port_map=True, auto_nat=True, bootstrap=True,
host_port: int = None, daemon_listen_port: int = None, **kwargs):
self = cls()
p2pd_path = Path(__file__).resolve().parents[1] / P2P.P2PD_RELATIVE_PATH
proc_args = self._make_process_args(
str(p2pd_path), *args,
quic=quic, tls=tls, connManager=conn_manager,
dhtClient=dht_client, natPortMap=nat_port_map,
autonat=auto_nat, b=bootstrap, **kwargs)
self._assign_daemon_ports(host_port, daemon_listen_port)
for try_count in range(self.NUM_RETRIES):
try:
self._initialize(proc_args)
await self._identify_client(P2P.RETRY_DELAY * (2 ** try_count))
except Exception as exc:
warnings.warn("Failed to initialize p2p daemon: " + str(exc), RuntimeWarning)
self._kill_child()
if try_count == P2P.NUM_RETRIES - 1:
raise
self._assign_daemon_ports()
continue
break
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self._kill_child()
def _initialize(self, proc_args: tp.List[str]) -> None:
proc_args = copy.deepcopy(proc_args)
proc_args.extend(self._make_process_args(
hostAddrs=f'/ip4/0.0.0.0/tcp/{self._host_port},/ip4/0.0.0.0/udp/{self._host_port}/quic',
listen=f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'
))
self._child = subprocess.Popen(
args=proc_args,
stdin=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, encoding="utf8"
)
self._client_listen_port = find_open_port()
self._client = p2pclient.Client(
Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))

async def _identify_client(self, delay):
await asyncio.sleep(delay)
encoded = await self._client.identify()
self.id = encoded[0].to_base58()

def _assign_daemon_ports(self, host_port=None, daemon_listen_port=None):
self._host_port, self._daemon_listen_port = host_port, daemon_listen_port
if host_port is None:
self._host_port = find_open_port()
if daemon_listen_port is None:
self._daemon_listen_port = find_open_port()
while self._daemon_listen_port == self._host_port:
self._daemon_listen_port = find_open_port()

@staticmethod
async def send_data(data, stream):
byte_str = pickle.dumps(data)
request = len(byte_str).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + byte_str
await stream.send_all(request)

class IncompleteRead(Exception):
pass

async def _receive_exactly(self, stream, n_bytes, max_bytes=1 << 16):
while len(self._buffer) < n_bytes:
data = await stream.receive_some(max_bytes)
if len(data) == 0:
raise P2P.IncompleteRead()
self._buffer.extend(data)

result = self._buffer[:n_bytes]
self._buffer = self._buffer[n_bytes:]
return bytes(result)

async def receive_data(self, stream, max_bytes=(1 < 16)):
header = await self._receive_exactly(stream, P2P.HEADER_LEN)
content_length = int.from_bytes(header, P2P.BYTEORDER)
data = await self._receive_exactly(stream, content_length)
return pickle.loads(data)

def _handle_stream(self, handle):
async def do_handle_stream(stream_info, stream):
try:
request = await self.receive_data(stream)
except P2P.IncompleteRead:
warnings.warn("Incomplete read while receiving request from peer", RuntimeWarning)
return
finally:
stream.close()
try:
result = handle(request)
await self.send_data(result, stream)
except Exception as exc:
await self.send_data(exc, stream)
finally:
await stream.close()

return do_handle_stream

def start_listening(self):
async def listen():
async with self._client.listen():
await self._server_stopped.wait()

self._listen_task = asyncio.create_task(listen())

async def stop_listening(self):
if self._listen_task is not None:
self._server_stopped.set()
self._listen_task.cancel()
try:
await self._listen_task
except asyncio.CancelledError:
self._listen_task = None
self._server_stopped.clear()

async def add_stream_handler(self, name, handle):
if self._listen_task is None:
self.start_listening()

await self._client.stream_handler(name, self._handle_stream(handle))

async def call_peer_handler(self, peer_id, handler_name, input_data):
libp2p_peer_id = ID.from_base58(peer_id)
stream_info, stream = await self._client.stream_open(libp2p_peer_id, (handler_name,))
try:
await self.send_data(input_data, stream)
return await self.receive_data(stream)
finally:
await stream.close()

def __del__(self):
self._kill_child()

def _kill_child(self):
if self._child.poll() is None:
if self._child is not None and self._child.poll() is None:
self._child.kill()
self._child.wait()

def _make_process_args(self, args: tp.Tuple[tp.Any],
kwargs: tp.Dict[str, tp.Any]) -> tp.List[str]:
proc_args = [self.LIBP2P_CMD]
def _make_process_args(self, *args, **kwargs) -> tp.List[str]:
proc_args = []
proc_args.extend(
str(entry) for entry in args
)
proc_args.extend(
f'-{key}={str(value)}' for key, value in kwargs.items()
f'-{key}={value}' if value is not None else f'-{key}'
for key, value in kwargs.items()
)
return proc_args


def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM),
opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
""" Finds a tcp port that can be occupied with a socket with *params and use *opt options """
try:
with contextlib.closing(socket.socket(*params)) as sock:
sock.bind(('', 0))
sock.setsockopt(*opt)
return sock.getsockname()[1]
except Exception:
raise
142 changes: 121 additions & 21 deletions tests/test_p2p_daemon.py
@@ -1,6 +1,8 @@
import asyncio
import multiprocessing as mp
import subprocess
from time import perf_counter

import numpy as np
import pytest

import hivemind.p2p
Expand All @@ -23,33 +25,131 @@ def is_process_running(pid: int) -> bool:
return subprocess.check_output(cmd, shell=True).decode('utf-8').strip() == RUNNING


@pytest.fixture()
def mock_p2p_class():
P2P.LIBP2P_CMD = "sleep"


def test_daemon_killed_on_del(mock_p2p_class):
start = perf_counter()
p2p_daemon = P2P('10s')
@pytest.mark.asyncio
async def test_daemon_killed_on_del():
p2p_daemon = await P2P.create()

child_pid = p2p_daemon._child.pid
assert is_process_running(child_pid)

del p2p_daemon
p2p_daemon.__del__()
assert not is_process_running(child_pid)
assert perf_counter() - start < 1


def test_daemon_killed_on_exit(mock_p2p_class):
start = perf_counter()
with P2P('10s') as daemon:
child_pid = daemon.pid
assert is_process_running(child_pid)
def handle_square(x):
return x ** 2

assert not is_process_running(child_pid)
assert perf_counter() - start < 1

def handle_add(args):
result = args[0]
for i in range(1, len(args)):
result = result + args[i]
return result


@pytest.mark.parametrize(
"test_input,handle",
[
pytest.param(10, handle_square, id="square_integer"),
pytest.param((1, 2), handle_add, id="add_integers"),
pytest.param(([1, 2, 3], [12, 13]), handle_add, id="add_lists"),
pytest.param(2, lambda x: x ** 3, id="lambda")
]
)
@pytest.mark.asyncio
async def test_call_peer_single_process(test_input, handle, handler_name="handle"):
server = await P2P.create()
server_pid = server._child.pid
await server.add_stream_handler(handler_name, handle)
assert is_process_running(server_pid)

client = await P2P.create()
client_pid = client._child.pid
assert is_process_running(client_pid)

await asyncio.sleep(1)
result = await client.call_peer_handler(server.id, handler_name, test_input)
assert result == handle(test_input)

await server.stop_listening()
server.__del__()
assert not is_process_running(server_pid)

client.__del__()
assert not is_process_running(client_pid)


@pytest.mark.asyncio
async def test_call_peer_different_processes():
handler_name = "square"
test_input = np.random.randn(2, 3)

server_side, client_side = mp.Pipe()
response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
response_received.value = 0

async def run_server():
server = await P2P.create()
server_pid = server._child.pid
await server.add_stream_handler(handler_name, handle_square)
assert is_process_running(server_pid)

server_side.send(server.id)
while response_received.value == 0:
await asyncio.sleep(0.5)

await server.stop_listening()
server.__del__()
assert not is_process_running(server_pid)

def server_target():
asyncio.run(run_server())

proc = mp.Process(target=server_target)
proc.start()

client = await P2P.create()
client_pid = client._child.pid
assert is_process_running(client_pid)

await asyncio.sleep(1)
peer_id = client_side.recv()

result = await client.call_peer_handler(peer_id, handler_name, test_input)
assert np.allclose(result, handle_square(test_input))
response_received.value = 1

client.__del__()
assert not is_process_running(client_pid)

proc.join()


def test_daemon_raises_on_faulty_args():
with pytest.raises(RuntimeError):
P2P(faulty='argument')
@pytest.mark.parametrize(
"test_input,handle",
[
pytest.param(np.random.randn(2, 3), handle_square, id="square"),
pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, id="add"),
]
)
@pytest.mark.asyncio
async def test_call_peer_numpy(test_input, handle, handler_name="handle"):
server = await P2P.create()
await server.add_stream_handler(handler_name, handle)
client = await P2P.create()

await asyncio.sleep(1)
result = await client.call_peer_handler(server.id, handler_name, test_input)
assert np.allclose(result, handle(test_input))


@pytest.mark.asyncio
async def test_call_peer_error(handler_name="handle"):
server = await P2P.create()
await server.add_stream_handler(handler_name, handle_add)
client = await P2P.create()

await asyncio.sleep(1)
result = await client.call_peer_handler(server.id, handler_name,
[np.zeros((2, 3)), np.zeros((3, 2))])
assert type(result) == ValueError

0 comments on commit 3595c94

Please sign in to comment.