Vector @overload not overwriting scalar one #4324

Open
2 tasks done
person142 opened this issue Jul 14, 2019 · 3 comments

Comments

@person142
Contributor

If I define a scalar overload for a function and then later define a vector overload, the vector overload appears to not be respected. A simple example is:

In [1]: import numba

In [2]: import scipy.special as sc

In [3]: @numba.extending.overload(sc.gamma)
   ...: def gamma_scalar(a):
   ...:     if a is numba.types.float64:
   ...:         return lambda a: a**2
   ...:

In [4]: @numba.extending.overload(sc.gamma)
   ...: def gamma_vector(a):
   ...:     if isinstance(a, numba.types.Array):
   ...:         @numba.vectorize
   ...:         def vector(a):
   ...:             return a
   ...:         return lambda a: vector(a)
   ...:

In [5]: @numba.njit
   ...: def gamma(x):
   ...:     return sc.gamma(x)
   ...:

In [6]: import numpy as np

In [7]: gamma(2.0)
Out[7]: 4.0

In [8]: gamma(np.array([2.0]))
Out[8]: array([4.])

I would have expected Out[8] to be array([2.]), since the vector overload just returns its input elementwise.
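For context, a numba @overload function is called with the argument types at compile time and returns an implementation when it applies, or None (in the report, by falling through) when it does not. That is why gamma_scalar and gamma_vector can both be registered for sc.gamma. The pure-Python sketch below is a toy model of that contract only: the string type tags and the resolve helper are hypothetical stand-ins, not numba's API. Because the two predicates are disjoint, at most one overload matches any given type.

```python
# Toy model of the @overload contract: each overload inspects the argument
# *type* and returns an implementation, or None if it does not apply.
# The string type tags and the `resolve` helper are hypothetical.

def gamma_scalar(a_type):
    if a_type == "float64":
        return lambda a: a ** 2
    return None  # explicit here; numba overloads usually just fall through

def gamma_vector(a_type):
    if a_type == "array(float64)":
        return lambda a: [x for x in a]  # elementwise identity
    return None

OVERLOADS = [gamma_scalar, gamma_vector]

def resolve(a_type):
    """Return the implementation offered by the first overload that applies."""
    for overload in OVERLOADS:
        impl = overload(a_type)
        if impl is not None:
            return impl
    raise TypeError(f"no overload accepts {a_type}")

print(resolve("float64")(2.0))           # 4.0
print(resolve("array(float64)")([2.0]))  # [2.0]
```

In this toy model the array call yields [2.0], which is the result the report expected from Out[8].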

@stuartarchibald
Contributor

Thanks for the report. I can reproduce this, and it is a bug. Interestingly, this works:

import numba

import scipy.special as sc

def zzz(x):
    pass

@numba.extending.overload(zzz)
def gamma_scalar(a):
    print("type of a", type(a))
    print("testing float ty")
    if isinstance(a, numba.types.Float):
        print("Using scalar overload")
        def scalar_impl(a):
            print("In scalar overload")
            return a ** 2
        return scalar_impl

    print("testing array ty")
    if isinstance(a, numba.types.Array):
        print("Using vector overload")
        def foo(a):
            print("In vector overload")
            return a + 10
        return foo

@numba.njit
def gamma(x):
    return zzz(x)

import numpy as np

print(gamma(2.0))
print(gamma(np.array([2.0])))

which suggests the problem is related to overloading a function bound from the scipy module (probably not scipy-specific, but something about functions from external modules in general).

@stuartarchibald
Contributor

This has something to do with the array expression rewrite pass.

@stuartarchibald
Contributor

OK, I think I've got it. What happens is this:

  1. The njit function gamma calls sc.gamma with an array argument.
  2. The typing mechanism correctly spots this and selects the gamma_vector overload.
  3. The "vectorized" function is then typed and appears as a ufunc called "gamma".
  4. When the rewrite pass runs, that ufunc is typed as "gamma", which goes back through the overload resolution mechanism. Since the ufunc kernel dispatches on the scalar dtype, "gamma" resolves to gamma_scalar.

In one line: gamma calls sc.gamma -> resolves to gamma_vector -> calls vector -> resolves to the "gamma" ufunc -> resolves to gamma_scalar -> calls lambda a: a**2, hence the unexpected answer.
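The chain above can be mimicked with a small, hypothetical pure-Python model. It assumes, as the diagnosis describes, that the ufunc built inside the vector overload surfaces under the name "gamma" and that the rewrite evaluates it per scalar by resolving that name again. The dictionaries and helper functions are illustrative only and do not reflect numba's internals.

```python
# Hypothetical toy model of the resolution chain (not numba internals).

def gamma_scalar(a_type):
    if a_type == "float64":
        return lambda a: a ** 2
    return None

def gamma_vector(a_type):
    if a_type == "array(float64)":
        return lambda a: apply_ufunc(vector, a)
    return None

# The vectorized identity kernel; per the diagnosis it surfaces as "gamma".
vector = {"name": "gamma", "kernel": lambda a: a}

# Overloads keyed only by the bare function name.
OVERLOADS = {"gamma": [gamma_scalar, gamma_vector]}

def resolve(name, a_type):
    for overload in OVERLOADS.get(name, []):
        impl = overload(a_type)
        if impl is not None:
            return impl
    return None

def apply_ufunc(ufunc, arr):
    # Assumed rewrite step: re-resolve the ufunc's *name* for each scalar.
    # A registered overload with that name shadows the ufunc's own kernel.
    out = []
    for x in arr:
        impl = resolve(ufunc["name"], "float64") or ufunc["kernel"]
        out.append(impl(x))
    return out

print(resolve("gamma", "float64")(2.0))           # 4.0
print(resolve("gamma", "array(float64)")([2.0]))  # [4.0], not [2.0]
```

Because lookup is keyed on the bare name, the per-element step finds gamma_scalar before it ever reaches the identity kernel, reproducing the [4.0] from Out[8] in miniature.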

I think the fix may require making ufunc resolution take into account the module a function came from (assuming it doesn't do so now, and that this is the cause of the problem).
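A toy sketch of that idea, again in hypothetical pure Python rather than numba code: the registry is keyed by (module, name), and the identity ufunc records the module it was defined in (assumed here to be __main__, as in an IPython session). Per-element resolution then finds no overload for ("__main__", "gamma") and falls back to the ufunc's own kernel.

```python
# Hypothetical sketch: key overload lookup on (module, name) instead of the
# bare name. Module strings and helpers are illustrative assumptions.

def gamma_scalar(a_type):
    if a_type == "float64":
        return lambda a: a ** 2
    return None

def gamma_vector(a_type):
    if a_type == "array(float64)":
        return lambda a: apply_ufunc(vector, a)
    return None

# The identity ufunc keeps the name "gamma" but records its defining module.
vector = {"module": "__main__", "name": "gamma", "kernel": lambda a: a}

OVERLOADS = {("scipy.special", "gamma"): [gamma_scalar, gamma_vector]}

def resolve(module, name, a_type):
    for overload in OVERLOADS.get((module, name), []):
        impl = overload(a_type)
        if impl is not None:
            return impl
    return None

def apply_ufunc(ufunc, arr):
    out = []
    for x in arr:
        # The lookup now includes the module, so ("__main__", "gamma") finds
        # no overload and the ufunc's own kernel is used instead.
        impl = resolve(ufunc["module"], ufunc["name"], "float64") or ufunc["kernel"]
        out.append(impl(x))
    return out

print(resolve("scipy.special", "gamma", "float64")(2.0))           # 4.0
print(resolve("scipy.special", "gamma", "array(float64)")([2.0]))  # [2.0]
```

The sketch only illustrates why qualifying the lookup by module would avoid the name collision.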

seibert mentioned this issue on Jul 29, 2019.