diff --git a/CHANGES.rst b/CHANGES.rst
index 7bedf3729..47eca0005 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -14,6 +14,10 @@ In development
   results or errors.
   https://github.com/joblib/joblib/pull/1055
 
+- Prevent a dask.distributed bug from surfacing in joblib's dask backend
+  during nested Parallel calls (due to joblib's auto-scattering feature)
+  https://github.com/joblib/joblib/pull/1061
+
 Release 0.15.1
 --------------
 
diff --git a/joblib/_dask.py b/joblib/_dask.py
index f4d91512b..e69b15b92 100644
--- a/joblib/_dask.py
+++ b/joblib/_dask.py
@@ -269,15 +269,22 @@ async def maybe_to_futures(args):
                 try:
                     f = call_data_futures[arg]
                 except KeyError:
+                    pass
+                if f is None:
                     if is_weakrefable(arg) and sizeof(arg) > 1e3:
                         # Automatically scatter large objects to some of
                         # the workers to avoid duplicated data transfers.
                         # Rely on automated inter-worker data stealing if
                         # more workers need to reuse this data
                         # concurrently.
+                        # set hash=False - nested scatter calls (i.e.
+                        # calling client.scatter inside a dask worker)
+                        # using hash=True often raise CancelledError,
+                        # see dask/distributed#3703
                         [f] = await self.client.scatter(
                             [arg],
-                            asynchronous=True
+                            asynchronous=True,
+                            hash=False
                         )
                         call_data_futures[arg] = f
 
diff --git a/joblib/test/test_dask.py b/joblib/test/test_dask.py
index b1e51f6a8..6af0a0a4e 100644
--- a/joblib/test/test_dask.py
+++ b/joblib/test/test_dask.py
@@ -11,7 +11,7 @@
 from .._dask import DaskDistributedBackend
 
 distributed = pytest.importorskip('distributed')
-from distributed import Client, LocalCluster
+from distributed import Client, LocalCluster, get_client
 from distributed.metrics import time
 from distributed.utils_test import cluster, inc
 
@@ -251,6 +251,36 @@ def test_auto_scatter(loop):
     assert counts[b['address']] == 0
 
 
+@pytest.mark.parametrize("retry_no", list(range(2)))
+def test_nested_scatter(loop, retry_no):
+
+    np = pytest.importorskip('numpy')
+
+    NUM_INNER_TASKS = 10
+    NUM_OUTER_TASKS = 10
+
+    def my_sum(x, i, j):
+        return np.sum(x)
+
+    def outer_function_joblib(array, i):
+        client = get_client()  # noqa
+        with parallel_backend("dask"):
+            results = Parallel()(
+                delayed(my_sum)(array[j:], i, j) for j in range(
+                    NUM_INNER_TASKS)
+            )
+        return sum(results)
+
+    with cluster() as (s, [a, b]):
+        with Client(s['address'], loop=loop) as _:
+            with parallel_backend("dask"):
+                my_array = np.ones(10000)
+                _ = Parallel()(
+                    delayed(outer_function_joblib)(
+                        my_array[i:], i) for i in range(NUM_OUTER_TASKS)
+                )
+
+
 def test_nested_backend_context_manager(loop):
     def get_nested_pids():
         pids = set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))
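
Note, not part of the patch: the sketch below is a minimal standalone illustration of the scenario the fix targets, for reviewers who want to see the raw dask.distributed behaviour without joblib in the loop. It assumes distributed is installed and a local two-worker cluster can be started; the helper names (inner, outer) are illustrative only. worker_client() stands in for the get_client() call joblib's backend makes from inside a worker task, and hash=False mirrors the change to client.scatter in joblib/_dask.py above.

# Minimal sketch of a nested scatter: client.scatter called from inside
# a task already running on a dask worker. With hash=True (the default),
# concurrent nested scatters of equal-valued data share a content-hashed
# key, which can race and surface as CancelledError (dask/distributed#3703);
# hash=False gives each scatter a unique key and sidesteps the collision.
from distributed import Client, LocalCluster, worker_client


def inner(chunk):
    return sum(chunk)


def outer(data):
    # worker_client() secedes from the worker's thread pool, so blocking
    # on sub-task results from inside a task cannot deadlock the worker.
    with worker_client() as client:
        # hash=False mirrors the fix in joblib/_dask.py above.
        [future] = client.scatter([data], hash=False)
        return client.submit(inner, future).result()


if __name__ == '__main__':
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        print(client.submit(outer, list(range(1000))).result())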