diff --git a/blackboxopt/optimization_loops/dask_distributed.py b/blackboxopt/optimization_loops/dask_distributed.py index 0a002f5..0572790 100644 --- a/blackboxopt/optimization_loops/dask_distributed.py +++ b/blackboxopt/optimization_loops/dask_distributed.py @@ -183,6 +183,8 @@ def run_optimization_loop( while dask_scheduler.has_running_jobs(): new_evaluations = dask_scheduler.check_for_results(timeout_s=20) + if post_evaluation_callback: + list(map(post_evaluation_callback, new_evaluations)) optimizer.report(new_evaluations) evaluations.extend(new_evaluations) diff --git a/tests/optimization_loops/dask_distributed_test.py b/tests/optimization_loops/dask_distributed_test.py index 0bb8391..d07a9b0 100644 --- a/tests/optimization_loops/dask_distributed_test.py +++ b/tests/optimization_loops/dask_distributed_test.py @@ -19,6 +19,7 @@ ) from blackboxopt import Evaluation, EvaluationSpecification, Objective +from blackboxopt.optimization_loops import testing from blackboxopt.optimization_loops.dask_distributed import ( MinimalDaskScheduler, run_optimization_loop, @@ -84,3 +85,52 @@ def test_restarting_workers(tmpdir): scheduler.shutdown() del dd_client del cluster + + +def delayed_evaluation_function(eval_spec: EvaluationSpecification) -> Evaluation: + """Add a small delay in the evaluation to avoid racing condition""" + time.sleep(0.1) + return eval_spec.create_evaluation(objectives={"loss": 0.0}) + + +def test_post_evaluation_callback(tmpdir): + cluster = dd.LocalCluster( + n_workers=1, threads_per_worker=1, local_directory=tmpdir, processes=True + ) + dd_client = dd.Client(cluster) + evaluations_from_callback = [] + + def callback(e: Evaluation): + evaluations_from_callback.append(e) + + evaluations = run_optimization_loop( + RandomSearch(testing.SPACE, [Objective("loss", False)], max_steps=10), + delayed_evaluation_function, + post_evaluation_callback=callback, + dask_client=dd_client, + ) + + assert len(evaluations) == len(evaluations_from_callback) + assert evaluations == evaluations_from_callback + + +def test_pre_evaluation_callback(tmpdir): + cluster = dd.LocalCluster( + n_workers=1, threads_per_worker=1, local_directory=tmpdir, processes=True + ) + dd_client = dd.Client(cluster) + + eval_specs_from_callback = [] + + def callback(e: Evaluation): + eval_specs_from_callback.append(e) + + evaluations = run_optimization_loop( + RandomSearch(testing.SPACE, [Objective("loss", False)], max_steps=10), + delayed_evaluation_function, + pre_evaluation_callback=callback, + dask_client=dd_client, + ) + + assert len(evaluations) == len(eval_specs_from_callback) + assert [e.get_specification() for e in evaluations] == eval_specs_from_callback