From b4608a1339dafd09b339f282d76957503cbfdc6d Mon Sep 17 00:00:00 2001 From: Leo Fang Date: Thu, 1 Jul 2021 02:32:22 -0400 Subject: [PATCH] simplify a bit DLPack support --- src/mpi4py/MPI/asdlpack.pxi | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/src/mpi4py/MPI/asdlpack.pxi b/src/mpi4py/MPI/asdlpack.pxi index e89994768..1cc21b36b 100644 --- a/src/mpi4py/MPI/asdlpack.pxi +++ b/src/mpi4py/MPI/asdlpack.pxi @@ -1,4 +1,5 @@ #------------------------------------------------------------------------------ +# Below is dlpack.h (as of commit 9b6176fd, to be released as v0.6) cdef extern from * nogil: ctypedef unsigned char uint8_t @@ -20,8 +21,8 @@ ctypedef enum DLDeviceType: kDLCUDAManaged = 13 ctypedef struct DLDevice: - DLDeviceType device_type - int device_id + DLDeviceType device_type + int device_id ctypedef enum DLDataTypeCode: kDLInt = 0 @@ -151,7 +152,8 @@ cdef inline Py_ssize_t dlpack_get_itemsize(const DLTensor *dltensor) nogil: #------------------------------------------------------------------------------ cdef int Py_CheckDLPackBuffer(object obj): - try: return hasattr(obj, '__dlpack__') + # we check __dlpack_device__ to avoid potential side effects + try: return hasattr(obj, '__dlpack_device__') except: return 0 cdef int Py_GetDLPackBuffer(object obj, Py_buffer *view, int flags) except -1: @@ -168,18 +170,14 @@ cdef int Py_GetDLPackBuffer(object obj, Py_buffer *view, int flags) except -1: cdef bint fixnull try: - dlpack = obj.__dlpack__ + # we check __dlpack_device__ first instead of __dlpack__ to avoid + # potential side effects + device_type, devide_id = obj.__dlpack_device__() except AttributeError: - raise NotImplementedError("dlpack: missing __dlpack__ method") + raise NotImplementedError("dlpack: missing support") - try: - dlpack_device = obj.__dlpack_device__ - except AttributeError: - dlpack_device = None - if dlpack_device is not None: - device_type, device_id = dlpack_device() - else: - device_type, devide_id = kDLCPU, 0 + # at this point, __dlpack__ should be there + dlpack = obj.__dlpack__ if device_type == kDLCPU: capsule = dlpack() else: @@ -200,7 +198,7 @@ cdef int Py_GetDLPackBuffer(object obj, Py_buffer *view, int flags) except -1: fixnull = (buf == NULL and size == 0) if fixnull: buf = &fixnull - PyBuffer_FillInfo(view, obj, buf, size, readonly, flags) + PyBuffer_FillInfo(view, capsule, buf, size, readonly, flags) if fixnull: view.buf = NULL if (flags & PyBUF_FORMAT) == PyBUF_FORMAT: @@ -211,7 +209,6 @@ cdef int Py_GetDLPackBuffer(object obj, Py_buffer *view, int flags) except -1: if managed.deleter != NULL: managed.deleter(managed) PyCapsule_SetName(capsule, b"used_dltensor") - del capsule return 0 #------------------------------------------------------------------------------