Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/main/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ that were made in every particular version.
From version 0.7.6 *Dependency Injector* framework strictly
follows `Semantic versioning`_

Development version
-------------------
- Add wiring injections into modules and class attributes.
See issue: `#411 <https://github.com/ets-labs/python-dependency-injector/issues/411>`_.
Many thanks to `@brunopereira27 <https://github.com/brunopereira27>`_ for submitting
the use case.

4.27.0
------
- Introduce wiring inspect filter to filter out ``flask.request`` and other local proxy objects
Expand Down
23 changes: 23 additions & 0 deletions docs/wiring.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,29 @@ To inject a container use special identifier ``<container>``:
def foo(container: Container = Provide['<container>']) -> None:
...


Making injections into modules and class attributes
---------------------------------------------------

You can use wiring to make injections into modules and class attributes.

.. literalinclude:: ../examples/wiring/example_attribute.py
:language: python
:lines: 3-
:emphasize-lines: 16,21

You could also use string identifiers to avoid a dependency on a container:

.. code-block:: python
:emphasize-lines: 1,6

service: Service = Provide['service']


class Main:

service: Service = Provide['service']

Wiring with modules and packages
--------------------------------

Expand Down
31 changes: 31 additions & 0 deletions examples/wiring/example_attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Wiring attribute example."""

import sys

from dependency_injector import containers, providers
from dependency_injector.wiring import Provide


class Service:
...


class Container(containers.DeclarativeContainer):

service = providers.Factory(Service)


service: Service = Provide[Container.service]


class Main:

service: Service = Provide[Container.service]


if __name__ == '__main__':
container = Container()
container.wire(modules=[sys.modules[__name__]])

assert isinstance(service, Service)
assert isinstance(Main.service, Service)
123 changes: 97 additions & 26 deletions src/dependency_injector/wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TypeVar,
Type,
Union,
Set,
cast,
)

Expand Down Expand Up @@ -82,22 +83,53 @@ class GenericMeta(type):
Container = Any


class Registry:
class PatchedRegistry:

def __init__(self):
self._storage = set()
self._callables: Set[Callable[..., Any]] = set()
self._attributes: Set[PatchedAttribute] = set()

def add(self, patched: Callable[..., Any]) -> None:
self._storage.add(patched)
def add_callable(self, patched: Callable[..., Any]) -> None:
self._callables.add(patched)

def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
for patched in self._storage:
def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
for patched in self._callables:
if patched.__module__ != module.__name__:
continue
yield patched

def add_attribute(self, patched: 'PatchedAttribute'):
self._attributes.add(patched)

_patched_registry = Registry()
def get_attributes_from_module(self, module: ModuleType) -> Iterator['PatchedAttribute']:
for attribute in self._attributes:
if not attribute.is_in_module(module):
continue
yield attribute

def clear_module_attributes(self, module: ModuleType):
for attribute in self._attributes.copy():
if not attribute.is_in_module(module):
continue
self._attributes.remove(attribute)


class PatchedAttribute:

def __init__(self, member: Any, name: str, marker: '_Marker'):
self.member = member
self.name = name
self.marker = marker

@property
def module_name(self) -> str:
if isinstance(self.member, ModuleType):
return self.member.__name__
else:
return self.member.__module__

def is_in_module(self, module: ModuleType) -> bool:
return self.module_name == module.__name__


class ProvidersMap:
Expand Down Expand Up @@ -278,9 +310,6 @@ def _is_starlette_request_cls(self, instance: object) -> bool:
and issubclass(instance, starlette.requests.Request)


inspect_filter = InspectFilter()


def wire( # noqa: C901
container: Container,
*,
Expand All @@ -301,20 +330,27 @@ def wire( # noqa: C901
providers_map = ProvidersMap(container)

for module in modules:
for name, member in inspect.getmembers(module):
if inspect_filter.is_excluded(member):
for member_name, member in inspect.getmembers(module):
if _inspect_filter.is_excluded(member):
continue
if inspect.isfunction(member):
_patch_fn(module, name, member, providers_map)
elif inspect.isclass(member):
for method_name, method in inspect.getmembers(member, _is_method):
_patch_method(member, method_name, method, providers_map)

for patched in _patched_registry.get_from_module(module):
if _is_marker(member):
_patch_attribute(module, member_name, member, providers_map)
elif inspect.isfunction(member):
_patch_fn(module, member_name, member, providers_map)
elif inspect.isclass(member):
cls = member
for cls_member_name, cls_member in inspect.getmembers(cls):
if _is_marker(cls_member):
_patch_attribute(cls, cls_member_name, cls_member, providers_map)
elif _is_method(cls_member):
_patch_method(cls, cls_member_name, cls_member, providers_map)

for patched in _patched_registry.get_callables_from_module(module):
_bind_injections(patched, providers_map)


def unwire(
def unwire( # noqa: C901
*,
modules: Optional[Iterable[ModuleType]] = None,
packages: Optional[Iterable[ModuleType]] = None,
Expand All @@ -335,15 +371,19 @@ def unwire(
for method_name, method in inspect.getmembers(member, inspect.isfunction):
_unpatch(member, method_name, method)

for patched in _patched_registry.get_from_module(module):
for patched in _patched_registry.get_callables_from_module(module):
_unbind_injections(patched)

for patched_attribute in _patched_registry.get_attributes_from_module(module):
_unpatch_attribute(patched_attribute)
_patched_registry.clear_module_attributes(module)


def inject(fn: F) -> F:
"""Decorate callable with injecting decorator."""
reference_injections, reference_closing = _fetch_reference_injections(fn)
patched = _get_patched(fn, reference_injections, reference_closing)
_patched_registry.add(patched)
_patched_registry.add_callable(patched)
return cast(F, patched)


Expand All @@ -358,7 +398,7 @@ def _patch_fn(
if not reference_injections:
return
fn = _get_patched(fn, reference_injections, reference_closing)
_patched_registry.add(fn)
_patched_registry.add_callable(fn)

_bind_injections(fn, providers_map)

Expand All @@ -384,7 +424,7 @@ def _patch_method(
if not reference_injections:
return
fn = _get_patched(fn, reference_injections, reference_closing)
_patched_registry.add(fn)
_patched_registry.add_callable(fn)

_bind_injections(fn, providers_map)

Expand All @@ -411,6 +451,31 @@ def _unpatch(
_unbind_injections(fn)


def _patch_attribute(
member: Any,
name: str,
marker: '_Marker',
providers_map: ProvidersMap,
) -> None:
provider = providers_map.resolve_provider(marker.provider, marker.modifier)
if provider is None:
return

_patched_registry.add_attribute(PatchedAttribute(member, name, marker))

if isinstance(marker, Provide):
instance = provider()
setattr(member, name, instance)
elif isinstance(marker, Provider):
setattr(member, name, provider)
else:
raise Exception(f'Unknown type of marker {marker}')


def _unpatch_attribute(patched: PatchedAttribute) -> None:
setattr(patched.member, patched.name, patched.marker)


def _fetch_reference_injections(
fn: Callable[..., Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
Expand Down Expand Up @@ -484,6 +549,10 @@ def _is_method(member):
return inspect.ismethod(member) or inspect.isfunction(member)


def _is_marker(member):
return isinstance(member, _Marker)


def _get_patched(fn, reference_injections, reference_closing):
if inspect.iscoroutinefunction(fn):
patched = _get_async_patched(fn)
Expand Down Expand Up @@ -825,9 +894,6 @@ def uninstall(self):
importlib.invalidate_caches()


_loader = AutoLoader()


def register_loader_containers(*containers: Container) -> None:
"""Register containers in auto-wiring module loader."""
_loader.register_containers(*containers)
Expand All @@ -851,3 +917,8 @@ def uninstall_loader() -> None:
def is_loader_installed() -> bool:
"""Check if auto-wiring module loader hook is installed."""
return _loader.installed


_patched_registry = PatchedRegistry()
_inspect_filter = InspectFilter()
_loader = AutoLoader()
10 changes: 10 additions & 0 deletions tests/unit/samples/wiringsamples/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,24 @@
from decimal import Decimal
from typing import Callable

from dependency_injector import providers
from dependency_injector.wiring import inject, Provide, Provider

from .container import Container, SubContainer
from .service import Service


service: Service = Provide[Container.service]
service_provider: Callable[..., Service] = Provider[Container.service]
undefined: Callable = Provide[providers.Provider()]


class TestClass:

service: Service = Provide[Container.service]
service_provider: Callable[..., Service] = Provider[Container.service]
undefined: Callable = Provide[providers.Provider()]

@inject
def __init__(self, service: Service = Provide[Container.service]):
self.service = service
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Test module for wiring with invalid type of marker for attribute injection."""

from dependency_injector.wiring import Closing

from .container import Container


service = Closing[Container.service]
9 changes: 9 additions & 0 deletions tests/unit/samples/wiringstringidssamples/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,17 @@
from .service import Service


service: Service = Provide['service']
service_provider: Callable[..., Service] = Provider['service']
undefined: Callable = Provide['undefined']


class TestClass:

service: Service = Provide['service']
service_provider: Callable[..., Service] = Provider['service']
undefined: Callable = Provide['undefined']

@inject
def __init__(self, service: Service = Provide['service']):
self.service = service
Expand Down
Loading