From 1e694ede10ba9b325fb41ce437b1317a1278044e Mon Sep 17 00:00:00 2001 From: "Falkner Stefan (CR/PJ-AI-R32)" Date: Thu, 15 Feb 2024 14:10:37 +0100 Subject: [PATCH] Add test for dask distributed optimization loop Signed-off-by: Falkner Stefan (CR/PJ-AI-R32) --- .../optimization_loops/dask_distributed.py | 2 + .../dask_distributed_test.py | 50 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/blackboxopt/optimization_loops/dask_distributed.py b/blackboxopt/optimization_loops/dask_distributed.py index 0a002f5b..05727904 100644 --- a/blackboxopt/optimization_loops/dask_distributed.py +++ b/blackboxopt/optimization_loops/dask_distributed.py @@ -183,6 +183,8 @@ def run_optimization_loop( while dask_scheduler.has_running_jobs(): new_evaluations = dask_scheduler.check_for_results(timeout_s=20) + if post_evaluation_callback: + list(map(post_evaluation_callback, new_evaluations)) optimizer.report(new_evaluations) evaluations.extend(new_evaluations) diff --git a/tests/optimization_loops/dask_distributed_test.py b/tests/optimization_loops/dask_distributed_test.py index 0bb83913..d07a9b00 100644 --- a/tests/optimization_loops/dask_distributed_test.py +++ b/tests/optimization_loops/dask_distributed_test.py @@ -19,6 +19,7 @@ ) from blackboxopt import Evaluation, EvaluationSpecification, Objective +from blackboxopt.optimization_loops import testing from blackboxopt.optimization_loops.dask_distributed import ( MinimalDaskScheduler, run_optimization_loop, @@ -84,3 +85,52 @@ def test_restarting_workers(tmpdir): scheduler.shutdown() del dd_client del cluster + + +def delayed_evaluation_function(eval_spec: EvaluationSpecification) -> Evaluation: + """Add a small delay in the evaluation to avoid racing condition""" + time.sleep(0.1) + return eval_spec.create_evaluation(objectives={"loss": 0.0}) + + +def test_post_evaluation_callback(tmpdir): + cluster = dd.LocalCluster( + n_workers=1, threads_per_worker=1, local_directory=tmpdir, processes=True + ) + dd_client = dd.Client(cluster) + evaluations_from_callback = [] + + def callback(e: Evaluation): + evaluations_from_callback.append(e) + + evaluations = run_optimization_loop( + RandomSearch(testing.SPACE, [Objective("loss", False)], max_steps=10), + delayed_evaluation_function, + post_evaluation_callback=callback, + dask_client=dd_client, + ) + + assert len(evaluations) == len(evaluations_from_callback) + assert evaluations == evaluations_from_callback + + +def test_pre_evaluation_callback(tmpdir): + cluster = dd.LocalCluster( + n_workers=1, threads_per_worker=1, local_directory=tmpdir, processes=True + ) + dd_client = dd.Client(cluster) + + eval_specs_from_callback = [] + + def callback(e: Evaluation): + eval_specs_from_callback.append(e) + + evaluations = run_optimization_loop( + RandomSearch(testing.SPACE, [Objective("loss", False)], max_steps=10), + delayed_evaluation_function, + pre_evaluation_callback=callback, + dask_client=dd_client, + ) + + assert len(evaluations) == len(eval_specs_from_callback) + assert [e.get_specification() for e in evaluations] == eval_specs_from_callback