Skip to content

Commit

Permalink
Merge pull request #1167 from workingloong/join-rdzv-with-node-id
Browse files Browse the repository at this point in the history
The agent joins the rendzvous with the node id.
  • Loading branch information
BalaBalaYi committed Jun 25, 2024
2 parents 7f05d90 + b221170 commit 0c18cc5
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 81 deletions.
1 change: 1 addition & 0 deletions dlrover/python/common/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ class CommWorldRequest(RendezvousRequest):

@dataclass
class JoinRendezvousRequest(RendezvousRequest):
node_rank: int = -1
node_ip: str = "" # The IP of node where the pod is located.


Expand Down
3 changes: 2 additions & 1 deletion dlrover/python/elastic_agent/master_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,8 @@ def num_nodes_waiting(self, rdzv_name):

def join_rendezvous(self, node_rank, local_world_size, rdzv_name=""):
request = grpc.JoinRendezvousRequest(
node_id=node_rank,
node_id=self._node_id,
node_rank=node_rank,
local_world_size=local_world_size,
rdzv_name=rdzv_name,
node_ip=self._node_ip,
Expand Down
16 changes: 2 additions & 14 deletions dlrover/python/master/elastic_training/net_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
from typing import Dict, List, Tuple


@dataclass
class NodeTopologyMeta(object):
node_id: int = 0
node_rank: int = 0
process_num: int = 0
node_ip: str = ""
asw: str = ""
psw: str = ""

def __repr__(self) -> str:
d: Dict[str, Any] = {}
d["node_rank"] = self.node_rank
d["process_num"] = self.process_num
if self.node_ip:
d["node_ip"] = self.node_ip
if self.asw:
d["asw"] = self.asw
if self.psw:
d["psw"] = self.psw
return json.dumps(d)


class TopologyQuerier(metaclass=ABCMeta):
@abstractmethod
Expand Down
117 changes: 83 additions & 34 deletions dlrover/python/master/elastic_training/rdzv_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,18 @@ def remove_alive_node(self, node: Node):
"""When a node is exited, the master will remove it from alive list."""
if node.id in self._alive_nodes:
self._alive_nodes.remove(node.id)
self._has_node_failed = True
remove_rank = -1
for rank, meta in self._waiting_nodes.items():
if meta.node_id == node.id:
remove_rank = rank
break
if remove_rank > 0:
self._waiting_nodes.pop(remove_rank, None)
logger.info(
f"Remove exited worker {node.name} from "
f"{self._name} rendezvous."
f"Remove exited worker {node.name} with rank {remove_rank} "
f" from {self._name} rendezvous."
)
self._has_node_failed = True

def update_rdzv_params(
self, min_nodes, max_ndoes, waiting_timeout, node_unit
Expand Down Expand Up @@ -155,26 +162,32 @@ def _check_rdzv_completed(self):
self._lastcall_time = 0
self._log_rendezvous_info()
if self._waiting_nodes:
waiting_node_ranks = sorted(list(self._waiting_nodes.keys()))
waiting_node_ids = []
for node in self._waiting_nodes.values():
waiting_node_ids.append(node.node_id)
logger.warning(
f"Waiting nodes not in {self._rdzv_round} rendezvous "
f"are {waiting_node_ranks}."
f"are {waiting_node_ids}."
)
elif time.time() - self._latest_log_nodes_time > 60:
self._latest_log_nodes_time = time.time()
waiting_node_ranks = sorted(list(self._waiting_nodes.keys()))
logger.info(
f"Waiting nodes in rendezvous are {waiting_node_ranks}"
)
waiting_nodes = {}
for rank, node in self._waiting_nodes.items():
waiting_nodes[node.node_id] = rank
logger.info(f"Waiting nodes in rendezvous are {waiting_nodes}")
return rdzv_completed

def _log_rendezvous_info(self):
node_ranks = sorted(list(self._rdzv_nodes.keys()))
node_ranks = {}
for rank, node in self._rdzv_nodes.items():
node_ranks[node.node_id] = rank
node_ranks = dict(sorted(node_ranks.items()))
node_rdzv_times = self._map_node_rank_to_id(self._node_rdzv_times)
logger.info(
f"Completed {self._rdzv_round} round "
f"rendezvous of {self._name} is {node_ranks} \n"
"The times of nodes to join rendezvous "
f"are {self._node_rdzv_times}."
f"are {node_rdzv_times}."
)
self._node_rdzv_times.clear()
if self._start_rdzv_ts > 0:
Expand All @@ -188,14 +201,18 @@ def _log_rendezvous_info(self):
def not_joined_rdzv_nodes(self):
"""Return workers which do not join a rendezvous."""
nodes = []
join_node_ids = []
for node in self._rdzv_nodes.values():
join_node_ids.append(node.node_id)
if self._rdzv_nodes:
for node_id in self._alive_nodes:
if node_id not in self._rdzv_nodes:
if node_id not in join_node_ids:
nodes.append(node_id)
return nodes

def join_rendezvous(
self,
node_id,
node_rank,
local_world_size,
node_ip="",
Expand All @@ -220,6 +237,7 @@ def join_rendezvous(
return self._rdzv_round
asw, psw = self._topology_querier.query(node_ip)
meta = NodeTopologyMeta(
node_id=node_id,
node_rank=node_rank,
node_ip=node_ip,
process_num=local_world_size,
Expand All @@ -235,6 +253,22 @@ def join_rendezvous(

return self._rdzv_round

def _map_node_rank_to_id(self, rank_dict):
"""
Convert the dict with the node rank as the key to a dict
with the node id. Because, it is more clear to show the node
id in the log than the node rank. If the log shows the node rank,
users need to search which node has the node rank.
"""
id_dict = {}
for node_rank, v in rank_dict.items():
if node_rank not in self._rdzv_nodes:
continue
node_id = self._rdzv_nodes[node_rank].node_id
id_dict[node_id] = v
id_dict = dict(sorted(id_dict.items()))
return id_dict

def num_nodes_waiting(self):
"""The elastic agent will restart training processes if it
find the number of waiting nodes is not zero. The manager
Expand Down Expand Up @@ -286,7 +320,7 @@ def get_comm_world(

@abstractmethod
def report_network_check_result(
self, node_id: int, normal: bool, elapsed_time: float
self, node_rank: int, normal: bool, elapsed_time: float
):
"""The node updates its status"""
pass
Expand Down Expand Up @@ -339,14 +373,17 @@ def get_comm_world(
)
ranks = list(self._rdzv_nodes.keys())
node_ips = []
node_ids = []
for node_rank in ranks:
node_ips.append(self._rdzv_nodes[node_rank].node_ip)
node = self._rdzv_nodes[node_rank]
node_ips.append(node.node_ip)
node_ids.append(node.node_id)
logger.info(
f"Node ranks are {ranks}.\n Node IPs are {node_ips}"
f"Node ids are {node_ids}.\n Node IPs are {node_ips}"
)
return self._rdzv_round, 0, self._rdzv_nodes

def report_network_check_result(self, node_id, normal, elapsed_time):
def report_network_check_result(self, node_rank, normal, elapsed_time):
return


Expand Down Expand Up @@ -388,13 +425,16 @@ def get_comm_world(
self._fault_nodes.clear()
self._straggler_nodes.clear()
self._node_groups = self._group_nodes(self._rdzv_round)
rank_groups = []
node_groups = []
for group in self._node_groups:
ranks = [rank for rank in group.keys()]
rank_groups.append(ranks)
ids = [
self._rdzv_nodes[rank].node_id
for rank in group.keys()
]
node_groups.append(ids)
logger.info(
f"Node groups of round {self._rdzv_round} "
f"are: {rank_groups}."
f"are: {node_groups}."
)
if self._rdzv_round % 2 == 0:
self._clear_check_status()
Expand Down Expand Up @@ -460,7 +500,8 @@ def _group_nodes(self, round):
def _check_abnormal_nodes(self):
abnormal_nodes = []
normal_nodes = []
for node_id, status in self._node_status.items():
for node_rank, status in self._node_status.items():
node_id = self._rdzv_nodes[node_rank].node_id
if status:
normal_nodes.append(node_id)
else:
Expand All @@ -471,27 +512,30 @@ def _check_abnormal_nodes(self):
)

def report_network_check_result(
self, node_id: int, succeed: bool, elapsed_time: float
self, node_rank: int, succeed: bool, elapsed_time: float
):
self._reported_nodes.add(node_id)
self._node_status.setdefault(node_id, succeed)
self._node_times.setdefault(node_id, elapsed_time)
self._node_status[node_id] = self._node_status[node_id] or succeed
self._node_times[node_id] = round(
min(self._node_times[node_id], elapsed_time), 3
self._reported_nodes.add(node_rank)
self._node_status.setdefault(node_rank, succeed)
self._node_times.setdefault(node_rank, elapsed_time)
self._node_status[node_rank] = self._node_status[node_rank] or succeed
self._node_times[node_rank] = round(
min(self._node_times[node_rank], elapsed_time), 3
)
if len(self._reported_nodes) == len(self._rdzv_nodes):
node_status = self._map_node_rank_to_id(self._node_status)
logger.info(
f"Round {self._rdzv_round}: The node status "
f"are {self._node_status}."
f"are {node_status}."
)
node_check_times = self._map_node_rank_to_id(self._node_times)
logger.info(
f"Round {self._rdzv_round}: The node elapsed time "
f"are {self._node_times}"
f"are {node_check_times}"
)

def join_rendezvous(
self,
node_id,
node_rank,
local_world_size,
node_ip="",
Expand All @@ -506,7 +550,9 @@ def join_rendezvous(
int: the number of rendezvous round.
"""
self._node_groups.clear()
return super().join_rendezvous(node_rank, local_world_size, node_ip)
return super().join_rendezvous(
node_id, node_rank, local_world_size, node_ip
)

def check_fault_node(self):
"""Check whether the job has fault nodes. Each task contains 2 rounds
Expand All @@ -518,11 +564,14 @@ def check_fault_node(self):
if not all_joined:
reason = NetworkFailureReason.WAITING_NODE
elif len(self._fault_nodes) == 0:
for node_id, status in self._node_status.items():
for node_rank, status in self._node_status.items():
if not status:
self._fault_nodes.add(node_id)
self._fault_nodes.add(node_rank)
if len(self._fault_nodes) > 0:
logger.warning(f"Fault nodes {self._fault_nodes}")
fault_nodes = {}
for rank in self._fault_nodes:
fault_nodes[rank] = self._rdzv_nodes[rank].node_id
logger.warning(f"Fault nodes are {fault_nodes}")
stragglers = self._detect_stragglers()
if not self._fault_nodes and not stragglers:
self._rdzv_round = (
Expand Down
12 changes: 9 additions & 3 deletions dlrover/python/master/servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,14 @@ def _check_straggler(self):

def _join_rendezvous(self, request: grpc.JoinRendezvousRequest):
rdzv_manager = self._rdzv_managers[request.rdzv_name]
node_rank = request.node_rank
if node_rank == -1: # Back compatibility
node_rank = request.node_id
round = rdzv_manager.join_rendezvous(
request.node_id, request.local_world_size, request.node_ip
request.node_id,
node_rank,
request.local_world_size,
request.node_ip,
)
if request.rdzv_name == RendezvousName.NETWORK_CHECK:
# The waiting node in the training rdzv should clear if
Expand All @@ -266,8 +272,8 @@ def _get_comm_world(self, request: grpc.CommWorldRequest):
res = grpc.RendezvousState(world={})
res.group = group
res.round = rdzv_round
for rank_id, meta in nodes.items():
res.world[rank_id] = meta.process_num
for rank, meta in nodes.items():
res.world[rank] = meta.process_num
if nodes and request.rdzv_name == RendezvousName.ELASTIC_TRAINING:
rdzv_round = rdzv_manager.get_rdzv_round()
metrics = {CustomMetricKeys.RDZV_ROUND: rdzv_round}
Expand Down
Loading

0 comments on commit 0c18cc5

Please sign in to comment.