Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix overload_method problem with stararg #4978

Merged
merged 6 commits into from
Dec 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions numba/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def __init__(self, dmm, fe_type):

@register_default(types.UniTuple)
@register_default(types.NamedUniTuple)
@register_default(types.StarArgUniTuple)
class UniTupleModel(DataModel):
def __init__(self, dmm, fe_type):
super(UniTupleModel, self).__init__(dmm, fe_type)
Expand Down Expand Up @@ -720,6 +721,7 @@ def __init__(self, dmm, fe_type):

@register_default(types.Tuple)
@register_default(types.NamedTuple)
@register_default(types.StarArgTuple)
class TupleModel(StructModel):
def __init__(self, dmm, fe_type):
members = [('f' + str(i), t) for i, t in enumerate(fe_type)]
Expand Down
2 changes: 1 addition & 1 deletion numba/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def normal_handler(index, param, value):
def default_handler(index, param, default):
return types.Omitted(default)
def stararg_handler(index, param, values):
return types.Tuple(values)
return types.StarArgTuple(values)
# For now, we take argument values from the @jit function, even
# in the case of generated jit.
args = fold_arguments(self.pysig, args, kws,
Expand Down
51 changes: 51 additions & 0 deletions numba/tests/test_extending.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,57 @@ def test():

self.assertEqual(test(), 0xdeadbeef)

def test_overload_method_stararg(self):
@overload_method(MyDummyType, "method_stararg")
def _ov_method_stararg(obj, val, val2, *args):
def get(obj, val, val2, *args):
return (val, val2, args)

return get

@njit
def foo(obj, *args):
# Test with expanding stararg
return obj.method_stararg(*args)

obj = MyDummy()
self.assertEqual(foo(obj, 1, 2), (1, 2, ()))
self.assertEqual(foo(obj, 1, 2, 3), (1, 2, (3,)))
self.assertEqual(foo(obj, 1, 2, 3, 4), (1, 2, (3, 4)))

@njit
def bar(obj):
# Test with explicit argument
return (
obj.method_stararg(1, 2),
obj.method_stararg(1, 2, 3),
obj.method_stararg(1, 2, 3, 4),
)

self.assertEqual(
bar(obj),
(
(1, 2, ()),
(1, 2, (3,)),
(1, 2, (3, 4))
),
)

# Check cases that put tuple type into stararg
# NOTE: the expected result has an extra tuple because of stararg.
self.assertEqual(
foo(obj, 1, 2, (3,)),
(1, 2, ((3,),)),
)
self.assertEqual(
foo(obj, 1, 2, (3, 4)),
(1, 2, ((3, 4),)),
)
self.assertEqual(
foo(obj, 1, 2, (3, (4, 5))),
(1, 2, ((3, (4, 5)),)),
)


def _assert_cache_stats(cfunc, expect_hit, expect_misses):
hit = cfunc._cache_hits[cfunc.signatures[0]]
Expand Down
26 changes: 24 additions & 2 deletions numba/types/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ class UniTuple(BaseAnonymousTuple, _HomogeneousTuple, Sequence):
def __init__(self, dtype, count):
self.dtype = dtype
self.count = count
name = "UniTuple(%s x %d)" % (dtype, count)
name = "%s(%s x %d)" % (
self.__class__.__name__, dtype, count,
)
super(UniTuple, self).__init__(name)

@property
Expand Down Expand Up @@ -276,7 +278,10 @@ def __init__(self, types):
self.types = tuple(types)
self.count = len(self.types)
self.dtype = UnionType(types)
name = "Tuple(%s)" % ', '.join(str(i) for i in self.types)
name = "%s(%s)" % (
self.__class__.__name__,
', '.join(str(i) for i in self.types),
)
super(Tuple, self).__init__(name)

@property
Expand All @@ -300,6 +305,23 @@ def unify(self, typingctx, other):
return Tuple(unified)


class StarArgTuple(Tuple):
"""To distinguish from Tuple() used as argument to a `*args`.
"""
def __new__(cls, types):
_HeterogeneousTuple.is_types_iterable(types)

if types and all(t == types[0] for t in types[1:]):
return StarArgUniTuple(dtype=types[0], count=len(types))
else:
return object.__new__(StarArgTuple)


class StarArgUniTuple(UniTuple):
"""To distinguish from UniTuple() used as argument to a `*args`.
"""
Comment on lines +308 to +322
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these get a custom name written to self.name so as to distinguish the type in traceback/debug?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch. yes it needs new name

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in 90956ff



class BaseNamedTuple(BaseTuple):
pass

Expand Down
15 changes: 13 additions & 2 deletions numba/typing/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,19 @@ def fold_arguments(pysig, args, kws, normal_handler, default_handler,
if param.kind == param.VAR_POSITIONAL:
# stararg may be omitted, in which case its "default" value
# is simply the empty tuple
ba.arguments[name] = stararg_handler(i, param,
ba.arguments.get(name, ()))
if name in ba.arguments:
argval = ba.arguments[name]
# NOTE: avoid wrapping the tuple type for stararg in another
# tuple.
if (len(argval) == 1 and
isinstance(argval[0], (types.StarArgTuple,
types.StarArgUniTuple))):
argval = tuple(argval[0])
else:
argval = ()
out = stararg_handler(i, param, argval)

ba.arguments[name] = out
elif name in ba.arguments:
# Non-stararg, present
ba.arguments[name] = normal_handler(i, param, ba.arguments[name])
Expand Down