Support builtins in @overload #3030
#2297 is a step in that direction. This would allow to … This PR is an attempt at the implementation. However, I am not sure the approach taken is the best choice: it required more invasive changes than I initially thought. Maybe starting at the other end by modifying … Also, the PR does not cover …
I hacked together a new version of `overload` that works for builtins:

```python
import numba
import numba.extending
import numba.targets.imputils


def overload_any(func):
    """
    Like `numba.extending.overload` but works for things like `getitem`, etc.

    Used like `generated_jit`:

        @overload_any("getitem")
        def getitem_const(val, i):
            if val.value == "hi":
                return lambda val, i: i
            elif val.value == "there":
                return lambda val, i: -i
    """
    def inner(overload_func):
        # lower dispatcher based on `numba.typing.templates._OverloadMethodTemplate.do_class_init`
        dispatcher = numba.generated_jit(nopython=True)(overload_func)
        disp_type = numba.types.Dispatcher(dispatcher)

        def impl(context, builder, sig, args):
            call = context.get_function(disp_type, sig)
            return call(builder, args)

        @numba.extending.type_callable(func)
        def type_inner(context):
            # Bind `dispatcher` as a default argument to keep a strong reference alive:
            # `disp_type` only holds a weak reference, and without this numba raises
            # "underlying object has vanished".
            def typer(*args, dispatcher=dispatcher):
                try:
                    sig = disp_type.get_call_type(context, args, {})
                except TypeError:  # None returned by overloaded function
                    return
                if sig:
                    # ideally, instead of adding a lowering for this specific type, we would just return the `impl`
                    # with the typing so it doesn't have to look it up. I am not sure how to do this in `type_callable`, though.
                    numba.targets.imputils.lower_builtin(func, *sig.args)(impl)
                    return sig.return_type
            return typer

        # Return the dispatcher (as `generated_jit` does) so the decorated name is not `None`.
        return dispatcher
    return inner
```

You use it like this:

```python
import numba


@overload_any("getitem")
def getitem_const(val, i):
    if val.value == "hi":
        return lambda val, i: i
    elif val.value == "there":
        return lambda val, i: -i


@numba.njit
def hi(i):
    return "there"[i], "hi"[i]


assert hi(10) == (-10, 10)
```
Feature request

High-level extension APIs such as `@overload` should support builtin operators (e.g. `setitem`, `in`). Feature parity with the low-level extension APIs is desirable in general.
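Python's `operator` module already exposes the builtin operators as ordinary functions (`operator.setitem`, `operator.contains`, and so on), which makes them natural keys for an `@overload`-style registry. The sketch below is stdlib-only and assumes overloads are looked up by these function objects; `overload` and `OVERLOADS` are hypothetical names, not numba's API. One subtlety for `in`: the expression `item in container` corresponds to `operator.contains(container, item)`, so the container comes first.

```python
import operator
from typing import Callable, Dict

# Hypothetical overload table keyed by the `operator` module's function objects.
OVERLOADS: Dict[Callable, Callable] = {}


def overload(op: Callable):
    """Register an implementation for builtin operator `op` (hypothetical)."""
    def register(impl: Callable) -> Callable:
        OVERLOADS[op] = impl
        return impl
    return register


@overload(operator.setitem)
def setitem_clamped(container, index, value):
    # Clamp out-of-range indices to the valid range instead of raising IndexError.
    index = max(0, min(index, len(container) - 1))
    container[index] = value


@overload(operator.contains)
def contains_casefold(container, item):
    # `item in container` is `operator.contains(container, item)`: container first.
    return any(item.casefold() == x.casefold() for x in container)
```

With this table, `OVERLOADS[operator.setitem](data, 10, 99)` on `data = [1, 2, 3]` writes to the last slot, and `OVERLOADS[operator.contains](["Hi"], "hI")` is `True`.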