Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add service checking direct reachability from peers #195

Merged
merged 23 commits into from
Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions src/petals/cli/run_dht.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
A copy of run_dht.py from hivemind with the ReachabilityProtocol added:
https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_dht.py

This script may be used for launching lightweight CPU machines serving as bootstrap nodes to a Petals swarm.

This may be eventually merged to the hivemind upstream.
"""

import time
from argparse import ArgumentParser
from secrets import token_hex

from hivemind.dht import DHT, DHTNode
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.networking import log_visible_maddrs

from petals.server.reachability import ReachabilityProtocol

# Select hivemind's "in_root_logger" handler mode before any logger is created
# (presumably this installs hivemind's log handler on the root logger — confirm in hivemind.utils.logging)
use_hivemind_log_handler("in_root_logger")
# Module-level logger used by the status reports below
logger = get_logger(__name__)


async def report_status(dht: DHT, node: DHTNode):
    """Log the size and contents of the local routing table and storage, then ping the swarm.

    :param dht: the DHT instance that runs this coroutine (not used by the body itself)
    :param node: the underlying DHT node whose protocol state is reported
    """
    routing_table = node.protocol.routing_table
    storage = node.protocol.storage
    num_known_nodes = len(routing_table.uid_to_peer_id) + 1  # +1 accounts for this node itself

    logger.info(f"{num_known_nodes} DHT nodes (including this one) are in the local routing table ")
    logger.debug(f"Routing table contents: {routing_table}")
    logger.info(f"Local storage contains {len(storage)} keys")
    logger.debug(f"Local storage contents: {storage}")

    # Looking up a random key contacts peers and keeps the routing table healthy (removes stale PeerIDs)
    heartbeat_key = f"heartbeat_{token_hex(16)}"
    await node.get(heartbeat_key, latest=True)


def main():
    """Launch a DHT node with the reachability service attached and report its status periodically.

    Parses command-line options, starts the DHT, logs its visible multiaddrs, attaches
    a ReachabilityProtocol to it, and then runs report_status every ``--refresh_period``
    seconds until the process is interrupted. On the way out (Ctrl+C or any error),
    the reachability service and the DHT are shut down; the exception itself still propagates.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--initial_peers",
        nargs="*",
        help="Multiaddrs of the peers that will welcome you into the existing DHT. "
        "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY",
    )
    parser.add_argument(
        "--host_maddrs",
        nargs="*",
        default=["/ip4/0.0.0.0/tcp/0", "/ip6/::/tcp/0"],
        help="Multiaddrs to listen for external connections from other DHT instances. "
        "Defaults to all IPv4 and IPv6 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0 /ip6/::/tcp/0",
    )
    parser.add_argument(
        "--announce_maddrs",
        nargs="*",
        help="Visible multiaddrs the host announces for external connections from other DHT instances",
    )
    parser.add_argument(
        "--use_ipfs",
        action="store_true",
        help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" '
        "part of the multiaddrs for the initial_peers "
        "(no need to specify a particular IPv4/IPv6 host and port)",
    )
    parser.add_argument(
        "--identity_path",
        help="Path to a private key file. If defined, makes the peer ID deterministic. "
        "If the file does not exist, writes a new private key to this file.",
    )
    parser.add_argument(
        "--no_relay",
        action="store_false",
        dest="use_relay",
        help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
    )
    parser.add_argument(
        "--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls"
    )
    parser.add_argument(
        "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"
    )

    args = parser.parse_args()

    dht = DHT(
        start=True,
        initial_peers=args.initial_peers,
        host_maddrs=args.host_maddrs,
        announce_maddrs=args.announce_maddrs,
        use_ipfs=args.use_ipfs,
        identity_path=args.identity_path,
        use_relay=args.use_relay,
        use_auto_relay=args.use_auto_relay,
    )
    try:
        log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)

        # await_ready=True makes startup failures of the reachability service surface here
        reachability_protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True)
        try:
            while True:
                dht.run_coroutine(report_status, return_future=False)
                time.sleep(args.refresh_period)
        finally:
            # Stop the reachability service before the DHT it is attached to
            reachability_protocol.shutdown()
    finally:
        dht.shutdown()


# Start the bootstrap node only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
3 changes: 3 additions & 0 deletions src/petals/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
"/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
"/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
]

# Base URL of the reachability API, written without a trailing slash because endpoint paths are appended to it.
# The reachability API is currently used only when connecting to the public swarm
REACHABILITY_API_URL = "http://health.petals.ml"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed the env variable here since this check currently works only for the public swarm, and the env var was giving the impression that you could enable it for a private swarm too.

We can update the logic for this in a future PR: e.g., make the server run the check if the swarm is public or a custom API URL is provided.

135 changes: 131 additions & 4 deletions src/petals/server/reachability.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
import asyncio
import math
import threading
import time
from concurrent.futures import Future
from contextlib import asynccontextmanager
from functools import partial
from secrets import token_hex
from typing import Optional

import requests
from hivemind.utils.logging import get_logger
from hivemind.dht import DHT, DHTNode
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
from hivemind.proto import dht_pb2
from hivemind.utils import get_logger

logger = get_logger(__file__)
from petals.constants import REACHABILITY_API_URL

logger = get_logger(__name__)

def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:

def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
"""verify that your peer is reachable from a (centralized) validator, whether directly or through a relay"""
for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
try:
r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10)
r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10)
r.raise_for_status()
response = r.json()

Expand All @@ -37,3 +51,116 @@ def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float =
f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
)


def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwargs) -> Optional[bool]:
    """Ask peers from the swarm whether they can reach a temporary local node started with **kwargs.

    A client-mode DHT node is created with the given network options, up to ``max_peers`` peers
    from its routing table are asked to check it, and the node is shut down afterwards.

    :param max_peers: stop after this many peers have given a definite answer
    :param threshold: minimal fraction of positive answers required to count as reachable
    :param kwargs: network options forwarded to DHTNode.create
    :returns: True if the fraction of positive answers is at least ``threshold``, False if it is
      lower, None if no peer gave a definite answer
    """

    async def _check_direct_reachability():
        probe_node = await DHTNode.create(client_mode=True, **kwargs)
        try:
            protocol = ReachabilityProtocol(probe=probe_node.protocol.p2p)
            num_reachable = num_answers = 0
            async with protocol.serve(probe_node.protocol.p2p):
                # Iterate over a snapshot: the routing table may change while we await responses
                candidates = list(probe_node.protocol.routing_table.peer_id_to_uid.keys())
                for remote_peer in candidates:
                    verdict = await protocol.call_check(remote_peer=remote_peer, check_peer=probe_node.peer_id)
                    if verdict is not None:  # None means the remote peer failed to check the probe
                        num_reachable += verdict
                        num_answers += 1
                        if num_answers >= max_peers:
                            break

            logger.info(f"Direct reachability: {num_reachable}/{num_answers}")
            if num_answers <= 0:
                return None
            return (num_reachable / num_answers) >= threshold
        finally:
            await probe_node.shutdown()

    return RemoteExpertWorker.run_coroutine(_check_direct_reachability())


# Keyword arguments for creating a minimal "probe" P2P instance: client DHT mode with relaying,
# AutoNAT, NAT port mapping, and listening all switched off, and a 60-unit startup timeout
# (presumably seconds — confirm against hivemind's P2P.create)
STRIPPED_PROBE_ARGS = {
    "dht_mode": "client",
    "use_relay": False,
    "auto_nat": False,
    "nat_port_map": False,
    "no_listen": True,
    "startup_timeout": 60,
}


class ReachabilityProtocol(ServicerBase):
    """Mini protocol to test if a locally running peer is accessible by other devices in the swarm

    A peer can ask a remote peer (via ``call_check``) whether some ``check_peer`` is reachable;
    the remote side answers in ``rpc_check``, contacting ``check_peer`` itself if needed.
    All outgoing calls are made through ``self.probe``.
    """

    def __init__(self, *, probe: Optional[P2P] = None, wait_timeout: float = 5.0):
        """
        :param probe: P2P instance used for outgoing rpc_check calls; attach_to_dht() creates one itself
        :param wait_timeout: base timeout for a single rpc_check call (doubled for indirect checks)
        """
        self.probe = probe
        self.wait_timeout = wait_timeout
        # Both are filled in by attach_to_dht() from inside the service thread and read by shutdown()
        self._event_loop = self._stop = None

    async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID) -> Optional[bool]:
        """Returns True if remote_peer can reach check_peer, False if it cannot, None if it did not respond"""
        try:
            request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=check_peer.to_bytes()))
            # Asking a peer about a third party gets twice the timeout of asking it about itself
            # (presumably because the remote peer must make its own call to check_peer first)
            timeout = self.wait_timeout if check_peer == remote_peer else self.wait_timeout * 2
            response = await self.get_stub(self.probe, remote_peer).rpc_check(request, timeout=timeout)
            logger.debug(f"call_check(remote_peer={remote_peer}, check_peer={check_peer}) -> {response.available}")
            return response.available
        except Exception:
            # Best effort: any failure (timeout, connection error, ...) counts as "no response"
            logger.debug(f"Requested {remote_peer} to check {check_peer}, but got:", exc_info=True)
            return None

    async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
        """Help another peer to check its reachability

        If the request names this peer itself, respond ``available=True`` right away.
        Otherwise, try to contact the named peer and report whether it answered positively.
        """
        response = dht_pb2.PingResponse(available=True)
        check_peer = PeerID(request.peer.node_id)
        if check_peer != context.local_id:  # remote peer wants us to check someone other than ourselves
            # `is True` collapses both False and None (no response) into "not available"
            response.available = await self.call_check(check_peer, check_peer=check_peer) is True
            logger.info(
                f"reachability.rpc_check(remote_peer=...{str(context.remote_id)[-6:]}, "
                f"check_peer=...{str(check_peer)[-6:]}) -> {response.available}"
            )
        return response

    @asynccontextmanager
    async def serve(self, p2p: P2P):
        """Register this protocol's handlers on p2p for the duration of the context, then remove them"""
        try:
            await self.add_p2p_handlers(p2p)
            yield self
        finally:
            await self.remove_p2p_handlers(p2p)

    @classmethod
    def attach_to_dht(cls, dht: DHT, await_ready: bool = False, **kwargs) -> Optional["ReachabilityProtocol"]:
        """Create a protocol instance and serve it on the DHT's P2P from a background daemon thread.

        The thread runs its own event loop, creates a stripped-down probe P2P instance for
        outgoing checks, and serves until shutdown() is called.

        :param dht: a running DHT whose P2P instance will handle incoming rpc_check requests
        :param await_ready: if True, block until the probe is created and re-raise startup errors
        :param kwargs: forwarded to the class constructor
        :returns: the protocol instance (possibly still starting up if await_ready is False)
        """
        protocol = cls(**kwargs)
        ready = Future()

        async def _serve_with_probe():
            try:
                common_p2p = await dht.replicate_p2p()
                # Remember the loop of this thread so that shutdown() can signal it from another thread
                protocol._event_loop = asyncio.get_running_loop()
                protocol._stop = asyncio.Event()

                # Bootstrap the probe from our own visible addresses and all currently known peers
                initial_peers = [str(addr) for addr in await common_p2p.get_visible_maddrs(latest=True)]
                for info in await common_p2p.list_peers():
                    initial_peers.extend(f"{addr}/p2p/{info.peer_id}" for addr in info.addrs)
                protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)

                ready.set_result(True)
                logger.info("Reachability service started")

                async with protocol.serve(common_p2p):
                    await protocol._stop.wait()
            except Exception as e:
                logger.warning(f"Reachability service failed: {repr(e)}")
                logger.debug("See detailed traceback below:", exc_info=True)

                # Only startup failures can still be reported to a caller waiting on `ready`
                if not ready.done():
                    ready.set_exception(e)
            finally:
                if protocol.probe is not None:
                    await protocol.probe.shutdown()
                logger.debug("Reachability service shut down")

        threading.Thread(target=partial(asyncio.run, _serve_with_probe()), daemon=True).start()
        if await_ready:
            ready.result()  # Propagates startup exceptions, if any
        return protocol

    def shutdown(self):
        """Ask the background service to stop; does nothing if the service has not started its loop"""
        if self._event_loop is not None and self._stop is not None:
            # asyncio.Event is not thread-safe, so the flag is set from within the service's own loop
            self._event_loop.call_soon_threadsafe(self._stop.set)
22 changes: 16 additions & 6 deletions src/petals/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from petals.server.block_utils import get_block_size
from petals.server.handler import TransformerConnectionHandler
from petals.server.memory_cache import MemoryCache
from petals.server.reachability import check_reachability
from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
from petals.server.throughput import get_dtype_name, get_host_throughput
from petals.utils.convert_block import check_device_balance, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
Expand Down Expand Up @@ -77,6 +77,7 @@ def __init__(
load_in_8bit: Optional[bool] = None,
tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
skip_reachability_check: bool = False,
dht_client_mode: Optional[bool] = None,
use_relay: bool = True,
use_auto_relay: bool = True,
**kwargs,
Expand Down Expand Up @@ -118,20 +119,27 @@ def __init__(
)
self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]

if dht_client_mode is None:
is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
dht_client_mode = is_reachable is False # if could not check reachability (returns None), run a full peer
logger.info(f"This server will run DHT in {'client' if dht_client_mode else 'full peer'} mode")
self.dht = DHT(
initial_peers=initial_peers,
start=True,
num_workers=self.block_config.n_layer,
use_relay=use_relay,
use_auto_relay=use_auto_relay,
client_mode=dht_client_mode,
**kwargs,
)
self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None

visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
if initial_peers == PUBLIC_INITIAL_PEERS:
logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
else:
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -277,7 +285,7 @@ def run(self):
use_auth_token=self.use_auth_token,
load_in_8bit=self.load_in_8bit,
tensor_parallel_devices=self.tensor_parallel_devices,
need_reachability_check=self.need_reachability_check,
should_validate_reachability=self.should_validate_reachability,
start=True,
)
try:
Expand Down Expand Up @@ -335,6 +343,8 @@ def _should_choose_other_blocks(self) -> bool:
def shutdown(self):
self.stop.set()

if self.reachability_protocol is not None:
self.reachability_protocol.shutdown()
self.dht.shutdown()
self.dht.join()

Expand Down Expand Up @@ -367,7 +377,7 @@ def create(
use_auth_token: Optional[str],
load_in_8bit: bool,
tensor_parallel_devices: Sequence[torch.device],
need_reachability_check: bool,
should_validate_reachability: bool,
**kwargs,
) -> ModuleContainer:
module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
Expand Down Expand Up @@ -422,8 +432,8 @@ def create(
max_batch_size=max_batch_size,
)

if need_reachability_check:
check_reachability(dht.peer_id)
if should_validate_reachability:
validate_reachability(dht.peer_id)
except:
logger.debug("Shutting down backends")
for backend in blocks.values():
Expand Down