Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added overload methods for Numpy's ufunc.reduce #7524

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions numba/core/runtime/nrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,6 @@ VISIBILITY_HIDDEN const NRT_api_functions* NRT_get_api(void);
*/
VISIBILITY_HIDDEN NRT_ExternalAllocator* _nrt_get_sample_external_allocator(void);

PyObject * NRT_adapt_ndarray_to_python_acqref(arystruct_t* arystruct, PyTypeObject *retty,
int ndim, int writeable, PyArray_Descr *descr)
#endif /* NUMBA_NRT_H_ */
23 changes: 14 additions & 9 deletions numba/np/ufunc/_internal.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "_internal.h"
#include "Python.h"
#include "methodobject.h"


/* A small object that handles deallocation of some of a PyUFunc's fields */
typedef struct {
Expand Down Expand Up @@ -98,15 +100,6 @@ PyTypeObject PyUFuncCleaner_Type = {
/* ______________________________________________________________________
* DUFunc: A call-time (hence dynamic) specializable ufunc.
*/

typedef struct {
PyObject_HEAD
PyObject * dispatcher;
PyUFuncObject * ufunc;
PyObject * keepalive;
int frozen;
} PyDUFuncObject;

static void
dufunc_dealloc(PyDUFuncObject *self)
{
Expand Down Expand Up @@ -350,6 +343,15 @@ dufunc_reduce(PyDUFuncObject * self, PyObject * args, PyObject *kws)
return ufunc_dispatch.ufunc_reduce((PyObject*)self->ufunc, args, kws);
}

static PyObject *
dufunc_reduce_direct(PyDUFuncObject * self, arystruct_t * args, int axis, PyTypeObject *retty, PyArray_Descr *descr)
{
PyObject * res = NRT_adapt_ndarray_to_python_acqref(args, retty, 2, 1, descr);
PyObject *kwargs = Py_BuildValue("{s:L}", "axis", axis);
// The code segfaults on the following call.
return dufunc_reduce(self, res, kwargs);
}

static PyObject *
dufunc_accumulate(PyDUFuncObject * self, PyObject * args, PyObject *kws)
{
Expand Down Expand Up @@ -677,6 +679,9 @@ MOD_INIT(_internal)
if (m == NULL)
return MOD_ERROR_VAL;

PyObject_SetAttrString(m, "dufunc_reduce_direct",
PyLong_FromVoidPtr((void*)&dufunc_reduce_direct));

if (PyType_Ready(&PyUFuncCleaner_Type) < 0)
return MOD_ERROR_VAL;

Expand Down
17 changes: 16 additions & 1 deletion numba/np/ufunc/_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,28 @@

#include "../../_pymodule.h"
#include <structmember.h>
#include "../../cext/cext.h"
#include "Python.h"

#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include "numpy/ndarrayobject.h"
#include "numpy/ufuncobject.h"
#include "../../_arraystruct.h"
#include "../../../numba/core/runtime/nrt.h"

extern PyObject *ufunc_fromfunc(PyObject *NPY_UNUSED(dummy), PyObject *args);

typedef struct {
PyObject_HEAD
PyObject * dispatcher;
PyUFuncObject * ufunc;
PyObject * keepalive;
int frozen;
} PyDUFuncObject;

NUMBA_EXPORT_FUNC(static PyObject *)
dufunc_reduce_direct(PyDUFuncObject * self, arystruct_t * args, int axis, PyTypeObject *retty, PyArray_Descr *descr);

int PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
PyObject *args, PyObject *kwds,
PyArrayObject **op);
Expand All @@ -24,4 +39,4 @@ int PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
return NULL; \
}

#endif /* NUMBA_UFUNC_INTERNAL_H_ */
#endif /* NUMBA_UFUNC_INTERNAL_H_ */
65 changes: 65 additions & 0 deletions numba/np/ufunc/dufunc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from llvmlite import ir, binding

from numba import jit, typeof
from numba.core import cgutils, types, serialize, sigutils
from numba.core.typing import npydecl
Expand All @@ -7,6 +9,9 @@
from numba.parfors import array_analysis
from numba.np.ufunc import ufuncbuilder
from numba.np import numpy_support
from numba.core.extending import intrinsic, overload_method
import llvmlite.llvmpy.core as lc
from llvmlite.llvmpy.core import Type


def make_dufunc_kernel(_dufunc):
Expand Down Expand Up @@ -321,3 +326,63 @@ def _install_cg(self, targetctx=None):


array_analysis.MAP_TYPES.append(DUFunc)

binding.add_symbol("dufunc_reduce_direct", _internal.dufunc_reduce_direct)


@intrinsic
def intr_reduce(typcontext, ft, argst, axist):
# TODO: What should the layout be based on `xt.layout`?
assert argst.ndim > 0
rett = types.Array(argst.dtype, argst.ndim - 1, 'C')
box_type = argst.box_type
arr_dtype = argst.dtype
sig = rett(ft, argst, axist)

def codegen(context, builder, signature, args):
f_ir, args_ir, axis_ir = args

axis_ir_type = axis_ir.type

voidptr = Type.pointer(Type.int(8))
aryptr = cgutils.alloca_once_value(builder, args_ir)
pyobj_type = context.get_argument_type(types.pyobject)

pyapi = context.get_python_api(builder)

serial_aryty_pytype = pyapi.unserialize(
pyapi.serialize_object(box_type)
)

fnty = ir.FunctionType(
pyobj_type,
[
f_ir.type,
voidptr,
axis_ir_type,
pyobj_type,
pyobj_type
]
)

fn = cgutils.get_or_insert_function(
builder.module, fnty, "dufunc_reduce_direct"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current issue I was facing is that this call cannot detect the implemented dufunc_reduce_direct method in C code. Is it a type issue ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to add the symbol - try using llvmlite.binding.add_symbol('dufunc_reduce_direct', <addr of dufunc_reduce_direct>)

)
fn.args[1].add_attribute(lc.ATTR_NO_CAPTURE)

res_ptr = builder.call(fn, [
f_ir, builder.bitcast(aryptr, voidptr), axis_ir,
arr_dtype, serial_aryty_pytype
])

return res_ptr

return sig, codegen


@overload_method(types.Function, "reduce")
def dufunc_reduce(fn, args, axis):
if isinstance(fn.typing_key, DUFunc):
def _reduce_impl(fn, args, axis):
return intr_reduce(fn, args, axis)
return _reduce_impl
15 changes: 14 additions & 1 deletion numba/tests/npyufunc/test_ufunc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from numba import float32, jit, njit
from numba import float32, jit, njit, vectorize
from numba.np.ufunc import Vectorize
from numba.core.errors import TypingError
from numba.tests.support import TestCase
Expand Down Expand Up @@ -140,6 +140,19 @@ def inner(x, y):
msg = "expected array(float64, 1d, C), got None"
self.assertIn(msg, str(raises.exception))

def test_reduce(self):

@vectorize
def add_twice(x, y):
return 2 * x + 2 * y

@njit
def test_fn(x, axis):
return add_twice.reduce(x, axis)

a = np.arange(2 * 3).reshape((2, 3))
self.assertPreciseEqual(test_fn(a, 1), np.sum(a, axis=1))


if __name__ == '__main__':
unittest.main()