diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 1b7c8f2d..c4dc1298 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -5,7 +5,7 @@ import pkgutil import sys from types import ModuleType -from typing import Optional, Iterable, Callable, Any, Tuple, List, Dict, Generic, TypeVar, cast +from typing import Optional, Iterable, Callable, Any, Tuple, Dict, Generic, TypeVar, cast if sys.version_info < (3, 7): from typing import GenericMeta @@ -226,11 +226,11 @@ def _unpatch_fn( def _resolve_injections( fn: Callable[..., Any], providers_map: ProvidersMap, -) -> Tuple[Dict[str, Any], List[Any]]: +) -> Tuple[Dict[str, Any], Dict[str, Any]]: signature = inspect.signature(fn) injections = {} - closing = [] + closing = {} for parameter_name, parameter in signature.parameters.items(): if not isinstance(parameter.default, _Marker): continue @@ -246,7 +246,7 @@ def _resolve_injections( continue if closing_modifier: - closing.append(provider) + closing[parameter_name] = provider if isinstance(marker, Provide): injections[parameter_name] = provider @@ -275,40 +275,44 @@ def _patch_with_injections(fn, injections, closing): if inspect.iscoroutinefunction(fn): @functools.wraps(fn) async def _patched(*args, **kwargs): - to_inject = {} + to_inject = kwargs.copy() for injection, provider in injections.items(): - to_inject[injection] = provider() - - to_inject.update(kwargs) + if injection not in kwargs: + to_inject[injection] = provider() result = await fn(*args, **to_inject) - for provider in closing: - if isinstance(provider, providers.Resource): - provider.shutdown() + for injection, provider in closing.items(): + if injection in kwargs: + continue + if not isinstance(provider, providers.Resource): + continue + provider.shutdown() return result else: @functools.wraps(fn) def _patched(*args, **kwargs): - to_inject = {} + to_inject = kwargs.copy() for injection, provider in injections.items(): - to_inject[injection] = provider() - - to_inject.update(kwargs) + if injection not in kwargs: + to_inject[injection] = provider() result = fn(*args, **to_inject) - for provider in closing: - if isinstance(provider, providers.Resource): - provider.shutdown() + for injection, provider in closing.items(): + if injection in kwargs: + continue + if not isinstance(provider, providers.Resource): + continue + provider.shutdown() return result _patched.__wired__ = True _patched.__original__ = fn _patched.__injections__ = injections - _patched.__closing__ = [] + _patched.__closing__ = closing return _patched diff --git a/tests/unit/samples/wiringsamples/resourceclosing.py b/tests/unit/samples/wiringsamples/resourceclosing.py index 33a160ca..f7f35bd1 100644 --- a/tests/unit/samples/wiringsamples/resourceclosing.py +++ b/tests/unit/samples/wiringsamples/resourceclosing.py @@ -6,6 +6,11 @@ class Service: init_counter: int = 0 shutdown_counter: int = 0 + @classmethod + def reset_counter(cls): + cls.init_counter = 0 + cls.shutdown_counter = 0 + @classmethod def init(cls): cls.init_counter += 1 diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index f5b1a509..d8e78c1d 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -178,6 +178,8 @@ def test_wire_multiple_containers(self): def test_closing_resource(self): from wiringsamples import resourceclosing + resourceclosing.Service.reset_counter() + container = resourceclosing.Container() container.wire(modules=[resourceclosing]) self.addCleanup(container.unwire) @@ -193,3 +195,23 @@ def test_closing_resource(self): self.assertEqual(result_2.shutdown_counter, 2) self.assertIsNot(result_1, result_2) + + def test_closing_resource_context(self): + from wiringsamples import resourceclosing + + resourceclosing.Service.reset_counter() + service = resourceclosing.Service() + + container = resourceclosing.Container() + container.wire(modules=[resourceclosing]) + self.addCleanup(container.unwire) + + result_1 = resourceclosing.test_function(service=service) + self.assertIs(result_1, service) + self.assertEqual(result_1.init_counter, 0) + self.assertEqual(result_1.shutdown_counter, 0) + + result_2 = resourceclosing.test_function(service=service) + self.assertIs(result_2, service) + self.assertEqual(result_2.init_counter, 0) + self.assertEqual(result_2.shutdown_counter, 0)