diff --git a/src/pysharedmem.pyx b/src/pysharedmem.pyx index 564b9a6..248dbcf 100644 --- a/src/pysharedmem.pyx +++ b/src/pysharedmem.pyx @@ -1,3 +1,67 @@ +cimport cpython +cimport cpython.mem +from cpython.mem cimport PyMem_Malloc, PyMem_Free + +cimport cython + cimport pysharedmem include "version.pxi" + + + +cdef class cbuffer: + cdef char[::1] buf + + + def __cinit__(self, Py_ssize_t length): + self.buf = None + + if length <= 0: + raise ValueError("length must be positive definite") + + cdef void* ptr = PyMem_Malloc(length * sizeof(char)) + if not ptr: + raise MemoryError("unable to allocate buffer") + + self.buf = ptr + + + def __getbuffer__(self, Py_buffer *buffer, int flags): + buffer.buf = &self.buf[0] + buffer.obj = self + buffer.len = len(self.buf) + buffer.readonly = 0 + buffer.itemsize = self.buf.ndim + buffer.format = "c" + buffer.ndim = self.buf.ndim + buffer.shape = self.buf.shape + buffer.strides = self.buf.strides + buffer.suboffsets = NULL + buffer.internal = NULL + + + def __releasebuffer__(self, Py_buffer *buffer): + pass + + + def __reduce__(self): + return (__create_cbuffer_data, (bytes(self.buf),)) + + + @cython.nonecheck(False) + def __dealloc__(self): + if self.buf is None: + PyMem_Free(&self.buf[0]) + self.buf = None + + +def __create_cbuffer_data(bytes data): + cdef Py_ssize_t length = len(data) + cdef cbuffer new = cbuffer(length) + + cdef Py_ssize_t i + for i in range(length): + new.buf[i] = data[i] + + return new diff --git a/tests/test_pysharedmem.py b/tests/test_pysharedmem.py index 6d75ad5..509f7b9 100644 --- a/tests/test_pysharedmem.py +++ b/tests/test_pysharedmem.py @@ -3,12 +3,125 @@ from __future__ import absolute_import +import pickle +import sys import unittest -class TestImportTopLevel(unittest.TestCase): - def runTest(self): +try: + unichr +except NameError: + unichr = chr + + +class TestPySharedMem(unittest.TestCase): + def test_import(self): try: import pysharedmem except ImportError: self.fail("Unable to import `pysharedmem`.") + + + def test_create_empty_mem(self): + import pysharedmem + with self.assertRaises(ValueError): + pysharedmem.cbuffer(0) + + + def test_create_negative_mem(self): + import pysharedmem + with self.assertRaises(ValueError): + pysharedmem.cbuffer(-1) + + + def test_create_singleton_mem(self): + import pysharedmem + pysharedmem.cbuffer(1) + + + def test_memoryview(self): + import pysharedmem + + l = 3 + b = pysharedmem.cbuffer(l) + + m = memoryview(b) + + self.assertEqual(m.readonly, False) + + self.assertEqual(m.format, 'c') + self.assertEqual(m.itemsize, 1) + + self.assertEqual(m.ndim, 1) + self.assertEqual(m.shape, (l,)) + self.assertEqual(m.strides, (1,)) + + if sys.version_info[0] >= 3: + self.assertEqual(m.suboffsets, ()) + else: + self.assertEqual(m.suboffsets, None) + + + @unittest.skipIf(sys.version_info[0] < 3, + "This test requires Python 3.") + def test_memoryview_py3(self): + import pysharedmem + + l = 3 + b = pysharedmem.cbuffer(l) + + m = memoryview(b) + + self.assertIs(m.obj, b) + + self.assertTrue(m.c_contiguous) + self.assertTrue(m.contiguous) + self.assertTrue(m.f_contiguous) + + self.assertEqual(m.readonly, False) + + self.assertEqual(m.format, 'c') + self.assertEqual(m.itemsize, 1) + + self.assertEqual(m.ndim, 1) + self.assertEqual(m.nbytes, l) + self.assertEqual(m.shape, (l,)) + self.assertEqual(m.strides, (1,)) + self.assertEqual(m.suboffsets, ()) + + + def test_assignment(self): + import pysharedmem + + l = 3 + b = pysharedmem.cbuffer(l) + + m = memoryview(b) + for i in range(len(m)): + m[i] = unichr(i).encode("utf-8") + + a1 = bytearray(m) + a2 = bytearray(b) + + self.assertIsNot(a1, a2) + self.assertEqual(a1, a2) + + + def test_pickle(self): + import pysharedmem + + l = 3 + b1 = pysharedmem.cbuffer(l) + + m = memoryview(b1) + for i in range(len(m)): + m[i] = unichr(i).encode("utf-8") + + b2 = pickle.loads(pickle.dumps(b1)) + + a1 = bytearray(b1) + a2 = bytearray(b2) + + self.assertIsNot(b1, b2) + self.assertIsNot(a1, a2) + self.assertEqual(a1, a2)