From 4733aad44eedc64a3111556397556acc6d89c09e Mon Sep 17 00:00:00 2001 From: Roman Mogylatov Date: Sat, 25 Sep 2021 15:36:48 -0400 Subject: [PATCH] Fix provide issue (#514) --- docs/main/changelog.rst | 1 + docs/wiring.rst | 20 +++++++++++++++----- src/dependency_injector/wiring.py | 10 ++++++++-- tests/unit/samples/wiringsamples/module.py | 8 +++++++- tests/unit/wiring/test_wiring_py36.py | 4 ++++ 5 files changed, 35 insertions(+), 8 deletions(-) diff --git a/docs/main/changelog.rst b/docs/main/changelog.rst index 6950db273..31ed445c7 100644 --- a/docs/main/changelog.rst +++ b/docs/main/changelog.rst @@ -10,6 +10,7 @@ follows `Semantic versioning`_ Develop ------- +- Fix a wiring bug with improper resolving of ``Provide[some_provider.provider]``. - Fix a typo in ``Factory`` provider docs ``service.add_attributes(clent=client)`` `#499 `_. Thanks to `@rajanjha786 `_ for the contribution. diff --git a/docs/wiring.rst b/docs/wiring.rst index ad74ad347..58deeae52 100644 --- a/docs/wiring.rst +++ b/docs/wiring.rst @@ -39,19 +39,29 @@ a function or method argument: Specifying an annotation is optional. -There are two types of markers: +To inject the provider itself use ``Provide[foo.provider]``: -- ``Provide[foo]`` - call the provider ``foo`` and injects the result -- ``Provider[foo]`` - injects the provider ``foo`` itself +.. code-block:: python + + from dependency_injector.providers import Factory + from dependency_injector.wiring import inject, Provide + + + @inject + def foo(bar_provider: Factory[Bar] = Provide[Container.bar.provider]): + bar = bar_provider(argument="baz") + ... +You can also use ``Provider[foo]`` for injecting the provider itself: .. code-block:: python + from dependency_injector.providers import Factory from dependency_injector.wiring import inject, Provider @inject - def foo(bar_provider: Callable[..., Bar] = Provider[Container.bar]): - bar = bar_provider() + def foo(bar_provider: Factory[Bar] = Provider[Container.bar]): + bar = bar_provider(argument="baz") ... You can use configuration, provided instance and sub-container providers as you normally do. diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 4dd77c098..a98c28c31 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -226,7 +226,10 @@ def _resolve_delegate( self, original: providers.Delegate, ) -> Optional[providers.Provider]: - return self._resolve_provider(original.provides) + provider = self._resolve_provider(original.provides) + if provider: + provider = provider.provider + return provider def _resolve_config_option( self, @@ -539,7 +542,10 @@ def _bind_injections(fn: Callable[..., Any], providers_map: ProvidersMap) -> Non if isinstance(marker, Provide): fn.__injections__[injection] = provider elif isinstance(marker, Provider): - fn.__injections__[injection] = provider.provider + if isinstance(provider, providers.Delegate): + fn.__injections__[injection] = provider + else: + fn.__injections__[injection] = provider.provider if injection in fn.__reference_closing__: fn.__closing__[injection] = provider diff --git a/tests/unit/samples/wiringsamples/module.py b/tests/unit/samples/wiringsamples/module.py index 333de332f..e1ccc1121 100644 --- a/tests/unit/samples/wiringsamples/module.py +++ b/tests/unit/samples/wiringsamples/module.py @@ -84,7 +84,13 @@ def test_config_value_required_undefined( @inject -def test_provide_provider(service_provider: Callable[..., Service] = Provider[Container.service.provider]): +def test_provide_provider(service_provider: Callable[..., Service] = Provide[Container.service.provider]): + service = service_provider() + return service + + +@inject +def test_provider_provider(service_provider: Callable[..., Service] = Provider[Container.service.provider]): service = service_provider() return service diff --git a/tests/unit/wiring/test_wiring_py36.py b/tests/unit/wiring/test_wiring_py36.py index 7c5037828..0d0a4a5f1 100644 --- a/tests/unit/wiring/test_wiring_py36.py +++ b/tests/unit/wiring/test_wiring_py36.py @@ -169,6 +169,10 @@ def test_provide_provider(self): service = module.test_provide_provider() self.assertIsInstance(service, Service) + def test_provider_provider(self): + service = module.test_provider_provider() + self.assertIsInstance(service, Service) + def test_provided_instance(self): class TestService: foo = {