Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Optional __model__ type #452

Merged
merged 15 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import typing
from abc import ABC, abstractmethod
from collections import Counter, abc, deque
from contextlib import suppress
Expand Down Expand Up @@ -36,7 +35,9 @@
ClassVar,
Collection,
Generic,
Iterable,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
Expand All @@ -45,7 +46,7 @@
from uuid import UUID

from faker import Faker
from typing_extensions import get_args
from typing_extensions import get_args, get_origin

from polyfactory.constants import (
DEFAULT_RANDOM,
Expand Down Expand Up @@ -190,12 +191,13 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901
)

if "__is_base_factory__" not in cls.__dict__ or not cls.__is_base_factory__:
model = getattr(cls, "__model__", None)
model: Optional[type[T]] = getattr(cls, "__model__", None) or cls._infer_model_type()
JacobCoffee marked this conversation as resolved.
Show resolved Hide resolved
if not model:
msg = f"required configuration attribute '__model__' is not set on {cls.__name__}"
raise ConfigurationException(
msg,
)
cls.__model__ = model
if not cls.is_supported_type(model):
for factory in BaseFactory._base_factories:
if factory.is_supported_type(model):
Expand All @@ -219,6 +221,27 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901
if cls.__set_as_default_factory_for_type__:
BaseFactory._factory_type_mapping[cls.__model__] = cls

@classmethod
def _infer_model_type(cls: type[F]) -> Optional[type[T]]:
"""Return model type inferred from class declaration.
class Foo(ModelFactory[MyModel]): # <<< MyModel
...

If more than one base class and/or generic arguments specified return None.

:returns: Inferred model type or None
"""
factory_bases: Iterable[type[BaseFactory[T]]] = (
b for b in getattr(cls, "__orig_bases__", []) if issubclass(get_origin(b), BaseFactory)
JacobCoffee marked this conversation as resolved.
Show resolved Hide resolved
)
generic_args: Sequence[type[T]] = [
arg for factory_base in factory_bases for arg in get_args(factory_base) if not isinstance(arg, TypeVar)
]
if len(generic_args) != 1:
return None

return generic_args[0]
JacobCoffee marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]:
"""Return a SyncPersistenceHandler if defined for the factory, otherwise raises a ConfigurationException.
Expand Down Expand Up @@ -678,7 +701,7 @@ def get_field_value_coverage( # noqa: C901
cls,
field_meta: FieldMeta,
field_build_parameters: Any | None = None,
) -> typing.Iterable[Any]:
) -> Iterable[Any]:
"""Return a field value on the subclass if existing, otherwise returns a mock value.

:param field_meta: FieldMeta instance.
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ line-length = 120
src = ["polyfactory", "tests", "docs/examples"]
target-version = "py38"

[tool.ruff.lint.pyupgrade]
# Preserve types, even if a file imports `from __future__ import annotations`.
keep-runtime-typing = true
JacobCoffee marked this conversation as resolved.
Show resolved Hide resolved

[tool.ruff.pydocstyle]
convention = "google"

Expand Down
Empty file.
48 changes: 48 additions & 0 deletions tests/optional_model_field/test_attrs_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import datetime as dt
from decimal import Decimal
from enum import Enum
from typing import Any, Dict, List, Tuple
from uuid import UUID

import pytest
from attrs import asdict, define

from polyfactory.factories.attrs_factory import AttrsFactory

pytestmark = [pytest.mark.attrs]


def test_with_basic_types_annotated() -> None:
class SampleEnum(Enum):
FOO = "foo"
BAR = "bar"

@define
class Foo:
bool_field: bool
int_field: int
float_field: float
str_field: str
bytse_field: bytes
bytearray_field: bytearray
tuple_field: Tuple[int, str]
tuple_with_variadic_args: Tuple[int, ...]
list_field: List[int]
dict_field: Dict[str, int]
datetime_field: dt.datetime
date_field: dt.date
time_field: dt.time
uuid_field: UUID
decimal_field: Decimal
enum_type: SampleEnum
any_type: Any

class FooFactory(AttrsFactory[Foo]):
...

assert getattr(FooFactory, "__model__") is Foo

foo: Foo = FooFactory.build()
foo_dict = asdict(foo)

assert foo == Foo(**foo_dict)
48 changes: 48 additions & 0 deletions tests/optional_model_field/test_beanie_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from typing import Callable, List

import pymongo
import pytest

try:
from beanie import Document, init_beanie
from beanie.odm.fields import Indexed, PydanticObjectId
from mongomock_motor import AsyncMongoMockClient

from polyfactory.factories.beanie_odm_factory import BeanieDocumentFactory
except ImportError:
pytest.importorskip("beanie")

BeanieDocumentFactory = None # type: ignore
Document = None # type: ignore
init_beanie = None # type: ignore
Indexed = None # type: ignore
PydanticObjectId = None # type: ignore


@pytest.fixture()
def mongo_connection() -> AsyncMongoMockClient:
return AsyncMongoMockClient()


class MyDocument(Document):
id: PydanticObjectId
name: str
index: Indexed(str, pymongo.DESCENDING) # type: ignore
siblings: List[PydanticObjectId]


class MyFactory(BeanieDocumentFactory[MyDocument]):
...


@pytest.fixture()
async def beanie_init(mongo_connection: AsyncMongoMockClient) -> None:
await init_beanie(database=mongo_connection.db_name, document_models=[MyDocument]) # type: ignore


async def test_handling_of_beanie_types(beanie_init: Callable) -> None:
assert getattr(MyFactory, "__model__") is MyDocument
result: MyDocument = MyFactory.build()
assert result.name
assert result.index
assert isinstance(result.index, str)
20 changes: 20 additions & 0 deletions tests/optional_model_field/test_generic_class_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Generic, TypeVar

from pydantic.generics import GenericModel

from polyfactory.factories.pydantic_factory import ModelFactory


def test_generic_model_is_not_an_error() -> None:
T = TypeVar("T")
P = TypeVar("P")

class Foo(GenericModel, Generic[T, P]): # type: ignore[misc]
val1: T
val2: P

class FooFactory(ModelFactory[Foo[str, int]]):
...

assert isinstance(FooFactory.build().val1, str)
assert isinstance(FooFactory.build().val2, int)
59 changes: 59 additions & 0 deletions tests/optional_model_field/test_model_inference_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Type, TypedDict

import pytest
from pydantic import BaseModel

from polyfactory import ConfigurationException
from polyfactory.factories import TypedDictFactory
from polyfactory.factories.attrs_factory import AttrsFactory
from polyfactory.factories.base import BaseFactory
from polyfactory.factories.msgspec_factory import MsgspecFactory
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory


@pytest.mark.parametrize(
"base_factory",
[
AttrsFactory,
ModelFactory,
MsgspecFactory,
SQLAlchemyFactory,
TypedDictFactory,
],
)
def test_model_without_generic_type_inference_error(base_factory: Type[BaseFactory]) -> None:
with pytest.raises(ConfigurationException):

class Foo(base_factory): # type: ignore
...


@pytest.mark.parametrize(
"base_factory",
[
AttrsFactory,
ModelFactory,
MsgspecFactory,
SQLAlchemyFactory,
TypedDictFactory,
],
)
def test_model_type_error(base_factory: Type[BaseFactory]) -> None:
with pytest.raises(ConfigurationException):

class Foo(base_factory[int]): # type: ignore
...


def test_model_multiple_inheritance_cannot_infer_error() -> None:
class PFoo(BaseModel):
val: int

class TDFoo(TypedDict):
val: str

with pytest.raises(ConfigurationException):

class Foo(ModelFactory[PFoo], TypedDictFactory[TDFoo]): # type: ignore
...
25 changes: 25 additions & 0 deletions tests/optional_model_field/test_msgspec_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import msgspec
from msgspec import Struct, structs

from polyfactory.factories.msgspec_factory import MsgspecFactory


def test_with_nested_struct() -> None:
class Foo(Struct):
int_field: int

class Bar(Struct):
int_field: int
foo_field: Foo

class BarFactory(MsgspecFactory[Bar]):
...

assert getattr(BarFactory, "__model__") is Bar

bar: Bar = BarFactory.build()
bar_dict = structs.asdict(bar)
bar_dict["foo_field"] = structs.asdict(bar_dict["foo_field"])

validated_bar = msgspec.convert(bar_dict, type=Bar)
assert validated_bar == bar
90 changes: 90 additions & 0 deletions tests/optional_model_field/test_odmantic_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from datetime import datetime
from typing import Any, List
from uuid import UUID

import bson
import pytest

try:
from odmantic import AIOEngine, EmbeddedModel, Model

from polyfactory.factories.odmantic_odm_factory import OdmanticModelFactory
except ImportError:
AIOEngine, EmbeddedModel, Model, OdmanticModelFactory = None, None, None, None # type: ignore
pytest.importorskip("odmantic")


class OtherEmbeddedDocument(EmbeddedModel): # type: ignore
name: str
serial: UUID
created_on: datetime
bson_id: bson.ObjectId
bson_int64: bson.Int64
bson_dec128: bson.Decimal128
bson_binary: bson.Binary


class MyEmbeddedDocument(EmbeddedModel): # type: ignore
name: str
serial: UUID
other_embedded_document: OtherEmbeddedDocument
created_on: datetime
bson_id: bson.ObjectId
bson_int64: bson.Int64
bson_dec128: bson.Decimal128
bson_binary: bson.Binary


class MyModel(Model): # type: ignore
created_on: datetime
bson_id: bson.ObjectId
bson_int64: bson.Int64
bson_dec128: bson.Decimal128
bson_binary: bson.Binary
name: str
embedded: MyEmbeddedDocument
embedded_list: List[MyEmbeddedDocument]


@pytest.fixture()
async def odmantic_engine(mongo_connection: Any) -> AIOEngine:
return AIOEngine(client=mongo_connection, database=mongo_connection.db_name)


def test_handles_odmantic_models() -> None:
class MyFactory(OdmanticModelFactory[MyModel]):
...

assert getattr(MyFactory, "__model__") is MyModel

result: MyModel = MyFactory.build()

assert isinstance(result, MyModel)
assert isinstance(result.id, bson.ObjectId)
assert isinstance(result.created_on, datetime)
assert isinstance(result.bson_id, bson.ObjectId)
assert isinstance(result.bson_int64, bson.Int64)
assert isinstance(result.bson_dec128, bson.Decimal128)
assert isinstance(result.bson_binary, bson.Binary)
assert isinstance(result.name, str)
assert isinstance(result.embedded, MyEmbeddedDocument)
assert isinstance(result.embedded_list, list)
for item in result.embedded_list:
assert isinstance(item, MyEmbeddedDocument)
assert isinstance(item.name, str)
assert isinstance(item.serial, UUID)
assert isinstance(item.created_on, datetime)
assert isinstance(item.bson_id, bson.ObjectId)
assert isinstance(item.bson_int64, bson.Int64)
assert isinstance(item.bson_dec128, bson.Decimal128)
assert isinstance(item.bson_binary, bson.Binary)

other = item.other_embedded_document
assert isinstance(other, OtherEmbeddedDocument)
assert isinstance(other.name, str)
assert isinstance(other.serial, UUID)
assert isinstance(other.created_on, datetime)
assert isinstance(other.bson_id, bson.ObjectId)
assert isinstance(other.bson_int64, bson.Int64)
assert isinstance(other.bson_dec128, bson.Decimal128)
assert isinstance(other.bson_binary, bson.Binary)
Loading
Loading