Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions packages/modern-di/modern_di/providers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from modern_di import Container
from modern_di.providers.abstract import AbstractCreatorProvider
from modern_di.providers.injected_factory import InjectedFactory


T_co = typing.TypeVar("T_co", covariant=True)
Expand All @@ -21,6 +22,10 @@ def __init__(
) -> None:
super().__init__(scope, creator, *args, **kwargs)

@property
def factory_provider(self) -> InjectedFactory[T_co]:
return InjectedFactory(self)

async def async_resolve(self, container: Container) -> T_co:
container = container.find_container(self.scope)
if (override := container.fetch_override(self.provider_id)) is not None:
Expand Down
24 changes: 24 additions & 0 deletions packages/modern-di/modern_di/providers/injected_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import functools
import typing

from modern_di import Container
from modern_di.providers.abstract import AbstractProvider


T_co = typing.TypeVar("T_co", covariant=True)
P = typing.ParamSpec("P")


class InjectedFactory(typing.Generic[T_co]):
__slots__ = ("_factory_provider",)

def __init__(self, factory_provider: AbstractProvider[T_co]) -> None:
self._factory_provider = factory_provider

async def async_resolve(self, container: Container) -> typing.Callable[[], T_co]:
await self._factory_provider.async_resolve(container)
return functools.partial(self._factory_provider.sync_resolve, container)

def sync_resolve(self, container: Container) -> typing.Callable[[], T_co]:
self._factory_provider.sync_resolve(container)
return functools.partial(self._factory_provider.sync_resolve, container)
54 changes: 54 additions & 0 deletions packages/modern-di/tests_core/providers/test_injected_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import dataclasses
import datetime

import pytest
from modern_di import Container, Scope, providers

from tests_core.creators import create_async_resource, create_sync_resource


@dataclasses.dataclass(kw_only=True, slots=True)
class DependentCreator:
dep1: datetime.datetime


async_resource = providers.Resource(Scope.APP, create_async_resource)
sync_resource = providers.Resource(Scope.APP, create_sync_resource)
request_sync_factory = providers.Factory(Scope.REQUEST, DependentCreator, dep1=sync_resource.cast)
request_async_factory = providers.Factory(Scope.REQUEST, DependentCreator, dep1=async_resource.cast)


async def test_injected_async_factory() -> None:
async with (
Container(scope=Scope.APP) as app_container,
app_container.build_child_container(scope=Scope.REQUEST) as request_container,
):
factory = await request_async_factory.factory_provider.async_resolve(request_container)
instance1, instance2 = factory(), factory()
assert instance1 is not instance2
assert isinstance(instance1, DependentCreator)
assert isinstance(instance2, DependentCreator)


async def test_injected_sync_factory() -> None:
with (
Container(scope=Scope.APP) as app_container,
app_container.build_child_container(scope=Scope.REQUEST) as request_container,
):
factory = request_sync_factory.factory_provider.sync_resolve(request_container)
instance1, instance2 = factory(), factory()
assert instance1 is not instance2
assert isinstance(instance1, DependentCreator)
assert isinstance(instance2, DependentCreator)


async def test_injected_async_factory_in_sync_mode() -> None:
with (
Container(scope=Scope.APP) as app_container,
app_container.build_child_container(scope=Scope.REQUEST) as request_container,
):
with pytest.raises(RuntimeError, match="Resolving async resource in sync container is not allowed"):
await request_async_factory.factory_provider.async_resolve(request_container)

with pytest.raises(RuntimeError, match="Async resource cannot be resolved synchronously"):
request_async_factory.factory_provider.sync_resolve(request_container)