Skip to content

Commit

Permalink
dlpack: Update to DLPack v1.0 and Python array API v2023.12
Browse files Browse the repository at this point in the history
  • Loading branch information
dalcinl committed May 14, 2024
1 parent 8472cca commit 09bd782
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 45 deletions.
4 changes: 3 additions & 1 deletion src/mpi4py/MPI.src/ascaibuf.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ cdef int Py_GetCAIBuffer(object obj, Py_buffer *view, int flags) except -1:
cdef bint readonly = 0
cdef Py_ssize_t s, size = 1
cdef Py_ssize_t itemsize = 1
cdef char *format = BYTE_FMT
cdef char byteorder = c'|'
cdef char typekind = c'u'

Expand All @@ -98,6 +99,7 @@ cdef int Py_GetCAIBuffer(object obj, Py_buffer *view, int flags) except -1:
byteorder = <char>ord(typestr[0:1])
typekind = <char>ord(typestr[1:2])
itemsize = <Py_ssize_t>int(typestr[2:])
format = cuda_get_format(typekind, itemsize)

if (flags & PyBUF_FORMAT) == PyBUF_FORMAT:
if byteorder == c'<': # little-endian
Expand Down Expand Up @@ -151,7 +153,7 @@ cdef int Py_GetCAIBuffer(object obj, Py_buffer *view, int flags) except -1:
PyBuffer_FillInfo(view, obj, buf, size*itemsize, readonly, flags)

if (flags & PyBUF_FORMAT) == PyBUF_FORMAT:
view.format = cuda_get_format(typekind, itemsize)
view.format = format
if view.format != BYTE_FMT:
view.itemsize = itemsize

Expand Down
108 changes: 85 additions & 23 deletions src/mpi4py/MPI.src/asdlpack.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@
cdef extern from * nogil:
ctypedef unsigned char uint8_t
ctypedef unsigned short uint16_t
ctypedef int int32_t
ctypedef signed int int32_t
ctypedef unsigned int uint32_t
ctypedef signed long long int64_t
ctypedef unsigned long long uint64_t

ctypedef struct DLPackVersion:
uint32_t major
uint32_t minor

ctypedef enum DLDeviceType:
kDLCPU = 1
kDLCUDA = 2
Expand Down Expand Up @@ -58,6 +63,17 @@ ctypedef struct DLManagedTensor:
void *manager_ctx
void (*deleter)(DLManagedTensor *)

ctypedef enum:
DLPACK_FLAG_BITMASK_READ_ONLY = (1UL << 0UL)
DLPACK_FLAG_BITMASK_IS_COPIED = (1UL << 1UL)

ctypedef struct DLManagedTensorVersioned:
DLPackVersion version
void *manager_ctx
void (*deleter)(DLManagedTensorVersioned *)
uint64_t flags
DLTensor dl_tensor

# -----------------------------------------------------------------------------

cdef extern from "Python.h":
Expand All @@ -67,6 +83,14 @@ cdef extern from "Python.h":

# -----------------------------------------------------------------------------

cdef inline int dlpack_check_version(
const DLPackVersion *version,
unsigned version_major,
) except -1:
if version == NULL: return 0
if version.major >= version_major: return 0
raise BufferError("dlpack: unexpected version")

cdef inline int dlpack_is_contig(
const DLTensor *dltensor,
char order,
Expand Down Expand Up @@ -189,15 +213,25 @@ cdef int Py_CheckDLPackBuffer(object obj) noexcept:
except: return 0 # ~> uncovered # noqa

cdef int Py_GetDLPackBuffer(object obj, Py_buffer *view, int flags) except -1:
cdef unsigned version_major = 1
cdef const char *capsulename = b"dltensor_versioned"
cdef const char *usedcapsulename = b"used_dltensor_versioned"
cdef uint64_t READONLY = DLPACK_FLAG_BITMASK_READ_ONLY
cdef object dlpack
cdef object dlpack_device
cdef tuple max_version
cdef unsigned device_type
cdef int device_id
cdef object capsule
cdef DLManagedTensor *managed
cdef void *pointer
cdef DLManagedTensorVersioned *managed1 = NULL
cdef DLManagedTensor *managed0 = NULL
cdef const DLPackVersion *dlversion
cdef const DLTensor *dltensor
cdef void *buf
cdef Py_ssize_t size
cdef Py_ssize_t itemsize
cdef char *format
cdef bint readonly

try:
Expand All @@ -207,37 +241,65 @@ cdef int Py_GetDLPackBuffer(object obj, Py_buffer *view, int flags) except -1:
raise NotImplementedError("dlpack: missing support")

device_type, device_id = dlpack_device()
if device_type == kDLCPU:
capsule = dlpack()
else:
capsule = dlpack(stream=-1)
<void> device_id # unused
if not PyCapsule_IsValid(capsule, b"dltensor"):
raise BufferError("dlpack: invalid capsule object")
<void> device_id # unused

managed = <DLManagedTensor*> PyCapsule_GetPointer(capsule, b"dltensor")
dltensor = &managed.dl_tensor
try: # DLPack v1.0+
max_version = (version_major, 0)
if device_type == kDLCPU:
capsule = dlpack(max_version=max_version, copy=False)
else:
capsule = dlpack(stream=-1, max_version=max_version, copy=False)
except TypeError: # DLPack v0.x
version_major = 0
capsulename = b"dltensor"
usedcapsulename = b"used_dltensor"
if device_type == kDLCPU:
capsule = dlpack()
else:
capsule = dlpack(stream=-1)

if not PyCapsule_IsValid(capsule, capsulename):
raise BufferError("dlpack: invalid capsule object")
pointer = PyCapsule_GetPointer(capsule, capsulename)
if version_major >= 1:
managed1 = <DLManagedTensorVersioned*> pointer
dlversion = &managed1.version
dltensor = &managed1.dl_tensor
readonly = (managed1.flags & READONLY) == READONLY
else:
managed0 = <DLManagedTensor*> pointer
dlversion = NULL
dltensor = &managed0.dl_tensor
readonly = 0

try:
dlpack_check_version(dlversion, version_major)
dlpack_check_shape(dltensor)
dlpack_check_contig(dltensor)

buf = dlpack_get_data(dltensor)
size = dlpack_get_size(dltensor)
readonly = 0

PyBuffer_FillInfo(view, obj, buf, size, readonly, flags)

if (flags & PyBUF_FORMAT) == PyBUF_FORMAT:
view.format = dlpack_get_format(dltensor)
if view.format != BYTE_FMT:
view.itemsize = dlpack_get_itemsize(dltensor)
itemsize = dlpack_get_itemsize(dltensor)
format = dlpack_get_format(dltensor)
finally:
if managed.deleter != NULL:
managed.deleter(managed)
PyCapsule_SetName(capsule, b"used_dltensor")
if managed1 != NULL:
if managed1.deleter != NULL:
managed1.deleter(managed1)
if managed0 != NULL:
if managed0.deleter != NULL:
managed0.deleter(managed0)
PyCapsule_SetName(capsule, usedcapsulename)
del capsule

if PYPY and readonly and ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE):
raise BufferError("Object is not writable") # ~> pypy

PyBuffer_FillInfo(view, obj, buf, size, readonly, flags)

if (flags & PyBUF_FORMAT) == PyBUF_FORMAT:
view.format = format
if view.format != BYTE_FMT:
view.itemsize = itemsize

return <int> device_type

# -----------------------------------------------------------------------------
85 changes: 70 additions & 15 deletions test/dlpackimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
if hasattr(sys, 'pypy_version_info'):
raise ImportError("unsupported on PyPy")

class DLPackVersion(ctypes.Structure):
_fields_ = [
("major", ctypes.c_uint32),
("minor", ctypes.c_uint32),
]

class DLDeviceType(IntEnum):
kDLCPU = 1
kDLCUDA = 2
Expand Down Expand Up @@ -63,9 +69,24 @@ class DLManagedTensor(ctypes.Structure):
("deleter", DLManagedTensorDeleter),
]

DLPACK_FLAG_BITMASK_READ_ONLY = 1 << 0
DLPACK_FLAG_BITMASK_IS_COPIED = 1 << 1

DLManagedTensorVersionedDeleter = ctypes.CFUNCTYPE(None, ctypes.c_void_p)

class DLManagedTensorVersioned(ctypes.Structure):
_fields_ = [
("version", DLPackVersion),
("manager_ctx", ctypes.c_void_p),
("deleter", DLManagedTensorDeleter),
("flags", ctypes.c_uint64),
("dl_tensor", DLTensor),
]

pyapi = ctypes.pythonapi

DLManagedTensor_p = ctypes.POINTER(DLManagedTensor)
DLManagedTensorVersioned_p = ctypes.POINTER(DLManagedTensorVersioned)

Py_IncRef = pyapi.Py_IncRef
Py_IncRef.restype = None
Expand Down Expand Up @@ -98,6 +119,13 @@ class DLManagedTensor(ctypes.Structure):
PyCapsule_GetContext.argtypes = [ctypes.py_object]


def make_dl_version(major, minor):
version = DLPackVersion()
version.major = major
version.minor = minor
return version


def make_dl_datatype(typecode, itemsize):
code = None
bits = itemsize * 8
Expand Down Expand Up @@ -195,11 +223,27 @@ def dl_managed_tensor_deleter(void_p):
if False: Py_DecRef(py_obj)


def make_dl_managed_tensor(obj):
managed = DLManagedTensor()
managed.dl_tensor = make_dl_tensor(obj)
managed.manager_ctx = make_dl_manager_ctx(obj)
managed.deleter = dl_managed_tensor_deleter
@DLManagedTensorVersionedDeleter
def dl_managed_tensor_versioned_deleter(void_p):
managed = ctypes.cast(void_p, DLManagedTensorVersioned_p)
manager_ctx = managed.contents.manager_ctx
py_obj = ctypes.cast(manager_ctx, ctypes.py_object)
if False: Py_DecRef(py_obj)


def make_dl_managed_tensor(obj, versioned=False):
if versioned:
managed = DLManagedTensorVersioned()
managed.version = make_dl_version(1, 0)
managed.manager_ctx = make_dl_manager_ctx(obj)
managed.deleter = dl_managed_tensor_versioned_deleter
managed.flags = 0
managed.dl_tensor = make_dl_tensor(obj)
else:
managed = DLManagedTensor()
managed.dl_tensor = make_dl_tensor(obj)
managed.manager_ctx = make_dl_manager_ctx(obj)
managed.deleter = dl_managed_tensor_deleter
return managed


Expand All @@ -213,22 +257,33 @@ def make_py_context(context):
@PyCapsule_Destructor
def py_capsule_destructor(void_p):
capsule = ctypes.cast(void_p, ctypes.py_object)
if PyCapsule_IsValid(capsule, b"dltensor"):
pointer = PyCapsule_GetPointer(capsule, b"dltensor")
managed = ctypes.cast(pointer, DLManagedTensor_p)
deleter = managed.contents.deleter
if deleter:
deleter(managed)
for py_capsule_name, dl_managed_tensor_type_p in (
(b"dltensor_versioned", DLManagedTensorVersioned_p),
(b"dltensor", DLManagedTensor_p),
):
if PyCapsule_IsValid(capsule, py_capsule_name):
pointer = PyCapsule_GetPointer(capsule, py_capsule_name)
managed = ctypes.cast(pointer, dl_managed_tensor_type_p)
deleter = managed.contents.deleter
if deleter:
deleter(managed)
break
context = PyCapsule_GetContext(capsule)
managed = ctypes.cast(context, ctypes.py_object)
Py_DecRef(managed)


def make_py_capsule(managed):
if not isinstance(managed, DLManagedTensor):
managed = make_dl_managed_tensor(managed)
def make_py_capsule(managed, versioned=False):
if versioned >= 1:
py_capsule_name = b"dltensor_versioned"
if not isinstance(managed, DLManagedTensorVersioned):
managed = make_dl_managed_tensor_versioned(managed)
else:
py_capsule_name = b"dltensor"
if not isinstance(managed, DLManagedTensor):
managed = make_dl_managed_tensor(managed)
pointer = ctypes.pointer(managed)
capsule = PyCapsule_New(pointer, b"dltensor", py_capsule_destructor)
capsule = PyCapsule_New(pointer, py_capsule_name, py_capsule_destructor)
context = make_py_context(managed)
PyCapsule_SetContext(capsule, context)
return capsule

0 comments on commit 09bd782

Please sign in to comment.