Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add no_unliteral flag to overloads #6043

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 9 additions & 5 deletions numba/core/extending.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def generic(self):
_overload_default_jit_options = {'no_cpython_wrapper': True}


def overload(func, jit_options={}, strict=True, inline='never'):
def overload(func, jit_options={}, strict=True, inline="never", no_unliteral=False):
"""
A decorator marking the decorated function as typing and implementing
*func* in nopython mode.
Expand Down Expand Up @@ -109,8 +109,9 @@ def len_impl(seq):
opts.update(jit_options) # let user options override

def decorate(overload_func):
template = make_overload_template(func, overload_func, opts, strict,
inline)
template = make_overload_template(
func, overload_func, opts, strict, inline, no_unliteral
)
infer(template)
if callable(func):
infer_global(func, types.Function(template))
Expand Down Expand Up @@ -205,8 +206,11 @@ def take_impl(arr, indices):

def decorate(overload_func):
template = make_overload_method_template(
typ, attr, overload_func,
inline=kwargs.get('inline', 'never'),
typ,
attr,
overload_func,
inline=kwargs.get("inline", "never"),
no_unliteral=kwargs.get("no_unliteral", False),
)
infer_getattr(template)
overload(overload_func, **kwargs)(overload_func)
Expand Down
48 changes: 26 additions & 22 deletions numba/core/types/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,16 +266,15 @@ def get_impl_key(self, sig):
return self._impl_keys[sig.args]

def get_call_type(self, context, args, kws):
failures = _ResolutionFailures(context, self, args, kws,
depth=self._depth)
failures = _ResolutionFailures(context, self, args, kws, depth=self._depth)
self._depth += 1
for temp_cls in self.templates:
temp = temp_cls(context)
for uselit in [True, False]:
try:
if uselit:
sig = temp.apply(args, kws)
else:
elif not getattr(temp, "_no_unliteral", False):
nolitargs = tuple([unliteral(a) for a in args])
nolitkws = {k: unliteral(v) for k, v in kws.items()}
sig = temp.apply(nolitargs, nolitkws)
Expand All @@ -288,19 +287,22 @@ def get_call_type(self, context, args, kws):
self._depth -= 1
return sig
else:
registered_sigs = getattr(temp, 'cases', None)
registered_sigs = getattr(temp, "cases", None)
if registered_sigs is not None:
msg = "No match for registered cases:\n%s"
msg = msg % '\n'.join(" * {}".format(x) for x in
registered_sigs)
msg = msg % "\n".join(
" * {}".format(x) for x in registered_sigs
)
else:
msg = 'No match.'
msg = "No match."
failures.add_error(temp, True, msg, uselit)

if len(failures) == 0:
raise AssertionError("Internal Error. "
"Function resolution ended with no failures "
"or successful signature")
raise AssertionError(
"Internal Error. "
"Function resolution ended with no failures "
"or successful signature"
)
failures.raise_error()

def get_call_signatures(self):
Expand Down Expand Up @@ -361,7 +363,6 @@ def get_call_type(self, context, args, kws):
literal_e = None
nonliteral_e = None


# Try with Literal
try:
out = template.apply(args, kws)
Expand All @@ -380,7 +381,7 @@ def get_call_type(self, context, args, kws):
# If the above template application failed and the non-literal args are
# different to the literal ones, try again with literals rewritten as
# non-literals
if not skip and out is None:
if not skip and out is None and not getattr(template, "_no_unliteral", False):
try:
out = template.apply(unliteral_args, unliteral_kws)
except Exception as exc:
Expand All @@ -391,30 +392,33 @@ def get_call_type(self, context, args, kws):
if out is None and (nonliteral_e is not None or literal_e is not None):
header = "- Resolution failure for {} arguments:\n{}\n"
tmplt = _termcolor.highlight(header)
if config.DEVELOPER_MODE:
indent = ' ' * 4
if numba.core.config.DEVELOPER_MODE:
indent = " " * 4

def add_bt(error):
if isinstance(error, BaseException):
# if the error is an actual exception instance, trace it
bt = traceback.format_exception(type(error), error,
error.__traceback__)
bt = traceback.format_exception(
type(error), error, error.__traceback__
)
else:
bt = [""]
nd2indent = '\n{}'.format(2 * indent)
errstr = _termcolor.reset(nd2indent +
nd2indent.join(_bt_as_lines(bt)))
nd2indent = "\n{}".format(2 * indent)
errstr += _termcolor.reset(nd2indent + nd2indent.join(bt_as_lines))
return _termcolor.reset(errstr)

else:
add_bt = lambda X: ''
add_bt = lambda X: ""

def nested_msg(literalness, e):
estr = str(e)
estr = estr if estr else (str(repr(e)) + add_bt(e))
new_e = errors.TypingError(textwrap.dedent(estr))
return tmplt.format(literalness, str(new_e))

raise errors.TypingError(nested_msg('literal', literal_e) +
nested_msg('non-literal', nonliteral_e))
raise errors.TypingError(
nested_msg("literal", literal_e) + nested_msg("non-literal", nonliteral_e)
)
return out


Expand Down
63 changes: 45 additions & 18 deletions numba/core/typing/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,19 +763,27 @@ def get_template_info(self):
return info


def make_overload_template(func, overload_func, jit_options, strict,
inline):
def make_overload_template(
func, overload_func, jit_options, strict, inline, no_unliteral
):
"""
Make a template class for function *func* overloaded by *overload_func*.
Compiler options are passed as a dictionary to *jit_options*.
"""
func_name = getattr(func, '__name__', str(func))
func_name = getattr(func, "__name__", str(func))
name = "OverloadTemplate_%s" % (func_name,)
base = _OverloadFunctionTemplate
dct = dict(key=func, _overload_func=staticmethod(overload_func),
_impl_cache={}, _compiled_overloads={}, _jit_options=jit_options,
_strict=strict, _inline=staticmethod(InlineOptions(inline)),
_inline_overloads={})
base = numba.core.typing.templates._OverloadFunctionTemplate
dct = dict(
key=func,
_overload_func=staticmethod(overload_func),
_impl_cache={},
_compiled_overloads={},
_jit_options=jit_options,
_strict=strict,
_inline=staticmethod(InlineOptions(inline)),
_inline_overloads={},
_no_unliteral=no_unliteral,
)
return type(base)(name, (base,), dct)


Expand Down Expand Up @@ -965,6 +973,7 @@ def _resolve(self, typ, attr):
class MethodTemplate(AbstractTemplate):
key = (self.key, attr)
_inline = self._inline
_no_unliteral = getattr(self, "_no_unliteral", False)
_overload_func = staticmethod(self._overload_func)
_inline_overloads = self._inline_overloads

Expand All @@ -981,35 +990,49 @@ def generic(_, args, kws):
return types.BoundFunction(MethodTemplate, typ)


def make_overload_attribute_template(typ, attr, overload_func, inline,
base=_OverloadAttributeTemplate):
def make_overload_attribute_template(
typ,
attr,
overload_func,
inline,
no_unliteral=False,
base=_OverloadAttributeTemplate,
):
"""
Make a template class for attribute *attr* of *typ* overloaded by
*overload_func*.
"""
assert isinstance(typ, types.Type) or issubclass(typ, types.Type)
name = "OverloadAttributeTemplate_%s_%s" % (typ, attr)
# Note the implementation cache is subclass-specific
dct = dict(key=typ, _attr=attr, _impl_cache={},
_inline=staticmethod(InlineOptions(inline)),
_inline_overloads={},
_overload_func=staticmethod(overload_func),
)
dct = dict(
key=typ,
_attr=attr,
_impl_cache={},
_inline=staticmethod(InlineOptions(inline)),
_inline_overloads={},
_no_unliteral=no_unliteral,
_overload_func=staticmethod(overload_func),
)
return type(base)(name, (base,), dct)


def make_overload_method_template(typ, attr, overload_func, inline):
def make_overload_method_template(typ, attr, overload_func, inline, no_unliteral):
"""
Make a template class for method *attr* of *typ* overloaded by
*overload_func*.
"""
return make_overload_attribute_template(
typ, attr, overload_func, inline=inline,
typ,
attr,
overload_func,
inline=inline,
no_unliteral=no_unliteral,
base=_OverloadMethodTemplate,
)


def bound_function(template_key):
def bound_function(template_key, no_unliteral=False):
"""
Wrap an AttributeTemplate resolve_* method to allow it to
resolve an instance method's signature rather than a instance attribute.
Expand All @@ -1026,6 +1049,7 @@ def resolve_conjugate(self, ty, args, kwds):
*template_key* (e.g. "complex.conjugate" above) will be used by the
target to look up the method's implementation, as a regular function.
"""

def wrapper(method_resolver):
@functools.wraps(method_resolver)
def attribute_resolver(self, ty):
Expand All @@ -1038,8 +1062,11 @@ def generic(_, args, kws):
sig = sig.replace(recvr=ty)
return sig

MethodTemplate._no_unliteral = no_unliteral
return types.BoundFunction(MethodTemplate, ty)

return attribute_resolver

return wrapper


Expand Down