Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
66 lines (52 sloc) 2.01 KB
import ufunc_override as np
Define a class with __array_priority__ and __ufunc_override__ attributes.
__ufunc_override__ should be a dict keyed by ufunc name, whose values are
the callables you want to be used in place of those ufuncs.
You can test this with a version of scipy.sparse here:
Build this then try it out with:
import scipy.sparse as sp
np.multiply(dense_matrix, sparse_matrix)
import numpy as np
class make_overridable(object):
    """Wrap a callable so that its arguments may supply a replacement.

    An argument takes part in the override protocol when it has both a
    ``__ufunc_override__`` attribute (a dict keyed by function name) and an
    ``__array_priority__`` attribute, and its dict holds a truthy entry for
    the wrapped callable's ``__name__``.  When several arguments qualify,
    the one with the highest ``__array_priority__`` supplies the
    replacement.  When none qualify, the wrapped callable runs unchanged.
    """

    def __init__(self, name):
        """Store the callable to wrap.

        ``name`` is the wrapped callable itself (not a string); it must
        have a ``__name__`` attribute, which the wrapper copies so that it
        reports the same name as the callable it stands in for.
        """
        self.name = name
        self.__name__ = self.name.__name__

    def __call__(self, *args, **kwargs):
        """Dispatch to an argument's override, or to the wrapped callable.

        Only positional arguments are inspected for overrides; keyword
        arguments are passed through untouched.
        """
        func_name = self.name.__name__

        # Get a list of the args that want to override.
        override_args = []
        for arg in args:
            if (hasattr(arg, '__ufunc_override__') and
                    hasattr(arg, '__array_priority__')):
                if arg.__ufunc_override__.get(func_name):
                    override_args.append(arg)

        # Sort by __array_priority__ (ascending), so the dominant argument
        # ends up last.  The sort is stable: among equal priorities the
        # argument that appears latest in ``args`` wins.
        override_args = sorted(override_args,
                               key=lambda arg: arg.__array_priority__)

        if override_args:
            dominant_arg = override_args[-1]
            remaining_args = [arg for arg in args if arg is not dominant_arg]
            new_func = dominant_arg.__ufunc_override__.get(func_name)
            # NOTE: the dominant argument is moved to the front, so the
            # override receives it first whatever its original position.
            return new_func(dominant_arg, *remaining_args, **kwargs)

        return self.name(*args, **kwargs)
def override_all():
    """Rebind every ufunc in the numpy namespace to an overridable wrapper.

    Each entry of the numpy module's namespace whose value is an
    ``np.ufunc`` instance is replaced, under the same name, by a
    ``make_overridable`` wrapping the original value.
    """
    # Work on a snapshot so the rebinding below never touches the mapping
    # that is being walked.
    for attr_name, attr_value in list(vars(np).items()):
        if not isinstance(getattr(np, attr_name), np.ufunc):
            continue
        setattr(np, attr_name, make_overridable(attr_value))
def override_set_numeric_ops():
    """Route ndarray's numeric operators through overridable wrappers.

    ``np.set_numeric_ops()`` called without arguments returns the table of
    callables currently used for the array operators.  Every entry is
    wrapped in ``make_overridable`` and the wrapped table is installed.

    NOTE(review): ``np.set_numeric_ops`` is deprecated in newer NumPy
    releases -- confirm it exists in the NumPy version being targeted.
    """
    def_ops = np.set_numeric_ops()
    new_ops = {}
    for name, call in def_ops.items():
        new_ops[name] = make_overridable(call)
    # Install the wrapped operators; without this call the table built
    # above would simply be discarded and the function would do nothing.
    np.set_numeric_ops(**new_ops)
def np_to_global():
    """Copy numpy's public names into this module's global namespace.

    Every name listed in ``np.__all__`` that is not already defined at
    module level is bound to the corresponding numpy attribute.  Names that
    already exist here are left untouched.
    """
    module_namespace = globals()
    for public_name in np.__all__:
        if public_name in module_namespace:
            continue
        module_namespace[public_name] = getattr(np, public_name)