Skip to content

Commit

Permalink
Merge pull request #2 from hpyproject/ndarray-as-hpy-type
Browse files Browse the repository at this point in the history
Use HPy to define the ndarray type
  • Loading branch information
rlamy committed Jan 18, 2021
2 parents 091cb31 + 24c3c4f commit 18019cb
Show file tree
Hide file tree
Showing 12 changed files with 68 additions and 26 deletions.
2 changes: 1 addition & 1 deletion numpy/core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ def generate_umath_c(ext, build_dir):
join(codegen_dir, 'generate_ufunc_api.py'),
]

config.add_extension('_multiarray_umath',
config.add_hpy_extension('_multiarray_umath',
sources=multiarray_src + umath_src +
common_src +
[generate_config_h,
Expand Down
7 changes: 4 additions & 3 deletions numpy/core/src/multiarray/arrayobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ maintainer email: oliphant.travis@ieee.org
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include "structmember.h"
#include "hpy.h"

/*#include <stdio.h>*/
#define NPY_NO_DEPRECATED_API NPY_API_VERSION
Expand Down Expand Up @@ -1821,9 +1822,9 @@ static PyType_Slot PyArray_Type_slots[] = {
{0, NULL},
};

NPY_NO_EXPORT PyType_Spec PyArray_Type_spec = {
NPY_NO_EXPORT HPyType_Spec PyArray_Type_spec = {
.name = "numpy.ndarray",
.basicsize = sizeof(PyArrayObject_fields),
.flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE),
.slots = PyArray_Type_slots,
.flags = (HPy_TPFLAGS_DEFAULT | HPy_TPFLAGS_BASETYPE),
.legacy_slots = PyArray_Type_slots,
};
4 changes: 3 additions & 1 deletion numpy/core/src/multiarray/arrayobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#error You should not include this
#endif

#include "hpy.h"

NPY_NO_EXPORT PyObject *
_strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
int rstrip);
Expand All @@ -26,6 +28,6 @@ array_might_be_written(PyArrayObject *obj);
*/
static const int NPY_ARRAY_WARN_ON_WRITE = (1 << 31);

extern NPY_NO_EXPORT PyType_Spec PyArray_Type_spec;
extern NPY_NO_EXPORT HPyType_Spec PyArray_Type_spec;

#endif
22 changes: 14 additions & 8 deletions numpy/core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#define PY_SSIZE_T_CLEAN
#include "Python.h"
#include "structmember.h"
#include "hpy.h"

#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#define _UMATHMODULE
Expand Down Expand Up @@ -4510,7 +4511,10 @@ intern_strings(void)

static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"_multiarray_umath",
/* XXX: Unclear if a dotted name is legit in .m_name, but universal
* mode requires it.
*/
"numpy.core._multiarray_umath",
NULL,
-1,
array_module_methods,
Expand All @@ -4521,7 +4525,8 @@ static struct PyModuleDef moduledef = {
};

/* Initialization function for the module */
PyMODINIT_FUNC PyInit__multiarray_umath(void) {
HPy_MODINIT(_multiarray_umath)
static HPy init__multiarray_umath_impl(HPyContext ctx) {
PyObject *m, *d, *s;
PyObject *c_api;

Expand Down Expand Up @@ -4588,13 +4593,14 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
goto err;
}

PyObject *array_type = PyType_FromSpec(&PyArray_Type_spec);
if (array_type == NULL) {
HPy h_array_type = HPyType_FromSpec(ctx, &PyArray_Type_spec, NULL);
if (HPy_IsNull(h_array_type)) {
goto err;
}
_PyArray_Type_p = (PyTypeObject*)array_type;
_PyArray_Type_p = (PyTypeObject*)HPy_AsPyObject(ctx, h_array_type);
PyArray_Type.tp_as_buffer = &array_as_buffer;
PyArray_Type.tp_weaklistoffset = offsetof(PyArrayObject_fields, weakreflist);
HPy_Close(ctx, h_array_type);

if (setup_scalartypes(d) < 0) {
goto err;
Expand Down Expand Up @@ -4725,7 +4731,7 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
ADDCONST(MAY_SHARE_EXACT);
#undef ADDCONST

PyDict_SetItemString(d, "ndarray", array_type);
PyDict_SetItemString(d, "ndarray", (PyObject*)&PyArray_Type);
PyDict_SetItemString(d, "flatiter", (PyObject *)&PyArrayIter_Type);
PyDict_SetItemString(d, "nditer", (PyObject *)&NpyIter_Type);
PyDict_SetItemString(d, "broadcast",
Expand Down Expand Up @@ -4767,12 +4773,12 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
if (initumath(m) != 0) {
goto err;
}
return m;
return HPy_FromPyObject(ctx, m);

err:
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_RuntimeError,
"cannot load multiarray module.");
}
return NULL;
return HPy_NULL;
}
1 change: 1 addition & 0 deletions numpy/ctypeslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def load_library(libname, loader_path):
so_ext2 = get_shared_lib_extension(is_python_ext=True)
if not so_ext2 == so_ext:
libname_ext.insert(0, libname + so_ext2)
libname_ext.insert(0, libname + '.hpy.so')
else:
libname_ext = [libname]

Expand Down
6 changes: 5 additions & 1 deletion numpy/distutils/command/build_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from glob import glob

from distutils.dep_util import newer_group
from distutils.command.build_ext import build_ext as old_build_ext
from setuptools.command.build_ext import build_ext as old_build_ext
from distutils.errors import DistutilsFileError, DistutilsSetupError,\
DistutilsError
from distutils.file_util import copy_file
Expand Down Expand Up @@ -543,6 +543,10 @@ def build_extension(self, ext):
build_temp=self.build_temp,
target_lang=ext.language)

if ext._needs_stub:
cmd = self.get_finalized_command('build_py').build_lib
self.write_stub(cmd, ext)

def _add_dummy_mingwex_sym(self, c_sources):
build_src = self.get_finalized_command("build_src").build_src
build_clib = self.get_finalized_command("build_clib").build_clib
Expand Down
39 changes: 28 additions & 11 deletions numpy/distutils/misc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,8 @@ def get_frame(level=0):

class Configuration:

_list_keys = ['packages', 'ext_modules', 'data_files', 'include_dirs',
_list_keys = ['packages', 'ext_modules', 'hpy_ext_modules', 'data_files',
'include_dirs',
'libraries', 'headers', 'scripts', 'py_modules',
'installed_libraries', 'define_macros']
_dict_keys = ['package_dir', 'installed_pkg_config']
Expand Down Expand Up @@ -1456,6 +1457,31 @@ def add_extension(self,name,sources,**kw):
The self.paths(...) method is applied to all lists that may contain
paths.
"""
from numpy.distutils.core import Extension
ext_args = self._process_extension_args(name, sources, **kw)
ext = Extension(**ext_args)
self.ext_modules.append(ext)

dist = self.get_distribution()
if dist is not None:
self.warn('distutils distribution has been initialized,'\
' it may be too late to add an extension '+name)
return ext

def add_hpy_extension(self, name, sources, **kw):
from numpy.distutils.core import Extension
ext_args = self._process_extension_args(name, sources, **kw)
ext = Extension(**ext_args)
self.hpy_ext_modules.append(ext)

dist = self.get_distribution()
if dist is not None:
self.warn('distutils distribution has been initialized,'\
' it may be too late to add an extension '+name)
return ext


def _process_extension_args(self, name, sources, **kw):
ext_args = copy.copy(kw)
ext_args['name'] = dot_join(self.name, name)
ext_args['sources'] = sources
Expand Down Expand Up @@ -1500,16 +1526,7 @@ def add_extension(self,name,sources,**kw):
ext_args['libraries'] = libnames + ext_args['libraries']
ext_args['define_macros'] = \
self.define_macros + ext_args.get('define_macros', [])

from numpy.distutils.core import Extension
ext = Extension(**ext_args)
self.ext_modules.append(ext)

dist = self.get_distribution()
if dist is not None:
self.warn('distutils distribution has been initialized,'\
' it may be too late to add an extension '+name)
return ext
return ext_args

def add_library(self,name,sources,**build_info):
"""
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ requires = [
"setuptools<49.2.0",
"wheel<=0.35.1",
"Cython>=0.29.21,<3.0", # Note: keep in sync with tools/cythonize.py
"hpy.devel @ git+https://github.com/hpyproject/hpy.git@8e20b89116c2993188157c09a6070a64f8efbd82#egg=hpy.devel"
]


Expand Down
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def get_docs_url():
# to the associated docs easily.
return "https://numpy.org/doc/{}.{}".format(MAJOR, MINOR)

HPY_ABI = 'cpython' if sys.implementation.name == 'cpython' else 'universal'

def setup_package():
src_path = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -405,6 +406,11 @@ def setup_package():
cmdclass=cmdclass,
python_requires='>=3.7',
zip_safe=False,
setup_requires=['hpy.devel'],
# distuils doesn't load hpy.devel unless hpy_ext_modules is present
# as a keyword
hpy_ext_modules=[],
hpy_abi=HPY_ABI,
entry_points={
'console_scripts': f2py_cmds
},
Expand Down
2 changes: 2 additions & 0 deletions test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ cffi
# - Mypy doesn't currently work on Python 3.9
mypy==0.790; platform_python_implementation != "PyPy"
typing_extensions
# HPy
git+https://github.com/hpyproject/hpy.git@8e20b89116c2993188157c09a6070a64f8efbd82#egg=hpy.devel
1 change: 1 addition & 0 deletions tools/travis-before-install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pip install --upgrade pip 'setuptools<49.2.0' wheel
# requirement using `grep cython test_requirements.txt` instead of simply
# writing 'pip install setuptools wheel cython'.
pip install `grep cython test_requirements.txt`
pip install `grep hpy.devel test_requirements.txt`

if [ -n "$DOWNLOAD_OPENBLAS" ]; then
pwd
Expand Down
3 changes: 2 additions & 1 deletion tools/travis-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ if [ -n "$PYTHON_OPTS" ]; then
fi

# make some warnings fatal, mostly to match windows compilers
werrors="-Werror=vla -Werror=nonnull -Werror=pointer-arith"
# werrors="-Werror=vla -Werror=nonnull -Werror=pointer-arith"
werrors="-Werror=nonnull -Werror=pointer-arith"
werrors="$werrors -Werror=implicit-function-declaration"

# build with c99 by default
Expand Down

0 comments on commit 18019cb

Please sign in to comment.