Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle unions properly #491

Merged
merged 13 commits into from
Jan 20, 2024
23 changes: 11 additions & 12 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@
from os.path import realpath
from pathlib import Path
from random import Random

from polyfactory.field_meta import Null

try:
from types import NoneType
except ImportError:
NoneType = type(None) # type: ignore[misc,assignment]

from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -58,6 +50,7 @@
RANDOMIZE_COLLECTION_LENGTH,
)
from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException
from polyfactory.field_meta import Null
from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use
from polyfactory.utils.helpers import (
flatten_annotation,
Expand All @@ -68,6 +61,7 @@
)
from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage
from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union
from polyfactory.utils.types import NoneType
from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage
from polyfactory.value_generators.constrained_collections import (
handle_constrained_collection,
Expand Down Expand Up @@ -693,6 +687,15 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
if field_meta.constraints:
return cls.get_constrained_field_value(annotation=unwrapped_annotation, field_meta=field_meta)

if is_union(field_meta.annotation) and field_meta.children:
seen_models = build_context["seen_models"]
children = [child for child in field_meta.children if child.annotation not in seen_models]

# `None` is removed from the children when creating FieldMeta so when `children`
# is empty, it must mean that the field meta is an optional type.
if children:
return cls.get_field_value(cls.__random__.choice(children), field_build_parameters, build_context)

if BaseFactory.is_factory_type(annotation=unwrapped_annotation):
if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]:
return None if is_optional(field_meta.annotation) else Null
Expand Down Expand Up @@ -740,10 +743,6 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912

return handle_collection_type(field_meta, origin, cls)

if is_union(unwrapped_annotation) and field_meta.children:
children = [child for child in field_meta.children if child.annotation not in build_context["seen_models"]]
return cls.get_field_value(cls.__random__.choice(children))

if is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar):
return create_random_string(cls.__random__, min_length=1, max_length=10)

Expand Down
9 changes: 8 additions & 1 deletion polyfactory/factories/pydantic_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from polyfactory.utils.deprecation import check_for_deprecated_parameters
from polyfactory.utils.helpers import unwrap_new_type, unwrap_optional
from polyfactory.utils.predicates import is_optional, is_safe_subclass, is_union
from polyfactory.utils.types import NoneType
from polyfactory.value_generators.primitives import create_random_bytes

try:
Expand Down Expand Up @@ -251,7 +252,13 @@ def from_model_field( # pragma: no cover
annotation = Literal[default_value] # pyright: ignore # noqa: PGH003

children: list[FieldMeta] = []
if model_field.key_field or model_field.sub_fields:

# Refer #412.
args = get_args(model_field.annotation)
if is_optional(model_field.annotation) and len(args) == 2: # noqa: PLR2004
child_annotation = args[0] if args[0] is not NoneType else args[1]
children.append(PydanticFieldMeta.from_type(child_annotation))
elif model_field.key_field or model_field.sub_fields:
fields_to_iterate = (
([model_field.key_field, *model_field.sub_fields])
if model_field.key_field is not None
Expand Down
8 changes: 3 additions & 5 deletions polyfactory/field_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
get_annotation_metadata,
normalize_annotation,
unwrap_annotated,
unwrap_args,
unwrap_new_type,
)
from polyfactory.utils.predicates import is_annotated, is_any_annotated
from polyfactory.utils.types import NoneType

if TYPE_CHECKING:
import datetime
Expand Down Expand Up @@ -99,10 +99,7 @@ def type_args(self) -> tuple[Any, ...]:

:returns: a tuple of types.
"""
return tuple(
TYPE_MAPPING[arg] if arg in TYPE_MAPPING else arg
for arg in unwrap_args(self.annotation, random=self.random)
)
return tuple(TYPE_MAPPING[arg] if arg in TYPE_MAPPING else arg for arg in get_args(self.annotation))

@classmethod
def from_type(
Expand Down Expand Up @@ -168,6 +165,7 @@ def from_type(
random=random,
)
for arg in extended_type_args
if arg is not NoneType
]
return field

Expand Down
9 changes: 4 additions & 5 deletions polyfactory/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,11 @@
import sys
from typing import TYPE_CHECKING, Any, Mapping

try:
from types import NoneType
except ImportError:
NoneType = type(None) # type: ignore[misc,assignment]

from typing_extensions import TypeAliasType, get_args, get_origin

from polyfactory.constants import TYPE_MAPPING
from polyfactory.utils.predicates import is_annotated, is_new_type, is_optional, is_safe_subclass, is_union
from polyfactory.utils.types import NoneType

if TYPE_CHECKING:
from random import Random
Expand Down Expand Up @@ -67,6 +63,7 @@ def unwrap_annotation(annotation: Any, random: Random) -> Any:
"""
while (
is_optional(annotation)
or is_union(annotation)
or is_new_type(annotation)
or is_annotated(annotation)
or isinstance(annotation, TypeAliasType)
Expand All @@ -79,6 +76,8 @@ def unwrap_annotation(annotation: Any, random: Random) -> Any:
annotation = unwrap_annotated(annotation, random=random)[0]
elif isinstance(annotation, TypeAliasType):
annotation = annotation.__value__
else:
annotation = unwrap_union(annotation, random=random)

return annotation

Expand Down
12 changes: 2 additions & 10 deletions polyfactory/utils/predicates.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
from __future__ import annotations

from inspect import isclass
from typing import Any, Literal, NewType, Optional, TypeVar, Union, get_args
from typing import Any, Literal, NewType, Optional, TypeVar, get_args

from typing_extensions import Annotated, NotRequired, ParamSpec, Required, TypeGuard, _AnnotatedAlias, get_origin

from polyfactory.constants import TYPE_MAPPING

try:
from types import NoneType, UnionType

UNION_TYPES = {UnionType, Union}
except ImportError:
NoneType = type(None) # type: ignore[misc,assignment]
UNION_TYPES = {Union}

from polyfactory.utils.types import UNION_TYPES, NoneType

P = ParamSpec("P")
T = TypeVar("T")
Expand Down
12 changes: 12 additions & 0 deletions polyfactory/utils/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Union

try:
from types import NoneType, UnionType

UNION_TYPES = {UnionType, Union}
except ImportError:
UNION_TYPES = {Union}

NoneType = type(None) # type: ignore[misc,assignment]

__all__ = ("NoneType", "UNION_TYPES")
4 changes: 2 additions & 2 deletions tests/test_dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ class MyClassFactory(ModelFactory[MyClass]):
test_obj_1 = MyClassFactory.build()
test_obj_2 = MyClassFactory.build()

assert isinstance(next(iter(test_obj_1.val.values())), int)
assert isinstance(next(iter(test_obj_2.val.values())), str)
assert isinstance(next(iter(test_obj_1.val.values())), str)
assert isinstance(next(iter(test_obj_2.val.values())), int)
60 changes: 60 additions & 0 deletions tests/test_union_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import List, Optional, Tuple, Union

import pytest
from annotated_types import Ge, MinLen
from pydantic import BaseModel
from typing_extensions import Annotated

from polyfactory.factories.pydantic_factory import ModelFactory


def test_union_types() -> None:
class A(BaseModel):
a: Union[List[str], List[int]]
b: Union[str, List[int]]
c: List[Union[Tuple[int, int], Tuple[str, int]]]

AFactory = ModelFactory.create_factory(A)

assert AFactory.build()


def test_collection_unions_with_models() -> None:
class A(BaseModel):
a: int

class B(BaseModel):
a: str

class C(BaseModel):
a: Union[List[A], List[B]]
b: List[Union[A, B]]

CFactory = ModelFactory.create_factory(C)

assert CFactory.build()


def test_constrained_union_types() -> None:
class A(BaseModel):
a: Union[Annotated[List[str], MinLen(100)], Annotated[int, Ge(1000)]]
b: Union[List[Annotated[str, MinLen(100)]], int]

AFactory = ModelFactory.create_factory(A)

assert AFactory.build()


@pytest.mark.parametrize("allow_none", (True, False))
def test_optional_type(allow_none: bool) -> None:
class A(BaseModel):
a: Union[str, None]
b: Optional[str]
c: Optional[Union[str, int, List[int]]]

class AFactory(ModelFactory[A]):
__model__ = A

__allow_none_optionals__ = allow_none

assert AFactory.build()
Loading