
use __array_function__ on functions outside numpy #13872

Open
rgommers opened this issue Jun 30, 2019 · 7 comments

@rgommers
Member

rgommers commented Jun 30, 2019

I think this is a bug, but I'm not 100% sure. The ndarray.__array_function__ implementation seems special: it's not recognized when applying array_function_dispatch to a function outside of NumPy (which NEP 18 suggests is possible).

To try this, I added the following lines to PyTorch at the end of torch/__init__.py:

import numpy as _np  # added here for completeness; not in torch/__init__.py by default

def _sum_dispatcher(input, dtype=None):
    return (input, dtype)

_sum = sum  # keep a reference to the original torch.sum
@_np.core.overrides.array_function_dispatch(_sum_dispatcher)
def sum(input, dtype=None):
    return _sum(input)  # don't worry about the missing `dtype` here, that's a torch issue

Then, I run the following (on NumPy 1.16.4 with the NUMPY_EXPERIMENTAL_ARRAY_FUNCTION envvar enabled; need to rebuild to try master - EDIT: same on current master):

import numpy as np
import torch
import sparse
import dask.array

t = torch.Tensor([1, 2])
x = t.numpy()
s = sparse.as_coo(x)
d = dask.array.from_array(x)

print("Sum of tensor t: ", torch.sum(t))
print("Sum of dask array d: ", torch.sum(d))
# Okay, let's add a compute()
print("Sum of dask array d (evaluated): ", torch.sum(d).compute())
print("Sum of sparse array s: ", torch.sum(s))
print("Sum of ndarray x: ", torch.sum(x))

This gives:

Sum of tensor t:  tensor(3.)
Sum of dask array d:  dask.array<sum-aggregate, shape=(), dtype=float32, chunksize=()>
Sum of dask array d (evaluated):  3.0
Sum of sparse array s:  3.0
Traceback (most recent call last):
  File "try_torch_array_function.py", line 18, in <module>
    print("Sum of ndarray x: ", torch.sum(x))
  File "/Users/rgommers/anaconda3/envs/pytorch/lib/python3.7/site-packages/numpy/core/overrides.py", line 165, in public_api
    implementation, public_api, relevant_args, args, kwargs)
  File "/Users/rgommers/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/__init__.py", line 326, in sum
    return _sum(input)
TypeError: sum(): argument 'input' (position 1) must be Tensor, not numpy.ndarray

So it works fine with Dask and pydata/sparse, but fails with NumPy - the traceback indicates that the dispatch to numpy.sum is not happening at all. Not expected, I think?

@shoyer
Member

shoyer commented Jun 30, 2019

I think this is working as expected, though there is indeed something surprising going on here.

By using array_function_dispatch, you define a new function torch.sum() that checks for overrides using __array_function__ -- there is no relationship with numpy.sum.

How then could dask and sparse know how to handle torch.sum? It's because dask's __array_function__ method (and presumably sparse's, too) ignores the top-level module name, assuming that it is always 'numpy'.

I suppose this is arguably a bug in dask's __array_function__ implementation: it should return NotImplemented if the top-level module is not NumPy. Though at this point this would just be about future-proofing: we haven't exposed array_function_dispatch publicly yet.
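
A minimal sketch of that stricter check, assuming the handler can key off func.__module__ (dask's actual implementation differs; StrictArray is made up):

import numpy as np

class StrictArray:
    """Hypothetical array type that only honors overrides for real NumPy functions."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        # Future-proofing: refuse look-alike functions defined outside the
        # numpy namespace, e.g. a torch.sum decorated with array_function_dispatch.
        module = getattr(func, "__module__", None) or ""
        if not (module == "numpy" or module.startswith("numpy.")):
            return NotImplemented
        if func is np.sum:
            return np.sum(self.data, **kwargs)
        return NotImplemented

With this, np.sum(StrictArray([1, 2])) still works, while the torch.sum defined above gets NotImplemented back and fails with the protocol's "no implementation found" TypeError instead of hitting torch internals.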

@rgommers
Member Author

I suppose this is arguably a bug in dask's __array_function__ implementation

That implementation would make dask.fft work with both numpy.fft and scipy.fft if SciPy decided to use __array_function__ for that module (just an example, not very likely at this point). It would also allow multiple compatible top-level namespaces to implement the same mechanism and interoperate. So there is something to be said for the approach Dask and Sparse take right now.

On the one hand, it could fragment the API landscape again. On the other hand, if PyTorch for example would like to do this and numpy behavior stays as it is now, it would need to create its own __torch_function__ protocol (which is under discussion). I'm not yet sure which one would be preferable. What do you think?
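
For contrast, a rough sketch of the permissive approach Dask and Sparse take today (not their actual code): resolve the incoming function by name in the library's own namespace and ignore which module defined it. The _implementations registry and PermissiveArray are made up:

import numpy as np

class PermissiveArray:
    def __init__(self, data):
        self.data = np.asarray(data)

    # Hypothetical name-based registry of this library's implementations.
    _implementations = {
        "sum": lambda arr, *a, **kw: np.sum(arr.data, *a, **kw),
    }

    def __array_function__(self, func, types, args, kwargs):
        # Dispatch purely on the function's name: numpy.sum, torch.sum, or a
        # hypothetical scipy-flavored sum would all land on the same handler.
        impl = type(self)._implementations.get(func.__name__)
        if impl is None:
            return NotImplemented
        return impl(args[0], *args[1:], **kwargs)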

By using array_function_dispatch, you define a new function torch.sum() that checks for overrides using __array_function__ -- there is no relationship with numpy.sum.

Indeed. That will always be the case when using array_function_dispatch outside NumPy. Your previous comments in NEP 18 and on the SciPy roadmap made me think this was intended. Or were you thinking that it's intended only for functions that don't overlap in name with anything in NumPy?

@shoyer
Member

shoyer commented Jun 30, 2019

On the other hand, if PyTorch for example would like to do this and numpy behavior stays as it is now, it would need to create its own __torch_function__ protocol (which is under discussion). I'm not yet sure which one would be preferable. What do you think?

I don't think it particularly matters in most cases. As long as the design is consistent, there's not much overhead in adding another special method.

That said, we did go to a lot of trouble in NumPy to optimize this (e.g., all the dispatching logic is written in C), and it seems needless to duplicate that, particularly for SciPy, which is so closely tied to NumPy anyway. So I would lean towards encouraging the use of __array_function__ without assuming that everything necessarily comes from NumPy.

With regards to libraries implementing arrays, torch.sum() and numpy.sum() are likely mostly but not entirely interchangeable. For example, they might differ in optional keyword arguments. So it seems pretty error prone to always use the same implementation for both. But obviously it's up to the implementing library (e.g., dask or sparse).

Also, clearly NumPy is not the ultimate array API -- it works, but if we were starting from scratch we might make different decisions. I would rather not use de-facto standards to enforce NumPy's API choices in __array_function__.

@shoyer
Member

shoyer commented Jun 30, 2019

By using array_function_dispatch, you define a new function torch.sum() that checks for overrides using __array_function__ -- there is no relationship with numpy.sum.

Indeed. That will always be the case when using array_function_dispatch outside NumPy. Your previous comments in NEP 18 and on the SciPy roadmap made me think this was intended. Or were you thinking that it's intended only for functions that don't overlap in name with anything in NumPy?

Yes, this is something I was thinking about in general, but we hadn't really thought about what it would look like in practice. It was tricky enough to figure things out for internal use in NumPy, so we didn't expose array_function_dispatch as a public API yet.

That said, we're not going to change its internal location for 1.17, so if we do make it public in 1.18 (e.g., as np.array_function_dispatch), you would also be safe to use it as np.core.overrides.array_function_dispatch for the 1.17 series.
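
For a downstream library, that would presumably look like an import dance along these lines (the top-level np.array_function_dispatch location is hypothetical at this point):

try:
    # Hypothetical public location, if it lands in 1.18 as suggested above.
    from numpy import array_function_dispatch
except ImportError:
    try:
        # Private but unchanged location for the 1.17 series.
        from numpy.core.overrides import array_function_dispatch
    except ImportError:
        # Older NumPy: vendor a copy of the decorator, or disable the feature.
        array_function_dispatch = None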

@rgommers
Member Author

That said, we're not going to change its internal location for 1.17, so if we do make it public in 1.18 (e.g., as np.array_function_dispatch), you would also be safe to use it as np.core.overrides.array_function_dispatch for the 1.17 series.

Given that any library interested in using it is going to want to support at least a few NumPy versions, it would need to vendor it anyway until NumPy 1.20 or so. So saying np.core.overrides is private is perfectly fine. It would be useful to rename it to _overrides, by the way, although at this point we've got so many mostly-private-but-not-underscored submodules that it doesn't really matter.

So I would lean towards encouraging using __array_function__ without assuming that everything necessarily comes from NumPy.

Yes, I think I agree. That would then imply removing the domain check from ndarray.__array_function__.

With regards to libraries implementing arrays, torch.sum() and numpy.sum() are likely mostly but not entirely interchangeable.

True. I think for almost any function (and also ufuncs), the numpy versions will have more keywords than other libraries. As long as the positional arguments match, and keywords that are common mean the same thing, I think this is fine.

For example, they might differ in optional keyword arguments. So it seems pretty error prone to always use the same implementation for both. But obviously it's up to the implementing library (e.g., dask or sparse).

This is always the case, right, independent of this discussion? The API used and the executing library need to match for all keywords used. If today we do np.abs(dask_array, where=some_mask), then dask, which doesn't support where, should raise an error.
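
A minimal sketch of that failure mode (using np.sum rather than np.abs, since np.abs is a ufunc and dispatches via __array_ufunc__ instead of __array_function__; the MyArray type is made up):

import numpy as np

class MyArray:
    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func is np.sum:
            # Only a subset of np.sum's keywords is supported here; anything
            # else should fail loudly rather than be silently ignored.
            unsupported = set(kwargs) - {"axis", "dtype"}
            if unsupported:
                raise TypeError("np.sum: unsupported keyword(s) %r" % unsupported)
            return np.sum(self.data, **kwargs)
        return NotImplemented

m = MyArray([1, 2, 3])
print(np.sum(m))                        # 6
# np.sum(m, where=[True, False, True])  # raises TypeError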

@shoyer
Member

shoyer commented Jul 6, 2019

Yes, I think I agree. That would then imply removing the domain check from ndarray.__array_function__.

I'm not sure I follow -- what domain check are you referring to?

For example, they might differ in optional keyword arguments. So it seems pretty error prone to always use the same implementation for both. But obviously it's up to the implementing library (e.g., dask or sparse).

This is always the case, right, independent of this discussion? The API used and the executing library need to match for all keywords used. If today we do np.abs(dask_array, where=some_mask), then dask, which doesn't support where, should raise an error.

I was contemplating new optional non-NumPy arguments, or especially cases where optional arguments have different semantics in NumPy vs. another library. The latter is probably not very common, but it is not inconceivable.

@rgommers by the way, I saw your recent PyData Amsterdam talk on YouTube! Very nice, but one minor correction: you actually can override np.concatenate and np.linspace with __array_function__. Only array creation functions that don't take array-likes as arguments cannot be overridden.

@rgommers
Member Author

rgommers commented Jul 6, 2019

I'm not sure I follow -- what domain check are you referring to?

IIRC, the check that prevents torch.sum(some_numpy_ndarray) from dispatching to np.sum when torch.sum is decorated with array_function_dispatch. But I could be wrong - I won't have time before SciPy to dig into this again. If this doesn't make sense, then never mind for now.

@rgommers by the way, I saw your recent PyData Amsterdam talk on YouTube! Very nice, but one minor correction: you actually can override np.concatenate and np.linspace with __array_function__. Only array creation functions that don't take array-likes as arguments cannot be overridden.

Thanks! Yes, I see you're right. It's a bit artificial, though: np.linspace(0, 10) is the normal form, and that certainly won't work. np.linspace(np.array(0), np.array(10)) works but is quite odd.
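
To make that concrete, a toy sketch (the Scalar wrapper is made up) showing that the override does fire once start/stop are array-likes:

import numpy as np

class Scalar:
    # Made-up zero-dimensional array-like, just to trigger dispatch.
    def __init__(self, value):
        self.value = value

    def __array_function__(self, func, types, args, kwargs):
        if func is np.linspace:
            start = args[0].value if isinstance(args[0], Scalar) else args[0]
            stop = args[1].value if isinstance(args[1], Scalar) else args[1]
            return np.linspace(start, stop, *args[2:], **kwargs)
        return NotImplemented

# Dispatches, because np.linspace checks its start/stop arguments:
print(np.linspace(Scalar(0), Scalar(10), num=5))   # [ 0.   2.5  5.   7.5 10. ]
# np.linspace(0, 10) with plain scalars never consults __array_function__.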

I was contemplating new optional non-NumPy arguments, or especially cases where optional arguments have different semantics in NumPy vs. another library. The latter is probably not very common, but it is not inconceivable.

Ah yes, that would indeed require some contemplation. At that point, using the package that has those non-NumPy keywords directly is probably the less confusing way to go.
