Skip to content

[BUG]: Avoiding promoter overrides leading to global state corruption#78

Open
SwayamInSync wants to merge 4 commits intonumpy:mainfrom
SwayamInSync:promoter-fix
Open

[BUG]: Avoiding promoter overrides leading to global state corruption#78
SwayamInSync wants to merge 4 commits intonumpy:mainfrom
SwayamInSync:promoter-fix

Conversation

@SwayamInSync
Copy link
Copy Markdown
Member

closes #76

Copy link
Copy Markdown
Member

@seberg seberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM, but the ld promoter doesn't really look right. I.e. it doesn't do anything interesting?

To be fair, that seems like an existing bug...


// Preserve the integer type for the exponent (slot 1)
Py_INCREF(op_dtypes[1]);
new_op_dtypes[1] = op_dtypes[1];
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you want to promote to PyArray_PyLongDType here. (assuming signature[1] == NULL).
(I guess you could add more logic in case of non-ints, but maybe doesn't matter much in practice... cast safety may even reject it normally anyway.)

Although long seems weird compared to just using intp that always works, but OK. Otherwise this looks like it doesn't do anything really.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I updated the loop registration + promoter to use intp

Copy link
Copy Markdown
Member

@ngoldbaum ngoldbaum left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! I asked Claude Code to review your code and it spotted some pre-existing reference leaks.

Maybe you could reduce boilerplate and the chance for bugs by using the following helper function:

diff --git a/src/include/umath/promoters.hpp b/src/include/umath/promoters.hpp
index 3b3c1ef..2a3d26c 100644
--- a/src/include/umath/promoters.hpp
+++ b/src/include/umath/promoters.hpp
@@ -12,9 +12,10 @@
 #include "../dtype.h"

 inline int
-quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
-                    PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[])
+quad_ufunc_promoter(PyObject *ob_ufunc, PyArray_DTypeMeta *const op_dtypes[],
+                    PyArray_DTypeMeta *const signature[], PyArray_DTypeMeta *new_op_dtypes[])
 {
+    PyUFuncObject *ufunc = (PyUFuncObject *)ufunc;
     int nin = ufunc->nin;
     int nargs = ufunc->nargs;
     PyArray_DTypeMeta *common = NULL;
@@ -56,7 +57,8 @@ quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
     }
     // If no common output dtype, use standard promotion for inputs
     if (common == NULL) {
-        common = PyArray_PromoteDTypeSequence(nin, op_dtypes);
+        common = PyArray_PromoteDTypeSequence(
+                nin, (PyArray_DTypeMeta **)op_dtypes);
         if (common == NULL) {
             if (PyErr_ExceptionMatches(PyExc_TypeError)) {
                 PyErr_Clear();  // Do not propagate normal promotion errors
@@ -86,5 +88,42 @@ quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
     return 0;
 }

+static inline int
+add_promoter(PyObject *ufunc, PyObject *dtypes[], size_t n_dtypes,
+             PyArrayMethod_PromoterFunction *promoter_impl)
+{
+    PyObject *DType_tuple = NULL;
+    PyObject *promoter_capsule = NULL;
+    int ret = -1;
+
+    DType_tuple = PyTuple_New(n_dtypes);
+
+    if (DType_tuple == NULL) {
+        goto cleanup;
+    }
+
+    for (size_t i=0; i<n_dtypes; i++) {
+        Py_INCREF((PyObject *)dtypes[i]);
+        PyTuple_SET_ITEM(DType_tuple, i, (PyObject *)dtypes[i]);
+    }
+
+    promoter_capsule = PyCapsule_New(
+            (void *)&promoter_impl, "numpy._ufunc_promoter", NULL);
+
+    if (promoter_capsule == NULL) {
+        goto cleanup;
+    }
+
+    if (PyUFunc_AddPromoter(ufunc, DType_tuple, promoter_capsule) < 0) {
+        goto cleanup;
+    }
+
+    ret = 0;
+  cleanup:
+    Py_XDECREF(promoter_capsule);
+    Py_XDECREF(DType_tuple);
+
+    return ret;
+}

-#endif
\ No newline at end of file
+#endif

And then you call it like so:

        PyObject *left_DTypes[] = {
            (PyObject *)&QuadPrecDType,
            (PyObject *)&PyArrayDescr_Type,
            (PyObject *)&QuadPrecDType,
        };

        if (add_promoter(ufunc, left_DTypes, 4, quad_ufunc_promoter) != 0) {
            Py_DECREF(ufunc);
            return -1;
        }



inline int
quad_ldexp_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type of the first argument should be PyObject *. There's also some const missing from the second and third parameters:

https://numpy.org/devdocs/reference/c-api/array.html#c.PyArrayMethod_PromoterFunction

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same error happens in quad_ufunc_promoter.

int
create_quad_ldexp_ufunc(PyObject *numpy, const char *ufunc_name)
{
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a pre-existing issue but since you're touching this code: ufunc object is never DECREF'd below, including the error paths.

int
create_quad_binary_2out_ufunc(PyObject *numpy, const char *ufunc_name)
{
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similarly here, ufunc is never cleaned up below. Since you're touching the error paths you might as well fix this too.

int
create_quad_binary_ufunc(PyObject *numpy, const char *ufunc_name)
{
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, ufunc is never DECREF'd.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Promoters are matching too broadly

3 participants