Skip to content

Commit

Permalink
Merge 9f6d236 into c787ac2
Browse files Browse the repository at this point in the history
  • Loading branch information
rmk135 committed Mar 1, 2021
2 parents c787ac2 + 9f6d236 commit faf8e5c
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 27 deletions.
7 changes: 7 additions & 0 deletions docs/main/changelog.rst
Expand Up @@ -7,6 +7,13 @@ that were made in every particular version.
From version 0.7.6 *Dependency Injector* framework strictly
follows `Semantic versioning`_

Development version
-------------------
- Add wiring injections into modules and class attributes.
See issue: `#411 <https://github.com/ets-labs/python-dependency-injector/issues/411>`_.
Many thanks to `@brunopereira27 <https://github.com/brunopereira27>`_ for submitting
the use case.

4.27.0
------
- Introduce wiring inspect filter to filter out ``flask.request`` and other local proxy objects
Expand Down
23 changes: 23 additions & 0 deletions docs/wiring.rst
Expand Up @@ -164,6 +164,29 @@ To inject a container use special identifier ``<container>``:
def foo(container: Container = Provide['<container>']) -> None:
...
Making injections into modules and class attributes
---------------------------------------------------

You can use wiring to make injections into modules and class attributes.

.. literalinclude:: ../examples/wiring/example_attribute.py
:language: python
:lines: 3-
:emphasize-lines: 16,21

You could also use string identifiers to avoid a dependency on a container:

.. code-block:: python
:emphasize-lines: 1,6
service: Service = Provide['service']
class Main:
service: Service = Provide['service']
Wiring with modules and packages
--------------------------------

Expand Down
31 changes: 31 additions & 0 deletions examples/wiring/example_attribute.py
@@ -0,0 +1,31 @@
"""Wiring attribute example."""

import sys

from dependency_injector import containers, providers
from dependency_injector.wiring import Provide


class Service:
...


class Container(containers.DeclarativeContainer):

service = providers.Factory(Service)


service: Service = Provide[Container.service]


class Main:

service: Service = Provide[Container.service]


if __name__ == '__main__':
container = Container()
container.wire(modules=[sys.modules[__name__]])

assert isinstance(service, Service)
assert isinstance(Main.service, Service)
123 changes: 97 additions & 26 deletions src/dependency_injector/wiring.py
Expand Up @@ -20,6 +20,7 @@
TypeVar,
Type,
Union,
Set,
cast,
)

Expand Down Expand Up @@ -82,22 +83,53 @@ class GenericMeta(type):
Container = Any


class Registry:
class PatchedRegistry:

def __init__(self):
self._storage = set()
self._callables: Set[Callable[..., Any]] = set()
self._attributes: Set[PatchedAttribute] = set()

def add(self, patched: Callable[..., Any]) -> None:
self._storage.add(patched)
def add_callable(self, patched: Callable[..., Any]) -> None:
self._callables.add(patched)

def get_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
for patched in self._storage:
def get_callables_from_module(self, module: ModuleType) -> Iterator[Callable[..., Any]]:
for patched in self._callables:
if patched.__module__ != module.__name__:
continue
yield patched

def add_attribute(self, patched: 'PatchedAttribute'):
self._attributes.add(patched)

_patched_registry = Registry()
def get_attributes_from_module(self, module: ModuleType) -> Iterator['PatchedAttribute']:
for attribute in self._attributes:
if not attribute.is_in_module(module):
continue
yield attribute

def clear_module_attributes(self, module: ModuleType):
for attribute in self._attributes.copy():
if not attribute.is_in_module(module):
continue
self._attributes.remove(attribute)


class PatchedAttribute:

def __init__(self, member: Any, name: str, marker: '_Marker'):
self.member = member
self.name = name
self.marker = marker

@property
def module_name(self) -> str:
if isinstance(self.member, ModuleType):
return self.member.__name__
else:
return self.member.__module__

def is_in_module(self, module: ModuleType) -> bool:
return self.module_name == module.__name__


class ProvidersMap:
Expand Down Expand Up @@ -278,9 +310,6 @@ def _is_starlette_request_cls(self, instance: object) -> bool:
and issubclass(instance, starlette.requests.Request)


inspect_filter = InspectFilter()


def wire( # noqa: C901
container: Container,
*,
Expand All @@ -301,20 +330,27 @@ def wire( # noqa: C901
providers_map = ProvidersMap(container)

for module in modules:
for name, member in inspect.getmembers(module):
if inspect_filter.is_excluded(member):
for member_name, member in inspect.getmembers(module):
if _inspect_filter.is_excluded(member):
continue
if inspect.isfunction(member):
_patch_fn(module, name, member, providers_map)
elif inspect.isclass(member):
for method_name, method in inspect.getmembers(member, _is_method):
_patch_method(member, method_name, method, providers_map)

for patched in _patched_registry.get_from_module(module):
if _is_marker(member):
_patch_attribute(module, member_name, member, providers_map)
elif inspect.isfunction(member):
_patch_fn(module, member_name, member, providers_map)
elif inspect.isclass(member):
cls = member
for cls_member_name, cls_member in inspect.getmembers(cls):
if _is_marker(cls_member):
_patch_attribute(cls, cls_member_name, cls_member, providers_map)
elif _is_method(cls_member):
_patch_method(cls, cls_member_name, cls_member, providers_map)

for patched in _patched_registry.get_callables_from_module(module):
_bind_injections(patched, providers_map)


def unwire(
def unwire( # noqa: C901
*,
modules: Optional[Iterable[ModuleType]] = None,
packages: Optional[Iterable[ModuleType]] = None,
Expand All @@ -335,15 +371,19 @@ def unwire(
for method_name, method in inspect.getmembers(member, inspect.isfunction):
_unpatch(member, method_name, method)

for patched in _patched_registry.get_from_module(module):
for patched in _patched_registry.get_callables_from_module(module):
_unbind_injections(patched)

for patched_attribute in _patched_registry.get_attributes_from_module(module):
_unpatch_attribute(patched_attribute)
_patched_registry.clear_module_attributes(module)


def inject(fn: F) -> F:
"""Decorate callable with injecting decorator."""
reference_injections, reference_closing = _fetch_reference_injections(fn)
patched = _get_patched(fn, reference_injections, reference_closing)
_patched_registry.add(patched)
_patched_registry.add_callable(patched)
return cast(F, patched)


Expand All @@ -358,7 +398,7 @@ def _patch_fn(
if not reference_injections:
return
fn = _get_patched(fn, reference_injections, reference_closing)
_patched_registry.add(fn)
_patched_registry.add_callable(fn)

_bind_injections(fn, providers_map)

Expand All @@ -384,7 +424,7 @@ def _patch_method(
if not reference_injections:
return
fn = _get_patched(fn, reference_injections, reference_closing)
_patched_registry.add(fn)
_patched_registry.add_callable(fn)

_bind_injections(fn, providers_map)

Expand All @@ -411,6 +451,31 @@ def _unpatch(
_unbind_injections(fn)


def _patch_attribute(
member: Any,
name: str,
marker: '_Marker',
providers_map: ProvidersMap,
) -> None:
provider = providers_map.resolve_provider(marker.provider, marker.modifier)
if provider is None:
return

_patched_registry.add_attribute(PatchedAttribute(member, name, marker))

if isinstance(marker, Provide):
instance = provider()
setattr(member, name, instance)
elif isinstance(marker, Provider):
setattr(member, name, provider)
else:
raise Exception(f'Unknown type of marker {marker}')


def _unpatch_attribute(patched: PatchedAttribute) -> None:
setattr(patched.member, patched.name, patched.marker)


def _fetch_reference_injections(
fn: Callable[..., Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
Expand Down Expand Up @@ -484,6 +549,10 @@ def _is_method(member):
return inspect.ismethod(member) or inspect.isfunction(member)


def _is_marker(member):
return isinstance(member, _Marker)


def _get_patched(fn, reference_injections, reference_closing):
if inspect.iscoroutinefunction(fn):
patched = _get_async_patched(fn)
Expand Down Expand Up @@ -825,9 +894,6 @@ def uninstall(self):
importlib.invalidate_caches()


_loader = AutoLoader()


def register_loader_containers(*containers: Container) -> None:
"""Register containers in auto-wiring module loader."""
_loader.register_containers(*containers)
Expand All @@ -851,3 +917,8 @@ def uninstall_loader() -> None:
def is_loader_installed() -> bool:
"""Check if auto-wiring module loader hook is installed."""
return _loader.installed


_patched_registry = PatchedRegistry()
_inspect_filter = InspectFilter()
_loader = AutoLoader()
10 changes: 10 additions & 0 deletions tests/unit/samples/wiringsamples/module.py
Expand Up @@ -3,14 +3,24 @@
from decimal import Decimal
from typing import Callable

from dependency_injector import providers
from dependency_injector.wiring import inject, Provide, Provider

from .container import Container, SubContainer
from .service import Service


service: Service = Provide[Container.service]
service_provider: Callable[..., Service] = Provider[Container.service]
undefined: Callable = Provide[providers.Provider()]


class TestClass:

service: Service = Provide[Container.service]
service_provider: Callable[..., Service] = Provider[Container.service]
undefined: Callable = Provide[providers.Provider()]

@inject
def __init__(self, service: Service = Provide[Container.service]):
self.service = service
Expand Down
@@ -0,0 +1,8 @@
"""Test module for wiring with invalid type of marker for attribute injection."""

from dependency_injector.wiring import Closing

from .container import Container


service = Closing[Container.service]
9 changes: 9 additions & 0 deletions tests/unit/samples/wiringstringidssamples/module.py
Expand Up @@ -19,8 +19,17 @@
from .service import Service


service: Service = Provide['service']
service_provider: Callable[..., Service] = Provider['service']
undefined: Callable = Provide['undefined']


class TestClass:

service: Service = Provide['service']
service_provider: Callable[..., Service] = Provider['service']
undefined: Callable = Provide['undefined']

@inject
def __init__(self, service: Service = Provide['service']):
self.service = service
Expand Down

0 comments on commit faf8e5c

Please sign in to comment.