diff --git a/src/mpi4py/MPI.src/ascaibuf.pxi b/src/mpi4py/MPI.src/ascaibuf.pxi index 9fc9408c..794717cd 100644 --- a/src/mpi4py/MPI.src/ascaibuf.pxi +++ b/src/mpi4py/MPI.src/ascaibuf.pxi @@ -73,6 +73,7 @@ cdef int Py_GetCAIBuffer(object obj, Py_buffer *view, int flags) except -1: cdef bint readonly = 0 cdef Py_ssize_t s, size = 1 cdef Py_ssize_t itemsize = 1 + cdef char *format = BYTE_FMT cdef char byteorder = c'|' cdef char typekind = c'u' @@ -98,6 +99,7 @@ cdef int Py_GetCAIBuffer(object obj, Py_buffer *view, int flags) except -1: byteorder = ord(typestr[0:1]) typekind = ord(typestr[1:2]) itemsize = int(typestr[2:]) + format = cuda_get_format(typekind, itemsize) if (flags & PyBUF_FORMAT) == PyBUF_FORMAT: if byteorder == c'<': # little-endian @@ -151,7 +153,7 @@ cdef int Py_GetCAIBuffer(object obj, Py_buffer *view, int flags) except -1: PyBuffer_FillInfo(view, obj, buf, size*itemsize, readonly, flags) if (flags & PyBUF_FORMAT) == PyBUF_FORMAT: - view.format = cuda_get_format(typekind, itemsize) + view.format = format if view.format != BYTE_FMT: view.itemsize = itemsize diff --git a/src/mpi4py/MPI.src/asdlpack.pxi b/src/mpi4py/MPI.src/asdlpack.pxi index f8c7e132..bc9db9f2 100644 --- a/src/mpi4py/MPI.src/asdlpack.pxi +++ b/src/mpi4py/MPI.src/asdlpack.pxi @@ -5,10 +5,15 @@ cdef extern from * nogil: ctypedef unsigned char uint8_t ctypedef unsigned short uint16_t - ctypedef int int32_t + ctypedef signed int int32_t + ctypedef unsigned int uint32_t ctypedef signed long long int64_t ctypedef unsigned long long uint64_t +ctypedef struct DLPackVersion: + uint32_t major + uint32_t minor + ctypedef enum DLDeviceType: kDLCPU = 1 kDLCUDA = 2 @@ -58,6 +63,17 @@ ctypedef struct DLManagedTensor: void *manager_ctx void (*deleter)(DLManagedTensor *) +ctypedef enum: + DLPACK_FLAG_BITMASK_READ_ONLY = (1UL << 0UL) + DLPACK_FLAG_BITMASK_IS_COPIED = (1UL << 1UL) + +ctypedef struct DLManagedTensorVersioned: + DLPackVersion version + void *manager_ctx + void (*deleter)(DLManagedTensorVersioned *) + uint64_t flags + DLTensor dl_tensor + # ----------------------------------------------------------------------------- cdef extern from "Python.h": @@ -67,6 +83,14 @@ cdef extern from "Python.h": # ----------------------------------------------------------------------------- +cdef inline int dlpack_check_version( + const DLPackVersion *version, + unsigned version_major, +) except -1: + if version == NULL: return 0 + if version.major >= version_major: return 0 + raise BufferError("dlpack: unexpected version") + cdef inline int dlpack_is_contig( const DLTensor *dltensor, char order, @@ -189,15 +213,25 @@ cdef int Py_CheckDLPackBuffer(object obj) noexcept: except: return 0 # ~> uncovered # noqa cdef int Py_GetDLPackBuffer(object obj, Py_buffer *view, int flags) except -1: + cdef unsigned version_major = 1 + cdef const char *capsulename = b"dltensor_versioned" + cdef const char *usedcapsulename = b"used_dltensor_versioned" + cdef uint64_t READONLY = DLPACK_FLAG_BITMASK_READ_ONLY cdef object dlpack cdef object dlpack_device + cdef tuple max_version cdef unsigned device_type cdef int device_id cdef object capsule - cdef DLManagedTensor *managed + cdef void *pointer + cdef DLManagedTensorVersioned *managed1 = NULL + cdef DLManagedTensor *managed0 = NULL + cdef const DLPackVersion *dlversion cdef const DLTensor *dltensor cdef void *buf cdef Py_ssize_t size + cdef Py_ssize_t itemsize + cdef char *format cdef bint readonly try: @@ -207,37 +241,65 @@ cdef int Py_GetDLPackBuffer(object obj, Py_buffer *view, int flags) except -1: raise NotImplementedError("dlpack: missing support") device_type, device_id = dlpack_device() - if device_type == kDLCPU: - capsule = dlpack() - else: - capsule = dlpack(stream=-1) - device_id # unused - if not PyCapsule_IsValid(capsule, b"dltensor"): - raise BufferError("dlpack: invalid capsule object") + device_id # unused - managed = PyCapsule_GetPointer(capsule, b"dltensor") - dltensor = &managed.dl_tensor + try: # DLPack v1.0+ + max_version = (version_major, 0) + if device_type == kDLCPU: + capsule = dlpack(max_version=max_version, copy=False) + else: + capsule = dlpack(stream=-1, max_version=max_version, copy=False) + except TypeError: # DLPack v0.x + version_major = 0 + capsulename = b"dltensor" + usedcapsulename = b"used_dltensor" + if device_type == kDLCPU: + capsule = dlpack() + else: + capsule = dlpack(stream=-1) + + if not PyCapsule_IsValid(capsule, capsulename): + raise BufferError("dlpack: invalid capsule object") + pointer = PyCapsule_GetPointer(capsule, capsulename) + if version_major >= 1: + managed1 = pointer + dlversion = &managed1.version + dltensor = &managed1.dl_tensor + readonly = (managed1.flags & READONLY) == READONLY + else: + managed0 = pointer + dlversion = NULL + dltensor = &managed0.dl_tensor + readonly = 0 try: + dlpack_check_version(dlversion, version_major) dlpack_check_shape(dltensor) dlpack_check_contig(dltensor) - buf = dlpack_get_data(dltensor) size = dlpack_get_size(dltensor) - readonly = 0 - - PyBuffer_FillInfo(view, obj, buf, size, readonly, flags) - - if (flags & PyBUF_FORMAT) == PyBUF_FORMAT: - view.format = dlpack_get_format(dltensor) - if view.format != BYTE_FMT: - view.itemsize = dlpack_get_itemsize(dltensor) + itemsize = dlpack_get_itemsize(dltensor) + format = dlpack_get_format(dltensor) finally: - if managed.deleter != NULL: - managed.deleter(managed) - PyCapsule_SetName(capsule, b"used_dltensor") + if managed1 != NULL: + if managed1.deleter != NULL: + managed1.deleter(managed1) + if managed0 != NULL: + if managed0.deleter != NULL: + managed0.deleter(managed0) + PyCapsule_SetName(capsule, usedcapsulename) del capsule + if PYPY and readonly and ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE): + raise BufferError("Object is not writable") # ~> pypy + + PyBuffer_FillInfo(view, obj, buf, size, readonly, flags) + + if (flags & PyBUF_FORMAT) == PyBUF_FORMAT: + view.format = format + if view.format != BYTE_FMT: + view.itemsize = itemsize + return device_type # ----------------------------------------------------------------------------- diff --git a/test/dlpackimpl.py b/test/dlpackimpl.py index 29284faa..4c45b699 100644 --- a/test/dlpackimpl.py +++ b/test/dlpackimpl.py @@ -5,6 +5,12 @@ if hasattr(sys, 'pypy_version_info'): raise ImportError("unsupported on PyPy") +class DLPackVersion(ctypes.Structure): + _fields_ = [ + ("major", ctypes.c_uint32), + ("minor", ctypes.c_uint32), + ] + class DLDeviceType(IntEnum): kDLCPU = 1 kDLCUDA = 2 @@ -63,9 +69,24 @@ class DLManagedTensor(ctypes.Structure): ("deleter", DLManagedTensorDeleter), ] +DLPACK_FLAG_BITMASK_READ_ONLY = 1 << 0 +DLPACK_FLAG_BITMASK_IS_COPIED = 1 << 1 + +DLManagedTensorVersionedDeleter = ctypes.CFUNCTYPE(None, ctypes.c_void_p) + +class DLManagedTensorVersioned(ctypes.Structure): + _fields_ = [ + ("version", DLPackVersion), + ("manager_ctx", ctypes.c_void_p), + ("deleter", DLManagedTensorDeleter), + ("flags", ctypes.c_uint64), + ("dl_tensor", DLTensor), +] + pyapi = ctypes.pythonapi DLManagedTensor_p = ctypes.POINTER(DLManagedTensor) +DLManagedTensorVersioned_p = ctypes.POINTER(DLManagedTensorVersioned) Py_IncRef = pyapi.Py_IncRef Py_IncRef.restype = None @@ -98,6 +119,13 @@ class DLManagedTensor(ctypes.Structure): PyCapsule_GetContext.argtypes = [ctypes.py_object] +def make_dl_version(major, minor): + version = DLPackVersion() + version.major = major + version.minor = minor + return version + + def make_dl_datatype(typecode, itemsize): code = None bits = itemsize * 8 @@ -195,11 +223,27 @@ def dl_managed_tensor_deleter(void_p): if False: Py_DecRef(py_obj) -def make_dl_managed_tensor(obj): - managed = DLManagedTensor() - managed.dl_tensor = make_dl_tensor(obj) - managed.manager_ctx = make_dl_manager_ctx(obj) - managed.deleter = dl_managed_tensor_deleter +@DLManagedTensorVersionedDeleter +def dl_managed_tensor_versioned_deleter(void_p): + managed = ctypes.cast(void_p, DLManagedTensorVersioned_p) + manager_ctx = managed.contents.manager_ctx + py_obj = ctypes.cast(manager_ctx, ctypes.py_object) + if False: Py_DecRef(py_obj) + + +def make_dl_managed_tensor(obj, versioned=False): + if versioned: + managed = DLManagedTensorVersioned() + managed.version = make_dl_version(1, 0) + managed.manager_ctx = make_dl_manager_ctx(obj) + managed.deleter = dl_managed_tensor_versioned_deleter + managed.flags = 0 + managed.dl_tensor = make_dl_tensor(obj) + else: + managed = DLManagedTensor() + managed.dl_tensor = make_dl_tensor(obj) + managed.manager_ctx = make_dl_manager_ctx(obj) + managed.deleter = dl_managed_tensor_deleter return managed @@ -213,22 +257,33 @@ def make_py_context(context): @PyCapsule_Destructor def py_capsule_destructor(void_p): capsule = ctypes.cast(void_p, ctypes.py_object) - if PyCapsule_IsValid(capsule, b"dltensor"): - pointer = PyCapsule_GetPointer(capsule, b"dltensor") - managed = ctypes.cast(pointer, DLManagedTensor_p) - deleter = managed.contents.deleter - if deleter: - deleter(managed) + for py_capsule_name, dl_managed_tensor_type_p in ( + (b"dltensor_versioned", DLManagedTensorVersioned_p), + (b"dltensor", DLManagedTensor_p), + ): + if PyCapsule_IsValid(capsule, py_capsule_name): + pointer = PyCapsule_GetPointer(capsule, py_capsule_name) + managed = ctypes.cast(pointer, dl_managed_tensor_type_p) + deleter = managed.contents.deleter + if deleter: + deleter(managed) + break context = PyCapsule_GetContext(capsule) managed = ctypes.cast(context, ctypes.py_object) Py_DecRef(managed) -def make_py_capsule(managed): - if not isinstance(managed, DLManagedTensor): - managed = make_dl_managed_tensor(managed) +def make_py_capsule(managed, versioned=False): + if versioned >= 1: + py_capsule_name = b"dltensor_versioned" + if not isinstance(managed, DLManagedTensorVersioned): + managed = make_dl_managed_tensor_versioned(managed) + else: + py_capsule_name = b"dltensor" + if not isinstance(managed, DLManagedTensor): + managed = make_dl_managed_tensor(managed) pointer = ctypes.pointer(managed) - capsule = PyCapsule_New(pointer, b"dltensor", py_capsule_destructor) + capsule = PyCapsule_New(pointer, py_capsule_name, py_capsule_destructor) context = make_py_context(managed) PyCapsule_SetContext(capsule, context) return capsule diff --git a/test/test_msgspec.py b/test/test_msgspec.py index f93ff64d..fc666ba4 100644 --- a/test/test_msgspec.py +++ b/test/test_msgspec.py @@ -41,11 +41,14 @@ def __setitem__(self, item, value): except ImportError: dlpack = None + class DLPackCPUBuf(BaseBuf): + versioned = True + def __init__(self, typecode, initializer): super().__init__(typecode, initializer) - self.managed = dlpack.make_dl_managed_tensor(self._buf) + self.managed = dlpack.make_dl_managed_tensor(self._buf, self.versioned) def __del__(self): self.managed = None @@ -57,7 +60,13 @@ def __dlpack_device__(self): device = self.managed.dl_tensor.device return (device.device_type, device.device_id) - def __dlpack__(self, stream=None): + def __dlpack__( + self, + stream=None, + max_version=None, + dl_device=None, + copy=None, + ): kDLCPU = dlpack.DLDeviceType.kDLCPU managed = self.managed device = managed.dl_tensor.device @@ -65,10 +74,18 @@ def __dlpack__(self, stream=None): assert stream is None else: assert stream == -1 - capsule = dlpack.make_py_capsule(managed) + capsule = dlpack.make_py_capsule(managed, self.versioned) return capsule +class DLPackCPUBufV0(DLPackCPUBuf): + + versioned = False + + def __dlpack__(self, stream=None): + return super().__dlpack__(stream=stream) + + if cupy is not None: class DLPackGPUBuf(BaseBuf): @@ -96,6 +113,25 @@ def __dlpack_device__(self): else: return (self.dev_type, self._buf.device.id) + if False: # TODO: wait until CuPy supports DLPack v1.0 + + def __dlpack__(self, stream=None, **kwargs): + assert self.has_dlpack + cupy.cuda.get_current_stream().synchronize() + return self._buf.__dlpack__(stream=-1, **kwargs) + + else: + + def __dlpack__(self, stream=None): + cupy.cuda.get_current_stream().synchronize() + if self.has_dlpack: + return self._buf.__dlpack__(stream=-1) + else: + return self._buf.toDlpack() + + + class DLPackGPUBufV0(DLPackGPUBuf): + def __dlpack__(self, stream=None): cupy.cuda.get_current_stream().synchronize() if self.has_dlpack: @@ -105,7 +141,7 @@ def __dlpack__(self, stream=None): else: - class DLPackGPUBuf(DLPackCPUBuf): + class DLPackGPUBufInitMixin: def __init__(self, *args): super().__init__(*args) @@ -113,6 +149,12 @@ def __init__(self, *args): device = self.managed.dl_tensor.device device.device_type = kDLCUDA + class DLPackGPUBuf(DLPackGPUBufInitMixin, DLPackCPUBuf): + pass + + class DLPackGPUBufV0(DLPackGPUBufInitMixin, DLPackCPUBufV0): + pass + # --- class CAIBuf(BaseBuf): @@ -473,6 +515,11 @@ class TestMessageSimpleDLPackCPUBuf(unittest.TestCase, def array(self, typecode, initializer): return DLPackCPUBuf(typecode, initializer) +class TestMessageSimpleDLPackCPUBufV0(TestMessageSimpleDLPackCPUBuf): + + def array(self, typecode, initializer): + return DLPackCPUBufV0(typecode, initializer) + @unittest.skipIf(cupy is None and (array is None or dlpack is None), 'cupy') class TestMessageSimpleDLPackGPUBuf(unittest.TestCase, BaseTestMessageSimpleArray): @@ -480,6 +527,11 @@ class TestMessageSimpleDLPackGPUBuf(unittest.TestCase, def array(self, typecode, initializer): return DLPackGPUBuf(typecode, initializer) +class TestMessageSimpleDLPackGPUBufV0(TestMessageSimpleDLPackGPUBuf): + + def array(self, typecode, initializer): + return DLPackGPUBufV0(typecode, initializer) + @unittest.skipIf(array is None, 'array') class TestMessageSimpleCAIBuf(unittest.TestCase, BaseTestMessageSimpleArray): @@ -568,6 +620,19 @@ def testNotContiguous(self): @unittest.skipIf(dlpack is None, 'dlpack') class TestMessageDLPackCPUBuf(unittest.TestCase): + def testVersion(self): + buf = DLPackCPUBuf('i', [0,1,2,3]) + buf.managed.version.major = 0 + self.assertRaises(BufferError, MPI.Get_address, buf) + + def testReadonly(self): + smsg = DLPackCPUBuf('i', [0,1,2,3]) + rmsg = DLPackCPUBuf('i', [0,0,0,0]) + smsg.managed.flags |= dlpack.DLPACK_FLAG_BITMASK_READ_ONLY + rmsg.managed.flags |= dlpack.DLPACK_FLAG_BITMASK_READ_ONLY + MPI.Get_address(smsg) + self.assertRaises(BufferError, Sendrecv, smsg, rmsg) + def testDevice(self): buf = DLPackCPUBuf('i', [0,1,2,3]) buf.__dlpack_device__ = None @@ -601,7 +666,7 @@ def testCapsule(self): del buf.__dlpack__ del capsule # - buf.__dlpack__ = lambda *args, **kwargs: None + buf.__dlpack__ = lambda *args, **kwargs: None self.assertRaises(BufferError, MPI.Get_address, buf) del buf.__dlpack__ @@ -720,9 +785,10 @@ def testByteOffset(self): @unittest.skipIf(array is None, 'array') class TestMessageCAIBuf(unittest.TestCase): - def testNonReadonly(self): + def testReadonly(self): smsg = CAIBuf('i', [1,2,3], readonly=True) rmsg = CAIBuf('i', [0,0,0], readonly=True) + MPI.Get_address(smsg) self.assertRaises(BufferError, Sendrecv, smsg, rmsg) def testNonContiguous(self):