Skip to content

Commit

Permalink
feat: pass on factory config
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong committed Jan 16, 2024
1 parent 80bd012 commit fa34a6f
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 3 deletions.
28 changes: 26 additions & 2 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,20 @@ class BaseFactory(ABC, Generic[T]):
"""
Flag indicating whether to use the default value on a specific field, if provided.
"""
__extra_providers__: dict[Any, Callable[[], Any]] | None = None

__config_keys__: tuple[str, ...] = (
"__check_model__",
"__allow_none_optionals__",
"__set_as_default_factory_for_type__",
"__faker__",
"__random__",
"__randomize_collection_length__",
"__min_collection_length__",
"__max_collection_length__",
"__use_defaults__",
)
"""Keys to be considered as config values to pass on to dynamically created factories."""

# cached attributes
_fields_metadata: list[FieldMeta]
Expand Down Expand Up @@ -345,6 +359,13 @@ def _handle_factory_field_coverage(

return CoverageContainerCallable(field_value) if callable(field_value) else field_value

@classmethod
def _get_config(cls) -> dict[str, Any]:
return {
**{key: getattr(cls, key) for key in cls.__config_keys__},
"__extra_providers__": cls.get_provider_map(),
}

@classmethod
def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]:
"""Get a factory from registered factories or generate a factory dynamically.
Expand All @@ -356,14 +377,16 @@ def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]:
if factory := BaseFactory._factory_type_mapping.get(model):
return factory

config = cls._get_config()

if cls.__base_factory_overrides__:
for model_ancestor in model.mro():
if factory := cls.__base_factory_overrides__.get(model_ancestor):
return factory.create_factory(model)
return factory.create_factory(model, **config)

for factory in reversed(BaseFactory._base_factories):
if factory.is_supported_type(model):
return factory.create_factory(model)
return factory.create_factory(model, **config)

msg = f"unsupported model type {model.__name__}"
raise ParameterException(msg) # pragma: no cover
Expand Down Expand Up @@ -503,6 +526,7 @@ def _create_generic_fn() -> Callable:
Callable: _create_generic_fn,
abc.Callable: _create_generic_fn,
Counter: lambda: Counter(cls.__faker__.pystr()),
**(cls.__extra_providers__ or {}),
}

@classmethod
Expand Down
7 changes: 7 additions & 0 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]):
__session__: ClassVar[Session | Callable[[], Session] | None] = None
__async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None

__config_keys__ = (
*BaseFactory.__config_keys__,
"__set_primary_key__",
"__set_foreign_keys__",
"__set_relationships__",
)

@classmethod
def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]:
"""Get mapping of types where column type."""
Expand Down
37 changes: 36 additions & 1 deletion tests/test_factory_configuration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Type
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Type

from typing_extensions import TypeGuard

from polyfactory.factories.base import BaseFactory, T
from polyfactory.factories.dataclass_factory import DataclassFactory


def test_setting_set_as_default_factory_for_type_on_base_factory() -> None:
Expand All @@ -18,3 +20,36 @@ def is_supported_type(cls, value: Any) -> TypeGuard[Type[T]]:
# list of base factories, but this obviously shouldn't be ran
# for any of the types.
return False


def test_inheriting_config() -> None:
class CustomType:
def __init__(self, a: int) -> None:
self.a = a

@dataclass
class Child:
a: List[int]
custom_type: CustomType

@dataclass
class Parent:
children: List[Child]

class ParentFactory(DataclassFactory[Parent]):
__randomize_collection_length__ = True
__min_collection_length__ = 5
__max_collection_length__ = 5

@classmethod
def get_provider_map(cls) -> Dict[Any, Callable[[], Any]]:
return {
**super().get_provider_map(),
int: lambda: 42,
CustomType: lambda: CustomType(a=5),
}

result = ParentFactory.build()
assert len(result.children) == 5
assert result.children[0].a == [42] * 5
assert result.children[0].custom_type.a == 5

0 comments on commit fa34a6f

Please sign in to comment.