Skip to content

Commit

Permalink
Fix error in Dask-on-Mars when compute multiple objects (#2348)
Browse files Browse the repository at this point in the history
  • Loading branch information
loopyme committed Aug 18, 2021
1 parent c4a424c commit ec9a854
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 11 deletions.
9 changes: 7 additions & 2 deletions mars/contrib/dask/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.

from dask import is_dask_collection, optimize
from dask.bag import Bag

from .scheduler import mars_dask_get
from .utils import reduce
from ...remote import spawn


def convert_dask_collection(dc):
Expand Down Expand Up @@ -46,5 +48,8 @@ def convert_dask_collection(dc):
else:
raise ValueError(
f"Dask collection object seems be broken, with unexpected key type:'{type(first_key).__name__}'")

return reduce(mars_dask_get(dsk, [key]))
res = reduce(mars_dask_get(dsk, [key]))
if isinstance(dc, Bag):
return spawn(lambda x: list(x[0][0]), args=(res,))
else:
return res
10 changes: 6 additions & 4 deletions mars/contrib/dask/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

from dask.core import istask, ishashable

from typing import List, Tuple
from .utils import reduce
from ...remote import spawn

Expand All @@ -39,8 +38,11 @@ def mars_scheduler(dsk: dict, keys: List[List[str]]):
Object
Computed values corresponding to the provided keys.
"""

return [[reduce(mars_dask_get(dsk, keys)).execute().fetch()]]
res = reduce(mars_dask_get(dsk, keys)).execute().fetch()
if not isinstance(res, List):
return [[res]]
else:
return res


def mars_dask_get(dsk: dict, keys: List[List]):
Expand Down
19 changes: 16 additions & 3 deletions mars/contrib/dask/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,25 @@ def test_bag(setup_cluster):

dask_res = result.compute()
assert dask_res == result.compute(scheduler=mars_scheduler)
assert dask_res == list(
convert_dask_collection(result).execute().fetch()
) # TODO: dask-bag computation will return weird tuple, which we don't know why
assert dask_res == convert_dask_collection(result).execute().fetch()


@pytest.mark.skipif(not dask_installed, reason='dask not installed')
def test_dask_errors():
with pytest.raises(TypeError):
convert_dask_collection({"foo": 0, "bar": 1})


@pytest.mark.skipif(not dask_installed, reason='dask not installed')
def test_multiple_objects(setup_cluster):
import dask

def inc(x: int):
return x + 1

test_list = [dask.delayed(inc)(i) for i in range(10)]
test_tuple = tuple(dask.delayed(inc)(i) for i in range(10))
test_dict = {str(i): dask.delayed(inc)(i) for i in range(10)}

for test_obj in (test_list, test_tuple, test_dict):
assert dask.compute(test_obj) == dask.compute(test_obj, scheduler=mars_scheduler)
2 changes: 0 additions & 2 deletions mars/contrib/dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ def concat(objs: List):
res = df_concat(objs)
else:
res = objs
while isinstance(res, List):
res = res[0]
return res.compute() if is_dask_collection(res) else res


Expand Down

0 comments on commit ec9a854

Please sign in to comment.