Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: properly set annotation in union with Annotated constraints in subtypes #355

Merged
merged 1 commit into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions polyfactory/field_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
TYPE_MAPPING,
)
from polyfactory.utils.helpers import normalize_annotation, unwrap_annotated, unwrap_args, unwrap_new_type
from polyfactory.utils.predicates import is_annotated
from polyfactory.utils.predicates import is_annotated, is_any_annotated

if TYPE_CHECKING:
import datetime
Expand Down Expand Up @@ -134,7 +134,7 @@ def from_type(
_, metadata = unwrap_annotated(annotation, random=random)
constraints = cls.parse_constraints(metadata)

if not any(is_annotated(arg) for arg in get_args(annotation)):
if not is_any_annotated(annotation):
annotation = TYPE_MAPPING[field_type] if field_type in TYPE_MAPPING else field_type
elif (origin := get_origin(annotation)) and origin in TYPE_MAPPING: # pragma: no cover
container = TYPE_MAPPING[origin]
Expand Down
12 changes: 12 additions & 0 deletions polyfactory/utils/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,18 @@ def is_annotated(annotation: Any) -> bool:
)


def is_any_annotated(annotation: Any) -> bool:
"""Determine whether any of the types in the given annotation is
`typing.Annotated`.

:param annotation: A type annotation.

:returns: A boolean
"""

return any(is_annotated(arg) or hasattr(arg, "__args__") and is_any_annotated(arg) for arg in get_args(annotation))


def get_type_origin(annotation: Any) -> Any:
"""Get the type origin of an annotation - safely.

Expand Down
20 changes: 20 additions & 0 deletions tests/test_annotated_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ class LocationFactory(ModelFactory[Location]):
assert LocationFactory.build()


def test_optional_tuple_with_annotated_constraints() -> None:
class Location(BaseModel):
long_lat: Union[Tuple[Annotated[float, Ge(-180), Le(180)], Annotated[float, Ge(-90), Le(90)]], None]

class LocationFactory(ModelFactory[Location]):
__model__ = Location

assert LocationFactory.build()


def test_legacy_tuple_with_annotated_constraints() -> None:
class Location(BaseModel):
long_lat: Tuple[Annotated[float, Ge(-180), Le(180)], Annotated[float, Ge(-90), Le(90)]]
Expand All @@ -67,3 +77,13 @@ class LocationFactory(ModelFactory[Location]):
__model__ = Location

assert LocationFactory.build()


def test_legacy_optional_tuple_with_annotated_constraints() -> None:
class Location(BaseModel):
long_lat: Union[Tuple[Annotated[float, Ge(-180), Le(180)], Annotated[float, Ge(-90), Le(90)]], None]

class LocationFactory(ModelFactory[Location]):
__model__ = Location

assert LocationFactory.build()
Loading