Skip to content

Commit

Permalink
Perform reachability check once blocks are loaded to avoid delays
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Jan 9, 2023
1 parent 6d8322e commit 8c69eb5
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 34 deletions.
37 changes: 37 additions & 0 deletions src/petals/server/reachability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import time

import requests
from hivemind.utils.logging import get_logger

logger = get_logger(__file__)


def check_reachability(peer_id, wait_time: float = 600, retry_delay: float = 15) -> None:
for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
try:
r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10)
r.raise_for_status()
response = r.json()

if response["success"]:
logger.info("Server is reachable from the Internet. It will appear at http://health.petals.ml soon")
return

if attempt_no == 0:
# If health.petals.ml didn't manage to connect right away, we need to wait for libp2p to set up relays
logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes")
time.sleep(retry_delay)
except Exception as e:
logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
return

raise RuntimeError(
f"Server has not become reachable from the Internet:\n\n"
f"{response['message']}\n\n"
f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
f" 1. Choose a specific port for the Petals server, for example, 31337.\n"
f" 2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
f" 3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
)
41 changes: 7 additions & 34 deletions src/petals/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import numpy as np
import psutil
import requests
import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
Expand All @@ -28,6 +27,7 @@
from petals.server.block_utils import get_block_size
from petals.server.handler import TransformerConnectionHandler
from petals.server.memory_cache import MemoryCache
from petals.server.reachability import check_reachability
from petals.server.throughput import get_host_throughput
from petals.utils.convert_block import check_device_balance, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
Expand Down Expand Up @@ -132,6 +132,7 @@ def __init__(
logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
else:
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -201,41 +202,8 @@ def __init__(
self.mean_balance_check_period = mean_balance_check_period
self.mean_block_selection_delay = mean_block_selection_delay

# We delay the reachability check to the end of init, so the server has time to join libp2p relays
if not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS:
self._check_reachability()

self.stop = threading.Event()

def _check_reachability(self, n_retries=10, retry_delay=30):
for i in range(n_retries):
try:
r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10)
r.raise_for_status()
response = r.json()

if response["success"]:
logger.info("Server is reachable from the Internet. It will appear at http://health.petals.ml soon")
return

if i < n_retries - 1:
logger.info(f"Server is not reachable from the Internet yet. Retrying in {retry_delay} sec")
time.sleep(retry_delay)
except Exception as e:
logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
return

raise RuntimeError(
f"Server has not become reachable from the Internet:\n\n"
f"{response['message']}\n\n"
f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
f" 1. Choose a specific port for the Petals server, for example, 31337.\n"
f" 2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
f" 3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
f" python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
f" 4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
)

def _choose_num_blocks(self) -> int:
assert (
self.converted_model_name_or_path == "bigscience/bloom-petals"
Expand Down Expand Up @@ -305,6 +273,7 @@ def run(self):
use_auth_token=self.use_auth_token,
load_in_8bit=self.load_in_8bit,
tensor_parallel_devices=self.tensor_parallel_devices,
need_reachability_check=self.need_reachability_check,
start=True,
)
try:
Expand Down Expand Up @@ -398,6 +367,7 @@ def create(
use_auth_token: Optional[str],
load_in_8bit: bool,
tensor_parallel_devices: Sequence[torch.device],
need_reachability_check: bool,
**kwargs,
) -> ModuleContainer:
module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
Expand Down Expand Up @@ -451,6 +421,9 @@ def create(
min_batch_size=min_batch_size,
max_batch_size=max_batch_size,
)

if need_reachability_check:
check_reachability(dht.peer_id)
except:
logger.debug("Shutting down backends")
for backend in blocks.values():
Expand Down

0 comments on commit 8c69eb5

Please sign in to comment.