Skip to content

Commit

Permalink
docs: Add note in docs
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong committed Jan 17, 2024
1 parent 7fe5528 commit 1bbeecc
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
4 changes: 4 additions & 0 deletions docs/usage/handling_custom_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ for handling dataclasses:
.. literalinclude:: /examples/handling_custom_types/test_example_2.py
:caption: Creating a custom dataclass factory with extended provider map
:language: python

.. note::
If extra configs values are defined for custom base classes, then ``__config_keys__`` should be extended so
that these values are correctly passed onto to concrete factories.
8 changes: 5 additions & 3 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ class BaseFactory(ABC, Generic[T]):
"""
Flag indicating whether to use the default value on a specific field, if provided.
"""
__extra_providers__: dict[Any, Callable[[], Any]] | None = None

__config_keys__: tuple[str, ...] = (
"__check_model__",
Expand All @@ -196,6 +195,9 @@ class BaseFactory(ABC, Generic[T]):
_factory_type_mapping: ClassVar[dict[Any, type[BaseFactory[Any]]]]
_base_factories: ClassVar[list[type[BaseFactory[Any]]]]

# Non-public attributes
_extra_providers: dict[Any, Callable[[], Any]] | None = None

def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: # noqa: C901
super().__init_subclass__(*args, **kwargs)

Expand Down Expand Up @@ -363,7 +365,7 @@ def _handle_factory_field_coverage(
def _get_config(cls) -> dict[str, Any]:
return {
**{key: getattr(cls, key) for key in cls.__config_keys__},
"__extra_providers__": cls.get_provider_map(),
"_extra_providers": cls.get_provider_map(),
}

@classmethod
Expand Down Expand Up @@ -526,7 +528,7 @@ def _create_generic_fn() -> Callable:
Callable: _create_generic_fn,
abc.Callable: _create_generic_fn,
Counter: lambda: Counter(cls.__faker__.pystr()),
**(cls.__extra_providers__ or {}),
**(cls._extra_providers or {}),
}

@classmethod
Expand Down

0 comments on commit 1bbeecc

Please sign in to comment.