Skip to content

Commit

Permalink
add DLPackCuPy arrayimpl
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Jul 1, 2021
1 parent b4608a1 commit ed1f45c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
40 changes: 40 additions & 0 deletions test/arrayimpl.py
Expand Up @@ -389,6 +389,46 @@ def size(self):
return self.array.size


if dlpack is not None and cupy is not None:

# Note: we do not create a BaseDLPackGPU class because each GPU library
# has its own way to get device ID etc, so we have to reimplement the
# DLPack support anyway

@add_backend
class DLPackCuPy(GPUArrayCuPy):

backend = 'dlpack-cupy'
has_dlpack = None
dev_type = None

def __init__(self, arg, typecode, shape=None):
super().__init__(arg, typecode, shape)
self._has_dlpack = hasattr(self.array, '__dlpack_device__')
# TODO(leofang): test CUDA managed memory?
if cupy.cuda.runtime.is_hip:
self.dev_type = dlpack.DLDeviceType.kDLROCM
else:
self.dev_type = dlpack.DLDeviceType.kDLCUDA

@property
def address(self):
return self.array.data.ptr

def __dlpack_device__(self):
if self.has_dlpack:
return self.array.__dlpack_device__()
else:
return (self.dev_type, self.array.device.id)

def __dlpack__(self, stream=None):
cupy.cuda.get_current_stream().synchronize()
if self.has_dlpack:
return self.array.__dlpack__(stream)
else:
return self.array.toDlpack()


if numba is not None:

@add_backend
Expand Down
2 changes: 1 addition & 1 deletion test/mpiunittest.py
Expand Up @@ -92,7 +92,7 @@ def key(s):
return None

def is_mpi_gpu(predicate, array):
if array.backend in ('cupy', 'numba'):
if array.backend in ('cupy', 'numba', 'dlpack-cupy'):
if mpi_predicate(predicate):
return True
return False
Expand Down

0 comments on commit ed1f45c

Please sign in to comment.