Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions modern_di/graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing

from modern_di import Container
from modern_di.providers import AbstractProvider, BaseCreatorProvider
from modern_di.providers.abstract import AbstractCreatorProvider, AbstractProvider


if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -29,11 +29,11 @@ def get_providers(cls) -> dict[str, AbstractProvider[typing.Any]]:
@classmethod
async def async_resolve_creators(cls, container: Container) -> None:
for provider in cls.get_providers().values():
if isinstance(provider, BaseCreatorProvider) and provider.scope == container.scope:
if isinstance(provider, AbstractCreatorProvider) and provider.scope == container.scope:
await provider.async_resolve(container)

@classmethod
def sync_resolve_creators(cls, container: Container) -> None:
for provider in cls.get_providers().values():
if isinstance(provider, BaseCreatorProvider) and provider.scope == container.scope:
if isinstance(provider, AbstractCreatorProvider) and provider.scope == container.scope:
provider.sync_resolve(container)
11 changes: 6 additions & 5 deletions modern_di/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from modern_di.providers.base import AbstractProvider, BaseCreatorProvider
from modern_di.providers.context_adapter import ContextAdapter
from modern_di.providers.dict import Dict
from modern_di.providers.factory import Factory
from modern_di.providers.list import List
from modern_di.providers.resource import Resource
from modern_di.providers.selector import Selector
from modern_di.providers.singleton import Singleton


__all__ = [
"AbstractProvider",
"BaseCreatorProvider",
"ContextAdapter",
"Factory",
"Dict",
"List",
"Selector",
"Singleton",
"Resource",
]
23 changes: 13 additions & 10 deletions modern_di/providers/base.py → modern_di/providers/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,25 @@ async def async_resolve(self, container: Container) -> T_co:
def sync_resolve(self, container: Container) -> T_co:
"""Resolve dependency synchronously."""

@property
def cast(self) -> T_co:
return typing.cast(T_co, self)

def _check_providers_scope(self, providers: typing.Iterable[typing.Any]) -> None:
if any(x.scope > self.scope for x in providers if isinstance(x, AbstractProvider)):
msg = "Scope of dependency cannot be more than scope of dependent"
raise RuntimeError(msg)


class AbstractOverrideProvider(AbstractProvider[T_co], abc.ABC):
def override(self, override_object: object, container: Container) -> None:
container.override(self.provider_id, override_object)

def reset_override(self, container: Container) -> None:
container.reset_override(self.provider_id)

@property
def cast(self) -> T_co:
return typing.cast(T_co, self)


class BaseCreatorProvider(AbstractProvider[T_co], abc.ABC):
class AbstractCreatorProvider(AbstractOverrideProvider[T_co], abc.ABC):
BASE_SLOTS: typing.ClassVar = [*AbstractProvider.BASE_SLOTS, "_args", "_kwargs", "_creator"]

def __init__(
Expand All @@ -49,11 +56,7 @@ def __init__(
**kwargs: P.kwargs,
) -> None:
super().__init__(scope)

if any(x.scope > self.scope for x in itertools.chain(args, kwargs.values()) if isinstance(x, AbstractProvider)):
msg = "Scope of dependency cannot be more than scope of dependent"
raise RuntimeError(msg)

self._check_providers_scope(itertools.chain(args, kwargs.values()))
self._creator: typing.Final = creator
self._args: typing.Final = args
self._kwargs: typing.Final = kwargs
Expand Down
31 changes: 0 additions & 31 deletions modern_di/providers/context_adapter.py

This file was deleted.

23 changes: 23 additions & 0 deletions modern_di/providers/dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import enum
import typing

from modern_di import Container
from modern_di.providers.abstract import AbstractProvider


T_co = typing.TypeVar("T_co", covariant=True)


class Dict(AbstractProvider[dict[str, T_co]]):
__slots__ = [*AbstractProvider.BASE_SLOTS, "_providers"]

def __init__(self, scope: enum.IntEnum, **providers: AbstractProvider[T_co]) -> None:
super().__init__(scope)
self._check_providers_scope(providers.values())
self._providers: typing.Final = providers

async def async_resolve(self, container: Container) -> dict[str, T_co]:
return {key: await provider.async_resolve(container) for key, provider in self._providers.items()}

def sync_resolve(self, container: Container) -> dict[str, T_co]:
return {key: provider.sync_resolve(container) for key, provider in self._providers.items()}
6 changes: 3 additions & 3 deletions modern_di/providers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import typing

from modern_di import Container
from modern_di.providers import BaseCreatorProvider
from modern_di.providers.abstract import AbstractCreatorProvider


T_co = typing.TypeVar("T_co", covariant=True)
P = typing.ParamSpec("P")


class Factory(BaseCreatorProvider[T_co]):
__slots__ = [*BaseCreatorProvider.BASE_SLOTS, "_creator"]
class Factory(AbstractCreatorProvider[T_co]):
__slots__ = [*AbstractCreatorProvider.BASE_SLOTS, "_creator"]

def __init__(
self,
Expand Down
23 changes: 23 additions & 0 deletions modern_di/providers/list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import enum
import typing

from modern_di import Container
from modern_di.providers.abstract import AbstractProvider


T_co = typing.TypeVar("T_co", covariant=True)


class List(AbstractProvider[list[T_co]]):
__slots__ = [*AbstractProvider.BASE_SLOTS, "_providers"]

def __init__(self, scope: enum.IntEnum, *providers: AbstractProvider[T_co]) -> None:
super().__init__(scope)
self._check_providers_scope(providers)
self._providers: typing.Final = providers

async def async_resolve(self, container: Container) -> list[T_co]:
return [await x.async_resolve(container) for x in self._providers]

def sync_resolve(self, container: Container) -> list[T_co]:
return [x.sync_resolve(container) for x in self._providers]
6 changes: 3 additions & 3 deletions modern_di/providers/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import typing

from modern_di import Container
from modern_di.providers import BaseCreatorProvider
from modern_di.providers.abstract import AbstractCreatorProvider


T_co = typing.TypeVar("T_co", covariant=True)
P = typing.ParamSpec("P")


class Resource(BaseCreatorProvider[T_co]):
__slots__ = [*BaseCreatorProvider.BASE_SLOTS, "_creator", "_args", "_kwargs", "_is_async"]
class Resource(AbstractCreatorProvider[T_co]):
__slots__ = [*AbstractCreatorProvider.BASE_SLOTS, "_creator", "_args", "_kwargs", "_is_async"]

def _is_creator_async(
self,
Expand Down
39 changes: 39 additions & 0 deletions modern_di/providers/selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import enum
import typing

from modern_di import Container
from modern_di.providers.abstract import AbstractProvider


T_co = typing.TypeVar("T_co", covariant=True)
P = typing.ParamSpec("P")


class Selector(AbstractProvider[T_co]):
__slots__ = [*AbstractProvider.BASE_SLOTS, "_function", "_providers"]

def __init__(
self, scope: enum.IntEnum, function: typing.Callable[..., str], **providers: AbstractProvider[T_co]
) -> None:
super().__init__(scope)
self._check_providers_scope(providers.values())
self._function: typing.Final = function
self._providers: typing.Final = providers

async def async_resolve(self, container: Container) -> T_co:
container = container.find_container(self.scope)
selected_key = self._function(**container.context)
if selected_key not in self._providers:
msg = f"No provider matches {selected_key}"
raise RuntimeError(msg)

return await self._providers[selected_key].async_resolve(container)

def sync_resolve(self, container: Container) -> T_co:
container = container.find_container(self.scope)
selected_key = self._function(**container.context)
if selected_key not in self._providers:
msg = f"No provider matches {selected_key}"
raise RuntimeError(msg)

return self._providers[selected_key].sync_resolve(container)
6 changes: 3 additions & 3 deletions modern_di/providers/singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import typing

from modern_di import Container
from modern_di.providers import BaseCreatorProvider
from modern_di.providers.abstract import AbstractCreatorProvider


T_co = typing.TypeVar("T_co", covariant=True)
P = typing.ParamSpec("P")


class Singleton(BaseCreatorProvider[T_co]):
__slots__ = [*BaseCreatorProvider.BASE_SLOTS, "_creator"]
class Singleton(AbstractCreatorProvider[T_co]):
__slots__ = [*AbstractCreatorProvider.BASE_SLOTS, "_creator"]

def __init__(
self,
Expand Down
File renamed without changes.
24 changes: 24 additions & 0 deletions tests/providers/test_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest

from modern_di import Container, Scope, providers
from tests.creators import create_async_resource, create_sync_resource


async_resource = providers.Resource(Scope.APP, create_async_resource)
sync_resource = providers.Resource(Scope.APP, create_sync_resource)
mapping = providers.Dict(Scope.APP, dep1=async_resource, dep2=sync_resource)


async def test_dict() -> None:
async with Container(scope=Scope.APP, context={"option": "app"}) as app_container:
mapping1 = await mapping.async_resolve(app_container)
mapping2 = mapping.sync_resolve(app_container)
resource1 = await async_resource.async_resolve(app_container)
resource2 = sync_resource.sync_resolve(app_container)
assert mapping1 == mapping2 == {"dep1": resource1, "dep2": resource2}


async def test_dict_wrong_scope() -> None:
request_factory_ = providers.Factory(Scope.REQUEST, lambda: "")
with pytest.raises(RuntimeError, match="Scope of dependency cannot be more than scope of dependent"):
providers.Dict(Scope.APP, dep1=request_factory_)
File renamed without changes.
24 changes: 24 additions & 0 deletions tests/providers/test_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest

from modern_di import Container, Scope, providers
from tests.creators import create_async_resource, create_sync_resource


async_resource = providers.Resource(Scope.APP, create_async_resource)
sync_resource = providers.Resource(Scope.APP, create_sync_resource)
sequence = providers.List(Scope.APP, async_resource, sync_resource)


async def test_list() -> None:
async with Container(scope=Scope.APP, context={"option": "app"}) as app_container:
sequence1 = await sequence.async_resolve(app_container)
sequence2 = sequence.sync_resolve(app_container)
resource1 = await async_resource.async_resolve(app_container)
resource2 = sync_resource.sync_resolve(app_container)
assert sequence1 == sequence2 == [resource1, resource2]


async def test_list_wrong_scope() -> None:
request_factory_ = providers.Factory(Scope.REQUEST, lambda: "")
with pytest.raises(RuntimeError, match="Scope of dependency cannot be more than scope of dependent"):
providers.List(Scope.APP, request_factory_)
File renamed without changes.
45 changes: 45 additions & 0 deletions tests/providers/test_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

from modern_di import Container, Scope, providers


def selector_function(*, option: str, **_: object) -> "str":
return option


app_factory = providers.Factory(Scope.APP, lambda: "app")
request_factory = providers.Factory(Scope.APP, lambda: "request")
app_selector = providers.Selector(Scope.APP, selector_function, app=app_factory, request=request_factory)
request_selector = providers.Selector(Scope.REQUEST, selector_function, app=app_factory, request=request_factory)


async def test_selector() -> None:
async with Container(scope=Scope.APP, context={"option": "app"}) as app_container:
instance1 = await app_selector.async_resolve(app_container)
instance2 = app_selector.sync_resolve(app_container)
assert instance1 == instance2 == "app"


async def test_selector_in_request_scope() -> None:
async with (
Container(scope=Scope.APP) as app_container,
app_container.build_child_container(context={"option": "request"}) as request_container,
):
instance1 = await request_selector.async_resolve(request_container)
instance2 = request_selector.sync_resolve(request_container)
assert instance1 == instance2 == "request"


async def test_selector_no_match() -> None:
async with Container(scope=Scope.APP, context={"option": "wrong"}) as app_container:
with pytest.raises(RuntimeError, match="No provider matches wrong"):
await app_selector.async_resolve(app_container)

with pytest.raises(RuntimeError, match="No provider matches wrong"):
app_selector.sync_resolve(app_container)


async def test_selector_wrong_scope() -> None:
request_factory_ = providers.Factory(Scope.REQUEST, lambda: "")
with pytest.raises(RuntimeError, match="Scope of dependency cannot be more than scope of dependent"):
providers.Selector(Scope.APP, lambda: "", request=request_factory_)
File renamed without changes.
Loading