Skip to content

Commit

Permalink
Add test for dask distributed optimization loop
Browse files Browse the repository at this point in the history
Signed-off-by: Falkner Stefan (CR/PJ-AI-R32) <Stefan.Falkner@de.bosch.com>
  • Loading branch information
sfalkner committed Feb 15, 2024
1 parent e43e0cc commit 1e694ed
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
2 changes: 2 additions & 0 deletions blackboxopt/optimization_loops/dask_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def run_optimization_loop(

while dask_scheduler.has_running_jobs():
new_evaluations = dask_scheduler.check_for_results(timeout_s=20)
if post_evaluation_callback:
list(map(post_evaluation_callback, new_evaluations))
optimizer.report(new_evaluations)
evaluations.extend(new_evaluations)

Expand Down
50 changes: 50 additions & 0 deletions tests/optimization_loops/dask_distributed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)

from blackboxopt import Evaluation, EvaluationSpecification, Objective
from blackboxopt.optimization_loops import testing
from blackboxopt.optimization_loops.dask_distributed import (
MinimalDaskScheduler,
run_optimization_loop,
Expand Down Expand Up @@ -84,3 +85,52 @@ def test_restarting_workers(tmpdir):
scheduler.shutdown()
del dd_client
del cluster


def delayed_evaluation_function(eval_spec: EvaluationSpecification) -> Evaluation:
"""Add a small delay in the evaluation to avoid racing condition"""
time.sleep(0.1)
return eval_spec.create_evaluation(objectives={"loss": 0.0})


def test_post_evaluation_callback(tmpdir):
cluster = dd.LocalCluster(
n_workers=1, threads_per_worker=1, local_directory=tmpdir, processes=True
)
dd_client = dd.Client(cluster)
evaluations_from_callback = []

def callback(e: Evaluation):
evaluations_from_callback.append(e)

evaluations = run_optimization_loop(
RandomSearch(testing.SPACE, [Objective("loss", False)], max_steps=10),
delayed_evaluation_function,
post_evaluation_callback=callback,
dask_client=dd_client,
)

assert len(evaluations) == len(evaluations_from_callback)
assert evaluations == evaluations_from_callback


def test_pre_evaluation_callback(tmpdir):
cluster = dd.LocalCluster(
n_workers=1, threads_per_worker=1, local_directory=tmpdir, processes=True
)
dd_client = dd.Client(cluster)

eval_specs_from_callback = []

def callback(e: Evaluation):
eval_specs_from_callback.append(e)

evaluations = run_optimization_loop(
RandomSearch(testing.SPACE, [Objective("loss", False)], max_steps=10),
delayed_evaluation_function,
pre_evaluation_callback=callback,
dask_client=dd_client,
)

assert len(evaluations) == len(eval_specs_from_callback)
assert [e.get_specification() for e in evaluations] == eval_specs_from_callback

0 comments on commit 1e694ed

Please sign in to comment.