Skip to content

Commit

Permalink
Use named scheduler dict to get default
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Rose committed Jun 15, 2022
1 parent 1472d71 commit 78766ad
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 39 deletions.
11 changes: 2 additions & 9 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@
compute_as_if_collection,
dont_optimize,
is_dask_collection,
named_schedulers,
persist,
tokenize,
)
from dask.blockwise import blockwise as core_blockwise
from dask.blockwise import broadcast_dimensions
from dask.compatibility import _EMSCRIPTEN
from dask.context import globalmethod
from dask.core import quote
from dask.delayed import Delayed, delayed
Expand Down Expand Up @@ -85,14 +85,7 @@
)
from dask.widgets import get_template

if _EMSCRIPTEN:
from dask import local

DEFAULT_GET = local.get_sync
else:
from dask import threaded

DEFAULT_GET = threaded.get
DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])

config.update_defaults({"array": {"chunk-size": "128MiB", "rechunk-threshold": 4}})

Expand Down
18 changes: 8 additions & 10 deletions dask/bag/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@
from dask import config
from dask.bag import chunk
from dask.bag.avro import to_avro
from dask.base import DaskMethodsMixin, dont_optimize, replace_name_in_key, tokenize
from dask.base import (
DaskMethodsMixin,
dont_optimize,
named_schedulers,
replace_name_in_key,
tokenize,
)
from dask.blockwise import blockwise
from dask.compatibility import _EMSCRIPTEN
from dask.context import globalmethod
from dask.core import flatten, get_dependencies, istask, quote, reverse_dict
from dask.delayed import Delayed, unpack_collections
Expand All @@ -64,14 +69,7 @@
takes_multiple_arguments,
)

if _EMSCRIPTEN:
from dask import local

DEFAULT_GET = local.get_sync
else:
from dask import multiprocessing

DEFAULT_GET = multiprocessing.get
DEFAULT_GET = named_schedulers.get("processes", named_schedulers["sync"])

no_default = "__no__default__"
no_result = type(
Expand Down
18 changes: 8 additions & 10 deletions dask/dataframe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,14 @@
from dask import core
from dask.array.core import Array, normalize_arg
from dask.bag import map_partitions as map_bag_partitions
from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize
from dask.base import (
DaskMethodsMixin,
dont_optimize,
is_dask_collection,
named_schedulers,
tokenize,
)
from dask.blockwise import Blockwise, BlockwiseDep, BlockwiseDepDict, blockwise
from dask.compatibility import _EMSCRIPTEN
from dask.context import globalmethod
from dask.dataframe import methods
from dask.dataframe._compat import PANDAS_GT_140, PANDAS_GT_150
Expand Down Expand Up @@ -80,14 +85,7 @@
)
from dask.widgets import get_template

if _EMSCRIPTEN:
from dask import local

DEFAULT_GET = local.get_sync
else:
from dask import threaded

DEFAULT_GET = threaded.get
DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])

no_default = "__no_default__"

Expand Down
11 changes: 2 additions & 9 deletions dask/delayed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
DaskMethodsMixin,
dont_optimize,
is_dask_collection,
named_schedulers,
replace_name_in_key,
)
from dask.base import tokenize as _tokenize
from dask.compatibility import _EMSCRIPTEN
from dask.context import globalmethod
from dask.core import flatten, quote
from dask.highlevelgraph import HighLevelGraph
Expand All @@ -24,14 +24,7 @@
__all__ = ["Delayed", "delayed"]


if _EMSCRIPTEN:
from dask import local

DEFAULT_GET = local.get_sync
else:
from dask import threaded

DEFAULT_GET = threaded.get
DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])


def unzip(ls, nout):
Expand Down
9 changes: 8 additions & 1 deletion dask/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,9 +1547,16 @@ def __dask_optimize__(cls, dsk, keys, **kwargs):
# platform. One might prefer patching `sys.platform` for a more direct test, but that
# causes problems in other libraries.
def check_default_scheduler(module, collection, expected, emscripten):
from contextlib import nullcontext
from unittest import mock

with mock.patch("dask.compatibility._EMSCRIPTEN", emscripten):
from dask.local import get_sync

if emscripten:
ctx = mock.patch("dask.base.named_schedulers", {"sync": get_sync})
else:
ctx = nullcontext()
with ctx:
import importlib

if expected == "sync":
Expand Down

0 comments on commit 78766ad

Please sign in to comment.