Skip to content

Commit

Permalink
Implement wiring autoloader
Browse files Browse the repository at this point in the history
  • Loading branch information
rmk135 committed Jan 29, 2021
1 parent 9225f9d commit 41e18d2
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 1 deletion.
101 changes: 101 additions & 0 deletions src/dependency_injector/wiring.py
Expand Up @@ -4,6 +4,7 @@
import functools
import inspect
import importlib
import importlib.machinery
import pkgutil
import sys
from types import ModuleType
Expand Down Expand Up @@ -52,6 +53,11 @@ class GenericMeta(type):
'Provide',
'Provider',
'Closing',
'register_loader_containers',
'unregister_loader_containers',
'install_loader',
'uninstall_loader',
'is_loader_installed',
)

T = TypeVar('T')
Expand Down Expand Up @@ -535,3 +541,98 @@ class Provider(_Marker):

class Closing(_Marker):
...


class AutoLoader:
"""Auto-wiring module loader.
Automatically wire containers when modules are imported.
"""

def __init__(self):
self.containers = []
self._path_hook = None

def register_containers(self, *containers):
self.containers.extend(containers)

if not self.installed:
self.install()

def unregister_containers(self, *containers):
for container in containers:
self.containers.remove(container)

if not self.containers:
self.uninstall()

def wire_module(self, module):
for container in self.containers:
container.wire(modules=[module])

@property
def installed(self):
return self._path_hook is not None

def install(self):
if self.installed:
return

loader = self

class SourcelessFileLoader(importlib.machinery.SourcelessFileLoader):
def exec_module(self, module):
super().exec_module(module)
loader.wire_module(module)

class SourceFileLoader(importlib.machinery.SourceFileLoader):
def exec_module(self, module):
super().exec_module(module)
loader.wire_module(module)

loader_details = [
(SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES),
(SourceFileLoader, importlib.machinery.SOURCE_SUFFIXES),
]

self._path_hook = importlib.machinery.FileFinder.path_hook(*loader_details)

sys.path_hooks.insert(0, self._path_hook)
sys.path_importer_cache.clear()
importlib.invalidate_caches()

def uninstall(self):
if not self.installed:
return

sys.path_hooks.remove(self._path_hook)
sys.path_importer_cache.clear()
importlib.invalidate_caches()


_loader = AutoLoader()


def register_loader_containers(*containers: Container) -> None:
"""Register containers in auto-wiring module loader."""
_loader.register_containers(*containers)


def unregister_loader_containers(*containers: Container) -> None:
"""Unregister containers from auto-wiring module loader."""
_loader.unregister_containers(*containers)


def install_loader() -> None:
"""Install auto-wiring module loader hook."""
_loader.install()


def uninstall_loader() -> None:
"""Uninstall auto-wiring module loader hook."""
_loader.uninstall()


def is_loader_installed() -> bool:
"""Check if auto-wiring module loader hook is installed."""
return _loader.installed
36 changes: 35 additions & 1 deletion tests/unit/wiring/test_wiring_py36.py
@@ -1,7 +1,15 @@
import contextlib
from decimal import Decimal
import importlib
import unittest

from dependency_injector.wiring import wire, Provide, Closing
from dependency_injector.wiring import (
wire,
Provide,
Closing,
register_loader_containers,
unregister_loader_containers,
)
from dependency_injector import errors

# Runtime import to avoid syntax errors in samples on Python < 3.5
Expand Down Expand Up @@ -367,3 +375,29 @@ def test_async_injections_with_closing(self):
self.assertIs(resource2, asyncinjections.resource2)
self.assertEqual(asyncinjections.resource2.init_counter, 2)
self.assertEqual(asyncinjections.resource2.shutdown_counter, 2)


class AutoLoaderTest(unittest.TestCase):

container: Container

def setUp(self) -> None:
self.container = Container(config={'a': {'b': {'c': 10}}})
importlib.reload(module)

def tearDown(self) -> None:
with contextlib.suppress(ValueError):
unregister_loader_containers(self.container)

self.container.unwire()

@classmethod
def tearDownClass(cls) -> None:
importlib.reload(module)

def test_register_container(self):
register_loader_containers(self.container)
importlib.reload(module)

service = module.test_function()
self.assertIsInstance(service, Service)

0 comments on commit 41e18d2

Please sign in to comment.