Skip to content

Commit

Permalink
fix: handle constrained unions properly (#499)
Browse files Browse the repository at this point in the history
  • Loading branch information
guacs committed Mar 2, 2024
1 parent c4e3d91 commit 0f8f9e8
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 44 deletions.
27 changes: 26 additions & 1 deletion polyfactory/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
from collections import abc, defaultdict, deque
from random import Random
from typing import (
Expand All @@ -21,9 +22,10 @@
except ImportError:
UnionType = Union # type: ignore[misc,assignment]

PY_38 = sys.version_info.major == 3 and sys.version_info.minor == 8 # noqa: PLR2004

# Mapping of type annotations into concrete types. This is used to normalize python <= 3.9 annotations.
TYPE_MAPPING = {
INSTANTIABLE_TYPE_MAPPING = {
DefaultDict: defaultdict,
Deque: deque,
Dict: dict,
Expand All @@ -41,6 +43,29 @@
UnionType: Union,
}


if not PY_38:
TYPE_MAPPING = INSTANTIABLE_TYPE_MAPPING
else:
# For 3.8, we have to keep the types from typing since dict[str] syntax is not supported in 3.8.
TYPE_MAPPING = {
DefaultDict: DefaultDict,
Deque: Deque,
Dict: Dict,
FrozenSet: FrozenSet,
Iterable: List,
List: List,
Mapping: Dict,
Sequence: List,
Set: Set,
Tuple: Tuple,
abc.Iterable: List,
abc.Mapping: Dict,
abc.Sequence: List,
abc.Set: Set,
}


DEFAULT_RANDOM = Random()
RANDOMIZE_COLLECTION_LENGTH = False
MIN_COLLECTION_LENGTH = 0
Expand Down
58 changes: 32 additions & 26 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,39 +164,45 @@ def from_field_info(

annotation = unwrap_new_type(field_info.annotation)
children: list[FieldMeta,] | None = None
constraints: Constraints = {}
name = field_info.alias if field_info.alias and use_alias else field_name

constraints: PydanticConstraints
# pydantic v2 does not always propagate metadata for Union types
if not field_info.metadata and is_optional(annotation):
field_info = FieldInfo.from_annotation(unwrap_optional(annotation))
elif is_union(annotation):
children = [
cls.from_field_info(
field_info=FieldInfo.from_annotation(arg),
field_name=field_name,
random=random,
use_alias=use_alias,
if is_union(annotation):
constraints = {}
children = []
for arg in get_args(annotation):
if arg is NoneType:
continue
child_field_info = FieldInfo.from_annotation(arg)
merged_field_info = FieldInfo.merge_field_infos(field_info, child_field_info)
children.append(
cls.from_field_info(
field_name="",
field_info=merged_field_info,
use_alias=use_alias,
random=random,
),
)
for arg in get_args(annotation)
]

metadata, is_json = [], False
for m in field_info.metadata:
if not is_json and isinstance(m, Json): # type: ignore[misc]
is_json = True
elif m is not None:
metadata.append(m)
else:
metadata, is_json = [], False
for m in field_info.metadata:
if not is_json and isinstance(m, Json): # type: ignore[misc]
is_json = True
elif m is not None:
metadata.append(m)

constraints = cls.parse_constraints(metadata=metadata) if metadata else {}
constraints = cast(PydanticConstraints, constraints)
constraints = cast(
PydanticConstraints,
cls.parse_constraints(metadata=metadata) if metadata else {},
)

if "url" in constraints:
# pydantic uses a sentinel value for url constraints
annotation = str
if "url" in constraints:
# pydantic uses a sentinel value for url constraints
annotation = str

if is_json:
constraints["json"] = True
if is_json:
constraints["json"] = True

return PydanticFieldMeta.from_type(
annotation=annotation,
Expand Down
11 changes: 5 additions & 6 deletions polyfactory/field_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
from polyfactory.utils.deprecation import check_for_deprecated_parameters
from polyfactory.utils.helpers import (
get_annotation_metadata,
normalize_annotation,
unwrap_annotated,
unwrap_new_type,
)
from polyfactory.utils.predicates import is_annotated, is_any_annotated
from polyfactory.utils.predicates import is_annotated
from polyfactory.utils.types import NoneType

if TYPE_CHECKING:
Expand Down Expand Up @@ -135,14 +134,14 @@ def from_type(
("max_collection_length", max_collection_length),
),
)
field_type = normalize_annotation(annotation, random=random)

if not constraints and is_annotated(annotation):
annotated = is_annotated(annotation)
if not constraints and annotated:
metadata = cls.get_constraints_metadata(annotation)
constraints = cls.parse_constraints(metadata)

if not is_any_annotated(annotation):
annotation = TYPE_MAPPING[field_type] if field_type in TYPE_MAPPING else field_type
if annotated:
annotation = get_args(annotation)[0]
elif (origin := get_origin(annotation)) and origin in TYPE_MAPPING: # pragma: no cover
container = TYPE_MAPPING[origin]
annotation = container[get_args(annotation)] # type: ignore[index]
Expand Down
5 changes: 5 additions & 0 deletions polyfactory/value_generators/complex_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing_extensions import is_typeddict

from polyfactory.constants import INSTANTIABLE_TYPE_MAPPING, PY_38
from polyfactory.field_meta import FieldMeta
from polyfactory.utils.model_coverage import CoverageContainer

Expand All @@ -20,6 +21,10 @@ def handle_collection_type(field_meta: FieldMeta, container_type: type, factory:
:returns: A built result.
"""

if PY_38 and container_type in INSTANTIABLE_TYPE_MAPPING:
container_type = INSTANTIABLE_TYPE_MAPPING[container_type] # type: ignore[assignment]

container = container_type()
if not field_meta.children:
return container
Expand Down
41 changes: 32 additions & 9 deletions tests/test_msgspec_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def test_with_new_type() -> None:
class User(Struct):
name: UnixName
groups: List[UnixName]
constrained_name: Annotated[UnixName, Meta(min_length=20)]

class UserFactory(MsgspecFactory[User]):
__model__ = User
Expand Down Expand Up @@ -279,13 +280,14 @@ class FooFactory(MsgspecFactory[Foo]):

def test_union_types() -> None:
class A(Struct):
a: Union[List[str], List[int]]
a: Union[List[str], int]
b: Union[str, List[int]]
c: List[Union[Tuple[int, int], Tuple[str, int]]]
c: List[Union[Tuple[int, int], float]]

AFactory = MsgspecFactory.create_factory(A)

assert AFactory.build()
a = AFactory.build()
assert msgspec.convert(structs.asdict(a), A) == a


def test_collection_unions_with_models() -> None:
Expand All @@ -296,22 +298,28 @@ class B(Struct):
a: str

class C(Struct):
a: Union[List[A], List[B]]
b: List[Union[A, B]]
a: Union[List[A], str]
b: List[Union[A, int]]

CFactory = MsgspecFactory.create_factory(C)

assert CFactory.build()
c = CFactory.build()
assert msgspec.convert(structs.asdict(c), C) == c


def test_constrained_union_types() -> None:
class A(Struct):
a: Union[Annotated[List[str], Meta(min_length=10)], Annotated[int, Meta(ge=1000)]]
b: Union[List[Annotated[str, Meta(min_length=20)]], int]
c: Optional[Annotated[int, Meta(ge=1000)]]
d: Union[Annotated[List[int], Meta(min_length=100)], Annotated[str, Meta(min_length=100)]]
e: Optional[Union[Annotated[List[int], Meta(min_length=100)], Annotated[str, Meta(min_length=100)]]]
f: Optional[Union[Annotated[List[int], Meta(min_length=100)], str]]

AFactory = MsgspecFactory.create_factory(A)
AFactory = MsgspecFactory.create_factory(A, __allow_none_optionals__=False)

assert AFactory.build()
a = AFactory.build()
assert msgspec.convert(structs.asdict(a), A) == a


@pytest.mark.parametrize("allow_none", (True, False))
Expand All @@ -326,4 +334,19 @@ class AFactory(MsgspecFactory[A]):

__allow_none_optionals__ = allow_none

assert AFactory.build()
a = AFactory.build()
assert msgspec.convert(structs.asdict(a), A) == a


def test_annotated_children() -> None:
class A(Struct):
a: Dict[int, Annotated[str, Meta(min_length=20)]]
b: List[Annotated[int, Meta(gt=1000)]]
c: Annotated[List[Annotated[int, Meta(gt=1000)]], Meta(min_length=50)]
d: Dict[int, Annotated[List[Annotated[str, Meta(min_length=1)]], Meta(min_length=1)]]

class AFactory(MsgspecFactory[A]):
__model__ = A

a = AFactory.build()
assert msgspec.convert(structs.asdict(a), A) == a
20 changes: 18 additions & 2 deletions tests/test_pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from uuid import UUID

import pytest
from annotated_types import Ge, Le, LowerCase, MinLen, UpperCase
from annotated_types import Ge, Gt, Le, LowerCase, MinLen, UpperCase
from typing_extensions import Annotated, TypeAlias

import pydantic
Expand Down Expand Up @@ -432,8 +432,12 @@ def test_constrained_union_types() -> None:
class A(BaseModel):
a: Union[Annotated[List[str], MinLen(100)], Annotated[int, Ge(1000)]]
b: Union[List[Annotated[str, MinLen(100)]], int]
c: Union[Annotated[List[int], MinLen(100)], None]
d: Union[Annotated[List[int], MinLen(100)], Annotated[List[str], MinLen(100)]]
e: Optional[Union[Annotated[List[int], MinLen(10)], Annotated[List[str], MinLen(10)]]]
f: Optional[Union[Annotated[List[int], MinLen(10)], List[str]]]

AFactory = ModelFactory.create_factory(A)
AFactory = ModelFactory.create_factory(A, __allow_none_optionals__=False)

assert AFactory.build()

Expand Down Expand Up @@ -804,3 +808,15 @@ class MyFactory(ModelFactory):
assert isinstance(next(iter(result.conlist_with_complex_type[0].values())), tuple)
assert len(next(iter(result.conlist_with_complex_type[0].values()))) == 3
assert all(isinstance(v, Person) for v in next(iter(result.conlist_with_complex_type[0].values())))


def test_annotated_children() -> None:
class A(BaseModel):
a: Dict[int, Annotated[str, MinLen(min_length=20)]]
b: List[Annotated[int, Gt(gt=1000)]]
c: Annotated[List[Annotated[int, Gt(gt=1000)]], MinLen(min_length=50)]
d: Dict[int, Annotated[List[Annotated[str, MinLen(1)]], MinLen(1)]]

AFactory = ModelFactory.create_factory(A)

assert AFactory.build()

0 comments on commit 0f8f9e8

Please sign in to comment.