-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
overloads.py
238 lines (185 loc) · 7.09 KB
/
overloads.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
"""
Overloads for ClassInstanceType for built-in functions that call dunder methods
on an object.
"""
from functools import wraps
import inspect
import operator
from numba.core.extending import overload
from numba.core.types import ClassInstanceType
def _get_args(n_args):
    """
    Return the parameter names used by the generated overload functions.

    n_args must be 1 or 2; the result is ["x"] or ["x", "y"] respectively.
    """
    assert n_args in (1, 2)
    names = ["x", "y"]
    return names[:n_args]
def class_instance_overload(target):
    """
    Build a decorator that registers an overload of target which only
    applies when the first positional argument is a ClassInstanceType.
    """
    def decorator(func):
        @wraps(func)
        def wrapped(*args, **kwargs):
            # Decline (return None) unless the first argument is a
            # ClassInstanceType.
            if isinstance(args[0], ClassInstanceType):
                return func(*args, **kwargs)
            return None

        # complex ctor needs special treatment as it uses kwargs
        if target is not complex:
            # wraps() makes the signature of func visible through wrapped;
            # its parameters must be exactly the names from _get_args.
            param_names = list(inspect.signature(wrapped).parameters)
            assert param_names == _get_args(len(param_names))

        register = overload(target)
        return register(wrapped)

    return decorator
def extract_template(template, name):
    """
    Execute the source text in template and return the object it binds to
    name.
    """
    scope = {}
    exec(template, scope)
    return scope[name]
def register_simple_overload(func, *attrs, n_args=1):
    """
    Register a ClassInstanceType overload of func backed by dunder methods.

    For each attr in attrs, in order, the method __attr__ is looked up on the
    class type; the first one that is defined supplies the implementation.
    n_args is the number of positional arguments func takes (1 or 2).
    """
    # Generate a stub with the right parameter names so that wraps() gives
    # the overload the signature that class_instance_overload checks for.
    signature = ','.join(_get_args(n_args))
    template = f"""
def func({signature}):
    pass
"""
    stub = extract_template(template, "func")

    @wraps(stub)
    def overload_func(*args, **kwargs):
        cls_type = args[0]
        candidates = [
            try_call_method(cls_type, f"__{attr}__", n_args)
            for attr in attrs
        ]
        return take_first(*candidates)

    return class_instance_overload(func)(overload_func)
def try_call_method(cls_type, method, n_args=1):
    """
    Generate a function that forwards its arguments to a method of cls_type.

    Returns None when method is not among cls_type.jit_methods. Otherwise
    returns a function of n_args positional arguments that calls method on
    its first argument, passing the remaining arguments along.
    """
    if method not in cls_type.jit_methods:
        return None

    receiver, *rest = _get_args(n_args)
    params = ','.join([receiver, *rest])
    call_args = ','.join(rest)
    template = f"""
def func({params}):
    return {receiver}.{method}({call_args})
"""
    return extract_template(template, "func")
def try_call_complex_method(cls_type, method):
    """
    Variant of try_call_method for the complex constructor.

    The generated function must use the keyword parameters real=0 and imag=0,
    so the argument names and defaults are fixed rather than generated.
    Returns None when method is not among cls_type.jit_methods.
    """
    if method not in cls_type.jit_methods:
        return None

    template = f"""
def func(real=0, imag=0):
    return real.{method}()
"""
    return extract_template(template, "func")
def take_first(*options):
    """
    Return the first option that is not None, or None if every option is
    None (or no options are given).

    Each option must be either None or a Python function.
    """
    assert not any(
        o is not None and not inspect.isfunction(o) for o in options
    ), options
    return next((o for o in options if o is not None), None)
@class_instance_overload(bool)
def class_bool(x):
    """
    Overload bool() for class instances.

    Preference order: __bool__ if defined, then bool(len(x)) if __len__ is
    defined, otherwise always True.
    """
    def always_true_impl(x):
        return True

    if '__len__' in x.jit_methods:
        def using_len_impl(x):
            return bool(len(x))
    else:
        using_len_impl = None

    return take_first(
        try_call_method(x, "__bool__"),
        using_len_impl,
        always_true_impl,
    )
@class_instance_overload(complex)
def class_complex(real=0, imag=0):
    """
    Overload complex() for class instances.

    Uses __complex__ if the class defines it; otherwise converts through
    float(real).
    """
    def using_float_impl(real=0, imag=0):
        return complex(float(real))

    return take_first(
        try_call_complex_method(real, "__complex__"),
        using_float_impl,
    )
@class_instance_overload(operator.contains)
def class_contains(x, y):
    """
    Overload the membership test for class instances via __contains__.
    """
    # https://docs.python.org/3/reference/expressions.html#membership-test-operations
    # TODO: use __iter__ if defined.
    contains_impl = try_call_method(x, "__contains__", 2)
    return contains_impl
@class_instance_overload(float)
def class_float(x):
    """
    Overload float() for class instances.

    Uses __float__ if defined; otherwise falls back to float(x.__index__())
    when __index__ is defined.
    """
    if "__index__" in x.jit_methods:
        def using_index_impl(x):
            return float(x.__index__())
    else:
        using_index_impl = None

    return take_first(try_call_method(x, "__float__"), using_index_impl)
@class_instance_overload(int)
def class_int(x):
    """
    Overload int() for class instances: __int__ if defined, else __index__.
    """
    return take_first(
        try_call_method(x, "__int__"),
        try_call_method(x, "__index__"),
    )
@class_instance_overload(str)
def class_str(x):
    """
    Overload str() for class instances: __str__ if defined, else repr(x).
    """
    def using_repr_impl(x):
        return repr(x)

    return take_first(try_call_method(x, "__str__"), using_repr_impl)
@class_instance_overload(operator.ne)
def class_ne(x, y):
    """
    Overload != for class instances: __ne__ if defined, else not (x == y).
    """
    # Unlike the other comparison operators this is not registered through
    # register_reflected_overload: per the Python data model, the fallback
    # for != is to invert __eq__, not to reflect the arguments.
    def inverted_eq_impl(x, y):
        return not (x == y)

    return take_first(try_call_method(x, "__ne__", 2), inverted_eq_impl)
def register_reflected_overload(func, meth_forward, meth_reflected):
    """
    Register a ClassInstanceType overload for the comparison operator func.

    The implementation calls x.__<meth_forward>__(y) when the first operand
    defines that method. Otherwise, if the second operand defines
    __<meth_reflected>__, it calls y.__<meth_reflected>__(x), i.e. the
    reflected method with the operands swapped. If neither method is
    defined, no implementation is returned.
    """
    forward_name = f"__{meth_forward}__"
    reflected_name = f"__{meth_reflected}__"

    def class_compare(x, y):
        normal_impl = try_call_method(x, forward_name, 2)

        # Only the first operand is checked to be a ClassInstanceType, so do
        # not assume the second one has a jit_methods attribute.
        if reflected_name in getattr(y, "jit_methods", ()):
            # Call the reflected method that was actually requested; a
            # hard-coded `y > x` is only correct for operator.lt.
            template = f"""
def func(x, y):
    return y.{reflected_name}(x)
"""
            reflected_impl = extract_template(template, "func")
        else:
            reflected_impl = None

        return take_first(normal_impl, reflected_impl)

    class_instance_overload(func)(class_compare)
# Built-ins that map onto a single dunder method.
for _func, _attr in ((abs, "abs"), (len, "len"), (hash, "hash")):
    register_simple_overload(_func, _attr)

# Comparison operators: (operator, forward method, reflected method).
# Note that eq is missing support for fallback to `x is y`, but `is` and
# `operator.is_` are presently unsupported in general.
for _func, _forward, _reflected in (
    (operator.ge, "ge", "le"),
    (operator.gt, "gt", "lt"),
    (operator.le, "le", "ge"),
    (operator.lt, "lt", "gt"),
    (operator.eq, "eq", "eq"),
):
    register_reflected_overload(_func, _forward, _reflected)

# Remaining operators: (operator, dunder names in preference order, n_args).
for _func, _attrs, _n_args in (
    # Arithmetic operators.
    (operator.add, ("add",), 2),
    (operator.floordiv, ("floordiv",), 2),
    (operator.lshift, ("lshift",), 2),
    (operator.mul, ("mul",), 2),
    (operator.mod, ("mod",), 2),
    (operator.neg, ("neg",), 1),
    (operator.pos, ("pos",), 1),
    (operator.pow, ("pow",), 2),
    (operator.rshift, ("rshift",), 2),
    (operator.sub, ("sub",), 2),
    (operator.truediv, ("truediv",), 2),
    # Inplace arithmetic operators.
    (operator.iadd, ("iadd", "add"), 2),
    (operator.ifloordiv, ("ifloordiv", "floordiv"), 2),
    (operator.ilshift, ("ilshift", "lshift"), 2),
    (operator.imul, ("imul", "mul"), 2),
    (operator.imod, ("imod", "mod"), 2),
    (operator.ipow, ("ipow", "pow"), 2),
    (operator.irshift, ("irshift", "rshift"), 2),
    (operator.isub, ("isub", "sub"), 2),
    (operator.itruediv, ("itruediv", "truediv"), 2),
    # Logical operators.
    (operator.and_, ("and",), 2),
    (operator.or_, ("or",), 2),
    (operator.xor, ("xor",), 2),
    # Inplace logical operators.
    (operator.iand, ("iand", "and"), 2),
    (operator.ior, ("ior", "or"), 2),
    (operator.ixor, ("ixor", "xor"), 2),
):
    register_simple_overload(_func, *_attrs, n_args=_n_args)

# Do not leave the loop variables behind in the module namespace.
del _func, _attr, _forward, _reflected, _attrs, _n_args