From bbb5fe0c4788956f50256175745c5de636f69799 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 8 Aug 2018 02:00:44 -0400 Subject: [PATCH] Create a trivial char array object This builds up a basic char array object in C with some Python trappings for handling construction and cleanup of the object. Otherwise provides no interface for accessing this object in Python. Since it's purpose is merely to manage a block of memory requested by the user to be passed off and used by other objects that implement the buffer protocol, there is no need for it to have other functionality. It simply allocates memory from Python's memory allocator and frees it on cleanup. Implements the buffer protocol in Cython for this object. Thus allowing the memory allocated to be reused by NumPy arrays or other Python objects that support the buffer protocol. --- setup.py | 2 + src/pysharedmem.pyx | 90 +++++++++++++++++++++++++++++ tests/test_pysharedmem.py | 117 +++++++++++++++++++++++++++++++++++++- 3 files changed, 207 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index e5dc3d0..b054300 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ extra_compile_args = [] cython_directives = {} cython_line_directives = {} +cython_compile_time_env = {"PY_VERSION_HEX": sys.hexversion} if "test" in sys.argv: @@ -99,6 +100,7 @@ for em in ext_modules: em.cython_directives = dict(cython_directives) em.cython_line_directives = dict(cython_line_directives) + em.cython_compile_time_env = dict(cython_compile_time_env) setup( diff --git a/src/pysharedmem.pyx b/src/pysharedmem.pyx index 564b9a6..8a8c195 100644 --- a/src/pysharedmem.pyx +++ b/src/pysharedmem.pyx @@ -1,3 +1,93 @@ +# cython: boundscheck = False +# cython: initializedcheck = False +# cython: nonecheck = False +# cython: wraparound = False + + +cimport cpython +cimport cpython.buffer +cimport cpython.bytes +cimport cpython.object +cimport cpython.tuple + cimport pysharedmem include "version.pxi" + +from cpython.buffer cimport PyBuffer_ToContiguous +from cpython.bytes cimport PyBytes_FromStringAndSize +from cpython.object cimport PyObject +from cpython.tuple cimport PyTuple_Pack + +cdef extern from "Python.h": + IF PY_VERSION_HEX >= 0x03040000: + void* PyMem_RawMalloc(size_t n) nogil + void PyMem_RawFree(void *p) nogil + ELSE: + void* PyMem_RawMalloc "PyMem_Malloc" (size_t n) nogil + void PyMem_RawFree "PyMem_Free" (void *p) nogil + +cdef extern from "Python.h": + object PyMemoryView_FromObject(object) + Py_buffer* PyMemoryView_GET_BUFFER(object) + + +cdef class cbuffer: + cdef char[::1] buf + + def __cinit__(self, Py_ssize_t length): + self.buf = None + + if length <= 0: + raise ValueError("length must be positive definite") + + cdef char* ptr = PyMem_RawMalloc(length * sizeof(char)) + if not ptr: + raise MemoryError("unable to allocate buffer") + + self.buf = ptr + + def __getbuffer__(self, Py_buffer *buffer, int flags): + buffer.buf = &self.buf[0] + buffer.obj = self + buffer.len = len(self.buf) + buffer.readonly = 0 + buffer.itemsize = self.buf.itemsize + buffer.format = "c" + buffer.ndim = self.buf.ndim + buffer.shape = self.buf.shape + buffer.strides = self.buf.strides + buffer.suboffsets = NULL + buffer.internal = NULL + + def __releasebuffer__(self, Py_buffer *buffer): + pass + + def __reduce__(self): + cdef bytes buf_bytes = PyBytes_FromStringAndSize( + &self.buf[0], len(self.buf) + ) + + cdef tuple result = PyTuple_Pack(1, buf_bytes) + result = PyTuple_Pack(2, frombuffer, result) + + return result + + def __dealloc__(self): + cdef char* ptr + if self.buf is not None: + ptr = &self.buf[0] + self.buf = None + PyMem_RawFree(ptr) + + +def frombuffer(src): + src = PyMemoryView_FromObject(src) + cdef Py_buffer* src_buf = PyMemoryView_GET_BUFFER(src) + + cdef cbuffer dest = cbuffer(src_buf.len) + cdef char* dest_ptr = &dest.buf[0] + + PyBuffer_ToContiguous(dest_ptr, src_buf, src_buf.len, 'C') + + return dest diff --git a/tests/test_pysharedmem.py b/tests/test_pysharedmem.py index 6d75ad5..593caae 100644 --- a/tests/test_pysharedmem.py +++ b/tests/test_pysharedmem.py @@ -3,12 +3,125 @@ from __future__ import absolute_import +import pickle +import sys import unittest -class TestImportTopLevel(unittest.TestCase): - def runTest(self): +try: + unichr +except NameError: + unichr = chr + + +class TestPySharedMem(unittest.TestCase): + def test_import(self): try: import pysharedmem except ImportError: self.fail("Unable to import `pysharedmem`.") + + + def test_create_empty_mem(self): + import pysharedmem + with self.assertRaises(ValueError): + pysharedmem.cbuffer(0) + + + def test_create_negative_mem(self): + import pysharedmem + with self.assertRaises(ValueError): + pysharedmem.cbuffer(-1) + + + def test_create_singleton_mem(self): + import pysharedmem + pysharedmem.cbuffer(1) + + + def test_memoryview(self): + import pysharedmem + + l = 3 + b = pysharedmem.cbuffer(l) + + m = memoryview(b) + + self.assertEqual(m.readonly, False) + + self.assertEqual(m.format, 'c') + self.assertEqual(m.itemsize, 1) + + self.assertEqual(m.ndim, 1) + self.assertEqual(m.shape, (l,)) + self.assertEqual(m.strides, (1,)) + + if sys.version_info[0] >= 3: + self.assertEqual(m.suboffsets, ()) + else: + self.assertEqual(m.suboffsets, None) + + + @unittest.skipIf(sys.version_info[0] < 3, + "This test requires Python 3.") + def test_memoryview_py3(self): + import pysharedmem + + l = 3 + b = pysharedmem.cbuffer(l) + + m = memoryview(b) + + self.assertIs(m.obj, b) + + self.assertTrue(m.c_contiguous) + self.assertTrue(m.contiguous) + self.assertTrue(m.f_contiguous) + + self.assertEqual(m.readonly, False) + + self.assertEqual(m.format, 'c') + self.assertEqual(m.itemsize, 1) + + self.assertEqual(m.ndim, 1) + self.assertEqual(m.nbytes, l) + self.assertEqual(m.shape, (l,)) + self.assertEqual(m.strides, (1,)) + self.assertEqual(m.suboffsets, ()) + + + def test_assignment(self): + import pysharedmem + + l = 3 + b = pysharedmem.cbuffer(l) + + m = memoryview(b) + for i in range(len(m)): + m[i] = unichr(i).encode("utf-8") + + a1 = bytearray(m) + a2 = bytearray(b) + + self.assertIsNot(a1, a2) + self.assertEqual(a1, a2) + + + def test_pickle(self): + import pysharedmem + + l = 3 + b1 = pysharedmem.cbuffer(l) + + m = memoryview(b1) + for i in range(len(m)): + m[i] = unichr(i).encode("utf-8") + + b2 = pickle.loads(pickle.dumps(b1)) + + a1 = bytes(memoryview(b1)) + a2 = bytes(memoryview(b2)) + + self.assertIsNot(b1, b2) + self.assertIsNot(a1, a2) + self.assertEqual(a1, a2)