Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix creation of string arrays from object types #3491

Merged
merged 2 commits into from
Jul 8, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 37 additions & 13 deletions numpy/core/src/multiarray/convert_datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -226,24 +226,36 @@ PyArray_AdaptFlexibleDType(PyObject *data_obj, PyArray_Descr *data_dtype,
break;
case NPY_OBJECT:
size = 64;
/*
* If we're adapting a string dtype for an array of string
* objects, call GetArrayParamsFromObject to figure out
* maximum string size, and use that as new dtype size.
*/
if ((flex_type_num == NPY_STRING ||
flex_type_num == NPY_UNICODE) &&
data_obj != NULL) {
/*
* Convert data array to list of objects since
* GetArrayParamsFromObject won't iterate through
* items in an array.
*/
list = PyArray_ToList(data_obj);
if (list != NULL) {
if (PyArray_CheckScalar(data_obj)) {
PyObject *scalar = PyArray_ToList(data_obj);
if (scalar != NULL) {
PyObject *s = PyObject_Str(scalar);
if (s == NULL) {
Py_DECREF(scalar);
Py_DECREF(*flex_dtype);
*flex_dtype = NULL;
return;
}
else {
size = PyObject_Length(s);
Py_DECREF(s);
}
Py_DECREF(scalar);
}
}
else if (PyArray_Check(data_obj)) {
/*
* Convert data array to list of objects since
* GetArrayParamsFromObject won't iterate over
* array.
*/
list = PyArray_ToList(data_obj);
result = PyArray_GetArrayParamsFromObject(
list,
flex_dtype,
*flex_dtype,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flex_dtype type changed or was this a previous error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a previous error that my first set of tests didn't catch.

0, &dtype,
&ndim, dims, &arr, NULL);
if (result == 0 && dtype != NULL) {
Expand All @@ -256,6 +268,18 @@ PyArray_AdaptFlexibleDType(PyObject *data_obj, PyArray_Descr *data_dtype,
}
Py_DECREF(list);
}
else if (PyArray_IsPythonScalar(data_obj)) {
PyObject *s = PyObject_Str(data_obj);
if (s == NULL) {
Py_DECREF(*flex_dtype);
*flex_dtype = NULL;
return;
}
else {
size = PyObject_Length(s);
Py_DECREF(s);
}
}
}
break;
case NPY_STRING:
Expand Down
31 changes: 24 additions & 7 deletions numpy/core/src/multiarray/ctors.c
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ PyArray_AssignFromSequence(PyArrayObject *self, PyObject *v)
*/

static int
discover_itemsize(PyObject *s, int nd, int *itemsize)
discover_itemsize(PyObject *s, int nd, int *itemsize, int size_as_string)
{
int n, r, i;

Expand All @@ -532,14 +532,26 @@ discover_itemsize(PyObject *s, int nd, int *itemsize)

if ((nd == 0) || PyString_Check(s) ||
#if defined(NPY_PY3K)
PyMemoryView_Check(s) ||
PyMemoryView_Check(s) ||
#else
PyBuffer_Check(s) ||
PyBuffer_Check(s) ||
#endif
PyUnicode_Check(s)) {
PyUnicode_Check(s)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like there should be more indentation of the logical expression part.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the preceding logical parts need an extra indent. Basically:

if (log1 ||
        log2 ||
        log3) {
    blah;
}


/* If an object has no length, leave it be */
n = PyObject_Length(s);
if (size_as_string && s != NULL && !PyString_Check(s)) {
PyObject *s_string = PyObject_Str(s);
if (s_string) {
n = PyObject_Length(s_string);
Py_DECREF(s_string);
}
else {
n = -1;
}
}
else {
n = PyObject_Length(s);
}
if (n == -1) {
PyErr_Clear();
}
Expand All @@ -557,7 +569,7 @@ discover_itemsize(PyObject *s, int nd, int *itemsize)
return -1;
}

r = discover_itemsize(e,nd-1,itemsize);
r = discover_itemsize(e, nd - 1, itemsize, size_as_string);
Py_DECREF(e);
if (r == -1) {
return -1;
Expand Down Expand Up @@ -1528,7 +1540,12 @@ PyArray_GetArrayParamsFromObject(PyObject *op,
if ((*out_dtype)->elsize == 0 &&
PyTypeNum_ISEXTENDED((*out_dtype)->type_num)) {
int itemsize = 0;
if (discover_itemsize(op, *out_ndim, &itemsize) < 0) {
int size_as_string = 0;
if ((*out_dtype)->type_num == NPY_STRING ||
(*out_dtype)->type_num == NPY_UNICODE) {
size_as_string = 1;
}
if (discover_itemsize(op, *out_ndim, &itemsize, size_as_string) < 0) {
Py_DECREF(*out_dtype);
if (PyErr_Occurred() &&
PyErr_GivenExceptionMatches(PyErr_Occurred(),
Expand Down
16 changes: 16 additions & 0 deletions numpy/core/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,22 @@ def test_array_astype():
assert_equal(a, b)
assert_equal(b.dtype, np.dtype('U10'))

a = np.array(123456789012345678901234567890, dtype='O').astype('S')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can do this more compactly as

In [2]: '1234567890'*3
Out[2]: '123456789012345678901234567890'

Which also has the virtue that one can tell how many digits there are.

assert_array_equal(a, np.array(b'1234567890' * 3, dtype='S30'))
a = np.array(123456789012345678901234567890, dtype='O').astype('U')
assert_array_equal(a, np.array(sixu('1234567890' * 3), dtype='U30'))

a = np.array([123456789012345678901234567890], dtype='O').astype('S')
assert_array_equal(a, np.array(b'1234567890' * 3, dtype='S30'))
a = np.array([123456789012345678901234567890], dtype='O').astype('U')
assert_array_equal(a, np.array(sixu('1234567890' * 3), dtype='U30'))

a = np.array(123456789012345678901234567890, dtype='S')
assert_array_equal(a, np.array(b'1234567890' * 3, dtype='S30'))
a = np.array(123456789012345678901234567890, dtype='U')
assert_array_equal(a, np.array(sixu('1234567890' * 3), dtype='U30'))


def test_copyto_fromscalar():
a = np.arange(6, dtype='f4').reshape(2,3)

Expand Down