
Function accepting njitted functions as arguments is slow #2952

Open
Juanlu001 opened this Issue May 6, 2018 · 5 comments


Juanlu001 commented May 6, 2018

I was trying out numba 0.38 and its new support for passing jitted functions as arguments with this code snippet:

# coding: utf-8
from scipy.optimize import newton
from numba import njit
@njit
def func(x):
    return x**3 - 1
@njit
def fprime(x):
    return 3 * x**2
@njit
def njit_newton(func, x0, fprime):
    for _ in range(50):
        fder = fprime(x0)
        fval = func(x0)
        newton_step = fval / fder
        x = x0 - newton_step
        if abs(x - x0) < 1.48e-8:
            return x
        x0 = x
            
get_ipython().run_line_magic('timeit', 'newton(func.py_func, 1.5, fprime=fprime.py_func)')
get_ipython().run_line_magic('timeit', 'newton(func, 1.5, fprime=fprime)')
get_ipython().run_line_magic('timeit', 'njit_newton.py_func(func, 1.5, fprime=fprime)')
get_ipython().run_line_magic('timeit', 'njit_newton(func, 1.5, fprime=fprime)')

I was surprised to find that njit_newton is the slowest of all, while njit_newton.py_func is the fastest:

$ ipython test_perf.py 
4.76 µs ± 8.52 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
4.14 µs ± 30.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
3.58 µs ± 26 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
20 µs ± 85.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
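The snippet relies on `get_ipython()`, so it only runs under IPython. Below is a minimal sketch of the same comparison as a plain script using the standard-library `timeit` module. The `best_of` helper is hypothetical. The scipy comparisons are left out to keep it dependency-light. The script falls back to undecorated Python functions when numba is not installed, and in that case the numbers say nothing about numba.

```python
import timeit

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback so the script still runs
    def njit(f):
        return f


@njit
def func(x):
    return x**3 - 1


@njit
def fprime(x):
    return 3 * x**2


@njit
def njit_newton(func, x0, fprime):
    # Same loop as in the report; returns None if it does not converge.
    for _ in range(50):
        fder = fprime(x0)
        fval = func(x0)
        newton_step = fval / fder
        x = x0 - newton_step
        if abs(x - x0) < 1.48e-8:
            return x
        x0 = x


def best_of(stmt, number=2_000, repeat=5):
    """Best per-call time in microseconds (hypothetical helper)."""
    times = timeit.repeat(stmt, number=number, repeat=repeat)
    return min(times) / number * 1e6


# Call once first so that compilation time is not part of the timings.
root = njit_newton(func, 1.5, fprime)

# With numba this is the uncompiled Python version; without it, the same function.
py_newton = getattr(njit_newton, "py_func", njit_newton)
timings = {
    "njit_newton.py_func": best_of(lambda: py_newton(func, 1.5, fprime)),
    "njit_newton": best_of(lambda: njit_newton(func, 1.5, fprime)),
}
for name, us in timings.items():
    print(f"{name}: {us:.2f} µs per call")
```

Each measurement also includes the small cost of the wrapping lambda, which is the same for both entries.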

cc @nikita-astronaut

(Inspiration: https://github.com/scipy/scipy/blob/607a21e07dad234f8e63fcf03b7994137a3ccd5b/scipy/optimize/zeros.py#L164-L182)


stuartarchibald commented May 7, 2018

Thanks for the report. I can reproduce.


Juanlu001 commented Jul 22, 2018

I would like to kindly ask where this issue stands in numba's general roadmap. For me (in particular for the poliastro project, which has numba as a core requirement) it is fairly critical to be able to pass functions as arguments, because that is what would let us write reusable, fast code (Newton's method is a paradigmatic example). However, I understand that this might be a) difficult to implement, b) less important to the numba developers than other features, c) low-hanging fruit waiting for a champion, or d) something else. I appreciate that it was acknowledged as a bug, so if the numba devs have any more information on when they think this could be fixed (short term, medium term, not before 2019), it would be great to know. Thanks for this wonderful library ❤️


Juanlu001 commented Jul 22, 2018

By the way, it seems that the problem lies in having functions in the signature, not in calling them:

In [12]: @njit
    ...: def njit_newton3(func, x0, fprime):
    ...:     for _ in range(50):
    ...:         fder = 3 * x0**2
    ...:         fval = x0**3 - 1
    ...:         newton_step = fval / fder
    ...:         x = x0 - newton_step
    ...:         if abs(x - x0) < 1.48e-8:
    ...:             return x
    ...:         x0 = x
    ...:         

In [13]: @njit
    ...: def njit_newton4(x0):
    ...:     for _ in range(50):
    ...:         fder = 3 * x0**2
    ...:         fval = x0**3 - 1
    ...:         newton_step = fval / fder
    ...:         x = x0 - newton_step
    ...:         if abs(x - x0) < 1.48e-8:
    ...:             return x
    ...:         x0 = x
    ...:         

In [14]: %timeit njit_newton3(func, 1.5, fprime=fprime)
26.8 µs ± 277 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [15]: %timeit njit_newton4(1.5)
268 ns ± 1.16 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [17]: @njit
    ...: def njit_newton5(x0):
    ...:     for _ in range(50):
    ...:         fder = fprime(x0)
    ...:         fval = func(x0)
    ...:         newton_step = fval / fder
    ...:         x = x0 - newton_step
    ...:         if abs(x - x0) < 1.48e-8:
    ...:             return x
    ...:         x0 = x
    ...:         

In [18]: %timeit njit_newton5(1.5)
278 ns ± 4.77 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

(This is with numba 0.39.)
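As a back-of-the-envelope check of that reading, subtracting the mean timings above separates the fixed cost of having function-typed arguments from the cost of actually calling `func` and `fprime`. This ignores the standard deviations and treats the means as exact.

```python
# Mean timings reported above, in nanoseconds.
t_with_func_args = 26_800  # njit_newton3: functions in the signature, never called
t_no_func_args = 268       # njit_newton4: no function arguments
t_global_funcs = 278       # njit_newton5: calls func/fprime as globals

overhead_ns = t_with_func_args - t_no_func_args   # fixed cost of the arguments
ratio = t_with_func_args / t_no_func_args         # relative slowdown
call_cost_ns = t_global_funcs - t_no_func_args    # cost of the calls themselves

print(f"extra cost per call with function arguments: {overhead_ns / 1000:.1f} µs")
print(f"slowdown factor: {ratio:.0f}x")
print(f"extra cost of calling func/fprime as globals: {call_cost_ns} ns")
```

That works out to about 26.5 µs of extra cost per call, roughly a 100x slowdown. Calling the two functions as globals adds only about 10 ns, which is within a few standard deviations of the baseline.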


Juanlu001 commented Jul 22, 2018

Investigating this further, I discovered that I can work around the issue by using closures:

In [24]: def newton_generator(func, fprime):
    ...:     @njit
    ...:     def njit_newton_final(x0):
    ...:         for _ in range(50):
    ...:             fder = fprime(x0)
    ...:             fval = func(x0)
    ...:             newton_step = fval / fder
    ...:             x = x0 - newton_step
    ...:             if abs(x - x0) < 1.48e-8:
    ...:                 return x
    ...:             x0 = x
    ...:     return njit_newton_final
    ...: 
    ...: 

In [25]: newton_func = newton_generator(func, fprime)

In [26]: newton_func(1.5)
Out[26]: 1.0

In [27]: %timeit newton_func(1.5)
297 ns ± 10.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
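The closure workaround can be packaged so that each `(func, fprime)` pair is compiled only once. Below is a minimal sketch that assumes `functools.lru_cache` is an acceptable way to memoize the factory. `make_newton`, its `tol`/`maxiter` parameters and the `RuntimeError` on non-convergence are illustrative choices, not numba or scipy API. As above, it falls back to plain Python if numba is missing.

```python
from functools import lru_cache

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback so the sketch still runs
    def njit(f):
        return f


@lru_cache(maxsize=None)
def make_newton(func, fprime, tol=1.48e-8, maxiter=50):
    """Build (once per argument combination) a jitted Newton solver that
    closes over func and fprime instead of taking them as arguments.
    Hypothetical helper, not part of numba."""
    @njit
    def newton(x0):
        for _ in range(maxiter):
            x = x0 - func(x0) / fprime(x0)
            if abs(x - x0) < tol:
                return x
            x0 = x
        raise RuntimeError("Newton iteration did not converge")
    return newton


@njit
def func(x):
    return x**3 - 1


@njit
def fprime(x):
    return 3 * x**2


solve = make_newton(func, fprime)
root = solve(1.5)
```

The cache keeps a reference to every solver it builds, so the functions passed in stay alive for the lifetime of the process. `maxsize` can be set if that matters. `tol` and `maxiter` are frozen into the compiled closure, so a different tolerance means a new cache entry and a new compilation.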

seibert added this to Triage in Bugs on Aug 27, 2018


Juanlu001 commented Jan 8, 2019

I have checked that this is still an issue with the latest version. Here is a much simpler example:

In [1]: from numba import njit

In [2]: @njit
   ...: def foo(x):
   ...:     return x
   ...:

In [3]: @njit
   ...: def foo_bad(x, func):
   ...:     return x
   ...:

In [4]: %timeit foo(1)
293 ns ± 2.12 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [5]: %timeit foo_bad(1, foo)
29.3 µs ± 473 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [8]: foo.inspect_types()
foo (int64,)
--------------------------------------------------------------------------------
# File: <ipython-input-2-32059a1fd1a4>
# --- LINE 1 --- 
# label 0

@njit

# --- LINE 2 --- 

def foo(x):

    # --- LINE 3 --- 
    #   x = arg(0, name=x)  :: int64
    #   $0.2 = cast(value=x)  :: int64
    #   del x
    #   return $0.2

    return x


================================================================================

In [9]: foo_bad.inspect_types()
foo_bad (int64, type(CPUDispatcher(<function foo at 0x7fc4a7b20488>)))
--------------------------------------------------------------------------------
# File: <ipython-input-3-3db07e304d7c>
# --- LINE 1 --- 
# label 0

@njit

# --- LINE 2 --- 

def foo_bad(x, func):

    # --- LINE 3 --- 
    #   x = arg(0, name=x)  :: int64
    #   func = arg(1, name=func)  :: type(CPUDispatcher(<function foo at 0x7fc4a7b20488>))
    #   del func
    #   $0.2 = cast(value=x)  :: int64
    #   del x
    #   return $0.2

    return x


================================================================================
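The `inspect_types()` output shows that `func` is typed as `type(CPUDispatcher(<function foo ...>))`, which is a type tied to that one dispatcher object. One consequence worth checking is that passing a different jitted function should compile a separate specialization of `foo_bad`. The sketch below assumes a numba version whose dispatchers expose the `signatures` attribute. `bar` is a hypothetical second function. Without numba it only runs the plain-Python fallback.

```python
try:
    from numba import njit
    HAVE_NUMBA = True
except ImportError:  # plain-Python fallback: nothing is compiled, nothing to inspect
    HAVE_NUMBA = False

    def njit(f):
        return f


@njit
def foo(x):
    return x


@njit
def bar(x):  # hypothetical second jitted function with the same shape as foo
    return x


@njit
def foo_bad(x, func):  # func is never used, exactly as in the report
    return x


# Each distinct jitted function passed as `func` is a distinct argument type.
foo_bad(1, foo)
foo_bad(1, bar)

if HAVE_NUMBA:
    for sig in foo_bad.signatures:
        print(sig)
```

The closure workaround above sidesteps this entirely, since the functions are no longer arguments at all.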