-
Notifications
You must be signed in to change notification settings - Fork 0
/
ufunc_override.py
65 lines (52 loc) · 2.01 KB
/
ufunc_override.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Usage:
    import ufunc_override as np

Define a class with ``__array_priority__`` and ``__ufunc_override__``
attributes. ``__ufunc_override__`` should be a dict keyed by ufunc name
(e.g. ``'multiply'``) whose values are the callables that should replace
those ufuncs. When several arguments provide an override, the one with the
highest ``__array_priority__`` wins, and its override is called with that
object as its first argument.

You can test this with a version of scipy.sparse here:
https://github.com/cowlicks/scipy/tree/ufunc-override-scipy
Build it, then try it out with:
    import scipy.sparse as sp
    np.multiply(dense_matrix, sparse_matrix)
"""
import numpy as np
class make_overridable(object):
    """Callable wrapper that lets argument objects override a function.

    Despite its name, ``self.name`` holds the wrapped callable (e.g. a numpy
    ufunc); the attribute name is kept for backward compatibility.

    When called, each positional argument that defines both
    ``__ufunc_override__`` (a mapping keyed by function name) and
    ``__array_priority__``, and whose mapping has a truthy entry for this
    function's ``__name__``, is a candidate. The candidate with the highest
    ``__array_priority__`` wins (the rightmost one among ties), and its
    override is called as ``override(winner, *other_args, **kwargs)``, where
    ``other_args`` are the remaining positional arguments in their original
    order. Without any candidate, the wrapped callable is called unchanged.

    Attributes not found on the wrapper itself (for a ufunc, e.g.
    ``reduce``, ``outer`` or ``nin``) are looked up on the wrapped callable.
    """

    def __init__(self, name):
        # ``name`` is the callable being wrapped, not a string.
        self.name = name
        self.__name__ = self.name.__name__

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. Read ``name`` through
        # __dict__ so a partially initialised instance cannot recurse here.
        try:
            func = self.__dict__['name']
        except KeyError:
            raise AttributeError(attr)
        return getattr(func, attr)

    def __call__(self, *args, **kwargs):
        func_name = self.name.__name__
        # Arguments that both opt in and actually provide an override for
        # this particular function.
        candidates = [
            arg for arg in args
            if hasattr(arg, '__ufunc_override__')
            and hasattr(arg, '__array_priority__')
            and arg.__ufunc_override__.get(func_name)
        ]
        if not candidates:
            return self.name(*args, **kwargs)

        # Highest __array_priority__ wins. max() keeps the first maximal
        # item it sees, so scanning right-to-left makes the rightmost
        # argument win ties.
        dominant = max(reversed(candidates),
                       key=lambda arg: arg.__array_priority__)

        # Remove only one slot held by the dominant object, so calls such
        # as np.add(x, x) still pass the second operand to the override.
        position = next(i for i, arg in enumerate(args) if arg is dominant)
        others = args[:position] + args[position + 1:]

        override = dominant.__ufunc_override__.get(func_name)
        return override(dominant, *others, **kwargs)
def override_all():
    """Swap every ufunc exposed on the numpy module for a wrapper.

    Each attribute of ``np`` whose value is an ``np.ufunc`` instance is
    rebound, on the numpy module itself, to ``make_overridable(ufunc)``.
    """
    # Take a snapshot first, because the loop below rebinds entries of
    # np.__dict__.
    entries = list(np.__dict__.items())
    ufuncs = [(attr, obj) for attr, obj in entries
              if isinstance(obj, np.ufunc)]
    for attr, ufunc in ufuncs:
        setattr(np, attr, make_overridable(ufunc))
def override_set_numeric_ops():
    """Wrap numpy's array-operator implementations with make_overridable.

    Calling ``np.set_numeric_ops()`` with no arguments returns the current
    mapping of operator names to functions. Each of those functions is
    wrapped, and the wrapped versions are installed back through
    ``np.set_numeric_ops(**wrapped)``.

    ``set_numeric_ops`` was deprecated and later removed (it is gone in
    NumPy 2.0). When it is unavailable, this function does nothing instead
    of raising AttributeError while the module is being imported.
    """
    set_ops = getattr(np, 'set_numeric_ops', None)
    if set_ops is None:
        return
    current_ops = set_ops()
    set_ops(**{op_name: make_overridable(op)
               for op_name, op in current_ops.items()})
def np_to_global():
    """Re-export numpy's public API from this module.

    Every name listed in ``np.__all__`` that this module does not already
    define is copied into the module namespace. Existing names are never
    overwritten.
    """
    namespace = globals()
    namespace.update({public: getattr(np, public)
                      for public in np.__all__
                      if public not in namespace})
# Install the overrides as soon as the module is imported. Order matters:
# np_to_global() runs last so that the numpy names it copies into this
# module are the wrappers that override_all() has already set on the numpy
# module.
override_all()
override_set_numeric_ops()
np_to_global()