Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: combine string ufuncs by passing on auxiliary data #25796

Merged
merged 7 commits into from
Feb 15, 2024
166 changes: 32 additions & 134 deletions numpy/_core/src/umath/stringdtype_ufuncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -961,48 +961,22 @@ string_intp_output_resolve_descriptors(
return NPY_NO_CASTING;
}

/*
 * Functor that wraps `buffer`/`size` in a UTF-8 Buffer, evaluates the Buffer
 * query selected at compile time by `f`, and stores the result, converted to
 * `T`, at the location `out` points to.
 *
 *   f      - which Buffer query to run (ISALPHA, ISDECIMAL, ISDIGIT,
 *            ISNUMERIC, ISSPACE or STR_LEN).
 *   T      - type written through `out`.
 *   buffer - start of the string data.
 *   size   - size passed to the Buffer constructor together with `buffer`.
 *   out    - destination, reinterpreted as `T *`.
 */
template <IMPLEMENTED_UNARY_FUNCTIONS f, typename T>
struct call_buffer_function {
    void operator()(const char *buffer, size_t size, char *out) {
        Buffer<ENCODING::UTF8> utf8((char *)buffer, size);
        T *result = (T *)out;

        if (f == IMPLEMENTED_UNARY_FUNCTIONS::ISALPHA) {
            *result = utf8.isalpha();
        }
        else if (f == IMPLEMENTED_UNARY_FUNCTIONS::ISDECIMAL) {
            *result = utf8.isdecimal();
        }
        else if (f == IMPLEMENTED_UNARY_FUNCTIONS::ISDIGIT) {
            *result = utf8.isdigit();
        }
        else if (f == IMPLEMENTED_UNARY_FUNCTIONS::ISNUMERIC) {
            *result = utf8.isnumeric();
        }
        else if (f == IMPLEMENTED_UNARY_FUNCTIONS::ISSPACE) {
            *result = utf8.isspace();
        }
        else if (f == IMPLEMENTED_UNARY_FUNCTIONS::STR_LEN) {
            *result = utf8.num_codepoints();
        }
        // Any other value of `f` leaves `*out` untouched.
    }
};


template <IMPLEMENTED_UNARY_FUNCTIONS f, const char* function_name>
typedef bool (Buffer<ENCODING::UTF8>::*utf8_buffer_method)();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, this is much simpler with the typedef; I'm glad the functor wasn't necessary in the end.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one took me by far the longest - mostly to realize that these method pointers are 16 bytes long. But, yes, once I had it, I was pleased with it!


static int
string_bool_output_unary_strided_loop(
PyArrayMethod_Context *context, char *const data[],
npy_intp const dimensions[],
npy_intp const strides[],
NpyAuxData *NPY_UNUSED(auxdata))
{
const char *ufunc_name = ((PyUFuncObject *)context->caller)->name;
utf8_buffer_method is_it = *(utf8_buffer_method *)(context->method->static_data);
PyArray_StringDTypeObject *descr = (PyArray_StringDTypeObject *)context->descriptors[0];
npy_string_allocator *allocator = NpyString_acquire_allocator(descr);
int has_string_na = descr->has_string_na;
int has_nan_na = descr->has_nan_na;
const npy_static_string *default_string = &descr->default_string;

npy_intp N = dimensions[0];
char *in = data[0];
char *out = data[1];
Expand All @@ -1015,11 +989,12 @@ string_bool_output_unary_strided_loop(
npy_static_string s = {0, NULL};
const char *buffer = NULL;
size_t size = 0;
Buffer<ENCODING::UTF8> buf;

int is_null = NpyString_load(allocator, ps, &s);

if (is_null == -1) {
npy_gil_error(PyExc_MemoryError, "Failed to load string in %s", function_name);
npy_gil_error(PyExc_MemoryError, "Failed to load string in %s", ufunc_name);
goto fail;
}

Expand All @@ -1032,7 +1007,7 @@ string_bool_output_unary_strided_loop(
else if (!has_string_na) {
npy_gil_error(PyExc_ValueError,
"Cannot use the %s function with a null that is "
"not a nan-like value", function_name);
"not a nan-like value", ufunc_name);
goto fail;
}
buffer = default_string->buf;
Expand All @@ -1042,9 +1017,8 @@ string_bool_output_unary_strided_loop(
buffer = s.buf;
size = s.size;
}

call_buffer_function<f, npy_bool> cbf;
cbf(buffer, size, out);
buf = Buffer<ENCODING::UTF8>((char *)buffer, size);
*(npy_bool *)out = (buf.*is_it)();

next_step:
in += in_stride;
Expand All @@ -1060,68 +1034,6 @@ string_bool_output_unary_strided_loop(
return -1;
}

// Name passed to the shared bool-output loop as its `function_name`
// template argument.
static const char isalpha_name[] = "isalpha";

/*
 * Strided loop for `isalpha`: passes every argument through unchanged to
 * string_bool_output_unary_strided_loop instantiated for ISALPHA and returns
 * that loop's result.
 */
static int
string_isalpha_strided_loop(PyArrayMethod_Context *context, char *const data[],
                            npy_intp const dimensions[],
                            npy_intp const strides[],
                            NpyAuxData *auxdata)
{
    const int status = string_bool_output_unary_strided_loop<
            IMPLEMENTED_UNARY_FUNCTIONS::ISALPHA, isalpha_name>(
            context, data, dimensions, strides, auxdata);
    return status;
}


// Name passed to the shared bool-output loop as its `function_name`
// template argument.
static const char isdecimal_name[] = "isdecimal";

/*
 * Strided loop for `isdecimal`: passes every argument through unchanged to
 * string_bool_output_unary_strided_loop instantiated for ISDECIMAL and
 * returns that loop's result.
 */
static int
string_isdecimal_strided_loop(PyArrayMethod_Context *context, char *const data[],
                              npy_intp const dimensions[],
                              npy_intp const strides[],
                              NpyAuxData *auxdata)
{
    const int status = string_bool_output_unary_strided_loop<
            IMPLEMENTED_UNARY_FUNCTIONS::ISDECIMAL, isdecimal_name>(
            context, data, dimensions, strides, auxdata);
    return status;
}

// Name passed to the shared bool-output loop as its `function_name`
// template argument.
static const char isdigit_name[] = "isdigit";

/*
 * Strided loop for `isdigit`: passes every argument through unchanged to
 * string_bool_output_unary_strided_loop instantiated for ISDIGIT and returns
 * that loop's result.
 */
static int
string_isdigit_strided_loop(PyArrayMethod_Context *context, char *const data[],
                            npy_intp const dimensions[],
                            npy_intp const strides[],
                            NpyAuxData *auxdata)
{
    const int status = string_bool_output_unary_strided_loop<
            IMPLEMENTED_UNARY_FUNCTIONS::ISDIGIT, isdigit_name>(
            context, data, dimensions, strides, auxdata);
    return status;
}

// Name passed to the shared bool-output loop as its `function_name`
// template argument.
static const char isnumeric_name[] = "isnumeric";

/*
 * Strided loop for `isnumeric`: passes every argument through unchanged to
 * string_bool_output_unary_strided_loop instantiated for ISNUMERIC and
 * returns that loop's result.
 */
static int
string_isnumeric_strided_loop(PyArrayMethod_Context *context, char *const data[],
                              npy_intp const dimensions[],
                              npy_intp const strides[],
                              NpyAuxData *auxdata)
{
    const int status = string_bool_output_unary_strided_loop<
            IMPLEMENTED_UNARY_FUNCTIONS::ISNUMERIC, isnumeric_name>(
            context, data, dimensions, strides, auxdata);
    return status;
}

// Name passed to the shared bool-output loop as its `function_name`
// template argument.
static const char isspace_name[] = "isspace";

/*
 * Strided loop for `isspace`: passes every argument through unchanged to
 * string_bool_output_unary_strided_loop instantiated for ISSPACE and returns
 * that loop's result.
 */
static int
string_isspace_strided_loop(PyArrayMethod_Context *context, char *const data[],
                            npy_intp const dimensions[],
                            npy_intp const strides[],
                            NpyAuxData *auxdata)
{
    const int status = string_bool_output_unary_strided_loop<
            IMPLEMENTED_UNARY_FUNCTIONS::ISSPACE, isspace_name>(
            context, data, dimensions, strides, auxdata);
    return status;
}


static int
string_strlen_strided_loop(PyArrayMethod_Context *context, char *const data[],
npy_intp const dimensions[],
Expand All @@ -1145,7 +1057,7 @@ string_strlen_strided_loop(PyArrayMethod_Context *context, char *const data[],
npy_static_string s = {0, NULL};
const char *buffer = NULL;
size_t size = 0;

Buffer<ENCODING::UTF8> buf;
int is_null = NpyString_load(allocator, ps, &s);

if (is_null == -1) {
Expand All @@ -1166,9 +1078,8 @@ string_strlen_strided_loop(PyArrayMethod_Context *context, char *const data[],
buffer = s.buf;
size = s.size;
}

call_buffer_function<IMPLEMENTED_UNARY_FUNCTIONS::STR_LEN, npy_intp> cbf;
cbf(buffer, size, out);
buf = Buffer<ENCODING::UTF8>((char *)buffer, size);
*(npy_intp *)out = buf.num_codepoints();

next_step:
in += in_stride;
Expand Down Expand Up @@ -2245,39 +2156,26 @@ init_stringdtype_ufuncs(PyObject *umath)
return -1;
}

if (init_ufunc(umath, "isalpha", bool_output_dtypes,
&string_bool_output_resolve_descriptors,
&string_isalpha_strided_loop, 1, 1, NPY_NO_CASTING,
(NPY_ARRAYMETHOD_FLAGS) 0, NULL) < 0) {
return -1;
}

if (init_ufunc(umath, "isdecimal", bool_output_dtypes,
&string_bool_output_resolve_descriptors,
&string_isdecimal_strided_loop, 1, 1, NPY_NO_CASTING,
(NPY_ARRAYMETHOD_FLAGS) 0, NULL) < 0) {
return -1;
}

if (init_ufunc(umath, "isdigit", bool_output_dtypes,
&string_bool_output_resolve_descriptors,
&string_isdigit_strided_loop, 1, 1, NPY_NO_CASTING,
(NPY_ARRAYMETHOD_FLAGS) 0, NULL) < 0) {
return -1;
}

if (init_ufunc(umath, "isnumeric", bool_output_dtypes,
&string_bool_output_resolve_descriptors,
&string_isnumeric_strided_loop, 1, 1, NPY_NO_CASTING,
(NPY_ARRAYMETHOD_FLAGS) 0, NULL) < 0) {
return -1;
}

if (init_ufunc(umath, "isspace", bool_output_dtypes,
&string_bool_output_resolve_descriptors,
&string_isspace_strided_loop, 1, 1, NPY_NO_CASTING,
(NPY_ARRAYMETHOD_FLAGS) 0, NULL) < 0) {
return -1;
const char *unary_loop_names[] = {
"isalpha", "isdecimal", "isdigit", "isnumeric", "isspace",
};
// Note: these are member function pointers, not regular function
// pointers, so we need to pass on their address, not value.
static utf8_buffer_method unary_loop_buffer_methods[] = {
&Buffer<ENCODING::UTF8>::isalpha,
&Buffer<ENCODING::UTF8>::isdecimal,
&Buffer<ENCODING::UTF8>::isdigit,
&Buffer<ENCODING::UTF8>::isnumeric,
&Buffer<ENCODING::UTF8>::isspace,
};
for (int i=0; i<5; i++) {
if (init_ufunc(umath, unary_loop_names[i], bool_output_dtypes,
&string_bool_output_resolve_descriptors,
&string_bool_output_unary_strided_loop, 1, 1, NPY_NO_CASTING,
(NPY_ARRAYMETHOD_FLAGS) 0,
&unary_loop_buffer_methods[i]) < 0) {
return -1;
}
}

PyArray_DTypeMeta *intp_output_dtypes[] = {
Expand Down