/
utils.py
105 lines (84 loc) · 3.83 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from collections import OrderedDict
import ctypes
from devito.cgen_utils import INT
from devito.ir.iet import Expression, ForeignExpression, FindNodes, Transformer
from devito.symbolics import FunctionFromPointer, ListInitializer, retrieve_indexed
from devito.tools import ctypes_pointer
# Names exported by ``from <this module> import *``.
__all__ = [
    'make_grid_accesses',
    'make_sharedptr_funcall',
    'rawpointer',
    'split_increment',
]
def make_sharedptr_funcall(call, params, sharedptr):
    """
    Build a call to method ``call`` on the object obtained via ``get`` from
    ``sharedptr``.

    The result is a :class:`FunctionFromPointer` for ``call`` whose pointer is
    itself the :class:`FunctionFromPointer` ``get`` applied to ``sharedptr``
    (presumably dereferencing a C++ shared pointer, per the function name).

    :param call: Name of the method to invoke.
    :param params: Arguments passed to ``call``.
    :param sharedptr: The object on which ``get`` is invoked.
    """
    pointee = FunctionFromPointer('get', sharedptr)
    return FunctionFromPointer(call, pointee, params)
def make_grid_accesses(node):
    """
    Construct a new Iteration/Expression based on ``node``, in which all
    :class:`types.Indexed` accesses have been converted into YASK grid
    accesses.

    Reads of Indexeds whose function is ``from_YASK`` become calls to
    ``namespace['code-grid-get']``. An Expression writing to a ``from_YASK``
    function becomes a :class:`ForeignExpression` wrapping a call to
    ``namespace['code-grid-put']``. Every other Expression is rebuilt with
    only its RHS translated.

    :param node: The IET node to transform.
    :returns: The node produced by applying a :class:`Transformer` that maps
              each original Expression to its translated counterpart.
    """
    def make_grid_gets(expr):
        # Replace every YASK-backed Indexed in ``expr`` with a grid "get" call.
        # Index expressions may themselves contain YASK accesses, hence the
        # recursive call on each index.
        mapper = {}
        indexeds = retrieve_indexed(expr)
        data_carriers = [i for i in indexeds if i.base.function.from_YASK]
        for i in data_carriers:
            name = namespace['code-grid-name'](i.base.function.name)
            args = [ListInitializer([INT(make_grid_gets(j))
                                     for j in i.indices])]
            mapper[i] = make_sharedptr_funcall(namespace['code-grid-get'],
                                               args, name)
        return expr.xreplace(mapper)

    mapper = {}
    for e in FindNodes(Expression).visit(node):
        lhs, rhs = e.expr.args

        # RHS translation
        rhs = make_grid_gets(rhs)

        # LHS translation
        if e.write.from_YASK:
            name = namespace['code-grid-name'](e.write.name)
            # The put call takes the value first, then the index list.
            indices = ListInitializer([INT(make_grid_gets(i))
                                       for i in lhs.indices])
            handle = make_sharedptr_funcall(namespace['code-grid-put'],
                                            [rhs, indices], name)
            processed = ForeignExpression(handle, e.dtype,
                                          is_Increment=e.is_increment)
        else:
            # Writing to a non-YASK target (per the original note, a scalar
            # temporary): keep the LHS, use the translated RHS.
            processed = Expression(e.expr.func(lhs, rhs))
        mapper[e] = processed
    return Transformer(mapper).visit(node)
def rawpointer(obj):
    """
    Reinterpret ``obj`` as a memory address and wrap it in a
    :class:`ctypes.c_void_p`.

    ``obj`` must support ``int()`` conversion. Only that integer value is
    handed to :func:`ctypes.cast`, not ``obj`` itself.
    """
    address = int(obj)
    return ctypes.cast(address, ctypes.c_void_p)
def split_increment(expr):
    """
    Split an increment of type: ::

        u->set_element(v + u->get_element(indices), indices)

    into its three main components, namely the target grid ``u``, the increment
    value ``v``, and the :class:`ListInitializer` ``indices``.

    Only the structural shape is checked:

    * ``expr`` must be a :class:`FunctionFromPointer` with exactly two
      parameters.
    * The second parameter must be a :class:`ListInitializer`.
    * The first parameter must be a two-term addition in which exactly one
      term is not a :class:`FunctionFromPointer`.

    The called method names, and the target/indices of the inner call, are
    not verified.

    :param expr: The expression to decompose.
    :returns: A 3-tuple ``(target, value, indices)``.
    :raises ValueError: If ``expr`` is not an increment or does not appear in
                        the normal form above.
    """
    if not isinstance(expr, FunctionFromPointer) or len(expr.params) != 2:
        raise ValueError("split_increment: expected a FunctionFromPointer with "
                         "exactly two parameters, got %s"
                         % type(expr).__name__)
    target = expr.pointer
    # Use a distinct name rather than rebinding the ``expr`` parameter.
    increment, indices = expr.params
    if not isinstance(indices, ListInitializer):
        raise ValueError("split_increment: second parameter must be a "
                         "ListInitializer, got %s" % type(indices).__name__)
    if not increment.is_Add or len(increment.args) != 2:
        raise ValueError("split_increment: first parameter must be a two-term "
                         "addition, got %s" % type(increment).__name__)
    # The grid read is the FunctionFromPointer term; the other is the value.
    values = [i for i in increment.args
              if not isinstance(i, FunctionFromPointer)]
    if len(values) != 1:
        raise ValueError("split_increment: expected exactly one "
                         "non-FunctionFromPointer term in the addition, "
                         "found %d" % len(values))
    return target, values[0], indices
# YASK conventions: naming templates, C++ identifiers and ctypes used by the
# YASK backend. Built from an ordered list of pairs so key order is explicit.
namespace = OrderedDict([
    # Name templates taking two arguments, formatted as
    # devito_<i>_<kind><j> (``j`` is formatted with %d).
    ('jit-yc-hook', lambda i, j: 'devito_%s_yc_hook%d' % (i, j)),
    ('jit-yk-hook', lambda i, j: 'devito_%s_yk_hook%d' % (i, j)),
    ('jit-yc-soln', lambda i, j: 'devito_%s_yc_soln%d' % (i, j)),
    ('jit-yk-soln', lambda i, j: 'devito_%s_yk_soln%d' % (i, j)),
    ('kernel-filename', 'yask_stencil_code.hpp'),
    # Solution-level C++ identifiers.
    ('code-soln-type', 'yask::yk_solution'),
    ('code-soln-name', 'soln'),
    ('code-soln-run', 'run_solution'),
    # Grid-level C++ identifiers; 'code-grid-name' maps a function name to
    # its grid variable name.
    ('code-grid-type', 'yask::yk_grid'),
    ('code-grid-name', lambda i: "grid_%s" % str(i)),
    ('code-grid-get', 'get_element'),
    ('code-grid-put', 'set_element'),
    ('code-grid-add', 'add_to_element'),
    # ctypes pointer types for the solution and grid handles.
    ('type-solution', ctypes_pointer('yask::yk_solution_ptr')),
    ('type-grid', ctypes_pointer('yask::yk_grid_ptr')),
    # NOTE(review): -1 presumably selects a default/local NUMA placement;
    # confirm against the YASK API.
    ('numa-put-local', -1),
])