
Support builtins in @overload #3030

Open
ehsantn opened this issue Jun 12, 2018 · 2 comments

Comments

@ehsantn
Contributor

ehsantn commented Jun 12, 2018

Feature request

High-level extension APIs such as @overload should support builtin operators (e.g. setitem, in). Feature parity with low-level extension APIs in general is desirable.
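For context, the builtin operators mentioned here already have plain-function spellings in Python's standard `operator` module (`setitem` is `operator.setitem`, `in` is `operator.contains`), so overloading those functions is one possible route to overloading the syntax. A quick pure-Python check of the correspondence (no Numba involved; the dict `d` is just an illustrative container):

```python
import operator

d = {}
d["a"] = 1                    # syntax form of setitem
operator.setitem(d, "b", 2)   # function form of the same operation

# getitem: d[k] and operator.getitem(d, k) agree.
assert d["b"] == operator.getitem(d, "b") == 2

# `in`: `k in d` corresponds to operator.contains(d, k), container first.
assert ("a" in d) and operator.contains(d, "a")
assert ("z" not in d) and not operator.contains(d, "z")

print(d)  # {'a': 1, 'b': 2}
```

Note the swapped operand order for `in`: any overload of `operator.contains` receives the container as its first argument.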

@asodeur
Contributor

asodeur commented Jun 25, 2018

#2297 is a step in that direction. It would allow one to @overload anything that has a corresponding function in the operator module by overloading that function.

This PR is an attempt at the implementation. However, I am not sure the approach taken is the best choice; it required more invasive changes than I initially expected. Starting from the other end, by modifying numba.templates.builtin_registry and/or @overload so that @overload could act on builtin operators directly, might have been smarter.

Also, the PR does not yet cover getitem, setitem, etc. (see numba.utils.OPERATORS_TO_BUILTINS in the PR for what has already been moved to operator module functions).
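The PR's actual `OPERATORS_TO_BUILTINS` table is not reproduced here. As a rough illustration only, a name-to-function table of the kind needed to translate string keys such as `"getitem"` (the form used by the `overload_any` helper later in this thread) into `operator` functions might look like the sketch below; `BUILTIN_NAME_TO_OPERATOR` and `resolve` are hypothetical names, and the entries are chosen for illustration rather than taken from Numba.

```python
import operator

# Hypothetical table (NOT the PR's OPERATORS_TO_BUILTINS): builtin names as
# strings -> equivalent functions in the standard `operator` module.
BUILTIN_NAME_TO_OPERATOR = {
    "getitem": operator.getitem,
    "setitem": operator.setitem,
    "delitem": operator.delitem,
    "in": operator.contains,  # `x in c` is operator.contains(c, x)
    "+": operator.add,
    "-": operator.sub,
    "*": operator.mul,
    "==": operator.eq,
    "<": operator.lt,
}


def resolve(key):
    """Return the operator-module function for a builtin name.

    Functions are passed through unchanged, so callers can use either form.
    """
    if callable(key):
        return key
    if key not in BUILTIN_NAME_TO_OPERATOR:
        raise KeyError(f"no operator-module equivalent for {key!r}")
    return BUILTIN_NAME_TO_OPERATOR[key]


print(resolve("getitem")([10, 20, 30], 1))  # 20
print(resolve("in")([1, 2], 2))             # True
```

Passing functions through unchanged lets one registration path accept both the string spelling and the `operator` function.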

@saulshanabrook
Contributor

I hacked together a new version of overload that works for builtins:

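import numba
import numba.extending
import numba.targets.imputils  # Numba of this era; later releases moved this to numba.core.imputils


def overload_any(func):
    """
    Like `numba.extending.overload` but works for things like `getitem`, etc.

    Used like `generated_jit`: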

        @overload_any("getitem")
        def getitem_const(val, i):
            if val.value == "hi":
                return lambda val, i: i
            elif val.value == "there":
                return lambda val, i: -i
    """

    def inner(overload_func):
        # lower dispatcher based on `numba.typing.templates._OverloadMethodTemplate.do_class_init`
        dispatcher = numba.generated_jit(nopython=True)(overload_func)
        disp_type = numba.types.Dispatcher(dispatcher)

        def impl(context, builder, sig, args):
            call = context.get_function(disp_type, sig)
            return call(builder, args)

        @numba.extending.type_callable(func)
        def type_inner(context):
            # need to pass in `dispatcher` or get "underlying object has vanished"
            def typer(*args, dispatcher=dispatcher):
                try:
                    sig = disp_type.get_call_type(context, args, {})
                except TypeError:  # None returned by overloaded function
                    return
                if sig:
                    # ideally, instead of adding a lowering for this specific type, we would just return the `impl`
                    # with the typing so it doesn't have to look it up. I am not sure how to do this in `type_callable`, though.
                    numba.targets.imputils.lower_builtin(func, *sig.args)(impl)
                    return sig.return_type

            return typer

    return inner

You use it like this:

@overload_any("getitem")
def getitem_const(val, i):
    if val.value == "hi":
        return lambda val, i: i
    elif val.value == "there":
        return lambda val, i: -i


@numba.njit
def hi(i):
    return "there"[i], "hi"[i]


assert hi(10) == (-10, 10)
