
[MRG] Don't scatter data inside dask workers #1061

Merged: 9 commits, Jul 1, 2020
joblib/_dask.py (9 changes: 8 additions & 1 deletion)

@@ -269,15 +269,22 @@ async def maybe_to_futures(args):
try:
    f = call_data_futures[arg]
except KeyError:
    pass
Contributor Author (pierreglaser) commented:

Now, we won't get confusing KeyErrors if something goes wrong during scattering :)
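
A minimal sketch of the failure mode this comment refers to, using a simplified synchronous `scatter` stand-in rather than the real `client.scatter` API (both helper functions here are illustrative, not part of the PR):

# Sketch: why scattering inside the `except KeyError` handler masks errors.
call_data_futures = {}

def lookup_before(arg, scatter):
    # Old pattern: scattering inside the handler chains any scatter
    # failure onto the KeyError, producing a confusing traceback
    # ("During handling of the above exception, another exception...").
    try:
        return call_data_futures[arg]
    except KeyError:
        return scatter(arg)

def lookup_after(arg, scatter):
    # New pattern: the handler only records a cache miss; scattering
    # happens outside it, so a scatter failure raises cleanly on its own.
    f = None
    try:
        f = call_data_futures[arg]
    except KeyError:
        pass
    if f is None:
        f = scatter(arg)
        call_data_futures[arg] = f
    return f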

if f is None:
    if is_weakrefable(arg) and sizeof(arg) > 1e3:
        # Automatically scatter large objects to some of
        # the workers to avoid duplicated data transfers.
        # Rely on automated inter-worker data stealing if
        # more workers need to reuse this data
        # concurrently.
        # Set hash=False - nested scatter calls (i.e.
        # calling client.scatter inside a dask worker)
        # using hash=True often raise CancelledError,
        # see dask/distributed#3703
        [f] = await self.client.scatter(
            [arg],
            asynchronous=True,
            hash=False
        )
        call_data_futures[arg] = f
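
For context, the code comment above points at dask/distributed#3703: nested `client.scatter` calls with content hashing can race against each other and get cancelled. Below is a small, self-contained sketch of the `hash=False` pattern from inside a worker task; the `task` function and the two-worker `LocalCluster` are illustrative assumptions, not part of this PR:

from distributed import Client, LocalCluster, get_client


def task(data):
    # Runs on a worker; get_client() returns the client tied to that worker.
    client = get_client()
    # hash=False assigns the scattered data a unique key instead of a
    # content hash, avoiding the CancelledError seen when several workers
    # concurrently scatter identical data (dask/distributed#3703).
    [future] = client.scatter([data], hash=False)
    return client.submit(sum, future).result()


if __name__ == "__main__":
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        print(client.submit(task, list(range(10))).result())  # 45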

joblib/test/test_dask.py (34 changes: 33 additions & 1 deletion)

@@ -11,7 +11,7 @@
from .._dask import DaskDistributedBackend

distributed = pytest.importorskip('distributed')
from distributed import Client, LocalCluster, get_client
from distributed.metrics import time
from distributed.utils_test import cluster, inc

@@ -251,6 +251,38 @@ def test_auto_scatter(loop):
        assert counts[b['address']] == 0


@pytest.mark.parametrize("retry_no", list(range(2)))
def test_nested_scatter(loop, retry_no):
    np = pytest.importorskip('numpy')

    NUM_INNER_TASKS = 10
    NUM_OUTER_TASKS = 10

    def my_sum(x, i, j):
        # print(f"running inner task {j} of outer task {i}")
        return np.sum(x)

    def outer_function_joblib(array, i):
        # print(f"running outer task {i}")
        client = get_client()  # noqa
        with parallel_backend("dask"):
            results = Parallel()(
                delayed(my_sum)(array[j:], i, j)
                for j in range(NUM_INNER_TASKS)
            )
        return sum(results)

    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop) as _:
            with parallel_backend("dask"):
                my_array = np.ones(10000)
                _ = Parallel()(
                    delayed(outer_function_joblib)(my_array[i:], i)
                    for i in range(NUM_OUTER_TASKS)
                )


def test_nested_backend_context_manager(loop):
    def get_nested_pids():
        pids = set(Parallel(n_jobs=2)(delayed(os.getpid)() for _ in range(2)))