diff --git a/numpy/core/code_generators/generate_numpy_api.py b/numpy/core/code_generators/generate_numpy_api.py index 96959b01c69a..917df6d81dd3 100644 --- a/numpy/core/code_generators/generate_numpy_api.py +++ b/numpy/core/code_generators/generate_numpy_api.py @@ -30,6 +30,7 @@ #if defined(NO_IMPORT) || defined(NO_IMPORT_ARRAY) extern void **PyArray_API; #else +HPyContext *numpy_global_ctx = NULL; #if defined(PY_ARRAY_UNIQUE_SYMBOL) void **PyArray_API; #else @@ -51,7 +52,6 @@ return -1; } c_api = PyObject_GetAttrString(numpy, "_ARRAY_API"); - Py_DECREF(numpy); if (c_api == NULL) { PyErr_SetString(PyExc_AttributeError, "_ARRAY_API not found"); return -1; @@ -69,6 +69,11 @@ return -1; } + PyObject *ctx_capsule = PyObject_GetAttrString(numpy, "_HPY_CONTEXT"); + Py_DECREF(numpy); + numpy_global_ctx = (HPyContext *)PyCapsule_GetPointer(ctx_capsule, NULL); + Py_DECREF(ctx_capsule); + /* Perform runtime check of C API version */ if (NPY_VERSION != PyArray_GetNDArrayCVersion()) { PyErr_Format(PyExc_RuntimeError, "module compiled against "\ diff --git a/numpy/core/include/numpy/ndarraytypes.h b/numpy/core/include/numpy/ndarraytypes.h index 1147f5159076..dbec133371e8 100644 --- a/numpy/core/include/numpy/ndarraytypes.h +++ b/numpy/core/include/numpy/ndarraytypes.h @@ -9,7 +9,13 @@ #define NPY_NO_EXPORT NPY_VISIBILITY_HIDDEN #include "hpy.h" -#define npy_get_context _HPyGetContext +extern NPY_NO_EXPORT HPyContext *numpy_global_ctx; +static NPY_INLINE HPyContext * +npy_get_context(void) +{ + assert(numpy_global_ctx != NULL); + return numpy_global_ctx; +} /* Only use thread if configured in config and python supports it */ #if defined WITH_THREAD && !NPY_NO_SMP diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 608b2535ed4f..285488b8f9de 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -105,6 +105,7 @@ set_legacy_print_mode(PyObject *NPY_UNUSED(self), PyObject *args) } NPY_NO_EXPORT PyTypeObject* _PyArray_Type_p = NULL; +NPY_NO_EXPORT HPyContext *numpy_global_ctx = NULL; /* Only here for API compatibility */ NPY_NO_EXPORT PyTypeObject PyBigArray_Type; @@ -4938,6 +4939,16 @@ static HPy init__multiarray_umath_impl(HPyContext *ctx) { goto err; } #endif + + /* Store the context so legacy functions and extensions can access it */ + numpy_global_ctx = ctx; + s = PyCapsule_New((void *)ctx, NULL, NULL); + if (s == NULL) { + goto err; + } + PyDict_SetItemString(d, "_HPY_CONTEXT", s); + Py_DECREF(s); + return HPy_FromPyObject(ctx, m); err: