From fa34a6faafaf8e0e95dd0b3e35bcfad871968761 Mon Sep 17 00:00:00 2001 From: Andrew Truong Date: Tue, 16 Jan 2024 12:04:54 +0000 Subject: [PATCH] feat: pass on factory config --- polyfactory/factories/base.py | 28 ++++++++++++++-- polyfactory/factories/sqlalchemy_factory.py | 7 ++++ tests/test_factory_configuration.py | 37 ++++++++++++++++++++- 3 files changed, 69 insertions(+), 3 deletions(-) diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 9ab29c0e..1ce347bc 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -175,6 +175,20 @@ class BaseFactory(ABC, Generic[T]): """ Flag indicating whether to use the default value on a specific field, if provided. """ + __extra_providers__: dict[Any, Callable[[], Any]] | None = None + + __config_keys__: tuple[str, ...] = ( + "__check_model__", + "__allow_none_optionals__", + "__set_as_default_factory_for_type__", + "__faker__", + "__random__", + "__randomize_collection_length__", + "__min_collection_length__", + "__max_collection_length__", + "__use_defaults__", + ) + """Keys to be considered as config values to pass on to dynamically created factories.""" # cached attributes _fields_metadata: list[FieldMeta] @@ -345,6 +359,13 @@ def _handle_factory_field_coverage( return CoverageContainerCallable(field_value) if callable(field_value) else field_value + @classmethod + def _get_config(cls) -> dict[str, Any]: + return { + **{key: getattr(cls, key) for key in cls.__config_keys__}, + "__extra_providers__": cls.get_provider_map(), + } + @classmethod def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]: """Get a factory from registered factories or generate a factory dynamically. @@ -356,14 +377,16 @@ def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]: if factory := BaseFactory._factory_type_mapping.get(model): return factory + config = cls._get_config() + if cls.__base_factory_overrides__: for model_ancestor in model.mro(): if factory := cls.__base_factory_overrides__.get(model_ancestor): - return factory.create_factory(model) + return factory.create_factory(model, **config) for factory in reversed(BaseFactory._base_factories): if factory.is_supported_type(model): - return factory.create_factory(model) + return factory.create_factory(model, **config) msg = f"unsupported model type {model.__name__}" raise ParameterException(msg) # pragma: no cover @@ -503,6 +526,7 @@ def _create_generic_fn() -> Callable: Callable: _create_generic_fn, abc.Callable: _create_generic_fn, Counter: lambda: Counter(cls.__faker__.pystr()), + **(cls.__extra_providers__ or {}), } @classmethod diff --git a/polyfactory/factories/sqlalchemy_factory.py b/polyfactory/factories/sqlalchemy_factory.py index c423d0b9..f34d3af6 100644 --- a/polyfactory/factories/sqlalchemy_factory.py +++ b/polyfactory/factories/sqlalchemy_factory.py @@ -73,6 +73,13 @@ class SQLAlchemyFactory(Generic[T], BaseFactory[T]): __session__: ClassVar[Session | Callable[[], Session] | None] = None __async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None + __config_keys__ = ( + *BaseFactory.__config_keys__, + "__set_primary_key__", + "__set_foreign_keys__", + "__set_relationships__", + ) + @classmethod def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]: """Get mapping of types where column type.""" diff --git a/tests/test_factory_configuration.py b/tests/test_factory_configuration.py index aa928693..38e902a2 100644 --- a/tests/test_factory_configuration.py +++ b/tests/test_factory_configuration.py @@ -1,8 +1,10 @@ -from typing import Any, Type +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Type from typing_extensions import TypeGuard from polyfactory.factories.base import BaseFactory, T +from polyfactory.factories.dataclass_factory import DataclassFactory def test_setting_set_as_default_factory_for_type_on_base_factory() -> None: @@ -18,3 +20,36 @@ def is_supported_type(cls, value: Any) -> TypeGuard[Type[T]]: # list of base factories, but this obviously shouldn't be ran # for any of the types. return False + + +def test_inheriting_config() -> None: + class CustomType: + def __init__(self, a: int) -> None: + self.a = a + + @dataclass + class Child: + a: List[int] + custom_type: CustomType + + @dataclass + class Parent: + children: List[Child] + + class ParentFactory(DataclassFactory[Parent]): + __randomize_collection_length__ = True + __min_collection_length__ = 5 + __max_collection_length__ = 5 + + @classmethod + def get_provider_map(cls) -> Dict[Any, Callable[[], Any]]: + return { + **super().get_provider_map(), + int: lambda: 42, + CustomType: lambda: CustomType(a=5), + } + + result = ParentFactory.build() + assert len(result.children) == 5 + assert result.children[0].a == [42] * 5 + assert result.children[0].custom_type.a == 5