Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: constrained unions #499

Merged
merged 9 commits into from
Mar 2, 2024
27 changes: 26 additions & 1 deletion polyfactory/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
from collections import abc, defaultdict, deque
from random import Random
from typing import (
Expand All @@ -21,9 +22,10 @@
except ImportError:
UnionType = Union # type: ignore[misc,assignment]

PY_38 = sys.version_info.major == 3 and sys.version_info.minor == 8 # noqa: PLR2004

# Mapping of type annotations into concrete types. This is used to normalize python <= 3.9 annotations.
TYPE_MAPPING = {
INSTANTIABLE_TYPE_MAPPING = {
DefaultDict: defaultdict,
Deque: deque,
Dict: dict,
Expand All @@ -41,6 +43,29 @@
UnionType: Union,
}


if not PY_38:
TYPE_MAPPING = INSTANTIABLE_TYPE_MAPPING
else:
# For 3.8, we have to keep the types from typing since dict[str] syntax is not supported in 3.8.
TYPE_MAPPING = {
DefaultDict: DefaultDict,
Deque: Deque,
Dict: Dict,
FrozenSet: FrozenSet,
Iterable: List,
List: List,
Mapping: Dict,
Sequence: List,
Set: Set,
Tuple: Tuple,
abc.Iterable: List,
abc.Mapping: Dict,
abc.Sequence: List,
abc.Set: Set,
}


DEFAULT_RANDOM = Random()
RANDOMIZE_COLLECTION_LENGTH = False
MIN_COLLECTION_LENGTH = 0
Expand Down
58 changes: 32 additions & 26 deletions polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,39 +164,45 @@ def from_field_info(

annotation = unwrap_new_type(field_info.annotation)
children: list[FieldMeta,] | None = None
constraints: Constraints = {}
name = field_info.alias if field_info.alias and use_alias else field_name

constraints: PydanticConstraints
# pydantic v2 does not always propagate metadata for Union types
if not field_info.metadata and is_optional(annotation):
field_info = FieldInfo.from_annotation(unwrap_optional(annotation))
elif is_union(annotation):
children = [
cls.from_field_info(
field_info=FieldInfo.from_annotation(arg),
field_name=field_name,
random=random,
use_alias=use_alias,
if is_union(annotation):
constraints = {}
children = []
for arg in get_args(annotation):
if arg is NoneType:
continue
child_field_info = FieldInfo.from_annotation(arg)
merged_field_info = FieldInfo.merge_field_infos(field_info, child_field_info)
children.append(
cls.from_field_info(
field_name="",
field_info=merged_field_info,
use_alias=use_alias,
random=random,
),
)
for arg in get_args(annotation)
]

metadata, is_json = [], False
for m in field_info.metadata:
if not is_json and isinstance(m, Json): # type: ignore[misc]
is_json = True
elif m is not None:
metadata.append(m)
else:
metadata, is_json = [], False
for m in field_info.metadata:
if not is_json and isinstance(m, Json): # type: ignore[misc]
is_json = True
elif m is not None:
metadata.append(m)

constraints = cls.parse_constraints(metadata=metadata) if metadata else {}
constraints = cast(PydanticConstraints, constraints)
constraints = cast(
PydanticConstraints,
cls.parse_constraints(metadata=metadata) if metadata else {},
)

if "url" in constraints:
# pydantic uses a sentinel value for url constraints
annotation = str
if "url" in constraints:
# pydantic uses a sentinel value for url constraints
annotation = str

if is_json:
constraints["json"] = True
if is_json:
constraints["json"] = True

return PydanticFieldMeta.from_type(
annotation=annotation,
Expand Down
11 changes: 5 additions & 6 deletions polyfactory/field_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
from polyfactory.utils.deprecation import check_for_deprecated_parameters
from polyfactory.utils.helpers import (
get_annotation_metadata,
normalize_annotation,
unwrap_annotated,
unwrap_new_type,
)
from polyfactory.utils.predicates import is_annotated, is_any_annotated
from polyfactory.utils.predicates import is_annotated
from polyfactory.utils.types import NoneType

if TYPE_CHECKING:
Expand Down Expand Up @@ -135,14 +134,14 @@ def from_type(
("max_collection_length", max_collection_length),
),
)
field_type = normalize_annotation(annotation, random=random)

if not constraints and is_annotated(annotation):
annotated = is_annotated(annotation)
if not constraints and annotated:
metadata = cls.get_constraints_metadata(annotation)
constraints = cls.parse_constraints(metadata)

if not is_any_annotated(annotation):
annotation = TYPE_MAPPING[field_type] if field_type in TYPE_MAPPING else field_type
if annotated:
annotation = get_args(annotation)[0]
elif (origin := get_origin(annotation)) and origin in TYPE_MAPPING: # pragma: no cover
container = TYPE_MAPPING[origin]
annotation = container[get_args(annotation)] # type: ignore[index]
Expand Down
5 changes: 5 additions & 0 deletions polyfactory/value_generators/complex_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing_extensions import is_typeddict

from polyfactory.constants import INSTANTIABLE_TYPE_MAPPING, PY_38
from polyfactory.field_meta import FieldMeta
from polyfactory.utils.model_coverage import CoverageContainer

Expand All @@ -20,6 +21,10 @@ def handle_collection_type(field_meta: FieldMeta, container_type: type, factory:

:returns: A built result.
"""

if PY_38 and container_type in INSTANTIABLE_TYPE_MAPPING:
container_type = INSTANTIABLE_TYPE_MAPPING[container_type] # type: ignore[assignment]

container = container_type()
if not field_meta.children:
return container
Expand Down
41 changes: 32 additions & 9 deletions tests/test_msgspec_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def test_with_new_type() -> None:
class User(Struct):
name: UnixName
groups: List[UnixName]
constrained_name: Annotated[UnixName, Meta(min_length=20)]

class UserFactory(MsgspecFactory[User]):
__model__ = User
Expand Down Expand Up @@ -279,13 +280,14 @@ class FooFactory(MsgspecFactory[Foo]):

def test_union_types() -> None:
class A(Struct):
a: Union[List[str], List[int]]
a: Union[List[str], int]
b: Union[str, List[int]]
c: List[Union[Tuple[int, int], Tuple[str, int]]]
c: List[Union[Tuple[int, int], float]]

AFactory = MsgspecFactory.create_factory(A)

assert AFactory.build()
a = AFactory.build()
assert msgspec.convert(structs.asdict(a), A) == a


def test_collection_unions_with_models() -> None:
Expand All @@ -296,22 +298,28 @@ class B(Struct):
a: str

class C(Struct):
a: Union[List[A], List[B]]
b: List[Union[A, B]]
a: Union[List[A], str]
b: List[Union[A, int]]

CFactory = MsgspecFactory.create_factory(C)

assert CFactory.build()
c = CFactory.build()
assert msgspec.convert(structs.asdict(c), C) == c


def test_constrained_union_types() -> None:
class A(Struct):
a: Union[Annotated[List[str], Meta(min_length=10)], Annotated[int, Meta(ge=1000)]]
b: Union[List[Annotated[str, Meta(min_length=20)]], int]
c: Optional[Annotated[int, Meta(ge=1000)]]
d: Union[Annotated[List[int], Meta(min_length=100)], Annotated[str, Meta(min_length=100)]]
e: Optional[Union[Annotated[List[int], Meta(min_length=100)], Annotated[str, Meta(min_length=100)]]]
f: Optional[Union[Annotated[List[int], Meta(min_length=100)], str]]

AFactory = MsgspecFactory.create_factory(A)
AFactory = MsgspecFactory.create_factory(A, __allow_none_optionals__=False)

assert AFactory.build()
a = AFactory.build()
assert msgspec.convert(structs.asdict(a), A) == a


@pytest.mark.parametrize("allow_none", (True, False))
Expand All @@ -326,4 +334,19 @@ class AFactory(MsgspecFactory[A]):

__allow_none_optionals__ = allow_none

assert AFactory.build()
a = AFactory.build()
assert msgspec.convert(structs.asdict(a), A) == a


def test_annotated_children() -> None:
class A(Struct):
a: Dict[int, Annotated[str, Meta(min_length=20)]]
b: List[Annotated[int, Meta(gt=1000)]]
c: Annotated[List[Annotated[int, Meta(gt=1000)]], Meta(min_length=50)]
d: Dict[int, Annotated[List[Annotated[str, Meta(min_length=1)]], Meta(min_length=1)]]

class AFactory(MsgspecFactory[A]):
__model__ = A

a = AFactory.build()
assert msgspec.convert(structs.asdict(a), A) == a
20 changes: 18 additions & 2 deletions tests/test_pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from uuid import UUID

import pytest
from annotated_types import Ge, Le, LowerCase, MinLen, UpperCase
from annotated_types import Ge, Gt, Le, LowerCase, MinLen, UpperCase
from typing_extensions import Annotated, TypeAlias

import pydantic
Expand Down Expand Up @@ -432,8 +432,12 @@ def test_constrained_union_types() -> None:
class A(BaseModel):
a: Union[Annotated[List[str], MinLen(100)], Annotated[int, Ge(1000)]]
b: Union[List[Annotated[str, MinLen(100)]], int]
c: Union[Annotated[List[int], MinLen(100)], None]
d: Union[Annotated[List[int], MinLen(100)], Annotated[List[str], MinLen(100)]]
e: Optional[Union[Annotated[List[int], MinLen(10)], Annotated[List[str], MinLen(10)]]]
f: Optional[Union[Annotated[List[int], MinLen(10)], List[str]]]

AFactory = ModelFactory.create_factory(A)
AFactory = ModelFactory.create_factory(A, __allow_none_optionals__=False)

assert AFactory.build()

Expand Down Expand Up @@ -804,3 +808,15 @@ class MyFactory(ModelFactory):
assert isinstance(next(iter(result.conlist_with_complex_type[0].values())), tuple)
assert len(next(iter(result.conlist_with_complex_type[0].values()))) == 3
assert all(isinstance(v, Person) for v in next(iter(result.conlist_with_complex_type[0].values())))


def test_annotated_children() -> None:
class A(BaseModel):
a: Dict[int, Annotated[str, MinLen(min_length=20)]]
b: List[Annotated[int, Gt(gt=1000)]]
c: Annotated[List[Annotated[int, Gt(gt=1000)]], MinLen(min_length=50)]
d: Dict[int, Annotated[List[Annotated[str, MinLen(1)]], MinLen(1)]]

AFactory = ModelFactory.create_factory(A)

assert AFactory.build()
Loading