Skip to content

Commit

Permalink
BUG: Provide correct format in Py_buffer for scalars (#10564)
Browse files Browse the repository at this point in the history
Unifies scalar and ndarray pep3118 format string generation
  • Loading branch information
vanossj authored and eric-wieser committed Mar 25, 2018
1 parent b296ad3 commit e4d678a
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 46 deletions.
143 changes: 113 additions & 30 deletions numpy/core/src/multiarray/buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "common.h"
#include "numpyos.h"
#include "arrayobject.h"
#include "scalartypes.h"

/*************************************************************************
**************** Implement Buffer Protocol ****************************
Expand Down Expand Up @@ -176,7 +177,7 @@ _is_natively_aligned_at(PyArray_Descr *descr,

static int
_buffer_format_string(PyArray_Descr *descr, _tmp_string_t *str,
PyArrayObject* arr, Py_ssize_t *offset,
PyObject* obj, Py_ssize_t *offset,
char *active_byteorder)
{
int k;
Expand Down Expand Up @@ -223,7 +224,7 @@ _buffer_format_string(PyArray_Descr *descr, _tmp_string_t *str,
Py_DECREF(subarray_tuple);

old_offset = *offset;
ret = _buffer_format_string(descr->subarray->base, str, arr, offset,
ret = _buffer_format_string(descr->subarray->base, str, obj, offset,
active_byteorder);
*offset = old_offset + (*offset - old_offset) * total_count;
return ret;
Expand Down Expand Up @@ -265,7 +266,7 @@ _buffer_format_string(PyArray_Descr *descr, _tmp_string_t *str,
}

/* Insert child item */
_buffer_format_string(child, str, arr, offset,
_buffer_format_string(child, str, obj, offset,
active_byteorder);

/* Insert field name */
Expand Down Expand Up @@ -302,6 +303,7 @@ _buffer_format_string(PyArray_Descr *descr, _tmp_string_t *str,
}
else {
int is_standard_size = 1;
int is_natively_aligned;
int is_native_only_type = (descr->type_num == NPY_LONGDOUBLE ||
descr->type_num == NPY_CLONGDOUBLE);
if (sizeof(npy_longlong) != 8) {
Expand All @@ -312,8 +314,16 @@ _buffer_format_string(PyArray_Descr *descr, _tmp_string_t *str,

*offset += descr->elsize;

if (descr->byteorder == '=' &&
_is_natively_aligned_at(descr, arr, *offset)) {
if (PyArray_IsScalar(obj, Generic)) {
/* scalars are always natively aligned */
is_natively_aligned = 1;
}
else {
is_natively_aligned = _is_natively_aligned_at(descr,
(PyArrayObject*)obj, *offset);
}

if (descr->byteorder == '=' && is_natively_aligned) {
/* Prefer native types, to cater for Cython */
is_standard_size = 0;
if (*active_byteorder != '@') {
Expand Down Expand Up @@ -448,43 +458,61 @@ static PyObject *_buffer_info_cache = NULL;

/* Fill in the info structure */
static _buffer_info_t*
_buffer_info_new(PyArrayObject *arr)
_buffer_info_new(PyObject *obj)
{
_buffer_info_t *info;
_tmp_string_t fmt = {NULL, 0, 0};
int k;
PyArray_Descr *descr = NULL;
int err = 0;

info = malloc(sizeof(_buffer_info_t));
if (info == NULL) {
goto fail;
}

/* Fill in format */
if (_buffer_format_string(PyArray_DESCR(arr), &fmt, arr, NULL, NULL) != 0) {
free(fmt.s);
goto fail;
}
_append_char(&fmt, '\0');
info->format = fmt.s;

/* Fill in shape and strides */
info->ndim = PyArray_NDIM(arr);

if (info->ndim == 0) {
if (PyArray_IsScalar(obj, Generic)) {
descr = PyArray_DescrFromScalar(obj);
if (descr == NULL) {
goto fail;
}
info->ndim = 0;
info->shape = NULL;
info->strides = NULL;
}
else {
info->shape = malloc(sizeof(Py_ssize_t) * PyArray_NDIM(arr) * 2 + 1);
if (info->shape == NULL) {
goto fail;
PyArrayObject * arr = (PyArrayObject *)obj;
descr = PyArray_DESCR(arr);
/* Fill in shape and strides */
info->ndim = PyArray_NDIM(arr);

if (info->ndim == 0) {
info->shape = NULL;
info->strides = NULL;
}
info->strides = info->shape + PyArray_NDIM(arr);
for (k = 0; k < PyArray_NDIM(arr); ++k) {
info->shape[k] = PyArray_DIMS(arr)[k];
info->strides[k] = PyArray_STRIDES(arr)[k];
else {
info->shape = malloc(sizeof(Py_ssize_t) * PyArray_NDIM(arr) * 2 + 1);
if (info->shape == NULL) {
goto fail;
}
info->strides = info->shape + PyArray_NDIM(arr);
for (k = 0; k < PyArray_NDIM(arr); ++k) {
info->shape[k] = PyArray_DIMS(arr)[k];
info->strides[k] = PyArray_STRIDES(arr)[k];
}
}
Py_INCREF(descr);
}

/* Fill in format */
err = _buffer_format_string(descr, &fmt, obj, NULL, NULL);
Py_DECREF(descr);
if (err != 0) {
free(fmt.s);
goto fail;
}
_append_char(&fmt, '\0');
info->format = fmt.s;

return info;

Expand Down Expand Up @@ -530,7 +558,7 @@ _buffer_info_free(_buffer_info_t *info)

/* Get buffer info from the global dictionary */
static _buffer_info_t*
_buffer_get_info(PyObject *arr)
_buffer_get_info(PyObject *obj)
{
PyObject *key = NULL, *item_list = NULL, *item = NULL;
_buffer_info_t *info = NULL, *old_info = NULL;
Expand All @@ -543,13 +571,13 @@ _buffer_get_info(PyObject *arr)
}

/* Compute information */
info = _buffer_info_new((PyArrayObject*)arr);
info = _buffer_info_new(obj);
if (info == NULL) {
return NULL;
}

/* Check if it is identical with an old one; reuse old one, if yes */
key = PyLong_FromVoidPtr((void*)arr);
key = PyLong_FromVoidPtr((void*)obj);
if (key == NULL) {
goto fail;
}
Expand Down Expand Up @@ -627,9 +655,8 @@ _buffer_clear_info(PyObject *arr)
}

/*
* Retrieving buffers
* Retrieving buffers for ndarray
*/

static int
array_getbuffer(PyObject *obj, Py_buffer *view, int flags)
{
Expand Down Expand Up @@ -751,6 +778,62 @@ array_getbuffer(PyObject *obj, Py_buffer *view, int flags)
return -1;
}

/*
* Retrieving buffers for scalars
*/
int
gentype_getbuffer(PyObject *self, Py_buffer *view, int flags)
{
_buffer_info_t *info = NULL;
PyArray_Descr *descr = NULL;
int elsize;

if (flags & PyBUF_WRITABLE) {
PyErr_SetString(PyExc_BufferError, "scalar buffer is readonly");
goto fail;
}

/* Fill in information */
info = _buffer_get_info(self);
if (info == NULL) {
PyErr_SetString(PyExc_BufferError,
"could not get scalar buffer information");
goto fail;
}

view->ndim = info->ndim;
view->shape = info->shape;
view->strides = info->strides;

if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
view->format = info->format;
} else {
view->format = NULL;
}

descr = PyArray_DescrFromScalar(self);
view->buf = (void *)scalar_value(self, descr);
elsize = descr->elsize;
#ifndef Py_UNICODE_WIDE
if (descr->type_num == NPY_UNICODE) {
elsize >>= 1;
}
#endif
view->len = elsize;
view->itemsize = elsize;

Py_DECREF(descr);

view->readonly = 1;
view->suboffsets = NULL;
view->obj = self;
Py_INCREF(self);
return 0;

fail:
view->obj = NULL;
return -1;
}

/*
* NOTE: for backward compatibility (esp. with PyArg_ParseTuple("s#", ...))
Expand Down
3 changes: 3 additions & 0 deletions numpy/core/src/multiarray/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ _array_dealloc_buffer_info(PyArrayObject *self);
NPY_NO_EXPORT PyArray_Descr*
_descriptor_from_pep3118_format(char *s);

NPY_NO_EXPORT int
gentype_getbuffer(PyObject *obj, Py_buffer *view, int flags);

#endif
17 changes: 1 addition & 16 deletions numpy/core/src/multiarray/scalartypes.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "npy_import.h"
#include "dragon4.h"
#include "npy_longdouble.h"
#include "buffer.h"

#include <stdlib.h>

Expand Down Expand Up @@ -2526,22 +2527,6 @@ gentype_getcharbuf(PyObject *self, Py_ssize_t segment, constchar **ptrptr)
}
#endif /* !defined(NPY_PY3K) */


static int
gentype_getbuffer(PyObject *self, Py_buffer *view, int flags)
{
Py_ssize_t len;
void *buf;

/* FIXME: XXX: the format is not implemented! -- this needs more work */

len = gentype_getreadbuf(self, 0, &buf);
return PyBuffer_FillInfo(view, self, buf, len, 1, flags);
}

/* releasebuffer is not needed */


static PyBufferProcs gentype_as_buffer = {
#if !defined(NPY_PY3K)
gentype_getreadbuf, /* bf_getreadbuffer*/
Expand Down
84 changes: 84 additions & 0 deletions numpy/core/tests/test_scalarbuffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
Test scalar buffer interface adheres to PEP 3118
"""
import sys
import numpy as np
from numpy.testing import run_module_suite, assert_, assert_equal, dec

# PEP3118 format strings for native (standard alignment and byteorder) types
scalars_and_codes = [
(np.bool_, '?'),
(np.byte, 'b'),
(np.short, 'h'),
(np.intc, 'i'),
(np.int_, 'l'),
(np.longlong, 'q'),
(np.ubyte, 'B'),
(np.ushort, 'H'),
(np.uintc, 'I'),
(np.uint, 'L'),
(np.ulonglong, 'Q'),
(np.half, 'e'),
(np.single, 'f'),
(np.double, 'd'),
(np.longdouble, 'g'),
(np.csingle, 'Zf'),
(np.cdouble, 'Zd'),
(np.clongdouble, 'Zg'),
]


class TestScalarPEP3118(object):
skip_if_no_buffer_interface = dec.skipif(sys.version_info.major < 3,
"scalars do not implement buffer interface in Python 2")

@skip_if_no_buffer_interface
def test_scalar_match_array(self):
for scalar, _ in scalars_and_codes:
x = scalar()
a = np.array([], dtype=np.dtype(scalar))
mv_x = memoryview(x)
mv_a = memoryview(a)
assert_equal(mv_x.format, mv_a.format)

@skip_if_no_buffer_interface
def test_scalar_dim(self):
for scalar, _ in scalars_and_codes:
x = scalar()
mv_x = memoryview(x)
assert_equal(mv_x.itemsize, np.dtype(scalar).itemsize)
assert_equal(mv_x.ndim, 0)
assert_equal(mv_x.shape, ())
assert_equal(mv_x.strides, ())
assert_equal(mv_x.suboffsets, ())

@skip_if_no_buffer_interface
def test_scalar_known_code(self):
for scalar, code in scalars_and_codes:
x = scalar()
mv_x = memoryview(x)
assert_equal(mv_x.format, code)

@skip_if_no_buffer_interface
def test_void_scalar_structured_data(self):
dt = np.dtype([('name', np.unicode_, 16), ('grades', np.float64, (2,))])
x = np.array(('ndarray_scalar', (1.2, 3.0)), dtype=dt)[()]
assert_(isinstance(x, np.void))
mv_x = memoryview(x)
expected_size = 16 * np.dtype((np.unicode_, 1)).itemsize
expected_size += 2 * np.dtype((np.float64, 1)).itemsize
assert_equal(mv_x.itemsize, expected_size)
assert_equal(mv_x.ndim, 0)
assert_equal(mv_x.shape, ())
assert_equal(mv_x.strides, ())
assert_equal(mv_x.suboffsets, ())

# check scalar format string against ndarray format string
a = np.array([('Sarah', (8.0, 7.0)), ('John', (6.0, 7.0))], dtype=dt)
assert_(isinstance(a, np.ndarray))
mv_a = memoryview(a)
assert_equal(mv_x.itemsize, mv_a.itemsize)
assert_equal(mv_x.format, mv_a.format)

if __name__ == "__main__":
run_module_suite()

0 comments on commit e4d678a

Please sign in to comment.