Issue with using overload on a vectorized function #4029

Comments
Thanks for the report. I'm fairly sure that

```python
import numba

@numba.vectorize
def f(x):
    pass

@numba.njit
def g(x):
    return f(x)

g(1.0)
```

reproduces the error you see. If, however, you change it to

```python
import numba

@numba.vectorize
def f(x):
    return 1

@numba.njit
def g(x):
    return f(x)

g(1.0)
```

this works fine. Once fixed up, going back to the overload example:

```python
import numba

@numba.vectorize
def f(x):
    return 1

@numba.extending.overload(f)
def f_overload(x):
    if x == numba.types.float64:
        return lambda x: x

@numba.njit
def g(x):
    return f(x)

print(g(10.0))
```

prints the value |
Thanks for looking into that @stuartarchibald. My original goal was to do basically your fixed up example. |
Breaking this down a bit, it looks like you want to select between two implementations.

The standard way to do this is to use `generated_jit`:

```python
import numba
import numpy as np

@numba.vectorize
def f(x):
    return 1

@numba.generated_jit(nopython=True)
def g(x):
    # args are types, not values
    if x is numba.types.float64:
        def scalar_impl(x):
            return 2.0
        return scalar_impl
    else:
        def vector_impl(x):
            return f(x)  # call the ufunc defined above
        return vector_impl
```

which then results in this:

```python
>>> print(g(5.0))
2.0
>>> print(g(np.arange(5)))
[1 1 1 1 1]
>>> print(g(np.arange(6).reshape(2,3)))
[[1 1 1]
 [1 1 1]]
```

(The only minor improvement I would prefer here is being able to return |
@seibert that would probably work for what I'm trying to do. (I'll try it out later today!) I am (however) in the situation where my https://github.com/person142/spycial/blob/master/spycial/trig.py#L119 |
Hmm, that is an interesting situation. Do you need the top-level user function to be a ufunc? It seems like the best way to straighten this out would be to do all the type-based selection of implementations at the top level function, and call out to different ufuncs or regular functions as needed. |
No matter what mechanism we offer, I think the metaprogramming you are trying to do will be challenging. :) |
Yeah, that's probably the right way to go. There was appeal in having ufuncs because it gave "drop in" replacements (though in my case I'd really prefer to just provide overloaded scalar functions and let users call |
Hm, passing a plain list hits an internal error:

```python
In [16]: g(np.array([1.0]))
Out[16]: array([1])

In [17]: g([1.0])
```

```
NotImplementedError                       Traceback (most recent call last)
...
LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
unsupported type for input operand: reflected list(float64)

File "<ipython-input-10-c3444b0c0ec1>", line 9:
    def vector_impl(x):
        return f(x)
        ^

[1] During: lowering "$0.3 = call $0.1(x, func=$0.1, args=[Var(x, <ipython-input-10-c3444b0c0ec1> (9))], kws=(), vararg=None)" at <ipython-input-10-c3444b0c0ec1> (9)
-------------------------------------------------------------------------------
This should not have happened, a problem has occurred in Numba's internals.
Please report the error message and traceback, along with a minimal reproducer
at: https://github.com/numba/numba/issues/new

If more help is needed please feel free to speak to the Numba core developers
directly at: https://gitter.im/numba/numba

Thanks in advance for your help in improving Numba!
```
|
Original issue text:

If I try to overload a function decorated with `numba.vectorize`, then I get a traceback which ends with a request to report the issue. If I do the same thing but without the `vectorize` decorator, the overload works as expected.