Issue with using overload on a vectorized function #4029

Open

person142 opened this issue Apr 29, 2019 · 8 comments

@person142
Contributor

If I try to overload a function decorated with numba.vectorize:

In [40]: @numba.vectorize
    ...: def f(x):
    ...:     pass
    ...:

In [41]: @numba.extending.overload(f)
    ...: def f_overload(x):
    ...:     if x == numba.types.float64:
    ...:         return lambda x: x
    ...:

In [42]: @numba.njit
    ...: def g(x):
    ...:     return f(x)
    ...:

In [43]: g(1.0)

then I get this traceback:

Traceback
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
~/Projects/spycial-dev/lib/python3.7/site-packages/numba/errors.py in new_error_context(fmt_, *args, **kwargs)
    626     try:
--> 627         yield
    628     except NumbaError as e:

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/lowering.py in lower_block(self, block)
    257                                    loc=self.loc, errcls_=defaulterrcls):
--> 258                 self.lower_inst(inst)
    259

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/lowering.py in lower_inst(self, inst)
    300             ty = self.typeof(inst.target.name)
--> 301             val = self.lower_assign(ty, inst)
    302             self.storevar(val, inst.target.name)

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/lowering.py in lower_assign(self, ty, inst)
    453         elif isinstance(value, ir.Expr):
--> 454             return self.lower_expr(ty, value)
    455

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/lowering.py in lower_expr(self, resty, expr)
    912         elif expr.op == 'call':
--> 913             res = self.lower_call(resty, expr)
    914             return res

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/lowering.py in lower_call(self, resty, expr)
    705         else:
--> 706             res = self._lower_call_normal(fnty, expr, signature)
    707

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/lowering.py in _lower_call_normal(self, fnty, expr, signature)
    884
--> 885         res = impl(self.builder, argvals, self.loc)
    886         return res

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/targets/base.py in __call__(self, builder, args, loc)
   1131     def __call__(self, builder, args, loc=None):
-> 1132         res = self._imp(self._context, builder, self._sig, args, loc=loc)
   1133         self._context.add_linking_libs(getattr(self, 'libs', ()))

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/targets/base.py in wrapper(*args, **kwargs)
   1156             kwargs.pop('loc')     # drop unused loc
-> 1157             return fn(*args, **kwargs)
   1158

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/npyufunc/dufunc.py in __call__(self, context, builder, sig, args)
     64                                           self.kernel,
---> 65                                           explicit_output=explicit_output)
     66

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/targets/npyimpl.py in numpy_ufunc_kernel(context, builder, sig, args, kernel_class, explicit_output)
    320                 context, builder,
--> 321                 lc.Constant.null(context.get_value_type(ret_ty)), ret_ty)
    322         arguments.append(output)

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/targets/npyimpl.py in _prepare_argument(ctxt, bld, inp, tyinp, where)
    172     else:
--> 173         raise NotImplementedError('unsupported type for {0}: {1}'.format(where, str(tyinp)))
    174

NotImplementedError: unsupported type for input operand: none

During handling of the above exception, another exception occurred:

LoweringError                             Traceback (most recent call last)
<ipython-input-43-b4ad4e189a16> in <module>
----> 1 g(1.0)

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
    367                     e.patch_message(''.join(e.args) + help_msg)
    368             # ignore the FULL_TRACEBACKS config, this needs reporting!
--> 369             raise e
    370
    371     def inspect_llvm(self, signature=None):

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/dispatcher.py in _compile_for_args(self, *args, **kws)
    324                 argtypes.append(self.typeof_pyval(a))
    325         try:
--> 326             return self.compile(tuple(argtypes))
    327         except errors.TypingError as e:
    328             # Intercept typing error that may be due to an argument

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/dispatcher.py in compile(self, sig)
    656
    657             self._cache_misses[sig] += 1
--> 658             cres = self._compiler.compile(args, return_type)
    659             self.add_overload(cres)
    660             self._cache.save_overload(sig, cres)

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/dispatcher.py in compile(self, args, return_type)
     80                                       args=args, return_type=return_type,
     81                                       flags=flags, locals=self.locals,
---> 82                                       pipeline_class=self.pipeline_class)
     83         # Check typing error if object mode is used
     84         if cres.typing_error is not None and not flags.enable_pyobject:

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/compiler.py in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    939     pipeline = pipeline_class(typingctx, targetctx, library,
    940                               args, return_type, flags, locals)
--> 941     return pipeline.compile_extra(func)
    942
    943

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/compiler.py in compile_extra(self, func)
    370         self.lifted = ()
    371         self.lifted_from = None
--> 372         return self._compile_bytecode()
    373
    374     def compile_ir(self, func_ir, lifted=(), lifted_from=None):

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/compiler.py in _compile_bytecode(self)
    870         """
    871         assert self.func_ir is None
--> 872         return self._compile_core()
    873
    874     def _compile_ir(self):

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/compiler.py in _compile_core(self)
    857         self.define_pipelines(pm)
    858         pm.finalize()
--> 859         res = pm.run(self.status)
    860         if res is not None:
    861             # Early pipeline completion

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/compiler.py in run(self, status)
    251                     # No more fallback pipelines?
    252                     if is_final_pipeline:
--> 253                         raise patched_exception
    254                     # Go to next fallback pipeline
    255                     else:

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/compiler.py in run(self, status)
    242                 try:
    243                     event(stage_name)
--> 244                     stage()
    245                 except _EarlyPipelineCompletion as e:
    246                     return e.result

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/compiler.py in stage_nopython_backend(self)
    729         """
    730         lowerfn = self.backend_nopython_mode
--> 731         self._backend(lowerfn, objectmode=False)
    732
    733     def stage_compile_interp_mode(self):

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/compiler.py in _backend(self, lowerfn, objectmode)
    679             self.library.enable_object_caching()
    680
--> 681         lowered = lowerfn()
    682         signature = typing.signature(self.return_type, *self.args)
    683         self.cr = compile_result(

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/compiler.py in backend_nopython_mode(self)
    666                 self.calltypes,
    667                 self.flags,
--> 668                 self.metadata)
    669
    670     def _backend(self, lowerfn, objectmode):

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/compiler.py in native_lowering_stage(targetctx, library, interp, typemap, restype, calltypes, flags, metadata)
   1061         lower = lowering.Lower(targetctx, library, fndesc, interp,
   1062                                metadata=metadata)
-> 1063         lower.lower()
   1064         if not flags.no_cpython_wrapper:
   1065             lower.create_cpython_wrapper(flags.release_gil)

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/lowering.py in lower(self)
    175         if self.generator_info is None:
    176             self.genlower = None
--> 177             self.lower_normal_function(self.fndesc)
    178         else:
    179             self.genlower = self.GeneratorLower(self)

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/lowering.py in lower_normal_function(self, fndesc)
    216         # Init argument values
    217         self.extract_function_arguments()
--> 218         entry_block_tail = self.lower_function_body()
    219
    220         # Close tail of entry block

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/lowering.py in lower_function_body(self)
    241             bb = self.blkmap[offset]
    242             self.builder.position_at_end(bb)
--> 243             self.lower_block(block)
    244
    245         self.post_lower()

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/lowering.py in lower_block(self, block)
    256             with new_error_context('lowering "{inst}" at {loc}', inst=inst,
    257                                    loc=self.loc, errcls_=defaulterrcls):
--> 258                 self.lower_inst(inst)
    259
    260     def create_cpython_wrapper(self, release_gil=False):

/usr/local/Cellar/python/3.7.1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/contextlib.py in __exit__(self, type, value, traceback)
    128                 value = type()
    129             try:
--> 130                 self.gen.throw(type, value, traceback)
    131             except StopIteration as exc:
    132                 # Suppress StopIteration *unless* it's the same exception that

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/errors.py in new_error_context(fmt_, *args, **kwargs)
    633         from numba import config
    634         tb = sys.exc_info()[2] if config.FULL_TRACEBACKS else None
--> 635         six.reraise(type(newerr), newerr, tb)
    636
    637

~/Projects/spycial-dev/lib/python3.7/site-packages/numba/six.py in reraise(tp, value, tb)
    657         if value.__traceback__ is not tb:
    658             raise value.with_traceback(tb)
--> 659         raise value
    660
    661 else:

LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
unsupported type for input operand: none

File "<ipython-input-42-b6814e4dfb5d>", line 3:
def g(x):
    return f(x)
    ^

[1] During: lowering "$0.3 = call $0.1(x, func=$0.1, args=[Var(x, <ipython-input-42-b6814e4dfb5d> (3))], kws=(), vararg=None)" at <ipython-input-42-b6814e4dfb5d> (3)
-------------------------------------------------------------------------------
This should not have happened, a problem has occurred in Numba's internals.

Please report the error message and traceback, along with a minimal reproducer
at: https://github.com/numba/numba/issues/new

If more help is needed please feel free to speak to the Numba core developers
directly at: https://gitter.im/numba/numba

Thanks in advance for your help in improving Numba!

which ends with a request to report the issue:

This should not have happened, a problem has occurred in Numba's internals.

Please report the error message and traceback, along with a minimal reproducer
at: https://github.com/numba/numba/issues/new

If more help is needed please feel free to speak to the Numba core developers
directly at: https://gitter.im/numba/numba

Thanks in advance for your help in improving Numba!

If I do the same thing but without the vectorize decorator, then the overload works as expected:

In [36]: def f(x):
    ...:     pass
    ...:

In [37]: @numba.extending.overload(f)
    ...: def f_overload(x):
    ...:     if x == numba.types.float64:
    ...:         return lambda x: x
    ...:

In [38]: @numba.njit
    ...: def g(x):
    ...:     return f(x)
    ...:

In [39]: g(1.0)
Out[39]: 1.0
@stuartarchibald
Contributor

Thanks for the report. Whilst I'm fairly sure that applying @overload to a function that is already decorated with a different kind of jit decorator would not work as expected, the immediate issue here is that the ufunc itself is invalid because it doesn't return a scalar value of a NumPy type:

import numba

@numba.vectorize
def f(x):
    pass

@numba.njit
def g(x):
    return f(x)

g(1.0)

reproduces the error you see.

If, however, you change f so that the code looks like:

import numba

@numba.vectorize
def f(x):
    return 1

@numba.njit
def g(x):
    return f(x)

g(1.0)

this works fine.

Once fixed up, going back to the @overload, this:

import numba

@numba.vectorize
def f(x):
    return 1

@numba.extending.overload(f)
def f_overload(x):
    if x == numba.types.float64:
        return lambda x: x

@numba.njit
def g(x):
    return f(x)


print(g(10.0))

prints the value 1. I think this is because f is already resolved in the type system, so the @overload won't even get called.
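
For completeness, here is a minimal sketch (hypothetical, not something tested in this thread) of one way to get an @overload that actually fires: keep the @overload on a plain Python scalar function, as in the working example from the report, and build the ufunc separately on top of it. The name scalar_f is made up for illustration.

import numba
import numpy as np

# Plain Python scalar function; the @overload below supplies its
# nopython-mode implementation.
def scalar_f(x):
    return x

@numba.extending.overload(scalar_f)
def scalar_f_overload(x):
    if x == numba.types.float64:
        return lambda x: x

# The ufunc is built separately; its kernel just calls the overloaded
# scalar function, which resolves through the @overload in nopython mode.
@numba.vectorize(['float64(float64)'])
def f(x):
    return scalar_f(x)

# nopython-mode callers use the scalar function directly...
@numba.njit
def g(x):
    return scalar_f(x)

print(g(1.0))                   # 1.0, via the @overload
# ...while array users call the ufunc.
print(f(np.array([1.0, 2.0])))  # [1. 2.]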

@person142
Contributor Author

Thanks for looking into that @stuartarchibald. My original goal was basically to do your fixed-up example.

@seibert
Contributor

seibert commented Apr 30, 2019

Breaking this down a bit, it looks like you want to select between two implementations:

  • one for scalars that is a regular @jit function
  • one for arrays that is a ufunc @vectorize function

The standard way to do this is to use @generated_jit to select an implementation based on argument types, which will then be compiled as a @jit function. Since @jit functions can call @vectorize compiled functions, you can do the following:

import numba
import numpy as np

@numba.vectorize
def f(x):
    return 1

@numba.generated_jit(nopython=True)
def g(x):
    # args are types, not values
    if x is numba.types.float64:
        def scalar_impl(x):
            return 2.0
        return scalar_impl
    else:
        def vector_impl(x):
            return f(x)  # call the ufunc defined above
        return vector_impl

which then results in this:

>>> print(g(5.0))
2.0
>>> print(g(np.arange(5)))
[1 1 1 1 1]
>>> print(g(np.arange(6).reshape(2,3)))
[[1 1 1]
 [1 1 1]]

(The only minor improvement I would prefer here is being able to return a @vectorize function directly from a @generated_jit, rather than having to wrap it in a regular Python function.) Note that this doesn't work for @guvectorize functions due to an existing limitation on calling those functions from nopython mode.

@person142
Contributor Author

@seibert that would probably work for what I'm trying to do. (I'll try it out later today!)

I am, however, in the situation where my vectorize functions are already using generated_jit to dispatch to different implementations, so I'll need a lot of boilerplate to get everything going (which isn't a huge deal at the end of the day). Here's a concrete example of what I'm doing now:

https://github.com/person142/spycial/blob/master/spycial/trig.py#L119
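
Roughly speaking, the pattern is something like this (a simplified, hypothetical sketch with placeholder names and bodies, not the actual spycial code): a @generated_jit dispatcher wrapped by a @vectorize kernel.

import numba

@numba.generated_jit(nopython=True)
def _kernel(x):
    # x is a Numba type here, not a value
    if x == numba.types.float32:
        return lambda x: x        # placeholder single-precision implementation
    else:
        return lambda x: 2.0 * x  # placeholder double-precision implementation

@numba.vectorize(['float32(float32)', 'float64(float64)'])
def f(x):
    return _kernel(x)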

@seibert
Contributor

seibert commented Apr 30, 2019

Hmm, that is an interesting situation. Do you need the top-level user function to be a ufunc? It seems like the best way to straighten this out would be to do all the type-based selection of implementations in the top-level function, and call out to different ufuncs or regular functions as needed.

@seibert
Contributor

seibert commented Apr 30, 2019

No matter what mechanism we offer, I think the metaprogramming you are trying to do will be challenging. :)

@person142
Contributor Author

person142 commented Apr 30, 2019

Do you need the top-level user function to be a ufunc? It seems like the best way to straighten this out would be to do all the type-based selection of implementations at the top level function, and call out to different ufuncs or regular functions as needed.

Yeah, that's probably the right way to go. There was appeal in having ufuncs because it gave "drop-in" replacements (though in my case reduce etc. aren't useful). But it also comes with a bunch of drawbacks: users can't specify target, for example.

I'd really prefer to just provide overloaded scalar functions and let users call vectorize as desired, but I worry it's a little too much friction. An alternative is to provide a split API: a bunch of scalar kernels in one module and corresponding ufuncs in another (roughly as sketched below), but then I got to wondering whether I could have my cake and eat it too...
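
To make the split-API idea concrete, here is a rough sketch (module and function names are hypothetical, bodies are placeholders):

# spycial/_scalar.py: scalar kernels, callable from nopython code
import numba

@numba.njit
def gamma_scalar(x):
    return x  # placeholder body; the real kernel would go here

# spycial/ufuncs.py: corresponding ufuncs for array users
@numba.vectorize(['float64(float64)'])
def gamma(x):
    return gamma_scalar(x)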

@person142
Contributor Author

Hm, the generated_jit approach does have a big drawback compared to a ufunc:

In [16]: g(np.array([1.0]))
Out[16]: array([1])

In [17]: g([1.0])
NotImplementedError                       Traceback (most recent call last)
...
LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
unsupported type for input operand: reflected list(float64)

File "<ipython-input-10-c3444b0c0ec1>", line 9:
        def vector_impl(x):
            return f(x)
            ^

[1] During: lowering "$0.3 = call $0.1(x, func=$0.1, args=[Var(x, <ipython-input-10-c3444b0c0ec1> (9))], kws=(), vararg=None)" at <ipython-input-10-c3444b0c0ec1> (9)
-------------------------------------------------------------------------------
This should not have happened, a problem has occurred in Numba's internals.

Please report the error message and traceback, along with a minimal reproducer
at: https://github.com/numba/numba/issues/new

If more help is needed please feel free to speak to the Numba core developers
directly at: https://gitter.im/numba/numba

Thanks in advance for your help in improving Numba!
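
(A possible workaround, assuming the generated_jit g from the example above and that array semantics are acceptable: coerce list inputs to ndarrays before calling, which is effectively what a real ufunc would do for you.)

import numpy as np

g(np.asarray([1.0]))  # works: array([1])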
