Skip to content

Commit

Permalink
add dask.array.compress
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Nov 30, 2015
1 parent c4df983 commit 8e06b9e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dask/array/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from_array, from_imperative, choose, where, coarsen, insert,
broadcast_to, ravel, reshape, fromfunction, unique, store, squeeze,
topk, bincount, histogram, map_blocks, atop, to_hdf5, dot, cov, array,
to_npy_stack, from_npy_stack)
to_npy_stack, from_npy_stack, compress)
from .core import (logaddexp, logaddexp2, conj, exp, log, log2, log10, log1p,
expm1, sqrt, square, sin, cos, tan, arcsin, arccos, arctan, arctan2,
hypot, sinh, cosh, tanh, arcsinh, arccosh, arctanh, deg2rad, rad2deg,
Expand Down
22 changes: 22 additions & 0 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1652,6 +1652,28 @@ def take(a, indices, axis=0):
return a[(slice(None),) * axis + (indices,)]


@wraps(np.compress)
def compress(condition, a, axis=None):
if axis is None:
raise NotImplementedError("Must select axis for compression")
if not -a.ndim <= axis < a.ndim:
raise ValueError('axis=(%s) out of bounds' % axis)
if axis < 0:
axis += a.ndim

condition = np.array(condition, dtype=bool)
if condition.ndim != 1:
raise ValueError("Condition must be one dimensional")
if len(condition) < a.shape[axis]:
condition = condition.copy()
condition.resize(a.shape[axis])

slc = ((slice(None),) * axis
+ (condition,)
+ (slice(None),) * (a.ndim - axis - 1))
return a[slc]


def _take_dask_array_from_numpy(a, indices, axis):
assert isinstance(a, np.ndarray)
assert isinstance(indices, Array)
Expand Down
19 changes: 19 additions & 0 deletions dask/array/tests/test_array_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,25 @@ def test_take():
assert same_keys(take(a, [3, 4, 5], axis=-1), take(a, [3, 4, 5], axis=-1))


def test_compress():
x = np.arange(25).reshape((5, 5))
a = from_array(x, chunks=(2, 2))

assert eq(np.compress([True, False, True, False, True], x, axis=0),
da.compress([True, False, True, False, True], a, axis=0))
assert eq(np.compress([True, False, True, False, True], x, axis=1),
da.compress([True, False, True, False, True], a, axis=1))
assert eq(np.compress([True, False], x, axis=1),
da.compress([True, False], a, axis=1))

with pytest.raises(NotImplementedError):
da.compress([True, False], a)
with pytest.raises(ValueError):
da.compress([True, False], a, axis=100)
with pytest.raises(ValueError):
da.compress([[True], [False]], a, axis=100)


def test_binops():
a = Array(dict((('a', i), np.array([''])) for i in range(3)),
'a', chunks=((1, 1, 1),))
Expand Down

0 comments on commit 8e06b9e

Please sign in to comment.