From 5cb45ab237e21a248cbb5237a4a2fbaff61f7fbe Mon Sep 17 00:00:00 2001
From: Pierre Glaser
Date: Mon, 13 Apr 2020 19:29:54 +0200
Subject: [PATCH 1/8] clear python exception state before scattering

In order to stop being polluted by unrelated key-errors coming from
call_data_futures lookup errors
---
 joblib/_dask.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/joblib/_dask.py b/joblib/_dask.py
index a3af0d1a8..cbf3cdb49 100644
--- a/joblib/_dask.py
+++ b/joblib/_dask.py
@@ -252,6 +252,8 @@ async def maybe_to_futures(args):
                     try:
                         f = call_data_futures[arg]
                     except KeyError:
+                        pass
+                    if f is None:
                         if is_weakrefable(arg) and sizeof(arg) > 1e3:
                             # Automatically scatter large objects to some of
                             # the workers to avoid duplicated data transfers.

From 40ef478154197ed545f53cab6545605ecd803fc8 Mon Sep 17 00:00:00 2001
From: Pierre Glaser
Date: Mon, 27 Apr 2020 16:50:28 +0200
Subject: [PATCH 2/8] don't scatter in nested dask calls

---
 joblib/_dask.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/joblib/_dask.py b/joblib/_dask.py
index cbf3cdb49..3cd04b423 100644
--- a/joblib/_dask.py
+++ b/joblib/_dask.py
@@ -25,6 +25,7 @@
         secede,
         rejoin
     )
+    from distributed import get_worker
     from distributed.utils import thread_state
 
     try:
@@ -43,6 +44,14 @@ def is_weakrefable(obj):
         return False
 
 
+def _in_dask_worker():
+    try:
+        worker = get_worker()
+    except ValueError:
+        worker = None
+    return worker
+
+
 class _WeakKeyDictionary:
     """A variant of weakref.WeakKeyDictionary for unhashable objects.
 
@@ -254,12 +263,16 @@ async def maybe_to_futures(args):
                     except KeyError:
                         pass
                     if f is None:
-                        if is_weakrefable(arg) and sizeof(arg) > 1e3:
+                        if (not _in_dask_worker() and is_weakrefable(arg) and
+                                sizeof(arg) > 1e3):
                             # Automatically scatter large objects to some of
                             # the workers to avoid duplicated data transfers.
                             # Rely on automated inter-worker data stealing if
                             # more workers need to reuse this data
                             # concurrently.
+                            # Because nested scatter call often end up
+                            # cancelling tasks, (distributed/issues/3703), we
+                            # never scatter inside nested Parallel calls.
                             [f] = await self.client.scatter(
                                 [arg],
                                 asynchronous=True

From ab1e6b18391897d530243305c5a76678cd4c757d Mon Sep 17 00:00:00 2001
From: Pierre Glaser
Date: Tue, 30 Jun 2020 22:36:09 +0200
Subject: [PATCH 3/8] apply matthew's suggestion about using hash=False

---
 joblib/_dask.py | 21 +++++++--------------
 1 file changed, 7 insertions(+), 14 deletions(-)

diff --git a/joblib/_dask.py b/joblib/_dask.py
index 3cd04b423..a9fbac72d 100644
--- a/joblib/_dask.py
+++ b/joblib/_dask.py
@@ -44,14 +44,6 @@ def is_weakrefable(obj):
         return False
 
 
-def _in_dask_worker():
-    try:
-        worker = get_worker()
-    except ValueError:
-        worker = None
-    return worker
-
-
 class _WeakKeyDictionary:
     """A variant of weakref.WeakKeyDictionary for unhashable objects.
 
@@ -263,19 +255,20 @@ async def maybe_to_futures(args):
                     except KeyError:
                         pass
                     if f is None:
-                        if (not _in_dask_worker() and is_weakrefable(arg) and
-                                sizeof(arg) > 1e3):
+                        if is_weakrefable(arg) and sizeof(arg) > 1e3:
                             # Automatically scatter large objects to some of
                             # the workers to avoid duplicated data transfers.
                             # Rely on automated inter-worker data stealing if
                             # more workers need to reuse this data
                             # concurrently.
-                            # Because nested scatter call often end up
-                            # cancelling tasks, (distributed/issues/3703), we
-                            # never scatter inside nested Parallel calls.
+                            # set hash=False - nested scatter calls (i.e
+                            # calling client.scatter inside a dask worker)
+                            # using hash=True often raises CancelledError,
+                            # see dask/distributed#3703
                             [f] = await self.client.scatter(
                                 [arg],
-                                asynchronous=True
+                                asynchronous=True,
+                                hash=False
                             )
                             call_data_futures[arg] = f
 

From f8a315832d9bc4e110774520ff91a92dfd34b8fd Mon Sep 17 00:00:00 2001
From: Pierre Glaser
Date: Tue, 30 Jun 2020 23:08:55 +0200
Subject: [PATCH 4/8] test that nested scatter calls run successfully

---
 joblib/test/test_dask.py | 34 +++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/joblib/test/test_dask.py b/joblib/test/test_dask.py
index 2b5b7dd92..c45f320cf 100644
--- a/joblib/test/test_dask.py
+++ b/joblib/test/test_dask.py
@@ -10,7 +10,7 @@
 from .._dask import DaskDistributedBackend
 
 distributed = pytest.importorskip('distributed')
-from distributed import Client, LocalCluster
+from distributed import Client, LocalCluster, get_client
 from distributed.metrics import time
 from distributed.utils_test import cluster, inc
 
@@ -152,6 +152,38 @@ def count_events(event_name, client):
     assert counts[b['address']] == 0
 
 
+@pytest.mark.parametrize("retry_no", list(range(2)))
+def test_nested_scatter(loop, retry_no):
+
+    np = pytest.importorskip('numpy')
+
+    NUM_INNER_TASKS = 10
+    NUM_OUTER_TASKS = 10
+
+    def my_sum(x, i, j):
+        # print(f"running inner task {j} of outer task {i}")
+        return np.sum(x)
+
+    def outer_function_joblib(array, i):
+        # print(f"running outer task {i}")
+        client = get_client()  # noqa
+        with parallel_backend("dask"):
+            results = Parallel()(
+                delayed(my_sum)(array[j:], i, j) for j in range(
+                    NUM_INNER_TASKS)
+            )
+        return sum(results)
+
+    with cluster() as (s, [a, b]):
+        with Client(s['address'], loop=loop) as _:
+            with parallel_backend("dask"):
+                my_array = np.ones(10000)
+                _ = Parallel()(
+                    delayed(outer_function_joblib)(
+                        my_array[i:], i) for i in range(NUM_OUTER_TASKS)
+                )
+
+
 def test_nested_backend_context_manager(loop):
     def get_nested_pids():
         pids = set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))

From b337e76625387a438f2ec5f61a4b6eb68403f549 Mon Sep 17 00:00:00 2001
From: Pierre Glaser
Date: Tue, 30 Jun 2020 23:36:13 +0200
Subject: [PATCH 5/8] remove unused import

---
 joblib/_dask.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/joblib/_dask.py b/joblib/_dask.py
index a9fbac72d..0bd395225 100644
--- a/joblib/_dask.py
+++ b/joblib/_dask.py
@@ -25,7 +25,6 @@
         secede,
         rejoin
     )
-    from distributed import get_worker
     from distributed.utils import thread_state
 
     try:

From c56cdfb7a79483e6d2d4e3ab604e225dea1e7ee2 Mon Sep 17 00:00:00 2001
From: Pierre Glaser
Date: Wed, 1 Jul 2020 11:48:51 +0200
Subject: [PATCH 6/8] update CHANGES.rst

---
 CHANGES.rst | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/CHANGES.rst b/CHANGES.rst
index 7bedf3729..47eca0005 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -14,6 +14,10 @@ In development
   results or errors.
   https://github.com/joblib/joblib/pull/1055
 
+- Prevent a dask.distributed bug from surfacing in joblib's dask backend
+  during nested Parallel calls (due to joblib's auto-scattering feature)
+  https://github.com/joblib/joblib/pull/1061
+
 Release 0.15.1
 --------------
 

From eeabb67b0fd4c2c902fd277f7f5d014b8f848a43 Mon Sep 17 00:00:00 2001
From: Pierre Glaser
Date: Wed, 1 Jul 2020 11:52:48 +0200
Subject: [PATCH 7/8] Update joblib/_dask.py

---
 joblib/_dask.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/joblib/_dask.py b/joblib/_dask.py
index 69ba54594..e69b15b92 100644
--- a/joblib/_dask.py
+++ b/joblib/_dask.py
@@ -279,7 +279,7 @@ async def maybe_to_futures(args):
                             # concurrently.
                             # set hash=False - nested scatter calls (i.e
                             # calling client.scatter inside a dask worker)
-                            # using hash=True often raises CancelledError,
+                            # using hash=True often raise CancelledError,
                             # see dask/distributed#3703
                             [f] = await self.client.scatter(
                                 [arg],

From 04d8188860c2a24a54c3ec2a74000b941dbbd04c Mon Sep 17 00:00:00 2001
From: Pierre Glaser
Date: Wed, 1 Jul 2020 11:53:21 +0200
Subject: [PATCH 8/8] remove leftover print function calls

---
 joblib/test/test_dask.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/joblib/test/test_dask.py b/joblib/test/test_dask.py
index 1c5559c42..6af0a0a4e 100644
--- a/joblib/test/test_dask.py
+++ b/joblib/test/test_dask.py
@@ -260,11 +260,9 @@ def test_nested_scatter(loop, retry_no):
     NUM_OUTER_TASKS = 10
 
     def my_sum(x, i, j):
-        # print(f"running inner task {j} of outer task {i}")
         return np.sum(x)
 
     def outer_function_joblib(array, i):
-        # print(f"running outer task {i}")
         client = get_client()  # noqa
         with parallel_backend("dask"):
             results = Parallel()(
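
Note (not part of the patch series): the following stand-alone sketch shows the
nested-Parallel pattern this series makes robust, mirroring the new
test_nested_scatter test above. The function names inner_sum and outer_task and
the in-process Client(processes=False) cluster are illustrative assumptions; the
only APIs relied on are distributed.Client, distributed.get_client and joblib's
Parallel, delayed and parallel_backend.

    # Stand-alone sketch (illustrative, not part of the patches).
    # The outer Parallel call runs its tasks on dask workers; each outer task
    # opens a nested parallel_backend("dask") block. Auto-scattering of the
    # large numpy array from inside a worker is the situation that used to hit
    # dask/distributed#3703 before hash=False was passed to client.scatter.
    import numpy as np
    from distributed import Client, get_client
    from joblib import Parallel, delayed, parallel_backend


    def inner_sum(x, j):
        return np.sum(x)


    def outer_task(array, i):
        get_client()  # reuse the worker's client for the nested call
        with parallel_backend("dask"):
            return sum(
                Parallel()(delayed(inner_sum)(array[j:], j) for j in range(10))
            )


    if __name__ == "__main__":
        # Assumption: an in-process cluster is enough to exercise the pattern.
        with Client(processes=False):
            my_array = np.ones(10000)
            with parallel_backend("dask"):
                print(Parallel()(
                    delayed(outer_task)(my_array[i:], i) for i in range(10)
                ))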