Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
"""
Provides wrapper functions for "glueing" together Numba implementations that are
written in the "old" style of a separate typing and lowering implementation.
"""
import types as pytypes
import textwrap
from threading import RLock
from collections import defaultdict
from numba.core import errors
class _OverloadWrapper(object):
"""This class does all the work of assembling and registering wrapped split
implementations.
"""
def __init__(self, function, typing_key=None):
assert function is not None
self._function = function
self._typing_key = typing_key
self._BIND_TYPES = dict()
self._selector = None
self._TYPER = None
# run to register overload, the intrinsic sorts out the binding to the
# registered impls at the point the overload is evaluated, i.e. this
# is all lazy.
self._build()
def _stub_generator(self, nargs, body_func, kwargs=None):
"""This generates a function that takes "nargs" count of arguments
and the presented kwargs, the "body_func" is the function that'll
type the overloaded function and then work out which lowering to
return"""
def stub(tyctx):
# body is supplied when the function is magic'd into life via glbls
return body(tyctx) # noqa: F821
if kwargs is None:
kwargs = {}
# create new code parts
stub_code = stub.__code__
co_args = [stub_code.co_argcount + nargs + len(kwargs)]
new_varnames = [*stub_code.co_varnames]
new_varnames.extend([f'tmp{x}' for x in range(nargs)])
new_varnames.extend([x for x, _ in kwargs.items()])
from numba.core import utils
if utils.PYVERSION >= (3, 8):
co_args.append(stub_code.co_posonlyargcount)
co_args.append(stub_code.co_kwonlyargcount)
co_args.extend([stub_code.co_nlocals + nargs + len(kwargs),
stub_code.co_stacksize,
stub_code.co_flags,
stub_code.co_code,
stub_code.co_consts,
stub_code.co_names,
tuple(new_varnames),
stub_code.co_filename,
stub_code.co_name,
stub_code.co_firstlineno,
stub_code.co_lnotab,
stub_code.co_freevars,
stub_code.co_cellvars
])
new_code = pytypes.CodeType(*co_args)
# get function
new_func = pytypes.FunctionType(new_code, {'body': body_func})
return new_func
def wrap_typing(self):
"""
Use this to replace @infer_global, it records the decorated function
as a typer for the argument `concrete_function`.
"""
if self._typing_key is None:
key = self._function
else:
key = self._typing_key
def inner(typing_class):
# Note that two templates could be used for the same function, to
# avoid @infer_global etc the typing template is copied. This is to
# ensure there's a 1:1 relationship between the typing templates and
# their keys.
clazz_dict = dict(typing_class.__dict__)
clazz_dict['key'] = key
cloned = type(f"cloned_template_for_{key}", typing_class.__bases__,
clazz_dict)
self._TYPER = cloned
_overload_glue.add_no_defer(key)
self._build()
return typing_class
return inner
def wrap_impl(self, *args):
"""
Use this to replace @lower*, it records the decorated function as the
lowering implementation
"""
assert self._TYPER is not None
def inner(lowerer):
self._BIND_TYPES[args] = lowerer
return lowerer
return inner
def _assemble(self):
"""Assembles the OverloadSelector definitions from the registered
typing to lowering map.
"""
from numba.core.base import OverloadSelector
if self._typing_key is None:
key = self._function
else:
key = self._typing_key
_overload_glue.flush_deferred_lowering(key)
self._selector = OverloadSelector()
msg = f"No entries in the typing->lowering map for {self._function}"
assert self._BIND_TYPES, msg
for sig, impl in self._BIND_TYPES.items():
self._selector.append(impl, sig)
def _build(self):
from numba.core.extending import overload, intrinsic
@overload(self._function, strict=False,
jit_options={'forceinline': True})
def ol_generated(*ol_args, **ol_kwargs):
def body(tyctx):
msg = f"No typer registered for {self._function}"
if self._TYPER is None:
raise errors.InternalError(msg)
typing = self._TYPER(tyctx)
sig = typing.apply(ol_args, ol_kwargs)
if sig is None:
# this follows convention of something not typeable
# returning None
return None
if self._selector is None:
self._assemble()
lowering = self._selector.find(sig.args)
msg = (f"Could not find implementation to lower {sig} for ",
f"{self._function}")
if lowering is None:
raise errors.InternalError(msg)
return sig, lowering
stub = self._stub_generator(len(ol_args), body, ol_kwargs)
intrin = intrinsic(stub)
# This is horrible, need to generate a jit wrapper function that
# walks the ol_kwargs into the intrin with a signature that
# matches the lowering sig. The actual kwarg var names matter,
# they have to match exactly.
arg_str = ','.join([f'tmp{x}' for x in range(len(ol_args))])
kws_str = ','.join(ol_kwargs.keys())
call_str = ','.join([x for x in (arg_str, kws_str) if x])
# NOTE: The jit_wrapper functions cannot take `*args`
# albeit this an obvious choice for accepting an unknown number
# of arguments. If this is done, `*args` ends up as a cascade of
# Tuple assembling in the IR which ends up with literal
# information being lost. As a result the _exact_ argument list
# is generated to match the number of arguments and kwargs.
name = str(self._function)
# This is to name the function with something vaguely identifiable
name = ''.join([x if x not in {'>','<',' ','-','.'} else '_'
for x in name])
gen = textwrap.dedent(("""
def jit_wrapper_{}({}):
return intrin({})
""")).format(name, call_str, call_str)
l = {}
g = {'intrin': intrin}
exec(gen, g, l)
return l['jit_wrapper_{}'.format(name)]
class _Gluer:
"""This is a helper class to make sure that each concrete overload has only
one wrapper as the code relies on the wrapper being a singleton."""
def __init__(self):
self._registered = dict()
self._lock = RLock()
# `_no_defer` stores keys that should not defer lowering because typing
# is already provided.
self._no_defer = set()
# `_deferred` stores lowering that must be deferred because the typing
# has not been provided.
self._deferred = defaultdict(list)
def __call__(self, func, typing_key=None):
with self._lock:
if typing_key is None:
key = func
else:
key = typing_key
if key in self._registered:
return self._registered[key]
else:
wrapper = _OverloadWrapper(func, typing_key=typing_key)
self._registered[key] = wrapper
return wrapper
def defer_lowering(self, key, lower_fn):
"""Defer lowering of the given key and lowering function.
"""
with self._lock:
if key in self._no_defer:
# Key is marked as no defer, register lowering now
lower_fn()
else:
# Defer
self._deferred[key].append(lower_fn)
def add_no_defer(self, key):
"""Stop lowering to be deferred for the given key.
"""
with self._lock:
self._no_defer.add(key)
def flush_deferred_lowering(self, key):
"""Flush the deferred lowering for the given key.
"""
with self._lock:
deferred = self._deferred.pop(key, [])
for cb in deferred:
cb()
# The single, module-wide registry instance. _OverloadWrapper, glue_typing and
# glue_lowering all go through it. The class name is deleted right after,
# presumably so that no second registry can be created by accident.
_overload_glue = _Gluer()
del _Gluer
def glue_typing(concrete_function, typing_key=None):
    """Decorator that registers the typing half of a split implementation.

    It is a text-only replacement for '@infer_global': the decorated typing
    template is attached to the single registered wrapper for
    'concrete_function' (keyed by 'typing_key' instead, when one is given).
    """
    wrapper = _overload_glue(concrete_function, typing_key=typing_key)
    return wrapper.wrap_typing()
def glue_lowering(*args):
    """Decorator that registers the lowering half of a split implementation.

    'args[0]' is the concrete function (the registry key) and 'args[1:]' are
    the types the lowering accepts. This is a text-only replacement for
    '@lower/@lower_builtin'. The actual registration goes through the
    registry's defer mechanism, so it runs now only if typing for the key is
    already known. The decorated function is returned unchanged.
    """
    def decorator(fn):
        key = args[0]

        def register():
            wrapper = _overload_glue(key, typing_key=key)
            return wrapper.wrap_impl(*args[1:])(fn)

        _overload_glue.defer_lowering(key, register)
        return fn
    return decorator