Allow numba.jit to compose with numba.stencil #3914

Open
mrocklin opened this issue Mar 30, 2019 · 2 comments

@mrocklin

Today, I think the way to use stencil functions with numba is to call them from inside another function that is jitted:

@numba.stencil
def f(x):
    ...

@numba.jit
def g(x):
    return f(x)
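A stencil kernel is written from the point of view of a single output element: `x[-1, 0]` means "the element one row above the current one". The slow pure-Python sketch below mimics those semantics for 2-D arrays. `apply_stencil` and `Relative` are hypothetical names, not numba APIs, and the fixed `radius` is a simplification (numba infers the neighborhood from the constant offsets used in the kernel). Elements whose neighborhood would fall off the edge are filled with a constant, 0 here, matching the zero border that numba produces in the example below.

```python
import numpy as np


class Relative:
    """Hypothetical accessor: x[di, dj] reads the element offset from the center (i, j)."""

    def __init__(self, a, i, j):
        self.a, self.i, self.j = a, i, j

    def __getitem__(self, offset):
        di, dj = offset
        return self.a[self.i + di, self.j + dj]


def apply_stencil(kernel, a, radius=1, cval=0):
    """Run `kernel` wherever the whole (2*radius+1) x (2*radius+1) neighborhood fits.

    Border elements that the kernel cannot reach are filled with `cval`.
    """
    out = np.full_like(a, cval)
    rows, cols = a.shape
    for i in range(radius, rows - radius):
        for j in range(radius, cols - radius):
            out[i, j] = kernel(Relative(a, i, j))
    return out


def f(x):
    # Same body as the stencil kernel in the example: sum of the 3x3 neighborhood.
    return (x[-1, -1] + x[-1, 0] + x[-1, 1] +
            x[ 0, -1] + x[ 0, 0] + x[ 0, 1] +
            x[ 1, -1] + x[ 1, 0] + x[ 1, 1])


x = np.arange(25).reshape((5, 5))
print(apply_stencil(f, x))
```

On `np.arange(25).reshape((5, 5))` this reproduces the Out[9] array from the example below; numba's job is to generate the same loop nest as compiled code.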

I wonder if it might make sense instead to allow numba.jit to be composed directly with numba.stencil:

@numba.jit
@numba.stencil
def f(x):
    ...

Or, going further, stencil could simply take a jit=True keyword. When I first tried to use stencil, I assumed that it would jit automatically. I suspect that most users coming to this function for the first time have the same expectation, accustomed as we are to the simple API of numba.jit.

Example

In [1]: import numba

In [2]: import numpy as np

In [3]: x = np.arange(25).reshape((5, 5))

In [4]: x
Out[4]:
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]])

In [5]: @numba.stencil
   ...: def f(x):
   ...:     return (x[-1, -1] + x[-1, 0] + x[-1, 1] +
   ...:             x[ 0, -1] + x[ 0, 0] + x[ 0, 1] +
   ...:             x[ 1, -1] + x[ 1, 0] + x[ 1, 1])
   ...:

In [6]: %timeit f(x)
189 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [7]: @numba.njit
   ...: @numba.stencil
   ...: def f(x):
   ...:     return (x[-1, -1] + x[-1, 0] + x[-1, 1] +
   ...:             x[ 0, -1] + x[ 0, 0] + x[ 0, 1] +
   ...:             x[ 1, -1] + x[ 1, 0] + x[ 1, 1])
   ...:
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-7-f2d04f4f5a9e> in <module>
      1 @numba.njit
----> 2 @numba.stencil
      3 def f(x):
      4     return (x[-1, -1] + x[-1, 0] + x[-1, 1] +
      5             x[ 0, -1] + x[ 0, 0] + x[ 0, 1] +

~/miniconda/envs/dev/lib/python3.7/site-packages/numba/decorators.py in njit(*args, **kws)
    234         warnings.warn('forceobj is set for njit and is ignored', RuntimeWarning)
    235     kws.update({'nopython': True})
--> 236     return jit(*args, **kws)
    237
    238

~/miniconda/envs/dev/lib/python3.7/site-packages/numba/decorators.py in jit(signature_or_function, locals, target, cache, pipeline_class, **options)
    171                    targetoptions=options, **dispatcher_args)
    172     if pyfunc is not None:
--> 173         return wrapper(pyfunc)
    174     else:
    175         return wrapper

~/miniconda/envs/dev/lib/python3.7/site-packages/numba/decorators.py in wrapper(func)
    187         disp = dispatcher(py_func=func, locals=locals,
    188                           targetoptions=targetoptions,
--> 189                           **dispatcher_args)
    190         if cache:
    191             disp.enable_caching()

~/miniconda/envs/dev/lib/python3.7/site-packages/numba/dispatcher.py in __init__(self, py_func, locals, targetoptions, impl_kind, pipeline_class)
    542         arg_count = len(pysig.parameters)
    543         can_fallback = not targetoptions.get('nopython', False)
--> 544         _DispatcherBase.__init__(self, arg_count, py_func, pysig, can_fallback)
    545
    546         functools.update_wrapper(self, py_func)

~/miniconda/envs/dev/lib/python3.7/site-packages/numba/dispatcher.py in __init__(self, arg_count, py_func, pysig, can_fallback)
    180
    181         argnames = tuple(pysig.parameters)
--> 182         default_values = self.py_func.__defaults__ or ()
    183         defargs = tuple(OmittedArg(val) for val in default_values)
    184         try:

AttributeError: 'StencilFunc' object has no attribute '__defaults__'

In [8]: @numba.njit
   ...: def g(x):
   ...:     return f(x)
   ...:

In [9]: g(x)
Out[9]:
array([[  0,   0,   0,   0,   0],
       [  0,  54,  63,  72,   0],
       [  0,  99, 108, 117,   0],
       [  0, 144, 153, 162,   0],
       [  0,   0,   0,   0,   0]])

In [10]: %timeit g(x)
569 ns ± 27.5 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
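The AttributeError in In [7] comes from the dispatcher reading `self.py_func.__defaults__`: numba.jit hands its argument to the dispatcher as a plain Python function, while @numba.stencil returns a StencilFunc object that is callable but does not carry function attributes such as `__defaults__`. A pure-Python illustration of that mismatch, using a hypothetical `StencilLike` class as a stand-in for StencilFunc (it is not numba code):

```python
import inspect


class StencilLike:
    """Hypothetical stand-in for numba's StencilFunc: a callable object, not a function."""

    def __init__(self, kernel):
        self.kernel = kernel

    def __call__(self, *args):
        return self.kernel(*args)


def kernel(a, scale=1):
    return scale * a


wrapped = StencilLike(kernel)

# A plain function carries the metadata that a function-compiling decorator reads...
print(inspect.isfunction(kernel), kernel.__defaults__)  # True (1,)
# ...but a callable instance does not, even though calling it works fine.
print(inspect.isfunction(wrapped), hasattr(wrapped, "__defaults__"))  # False False
print(wrapped(3))  # 3
```

Supporting `@numba.jit` on top of `@numba.stencil` would therefore mean teaching jit to recognize StencilFunc objects rather than treating them as ordinary functions.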
@mrocklin
Author

It would also be nice for this to work with other decorators, like guvectorize.

@stuartarchibald
Contributor

Thanks for the suggestion; I think this would be a good thing to get working too. I'm fairly sure it's technically possible, but it may require a fair bit of unpicking of the current implementation, as it's quite involved (it might also be easier if/when more generic function IR inlining is supported).
