Add service checking direct reachability from peers #195

Merged 23 commits on Jan 12, 2023

Changes from 14 commits
3 changes: 3 additions & 0 deletions src/petals/constants.py
@@ -4,3 +4,6 @@
     "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
     "/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
 ]
+
+# The reachability API is currently used only when connecting to the public swarm
+REACHABILITY_API_URL = "http://health.petals.ml"
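For reference, validate_reachability() in the next file consumes this constant as follows; a condensed sketch of the request it issues, with retries and user-facing diagnostics omitted (query_reachability is an illustrative name, not part of the PR):

import requests

from petals.constants import REACHABILITY_API_URL

def query_reachability(peer_id: str) -> dict:
    # Ask the centralized health service whether `peer_id` is reachable;
    # the JSON response includes fields such as "your_ip" for diagnostics.
    r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10)
    r.raise_for_status()
    return r.json()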
A collaborator commented:

I've removed the env variable here since this check currently works only for the public swarm, and the env var gave the impression that you could enable it for the private swarm too.

We can update this logic in a future PR: e.g., make the server run the check if the swarm is public or a custom API URL is provided.
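A hypothetical sketch of that follow-up (should_run_check and api_url are illustrative names, not part of this PR):

from typing import Optional

from petals.constants import PUBLIC_INITIAL_PEERS

def should_run_check(initial_peers, api_url: Optional[str]) -> bool:
    # Run the check when joining the public swarm, or when the operator
    # explicitly points the server at a custom reachability API.
    return initial_peers == PUBLIC_INITIAL_PEERS or api_url is not None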

119 changes: 115 additions & 4 deletions src/petals/server/reachability.py
@@ -1,16 +1,28 @@
+import asyncio
 import math
+import threading
 import time
+from contextlib import asynccontextmanager
+from functools import partial
+from typing import Optional
 
 import requests
-from hivemind.utils.logging import get_logger
+from hivemind.dht import DHT, DHTNode
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
+from hivemind.proto import dht_pb2
+from hivemind.utils import get_logger
 
-logger = get_logger(__file__)
+from petals.constants import REACHABILITY_API_URL
+
+logger = get_logger(__name__)
 
 
-def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
+def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
     """verify that your peer is reachable from a (centralized) validator, whether directly or through a relay"""
     for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
         try:
-            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10)
+            r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10)
             r.raise_for_status()
             response = r.json()

@@ -37,3 +49,102 @@ def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float =
             f"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
             f"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
         )
+
+
+def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwargs) -> Optional[bool]:
+    """test if your peer is accessible by others in the swarm with the specified network options in **kwargs"""
+
+    async def _check_direct_reachability():
+        target_dht = await DHTNode.create(client_mode=True, **kwargs)
+        try:
+            protocol = ReachabilityProtocol(target_dht.protocol.p2p)
+            async with protocol.serve():
+                successes = requests = 0
+                for remote_peer in list(target_dht.protocol.routing_table.peer_id_to_uid.keys()):
+                    probe_available = await protocol.call_check(remote_peer=remote_peer, check_peer=target_dht.peer_id)
+                    if probe_available is None:
+                        continue  # remote peer failed to check probe
+                    successes += probe_available
+                    requests += 1
+                    if requests >= max_peers:
+                        break
+
+                logger.info(f"Direct reachability: {successes}/{requests}")
+                return (successes / requests) >= threshold if requests > 0 else None
+        finally:
+            await target_dht.shutdown()
+
+    return RemoteExpertWorker.run_coroutine(_check_direct_reachability())
+
+
+PROBE_P2P_ARGS = dict(dht_mode="client", use_relay=False, auto_nat=False, nat_port_map=False, no_listen=True)
+
+
+class ReachabilityProtocol(ServicerBase):
+    """Mini protocol to test if a locally running peer is accessible by other devices in the swarm"""
+
+    def __init__(self, p2p: P2P, *, probe: Optional[P2P] = None, wait_timeout: float = 5.0):
+        probe = probe if probe is not None else p2p
+        self.p2p, self.probe, self.wait_timeout = p2p, probe, wait_timeout
+        self._event_loop = self._stop = None
+
+    async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID) -> Optional[bool]:
+        """Returns True if remote_peer can reach check_peer, False if it cannot, None if it did not respond"""
+        try:
+            request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=check_peer.to_bytes()))
+            timeout = self.wait_timeout if check_peer == remote_peer else self.wait_timeout * 2
+            response = await self.get_stub(self.probe, remote_peer).rpc_check(request, timeout=timeout)
+            logger.debug(f"call_check(remote_peer={remote_peer}, check_peer={check_peer}) -> {response.available}")
+            return response.available
+        except Exception as e:
+            logger.debug(f"Requested {remote_peer} to check {check_peer}, but got:", exc_info=True)
+            return None
+
+    async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
+        """Help another peer to check its reachability"""
+        response = dht_pb2.PingResponse(available=True)
+        check_peer = PeerID(request.peer.node_id)
+        if check_peer != context.local_id:  # remote peer wants us to check someone other than ourselves
+            response.available = await self.call_check(check_peer, check_peer=check_peer) is True
+        logger.debug(f"rpc_check(check_peer={check_peer}) -> {response.available}")
+        return response
+
+    @asynccontextmanager
+    async def serve(self):
+        try:
+            await self.add_p2p_handlers(self.p2p)
+            yield self
+        finally:
+            await self.remove_p2p_handlers(self.p2p)
+
+    @classmethod
+    def attach_to_dht(cls, dht: DHT, **kwargs) -> "ReachabilityProtocol":
+        protocol = None
+        ready = threading.Event()
+
+        async def _serve_with_probe():
+            nonlocal protocol
+            protocol = cls(p2p=await dht.replicate_p2p(), **kwargs)
+            protocol._event_loop = asyncio.get_event_loop()
+            protocol._stop = asyncio.Event()
+            ready.set()
+
+            initial_peers = list(map(str, await protocol.p2p.get_visible_maddrs(latest=True)))
+            for info in await protocol.p2p.list_peers():
+                initial_peers.extend(f"{addr}/p2p/{info.peer_id}" for addr in info.addrs)
+            protocol.probe = await P2P.create(initial_peers, **PROBE_P2P_ARGS)
+
+            try:
+                async with protocol.serve():
+                    await protocol._stop.wait()
+            finally:
+                await protocol.probe.shutdown()
+                logger.debug("ReachabilityProtocol shut down")
+
+        threading.Thread(target=partial(asyncio.run, _serve_with_probe()), daemon=True).start()
+        ready.wait()
+        return protocol
+
+    def shutdown(self):
+        if self._event_loop is not None and self._stop is not None:
+            self._event_loop.call_soon_threadsafe(self._stop.set)
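
Taken together, a server can probe its own direct reachability before deciding on a DHT mode, then serve the protocol for other peers; a minimal sketch that mirrors the server.py changes below (assuming the public swarm's bootstrap peers):

from hivemind import DHT

from petals.constants import PUBLIC_INITIAL_PEERS
from petals.server.reachability import ReachabilityProtocol, check_direct_reachability

# Probe whether other peers can open a direct connection to this node.
# None means the probe was inconclusive (no peer answered), so only an
# explicit False demotes the node to DHT client mode.
is_reachable = check_direct_reachability(initial_peers=PUBLIC_INITIAL_PEERS, use_relay=False)
client_mode = is_reachable is False

dht = DHT(initial_peers=PUBLIC_INITIAL_PEERS, client_mode=client_mode, start=True)

# Full peers also serve rpc_check(), helping other nodes run the same probe.
protocol = ReachabilityProtocol.attach_to_dht(dht) if not client_mode else None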
22 changes: 16 additions & 6 deletions src/petals/server/server.py
@@ -26,7 +26,7 @@
 from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
-from petals.server.reachability import check_reachability
+from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
 from petals.server.throughput import get_dtype_name, get_host_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -77,6 +77,7 @@ def __init__(
         load_in_8bit: Optional[bool] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
+        dht_client_mode: Optional[bool] = None,
         use_relay: bool = True,
         use_auto_relay: bool = True,
         **kwargs,
@@ -118,20 +119,27 @@ def __init__(
         )
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
+        if dht_client_mode is None:
+            is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
+            dht_client_mode = is_reachable is False  # if could not check reachability (returns None), run a full peer
+            logger.info(f"This server will run DHT in {'client' if dht_client_mode else 'full peer'} mode")
         self.dht = DHT(
             initial_peers=initial_peers,
             start=True,
             num_workers=self.block_config.n_layer,
             use_relay=use_relay,
             use_auto_relay=use_auto_relay,
+            client_mode=dht_client_mode,
             **kwargs,
         )
+        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None
 
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
             logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
-        self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
+        self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -277,7 +285,7 @@ def run(self):
             use_auth_token=self.use_auth_token,
             load_in_8bit=self.load_in_8bit,
             tensor_parallel_devices=self.tensor_parallel_devices,
-            need_reachability_check=self.need_reachability_check,
+            should_validate_reachability=self.should_validate_reachability,
             start=True,
         )
         try:
@@ -335,6 +343,8 @@ def _should_choose_other_blocks(self) -> bool:
     def shutdown(self):
         self.stop.set()
 
+        if self.reachability_protocol is not None:
+            self.reachability_protocol.shutdown()
         self.dht.shutdown()
         self.dht.join()

@@ -367,7 +377,7 @@ def create(
         use_auth_token: Optional[str],
         load_in_8bit: bool,
         tensor_parallel_devices: Sequence[torch.device],
-        need_reachability_check: bool,
+        should_validate_reachability: bool,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@@ -422,8 +432,8 @@
                 max_batch_size=max_batch_size,
             )
 
-            if need_reachability_check:
-                check_reachability(dht.peer_id)
+            if should_validate_reachability:
+                validate_reachability(dht.peer_id)
         except:
             logger.debug("Shutting down backends")
             for backend in blocks.values():
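
In sum, after this PR a public-swarm server first probes direct reachability (to pick a DHT mode) and then, unless skip_reachability_check is set, runs the strict check against the centralized validator at the end of ModuleContainer.create(); a condensed sketch of that final step, with initial_peers, skip_reachability_check, and dht assumed in scope as in Server above:

from petals.constants import PUBLIC_INITIAL_PEERS
from petals.server.reachability import validate_reachability

# The strict check raises an error with troubleshooting steps if the peer
# cannot be reached even through relays; it is skipped for private swarms.
should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
if should_validate_reachability:
    validate_reachability(dht.peer_id)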