Skip to content

Commit

Permalink
ENH: allow NEP 42 dtypes to work with np.char (#22863)
Browse files Browse the repository at this point in the history
This makes it possible for new-style NEP 42 string dtypes like ASCIIDType to work with the functions in np.char, this has leads to some mild modification (stricter behavior in bad paths).

It will only work with dtypes with a scalar that subclasses str or bytes. I also assume that you can create instances of the user dtype from python like dtype_instance = CustomDType(size_in_bytes). This is a pretty big assumption about the API of the dtype, I'm not sure offhand how I can do this more portably or more safely.

I also added a new macro, NPY_DT_is_user_defined, which checks dtype->type_num == -1, which is currently true for all custom dtypes using the experimental dtype API. This new macro is needed because NPY_DT_is_legacy will return false for np.void.

This is only tested via the user dtypes currently.
  • Loading branch information
ngoldbaum committed Jan 10, 2023
1 parent df1ee34 commit 737b064
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 43 deletions.
11 changes: 4 additions & 7 deletions doc/release/upcoming_changes/22863.new_feature.rst
@@ -1,7 +1,4 @@
String dtype instances can be created from the string abstract dtype classes
----------------------------------------------------------------------------
It is now possible to create a string dtype instance with a size without
using the string name of the dtype. For example, ``type(np.dtype('U'))(8)``
will create a dtype that is equivalent to ``np.dtype('U8')``. This feature
is most useful when writing generic code dealing with string dtype
classes.
String functions in np.char are compatible with NEP 42 custom dtypes
--------------------------------------------------------------------
Custom dtypes that represent unicode strings or byte strings can now be
passed to the string functions in np.char.
7 changes: 7 additions & 0 deletions doc/release/upcoming_changes/22963.new_feature.rst
@@ -0,0 +1,7 @@
String dtype instances can be created from the string abstract dtype classes
----------------------------------------------------------------------------
It is now possible to create a string dtype instance with a size without
using the string name of the dtype. For example, ``type(np.dtype('U'))(8)``
will create a dtype that is equivalent to ``np.dtype('U8')``. This feature
is most useful when writing generic code dealing with string dtype
classes.
77 changes: 45 additions & 32 deletions numpy/core/defchararray.py
Expand Up @@ -46,26 +46,29 @@
overrides.array_function_dispatch, module='numpy.char')


def _use_unicode(*args):
"""
Helper function for determining the output type of some string
operations.
def _is_unicode(arr):
"""Returns True if arr is a string or a string array with a dtype that
represents a unicode string, otherwise returns False.
For an operation on two ndarrays, if at least one is unicode, the
result should be unicode.
"""
for x in args:
if (isinstance(x, str) or
issubclass(numpy.asarray(x).dtype.type, unicode_)):
return unicode_
return string_
if (isinstance(arr, str) or
issubclass(numpy.asarray(arr).dtype.type, str)):
return True
return False


def _to_string_or_unicode_array(result):
def _to_string_or_unicode_array(result, output_dtype_like=None):
"""
Helper function to cast a result back into a string or unicode array
if an object array must be used as an intermediary.
Helper function to cast a result back into an array
with the appropriate dtype if an object array must be used
as an intermediary.
"""
return numpy.asarray(result.tolist())
ret = numpy.asarray(result.tolist())
dtype = getattr(output_dtype_like, 'dtype', None)
if dtype is not None:
return ret.astype(type(dtype)(_get_num_chars(ret)), copy=False)
return ret


def _clean_args(*args):
"""
Expand Down Expand Up @@ -319,9 +322,19 @@ def add(x1, x2):
arr1 = numpy.asarray(x1)
arr2 = numpy.asarray(x2)
out_size = _get_num_chars(arr1) + _get_num_chars(arr2)
dtype = _use_unicode(arr1, arr2)
return _vec_string(arr1, (dtype, out_size), '__add__', (arr2,))

if type(arr1.dtype) != type(arr2.dtype):
# Enforce this for now. The solution to it will be implement add
# as a ufunc. It never worked right on Python 3: bytes + unicode gave
# nonsense unicode + bytes errored, and unicode + object used the
# object dtype itemsize as num chars (worked on short strings).
# bytes + void worked but promoting void->bytes is dubious also.
raise TypeError(
"np.char.add() requires both arrays of the same dtype kind, but "
f"got dtypes: '{arr1.dtype}' and '{arr2.dtype}' (the few cases "
"where this used to work often lead to incorrect results).")

return _vec_string(arr1, type(arr1.dtype)(out_size), '__add__', (arr2,))

def _multiply_dispatcher(a, i):
return (a,)
Expand Down Expand Up @@ -371,7 +384,7 @@ def multiply(a, i):
raise ValueError("Can only multiply by integers")
out_size = _get_num_chars(a_arr) * max(int(i_arr.max()), 0)
return _vec_string(
a_arr, (a_arr.dtype.type, out_size), '__mul__', (i_arr,))
a_arr, type(a_arr.dtype)(out_size), '__mul__', (i_arr,))


def _mod_dispatcher(a, values):
Expand Down Expand Up @@ -403,7 +416,7 @@ def mod(a, values):
"""
return _to_string_or_unicode_array(
_vec_string(a, object_, '__mod__', (values,)))
_vec_string(a, object_, '__mod__', (values,)), a)


@array_function_dispatch(_unary_op_dispatcher)
Expand Down Expand Up @@ -499,7 +512,7 @@ def center(a, width, fillchar=' '):
if numpy.issubdtype(a_arr.dtype, numpy.string_):
fillchar = asbytes(fillchar)
return _vec_string(
a_arr, (a_arr.dtype.type, size), 'center', (width_arr, fillchar))
a_arr, type(a_arr.dtype)(size), 'center', (width_arr, fillchar))


def _count_dispatcher(a, sub, start=None, end=None):
Expand Down Expand Up @@ -723,7 +736,7 @@ def expandtabs(a, tabsize=8):
"""
return _to_string_or_unicode_array(
_vec_string(a, object_, 'expandtabs', (tabsize,)))
_vec_string(a, object_, 'expandtabs', (tabsize,)), a)


@array_function_dispatch(_count_dispatcher)
Expand Down Expand Up @@ -1043,7 +1056,7 @@ def join(sep, seq):
"""
return _to_string_or_unicode_array(
_vec_string(sep, object_, 'join', (seq,)))
_vec_string(sep, object_, 'join', (seq,)), seq)



Expand Down Expand Up @@ -1084,7 +1097,7 @@ def ljust(a, width, fillchar=' '):
if numpy.issubdtype(a_arr.dtype, numpy.string_):
fillchar = asbytes(fillchar)
return _vec_string(
a_arr, (a_arr.dtype.type, size), 'ljust', (width_arr, fillchar))
a_arr, type(a_arr.dtype)(size), 'ljust', (width_arr, fillchar))


@array_function_dispatch(_unary_op_dispatcher)
Expand Down Expand Up @@ -1218,7 +1231,7 @@ def partition(a, sep):
"""
return _to_string_or_unicode_array(
_vec_string(a, object_, 'partition', (sep,)))
_vec_string(a, object_, 'partition', (sep,)), a)


def _replace_dispatcher(a, old, new, count=None):
Expand Down Expand Up @@ -1263,8 +1276,7 @@ def replace(a, old, new, count=None):
array(['The dwash was fresh', 'Thwas was it'], dtype='<U19')
"""
return _to_string_or_unicode_array(
_vec_string(
a, object_, 'replace', [old, new] + _clean_args(count)))
_vec_string(a, object_, 'replace', [old, new] + _clean_args(count)), a)


@array_function_dispatch(_count_dispatcher)
Expand Down Expand Up @@ -1363,7 +1375,7 @@ def rjust(a, width, fillchar=' '):
if numpy.issubdtype(a_arr.dtype, numpy.string_):
fillchar = asbytes(fillchar)
return _vec_string(
a_arr, (a_arr.dtype.type, size), 'rjust', (width_arr, fillchar))
a_arr, type(a_arr.dtype)(size), 'rjust', (width_arr, fillchar))


@array_function_dispatch(_partition_dispatcher)
Expand Down Expand Up @@ -1399,7 +1411,7 @@ def rpartition(a, sep):
"""
return _to_string_or_unicode_array(
_vec_string(a, object_, 'rpartition', (sep,)))
_vec_string(a, object_, 'rpartition', (sep,)), a)


def _split_dispatcher(a, sep=None, maxsplit=None):
Expand Down Expand Up @@ -1829,7 +1841,7 @@ def zfill(a, width):
width_arr = numpy.asarray(width)
size = int(numpy.max(width_arr.flat))
return _vec_string(
a_arr, (a_arr.dtype.type, size), 'zfill', (width_arr,))
a_arr, type(a_arr.dtype)(size), 'zfill', (width_arr,))


@array_function_dispatch(_unary_op_dispatcher)
Expand Down Expand Up @@ -1864,7 +1876,7 @@ def isnumeric(a):
array([ True, False, False, False, False])
"""
if _use_unicode(a) != unicode_:
if not _is_unicode(a):
raise TypeError("isnumeric is only available for Unicode strings and arrays")
return _vec_string(a, bool_, 'isnumeric')

Expand Down Expand Up @@ -1901,8 +1913,9 @@ def isdecimal(a):
array([ True, False, False, False])
"""
if _use_unicode(a) != unicode_:
raise TypeError("isnumeric is only available for Unicode strings and arrays")
if not _is_unicode(a):
raise TypeError(
"isdecimal is only available for Unicode strings and arrays")
return _vec_string(a, bool_, 'isdecimal')


Expand Down
1 change: 1 addition & 0 deletions numpy/core/src/multiarray/dtypemeta.h
Expand Up @@ -81,6 +81,7 @@ typedef struct {
#define NPY_DT_is_legacy(dtype) (((dtype)->flags & NPY_DT_LEGACY) != 0)
#define NPY_DT_is_abstract(dtype) (((dtype)->flags & NPY_DT_ABSTRACT) != 0)
#define NPY_DT_is_parametric(dtype) (((dtype)->flags & NPY_DT_PARAMETRIC) != 0)
#define NPY_DT_is_user_defined(dtype) (((dtype)->type_num == -1))

/*
* Macros for convenient classmethod calls, since these require
Expand Down
44 changes: 40 additions & 4 deletions numpy/core/src/multiarray/multiarraymodule.c
Expand Up @@ -3785,6 +3785,34 @@ format_longfloat(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
TrimMode_LeaveOneZero, -1, -1);
}

/*
* returns 1 if array is a user-defined string dtype, sets an error and
* returns 0 otherwise
*/
static int _is_user_defined_string_array(PyArrayObject* array)
{
if (NPY_DT_is_user_defined(PyArray_DESCR(array))) {
PyTypeObject* scalar_type = NPY_DTYPE(PyArray_DESCR(array))->scalar_type;
if (PyType_IsSubtype(scalar_type, &PyBytes_Type) ||
PyType_IsSubtype(scalar_type, &PyUnicode_Type)) {
return 1;
}
else {
PyErr_SetString(
PyExc_TypeError,
"string comparisons are only allowed for dtypes with a "
"scalar type that is a subtype of str or bytes.");
return 0;
}
}
else {
PyErr_SetString(
PyExc_TypeError,
"string operation on non-string array");
return 0;
}
}


/*
* The only purpose of this function is that it allows the "rstrip".
Expand Down Expand Up @@ -3861,6 +3889,9 @@ compare_chararrays(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
else {
PyErr_SetString(PyExc_TypeError,
"comparison of non-string arrays");
Py_DECREF(newarr);
Py_DECREF(newoth);
return NULL;
}
Py_DECREF(newarr);
Py_DECREF(newoth);
Expand Down Expand Up @@ -4061,10 +4092,15 @@ _vec_string(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *NPY_UNUSED(kw
method = PyObject_GetAttr((PyObject *)&PyUnicode_Type, method_name);
}
else {
PyErr_SetString(PyExc_TypeError,
"string operation on non-string array");
Py_DECREF(type);
goto err;
if (_is_user_defined_string_array(char_array)) {
PyTypeObject* scalar_type =
NPY_DTYPE(PyArray_DESCR(char_array))->scalar_type;
method = PyObject_GetAttr((PyObject*)scalar_type, method_name);
}
else {
Py_DECREF(type);
goto err;
}
}
if (method == NULL) {
Py_DECREF(type);
Expand Down
14 changes: 14 additions & 0 deletions numpy/core/tests/test_defchararray.py
@@ -1,3 +1,5 @@
import pytest

import numpy as np
from numpy.core.multiarray import _vec_string
from numpy.testing import (
Expand Down Expand Up @@ -670,3 +672,15 @@ def test_empty_indexing():
# empty chararray instead of a chararray with a single empty string in it.
s = np.chararray((4,))
assert_(s[[]].size == 0)


@pytest.mark.parametrize(["dt1", "dt2"],
[("S", "U"), ("U", "S"), ("S", "O"), ("U", "O"),
("S", "d"), ("S", "V")])
def test_add_types(dt1, dt2):
arr1 = np.array([1234234], dtype=dt1)
# If the following fails, e.g. use a number and test "V" explicitly
arr2 = np.array([b"423"], dtype=dt2)
with pytest.raises(TypeError,
match=f".*same dtype kind.*{arr1.dtype}.*{arr2.dtype}"):
np.char.add(arr1, arr2)

0 comments on commit 737b064

Please sign in to comment.