Skip to content

Commit

Permalink
Merge pull request #6111 from stuartarchibald/fix/6094_a
Browse files Browse the repository at this point in the history
Decouple LiteralList and LiteralStrKeyDict from tuple
  • Loading branch information
sklam committed Aug 11, 2020
1 parent 993612a commit 01f0fae
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 11 deletions.
42 changes: 38 additions & 4 deletions numba/core/types/containers.py
Expand Up @@ -468,19 +468,44 @@ def __unliteral__(self):
initial_value=None)


class LiteralList(Literal, _HeterogeneousTuple):
class LiteralList(Literal, ConstSized, Hashable):
"""A heterogeneous immutable list (basically a tuple with list semantics).
"""

mutable = False

def __init__(self, literal_value):
_HeterogeneousTuple.is_types_iterable(literal_value)
self.is_types_iterable(literal_value)
self._literal_init(list(literal_value))
self.types = tuple(literal_value)
self.count = len(self.types)
self.name = "LiteralList({})".format(literal_value)

def __getitem__(self, i):
"""
Return element at position i
"""
return self.types[i]

def __len__(self):
return len(self.types)

def __iter__(self):
return iter(self.types)

@classmethod
def from_types(cls, tys):
return LiteralList(tys)

@staticmethod
def is_types_iterable(types):
if not isinstance(types, Iterable):
raise TypingError("Argument 'types' is not iterable")

@property
def iterator_type(self):
return ListIter(self)

def __unliteral__(self):
return Poison(self)

Expand Down Expand Up @@ -732,7 +757,7 @@ def __unliteral__(self):
return DictType(self.key_type, self.value_type)


class LiteralStrKeyDict(Literal, NamedTuple):
class LiteralStrKeyDict(Literal, ConstSized, Hashable):
"""A Dictionary of string keys to heterogeneous values (basically a
namedtuple with dict semantics).
"""
Expand All @@ -745,7 +770,10 @@ def __init__(self, literal_value, value_index=None):
strkeys = [x.literal_value for x in literal_value.keys()]
self.tuple_ty = namedtuple("_ntclazz", " ".join(strkeys))
tys = [x for x in literal_value.values()]
NamedTuple.__init__(self, tys, self.tuple_ty)
self.types = tuple(tys)
self.count = len(self.types)
self.fields = tuple(self.tuple_ty._fields)
self.instance_class = self.tuple_ty
self.name = "LiteralStrKey[Dict]({})".format(literal_value)

def __unliteral__(self):
Expand All @@ -768,6 +796,12 @@ def unify(self, typingctx, other):
d = {k: v for k, v in zip(self.literal_value.keys(), tys)}
return LiteralStrKeyDict(d)

def __len__(self):
return len(self.types)

def __iter__(self):
return iter(self.types)

@property
def key(self):
# use the namedtuple fields not the namedtuple itself as it's created
Expand Down
36 changes: 32 additions & 4 deletions numba/core/typing/builtins.py
Expand Up @@ -613,15 +613,43 @@ def generic(self, args, kws):
ret = tup.types[idx]
elif isinstance(idx, slice):
ret = types.BaseTuple.from_types(tup.types[idx])
elif isinstance(tup, types.LiteralStrKeyDict):
if isinstance(idx, str):
lookup = tup.fields.index(idx)
ret = tup.types[lookup]
if ret is not None:
sig = signature(ret, *args)
return sig


@infer
class StaticGetItemLiteralList(AbstractTemplate):
key = "static_getitem"

def generic(self, args, kws):
tup, idx = args
ret = None
if not isinstance(tup, types.LiteralList):
return
if isinstance(idx, int):
ret = tup.types[idx]
if ret is not None:
sig = signature(ret, *args)
return sig


@infer
class StaticGetItemLiteralStrKeyDict(AbstractTemplate):
key = "static_getitem"

def generic(self, args, kws):
tup, idx = args
ret = None
if not isinstance(tup, types.LiteralStrKeyDict):
return
if isinstance(idx, str):
lookup = tup.fields.index(idx)
ret = tup.types[lookup]
if ret is not None:
sig = signature(ret, *args)
return sig

# Generic implementation for "not in"

@infer
Expand Down
6 changes: 3 additions & 3 deletions numba/core/untyped_passes.py
Expand Up @@ -617,7 +617,7 @@ class TransformLiteralUnrollConstListToTuple(FunctionPass):
"""
_name = "transform_literal_unroll_const_list_to_tuple"

_accepted_types = (types.BaseTuple,)
_accepted_types = (types.BaseTuple, types.LiteralList)

def __init__(self):
FunctionPass.__init__(self)
Expand Down Expand Up @@ -740,7 +740,7 @@ class MixedContainerUnroller(FunctionPass):

_DEBUG = False

_accepted_types = (types.BaseTuple,)
_accepted_types = (types.BaseTuple, types.LiteralList)

def __init__(self):
FunctionPass.__init__(self)
Expand Down Expand Up @@ -1264,7 +1264,7 @@ class IterLoopCanonicalization(FunctionPass):
_DEBUG = False

# if partial typing info is available it will only look at these types
_accepted_types = (types.BaseTuple,)
_accepted_types = (types.BaseTuple, types.LiteralList)
_accepted_calls = (literal_unroll,)

def __init__(self):
Expand Down
18 changes: 18 additions & 0 deletions numba/cpython/listobj.py
Expand Up @@ -1215,6 +1215,13 @@ def literal_list_getitem(lst, *args):
"statically determined.")
raise errors.TypingError(msg)

@overload(len)
def literal_list_len(lst):
if not isinstance(lst, types.LiteralList):
return
l = lst.count
return lambda lst: l

@overload(operator.contains)
def literal_list_contains(lst, item):
if isinstance(lst, types.LiteralList):
Expand All @@ -1224,3 +1231,14 @@ def impl(lst, item):
return True
return False
return impl

@lower_cast(types.LiteralList, types.LiteralList)
def literallist_to_literallist(context, builder, fromty, toty, val):
if len(fromty) != len(toty):
# Disallowed by typing layer
raise NotImplementedError

olditems = cgutils.unpack_tuple(builder, val, len(fromty))
items = [context.cast(builder, v, f, t)
for v, f, t in zip(olditems, fromty, toty)]
return context.make_tuple(builder, toty, items)
2 changes: 2 additions & 0 deletions numba/cpython/tupleobj.py
Expand Up @@ -346,6 +346,8 @@ def getitem_unituple(context, builder, sig, args):


@lower_builtin('static_getitem', types.LiteralStrKeyDict, types.StringLiteral)
@lower_builtin('static_getitem', types.LiteralList, types.IntegerLiteral)
@lower_builtin('static_getitem', types.LiteralList, types.SliceLiteral)
@lower_builtin('static_getitem', types.BaseTuple, types.IntegerLiteral)
@lower_builtin('static_getitem', types.BaseTuple, types.SliceLiteral)
def static_getitem_tuple(context, builder, sig, args):
Expand Down
19 changes: 19 additions & 0 deletions numba/tests/test_dictobject.py
Expand Up @@ -2149,6 +2149,25 @@ def foo():

np.testing.assert_allclose(foo(), np.ones(3) * 10)

def test_tuple_not_in_mro(self):
# Related to #6094, make sure that LiteralStrKey does not inherit from
# types.BaseTuple as this breaks isinstance checks.
def bar(x):
pass

@overload(bar)
def ol_bar(x):
self.assertFalse(isinstance(x, types.BaseTuple))
self.assertTrue(isinstance(x, types.LiteralStrKeyDict))
return lambda x: ...

@njit
def foo():
d = {'a': 1, 'b': 'c'}
bar(d)

foo()


if __name__ == '__main__':
unittest.main()
32 changes: 32 additions & 0 deletions numba/tests/test_lists.py
Expand Up @@ -1735,6 +1735,19 @@ def foo():

self.assertEqual(foo.py_func(), foo())

def test_staticgetitem_slice(self):
# this is forbidden by typing as there's no way to serialize a list of
# any kind as required by returning a (static) slice of a LiteralList
@njit
def foo():
l = ['a', 'b', 1]
return l[:2]

with self.assertRaises(errors.TypingError) as raises:
foo()
expect = "Cannot __getitem__ on a literal list"
self.assertIn(expect, str(raises.exception))

def test_setitem(self):

@njit
Expand Down Expand Up @@ -1794,6 +1807,25 @@ def foo():

self.assertEqual(foo(), foo.py_func())

def test_tuple_not_in_mro(self):
# Related to #6094, make sure that LiteralList does not inherit from
# types.BaseTuple as this breaks isinstance checks.
def bar(x):
pass

@overload(bar)
def ol_bar(x):
self.assertFalse(isinstance(x, types.BaseTuple))
self.assertTrue(isinstance(x, types.LiteralList))
return lambda x: ...

@njit
def foo():
l = ['a', 1]
bar(l)

foo()


if __name__ == '__main__':
unittest.main()
8 changes: 8 additions & 0 deletions numba/typed/dictobject.py
Expand Up @@ -1247,6 +1247,14 @@ def impl(d, k):
return impl


@overload(len)
def literalstrkeydict_impl_len(d):
if not isinstance(d, types.LiteralStrKeyDict):
return
l = d.count
return lambda d: l


@overload(operator.setitem)
def literalstrkeydict_banned_impl_setitem(d, key, value):
if not isinstance(d, types.LiteralStrKeyDict):
Expand Down

0 comments on commit 01f0fae

Please sign in to comment.