diff --git a/packages/modern-di/modern_di/providers/factory.py b/packages/modern-di/modern_di/providers/factory.py index b65e6ed..660e8d6 100644 --- a/packages/modern-di/modern_di/providers/factory.py +++ b/packages/modern-di/modern_di/providers/factory.py @@ -3,6 +3,7 @@ from modern_di import Container from modern_di.providers.abstract import AbstractCreatorProvider +from modern_di.providers.injected_factory import InjectedFactory T_co = typing.TypeVar("T_co", covariant=True) @@ -21,6 +22,10 @@ def __init__( ) -> None: super().__init__(scope, creator, *args, **kwargs) + @property + def factory_provider(self) -> InjectedFactory[T_co]: + return InjectedFactory(self) + async def async_resolve(self, container: Container) -> T_co: container = container.find_container(self.scope) if (override := container.fetch_override(self.provider_id)) is not None: diff --git a/packages/modern-di/modern_di/providers/injected_factory.py b/packages/modern-di/modern_di/providers/injected_factory.py new file mode 100644 index 0000000..996cb66 --- /dev/null +++ b/packages/modern-di/modern_di/providers/injected_factory.py @@ -0,0 +1,24 @@ +import functools +import typing + +from modern_di import Container +from modern_di.providers.abstract import AbstractProvider + + +T_co = typing.TypeVar("T_co", covariant=True) +P = typing.ParamSpec("P") + + +class InjectedFactory(typing.Generic[T_co]): + __slots__ = ("_factory_provider",) + + def __init__(self, factory_provider: AbstractProvider[T_co]) -> None: + self._factory_provider = factory_provider + + async def async_resolve(self, container: Container) -> typing.Callable[[], T_co]: + await self._factory_provider.async_resolve(container) + return functools.partial(self._factory_provider.sync_resolve, container) + + def sync_resolve(self, container: Container) -> typing.Callable[[], T_co]: + self._factory_provider.sync_resolve(container) + return functools.partial(self._factory_provider.sync_resolve, container) diff --git a/packages/modern-di/tests_core/providers/test_injected_factory.py b/packages/modern-di/tests_core/providers/test_injected_factory.py new file mode 100644 index 0000000..e51d864 --- /dev/null +++ b/packages/modern-di/tests_core/providers/test_injected_factory.py @@ -0,0 +1,54 @@ +import dataclasses +import datetime + +import pytest +from modern_di import Container, Scope, providers + +from tests_core.creators import create_async_resource, create_sync_resource + + +@dataclasses.dataclass(kw_only=True, slots=True) +class DependentCreator: + dep1: datetime.datetime + + +async_resource = providers.Resource(Scope.APP, create_async_resource) +sync_resource = providers.Resource(Scope.APP, create_sync_resource) +request_sync_factory = providers.Factory(Scope.REQUEST, DependentCreator, dep1=sync_resource.cast) +request_async_factory = providers.Factory(Scope.REQUEST, DependentCreator, dep1=async_resource.cast) + + +async def test_injected_async_factory() -> None: + async with ( + Container(scope=Scope.APP) as app_container, + app_container.build_child_container(scope=Scope.REQUEST) as request_container, + ): + factory = await request_async_factory.factory_provider.async_resolve(request_container) + instance1, instance2 = factory(), factory() + assert instance1 is not instance2 + assert isinstance(instance1, DependentCreator) + assert isinstance(instance2, DependentCreator) + + +async def test_injected_sync_factory() -> None: + with ( + Container(scope=Scope.APP) as app_container, + app_container.build_child_container(scope=Scope.REQUEST) as request_container, + ): + factory = request_sync_factory.factory_provider.sync_resolve(request_container) + instance1, instance2 = factory(), factory() + assert instance1 is not instance2 + assert isinstance(instance1, DependentCreator) + assert isinstance(instance2, DependentCreator) + + +async def test_injected_async_factory_in_sync_mode() -> None: + with ( + Container(scope=Scope.APP) as app_container, + app_container.build_child_container(scope=Scope.REQUEST) as request_container, + ): + with pytest.raises(RuntimeError, match="Resolving async resource in sync container is not allowed"): + await request_async_factory.factory_provider.async_resolve(request_container) + + with pytest.raises(RuntimeError, match="Async resource cannot be resolved synchronously"): + request_async_factory.factory_provider.sync_resolve(request_container)