Skip to content

Commit

Permalink
simplify a bit DLPack support
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Jul 1, 2021
1 parent 7338c11 commit b4608a1
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions src/mpi4py/MPI/asdlpack.pxi
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#------------------------------------------------------------------------------
# Below is dlpack.h (as of commit 9b6176fd, to be released as v0.6)

cdef extern from * nogil:
ctypedef unsigned char uint8_t
Expand All @@ -20,8 +21,8 @@ ctypedef enum DLDeviceType:
kDLCUDAManaged = 13

ctypedef struct DLDevice:
DLDeviceType device_type
int device_id
DLDeviceType device_type
int device_id

ctypedef enum DLDataTypeCode:
kDLInt = 0
Expand Down Expand Up @@ -151,7 +152,8 @@ cdef inline Py_ssize_t dlpack_get_itemsize(const DLTensor *dltensor) nogil:
#------------------------------------------------------------------------------

cdef int Py_CheckDLPackBuffer(object obj):
try: return <bint>hasattr(obj, '__dlpack__')
# we check __dlpack_device__ to avoid potential side effects
try: return <bint>hasattr(obj, '__dlpack_device__')
except: return 0

cdef int Py_GetDLPackBuffer(object obj, Py_buffer *view, int flags) except -1:
Expand All @@ -168,18 +170,14 @@ cdef int Py_GetDLPackBuffer(object obj, Py_buffer *view, int flags) except -1:
cdef bint fixnull

try:
dlpack = obj.__dlpack__
# we check __dlpack_device__ first instead of __dlpack__ to avoid
# potential side effects
device_type, devide_id = obj.__dlpack_device__()
except AttributeError:
raise NotImplementedError("dlpack: missing __dlpack__ method")
raise NotImplementedError("dlpack: missing support")

try:
dlpack_device = obj.__dlpack_device__
except AttributeError:
dlpack_device = None
if dlpack_device is not None:
device_type, device_id = dlpack_device()
else:
device_type, devide_id = kDLCPU, 0
# at this point, __dlpack__ should be there
dlpack = obj.__dlpack__
if device_type == kDLCPU:
capsule = dlpack()
else:
Expand All @@ -200,7 +198,7 @@ cdef int Py_GetDLPackBuffer(object obj, Py_buffer *view, int flags) except -1:

fixnull = (buf == NULL and size == 0)
if fixnull: buf = &fixnull
PyBuffer_FillInfo(view, obj, buf, size, readonly, flags)
PyBuffer_FillInfo(view, capsule, buf, size, readonly, flags)
if fixnull: view.buf = NULL

if (flags & PyBUF_FORMAT) == PyBUF_FORMAT:
Expand All @@ -211,7 +209,6 @@ cdef int Py_GetDLPackBuffer(object obj, Py_buffer *view, int flags) except -1:
if managed.deleter != NULL:
managed.deleter(managed)
PyCapsule_SetName(capsule, b"used_dltensor")
del capsule
return 0

#------------------------------------------------------------------------------

0 comments on commit b4608a1

Please sign in to comment.