Skip to content

Commit

Permalink
Merge pull request #3491 from ContinuumIO/astype_fix2
Browse files Browse the repository at this point in the history
Fix creation of string arrays from object types
  • Loading branch information
charris committed Jul 8, 2013
2 parents f1c7766 + 97372db commit 884c403
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 20 deletions.
50 changes: 37 additions & 13 deletions numpy/core/src/multiarray/convert_datatype.c
Expand Up @@ -226,24 +226,36 @@ PyArray_AdaptFlexibleDType(PyObject *data_obj, PyArray_Descr *data_dtype,
break;
case NPY_OBJECT:
size = 64;
/*
* If we're adapting a string dtype for an array of string
* objects, call GetArrayParamsFromObject to figure out
* maximum string size, and use that as new dtype size.
*/
if ((flex_type_num == NPY_STRING ||
flex_type_num == NPY_UNICODE) &&
data_obj != NULL) {
/*
* Convert data array to list of objects since
* GetArrayParamsFromObject won't iterate through
* items in an array.
*/
list = PyArray_ToList(data_obj);
if (list != NULL) {
if (PyArray_CheckScalar(data_obj)) {
PyObject *scalar = PyArray_ToList(data_obj);
if (scalar != NULL) {
PyObject *s = PyObject_Str(scalar);
if (s == NULL) {
Py_DECREF(scalar);
Py_DECREF(*flex_dtype);
*flex_dtype = NULL;
return;
}
else {
size = PyObject_Length(s);
Py_DECREF(s);
}
Py_DECREF(scalar);
}
}
else if (PyArray_Check(data_obj)) {
/*
* Convert the data array to a list of objects, since
* GetArrayParamsFromObject won't iterate over the
* items of an array.
*/
list = PyArray_ToList(data_obj);
result = PyArray_GetArrayParamsFromObject(
list,
flex_dtype,
*flex_dtype,
0, &dtype,
&ndim, dims, &arr, NULL);
if (result == 0 && dtype != NULL) {
Expand All @@ -256,6 +268,18 @@ PyArray_AdaptFlexibleDType(PyObject *data_obj, PyArray_Descr *data_dtype,
}
Py_DECREF(list);
}
else if (PyArray_IsPythonScalar(data_obj)) {
PyObject *s = PyObject_Str(data_obj);
if (s == NULL) {
Py_DECREF(*flex_dtype);
*flex_dtype = NULL;
return;
}
else {
size = PyObject_Length(s);
Py_DECREF(s);
}
}
}
break;
case NPY_STRING:
Expand Down
31 changes: 24 additions & 7 deletions numpy/core/src/multiarray/ctors.c
Expand Up @@ -521,7 +521,7 @@ PyArray_AssignFromSequence(PyArrayObject *self, PyObject *v)
*/

static int
discover_itemsize(PyObject *s, int nd, int *itemsize)
discover_itemsize(PyObject *s, int nd, int *itemsize, int size_as_string)
{
int n, r, i;

Expand All @@ -532,14 +532,26 @@ discover_itemsize(PyObject *s, int nd, int *itemsize)

if ((nd == 0) || PyString_Check(s) ||
#if defined(NPY_PY3K)
PyMemoryView_Check(s) ||
PyMemoryView_Check(s) ||
#else
PyBuffer_Check(s) ||
PyBuffer_Check(s) ||
#endif
PyUnicode_Check(s)) {
PyUnicode_Check(s)) {

/* If an object has no length, leave it be */
n = PyObject_Length(s);
if (size_as_string && s != NULL && !PyString_Check(s)) {
PyObject *s_string = PyObject_Str(s);
if (s_string) {
n = PyObject_Length(s_string);
Py_DECREF(s_string);
}
else {
n = -1;
}
}
else {
n = PyObject_Length(s);
}
if (n == -1) {
PyErr_Clear();
}
Expand All @@ -557,7 +569,7 @@ discover_itemsize(PyObject *s, int nd, int *itemsize)
return -1;
}

r = discover_itemsize(e,nd-1,itemsize);
r = discover_itemsize(e, nd - 1, itemsize, size_as_string);
Py_DECREF(e);
if (r == -1) {
return -1;
Expand Down Expand Up @@ -1528,7 +1540,12 @@ PyArray_GetArrayParamsFromObject(PyObject *op,
if ((*out_dtype)->elsize == 0 &&
PyTypeNum_ISEXTENDED((*out_dtype)->type_num)) {
int itemsize = 0;
if (discover_itemsize(op, *out_ndim, &itemsize) < 0) {
int size_as_string = 0;
if ((*out_dtype)->type_num == NPY_STRING ||
(*out_dtype)->type_num == NPY_UNICODE) {
size_as_string = 1;
}
if (discover_itemsize(op, *out_ndim, &itemsize, size_as_string) < 0) {
Py_DECREF(*out_dtype);
if (PyErr_Occurred() &&
PyErr_GivenExceptionMatches(PyErr_Occurred(),
Expand Down
16 changes: 16 additions & 0 deletions numpy/core/tests/test_api.py
Expand Up @@ -259,6 +259,22 @@ def test_array_astype():
assert_equal(a, b)
assert_equal(b.dtype, np.dtype('U10'))

a = np.array(123456789012345678901234567890, dtype='O').astype('S')
assert_array_equal(a, np.array(b'1234567890' * 3, dtype='S30'))
a = np.array(123456789012345678901234567890, dtype='O').astype('U')
assert_array_equal(a, np.array(sixu('1234567890' * 3), dtype='U30'))

a = np.array([123456789012345678901234567890], dtype='O').astype('S')
assert_array_equal(a, np.array(b'1234567890' * 3, dtype='S30'))
a = np.array([123456789012345678901234567890], dtype='O').astype('U')
assert_array_equal(a, np.array(sixu('1234567890' * 3), dtype='U30'))

a = np.array(123456789012345678901234567890, dtype='S')
assert_array_equal(a, np.array(b'1234567890' * 3, dtype='S30'))
a = np.array(123456789012345678901234567890, dtype='U')
assert_array_equal(a, np.array(sixu('1234567890' * 3), dtype='U30'))


def test_copyto_fromscalar():
a = np.arange(6, dtype='f4').reshape(2,3)

Expand Down

0 comments on commit 884c403

Please sign in to comment.