diff --git a/dlrover/python/common/grpc.py b/dlrover/python/common/grpc.py index 335e50e2f..328b30781 100644 --- a/dlrover/python/common/grpc.py +++ b/dlrover/python/common/grpc.py @@ -333,6 +333,7 @@ class CommWorldRequest(RendezvousRequest): @dataclass class JoinRendezvousRequest(RendezvousRequest): + node_rank: int = -1 node_ip: str = "" # The IP of node where the pod is located. diff --git a/dlrover/python/elastic_agent/master_client.py b/dlrover/python/elastic_agent/master_client.py index 97a1b3124..b78fe044b 100644 --- a/dlrover/python/elastic_agent/master_client.py +++ b/dlrover/python/elastic_agent/master_client.py @@ -310,7 +310,8 @@ def num_nodes_waiting(self, rdzv_name): def join_rendezvous(self, node_rank, local_world_size, rdzv_name=""): request = grpc.JoinRendezvousRequest( - node_id=node_rank, + node_id=self._node_id, + node_rank=node_rank, local_world_size=local_world_size, rdzv_name=rdzv_name, node_ip=self._node_ip, diff --git a/dlrover/python/master/elastic_training/net_topology.py b/dlrover/python/master/elastic_training/net_topology.py index 9c9bd365e..c1d93d1c1 100644 --- a/dlrover/python/master/elastic_training/net_topology.py +++ b/dlrover/python/master/elastic_training/net_topology.py @@ -10,33 +10,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json from abc import ABCMeta, abstractmethod from collections import OrderedDict from dataclasses import dataclass -from typing import Any, Dict, List, Tuple +from typing import Dict, List, Tuple @dataclass class NodeTopologyMeta(object): + node_id: int = 0 node_rank: int = 0 process_num: int = 0 node_ip: str = "" asw: str = "" psw: str = "" - def __repr__(self) -> str: - d: Dict[str, Any] = {} - d["node_rank"] = self.node_rank - d["process_num"] = self.process_num - if self.node_ip: - d["node_ip"] = self.node_ip - if self.asw: - d["asw"] = self.asw - if self.psw: - d["psw"] = self.psw - return json.dumps(d) - class TopologyQuerier(metaclass=ABCMeta): @abstractmethod diff --git a/dlrover/python/master/elastic_training/rdzv_manager.py b/dlrover/python/master/elastic_training/rdzv_manager.py index 60037f2e0..1d705f2d0 100644 --- a/dlrover/python/master/elastic_training/rdzv_manager.py +++ b/dlrover/python/master/elastic_training/rdzv_manager.py @@ -95,11 +95,18 @@ def remove_alive_node(self, node: Node): """When a node is exited, the master will remove it from alive list.""" if node.id in self._alive_nodes: self._alive_nodes.remove(node.id) + self._has_node_failed = True + remove_rank = -1 + for rank, meta in self._waiting_nodes.items(): + if meta.node_id == node.id: + remove_rank = rank + break + if remove_rank > 0: + self._waiting_nodes.pop(remove_rank, None) logger.info( - f"Remove exited worker {node.name} from " - f"{self._name} rendezvous." + f"Remove exited worker {node.name} with rank {remove_rank} " + f" from {self._name} rendezvous." ) - self._has_node_failed = True def update_rdzv_params( self, min_nodes, max_ndoes, waiting_timeout, node_unit @@ -155,26 +162,32 @@ def _check_rdzv_completed(self): self._lastcall_time = 0 self._log_rendezvous_info() if self._waiting_nodes: - waiting_node_ranks = sorted(list(self._waiting_nodes.keys())) + waiting_node_ids = [] + for node in self._waiting_nodes.values(): + waiting_node_ids.append(node.node_id) logger.warning( f"Waiting nodes not in {self._rdzv_round} rendezvous " - f"are {waiting_node_ranks}." + f"are {waiting_node_ids}." ) elif time.time() - self._latest_log_nodes_time > 60: self._latest_log_nodes_time = time.time() - waiting_node_ranks = sorted(list(self._waiting_nodes.keys())) - logger.info( - f"Waiting nodes in rendezvous are {waiting_node_ranks}" - ) + waiting_nodes = {} + for rank, node in self._waiting_nodes.items(): + waiting_nodes[node.node_id] = rank + logger.info(f"Waiting nodes in rendezvous are {waiting_nodes}") return rdzv_completed def _log_rendezvous_info(self): - node_ranks = sorted(list(self._rdzv_nodes.keys())) + node_ranks = {} + for rank, node in self._rdzv_nodes.items(): + node_ranks[node.node_id] = rank + node_ranks = dict(sorted(node_ranks.items())) + node_rdzv_times = self._map_node_rank_to_id(self._node_rdzv_times) logger.info( f"Completed {self._rdzv_round} round " f"rendezvous of {self._name} is {node_ranks} \n" "The times of nodes to join rendezvous " - f"are {self._node_rdzv_times}." + f"are {node_rdzv_times}." ) self._node_rdzv_times.clear() if self._start_rdzv_ts > 0: @@ -188,14 +201,18 @@ def _log_rendezvous_info(self): def not_joined_rdzv_nodes(self): """Return workers which do not join a rendezvous.""" nodes = [] + join_node_ids = [] + for node in self._rdzv_nodes.values(): + join_node_ids.append(node.node_id) if self._rdzv_nodes: for node_id in self._alive_nodes: - if node_id not in self._rdzv_nodes: + if node_id not in join_node_ids: nodes.append(node_id) return nodes def join_rendezvous( self, + node_id, node_rank, local_world_size, node_ip="", @@ -220,6 +237,7 @@ def join_rendezvous( return self._rdzv_round asw, psw = self._topology_querier.query(node_ip) meta = NodeTopologyMeta( + node_id=node_id, node_rank=node_rank, node_ip=node_ip, process_num=local_world_size, @@ -235,6 +253,22 @@ def join_rendezvous( return self._rdzv_round + def _map_node_rank_to_id(self, rank_dict): + """ + Convert the dict with the node rank as the key to a dict + with the node id. Because, it is more clear to show the node + id in the log than the node rank. If the log shows the node rank, + users need to search which node has the node rank. + """ + id_dict = {} + for node_rank, v in rank_dict.items(): + if node_rank not in self._rdzv_nodes: + continue + node_id = self._rdzv_nodes[node_rank].node_id + id_dict[node_id] = v + id_dict = dict(sorted(id_dict.items())) + return id_dict + def num_nodes_waiting(self): """The elastic agent will restart training processes if it find the number of waiting nodes is not zero. The manager @@ -286,7 +320,7 @@ def get_comm_world( @abstractmethod def report_network_check_result( - self, node_id: int, normal: bool, elapsed_time: float + self, node_rank: int, normal: bool, elapsed_time: float ): """The node updates its status""" pass @@ -339,14 +373,17 @@ def get_comm_world( ) ranks = list(self._rdzv_nodes.keys()) node_ips = [] + node_ids = [] for node_rank in ranks: - node_ips.append(self._rdzv_nodes[node_rank].node_ip) + node = self._rdzv_nodes[node_rank] + node_ips.append(node.node_ip) + node_ids.append(node.node_id) logger.info( - f"Node ranks are {ranks}.\n Node IPs are {node_ips}" + f"Node ids are {node_ids}.\n Node IPs are {node_ips}" ) return self._rdzv_round, 0, self._rdzv_nodes - def report_network_check_result(self, node_id, normal, elapsed_time): + def report_network_check_result(self, node_rank, normal, elapsed_time): return @@ -388,13 +425,16 @@ def get_comm_world( self._fault_nodes.clear() self._straggler_nodes.clear() self._node_groups = self._group_nodes(self._rdzv_round) - rank_groups = [] + node_groups = [] for group in self._node_groups: - ranks = [rank for rank in group.keys()] - rank_groups.append(ranks) + ids = [ + self._rdzv_nodes[rank].node_id + for rank in group.keys() + ] + node_groups.append(ids) logger.info( f"Node groups of round {self._rdzv_round} " - f"are: {rank_groups}." + f"are: {node_groups}." ) if self._rdzv_round % 2 == 0: self._clear_check_status() @@ -460,7 +500,8 @@ def _group_nodes(self, round): def _check_abnormal_nodes(self): abnormal_nodes = [] normal_nodes = [] - for node_id, status in self._node_status.items(): + for node_rank, status in self._node_status.items(): + node_id = self._rdzv_nodes[node_rank].node_id if status: normal_nodes.append(node_id) else: @@ -471,27 +512,30 @@ def _check_abnormal_nodes(self): ) def report_network_check_result( - self, node_id: int, succeed: bool, elapsed_time: float + self, node_rank: int, succeed: bool, elapsed_time: float ): - self._reported_nodes.add(node_id) - self._node_status.setdefault(node_id, succeed) - self._node_times.setdefault(node_id, elapsed_time) - self._node_status[node_id] = self._node_status[node_id] or succeed - self._node_times[node_id] = round( - min(self._node_times[node_id], elapsed_time), 3 + self._reported_nodes.add(node_rank) + self._node_status.setdefault(node_rank, succeed) + self._node_times.setdefault(node_rank, elapsed_time) + self._node_status[node_rank] = self._node_status[node_rank] or succeed + self._node_times[node_rank] = round( + min(self._node_times[node_rank], elapsed_time), 3 ) if len(self._reported_nodes) == len(self._rdzv_nodes): + node_status = self._map_node_rank_to_id(self._node_status) logger.info( f"Round {self._rdzv_round}: The node status " - f"are {self._node_status}." + f"are {node_status}." ) + node_check_times = self._map_node_rank_to_id(self._node_times) logger.info( f"Round {self._rdzv_round}: The node elapsed time " - f"are {self._node_times}" + f"are {node_check_times}" ) def join_rendezvous( self, + node_id, node_rank, local_world_size, node_ip="", @@ -506,7 +550,9 @@ def join_rendezvous( int: the number of rendezvous round. """ self._node_groups.clear() - return super().join_rendezvous(node_rank, local_world_size, node_ip) + return super().join_rendezvous( + node_id, node_rank, local_world_size, node_ip + ) def check_fault_node(self): """Check whether the job has fault nodes. Each task contains 2 rounds @@ -518,11 +564,14 @@ def check_fault_node(self): if not all_joined: reason = NetworkFailureReason.WAITING_NODE elif len(self._fault_nodes) == 0: - for node_id, status in self._node_status.items(): + for node_rank, status in self._node_status.items(): if not status: - self._fault_nodes.add(node_id) + self._fault_nodes.add(node_rank) if len(self._fault_nodes) > 0: - logger.warning(f"Fault nodes {self._fault_nodes}") + fault_nodes = {} + for rank in self._fault_nodes: + fault_nodes[rank] = self._rdzv_nodes[rank].node_id + logger.warning(f"Fault nodes are {fault_nodes}") stragglers = self._detect_stragglers() if not self._fault_nodes and not stragglers: self._rdzv_round = ( diff --git a/dlrover/python/master/servicer.py b/dlrover/python/master/servicer.py index 9558e3aa3..048c27b8b 100644 --- a/dlrover/python/master/servicer.py +++ b/dlrover/python/master/servicer.py @@ -242,8 +242,14 @@ def _check_straggler(self): def _join_rendezvous(self, request: grpc.JoinRendezvousRequest): rdzv_manager = self._rdzv_managers[request.rdzv_name] + node_rank = request.node_rank + if node_rank == -1: # Back compatibility + node_rank = request.node_id round = rdzv_manager.join_rendezvous( - request.node_id, request.local_world_size, request.node_ip + request.node_id, + node_rank, + request.local_world_size, + request.node_ip, ) if request.rdzv_name == RendezvousName.NETWORK_CHECK: # The waiting node in the training rdzv should clear if @@ -266,8 +272,8 @@ def _get_comm_world(self, request: grpc.CommWorldRequest): res = grpc.RendezvousState(world={}) res.group = group res.round = rdzv_round - for rank_id, meta in nodes.items(): - res.world[rank_id] = meta.process_num + for rank, meta in nodes.items(): + res.world[rank] = meta.process_num if nodes and request.rdzv_name == RendezvousName.ELASTIC_TRAINING: rdzv_round = rdzv_manager.get_rdzv_round() metrics = {CustomMetricKeys.RDZV_ROUND: rdzv_round} diff --git a/dlrover/python/tests/test_elastic_training_agent.py b/dlrover/python/tests/test_elastic_training_agent.py index 86cc951e7..0bdb49690 100644 --- a/dlrover/python/tests/test_elastic_training_agent.py +++ b/dlrover/python/tests/test_elastic_training_agent.py @@ -75,11 +75,10 @@ def setUp(self) -> None: ) master_addr = "127.0.0.1" - node_id = 0 self.rdzv_handler = MasterRendezvousHandler( RendezvousName.ELASTIC_TRAINING, - node_id, + 0, rdzv_parameters, local_world_size=self.config.nproc_per_node, ) @@ -121,18 +120,21 @@ def test_auto_configure(self): self.assertTrue(config.network_check) def test_rank0_rendzevous(self): - node_id = 0 agent = ElasticTrainingAgent( - node_rank=node_id, + node_rank=0, config=self.config, entrypoint="python", spec=self.spec, start_method=self.config.start_method, log_dir=self.config.log_dir, ) + + # Mock node rank 1 joins the rendezvous. + self.rdzv_handler._client._node_id = 1 self.rdzv_handler._client.join_rendezvous( 1, 8, self.rdzv_handler._name ) + agent._client._node_id = 0 agent._rendezvous(agent._worker_group) worker_group = agent._worker_group self.assertEqual(len(worker_group.workers), 8) @@ -147,22 +149,26 @@ def test_rank0_rendzevous(self): ) def test_rank1_rendzevous(self): - node_id = 1 agent = ElasticTrainingAgent( - node_rank=node_id, + node_rank=1, config=self.config, entrypoint="python", spec=self.spec, start_method=self.config.start_method, log_dir=self.config.log_dir, ) - self.rdzv_handler._node_rank = node_id + # Mock node rank 0 joins the rendezvous. + self.rdzv_handler._client._node_id = 0 self.rdzv_handler._client.join_rendezvous( 0, 8, self.rdzv_handler._name ) store = self.rdzv_handler._get_store(round=1, group=0) store.set("MASTER_ADDR", "127.0.0.1".encode()) store.set("MASTER_PORT", "12345".encode()) + + # Set the node id and rank as 1. + agent._client._node_id = 1 + self.spec.rdzv_handler._node_rank = 1 agent._rendezvous(agent._worker_group) worker_group = agent._worker_group self.assertEqual(len(worker_group.workers), 8) @@ -283,7 +289,7 @@ def test_report_resource_with_step(self): def test_check_network_rdzv_for_elastic_training(self): self._master.rdzv_managers[ RendezvousName.NETWORK_CHECK - ].join_rendezvous(0, 8) + ].join_rendezvous(0, 0, 8) with self.assertRaises(RendezvousOutSyncError): self.rdzv_handler._check_network_rdzv_for_elastic_training() diff --git a/dlrover/python/tests/test_net_topology.py b/dlrover/python/tests/test_net_topology.py index 45d2aca89..593b2cd07 100644 --- a/dlrover/python/tests/test_net_topology.py +++ b/dlrover/python/tests/test_net_topology.py @@ -34,7 +34,12 @@ def test_dp_topology_sorter(self): node_ip = f"192.168.0.{i}" asw, psw = sw_querier.query(node_ip) node = NodeTopologyMeta( - node_rank=i, process_num=8, node_ip=node_ip, asw=asw, psw=psw + node_id=i, + node_rank=i, + process_num=8, + node_ip=node_ip, + asw=asw, + psw=psw, ) nodes[i] = node sorted_nodes = sorter.sort(nodes) @@ -42,7 +47,7 @@ def test_dp_topology_sorter(self): self.assertListEqual(node_ranks, list(range(node_num))) for node in nodes.values(): - asw_index = node.node_rank % 3 + asw_index = node.node_id % 3 node.asw = f"asw-{asw_index}" sorted_nodes = sorter.sort(nodes) diff --git a/dlrover/python/tests/test_rdzv_manager.py b/dlrover/python/tests/test_rdzv_manager.py index 21554876a..89cc5b383 100644 --- a/dlrover/python/tests/test_rdzv_manager.py +++ b/dlrover/python/tests/test_rdzv_manager.py @@ -22,6 +22,9 @@ build_master_client, ) from dlrover.python.elastic_agent.torch.master_kv_store import MasterKVStore +from dlrover.python.master.elastic_training.net_topology import ( + NodeTopologyMeta, +) from dlrover.python.master.elastic_training.rdzv_manager import ( ElasticTrainingRendezvousManager, NetworkCheckRendezvousManager, @@ -62,14 +65,14 @@ def test_max_nodes(self): rdzv_round = rdzv_manager.get_rdzv_round() self.assertEqual(rdzv_round, 0) rdzv_manager._alive_nodes = [0, 1, 2] - rdzv_manager.join_rendezvous(0, 8) - rdzv_manager.join_rendezvous(1, 8) + rdzv_manager.join_rendezvous(0, 0, 8) + rdzv_manager.join_rendezvous(1, 1, 8) round, _, world = rdzv_manager.get_comm_world(0) self.assertEqual(round, 0) self.assertDictEqual(world, {}) self.assertEqual(len(rdzv_manager._waiting_nodes), 2) self.assertEqual(len(rdzv_manager._rdzv_nodes), 0) - rdzv_manager.join_rendezvous(2, 8) + rdzv_manager.join_rendezvous(2, 2, 8) self.assertDictEqual( rdzv_manager._node_rdzv_times, {0: 0.0, 1: 0.0, 2: 0.0} ) @@ -89,8 +92,8 @@ def test_min_nodes(self): rdzv_manager.add_alive_node(node_0) node_2 = Node("worker", 2) rdzv_manager.add_alive_node(node_2) - rdzv_manager.join_rendezvous(0, 8) - rdzv_manager.join_rendezvous(1, 8) + rdzv_manager.join_rendezvous(0, 0, 8) + rdzv_manager.join_rendezvous(1, 1, 8) rdzv_manager.remove_alive_node(node_2) self.assertEqual(len(rdzv_manager._alive_nodes), 2) self.assertEqual(len(rdzv_manager._waiting_nodes), 2) @@ -113,7 +116,7 @@ def test_min_nodes_with_unit(self): for i in range(test_loop): node = Node("worker", i, name=f"worker-{i}") rdzv_manager.add_alive_node(node) - rdzv_manager.join_rendezvous(i, 8) + rdzv_manager.join_rendezvous(i, i, 8) self.assertEqual(len(rdzv_manager._alive_nodes), test_loop) self.assertEqual(len(rdzv_manager._waiting_nodes), test_loop) self.assertEqual(len(rdzv_manager._rdzv_nodes), 0) @@ -131,8 +134,8 @@ def test_min_nodes_with_unit(self): # Test the number of waiting nodes is less than the node unit. self.assertEqual(rdzv_manager.num_nodes_waiting(), 0) - rdzv_manager.join_rendezvous(10, 8) - rdzv_manager.join_rendezvous(11, 8) + rdzv_manager.join_rendezvous(10, 10, 8) + rdzv_manager.join_rendezvous(11, 11, 8) self.assertEqual( len(rdzv_manager._waiting_nodes), rdzv_manager.num_nodes_waiting() ) @@ -147,14 +150,14 @@ def test_min_nodes_with_unit(self): rdzv_manager.add_alive_node(node_11) rdzv_manager.remove_alive_node(node_10) rdzv_manager.remove_alive_node(node_11) - self.assertEqual(len(rdzv_manager._waiting_nodes), 4) + self.assertEqual(len(rdzv_manager._waiting_nodes), 2) # Test the number of waiting nodes is equal or # bigger than the node unit. for i in range(12, 16): - rdzv_manager.join_rendezvous(i, 8) + rdzv_manager.join_rendezvous(i, i, 8) num = rdzv_manager.num_nodes_waiting() - self.assertEqual(num, 8) + self.assertEqual(num, 6) rdzv_manager.clear_waiting_nodes() num = rdzv_manager.num_nodes_waiting() self.assertEqual(num, 0) @@ -166,7 +169,7 @@ def test_network_check_rdzv(self): rdzv_manager.update_rdzv_params(4, 4, 60, 1) rdzv_manager._alive_nodes = [0, 1, 2, 3] for i in range(4): - round = rdzv_manager.join_rendezvous(i, 8) + round = rdzv_manager.join_rendezvous(i, i, 8) self.assertEqual(round, 0) round, group, world = rdzv_manager.get_comm_world(0) self.assertEqual(round, 1) @@ -184,7 +187,7 @@ def test_network_check_rdzv(self): rdzv_manager.report_network_check_result(3, False, 0.0) for i in range(4): - round = rdzv_manager.join_rendezvous(i, 8) + round = rdzv_manager.join_rendezvous(i, i, 8) self.assertEqual(round, 1) round, group, world = rdzv_manager.get_comm_world(0) self.assertEqual(round, 2) @@ -197,7 +200,7 @@ def test_network_check_rdzv(self): self.assertFalse(success) for i in range(4): - round = rdzv_manager.join_rendezvous(i, 8) + round = rdzv_manager.join_rendezvous(i, i, 8) self.assertEqual(round, 2) round, group, world = rdzv_manager.get_comm_world(3) self.assertEqual(round, 3) @@ -214,7 +217,7 @@ def test_network_check_rdzv_with_single_node(self): rdzv_manager = NetworkCheckRendezvousManager() rdzv_manager.update_rdzv_params(1, 1, 60, 1) rdzv_manager._alive_nodes = [0] - round = rdzv_manager.join_rendezvous(0, 8) + round = rdzv_manager.join_rendezvous(0, 0, 8) self.assertEqual(round, 0) round, _, world = rdzv_manager.get_comm_world(0) self.assertEqual(round, 1) @@ -231,7 +234,7 @@ def test_network_check_straggler_even_nodes(self): rdzv_manager.update_rdzv_params(6, 6, 60, 1) rdzv_manager._alive_nodes = [0, 1, 2, 3, 4, 5] for i in range(6): - round = rdzv_manager.join_rendezvous(i, 8) + round = rdzv_manager.join_rendezvous(i, i, 8) self.assertEqual(round, 0) round, group, world = rdzv_manager.get_comm_world(0) self.assertEqual(round, 1) @@ -248,7 +251,7 @@ def test_network_check_straggler_even_nodes(self): self.assertListEqual(stragglers, [4, 5]) for i in range(6): - round = rdzv_manager.join_rendezvous(i, 8) + round = rdzv_manager.join_rendezvous(i, i, 8) self.assertEqual(round, 1) round, group, world = rdzv_manager.get_comm_world(5) self.assertEqual(round, 2) @@ -266,7 +269,7 @@ def test_network_check_straggler_old_nodes(self): rdzv_manager.update_rdzv_params(5, 5, 60, 1) rdzv_manager._alive_nodes = [0, 1, 2, 3, 4] for i in range(5): - round = rdzv_manager.join_rendezvous(i, 8) + round = rdzv_manager.join_rendezvous(i, i, 8) self.assertEqual(round, 0) round, group, world = rdzv_manager.get_comm_world(0) self.assertEqual(round, 1) @@ -283,7 +286,7 @@ def test_network_check_straggler_old_nodes(self): self.assertListEqual(stragglers, [0, 1]) for i in range(5): - round = rdzv_manager.join_rendezvous(i, 8) + round = rdzv_manager.join_rendezvous(i, i, 8) self.assertEqual(round, 1) round, group, world = rdzv_manager.get_comm_world(1) self.assertEqual(round, 2) @@ -305,3 +308,14 @@ def test_sync_ckpt_nodes(self): self.assertTrue(success) success = rdzv_manager.sync_ckpt_nodes(1, 90) self.assertFalse(success) + + def test_map_node_rank_to_id(self): + rdzv_manager = ElasticTrainingRendezvousManager() + rdzv_manager._rdzv_nodes[0] = NodeTopologyMeta( + node_id=1, + node_rank=0, + process_num=8, + ) + rank_d = {0: True} + id_d = rdzv_manager._map_node_rank_to_id(rank_d) + self.assertDictEqual(id_d, {1: True})