Skip to content

Commit

Permalink
fix: properly set annotation in union with nested Annotated (#355)
Browse files Browse the repository at this point in the history
  • Loading branch information
guacs committed Sep 12, 2023
1 parent 7af5469 commit f639c26
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
4 changes: 2 additions & 2 deletions polyfactory/field_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
TYPE_MAPPING,
)
from polyfactory.utils.helpers import normalize_annotation, unwrap_annotated, unwrap_args, unwrap_new_type
from polyfactory.utils.predicates import is_annotated
from polyfactory.utils.predicates import is_annotated, is_any_annotated

if TYPE_CHECKING:
import datetime
Expand Down Expand Up @@ -134,7 +134,7 @@ def from_type(
_, metadata = unwrap_annotated(annotation, random=random)
constraints = cls.parse_constraints(metadata)

if not any(is_annotated(arg) for arg in get_args(annotation)):
if not is_any_annotated(annotation):
annotation = TYPE_MAPPING[field_type] if field_type in TYPE_MAPPING else field_type
elif (origin := get_origin(annotation)) and origin in TYPE_MAPPING: # pragma: no cover
container = TYPE_MAPPING[origin]
Expand Down
12 changes: 12 additions & 0 deletions polyfactory/utils/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,18 @@ def is_annotated(annotation: Any) -> bool:
)


def is_any_annotated(annotation: Any) -> bool:
"""Determine whether any of the types in the given annotation is
`typing.Annotated`.
:param annotation: A type annotation.
:returns: A boolean
"""

return any(is_annotated(arg) or hasattr(arg, "__args__") and is_any_annotated(arg) for arg in get_args(annotation))


def get_type_origin(annotation: Any) -> Any:
"""Get the type origin of an annotation - safely.
Expand Down
20 changes: 20 additions & 0 deletions tests/test_annotated_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ class LocationFactory(ModelFactory[Location]):
assert LocationFactory.build()


def test_optional_tuple_with_annotated_constraints() -> None:
class Location(BaseModel):
long_lat: Union[Tuple[Annotated[float, Ge(-180), Le(180)], Annotated[float, Ge(-90), Le(90)]], None]

class LocationFactory(ModelFactory[Location]):
__model__ = Location

assert LocationFactory.build()


def test_legacy_tuple_with_annotated_constraints() -> None:
class Location(BaseModel):
long_lat: Tuple[Annotated[float, Ge(-180), Le(180)], Annotated[float, Ge(-90), Le(90)]]
Expand All @@ -67,3 +77,13 @@ class LocationFactory(ModelFactory[Location]):
__model__ = Location

assert LocationFactory.build()


def test_legacy_optional_tuple_with_annotated_constraints() -> None:
class Location(BaseModel):
long_lat: Union[Tuple[Annotated[float, Ge(-180), Le(180)], Annotated[float, Ge(-90), Le(90)]], None]

class LocationFactory(ModelFactory[Location]):
__model__ = Location

assert LocationFactory.build()

0 comments on commit f639c26

Please sign in to comment.