Skip to content

Commit

Permalink
Support concatenate for scipy.sparse in dask array
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Aug 1, 2018
1 parent 77af8f8 commit 40890d7
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
16 changes: 16 additions & 0 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,22 @@ def register_sparse():
tensordot_lookup.register(sparse.COO, sparse.tensordot)


@concatenate_lookup.register_lazy('scipy')
def register_scipy_sparse():
import scipy.sparse

def _concatenate(L, axis=0):
if axis == 0:
return scipy.sparse.vstack(L)
elif axis == 1:
return scipy.sparse.hstack(L)
else:
msg = ("Can only concatenate scipy sparse matrices for axis in "
"{0, 1}. Got %s" % axis)
raise ValueError(msg)
concatenate_lookup.register(scipy.sparse.spmatrix, _concatenate)


class PerformanceWarning(Warning):
""" A warning given when bad chunking may cause poor performance """

Expand Down
17 changes: 14 additions & 3 deletions dask/array/tests/test_array_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3620,11 +3620,22 @@ def test_blocks_indexer():


def test_dask_array_holds_scipy_sparse_containers():
sparse = pytest.importorskip('scipy.sparse')
pytest.importorskip('scipy.sparse')
import scipy.sparse
x = da.random.random((1000, 10), chunks=(100, 10))
x[x < 0.9] = 0
y = x.map_blocks(sparse.csr_matrix)
xx = x.compute()
y = x.map_blocks(scipy.sparse.csr_matrix)

vs = y.to_delayed().flatten().tolist()
values = dask.compute(*vs, scheduler='single-threaded')
assert all(isinstance(v, sparse.csr_matrix) for v in values)
assert all(isinstance(v, scipy.sparse.csr_matrix) for v in values)

yy = y.compute(scheduler='single-threaded')
assert isinstance(yy, scipy.sparse.spmatrix)
assert (yy == xx).all()

z = x.T.map_blocks(scipy.sparse.csr_matrix)
zz = z.compute(scheduler='single-threaded')
assert isinstance(yy, scipy.sparse.spmatrix)
assert (zz == xx.T).all()

0 comments on commit 40890d7

Please sign in to comment.