From 534c2e5431983bf83764138e81a5536ad9bae7ee Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Thu, 20 Apr 2023 14:27:37 +0200 Subject: [PATCH] MAINT: Add a proper implementation for structured zerofill This reorganizes the generic traversal functions for structured dtypes a bit (before it wasn't quite generic). Then uses that for zerofilling. We could get the two a bit closer by also supporting `func==NULL` explicitly for clearing. (I have not included this here.) The old `FillObjectArray` is still in use now, this is not ideal and the new approach should be duplicated to add an "emptyfill" (same semantics, normally just memset to 0, but for objects we place an explicit `None` object). This is a necessary follow-up to gh-23591. --- numpy/core/src/multiarray/dtype_traversal.c | 267 ++++++++++++++------ numpy/core/src/multiarray/dtype_traversal.h | 9 +- numpy/core/src/multiarray/dtypemeta.c | 3 +- numpy/core/src/multiarray/refcount.c | 7 +- numpy/core/src/multiarray/refcount.h | 3 - numpy/core/src/multiarray/usertypes.c | 5 +- 6 files changed, 202 insertions(+), 92 deletions(-) diff --git a/numpy/core/src/multiarray/dtype_traversal.c b/numpy/core/src/multiarray/dtype_traversal.c index 769c2e015d55..96a00c189a73 100644 --- a/numpy/core/src/multiarray/dtype_traversal.c +++ b/numpy/core/src/multiarray/dtype_traversal.c @@ -24,7 +24,6 @@ #include "alloc.h" #include "array_method.h" #include "dtypemeta.h" -#include "refcount.h" #include "dtype_traversal.h" @@ -32,6 +31,11 @@ #define NPY_LOWLEVEL_BUFFER_BLOCKSIZE 128 +typedef int traverse_func_get( + void *traverse_context, PyArray_Descr *dtype, int aligned, + npy_intp stride, NPY_traverse_info *clear_info, + NPY_ARRAYMETHOD_FLAGS *flags); + /* * Generic Clear function helpers: */ @@ -89,6 +93,45 @@ PyArray_GetClearFunction( } +/* + * Generic zerofill/fill function helper: + */ + +static int +get_zerofill_function( + void *traverse_context, PyArray_Descr *dtype, int aligned, + npy_intp stride, NPY_traverse_info
*zerofill_info, + NPY_ARRAYMETHOD_FLAGS *flags) +{ + NPY_traverse_info_init(zerofill_info); + /* not that filling code bothers to check e.g. for floating point flags */ + *flags = PyArrayMethod_MINIMAL_FLAGS; + + get_traverse_loop_function *get_zerofill = NPY_DT_SLOTS(NPY_DTYPE(dtype))->get_fill_zero_loop; + if (get_zerofill == NULL) { + /* Allowed to be NULL (and accept it here) */ + return 0; + } + + if (get_zerofill(traverse_context, dtype, aligned, stride, + &zerofill_info->func, &zerofill_info->auxdata, flags) < 0) { + /* callee should clean up, but make sure outside debug mode */ + assert(zerofill_info->func == NULL); + zerofill_info->func = NULL; + return -1; + } + if (zerofill_info->func == NULL) { + /* Zerofill also may return func=NULL without an error. */ + return 0; + } + + Py_INCREF(dtype); + zerofill_info->descr = dtype; + + return 0; +} + + /****************** Python Object clear ***********************/ static int @@ -157,7 +200,7 @@ npy_object_get_fill_zero_loop(void *NPY_UNUSED(traverse_context), return 0; } -/**************** Structured DType clear funcationality ***************/ +/**************** Structured DType generic functionality ***************/ /* * Note that legacy user dtypes also make use of this.
Someone managed to @@ -172,20 +215,20 @@ npy_object_get_fill_zero_loop(void *NPY_UNUSED(traverse_context), typedef struct { npy_intp src_offset; NPY_traverse_info info; -} single_field_clear_data; +} single_field_traverse_data; typedef struct { NpyAuxData base; npy_intp field_count; - single_field_clear_data fields[]; -} fields_clear_data; + single_field_traverse_data fields[]; +} fields_traverse_data; /* traverse data free function */ static void -fields_clear_data_free(NpyAuxData *data) +fields_traverse_data_free(NpyAuxData *data) { - fields_clear_data *d = (fields_clear_data *)data; + fields_traverse_data *d = (fields_traverse_data *)data; for (npy_intp i = 0; i < d->field_count; ++i) { NPY_traverse_info_xfree(&d->fields[i].info); @@ -196,16 +239,16 @@ fields_clear_data_free(NpyAuxData *data) /* traverse data copy function (untested due to no direct use currently) */ static NpyAuxData * -fields_clear_data_clone(NpyAuxData *data) +fields_traverse_data_clone(NpyAuxData *data) { - fields_clear_data *d = (fields_clear_data *)data; + fields_traverse_data *d = (fields_traverse_data *)data; npy_intp field_count = d->field_count; - npy_intp structsize = sizeof(fields_clear_data) + - field_count * sizeof(single_field_clear_data); + npy_intp structsize = sizeof(fields_traverse_data) + + field_count * sizeof(single_field_traverse_data); /* Allocate the data and populate it */ - fields_clear_data *newdata = PyMem_Malloc(structsize); + fields_traverse_data *newdata = PyMem_Malloc(structsize); if (newdata == NULL) { return NULL; } @@ -213,15 +256,15 @@ fields_clear_data_clone(NpyAuxData *data) newdata->field_count = 0; /* Copy all the fields transfer data */ - single_field_clear_data *in_field = d->fields; - single_field_clear_data *new_field = newdata->fields; + single_field_traverse_data *in_field = d->fields; + single_field_traverse_data *new_field = newdata->fields; for (; newdata->field_count < field_count; newdata->field_count++, in_field++, new_field++) { 
new_field->src_offset = in_field->src_offset; if (NPY_traverse_info_copy(&new_field->info, &in_field->info) < 0) { - fields_clear_data_free((NpyAuxData *)newdata); + fields_traverse_data_free((NpyAuxData *)newdata); return NULL; } } @@ -236,7 +279,7 @@ traverse_fields_function( char *data, npy_intp N, npy_intp stride, NpyAuxData *auxdata) { - fields_clear_data *d = (fields_clear_data *)auxdata; + fields_traverse_data *d = (fields_traverse_data *)auxdata; npy_intp i, field_count = d->field_count; /* Do the traversing a block at a time for better memory caching */ @@ -245,7 +288,7 @@ traverse_fields_function( for (;;) { if (N > blocksize) { for (i = 0; i < field_count; ++i) { - single_field_clear_data field = d->fields[i]; + single_field_traverse_data field = d->fields[i]; if (field.info.func(traverse_context, field.info.descr, data + field.src_offset, blocksize, stride, field.info.auxdata) < 0) { @@ -257,7 +300,7 @@ traverse_fields_function( } else { for (i = 0; i < field_count; ++i) { - single_field_clear_data field = d->fields[i]; + single_field_traverse_data field = d->fields[i]; if (field.info.func(traverse_context, field.info.descr, data + field.src_offset, N, stride, field.info.auxdata) < 0) { @@ -271,10 +314,11 @@ traverse_fields_function( static int -get_clear_fields_transfer_function( +get_fields_traverse_function( void *traverse_context, PyArray_Descr *dtype, int NPY_UNUSED(aligned), npy_intp stride, traverse_loop_function **out_func, - NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags) + NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags, + traverse_func_get *traverse_func_get) { PyObject *names, *key, *tup, *title; PyArray_Descr *fld_dtype; @@ -285,19 +329,19 @@ get_clear_fields_transfer_function( field_count = PyTuple_GET_SIZE(dtype->names); /* Over-allocating here: less fields may be used */ - structsize = (sizeof(fields_clear_data) + - field_count * sizeof(single_field_clear_data)); + structsize = (sizeof(fields_traverse_data) + + field_count 
* sizeof(single_field_traverse_data)); /* Allocate the data and populate it */ - fields_clear_data *data = PyMem_Malloc(structsize); + fields_traverse_data *data = PyMem_Malloc(structsize); if (data == NULL) { PyErr_NoMemory(); return -1; } - data->base.free = &fields_clear_data_free; - data->base.clone = &fields_clear_data_clone; + data->base.free = &fields_traverse_data_free; + data->base.clone = &fields_traverse_data_clone; data->field_count = 0; - single_field_clear_data *field = data->fields; + single_field_traverse_data *field = data->fields; for (i = 0; i < field_count; ++i) { int offset; @@ -307,19 +351,26 @@ get_clear_fields_transfer_function( NPY_AUXDATA_FREE((NpyAuxData *)data); return -1; } - if (PyDataType_REFCHK(fld_dtype)) { - NPY_ARRAYMETHOD_FLAGS clear_flags; - if (get_clear_function( - traverse_context, fld_dtype, 0, - stride, &field->info, &clear_flags) < 0) { - NPY_AUXDATA_FREE((NpyAuxData *)data); - return -1; - } - *flags = PyArrayMethod_COMBINED_FLAGS(*flags, clear_flags); - field->src_offset = offset; - data->field_count++; - field++; + if (traverse_func_get == &get_clear_function + && !PyDataType_REFCHK(fld_dtype)) { + /* No need to do clearing (could change to use NULL return) */ + continue; + } + NPY_ARRAYMETHOD_FLAGS clear_flags; + if (traverse_func_get( + traverse_context, fld_dtype, 0, + stride, &field->info, &clear_flags) < 0) { + NPY_AUXDATA_FREE((NpyAuxData *)data); + return -1; + } + if (field->info.func == NULL) { + /* zerofill allows NULL func as "default" memset to zero */ + continue; } + *flags = PyArrayMethod_COMBINED_FLAGS(*flags, clear_flags); + field->src_offset = offset; + data->field_count++; + field++; } *out_func = &traverse_fields_function; @@ -333,14 +384,14 @@ typedef struct { NpyAuxData base; npy_intp count; NPY_traverse_info info; -} subarray_clear_data; +} subarray_traverse_data; /* traverse data free function */ static void -subarray_clear_data_free(NpyAuxData *data) +subarray_traverse_data_free(NpyAuxData *data) 
{ - subarray_clear_data *d = (subarray_clear_data *)data; + subarray_traverse_data *d = (subarray_traverse_data *)data; NPY_traverse_info_xfree(&d->info); PyMem_Free(d); @@ -351,17 +402,17 @@ subarray_clear_data_free(NpyAuxData *data) * We seem to be neither using nor exposing this right now, so leave it NULL. * (The implementation below should be functional.) */ -#define subarray_clear_data_clone NULL +#define subarray_traverse_data_clone NULL -#ifndef subarray_clear_data_clone +#ifndef subarray_traverse_data_clone /* traverse data copy function */ static NpyAuxData * -subarray_clear_data_clone(NpyAuxData *data) +subarray_traverse_data_clone(NpyAuxData *data) { - subarray_clear_data *d = (subarray_clear_data *)data; + subarray_traverse_data *d = (subarray_traverse_data *)data; /* Allocate the data and populate it */ - subarray_clear_data *newdata = PyMem_Malloc(sizeof(subarray_clear_data)); + subarray_traverse_data *newdata = PyMem_Malloc(sizeof(subarray_traverse_data)); if (newdata == NULL) { return NULL; } @@ -384,7 +435,7 @@ traverse_subarray_func( char *data, npy_intp N, npy_intp stride, NpyAuxData *auxdata) { - subarray_clear_data *subarr_data = (subarray_clear_data *)auxdata; + subarray_traverse_data *subarr_data = (subarray_traverse_data *)auxdata; traverse_loop_function *func = subarr_data->info.func; PyArray_Descr *sub_descr = subarr_data->info.descr; @@ -404,27 +455,35 @@ traverse_subarray_func( static int -get_subarray_clear_func( +get_subarray_traverse_func( void *traverse_context, PyArray_Descr *dtype, int aligned, npy_intp size, npy_intp stride, traverse_loop_function **out_func, - NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags) + NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags, + traverse_func_get *traverse_func_get) { - subarray_clear_data *auxdata = PyMem_Malloc(sizeof(subarray_clear_data)); + subarray_traverse_data *auxdata = PyMem_Malloc(sizeof(subarray_traverse_data)); if (auxdata == NULL) { PyErr_NoMemory(); return -1; } 
auxdata->count = size; - auxdata->base.free = &subarray_clear_data_free; - auxdata->base.clone = subarray_clear_data_clone; + auxdata->base.free = &subarray_traverse_data_free; + auxdata->base.clone = subarray_traverse_data_clone; - if (get_clear_function( + if (traverse_func_get( traverse_context, dtype, aligned, dtype->elsize, &auxdata->info, flags) < 0) { PyMem_Free(auxdata); return -1; } + if (auxdata->info.func == NULL) { + /* zerofill allows func to be NULL, in which case we need not do anything */ + PyMem_Free(auxdata); + *out_func = NULL; + *out_auxdata = NULL; + return 0; + } *out_func = &traverse_subarray_func; *out_auxdata = (NpyAuxData *)auxdata; @@ -469,9 +528,9 @@ npy_get_clear_void_and_legacy_user_dtype_loop( size = PyArray_MultiplyList(shape.ptr, shape.len); npy_free_cache_dim_obj(shape); - if (get_subarray_clear_func( + if (get_subarray_traverse_func( traverse_context, dtype->subarray->base, aligned, size, stride, - out_func, out_auxdata, flags) < 0) { + out_func, out_auxdata, flags, &get_clear_function) < 0) { return -1; } @@ -479,9 +538,9 @@ npy_get_clear_void_and_legacy_user_dtype_loop( } /* If there are fields, need to do each field */ else if (PyDataType_HASFIELDS(dtype)) { - if (get_clear_fields_transfer_function( + if (get_fields_traverse_function( traverse_context, dtype, aligned, stride, - out_func, out_auxdata, flags) < 0) { + out_func, out_auxdata, flags, &get_clear_function) < 0) { return -1; } return 0; @@ -507,38 +566,86 @@ npy_get_clear_void_and_legacy_user_dtype_loop( /**************** Structured DType zero fill ***************/ + static int -fill_zero_void_with_objects_strided_loop( - void *NPY_UNUSED(traverse_context), PyArray_Descr *descr, - char *data, npy_intp size, npy_intp stride, - NpyAuxData *NPY_UNUSED(auxdata)) +zerofill_fields_function( + void *traverse_context, PyArray_Descr *descr, + char *data, npy_intp N, npy_intp stride, + NpyAuxData *auxdata) { - PyObject *zero = PyLong_FromLong(0); - while (size--) { -
_fillobject(data, zero, descr); - data += stride; + npy_intp itemsize = descr->elsize; + + /* + * TODO: We could optimize this by chunking, but since we currently memset + * each element always, just loop manually. + */ + while (N--) { + memset(data, 0, itemsize); + if (traverse_fields_function( + traverse_context, descr, data, 1, stride, auxdata) < 0) { + return -1; + } + data +=stride; } - Py_DECREF(zero); return 0; } - +/* + * Similar to other (e.g. clear) traversal loop getter, but unlike it, we + * do need to take care of zeroing out everything (in principle not gaps). + * So we add a memset before calling the actual traverse function for the + * structured path. + */ NPY_NO_EXPORT int -npy_void_get_fill_zero_loop(void *NPY_UNUSED(traverse_context), - PyArray_Descr *descr, - int NPY_UNUSED(aligned), - npy_intp NPY_UNUSED(fixed_stride), - traverse_loop_function **out_loop, - NpyAuxData **NPY_UNUSED(out_auxdata), - NPY_ARRAYMETHOD_FLAGS *flags) +npy_get_zerofill_void_and_legacy_user_dtype_loop( + void *traverse_context, PyArray_Descr *dtype, int aligned, + npy_intp stride, traverse_loop_function **out_func, + NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags) { - *flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; - if (PyDataType_REFCHK(descr)) { - *flags |= NPY_METH_REQUIRES_PYAPI; - *out_loop = &fill_zero_void_with_objects_strided_loop; + if (PyDataType_HASSUBARRAY(dtype)) { + PyArray_Dims shape = {NULL, -1}; + npy_intp size; + + if (!(PyArray_IntpConverter(dtype->subarray->shape, &shape))) { + PyErr_SetString(PyExc_ValueError, + "invalid subarray shape"); + return -1; + } + size = PyArray_MultiplyList(shape.ptr, shape.len); + npy_free_cache_dim_obj(shape); + + if (get_subarray_traverse_func( + traverse_context, dtype->subarray->base, aligned, size, stride, + out_func, out_auxdata, flags, &get_zerofill_function) < 0) { + return -1; + } + + return 0; } - else { - *out_loop = NULL; + /* If there are fields, need to do each field */ + else if 
(PyDataType_HASFIELDS(dtype)) { + if (get_fields_traverse_function( + traverse_context, dtype, aligned, stride, + out_func, out_auxdata, flags, &get_zerofill_function) < 0) { + return -1; + } + if (((fields_traverse_data *)*out_auxdata)->field_count == 0) { + /* If no field has a custom zerofill, just return NULL for zerofill */ + NPY_AUXDATA_FREE(*out_auxdata); + *out_auxdata = NULL; + *out_func = NULL; + return 0; + } + /* + * Traversal skips fields that have no custom zeroing, so we need + * to take care of it. + */ + *out_func = &zerofill_fields_function; + return 0; } + + /* Otherwise, assume there is nothing to do (user dtypes reach here) */ + *out_auxdata = NULL; + *out_func = NULL; return 0; } diff --git a/numpy/core/src/multiarray/dtype_traversal.h b/numpy/core/src/multiarray/dtype_traversal.h index a9c185382561..fc12c0f7ba20 100644 --- a/numpy/core/src/multiarray/dtype_traversal.h +++ b/numpy/core/src/multiarray/dtype_traversal.h @@ -29,11 +29,10 @@ npy_object_get_fill_zero_loop( NPY_ARRAYMETHOD_FLAGS *flags); NPY_NO_EXPORT int -npy_void_get_fill_zero_loop( - void *NPY_UNUSED(traverse_context), PyArray_Descr *descr, - int NPY_UNUSED(aligned), npy_intp NPY_UNUSED(fixed_stride), - traverse_loop_function **out_loop, NpyAuxData **NPY_UNUSED(out_auxdata), - NPY_ARRAYMETHOD_FLAGS *flags); +npy_get_zerofill_void_and_legacy_user_dtype_loop( + void *traverse_context, PyArray_Descr *dtype, int aligned, + npy_intp stride, traverse_loop_function **out_func, + NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags); /* Helper to deal with calling or nesting simple strided loops */ diff --git a/numpy/core/src/multiarray/dtypemeta.c b/numpy/core/src/multiarray/dtypemeta.c index f8c1b661700a..c833f3f6a4ee 100644 --- a/numpy/core/src/multiarray/dtypemeta.c +++ b/numpy/core/src/multiarray/dtypemeta.c @@ -893,7 +893,8 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr) void_discover_descr_from_pyobject); dt_slots->common_instance = void_common_instance;
dt_slots->ensure_canonical = void_ensure_canonical; - dt_slots->get_fill_zero_loop = npy_void_get_fill_zero_loop; + dt_slots->get_fill_zero_loop = + npy_get_zerofill_void_and_legacy_user_dtype_loop; dt_slots->get_clear_loop = npy_get_clear_void_and_legacy_user_dtype_loop; } diff --git a/numpy/core/src/multiarray/refcount.c b/numpy/core/src/multiarray/refcount.c index d200957c3630..876bb53e1655 100644 --- a/numpy/core/src/multiarray/refcount.c +++ b/numpy/core/src/multiarray/refcount.c @@ -350,6 +350,11 @@ PyArray_XDECREF(PyArrayObject *mp) return 0; } + +static void +_fillobject(char *optr, PyObject *obj, PyArray_Descr *dtype); + + /*NUMPY_API * Assumes contiguous */ @@ -392,7 +397,7 @@ PyArray_FillObjectArray(PyArrayObject *arr, PyObject *obj) } } -NPY_NO_EXPORT void +static void _fillobject(char *optr, PyObject *obj, PyArray_Descr *dtype) { if (!PyDataType_FLAGCHK(dtype, NPY_ITEM_REFCOUNT)) { diff --git a/numpy/core/src/multiarray/refcount.h b/numpy/core/src/multiarray/refcount.h index 7f39b9ca4c2f..16d34e292fe1 100644 --- a/numpy/core/src/multiarray/refcount.h +++ b/numpy/core/src/multiarray/refcount.h @@ -24,7 +24,4 @@ PyArray_XDECREF(PyArrayObject *mp); NPY_NO_EXPORT void PyArray_FillObjectArray(PyArrayObject *arr, PyObject *obj); -NPY_NO_EXPORT void -_fillobject(char *optr, PyObject *obj, PyArray_Descr *dtype); - #endif /* NUMPY_CORE_SRC_MULTIARRAY_REFCOUNT_H_ */ diff --git a/numpy/core/src/multiarray/usertypes.c b/numpy/core/src/multiarray/usertypes.c index a172343f1dc7..4e413c9b2069 100644 --- a/numpy/core/src/multiarray/usertypes.c +++ b/numpy/core/src/multiarray/usertypes.c @@ -271,10 +271,11 @@ PyArray_RegisterDataType(PyArray_Descr *descr) } if (use_void_clearimpl) { /* See comment where use_void_clearimpl is set... 
*/ - PyArray_DTypeMeta *Void = PyArray_DTypeFromTypeNum(NPY_VOID); NPY_DT_SLOTS(NPY_DTYPE(descr))->get_clear_loop = ( &npy_get_clear_void_and_legacy_user_dtype_loop); - Py_DECREF(Void); + /* Also use the void zerofill since there may be objects */ + NPY_DT_SLOTS(NPY_DTYPE(descr))->get_fill_zero_loop = ( + &npy_get_zerofill_void_and_legacy_user_dtype_loop); } return typenum;