/
transform.py
78 lines (61 loc) · 2.11 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import ast
import inspect
import sys
import textwrap
def rewrite_ndreduce(func):
    """Transform an aggregation function into something numba can handle.

    More precisely, it converts a function whose source looks like::

        @ndreduce
        def my_func(x):
            ...
            return foo

    into::

        def __numbagg_transformed_func(x, __numbagg_out):
            ...
            __numbagg_out[0] = foo

    which is the form numba needs for writing a gufunc that returns a scalar
    value. Decorators on the original definition are dropped from the
    rewritten one.

    Parameters
    ----------
    func : function
        Function whose source code can be retrieved with
        ``inspect.getsource``.

    Returns
    -------
    function
        The rewritten function, defined against ``func.__globals__``.

    Raises
    ------
    TypeError
        If the rewritten source does not define the transformed function.
    """
    return _apply_ast_rewrite(func, _NDReduceTransformer())
# Name of the output-array parameter appended to every rewritten signature;
# each `return <expr>` in the reducer becomes an assignment to index 0 of it.
_OUT_NAME: str = "__numbagg_out"
# Name given to the rewritten function so it can be looked up in the scope
# populated by exec() after the AST rewrite.
_TRANSFORMED_FUNC_NAME: str = "__numbagg_transformed_func"
def _apply_ast_rewrite(func, node_transformer):
    """Recompile ``func`` after passing its source AST through a transformer.

    A hack to make the syntax for writing aggregators more Pythonic.
    This should go away once numba is more fully featured.

    Parameters
    ----------
    func : function
        Function whose source code can be retrieved with
        ``inspect.getsource``.
    node_transformer : ast.NodeTransformer
        Transformer applied to the parsed module; it must produce a
        module-level function named ``_TRANSFORMED_FUNC_NAME``.

    Returns
    -------
    function
        The newly defined function, executed against ``func.__globals__``.

    Raises
    ------
    TypeError
        If the transformed code does not define ``_TRANSFORMED_FUNC_NAME``.
    """
    orig_source = inspect.getsource(func)
    # getsource() keeps the definition's original indentation, so a function
    # defined inside a class or another function would not parse without this.
    tree = ast.parse(textwrap.dedent(orig_source))
    tree = node_transformer.visit(tree)
    ast.fix_missing_locations(tree)
    code = compile(tree, filename="<ast>", mode="exec")
    scope: dict = {}
    # Execute against the defining module's globals so names used in the body
    # resolve as they did originally. Free variables from an enclosing
    # function are NOT available here, only module globals.
    exec(code, func.__globals__, scope)
    try:
        return scope[_TRANSFORMED_FUNC_NAME]
    except KeyError:
        # The missing-key traceback adds nothing; report the source instead.
        raise TypeError(
            "failed to rewrite function definition:\n%s" % orig_source
        ) from None
class _NDReduceTransformer(ast.NodeTransformer):
def visit_FunctionDef(self, node):
args = node.args.args + [ast.arg(arg=_OUT_NAME, annotation=None)]
arguments = ast.arguments(
args=args,
vararg=None,
kwonlyargs=[],
kw_defaults=[],
kwarg=None,
defaults=[],
posonlyargs=[],
)
function_def = ast.FunctionDef(
name=_TRANSFORMED_FUNC_NAME,
args=arguments,
body=node.body,
decorator_list=[],
)
return self.generic_visit(function_def)
def visit_Return(self, node):
subscript = ast.Subscript(
value=ast.Name(id=_OUT_NAME, ctx=ast.Load()),
slice=ast.Index(value=ast.Num(n=0)),
ctx=ast.Store(),
)
assign = ast.Assign(targets=[subscript], value=node.value)
return assign