Skip to content

Commit

Permalink
Implement wiring by string id
Browse files Browse the repository at this point in the history
  • Loading branch information
rmk135 committed Feb 20, 2021
1 parent c8a8603 commit f153f46
Show file tree
Hide file tree
Showing 12 changed files with 784 additions and 13 deletions.
151 changes: 138 additions & 13 deletions src/dependency_injector/wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ class GenericMeta(type):
'wire',
'unwire',
'inject',
'as_int',
'as_float',
'as_',
'required',
'invariant',
'provided',
'Provide',
'Provider',
'Closing',
Expand Down Expand Up @@ -85,16 +91,19 @@ def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:

class ProvidersMap:

CONTAINER_STRING_ID = '<container>'

def __init__(self, container):
self._container = container
self._map = self._create_providers_map(
current_container=container,
original_container=container.declarative_parent,
original_container=container.declarative_parent if container.declarative_parent else container,
)

def resolve_provider(
self,
provider: providers.Provider,
provider: Union[providers.Provider, str],
modifier: Optional['Modifier'] = None,
) -> Optional[providers.Provider]:
if isinstance(provider, providers.Delegate):
return self._resolve_delegate(provider)
Expand All @@ -110,18 +119,40 @@ def resolve_provider(
elif isinstance(provider, providers.TypedConfigurationOption):
return self._resolve_config_option(provider.option, as_=provider.provides)
elif isinstance(provider, str):
current_provider = self._container
for segment in provider.split('.'):
current_provider = getattr(current_provider, segment)
return current_provider
return self._resolve_string_id(provider, modifier)
else:
return self._resolve_provider(provider)

def _resolve_delegate(
self,
original: providers.Delegate,
) -> Optional[providers.Provider]:
return self._resolve_provider(original.provides)
def _resolve_string_id(self, id: str, modifier: Optional['Modifier'] = None) -> Optional[providers.Provider]:
if id == self.CONTAINER_STRING_ID:
return self._container.__self__

provider = self._container
for segment in id.split('.'):
try:
provider = getattr(provider, segment)
except AttributeError:
return

if isinstance(modifier, TypeModifier):
provider = provider.as_(modifier.type_)
elif isinstance(modifier, RequiredModifier):
provider = provider.required()
if modifier.type_modifier:
provider = provider.as_(modifier.type_modifier.type_)
elif isinstance(modifier, InvariantModifier):
invariant_segment = self._resolve_string_id(modifier.id)
provider = provider[invariant_segment]
elif isinstance(modifier, ProvidedInstance):
provider = provider.provided
for type_, value in modifier.segments:
if type_ == ProvidedInstance.TYPE_ATTRIBUTE:
provider = getattr(provider, value)
elif type_ == ProvidedInstance.TYPE_ITEM:
provider = provider[value]
elif type_ == ProvidedInstance.TYPE_CALL:
provider = provider.call()
return provider

def _resolve_provided_instance(
self,
Expand Down Expand Up @@ -156,6 +187,12 @@ def _resolve_provided_instance(

return new

def _resolve_delegate(
self,
original: providers.Delegate,
) -> Optional[providers.Provider]:
return self._resolve_provider(original.provides)

def _resolve_config_option(
self,
original: providers.ConfigurationOption,
Expand Down Expand Up @@ -386,7 +423,7 @@ def _fetch_reference_injections(

def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> None:
for injection, marker in fn.__reference_injections__.items():
provider = providers_map.resolve_provider(marker.provider)
provider = providers_map.resolve_provider(marker.provider, marker.modifier)

if provider is None:
continue
Expand Down Expand Up @@ -521,20 +558,108 @@ def _is_declarative_container(instance: Any) -> bool:
and getattr(instance, 'declarative_parent', None) is None)


class Modifier:
...


class TypeModifier(Modifier):
def __init__(self, type_: Type):
self.type_ = type_


def as_int() -> TypeModifier:
return TypeModifier(int)


def as_float() -> TypeModifier:
return TypeModifier(float)


def as_(type_: Type) -> TypeModifier:
return TypeModifier(type_)


class RequiredModifier(Modifier):
def __init__(self):
self.type_modifier = None

def as_int(self) -> 'RequiredModifier':
self.type_modifier = TypeModifier(int)
return self


def as_float(self) -> 'RequiredModifier':
self.type_modifier = TypeModifier(float)
return self


def as_(self, type_: Type) -> 'RequiredModifier':
self.type_modifier = TypeModifier(type_)
return self


def required() -> RequiredModifier:
return RequiredModifier()


class InvariantModifier(Modifier):
def __init__(self, id: str) -> None:
self.id = id


def invariant(id: str) -> InvariantModifier:
return InvariantModifier(id)


class ProvidedInstance(Modifier):

TYPE_ATTRIBUTE = 'attr'
TYPE_ITEM = 'item'
TYPE_CALL = 'call'

def __init__(self):
self.segments = []

def __getattr__(self, item):
self.segments.append((self.TYPE_ATTRIBUTE, item))
return self

def __getitem__(self, item):
self.segments.append((self.TYPE_ITEM, item))
return self

def call(self):
self.segments.append((self.TYPE_CALL, None))
return self


def provided() -> ProvidedInstance:
return ProvidedInstance()


class ClassGetItemMeta(GenericMeta):
def __getitem__(cls, item):
# Spike for Python 3.6
if isinstance(item, tuple):
return cls(*item)
return cls(item)


class _Marker(Generic[T], metaclass=ClassGetItemMeta):

def __init__(self, provider: Union[providers.Provider, Container, str]) -> None:
def __init__(
self,
provider: Union[providers.Provider, Container, str],
modifier: Optional[Modifier] = None,
) -> None:
if _is_declarative_container(provider):
provider = provider.__self__
self.provider = provider
self.modifier = modifier

def __class_getitem__(cls, item) -> T:
if isinstance(item, tuple):
return cls(*item)
return cls(item)

def __call__(self) -> T:
Expand Down
Empty file.
50 changes: 50 additions & 0 deletions tests/unit/samples/wiringstringidssamples/asyncinjections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import asyncio

from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide, Closing


class TestResource:
def __init__(self):
self.init_counter = 0
self.shutdown_counter = 0

def reset_counters(self):
self.init_counter = 0
self.shutdown_counter = 0


resource1 = TestResource()
resource2 = TestResource()


async def async_resource(resource):
await asyncio.sleep(0.001)
resource.init_counter += 1

yield resource

await asyncio.sleep(0.001)
resource.shutdown_counter += 1


class Container(containers.DeclarativeContainer):

resource1 = providers.Resource(async_resource, providers.Object(resource1))
resource2 = providers.Resource(async_resource, providers.Object(resource2))


@inject
async def async_injection(
resource1: object = Provide['resource1'],
resource2: object = Provide['resource2'],
):
return resource1, resource2


@inject
async def async_injection_with_closing(
resource1: object = Closing[Provide['resource1']],
resource2: object = Closing[Provide['resource2']],
):
return resource1, resource2
17 changes: 17 additions & 0 deletions tests/unit/samples/wiringstringidssamples/container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from dependency_injector import containers, providers

from .service import Service


class SubContainer(containers.DeclarativeContainer):

int_object = providers.Object(1)


class Container(containers.DeclarativeContainer):

config = providers.Configuration()

service = providers.Factory(Service)

sub = providers.Container(SubContainer)

0 comments on commit f153f46

Please sign in to comment.