Skip to content

Commit

Permalink
[train] group consecutive workers by IP (ray-project#38490)
Browse files Browse the repository at this point in the history
Signed-off-by: Matthew Deng <matt@anyscale.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
matthewdeng authored and arvind-chandra committed Aug 31, 2023
1 parent ea318a4 commit 64335e9
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 51 deletions.
7 changes: 4 additions & 3 deletions python/ray/train/_internal/backend_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,10 @@ def start(
# trainable, thus allowing for lazy checkpoint transfer to be used.
# See https://github.com/ray-project/ray/issues/33073
# for more context.
# TODO remove
if self._trial_info and self._trial_info.driver_ip:
self.worker_group._move_workers_with_ip_to_front(self._trial_info.driver_ip)
# TODO remove passing in trial_driver_ip.

trial_driver_ip = self._trial_info.driver_ip if self._trial_info else None
self.worker_group.group_workers_by_ip(trial_driver_ip)

worker_locs = [
f"{w.metadata.pid} ({w.metadata.node_ip})"
Expand Down
49 changes: 29 additions & 20 deletions python/ray/train/_internal/worker_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import socket
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, List, TypeVar, Optional, Dict, Type, Tuple, Union

Expand Down Expand Up @@ -360,26 +361,34 @@ def add_workers(self, num_workers: int):
for i in range(len(new_actors)):
self.workers.append(Worker(actor=new_actors[i], metadata=metadata[i]))

def _move_workers_with_ip_to_front(self, ip):
    """Move every worker whose node IP equals ``ip`` to the front.

    Hack to avoid OOMs.
    This is just a temporary solution for Train loading entire checkpoints
    into memory by ensuring that the rank 0 worker is on the same node as
    trainable, thus allowing for lazy checkpoint transfer to be used.
    See https://github.com/ray-project/ray/issues/33073
    for more context.
    TODO remove

    The relative order of the matching workers, and of the remaining
    workers, is preserved. If no worker matches, ``self.workers`` is left
    untouched (it is not even rebound to a new list).

    Args:
        ip: Node IP compared against each worker's ``metadata.node_ip``.
    """
    # Stable partition in a single pass: matching workers vs. everyone else.
    matching = []
    remaining = []
    for worker in self.workers:
        if worker.metadata.node_ip == ip:
            matching.append(worker)
        else:
            remaining.append(worker)

    # Only rebind when something actually has to move.
    if matching:
        self.workers = matching + remaining
def group_workers_by_ip(self, _first_ip: Optional[str] = None):
    """Groups workers by IP.

    This is useful for collocating workers on the same node.

    After the call, workers that share a ``metadata.node_ip`` are contiguous
    in ``self.workers``. Groups appear in the order in which their IP is
    first encountered, and workers keep their relative order inside a group.

    Args:
        _first_ip: The first IP to group by.
            Hack to avoid OOMs.
            This is just a temporary solution for Train loading entire
            checkpoints into memory by ensuring that the rank 0 worker is
            on the same node as trainable, thus allowing for lazy
            checkpoint transfer to be used.
            See https://github.com/ray-project/ray/issues/33073
            for more context.
            TODO remove this argument.
    """
    # Dicts preserve insertion order, so the order in which keys are created
    # is the order in which the groups are emitted below.
    ip_to_workers = defaultdict(list)

    if _first_ip is not None:
        # Create this key before any worker is seen so its group comes first.
        # If no worker has this IP the group stays empty and adds nothing.
        ip_to_workers[_first_ip] = []

    for worker in self.workers:
        ip_to_workers[worker.metadata.node_ip].append(worker)

    # Flatten the groups back into a single list.
    self.workers = [
        worker for group in ip_to_workers.values() for worker in group
    ]

def __len__(self):
    """Return the number of workers currently held in ``self.workers``."""
    return len(self.workers)
4 changes: 2 additions & 2 deletions python/ray/train/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def train_func():
return train.get_context().get_local_world_size()

e.start_training(train_func, datasets={}, data_config=DataConfig())
assert list(e.finish_training()) == [2, 1, 2]
assert list(e.finish_training()) == [2, 2, 1]


def test_node_ranks(ray_2_node_2_cpu):
Expand All @@ -163,7 +163,7 @@ def train_func():
return train.get_context().get_node_rank()

e.start_training(train_func, datasets={}, data_config=DataConfig())
assert list(e.finish_training()) == [0, 1, 0]
assert list(e.finish_training()) == [0, 0, 1]


def test_train_failure(ray_start_2_cpus):
Expand Down
59 changes: 33 additions & 26 deletions python/ray/train/tests/test_worker_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import ray
from ray.train._internal.worker_group import WorkerGroup, Worker, WorkerMetadata
from copy import deepcopy
from random import seed, shuffle


@pytest.fixture
Expand Down Expand Up @@ -83,32 +81,41 @@ def test_execute_args(ray_start_2_cpus):
assert all(o == 1 for o in outputs)


def test_move_workers_with_ip_to_front(ray_start_2_cpus):
    """Check that workers on a given IP are moved to the front of the group.

    Builds 32 dummy workers (two per IP across 16 IPs), shuffles them with a
    fixed seed, and verifies that both workers on "10.1.10.1" end up first
    while the overall multiset of IPs is unchanged.
    """
    wg = WorkerGroup(num_workers=2)
    # Replace the workers with metadata-only dummies; no actor is needed
    # because only ``metadata.node_ip`` is inspected.
    wg.workers = [
        Worker(
            actor=None,
            metadata=WorkerMetadata(
                node_id="dummy",
                node_ip=f"10.1.10.{i}",
                hostname="dummy",
                gpu_ids=None,
                pid=0,
            ),
        )
        for i in range(1, 17)
    ]
    # Duplicate so that every IP hosts exactly two workers.
    wg.workers += deepcopy(wg.workers)
    workers_pre_move = deepcopy(wg.workers)
    # Fixed seed keeps the shuffled order reproducible.
    seed(1)
    shuffle(wg.workers)
    wg._move_workers_with_ip_to_front("10.1.10.1")
    assert wg.workers[0].metadata.node_ip == "10.1.10.1"
    assert wg.workers[1].metadata.node_ip == "10.1.10.1"
    # No worker may be lost or duplicated by the move.
    assert sorted([w.metadata.node_ip for w in workers_pre_move]) == sorted(
        [w.metadata.node_ip for w in wg.workers]
    )
def test_group_workers_by_ip(ray_start_2_cpus):
    """Check that ``group_workers_by_ip`` makes same-IP workers contiguous.

    Covers both the default ordering (IPs in first-encountered order) and the
    ``_first_ip`` override that forces one IP's group to the front.
    """

    def create_worker_group(ips):
        """Return a WorkerGroup whose workers are dummies with the given IPs."""
        wg = WorkerGroup(num_workers=2)
        # Only ``metadata.node_ip`` matters here, so no actor is attached.
        wg.workers = [
            Worker(
                actor=None,
                metadata=WorkerMetadata(
                    node_id="dummy",
                    node_ip=ip,
                    hostname="dummy",
                    gpu_ids=None,
                    pid=0,
                ),
            )
            for ip in ips
        ]
        return wg

    # Same input order for both scenarios; IPs are first seen as 2, 3, 1, 4.
    input_ips = ["2", "3", "1", "4", "2", "1", "3", "3", "4", "2"]

    wg = create_worker_group(input_ips)
    wg.group_workers_by_ip()
    expected = ["2", "2", "2", "3", "3", "3", "1", "1", "4", "4"]
    ips = [w.metadata.node_ip for w in wg.workers]
    assert ips == expected, (
        "Workers should be grouped by IP "
        "and follow the same original order of IPs encountered (2, 3, 1, 4)."
    )

    wg = create_worker_group(input_ips)
    wg.group_workers_by_ip(_first_ip="1")
    expected = ["1", "1", "2", "2", "2", "3", "3", "3", "4", "4"]
    ips = [w.metadata.node_ip for w in wg.workers]
    assert (
        ips == expected
    ), "Workers should be grouped by IP, with the first IP being 1."


def test_execute_single(ray_start_2_cpus):
wg = WorkerGroup(num_workers=2)
Expand Down

0 comments on commit 64335e9

Please sign in to comment.