Skip to content

Commit

Permalink
Merge 4fb07db into a6f6a72
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed Aug 8, 2018
2 parents a6f6a72 + 4fb07db commit 6f80f40
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 2 deletions.
2 changes: 2 additions & 0 deletions setup.py
Expand Up @@ -70,6 +70,7 @@
extra_compile_args = []
cython_directives = {}
cython_line_directives = {}
cython_compile_time_env = {"PY_VERSION_HEX": sys.hexversion}


if "test" in sys.argv:
Expand Down Expand Up @@ -99,6 +100,7 @@
for em in ext_modules:
em.cython_directives = dict(cython_directives)
em.cython_line_directives = dict(cython_line_directives)
em.cython_compile_time_env = dict(cython_compile_time_env)


setup(
Expand Down
91 changes: 91 additions & 0 deletions src/pysharedmem.pyx
@@ -1,3 +1,94 @@
# cython: boundscheck = False
# cython: initializedcheck = False
# cython: nonecheck = False
# cython: wraparound = False


cimport cpython
cimport cpython.buffer
cimport cpython.bytes
cimport cpython.object
cimport cpython.tuple

from cpython.buffer cimport PyBuffer_ToContiguous
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.object cimport PyObject
from cpython.tuple cimport PyTuple_Pack

cimport pysharedmem

include "version.pxi"


cdef extern from "Python.h":
IF PY_VERSION_HEX >= 0x03040000:
void* PyMem_RawMalloc(size_t n) nogil
void PyMem_RawFree(void *p) nogil
ELSE:
void* PyMem_RawMalloc "PyMem_Malloc" (size_t n) nogil
void PyMem_RawFree "PyMem_Free" (void *p) nogil

cdef extern from "Python.h":
object PyMemoryView_FromObject(object)
Py_buffer* PyMemoryView_GET_BUFFER(object)


cdef class cbuffer:
cdef char[::1] buf

def __cinit__(self, Py_ssize_t length):
self.buf = None

if length <= 0:
raise ValueError("length must be positive definite")

cdef char* ptr = <char*>PyMem_RawMalloc(length * sizeof(char))
if not ptr:
raise MemoryError("unable to allocate buffer")

self.buf = <char[:length]>ptr

def __getbuffer__(self, Py_buffer *buffer, int flags):
buffer.buf = &self.buf[0]
buffer.obj = self
buffer.len = len(self.buf)
buffer.readonly = 0
buffer.itemsize = self.buf.itemsize
buffer.format = "c"
buffer.ndim = self.buf.ndim
buffer.shape = self.buf.shape
buffer.strides = self.buf.strides
buffer.suboffsets = NULL
buffer.internal = NULL

def __releasebuffer__(self, Py_buffer *buffer):
pass

def __reduce__(self):
cdef bytes buf_bytes = PyBytes_FromStringAndSize(
&self.buf[0], len(self.buf)
)

cdef tuple result = PyTuple_Pack(1, <PyObject*>buf_bytes)
result = PyTuple_Pack(2, <PyObject*>frombuffer, <PyObject*>result)

return result

def __dealloc__(self):
cdef char* ptr
if self.buf is not None:
ptr = &self.buf[0]
self.buf = None
PyMem_RawFree(<void*>ptr)


def frombuffer(src):
src = PyMemoryView_FromObject(src)
cdef Py_buffer* src_buf = PyMemoryView_GET_BUFFER(src)

cdef cbuffer dest = cbuffer(src_buf.len)
cdef char* dest_ptr = &dest.buf[0]

PyBuffer_ToContiguous(<void*>dest_ptr, src_buf, src_buf.len, 'C')

return dest
117 changes: 115 additions & 2 deletions tests/test_pysharedmem.py
Expand Up @@ -3,12 +3,125 @@

from __future__ import absolute_import

import pickle
import sys
import unittest


class TestImportTopLevel(unittest.TestCase):
def runTest(self):
try:
unichr
except NameError:
unichr = chr


class TestPySharedMem(unittest.TestCase):
def test_import(self):
try:
import pysharedmem
except ImportError:
self.fail("Unable to import `pysharedmem`.")


def test_create_empty_mem(self):
import pysharedmem
with self.assertRaises(ValueError):
pysharedmem.cbuffer(0)


def test_create_negative_mem(self):
import pysharedmem
with self.assertRaises(ValueError):
pysharedmem.cbuffer(-1)


def test_create_singleton_mem(self):
import pysharedmem
pysharedmem.cbuffer(1)


def test_memoryview(self):
import pysharedmem

l = 3
b = pysharedmem.cbuffer(l)

m = memoryview(b)

self.assertEqual(m.readonly, False)

self.assertEqual(m.format, 'c')
self.assertEqual(m.itemsize, 1)

self.assertEqual(m.ndim, 1)
self.assertEqual(m.shape, (l,))
self.assertEqual(m.strides, (1,))

if sys.version_info[0] >= 3:
self.assertEqual(m.suboffsets, ())
else:
self.assertEqual(m.suboffsets, None)


@unittest.skipIf(sys.version_info[0] < 3,
"This test requires Python 3.")
def test_memoryview_py3(self):
import pysharedmem

l = 3
b = pysharedmem.cbuffer(l)

m = memoryview(b)

self.assertIs(m.obj, b)

self.assertTrue(m.c_contiguous)
self.assertTrue(m.contiguous)
self.assertTrue(m.f_contiguous)

self.assertEqual(m.readonly, False)

self.assertEqual(m.format, 'c')
self.assertEqual(m.itemsize, 1)

self.assertEqual(m.ndim, 1)
self.assertEqual(m.nbytes, l)
self.assertEqual(m.shape, (l,))
self.assertEqual(m.strides, (1,))
self.assertEqual(m.suboffsets, ())


def test_assignment(self):
import pysharedmem

l = 3
b = pysharedmem.cbuffer(l)

m = memoryview(b)
for i in range(len(m)):
m[i] = unichr(i).encode("utf-8")

a1 = bytearray(m)
a2 = bytearray(b)

self.assertIsNot(a1, a2)
self.assertEqual(a1, a2)


def test_pickle(self):
import pysharedmem

l = 3
b1 = pysharedmem.cbuffer(l)

m = memoryview(b1)
for i in range(len(m)):
m[i] = unichr(i).encode("utf-8")

b2 = pickle.loads(pickle.dumps(b1))

a1 = bytes(memoryview(b1))
a2 = bytes(memoryview(b2))

self.assertIsNot(b1, b2)
self.assertIsNot(a1, a2)
self.assertEqual(a1, a2)

0 comments on commit 6f80f40

Please sign in to comment.