Skip to content

Commit

Permalink
Merge pull request #4978 from sklam/fix/iss4944
Browse files Browse the repository at this point in the history
Fix overload_method problem with stararg
  • Loading branch information
seibert committed Dec 18, 2019
2 parents 890c5c3 + 90956ff commit 63fbcfd
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 5 deletions.
2 changes: 2 additions & 0 deletions numba/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def __init__(self, dmm, fe_type):

@register_default(types.UniTuple)
@register_default(types.NamedUniTuple)
@register_default(types.StarArgUniTuple)
class UniTupleModel(DataModel):
def __init__(self, dmm, fe_type):
super(UniTupleModel, self).__init__(dmm, fe_type)
Expand Down Expand Up @@ -720,6 +721,7 @@ def __init__(self, dmm, fe_type):

@register_default(types.Tuple)
@register_default(types.NamedTuple)
@register_default(types.StarArgTuple)
class TupleModel(StructModel):
def __init__(self, dmm, fe_type):
members = [('f' + str(i), t) for i, t in enumerate(fe_type)]
Expand Down
2 changes: 1 addition & 1 deletion numba/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def normal_handler(index, param, value):
def default_handler(index, param, default):
return types.Omitted(default)
def stararg_handler(index, param, values):
return types.Tuple(values)
return types.StarArgTuple(values)
# For now, we take argument values from the @jit function, even
# in the case of generated jit.
args = fold_arguments(self.pysig, args, kws,
Expand Down
51 changes: 51 additions & 0 deletions numba/tests/test_extending.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,57 @@ def test():

self.assertEqual(test(), 0xdeadbeef)

def test_overload_method_stararg(self):
@overload_method(MyDummyType, "method_stararg")
def _ov_method_stararg(obj, val, val2, *args):
def get(obj, val, val2, *args):
return (val, val2, args)

return get

@njit
def foo(obj, *args):
# Test with expanding stararg
return obj.method_stararg(*args)

obj = MyDummy()
self.assertEqual(foo(obj, 1, 2), (1, 2, ()))
self.assertEqual(foo(obj, 1, 2, 3), (1, 2, (3,)))
self.assertEqual(foo(obj, 1, 2, 3, 4), (1, 2, (3, 4)))

@njit
def bar(obj):
# Test with explicit argument
return (
obj.method_stararg(1, 2),
obj.method_stararg(1, 2, 3),
obj.method_stararg(1, 2, 3, 4),
)

self.assertEqual(
bar(obj),
(
(1, 2, ()),
(1, 2, (3,)),
(1, 2, (3, 4))
),
)

# Check cases that put tuple type into stararg
# NOTE: the expected result has an extra tuple because of stararg.
self.assertEqual(
foo(obj, 1, 2, (3,)),
(1, 2, ((3,),)),
)
self.assertEqual(
foo(obj, 1, 2, (3, 4)),
(1, 2, ((3, 4),)),
)
self.assertEqual(
foo(obj, 1, 2, (3, (4, 5))),
(1, 2, ((3, (4, 5)),)),
)


def _assert_cache_stats(cfunc, expect_hit, expect_misses):
hit = cfunc._cache_hits[cfunc.signatures[0]]
Expand Down
26 changes: 24 additions & 2 deletions numba/types/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ class UniTuple(BaseAnonymousTuple, _HomogeneousTuple, Sequence):
def __init__(self, dtype, count):
self.dtype = dtype
self.count = count
name = "UniTuple(%s x %d)" % (dtype, count)
name = "%s(%s x %d)" % (
self.__class__.__name__, dtype, count,
)
super(UniTuple, self).__init__(name)

@property
Expand Down Expand Up @@ -276,7 +278,10 @@ def __init__(self, types):
self.types = tuple(types)
self.count = len(self.types)
self.dtype = UnionType(types)
name = "Tuple(%s)" % ', '.join(str(i) for i in self.types)
name = "%s(%s)" % (
self.__class__.__name__,
', '.join(str(i) for i in self.types),
)
super(Tuple, self).__init__(name)

@property
Expand All @@ -300,6 +305,23 @@ def unify(self, typingctx, other):
return Tuple(unified)


class StarArgTuple(Tuple):
"""To distinguish from Tuple() used as argument to a `*args`.
"""
def __new__(cls, types):
_HeterogeneousTuple.is_types_iterable(types)

if types and all(t == types[0] for t in types[1:]):
return StarArgUniTuple(dtype=types[0], count=len(types))
else:
return object.__new__(StarArgTuple)


class StarArgUniTuple(UniTuple):
"""To distinguish from UniTuple() used as argument to a `*args`.
"""


class BaseNamedTuple(BaseTuple):
pass

Expand Down
15 changes: 13 additions & 2 deletions numba/typing/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,19 @@ def fold_arguments(pysig, args, kws, normal_handler, default_handler,
if param.kind == param.VAR_POSITIONAL:
# stararg may be omitted, in which case its "default" value
# is simply the empty tuple
ba.arguments[name] = stararg_handler(i, param,
ba.arguments.get(name, ()))
if name in ba.arguments:
argval = ba.arguments[name]
# NOTE: avoid wrapping the tuple type for stararg in another
# tuple.
if (len(argval) == 1 and
isinstance(argval[0], (types.StarArgTuple,
types.StarArgUniTuple))):
argval = tuple(argval[0])
else:
argval = ()
out = stararg_handler(i, param, argval)

ba.arguments[name] = out
elif name in ba.arguments:
# Non-stararg, present
ba.arguments[name] = normal_handler(i, param, ba.arguments[name])
Expand Down

0 comments on commit 63fbcfd

Please sign in to comment.