Skip to content

Commit

Permalink
Store the context used to import _multiarray so that legacy functions
Browse files Browse the repository at this point in the history
and external extensions can find it.

This removes the dependency on _HPyGetContext() which is specific
to the cpython ABI.
  • Loading branch information
rlamy committed Dec 2, 2021
1 parent d28ca66 commit 1466a7c
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
7 changes: 6 additions & 1 deletion numpy/core/code_generators/generate_numpy_api.py
Expand Up @@ -30,6 +30,7 @@
#if defined(NO_IMPORT) || defined(NO_IMPORT_ARRAY)
extern void **PyArray_API;
#else
HPyContext *numpy_global_ctx = NULL;
#if defined(PY_ARRAY_UNIQUE_SYMBOL)
void **PyArray_API;
#else
Expand All @@ -51,7 +52,6 @@
return -1;
}
c_api = PyObject_GetAttrString(numpy, "_ARRAY_API");
Py_DECREF(numpy);
if (c_api == NULL) {
PyErr_SetString(PyExc_AttributeError, "_ARRAY_API not found");
return -1;
Expand All @@ -69,6 +69,11 @@
return -1;
}
PyObject *ctx_capsule = PyObject_GetAttrString(numpy, "_HPY_CONTEXT");
Py_DECREF(numpy);
numpy_global_ctx = (HPyContext *)PyCapsule_GetPointer(ctx_capsule, NULL);

This comment has been minimized.

Copy link
@antocuni

antocuni Dec 2, 2021

and here you probably want an assert(numpy_global_ctx != NULL)?

This comment has been minimized.

Copy link
@rlamy

rlamy Dec 15, 2021

Author Member

I think this actually needs proper error handling, like for _ARRAY_APIabove

Py_DECREF(ctx_capsule);
/* Perform runtime check of C API version */
if (NPY_VERSION != PyArray_GetNDArrayCVersion()) {
PyErr_Format(PyExc_RuntimeError, "module compiled against "\
Expand Down
8 changes: 7 additions & 1 deletion numpy/core/include/numpy/ndarraytypes.h
Expand Up @@ -9,7 +9,13 @@
#define NPY_NO_EXPORT NPY_VISIBILITY_HIDDEN

#include "hpy.h"
#define npy_get_context _HPyGetContext
extern NPY_NO_EXPORT HPyContext *numpy_global_ctx;
static NPY_INLINE HPyContext *
npy_get_context(void)
{
assert(numpy_global_ctx != NULL);
return numpy_global_ctx;
}

/* Only use thread if configured in config and python supports it */
#if defined WITH_THREAD && !NPY_NO_SMP
Expand Down
11 changes: 11 additions & 0 deletions numpy/core/src/multiarray/multiarraymodule.c
Expand Up @@ -105,6 +105,7 @@ set_legacy_print_mode(PyObject *NPY_UNUSED(self), PyObject *args)
}

NPY_NO_EXPORT PyTypeObject* _PyArray_Type_p = NULL;
NPY_NO_EXPORT HPyContext *numpy_global_ctx = NULL;

/* Only here for API compatibility */
NPY_NO_EXPORT PyTypeObject PyBigArray_Type;
Expand Down Expand Up @@ -4938,6 +4939,16 @@ static HPy init__multiarray_umath_impl(HPyContext *ctx) {
goto err;
}
#endif

/* Store the context so legacy functions and extensions can access it */
numpy_global_ctx = ctx;

This comment has been minimized.

Copy link
@antocuni

antocuni Dec 2, 2021

I'd add a assert(numpy_global_ctx == NULL), just for extra safety

This comment has been minimized.

Copy link
@rlamy

rlamy Dec 15, 2021

Author Member

well, that would prevent reloading, but that's probably a good thing

s = PyCapsule_New((void *)ctx, NULL, NULL);
if (s == NULL) {
goto err;
}
PyDict_SetItemString(d, "_HPY_CONTEXT", s);
Py_DECREF(s);

return HPy_FromPyObject(ctx, m);

err:
Expand Down

0 comments on commit 1466a7c

Please sign in to comment.