Skip to content

Commit

Permalink
Merge pull request #9239 from guilhermeleobas/guilhermeleobas/ufunc_at
Browse files Browse the repository at this point in the history
ufunc.at
  • Loading branch information
sklam committed Mar 29, 2024
2 parents 89218bb + a539aa2 commit a060559
Show file tree
Hide file tree
Showing 3 changed files with 627 additions and 2 deletions.
4 changes: 4 additions & 0 deletions docs/upcoming_changes/9239.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Add experimental support for ufunc.at
-------------------------------------

Experimental support for ``ufunc.at`` is added.
233 changes: 233 additions & 0 deletions numba/np/ufunc/dufunc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import operator
import warnings

import numpy as np
Expand All @@ -20,6 +21,115 @@
from numba.core.compiler_lock import global_compiler_lock


class UfuncAtIterator:

def __init__(self, ufunc, a, a_ty, indices, indices_ty, b=None, b_ty=None):
self.ufunc = ufunc
self.a = a
self.a_ty = a_ty
self.indices = indices
self.indices_ty = indices_ty
self.b = b
self.b_ty = b_ty

def run(self, context, builder):
self._prepare(context, builder)
loop_indices, _ = self.indexer.begin_loops()
self._call_ufunc(context, builder, loop_indices)
self.indexer.end_loops()

def need_advanced_indexing(self):
return isinstance(self.indices_ty, types.BaseTuple)

def _prepare(self, context, builder):
from numba.np.arrayobj import normalize_indices, FancyIndexer

a, indices = self.a, self.indices
a_ty, indices_ty = self.a_ty, self.indices_ty

zero = context.get_value_type(types.intp)(0)

if self.b is not None:
self.b_indice = cgutils.alloca_once_value(builder, zero)

if self.need_advanced_indexing():
indices = cgutils.unpack_tuple(builder, indices,
count=len(indices_ty))
index_types = indices_ty.types
index_types, indices = normalize_indices(context, builder,
index_types, indices)
else:
indices = (indices,)
index_types = (indices_ty,)
index_types, indices = normalize_indices(context, builder,
index_types, indices)

self.indexer = FancyIndexer(context, builder, a_ty, a,
index_types, indices)
self.indexer.prepare()
self.cres = self._compile_ufunc(context, builder)

def _load_val(self, context, builder, loop_indices, array, array_ty):
from numba.np.arrayobj import load_item
shapes = cgutils.unpack_tuple(builder, array.shape)
strides = cgutils.unpack_tuple(builder, array.strides)
data = array.data

ptr = cgutils.get_item_pointer2(context, builder, data, shapes, strides,
array_ty.layout, loop_indices)
val = load_item(context, builder, array_ty, ptr)
return ptr, val

def _load_flat(self, context, builder, indices, array, array_ty):
idx = builder.load(indices)
sig = array_ty.dtype(array_ty, types.intp)
impl = context.get_function(operator.getitem, sig)
val = impl(builder, (array, idx))

# increment indices
one = context.get_value_type(types.intp)(1)
idx = builder.add(idx, one)
builder.store(idx, indices)

return None, val

def _store_val(self, context, builder, array, array_ty, ptr, val):
from numba.np.arrayobj import store_item
fromty = self.cres.signature.return_type
toty = array_ty.dtype
val = context.cast(builder, val, fromty, toty)
store_item(context, builder, array_ty, val, ptr)

def _compile_ufunc(self, context, builder):
ufunc = self.ufunc.key[0]

if self.b is None:
sig = (self.a_ty.dtype,)
else:
sig = (self.a_ty.dtype, self.b_ty.dtype)

cres = ufunc.add(sig)
context.add_linking_libs((cres.library,))
return cres

def _call_ufunc(self, context, builder, loop_indices):
cres = self.cres
a, a_ty = self.a, self.a_ty

ptr, val = self._load_val(context, builder, loop_indices, a, a_ty)

if self.b is None:
args = (val,)
else:
b, b_ty, b_idx = self.b, self.b_ty, self.b_indice
_, val_b = self._load_flat(context, builder, b_idx, b, b_ty)
args = (val, val_b)

res = context.call_internal(builder, cres.fndesc, cres.signature,
args)
self._store_val(context, builder, a, a_ty, ptr, res)


def make_dufunc_kernel(_dufunc):
from numba.np import npyimpl

Expand Down Expand Up @@ -293,6 +403,108 @@ def impl(ufunc):

def _install_ufunc_methods(self, template) -> None:
self._install_ufunc_reduce(template)
self._install_ufunc_at(template)

def _install_ufunc_at(self, template) -> None:
at = types.Function(template)

@overload_method(at, 'at')
def ol_at(ufunc, a, indices, b=None):
warnings.warn("ufunc.at feature is experimental",
category=errors.NumbaExperimentalFeatureWarning)

if not isinstance(a, types.Array):
msg = 'The first argument "a" must be array-like'
raise errors.NumbaTypeError(msg)

indices_arr = isinstance(indices, types.Array)
indices_list = isinstance(indices, types.List)
indices_tuple = isinstance(indices, types.Tuple)
indices_slice = isinstance(indices, types.SliceType)
indices_scalar = not (indices_arr or indices_slice or indices_tuple)
indices_empty_tuple = indices_tuple and len(indices) == 0
b_array = isinstance(b, (types.Array, types.Sequence, types.List,
types.Tuple))
b_none = cgutils.is_nonelike(b)
b_scalar = not (b_array or b_none)
need_cast = any([indices_list])

nin = self.ufunc.nin

# missing second argument?
if nin == 2 and cgutils.is_nonelike(b):
raise errors.TypingError('second operand needed for ufunc')

# extra second argument
if nin == 1 and not cgutils.is_nonelike(b):
msg = 'second operand provided when ufunc is unary'
raise errors.TypingError(msg)

if cgutils.is_nonelike(b):
self.add((a.dtype,))
elif b_scalar:
self.add((a.dtype, b))
else:
self.add((a.dtype, b.dtype))

def apply_ufunc_codegen(context, builder, sig, args):
from numba.np.arrayobj import make_array

if len(args) == 4:
_, aty, idxty, bty = sig.args
_, a, indices, b = args
else:
_, aty, idxty, bty = sig.args + (None,)
_, a, indices, b = args + (None,)

a = make_array(aty)(context, builder, a)
at_iter = UfuncAtIterator(ufunc, a, aty, indices, idxty, b, bty)
at_iter.run(context, builder)

@intrinsic
def apply_a_b_ufunc(typingctx, ufunc, a, indices, b):
sig = types.none(ufunc, a, indices, b)
return sig, apply_ufunc_codegen

@intrinsic
def apply_a_ufunc(typingctx, ufunc, a, indices):
sig = types.none(ufunc, a, indices)
return sig, apply_ufunc_codegen

def impl_cast(ufunc, a, indices, b=None):
if b_none:
return ufunc.at(a, np.asarray(indices))
else:
return ufunc.at(a,
np.asarray(indices),
np.asarray(b))

def impl_generic(ufunc, a, indices, b=None):
if b_none:
apply_a_ufunc(ufunc, a, indices,)
else:
b_ = np.asarray(b)
a_ = a[indices]
b_ = np.broadcast_to(b_, a_.shape)
apply_a_b_ufunc(ufunc, a, indices, b_.flat)

def impl_indices_empty_b_scalar(ufunc, a, indices, b=None):
a[()] = ufunc(a[()], b)

def impl_scalar_scalar(ufunc, a, indices, b=None):
if b_none:
a[indices] = ufunc(a[indices])
else:
a[indices] = ufunc(a[indices], b)

if need_cast:
return impl_cast
elif indices_empty_tuple and b_scalar:
return impl_indices_empty_b_scalar
elif indices_scalar and b_scalar:
return impl_scalar_scalar
else:
return impl_generic

def _install_ufunc_reduce(self, template) -> None:
at = types.Function(template)
Expand Down Expand Up @@ -579,6 +791,27 @@ def impl_axis_none(ufunc,
# elif array.ndim == 1:
# return impl_1d

def at(self, a, indices, b=None):
# dynamic compile ufunc.at
args = (a,) if cgutils.is_nonelike(b) else (a, b)
argtys = (typeof(arg) for arg in args)
ewise_types = tuple(arg.dtype if isinstance(arg, types.Array) else arg
for arg in argtys)

if self.find_ewise_function(ewise_types) == (None, None):
# cannot find a matching function and compilation is disabled
if self._frozen:
msg = "compilation disabled for %s.at(...)" % (self,)
raise RuntimeError(msg)

self._compile_for_args(*args)

# all good, just dispatch to the function
if cgutils.is_nonelike(b):
return super().at(a, indices)
else:
return super().at(*(a, indices, b))

def _install_type(self, typingctx=None):
"""Constructs and installs a typing class for a DUFunc object in the
input typing context. If no typing context is given, then
Expand Down

0 comments on commit a060559

Please sign in to comment.