Skip to content

Commit

Permalink
Fix issue with wiring and resource initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
rmk135 committed Oct 30, 2020
1 parent 4c46d34 commit c5f799a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 19 deletions.
42 changes: 23 additions & 19 deletions src/dependency_injector/wiring.py
Expand Up @@ -5,7 +5,7 @@
import pkgutil
import sys
from types import ModuleType
from typing import Optional, Iterable, Callable, Any, Tuple, List, Dict, Generic, TypeVar, cast
from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, cast

if sys.version_info < (3, 7):
from typing import GenericMeta
Expand Down Expand Up @@ -226,11 +226,11 @@ def _unpatch_fn(
def _resolve_injections(
fn: Callable[..., Any],
providers_map: ProvidersMap,
) -> Tuple[Dict[str, Any], List[Any]]:
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
signature = inspect.signature(fn)

injections = {}
closing = []
closing = {}
for parameter_name, parameter in signature.parameters.items():
if not isinstance(parameter.default, _Marker):
continue
Expand All @@ -246,7 +246,7 @@ def _resolve_injections(
continue

if closing_modifier:
closing.append(provider)
closing[parameter_name] = provider

if isinstance(marker, Provide):
injections[parameter_name] = provider
Expand Down Expand Up @@ -275,40 +275,44 @@ def _patch_with_injections(fn, injections, closing):
if inspect.iscoroutinefunction(fn):
@functools.wraps(fn)
async def _patched(*args, **kwargs):
to_inject = {}
to_inject = kwargs.copy()
for injection, provider in injections.items():
to_inject[injection] = provider()

to_inject.update(kwargs)
if injection not in kwargs:
to_inject[injection] = provider()

result = await fn(*args, **to_inject)

for provider in closing:
if isinstance(provider, providers.Resource):
provider.shutdown()
for injection, provider in closing.items():
if injection in kwargs:
continue
if not isinstance(provider, providers.Resource):
continue
provider.shutdown()

return result
else:
@functools.wraps(fn)
def _patched(*args, **kwargs):
to_inject = {}
to_inject = kwargs.copy()
for injection, provider in injections.items():
to_inject[injection] = provider()

to_inject.update(kwargs)
if injection not in kwargs:
to_inject[injection] = provider()

result = fn(*args, **to_inject)

for provider in closing:
if isinstance(provider, providers.Resource):
provider.shutdown()
for injection, provider in closing.items():
if injection in kwargs:
continue
if not isinstance(provider, providers.Resource):
continue
provider.shutdown()

return result

_patched.__wired__ = True
_patched.__original__ = fn
_patched.__injections__ = injections
_patched.__closing__ = []
_patched.__closing__ = closing

return _patched

Expand Down
5 changes: 5 additions & 0 deletions tests/unit/samples/wiringsamples/resourceclosing.py
Expand Up @@ -6,6 +6,11 @@ class Service:
init_counter: int = 0
shutdown_counter: int = 0

@classmethod
def reset_counter(cls):
cls.init_counter = 0
cls.shutdown_counter = 0

@classmethod
def init(cls):
cls.init_counter += 1
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/wiring/test_wiring_py36.py
Expand Up @@ -178,6 +178,8 @@ def test_wire_multiple_containers(self):
def test_closing_resource(self):
from wiringsamples import resourceclosing

resourceclosing.Service.reset_counter()

container = resourceclosing.Container()
container.wire(modules=[resourceclosing])
self.addCleanup(container.unwire)
Expand All @@ -193,3 +195,23 @@ def test_closing_resource(self):
self.assertEqual(result_2.shutdown_counter, 2)

self.assertIsNot(result_1, result_2)

def test_closing_resource_context(self):
from wiringsamples import resourceclosing

resourceclosing.Service.reset_counter()
service = resourceclosing.Service()

container = resourceclosing.Container()
container.wire(modules=[resourceclosing])
self.addCleanup(container.unwire)

result_1 = resourceclosing.test_function(service=service)
self.assertIs(result_1, service)
self.assertEqual(result_1.init_counter, 0)
self.assertEqual(result_1.shutdown_counter, 0)

result_2 = resourceclosing.test_function(service=service)
self.assertIs(result_2, service)
self.assertEqual(result_2.init_counter, 0)
self.assertEqual(result_2.shutdown_counter, 0)

0 comments on commit c5f799a

Please sign in to comment.