Skip to content

Commit

Permalink
Create a trivial char array object
Browse files Browse the repository at this point in the history
This builds up a basic char array object in C with some Python trappings
for handling construction and cleanup of the object. Otherwise provides
no interface for accessing this object in Python. Since it's purpose is
merely to manage a block of memory requested by the user to be passed
off and used by other objects that implement the buffer protocol, there
is no need for it to have other functionality. It simply allocates
memory from Python's memory allocator and frees it on cleanup.
Implements the buffer protocol in Cython for this object. Thus allowing
the memory allocated to be reused by NumPy arrays or other Python
objects that support the buffer protocol.
  • Loading branch information
jakirkham committed Aug 8, 2018
1 parent a6f6a72 commit bbb5fe0
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 2 deletions.
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
extra_compile_args = []
cython_directives = {}
cython_line_directives = {}
cython_compile_time_env = {"PY_VERSION_HEX": sys.hexversion}


if "test" in sys.argv:
Expand Down Expand Up @@ -99,6 +100,7 @@
for em in ext_modules:
em.cython_directives = dict(cython_directives)
em.cython_line_directives = dict(cython_line_directives)
em.cython_compile_time_env = dict(cython_compile_time_env)


setup(
Expand Down
90 changes: 90 additions & 0 deletions src/pysharedmem.pyx
Original file line number Diff line number Diff line change
@@ -1,3 +1,93 @@
# cython: boundscheck = False
# cython: initializedcheck = False
# cython: nonecheck = False
# cython: wraparound = False


cimport cpython
cimport cpython.buffer
cimport cpython.bytes
cimport cpython.object
cimport cpython.tuple

cimport pysharedmem

include "version.pxi"

from cpython.buffer cimport PyBuffer_ToContiguous
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.object cimport PyObject
from cpython.tuple cimport PyTuple_Pack

cdef extern from "Python.h":
IF PY_VERSION_HEX >= 0x03040000:
void* PyMem_RawMalloc(size_t n) nogil
void PyMem_RawFree(void *p) nogil
ELSE:
void* PyMem_RawMalloc "PyMem_Malloc" (size_t n) nogil
void PyMem_RawFree "PyMem_Free" (void *p) nogil

cdef extern from "Python.h":
object PyMemoryView_FromObject(object)
Py_buffer* PyMemoryView_GET_BUFFER(object)


cdef class cbuffer:
cdef char[::1] buf

def __cinit__(self, Py_ssize_t length):
self.buf = None

if length <= 0:
raise ValueError("length must be positive definite")

cdef char* ptr = <char*>PyMem_RawMalloc(length * sizeof(char))
if not ptr:
raise MemoryError("unable to allocate buffer")

self.buf = <char[:length]>ptr

def __getbuffer__(self, Py_buffer *buffer, int flags):
buffer.buf = &self.buf[0]
buffer.obj = self
buffer.len = len(self.buf)
buffer.readonly = 0
buffer.itemsize = self.buf.itemsize
buffer.format = "c"
buffer.ndim = self.buf.ndim
buffer.shape = self.buf.shape
buffer.strides = self.buf.strides
buffer.suboffsets = NULL
buffer.internal = NULL

def __releasebuffer__(self, Py_buffer *buffer):
pass

def __reduce__(self):
cdef bytes buf_bytes = PyBytes_FromStringAndSize(
&self.buf[0], len(self.buf)
)

cdef tuple result = PyTuple_Pack(1, <PyObject*>buf_bytes)
result = PyTuple_Pack(2, <PyObject*>frombuffer, <PyObject*>result)

return result

def __dealloc__(self):
cdef char* ptr
if self.buf is not None:
ptr = &self.buf[0]
self.buf = None
PyMem_RawFree(<void*>ptr)


def frombuffer(src):
src = PyMemoryView_FromObject(src)
cdef Py_buffer* src_buf = PyMemoryView_GET_BUFFER(src)

cdef cbuffer dest = cbuffer(src_buf.len)
cdef char* dest_ptr = &dest.buf[0]

PyBuffer_ToContiguous(<void*>dest_ptr, src_buf, src_buf.len, 'C')

return dest
117 changes: 115 additions & 2 deletions tests/test_pysharedmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,125 @@

from __future__ import absolute_import

import pickle
import sys
import unittest


class TestImportTopLevel(unittest.TestCase):
def runTest(self):
try:
unichr
except NameError:
unichr = chr


class TestPySharedMem(unittest.TestCase):
def test_import(self):
try:
import pysharedmem
except ImportError:
self.fail("Unable to import `pysharedmem`.")


def test_create_empty_mem(self):
import pysharedmem
with self.assertRaises(ValueError):
pysharedmem.cbuffer(0)


def test_create_negative_mem(self):
import pysharedmem
with self.assertRaises(ValueError):
pysharedmem.cbuffer(-1)


def test_create_singleton_mem(self):
import pysharedmem
pysharedmem.cbuffer(1)


def test_memoryview(self):
import pysharedmem

l = 3
b = pysharedmem.cbuffer(l)

m = memoryview(b)

self.assertEqual(m.readonly, False)

self.assertEqual(m.format, 'c')
self.assertEqual(m.itemsize, 1)

self.assertEqual(m.ndim, 1)
self.assertEqual(m.shape, (l,))
self.assertEqual(m.strides, (1,))

if sys.version_info[0] >= 3:
self.assertEqual(m.suboffsets, ())
else:
self.assertEqual(m.suboffsets, None)


@unittest.skipIf(sys.version_info[0] < 3,
"This test requires Python 3.")
def test_memoryview_py3(self):
import pysharedmem

l = 3
b = pysharedmem.cbuffer(l)

m = memoryview(b)

self.assertIs(m.obj, b)

self.assertTrue(m.c_contiguous)
self.assertTrue(m.contiguous)
self.assertTrue(m.f_contiguous)

self.assertEqual(m.readonly, False)

self.assertEqual(m.format, 'c')
self.assertEqual(m.itemsize, 1)

self.assertEqual(m.ndim, 1)
self.assertEqual(m.nbytes, l)
self.assertEqual(m.shape, (l,))
self.assertEqual(m.strides, (1,))
self.assertEqual(m.suboffsets, ())


def test_assignment(self):
import pysharedmem

l = 3
b = pysharedmem.cbuffer(l)

m = memoryview(b)
for i in range(len(m)):
m[i] = unichr(i).encode("utf-8")

a1 = bytearray(m)
a2 = bytearray(b)

self.assertIsNot(a1, a2)
self.assertEqual(a1, a2)


def test_pickle(self):
import pysharedmem

l = 3
b1 = pysharedmem.cbuffer(l)

m = memoryview(b1)
for i in range(len(m)):
m[i] = unichr(i).encode("utf-8")

b2 = pickle.loads(pickle.dumps(b1))

a1 = bytes(memoryview(b1))
a2 = bytes(memoryview(b2))

self.assertIsNot(b1, b2)
self.assertIsNot(a1, a2)
self.assertEqual(a1, a2)

0 comments on commit bbb5fe0

Please sign in to comment.