Skip to content

Commit

Permalink
Vendor indices based on our Dask contribution
Browse files Browse the repository at this point in the history
This is just our code that we contributed to Dask. So there is no issue
with us vendoring it here. That said, we use a BSD 3-Clause license just
like Dask. So if there were any issue, we are basically using the same
license as well. This is just borrowed from the same vendoring in
`dask-ndmeasure`.

ref: dask-image/dask-ndmeasure@3bfa650
  • Loading branch information
jakirkham committed Sep 28, 2017
1 parent 71a6ecb commit 6840551
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 0 deletions.
63 changes: 63 additions & 0 deletions dask_distance/_compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-


import itertools

import numpy

import dask
Expand All @@ -26,3 +28,64 @@ def _asarray(a):
a = dask.array.from_array(a, a.shape)

return a


def _indices(dimensions, dtype=int, chunks=None):
"""
Implements NumPy's ``indices`` for Dask Arrays.
Generates a grid of indices covering the dimensions provided.
The final array has the shape ``(len(dimensions), *dimensions)``. The
chunks are used to specify the chunking for axis 1 up to
``len(dimensions)``. The 0th axis always has chunks of length 1.
Parameters
----------
dimensions : sequence of ints
The shape of the index grid.
dtype : dtype, optional
Type to use for the array. Default is ``int``.
chunks : sequence of ints
The number of samples on each block. Note that the last block will
have fewer samples if ``len(array) % chunks != 0``.
Returns
-------
grid : dask array
Notes
-----
Borrowed from my Dask Array contribution.
"""
if chunks is None:
raise ValueError("Must supply a chunks= keyword argument")

dimensions = tuple(dimensions)
dtype = numpy.dtype(dtype)
chunks = tuple(chunks)

if len(dimensions) != len(chunks):
raise ValueError("Need one more chunk than dimensions.")

grid = []
if numpy.prod(dimensions):
for i in range(len(dimensions)):
s = len(dimensions) * [None]
s[i] = slice(None)
s = tuple(s)

r = dask.array.arange(dimensions[i], dtype=dtype, chunks=chunks[i])
r = r[s]

for j in itertools.chain(range(i), range(i + 1, len(dimensions))):
r = r.repeat(dimensions[j], axis=j)

grid.append(r)

if grid:
grid = dask.array.stack(grid)
else:
grid = dask.array.empty(
(len(dimensions),) + dimensions, dtype=dtype, chunks=(1,) + chunks
)

return grid
61 changes: 61 additions & 0 deletions tests/test__compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,22 @@
# -*- coding: utf-8 -*-


import distutils.version as ver

import pytest

import numpy as np

import dask
import dask.array as da
import dask.array.utils as dau

import dask_distance._compat


old_dask = ver.LooseVersion(dask.__version__) <= ver.LooseVersion("0.13.0")


@pytest.mark.parametrize("x", [
list(range(5)),
np.random.randint(10, size=(15, 16)),
Expand All @@ -26,3 +32,58 @@ def test_asarray(x):
x = np.asarray(x)

dau.assert_eq(d, x)


def test_indices_no_chunks():
with pytest.raises(ValueError):
dask_distance._compat._indices((1,))


def test_indices_wrong_chunks():
with pytest.raises(ValueError):
dask_distance._compat._indices((1,), chunks=tuple())


@pytest.mark.parametrize(
"dimensions, dtype, chunks",
[
(tuple(), int, tuple()),
(tuple(), float, tuple()),
((0,), float, (1,)),
((0, 1, 2), float, (1, 1, 2)),
]
)
def test_empty_indicies(dimensions, dtype, chunks):
darr = dask_distance._compat._indices(dimensions, dtype, chunks=chunks)
nparr = np.indices(dimensions, dtype)

assert darr.shape == nparr.shape
assert darr.dtype == nparr.dtype

try:
dau.assert_eq(darr, nparr)
except IndexError:
if len(dimensions) and old_dask:
pytest.skip(
"Dask pre-0.14.0 is unable to compute this empty array."
)
else:
raise


def test_indicies():
darr = dask_distance._compat._indices((1,), chunks=(1,))
nparr = np.indices((1,))
dau.assert_eq(darr, nparr)

darr = dask_distance._compat._indices((1,), float, chunks=(1,))
nparr = np.indices((1,), float)
dau.assert_eq(darr, nparr)

darr = dask_distance._compat._indices((2, 1), chunks=(2, 1))
nparr = np.indices((2, 1))
dau.assert_eq(darr, nparr)

darr = dask_distance._compat._indices((2, 3), chunks=(1, 2))
nparr = np.indices((2, 3))
dau.assert_eq(darr, nparr)

0 comments on commit 6840551

Please sign in to comment.