
Dask in pyodide (#9053)
This is a small step towards #7764 and dask/distributed#6257. It's basically just defensively importing `threading` and `multiprocessing` and defaulting to the synchronous scheduler if those imports fail. So this is currently mostly useful for demos and training around the dask collections API. But it *does* work.

This is distinct from actually getting a `distributed.Client` working and talking to a remote cluster, which will require some actual networking work.
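
For reference, a minimal sketch of the pattern this commit introduces (illustrative stand-ins only; the real change is the `named_schedulers` registry in `dask/base.py` below):

```python
import sys

# The synchronous scheduler works everywhere, so register it unconditionally.
def get_sync(dsk, keys, **kwargs):
    ...  # in-process, single-threaded execution

named_schedulers = {"sync": get_sync}

# Thread/process schedulers are only registered off Emscripten, where the
# threading/multiprocessing machinery is unavailable.
if sys.platform != "emscripten":
    def get_threaded(dsk, keys, **kwargs):
        ...  # thread-pool execution

    named_schedulers["threads"] = get_threaded

# Each collection then picks its default with a safe fallback:
DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
```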
Ian Rose committed Jun 20, 2022
1 parent d33e4a5 commit a62c008
Showing 12 changed files with 137 additions and 31 deletions.
7 changes: 5 additions & 2 deletions dask/array/core.py
@@ -31,7 +31,7 @@
 from tlz import accumulate, concat, first, frequencies, groupby, partition
 from tlz.curried import pluck
 
-from dask import compute, config, core, threaded
+from dask import compute, config, core
 from dask.array import chunk
 from dask.array.chunk import getitem
 from dask.array.chunk_types import is_valid_array_chunk, is_valid_chunk_type
@@ -49,6 +49,7 @@
     compute_as_if_collection,
     dont_optimize,
     is_dask_collection,
+    named_schedulers,
     persist,
     tokenize,
 )
@@ -84,6 +85,8 @@
 )
 from dask.widgets import get_template
 
+DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
+
 config.update_defaults({"array": {"chunk-size": "128MiB", "rechunk-threshold": 4}})
 
 unknown_chunk_message = (
@@ -1406,7 +1409,7 @@ def __dask_tokenize__(self):
     __dask_optimize__ = globalmethod(
         optimize, key="array_optimize", falsey=dont_optimize
     )
-    __dask_scheduler__ = staticmethod(threaded.get)
+    __dask_scheduler__ = staticmethod(DEFAULT_GET)
 
     def __dask_postcompute__(self):
         return finalize, ()
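
The new `DEFAULT_GET` line above carries the whole fallback. A toy illustration of how the `dict.get` resolution behaves on each platform (strings stand in for the real scheduler functions):

```python
# Emscripten: "threads" was never registered, so the "sync" fallback wins.
named_schedulers = {"sync": "local.get_sync"}
assert named_schedulers.get("threads", named_schedulers["sync"]) == "local.get_sync"

# Regular CPython: "threads" is present and is chosen.
named_schedulers["threads"] = "threaded.get"
assert named_schedulers.get("threads", named_schedulers["sync"]) == "threaded.get"
```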
15 changes: 11 additions & 4 deletions dask/bag/core.py
@@ -40,13 +40,18 @@
 from dask import config
 from dask.bag import chunk
 from dask.bag.avro import to_avro
-from dask.base import DaskMethodsMixin, dont_optimize, replace_name_in_key, tokenize
+from dask.base import (
+    DaskMethodsMixin,
+    dont_optimize,
+    named_schedulers,
+    replace_name_in_key,
+    tokenize,
+)
 from dask.blockwise import blockwise
 from dask.context import globalmethod
 from dask.core import flatten, get_dependencies, istask, quote, reverse_dict
 from dask.delayed import Delayed, unpack_collections
 from dask.highlevelgraph import HighLevelGraph
-from dask.multiprocessing import get as mpget
 from dask.optimization import cull, fuse, inline
 from dask.sizeof import sizeof
 from dask.utils import (
@@ -64,6 +69,8 @@
     takes_multiple_arguments,
 )
 
+DEFAULT_GET = named_schedulers.get("processes", named_schedulers["sync"])
+
 no_default = "__no__default__"
 no_result = type(
     "no_result", (object,), {"__slots__": (), "__reduce__": lambda self: "no_result"}
@@ -371,7 +378,7 @@ def __dask_tokenize__(self):
         return self.key
 
     __dask_optimize__ = globalmethod(optimize, key="bag_optimize", falsey=dont_optimize)
-    __dask_scheduler__ = staticmethod(mpget)
+    __dask_scheduler__ = staticmethod(DEFAULT_GET)
 
     def __dask_postcompute__(self):
         return finalize_item, ()
@@ -481,7 +488,7 @@ def __dask_tokenize__(self):
         return self.name
 
     __dask_optimize__ = globalmethod(optimize, key="bag_optimize", falsey=dont_optimize)
-    __dask_scheduler__ = staticmethod(mpget)
+    __dask_scheduler__ = staticmethod(DEFAULT_GET)
 
     def __dask_postcompute__(self):
         return finalize, ()
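
Note that `Bag` falls back from `"processes"` rather than `"threads"`, preserving its historical default: bag workloads tend to be pure-Python and GIL-bound, so the multiprocessing scheduler is preferred whenever it exists. A sketch under the same stand-in convention as above:

```python
# Emscripten-like registry: no "processes" entry, so Bag degrades to sync.
named_schedulers = {"sync": "local.get_sync"}
assert named_schedulers.get("processes", named_schedulers["sync"]) == "local.get_sync"
```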
24 changes: 15 additions & 9 deletions dask/base.py
@@ -22,14 +22,15 @@
 from tlz import curry, groupby, identity, merge
 from tlz.functoolz import Compose
 
-from dask import config, local, threaded
-from dask.compatibility import _PY_VERSION
+from dask import config, local
+from dask.compatibility import _EMSCRIPTEN, _PY_VERSION
 from dask.context import thread_state
 from dask.core import flatten
 from dask.core import get as simple_get
 from dask.core import literal, quote
 from dask.hashing import hash_buffer_hex
 from dask.system import CPU_COUNT
+from dask.typing import SchedulerGetCallable
 from dask.utils import Dispatch, apply, ensure_dict, key_split
 
 __all__ = (
@@ -1284,19 +1285,24 @@ def _colorize(t):
     return "#" + h
 
 
-named_schedulers = {
+named_schedulers: dict[str, SchedulerGetCallable] = {
     "sync": local.get_sync,
     "synchronous": local.get_sync,
     "single-threaded": local.get_sync,
-    "threads": threaded.get,
-    "threading": threaded.get,
 }
 
-try:
+if not _EMSCRIPTEN:
+    from dask import threaded
+
+    named_schedulers.update(
+        {
+            "threads": threaded.get,
+            "threading": threaded.get,
+        }
+    )
+
     from dask import multiprocessing as dask_multiprocessing
-except ImportError:
-    pass
-else:
+
     named_schedulers.update(
         {
             "processes": dask_multiprocessing.get,
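
With this restructuring, `dask.base.named_schedulers` becomes the single source of truth for which schedulers exist. A quick way to inspect it (expected output shown as comments; the exact aliases depend on the platform):

```python
from dask.base import named_schedulers

print(sorted(named_schedulers))
# Regular CPython: ['processes', 'single-threaded', 'sync', 'synchronous',
#                   'threading', 'threads']
# Emscripten/Pyodide: ['single-threaded', 'sync', 'synchronous']
```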
2 changes: 2 additions & 0 deletions dask/compatibility.py
@@ -3,3 +3,5 @@
 from packaging.version import parse as parse_version
 
 _PY_VERSION = parse_version(".".join(map(str, sys.version_info[:3])))
+
+_EMSCRIPTEN = sys.platform == "emscripten"
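
Pyodide's CPython build reports `emscripten` as its platform, which is what makes this a reliable probe (a check one might run in a Pyodide console):

```python
import sys

# Prints "emscripten" under Pyodide; "linux", "darwin", "win32", etc. elsewhere.
print(sys.platform)
```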
16 changes: 12 additions & 4 deletions dask/dataframe/core.py
@@ -20,10 +20,16 @@
 from tlz import first, merge, partition_all, remove, unique
 
 import dask.array as da
-from dask import core, threaded
+from dask import core
 from dask.array.core import Array, normalize_arg
 from dask.bag import map_partitions as map_bag_partitions
-from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize
+from dask.base import (
+    DaskMethodsMixin,
+    dont_optimize,
+    is_dask_collection,
+    named_schedulers,
+    tokenize,
+)
 from dask.blockwise import Blockwise, BlockwiseDep, BlockwiseDepDict, blockwise
 from dask.context import globalmethod
 from dask.dataframe import methods
@@ -79,6 +85,8 @@
 )
 from dask.widgets import get_template
 
+DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
+
 no_default = "__no_default__"
 
 GROUP_KEYS_DEFAULT = None if PANDAS_GT_150 else True
@@ -163,7 +171,7 @@ def __dask_layers__(self):
     __dask_optimize__ = globalmethod(
         optimize, key="dataframe_optimize", falsey=dont_optimize
     )
-    __dask_scheduler__ = staticmethod(threaded.get)
+    __dask_scheduler__ = staticmethod(DEFAULT_GET)
 
     def __dask_postcompute__(self):
         return first, ()
@@ -345,7 +353,7 @@ def __dask_tokenize__(self):
     __dask_optimize__ = globalmethod(
         optimize, key="dataframe_optimize", falsey=dont_optimize
     )
-    __dask_scheduler__ = staticmethod(threaded.get)
+    __dask_scheduler__ = staticmethod(DEFAULT_GET)
 
     def __dask_postcompute__(self):
         return finalize, ()
13 changes: 10 additions & 3 deletions dask/dataframe/io/hdf.py
@@ -8,14 +8,21 @@
 from fsspec.utils import build_name_function, stringify_path
 from tlz import merge
 
-from dask import config, multiprocessing
-from dask.base import compute_as_if_collection, get_scheduler, tokenize
+from dask import config
+from dask.base import (
+    compute_as_if_collection,
+    get_scheduler,
+    named_schedulers,
+    tokenize,
+)
 from dask.dataframe.core import DataFrame
 from dask.dataframe.io.io import _link, from_map
 from dask.dataframe.io.utils import DataFrameIOFunction
 from dask.delayed import Delayed, delayed
 from dask.utils import get_scheduler_lock
 
+MP_GET = named_schedulers.get("processes", object())
+
 
 def _pd_to_hdf(pd_to_hdf, lock, args, kwargs=None):
     """A wrapper function around pd_to_hdf that enables locking"""
@@ -193,7 +200,7 @@ def to_hdf(
     if lock is None:
         if not single_node:
             lock = True
-        elif not single_file and _actual_get is not multiprocessing.get:
+        elif not single_file and _actual_get is not MP_GET:
             # if we're writing to multiple files with the multiprocessing
             # scheduler we don't need to lock
             lock = True
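
The `object()` default in `MP_GET` is doing real work: when no `"processes"` scheduler was registered, `MP_GET` becomes a fresh sentinel that no actual scheduler can be identical to, so the identity check in `to_hdf` safely reads as "not using multiprocessing" instead of raising on a missing module. A small sketch of the idiom (hypothetical `_actual_get`):

```python
named_schedulers = {}  # Emscripten-like: "processes" was never registered
MP_GET = named_schedulers.get("processes", object())

def _actual_get(dsk, keys, **kwargs):  # whichever scheduler compute() resolved
    ...

# Always True when multiprocessing is unavailable, so the conservative
# locking branch is taken rather than crashing.
assert _actual_get is not MP_GET
```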
8 changes: 6 additions & 2 deletions dask/delayed.py
@@ -7,11 +7,12 @@
 
 from tlz import concat, curry, merge, unique
 
-from dask import config, threaded
+from dask import config
 from dask.base import (
     DaskMethodsMixin,
     dont_optimize,
     is_dask_collection,
+    named_schedulers,
     replace_name_in_key,
 )
 from dask.base import tokenize as _tokenize
@@ -23,6 +24,9 @@
 __all__ = ["Delayed", "delayed"]
 
 
+DEFAULT_GET = named_schedulers.get("threads", named_schedulers["sync"])
+
+
 def unzip(ls, nout):
     """Unzip a list of lists into ``nout`` outputs."""
     out = list(zip(*ls))
@@ -518,7 +522,7 @@ def __dask_layers__(self):
     def __dask_tokenize__(self):
         return self.key
 
-    __dask_scheduler__ = staticmethod(threaded.get)
+    __dask_scheduler__ = staticmethod(DEFAULT_GET)
     __dask_optimize__ = globalmethod(optimize, key="delayed_optimize")
 
     def __dask_postcompute__(self):
5 changes: 4 additions & 1 deletion dask/local.py
@@ -106,7 +106,10 @@
 See the function ``inline_functions`` for more information.
 """
+from __future__ import annotations
+
 import os
+from collections.abc import Hashable, Mapping, Sequence
 from concurrent.futures import Executor, Future
 from functools import partial
 from queue import Empty, Queue
@@ -545,7 +548,7 @@ def submit(self, fn, *args, **kwargs):
 synchronous_executor = SynchronousExecutor()
 
 
-def get_sync(dsk, keys, **kwargs):
+def get_sync(dsk: Mapping, keys: Sequence[Hashable] | Hashable, **kwargs):
     """A naive synchronous version of get_async
 
     Can be useful for debugging.
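
The `from __future__ import annotations` addition is what lets the new `Sequence[Hashable] | Hashable` annotation load on Python 3.8/3.9: PEP 604 `X | Y` unions only evaluate at runtime on 3.10+, but postponed evaluation keeps annotations as strings. A self-contained illustration:

```python
from __future__ import annotations  # must precede all other statements

from collections.abc import Hashable, Mapping, Sequence

# Without the future import, evaluating this annotation raises TypeError
# ("unsupported operand type(s) for |") on Python < 3.10 at import time.
def get_sync(dsk: Mapping, keys: Sequence[Hashable] | Hashable, **kwargs):
    """Mirror of the new dask.local.get_sync signature."""
```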
5 changes: 3 additions & 2 deletions dask/multiprocessing.py
@@ -7,6 +7,7 @@
 import pickle
 import sys
 import traceback
+from collections.abc import Hashable, Mapping, Sequence
 from concurrent.futures import ProcessPoolExecutor
 from functools import partial
 from warnings import warn
@@ -143,8 +144,8 @@ def get_context():
 
 
 def get(
-    dsk,
-    keys,
+    dsk: Mapping,
+    keys: Sequence[Hashable] | Hashable,
     num_workers=None,
     func_loads=None,
     func_dumps=None,
57 changes: 57 additions & 0 deletions dask/tests/test_base.py
@@ -1,5 +1,6 @@
 import dataclasses
 import datetime
+import inspect
 import os
 import subprocess
 import sys
@@ -1543,3 +1544,59 @@ def __dask_optimize__(cls, dsk, keys, **kwargs):
     )[0]
     assert optimized
     da.utils.assert_eq(x, result)
+
+
+# A function designed to be run in a subprocess with dask.compatibility._EMSCRIPTEN
+# patched. This allows for checking for different default schedulers depending on the
+# platform. One might prefer patching `sys.platform` for a more direct test, but that
+# causes problems in other libraries.
+def check_default_scheduler(module, collection, expected, emscripten):
+    from contextlib import nullcontext
+    from unittest import mock
+
+    from dask.local import get_sync
+
+    if emscripten:
+        ctx = mock.patch("dask.base.named_schedulers", {"sync": get_sync})
+    else:
+        ctx = nullcontext()
+    with ctx:
+        import importlib
+
+        if expected == "sync":
+            from dask.local import get_sync as get
+        elif expected == "threads":
+            from dask.threaded import get
+        elif expected == "processes":
+            from dask.multiprocessing import get
+
+        mod = importlib.import_module(module)
+
+        assert getattr(mod, collection).__dask_scheduler__ == get
+
+
+@pytest.mark.parametrize(
+    "params",
+    (
+        "'dask.dataframe', '_Frame', 'sync', True",
+        "'dask.dataframe', '_Frame', 'threads', False",
+        "'dask.array', 'Array', 'sync', True",
+        "'dask.array', 'Array', 'threads', False",
+        "'dask.bag', 'Bag', 'sync', True",
+        "'dask.bag', 'Bag', 'processes', False",
+    ),
+)
+def test_emscripten_default_scheduler(params):
+    pytest.importorskip("dask.array")
+    pytest.importorskip("dask.dataframe")
+    proc = subprocess.run(
+        [
+            sys.executable,
+            "-c",
+            (
+                inspect.getsource(check_default_scheduler)
+                + f"check_default_scheduler({params})\n"
+            ),
+        ]
+    )
+    proc.check_returncode()
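
The trick here is that each parametrized case runs in a pristine interpreter: `inspect.getsource` turns the checker into plain text, and `python -c` executes it before any dask module has been imported in that process. The same technique in miniature (hypothetical `probe` function):

```python
import inspect
import subprocess
import sys

def probe():
    import sys
    # Nothing from the parent test process leaks into the child interpreter.
    print("dask preloaded?", "dask" in sys.modules)

code = inspect.getsource(probe) + "probe()\n"
subprocess.run([sys.executable, "-c", code], check=True)
```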
14 changes: 11 additions & 3 deletions dask/threaded.py
@@ -10,6 +10,7 @@
 import sys
 import threading
 from collections import defaultdict
+from collections.abc import Hashable, Mapping, Sequence
 from concurrent.futures import Executor, ThreadPoolExecutor
 from threading import Lock, current_thread
 
@@ -32,15 +33,22 @@ def pack_exception(e, dumps):
     return e, sys.exc_info()[2]
 
 
-def get(dsk, result, cache=None, num_workers=None, pool=None, **kwargs):
+def get(
+    dsk: Mapping,
+    keys: Sequence[Hashable] | Hashable,
+    cache=None,
+    num_workers=None,
+    pool=None,
+    **kwargs,
+):
     """Threaded cached implementation of dask.get
 
     Parameters
     ----------
 
     dsk: dict
         A dask dictionary specifying a workflow
-    result: key or list of keys
+    keys: key or list of keys
         Keys corresponding to desired data
     num_workers: integer of thread count
         The number of threads to use in the ThreadPool that will actually execute tasks
@@ -82,7 +90,7 @@ def get(dsk, result, cache=None, num_workers=None, pool=None, **kwargs):
         pool.submit,
         pool._max_workers,
         dsk,
-        result,
+        keys,
         cache=cache,
         get_id=_thread_get_id,
         pack_exception=pack_exception,