From 56536b10c5677b8b64c0f2f6da3885332672d7a1 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Wed, 23 Feb 2022 14:25:50 -0800 Subject: [PATCH] BUG: Fix numba DUFuncs added loops getting picked up It was always my intention to do this: If no loop is found and we go into the legacy ufunc path and legacy resolving works, we need to double check that the ufunc was not mutated. Normal operation never mutates ufuncs (its really not meant to be) but numbas DUFuncs need to do it (they compile loops dynamically). The future is much brighter for them in this regard, but right now they have to keep working. Closes gh-20735 --- numpy/core/src/umath/dispatching.c | 34 ++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/numpy/core/src/umath/dispatching.c b/numpy/core/src/umath/dispatching.c index 81d47a0e1520..61a95a857e73 100644 --- a/numpy/core/src/umath/dispatching.c +++ b/numpy/core/src/umath/dispatching.c @@ -746,6 +746,40 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc, } info = promote_and_get_info_and_ufuncimpl(ufunc, ops, signature, new_op_dtypes, NPY_FALSE); + if (info == NULL) { + /* + * NOTE: This block exists solely to support numba's DUFuncs which add + * new loops dynamically, so our list may get outdated. Thus, we + * have to make sure that the loop exists. + * + * Before adding a new loop, ensure that it actually exists. There + * is a tiny chance that this would not work, but it would require an + * extension additionally have a custom loop getter. + * This check should ensure a the right error message, but in principle + * we could try to call the loop getter here. + */ + char *types = ufunc->types; + npy_bool loop_exists = NPY_FALSE; + for (int i = 0; i < ufunc->ntypes; ++i) { + loop_exists = NPY_TRUE; /* assume it exists, break if not */ + for (int j = 0; j < ufunc->nargs; ++j) { + if (types[j] != new_op_dtypes[j]->type_num) { + loop_exists = NPY_FALSE; + break; + } + } + if (loop_exists) { + break; + } + types += ufunc->nargs; + } + + if (loop_exists) { + info = add_and_return_legacy_wrapping_ufunc_loop( + ufunc, new_op_dtypes, 0); + } + } + for (int i = 0; i < ufunc->nargs; i++) { Py_XDECREF(new_op_dtypes[i]); }