Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
213 lines (175 sloc) 6.49 KB
from collections import deque
from numba.core import types, cgutils
class DataPacker(object):
    """
    Pack a number of typed values into an anonymous data structure and
    load them back out again.

    Positions whose front-end type is ``types.Omitted`` are skipped: they
    occupy no slot in the packed structure and are never loaded.
    """
    # XXX should DataPacker be a model for a dedicated type?

    def __init__(self, dmm, fe_types):
        self._dmm = dmm
        self._fe_types = fe_types
        # One data model per front-end type, indexed by formal position.
        self._models = [dmm.lookup(ty) for ty in fe_types]
        # Formal positions that actually get packed: slot k of the packed
        # structure holds formal argument self._pack_map[k].
        self._pack_map = [pos for pos, ty in enumerate(fe_types)
                          if not isinstance(ty, types.Omitted)]
        # Data types of the packed slots, in slot order.
        self._be_types = [self._models[pos].get_data_type()
                          for pos in self._pack_map]

    def as_data(self, builder, values):
        """
        Return *values* packed as an anonymous data structure.

        *values* is indexed by formal position; entries at Omitted
        positions are never read.
        """
        fields = [self._models[pos].as_data(builder, values[pos])
                  for pos in self._pack_map]
        return cgutils.make_anonymous_struct(builder, fields)

    def _do_load(self, builder, ptr, formal_list=None):
        """
        Load every packed slot from *ptr*, in slot order.

        Without *formal_list*, return a list of ``(front-end type, value)``
        pairs. With *formal_list*, store each value at its formal position
        in that sequence instead and return an empty list.
        """
        pairs = []
        if formal_list is None:
            def store(pos, value):
                pairs.append((self._fe_types[pos], value))
        else:
            def store(pos, value):
                formal_list[pos] = value

        for slot, pos in enumerate(self._pack_map):
            field_ptr = cgutils.gep_inbounds(builder, ptr, 0, slot)
            store(pos, self._models[pos].load_from_data_pointer(builder,
                                                                field_ptr))
        return pairs

    def load(self, builder, ptr):
        """
        Load the packed values and return them as a list of
        ``(type, value)`` tuples.
        """
        return self._do_load(builder, ptr, formal_list=None)

    def load_into(self, builder, ptr, formal_list):
        """
        Load the packed values into *formal_list*, a sequence indexed by
        formal argument number; Omitted positions are left untouched.
        """
        self._do_load(builder, ptr, formal_list=formal_list)
class ArgPacker(object):
    """
    Compute the position for each high-level typed argument.

    Every composite argument is flattened into primitive types, and a
    position map is kept so the flattened arguments can be unflattened
    again.

    Structs (especially nested structs) have specific ABI requirements
    (e.g. alignment, pointer address space, ...) on different architectures
    (e.g. OpenCL, CUDA); flattening composite argument types simplifies the
    call setup from the Python side. Functions then only receive simple
    primitive types, of which there are only a handful.
    """

    def __init__(self, dmm, fe_args):
        self._dmm = dmm
        self._fe_args = fe_args
        self._nargs = len(fe_args)
        # Resolve each argument's data model and its (possibly nested)
        # argument type, one argument at a time and in order.
        resolved = [(model, model.get_argument_type())
                    for model in map(self._dmm.lookup, fe_args)]
        self._dm_args = [model for model, _ in resolved]
        argtys = [argty for _, argty in resolved]
        # The nesting of argtys is the pattern used to rebuild per-argument
        # values from their flattened form.
        self._unflattener = _Unflattener(argtys)
        self._be_args = list(_flatten(argtys))

    def as_arguments(self, builder, values):
        """Flatten all argument values into one tuple.

        Raises TypeError when the number of *values* differs from the number
        of front-end arguments this packer was built for.
        """
        nvalues = len(values)
        if nvalues != self._nargs:
            raise TypeError("invalid number of args: expected %d, got %d"
                            % (self._nargs, nvalues))
        if not values:
            return ()
        nested = [model.as_argument(builder, value)
                  for model, value in zip(self._dm_args, values)]
        return tuple(_flatten(nested))

    def from_arguments(self, builder, args):
        """Unflatten all argument values.

        Returns a list with one value per argument, produced by that
        argument's data model from its slice of the unflattened *args*.
        """
        tree = self._unflattener.unflatten(args)
        return [model.from_argument(builder, node)
                for model, node in zip(self._dm_args, tree)]

    def assign_names(self, args, names):
        """Assign names to the flattened argument values.

        *args* is unflattened with the argument pattern; each top-level
        argument takes the matching entry of *names*, and nested leaves get
        that name suffixed with their dotted position (e.g. ``name.1.0``).
        Surplus entries on either side are ignored (``zip`` semantics).
        """
        tree = self._unflattener.unflatten(args)
        for node, base_name in zip(tree, names):
            self._assign_names(node, base_name)

    def _assign_names(self, val_or_nested, name, depth=()):
        # Leaf: set ``.name`` to "<name>.<i>.<j>..." with empty parts dropped.
        if not isinstance(val_or_nested, (tuple, list)):
            suffix = '.'.join(str(index) for index in depth)
            val_or_nested.name = '.'.join(
                part for part in (name, suffix) if part)
            return
        # Nested sequence: recurse, extending the position path.
        for pos, child in enumerate(val_or_nested):
            self._assign_names(child, name, depth=depth + (pos,))

    @property
    def argument_types(self):
        """Return a tuple of the LLVM types that result from flattening
        composite types.
        """
        flattened = self._be_args
        return tuple([ty for ty in flattened if ty != ()])
def _flatten(iterable):
    """
    Lazily yield the leaves of a nested iterable of (tuple, list).

    Only ``tuple`` and ``list`` instances are descended into, depth-first
    and left to right; empty ones contribute nothing, and every other
    object (strings included) is yielded as-is.
    """
    for item in iterable:
        if isinstance(item, (tuple, list)):
            yield from _flatten(item)
        else:
            yield item
# Opcodes produced by _Unflattener._build_unflatten_code() and replayed by
# _Unflattener.unflatten().
_PUSH_LIST = 1           # open a new nested list and descend into it
_APPEND_NEXT_VALUE = 2   # move the next flat value into the current list
_APPEND_EMPTY_TUPLE = 3  # insert () for an empty sequence in the pattern
_POP = 4                 # close the current nested list, back to its parent


class _Unflattener(object):
    """
    An object used to unflatten nested sequences after a given pattern
    (an arbitrarily nested sequence).

    The pattern shows the nested sequence shape desired when unflattening;
    the values it contains are irrelevant.

    The pattern is compiled once, at construction time, into a flat list of
    opcodes which :meth:`unflatten` replays against each flat input.
    """

    def __init__(self, pattern):
        self._code = self._build_unflatten_code(pattern)

    def _build_unflatten_code(self, iterable):
        """Build the unflatten opcode sequence for the given *iterable*
        structure (an iterable of nested sequences).

        A non-empty tuple/list becomes ``_PUSH_LIST``, the opcodes of its
        items, then ``_POP``; an empty one becomes ``_APPEND_EMPTY_TUPLE``;
        any other item becomes ``_APPEND_NEXT_VALUE``.
        """
        code = []

        def rec(iterable):
            for item in iterable:
                if isinstance(item, (tuple, list)):
                    if item:
                        code.append(_PUSH_LIST)
                        rec(item)
                        code.append(_POP)
                    else:
                        code.append(_APPEND_EMPTY_TUPLE)
                else:
                    code.append(_APPEND_NEXT_VALUE)

        rec(iterable)
        return code

    def unflatten(self, flatiter):
        """Rebuild the nested structure described by the pattern.

        Values are taken from *flatiter* in order, one per leaf of the
        pattern. Non-empty nested sequences are rebuilt as lists (whether
        the pattern used tuples or lists), empty ones as ``()``, and the
        outermost container is always a list.

        Raises IndexError if *flatiter* holds fewer values than the pattern
        has leaves; leftover values fail an internal assertion (when
        assertions are enabled).
        """
        vals = deque(flatiter)
        res = []
        cur = res
        stack = []
        for op in self._code:
            # Compare opcodes by value: identity of int objects is a CPython
            # caching detail, not a language guarantee.
            if op == _PUSH_LIST:
                stack.append(cur)
                cur.append([])
                cur = cur[-1]
            elif op == _APPEND_NEXT_VALUE:
                cur.append(vals.popleft())
            elif op == _APPEND_EMPTY_TUPLE:
                cur.append(())
            elif op == _POP:
                cur = stack.pop()
        assert not stack, stack
        assert not vals, vals
        return res
You can’t perform that action at this time.