CUDA: Support for @overload
#6944
Conversation
Joint work with @stuartarchibald. The implementation consists of:

- Addition of `get_call_template`, `nopython_signatures`, and `get_overload` to the class that is now `DeviceDispatcher`. These are needed by the overload infrastructure, and are inspired by the equivalents in `_DispatcherBase`.
- Registration of a CUDA `jit` decorator. The registration with `dispatcher_register` is removed and replaced with this new registration, because it did nothing (it was for the `target` kwarg to the `numba.jit` decorator, which is deprecated). The `jit` decorator for overloads is a bit special in that it always adds the `device=True` kwarg to the regular `jit` decorator, because overloads will always be device functions.
- Addition of a `no_cpython_wrapper` kwarg to `jitdevice`, which is needed because the overload infrastructure passes it. No action is required regardless of its value, because there is no CPython wrapper in the CUDA target.

A refactoring is also made: `DeviceFunctionTemplate` is renamed to `DeviceDispatcher`, as this reflects more closely what it is - it is analogous to the `Dispatcher` class used for kernels, and the `DeviceFunction` class is more analogous to the `_Kernel` class. Eventually `Dispatcher` will subsume these classes as we move towards a more "modern" dispatcher implementation in the CUDA target.

Tests are added; these are all skipped on the simulator because overloading doesn't really exist as a concept without compilation.
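The core idea above is that an overload can be registered for a specific target (here, CUDA) while a generic registration serves the other targets. As a rough illustration of that resolution order, here is a toy model only; it is not Numba's actual overload machinery, and `register_overload` / `resolve_overload` are hypothetical names:

```python
# Toy model of target-specific overload resolution, for illustration only.
# Implementations are registered per target; lookup prefers a target-specific
# overload and falls back to the "generic" one, mirroring how a CUDA overload
# takes precedence in a kernel while the CPU target keeps the generic variant.
_overloads = {}

def register_overload(name, target, impl):
    """Record impl as the overload of `name` for the given target."""
    _overloads.setdefault(name, {})[target] = impl

def resolve_overload(name, target):
    """Prefer a target-specific overload; otherwise use the generic one."""
    impls = _overloads.get(name, {})
    return impls.get(target, impls.get("generic"))

# A generic implementation and a CUDA-specific one.
register_overload("scale", "generic", lambda x: x * 2)
register_overload("scale", "cuda", lambda x: x * 3)

print(resolve_overload("scale", "cpu")(5))   # generic fallback: 10
print(resolve_overload("scale", "cuda")(5))  # CUDA-specific: 15
```

The fallback in `resolve_overload` is why registering a CUDA-only overload does not disturb compilation for other targets: they simply never see it.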
This failure: https://dev.azure.com/numba/ff1fe4d0-ed73-4f1c-b894-1d50a27e048f/_apis/build/builds/8536/logs/130 AFAICT has nothing to do with this PR; it is fixed in #6945.
A quick note on this, mainly for @stuartarchibald: Earlier versions of this had …
Agree, I think the registrations needed to get …
Using prints and capturing stdout to test overloads is both overly complicated and difficult to debug (terminal-based debuggers and debug prints don't work during the tests). The value-based check that replaces it suffers neither of these issues, and it is at least as clear what each test does (more so, in my opinion).
@stuartarchibald As you suggested OOB, I've replaced the printing and captured output in the tests with a value-based check - this is much tidier, and makes issues easier to deal with.
Thanks for the patch @gmarkall, great to see this working :) One minor comment to resolve, else this looks good.
def test_call_hardware_overloaded(self):
    def kernel(x):
        hardware_overloaded(x)

    expected = CUDA_HARDWARE_OL
    self.check_overload(kernel, expected)

def test_generic_calls_hardware_overloaded(self):
    def kernel(x):
        generic_calls_hardware_overloaded(x)

    expected = GENERIC_CALLS_HARDWARE_OL * CUDA_HARDWARE_OL
    self.check_overload(kernel, expected)

def test_cuda_calls_hardware_overloaded(self):
    def kernel(x):
        cuda_calls_hardware_overloaded(x)

    expected = CUDA_CALLS_HARDWARE_OL * CUDA_HARDWARE_OL
    self.check_overload(kernel, expected)

def test_hardware_overloaded_calls_hardware_overloaded(self):
    def kernel(x):
        hardware_overloaded_calls_hardware_overloaded(x)

    expected = CUDA_HARDWARE_OL_CALLS_HARDWARE_OL * CUDA_HARDWARE_OL
    self.check_overload(kernel, expected)
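The value-based check these tests rely on could work roughly as sketched below. This is an assumption for illustration: the real `check_overload` helper lives in Numba's CUDA test suite and jits and launches an actual kernel, whereas here the kernel is simulated as a plain function writing into a one-element buffer.

```python
# Sketch of a value-based overload check: run the kernel on a one-element
# buffer and compare the stored result against an expected constant.
# This replaces capturing stdout, which is much harder to debug.
def check_overload(kernel, expected):
    x = [0]      # stand-in for a one-element device array
    kernel(x)    # the real helper would jit this and launch it on the GPU
    assert x[0] == expected, (x[0], expected)

# Example: a "kernel" that stores the overloaded function's result.
CUDA_HARDWARE_OL = 2  # hypothetical constant, echoing the test suite's naming

def hardware_overloaded(x):
    x[0] = CUDA_HARDWARE_OL

def kernel(x):
    hardware_overloaded(x)

check_overload(kernel, CUDA_HARDWARE_OL)  # passes silently
```

Because the result is a plain value rather than captured output, a failing test reports exactly which value was produced, and terminal-based debuggers and debug prints keep working during the run.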
In this section, maybe test that the CPU calls would pick up the "generic" variants?
Added.
This is checking when there are both generic and CUDA overloads, to ensure the CUDA overloads don't interfere with the generic / CPU target. From PR numba#6944 feedback.
Thanks for adding the extra tests in f1b54d0, this looks good. There's a conflict to resolve against mainline, and once that's done this can be tested on the farm. Thanks again.
@stuartarchibald Many thanks, conflict resolved.
Thanks for getting this working, it's great to see the new API in action!
Thanks, patch approved.
Buildfarm ID:
Passed.