Scalars raise typing errors for ndenumerate, reshape, transpose, min, etc. #3469

Open
2 tasks done
arvoelke opened this issue Nov 3, 2018 · 2 comments

@arvoelke
Contributor

arvoelke commented Nov 3, 2018

Reporting a bug

Sorry if this has already been reported or dismissed as something numba will not support; I searched and couldn't find anything general on this (only some related issues such as #846 and #1825).

Basically, numpy converts scalars to zero-dimensional arrays, whereas numba raises a typing error when a scalar is passed to several of the numpy functions it supports. Here are a few examples of how this bug manifests:

import numpy as np
from numba import njit

# Each snippet below fails on its own with a typing error.

@njit
def f():
    return list(np.ndenumerate(1))
f()

@njit
def f():
    return np.reshape(1, (1,))
f()

@njit
def f():
    return np.transpose(1)
f()

@njit
def f():
    return np.min(1)
f()

@njit
def f():
    return np.max(1)
f()

There are likely a few more instances (sum, prod, argmin, argmax, median), although I did check the list of supported methods somewhat thoroughly. Each case produces the same error:

Invalid usage of Function(<class '...'>) with parameters (int64)

and, in all cases, removing the @njit decorator gives a reasonable result, as does replacing the scalar with a 1-d array.
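
For comparison, here is a minimal sketch of the same calls without the @njit decorator, using plain numpy only. The variable names are illustrative and not part of the original report.

```python
import numpy as np

# Without @njit, numpy converts the Python scalar 1 to a
# zero-dimensional array internally, so every call succeeds.
pairs = list(np.ndenumerate(1))   # a single (index, value) pair; the index is ()
reshaped = np.reshape(1, (1,))    # 1-d array holding one element
transposed = np.transpose(1)      # zero-dimensional result
smallest = np.min(1)
largest = np.max(1)

print(pairs, reshaped.shape, np.ndim(transposed), smallest, largest)
```

The empty index tuple returned by np.ndenumerate is what a zero-dimensional array looks like when iterated.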

To see why this matters: a user might call np.clip(a, a_min=0, a_max=1) with a scalar a, and a numba implementation of np.clip (#3468) might use np.ndenumerate(a), which would then run into this corner case. It seems unusual to handle this case inside the np.clip implementation. More to the point, this is inconsistent with numpy, and I presume one of numba's goals is for existing numpy code to work off the shelf.
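
To illustrate the corner case, here is a plain-numpy sketch of a clip written around np.ndenumerate. The function clip_sketch is hypothetical and is not the implementation from #3468; it only shows that the loop works for scalars in numpy because np.asarray turns them into zero-dimensional arrays first.

```python
import numpy as np

def clip_sketch(a, a_min, a_max):
    """Hypothetical elementwise clip built on np.ndenumerate (plain numpy)."""
    arr = np.asarray(a)            # a scalar becomes a zero-dimensional array
    out = np.empty_like(arr)
    for idx, value in np.ndenumerate(arr):
        # For a zero-dimensional array, idx is the empty tuple ().
        out[idx] = min(max(value, a_min), a_max)
    return out

print(clip_sketch(5, 0, 1))
print(clip_sketch(np.array([-1.0, 0.5, 2.0]), 0, 1))
```

Unlike np.clip, this sketch returns a zero-dimensional array rather than a scalar when given a scalar input.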

@stuartarchibald
Contributor

Thanks for the report. This is an unsupported feature, noted in #3175, but it seems to be more prevalent than noted there. I'll have a think about whether there's a quick way to fix this!

@adigitoleo

adigitoleo commented Jul 11, 2022

Probably the same issue?

In [1]: import numpy as np

In [2]: import numba as nb

In [3]: @nb.njit(fastmath=True)
   ...: def test_nbclip(a):
   ...:     return np.clip(np.trace(a) - 1.0, -1.0, 1.0)
   ...:

In [4]: a = np.random.rand(3, 3)

In [5]: np.trace(a)
Out[5]: 1.496066502191042

In [6]: for i in range(10):
   ...:     print(test_nbclip(a))
   ...:

The traceback mentions that it tries to call ndenumerate. Is there a way to work around this? I'm trying to pass a float to arccos. Should I just handle NaNs afterward?

OK, for now I've just used manual if clauses; that works and still offers a reasonable speedup. I'm curious why ndenumerate is being called, but I'll have to find time to look into that.
