Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix extending self #3

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions numba/listobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
_get_incref_decref,
_get_equal,
_container_get_data,
_container_get_meminfo,
)


Expand Down Expand Up @@ -202,6 +203,15 @@ def codegen(context, builder, sig, args):
return sig, codegen


@lower_builtin(operator.is_, types.ListType, types.ListType)
def list_is(context, builder, sig, args):
a_meminfo = _container_get_meminfo(context, builder, sig.args[0], args[0])
b_meminfo = _container_get_meminfo(context, builder, sig.args[1], args[1])
ma = builder.ptrtoint(a_meminfo, cgutils.intp_t)
mb = builder.ptrtoint(b_meminfo, cgutils.intp_t)
return builder.icmp_signed('==', ma, mb)


def _call_list_free(context, builder, ptr):
"""Call numba_list_free(ptr)
"""
Expand Down Expand Up @@ -792,13 +802,21 @@ def impl_extend(l, iterable):
if not isinstance(iterable, types.IterableType):
raise TypingError("extend argument must be iterable")

itemty = l.item_type
if isinstance(iterable, types.ListType):
def impl(l, iterable):
# guard against l.extend(l)
if l is iterable:
iterable = iterable.copy()
for i in iterable:
l.append(i)

def impl(l, iterable):
for i in iterable:
l.append(_cast(i, itemty))
return impl
else:
def impl(l, iterable):
for i in iterable:
l.append(i)

return impl
return impl


@overload_method(types.ListType, 'insert')
Expand Down
46 changes: 46 additions & 0 deletions numba/tests/test_typedlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,52 @@ def test_setitem_slice_value_error(self):
)


class TestExtend(MemoryLeakMixin, TestCase):

def test_extend_other(self):
@njit
def impl(other):
l = List.empty_list(types.int32)
for x in range(10):
l.append(x)
l.extend(other)
return l

other = List.empty_list(types.int32)
for x in range(10):
other.append(x)

expected = impl.py_func(other)
got = impl(other)
self.assertEqual(expected, got)

def test_extend_self(self):
@njit
def impl():
l = List.empty_list(types.int32)
for x in range(10):
l.append(x)
l.extend(l)
return l

expected = impl.py_func()
got = impl()
self.assertEqual(expected, got)

def test_extend_tuple(self):
@njit
def impl():
l = List.empty_list(types.int32)
for x in range(10):
l.append(x)
l.extend((100,200,300))
return l

expected = impl.py_func()
got = impl()
self.assertEqual(expected, got)


class TestListRefctTypes(MemoryLeakMixin, TestCase):

@skip_py2
Expand Down
8 changes: 8 additions & 0 deletions numba/typedobjectutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ def _container_get_data(context, builder, container_ty, c):
return conatainer_struct.data


def _container_get_meminfo(context, builder, container_ty, c):
"""Helper to get the meminfo for a container
"""
ctor = cgutils.create_struct_proxy(container_ty)
conatainer_struct = ctor(context, builder, value=c)
return conatainer_struct.meminfo


def _get_incref_decref(context, module, datamodel, container_type):
assert datamodel.contains_nrt_meminfo()

Expand Down
3 changes: 3 additions & 0 deletions numba/types/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,9 @@ def key(self):
class ListType(IterableType):
"""List type
"""

mutable = True

def __init__(self, itemty):
assert not isinstance(itemty, TypeRef)
itemty = unliteral(itemty)
Expand Down