Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: core: make unpickling with encoding='bytes' work #4888

Merged
merged 2 commits into from Jul 22, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
181 changes: 131 additions & 50 deletions numpy/core/src/multiarray/descriptor.c
Expand Up @@ -2369,11 +2369,8 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
{
int elsize = -1, alignment = -1;
int version = 4;
#if defined(NPY_PY3K)
int endian;
#else
char endian;
#endif
PyObject *endian_obj;
PyObject *subarray, *fields, *names = NULL, *metadata=NULL;
int incref_names = 1;
int int_dtypeflags = 0;
Expand All @@ -2390,68 +2387,39 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
}
switch (PyTuple_GET_SIZE(PyTuple_GET_ITEM(args,0))) {
case 9:
#if defined(NPY_PY3K)
#define _ARGSTR_ "(iCOOOiiiO)"
#else
#define _ARGSTR_ "(icOOOiiiO)"
#endif
if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian,
if (!PyArg_ParseTuple(args, "(iOOOOiiiO)", &version, &endian_obj,
&subarray, &names, &fields, &elsize,
&alignment, &int_dtypeflags, &metadata)) {
PyErr_Clear();
return NULL;
#undef _ARGSTR_
}
break;
case 8:
#if defined(NPY_PY3K)
#define _ARGSTR_ "(iCOOOiii)"
#else
#define _ARGSTR_ "(icOOOiii)"
#endif
if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian,
if (!PyArg_ParseTuple(args, "(iOOOOiii)", &version, &endian_obj,
&subarray, &names, &fields, &elsize,
&alignment, &int_dtypeflags)) {
return NULL;
#undef _ARGSTR_
}
break;
case 7:
#if defined(NPY_PY3K)
#define _ARGSTR_ "(iCOOOii)"
#else
#define _ARGSTR_ "(icOOOii)"
#endif
if (!PyArg_ParseTuple(args, _ARGSTR_, &version, &endian,
if (!PyArg_ParseTuple(args, "(iOOOOii)", &version, &endian_obj,
&subarray, &names, &fields, &elsize,
&alignment)) {
return NULL;
#undef _ARGSTR_
}
break;
case 6:
#if defined(NPY_PY3K)
#define _ARGSTR_ "(iCOOii)"
#else
#define _ARGSTR_ "(icOOii)"
#endif
if (!PyArg_ParseTuple(args, _ARGSTR_, &version,
&endian, &subarray, &fields,
if (!PyArg_ParseTuple(args, "(iOOOii)", &version,
&endian_obj, &subarray, &fields,
&elsize, &alignment)) {
PyErr_Clear();
#undef _ARGSTR_
return NULL;
}
break;
case 5:
version = 0;
#if defined(NPY_PY3K)
#define _ARGSTR_ "(COOii)"
#else
#define _ARGSTR_ "(cOOii)"
#endif
if (!PyArg_ParseTuple(args, _ARGSTR_,
&endian, &subarray, &fields, &elsize,
if (!PyArg_ParseTuple(args, "(OOOii)",
&endian_obj, &subarray, &fields, &elsize,
&alignment)) {
#undef _ARGSTR_
return NULL;
}
break;
Expand Down Expand Up @@ -2494,11 +2462,55 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
}
}

/* Parse endian */
if (PyUnicode_Check(endian_obj) || PyBytes_Check(endian_obj)) {
PyObject *tmp = NULL;
char *str;
Py_ssize_t len;

if (PyUnicode_Check(endian_obj)) {
tmp = PyUnicode_AsASCIIString(endian_obj);
if (tmp == NULL) {
return NULL;
}
endian_obj = tmp;
}

if (PyBytes_AsStringAndSize(endian_obj, &str, &len) == -1) {
Py_XDECREF(tmp);
return NULL;
}
if (len != 1) {
PyErr_SetString(PyExc_ValueError,
"endian is not 1-char string in Numpy dtype unpickling");
Py_XDECREF(tmp);
return NULL;
}
endian = str[0];
Py_XDECREF(tmp);
}
else {
PyErr_SetString(PyExc_ValueError,
"endian is not a string in Numpy dtype unpickling");
return NULL;
}

if ((fields == Py_None && names != Py_None) ||
(names == Py_None && fields != Py_None)) {
PyErr_Format(PyExc_ValueError,
"inconsistent fields and names");
"inconsistent fields and names in Numpy dtype unpickling");
return NULL;
}

if (names != Py_None && !PyTuple_Check(names)) {
PyErr_Format(PyExc_ValueError,
"non-tuple names in Numpy dtype unpickling");
return NULL;
}

if (fields != Py_None && !PyDict_Check(fields)) {
PyErr_Format(PyExc_ValueError,
"non-dict fields in Numpy dtype unpickling");
return NULL;
}

Expand Down Expand Up @@ -2563,13 +2575,82 @@ arraydescr_setstate(PyArray_Descr *self, PyObject *args)
}

if (fields != Py_None) {
Py_XDECREF(self->fields);
self->fields = fields;
Py_INCREF(fields);
Py_XDECREF(self->names);
self->names = names;
if (incref_names) {
Py_INCREF(names);
/*
* Ensure names are of appropriate string type
*/
Py_ssize_t i;
int names_ok = 1;
PyObject *name;

for (i = 0; i < PyTuple_GET_SIZE(names); ++i) {
name = PyTuple_GET_ITEM(names, i);
if (!PyUString_Check(name)) {
names_ok = 0;
break;
}
}

if (names_ok) {
Py_XDECREF(self->fields);
self->fields = fields;
Py_INCREF(fields);
Py_XDECREF(self->names);
self->names = names;
if (incref_names) {
Py_INCREF(names);
}
}
else {
#if defined(NPY_PY3K)
/*
* To support pickle.load(f, encoding='bytes') for loading Py2
* generated pickles on Py3, we need to be more lenient and convert
* field names from byte strings to unicode.
*/
PyObject *tmp, *new_name, *field;

tmp = PyDict_New();
if (tmp == NULL) {
return NULL;
}
Py_XDECREF(self->fields);
self->fields = tmp;

tmp = PyTuple_New(PyTuple_GET_SIZE(names));
if (tmp == NULL) {
return NULL;
}
Py_XDECREF(self->names);
self->names = tmp;

for (i = 0; i < PyTuple_GET_SIZE(names); ++i) {
name = PyTuple_GET_ITEM(names, i);
field = PyDict_GetItem(fields, name);
if (!field) {
return NULL;
}

if (PyUnicode_Check(name)) {
new_name = name;
Py_INCREF(new_name);
}
else {
new_name = PyUnicode_FromEncodedObject(name, "ASCII", "strict");
if (new_name == NULL) {
return NULL;
}
}

PyTuple_SET_ITEM(self->names, i, new_name);
if (PyDict_SetItem(self->fields, new_name, field) != 0) {
return NULL;
}
}
#else
PyErr_Format(PyExc_ValueError,
"non-string names in Numpy dtype unpickling");
return NULL;
#endif
}
}

Expand Down
35 changes: 35 additions & 0 deletions numpy/core/tests/test_regression.py
Expand Up @@ -398,6 +398,41 @@ def __getitem__(self, key):

assert_raises(KeyError, np.lexsort, BuggySequence())

def test_pickle_py2_bytes_encoding(self):
    # Verify that arrays and scalars which were pickled under Python 2
    # can be loaded under Python 3 via pickle.loads(..., encoding='bytes').

    # Each entry pairs the expected object with its Python 2 pickle bytes.
    cases = [
        # unicode scalar
        (np.unicode_('\u6f2c'),
         asbytes("cnumpy.core.multiarray\nscalar\np0\n(cnumpy\ndtype\np1\n"
                 "(S'U1'\np2\nI0\nI1\ntp3\nRp4\n(I3\nS'<'\np5\nNNNI4\nI4\n"
                 "I0\ntp6\nbS',o\\x00\\x00'\np7\ntp8\nRp9\n.")),

        # plain float64 array
        (np.array([9e123], dtype=np.float64),
         asbytes("cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\n"
                 "p1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\n"
                 "p7\n(S'f8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'<'\np11\nNNNI-1\nI-1\n"
                 "I0\ntp12\nbI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np13\ntp14\nb.")),

        # structured array with a single named field
        (np.array([(9e123,)], dtype=[('name', float)]),
         asbytes("cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n"
                 "(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\np7\n"
                 "(S'V8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'|'\np11\nN(S'name'\np12\ntp13\n"
                 "(dp14\ng12\n(g7\n(S'f8'\np15\nI0\nI1\ntp16\nRp17\n(I3\nS'<'\np18\nNNNI-1\n"
                 "I-1\nI0\ntp19\nbI0\ntp20\nsI8\nI1\nI0\ntp21\n"
                 "bI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np22\ntp23\nb.")),
    ]

    # encoding='bytes' was added in Py3.4; nothing is checked on older
    # interpreters.
    if sys.version_info[:2] < (3, 4):
        return

    for expected, payload in cases:
        loaded = pickle.loads(payload, encoding='bytes')
        assert_equal(loaded, expected)

        # Field names of a loaded structured array must be str objects.
        field_names = None
        if isinstance(loaded, np.ndarray):
            field_names = loaded.dtype.names
        if field_names:
            for field_name in field_names:
                assert_(isinstance(field_name, str))

def test_pickle_dtype(self,level=rlevel):
"""Ticket #251"""
Expand Down