Skip to content

Commit

Permalink
update ray evaluator to compute as many workers as the number of CPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
Deathn0t committed Jun 22, 2020
1 parent 2861010 commit 266bb2b
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,56 +15,57 @@
def compute_objective(func, x):
return func(x)


class RayFuture:
FAIL_RETURN_VALUE = Evaluator.FAIL_RETURN_VALUE

def __init__(self, func, x):
self.id_res = compute_objective.remote(func, x)
self._state = 'active'
self._state = "active"
self._result = None

def _poll(self):
if not self._state == 'active':
if not self._state == "active":
return

id_done, _ = ray.wait([self.id_res], num_returns=1, timeout=0.001)

if len(id_done) == 1:
try:
self._result = ray.get(id_done[0])
self._state = 'done'
self._state = "done"
except Exception:
self._state = 'failed'
self._state = "failed"
else:
self._state = 'active'
self._state = "active"

def result(self):
if not self.done:
self._result = self.FAIL_RETURN_VALUE
return self._result

def cancel(self):
pass # NOT AVAILABLE YET
pass # NOT AVAILABLE YET

@property
def active(self):
self._poll()
return self._state == 'active'
return self._state == "active"

@property
def done(self):
self._poll()
return self._state == 'done'
return self._state == "done"

@property
def failed(self):
self._poll()
return self._state == 'failed'
return self._state == "failed"

@property
def cancelled(self):
self._poll()
return self._state == 'cancelled'
return self._state == "cancelled"


class RayEvaluator(Evaluator):
Expand All @@ -73,22 +74,26 @@ class RayEvaluator(Evaluator):
Args:
redis_address (str, optional): The "IP:PORT" redis address for the RAY-driver to connect on the RAY-head.
"""
WaitResult = namedtuple(
'WaitResult', ['active', 'done', 'failed', 'cancelled'])

WaitResult = namedtuple("WaitResult", ["active", "done", "failed", "cancelled"])

def __init__(self, run_function, cache_key=None, redis_address=None, **kwargs):
super().__init__(run_function, cache_key, **kwargs)

logger.info(f'RAY Evaluator init: redis-address={redis_address}')
logger.info(f"RAY Evaluator init: redis-address={redis_address}")

if not redis_address is None:
proc_info = ray.init(redis_address=redis_address)
else:
proc_info = ray.init()

self.num_workers = len(ray.nodes())
# self.num_workers = len(ray.nodes())
self.num_cpus = int(sum([node["Resources"]["CPU"] for node in ray.nodes()]))
self.num_workers = self.num_cpus

logger.info(f"RAY Evaluator will execute: '{self._run_function}', proc_info: {proc_info}")
logger.info(
f"RAY Evaluator will execute: '{self._run_function}', proc_info: {proc_info}"
)

def _eval_exec(self, x: dict):
assert isinstance(x, dict)
Expand All @@ -102,20 +107,25 @@ def _timer(timeout):
else:
timeout = max(float(timeout), 0.01)
start = time.time()
return lambda: (time.time()-start) < timeout
return lambda: (time.time() - start) < timeout

def wait(self, futures, timeout=None, return_when='ANY_COMPLETED'):
assert return_when.strip() in ['ANY_COMPLETED', 'ALL_COMPLETED']
waitall = bool(return_when.strip() == 'ALL_COMPLETED')
def wait(self, futures, timeout=None, return_when="ANY_COMPLETED"):
assert return_when.strip() in ["ANY_COMPLETED", "ALL_COMPLETED"]
waitall = bool(return_when.strip() == "ALL_COMPLETED")

num_futures = len(futures)
active_futures = [f for f in futures if f.active]
time_isLeft = self._timer(timeout)

if waitall:
def can_exit(): return len(active_futures) == 0

def can_exit():
return len(active_futures) == 0

else:
def can_exit(): return len(active_futures) < num_futures

def can_exit():
return len(active_futures) < num_futures

while time_isLeft():
if can_exit():
Expand All @@ -125,15 +135,17 @@ def can_exit(): return len(active_futures) < num_futures
time.sleep(0.04)

if not can_exit():
raise TimeoutError(f'{timeout} sec timeout expired while '
f'waiting on {len(futures)} tasks until {return_when}')
raise TimeoutError(
f"{timeout} sec timeout expired while "
f"waiting on {len(futures)} tasks until {return_when}"
)

results = defaultdict(list)
for f in futures:
results[f._state].append(f)
return self.WaitResult(
active=results['active'],
done=results['done'],
failed=results['failed'],
cancelled=results['cancelled']
active=results["active"],
done=results["done"],
failed=results["failed"],
cancelled=results["cancelled"],
)
23 changes: 19 additions & 4 deletions deephyper/evaluator/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@ class Evaluator:
assert os.path.isfile(PYTHON_EXE)

def __init__(
self, run_function, cache_key=None, encoder=Encoder, seed=None, **kwargs
self,
run_function,
cache_key=None,
encoder=Encoder,
seed=None,
num_workers=None,
**kwargs,
):
self.encoder = encoder # dict --> uuid
self.pending_evals = {} # uid --> Future
Expand All @@ -77,7 +83,7 @@ def __init__(
self.elapsed_times = {}

self._run_function = run_function
self.num_workers = 0
self.num_workers = num_workers

if (cache_key is not None) and (cache_key != "to_dict"):
if callable(cache_key):
Expand All @@ -99,7 +105,12 @@ def __init__(

@staticmethod
def create(
run_function, cache_key=None, method="subprocess", redis_address=None, **kwargs
run_function,
cache_key=None,
method="subprocess",
redis_address=None,
num_workers=None,
**kwargs,
):
available_methods = [
"balsam",
Expand Down Expand Up @@ -136,12 +147,16 @@ def create(

Eval = MPIWorkerPool(run_function, cache_key=cache_key, **kwargs)
elif method == "ray":
from deephyper.evaluator.ray_evaluator import RayEvaluator
from deephyper.evaluator._ray_evaluator import RayEvaluator

Eval = RayEvaluator(
run_function, cache_key=cache_key, redis_address=redis_address, **kwargs
)

# Override the number of workers if passed as an argument
if not (num_workers is None) and type(num_workers) is int:
Eval.num_workers = num_workers

return Eval

def encode(self, x):
Expand Down
22 changes: 18 additions & 4 deletions deephyper/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__(
evaluator: str,
max_evals: int = 1000000,
seed: int = None,
num_nodes_master: int=1,
num_nodes_master: int = 1,
num_workers: int = None,
**kwargs,
):
kwargs["problem"] = problem
Expand All @@ -60,7 +61,13 @@ def __init__(
logger.info(notice)
util.banner(notice)

self.evaluator = Evaluator.create(self.run_func, method=evaluator, num_nodes_master=num_nodes_master, **kwargs)
self.evaluator = Evaluator.create(
self.run_func,
method=evaluator,
num_nodes_master=num_nodes_master,
num_workers=num_workers,
**kwargs,
)
self.num_workers = self.evaluator.num_workers
self.max_evals = max_evals

Expand Down Expand Up @@ -152,7 +159,7 @@ def _base_parser(parser=None) -> argparse.ArgumentParser:
"--num-evals-per-node",
default=1,
type=int,
help="Number of evaluations performed on each node. Only valid if evaluator==balsam and balsam job-mode is 'serial'."
help="Number of evaluations performed on each node. Only valid if evaluator==balsam and balsam job-mode is 'serial'.",
)
parser.add_argument(
"--num-nodes-per-eval",
Expand All @@ -164,5 +171,12 @@ def _base_parser(parser=None) -> argparse.ArgumentParser:
"--num-threads-per-rank",
default=64,
type=int,
help="Number of threads per MPI rank. Only valid if evaluator==balsam and balsam job-mode is 'mpi'.")
help="Number of threads per MPI rank. Only valid if evaluator==balsam and balsam job-mode is 'mpi'.",
)
parser.add_argument(
"--num-workers",
default=None,
type=int,
help="Number of parallel workers for the search. By default, it is be automatically computed depending on the chosen evaluator. If fixed then the default number of workers is override by this value.",
)
return parser
2 changes: 1 addition & 1 deletion docs/software/evaluators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ BalsamEvaluator
RayEvaluator
=============

.. autoclass:: deephyper.evaluator.ray_evaluator.RayEvaluator
.. autoclass:: deephyper.evaluator._ray_evaluator.RayEvaluator


.. _subprocess-evaluator:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/deephyper.evaluator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ deephyper.evaluator.evaluate module
deephyper.evaluator.ray\_evaluator module
-----------------------------------------

.. automodule:: deephyper.evaluator.ray_evaluator
.. automodule:: deephyper.evaluator._ray_evaluator
:members:
:undoc-members:
:show-inheritance:
Expand Down

0 comments on commit 266bb2b

Please sign in to comment.