Cannot compile in nopython mode for a "specialized" function #5817

Closed
flothesof opened this issue Jun 4, 2020 · 6 comments
Labels
more info needed (this issue needs more information), question (notes an issue as a question)

Comments

@flothesof

flothesof commented Jun 4, 2020

Hi,

I am working on a physics problem where I assemble a matrix using an analytic function which has parameters of two kinds:

  • ones available ahead of time (predef_param below)
  • ones that change during the assembling of the matrix (runtime_param below)

To make my assembling code more general, I have tried to separate both concerns (ahead of time vs. runtime) so that I can reuse it.

As a result, I need a function simpler_func with fewer parameters that calls the original complex_func.

I have code that looks approximately like this:

import numba as nb
import numpy as np

PARAM = 5.

def complex_func(runtime_param, predef_param):
    return np.cos(predef_param) ** runtime_param

@nb.jit(forceobj=True)
def simpler_func(runtime_param):
    return complex_func(runtime_param, PARAM)

@nb.njit
def assemble_matrix(N):
    mat = np.zeros((N, N))
    mapping = np.arange(N)
    for i in range(N):
        I = mapping[i]
        for j in range(N):
            J = mapping[j]
            mat[i, j] = simpler_func(I + J)
    return mat

def main():
    assemble_matrix(5)

Running the main function results in the following error with numba 0.49.1:

TypingError                               Traceback (most recent call last)
<ipython-input-25-263240bbee7e> in <module>
----> 1 main()

<ipython-input-24-22ddcf18b8e0> in main()
     20 
     21 def main():
---> 22     assemble_matrix(5)

C:\Anaconda3\envs\pyrus_dev\lib\site-packages\numba\core\dispatcher.py in _compile_for_args(self, *args, **kws)
    399                 e.patch_message(msg)
    400 
--> 401             error_rewrite(e, 'typing')
    402         except errors.UnsupportedError as e:
    403             # Something unsupported is present in the user code, add help info

C:\Anaconda3\envs\pyrus_dev\lib\site-packages\numba\core\dispatcher.py in error_rewrite(e, issue_type)
    342                 raise e
    343             else:
--> 344                 reraise(type(e), e, None)
    345 
    346         argtypes = []

C:\Anaconda3\envs\pyrus_dev\lib\site-packages\numba\core\utils.py in reraise(tp, value, tb)
     78         value = tp()
     79     if value.__traceback__ is not tb:
---> 80         raise value.with_traceback(tb)
     81     raise value
     82 

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of type(CPUDispatcher(<function simpler_func at 0x0000024D428E8EE8>)) with parameters (int64)
 * parameterized
[1] During: resolving callee type: type(CPUDispatcher(<function simpler_func at 0x0000024D428E8EE8>))
[2] During: typing of call at <ipython-input-24-22ddcf18b8e0> (18)


File "<ipython-input-24-22ddcf18b8e0>", line 18:
def assemble_matrix(N):
    <source elided>
            J = mapping[j]
            mat[i, j] = simpler_func(I + J)

Is there a way to make this work? What am I doing wrong? Does this have something to do with the fact that functions can't be passed around well?

Thank you in advance for your help.

Florian

@flothesof
Author

I have also tried to rewrite the above as closure-style code (inspired by the docs example https://numba.pydata.org/numba-doc/latest/user/faq.html?highlight=function#can-i-pass-a-function-as-an-argument-to-a-jitted-function), but this fails too:

import numba as nb
import numpy as np

@nb.jit
def complex_func(runtime_param, predef_param):
    return np.cos(predef_param) ** runtime_param

def make_simpler_func(complex_func):
    PARAM = 5.

    @nb.jit(forceobj=True)
    def simpler_func(runtime_param):
        return complex_func(runtime_param, PARAM)
    
    return simpler_func

@nb.njit
def assemble_matrix(N, simpler_func):
    mat = np.zeros((N, N))
    mapping = np.arange(N)
    for i in range(N):
        I = mapping[i]
        for j in range(N):
            J = mapping[j]
            mat[i, j] = simpler_func(I + J)
    return mat

def main():
    simpler_func = make_simpler_func(complex_func)
    assemble_matrix(5, simpler_func)

Traceback:

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<ipython-input-39-263240bbee7e> in <module>
----> 1 main()

<ipython-input-34-bfc00d884e85> in main()
     27 def main():
     28     simpler_func = make_simpler_func(complex_func)
---> 29     assemble_matrix(5, simpler_func)

C:\Anaconda3\envs\pyrus_dev\lib\site-packages\numba\core\dispatcher.py in _compile_for_args(self, *args, **kws)
    399                 e.patch_message(msg)
    400 
--> 401             error_rewrite(e, 'typing')
    402         except errors.UnsupportedError as e:
    403             # Something unsupported is present in the user code, add help info

C:\Anaconda3\envs\pyrus_dev\lib\site-packages\numba\core\dispatcher.py in error_rewrite(e, issue_type)
    342                 raise e
    343             else:
--> 344                 reraise(type(e), e, None)
    345 
    346         argtypes = []

C:\Anaconda3\envs\pyrus_dev\lib\site-packages\numba\core\utils.py in reraise(tp, value, tb)
     78         value = tp()
     79     if value.__traceback__ is not tb:
---> 80         raise value.with_traceback(tb)
     81     raise value
     82 

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of type(CPUDispatcher(<function make_simpler_func.<locals>.simpler_func at 0x00000278B1E25318>)) with parameters (int64)
 * parameterized
[1] During: resolving callee type: type(CPUDispatcher(<function make_simpler_func.<locals>.simpler_func at 0x00000278B1E25318>))
[2] During: typing of call at <ipython-input-34-bfc00d884e85> (24)


File "<ipython-input-34-bfc00d884e85>", line 24:
def assemble_matrix(N, simpler_func):
    <source elided>
            J = mapping[j]
            mat[i, j] = simpler_func(I + J)
            ^

@gmarkall
Member

gmarkall commented Jun 4, 2020

Can you explain a bit about what prevents you from jitting complex_func, i.e. doing something like:

from numba import jit, njit, objmode
import numpy as np

PARAM = 5.


@njit
def complex_func(runtime_param, predef_param):
    return np.cos(predef_param) ** runtime_param


@njit
def simpler_func(runtime_param):
    return complex_func(runtime_param, PARAM)


@njit
def assemble_matrix(N):
    mat = np.zeros((N, N))
    mapping = np.arange(N)
    for i in range(N):
        I = mapping[i]
        for j in range(N):
            J = mapping[j]
            mat[i, j] = simpler_func(I + J)
    return mat


def main():
    assemble_matrix(5)


if __name__ == '__main__':
    main()

?

@gmarkall added the "more info needed" and "question" labels on Jun 4, 2020
@HPLegion
Contributor

HPLegion commented Jun 4, 2020

Not a solution, but a clarification as to what may be going wrong here:
You cannot call a function that was compiled in object mode (forceobj=True) from a function that is compiled in nopython mode, as this example illustrates:

from numba import jit, njit

@jit(forceobj=True)
def f(x):
    return x

@njit
def g(x):
    return f(x)

g(1)

Invalid use of type(CPUDispatcher(<function f at 0x7f7c5b4ae4d0>)) with parameters (int64)
 * parameterized
[1] During: resolving callee type: type(CPUDispatcher(<function f at 0x7f7c5b4ae4d0>))
[2] During: typing of call at njit_calls_jit.py (9)


File "njit_calls_jit.py", line 9:
def g(x):
    return f(x)
    ^

Maybe there would be a way to do this by explicitly using the object mode context manager but this may bring with it a hefty performance penalty.

@flothesof
Author

Hi @gmarkall and @HPLegion
Thank you very much for your replies, which are very helpful.
It turns out that I did have a reason for using the forceobj=True flag: I was using a function from scipy:

from numba import jit, njit, objmode
from scipy.special import factorial2
import numpy as np

PARAM = 5.


@njit
def complex_func(runtime_param, predef_param):
    return np.cos(predef_param) ** runtime_param * factorial2(predef_param)


@njit
def simpler_func(runtime_param):
    return complex_func(runtime_param, PARAM)


@njit
def assemble_matrix(N):
    mat = np.zeros((N, N))
    mapping = np.arange(N)
    for i in range(N):
        I = mapping[i]
        for j in range(N):
            J = mapping[j]
            mat[i, j] = simpler_func(I + J)
    return mat


def main():
    return assemble_matrix(5)

if __name__ == '__main__':
    main()

On my machine this does not work, while @gmarkall's original example indeed works fine.

Based on your inputs, I think I can now reframe my question as: is it possible to have an njitted function call functions that cannot be njitted?

@HPLegion you said it might be possible through the object mode context manager. I'll try looking into it.

Thanks for your help!

@stuartarchibald
Contributor

@flothesof perhaps take a look at http://numba.pydata.org/numba-doc/latest/user/withobjmode.html#the-objmode-context-manager, it seems like it would fit your use case.

@flothesof
Author

Indeed, this might actually allow me to do it "as is".

In the case I was working on when I submitted the issue, the blocker really was the scipy.special.factorial2 function. Reimplementing it in numba-accelerated Python let me do what I wanted.
I will probably face more elaborate cases in the near future where I won't know how to reimplement the function in pure Python, in which case I'll try to use object mode.

Given that my problem is solved and a workaround provided for further cases, I'll close the issue now. Feel free to reopen if you feel it's a better choice.

Thanks to all for your help.

Regards
Florian


4 participants