Skip to content

Commit

Permalink
Support frozenset.
Browse files Browse the repository at this point in the history
Fixes Stewori#35.
  • Loading branch information
mitar committed Apr 13, 2018
1 parent 152a221 commit 3130519
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
14 changes: 9 additions & 5 deletions pytypes/type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import collections
from inspect import isfunction, ismethod, isclass, ismodule
try:
from backports.typing import Tuple, Dict, List, Set, Union, Any, \
from backports.typing import Tuple, Dict, List, Set, FrozenSet, Union, Any, \
Sequence, Mapping, TypeVar, Container, Generic, Sized, Iterable
except ImportError:
from typing import Tuple, Dict, List, Set, Union, Any, \
from typing import Tuple, Dict, List, Set, FrozenSet, Union, Any, \
Sequence, Mapping, TypeVar, Container, Generic, Sized, Iterable
try:
# Python 3.7
Expand Down Expand Up @@ -465,9 +465,13 @@ def _deep_type(obj, checked, checked_len, depth = None, max_sample = None):
tpl1 = tuple(_deep_type(t, checked, checked_len2, depth-1) for t in ksmpl)
tpl2 = tuple(_deep_type(t, checked, checked_len2, depth-1) for t in vsmpl)
res = Dict[Union[tpl1], Union[tpl2]]
elif res == set:
elif res == set or res == frozenset:
if res == set:
typ = Set
else:
typ = FrozenSet
if len(obj) == 0:
return Empty[Set]
return Empty[typ]
if max_sample == -1 or max_sample >= len(obj)-1 or len(obj) <= 2:
tpl = tuple(_deep_type(t, checked, depth-1) for t in obj)
else:
Expand All @@ -485,7 +489,7 @@ def _deep_type(obj, checked, checked_len, depth = None, max_sample = None):
j -= 1
smpl.append(next(itr))
tpl = tuple(_deep_type(t, checked, depth-1) for t in smpl)
res = Set[Union[tpl]]
res = typ[Union[tpl]]
elif res == types.GeneratorType:
res = get_generator_type(obj)
elif sys.version_info.major == 2 and isinstance(obj, types.InstanceType):
Expand Down
4 changes: 4 additions & 0 deletions tests/test_typechecker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4664,6 +4664,10 @@ class Foo(typing.Generic[T]):
# No exception.
resolve_fw_decl(Foo)

# See: https://github.com/Stewori/pytypes/issues/35
def test_frozenset(self):
self.assertTrue(pytypes.is_of_type(frozenset({1, 2, 'a', None, 'b'}), typing.AbstractSet[typing.Union[str, int, None]]))


if __name__ == '__main__':
unittest.main()

0 comments on commit 3130519

Please sign in to comment.