Skip to content

Commit

Permalink
add simple message tests via DLPack GPU support
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Jul 1, 2021
1 parent 79f64d5 commit 0780e63
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
2 changes: 1 addition & 1 deletion test/arrayimpl.py
Expand Up @@ -424,7 +424,7 @@ def __dlpack_device__(self):
def __dlpack__(self, stream=None):
cupy.cuda.get_current_stream().synchronize()
if self.has_dlpack:
return self.array.__dlpack__(stream)
return self.array.__dlpack__(stream=-1)
else:
return self.array.toDlpack()

Expand Down
46 changes: 45 additions & 1 deletion test/test_msgspec.py
Expand Up @@ -91,6 +91,41 @@ def __dlpack__(self, stream=None):
capsule = dlpack.make_py_capsule(managed)
return capsule


if cupy is not None:

class DLPackGPUBuf(BaseBuf):

has_dlpack = None
dev_type = None

def __init__(self, typecode, initializer):
self._buf = cupy.array(initializer, dtype=typecode)
self.has_dlpack = hasattr(self._buf, '__dlpack_device__')
# TODO(leofang): test CUDA managed memory?
if cupy.cuda.runtime.is_hip:
self.dev_type = dlpack.DLDeviceType.kDLROCM
else:
self.dev_type = dlpack.DLDeviceType.kDLCUDA

def __del__(self):
if not pypy and sys.getrefcount(self._buf) > 2:
raise RuntimeError('dlpack: possible reference leak')

def __dlpack_device__(self):
if self.has_dlpack:
return self._buf.__dlpack_device__()
else:
return (self.dev_type, self._buf.device.id)

def __dlpack__(self, stream=None):
cupy.cuda.get_current_stream().synchronize()
if self.has_dlpack:
return self._buf.__dlpack__(stream=-1)
else:
return self._buf.toDlpack()


# ---

class CAIBuf(BaseBuf):
Expand Down Expand Up @@ -426,12 +461,21 @@ def testNotContiguous(self):
@unittest.skipIf(array is None, 'array')
@unittest.skipIf(dlpack is None, 'dlpack')
class TestMessageSimpleDLPackCPUBuf(unittest.TestCase,
BaseTestMessageSimpleArray):
BaseTestMessageSimpleArray):

def array(self, typecode, initializer):
return DLPackCPUBuf(typecode, initializer)


@unittest.skipIf(cupy is None, 'cupy')
@unittest.skipIf(dlpack is None, 'dlpack')
class TestMessageSimpleDLPackGPUBuf(unittest.TestCase,
BaseTestMessageSimpleArray):

def array(self, typecode, initializer):
return DLPackGPUBuf(typecode, initializer)


@unittest.skipIf(array is None, 'array')
class TestMessageSimpleCAIBuf(unittest.TestCase,
BaseTestMessageSimpleArray):
Expand Down

0 comments on commit 0780e63

Please sign in to comment.