diff --git a/asciidtype/.flake8 b/asciidtype/.flake8 index 80676bc7..fff82cbb 100644 --- a/asciidtype/.flake8 +++ b/asciidtype/.flake8 @@ -1,2 +1,3 @@ [flake8] per-file-ignores = __init__.py:F401 +max-line-length = 160 diff --git a/asciidtype/asciidtype/src/casts.c b/asciidtype/asciidtype/src/casts.c index bc4f9eb7..dc504b9a 100644 --- a/asciidtype/asciidtype/src/casts.c +++ b/asciidtype/asciidtype/src/casts.c @@ -28,13 +28,17 @@ ascii_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self), loop_descrs[1] = given_descrs[1]; } - if (((ASCIIDTypeObject *)loop_descrs[0])->size == - ((ASCIIDTypeObject *)loop_descrs[1])->size) { + long in_size = ((ASCIIDTypeObject *)loop_descrs[0])->size; + long out_size = ((ASCIIDTypeObject *)loop_descrs[1])->size; + + if (in_size == out_size) { *view_offset = 0; return NPY_NO_CASTING; } - - return NPY_SAME_KIND_CASTING; + else if (in_size > out_size) { + return NPY_UNSAFE_CASTING; + } + return NPY_SAFE_CASTING; } static int @@ -72,33 +76,224 @@ ascii_to_ascii(PyArrayMethod_Context *context, char *const data[], return 0; } +static NPY_CASTING +unicode_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self), + PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]), + PyArray_Descr *given_descrs[2], + PyArray_Descr *loop_descrs[2], + npy_intp *NPY_UNUSED(view_offset)) +{ + Py_INCREF(given_descrs[0]); + loop_descrs[0] = given_descrs[0]; + // numpy stores unicode as UCS4 (4 bytes wide), so bitshift + // by 2 to get the number of ASCII bytes needed + long in_size = (loop_descrs[0]->elsize) >> 2; + if (given_descrs[1] == NULL) { + ASCIIDTypeObject *ascii_descr = new_asciidtype_instance(in_size); + loop_descrs[1] = (PyArray_Descr *)ascii_descr; + } + else { + Py_INCREF(given_descrs[1]); + loop_descrs[1] = given_descrs[1]; + } + + long out_size = ((ASCIIDTypeObject *)loop_descrs[1])->size; + + if (out_size >= in_size) { + return NPY_SAFE_CASTING; + } + + return NPY_UNSAFE_CASTING; +} + static int -ascii_to_ascii_get_loop(PyArrayMethod_Context *context, int aligned, - int NPY_UNUSED(move_references), - const npy_intp *strides, - PyArrayMethod_StridedLoop **out_loop, - NpyAuxData **NPY_UNUSED(out_transferdata), - NPY_ARRAYMETHOD_FLAGS *flags) +unicode_to_ascii(PyArrayMethod_Context *context, char *const data[], + npy_intp const dimensions[], npy_intp const strides[], + NpyAuxData *NPY_UNUSED(auxdata)) { - *out_loop = (PyArrayMethod_StridedLoop *)&ascii_to_ascii; + PyArray_Descr **descrs = context->descriptors; + long in_size = (descrs[0]->elsize) / 4; + long out_size = ((ASCIIDTypeObject *)descrs[1])->size; + long copy_size; + + if (out_size > in_size) { + copy_size = in_size; + } + else { + copy_size = out_size; + } + + npy_intp N = dimensions[0]; + char *in = data[0]; + char *out = data[1]; + npy_intp in_stride = strides[0]; + npy_intp out_stride = strides[1]; + + while (N--) { + // copy input characters, checking that input UCS4 + // characters are all ascii, raising an error otherwise + for (int i = 0; i < copy_size; i++) { + Py_UCS4 c = ((Py_UCS4 *)in)[i]; + if (c > 127) { + PyErr_SetString( + PyExc_TypeError, + "Can only store ASCII text in a ASCIIDType array."); + return -1; + } + // UCS4 character is ascii, so casting to Py_UCS1 does not truncate + out[i] = (Py_UCS1)c; + } + // write zeros to remaining ASCII characters (if any) + for (int i = copy_size; i < out_size; i++) { + *(out + i) = '\0'; + } + in += in_stride; + out += out_stride; + } - *flags = 0; return 0; } +static int +ascii_to_unicode(PyArrayMethod_Context *context, char *const data[], + npy_intp const dimensions[], npy_intp const strides[], + NpyAuxData *NPY_UNUSED(auxdata)) +{ + PyArray_Descr **descrs = context->descriptors; + long in_size = ((ASCIIDTypeObject *)descrs[0])->size; + long out_size = (descrs[1]->elsize) / 4; + long copy_size; + + if (out_size > in_size) { + copy_size = in_size; + } + else { + copy_size = out_size; + } + + npy_intp N = dimensions[0]; + char *in = data[0]; + char *out = data[1]; + npy_intp in_stride = strides[0]; + npy_intp out_stride = strides[1]; + + while (N--) { + // copy ASCII input to first byte, fill rest with zeros + for (int i = 0; i < copy_size; i++) { + ((Py_UCS4 *)out)[i] = ((Py_UCS1 *)in)[i]; + } + // fill all remaining UCS4 characters with zeros + for (int i = copy_size; i < out_size; i++) { + ((Py_UCS4 *)out)[i] = (Py_UCS1)0; + } + in += in_stride; + out += out_stride; + } + return 0; +} + +static NPY_CASTING +ascii_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self), + PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]), + PyArray_Descr *given_descrs[2], + PyArray_Descr *loop_descrs[2], + npy_intp *NPY_UNUSED(view_offset)) +{ + Py_INCREF(given_descrs[0]); + loop_descrs[0] = given_descrs[0]; + long in_size = ((ASCIIDTypeObject *)given_descrs[0])->size; + if (given_descrs[1] == NULL) { + PyArray_Descr *unicode_descr = PyArray_DescrNewFromType(NPY_UNICODE); + // numpy stores unicode as UCS4 (4 bytes wide), so bitshift + // by 2 to get the number of bytes needed to store the UCS4 charaters + unicode_descr->elsize = in_size << 2; + loop_descrs[1] = unicode_descr; + } + else { + Py_INCREF(given_descrs[1]); + loop_descrs[1] = given_descrs[1]; + } + + long out_size = (loop_descrs[1]->elsize) >> 2; + + if (out_size >= in_size) { + return NPY_SAFE_CASTING; + } + + return NPY_UNSAFE_CASTING; +} + static PyArray_DTypeMeta *a2a_dtypes[2] = {NULL, NULL}; static PyType_Slot a2a_slots[] = { {NPY_METH_resolve_descriptors, &ascii_to_ascii_resolve_descriptors}, - {_NPY_METH_get_loop, &ascii_to_ascii_get_loop}, + {NPY_METH_strided_loop, &ascii_to_ascii}, + {NPY_METH_unaligned_strided_loop, &ascii_to_ascii}, {0, NULL}}; PyArrayMethod_Spec ASCIIToASCIICastSpec = { .name = "cast_ASCIIDType_to_ASCIIDType", .nin = 1, .nout = 1, - .flags = NPY_METH_SUPPORTS_UNALIGNED, - .casting = NPY_SAME_KIND_CASTING, + .casting = NPY_UNSAFE_CASTING, + .flags = (NPY_METH_NO_FLOATINGPOINT_ERRORS | + NPY_METH_SUPPORTS_UNALIGNED), .dtypes = a2a_dtypes, .slots = a2a_slots, }; + +static PyType_Slot u2a_slots[] = { + {NPY_METH_resolve_descriptors, &unicode_to_ascii_resolve_descriptors}, + {NPY_METH_strided_loop, &unicode_to_ascii}, + {0, NULL}}; + +static char *u2a_name = "cast_Unicode_to_ASCIIDType"; + +static PyType_Slot a2u_slots[] = { + {NPY_METH_resolve_descriptors, &ascii_to_unicode_resolve_descriptors}, + {NPY_METH_strided_loop, &ascii_to_unicode}, + {0, NULL}}; + +static char *a2u_name = "cast_ASCIIDType_to_Unicode"; + +PyArrayMethod_Spec ** +get_casts(void) +{ + PyArray_DTypeMeta **u2a_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *)); + u2a_dtypes[0] = &PyArray_UnicodeDType; + u2a_dtypes[1] = NULL; + + PyArrayMethod_Spec *UnicodeToASCIICastSpec = + malloc(sizeof(PyArrayMethod_Spec)); + + UnicodeToASCIICastSpec->name = u2a_name; + UnicodeToASCIICastSpec->nin = 1; + UnicodeToASCIICastSpec->nout = 1; + UnicodeToASCIICastSpec->casting = NPY_UNSAFE_CASTING; + UnicodeToASCIICastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; + UnicodeToASCIICastSpec->dtypes = u2a_dtypes; + UnicodeToASCIICastSpec->slots = u2a_slots; + + PyArray_DTypeMeta **a2u_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *)); + a2u_dtypes[0] = NULL; + a2u_dtypes[1] = &PyArray_UnicodeDType; + + PyArrayMethod_Spec *ASCIIToUnicodeCastSpec = + malloc(sizeof(PyArrayMethod_Spec)); + + ASCIIToUnicodeCastSpec->name = a2u_name; + ASCIIToUnicodeCastSpec->nin = 1; + ASCIIToUnicodeCastSpec->nout = 1; + ASCIIToUnicodeCastSpec->casting = NPY_UNSAFE_CASTING; + ASCIIToUnicodeCastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; + ASCIIToUnicodeCastSpec->dtypes = a2u_dtypes; + ASCIIToUnicodeCastSpec->slots = a2u_slots; + + PyArrayMethod_Spec **casts = malloc(4 * sizeof(PyArrayMethod_Spec *)); + casts[0] = &ASCIIToASCIICastSpec; + casts[1] = UnicodeToASCIICastSpec; + casts[2] = ASCIIToUnicodeCastSpec; + casts[3] = NULL; + + return casts; +} diff --git a/asciidtype/asciidtype/src/casts.h b/asciidtype/asciidtype/src/casts.h index f403fe8c..8dc8605e 100644 --- a/asciidtype/asciidtype/src/casts.h +++ b/asciidtype/asciidtype/src/casts.h @@ -10,6 +10,7 @@ #include "numpy/experimental_dtype_api.h" #include "numpy/ndarraytypes.h" -extern PyArrayMethod_Spec ASCIIToASCIICastSpec; +PyArrayMethod_Spec ** +get_casts(void); #endif /* _NPY_CASTS_H */ diff --git a/asciidtype/asciidtype/src/dtype.c b/asciidtype/asciidtype/src/dtype.c index 67387e99..cc56f619 100644 --- a/asciidtype/asciidtype/src/dtype.c +++ b/asciidtype/asciidtype/src/dtype.c @@ -16,6 +16,7 @@ get_value(PyObject *scalar) PyErr_SetString( PyExc_TypeError, "Can only store ASCII text in a ASCIIDType array."); + return NULL; } } else if (scalar_type != ASCIIScalar_Type) { @@ -29,6 +30,12 @@ get_value(PyObject *scalar) return NULL; } ret_bytes = PyUnicode_AsASCIIString(value); + if (ret_bytes == NULL) { + PyErr_SetString( + PyExc_TypeError, + "Can only store ASCII text in a ASCIIDType array."); + return NULL; + } Py_DECREF(value); } return ret_bytes; @@ -38,20 +45,16 @@ get_value(PyObject *scalar) * Internal helper to create new instances */ ASCIIDTypeObject * -new_asciidtype_instance(PyObject *size) +new_asciidtype_instance(long size) { ASCIIDTypeObject *new = (ASCIIDTypeObject *)PyArrayDescr_Type.tp_new( (PyTypeObject *)&ASCIIDType, NULL, NULL); if (new == NULL) { return NULL; } - long size_l = PyLong_AsLong(size); - if (size_l == -1 && PyErr_Occurred()) { - return NULL; - } - new->size = size_l; - new->base.elsize = size_l * sizeof(char); - new->base.alignment = size_l *_Alignof(char); + new->size = size; + new->base.elsize = size * sizeof(char); + new->base.alignment = size *_Alignof(char); return new; } @@ -182,18 +185,14 @@ asciidtype_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwds) { static char *kwargs_strs[] = {"size", NULL}; - PyObject *size = NULL; + long size = 0; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O:ASCIIDType", kwargs_strs, + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|l:ASCIIDType", kwargs_strs, &size)) { return NULL; } - if (size == NULL) { - size = PyLong_FromLong(0); - } PyObject *ret = (PyObject *)new_asciidtype_instance(size); - Py_DECREF(size); return ret; } @@ -239,7 +238,7 @@ PyArray_DTypeMeta ASCIIDType = { int init_ascii_dtype(void) { - static PyArrayMethod_Spec *casts[] = {&ASCIIToASCIICastSpec, NULL}; + PyArrayMethod_Spec **casts = get_casts(); PyArrayDTypeMeta_Spec ASCIIDType_DTypeSpec = { .flags = NPY_DT_PARAMETRIC, @@ -267,5 +266,11 @@ init_ascii_dtype(void) ASCIIDType.singleton = singleton; + free(ASCIIDType_DTypeSpec.casts[1]->dtypes); + free(ASCIIDType_DTypeSpec.casts[1]); + free(ASCIIDType_DTypeSpec.casts[2]->dtypes); + free(ASCIIDType_DTypeSpec.casts[2]); + free(ASCIIDType_DTypeSpec.casts); + return 0; } diff --git a/asciidtype/asciidtype/src/dtype.h b/asciidtype/asciidtype/src/dtype.h index 232c6e63..47fcf8f1 100644 --- a/asciidtype/asciidtype/src/dtype.h +++ b/asciidtype/asciidtype/src/dtype.h @@ -22,7 +22,7 @@ extern PyArray_DTypeMeta ASCIIDType; extern PyTypeObject *ASCIIScalar_Type; ASCIIDTypeObject * -new_asciidtype_instance(PyObject *size); +new_asciidtype_instance(long size); int init_ascii_dtype(void); diff --git a/asciidtype/tests/test_asciidtype.py b/asciidtype/tests/test_asciidtype.py index 9c11e541..3fde5ac1 100644 --- a/asciidtype/tests/test_asciidtype.py +++ b/asciidtype/tests/test_asciidtype.py @@ -1,4 +1,7 @@ +import re + import numpy as np +import pytest from asciidtype import ASCIIDType, ASCIIScalar @@ -50,24 +53,127 @@ def test_creation_truncation(): def test_casting_to_asciidtype(): - arr = np.array(["hello", "this", "is", "an", "array"], dtype=ASCIIDType(5)) + for dtype in (None, ASCIIDType(5)): + arr = np.array(["this", "is", "an", "array"], dtype=dtype) - assert repr(arr.astype(ASCIIDType(7))) == ( - "array(['hello', 'this', 'is', 'an', 'array'], dtype=ASCIIDType(7))" - ) + assert repr(arr.astype(ASCIIDType(7))) == ( + "array(['this', 'is', 'an', 'array'], dtype=ASCIIDType(7))" + ) - assert repr(arr.astype(ASCIIDType(5))) == ( - "array(['hello', 'this', 'is', 'an', 'array'], dtype=ASCIIDType(5))" + assert repr(arr.astype(ASCIIDType(5))) == ( + "array(['this', 'is', 'an', 'array'], dtype=ASCIIDType(5))" + ) + + assert repr(arr.astype(ASCIIDType(4))) == ( + "array(['this', 'is', 'an', 'arra'], dtype=ASCIIDType(4))" + ) + + assert repr(arr.astype(ASCIIDType(1))) == ( + "array(['t', 'i', 'a', 'a'], dtype=ASCIIDType(1))" + ) + + # assert repr(arr.astype(ASCIIDType())) == ( + # "array(['', '', '', '', ''], dtype=ASCIIDType(0))" + # ) + + +def test_casting_safety(): + arr = np.array(["this", "is", "an", "array"]) + assert repr(arr.astype(ASCIIDType(6), casting="safe")) == ( + "array(['this', 'is', 'an', 'array'], dtype=ASCIIDType(6))" + ) + assert repr(arr.astype(ASCIIDType(5), casting="safe")) == ( + "array(['this', 'is', 'an', 'array'], dtype=ASCIIDType(5))" + ) + with pytest.raises( + TypeError, + match=re.escape( + "Cannot cast array data from dtype('