Skip to content

Commit

Permalink
Added boxing and unboxing for Numba Array in dufunc_reduce_direct
Browse files Browse the repository at this point in the history
  • Loading branch information
kc611 committed Nov 23, 2021
1 parent 898713b commit a0e2e21
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 12 deletions.
2 changes: 2 additions & 0 deletions numba/core/runtime/nrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,6 @@ VISIBILITY_HIDDEN const NRT_api_functions* NRT_get_api(void);
*/
VISIBILITY_HIDDEN NRT_ExternalAllocator* _nrt_get_sample_external_allocator(void);

PyObject * NRT_adapt_ndarray_to_python_acqref(arystruct_t* arystruct, PyTypeObject *retty,
int ndim, int writeable, PyArray_Descr *descr)
#endif /* NUMBA_NRT_H_ */
8 changes: 4 additions & 4 deletions numba/np/ufunc/_internal.c
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,10 @@ dufunc_reduce(PyDUFuncObject * self, PyObject * args, PyObject *kws)
}

static PyObject *
dufunc_reduce_direct(PyDUFuncObject * self, PyObject * args, int axis)
{
PyObject *kwargs = Py_BuildValue("{s:L}", "axis", axis);
return dufunc_reduce(self, args, kwargs);
dufunc_reduce_direct(PyDUFuncObject * self, arystruct_t * args, int axis, PyTypeObject *retty, PyArray_Descr *descr)
{
PyObject * res = NRT_adapt_ndarray_to_python_acqref(args, retty, 2, 1, descr);
return res;
}

static PyObject *
Expand Down
4 changes: 3 additions & 1 deletion numba/np/ufunc/_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include "numpy/ndarrayobject.h"
#include "numpy/ufuncobject.h"
#include "../../_arraystruct.h"
#include "../../../numba/core/runtime/nrt.h"

extern PyObject *ufunc_fromfunc(PyObject *NPY_UNUSED(dummy), PyObject *args);

Expand All @@ -23,7 +25,7 @@ typedef struct {
} PyDUFuncObject;

NUMBA_EXPORT_FUNC(static PyObject *)
dufunc_reduce_direct(PyDUFuncObject * self, PyObject * args, int axis);
dufunc_reduce_direct(PyDUFuncObject * self, arystruct_t * args, int axis, PyTypeObject *retty, PyArray_Descr *descr);

int PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
PyObject *args, PyObject *kwds,
Expand Down
38 changes: 31 additions & 7 deletions numba/np/ufunc/dufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from numba.np.ufunc import ufuncbuilder
from numba.np import numpy_support
from numba.core.extending import intrinsic, overload_method
import llvmlite.llvmpy.core as lc
from llvmlite.llvmpy.core import Type


def make_dufunc_kernel(_dufunc):
Expand Down Expand Up @@ -332,26 +334,48 @@ def _install_cg(self, targetctx=None):
def intr_reduce(typcontext, ft, argst, axist):
# TODO: What should the layout be based on `xt.layout`?
assert argst.ndim > 0
rett = types.Array(argst.dtype, argst.ndim - 1, 'C')
rett = types.Array(argst.dtype, argst.ndim, 'C')
box_type = argst.box_type
arr_dtype = argst.dtype
sig = rett(ft, argst, axist)

def codegen(context, builder, signature, args):
f_ir, args_ir, axis_ir = args

ret_ir_type = context.get_value_type(signature.return_type)

args_ir_type = args_ir.type
axis_ir_type = axis_ir.type

voidptr = Type.pointer(Type.int(8))
aryptr = cgutils.alloca_once_value(builder, args_ir)
pyobj_type = context.get_argument_type(types.pyobject)

pyapi = context.get_python_api(builder)

serial_aryty_pytype = pyapi.unserialize(
pyapi.serialize_object(box_type)
)

fnty = ir.FunctionType(
ret_ir_type,
[f_ir.type, args_ir_type, axis_ir_type]
pyobj_type,
[
f_ir.type,
voidptr,
axis_ir_type,
pyobj_type,
pyobj_type
]
)

fn = cgutils.get_or_insert_function(
builder.module, fnty, "dufunc_reduce_direct"
)
return builder.call(fn, [f_ir, args_ir, axis_ir])
fn.args[1].add_attribute(lc.ATTR_NO_CAPTURE)

res_ptr = builder.call(fn, [
f_ir, builder.bitcast(aryptr, voidptr), axis_ir,
arr_dtype, serial_aryty_pytype
])

return res_ptr

return sig, codegen

Expand Down

0 comments on commit a0e2e21

Please sign in to comment.