Sorry if this was reported elsewhere or dismissed as something that numba will not support, but I tried searching and couldn't find anything general on this (only some related issues such as #846, #1825).
Basically, NumPy treats scalars as zero-dimensional arrays, while numba throws a typing error when a scalar is passed to a few of these functions. Here are a few examples of how this bug manifests:
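```python
import numpy as np
from numba import njit

@njit
def f_ndenumerate():
    return list(np.ndenumerate(1))

@njit
def f_reshape():
    return np.reshape(1, (1,))

@njit
def f_transpose():
    return np.transpose(1)

@njit
def f_min():
    return np.min(1)

@njit
def f_max():
    return np.max(1)

# At the time of this report, every call failed with the same typing error.
for fn in (f_ndenumerate, f_reshape, f_transpose, f_min, f_max):
    try:
        print(fn.py_func.__name__, fn())
    except Exception as exc:
        print(fn.py_func.__name__, "failed:", type(exc).__name__)
```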
There are likely a few more instances (`sum`, `prod`, `argmin`, `argmax`, `median`), although I did check the list of supported methods somewhat thoroughly. Each case produces the same error:

```
Invalid usage of Function(<class '...'>) with parameters (int64)
```

and, in all cases, removing the `@njit` decorator gives a reasonable result, as does replacing the scalar with a 1-d array.
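For reference, here is what plain NumPy (no `@njit`) does with the same scalar arguments, next to the 1-element 1-d array form that works. This is a minimal sketch; the comments describe the shape of each result rather than exact printed text, which differs slightly between NumPy versions.

```python
import numpy as np

# Without @njit, NumPy promotes each scalar argument to a 0-d array.
print(list(np.ndenumerate(1)))   # one entry, whose index is the empty tuple ()
print(np.reshape(1, (1,)))       # a 1-d array with a single element
print(np.ndim(np.transpose(1)))  # 0: the transposed value is 0-dimensional
print(np.min(1), np.max(1))      # both reduce over the single element

# Workaround from the report: pass a 1-element 1-d array instead of a scalar.
one = np.array([1])
print(list(np.ndenumerate(one)))  # one entry, with index (0,)
print(np.min(one), np.max(one))
```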
To motivate why this can matter: a user might call `np.clip(a, a_min=0, a_max=1)` to clip a scalar value `a`, and a numba implementation of `np.clip` (#3468) might use `np.ndenumerate(a)`, which would then run into this corner case. It seems unusual to special-case scalars within the `np.clip` implementation. More to the point, this is inconsistent with NumPy, and I presume one of the goals is for existing NumPy code to work with numba off-the-shelf.
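To make the `np.clip` scenario concrete, here is a hypothetical ndenumerate-based clip written in plain NumPy. It is an illustration only, not numba's actual implementation, and `clip_ndenumerate` is an illustrative name. Because `np.asarray` promotes a scalar to a 0-d array whose only index is `()`, a single loop covers both scalars and arrays.

```python
import numpy as np


def clip_ndenumerate(a, a_min, a_max):
    """Hypothetical ndenumerate-based clip in plain NumPy (not numba's actual code)."""
    arr = np.asarray(a)      # a Python scalar becomes a 0-d array here
    out = np.empty_like(arr)
    for idx, value in np.ndenumerate(arr):
        # For a 0-d array the only index is (), and out[()] addresses its single element.
        out[idx] = min(max(value, a_min), a_max)
    return out


print(clip_ndenumerate(1.5, 0, 1))               # 0-d result
print(clip_ndenumerate([-1.0, 0.5, 2.0], 0, 1))  # element-wise on a 1-d input
```

The only scalar-specific step is the `np.asarray` promotion. An implementation that passes the scalar straight to `np.ndenumerate` relies on the callee to do that promotion, which is the step numba rejected at the time of this report.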
Thanks for the report. This is an unsupported feature, noted in #3175, but seems to be more prevalent. I'll have a think about whether there's a quick way to fix this!
```
In [1]: import numpy as np

In [2]: import numba as nb

In [3]: @nb.njit(fastmath=True)
   ...: def test_nbclip(a):
   ...:     return np.clip(np.trace(a) - 1.0, -1.0, 1.0)
   ...:

In [4]: a = np.random.rand(3, 3)

In [5]: np.trace(a)
Out[5]: 1.496066502191042

In [6]: for i in range(10):
   ...:     print(test_nbclip(a))
   ...:
```
The traceback mentions that it tries to call `np.ndenumerate`. Is there a way to work around this? I'm trying to pass some float to `arccos`. Should I just handle NaNs afterward?

OK, for now I've just written manual if clauses; this works and still offers a reasonable speedup. I'm curious why `ndenumerate` is being called, but I'll have to find time to look into that.
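A minimal sketch of that if-clause workaround, assuming the goal is `arccos(trace(a) - 1)` as in the session above. `clip_scalar` and `arccos_of_trace` are hypothetical names, and the `try`/`except` fallback is only there so the sketch also runs as plain Python when numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python so the sketch still runs without numba
    def njit(func=None, **kwargs):
        return func if func is not None else (lambda f: f)


@njit
def clip_scalar(x, lo, hi):
    # Manual if clauses instead of np.clip, so no array machinery is involved.
    if x < lo:
        return lo
    if x > hi:
        return hi
    return x


@njit(fastmath=True)
def arccos_of_trace(a):
    # Clip trace(a) - 1 into arccos's domain [-1, 1] so the result is never NaN.
    return np.arccos(clip_scalar(np.trace(a) - 1.0, -1.0, 1.0))


a = np.random.rand(3, 3)
print(arccos_of_trace(a))
```

Clipping before the call avoids having to clean up NaNs afterward, since the argument can never leave the domain of `arccos`.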