diff --git a/setup.py b/setup.py index e5dc3d0..b054300 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ extra_compile_args = [] cython_directives = {} cython_line_directives = {} +cython_compile_time_env = {"PY_VERSION_HEX": sys.hexversion} if "test" in sys.argv: @@ -99,6 +100,7 @@ for em in ext_modules: em.cython_directives = dict(cython_directives) em.cython_line_directives = dict(cython_line_directives) + em.cython_compile_time_env = dict(cython_compile_time_env) setup( diff --git a/src/pysharedmem.pyx b/src/pysharedmem.pyx index 564b9a6..3d00753 100644 --- a/src/pysharedmem.pyx +++ b/src/pysharedmem.pyx @@ -1,3 +1,93 @@ +cimport cpython +cimport cpython.buffer +cimport cpython.bytes + +from cpython.buffer cimport ( + PyObject_GetBuffer, + PyBuffer_Release, + PyBuffer_ToContiguous, +) +from cpython.bytes cimport ( + PyBytes_FromStringAndSize, +) + +cimport cython + cimport pysharedmem include "version.pxi" + + +cdef extern from "Python.h": + IF PY_VERSION_HEX >= 0x03040000: + void* PyMem_RawMalloc(size_t n) nogil + void PyMem_RawFree(void *p) nogil + ELSE: + void* PyMem_RawMalloc "PyMem_Malloc" (size_t n) nogil + void PyMem_RawFree "PyMem_Free" (void *p) nogil + + + +cdef class cbuffer: + cdef char[::1] buf + + + def __cinit__(self, Py_ssize_t length): + self.buf = None + + if length <= 0: + raise ValueError("length must be positive definite") + + cdef char* ptr = PyMem_RawMalloc(length * sizeof(char)) + if not ptr: + raise MemoryError("unable to allocate buffer") + + self.buf = ptr + + + def __getbuffer__(self, Py_buffer *buffer, int flags): + buffer.buf = &self.buf[0] + buffer.obj = self + buffer.len = len(self.buf) + buffer.readonly = 0 + buffer.itemsize = self.buf.itemsize + buffer.format = "c" + buffer.ndim = self.buf.ndim + buffer.shape = self.buf.shape + buffer.strides = self.buf.strides + buffer.suboffsets = NULL + buffer.internal = NULL + + + def __releasebuffer__(self, Py_buffer *buffer): + pass + + + def __reduce__(self): + cdef bytes buf_bytes = PyBytes_FromStringAndSize( + &self.buf[0], len(self.buf) + ) + return (frombuffer, (buf_bytes,)) + + + @cython.nonecheck(False) + def __dealloc__(self): + cdef char* ptr + if self.buf is not None: + ptr = &self.buf[0] + self.buf = None + PyMem_RawFree(ptr) + + +def frombuffer(object src): + cdef Py_buffer src_buf + cdef cbuffer dest + + PyObject_GetBuffer(src, &src_buf, 0) + + dest = cbuffer(src_buf.len) + PyBuffer_ToContiguous(&dest.buf[0], &src_buf, src_buf.len, 'C') + + PyBuffer_Release(&src_buf) + + return dest diff --git a/tests/test_pysharedmem.py b/tests/test_pysharedmem.py index 6d75ad5..593caae 100644 --- a/tests/test_pysharedmem.py +++ b/tests/test_pysharedmem.py @@ -3,12 +3,125 @@ from __future__ import absolute_import +import pickle +import sys import unittest -class TestImportTopLevel(unittest.TestCase): - def runTest(self): +try: + unichr +except NameError: + unichr = chr + + +class TestPySharedMem(unittest.TestCase): + def test_import(self): try: import pysharedmem except ImportError: self.fail("Unable to import `pysharedmem`.") + + + def test_create_empty_mem(self): + import pysharedmem + with self.assertRaises(ValueError): + pysharedmem.cbuffer(0) + + + def test_create_negative_mem(self): + import pysharedmem + with self.assertRaises(ValueError): + pysharedmem.cbuffer(-1) + + + def test_create_singleton_mem(self): + import pysharedmem + pysharedmem.cbuffer(1) + + + def test_memoryview(self): + import pysharedmem + + l = 3 + b = pysharedmem.cbuffer(l) + + m = memoryview(b) + + self.assertEqual(m.readonly, False) + + self.assertEqual(m.format, 'c') + self.assertEqual(m.itemsize, 1) + + self.assertEqual(m.ndim, 1) + self.assertEqual(m.shape, (l,)) + self.assertEqual(m.strides, (1,)) + + if sys.version_info[0] >= 3: + self.assertEqual(m.suboffsets, ()) + else: + self.assertEqual(m.suboffsets, None) + + + @unittest.skipIf(sys.version_info[0] < 3, + "This test requires Python 3.") + def test_memoryview_py3(self): + import pysharedmem + + l = 3 + b = pysharedmem.cbuffer(l) + + m = memoryview(b) + + self.assertIs(m.obj, b) + + self.assertTrue(m.c_contiguous) + self.assertTrue(m.contiguous) + self.assertTrue(m.f_contiguous) + + self.assertEqual(m.readonly, False) + + self.assertEqual(m.format, 'c') + self.assertEqual(m.itemsize, 1) + + self.assertEqual(m.ndim, 1) + self.assertEqual(m.nbytes, l) + self.assertEqual(m.shape, (l,)) + self.assertEqual(m.strides, (1,)) + self.assertEqual(m.suboffsets, ()) + + + def test_assignment(self): + import pysharedmem + + l = 3 + b = pysharedmem.cbuffer(l) + + m = memoryview(b) + for i in range(len(m)): + m[i] = unichr(i).encode("utf-8") + + a1 = bytearray(m) + a2 = bytearray(b) + + self.assertIsNot(a1, a2) + self.assertEqual(a1, a2) + + + def test_pickle(self): + import pysharedmem + + l = 3 + b1 = pysharedmem.cbuffer(l) + + m = memoryview(b1) + for i in range(len(m)): + m[i] = unichr(i).encode("utf-8") + + b2 = pickle.loads(pickle.dumps(b1)) + + a1 = bytes(memoryview(b1)) + a2 = bytes(memoryview(b2)) + + self.assertIsNot(b1, b2) + self.assertIsNot(a1, a2) + self.assertEqual(a1, a2)