Skip to content

Commit

Permalink
fix: correctly handle collections with constrained items (#436)
Browse files Browse the repository at this point in the history
  • Loading branch information
guacs authored Nov 12, 2023
1 parent b1e8b5e commit 6b7512d
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 8 deletions.
15 changes: 14 additions & 1 deletion polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

if TYPE_CHECKING:
from random import Random
from typing import Callable
from typing import Callable, Sequence

from typing_extensions import NotRequired, TypeGuard

Expand Down Expand Up @@ -284,6 +284,19 @@ def from_model_field( # pragma: no cover
constraints=cast("PydanticConstraints", {k: v for k, v in constraints.items() if v is not None}) or None,
)

if VERSION.startswith("2"):

@classmethod
def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]:
metadata = []
for m in super().get_constraints_metadata(annotation):
if isinstance(m, FieldInfo):
metadata.extend(m.metadata)
else:
metadata.append(m)

return metadata


class ModelFactory(Generic[T], BaseFactory[T]):
"""Base factory for pydantic models"""
Expand Down
29 changes: 24 additions & 5 deletions polyfactory/field_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,21 @@
from polyfactory.collection_extender import CollectionExtender
from polyfactory.constants import DEFAULT_RANDOM, TYPE_MAPPING
from polyfactory.utils.deprecation import check_for_deprecated_parameters
from polyfactory.utils.helpers import normalize_annotation, unwrap_annotated, unwrap_args, unwrap_new_type
from polyfactory.utils.helpers import (
get_annotation_metadata,
normalize_annotation,
unwrap_annotated,
unwrap_args,
unwrap_new_type,
)
from polyfactory.utils.predicates import is_annotated, is_any_annotated

if TYPE_CHECKING:
import datetime
from decimal import Decimal
from random import Random
from typing import Sequence

from _pydecimal import Decimal
from typing_extensions import NotRequired, Self


Expand Down Expand Up @@ -134,7 +141,7 @@ def from_type(
field_type = normalize_annotation(annotation, random=random)

if not constraints and is_annotated(annotation):
_, metadata = unwrap_annotated(annotation, random=random)
metadata = cls.get_constraints_metadata(annotation)
constraints = cls.parse_constraints(metadata)

if not is_any_annotated(annotation):
Expand All @@ -156,7 +163,7 @@ def from_type(
number_of_args = 1
extended_type_args = CollectionExtender.extend_type_args(field.annotation, field.type_args, number_of_args)
field.children = [
FieldMeta.from_type(
cls.from_type(
annotation=unwrap_new_type(arg),
random=random,
)
Expand All @@ -165,7 +172,7 @@ def from_type(
return field

@classmethod
def parse_constraints(cls, metadata: list[Any]) -> "Constraints":
def parse_constraints(cls, metadata: Sequence[Any]) -> "Constraints":
constraints = {}

for value in metadata:
Expand Down Expand Up @@ -215,3 +222,15 @@ def parse_constraints(cls, metadata: list[Any]) -> "Constraints":
},
)
return cast("Constraints", constraints)

@classmethod
def get_constraints_metadata(cls, annotation: Any) -> Sequence[Any]:
"""Get the metadatas of the constraints from the given annotation.
:param annotation: A type annotation.
:param random: An instance of random.Random.
:returns: A list of the metadata in the annotation.
"""

return get_annotation_metadata(annotation)
12 changes: 12 additions & 0 deletions polyfactory/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

if TYPE_CHECKING:
from random import Random
from typing import Sequence


def unwrap_new_type(annotation: Any) -> Any:
Expand Down Expand Up @@ -162,6 +163,17 @@ def normalize_annotation(annotation: Any, random: Random) -> Any:
return origin


def get_annotation_metadata(annotation: Any) -> Sequence[Any]:
"""Get the metadata in the annotation.
:param annotation: A type annotation.
:returns: The metadata.
"""

return get_args(annotation)[1:]


def get_collection_type(annotation: Any) -> type[list | tuple | set | frozenset | dict]:
"""Get the collection type from the annotation.
Expand Down
36 changes: 35 additions & 1 deletion tests/test_msgspec_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import msgspec
import pytest
from msgspec import Struct, structs
from msgspec import Meta, Struct, structs
from typing_extensions import Annotated

from polyfactory.exceptions import ParameterException
Expand Down Expand Up @@ -215,3 +215,37 @@ class BarFactory(MsgspecFactory[Bar]):

validated_bar = msgspec.convert(bar_dict, type=Bar)
assert validated_bar == bar


def test_sequence_with_constrained_item_types() -> None:
ConstrainedInt = Annotated[int, Meta(ge=100, le=200)]

class Foo(Struct):
list_field: List[ConstrainedInt]
tuple_field: Tuple[ConstrainedInt]
variable_tuple_field: Tuple[ConstrainedInt, ...]
set_field: Set[ConstrainedInt]

class FooFactory(MsgspecFactory[Foo]):
__model__ = Foo

foo = FooFactory.build()
validated_foo = msgspec.convert(structs.asdict(foo), Foo)

assert validated_foo == foo


def test_mapping_with_constrained_item_types() -> None:
ConstrainedInt = Annotated[int, Meta(ge=100, le=200)]
ConstrainedStr = Annotated[str, Meta(min_length=1, max_length=3)]

class Foo(Struct):
dict_field = Dict[ConstrainedStr, ConstrainedInt]

class FooFactory(MsgspecFactory[Foo]):
__model__ = Foo

foo = FooFactory.build()
validated_foo = msgspec.convert(structs.asdict(foo), Foo)

assert validated_foo == foo
31 changes: 30 additions & 1 deletion tests/test_pydantic_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import sys
from typing import Optional
from typing import Dict, List, Optional, Set, Tuple

import pytest
from pydantic import VERSION, BaseModel, Field, Json
from typing_extensions import Annotated

from polyfactory.factories.pydantic_factory import ModelFactory

Expand Down Expand Up @@ -78,3 +79,31 @@ class BFactory(ModelFactory[B]):
__model__ = B

assert isinstance(BFactory.build(), B)


def test_sequence_with_annotated_item_types() -> None:
ConstrainedInt = Annotated[int, Field(ge=100, le=200)]

class Foo(BaseModel):
list_field: List[ConstrainedInt]
tuple_field: Tuple[ConstrainedInt]
variable_tuple_field: Tuple[ConstrainedInt, ...]
set_field: Set[ConstrainedInt]

class FooFactory(ModelFactory[Foo]):
__model__ = Foo

assert FooFactory.build()


def test_mapping_with_annotated_item_types() -> None:
ConstrainedInt = Annotated[int, Field(ge=100, le=200)]
ConstrainedStr = Annotated[str, Field(min_length=1, max_length=3)]

class Foo(BaseModel):
dict_field: Dict[ConstrainedStr, ConstrainedInt]

class FooFactory(ModelFactory[Foo]):
__model__ = Foo

assert FooFactory.build()

0 comments on commit 6b7512d

Please sign in to comment.