Skip to content

Commit

Permalink
Merge pull request #6058 from sklam/fix/overspecialize_option
Browse files Browse the repository at this point in the history
Add prefer_literal option to overload API
  • Loading branch information
sklam committed Aug 4, 2020
2 parents 8b40f73 + 50d3df6 commit bed4d04
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 36 deletions.
12 changes: 10 additions & 2 deletions numba/core/extending.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def generic(self):
_overload_default_jit_options = {'no_cpython_wrapper': True}


def overload(func, jit_options={}, strict=True, inline='never'):
def overload(func, jit_options={}, strict=True, inline='never',
prefer_literal=False):
"""
A decorator marking the decorated function as typing and implementing
*func* in nopython mode.
Expand Down Expand Up @@ -101,6 +102,12 @@ def len_impl(seq):
holds the information from the callee. The function should return Truthy
to determine whether to inline, this essentially permitting custom
inlining rules (typical use might be cost models).
The *prefer_literal* option allows users to control if literal types should
be tried first or last. The default (`False`) is to use non-literal types.
Implementations that can specialize based on literal values should set the
option to `True`. Note, this option maybe expanded in the near future to
allow for more control (e.g. disabling non-literal types).
"""
from numba.core.typing.templates import make_overload_template, infer_global

Expand All @@ -110,7 +117,7 @@ def len_impl(seq):

def decorate(overload_func):
template = make_overload_template(func, overload_func, opts, strict,
inline)
inline, prefer_literal)
infer(template)
if callable(func):
infer_global(func, types.Function(template))
Expand Down Expand Up @@ -207,6 +214,7 @@ def decorate(overload_func):
template = make_overload_method_template(
typ, attr, overload_func,
inline=kwargs.get('inline', 'never'),
prefer_literal=kwargs.get('prefer_literal', False)
)
infer_getattr(template)
overload(overload_func, **kwargs)(overload_func)
Expand Down
64 changes: 38 additions & 26 deletions numba/core/types/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,33 +375,45 @@ def get_call_type(self, context, args, kws):
template = self.template(context)
literal_e = None
nonliteral_e = None
out = None


# Try with Literal
try:
out = template.apply(args, kws)
except Exception as exc:
if isinstance(exc, errors.ForceLiteralArg):
raise exc
literal_e = exc
out = None

# if the unliteral_args and unliteral_kws are the same as the literal
# ones, set up to not bother retrying
unliteral_args = tuple([unliteral(a) for a in args])
unliteral_kws = {k: unliteral(v) for k, v in kws.items()}
skip = unliteral_args == args and kws == unliteral_kws

# If the above template application failed and the non-literal args are
# different to the literal ones, try again with literals rewritten as
# non-literals
if not skip and out is None:
try:
out = template.apply(unliteral_args, unliteral_kws)
except Exception as exc:
if isinstance(exc, errors.ForceLiteralArg):
raise exc
nonliteral_e = exc
choice = [True, False] if template.prefer_literal else [False, True]
for uselit in choice:
if uselit:
# Try with Literal
try:
out = template.apply(args, kws)
except Exception as exc:
if isinstance(exc, errors.ForceLiteralArg):
raise exc
literal_e = exc
out = None
else:
break
else:
# if the unliteral_args and unliteral_kws are the same as the literal
# ones, set up to not bother retrying
unliteral_args = tuple([_unlit_non_poison(a) for a in args])
unliteral_kws = {k: _unlit_non_poison(v)
for k, v in kws.items()}
skip = unliteral_args == args and kws == unliteral_kws

# If the above template application failed and the non-literal args are
# different to the literal ones, try again with literals rewritten as
# non-literals
if not skip and out is None:
try:
out = template.apply(unliteral_args, unliteral_kws)
except Exception as exc:
if isinstance(exc, errors.ForceLiteralArg):
if template.prefer_literal:
# For template that prefers literal types,
# reaching here means that the literal types
# have failed typing as well.
raise exc
nonliteral_e = exc
else:
break

if out is None and (nonliteral_e is not None or literal_e is not None):
header = "- Resolution failure for {} arguments:\n{}\n"
Expand Down
7 changes: 4 additions & 3 deletions numba/core/typing/arraydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,9 @@ def generic_index(self, args, kws):
assert not kws
return signature(types.intp, recvr=self.this)

def install_array_method(name, generic):
my_attr = {"key": "array." + name, "generic": generic}
def install_array_method(name, generic, prefer_literal=True):
my_attr = {"key": "array." + name, "generic": generic,
"prefer_literal": prefer_literal}
temp_class = type("Array_" + name, (AbstractTemplate,), my_attr)
def array_attribute_attachment(self, ary):
return types.BoundFunction(temp_class, ary)
Expand All @@ -758,7 +759,7 @@ def array_attribute_attachment(self, ary):

# Functions that return a machine-width type, to avoid overflows
install_array_method("prod", generic_expand)
install_array_method("sum", sum_expand)
install_array_method("sum", sum_expand, prefer_literal=True)

# Functions that return a machine-width type, to avoid overflows
for fname in ["cumsum", "cumprod"]:
Expand Down
15 changes: 10 additions & 5 deletions numba/core/typing/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ def get_template_info(self):


def make_overload_template(func, overload_func, jit_options, strict,
inline):
inline, prefer_literal=False):
"""
Make a template class for function *func* overloaded by *overload_func*.
Compiler options are passed as a dictionary to *jit_options*.
Expand All @@ -793,7 +793,7 @@ def make_overload_template(func, overload_func, jit_options, strict,
dct = dict(key=func, _overload_func=staticmethod(overload_func),
_impl_cache={}, _compiled_overloads={}, _jit_options=jit_options,
_strict=strict, _inline=staticmethod(InlineOptions(inline)),
_inline_overloads={})
_inline_overloads={}, prefer_literal=prefer_literal)
return type(base)(name, (base,), dct)


Expand Down Expand Up @@ -985,6 +985,7 @@ class MethodTemplate(AbstractTemplate):
_inline = self._inline
_overload_func = staticmethod(self._overload_func)
_inline_overloads = self._inline_overloads
prefer_literal = self.prefer_literal

def generic(_, args, kws):
args = (typ,) + tuple(args)
Expand All @@ -1000,6 +1001,7 @@ def generic(_, args, kws):


def make_overload_attribute_template(typ, attr, overload_func, inline,
prefer_literal=False,
base=_OverloadAttributeTemplate):
"""
Make a template class for attribute *attr* of *typ* overloaded by
Expand All @@ -1012,18 +1014,21 @@ def make_overload_attribute_template(typ, attr, overload_func, inline,
_inline=staticmethod(InlineOptions(inline)),
_inline_overloads={},
_overload_func=staticmethod(overload_func),
prefer_literal=prefer_literal,
)
return type(base)(name, (base,), dct)
obj = type(base)(name, (base,), dct)
return obj


def make_overload_method_template(typ, attr, overload_func, inline):
def make_overload_method_template(typ, attr, overload_func, inline,
prefer_literal=False):
"""
Make a template class for method *attr* of *typ* overloaded by
*overload_func*.
"""
return make_overload_attribute_template(
typ, attr, overload_func, inline=inline,
base=_OverloadMethodTemplate,
base=_OverloadMethodTemplate, prefer_literal=prefer_literal,
)


Expand Down
95 changes: 95 additions & 0 deletions numba/tests/test_extending.py
Original file line number Diff line number Diff line change
Expand Up @@ -1770,5 +1770,100 @@ def foo(x):
)


class TestOverloadPreferLiteral(TestCase):
def test_overload(self):
def prefer_lit(x):
pass

def non_lit(x):
pass

def ov(x):
if isinstance(x, types.IntegerLiteral):
# With prefer_literal=False, this branch will not be reached.
if x.literal_value == 1:
def impl(x):
return 0xcafe
return impl
else:
raise errors.TypingError('literal value')
else:
def impl(x):
return x * 100
return impl

overload(prefer_lit, prefer_literal=True)(ov)
overload(non_lit)(ov)

@njit
def check_prefer_lit(x):
return prefer_lit(1), prefer_lit(2), prefer_lit(x)

a, b, c = check_prefer_lit(3)
self.assertEqual(a, 0xcafe)
self.assertEqual(b, 200)
self.assertEqual(c, 300)

@njit
def check_non_lit(x):
return non_lit(1), non_lit(2), non_lit(x)

a, b, c = check_non_lit(3)
self.assertEqual(a, 100)
self.assertEqual(b, 200)
self.assertEqual(c, 300)

def test_overload_method(self):
def ov(self, x):
if isinstance(x, types.IntegerLiteral):
# With prefer_literal=False, this branch will not be reached.
if x.literal_value == 1:
def impl(self, x):
return 0xcafe
return impl
else:
raise errors.TypingError('literal value')
else:
def impl(self, x):
return x * 100
return impl

overload_method(
MyDummyType, "method_prefer_literal",
prefer_literal=True,
)(ov)

overload_method(
MyDummyType, "method_non_literal",
prefer_literal=False,
)(ov)

@njit
def check_prefer_lit(dummy, x):
return (
dummy.method_prefer_literal(1),
dummy.method_prefer_literal(2),
dummy.method_prefer_literal(x),
)

a, b, c = check_prefer_lit(MyDummy(), 3)
self.assertEqual(a, 0xcafe)
self.assertEqual(b, 200)
self.assertEqual(c, 300)

@njit
def check_non_lit(dummy, x):
return (
dummy.method_non_literal(1),
dummy.method_non_literal(2),
dummy.method_non_literal(x),
)

a, b, c = check_non_lit(MyDummy(), 3)
self.assertEqual(a, 100)
self.assertEqual(b, 200)
self.assertEqual(c, 300)


if __name__ == "__main__":
unittest.main()

0 comments on commit bed4d04

Please sign in to comment.