diff --git a/numpy/core/_asarray.py b/numpy/core/_asarray.py index cbaab8c3f960..a9abc5a88ca3 100644 --- a/numpy/core/_asarray.py +++ b/numpy/core/_asarray.py @@ -24,10 +24,6 @@ } -def _require_dispatcher(a, dtype=None, requirements=None, *, like=None): - return (like,) - - @set_array_function_like_doc @set_module('numpy') def require(a, dtype=None, requirements=None, *, like=None): @@ -100,10 +96,10 @@ def require(a, dtype=None, requirements=None, *, like=None): """ if like is not None: return _require_with_like( + like, a, dtype=dtype, requirements=requirements, - like=like, ) if not requirements: @@ -135,6 +131,4 @@ def require(a, dtype=None, requirements=None, *, like=None): return arr -_require_with_like = array_function_dispatch( - _require_dispatcher -)(require) +_require_with_like = array_function_dispatch()(require) diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 577f8e7cd53c..91e35c684f65 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -130,10 +130,6 @@ def zeros_like(a, dtype=None, order='K', subok=True, shape=None): return res -def _ones_dispatcher(shape, dtype=None, order=None, *, like=None): - return(like,) - - @set_array_function_like_doc @set_module('numpy') def ones(shape, dtype=None, order='C', *, like=None): @@ -187,16 +183,13 @@ def ones(shape, dtype=None, order='C', *, like=None): """ if like is not None: - return _ones_with_like(shape, dtype=dtype, order=order, like=like) + return _ones_with_like(like, shape, dtype=dtype, order=order) a = empty(shape, dtype, order) multiarray.copyto(a, 1, casting='unsafe') return a - -_ones_with_like = array_function_dispatch( - _ones_dispatcher -)(ones) +_ones_with_like = array_function_dispatch()(ones) def _ones_like_dispatcher(a, dtype=None, order=None, subok=None, shape=None): @@ -323,7 +316,7 @@ def full(shape, fill_value, dtype=None, order='C', *, like=None): """ if like is not None: - return _full_with_like(shape, fill_value, dtype=dtype, order=order, like=like) + return _full_with_like(like, shape, fill_value, dtype=dtype, order=order) if dtype is None: fill_value = asarray(fill_value) @@ -333,9 +326,7 @@ def full(shape, fill_value, dtype=None, order='C', *, like=None): return a -_full_with_like = array_function_dispatch( - _full_dispatcher -)(full) +_full_with_like = array_function_dispatch()(full) def _full_like_dispatcher(a, fill_value, dtype=None, order=None, subok=None, shape=None): @@ -1778,10 +1769,6 @@ def indices(dimensions, dtype=int, sparse=False): return res -def _fromfunction_dispatcher(function, shape, *, dtype=None, like=None, **kwargs): - return (like,) - - @set_array_function_like_doc @set_module('numpy') def fromfunction(function, shape, *, dtype=float, like=None, **kwargs): @@ -1847,15 +1834,13 @@ def fromfunction(function, shape, *, dtype=float, like=None, **kwargs): """ if like is not None: - return _fromfunction_with_like(function, shape, dtype=dtype, like=like, **kwargs) + return _fromfunction_with_like(like, function, shape, dtype=dtype, **kwargs) args = indices(shape, dtype=dtype) return function(*args, **kwargs) -_fromfunction_with_like = array_function_dispatch( - _fromfunction_dispatcher -)(fromfunction) +_fromfunction_with_like = array_function_dispatch()(fromfunction) def _frombuffer(buf, dtype, shape, order): @@ -2130,10 +2115,6 @@ def _maketup(descr, val): return tuple(res) -def _identity_dispatcher(n, dtype=None, *, like=None): - return (like,) - - @set_array_function_like_doc @set_module('numpy') def identity(n, dtype=None, *, like=None): @@ -2168,15 +2149,13 @@ def identity(n, dtype=None, *, like=None): """ if like is not None: - return _identity_with_like(n, dtype=dtype, like=like) + return _identity_with_like(like, n, dtype=dtype) from numpy import eye return eye(n, dtype=dtype, like=like) -_identity_with_like = array_function_dispatch( - _identity_dispatcher -)(identity) +_identity_with_like = array_function_dispatch()(identity) def _allclose_dispatcher(a, b, rtol=None, atol=None, equal_nan=None): diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 46e1fbe2c0fa..25892d5de458 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -6,7 +6,7 @@ from .._utils import set_module from .._utils._inspect import getargspec from numpy.core._multiarray_umath import ( - add_docstring, implement_array_function, _get_implementing_args) + add_docstring, _get_implementing_args, _ArrayFunctionDispatcher) ARRAY_FUNCTIONS = set() @@ -33,40 +33,33 @@ def set_array_function_like_doc(public_api): add_docstring( - implement_array_function, + _ArrayFunctionDispatcher, """ - Implement a function with checks for __array_function__ overrides. + Class to wrap functions with checks for __array_function__ overrides. All arguments are required, and can only be passed by position. Parameters ---------- + dispatcher : function or None + The dispatcher function that returns a single sequence-like object + of all arguments relevant. It must have the same signature (except + the default values) as the actual implementation. + If ``None``, this is a ``like=`` dispatcher and the + ``_ArrayFunctionDispatcher`` must be called with ``like`` as the + first (additional and positional) argument. implementation : function Function that implements the operation on NumPy array without - overrides when called like ``implementation(*args, **kwargs)``. - public_api : function - Function exposed by NumPy's public API originally called like - ``public_api(*args, **kwargs)`` on which arguments are now being - checked. - relevant_args : iterable - Iterable of arguments to check for __array_function__ methods. - args : tuple - Arbitrary positional arguments originally passed into ``public_api``. - kwargs : dict - Arbitrary keyword arguments originally passed into ``public_api``. + overrides when called like. - Returns - ------- - Result from calling ``implementation()`` or an ``__array_function__`` - method, as appropriate. - - Raises - ------ - TypeError : if no implementation is found. + Attributes + ---------- + _implementation : function + The original implementation passed in. """) -# exposed for testing purposes; used internally by implement_array_function +# exposed for testing purposes; used internally by _ArrayFunctionDispatcher add_docstring( _get_implementing_args, """ @@ -110,7 +103,7 @@ def verify_matching_signatures(implementation, dispatcher): 'default argument values') -def array_function_dispatch(dispatcher, module=None, verify=True, +def array_function_dispatch(dispatcher=None, module=None, verify=True, docs_from_dispatcher=False): """Decorator for adding dispatch with the __array_function__ protocol. @@ -118,10 +111,14 @@ def array_function_dispatch(dispatcher, module=None, verify=True, Parameters ---------- - dispatcher : callable + dispatcher : callable or None Function that when called like ``dispatcher(*args, **kwargs)`` with arguments from the NumPy function call returns an iterable of array-like arguments to check for ``__array_function__``. + + If `None`, the first argument is used as the single `like=` argument + and not passed on. A function implementing `like=` must call its + dispatcher with `like` as the first non-keyword argument. module : str, optional __module__ attribute to set on new function, e.g., ``module='numpy'``. By default, module is copied from the decorated function. @@ -154,45 +151,28 @@ def decorator(implementation): def decorator(implementation): if verify: - verify_matching_signatures(implementation, dispatcher) + if dispatcher is not None: + verify_matching_signatures(implementation, dispatcher) + else: + # Using __code__ directly similar to verify_matching_signature + co = implementation.__code__ + last_arg = co.co_argcount + co.co_kwonlyargcount - 1 + last_arg = co.co_varnames[last_arg] + if last_arg != "like" or co.co_kwonlyargcount == 0: + raise RuntimeError( + "__array_function__ expects `like=` to be the last " + "argument and a keyword-only argument. " + f"{implementation} does not seem to comply.") if docs_from_dispatcher: add_docstring(implementation, dispatcher.__doc__) - @functools.wraps(implementation) - def public_api(*args, **kwargs): - try: - relevant_args = dispatcher(*args, **kwargs) - except TypeError as exc: - # Try to clean up a signature related TypeError. Such an - # error will be something like: - # dispatcher.__name__() got an unexpected keyword argument - # - # So replace the dispatcher name in this case. In principle - # TypeErrors may be raised from _within_ the dispatcher, so - # we check that the traceback contains a string that starts - # with the name. (In principle we could also check the - # traceback length, as it would be deeper.) - msg = exc.args[0] - disp_name = dispatcher.__name__ - if not isinstance(msg, str) or not msg.startswith(disp_name): - raise - - # Replace with the correct name and re-raise: - new_msg = msg.replace(disp_name, public_api.__name__) - raise TypeError(new_msg) from None - - return implement_array_function( - implementation, public_api, relevant_args, args, kwargs) - - public_api.__code__ = public_api.__code__.replace( - co_name=implementation.__name__, - co_filename='<__array_function__ internals>') + public_api = _ArrayFunctionDispatcher(dispatcher, implementation) + public_api = functools.wraps(implementation)(public_api) + if module is not None: public_api.__module__ = module - public_api._implementation = implementation - ARRAY_FUNCTIONS.add(public_api) return public_api diff --git a/numpy/core/src/multiarray/arrayfunction_override.c b/numpy/core/src/multiarray/arrayfunction_override.c index 2bb3fbe2848b..e27bb516ecda 100644 --- a/numpy/core/src/multiarray/arrayfunction_override.c +++ b/numpy/core/src/multiarray/arrayfunction_override.c @@ -1,11 +1,15 @@ #define NPY_NO_DEPRECATED_API NPY_API_VERSION #define _MULTIARRAYMODULE +#include +#include "structmember.h" + #include "npy_pycompat.h" #include "get_attr_string.h" #include "npy_import.h" #include "multiarraymodule.h" +#include "arrayfunction_override.h" /* Return the ndarray.__array_function__ method. */ static PyObject * @@ -200,183 +204,67 @@ call_array_function(PyObject* argument, PyObject* method, } -/** - * Internal handler for the array-function dispatching. The helper returns - * either the result, or NotImplemented (as a borrowed reference). - * - * @param public_api The public API symbol used for dispatching - * @param relevant_args Arguments which may implement __array_function__ - * @param args Original arguments - * @param kwargs Original keyword arguments - * - * @returns The result of the dispatched version, or a borrowed reference - * to NotImplemented to indicate the default implementation should - * be used. + +/* + * Helper to convert from vectorcall convention, since the protocol requires + * args and kwargs to be passed as tuple and dict explicitly. + * We always pass a dict, so always returns it. */ -static PyObject * -array_implement_array_function_internal( - PyObject *public_api, PyObject *relevant_args, - PyObject *args, PyObject *kwargs) +static int +get_args_and_kwargs( + PyObject *const *fast_args, Py_ssize_t len_args, PyObject *kwnames, + PyObject **out_args, PyObject **out_kwargs) { - PyObject *implementing_args[NPY_MAXARGS]; - PyObject *array_function_methods[NPY_MAXARGS]; - PyObject *types = NULL; - - PyObject *result = NULL; - - static PyObject *errmsg_formatter = NULL; + len_args = PyVectorcall_NARGS(len_args); + PyObject *args = PyTuple_New(len_args); + PyObject *kwargs = NULL; - relevant_args = PySequence_Fast( - relevant_args, - "dispatcher for __array_function__ did not return an iterable"); - if (relevant_args == NULL) { - return NULL; + if (args == NULL) { + return -1; } - - /* Collect __array_function__ implementations */ - int num_implementing_args = get_implementing_args_and_methods( - relevant_args, implementing_args, array_function_methods); - if (num_implementing_args == -1) { - goto cleanup; + for (Py_ssize_t i = 0; i < len_args; i++) { + Py_INCREF(fast_args[i]); + PyTuple_SET_ITEM(args, i, fast_args[i]); } - - /* - * Handle the typical case of no overrides. This is merely an optimization - * if some arguments are ndarray objects, but is also necessary if no - * arguments implement __array_function__ at all (e.g., if they are all - * built-in types). - */ - int any_overrides = 0; - for (int j = 0; j < num_implementing_args; j++) { - if (!is_default_array_function(array_function_methods[j])) { - any_overrides = 1; - break; - } - } - if (!any_overrides) { - /* - * When the default implementation should be called, return - * `Py_NotImplemented` to indicate this. - */ - result = Py_NotImplemented; - goto cleanup; - } - - /* - * Create a Python object for types. - * We use a tuple, because it's the fastest Python collection to create - * and has the bonus of being immutable. - */ - types = PyTuple_New(num_implementing_args); - if (types == NULL) { - goto cleanup; - } - for (int j = 0; j < num_implementing_args; j++) { - PyObject *arg_type = (PyObject *)Py_TYPE(implementing_args[j]); - Py_INCREF(arg_type); - PyTuple_SET_ITEM(types, j, arg_type); - } - - /* Call __array_function__ methods */ - for (int j = 0; j < num_implementing_args; j++) { - PyObject *argument = implementing_args[j]; - PyObject *method = array_function_methods[j]; - - /* - * We use `public_api` instead of `implementation` here so - * __array_function__ implementations can do equality/identity - * comparisons. - */ - result = call_array_function( - argument, method, public_api, types, args, kwargs); - - if (result == Py_NotImplemented) { - /* Try the next one */ - Py_DECREF(result); - result = NULL; + kwargs = PyDict_New(); + if (kwnames != NULL) { + if (kwargs == NULL) { + Py_DECREF(args); + return -1; } - else { - /* Either a good result, or an exception was raised. */ - goto cleanup; + Py_ssize_t nkwargs = PyTuple_GET_SIZE(kwnames); + for (Py_ssize_t i = 0; i < nkwargs; i++) { + PyObject *key = PyTuple_GET_ITEM(kwnames, i); + PyObject *value = fast_args[i+len_args]; + if (PyDict_SetItem(kwargs, key, value) < 0) { + Py_DECREF(args); + Py_DECREF(kwargs); + return -1; + } } } + *out_args = args; + *out_kwargs = kwargs; + return 0; +} + +static void +set_no_matching_types_error(PyObject *public_api, PyObject *types) +{ + static PyObject *errmsg_formatter = NULL; /* No acceptable override found, raise TypeError. */ npy_cache_import("numpy.core._internal", "array_function_errmsg_formatter", &errmsg_formatter); if (errmsg_formatter != NULL) { PyObject *errmsg = PyObject_CallFunctionObjArgs( - errmsg_formatter, public_api, types, NULL); + errmsg_formatter, public_api, types, NULL); if (errmsg != NULL) { PyErr_SetObject(PyExc_TypeError, errmsg); Py_DECREF(errmsg); } } - -cleanup: - for (int j = 0; j < num_implementing_args; j++) { - Py_DECREF(implementing_args[j]); - Py_DECREF(array_function_methods[j]); - } - Py_XDECREF(types); - Py_DECREF(relevant_args); - return result; -} - - -/* - * Implements the __array_function__ protocol for a Python function, as described in - * in NEP-18. See numpy.core.overrides for a full docstring. - */ -NPY_NO_EXPORT PyObject * -array_implement_array_function( - PyObject *NPY_UNUSED(dummy), PyObject *positional_args) -{ - PyObject *res, *implementation, *public_api, *relevant_args, *args, *kwargs; - - if (!PyArg_UnpackTuple( - positional_args, "implement_array_function", 5, 5, - &implementation, &public_api, &relevant_args, &args, &kwargs)) { - return NULL; - } - - /* - * Remove `like=` kwarg, which is NumPy-exclusive and thus not present - * in downstream libraries. If `like=` is specified but doesn't - * implement `__array_function__`, raise a `TypeError`. - */ - if (kwargs != NULL && PyDict_Contains(kwargs, npy_ma_str_like)) { - PyObject *like_arg = PyDict_GetItem(kwargs, npy_ma_str_like); - if (like_arg != NULL) { - PyObject *tmp_has_override = get_array_function(like_arg); - if (tmp_has_override == NULL) { - return PyErr_Format(PyExc_TypeError, - "The `like` argument must be an array-like that " - "implements the `__array_function__` protocol."); - } - Py_DECREF(tmp_has_override); - PyDict_DelItem(kwargs, npy_ma_str_like); - - /* - * If `like=` kwarg was removed, `implementation` points to the NumPy - * public API, as `public_api` is in that case the wrapper dispatcher - * function. For example, in the `np.full` case, `implementation` is - * `np.full`, whereas `public_api` is `_full_with_like`. This is done - * to ensure `__array_function__` implementations can do - * equality/identity comparisons when `like=` is present. - */ - public_api = implementation; - } - } - - res = array_implement_array_function_internal( - public_api, relevant_args, args, kwargs); - - if (res == Py_NotImplemented) { - return PyObject_Call(implementation, args, kwargs); - } - return res; } /* @@ -392,64 +280,48 @@ array_implement_c_array_function_creation( PyObject *args, PyObject *kwargs, PyObject *const *fast_args, Py_ssize_t len_args, PyObject *kwnames) { - PyObject *relevant_args = NULL; + PyObject *dispatch_types = NULL; PyObject *numpy_module = NULL; PyObject *public_api = NULL; PyObject *result = NULL; /* If `like` doesn't implement `__array_function__`, raise a `TypeError` */ - PyObject *tmp_has_override = get_array_function(like); - if (tmp_has_override == NULL) { + PyObject *method = get_array_function(like); + if (method == NULL) { return PyErr_Format(PyExc_TypeError, "The `like` argument must be an array-like that " "implements the `__array_function__` protocol."); } - Py_DECREF(tmp_has_override); - - if (fast_args != NULL) { + if (is_default_array_function(method)) { /* - * Convert from vectorcall convention, since the protocol requires - * the normal convention. We have to do this late to ensure the - * normal path where NotImplemented is returned is fast. + * Return a borrowed reference of Py_NotImplemented to defer back to + * the original function. */ + Py_DECREF(method); + return Py_NotImplemented; + } + + dispatch_types = PyTuple_Pack(1, Py_TYPE(like)); + if (dispatch_types == NULL) { + goto finish; + } + + /* We have to call __array_function__ properly, which needs some prep */ + if (fast_args != NULL) { assert(args == NULL); assert(kwargs == NULL); - args = PyTuple_New(len_args); - if (args == NULL) { - return NULL; - } - for (Py_ssize_t i = 0; i < len_args; i++) { - Py_INCREF(fast_args[i]); - PyTuple_SET_ITEM(args, i, fast_args[i]); - } - if (kwnames != NULL) { - kwargs = PyDict_New(); - if (kwargs == NULL) { - Py_DECREF(args); - return NULL; - } - Py_ssize_t nkwargs = PyTuple_GET_SIZE(kwnames); - for (Py_ssize_t i = 0; i < nkwargs; i++) { - PyObject *key = PyTuple_GET_ITEM(kwnames, i); - PyObject *value = fast_args[i+len_args]; - if (PyDict_SetItem(kwargs, key, value) < 0) { - Py_DECREF(args); - Py_DECREF(kwargs); - return NULL; - } - } + if (get_args_and_kwargs( + fast_args, len_args, kwnames, &args, &kwargs) < 0) { + goto finish; } } - relevant_args = PyTuple_Pack(1, like); - if (relevant_args == NULL) { - goto finish; - } /* The like argument must be present in the keyword arguments, remove it */ if (PyDict_DelItem(kwargs, npy_ma_str_like) < 0) { goto finish; } + /* Fetch the actual symbol (the long way right now) */ numpy_module = PyImport_Import(npy_ma_str_numpy); if (numpy_module == NULL) { goto finish; @@ -466,16 +338,20 @@ array_implement_c_array_function_creation( goto finish; } - result = array_implement_array_function_internal( - public_api, relevant_args, args, kwargs); + result = call_array_function(like, method, + public_api, dispatch_types, args, kwargs); - finish: - if (kwnames != NULL) { - /* args and kwargs were converted from vectorcall convention */ - Py_XDECREF(args); - Py_XDECREF(kwargs); + if (result == Py_NotImplemented) { + Py_DECREF(result); + result = NULL; + set_no_matching_types_error(public_api, dispatch_types); } - Py_XDECREF(relevant_args); + + finish: + Py_DECREF(method); + Py_XDECREF(args); + Py_XDECREF(kwargs); + Py_XDECREF(dispatch_types); Py_XDECREF(public_api); return result; } @@ -530,3 +406,275 @@ array__get_implementing_args( Py_DECREF(relevant_args); return result; } + + +typedef struct { + PyObject_HEAD + vectorcallfunc vectorcall; + PyObject *dict; + PyObject *relevant_arg_func; + PyObject *default_impl; +} PyArray_ArrayFunctionDispatcherObject; + + +static void +dispatcher_dealloc(PyArray_ArrayFunctionDispatcherObject *self) +{ + Py_CLEAR(self->relevant_arg_func); + Py_CLEAR(self->default_impl); + Py_CLEAR(self->dict); + PyObject_FREE(self); +} + + +static PyObject * +dispatcher_vectorcall(PyArray_ArrayFunctionDispatcherObject *self, + PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames) +{ + PyObject *result = NULL; + PyObject *types = NULL; + PyObject *relevant_args = NULL; + + PyObject *public_api; + + /* __array_function__ passes args, kwargs. These may be filled: */ + PyObject *packed_args = NULL; + PyObject *packed_kwargs = NULL; + + PyObject *implementing_args[NPY_MAXARGS]; + PyObject *array_function_methods[NPY_MAXARGS]; + + int num_implementing_args; + + if (self->relevant_arg_func != NULL) { + public_api = (PyObject *)self; + + /* Typical path, need to call the relevant_arg_func and unpack them */ + relevant_args = PyObject_Vectorcall( + self->relevant_arg_func, args, len_args, kwnames); + if (relevant_args == NULL) { + return NULL; + } + Py_SETREF(relevant_args, PySequence_Fast(relevant_args, + "dispatcher for __array_function__ did not return an iterable")); + if (relevant_args == NULL) { + return NULL; + } + + num_implementing_args = get_implementing_args_and_methods( + relevant_args, implementing_args, array_function_methods); + if (num_implementing_args < 0) { + Py_DECREF(relevant_args); + return NULL; + } + } + else { + /* For like= dispatching from Python, the public_symbol is the impl */ + public_api = self->default_impl; + + /* + * We are dealing with `like=` from Python. For simplicity, the + * Python code passes it on as the first argument. + */ + if (PyVectorcall_NARGS(len_args) == 0) { + PyErr_Format(PyExc_TypeError, + "`like` argument dispatching, but first argument is not " + "positional in call to %S.", self->default_impl); + return NULL; + } + + array_function_methods[0] = get_array_function(args[0]); + if (array_function_methods[0] == NULL) { + return PyErr_Format(PyExc_TypeError, + "The `like` argument must be an array-like that " + "implements the `__array_function__` protocol."); + } + num_implementing_args = 1; + implementing_args[0] = args[0]; + Py_INCREF(implementing_args[0]); + + /* do not pass the like argument */ + len_args = PyVectorcall_NARGS(len_args) - 1; + len_args |= PY_VECTORCALL_ARGUMENTS_OFFSET; + args++; + } + + /* + * Handle the typical case of no overrides. This is merely an optimization + * if some arguments are ndarray objects, but is also necessary if no + * arguments implement __array_function__ at all (e.g., if they are all + * built-in types). + */ + int any_overrides = 0; + for (int j = 0; j < num_implementing_args; j++) { + if (!is_default_array_function(array_function_methods[j])) { + any_overrides = 1; + break; + } + } + if (!any_overrides) { + /* Directly call the actual implementation. */ + result = PyObject_Vectorcall(self->default_impl, args, len_args, kwnames); + goto cleanup; + } + + /* Find args and kwargs as tuple and dict, as we pass them out: */ + if (get_args_and_kwargs( + args, len_args, kwnames, &packed_args, &packed_kwargs) < 0) { + goto cleanup; + } + + /* + * Create a Python object for types. + * We use a tuple, because it's the fastest Python collection to create + * and has the bonus of being immutable. + */ + types = PyTuple_New(num_implementing_args); + if (types == NULL) { + goto cleanup; + } + for (int j = 0; j < num_implementing_args; j++) { + PyObject *arg_type = (PyObject *)Py_TYPE(implementing_args[j]); + Py_INCREF(arg_type); + PyTuple_SET_ITEM(types, j, arg_type); + } + + /* Call __array_function__ methods */ + for (int j = 0; j < num_implementing_args; j++) { + PyObject *argument = implementing_args[j]; + PyObject *method = array_function_methods[j]; + + result = call_array_function( + argument, method, public_api, types, + packed_args, packed_kwargs); + + if (result == Py_NotImplemented) { + /* Try the next one */ + Py_DECREF(result); + result = NULL; + } + else { + /* Either a good result, or an exception was raised. */ + goto cleanup; + } + } + + set_no_matching_types_error(public_api, types); + +cleanup: + for (int j = 0; j < num_implementing_args; j++) { + Py_DECREF(implementing_args[j]); + Py_DECREF(array_function_methods[j]); + } + Py_XDECREF(packed_args); + Py_XDECREF(packed_kwargs); + Py_XDECREF(types); + Py_XDECREF(relevant_args); + return result; +} + + +static PyObject * +dispatcher_new(PyTypeObject *NPY_UNUSED(cls), PyObject *args, PyObject *kwargs) +{ + PyArray_ArrayFunctionDispatcherObject *self; + + self = PyObject_New( + PyArray_ArrayFunctionDispatcherObject, + &PyArrayFunctionDispatcher_Type); + if (self == NULL) { + return PyErr_NoMemory(); + } + + char *kwlist[] = {"", "", NULL}; + if (!PyArg_ParseTupleAndKeywords( + args, kwargs, "OO:_ArrayFunctionDispatcher", kwlist, + &self->relevant_arg_func, &self->default_impl)) { + Py_DECREF(self); + return NULL; + } + + self->vectorcall = (vectorcallfunc)dispatcher_vectorcall; + if (self->relevant_arg_func == Py_None) { + /* NULL in the relevant arg function means we use `like=` */ + Py_CLEAR(self->relevant_arg_func); + } + else { + Py_INCREF(self->relevant_arg_func); + } + Py_INCREF(self->default_impl); + + /* Need to be like a Python function that has arbitrary attributes */ + self->dict = PyDict_New(); + if (self->dict == NULL) { + Py_DECREF(self); + return NULL; + } + return (PyObject *)self; +} + + +static PyObject * +dispatcher_str(PyArray_ArrayFunctionDispatcherObject *self) +{ + return PyObject_Str(self->default_impl); +} + + +static PyObject * +dispatcher_repr(PyObject *self) +{ + PyObject *name = PyObject_GetAttrString(self, "__name__"); + if (name == NULL) { + return NULL; + } + /* Print like a normal function */ + return PyUnicode_FromFormat("", name, self); +} + +static PyObject * +dispatcher_get_implementation( + PyArray_ArrayFunctionDispatcherObject *self, void *NPY_UNUSED(closure)) +{ + Py_INCREF(self->default_impl); + return self->default_impl; +} + + +static PyObject * +dispatcher_reduce(PyObject *self, PyObject *NPY_UNUSED(args)) +{ + return PyObject_GetAttrString(self, "__qualname__"); +} + + +static struct PyMethodDef func_dispatcher_methods[] = { + {"__reduce__", + (PyCFunction)dispatcher_reduce, METH_NOARGS, NULL}, + {NULL, NULL, 0, NULL} +}; + + +static struct PyGetSetDef func_dispatcher_getset[] = { + {"__dict__", &PyObject_GenericGetDict, 0, NULL, 0}, + {"_implementation", (getter)&dispatcher_get_implementation, 0, NULL, 0}, + {0, 0, 0, 0, 0} +}; + + +NPY_NO_EXPORT PyTypeObject PyArrayFunctionDispatcher_Type = { + PyVarObject_HEAD_INIT(NULL, 0) + .tp_name = "numpy._ArrayFunctionDispatcher", + .tp_basicsize = sizeof(PyArray_ArrayFunctionDispatcherObject), + /* We have a dict, so in theory could traverse, but in practice... */ + .tp_dictoffset = offsetof(PyArray_ArrayFunctionDispatcherObject, dict), + .tp_dealloc = (destructor)dispatcher_dealloc, + .tp_new = (newfunc)dispatcher_new, + .tp_str = (reprfunc)dispatcher_str, + .tp_repr = (reprfunc)dispatcher_repr, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_VECTORCALL, + .tp_methods = func_dispatcher_methods, + .tp_getset = func_dispatcher_getset, + .tp_call = &PyVectorcall_Call, + .tp_vectorcall_offset = offsetof(PyArray_ArrayFunctionDispatcherObject, vectorcall), +}; diff --git a/numpy/core/src/multiarray/arrayfunction_override.h b/numpy/core/src/multiarray/arrayfunction_override.h index 09f7ee5480ed..3b8b88bacf3f 100644 --- a/numpy/core/src/multiarray/arrayfunction_override.h +++ b/numpy/core/src/multiarray/arrayfunction_override.h @@ -1,9 +1,7 @@ #ifndef NUMPY_CORE_SRC_MULTIARRAY_ARRAYFUNCTION_OVERRIDE_H_ #define NUMPY_CORE_SRC_MULTIARRAY_ARRAYFUNCTION_OVERRIDE_H_ -NPY_NO_EXPORT PyObject * -array_implement_array_function( - PyObject *NPY_UNUSED(dummy), PyObject *positional_args); +extern NPY_NO_EXPORT PyTypeObject PyArrayFunctionDispatcher_Type; NPY_NO_EXPORT PyObject * array__get_implementing_args( diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 94fa2a9092ac..6b1862b185d9 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -4539,9 +4539,6 @@ static struct PyMethodDef array_module_methods[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"_monotonicity", (PyCFunction)arr__monotonicity, METH_VARARGS | METH_KEYWORDS, NULL}, - {"implement_array_function", - (PyCFunction)array_implement_array_function, - METH_VARARGS, NULL}, {"interp", (PyCFunction)arr_interp, METH_VARARGS | METH_KEYWORDS, NULL}, {"interp_complex", (PyCFunction)arr_interp_complex, @@ -5112,6 +5109,12 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) { if (set_typeinfo(d) != 0) { goto err; } + if (PyType_Ready(&PyArrayFunctionDispatcher_Type) < 0) { + goto err; + } + PyDict_SetItemString( + d, "_ArrayFunctionDispatcher", + (PyObject *)&PyArrayFunctionDispatcher_Type); if (PyType_Ready(&PyArrayMethod_Type) < 0) { goto err; } diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 71d600c30039..0c1740df1766 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -760,13 +760,6 @@ def _ensure_ndmin_ndarray(a, *, ndmin: int): _loadtxt_chunksize = 50000 -def _loadtxt_dispatcher( - fname, dtype=None, comments=None, delimiter=None, - converters=None, skiprows=None, usecols=None, unpack=None, - ndmin=None, encoding=None, max_rows=None, *, like=None): - return (like,) - - def _check_nonneg_int(value, name="argument"): try: operator.index(value) @@ -1331,10 +1324,10 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, if like is not None: return _loadtxt_with_like( - fname, dtype=dtype, comments=comments, delimiter=delimiter, + like, fname, dtype=dtype, comments=comments, delimiter=delimiter, converters=converters, skiprows=skiprows, usecols=usecols, unpack=unpack, ndmin=ndmin, encoding=encoding, - max_rows=max_rows, like=like + max_rows=max_rows ) if isinstance(delimiter, bytes): @@ -1361,9 +1354,7 @@ def loadtxt(fname, dtype=float, comments='#', delimiter=None, return arr -_loadtxt_with_like = array_function_dispatch( - _loadtxt_dispatcher -)(loadtxt) +_loadtxt_with_like = array_function_dispatch()(loadtxt) def _savetxt_dispatcher(fname, X, fmt=None, delimiter=None, newline=None, @@ -1724,17 +1715,6 @@ def fromregex(file, regexp, dtype, encoding=None): #####-------------------------------------------------------------------------- -def _genfromtxt_dispatcher(fname, dtype=None, comments=None, delimiter=None, - skip_header=None, skip_footer=None, converters=None, - missing_values=None, filling_values=None, usecols=None, - names=None, excludelist=None, deletechars=None, - replace_space=None, autostrip=None, case_sensitive=None, - defaultfmt=None, unpack=None, usemask=None, loose=None, - invalid_raise=None, max_rows=None, encoding=None, - *, ndmin=None, like=None): - return (like,) - - @set_array_function_like_doc @set_module('numpy') def genfromtxt(fname, dtype=float, comments='#', delimiter=None, @@ -1932,7 +1912,7 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, if like is not None: return _genfromtxt_with_like( - fname, dtype=dtype, comments=comments, delimiter=delimiter, + like, fname, dtype=dtype, comments=comments, delimiter=delimiter, skip_header=skip_header, skip_footer=skip_footer, converters=converters, missing_values=missing_values, filling_values=filling_values, usecols=usecols, names=names, @@ -1942,7 +1922,6 @@ def genfromtxt(fname, dtype=float, comments='#', delimiter=None, unpack=unpack, usemask=usemask, loose=loose, invalid_raise=invalid_raise, max_rows=max_rows, encoding=encoding, ndmin=ndmin, - like=like ) _ensure_ndmin_ndarray_check_param(ndmin) @@ -2471,9 +2450,7 @@ def encode_unicode_cols(row_tup): return output -_genfromtxt_with_like = array_function_dispatch( - _genfromtxt_dispatcher -)(genfromtxt) +_genfromtxt_with_like = array_function_dispatch()(genfromtxt) def recfromtxt(fname, **kwargs): diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py index dcb4ed46ce14..ed4f9870420b 100644 --- a/numpy/lib/twodim_base.py +++ b/numpy/lib/twodim_base.py @@ -155,10 +155,6 @@ def flipud(m): return m[::-1, ...] -def _eye_dispatcher(N, M=None, k=None, dtype=None, order=None, *, like=None): - return (like,) - - @set_array_function_like_doc @set_module('numpy') def eye(N, M=None, k=0, dtype=float, order='C', *, like=None): @@ -209,7 +205,7 @@ def eye(N, M=None, k=0, dtype=float, order='C', *, like=None): """ if like is not None: - return _eye_with_like(N, M=M, k=k, dtype=dtype, order=order, like=like) + return _eye_with_like(like, N, M=M, k=k, dtype=dtype, order=order) if M is None: M = N m = zeros((N, M), dtype=dtype, order=order) @@ -228,9 +224,7 @@ def eye(N, M=None, k=0, dtype=float, order='C', *, like=None): return m -_eye_with_like = array_function_dispatch( - _eye_dispatcher -)(eye) +_eye_with_like = array_function_dispatch()(eye) def _diag_dispatcher(v, k=None): @@ -369,10 +363,6 @@ def diagflat(v, k=0): return wrap(res) -def _tri_dispatcher(N, M=None, k=None, dtype=None, *, like=None): - return (like,) - - @set_array_function_like_doc @set_module('numpy') def tri(N, M=None, k=0, dtype=float, *, like=None): @@ -416,7 +406,7 @@ def tri(N, M=None, k=0, dtype=float, *, like=None): """ if like is not None: - return _tri_with_like(N, M=M, k=k, dtype=dtype, like=like) + return _tri_with_like(like, N, M=M, k=k, dtype=dtype) if M is None: M = N @@ -430,9 +420,7 @@ def tri(N, M=None, k=0, dtype=float, *, like=None): return m -_tri_with_like = array_function_dispatch( - _tri_dispatcher -)(tri) +_tri_with_like = array_function_dispatch()(tri) def _trilu_dispatcher(m, k=None):