From f5d920411cb73a65d907994dc548cb8329600a16 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Mon, 12 Dec 2022 14:33:24 -0700 Subject: [PATCH 01/13] add missing return for error case --- asciidtype/asciidtype/src/dtype.c | 1 + 1 file changed, 1 insertion(+) diff --git a/asciidtype/asciidtype/src/dtype.c b/asciidtype/asciidtype/src/dtype.c index 67387e99..2c37a8a9 100644 --- a/asciidtype/asciidtype/src/dtype.c +++ b/asciidtype/asciidtype/src/dtype.c @@ -16,6 +16,7 @@ get_value(PyObject *scalar) PyErr_SetString( PyExc_TypeError, "Can only store ASCII text in a ASCIIDType array."); + return NULL; } } else if (scalar_type != ASCIIScalar_Type) { From 09a84ef9ff80ddfb09df1c1b06c720e4f900921a Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Mon, 12 Dec 2022 14:34:20 -0700 Subject: [PATCH 02/13] add missing error checking in get_value --- asciidtype/asciidtype/src/dtype.c | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/asciidtype/asciidtype/src/dtype.c b/asciidtype/asciidtype/src/dtype.c index 2c37a8a9..60dca09c 100644 --- a/asciidtype/asciidtype/src/dtype.c +++ b/asciidtype/asciidtype/src/dtype.c @@ -30,6 +30,12 @@ get_value(PyObject *scalar) return NULL; } ret_bytes = PyUnicode_AsASCIIString(value); + if (ret_bytes == NULL) { + PyErr_SetString( + PyExc_TypeError, + "Can only store ASCII text in a ASCIIDType array."); + return NULL; + } Py_DECREF(value); } return ret_bytes; From 38630cca171baeb7be7d4f91d92a990abd714036 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Mon, 12 Dec 2022 14:35:14 -0700 Subject: [PATCH 03/13] remove incorrect decref for borrowed reference --- asciidtype/asciidtype/src/dtype.c | 1 - 1 file changed, 1 deletion(-) diff --git a/asciidtype/asciidtype/src/dtype.c b/asciidtype/asciidtype/src/dtype.c index 60dca09c..980a6994 100644 --- a/asciidtype/asciidtype/src/dtype.c +++ b/asciidtype/asciidtype/src/dtype.c @@ -200,7 +200,6 @@ asciidtype_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwds) } PyObject *ret = (PyObject *)new_asciidtype_instance(size); - Py_DECREF(size); return ret; } From f7bb0ba29ff135a8acb4eb3ff1d65c61b67301d6 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Mon, 12 Dec 2022 14:37:11 -0700 Subject: [PATCH 04/13] add NPY_UNUSED for unused parameters to ascii_to_ascii_get_loop --- asciidtype/asciidtype/src/casts.c | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/asciidtype/asciidtype/src/casts.c b/asciidtype/asciidtype/src/casts.c index bc4f9eb7..6f6c5ef7 100644 --- a/asciidtype/asciidtype/src/casts.c +++ b/asciidtype/asciidtype/src/casts.c @@ -73,9 +73,10 @@ ascii_to_ascii(PyArrayMethod_Context *context, char *const data[], } static int -ascii_to_ascii_get_loop(PyArrayMethod_Context *context, int aligned, +ascii_to_ascii_get_loop(PyArrayMethod_Context *NPY_UNUSED(context), + int NPY_UNUSED(aligned), int NPY_UNUSED(move_references), - const npy_intp *strides, + const npy_intp *NPY_UNUSED(strides), PyArrayMethod_StridedLoop **out_loop, NpyAuxData **NPY_UNUSED(out_transferdata), NPY_ARRAYMETHOD_FLAGS *flags) From 9e86760497e37a5810ef6ca026b6974c8cbc8dbd Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Mon, 12 Dec 2022 14:51:32 -0700 Subject: [PATCH 05/13] increase maximum allowed line length --- asciidtype/.flake8 | 1 + 1 file changed, 1 insertion(+) diff --git a/asciidtype/.flake8 b/asciidtype/.flake8 index 80676bc7..fff82cbb 100644 --- a/asciidtype/.flake8 +++ b/asciidtype/.flake8 @@ -1,2 +1,3 @@ [flake8] per-file-ignores = __init__.py:F401 +max-line-length = 160 From bc38f899c6f566e70722dcea161efc611f14c542 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Mon, 12 Dec 2022 14:53:37 -0700 Subject: [PATCH 06/13] add ascii to unicode and unicode to ascii casts --- asciidtype/asciidtype/src/casts.c | 240 ++++++++++++++++++++++++++++ asciidtype/asciidtype/src/casts.h | 3 +- asciidtype/asciidtype/src/dtype.c | 8 +- asciidtype/tests/test_asciidtype.py | 80 +++++++--- 4 files changed, 308 insertions(+), 23 deletions(-) diff --git a/asciidtype/asciidtype/src/casts.c b/asciidtype/asciidtype/src/casts.c index 6f6c5ef7..5d10f2ba 100644 --- a/asciidtype/asciidtype/src/casts.c +++ b/asciidtype/asciidtype/src/casts.c @@ -87,6 +87,190 @@ ascii_to_ascii_get_loop(PyArrayMethod_Context *NPY_UNUSED(context), return 0; } +static NPY_CASTING +unicode_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self), + PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]), + PyArray_Descr *given_descrs[2], + PyArray_Descr *loop_descrs[2], + npy_intp *NPY_UNUSED(view_offset)) +{ + Py_INCREF(given_descrs[0]); + loop_descrs[0] = given_descrs[0]; + if (given_descrs[1] == NULL) { + Py_INCREF(given_descrs[0]); + loop_descrs[1] = given_descrs[0]; + } + else { + Py_INCREF(given_descrs[1]); + loop_descrs[1] = given_descrs[1]; + } + + return NPY_SAME_KIND_CASTING; +} + +static int +ucs4_character_is_ascii(char *buffer) +{ + int first_char = buffer[0]; + + if (first_char < 0 || first_char > 127) { + return -1; + } + + for (int i = 1; i < 4; i++) { + if (buffer[i] != 0) { + return -1; + } + } + + return 0; +} + +static int +unicode_to_ascii(PyArrayMethod_Context *context, char *const data[], + npy_intp const dimensions[], npy_intp const strides[], + NpyAuxData *NPY_UNUSED(auxdata)) +{ + PyArray_Descr **descrs = context->descriptors; + long in_size = (descrs[0]->elsize) / 4; + long out_size = ((ASCIIDTypeObject *)descrs[1])->size; + long copy_size; + + if (out_size > in_size) { + copy_size = in_size; + } + else { + copy_size = out_size; + } + + npy_intp N = dimensions[0]; + char *in = data[0]; + char *out = data[1]; + npy_intp in_stride = strides[0]; + npy_intp out_stride = strides[1]; + + while (N--) { + // copy input characters, checking that input UCS4 + // characters are all ascii, raising an error otherwise + for (int i = 0; i < copy_size; i++) { + if (ucs4_character_is_ascii(in) == -1) { + PyGILState_STATE gstate; + gstate = PyGILState_Ensure(); + PyErr_SetString( + PyExc_TypeError, + "Can only store ASCII text in a ASCIIDType array."); + PyGILState_Release(gstate); + return -1; + } + // UCS4 character is ascii, so copy first byte of character + // into output, ignoring the rest + *(out + i) = *(in + i * 4); + } + // write zeros to remaining ASCII characters (if any) + for (int i = copy_size; i < out_size; i++) { + *(out + i) = '\0'; + } + in += in_stride; + out += out_stride; + } + + return 0; +} + +static int +unicode_to_ascii_get_loop(PyArrayMethod_Context *NPY_UNUSED(context), + int NPY_UNUSED(aligned), + int NPY_UNUSED(move_references), + const npy_intp *NPY_UNUSED(strides), + PyArrayMethod_StridedLoop **out_loop, + NpyAuxData **NPY_UNUSED(out_transferdata), + NPY_ARRAYMETHOD_FLAGS *flags) +{ + *out_loop = (PyArrayMethod_StridedLoop *)&unicode_to_ascii; + + *flags = 0; + return 0; +} + +static int +ascii_to_unicode(PyArrayMethod_Context *context, char *const data[], + npy_intp const dimensions[], npy_intp const strides[], + NpyAuxData *NPY_UNUSED(auxdata)) +{ + PyArray_Descr **descrs = context->descriptors; + long in_size = ((ASCIIDTypeObject *)descrs[0])->size; + long out_size = (descrs[1]->elsize) / 4; + long copy_size; + + if (out_size > in_size) { + copy_size = in_size; + } + else { + copy_size = out_size; + } + + npy_intp N = dimensions[0]; + char *in = data[0]; + char *out = data[1]; + npy_intp in_stride = strides[0]; + npy_intp out_stride = strides[1]; + + while (N--) { + // copy ASCII input to first byte, fill rest with zeros + for (int i = 0; i < copy_size; i++) { + *(out + i * 4) = *(in + i); + for (int j = 1; j < 4; j++) { + *(out + i * 4 + j) = '\0'; + } + } + // fill all remaining UCS4 characters with zeros + for (int i = copy_size; i < out_size; i++) { + for (int j = 0; j < 4; j++) { + *(out + i * 4 + j) = '\0'; + } + } + in += in_stride; + out += out_stride; + } + return 0; +} + +static NPY_CASTING +ascii_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self), + PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]), + PyArray_Descr *given_descrs[2], + PyArray_Descr *loop_descrs[2], + npy_intp *NPY_UNUSED(view_offset)) +{ + Py_INCREF(given_descrs[0]); + loop_descrs[0] = given_descrs[0]; + if (given_descrs[1] == NULL) { + Py_INCREF(given_descrs[0]); + loop_descrs[1] = given_descrs[0]; + } + else { + Py_INCREF(given_descrs[1]); + loop_descrs[1] = given_descrs[1]; + } + + return NPY_SAME_KIND_CASTING; +} + +static int +ascii_to_unicode_get_loop(PyArrayMethod_Context *NPY_UNUSED(context), + int NPY_UNUSED(aligned), + int NPY_UNUSED(move_references), + const npy_intp *NPY_UNUSED(strides), + PyArrayMethod_StridedLoop **out_loop, + NpyAuxData **NPY_UNUSED(out_transferdata), + NPY_ARRAYMETHOD_FLAGS *flags) +{ + *out_loop = (PyArrayMethod_StridedLoop *)&ascii_to_unicode; + + *flags = 0; + return 0; +} + static PyArray_DTypeMeta *a2a_dtypes[2] = {NULL, NULL}; static PyType_Slot a2a_slots[] = { @@ -103,3 +287,59 @@ PyArrayMethod_Spec ASCIIToASCIICastSpec = { .dtypes = a2a_dtypes, .slots = a2a_slots, }; + +static PyType_Slot u2a_slots[] = { + {NPY_METH_resolve_descriptors, &unicode_to_ascii_resolve_descriptors}, + {_NPY_METH_get_loop, &unicode_to_ascii_get_loop}, + {0, NULL}}; + +static char *u2a_name = "cast_Unicode_to_ASCIIDType"; + +static PyType_Slot a2u_slots[] = { + {NPY_METH_resolve_descriptors, &ascii_to_unicode_resolve_descriptors}, + {_NPY_METH_get_loop, &ascii_to_unicode_get_loop}, + {0, NULL}}; + +static char *a2u_name = "cast_ASCIIDType_to_Unicode"; + +PyArrayMethod_Spec ** +get_casts(void) +{ + PyArray_DTypeMeta **u2a_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *)); + u2a_dtypes[0] = &PyArray_UnicodeDType; + u2a_dtypes[1] = NULL; + + PyArrayMethod_Spec *UnicodeToASCIICastSpec = + malloc(sizeof(PyArrayMethod_Spec)); + + UnicodeToASCIICastSpec->name = u2a_name; + UnicodeToASCIICastSpec->nin = 1; + UnicodeToASCIICastSpec->nout = 1; + UnicodeToASCIICastSpec->flags = NPY_METH_SUPPORTS_UNALIGNED; + UnicodeToASCIICastSpec->casting = NPY_SAME_KIND_CASTING; + UnicodeToASCIICastSpec->dtypes = u2a_dtypes; + UnicodeToASCIICastSpec->slots = u2a_slots; + + PyArray_DTypeMeta **a2u_dtypes = malloc(2 * sizeof(PyArray_DTypeMeta *)); + a2u_dtypes[0] = NULL; + a2u_dtypes[1] = &PyArray_UnicodeDType; + + PyArrayMethod_Spec *ASCIIToUnicodeCastSpec = + malloc(sizeof(PyArrayMethod_Spec)); + + ASCIIToUnicodeCastSpec->name = a2u_name; + ASCIIToUnicodeCastSpec->nin = 1; + ASCIIToUnicodeCastSpec->nout = 1; + ASCIIToUnicodeCastSpec->flags = NPY_METH_SUPPORTS_UNALIGNED; + ASCIIToUnicodeCastSpec->casting = NPY_SAME_KIND_CASTING; + ASCIIToUnicodeCastSpec->dtypes = a2u_dtypes; + ASCIIToUnicodeCastSpec->slots = a2u_slots; + + PyArrayMethod_Spec **casts = malloc(4 * sizeof(PyArrayMethod_Spec *)); + casts[0] = &ASCIIToASCIICastSpec; + casts[1] = UnicodeToASCIICastSpec; + casts[2] = ASCIIToUnicodeCastSpec; + casts[3] = NULL; + + return casts; +} diff --git a/asciidtype/asciidtype/src/casts.h b/asciidtype/asciidtype/src/casts.h index f403fe8c..8dc8605e 100644 --- a/asciidtype/asciidtype/src/casts.h +++ b/asciidtype/asciidtype/src/casts.h @@ -10,6 +10,7 @@ #include "numpy/experimental_dtype_api.h" #include "numpy/ndarraytypes.h" -extern PyArrayMethod_Spec ASCIIToASCIICastSpec; +PyArrayMethod_Spec ** +get_casts(void); #endif /* _NPY_CASTS_H */ diff --git a/asciidtype/asciidtype/src/dtype.c b/asciidtype/asciidtype/src/dtype.c index 980a6994..cb8df318 100644 --- a/asciidtype/asciidtype/src/dtype.c +++ b/asciidtype/asciidtype/src/dtype.c @@ -245,7 +245,7 @@ PyArray_DTypeMeta ASCIIDType = { int init_ascii_dtype(void) { - static PyArrayMethod_Spec *casts[] = {&ASCIIToASCIICastSpec, NULL}; + PyArrayMethod_Spec **casts = get_casts(); PyArrayDTypeMeta_Spec ASCIIDType_DTypeSpec = { .flags = NPY_DT_PARAMETRIC, @@ -273,5 +273,11 @@ init_ascii_dtype(void) ASCIIDType.singleton = singleton; + free(ASCIIDType_DTypeSpec.casts[1]->dtypes); + free(ASCIIDType_DTypeSpec.casts[1]); + free(ASCIIDType_DTypeSpec.casts[2]->dtypes); + free(ASCIIDType_DTypeSpec.casts[2]); + free(ASCIIDType_DTypeSpec.casts); + return 0; } diff --git a/asciidtype/tests/test_asciidtype.py b/asciidtype/tests/test_asciidtype.py index 9c11e541..fdc900ed 100644 --- a/asciidtype/tests/test_asciidtype.py +++ b/asciidtype/tests/test_asciidtype.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from asciidtype import ASCIIDType, ASCIIScalar @@ -50,24 +51,61 @@ def test_creation_truncation(): def test_casting_to_asciidtype(): - arr = np.array(["hello", "this", "is", "an", "array"], dtype=ASCIIDType(5)) - - assert repr(arr.astype(ASCIIDType(7))) == ( - "array(['hello', 'this', 'is', 'an', 'array'], dtype=ASCIIDType(7))" - ) - - assert repr(arr.astype(ASCIIDType(5))) == ( - "array(['hello', 'this', 'is', 'an', 'array'], dtype=ASCIIDType(5))" - ) - - assert repr(arr.astype(ASCIIDType(4))) == ( - "array(['hell', 'this', 'is', 'an', 'arra'], dtype=ASCIIDType(4))" - ) - - assert repr(arr.astype(ASCIIDType(1))) == ( - "array(['h', 't', 'i', 'a', 'a'], dtype=ASCIIDType(1))" - ) - - # assert repr(arr.astype(ASCIIDType())) == ( - # "array(['', '', '', '', ''], dtype=ASCIIDType(0))" - # ) + for dtype in (None, ASCIIDType(5)): + arr = np.array(["this", "is", "an", "array"], dtype=dtype) + + assert repr(arr.astype(ASCIIDType(7))) == ( + "array(['this', 'is', 'an', 'array'], dtype=ASCIIDType(7))" + ) + + assert repr(arr.astype(ASCIIDType(5))) == ( + "array(['this', 'is', 'an', 'array'], dtype=ASCIIDType(5))" + ) + + assert repr(arr.astype(ASCIIDType(4))) == ( + "array(['this', 'is', 'an', 'arra'], dtype=ASCIIDType(4))" + ) + + assert repr(arr.astype(ASCIIDType(1))) == ( + "array(['t', 'i', 'a', 'a'], dtype=ASCIIDType(1))" + ) + + # assert repr(arr.astype(ASCIIDType())) == ( + # "array(['', '', '', '', ''], dtype=ASCIIDType(0))" + # ) + + +def test_unicode_to_ascii_to_unicode(): + arr = np.array(["hello", "this", "is", "an", "array"]) + ascii_arr = arr.astype(ASCIIDType(5)) + round_trip_arr = ascii_arr.astype("U5") + np.testing.assert_array_equal(arr, round_trip_arr) + + +def test_creation_fails_with_non_ascii_characters(): + inps = [ + ["πŸ˜€", "Β‘", "Β©", "ΓΏ"], + ["πŸ˜€", "hello", "some", "ascii"], + ["hello", "some", "ascii", "πŸ˜€"], + ] + for inp in inps: + with pytest.raises( + TypeError, + match="Can only store ASCII text in a ASCIIDType array.", + ): + np.array(inp, dtype=ASCIIDType(5)) + + +def test_casting_fails_with_non_ascii_characters(): + inps = [ + ["πŸ˜€", "Β‘", "Β©", "ΓΏ"], + ["πŸ˜€", "hello", "some", "ascii"], + ["hello", "some", "ascii", "πŸ˜€"], + ] + for inp in inps: + arr = np.array(inp) + with pytest.raises( + TypeError, + match="Can only store ASCII text in a ASCIIDType array.", + ): + arr.astype(ASCIIDType(5)) From d73cec2ec1d4c1824577cca571da90ba936c98fb Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Tue, 13 Dec 2022 12:21:03 -0700 Subject: [PATCH 07/13] make new_asciidtype_instance take a long instead of PyObject* --- asciidtype/asciidtype/src/dtype.c | 19 ++++++------------- asciidtype/asciidtype/src/dtype.h | 2 +- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/asciidtype/asciidtype/src/dtype.c b/asciidtype/asciidtype/src/dtype.c index cb8df318..cc56f619 100644 --- a/asciidtype/asciidtype/src/dtype.c +++ b/asciidtype/asciidtype/src/dtype.c @@ -45,20 +45,16 @@ get_value(PyObject *scalar) * Internal helper to create new instances */ ASCIIDTypeObject * -new_asciidtype_instance(PyObject *size) +new_asciidtype_instance(long size) { ASCIIDTypeObject *new = (ASCIIDTypeObject *)PyArrayDescr_Type.tp_new( (PyTypeObject *)&ASCIIDType, NULL, NULL); if (new == NULL) { return NULL; } - long size_l = PyLong_AsLong(size); - if (size_l == -1 && PyErr_Occurred()) { - return NULL; - } - new->size = size_l; - new->base.elsize = size_l * sizeof(char); - new->base.alignment = size_l *_Alignof(char); + new->size = size; + new->base.elsize = size * sizeof(char); + new->base.alignment = size *_Alignof(char); return new; } @@ -189,15 +185,12 @@ asciidtype_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwds) { static char *kwargs_strs[] = {"size", NULL}; - PyObject *size = NULL; + long size = 0; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|O:ASCIIDType", kwargs_strs, + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|l:ASCIIDType", kwargs_strs, &size)) { return NULL; } - if (size == NULL) { - size = PyLong_FromLong(0); - } PyObject *ret = (PyObject *)new_asciidtype_instance(size); return ret; diff --git a/asciidtype/asciidtype/src/dtype.h b/asciidtype/asciidtype/src/dtype.h index 232c6e63..47fcf8f1 100644 --- a/asciidtype/asciidtype/src/dtype.h +++ b/asciidtype/asciidtype/src/dtype.h @@ -22,7 +22,7 @@ extern PyArray_DTypeMeta ASCIIDType; extern PyTypeObject *ASCIIScalar_Type; ASCIIDTypeObject * -new_asciidtype_instance(PyObject *size); +new_asciidtype_instance(long size); int init_ascii_dtype(void); From 23ad336c6331190df778570da824c95b6f46eaf5 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Tue, 13 Dec 2022 12:25:19 -0700 Subject: [PATCH 08/13] ascii <-> unicode resolve_descriptors return correct descriptors for when the output descriptor is abstract --- asciidtype/asciidtype/src/casts.c | 15 +++++++++++---- asciidtype/tests/test_asciidtype.py | 5 +++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/asciidtype/asciidtype/src/casts.c b/asciidtype/asciidtype/src/casts.c index 5d10f2ba..ae617c53 100644 --- a/asciidtype/asciidtype/src/casts.c +++ b/asciidtype/asciidtype/src/casts.c @@ -97,8 +97,11 @@ unicode_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self), Py_INCREF(given_descrs[0]); loop_descrs[0] = given_descrs[0]; if (given_descrs[1] == NULL) { - Py_INCREF(given_descrs[0]); - loop_descrs[1] = given_descrs[0]; + // numpy stores unicode as UCS4 (4 bytes wide), so bitshift + // by 2 to get the number of ASCII bytes needed + long size = (loop_descrs[0]->elsize) >> 2; + ASCIIDTypeObject *ascii_descr = new_asciidtype_instance(size); + loop_descrs[1] = (PyArray_Descr *)ascii_descr; } else { Py_INCREF(given_descrs[1]); @@ -245,8 +248,12 @@ ascii_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self), Py_INCREF(given_descrs[0]); loop_descrs[0] = given_descrs[0]; if (given_descrs[1] == NULL) { - Py_INCREF(given_descrs[0]); - loop_descrs[1] = given_descrs[0]; + PyArray_Descr *unicode_descr = PyArray_DescrNewFromType(NPY_UNICODE); + long num_ascii_bytes = ((ASCIIDTypeObject *)given_descrs[0])->size; + // numpy stores unicode as UCS4 (4 bytes wide), so bitshift + // by 2 to get the number of bytes needed to store the UCS4 charaters + unicode_descr->elsize = num_ascii_bytes << 2; + loop_descrs[1] = unicode_descr; } else { Py_INCREF(given_descrs[1]); diff --git a/asciidtype/tests/test_asciidtype.py b/asciidtype/tests/test_asciidtype.py index fdc900ed..06d63c98 100644 --- a/asciidtype/tests/test_asciidtype.py +++ b/asciidtype/tests/test_asciidtype.py @@ -78,8 +78,9 @@ def test_casting_to_asciidtype(): def test_unicode_to_ascii_to_unicode(): arr = np.array(["hello", "this", "is", "an", "array"]) ascii_arr = arr.astype(ASCIIDType(5)) - round_trip_arr = ascii_arr.astype("U5") - np.testing.assert_array_equal(arr, round_trip_arr) + for dtype in ["U5", np.unicode_, np.str_]: + round_trip_arr = ascii_arr.astype(dtype) + np.testing.assert_array_equal(arr, round_trip_arr) def test_creation_fails_with_non_ascii_characters(): From e4a019a6854ade5fdb851d1932d2d16d40aa421f Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Tue, 13 Dec 2022 13:20:29 -0700 Subject: [PATCH 09/13] remove get_loop and fix casting safety --- asciidtype/asciidtype/src/casts.c | 105 ++++++++++------------------ asciidtype/tests/test_asciidtype.py | 67 ++++++++++++++++++ 2 files changed, 105 insertions(+), 67 deletions(-) diff --git a/asciidtype/asciidtype/src/casts.c b/asciidtype/asciidtype/src/casts.c index ae617c53..75a9ba35 100644 --- a/asciidtype/asciidtype/src/casts.c +++ b/asciidtype/asciidtype/src/casts.c @@ -28,13 +28,17 @@ ascii_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self), loop_descrs[1] = given_descrs[1]; } - if (((ASCIIDTypeObject *)loop_descrs[0])->size == - ((ASCIIDTypeObject *)loop_descrs[1])->size) { + long in_size = ((ASCIIDTypeObject *)loop_descrs[0])->size; + long out_size = ((ASCIIDTypeObject *)loop_descrs[1])->size; + + if (in_size == out_size) { *view_offset = 0; return NPY_NO_CASTING; } - - return NPY_SAME_KIND_CASTING; + else if (in_size > out_size) { + return NPY_UNSAFE_CASTING; + } + return NPY_SAFE_CASTING; } static int @@ -72,21 +76,6 @@ ascii_to_ascii(PyArrayMethod_Context *context, char *const data[], return 0; } -static int -ascii_to_ascii_get_loop(PyArrayMethod_Context *NPY_UNUSED(context), - int NPY_UNUSED(aligned), - int NPY_UNUSED(move_references), - const npy_intp *NPY_UNUSED(strides), - PyArrayMethod_StridedLoop **out_loop, - NpyAuxData **NPY_UNUSED(out_transferdata), - NPY_ARRAYMETHOD_FLAGS *flags) -{ - *out_loop = (PyArrayMethod_StridedLoop *)&ascii_to_ascii; - - *flags = 0; - return 0; -} - static NPY_CASTING unicode_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta *NPY_UNUSED(dtypes[2]), @@ -96,11 +85,11 @@ unicode_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self), { Py_INCREF(given_descrs[0]); loop_descrs[0] = given_descrs[0]; + // numpy stores unicode as UCS4 (4 bytes wide), so bitshift + // by 2 to get the number of ASCII bytes needed + long in_size = (loop_descrs[0]->elsize) >> 2; if (given_descrs[1] == NULL) { - // numpy stores unicode as UCS4 (4 bytes wide), so bitshift - // by 2 to get the number of ASCII bytes needed - long size = (loop_descrs[0]->elsize) >> 2; - ASCIIDTypeObject *ascii_descr = new_asciidtype_instance(size); + ASCIIDTypeObject *ascii_descr = new_asciidtype_instance(in_size); loop_descrs[1] = (PyArray_Descr *)ascii_descr; } else { @@ -108,7 +97,13 @@ unicode_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self), loop_descrs[1] = given_descrs[1]; } - return NPY_SAME_KIND_CASTING; + long out_size = ((ASCIIDTypeObject *)loop_descrs[1])->size; + + if (out_size >= in_size) { + return NPY_SAFE_CASTING; + } + + return NPY_UNSAFE_CASTING; } static int @@ -157,12 +152,9 @@ unicode_to_ascii(PyArrayMethod_Context *context, char *const data[], // characters are all ascii, raising an error otherwise for (int i = 0; i < copy_size; i++) { if (ucs4_character_is_ascii(in) == -1) { - PyGILState_STATE gstate; - gstate = PyGILState_Ensure(); PyErr_SetString( PyExc_TypeError, "Can only store ASCII text in a ASCIIDType array."); - PyGILState_Release(gstate); return -1; } // UCS4 character is ascii, so copy first byte of character @@ -180,21 +172,6 @@ unicode_to_ascii(PyArrayMethod_Context *context, char *const data[], return 0; } -static int -unicode_to_ascii_get_loop(PyArrayMethod_Context *NPY_UNUSED(context), - int NPY_UNUSED(aligned), - int NPY_UNUSED(move_references), - const npy_intp *NPY_UNUSED(strides), - PyArrayMethod_StridedLoop **out_loop, - NpyAuxData **NPY_UNUSED(out_transferdata), - NPY_ARRAYMETHOD_FLAGS *flags) -{ - *out_loop = (PyArrayMethod_StridedLoop *)&unicode_to_ascii; - - *flags = 0; - return 0; -} - static int ascii_to_unicode(PyArrayMethod_Context *context, char *const data[], npy_intp const dimensions[], npy_intp const strides[], @@ -247,12 +224,12 @@ ascii_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self), { Py_INCREF(given_descrs[0]); loop_descrs[0] = given_descrs[0]; + long in_size = ((ASCIIDTypeObject *)given_descrs[0])->size; if (given_descrs[1] == NULL) { PyArray_Descr *unicode_descr = PyArray_DescrNewFromType(NPY_UNICODE); - long num_ascii_bytes = ((ASCIIDTypeObject *)given_descrs[0])->size; // numpy stores unicode as UCS4 (4 bytes wide), so bitshift // by 2 to get the number of bytes needed to store the UCS4 charaters - unicode_descr->elsize = num_ascii_bytes << 2; + unicode_descr->elsize = in_size << 2; loop_descrs[1] = unicode_descr; } else { @@ -260,51 +237,44 @@ ascii_to_unicode_resolve_descriptors(PyObject *NPY_UNUSED(self), loop_descrs[1] = given_descrs[1]; } - return NPY_SAME_KIND_CASTING; -} + long out_size = (loop_descrs[1]->elsize) >> 2; -static int -ascii_to_unicode_get_loop(PyArrayMethod_Context *NPY_UNUSED(context), - int NPY_UNUSED(aligned), - int NPY_UNUSED(move_references), - const npy_intp *NPY_UNUSED(strides), - PyArrayMethod_StridedLoop **out_loop, - NpyAuxData **NPY_UNUSED(out_transferdata), - NPY_ARRAYMETHOD_FLAGS *flags) -{ - *out_loop = (PyArrayMethod_StridedLoop *)&ascii_to_unicode; + if (out_size >= in_size) { + return NPY_SAFE_CASTING; + } - *flags = 0; - return 0; + return NPY_UNSAFE_CASTING; } static PyArray_DTypeMeta *a2a_dtypes[2] = {NULL, NULL}; static PyType_Slot a2a_slots[] = { {NPY_METH_resolve_descriptors, &ascii_to_ascii_resolve_descriptors}, - {_NPY_METH_get_loop, &ascii_to_ascii_get_loop}, + {NPY_METH_strided_loop, &ascii_to_ascii}, + {NPY_METH_unaligned_strided_loop, &ascii_to_ascii}, {0, NULL}}; PyArrayMethod_Spec ASCIIToASCIICastSpec = { .name = "cast_ASCIIDType_to_ASCIIDType", .nin = 1, .nout = 1, - .flags = NPY_METH_SUPPORTS_UNALIGNED, - .casting = NPY_SAME_KIND_CASTING, + .casting = NPY_UNSAFE_CASTING, + .flags = (NPY_METH_NO_FLOATINGPOINT_ERRORS | + NPY_METH_SUPPORTS_UNALIGNED), .dtypes = a2a_dtypes, .slots = a2a_slots, }; static PyType_Slot u2a_slots[] = { {NPY_METH_resolve_descriptors, &unicode_to_ascii_resolve_descriptors}, - {_NPY_METH_get_loop, &unicode_to_ascii_get_loop}, + {NPY_METH_strided_loop, &unicode_to_ascii}, {0, NULL}}; static char *u2a_name = "cast_Unicode_to_ASCIIDType"; static PyType_Slot a2u_slots[] = { {NPY_METH_resolve_descriptors, &ascii_to_unicode_resolve_descriptors}, - {_NPY_METH_get_loop, &ascii_to_unicode_get_loop}, + {NPY_METH_strided_loop, &ascii_to_unicode}, {0, NULL}}; static char *a2u_name = "cast_ASCIIDType_to_Unicode"; @@ -322,8 +292,9 @@ get_casts(void) UnicodeToASCIICastSpec->name = u2a_name; UnicodeToASCIICastSpec->nin = 1; UnicodeToASCIICastSpec->nout = 1; - UnicodeToASCIICastSpec->flags = NPY_METH_SUPPORTS_UNALIGNED; - UnicodeToASCIICastSpec->casting = NPY_SAME_KIND_CASTING; + UnicodeToASCIICastSpec->casting = NPY_UNSAFE_CASTING, + UnicodeToASCIICastSpec->flags = + (NPY_METH_NO_FLOATINGPOINT_ERRORS | NPY_METH_REQUIRES_PYAPI); UnicodeToASCIICastSpec->dtypes = u2a_dtypes; UnicodeToASCIICastSpec->slots = u2a_slots; @@ -337,8 +308,8 @@ get_casts(void) ASCIIToUnicodeCastSpec->name = a2u_name; ASCIIToUnicodeCastSpec->nin = 1; ASCIIToUnicodeCastSpec->nout = 1; - ASCIIToUnicodeCastSpec->flags = NPY_METH_SUPPORTS_UNALIGNED; - ASCIIToUnicodeCastSpec->casting = NPY_SAME_KIND_CASTING; + ASCIIToUnicodeCastSpec->casting = NPY_UNSAFE_CASTING, + ASCIIToUnicodeCastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; ASCIIToUnicodeCastSpec->dtypes = a2u_dtypes; ASCIIToUnicodeCastSpec->slots = a2u_slots; diff --git a/asciidtype/tests/test_asciidtype.py b/asciidtype/tests/test_asciidtype.py index 06d63c98..3fde5ac1 100644 --- a/asciidtype/tests/test_asciidtype.py +++ b/asciidtype/tests/test_asciidtype.py @@ -1,3 +1,5 @@ +import re + import numpy as np import pytest @@ -75,6 +77,71 @@ def test_casting_to_asciidtype(): # ) +def test_casting_safety(): + arr = np.array(["this", "is", "an", "array"]) + assert repr(arr.astype(ASCIIDType(6), casting="safe")) == ( + "array(['this', 'is', 'an', 'array'], dtype=ASCIIDType(6))" + ) + assert repr(arr.astype(ASCIIDType(5), casting="safe")) == ( + "array(['this', 'is', 'an', 'array'], dtype=ASCIIDType(5))" + ) + with pytest.raises( + TypeError, + match=re.escape( + "Cannot cast array data from dtype(' Date: Tue, 13 Dec 2022 13:28:39 -0700 Subject: [PATCH 10/13] use unsigned char in ucs4_character_is_ascii --- asciidtype/asciidtype/src/casts.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asciidtype/asciidtype/src/casts.c b/asciidtype/asciidtype/src/casts.c index 75a9ba35..42a1472d 100644 --- a/asciidtype/asciidtype/src/casts.c +++ b/asciidtype/asciidtype/src/casts.c @@ -109,9 +109,9 @@ unicode_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self), static int ucs4_character_is_ascii(char *buffer) { - int first_char = buffer[0]; + unsigned char first_char = buffer[0]; - if (first_char < 0 || first_char > 127) { + if (first_char > 127) { return -1; } From 3f21ba6255ef14eaa81f6c43cc6d687a63b58e66 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Tue, 13 Dec 2022 13:34:52 -0700 Subject: [PATCH 11/13] simplify ascii to unicode casting use PY_UCS types --- asciidtype/asciidtype/src/casts.c | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/asciidtype/asciidtype/src/casts.c b/asciidtype/asciidtype/src/casts.c index 42a1472d..6175d261 100644 --- a/asciidtype/asciidtype/src/casts.c +++ b/asciidtype/asciidtype/src/casts.c @@ -198,16 +198,11 @@ ascii_to_unicode(PyArrayMethod_Context *context, char *const data[], while (N--) { // copy ASCII input to first byte, fill rest with zeros for (int i = 0; i < copy_size; i++) { - *(out + i * 4) = *(in + i); - for (int j = 1; j < 4; j++) { - *(out + i * 4 + j) = '\0'; - } + ((Py_UCS4 *)out)[i] = ((Py_UCS1 *)in)[i]; } // fill all remaining UCS4 characters with zeros for (int i = copy_size; i < out_size; i++) { - for (int j = 0; j < 4; j++) { - *(out + i * 4 + j) = '\0'; - } + ((Py_UCS4 *)out)[i] = (Py_UCS1)0; } in += in_stride; out += out_stride; From 7b883218f9ba3a1f37fb30fa06760ea342f9ea8a Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Wed, 14 Dec 2022 11:15:41 -0700 Subject: [PATCH 12/13] simplify unicode to ascii cast --- asciidtype/asciidtype/src/casts.c | 26 ++++---------------------- 1 file changed, 4 insertions(+), 22 deletions(-) diff --git a/asciidtype/asciidtype/src/casts.c b/asciidtype/asciidtype/src/casts.c index 6175d261..5ba907b8 100644 --- a/asciidtype/asciidtype/src/casts.c +++ b/asciidtype/asciidtype/src/casts.c @@ -106,24 +106,6 @@ unicode_to_ascii_resolve_descriptors(PyObject *NPY_UNUSED(self), return NPY_UNSAFE_CASTING; } -static int -ucs4_character_is_ascii(char *buffer) -{ - unsigned char first_char = buffer[0]; - - if (first_char > 127) { - return -1; - } - - for (int i = 1; i < 4; i++) { - if (buffer[i] != 0) { - return -1; - } - } - - return 0; -} - static int unicode_to_ascii(PyArrayMethod_Context *context, char *const data[], npy_intp const dimensions[], npy_intp const strides[], @@ -151,15 +133,15 @@ unicode_to_ascii(PyArrayMethod_Context *context, char *const data[], // copy input characters, checking that input UCS4 // characters are all ascii, raising an error otherwise for (int i = 0; i < copy_size; i++) { - if (ucs4_character_is_ascii(in) == -1) { + Py_UCS4 c = ((Py_UCS4 *)in)[i]; + if (c > 127) { PyErr_SetString( PyExc_TypeError, "Can only store ASCII text in a ASCIIDType array."); return -1; } - // UCS4 character is ascii, so copy first byte of character - // into output, ignoring the rest - *(out + i) = *(in + i * 4); + // UCS4 character is ascii, so casting to Py_UCS1 does not truncate + out[i] = (Py_UCS1)c; } // write zeros to remaining ASCII characters (if any) for (int i = copy_size; i < out_size; i++) { From dae600b5fae5922de361e16441fc95f0dc9d5207 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Wed, 14 Dec 2022 11:16:14 -0700 Subject: [PATCH 13/13] don't use NPY_METH_REQUIRES_PYAPI --- asciidtype/asciidtype/src/casts.c | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/asciidtype/asciidtype/src/casts.c b/asciidtype/asciidtype/src/casts.c index 5ba907b8..dc504b9a 100644 --- a/asciidtype/asciidtype/src/casts.c +++ b/asciidtype/asciidtype/src/casts.c @@ -269,9 +269,8 @@ get_casts(void) UnicodeToASCIICastSpec->name = u2a_name; UnicodeToASCIICastSpec->nin = 1; UnicodeToASCIICastSpec->nout = 1; - UnicodeToASCIICastSpec->casting = NPY_UNSAFE_CASTING, - UnicodeToASCIICastSpec->flags = - (NPY_METH_NO_FLOATINGPOINT_ERRORS | NPY_METH_REQUIRES_PYAPI); + UnicodeToASCIICastSpec->casting = NPY_UNSAFE_CASTING; + UnicodeToASCIICastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; UnicodeToASCIICastSpec->dtypes = u2a_dtypes; UnicodeToASCIICastSpec->slots = u2a_slots; @@ -285,7 +284,7 @@ get_casts(void) ASCIIToUnicodeCastSpec->name = a2u_name; ASCIIToUnicodeCastSpec->nin = 1; ASCIIToUnicodeCastSpec->nout = 1; - ASCIIToUnicodeCastSpec->casting = NPY_UNSAFE_CASTING, + ASCIIToUnicodeCastSpec->casting = NPY_UNSAFE_CASTING; ASCIIToUnicodeCastSpec->flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; ASCIIToUnicodeCastSpec->dtypes = a2u_dtypes; ASCIIToUnicodeCastSpec->slots = a2u_slots;