Support NVIDIA's CUDA Python bindings #7461

Merged
merged 77 commits on Nov 24, 2021
Changes from 73 commits
77 commits
ccfce50
CUDA: Start of support for CUDA Python bindings
gmarkall Sep 7, 2021
c51a9d1
CUDA testsuite runs with CUDA Python bindings
gmarkall Sep 7, 2021
8db1185
Implement memory allocation for CUDA Python
gmarkall Sep 7, 2021
1244c4e
Some fixes for views with CUDA Python
gmarkall Sep 7, 2021
2ba0e4b
CUDA: Add framework for two separate linker implementations
gmarkall Sep 7, 2021
cd57e37
CUDA: Implement load_module_image with CUDA Python bindings
gmarkall Sep 7, 2021
cb3d680
Implement modules and functions for CUDA Python
gmarkall Sep 8, 2021
b4e93c6
Correct argument preparation with CUDA Python
gmarkall Sep 8, 2021
e4db9a5
Kernel launches now starting with CUDA Python
gmarkall Sep 8, 2021
bd539e8
Skip record test with CUDA Python
gmarkall Sep 8, 2021
0670f2d
Fix block size and occupancy functions and kernel launch for CUDA Python
gmarkall Sep 8, 2021
5e69273
CUDA: Only handle device pointers in launch_kernel
gmarkall Sep 9, 2021
84ce354
Fix CUDA Python stream creation and skip some tests
gmarkall Sep 9, 2021
d25ee6e
Revert changes to context stack
gmarkall Sep 9, 2021
03dd908
CUDA Python fixes for IPC, streams, and CAI
gmarkall Sep 9, 2021
c033420
CUDA Python IPC and context fixes
gmarkall Sep 9, 2021
d1c1eb9
CUDA Python host allocation fixes
gmarkall Sep 9, 2021
96f617a
CUDA Python host allocation fixes
gmarkall Sep 9, 2021
e1dfe1e
Fix managed allocation with CUDA Python
gmarkall Sep 9, 2021
afa0d1b
Fix views with CUDA Python
gmarkall Sep 9, 2021
0b974c7
Some CUDA Python IPC fixes
gmarkall Sep 9, 2021
991b58f
Fix test_cuda_memory with CUDA Python
gmarkall Sep 9, 2021
92e053d
Fix record argument passing with CUDA Python
gmarkall Sep 9, 2021
24d7b48
Unskip remaining skipped CUDA Python tests
gmarkall Sep 9, 2021
00d568c
Fix CUDA driver tests with CUDA Python
gmarkall Sep 10, 2021
5d757be
Fix a couple of CAI tests with CUDA Python
gmarkall Sep 10, 2021
55161ec
Fix CUDA Array Interface tests with CUDA Python
gmarkall Sep 16, 2021
b1ef00f
Mark PTDS as unsupported with CUDA Python
gmarkall Sep 16, 2021
58fd927
Fix context stack tests with CUDA Python
gmarkall Sep 16, 2021
eb959b2
Add file extension map for CUDA Python
gmarkall Sep 28, 2021
be49f23
Fix async callbacks for CUDA Python
gmarkall Sep 28, 2021
217b658
Fix event recording for CUDA Python
gmarkall Sep 28, 2021
3d07299
Fix device_memory_size for CUDA Python
gmarkall Sep 28, 2021
cca4e4e
Fix test_managed_alloc for CUDA Python
gmarkall Sep 28, 2021
7c9b3c3
Fix a few more CUDA Python fails
gmarkall Sep 28, 2021
93cc0f1
Fix remaining CUDA Python test fails
gmarkall Sep 29, 2021
08182d0
Fix import when CUDA Python not available
gmarkall Sep 29, 2021
25eb71c
Merge remote-tracking branch 'numba/master' into cuda-python
gmarkall Sep 29, 2021
0dd03dd
Small comment and whitespace change undo
gmarkall Sep 29, 2021
7cc1f53
Merge remote-tracking branch 'numba/master' into cuda-python
gmarkall Oct 6, 2021
3b0a363
Reuse alloc_key for allocations key in memhostalloc
gmarkall Oct 6, 2021
6404dfb
Simplify getting pointers for ctypes functions
gmarkall Oct 6, 2021
0c6ed5b
Don't use CUDA Python by default
gmarkall Oct 6, 2021
c2d4d8d
Remove some dead code
gmarkall Oct 6, 2021
12effed
Document CUDA Python environment variable
gmarkall Oct 6, 2021
379ac22
Merge remote-tracking branch 'numba/master' into cuda-python
gmarkall Nov 1, 2021
7847d51
driver.py: rename cuda_driver to binding (PR #7461 feedback)
gmarkall Nov 1, 2021
4beb7dd
Rename CUDA_USE_CUDA_PYTHON to CUDA_USE_NV_BINDING
gmarkall Nov 1, 2021
caf34c0
Update docs for CUDA_USE_NVIDIA_BINDING
gmarkall Nov 1, 2021
004d74a
CUDA driver: Use defined values instead of magic numbers for streams
gmarkall Nov 1, 2021
c3e7fdb
CUDA driver error checking: factor out fork detection
gmarkall Nov 1, 2021
2eef758
Use CU_STREAM_DEFAULT in Stream.__repr__
gmarkall Nov 1, 2021
af14cc8
Fix spelling of CU_JIT_INPUT_FATBINARY
gmarkall Nov 1, 2021
d110d1c
Add docstring to add_file_guess_ext
gmarkall Nov 1, 2021
3d964cd
CUDA: Remove a needless del from the Ctypes linker
gmarkall Nov 1, 2021
84d46d4
Some small fixups from PR #7461 feedback
gmarkall Nov 1, 2021
d4d5176
CUDA: Fix simulator by adding missing USE_NV_BINDING to simulator
gmarkall Nov 1, 2021
96776f7
CUDA: Use helper function in test_derived_pointer
gmarkall Nov 1, 2021
0a7a8d8
Re-enable profiler with CUDA Python
gmarkall Nov 1, 2021
0617911
Update documentation for NVIDIA bindings
gmarkall Nov 3, 2021
6eb1924
Merge remote-tracking branch 'numba/master' into cuda-python
gmarkall Nov 4, 2021
43f3ae7
PR #7461 feedback on deprecation wording
gmarkall Nov 8, 2021
57413cf
Merge remote-tracking branch 'numba/master' into cuda-python
gmarkall Nov 22, 2021
37ef39b
CUDA: Add function to get driver version
gmarkall Nov 22, 2021
771bc38
Report CUDA binding availability and use in Numba sysinfo
gmarkall Nov 22, 2021
81809f7
Merge remote-tracking branch 'gmarkall/cuda-python' into cuda-python
gmarkall Nov 22, 2021
5699b91
CUDA: Attempt to test with NVIDIA binding on CUDA 11.4
gmarkall Nov 22, 2021
dd57c0c
CUDA: Add docs for NVIDIA binding support
gmarkall Nov 22, 2021
b86d4fa
Correct spelling of NUMBA_CUDA_USE_NVIDIA_BINDING
gmarkall Nov 22, 2021
29b3ea8
Revert "Re-enable profiler with CUDA Python"
gmarkall Nov 22, 2021
22a4b74
CUDA docs: Note that profiler not supported with NV bindings
gmarkall Nov 22, 2021
1b59892
Correct mis-spelled env var in docs
gmarkall Nov 23, 2021
60321d5
Update CUDA docs based on PR #7461 feedback
gmarkall Nov 23, 2021
1ca9acb
Warn when NVIDIA bindings requested but not found
gmarkall Nov 23, 2021
4a50d0c
Mention env var in NVIDIA bindings warning
gmarkall Nov 23, 2021
8ab1535
Update NV binding env var docs
gmarkall Nov 23, 2021
f93c602
Update numba/core/config.py
gmarkall Nov 23, 2021
14 changes: 14 additions & 0 deletions buildscripts/gpuci/build.sh
@@ -16,6 +16,14 @@ cd "$WORKSPACE"
# Determine CUDA release version
export CUDA_REL=${CUDA_VERSION%.*}

# Test with NVIDIA Bindings on CUDA 11.4
if [ $CUDA_TOOLKIT_VER == "11.4" ]
then
export NUMBA_CUDA_USE_NVIDIA_BINDING=1;
else
export NUMBA_CUDA_USE_NVIDIA_BINDING=0;
fi;

################################################################################
# SETUP - Check environment
################################################################################
@@ -41,6 +49,12 @@ gpuci_mamba_retry create -n numba_ci -y \

conda activate numba_ci

if [ $NUMBA_CUDA_USE_NVIDIA_BINDING == "1" ]
then
gpuci_logger "Install NVIDIA CUDA Python bindings";
gpuci_mamba_retry install nvidia::cuda-python;
fi;

gpuci_logger "Install numba"
python setup.py develop

31 changes: 31 additions & 0 deletions docs/source/cuda/bindings.rst
@@ -0,0 +1,31 @@
CUDA Bindings
=============

Numba supports two bindings to the CUDA Driver APIs: its own internal bindings
based on ctypes, and the official `NVIDIA CUDA Python bindings
<https://nvidia.github.io/cuda-python/>`_. Functionality is equivalent between
the two bindings, with two exceptions:

* The NVIDIA bindings presently do not support Per-Thread Default Streams
(PTDS), and an exception will be raised on import if PTDS is enabled along
with the NVIDIA bindings.
* The profiling APIs are not available with the NVIDIA bindings.

The internal bindings are used by default. If the NVIDIA bindings are installed,
then they can be used by setting the environment variable
``NUMBA_CUDA_USE_NVIDIA_BINDING`` to ``1`` prior to the import of Numba. Once
Numba has been imported, the selected binding cannot be changed.


Roadmap
-------

In future versions of Numba:

- The NVIDIA Bindings will be used by default, if they are installed.
- The internal bindings will be deprecated.
- The internal bindings will be removed.

It is expected that the NVIDIA bindings will be the default in Numba 0.56; at
present, no specific release is planned for the deprecation or removal of the
internal bindings.
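
The binding selection described above has to happen before Numba is first imported. A minimal sketch of what that looks like in practice (assumes ``cuda-python`` is installed and a CUDA-capable GPU is present)::

    import os

    # Select the NVIDIA binding; this must run before the first import of Numba.
    os.environ["NUMBA_CUDA_USE_NVIDIA_BINDING"] = "1"

    import numpy as np
    from numba import cuda

    @cuda.jit
    def add_one(x):
        i = cuda.grid(1)
        if i < x.size:
            x[i] += 1

    arr = cuda.to_device(np.zeros(16))
    add_one[1, 16](arr)
    print(arr.copy_to_host())

Once this has run, the selection is fixed for the lifetime of the process, matching the note above that the binding cannot be changed after Numba is imported.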
1 change: 1 addition & 0 deletions docs/source/cuda/index.rst
@@ -23,4 +23,5 @@ Numba for CUDA GPUs
ipc.rst
cuda_array_interface.rst
external-memory.rst
bindings.rst
faq.rst
23 changes: 23 additions & 0 deletions docs/source/cuda/overview.rst
@@ -57,6 +57,29 @@ If you are not using Conda or if you want to use a different version of CUDA
toolkit, the following describe how Numba searches for a CUDA toolkit
installation.

.. _cuda-bindings:

CUDA Bindings
~~~~~~~~~~~~~

Numba supports interacting with the CUDA Driver API via the `NVIDIA CUDA Python
bindings <https://nvidia.github.io/cuda-python/>`_ and its own ctypes-based
binding. The ctypes-based binding is presently the default as Per-Thread
Default Streams and the profiler APIs are not supported with the NVIDIA
bindings, but otherwise functionality is equivalent between the two. You can
install the NVIDIA bindings with::

$ conda install nvidia::cuda-python

if you are using Conda, or::

$ pip install cuda-python

if you are using pip.

The use of the NVIDIA bindings is enabled by setting the environment variable
:envvar:`NUMBA_CUDA_USE_NVIDIA_BINDING` to ``"1"``.

.. _cudatoolkit-lookup:

Setting CUDA Installation Path
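A small companion sketch to the overview text above: checking at runtime which binding Numba actually picked up. ``USE_NV_BINDING`` is the module-level flag introduced by this PR in ``numba.cuda.cudadrv.driver``; treat the check as illustrative rather than a stable public API::

    import os
    os.environ.setdefault("NUMBA_CUDA_USE_NVIDIA_BINDING", "1")

    from numba.cuda.cudadrv import driver

    if driver.USE_NV_BINDING:
        print("Using the NVIDIA cuda-python binding")
    else:
        print("Using Numba's internal ctypes binding")
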
9 changes: 8 additions & 1 deletion docs/source/reference/envvars.rst
@@ -516,11 +516,18 @@ GPU support
heuristic needs to check the number of SMs available on the device in the
current context.

.. envvar:: CUDA_WARN_ON_IMPLICIT_COPY
.. envvar:: NUMBA_CUDA_WARN_ON_IMPLICIT_COPY

Enable warnings if a kernel is launched with host memory, which forces a copy to and
from the device. This option is on by default (default value is 1).
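
For context, a minimal sketch of the behaviour this warning guards: launching a kernel on a host NumPy array forces a copy to the device and back, whereas passing a device array does not (the exact warning text may vary between Numba versions)::

    import numpy as np
    from numba import cuda

    @cuda.jit
    def scale(x):
        i = cuda.grid(1)
        if i < x.size:
            x[i] *= 2.0

    host_arr = np.ones(1024)
    scale[8, 128](host_arr)        # implicit copies; emits a performance warning by default

    dev_arr = cuda.to_device(host_arr)
    scale[8, 128](dev_arr)         # no implicit copy, no warning
    result = dev_arr.copy_to_host()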

.. envvar:: NUMBA_CUDA_USE_NVIDIA_BINDING

When set to 1, Numba will use the `NVIDIA CUDA Python binding
<https://nvidia.github.io/cuda-python/>`_ to make calls to the driver API
instead of using its own ctypes binding. This defaults to 0 (off), as the
NVIDIA binding is currently missing support for Per-Thread Default
Streams.

Threading Control
-----------------
2 changes: 2 additions & 0 deletions docs/source/user/installing.rst
@@ -246,6 +246,8 @@ vary with target operating system and hardware. The following lists them all
Python 3.7.
* ``typeguard`` - used by ``runtests.py`` for
:ref:`runtime type-checking <type_anno_check>`.
* ``cuda-python`` - The NVIDIA CUDA Python bindings. See :ref:`cuda-bindings`.
Numba is tested with Version 11.5 of the bindings.

* To build the documentation:

18 changes: 18 additions & 0 deletions numba/core/config.py
@@ -119,6 +119,20 @@ def update(self, force=False):
# Store a copy
self.old_environ = dict(new_environ)

self.validate()

def validate(self):
if CUDA_USE_NVIDIA_BINDING: # noqa: F821
try:
import cuda # noqa: F401
except ImportError as ie:
msg = ("CUDA Python bindings requested, "
"but they are not importable")
raise RuntimeError(msg) from ie

if CUDA_PER_THREAD_DEFAULT_STREAM: # noqa: F821
warnings.warn("PTDS is not supported with CUDA Python")

def process_environ(self, environ):
def _readenv(name, ctor, default):
value = environ.get(name)
@@ -170,6 +184,10 @@ def optional_str(x):
CUDA_LOW_OCCUPANCY_WARNINGS = _readenv(
"NUMBA_CUDA_LOW_OCCUPANCY_WARNINGS", int, 1)

# Whether to use the official CUDA Python API Bindings
CUDA_USE_NVIDIA_BINDING = _readenv(
"NUMBA_CUDA_USE_NVIDIA_BINDING", int, 0)

# Debug flag to control compiler debug print
DEBUG = _readenv("NUMBA_DEBUG", int, 0)

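The ``validate`` logic above amounts to two checks: the ``cuda`` package must be importable when the NVIDIA binding is requested, and PTDS is reported as unsupported. A standalone sketch of the same idea, with names chosen for illustration rather than taken from Numba's internals::

    import os
    import warnings

    def check_cuda_binding_config():
        use_nv_binding = int(os.environ.get("NUMBA_CUDA_USE_NVIDIA_BINDING", "0"))
        use_ptds = int(os.environ.get("NUMBA_CUDA_PER_THREAD_DEFAULT_STREAM", "0"))

        if use_nv_binding:
            try:
                import cuda  # noqa: F401  -- the NVIDIA cuda-python package
            except ImportError as ie:
                raise RuntimeError("CUDA Python bindings requested, "
                                   "but they are not importable") from ie
            if use_ptds:
                warnings.warn("PTDS is not supported with the NVIDIA binding")

    check_cuda_binding_config()
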
8 changes: 6 additions & 2 deletions numba/cuda/api.py
@@ -227,9 +227,13 @@ def open_ipc_array(handle, shape, dtype, strides=None, offset=0):
# compute size
size = np.prod(shape) * dtype.itemsize
# manually recreate the IPC mem handle
handle = driver.drvapi.cu_ipc_mem_handle(*handle)
if driver.USE_NV_BINDING:
driver_handle = driver.binding.CUipcMemHandle()
driver_handle.reserved = handle
else:
driver_handle = driver.drvapi.cu_ipc_mem_handle(*handle)
# use *IpcHandle* to open the IPC memory
ipchandle = driver.IpcHandle(None, handle, size, offset=offset)
ipchandle = driver.IpcHandle(None, driver_handle, size, offset=offset)
yield ipchandle.open_array(current_context(), shape=shape,
strides=strides, dtype=dtype)
ipchandle.close()
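For reference, the workflow that ``open_ipc_array`` and ``IpcHandle`` support looks roughly like the following two-process sketch, based on Numba's documented IPC API (Linux only; the handle must be transferred between processes, e.g. by pickling)::

    import numpy as np
    from numba import cuda

    # Producer process: allocate device memory and export an IPC handle.
    darr = cuda.to_device(np.arange(16, dtype=np.float64))
    ipch = darr.get_ipc_handle()   # picklable; send it to another process

    # Consumer process: open the handle as a device array view.
    with ipch as remote_darr:
        print(remote_darr.copy_to_host())
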
2 changes: 1 addition & 1 deletion numba/cuda/codegen.py
@@ -167,7 +167,7 @@ def get_cubin(self, cc=None):
if cubin:
return cubin

linker = driver.Linker(max_registers=self._max_registers, cc=cc)
linker = driver.Linker.new(max_registers=self._max_registers, cc=cc)

ptxes = self._get_ptxes(cc=cc)
for ptx in ptxes:
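The change from ``driver.Linker(...)`` to ``driver.Linker.new(...)`` reflects that there are now two linker implementations, one per binding, chosen behind a factory method. A rough sketch of that pattern; the class layout is illustrative rather than a copy of Numba's internals::

    USE_NV_BINDING = False  # in Numba this comes from the configuration

    class Linker:
        """Base linker; construct concrete instances via Linker.new()."""
        @classmethod
        def new(cls, max_registers=0, cc=None):
            # Select the implementation for the active driver binding.
            if USE_NV_BINDING:
                return CudaPythonLinker(max_registers, cc)
            return CtypesLinker(max_registers, cc)

    class CtypesLinker(Linker):
        def __init__(self, max_registers=0, cc=None):
            self.max_registers, self.cc = max_registers, cc
            # ... drive cuLinkCreate/cuLinkAddData through the ctypes binding

    class CudaPythonLinker(Linker):
        def __init__(self, max_registers=0, cc=None):
            self.max_registers, self.cc = max_registers, cc
            # ... drive the same driver APIs through cuda-python

    linker = Linker.new(max_registers=32, cc=(7, 5))
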
21 changes: 18 additions & 3 deletions numba/cuda/compiler.py
@@ -563,7 +563,12 @@ def launch(self, args, griddim, blockdim, stream=0, sharedmem=0):
for t, v in zip(self.argument_types, args):
self._prepare_args(t, v, stream, retr, kernelargs)

stream_handle = stream and stream.handle or None
if driver.USE_NV_BINDING:
zero_stream = driver.binding.CUstream(0)
else:
zero_stream = None

stream_handle = stream and stream.handle or zero_stream

# Invoke kernel
driver.launch_kernel(cufunc.handle,
@@ -634,7 +639,14 @@ def _prepare_args(self, ty, val, stream, retr, kernelargs):
parent = ctypes.c_void_p(0)
nitems = c_intp(devary.size)
itemsize = c_intp(devary.dtype.itemsize)
data = ctypes.c_void_p(driver.device_pointer(devary))

ptr = driver.device_pointer(devary)

if driver.USE_NV_BINDING:
ptr = int(ptr)

data = ctypes.c_void_p(ptr)

kernelargs.append(meminfo)
kernelargs.append(parent)
kernelargs.append(nitems)
@@ -674,7 +686,10 @@ def _prepare_args(self, ty, val, stream, retr, kernelargs):

elif isinstance(ty, types.Record):
devrec = wrap_arg(val).to_device(retr, stream)
kernelargs.append(devrec)
ptr = devrec.device_ctypes_pointer
if driver.USE_NV_BINDING:
ptr = ctypes.c_void_p(int(ptr))
kernelargs.append(ptr)

elif isinstance(ty, types.BaseTuple):
assert len(ty) == len(val)
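The ``int(ptr)`` conversions above exist because the two bindings represent device pointers differently: the internal binding yields ``ctypes``-style pointers with a ``.value``, while cuda-python yields ``CUdeviceptr`` objects that convert to a plain integer with ``int()``. A hypothetical helper (not Numba API) that normalises either form before it is packed into the kernel argument list::

    import ctypes

    def as_c_void_p(ptr):
        """Return a ctypes.c_void_p for a device pointer from either binding."""
        if isinstance(ptr, ctypes.c_void_p):
            return ptr                       # internal ctypes binding
        return ctypes.c_void_p(int(ptr))     # cuda-python CUdeviceptr, or a plain int
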
23 changes: 18 additions & 5 deletions numba/cuda/cudadrv/devicearray.py
@@ -105,19 +105,29 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
self.alloc_size = _driver.device_memory_size(gpu_data)
else:
# Make NULL pointer for empty allocation
if _driver.USE_NV_BINDING:
null = _driver.binding.CUdeviceptr(0)
else:
null = c_void_p(0)
gpu_data = _driver.MemoryPointer(context=devices.get_context(),
pointer=c_void_p(0), size=0)
pointer=null, size=0)
self.alloc_size = 0

self.gpu_data = gpu_data
self.stream = stream

@property
def __cuda_array_interface__(self):
if self.device_ctypes_pointer.value is not None:
ptr = self.device_ctypes_pointer.value
else:
ptr = 0
if _driver.USE_NV_BINDING:
if self.device_ctypes_pointer is not None:
ptr = int(self.device_ctypes_pointer)
else:
ptr = 0
else:
if self.device_ctypes_pointer.value is not None:
ptr = self.device_ctypes_pointer.value
else:
ptr = 0

return {
'shape': tuple(self.shape),
@@ -191,7 +201,10 @@ def device_ctypes_pointer(self):
"""Returns the ctypes pointer to the GPU data buffer
"""
if self.gpu_data is None:
return c_void_p(0)
if _driver.USE_NV_BINDING:
return _driver.binding.CUdeviceptr(0)
else:
return c_void_p(0)
else:
return self.gpu_data.device_ctypes_pointer

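Whichever branch runs in ``__cuda_array_interface__`` above, the result is the raw device pointer reported as a plain integer. A quick sketch of inspecting the interface on a device array (requires a GPU; the pointer value shown is of course machine-specific)::

    import numpy as np
    from numba import cuda

    darr = cuda.device_array(8, dtype=np.float32)
    cai = darr.__cuda_array_interface__
    print(cai["shape"], cai["typestr"], cai["data"])
    # e.g. (8,) '<f4' (139637124366336, False)
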
14 changes: 11 additions & 3 deletions numba/cuda/cudadrv/devices.py
@@ -14,7 +14,7 @@
import threading
from contextlib import contextmanager

from .driver import driver
from .driver import driver, USE_NV_BINDING


class _DeviceList(object):
@@ -139,6 +139,8 @@ def get_or_create_context(self, devnum):
else:
return attached_ctx
else:
if USE_NV_BINDING:
devnum = int(devnum)
return self._activate_context_for(devnum)

def _get_or_create_context_uncached(self, devnum):
@@ -155,10 +157,16 @@ def _get_or_create_context_uncached(self, devnum):
# Get primary context for the active device
ctx = self.gpus[ac.devnum].get_primary_context()
# Is active context the primary context?
if ctx.handle.value != ac.context_handle.value:
if USE_NV_BINDING:
ctx_handle = int(ctx.handle)
ac_ctx_handle = int(ac.context_handle)
else:
ctx_handle = ctx.handle.value
ac_ctx_handle = ac.context_handle.value
if ctx_handle != ac_ctx_handle:
msg = ('Numba cannot operate on non-primary'
' CUDA context {:x}')
raise RuntimeError(msg.format(ac.context_handle.value))
raise RuntimeError(msg.format(ac_ctx_handle))
# Ensure the context is ready
ctx.prepare_for_use()
return ctx
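The primary-context check above applies the same normalisation idea as the pointer handling earlier in the PR, this time to context handles: ``.value`` for the ctypes binding, ``int()`` for cuda-python. A hypothetical helper expressing the comparison::

    def handle_as_int(handle, use_nv_binding):
        """Normalise a context handle from either binding to a plain int."""
        return int(handle) if use_nv_binding else handle.value

    # The check in _get_or_create_context_uncached is then effectively:
    #   handle_as_int(ctx.handle, USE_NV_BINDING) != handle_as_int(ac.context_handle, USE_NV_BINDING)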