Add support for __array_function__ protocol #4567

Merged: 15 commits, Mar 15, 2019
3 changes: 2 additions & 1 deletion .travis.yml
@@ -11,6 +11,7 @@ _base_envs:
- &no_optimize XTRATESTARGS=
- &imports TEST_IMPORTS='true'
- &no_imports TEST_IMPORTS='false'
- &array_function NUMPY_EXPERIMENTAL_ARRAY_FUNCTION='1'

jobs:
fast_finish: true
@@ -44,7 +45,7 @@ jobs:

- env: &py37_env
- PYTHON=3.7
- NUMPY=1.15.0
- NUMPY=1.16.2
- PANDAS>=0.24.1
- *test_and_lint
- *no_coverage
4 changes: 4 additions & 0 deletions continuous_integration/travis/run_tests.sh
@@ -13,3 +13,7 @@ else
echo "py.test dask --runslow $XTRATESTARGS"
py.test dask --runslow $XTRATESTARGS
fi

# This must be set to test the __array_function__ protocol with
# NumPy v1.16.x; the protocol is enabled by default starting in v1.17.
export NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1
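For anyone reproducing this locally: the opt-in is just an environment variable, and it must be exported before the test process starts (a sketch; the pytest invocation is illustrative):

```shell
# Opt in to __array_function__ on NumPy 1.16.x; on 1.17+ the protocol is
# on by default and exporting the variable is a harmless no-op.
export NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1
echo "opt-in set to: $NUMPY_EXPERIMENTAL_ARRAY_FUNCTION"
# py.test dask/array/tests/test_array_function.py --runslow   # illustrative
```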
16 changes: 16 additions & 0 deletions dask/array/core.py
@@ -1002,6 +1002,22 @@ def __array__(self, dtype=None, **kwargs):
        x = np.array(x)
        return x

    def __array_function__(self, func, types, args, kwargs):
        import dask.array as module
        for submodule in func.__module__.split('.')[1:]:
            try:
                module = getattr(module, submodule)
            except AttributeError:
                return NotImplemented
        if not hasattr(module, func.__name__):
            return NotImplemented
        da_func = getattr(module, func.__name__)
        if da_func is func:
            return NotImplemented
        if not any(issubclass(t, (Array, np.ndarray)) for t in types):
            return NotImplemented
Member:

So in many cases a dask array function will work with numpy objects (or presumably anything on which it can effectively call asarray). I suggest a test like np.tensordot(dask_array, numpy_array) and replacing the all above with an any.

Member:

The __array_function__ protocol guarantees that it will only get called if dask.array.Array is in types. So if that's all you wanted to check, you could drop this entirely.

The problem is that you want to support operations involving some, but not all, other array types. For example, xarray.DataArray wraps dask arrays, not the other way around. So dask.Array.__array_function__ should return NotImplemented in that case, leaving it up to xarray.DataArray.__array_function__ to define the operation.

We basically need some protocol or registration system to mark arrays as "can be coerced into a dask array", e.g., either

sparse.Array.__dask_chunk_compatible__ = True

or

dask.array.register_chunk_type(sparse.Array)
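The deferral mechanics being discussed can be shown with a dependency-free sketch that loosely mimics NumPy's dispatch loop; all class names here are toy stand-ins, not real dask or xarray types:

```python
# Toy re-implementation of the dispatch order NumPy uses: each relevant
# type gets a chance, and returning NotImplemented passes the ball on.
class DaskArray:
    @staticmethod
    def __array_function__(func, types, args, kwargs):
        # Defer whenever a type we don't know how to coerce is involved,
        # e.g. an xarray-like wrapper that holds dask arrays inside it.
        if not all(issubclass(t, DaskArray) for t in types):
            return NotImplemented
        return "handled by DaskArray"

class WrapperArray:  # wraps dask arrays, like xarray.DataArray
    @staticmethod
    def __array_function__(func, types, args, kwargs):
        return "handled by WrapperArray"

def dispatch(func, relevant_types, args, kwargs):
    # NumPy tries each relevant type's __array_function__ in turn and
    # raises TypeError only if every one of them defers.
    for t in relevant_types:
        result = t.__array_function__(func, relevant_types, args, kwargs)
        if result is not NotImplemented:
            return result
    raise TypeError("no implementation found")

assert dispatch(None, [DaskArray], (), {}) == "handled by DaskArray"
# With a wrapper in the mix, DaskArray defers and the wrapper wins:
assert dispatch(None, [DaskArray, WrapperArray], (), {}) == "handled by WrapperArray"
```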

Member:

Hrm, that's an interesting problem. Maybe we try to turn all of the args into dask arrays with da.asarray and, if that works, proceed?

I'm not sure that this would produce the correct behavior for xarray.DataArray, though. I genuinely don't know what the right behavior is there.

Contributor:

I actually think my idea for mixins can help with this. Specifically, we could check for the specific mixins that Dask absolutely requires; I'm assuming that'd be the indexing mixin.

Member Author:

> Hrm, that's an interesting problem. Maybe we try to turn all of the args into dask arrays with da.asarray and, if that works, proceed?

I actually like this idea for its simplicity.

If we want to be completely general, we could use hasattr(t, '__array_function__'). But this would also include Pandas DataFrames at some point, I'm assuming...

I also like the mixin idea in general, but I think it would then be best to first agree on all the mixins we expect to have/support. Plus, if I understood correctly, it feels like this will require much more type checking, which may become complicated to handle.
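The duck-typing check floated here is tiny to sketch (the helper name is hypothetical); it also shows why the check would admit anything exposing the protocol, pandas included once it implements it:

```python
# Hypothetical helper: treat any type exposing the protocol as a
# candidate duck array. Note this says nothing about whether dask can
# actually coerce instances of that type into chunks.
def implements_array_function(t):
    return hasattr(t, '__array_function__')

class DuckArray:
    def __array_function__(self, func, types, args, kwargs):
        return NotImplemented

class NotAnArray:
    pass

assert implements_array_function(DuckArray)
assert not implements_array_function(NotAnArray)
```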

Member:

I think that would only help to address the second point above, and I'm not sure it would solve it. As a worst case scenario, I suspect that a project like TensorFlow would not inherit from those classes. We could get NumPy, Dask, Sparse, and Xarray on board, but I'm not sure about any of the others out there.

I personally think that duck typing is more likely to work across projects than explicit typing.

Mixins may still be a good idea for the reason you mention (you can easily implement a lot of NumPy-like functionality without too much effort), but I don't think they're a good choice for testing NumPy-ness.

Contributor:

I do agree that it's hard to get traction for mixins or to add dependencies, but adding a protocol isn't hard.

While you're right that TensorFlow doesn't plan on implementing any of these protocols, it has always historically done its own thing, and it has its own distributed system. I suspect that with Google's historical "rewrite the world and do it better" philosophy, it'll be hard to get them to play along anyway.

In contrast, anyone specifically aiming to be a NumPy duck array won't have a problem implementing protocols, and that is what I'm planning to do: allow mixins to implement methods and so forth based on the implementation of protocols.

Member Author:

For the time being, I've used @shoyer's and @mrocklin's suggestions, replacing only that last condition in __array_function__ with:

if not all(issubclass(t, Array) for t in types):
    return NotImplemented

It seems to me that the mixin discussion has grown enough that we should move it to a Dask issue, discuss it there, and review the current code later. Does anybody disagree with my proposal?

Member Author:

Since nobody opposed, I moved the discussion to #4583.

Member:

I think that with @shoyer's statement below we can drop this conditional entirely; it should always pass:

> The __array_function__ protocol guarantees that it will only get called if dask.array.Array is in types. So if that's all you wanted to check, you could drop this entirely.
        return da_func(*args, **kwargs)

    @property
    def _elemwise(self):
        return elemwise
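The lookup in the method above mirrors the NumPy namespace inside dask.array: func.__module__ (e.g. numpy.linalg), minus its leading package, is walked under dask.array and the function's name is looked up there. A dependency-free sketch of the same walk, using stand-in namespaces rather than the real modules:

```python
from types import SimpleNamespace

# Stand-ins: a "numpy.linalg.norm"-like function on one side and a
# mirrored "dask.array.linalg" namespace on the other (the real code
# starts the walk from the dask.array module itself).
def np_norm(x):
    return "numpy result"
np_norm.__module__ = 'numpy.linalg'
np_norm.__name__ = 'norm'

da_namespace = SimpleNamespace(
    linalg=SimpleNamespace(norm=lambda *a, **k: "dask result"))

def lookup(func, root):
    """Walk func.__module__ (minus its leading package) under root."""
    module = root
    for submodule in func.__module__.split('.')[1:]:
        module = getattr(module, submodule, None)
        if module is None:
            return None  # the diff returns NotImplemented at this point
    return getattr(module, func.__name__, None)

da_func = lookup(np_norm, da_namespace)
assert da_func is not None
assert da_func() == "dask result"
```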
95 changes: 95 additions & 0 deletions dask/array/tests/test_array_function.py
@@ -0,0 +1,95 @@
import pytest
np = pytest.importorskip('numpy', minversion='1.16')

import os

import dask.array as da
from dask.array.utils import assert_eq


env_name = "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION"
missing_arrfunc_cond = env_name not in os.environ or os.environ[env_name] != "1"
missing_arrfunc_reason = env_name + " undefined or disabled"


@pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason)
@pytest.mark.parametrize('func', [
    lambda x: np.concatenate([x, x, x]),
    lambda x: np.cov(x, x),
    lambda x: np.dot(x, x),
    lambda x: np.dstack(x),
    lambda x: np.flip(x, axis=0),
    lambda x: np.hstack(x),
    lambda x: np.matmul(x, x),
    lambda x: np.mean(x),
    lambda x: np.stack([x, x]),
    lambda x: np.sum(x),
    lambda x: np.var(x),
    lambda x: np.vstack(x),
    lambda x: np.fft.fft(x.rechunk(x.shape) if isinstance(x, da.Array) else x),
    lambda x: np.fft.fft2(x.rechunk(x.shape) if isinstance(x, da.Array) else x),
    lambda x: np.linalg.norm(x)])
def test_array_function_dask(func):
    x = np.random.random((100, 100))
    y = da.from_array(x, chunks=(50, 50))
    res_x = func(x)
    res_y = func(y)

    assert isinstance(res_y, da.Array)
    assert_eq(res_y, res_x)


@pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason)
@pytest.mark.parametrize('func', [
    lambda x: np.min_scalar_type(x),
    lambda x: np.linalg.det(x),
    lambda x: np.linalg.eigvals(x)])
def test_array_notimpl_function_dask(func):
    x = np.random.random((100, 100))
    y = da.from_array(x, chunks=(50, 50))

    with pytest.raises(TypeError):
        func(y)


@pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason)
def test_array_function_sparse_transpose():
    sparse = pytest.importorskip('sparse')
    x = da.random.random((500, 500), chunks=(100, 100))
    x[x < 0.9] = 0

    y = x.map_blocks(sparse.COO)

    assert_eq(np.transpose(x), np.transpose(y))


@pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason)
@pytest.mark.xfail(reason="requires sparse support for __array_function__",
                   strict=False)
def test_array_function_sparse_tensordot():
    sparse = pytest.importorskip('sparse')
    x = np.random.random((2, 3, 4))
    x[x < 0.9] = 0
    y = np.random.random((4, 3, 2))
    y[y < 0.9] = 0

    xx = sparse.COO(x)
    yy = sparse.COO(y)

    assert_eq(np.tensordot(x, y, axes=(2, 0)),
              np.tensordot(xx, yy, axes=(2, 0)).todense())


@pytest.mark.skipif(missing_arrfunc_cond, reason=missing_arrfunc_reason)
def test_array_function_cupy_svd():
    cupy = pytest.importorskip('cupy')
    x = cupy.random.random((500, 100))

    y = da.from_array(x, chunks=(100, 100), asarray=False)

    u_base, s_base, v_base = da.linalg.svd(y)
    u, s, v = np.linalg.svd(y)

    assert_eq(u, u_base)
    assert_eq(s, s_base)
    assert_eq(v, v_base)
20 changes: 10 additions & 10 deletions dask/array/tests/test_reductions.py
@@ -473,30 +473,30 @@ def test_topk_argtopk1(npfunc, daskfunc, split_every):

# 1-dimensional arrays
# top 5 elements, sorted descending
assert_eq(npfunc(a)[-k:][::-1],
assert_eq(npfunc(a.compute())[-k:][::-1],
daskfunc(a, k, split_every=split_every))
# bottom 5 elements, sorted ascending
assert_eq(npfunc(a)[:k],
assert_eq(npfunc(a.compute())[:k],
daskfunc(a, -k, split_every=split_every))

# n-dimensional arrays
# also testing when k > chunk
# top 5 elements, sorted descending
assert_eq(npfunc(b, axis=0)[-k:, :, :][::-1, :, :],
assert_eq(npfunc(b.compute(), axis=0)[-k:, :, :][::-1, :, :],
daskfunc(b, k, axis=0, split_every=split_every))
assert_eq(npfunc(b, axis=1)[:, -k:, :][:, ::-1, :],
assert_eq(npfunc(b.compute(), axis=1)[:, -k:, :][:, ::-1, :],
daskfunc(b, k, axis=1, split_every=split_every))
assert_eq(npfunc(b, axis=-1)[:, :, -k:][:, :, ::-1],
assert_eq(npfunc(b.compute(), axis=-1)[:, :, -k:][:, :, ::-1],
daskfunc(b, k, axis=-1, split_every=split_every))
with pytest.raises(ValueError):
daskfunc(b, k, axis=3, split_every=split_every)

# bottom 5 elements, sorted ascending
assert_eq(npfunc(b, axis=0)[:k, :, :],
assert_eq(npfunc(b.compute(), axis=0)[:k, :, :],
daskfunc(b, -k, axis=0, split_every=split_every))
assert_eq(npfunc(b, axis=1)[:, :k, :],
assert_eq(npfunc(b.compute(), axis=1)[:, :k, :],
daskfunc(b, -k, axis=1, split_every=split_every))
assert_eq(npfunc(b, axis=-1)[:, :, :k],
assert_eq(npfunc(b.compute(), axis=-1)[:, :, :k],
daskfunc(b, -k, axis=-1, split_every=split_every))
with pytest.raises(ValueError):
daskfunc(b, -k, axis=3, split_every=split_every)
@@ -514,10 +514,10 @@ def test_topk_argtopk2(npfunc, daskfunc, split_every, chunksize):
k = 5

# top 5 elements, sorted descending
assert_eq(npfunc(a)[-k:][::-1],
assert_eq(npfunc(a.compute())[-k:][::-1],
daskfunc(a, k, split_every=split_every))
# bottom 5 elements, sorted ascending
assert_eq(npfunc(a)[:k],
assert_eq(npfunc(a.compute())[:k],
daskfunc(a, -k, split_every=split_every))


2 changes: 1 addition & 1 deletion dask/array/tests/test_routines.py
@@ -502,7 +502,7 @@ def test_bincount_with_weights():

dweights = da.from_array(weights, chunks=2)
e = da.bincount(d, weights=dweights, minlength=6)
assert_eq(e, np.bincount(x, weights=dweights, minlength=6))
assert_eq(e, np.bincount(x, weights=dweights.compute(), minlength=6))
Member:

Why was this change needed?

Member Author:

Because here the right argument of assert_eq() is used as the baseline, and with x being a NumPy array, it fails here:

    @wraps(np.bincount)
    def bincount(x, weights=None, minlength=None):
        if minlength is None:
            raise TypeError("Must specify minlength argument in da.bincount")
        assert x.ndim == 1
        if weights is not None:
>           assert weights.chunks == x.chunks
E           AttributeError: 'numpy.ndarray' object has no attribute 'chunks'

Therefore, to me it makes sense to make both arguments NumPy arrays and let the call fall back to NumPy's implementation.

Contributor:

I think the reverse should be done, with weights being coerced to da.Array with the same chunks as x.

Member:

Ah, I see. I think I agree with @pentschev here. The left argument, e, is testing the dask/dask situation, while the right argument is testing the numpy/numpy situation.

Member Author:

> I think the reverse should be done, with weights being coerced to da.Array with the same chunks as x.

That would cause da.bincount() to be dispatched, and then we would be testing against e, which computes exactly that.

assert same_keys(da.bincount(d, weights=dweights, minlength=6), e)
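The convention being defended above, lazy result on the left and eager baseline on the right, can be sketched with a toy assert_eq-like helper (all names here are hypothetical; dask's real assert_eq lives in dask.array.utils):

```python
class Lazy:
    """Toy stand-in for a dask collection: holds a thunk until computed."""
    def __init__(self, thunk):
        self._thunk = thunk

    def compute(self):
        return self._thunk()

def toy_assert_eq(result, baseline):
    # Compute the (possibly lazy) left side; the right side is the eager
    # ground truth, so passing a lazy object there would defeat the test.
    if hasattr(result, 'compute'):
        result = result.compute()
    assert result == baseline, (result, baseline)

weights = [1, 2, 3]
e = Lazy(lambda: sum(weights))   # the "dask" path under test
toy_assert_eq(e, sum(weights))   # eager baseline stays on the right
```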


17 changes: 14 additions & 3 deletions dask/array/tests/test_ufunc.py
@@ -56,10 +56,10 @@ def test_ufunc():
unary_ufuncs = ['absolute', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan',
'arctanh', 'bitwise_not', 'cbrt', 'ceil', 'conj', 'cos',
'cosh', 'deg2rad', 'degrees', 'exp', 'exp2', 'expm1', 'fabs',
'fix', 'floor', 'i0', 'invert','isfinite', 'isinf', 'isnan', 'log',
'log10', 'log1p', 'log2', 'logical_not', 'nan_to_num',
'fix', 'floor', 'invert','isfinite', 'isinf', 'isnan', 'log',
'log10', 'log1p', 'log2', 'logical_not',
'negative', 'rad2deg', 'radians', 'reciprocal', 'rint', 'sign',
'signbit', 'sin', 'sinc', 'sinh', 'spacing', 'sqrt', 'square',
'signbit', 'sin', 'sinh', 'spacing', 'sqrt', 'square',
'tan', 'tanh', 'trunc']


@@ -276,6 +276,17 @@ def test_issignedinf():
assert_eq(np.isposinf(arr), da.isposinf(darr))


@pytest.mark.parametrize('func', ['i0', 'sinc', 'nan_to_num'])
def test_non_ufunc_others(func):
    arr = np.random.randint(1, 100, size=(20, 20))
    darr = da.from_array(arr, 3)

    dafunc = getattr(da, func)
    npfunc = getattr(np, func)

    assert_eq(dafunc(darr), npfunc(arr), equal_nan=True)


def test_frompyfunc():
    myadd = da.frompyfunc(add, 2, 1)
    np_myadd = np.frompyfunc(add, 2, 1)