Skip to content

Commit

Permalink
Inspector now distinguishes betwen UnionType and typing.Union
Browse files Browse the repository at this point in the history
  • Loading branch information
sg495 committed Feb 22, 2024
1 parent d2ff39d commit 74007de
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 14 deletions.
18 changes: 18 additions & 0 deletions test/test_00_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
else:
from typing_extensions import TypedDict

if sys.version_info[1] >= 10:
from types import UnionType
else:
UnionType = None


_basic_types = [
bool, int, float, complex, str, bytes, bytearray,
list, tuple, set, frozenset, dict, type(None)
Expand Down Expand Up @@ -402,3 +408,15 @@ def test_subtype() -> None:
validate(10, typing.Type[int])
with pytest.raises(TypeError):
validate(10, typing.Type[typing.Union[str, float]])

@pytest.mark.parametrize("val, ts", _union_cases)
def test_union_type_cases(val: typing.Any, ts: typing.List[typing.Any]) -> None:
if UnionType is not None:
for t in ts:
members = t.__args__
if not members:
continue
u = members[0]
for t in members[1:]:
u |= t
validate(val, u)
21 changes: 21 additions & 0 deletions test/test_01_can_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import typing
import pytest

if sys.version_info[1] >= 10:
from types import UnionType
else:
UnionType = None

from typing_validation import can_validate, validation_aliases
from typing_validation.inspector import _typing_equiv
from typing_validation.validation import _pseudotypes_dict
Expand Down Expand Up @@ -119,3 +124,19 @@ def test_subtype() -> None:
assert can_validate(typing.Type[typing.Union[int,str]])
assert can_validate(typing.Type[typing.Any])
assert can_validate(typing.Type[typing.Union[typing.Any, str, int]])

_union_cases_ts = sorted({
t for _, ts in _union_cases for t in ts
}, key=repr)


@pytest.mark.parametrize("t", _union_cases_ts)
def test_union_type_cases(t: typing.Any) -> None:
if UnionType is not None:
members = t.__args__
if not members:
return
u = members[0]
for t in members[1:]:
u |= t
assert can_validate(u)
45 changes: 34 additions & 11 deletions typing_validation/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
typing.Tuple[Literal["mapping"], None],
typing.Tuple[Literal["typed-dict"], type],
typing.Tuple[Literal["typevar"], TypeVar],
typing.Tuple[Literal["union"], int],
typing.Tuple[Literal["union"], tuple[int, bool]],
typing.Tuple[Literal["tuple"], Optional[int]],
typing.Tuple[Literal["user-class"], Optional[int]],
typing.Tuple[Literal["alias"], str],
Expand All @@ -44,6 +44,11 @@
else:
TypeConstructorArgs = typing.Tuple[str, Any]

if sys.version_info[1] >= 10:
from types import UnionType
else:
UnionType = None

if sys.version_info[1] >= 11:
from typing import Self
else:
Expand Down Expand Up @@ -186,12 +191,19 @@ def _recorded_type(self, idx: int) -> typing.Tuple[Any, int]:
idx,
) # pylint: disable = unnecessary-dunder-call
if tag == "union":
assert isinstance(param, int)
assert isinstance(param, tuple)
num_members, use_UnionType = param
assert isinstance(num_members, int)
member_ts: typing.List[Any] = []
for _ in range(param):
for _ in range(num_members):
member_t, idx = self._recorded_type(idx + 1)
member_ts.append(member_t)
return typing.Union.__getitem__(tuple(member_ts)), idx
if not use_UnionType:
return typing.Union.__getitem__(tuple(member_ts)), idx
union_type = member_ts[0]
for t in member_ts[1:]:
union_type |= t
return union_type, idx
if tag == "typed-dict":
for _ in get_type_hints(param):
_, idx = self._recorded_type(idx + 1)
Expand Down Expand Up @@ -302,8 +314,11 @@ def _record_collection(self, item_t: Any) -> None:
def _record_mapping(self, key_t: Any, value_t: Any) -> None:
self._append_constructor_args(("mapping", None))

def _record_union(self, *member_ts: Any) -> None:
self._append_constructor_args(("union", len(member_ts)))
def _record_union(self, *member_ts: Any, use_UnionType: bool = False) -> None:
if use_UnionType:
assert member_ts, "Cannot use UnionType with empty members."
assert UnionType is not None, "Cannot use UnionType, version <= 3.9"
self._append_constructor_args(("union", (len(member_ts), use_UnionType)))

def _record_variadic_tuple(self, item_t: Any) -> None:
self._append_constructor_args(("tuple", None))
Expand Down Expand Up @@ -385,14 +400,22 @@ def _repr(
]
return lines, idx
if tag == "union":
assert isinstance(param, int)
lines = [indent + "Union["]
for _ in range(param):
assert isinstance(param, tuple)
num_members, use_UnionType = param
assert isinstance(num_members, int)
lines = []
if not use_UnionType:
lines.append(indent + "Union[")
for _ in range(num_members):
member_lines, idx = self._repr(idx + 1, level + 1)
member_lines[-1] += ","
if use_UnionType:
member_lines[-1] += "|"
else:
member_lines[-1] += ","
lines.extend(member_lines)
assert len(lines) > 1, "Cannot take a union of no types."
lines.append(indent + "]")
if not use_UnionType:
lines.append(indent + "]")
return lines, idx
if tag == "typed-dict":
t = param
Expand Down
6 changes: 3 additions & 3 deletions typing_validation/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def _validate_tuple(val: Any, t: Any) -> None:
) from None


def _validate_union(val: Any, t: Any, *, union_type: bool = False) -> None:
def _validate_union(val: Any, t: Any, *, use_UnionType: bool = False) -> None:
"""
Union type validation. Each type ``u`` listed in the union type ``t`` is checked:
Expand All @@ -444,7 +444,7 @@ def _validate_union(val: Any, t: Any, *, union_type: bool = False) -> None:
t.__args__, tuple
), f"For type {repr(t)}, expected '__args__' to be a tuple."
if isinstance(val, TypeInspector):
val._record_union(*t.__args__)
val._record_union(*t.__args__, use_UnionType=use_UnionType)
for member_t in t.__args__:
validate(val, member_t)
return
Expand Down Expand Up @@ -815,7 +815,7 @@ def validate(val: Any, t: Any) -> Literal[True]:
_validate_typevar(val, t)
return True
if UnionType is not None and isinstance(t, UnionType):
_validate_union(val, t, union_type=True)
_validate_union(val, t, use_UnionType=True)
return True
if hasattr(t, "__origin__"): # parametric types
if t.__origin__ is Union:
Expand Down

0 comments on commit 74007de

Please sign in to comment.