Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correctly handle collections with constrained items #436

Merged
merged 8 commits into from
Nov 12, 2023
15 changes: 14 additions & 1 deletion polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

if TYPE_CHECKING:
from random import Random
from typing import Callable
from typing import Callable, Sequence

from typing_extensions import NotRequired, TypeGuard

Expand Down Expand Up @@ -284,6 +284,19 @@ def from_model_field( # pragma: no cover
constraints=cast("PydanticConstraints", {k: v for k, v in constraints.items() if v is not None}) or None,
)

if VERSION.startswith("2"):

@classmethod
def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]:
metadata = []
for m in super().get_constraints_metadata(annotation):
if isinstance(m, FieldInfo):
metadata.extend(m.metadata)
else:
metadata.append(m)

return metadata


class ModelFactory(Generic[T], BaseFactory[T]):
"""Base factory for pydantic models"""
Expand Down
29 changes: 24 additions & 5 deletions polyfactory/field_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,21 @@
from polyfactory.collection_extender import CollectionExtender
from polyfactory.constants import DEFAULT_RANDOM, TYPE_MAPPING
from polyfactory.utils.deprecation import check_for_deprecated_parameters
from polyfactory.utils.helpers import normalize_annotation, unwrap_annotated, unwrap_args, unwrap_new_type
from polyfactory.utils.helpers import (
get_annotation_metadata,
normalize_annotation,
unwrap_annotated,
unwrap_args,
unwrap_new_type,
)
from polyfactory.utils.predicates import is_annotated, is_any_annotated

if TYPE_CHECKING:
import datetime
from decimal import Decimal
from random import Random
from typing import Sequence

from _pydecimal import Decimal
from typing_extensions import NotRequired, Self


Expand Down Expand Up @@ -134,7 +141,7 @@ def from_type(
field_type = normalize_annotation(annotation, random=random)

if not constraints and is_annotated(annotation):
_, metadata = unwrap_annotated(annotation, random=random)
metadata = cls.get_constraints_metadata(annotation)
constraints = cls.parse_constraints(metadata)

if not is_any_annotated(annotation):
Expand All @@ -156,7 +163,7 @@ def from_type(
number_of_args = 1
extended_type_args = CollectionExtender.extend_type_args(field.annotation, field.type_args, number_of_args)
field.children = [
FieldMeta.from_type(
cls.from_type(
annotation=unwrap_new_type(arg),
random=random,
)
Expand All @@ -165,7 +172,7 @@ def from_type(
return field

@classmethod
def parse_constraints(cls, metadata: list[Any]) -> "Constraints":
def parse_constraints(cls, metadata: Sequence[Any]) -> "Constraints":
constraints = {}

for value in metadata:
Expand Down Expand Up @@ -215,3 +222,15 @@ def parse_constraints(cls, metadata: list[Any]) -> "Constraints":
},
)
return cast("Constraints", constraints)

@classmethod
def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]:
"""Get the metadatas of the constraints from the given annotation.

:param annotation: A type annotation.
:param random: An instance of random.Random.

:returns: A list of the metadata in the annotation.
"""

return get_annotation_metadata(annotation)
12 changes: 12 additions & 0 deletions polyfactory/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

if TYPE_CHECKING:
from random import Random
from typing import Sequence


def unwrap_new_type(annotation: Any) -> Any:
Expand Down Expand Up @@ -162,6 +163,17 @@ def normalize_annotation(annotation: Any, random: Random) -> Any:
return origin


def get_annotation_metadata(annotation: Any) -> Sequence[Any]:
"""Get the metadata in the annotation.

:param annotation: A type annotation.

:returns: The metadata.
"""

return get_args(annotation)[1:]


def get_collection_type(annotation: Any) -> type[list | tuple | set | frozenset | dict]:
"""Get the collection type from the annotation.

Expand Down
36 changes: 35 additions & 1 deletion tests/test_msgspec_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import msgspec
import pytest
from msgspec import Struct, structs
from msgspec import Meta, Struct, structs
from typing_extensions import Annotated

from polyfactory.exceptions import ParameterException
Expand Down Expand Up @@ -215,3 +215,37 @@ class BarFactory(MsgspecFactory[Bar]):

validated_bar = msgspec.convert(bar_dict, type=Bar)
assert validated_bar == bar


def test_sequence_with_constrained_item_types() -> None:
ConstrainedInt = Annotated[int, Meta(ge=100, le=200)]

class Foo(Struct):
list_field: List[ConstrainedInt]
tuple_field: Tuple[ConstrainedInt]
variable_tuple_field: Tuple[ConstrainedInt, ...]
set_field: Set[ConstrainedInt]

class FooFactory(MsgspecFactory[Foo]):
__model__ = Foo

foo = FooFactory.build()
validated_foo = msgspec.convert(structs.asdict(foo), Foo)

assert validated_foo == foo


def test_mapping_with_constrained_item_types() -> None:
ConstrainedInt = Annotated[int, Meta(ge=100, le=200)]
ConstrainedStr = Annotated[str, Meta(min_length=1, max_length=3)]

class Foo(Struct):
dict_field = Dict[ConstrainedStr, ConstrainedInt]

class FooFactory(MsgspecFactory[Foo]):
__model__ = Foo

foo = FooFactory.build()
validated_foo = msgspec.convert(structs.asdict(foo), Foo)

assert validated_foo == foo
31 changes: 30 additions & 1 deletion tests/test_pydantic_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import sys
from typing import Optional
from typing import Dict, List, Optional, Set, Tuple

import pytest
from pydantic import VERSION, BaseModel, Field, Json
from typing_extensions import Annotated

from polyfactory.factories.pydantic_factory import ModelFactory

Expand Down Expand Up @@ -78,3 +79,31 @@ class BFactory(ModelFactory[B]):
__model__ = B

assert isinstance(BFactory.build(), B)


def test_sequence_with_annotated_item_types() -> None:
ConstrainedInt = Annotated[int, Field(ge=100, le=200)]

class Foo(BaseModel):
list_field: List[ConstrainedInt]
tuple_field: Tuple[ConstrainedInt]
variable_tuple_field: Tuple[ConstrainedInt, ...]
set_field: Set[ConstrainedInt]

class FooFactory(ModelFactory[Foo]):
__model__ = Foo

assert FooFactory.build()


def test_mapping_with_annotated_item_types() -> None:
ConstrainedInt = Annotated[int, Field(ge=100, le=200)]
ConstrainedStr = Annotated[str, Field(min_length=1, max_length=3)]

class Foo(BaseModel):
dict_field: Dict[ConstrainedStr, ConstrainedInt]

class FooFactory(ModelFactory[Foo]):
__model__ = Foo

assert FooFactory.build()
Loading