refactor np.where to use overload #8258
Thanks for the review, @apmasell
@guilhermeleobas thanks for the patch. In the interests of making this easier to review, perhaps split the API refactoring and the addition of new functionality into separate patches?
Force-pushed from fec7c97 to 924893e.
…numpy version < 1.20
Force-pushed from 924893e to 66eaa41.
@stuartarchibald, this is a new impl. of
@guilhermeleobas please could you resolve the conflicts against
The code looks functionally good and well written. The main issues are with the handling of `None`, plus a few nitpicks, but otherwise it LGTM.
numba/np/arraymath.py
Outdated
if x.layout == y.layout == condition.layout:
    layout = x.layout
else:
    layout = 'C'
Should this be `A` format? Having a `C` would mean the implementation would tend to go through `_where_fast_inner_impl` even though the arrays might have different layouts (which may not be `C` or `F`)?
(The resulting array should be C ordered, but that is handled automatically since the default layout is C ordered.)
I just kept the original behavior, which assigns the layout to `C`:
numba/numba/core/typing/npydecl.py
Lines 710 to 715 in 3bee7be
if (cond.ndim == x.ndim == y.ndim):
    if x.layout == y.layout == cond.layout:
        retty = types.Array(retdty, x.ndim, x.layout)
    else:
        retty = types.Array(retdty, x.ndim, 'C')
    return signature(retty, *args)
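The layout rule in the quoted snippet can be sketched in plain Python. This is only an illustration of the rule under discussion, not Numba code: `result_layout` is a hypothetical helper, and the layout codes are passed as plain strings rather than taken from Numba array types.

```python
def result_layout(cond_layout, x_layout, y_layout):
    """Hypothetical helper mirroring the rule quoted above.

    If all three inputs share one layout code, reuse it;
    otherwise fall back to 'C'.
    """
    if x_layout == y_layout == cond_layout:
        return x_layout
    return 'C'


# All inputs Fortran-ordered: the shared layout is kept.
same = result_layout('F', 'F', 'F')
# Mixed layouts: the rule falls back to 'C'.
mixed = result_layout('C', 'F', 'A')
print(same, mixed)
```

The reviewer's question is about the fallback branch: whether mixed inputs should be labelled `A` rather than `C`.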
numba/np/arraymath.py
Outdated
    raise NumbaTypeError(msg.format(name))

if is_nonelike(x) and is_nonelike(y):
    return _where_cond_none_none
The support for `None` objects seems to deviate from what's expected:
import numpy as np
import numba
def foo_np(cond, x, y):
    return np.where(cond, x, y)

@numba.njit
def foo_nb(cond, x, y):
    return np.where(cond, x, y)
cond = np.array([None, 1])
x = np.array([0, 1])
y = np.array([3, 4])
print(foo_np(cond, x, y))
print(foo_nb(cond, x, y)) # Error
cond = None
x = np.array([0, 1])
y = np.array([3, 4])
print(foo_np(cond, x, y))
print(foo_nb(cond, x, y)) # Error
cond = np.array([0, 1])
x = np.array([0, 1])
y = None
print(foo_np(cond, x, y))
print(foo_nb(cond, x, y)) # Error
cond = np.array([0, 1])
x = None
y = None
print(foo_np(cond, x, y))
print(foo_nb(cond, x, y)) # Wrong Results
If this is not supported then it should be caught as a proper error. Otherwise, the behaviour of `None` as the condition seems to be the same as `0`, and within arrays it seems to be treated as a normal element.
Supporting `None` inputs will be tricky. I'll try to address this in the following days.
def test_np_where_numpy_ndim(self):
    # https://github.com/numpy/numpy/blob/fe2bb380fd9a084b622ff3f00cb6f245e8c1a10e/numpy/core/tests/test_multiarray.py#L8737-L8749
    pyfunc = np_where_3
    cfunc = jit(nopython=True)(pyfunc)
Might be better to use the `njit` and `func_name.py_func` API over here?
I'm following the convention used in the test file.
tmpmask = c != 0
c[c == 0] = 41247212
c[tmpmask] = 0
np.testing.assert_equal(cfunc(c, b, a), r)
We could add some tests involving `None`?
Thanks for the patch, this largely looks good, and it's great to see the NumPy tests passing too. I've left a few comments inline; once those are resolved, this should be good to merge.
numba/np/arraymath.py
Outdated
if is_nonelike(x) and is_nonelike(y):
    return _where_cond_none_none
I think the use of kwargs with a default of `None` is potentially going to cause issues. Consider:
In [5]: np.where([3], None, None)
Out[5]: array([None], dtype=object)
vs.
In [13]: np.array([3]).nonzero()
Out[13]: (array([0]),)
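The ambiguity can be sketched in plain NumPy. This is a hedged illustration rather than the PR's implementation: `where_sketch` and the `_MISSING` sentinel are hypothetical names, chosen to show how an explicit `None` argument can be told apart from an omitted one.

```python
import numpy as np

# Hypothetical sentinel: distinguishes "argument omitted" from "None passed".
_MISSING = object()


def where_sketch(cond, x=_MISSING, y=_MISSING):
    """Illustrative wrapper, not the Numba overload."""
    if x is _MISSING and y is _MISSING:
        # One-argument form: equivalent to nonzero().
        return np.asarray(cond).nonzero()
    if x is _MISSING or y is _MISSING:
        raise ValueError("either both or neither of x and y should be given")
    # Three-argument form: None is an ordinary (object) value here.
    return np.where(cond, x, y)


one_arg = where_sketch([3])                # tuple of index arrays
three_arg = where_sketch([3], None, None)  # object array holding None
print(one_arg, three_arg)
```

With `None` as the default, the two calls above would be indistinguishable inside the function, which is the concern raised in the comment.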
I've changed the code to not use `None` as the default value.
numba/np/arraymath.py
Outdated
#
# >>> np.where([0, 1], None, None)
# array([None, None])
if x is None and y is None:
I think this should be:
- if x is None and y is None:
+ if is_nonelike(x) and is_nonelike(y):
as this is in the typing domain. However, I think there are further issues, e.g.:
from numba import njit
import numpy as np
@njit
def foo(a, x, y):
    return np.where(a, x, y)
args = (1, None, None)
expected = foo.py_func(*args)
got = foo(*args)
print(expected, type(expected))
print(got, type(got))
produces:
None <class 'numpy.ndarray'>
(array([0]),) <class 'tuple'>
Setting `args = (np.ones(4), None, None)` also does something similarly strange.
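The "typing domain" point can be illustrated with a toy sketch. The classes below are stand-ins invented for this example, not Numba's types, and `is_nonelike_sketch` is a hypothetical helper rather than Numba's `is_nonelike`. The idea, as I understand it, is that a typing-time function receives type objects, so an identity check against `None` does not fire when a type object represents "the type of None".

```python
class NoneTypeStandIn:
    """Stand-in for a compiler's 'type of None' object (hypothetical)."""


class OmittedStandIn:
    """Stand-in for the type of an omitted argument with a default (hypothetical)."""

    def __init__(self, value):
        self.value = value


def is_nonelike_sketch(ty):
    # Accept the raw None value, a 'type of None' object,
    # or an omitted argument whose default is None.
    return (
        ty is None
        or isinstance(ty, NoneTypeStandIn)
        or (isinstance(ty, OmittedStandIn) and ty.value is None)
    )


none_type = NoneTypeStandIn()
identity_check = none_type is None            # misses the type object
helper_check = is_nonelike_sketch(none_type)  # recognises it
print(identity_check, helper_check)
```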
  for idx, c in np.ndenumerate(cond):
-     res[idx] = x if c else y[idx]
+     res[idx] = x[idx] if c else y[idx]
This fails to unify for `res` in the case of 'unusual' inputs like:
cond, x, y = np.arange(-2, 2, 1), np.zeros((4, 4)), np.ones((4, 4), dtype='<U5')
Is there anything we can do in this case?
Whilst it may be possible to assess the output type based on the inputs and explicitly ban unsupported combinations, I think it's OK to leave it as is: the error message is reasonably informative, and working out what's "unsupported" is probably complicated. Do you feel differently?
cond_ = np.broadcast_to(cond1, shape)
x_ = np.broadcast_to(x1, shape)
y_ = np.broadcast_to(y1, shape)
I guess there are cases where it's faster to compute then broadcast, as opposed to broadcast then compute, e.g. where `cond` has a smaller dimension than `x` and `y`. Perhaps leave this optimisation for now and concentrate on correctness, with a view to getting this merged!
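The broadcast-then-compute approach discussed here can be sketched in plain NumPy. This is an assumption-laden illustration, not the PR's code: `where_broadcast` is a hypothetical name, and the result dtype is picked with `np.result_type`, which may not match what the actual implementation does.

```python
import numpy as np


def where_broadcast(cond, x, y):
    """Hypothetical sketch: broadcast all inputs first, then select elementwise."""
    cond, x, y = np.asarray(cond), np.asarray(x), np.asarray(y)
    # Common shape of the three inputs under NumPy broadcasting rules.
    shape = np.broadcast(cond, x, y).shape
    cond_ = np.broadcast_to(cond, shape)
    x_ = np.broadcast_to(x, shape)
    y_ = np.broadcast_to(y, shape)
    res = np.empty(shape, dtype=np.result_type(x, y))
    for idx, c in np.ndenumerate(cond_):
        res[idx] = x_[idx] if c else y_[idx]
    return res


cond = np.array([True, False, True])  # 1-D, broadcast across rows
x = np.arange(6).reshape(2, 3)        # 2-D
y = np.array(-1)                      # 0-D scalar array
out = where_broadcast(cond, x, y)
print(out)
```

Here `cond` has fewer dimensions than `x`, so broadcasting first repeats the condition over every row. That is the case where computing on the smaller array before broadcasting might be cheaper.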
@kc611 @stuartarchibald, would you folks be ok with not supporting The current approach fails with
@guilhermeleobas I think that would be fine, the existing implementation in Numba doesn't support it either so it's not a regression. Thanks!
Thanks for the updates @guilhermeleobas, I think they address everything in the review. I'm inclined to leave the issue with unifying 'unusual' array types for now unless you feel strongly otherwise. Thanks again!
xref: #8254
`np.where` to use `@overload`