Skip to content

Commit

Permalink
Abort speedtest if it runs too long
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed May 9, 2023
1 parent 6eb306a commit 7ed2a1d
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions src/petals/server/throughput.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import fcntl
import json
import math
import multiprocessing as mp
import os
import time
from collections import Counter
Expand Down Expand Up @@ -120,24 +121,25 @@ def measure_throughput_info(
}
try:
throughput_info["network_rps"] = measure_network_rps(config)
except Exception:
logger.warning("Failed to measure network throughput:", exc_info=True)
logger.warning("Proceeding with the compute throughput only")
except Exception as e:
logger.warning(f"Failed to measure network throughput: {repr(e)}. Proceeding with compute throughput only")
return throughput_info


def measure_network_rps(config: BloomConfig) -> Optional[float]:
s = speedtest.Speedtest()
s.get_servers()
s.get_best_server()
s.download()
s.upload()
network_info = s.results.dict()
def measure_network_rps(config: BloomConfig, *, timeout: float = 60) -> Optional[float]:
pipe_recv, pipe_send = mp.Pipe(duplex=False)
process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
process.start()

if not pipe_recv.poll(timeout):
process.terminate()
raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
network_info = pipe_recv.recv()

bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
if network_rps == 0:
raise ValueError("speedtest has returned network_rps == 0")
raise RuntimeError("speedtest has returned network_rps == 0")

logger.info(
f"Network throughput: {network_rps:.1f} RPS "
Expand All @@ -147,6 +149,15 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]:
return network_rps


def _measure_bits_per_second(pipe_send: mp.Pipe):
    """Run a full speedtest and send its results dict through ``pipe_send``.

    Intended to be used as the target of a separate ``mp.Process`` so the
    caller can stop waiting if the measurement takes too long.

    Args:
        pipe_send: Object whose ``send()`` method is used to deliver the
            speedtest results dict.
            NOTE(review): ``mp.Pipe`` is a factory function, not a type; the
            value passed here is presumably the sending ``Connection`` of a
            one-way pipe -- confirm against the caller.
    """
    tester = speedtest.Speedtest()

    # Server discovery has to happen before the transfer tests are run.
    tester.get_servers()
    tester.get_best_server()

    # Run the download test first, then the upload test.
    tester.download()
    tester.upload()

    measurements = tester.results.dict()
    pipe_send.send(measurements)


def measure_compute_rps(
config: BloomConfig,
device: torch.device,
Expand Down

0 comments on commit 7ed2a1d

Please sign in to comment.