Skip to content

Commit

Permalink
Merge pull request #5958 from luk-f-a/typed-list-as-generic
Browse files Browse the repository at this point in the history
Making typed.List a typing Generic
  • Loading branch information
sklam committed Feb 1, 2021
2 parents c743ac8 + 1dfa46f commit 5f98f6b
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 14 deletions.
Empty file added numba/typed/py.typed
Empty file.
64 changes: 50 additions & 14 deletions numba/typed/typedlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@
from numba.typed import listobject
from numba.core.errors import TypingError, LoweringError
from numba.core.typing.templates import Signature
import typing as pt
import sys


Int_or_Slice = pt.Union["pt.SupportsIndex", slice]


if sys.version_info >= (3, 8):
T_co = pt.TypeVar('T_co', covariant=True)

class _Sequence(pt.Protocol[T_co]):
def __getitem__(self, i: int) -> T_co:
...

def __len__(self) -> int:
...


DEFAULT_ALLOCATED = listobject.DEFAULT_ALLOCATED

Expand Down Expand Up @@ -170,7 +187,11 @@ def _from_meminfo_ptr(ptr, listtype):
return List(meminfo=ptr, lsttype=listtype)


class List(MutableSequence):
T = pt.TypeVar('T')
T_or_ListT = pt.Union[T, 'List[T]']


class List(MutableSequence, pt.Generic[T]):
"""A typed-list usable in Numba compiled functions.
Implements the MutableSequence interface.
Expand Down Expand Up @@ -278,7 +299,7 @@ def _initialise_list(self, item):
lsttype = types.ListType(typeof(item))
self._list_type, self._opaque = self._parse_arg(lsttype)

def __len__(self):
def __len__(self) -> int:
if not self._typed:
return 0
else:
Expand Down Expand Up @@ -317,44 +338,58 @@ def __gt__(self, other):
def __ge__(self, other):
return _ge(self, other)

def append(self, item):
def append(self, item: T) -> None:
if not self._typed:
self._initialise_list(item)
_append(self, item)

def __setitem__(self, i, item):
# noqa F811 comments required due to github.com/PyCQA/pyflakes/issues/592
# noqa E704 required to follow overload style of using ... in the same line
@pt.overload # type: ignore[override]
def __setitem__(self, i: int, o: T) -> None: ... # noqa: F811, E704
@pt.overload
def __setitem__(self, s: slice, o: 'List[T]') -> None: ... # noqa: F811, E704, E501

def __setitem__(self, i: Int_or_Slice, item: T_or_ListT) -> None: # noqa: F811, E501
if not self._typed:
self._initialise_list(item)
_setitem(self, i, item)

def __getitem__(self, i):
# noqa F811 comments required due to github.com/PyCQA/pyflakes/issues/592
# noqa E704 required to follow overload style of using ... in the same line
@pt.overload
def __getitem__(self, i: int) -> T: ... # noqa: F811, E704
@pt.overload
def __getitem__(self, i: slice) -> 'List[T]': ... # noqa: F811, E704

def __getitem__(self, i: Int_or_Slice) -> T_or_ListT: # noqa: F811
if not self._typed:
raise IndexError
else:
return _getitem(self, i)

def __iter__(self):
def __iter__(self) -> pt.Iterator[T]:
for i in range(len(self)):
yield self[i]

def __contains__(self, item):
def __contains__(self, item: T) -> bool: # type: ignore[override]
return _contains(self, item)

def __delitem__(self, i):
def __delitem__(self, i: Int_or_Slice) -> None:
_delitem(self, i)

def insert(self, i, item):
def insert(self, i: int, item: T) -> None:
if not self._typed:
self._initialise_list(item)
_insert(self, i, item)

def count(self, item):
def count(self, item: T) -> int:
return _count(self, item)

def pop(self, i=-1):
def pop(self, i: "pt.SupportsIndex" = -1) -> T:
return _pop(self, i)

def extend(self, iterable):
def extend(self, iterable: "_Sequence[T]") -> None: #type: ignore[override]
# Empty iterable, do nothing
if len(iterable) == 0:
return None
Expand All @@ -365,7 +400,7 @@ def extend(self, iterable):
self._initialise_list(iterable[0])
return _extend(self, iterable)

def remove(self, item):
def remove(self, item: T) -> None:
return _remove(self, item)

def clear(self):
Expand All @@ -377,7 +412,8 @@ def reverse(self):
def copy(self):
return _copy(self)

def index(self, item, start=None, stop=None):
def index(self, item: T, start: pt.Optional[int] = None,
stop: pt.Optional[int] = None) -> int:
return _index(self, item, start, stop)

def sort(self, key=None, reverse=False):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def check_file_at_path(path2file):
"numba.cext": ["*.c", "*.h"],
# numba gdb hook init command language file
"numba.misc": ["cmdlang.gdb"],
"numba.typed": ["py.typed"],
},
scripts=["numba/pycc/pycc", "bin/numba"],
author="Anaconda, Inc.",
Expand Down

0 comments on commit 5f98f6b

Please sign in to comment.