From a729ba75f9f92ed9583a22ad2e3105d39367fd47 Mon Sep 17 00:00:00 2001
From: Abin Shahab
Date: Tue, 16 Nov 2021 23:05:10 -0800
Subject: [PATCH] RayExecutor V2: Dynamic executor for elastic and static jobs
 (#3230)

This resolves #3190 by adding elastic parameters to the RayExecutor API for
Horovod. The API now supports both static (non-elastic) and elastic Horovod
jobs.

Example of a static job (identical to the current RayExecutor):

```python
import ray
import horovod.torch as hvd
from horovod.ray import RayExecutor

ray.init()
settings = RayExecutor.create_settings(timeout_s=30)
executor = RayExecutor(settings, num_workers=num_workers, use_gpu=True)
executor.start()

def simple_fn():
    hvd.init()
    print("hvd rank", hvd.rank())
    return hvd.rank()

result = executor.run(simple_fn)
assert len(set(result)) == num_workers
executor.shutdown()
```

Example of an elastic job:

```python
import torch
import horovod.torch as hvd
from horovod.ray import RayExecutor

def training_fn():
    hvd.init()
    model = Model()
    torch.cuda.set_device(hvd.local_rank())

    @hvd.elastic.run
    def train(state):
        for state.epoch in range(state.epoch, epochs):
            ...
            state.commit()

    state = hvd.elastic.TorchState(model, optimizer, batch=0, epoch=0)
    state.register_reset_callbacks([on_state_reset])
    train(state)
    return

executor = RayExecutor(settings, min_workers=1, use_gpu=True, cpus_per_worker=2)
executor.start()
executor.run(training_fn)
```

Signed-off-by: Abin Shahab
---
 CHANGELOG.md                       |   4 +-
 docs/ray.rst                       |  10 +-
 horovod/ray/adapter.py             | 126 +++++++
 horovod/ray/elastic.py             |   1 +
 horovod/ray/elastic_v2.py          | 533 +++++++++++++++++++++++++++++
 horovod/ray/runner.py              | 282 ++++++++++-----
 test/single/test_ray.py            |  28 +-
 test/single/test_ray_elastic_v2.py | 371 ++++++++++++++++++++
 8 files changed, 1236 insertions(+), 119 deletions(-)
 create mode 100644 horovod/ray/adapter.py
 create mode 100644 horovod/ray/elastic_v2.py
 create mode 100644 test/single/test_ray_elastic_v2.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a0f5a0a1aa..0093a4377e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,13 +7,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

 ## [Unreleased] - YYYY-MM-DD

 ### Added
+- Added elastic keyword parameters to the RayExecutor API: this API supports both static (non-elastic) and elastic Horovod jobs. This resolves issue
+[#3190](https://github.com/horovod/horovod/issues/3190).
 - TensorFlow: Added in-place broadcasting of variables. ([#3128](https://github.com/horovod/horovod/pull/3128))

 ### Changed

 ### Deprecated
-
+- Deprecated the ElasticRayExecutor APIs in favor of the new RayExecutor API. ([#3190](https://github.com/horovod/horovod/issues/3190))
 ### Removed

 ### Fixed

diff --git a/docs/ray.rst b/docs/ray.rst
index f5b77992b4..e15e5e9133 100644
--- a/docs/ray.rst
+++ b/docs/ray.rst
@@ -110,7 +110,7 @@ A unique feature of Ray is its support for `stateful Actors
-Ray also supports `elastic execution `_ via :ref:`the ElasticRayExecutor `. Similar to default Horovod, the difference between the non-elastic and elastic versions of Ray is that the hosts and number of workers is dynamically determined at runtime.
+Ray also supports `elastic execution `_ via :ref:`the RayExecutor `. Similar to default Horovod, the difference between the non-elastic and elastic versions of Ray is that the hosts and number of workers is dynamically determined at runtime.

 You must first set up `a Ray cluster`_. Ray clusters can support autoscaling for any cloud provider (AWS, GCP, Azure).

@@ -153,10 +153,12 @@ You can then attach to the underlying Ray cluster and execute the training function:

 ..
code-block:: python import ray + from horovod.ray import RayExecutor + ray.init(address="auto") # attach to the Ray cluster - settings = ElasticRayExecutor.create_settings(verbose=True) - executor = ElasticRayExecutor( - settings, use_gpu=True, cpus_per_slot=2) + settings = RayExecutor.create_settings(verbose=True) + executor = RayExecutor( + settings, min_workers=1, use_gpu=True, cpus_per_slot=2) executor.start() executor.run(training_fn) diff --git a/horovod/ray/adapter.py b/horovod/ray/adapter.py new file mode 100644 index 0000000000..8cd722ff69 --- /dev/null +++ b/horovod/ray/adapter.py @@ -0,0 +1,126 @@ +from abc import ABC, abstractmethod +from typing import Dict, Callable, Any, Optional, List +from dataclasses import dataclass + +@dataclass +class BaseParams: + cpus_per_worker: int = 1 + use_gpu: bool = False + gpus_per_worker: Optional[int] = None + def __post_init__(self): + if self.gpus_per_worker and not self.use_gpu: + raise ValueError("gpus_per_worker is set, but use_gpu is False. " + "use_gpu must be True if gpus_per_worker is " + "set. ") + if self.use_gpu and isinstance(self.gpus_per_worker, + int) and self.gpus_per_worker < 1: + raise ValueError( + f"gpus_per_worker must be >= 1: Got {self.gpus_per_worker}.") + self.gpus_per_worker = self.gpus_per_worker or int(self.use_gpu) + + +class Adapter(ABC): + """Adapter for executing Ray calls for various types(e.g. static and elastic) + Horovod jobs. + """ + @abstractmethod + def start(self, + executable_cls: type = None, + executable_args: Optional[List] = None, + executable_kwargs: Optional[Dict] = None, + extra_env_vars: Optional[Dict] = None): + """Starts the Adapter + + Args: + executable_cls (type): The class that will be created within + an actor (BaseHorovodWorker). This will allow Horovod + to establish its connections and set env vars. + executable_args (List): Arguments to be passed into the + worker class upon initialization. + executable_kwargs (Dict): Keyword arguments to be passed into the + worker class upon initialization. + extra_env_vars (Dict): Environment variables to be set + on the actors (worker processes) before initialization. + """ + raise NotImplementedError("Method must be implemented in a subclass") + + @abstractmethod + def execute(self, fn: Callable[["executable_cls"], Any], + callbacks: Optional[List[Callable]] = None) -> List[Any]: + """Executes the provided function on all workers. + + Args: + fn: Target function to be invoked on every object. + callbacks: List of callables. Each callback must either + be a callable function or a class that implements __call__. + Every callback will be invoked on every value logged + by the rank 0 worker. + + Returns: + Deserialized return values from the target function. + """ + raise NotImplementedError("Method must be implemented in a subclass") + + @abstractmethod + def run(self, + fn: Callable[[Any], Any], + args: Optional[List] = None, + kwargs: Optional[Dict] = None, + callbacks: Optional[List[Callable]] = None) -> List[Any]: + """Executes the provided function on all workers. + + Args: + fn: Target function that can be executed with arbitrary + args and keyword arguments. + args: List of arguments to be passed into the target function. + kwargs: Dictionary of keyword arguments to be + passed into the target function. + callbacks: List of callables. Each callback must either + be a callable function or a class that implements __call__. + Every callback will be invoked on every value logged + by the rank 0 worker. 
+ + Returns: + Deserialized return values from the target function. + """ + raise NotImplementedError("Method must be implemented in a subclass") + + @abstractmethod + def run_remote(self, + fn: Callable[[Any], Any], + args: Optional[List] = None, + kwargs: Optional[Dict] = None, + callbacks: Optional[List[Callable]] = None): + + """Executes the provided function on all workers. + + Args: + fn: Target function that can be executed with arbitrary + args and keyword arguments. + args: List of arguments to be passed into the target function. + kwargs: Dictionary of keyword arguments to be + passed into the target function. + + Returns: + list: List of ObjectRefs that you can run `ray.get` on to + retrieve values. + """ + raise NotImplementedError("Method must be implemented in a subclass") + + @abstractmethod + def execute_single(self, + fn: Callable[["executable_cls"], Any]) -> List[Any]: + """Executes the provided function on the rank 0 worker (chief). + + Args: + fn: Target function to be invoked on the chief object. + + Returns: + Deserialized return values from the target function. + """ + raise NotImplementedError("Method must be implemented in a subclass") + + @abstractmethod + def shutdown(self): + """Destroys the adapter.""" + raise NotImplementedError("Method must be implemented in a subclass") diff --git a/horovod/ray/elastic.py b/horovod/ray/elastic.py index 2cafadf59c..5395fd5ccc 100644 --- a/horovod/ray/elastic.py +++ b/horovod/ray/elastic.py @@ -179,6 +179,7 @@ class ElasticRayExecutor: settings, use_gpu=True, cpus_per_slot=2) executor.start() executor.run(train_fn) + warning:: .. deprecated:: 0.25.0 """ @staticmethod diff --git a/horovod/ray/elastic_v2.py b/horovod/ray/elastic_v2.py new file mode 100644 index 0000000000..20f92d2b19 --- /dev/null +++ b/horovod/ray/elastic_v2.py @@ -0,0 +1,533 @@ +from typing import Callable, List, Any, Dict, Optional +import logging +import ray.exceptions +import socket + +import time +import os +import random +import math +import threading +from dataclasses import dataclass + +from horovod.ray.adapter import Adapter, BaseParams +from horovod.runner.http.http_server import RendezvousServer +from horovod.ray.utils import detect_nics +from horovod.runner.elastic.rendezvous import create_rendezvous_handler +from horovod.runner.gloo_run import (create_slot_env_vars, create_run_env_vars, + _get_min_start_hosts) +from horovod.ray.worker import BaseHorovodWorker +from horovod.runner.elastic.discovery import HostDiscovery +from horovod.runner.elastic.driver import ElasticDriver + +import ray +import ray.exceptions +from horovod.ray.worker import BaseHorovodWorker +from horovod.ray.utils import detect_nics +logger = logging.getLogger(__name__) + +if hasattr(ray.exceptions, "GetTimeoutError"): + GetTimeoutError = ray.exceptions.GetTimeoutError +elif hasattr(ray.exceptions, "RayTimeoutError"): + GetTimeoutError = ray.exceptions.RayTimeoutError +else: + raise ImportError("Unable to find Ray Timeout Error class " + "(GetTimeoutError, RayTimeoutError). " + "This is likely due to the Ray version not " + "compatible with Horovod-Ray.") + + +class RayHostDiscovery(HostDiscovery): + """Uses Ray global state to obtain host mapping. 
+ + Assumes that the whole global state is available for usage.""" + + def __init__(self, use_gpu=False, cpus_per_worker=1, gpus_per_worker=1): + self.use_gpu = use_gpu + self.cpus_per_worker = cpus_per_worker + self.gpus_per_worker = gpus_per_worker + logger.debug(f"Discovery started with {cpus_per_worker} CPU / " + f"{gpus_per_worker} GPU per slot.") + + def find_available_hosts_and_slots(self) -> Dict[str, int]: + """Returns a dict mapping -> .""" + alive_nodes = [k for k in ray.nodes() if k["alive"]] + host_mapping = {} + for node in alive_nodes: + hostname = node["NodeManagerAddress"] + resources = node["Resources"] + slots = resources.get("CPU", 0) // self.cpus_per_worker + if self.use_gpu: + gpu_slots = resources.get("GPU", 0) // self.gpus_per_worker + slots = min(slots, gpu_slots) + slots = int(math.ceil(slots)) + if slots: + host_mapping[hostname] = slots + + if host_mapping and sum(host_mapping.values()) == 0: + logger.info(f"Detected {len(host_mapping)} hosts, but no hosts " + "have available slots.") + logger.debug(f"Alive nodes: {alive_nodes}") + return host_mapping + + +class TestDiscovery(RayHostDiscovery): + def __init__(self, + min_hosts, + max_hosts, + change_frequency_s, + use_gpu=False, + cpus_per_worker=1, + gpus_per_worker=1, + verbose=True, + _graceful=True): + super().__init__( + use_gpu=use_gpu, + cpus_per_worker=cpus_per_worker, + gpus_per_worker=gpus_per_worker) + self._min_hosts = min_hosts + self._graceful = _graceful + self._max_hosts = max_hosts + self._change_frequency_s = change_frequency_s + self._last_reset_t = None + self.verbose = verbose + self._removed_hosts = set() + + def add_host(self, hosts): + available_hosts = self._removed_hosts & hosts.keys() + if available_hosts: + host = random.choice(list(available_hosts)) + self._removed_hosts.remove(host) + else: + print("No hosts to add.") + + def remove_host(self, hosts): + good_hosts = [k for k in hosts if k not in self._removed_hosts] + + from ray.autoscaler._private.commands import kill_node + if good_hosts: + if self._graceful: + host = random.choice(good_hosts) + else: + host = kill_node( + os.path.expanduser("~/ray_bootstrap_config.yaml"), True, + False, None) + self._removed_hosts.add(host) + + def change_hosts(self, hosts): + for host in self._removed_hosts: + if host not in hosts: + self._removed_hosts.remove(host) + current_hosts = len(hosts) - len(self._removed_hosts) + if current_hosts <= self._min_hosts: + self.add_host(hosts) + elif current_hosts >= self._max_hosts: + self.remove_host(hosts) + else: + if random.random() < 0.5: + self.add_host(hosts) + else: + self.remove_host(hosts) + + def find_available_hosts_and_slots(self): + t = time.time() + if self._last_reset_t is None: + self._last_reset_t = t + hosts = super().find_available_hosts_and_slots() + if t - self._last_reset_t >= self._change_frequency_s: + self.change_hosts(hosts) + self._last_reset_t = t + if self.verbose: + print(f"Total hosts: {len(hosts)}") + remaining = { + k: v + for k, v in hosts.items() if k not in self._removed_hosts + } + if self.verbose: + print(f"Remaining hosts: {len(remaining)} -- {remaining}") + return remaining + +@dataclass +class ElasticParams(BaseParams): + """Parameters for elastic jobs. + + Args: + min_workers (int): Minimum number of processes running for + training to continue. If number of available processes dips + below this threshold, then training will wait for + more instances to become available. 
+ max_workers (int): Maximum number of training processes, + beyond which no additional processes will be created. + If not specified, then will be unbounded. + reset_limit (int): Maximum number of times that the training + job can scale up or down the number of workers after + which the job is terminated. + elastic_timeout (int): Timeout for elastic initialisation after + re-scaling the cluster. The default value is 600 seconds. + Alternatively, the environment variable + HOROVOD_ELASTIC_TIMEOUT can also be used. + cpus_per_worker (int): Number of CPU resources to allocate to + each worker. + use_gpu (bool): Whether to use GPU for allocation. TODO: this + can be removed. + gpus_per_worker (int): Number of GPU resources to allocate to + each worker. + + """ + min_workers: int = 1 + max_workers: int = None + reset_limit: int = None + elastic_timeout: int = 600 + override_discovery: bool = True + + @property + def elastic(self): + return True + + @property + def adapter(self): + return ElasticAdapter + +class ElasticAdapter(Adapter): + """Adapter for executing Ray calls for elastic Horovod jobs. + + Args: + settings (horovod.Settings): Configuration for job setup. You can + use a standard Horovod Settings object or create one directly + from RayExecutor.create_settings. + min_workers (int): Minimum number of processes running for + training to continue. If number of available processes dips + below this threshold, then training will wait for + more instances to become available. + max_workers (int): Maximum number of training processes, + beyond which no additional processes will be created. + If not specified, then will be unbounded. + reset_limit (int): Maximum number of times that the training + job can scale up or down the number of workers after + which the job is terminated. + elastic_timeout (int): Timeout for elastic initialisation after + re-scaling the cluster. The default value is 600 seconds. + Alternatively, the environment variable + HOROVOD_ELASTIC_TIMEOUT can also be used.' + cpus_per_worker (int): Number of CPU resources to allocate to + each worker. + use_gpu (bool): Whether to use GPU for allocation. TODO: this + can be removed. + gpus_per_worker (int): Number of GPU resources to allocate to + each worker. + override_discovery (bool): Whether for the ElasticRayExecutor to + automatically provide a discovery mechanism for ElasticSettings. + + """ + def __init__(self, + settings, + min_workers: int, + max_workers: Optional[int] = None, + use_gpu: bool = False, + cpus_per_worker: int = 1, + gpus_per_worker: Optional[int] = None, + override_discovery: bool=True, + reset_limit: int = None, + elastic_timeout: int = 600): + self.settings = settings + if override_discovery: + settings.discovery = RayHostDiscovery( + use_gpu=use_gpu, + cpus_per_worker=cpus_per_worker, + gpus_per_worker=gpus_per_worker) + self.cpus_per_worker = cpus_per_worker + self.gpus_per_worker = gpus_per_worker + self.use_gpu = use_gpu + # moved from settings + self.min_workers = min_workers + self.max_workers = max_workers + self.num_workers = min_workers + self.reset_limit = reset_limit + self.elastic_timeout = elastic_timeout + self.driver = None + self.rendezvous = None + + def start(self, + executable_cls: type = None, + executable_args: Optional[List] = None, + executable_kwargs: Optional[Dict] = None, + extra_env_vars: Optional[Dict] = None): + """Starts the Horovod driver and services. + + Args: + executable_cls (type): The class that will be created within + an actor (BaseHorovodWorker). 
This will allow Horovod + to establish its connections and set env vars. + executable_args (List): Arguments to be passed into the + worker class upon initialization. + executable_kwargs (Dict): Keyword arguments to be passed into the + worker class upon initialization. + extra_env_vars (Dict): Environment variables to be set + on the actors (worker processes) before initialization. + + """ + + self.rendezvous = RendezvousServer(self.settings.verbose) + self.driver = ElasticDriver( + rendezvous=self.rendezvous, + discovery=self.settings.discovery, + min_np=self.min_workers, + max_np=self.max_workers, + timeout=self.elastic_timeout, + reset_limit=self.reset_limit, + verbose=self.settings.verbose) + handler = create_rendezvous_handler(self.driver) + logger.debug("[ray] starting rendezvous") + global_rendezv_port = self.rendezvous.start(handler) + + logger.debug(f"[ray] waiting for {self.num_workers} to start.") + self.driver.wait_for_available_slots(self.num_workers) + + # Host-to-host common interface detection + # requires at least 2 hosts in an elastic job. + min_hosts = _get_min_start_hosts(self.settings) + current_hosts = self.driver.wait_for_available_slots( + self.num_workers, min_hosts=min_hosts) + logger.debug("[ray] getting common interfaces") + nics = detect_nics( + self.settings, + all_host_names=current_hosts.host_assignment_order, + ) + logger.debug("[ray] getting driver IP") + server_ip = socket.gethostbyname(socket.gethostname()) + self.run_env_vars = create_run_env_vars( + server_ip, nics, global_rendezv_port, elastic=True) + + self.executable_cls = executable_cls + self.executable_args = executable_args + self.executable_kwargs = executable_kwargs + self.env_vars = extra_env_vars or {} + + + def _create_resources(self, hostname: str): + resources = dict( + num_cpus=self.cpus_per_worker, + num_gpus=int(self.use_gpu) * self.gpus_per_worker, + resources={f"node:{hostname}": 0.01}) + return resources + + def _create_remote_worker(self, slot_info, worker_env_vars): + hostname = slot_info.hostname + loaded_worker_cls = self.remote_worker_cls.options( + **self._create_resources(hostname)) + + worker = loaded_worker_cls.remote() + worker.update_env_vars.remote(worker_env_vars) + worker.update_env_vars.remote(create_slot_env_vars(slot_info)) + if self.use_gpu: + visible_devices = ",".join( + [str(i) for i in range(slot_info.local_size)]) + worker.update_env_vars.remote({ + "CUDA_VISIBLE_DEVICES": + visible_devices + }) + return worker + + def _create_spawn_worker_fn(self, return_results: List, + worker_fn: Callable, + queue: "ray.util.Queue") -> Callable: + self.remote_worker_cls = ray.remote(BaseHorovodWorker) + # event = register_shutdown_event() + worker_env_vars = {} + worker_env_vars.update(self.run_env_vars.copy()) + worker_env_vars.update(self.env_vars.copy()) + worker_env_vars.update({"PYTHONUNBUFFERED": "1"}) + + def worker_loop(slot_info, events): + def ping_worker(worker): + # There is an odd edge case where a node can be removed + # before the remote worker is started, leading to a failure + # in trying to create the horovod mesh. 
+ try: + ping = worker.execute.remote(lambda _: 1) + ray.get(ping, timeout=10) + except Exception as e: + logger.error(f"{slot_info.hostname}: Ping failed - {e}") + return False + return True + + worker = self._create_remote_worker(slot_info, worker_env_vars) + if not ping_worker(worker): + return 1, time.time() + + ray.get(worker.set_queue.remote(queue)) + future = worker.execute.remote(worker_fn) + + result = None + while result is None: + try: + # TODO: make this event driven at some point. + retval = ray.get(future, timeout=0.1) + return_results.append((slot_info.rank, retval)) + # Success + result = 0, time.time() + except GetTimeoutError: + # Timeout + if any(e.is_set() for e in events): + ray.kill(worker) + result = 1, time.time() + except Exception as e: + logger.error(f"{slot_info.hostname}[{slot_info.rank}]:{e}") + ray.kill(worker) + result = 1, time.time() + logger.debug(f"Worker ({slot_info}) routine is done!") + return result + + return worker_loop + + + def run(self, + fn: Callable[[Any], Any], + args: Optional[List] = None, + kwargs: Optional[Dict] = None, + callbacks: Optional[List[Callable]] = None) -> List[Any]: + """Executes the provided function on all workers. + + Args: + fn: Target function that can be executed with arbitrary + args and keyword arguments. + args: List of arguments to be passed into the target function. + kwargs: Dictionary of keyword arguments to be + passed into the target function. + callbacks: List of callables. Each callback must either + be a callable function or a class that implements __call__. + Every callback will be invoked on every value logged + by the rank 0 worker. + + Returns: + Deserialized return values from the target function. + """ + args = args or [] + kwargs = kwargs or {} + f = lambda _: fn(*args, **kwargs) + return self._run_remote(f, callbacks=callbacks) + + def _run_remote(self, + worker_fn: Callable, + callbacks: Optional[List[Callable]] = None) -> List[Any]: + """Executes the provided function on all workers. + + Args: + worker_fn: Target elastic function that can be executed. + callbacks: List of callables. Each callback must either + be a callable function or a class that implements __call__. + Every callback will be invoked on every value logged + by the rank 0 worker. + + Returns: + List of return values from every completed worker. 
+ """ + return_values = [] + from ray.util.queue import Queue + import inspect + args = inspect.getfullargspec(Queue).args + if "actor_options" not in args: + # Ray 1.1 and less + _queue = Queue() + else: + _queue = Queue(actor_options={ + "num_cpus": 0, + "resources": { + ray.state.current_node_id(): 0.001 + } + }) + self.driver.start( + self.num_workers, + self._create_spawn_worker_fn(return_values, worker_fn, _queue)) + + def _process_calls(queue, callbacks, event): + if not callbacks: + return + while queue.actor: + if not queue.empty(): + result = queue.get_nowait() + for c in callbacks: + c(result) + # avoid slamming the CI + elif event.is_set(): + break + time.sleep(0.1) + + try: + event = threading.Event() + _callback_thread = threading.Thread( + target=_process_calls, + args=(_queue, callbacks, event), + daemon=True) + _callback_thread.start() + res = self.driver.get_results() + event.set() + if _callback_thread: + _callback_thread.join(timeout=60) + finally: + if hasattr(_queue, "shutdown"): + _queue.shutdown() + else: + done_ref = _queue.actor.__ray_terminate__.remote() + done, not_done = ray.wait([done_ref], timeout=5) + if not_done: + ray.kill(_queue.actor) + self.driver.stop() + + if res.error_message is not None: + raise RuntimeError(res.error_message) + + for name, value in sorted( + res.worker_results.items(), key=lambda item: item[1][1]): + exit_code, timestamp = value + if exit_code != 0: + raise RuntimeError( + 'Horovod detected that one or more processes ' + 'exited with non-zero ' + 'status, thus causing the job to be terminated. ' + 'The first process ' + 'to do so was:\nProcess name: {name}\nExit code: {code}\n' + .format(name=name, code=exit_code)) + + return_values = [ + value for k, value in sorted(return_values, key=lambda kv: kv[0]) + ] + return return_values + + def run_remote(self, + fn: Callable[[Any], Any]) -> List[Any]: + raise NotImplementedError("ObjectRefs cannot be returned from Elastic runs as the workers are ephemeral") + + def execute(self, fn: Callable[["executable_cls"], Any], + callbacks: Optional[List[Callable]] = None) -> List[Any]: + """Executes the provided function on all workers. + + Args: + fn: Target function to be invoked on every object. + callbacks: List of callables. Each callback must either + be a callable function or a class that implements __call__. + Every callback will be invoked on every value logged + by the rank 0 worker. + Returns: + Deserialized return values from the target function. + """ + return ray.get(self._run_remote(fn, callbacks=callbacks)) + + def execute_single(self, + fn: Callable[["executable_cls"], Any]) -> List[Any]: + """Executes the provided function on the rank 0 worker (chief). + + Args: + fn: Target function to be invoked on the chief object. + + Returns: + Deserialized return values from the target function. + """ + raise NotImplementedError("Elastic mode does not support execute_single. 
Please use the execute method instead") + + def shutdown(self): + """Destroys the driver.""" + if not self.driver: + return + assert self.driver.finished() + self.driver = None diff --git a/horovod/ray/runner.py b/horovod/ray/runner.py index 595e0361de..ab94d18424 100644 --- a/horovod/ray/runner.py +++ b/horovod/ray/runner.py @@ -1,19 +1,21 @@ import ray from ray.util.placement_group import get_current_placement_group -import warnings from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, asdict import os from typing import Dict, Callable, Any, Optional, List import logging +import ray.exceptions +from horovod.ray.adapter import Adapter, BaseParams from horovod.runner.common.util import secret, timeout, hosts from horovod.runner.http.http_server import RendezvousServer from horovod.ray.utils import detect_nics, nics_to_env_var, map_blocking from horovod.ray.strategy import ColocatedStrategy, PGStrategy -logger = logging.getLogger(__name__) +from horovod.ray.elastic_v2 import ElasticParams +logger = logging.getLogger(__name__) @dataclass class MiniSettings: @@ -28,6 +30,7 @@ class MiniSettings: ssh_identity_file: str = None timeout_s: int = 300 placement_group_timeout_s: int = 100 + elastic: bool = False @property def start_timeout(self): @@ -126,6 +129,42 @@ def establish_rendezvous(self) -> Dict[str, str]: } +@dataclass +class StaticParams(BaseParams): + """Parameters for non-elastic jobs. + + Args: + num_workers (int): Number of workers to use for training. + cpus_per_worker (int): Number of CPU resources to allocate to + each worker. + use_gpu (bool): Whether to use GPU for allocation. TODO: this + can be removed. + gpus_per_worker (int): Number of GPU resources to allocate to + each worker. + num_hosts (int): Alternative API to ``num_workers``. Number of + machines to execute the job on. Used to enforce equal number of + workers on each machine. + num_workers_per_host (int): Alternative API to + ``num_workers``. Number of workers to be placed on each machine. + Used to enforce equal number of workers on each machine. Only + used in conjunction with `num_hosts`. + use_current_placement_group (bool): Whether to use the current + placement group instead of creating a new one. Defaults to True. + + """ + num_workers: Optional[int] = None + num_hosts: Optional[int] = None + num_workers_per_host: int = 1 + use_current_placement_group: bool = True + + @property + def elastic(self): + return False + + @property + def adapter(self): + return StaticAdapter + class RayExecutor: """Job class for Horovod + Ray integration. @@ -149,15 +188,32 @@ class RayExecutor: used in conjunction with `num_hosts`. use_current_placement_group (bool): Whether to use the current placement group instead of creating a new one. Defaults to True. + min_workers (int): Minimum number of processes running for + training to continue. If number of available processes dips + below this threshold, then training will wait for + more instances to become available. + max_workers (int): Maximum number of training processes, + beyond which no additional processes will be created. + If not specified, then will be unbounded. + reset_limit (int): Maximum number of times that the training + job can scale up or down the number of workers after + which the job is terminated. + elastic_timeout (int): Timeout for elastic initialisation after + re-scaling the cluster. The default value is 600 seconds. + Alternatively, the environment variable + HOROVOD_ELASTIC_TIMEOUT can also be used. 
+ override_discovery (bool): Whether for the ElasticRayExecutor to + automatically provide a discovery mechanism for ElasticSettings. """ @classmethod def create_settings(cls, - timeout_s, + timeout_s=30, ssh_identity_file=None, ssh_str=None, - placement_group_timeout_s=100): + placement_group_timeout_s=100, + nics=None): """Create a mini setting object. Args: @@ -168,6 +224,7 @@ def create_settings(cls, file contents. Writes the private key to ssh_identity_file. placement_group_timeout_s (int): Timeout parameter for Ray Placement Group creation. + nics (set): Network interfaces that can be used for communication. Returns: MiniSettings object. @@ -179,7 +236,8 @@ def create_settings(cls, return MiniSettings( ssh_identity_file=ssh_identity_file, timeout_s=timeout_s, - placement_group_timeout_s=placement_group_timeout_s) + placement_group_timeout_s=placement_group_timeout_s, + nics=nics) def __init__( self, @@ -191,66 +249,43 @@ def __init__( use_gpu: bool = False, gpus_per_worker: Optional[int] = None, use_current_placement_group: bool = True, - # Deprecated Args. - num_slots: Optional[int] = None, - cpus_per_slot: Optional[int] = None, - gpus_per_slot: Optional[int] = None): - - if num_slots: - warnings.warn( - "`num_slots` is now deprecated. Please use the `num_workers` " - "API, or to enforce an equal number of workers on each node, " - "set `num_hosts` and `num_workers_per_host`. " - "This will raise an error in a later release of Horovod. " - "Setting num_workers_per_host = num_slots.", - category=DeprecationWarning, - stacklevel=2) - num_workers_per_host = num_slots - - if cpus_per_slot or gpus_per_slot: - warnings.warn( - "`cpus_per_slot` and `gpus_per_slot` have been deprecated. " - "Use `cpus_per_worker` and `gpus_per_worker` instead. " - "This will raise an error in a later release of Horovod. " - "Setting cpus/gpus_per_slot = cpus/gpus_per_worker.", - category=DeprecationWarning, - stacklevel=2) - cpus_per_worker = cpus_per_slot - gpus_per_worker = gpus_per_slot - - if not (num_workers or num_hosts): - raise ValueError("One of `num_workers` or `num_hosts` must be " - "set.") - - if num_workers and num_hosts: - raise ValueError("Only one of `num_workers` and `num_hosts` must be " - "set.") - - if gpus_per_worker and not use_gpu: - raise ValueError("gpus_per_worker is set, but use_gpu is False. " - "use_gpu must be True if gpus_per_worker is " - "set. ") - if use_gpu and isinstance(gpus_per_worker, - int) and gpus_per_worker < 1: - raise ValueError( - f"gpus_per_worker must be >= 1: Got {gpus_per_worker}.") - - kwargs = dict( - num_workers=num_workers, - num_hosts=num_hosts, - num_workers_per_host=num_workers_per_host, - cpus_per_worker=cpus_per_worker, - use_gpu=use_gpu, - gpus_per_worker=gpus_per_worker, - use_current_placement_group=use_current_placement_group - ) - self._is_remote = False - if ray.util.client.ray.is_connected(): - RemoteDriver = ray.remote(_ExecutorDriver) - self.driver = RemoteDriver.remote(settings, **kwargs) - self._is_remote = True + + min_workers: int = None, + max_workers: int = None, + reset_limit: int = None, + elastic_timeout: int = 600, + override_discovery: bool = True + ): + if max_workers and (not min_workers or min_workers <= 0): + raise ValueError("`max_workers` provided without any positive `min_workers`" + "Elastic workloads require a positive `min_workers`") + if min_workers and num_workers: + raise ValueError("Both `min_workers` and `num_workers` provided." 
+ "Only one of the above is allowed as workloads cannot be elastic and non-elastic.") + + if min_workers is not None: + self.params = ElasticParams( + min_workers=min_workers, + max_workers=max_workers, + reset_limit=reset_limit, + elastic_timeout=elastic_timeout, + override_discovery=override_discovery, + cpus_per_worker=cpus_per_worker, + use_gpu=use_gpu, + gpus_per_worker=gpus_per_worker + ) else: - self.driver = _ExecutorDriver(settings, **kwargs) + self.params = StaticParams( + num_workers=num_workers, + num_hosts=num_hosts, + num_workers_per_host=num_workers_per_host, + cpus_per_worker=cpus_per_worker, + use_gpu=use_gpu, + gpus_per_worker=gpus_per_worker, + use_current_placement_group=use_current_placement_group + ) + self.settings = settings + self.settings.elastic = self.params.elastic def start(self, executable_cls: type = None, @@ -274,31 +309,51 @@ def start(self, worker class upon initialization. extra_env_vars (Dict): Environment variables to be set on the actors (worker processes) before initialization. - """ + self._initialize_adapter() + kwargs_ = dict( executable_cls=executable_cls, executable_args=executable_args, executable_kwargs=executable_kwargs, extra_env_vars=extra_env_vars) - return self._maybe_call_ray(self.driver.start, **kwargs_) + return self._maybe_call_ray(self.adapter.start, **kwargs_) + + def _initialize_adapter(self): + kwargs = asdict(self.params) + logger.debug(f"Kwargs: {kwargs}") + Adapter = self.params.adapter + self._is_remote = False + if ray.util.client.ray.is_connected(): + RemoteAdapter = ray.remote(Adapter) + self.adapter = RemoteAdapter.remote(self.settings, **kwargs) + self._is_remote = True + else: + self.adapter= Adapter(self.settings, **kwargs) - def execute(self, fn: Callable[["executable_cls"], Any]) -> List[Any]: + def execute(self, fn: Callable[["executable_cls"], Any], + callbacks: Optional[List[Callable]] = None) -> List[Any]: """Executes the provided function on all workers. Args: fn: Target function to be invoked on every object. + callbacks: List of callables. Each callback must either + be a callable function or a class that implements __call__. + Every callback will be invoked on every value logged + by the rank 0 worker. Returns: Deserialized return values from the target function. """ - kwargs_ = dict(fn=fn) - return self._maybe_call_ray(self.driver.execute, **kwargs_) + kwargs_ = dict(fn=fn, callbacks=callbacks) + # invoke run_remote + return self._maybe_call_ray(self.adapter.execute, **kwargs_) def run(self, fn: Callable[[Any], Any], args: Optional[List] = None, - kwargs: Optional[Dict] = None) -> List[Any]: + kwargs: Optional[Dict] = None, + callbacks: Optional[List[Callable]] = None) -> List[Any]: """Executes the provided function on all workers. Args: @@ -307,12 +362,16 @@ def run(self, args: List of arguments to be passed into the target function. kwargs: Dictionary of keyword arguments to be passed into the target function. + callbacks: List of callables. Each callback must either + be a callable function or a class that implements __call__. + Every callback will be invoked on every value logged + by the rank 0 worker. Returns: Deserialized return values from the target function. """ - kwargs_ = dict(fn=fn, args=args, kwargs=kwargs) - return self._maybe_call_ray(self.driver.run, **kwargs_) + kwargs_ = dict(fn=fn, args=args, kwargs=kwargs, callbacks=callbacks) + return self._maybe_call_ray(self.adapter.run, **kwargs_) def run_remote(self, fn: Callable[[Any], Any], @@ -332,7 +391,7 @@ def run_remote(self, retrieve values. 
""" kwargs_ = dict(fn=fn, args=args, kwargs=kwargs) - return self._maybe_call_ray(self.driver.run_remote, **kwargs_) + return self._maybe_call_ray(self.adapter.run_remote, **kwargs_) def execute_single(self, fn: Callable[["executable_cls"], Any]) -> List[Any]: @@ -345,12 +404,12 @@ def execute_single(self, Deserialized return values from the target function. """ kwargs = dict(fn=fn) - return self._maybe_call_ray(self.driver.execute_single, **kwargs) + return self._maybe_call_ray(self.adapter.execute_single, **kwargs) def shutdown(self): """Destroys the provided workers.""" - result = self._maybe_call_ray(self.driver.shutdown) - del self.driver + result = self._maybe_call_ray(self.adapter.shutdown) + del self.adapter return result def _maybe_call_ray(self, driver_func, *args, **kwargs): @@ -360,9 +419,31 @@ def _maybe_call_ray(self, driver_func, *args, **kwargs): return driver_func(**kwargs) -class _ExecutorDriver: - """Base driver for executing Ray calls.""" +class StaticAdapter(Adapter): + """Adapter for executing Ray calls for non-elastic Horovod jobs. + + Args: + settings (horovod.Settings): Configuration for job setup. You can + use a standard Horovod Settings object or create one directly + from RayExecutor.create_settings. + num_workers (int): Number of workers to use for training. + cpus_per_worker (int): Number of CPU resources to allocate to + each worker. + use_gpu (bool): Whether to use GPU for allocation. TODO: this + can be removed. + gpus_per_worker (int): Number of GPU resources to allocate to + each worker. + num_hosts (int): Alternative API to ``num_workers``. Number of + machines to execute the job on. Used to enforce equal number of + workers on each machine. + num_workers_per_host (int): Alternative API to + ``num_workers``. Number of workers to be placed on each machine. + Used to enforce equal number of workers on each machine. Only + used in conjunction with `num_hosts`. + use_current_placement_group (bool): Whether to use the current + placement group instead of creating a new one. Defaults to True. + """ def __init__(self, settings, num_workers: Optional[int] = None, @@ -477,21 +558,27 @@ def start(self, self._start_executables(executable_cls, executable_args, executable_kwargs) - def execute(self, fn: Callable[["executable_cls"], Any]) -> List[Any]: + def execute(self, fn: Callable[["executable_cls"], Any], + callbacks: Optional[List[Callable]] = None) -> List[Any]: """Executes the provided function on all workers. Args: fn: Target function to be invoked on every object. + callbacks: List of callables. Each callback must either + be a callable function or a class that implements __call__. + Every callback will be invoked on every value logged + by the rank 0 worker. Returns: Deserialized return values from the target function. """ - return ray.get([worker.execute.remote(fn) for worker in self.workers]) + return ray.get(self._run_remote(fn)) def run(self, fn: Callable[[Any], Any], args: Optional[List] = None, - kwargs: Optional[Dict] = None) -> List[Any]: + kwargs: Optional[Dict] = None, + callbacks: Optional[List[Callable]] = None) -> List[Any]: """Executes the provided function on all workers. Args: @@ -504,12 +591,17 @@ def run(self, Returns: Deserialized return values from the target function. 
""" - return ray.get(self.run_remote(fn, args, kwargs)) + args = args or [] + kwargs = kwargs or {} + f = lambda w: fn(*args, **kwargs) + return ray.get(self._run_remote(fn=f)) def run_remote(self, fn: Callable[[Any], Any], args: Optional[List] = None, - kwargs: Optional[Dict] = None) -> List[Any]: + kwargs: Optional[Dict] = None, + callbacks: Optional[List[Callable]] = None): + """Executes the provided function on all workers. Args: @@ -525,9 +617,25 @@ def run_remote(self, """ args = args or [] kwargs = kwargs or {} + f = lambda w: fn(*args, **kwargs) + return self._run_remote(fn=f) + + def _run_remote(self, + fn: Callable[[Any], Any]) -> List[Any]: + """Executes the provided function on all workers. + + Args: + fn: Target function that can be executed with arbitrary + args and keyword arguments. + + Returns: + list: List of ObjectRefs that you can run `ray.get` on to + retrieve values. + """ + # Use run_remote for all calls + # for elastic, start the driver and launch the job return [ - worker.execute.remote(lambda w: fn(*args, **kwargs)) - for worker in self.workers + worker.execute.remote(fn) for worker in self.workers ] def execute_single(self, @@ -543,7 +651,7 @@ def execute_single(self, return ray.get(self.workers[0].execute.remote(fn)) def shutdown(self): - """Destroys the provided workers.""" + """Destroys the workers.""" for worker in self.workers: del worker diff --git a/test/single/test_ray.py b/test/single/test_ray.py index a51396f80c..de71ac4a1c 100644 --- a/test/single/test_ray.py +++ b/test/single/test_ray.py @@ -339,32 +339,6 @@ def rank_epoch(self): hjob.shutdown() -@pytest.mark.skipif( - not gloo_built(), reason='Gloo is required for Ray integration') -def test_ray_deprecation(ray_start_4_cpus): - class Executable: - def __init__(self, epochs): - import horovod.torch as hvd - self.hvd = hvd - self.epochs = epochs - self.hvd.init() - - def rank_epoch(self): - return self.hvd.rank() * self.epochs - - setting = RayExecutor.create_settings(timeout_s=30) - hjob = RayExecutor( - setting, - num_hosts=1, - num_slots=2, - cpus_per_slot=2, - use_gpu=torch.cuda.is_available()) - hjob.start(executable_cls=Executable, executable_args=[2]) - result = hjob.execute(lambda w: w.rank_epoch()) - assert set(result) == {0, 2} - hjob.shutdown() - - def _train(batch_size=32, batch_per_iter=10): import torch.nn.functional as F import torch.optim as optim @@ -445,7 +419,7 @@ def simple_fn(worker): gpus_per_worker=int(torch.cuda.is_available()) or None, use_gpu=torch.cuda.is_available()) hjob.start() - assert not hjob.driver.strategy._created_placement_group + assert not hjob.adapter.strategy._created_placement_group result = hjob.execute(simple_fn) assert set(result) == {0, 1, 2, 3} hjob.shutdown() diff --git a/test/single/test_ray_elastic_v2.py b/test/single/test_ray_elastic_v2.py new file mode 100644 index 0000000000..be478e98b6 --- /dev/null +++ b/test/single/test_ray_elastic_v2.py @@ -0,0 +1,371 @@ +"""Ray-Horovod Elastic training unit tests. + +This is currently not run on the Ray CI. +""" +from contextlib import contextmanager +import psutil +import os +import socket + +import mock +import pytest +import ray + +from horovod.common.util import gloo_built +from horovod.runner.elastic.discovery import HostDiscovery +from horovod.ray.elastic_v2 import RayHostDiscovery +from horovod.ray.runner import RayExecutor + + + +@pytest.fixture +def ray_shutdown(): + yield + # The code after the yield will run as teardown code. 
+ ray.shutdown() + + +@pytest.fixture +def ray_8_cpus(): + ray.init(num_cpus=8, resources={ + f"node:host-{i}": 1 for i in range(10)}) + yield + # The code after the yield will run as teardown code. + ray.shutdown() + + +@pytest.fixture +def ray_8_cpus_gpus(): + if "CUDA_VISIBLE_DEVICES" in os.environ: + if len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")) < 8: + pytest.skip("Avoiding mismatched GPU machine.") + ray.init(num_cpus=8, num_gpus=8, resources={ + f"node:host-{i}": 1 for i in range(10)}) + try: + yield + finally: + # The code after the yield will run as teardown code. + ray.shutdown() + + +class TestRayDiscoverySuite: + @pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') + def test_cpu_discovery(self, ray_shutdown): + ray.init(num_cpus=4, num_gpus=1) + discovery = RayHostDiscovery(cpus_per_worker=1) + mapping = discovery.find_available_hosts_and_slots() + assert len(mapping) == 1 + assert list(mapping.values()) == [4] + + @pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') + def test_gpu_discovery(self, ray_shutdown): + ray.init(num_cpus=4, num_gpus=1) + discovery = RayHostDiscovery(use_gpu=True, cpus_per_worker=1) + mapping = discovery.find_available_hosts_and_slots() + assert len(mapping) == 1 + assert list(mapping.values()) == [1] + + @pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') + def test_gpu_slot_discovery(self, ray_shutdown): + ray.init(num_cpus=4, num_gpus=4) + discovery = RayHostDiscovery( + use_gpu=True, cpus_per_worker=1, gpus_per_worker=2) + mapping = discovery.find_available_hosts_and_slots() + assert len(mapping) == 1 + assert list(mapping.values()) == [2] + + @pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') + def test_multinode(self, monkeypatch): + def create_multi_node_mock(): + host_names = ["host-1", "host-2", "host-3"] + resources = {"GPU": 2, "CPU": 8} + + def create_node_entry(hostname): + return { + "NodeManagerAddress": hostname, + "Resources": resources.copy(), + "alive": True + } + + return map(create_node_entry, host_names) + + monkeypatch.setattr(ray, "nodes", create_multi_node_mock) + discovery = RayHostDiscovery(use_gpu=True, cpus_per_worker=1) + mapping = discovery.find_available_hosts_and_slots() + assert len(mapping) == 3 + assert list(mapping.values()) == [2, 2, 2] + + @pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') + def test_multinode_gpus_per_slot(self, monkeypatch): + def create_multi_node_mock(): + host_names = ["host-1", "host-2", "host-3"] + resources = {"GPU": 2, "CPU": 8} + + def create_node_entry(hostname): + return { + "NodeManagerAddress": hostname, + "Resources": resources.copy(), + "alive": True + } + + return map(create_node_entry, host_names) + + monkeypatch.setattr(ray, "nodes", create_multi_node_mock) + discovery = RayHostDiscovery(use_gpu=True, gpus_per_worker=2) + mapping = discovery.find_available_hosts_and_slots() + assert len(mapping) == 3 + assert list(mapping.values()) == [1, 1, 1] + + @pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') + def test_multinode_mismatch(self, monkeypatch): + def create_multi_node_mock(): + host_names = ["host-1", "host-2", "host-3"] + resources = {"CPU": 8} + + def create_node_entry(hostname): + return { + "NodeManagerAddress": hostname, + "Resources": resources.copy(), + "alive": True + } + + return map(create_node_entry, host_names) + + monkeypatch.setattr(ray, 
"nodes", create_multi_node_mock) + discovery = RayHostDiscovery(use_gpu=True, cpus_per_worker=1) + mapping = discovery.find_available_hosts_and_slots() + assert sum(mapping.values()) == 0 + + +class SimpleTestDiscovery(HostDiscovery): + def __init__(self, schedule): + self._schedule = schedule + self._generator = self.host_generator() + + def host_generator(self): + for iters, hosts in self._schedule: + iters = iters or 500 # max + for i in range(iters): + yield hosts + + def find_available_hosts_and_slots(self): + hostlist = next(self._generator) + hosts = {} + for item in hostlist: + host, slots = item.split(":") + slots = int(slots) + hosts[host] = slots + return hosts + + +class StatusCallback: + def __init__(self): + self._journal = [] + + def __call__(self, info_dict): + self._journal.append(info_dict) + + def fetch(self): + return self._journal.copy() + + +def _create_training_function(iterations): + def training_fn(): + import time + import torch + import horovod.torch as hvd + from horovod.ray import ray_logger + + hvd.init() + + model = torch.nn.Sequential(torch.nn.Linear(2, 2)) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + ray_logger.log({"started": True, "pid": os.getpid()}) + + @hvd.elastic.run + def train(state): + for state.epoch in range(state.epoch, iterations): + ray_logger.log({"training": True, "pid": os.getpid()}) + time.sleep(0.1) + state.commit() # triggers scale-up, scale-down + ray_logger.log({"finished": True, "pid": os.getpid()}) + + state = hvd.elastic.TorchState( + model, optimizer, batch=0, epoch=0, commits=0, rendezvous=0) + train(state) + return True + + return training_fn + + +@contextmanager +def fault_tolerance_patches(): + with mock.patch( + 'horovod.runner.elastic.driver.DISCOVER_HOSTS_FREQUENCY_SECS', + 0.1): + with mock.patch( + "horovod.runner.util.network.get_driver_ip", + return_value=socket.gethostbyname(socket.gethostname())): + yield + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +@pytest.mark.skip(reason='https://github.com/horovod/horovod/issues/3197') +def test_fault_tolerance_hosts_added_and_removed(ray_8_cpus): + with fault_tolerance_patches(): + discovery_schedule = [ + (10, ['host-1:2']), + (30, ['host-1:2', 'host-2:1', 'host-3:1']), + (None, ['host-2:1']), + ] + nics = list(psutil.net_if_addrs().keys())[0] + + settings = RayExecutor.create_settings(nics={nics}) + settings.discovery = SimpleTestDiscovery(discovery_schedule) + executor = RayExecutor( + settings, + min_workers=1, + cpus_per_worker=1, override_discovery=False) + + training_fn = _create_training_function(iterations=50) + executor.start() + trace = StatusCallback() + results = executor.run(training_fn, callbacks=[trace]) + assert len(results) == 1 + + events = trace.fetch() + assert sum(int("started" in e) for e in events) == 4, events + assert sum(int("finished" in e) for e in events) == 1, events + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +@pytest.mark.skip(reason='https://github.com/horovod/horovod/issues/3197') +def test_fault_tolerance_hosts_remove_and_add(ray_8_cpus): + with fault_tolerance_patches(): + discovery_schedule = [ + (10, ['host-1:2', 'host-2:1', 'host-3:2']), + (10, ['host-1:2']), + (None, ['host-1:2', 'host-4:1', 'host-5:1']), + ] + nics = list(psutil.net_if_addrs().keys())[0] + + settings = RayExecutor.create_settings(nics={nics}) + settings.discovery = SimpleTestDiscovery(discovery_schedule) + executor = RayExecutor(settings, + min_workers=1, 
cpus_per_worker=1, override_discovery=False) + + training_fn = _create_training_function(iterations=30) + executor.start() + trace = StatusCallback() + results = executor.run(training_fn, callbacks=[trace]) + assert len(results) == 4 + + events = trace.fetch() + assert sum(int("started" in e) for e in events) == 7, events + assert sum(int("finished" in e) for e in events) == 4, events + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +def test_max_np(ray_8_cpus): + with fault_tolerance_patches(): + discovery_schedule = [ + (10, ['host-1:2']), + (None, ['host-1:2', 'host-4:1', 'host-5:1']), + ] + nics = list(psutil.net_if_addrs().keys())[0] + + settings = RayExecutor.create_settings(nics={nics}) + settings.discovery = SimpleTestDiscovery(discovery_schedule) + executor = RayExecutor(settings, + min_workers=1, max_workers=2, cpus_per_worker=1, override_discovery=False) + + training_fn = _create_training_function(iterations=20) + executor.start() + trace = StatusCallback() + results = executor.run(training_fn, callbacks=[trace]) + assert len(results) == 2 + + events = trace.fetch() + assert sum(int("started" in e) for e in events) == 2, events + assert sum(int("finished" in e) for e in events) == 2, events + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +def test_min_np(ray_8_cpus): + with fault_tolerance_patches(): + discovery_schedule = [ + (10, ['host-1:1']), + (10, ['host-1:1', 'host-4:1', 'host-5:1']), + (None, ['host-1:1', 'host-4:1', 'host-5:1', 'host-6:1']), + ] + nics = list(psutil.net_if_addrs().keys())[0] + + settings = RayExecutor.create_settings(nics={nics}) + settings.discovery = SimpleTestDiscovery(discovery_schedule) + executor = RayExecutor(settings, + min_workers=4, + max_workers=4, + override_discovery=False + ) + + training_fn = _create_training_function(iterations=30) + executor.start() + trace = StatusCallback() + results = executor.run(training_fn, callbacks=[trace]) + assert len(results) == 4 + + events = trace.fetch() + assert sum(int("started" in e) for e in events) == 4, events + assert sum(int("finished" in e) for e in events) == 4, events + + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +def test_gpu_e2e(ray_8_cpus_gpus): + with fault_tolerance_patches(): + discovery_schedule = [ + (10, ['host-1:1']), + (10, ['host-1:1', 'host-4:1', 'host-5:1']), + (None, ['host-1:1', 'host-4:1', 'host-5:1', 'host-6:1']), + ] + nics = list(psutil.net_if_addrs().keys())[0] + + settings = RayExecutor.create_settings(nics={nics}) + settings.discovery = SimpleTestDiscovery(discovery_schedule) + executor = RayExecutor(settings, + min_workers=4, max_workers=4, gpus_per_worker=1, use_gpu=True, override_discovery=False) + + training_fn = _create_training_function(iterations=30) + executor.start() + trace = StatusCallback() + results = executor.run(training_fn, callbacks=[trace]) + assert len(results) == 4 + + events = trace.fetch() + assert sum(int("started" in e) for e in events) == 4, events + assert sum(int("finished" in e) for e in events) == 4, events + +@pytest.mark.skipif( + not gloo_built(), reason='Gloo is required for Ray integration') +def test_both_num_workers_min_workers(ray_8_cpus): + settings = RayExecutor.create_settings() + with pytest.raises(ValueError, match=r"Both `min_workers` and `num_workers` provided."): + executor = RayExecutor( + settings, + min_workers=1, + num_workers=1, + cpus_per_worker=1) + +if __name__ == "__main__": + import 
sys + sys.exit(pytest.main(sys.argv[1:] + ["-v", "-x", __file__]))
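
Note on the new `callbacks` plumbing introduced by this patch: `RayExecutor.run` and `RayExecutor.execute` now accept a `callbacks` list, and every value the rank 0 worker emits through `horovod.ray.ray_logger.log` is passed to each callback (this is how `StatusCallback` in `test_ray_elastic_v2.py` observes worker progress). The snippet below is a minimal sketch of that pattern, assuming a Gloo-built Horovod and a running Ray cluster reachable at `address="auto"`; the training body, worker counts, and printed output are illustrative placeholders, not part of the patch.

```python
import os

import ray
import horovod.torch as hvd
from horovod.ray import RayExecutor, ray_logger


def training_fn():
    hvd.init()
    # Values logged here on the rank 0 worker are forwarded to every
    # callback passed to executor.run().
    ray_logger.log({"started": True, "pid": os.getpid()})
    ...  # training loop elided; see the elastic example in the commit message
    return hvd.rank()


class StatusCallback:
    """Collects every value logged by the rank 0 worker."""

    def __init__(self):
        self.journal = []

    def __call__(self, info):
        self.journal.append(info)


ray.init(address="auto")  # attach to an existing Ray cluster
settings = RayExecutor.create_settings(timeout_s=30)
# min_workers/max_workers select the elastic (ElasticAdapter) code path.
executor = RayExecutor(settings, min_workers=1, max_workers=2, cpus_per_worker=1)
executor.start()

trace = StatusCallback()
results = executor.run(training_fn, callbacks=[trace])
executor.shutdown()
print(trace.journal)
```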