Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Part 1: dynamic override for overload decorator #9578

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

dlee992
Copy link
Contributor

@dlee992 dlee992 commented May 21, 2024

A potential way to
fix #5043

Notes for reviewers:

  • add override option for overload decorator
  • handle this override request by modifying the behavior of Function type's augment method
  • add override() method for Dispatcher class
  • add callers to record the call graph
    • utilize the CallStack to populate callers
    • after callee is cached, it won't appear in CallStack, need to complete callers at another location
    • dispatcher call dispatcher directly, handle this case in typeof_global in Typeinferer
  • add disp_map to record the mapping between compiled py_func name and its dispatcher instance
  • when found override=True, call cputarget.typing_context.refresh() to flush dispatcher instances of callers
  • how to handle disp_map when the dispatcher is defined via exec: skip for now

TODOs:

  • shouldn't explicitly call recompile(), make the recompilation detection automatically
  • test if it works with callers decorated by @njit(cache=True)
  • test if it works with call chains whose length is > 1, i.e., test with indirect callers
  • test if it works with the overloaded function having a different signature with the first time
  • how about adding sigs, or (args, kws) into the disp_map and callers?

Future plan for next PR:

  • add same support for overload_attr and overload_method.

@dlee992 dlee992 marked this pull request as draft May 21, 2024 22:07
@dlee992
Copy link
Contributor Author

dlee992 commented May 23, 2024

Test code:

from numba import njit
from numba.extending import overload

def foo():
    raise NotImplementedError

@overload(foo, override=False)
def ol_foo0():
    def impl0():
        return 0
    return impl0

@njit
def goo():
    return foo()

print(goo())

@overload(foo, override=True)
def ol_foo1():
    def impl1():
        return 1
    return impl1

print(goo())

@njit
def hoo():
    return foo()

@njit
def joo():
    return hoo()

print(hoo())
print(joo())

@overload(foo, override=True)
def ol_foo2():
    def impl2():
        return 2
    return impl2

print(goo())
print(hoo())
print(joo())

With the PR, output:

# 1st definition of foo
0
# 2nd definition of foo
1
1
1
# 3rd definition of foo
2
2
2

Without the PR, output is 0 0 0 0 0 0 0.

@dlee992 dlee992 changed the title dynamic override of user-defined overload functions/methods Part 1: dynamic override for overload decorator May 24, 2024
@sklam
Copy link
Member

sklam commented May 28, 2024

@dlee992, after discussing in today's triage, we think that target-extension API can fulfill the same needs. See

@overload(operator.add, target="dpu")
def ol_add(x, y):
if isinstance(x, types.Integer) and isinstance(y, types.Integer):
def impl(x, y):
return intrin_add(x, y)
return impl
for an example. The other problem with proposed API of @overload(.., override=True) is that multiple packages can override the same API, leading to conflicts or behavior that depends on import order.

@dlee992
Copy link
Contributor Author

dlee992 commented May 28, 2024

we think that target-extension API can fulfill the same needs

I think I didn't get it. Do you imply if we change the decorator usage from @overload(..., override=True) to @overload(..., target='cpu'), the dynamic override will work? I tested this locally in #9578 (comment), but it didn't work.

AH, do you mean I have to create a new Target for myself? But can my new Target class also inherit all the overloads from cpu target? I mean I just want to dynamically change one overload definition for one operation, while keeping using other overloads from cpu target.

In fact, my final desire is to use the dynamic override for overload_method, I just try find a way to support overload, then extend this feature to overload_method as Part 2.

The other problem with proposed API of @overload(.., override=True) is that multiple packages can override the same API, leading to conflicts or behavior that depends on import order.

Yeah, the concern is reasonable. But I think we can suggest numba-extension library developers DONOT use this flag if possible, while suggesting numba end-users to use this flag for debugging, just as described in the original issue #5043.

after discussing in today's triage

BTW, can I join the triage meeting? Thanks!

@guilhermeleobas
Copy link
Contributor

guilhermeleobas commented Jun 25, 2024

AH, do you mean I have to create a new Target for myself? But can my new Target class also inherit all the overloads from cpu target? I mean I just want to dynamically change one overload definition for one operation, while keeping using other overloads from cpu target.

Yes, you can! Just inherit from the CPU target when creating your own target:

from numba.core.registry import CPUDispatcher, cpu_target
from numba.core.target_extension import (
    dispatcher_registry,
    target_registry,
    CPU,
)

class MyTarget(CPU): ...


class MyTargetDispatcher(CPUDispatcher):
    targetdescr = cpu_target

target_registry["my_target"] = MyTarget
dispatcher_registry[target_registry["my_target"]] = MyTargetDispatcher

@dlee992
Copy link
Contributor Author

dlee992 commented Jun 25, 2024

Thanks! Let me figure out how to apply this feature into my use case.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

@overload cannot replace previously defined or built-in implementations
3 participants