Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add prefer_literal option to overload API #6058

Merged
merged 10 commits into from
Aug 4, 2020
12 changes: 10 additions & 2 deletions numba/core/extending.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def generic(self):
_overload_default_jit_options = {'no_cpython_wrapper': True}


def overload(func, jit_options={}, strict=True, inline='never'):
def overload(func, jit_options={}, strict=True, inline='never',
prefer_literal=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be documented? OR given the likelihood of it becoming a namedtuple or something for the next release, leave it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, i should document it and leave a note that it might change.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done in 89a3840

"""
A decorator marking the decorated function as typing and implementing
*func* in nopython mode.
Expand Down Expand Up @@ -101,6 +102,12 @@ def len_impl(seq):
holds the information from the callee. The function should return Truthy
to determine whether to inline, this essentially permitting custom
inlining rules (typical use might be cost models).

The *prefer_literal* option allows users to control if literal types should
be tried first or last. The default (`False`) is to use non-literal types.
Implementations that can specialize based on literal values should set the
option to `True`. Note, this option maybe expanded in the near future to
allow for more control (e.g. disabling non-literal types).
"""
from numba.core.typing.templates import make_overload_template, infer_global

Expand All @@ -110,7 +117,7 @@ def len_impl(seq):

def decorate(overload_func):
template = make_overload_template(func, overload_func, opts, strict,
inline)
inline, prefer_literal)
infer(template)
if callable(func):
infer_global(func, types.Function(template))
Expand Down Expand Up @@ -207,6 +214,7 @@ def decorate(overload_func):
template = make_overload_method_template(
typ, attr, overload_func,
inline=kwargs.get('inline', 'never'),
prefer_literal=kwargs.get('prefer_literal', False)
)
infer_getattr(template)
overload(overload_func, **kwargs)(overload_func)
Expand Down
64 changes: 38 additions & 26 deletions numba/core/types/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,33 +375,45 @@ def get_call_type(self, context, args, kws):
template = self.template(context)
literal_e = None
nonliteral_e = None
out = None


# Try with Literal
try:
out = template.apply(args, kws)
except Exception as exc:
if isinstance(exc, errors.ForceLiteralArg):
raise exc
literal_e = exc
out = None

# if the unliteral_args and unliteral_kws are the same as the literal
# ones, set up to not bother retrying
unliteral_args = tuple([unliteral(a) for a in args])
unliteral_kws = {k: unliteral(v) for k, v in kws.items()}
skip = unliteral_args == args and kws == unliteral_kws

# If the above template application failed and the non-literal args are
# different to the literal ones, try again with literals rewritten as
# non-literals
if not skip and out is None:
try:
out = template.apply(unliteral_args, unliteral_kws)
except Exception as exc:
if isinstance(exc, errors.ForceLiteralArg):
raise exc
nonliteral_e = exc
choice = [True, False] if template.prefer_literal else [False, True]
for uselit in choice:
if uselit:
# Try with Literal
try:
out = template.apply(args, kws)
except Exception as exc:
if isinstance(exc, errors.ForceLiteralArg):
raise exc
literal_e = exc
out = None
else:
break
else:
# if the unliteral_args and unliteral_kws are the same as the literal
# ones, set up to not bother retrying
unliteral_args = tuple([_unlit_non_poison(a) for a in args])
unliteral_kws = {k: _unlit_non_poison(v)
for k, v in kws.items()}
skip = unliteral_args == args and kws == unliteral_kws

# If the above template application failed and the non-literal args are
# different to the literal ones, try again with literals rewritten as
# non-literals
if not skip and out is None:
try:
out = template.apply(unliteral_args, unliteral_kws)
except Exception as exc:
if isinstance(exc, errors.ForceLiteralArg):
if template.prefer_literal:
# For template that prefers literal types,
# reaching here means that the literal types
# have failed typing as well.
raise exc
nonliteral_e = exc
else:
break

if out is None and (nonliteral_e is not None or literal_e is not None):
header = "- Resolution failure for {} arguments:\n{}\n"
Expand Down
7 changes: 4 additions & 3 deletions numba/core/typing/arraydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,8 +744,9 @@ def generic_index(self, args, kws):
assert not kws
return signature(types.intp, recvr=self.this)

def install_array_method(name, generic):
my_attr = {"key": "array." + name, "generic": generic}
def install_array_method(name, generic, prefer_literal=True):
my_attr = {"key": "array." + name, "generic": generic,
"prefer_literal": prefer_literal}
temp_class = type("Array_" + name, (AbstractTemplate,), my_attr)
def array_attribute_attachment(self, ary):
return types.BoundFunction(temp_class, ary)
Expand All @@ -758,7 +759,7 @@ def array_attribute_attachment(self, ary):

# Functions that return a machine-width type, to avoid overflows
install_array_method("prod", generic_expand)
install_array_method("sum", sum_expand)
install_array_method("sum", sum_expand, prefer_literal=True)

# Functions that return a machine-width type, to avoid overflows
for fname in ["cumsum", "cumprod"]:
Expand Down
15 changes: 10 additions & 5 deletions numba/core/typing/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ def get_template_info(self):


def make_overload_template(func, overload_func, jit_options, strict,
inline):
inline, prefer_literal=False):
"""
Make a template class for function *func* overloaded by *overload_func*.
Compiler options are passed as a dictionary to *jit_options*.
Expand All @@ -793,7 +793,7 @@ def make_overload_template(func, overload_func, jit_options, strict,
dct = dict(key=func, _overload_func=staticmethod(overload_func),
_impl_cache={}, _compiled_overloads={}, _jit_options=jit_options,
_strict=strict, _inline=staticmethod(InlineOptions(inline)),
_inline_overloads={})
_inline_overloads={}, prefer_literal=prefer_literal)
return type(base)(name, (base,), dct)


Expand Down Expand Up @@ -985,6 +985,7 @@ class MethodTemplate(AbstractTemplate):
_inline = self._inline
_overload_func = staticmethod(self._overload_func)
_inline_overloads = self._inline_overloads
prefer_literal = self.prefer_literal

def generic(_, args, kws):
args = (typ,) + tuple(args)
Expand All @@ -1000,6 +1001,7 @@ def generic(_, args, kws):


def make_overload_attribute_template(typ, attr, overload_func, inline,
prefer_literal=False,
base=_OverloadAttributeTemplate):
"""
Make a template class for attribute *attr* of *typ* overloaded by
Expand All @@ -1012,18 +1014,21 @@ def make_overload_attribute_template(typ, attr, overload_func, inline,
_inline=staticmethod(InlineOptions(inline)),
_inline_overloads={},
_overload_func=staticmethod(overload_func),
prefer_literal=prefer_literal,
)
return type(base)(name, (base,), dct)
obj = type(base)(name, (base,), dct)
return obj


def make_overload_method_template(typ, attr, overload_func, inline):
def make_overload_method_template(typ, attr, overload_func, inline,
prefer_literal=False):
"""
Make a template class for method *attr* of *typ* overloaded by
*overload_func*.
"""
return make_overload_attribute_template(
typ, attr, overload_func, inline=inline,
base=_OverloadMethodTemplate,
base=_OverloadMethodTemplate, prefer_literal=prefer_literal,
)


Expand Down
95 changes: 95 additions & 0 deletions numba/tests/test_extending.py
Original file line number Diff line number Diff line change
Expand Up @@ -1770,5 +1770,100 @@ def foo(x):
)


class TestOverloadPreferLiteral(TestCase):
def test_overload(self):
def prefer_lit(x):
pass

def non_lit(x):
pass

def ov(x):
if isinstance(x, types.IntegerLiteral):
# With prefer_literal=False, this branch will not be reached.
if x.literal_value == 1:
def impl(x):
return 0xcafe
return impl
else:
raise errors.TypingError('literal value')
else:
def impl(x):
return x * 100
return impl

overload(prefer_lit, prefer_literal=True)(ov)
overload(non_lit)(ov)

@njit
def check_prefer_lit(x):
return prefer_lit(1), prefer_lit(2), prefer_lit(x)

a, b, c = check_prefer_lit(3)
self.assertEqual(a, 0xcafe)
self.assertEqual(b, 200)
self.assertEqual(c, 300)

@njit
def check_non_lit(x):
return non_lit(1), non_lit(2), non_lit(x)

a, b, c = check_non_lit(3)
self.assertEqual(a, 100)
self.assertEqual(b, 200)
self.assertEqual(c, 300)

def test_overload_method(self):
def ov(self, x):
if isinstance(x, types.IntegerLiteral):
# With prefer_literal=False, this branch will not be reached.
if x.literal_value == 1:
def impl(self, x):
return 0xcafe
return impl
else:
raise errors.TypingError('literal value')
else:
def impl(self, x):
return x * 100
return impl

overload_method(
MyDummyType, "method_prefer_literal",
prefer_literal=True,
)(ov)

overload_method(
MyDummyType, "method_non_literal",
prefer_literal=False,
)(ov)

@njit
def check_prefer_lit(dummy, x):
return (
dummy.method_prefer_literal(1),
dummy.method_prefer_literal(2),
dummy.method_prefer_literal(x),
)

a, b, c = check_prefer_lit(MyDummy(), 3)
self.assertEqual(a, 0xcafe)
self.assertEqual(b, 200)
self.assertEqual(c, 300)

@njit
def check_non_lit(dummy, x):
return (
dummy.method_non_literal(1),
dummy.method_non_literal(2),
dummy.method_non_literal(x),
)

a, b, c = check_non_lit(MyDummy(), 3)
self.assertEqual(a, 100)
self.assertEqual(b, 200)
self.assertEqual(c, 300)


if __name__ == "__main__":
unittest.main()