Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add field overrides #519

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
RANDOMIZE_COLLECTION_LENGTH,
)
from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException
from polyfactory.field_meta import Null
from polyfactory.field_meta import FieldMeta, Null
from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use
from polyfactory.utils.helpers import (
flatten_annotation,
Expand Down Expand Up @@ -82,7 +82,7 @@
if TYPE_CHECKING:
from typing_extensions import TypeGuard

from polyfactory.field_meta import Constraints, FieldMeta
from polyfactory.field_meta import Constraints
from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol


Expand Down Expand Up @@ -914,7 +914,7 @@ def _check_declared_fields_exist_in_model(cls) -> None:
raise ConfigurationException(error_message)

@classmethod
def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: # noqa: C901
"""Process the given kwargs and generate values for the factory's model.

:param kwargs: Any build kwargs.
Expand All @@ -931,7 +931,12 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
for field_meta in cls.get_model_fields():
field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)
if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta):
if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name):
field_value = getattr(cls, field_meta.name, Null)
if (
field_value is not Null
and not hasattr(BaseFactory, field_meta.name)
and not isinstance(field_value, FieldMeta)
):
field_value = getattr(cls, field_meta.name)
if isinstance(field_value, Ignore):
continue
Expand All @@ -951,6 +956,16 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
)
continue

if isinstance(field_value, FieldMeta):
if field_value.annotation is not Null:
field_meta.annotation = field_value.annotation
field_meta.children = field_value.children

if field_value.constraints:
if field_meta.constraints is None:
field_meta.constraints = {}
field_meta.constraints.update(field_value.constraints)

Comment on lines +959 to +968
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't this just be cls.get_field_value(field_value, ...) here instead? The modifying of the field_meta directly may have unforeseen consequences if they're cached or something.

field_result = cls.get_field_value(
field_meta,
field_build_parameters=field_build_parameters,
Expand Down
12 changes: 10 additions & 2 deletions polyfactory/fields.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

from typing import Any, Callable, Generic, TypedDict, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, TypedDict, TypeVar, cast

from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, Unpack

from polyfactory.exceptions import ParameterException
from polyfactory.field_meta import FieldMeta, Null

if TYPE_CHECKING:
from polyfactory.field_meta import Constraints

T = TypeVar("T")
P = ParamSpec("P")
Expand Down Expand Up @@ -114,3 +118,7 @@ def to_value(self) -> Any:

msg = "fixture has not been registered using the register_factory decorator"
raise ParameterException(msg)


def Field(annotation: Any = Null, **constraints: Unpack[Constraints]) -> FieldMeta: # noqa: N802
return FieldMeta.from_type(annotation, constraints=constraints)
25 changes: 23 additions & 2 deletions tests/test_factory_fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import random
from datetime import datetime, timedelta
from typing import Any, Optional, Union
from typing import Any, Literal, Optional, Union

import pytest

Expand All @@ -9,7 +9,7 @@
from polyfactory.decorators import post_generated
from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.fields import Ignore, PostGenerated, Require, Use
from polyfactory.fields import Field, Ignore, PostGenerated, Require, Use


def test_use() -> None:
Expand Down Expand Up @@ -202,3 +202,24 @@ class NoFieldModel(BaseModel):
match="unknown_field is declared on the factory NoFieldModelFactory but it is not part of the model NoFieldModel",
):
ModelFactory.create_factory(NoFieldModel, bases=None, __check_model__=True, unknown_field=factory_field)


def test_field() -> None:
class Model(BaseModel):
a: int
b: Union[str, None]
complex_type: list[tuple[int, int]]

class Factory(ModelFactory[Model]):
a = Field(ge=1, le=2)
b = Field(annotation=None)
complex_type = Field(
annotation=list[tuple[Literal[1], Literal[2]]],
min_length=2,
max_length=2,
)

result = Factory.build()
assert 1 <= result.a <= 2
assert result.b is None
assert result.complex_type == [(1, 2), (1, 2)]
Loading