Tuple of functions is not iterable / cannot be accessed by index #7923

Closed
sakshamkumar-byt opened this issue Mar 19, 2022 · 4 comments
Labels
question Notes an issue as a question

Comments

@sakshamkumar-byt

I am using Numba 0.52.0 with Intel SDC for pandas support.
This is what my method looks like:

@overload_method(SeriesType, 'transform')
def pd_series_overload_single_func_args(self, func, *args):
    func_args = [self.dtype]
    if isinstance(self, SeriesType) and isinstance(func, types.Callable):

        func_args.extend(args)
        sig = func.get_call_type(cpu_target.typing_context, func_args, {})
        output_type = sig.return_type

        # find if final arg of function is *args
        def impl(self, func, *args):
            input_arr = self._data
            length = len(input_arr)

            output_arr = numpy.empty(length, dtype=output_type)

            for i in prange(length):
                output_arr[i] = func(input_arr[i], *args)

            return pandas.Series(output_arr, index=self._index, name=self._name)

        return impl
    elif isinstance(self, SeriesType) and isinstance(func, types.Tuple):
        output_types = []
        output_cols = []
        n_series = len(func)

        for i in prange(n_series):
            sig = func[i].get_call_type(cpu_target.typing_context, func_args, {})
            output_types.append(sig.return_type)
            output_cols.append(func[i].dispatcher.py_func.__name__)
            # print(sig.return_type)
            # print(func[i].dispatcher.py_func.__name__)

        col_names = tuple(output_cols)
        col_dtypes = tuple(output_types)

        def impl(self, func, *args):
            for i in func:
                print(i)

        return impl

When I try to run:

@njit
def square(x, *args):
    return x ** 2

@njit()
def sum(x, *args):
    return x + 1

@njit
def series_apply():
    s = pd.Series([20.12, 21.2, 12.3],
                  index=['London', 'New York', 'Helsinki'])
    func = (sum, square)
    return s.transform(func)

this fails with the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/saksham.kumar/miniconda3/envs/sdc-env/lib/python3.7/site-packages/numba/core/dispatcher.py", line 414, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/home/saksham.kumar/miniconda3/envs/sdc-env/lib/python3.7/site-packages/numba/core/dispatcher.py", line 357, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function pd_series_overload_single_func_args at 0x7fe287830dd0>) found for signature:

 >>> pd_series_overload_single_func_args(series(float64, array(float64, 1d, C), StringArrayType(), False), Tuple(type(CPUDispatcher(<function sum at 0x7fe28efeee60>)), type(CPUDispatcher(<function square at 0x7fe2976663b0>))))

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'pd_series_overload_single_func_args': File: ../../../../../../../../data00/home/saksham.kumar/sdc_seconds/sdc/datatypes/pandas_series_functions/transform.py: Line 11.
    With argument(s): '(series(float64, array(float64, 1d, C), StringArrayType(), False), Tuple(type(CPUDispatcher(<function sum at 0x7fe28efeee60>)), type(CPUDispatcher(<function square at 0x7fe2976663b0>))))':
   Rejected as the implementation raised a specific error:
     TypingError: Failed in nopython mode pipeline (step: nopython frontend)
   Invalid use of getiter with parameters (Tuple(type(CPUDispatcher(<function sum at 0x7fe28efeee60>)), type(CPUDispatcher(<function square at 0x7fe2976663b0>))))
   
   During: typing of intrinsic-call at /data00/home/saksham.kumar/sdc_seconds/sdc/datatypes/pandas_series_functions/transform.py (75)
   
   File "sdc/datatypes/pandas_series_functions/transform.py", line 75:
           def impl(self, func, *args):
               <source elided>
               length = len(input_arr)
               for i in func:
               ^

  raised from /home/saksham.kumar/miniconda3/envs/sdc-env/lib/python3.7/site-packages/numba/core/typeinfer.py:1071

- Resolution failure for non-literal arguments:
None

During: resolving callee type: BoundFunction((<class 'sdc.hiframes.pd_series_type.SeriesType'>, 'transform') for series(float64, array(float64, 1d, C), StringArrayType(), False))
During: typing of call at <stdin> (1)


File "<stdin>", line 1:
<source missing, REPL/exec in use?>

The same also happens when trying to run with

@overload_method(SeriesType, 'transform')
def pd_series_overload_single_func_args(self, func, *args):
    func_args = [self.dtype]
    if isinstance(self, SeriesType) and isinstance(func, types.Callable):

        func_args.extend(args)
        sig = func.get_call_type(cpu_target.typing_context, func_args, {})
        output_type = sig.return_type

        # find if final arg of function is *args
        def impl(self, func, *args):
            input_arr = self._data
            length = len(input_arr)

            output_arr = numpy.empty(length, dtype=output_type)

            for i in prange(length):
                output_arr[i] = func(input_arr[i], *args)

            return pandas.Series(output_arr, index=self._index, name=self._name)

        return impl
    elif isinstance(self, SeriesType) and isinstance(func, types.Tuple):
        output_types = []
        output_cols = []
        n_series = len(func)

        for i in prange(n_series):
            sig = func[i].get_call_type(cpu_target.typing_context, func_args, {})
            output_types.append(sig.return_type)
            output_cols.append(func[i].dispatcher.py_func.__name__)
            # print(sig.return_type)
            # print(func[i].dispatcher.py_func.__name__)

        col_names = tuple(output_cols)
        col_dtypes = tuple(output_types)

        def impl(self, func, *args):
            print(func[0])

        return impl
@stuartarchibald
Contributor

Thanks for the report. From the above, I think a minimal working reproducer is:

from numba import njit

@njit
def square(x):
    return x ** 2

@njit
def sum(x):
    return x + 1

@njit
def foo():
    funcs = (sum, square)
    for fn in funcs:
        print(fn(10.))

foo()

and that the error is about iterating over a tuple of functions.

The underlying issue is that Numba types a tuple like funcs = (sum, square) as a heterogeneous tuple of dispatcher types (dispatchers are the objects returned by @njit). Because the tuple is heterogeneous, it is not possible to do a getitem on it with a non-constant index, or to iterate over it, as Numba cannot know which dispatcher type will be returned.

There are two ways around this:

  1. Use numba.literal_unroll. This function unrolls the tuple, generates a version of the loop body for each type, and then wires up the iteration space so that it behaves like iterating over the tuple. Example:

    from numba import njit, literal_unroll
    
    @njit
    def square(x):
        return x ** 2
    
    @njit
    def sum(x):
        return x + 1
    
    @njit
    def foo():
        funcs = (sum, square)
        for fn in literal_unroll(funcs):
            print(fn(10.))
    
    foo()
  2. When declaring the functions that go into the tuple, specify one or more signatures. This lets Numba treat the functions as first-class function types, so the tuple is typed as a homogeneous tuple of first-class functions with a given signature. getitem and iteration then work because the type of the variable is the same in each iteration. Example:

    from numba import njit
    
    @njit('float64(float64)')
    def square(x):
        return x ** 2
    
    @njit('float64(float64)')
    def sum(x):
        return x + 1
    
    @njit
    def foo():
        funcs = (sum, square)
        for fn in funcs:
            print(fn(10.))
    
    foo()

Hope this helps.

@stuartarchibald stuartarchibald added the question Notes an issue as a question label Mar 21, 2022
@sakshamkumar-byt
Author

@stuartarchibald I see.
However, I was able to directly access an element using func[0] or func[1]. The problem seems to arise only when using a loop.

This is why my impl looks something like this:

    elif isinstance(self, SeriesType) and isinstance(func, types.Tuple):
        output_types = []
        n_series = len(func)

        for i in prange(n_series):
            sig = func[i].get_call_type(cpu_target.typing_context, func_args, {})
            output_types.append(sig.return_type)

        func_lines = [f"def impl(self, func, *args):",
                      f"  input_arr = self._data",
                      f"  length = len(input_arr)"]

        results = []
        for i in range(n_series):
            result_c = f"s_{i}"
            func_lines += [f"  output_arr_{i} = numpy.empty(length, dtype=types.{output_types[i]})",
                           f"  for i in prange(length):",
                           f"    output_arr_{i}[i] = func[{i}](input_arr[i], *args)",
                           f"  print(output_arr_{i})"]
        func_text = '\n'.join(func_lines)
        loc_vars = {}
        exec(func_text, global_vars, loc_vars)
        _impl = loc_vars['impl']

        return _impl

This works fine as the func tuple is not accessed by a loop, but it takes away any parallelism that prange may provide.
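
The code-generation idea above, stripped of Numba and SDC specifics, can be sketched in plain Python as follows. This is only an illustration of the technique (emit one line per tuple element so that every index is a literal constant, then exec the text); make_apply_all and apply_all are hypothetical names and not part of the original implementation.

```python
def make_apply_all(n_funcs):
    """Build a function that calls funcs[0], funcs[1], ... by constant index."""
    lines = ["def apply_all(funcs, value):"]
    for k in range(n_funcs):
        # Each generated line indexes the tuple with a literal constant.
        lines.append(f"    r_{k} = funcs[{k}](value)")
    # Collect the per-function results into a tuple.
    results = "".join(f"r_{k}, " for k in range(n_funcs))
    lines.append(f"    return ({results})")

    namespace = {}
    exec("\n".join(lines), {}, namespace)
    return namespace["apply_all"]


apply_all = make_apply_all(2)
print(apply_all((lambda x: x + 1, lambda x: x ** 2), 10.0))
```

Because the generated source contains no loop over the tuple, a compiler that needs a static type per element never has to resolve a runtime index.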

@stuartarchibald
Contributor

@stuartarchibald I see. However, I was able to directly access an element using func[0] or func[1]. The problem seems to arise only when using a loop.

If the index is a compile-time constant (like the literal 0 or 1), it's possible to work out the type of the dispatcher at that index, so it's not a problem. Difficulties arise only in the loop case, for example, in:

@njit
def foo():
    funcs = (sum, square)
    for fn in funcs:
        print(fn(10.))

the type of fn changes throughout execution.

This works fine as the func tuple is not accessed by a loop, but it takes away any parallelism that prange may provide.

You can still use the parallelism from prange if signatures are provided for the functions in the tuple such that they are considered first-class function types, for example:

from numba import njit, prange

@njit('float64(float64)')
def square(x):
    return x ** 2

@njit('float64(float64)')
def sum(x):
    return x + 1

@njit(parallel=True)
def foo():
    funcs = (sum, square)
    for idx in prange(len(funcs)):
        fn = funcs[idx]
        print(fn(10.))

foo()

foo.parallel_diagnostics()

@stuartarchibald
Contributor

Closing this question as it seems to be resolved. Numba now has a Discourse forum, https://numba.discourse.group/, which is great for questions like this; please do consider posting there in future :) Thanks!
