From 36e7c5364da6baa0bb0bcae9307902d4d6e7881b Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Thu, 11 Aug 2022 19:17:58 +0800
Subject: [PATCH] [dask] Deterministic rank assignment. (#8018)

---
 python-package/xgboost/dask.py    | 18 ++++++++--
 python-package/xgboost/tracker.py | 59 ++++++++++++++++++++++---------
 tests/python/test_tracker.py      | 32 +++++++++++++++++
 3 files changed, 90 insertions(+), 19 deletions(-)

diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 22d284f798b1..951676a81757 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -170,9 +170,11 @@ def _try_start_tracker(
                 use_logger=False,
             )
         else:
-            assert isinstance(addrs[0], str) or addrs[0] is None
+            addr = addrs[0]
+            assert isinstance(addr, str) or addr is None
+            host_ip = get_host_ip(addr)
             rabit_context = RabitTracker(
-                host_ip=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
+                host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
             )
             env.update(rabit_context.worker_envs())
             rabit_context.start(n_workers)
@@ -222,8 +224,16 @@ class RabitContext(rabit.RabitContext):
     def __init__(self, args: List[bytes]) -> None:
         super().__init__(args)
         worker = distributed.get_worker()
+        with distributed.worker_client() as client:
+            info = client.scheduler_info()
+            w = info["workers"][worker.address]
+            wid = w["id"]
+        # Use the task ID for rank assignment, which keeps the RABIT rank
+        # consistent with the dask worker ID (not identical, since the task ID
+        # is a string and "10" sorts before "2").  This outsources rank
+        # assignment to dask and avoids non-deterministic assignment.
         self.args.append(
-            ("DMLC_TASK_ID=[xgboost.dask]:" + str(worker.address)).encode()
+            (f"DMLC_TASK_ID=[xgboost.dask-{wid}]:" + str(worker.address)).encode()
         )


@@ -841,6 +851,8 @@ async def _get_rabit_args(
     except Exception:  # pylint: disable=broad-except
         sched_addr = None

+    # make sure all workers are online so that we can obtain reliable scheduler_info
+    client.wait_for_workers(n_workers)
     env = await client.run_on_scheduler(
         _start_tracker, n_workers, sched_addr, user_addr
     )
diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py
index e19181bf4aea..6dc6167d9517 100644
--- a/python-package/xgboost/tracker.py
+++ b/python-package/xgboost/tracker.py
@@ -32,15 +32,15 @@ def recvall(self, nbytes: int) -> bytes:
             chunk = self.sock.recv(min(nbytes - nread, 1024))
             nread += len(chunk)
             res.append(chunk)
-        return b''.join(res)
+        return b"".join(res)

     def recvint(self) -> int:
-        """Receive an integer of 32 bytes"""
-        return struct.unpack('@i', self.recvall(4))[0]
+        """Receive a 32-bit integer."""
+        return struct.unpack("@i", self.recvall(4))[0]

     def sendint(self, value: int) -> None:
-        """Send an integer of 32 bytes"""
-        self.sock.sendall(struct.pack('@i', value))
+        """Send a 32-bit integer."""
+        self.sock.sendall(struct.pack("@i", value))

     def sendstr(self, value: str) -> None:
         """Send a Python string"""
@@ -69,6 +69,7 @@ def get_family(addr: str) -> int:

 class WorkerEntry:
     """Handler for each worker."""
+
     def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
         worker = ExSocket(sock)
         self.sock = worker
@@ -78,7 +79,7 @@ def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
         worker.sendint(MAGIC_NUM)
         self.rank = worker.recvint()
         self.world_size = worker.recvint()
-        self.jobid = worker.recvstr()
+        self.task_id = worker.recvstr()
         self.cmd = worker.recvstr()
         self.wait_accept = 0
         self.port: Optional[int] = None
@@ -96,8 +97,8 @@ def decide_rank(self, job_map: Dict[str, int]) -> int:
         """Get the rank of the current entry."""
         if self.rank >= 0:
             return self.rank
-        if self.jobid != 'NULL' and self.jobid in job_map:
-            return job_map[self.jobid]
+        if self.task_id != "NULL" and self.task_id in job_map:
+            return job_map[self.task_id]
         return -1

     def assign_rank(
@@ -180,7 +181,12 @@ class RabitTracker:
     """

     def __init__(
-        self, host_ip: str, n_workers: int, port: int = 0, use_logger: bool = False
+        self,
+        host_ip: str,
+        n_workers: int,
+        port: int = 0,
+        use_logger: bool = False,
+        sortby: str = "host",
     ) -> None:
         """A Python implementation of RABIT tracker.

@@ -190,6 +196,13 @@ def __init__(
             Use logging.info for the tracker print command. When set to False, the
             Python print function is used instead.

+        sortby:
+            How to sort the workers for rank assignment. The default is host, but
+            users can set the `DMLC_TASK_ID` via RABIT initialization arguments to
+            obtain a deterministic rank assignment. Available options are:
+              - host
+              - task
+
         """
         sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
         sock.bind((host_ip, port))
@@ -200,6 +213,7 @@ def __init__(
         self.thread: Optional[Thread] = None
         self.n_workers = n_workers
         self._use_logger = use_logger
+        self._sortby = sortby
         logging.info("start listen on %s:%d", host_ip, self.port)

     def __del__(self) -> None:
@@ -223,7 +237,7 @@ def worker_envs(self) -> Dict[str, Union[str, int]]:
         Get environment variables for workers;
         can be passed in as args or envs.
         """
-        return {'DMLC_TRACKER_URI': self.host_ip, 'DMLC_TRACKER_PORT': self.port}
+        return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}

     def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
         tree_map: _TreeMap = {}
@@ -296,8 +310,16 @@ def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingMap]:
                 parent_map_[rmap[k]] = -1
         return tree_map_, parent_map_, ring_map_

+    def _sort_pending(self, pending: List[WorkerEntry]) -> List[WorkerEntry]:
+        if self._sortby == "host":
+            pending.sort(key=lambda s: s.host)
+        elif self._sortby == "task":
+            pending.sort(key=lambda s: s.task_id)
+        return pending
+
     def accept_workers(self, n_workers: int) -> None:
         """Wait for all workers to connect to the tracker."""
+
         # set of nodes that have finished the job
         shutdown: Dict[int, WorkerEntry] = {}
         # set of nodes that are waiting for connections
@@ -341,27 +363,32 @@ def accept_workers(self, n_workers: int) -> None:
                 assert todo_nodes
                 pending.append(s)
                 if len(pending) == len(todo_nodes):
-                    pending.sort(key=lambda x: x.host)
+                    pending = self._sort_pending(pending)
                     for s in pending:
                         rank = todo_nodes.pop(0)
-                        if s.jobid != 'NULL':
-                            job_map[s.jobid] = rank
+                        if s.task_id != "NULL":
+                            job_map[s.task_id] = rank
                         s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                         if s.wait_accept > 0:
                             wait_conn[rank] = s
-                        logging.debug('Received %s signal from %s; assign rank %d',
-                                      s.cmd, s.host, s.rank)
+                        logging.debug(
+                            "Received %s signal from %s; assign rank %d",
+                            s.cmd,
+                            s.host,
+                            s.rank,
+                        )
                     if not todo_nodes:
-                        logging.info('@tracker All of %d nodes getting started', n_workers)
+                        logging.info("@tracker All of %d nodes getting started", n_workers)
             else:
                 s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
-                logging.debug('Received %s signal from %d', s.cmd, s.rank)
+                logging.debug("Received %s signal from %d", s.cmd, s.rank)
                 if s.wait_accept > 0:
                     wait_conn[rank] = s

-        logging.info('@tracker All nodes finishes job')
+        logging.info("@tracker All nodes finishes job")

     def start(self, n_workers: int) -> None:
         """Start the tracker; it will wait for `n_workers` to connect."""
+
         def run() -> None:
             self.accept_workers(n_workers)
diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py
index 885221aae4ae..b9ae17531790 100644
--- a/tests/python/test_tracker.py
+++ b/tests/python/test_tracker.py
@@ -4,6 +4,7 @@
 import testing as tm
 import numpy as np
 import sys
+import re

 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -58,3 +59,34 @@ def test_rabit_ops():
     with LocalCluster(n_workers=n_workers) as cluster:
         with Client(cluster) as client:
             run_rabit_ops(client, n_workers)
+
+
+def test_rank_assignment() -> None:
+    from distributed import Client, LocalCluster
+    from test_with_dask import _get_client_workers
+
+    def local_test(worker_id):
+        with xgb.dask.RabitContext(args):
+            for val in args:
+                sval = val.decode("utf-8")
+                if sval.startswith("DMLC_TASK_ID"):
+                    task_id = sval
+                    break
+            matched = re.search(".*-([0-9]).*", task_id)
+            rank = xgb.rabit.get_rank()
+            # As long as the number of workers is less than 10, the rank and
+            # the worker ID should be the same.
+            assert rank == int(matched.group(1))
+
+    with LocalCluster(n_workers=8) as cluster:
+        with Client(cluster) as client:
+            workers = _get_client_workers(client)
+            args = client.sync(
+                xgb.dask._get_rabit_args,
+                len(workers),
+                None,
+                client,
+            )
+
+            futures = client.map(local_test, range(len(workers)), workers=workers)
+            client.gather(futures)
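
Usage note: below is a minimal sketch (not part of the patch) of driving the
tracker directly with the new `sortby` parameter; it assumes an xgboost build
that includes this change.

    from xgboost.tracker import RabitTracker

    # Ask the tracker to assign ranks by DMLC_TASK_ID instead of by host name.
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=2, sortby="task")
    env = tracker.worker_envs()  # {'DMLC_TRACKER_URI': ..., 'DMLC_TRACKER_PORT': ...}
    tracker.start(2)  # spawns a background thread waiting for 2 workers

    # Each worker should then pass a DMLC_TASK_ID initialization argument
    # (e.g. b"DMLC_TASK_ID=0") so that ranks follow the task-ID sort order.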
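A second sketch (also not from the patch) illustrates the caveat noted in the
RabitContext comment: task IDs are compared as strings, so "10" sorts before
"2", and the rank only matches the numeric worker ID while there are fewer
than 10 workers, which is exactly what test_rank_assignment relies on.

    # String sort order of task IDs shaped like the ones RabitContext builds.
    task_ids = [f"[xgboost.dask-{i}]:tcp://w" for i in range(12)]
    print(sorted(task_ids)[:3])
    # ['[xgboost.dask-0]:tcp://w', '[xgboost.dask-10]:tcp://w',
    #  '[xgboost.dask-11]:tcp://w']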