
[MRG] Don't scatter data inside dask workers #1061

Merged
merged 9 commits on Jul 1, 2020
joblib/_dask.py: 17 changes (16 additions & 1 deletion)
@@ -25,6 +25,7 @@
     secede,
     rejoin
 )
+from distributed import get_worker
 from distributed.utils import thread_state
 
 try:
@@ -43,6 +44,14 @@ def is_weakrefable(obj):
         return False
 
 
+def _in_dask_worker():
+    # Return the current dask worker when running inside one, None
+    # otherwise; callers rely only on the truthiness of the result.
+    try:
+        worker = get_worker()
+    except ValueError:
+        worker = None
+    return worker
Contributor:
This is probably a little bit more robust.

Suggested change:
-    try:
-        worker = get_worker()
-    except ValueError:
-        worker = None
-    return worker
+    from distributed.utils import thread_state
+    return hasattr(thread_state, "execution_state")
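
For context (not part of the diff), a minimal sketch of how the suggested check behaves. It assumes distributed sets `execution_state` on the thread-local `thread_state` only while a task is executing on a worker thread:

```python
# Hypothetical demo, not part of this PR: the suggested check is True only
# in a thread that is currently executing a dask task.
from distributed import Client
from distributed.utils import thread_state


def in_dask_worker():
    return hasattr(thread_state, "execution_state")


if __name__ == "__main__":
    client = Client(processes=False)
    print(in_dask_worker())                        # False: client code
    print(client.submit(in_dask_worker).result())  # True: inside a worker
```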



class _WeakKeyDictionary:
"""A variant of weakref.WeakKeyDictionary for unhashable objects.

@@ -252,12 +261,18 @@ async def maybe_to_futures(args):
                 try:
                     f = call_data_futures[arg]
                 except KeyError:
-                    if is_weakrefable(arg) and sizeof(arg) > 1e3:
+                    pass
Contributor Author:
Now we won't get confusing KeyErrors if something goes wrong during scattering :)

+                if f is None:
+                    if (not _in_dask_worker() and is_weakrefable(arg) and
+                            sizeof(arg) > 1e3):
                         # Automatically scatter large objects to some of
                         # the workers to avoid duplicated data transfers.
                         # Rely on automated inter-worker data stealing if
                         # more workers need to reuse this data
                         # concurrently.
+                        # Because nested scatter calls often end up
+                        # cancelling tasks (distributed/issues/3703), we
+                        # never scatter inside nested Parallel calls.
                         [f] = await self.client.scatter(
                             [arg],
                             asynchronous=True
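
For illustration (not part of the PR), a sketch of the nested-Parallel scenario this change addresses; the function names and array size here are hypothetical:

```python
import numpy as np
from distributed import Client
from joblib import Parallel, delayed, parallel_backend


def inner(i, data):
    return i * data.sum()


def outer(data):
    # This nested Parallel call runs inside a dask worker. Before this PR,
    # it could scatter `data` from within the worker, which sometimes
    # cancelled tasks (distributed/issues/3703); with this PR, scattering
    # is skipped inside workers.
    return Parallel(n_jobs=2)(delayed(inner)(i, data) for i in range(2))


if __name__ == "__main__":
    client = Client(processes=False)
    data = np.ones(50_000)  # well over the 1e3-byte scatter threshold
    with parallel_backend("dask"):
        results = Parallel(n_jobs=2)(delayed(outer)(data) for _ in range(2))
    print(results)
```

Only the outer call, issued from the client process, qualifies for automatic scattering; the inner calls fall back to plain argument passing.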