Inserting global njit, numpy function into an njit function fails #4602
Comments
Here's the response when
|
@mroeschke thank you for asking this question about Numba. Could you possibly expand a bit on what larger problem you are attempting to solve? I.e. is anything stopping you from using the following construct:
|
As hinted at by @esc, the issue is that you are trying to jit the |
I'm currently trying to implement rolling apply in pandas, in which a user can pass an arbitrary function that can be applied to a window of data. My exact implementation can be found in https://github.com/twosigma/pandas/pull/29/files#diff-0de5c5d9abfcdd141e83701eaaec4358R1145, but posting the relevant part:
If I change
this will fail in nopython mode (which I would really like to maintain here for performance)
|
@mroeschke Thanks for the update. This might help? It's unfortunately not particularly elegant.

import numba
import numpy as np

def make_rolling_apply(func, args=(), kwargs=()):

    @numba.generated_jit(nopython=True)
    def numba_func(window, *_args):
        if getattr(np, func.__name__, False):
            def impl(window, *_args):
                return func(window, *_args)
            return impl
        else:
            jf = numba.njit(func)
            def impl(window, *_args):
                return jf(window, *_args)
            return impl

    # I'd like this function signature to remain fixed
    # (consistent with other roll functions)
    @numba.njit
    def roll_apply(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ):
        result = np.empty(len(begin))
        for i, (start, stop) in enumerate(zip(begin, end)):
            window = values[start:stop]
            count_nan = np.sum(np.isnan(window))
            if len(window) - count_nan >= minimum_periods:
                result[i] = numba_func(window, *args)
            else:
                result[i] = np.nan
        return result
    return roll_apply

def _apply(the_func, args, kwargs):
    # this stuff would come from self?
    N = 10
    values = np.ones((N),)
    begin = np.arange(N)
    end = begin + 1
    minimum_periods = 1

    impl = make_rolling_apply(the_func, args=args, kwargs=kwargs)
    return impl(values, begin, end, minimum_periods)

print(_apply(np.sum, (), {}))

def foo(window, *args):
    arg1, arg2 = args
    return (window[0] + arg1) / arg2

print(_apply(foo, args=(10, 20), kwargs={}))

def bar(window, *args):
    arg1, arg2 = args
    return (window[0] + arg1) / arg2[1]

print(_apply(bar, args=(10, np.ones(5)), kwargs={}))
|
Thanks @stuartarchibald! I'll try this solution tonight. Special handling of NumPy functions may be the right workaround here. |
Your solution worked @stuartarchibald, thanks! So should I assume |
@mroeschke great! It might also be worth adding a check in the NumPy function identification part to make sure it is indeed the NumPy function, i.e. this breaks:

In [4]: import numpy as np

In [5]: def sum(x):
   ...:     pass
   ...:

In [6]: func = sum

In [7]: getattr(np, func.__name__, False)
Out[7]: <function numpy.sum(a, axis=None, dtype=None, out=None, keepdims=<no value>, initial=<no value>)>

which probably needs something like:

In [11]: possible_np_func = getattr(np, func.__name__, False)

In [12]: if func is possible_np_func:
    ...:     print("is Numpy func")
    ...: else:
    ...:     print("is not NumPy func")
    ...:
is not NumPy func

to make sure only NumPy functions are treated this way.
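The identity check described above can be wrapped in a small helper. This is only a sketch: the name `is_function_from` is hypothetical and not part of Numba or NumPy, and the module is passed in as a parameter so the demonstration can use the standard-library `math` module and run without extra dependencies.

```python
import math


def is_function_from(module, func):
    """Return True only if `func` is the exact object that `module` exposes
    under the name `func.__name__` (identity, not just a name match)."""
    name = getattr(func, "__name__", None)
    if name is None:
        # Objects without a __name__ (e.g. functools.partial) cannot match.
        return False
    return getattr(module, name, None) is func


def sqrt(x):
    """User function that happens to share its name with math.sqrt."""
    return x


print(is_function_from(math, math.sqrt))  # the module's own function
print(is_function_from(math, sqrt))       # same name, different object
```

In the setting of this thread the call would presumably be `is_function_from(np, func)`, replacing the name-only `getattr(np, func.__name__, False)` test so that a user-defined `sum` is not mistaken for `np.sum`.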
This is up for discussion at the Numba core developer meeting today. With #4599 in the works, use cases similar to the one presented here are more likely. We also briefly discussed at an issue triage session yesterday how such a feature could be implemented; it seems to be technically possible. |
Outcome of discussion was #4608 |
Thanks. I'll close this issue in favor of #4608 |