Permalink
Browse files

BUG: Fix CastToType to handle string->string casts (ticket #1748)

  • Loading branch information...
1 parent 13212a5 commit cfff7508bc29e6bc0c44b2d42d7bb23e143d5bc3 @mwiebe mwiebe committed Mar 4, 2011
Showing with 38 additions and 26 deletions.
  1. +28 −26 numpy/core/src/multiarray/convert_datatype.c
  2. +10 −0 numpy/core/tests/test_regression.py
@@ -23,54 +23,56 @@
* Cast an array using typecode structure.
* steals reference to at --- cannot be NULL
*
- * This function always makes a copy of mp, even if the dtype
+ * This function always makes a copy of arr, even if the dtype
* doesn't change.
*/
NPY_NO_EXPORT PyObject *
-PyArray_CastToType(PyArrayObject *mp, PyArray_Descr *at, int fortran)
+PyArray_CastToType(PyArrayObject *arr, PyArray_Descr *dtype, int fortran)
{
PyObject *out;
- int ret;
- PyArray_Descr *mpd;
+ PyArray_Descr *arr_dtype;
- mpd = mp->descr;
+ arr_dtype = PyArray_DESCR(arr);
- if (at->elsize == 0) {
- PyArray_DESCR_REPLACE(at);
- if (at == NULL) {
+ if (dtype->elsize == 0) {
+ PyArray_DESCR_REPLACE(dtype);
+ if (dtype == NULL) {
return NULL;
}
- if (mpd->type_num == PyArray_STRING &&
- at->type_num == PyArray_UNICODE) {
- at->elsize = mpd->elsize << 2;
+
+ if (arr_dtype->type_num == dtype->type_num) {
+ dtype->elsize = arr_dtype->elsize;
+ }
+ else if (arr_dtype->type_num == NPY_STRING &&
+ dtype->type_num == NPY_UNICODE) {
+ dtype->elsize = arr_dtype->elsize * 4;
}
- if (mpd->type_num == PyArray_UNICODE &&
- at->type_num == PyArray_STRING) {
- at->elsize = mpd->elsize >> 2;
+ else if (arr_dtype->type_num == NPY_UNICODE &&
+ dtype->type_num == NPY_STRING) {
+ dtype->elsize = arr_dtype->elsize / 4;
}
- if (at->type_num == PyArray_VOID) {
- at->elsize = mpd->elsize;
+ else if (dtype->type_num == NPY_VOID) {
+ dtype->elsize = arr_dtype->elsize;
}
}
- out = PyArray_NewFromDescr(Py_TYPE(mp), at,
- mp->nd,
- mp->dimensions,
+ out = PyArray_NewFromDescr(Py_TYPE(arr), dtype,
+ arr->nd,
+ arr->dimensions,
NULL, NULL,
fortran,
- (PyObject *)mp);
+ (PyObject *)arr);
if (out == NULL) {
return NULL;
}
- ret = PyArray_CopyInto((PyArrayObject *)out, mp);
- if (ret != -1) {
- return out;
- }
- Py_DECREF(out);
- return NULL;
+ if (PyArray_CopyInto((PyArrayObject *)out, arr) < 0) {
+ Py_DECREF(out);
+ return NULL;
+ }
+ return out;
}
/*NUMPY_API
@@ -1535,5 +1535,15 @@ def test_setting_rank0_string(self):
a[()] = np.array(4)
assert_equal(a, np.array(4))
+ def test_string_astype(self):
+ "Ticket #1748"
+ s1 = asbytes('black')
+ s2 = asbytes('white')
+ s3 = asbytes('other')
+ a = np.array([[s1],[s2],[s3]])
+ assert_equal(a.dtype, np.dtype('S5'))
+ b = a.astype('str')
+ assert_equal(b.dtype, np.dtype('S5'))
+
if __name__ == "__main__":
run_module_suite()

0 comments on commit cfff750

Please sign in to comment.