Skip to content

Commit a35db58

Browse files
committed
Add some functionality and tests for declarative containers
+ Add checks for valid provider type + Add some wider functionality for overriding
1 parent 68ae1b8 commit a35db58

File tree

4 files changed

+173
-21
lines changed

4 files changed

+173
-21
lines changed

dependency_injector/containers.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import six
44

55
from dependency_injector import (
6+
providers,
67
utils,
78
errors,
89
)
@@ -18,7 +19,8 @@ def __new__(mcs, class_name, bases, attributes):
1819
if utils.is_provider(provider))
1920

2021
inherited_providers = tuple((name, provider)
21-
for base in bases if utils.is_catalog(base)
22+
for base in bases if utils.is_container(
23+
base)
2224
for name, provider in six.iteritems(
2325
base.cls_providers))
2426

@@ -28,14 +30,8 @@ def __new__(mcs, class_name, bases, attributes):
2830

2931
cls = type.__new__(mcs, class_name, bases, attributes)
3032

31-
if cls.provider_type:
32-
for provider in six.itervalues(cls.providers):
33-
try:
34-
assert isinstance(provider, cls.provider_type)
35-
except AssertionError:
36-
raise errors.Error('{0} can contain only {1} '
37-
'instances'.format(cls,
38-
cls.provider_type))
33+
for provider in six.itervalues(cls.providers):
34+
cls._check_provider_type(provider)
3935

4036
return cls
4137

@@ -46,6 +42,7 @@ def __setattr__(cls, name, value):
4642
dictionary.
4743
"""
4844
if utils.is_provider(value):
45+
cls._check_provider_type(value)
4946
cls.providers[name] = value
5047
cls.cls_providers[name] = value
5148
super(DeclarativeContainerMetaClass, cls).__setattr__(name, value)
@@ -61,14 +58,19 @@ def __delattr__(cls, name):
6158
del cls.cls_providers[name]
6259
super(DeclarativeContainerMetaClass, cls).__delattr__(name)
6360

61+
def _check_provider_type(cls, provider):
62+
if not isinstance(provider, cls.provider_type):
63+
raise errors.Error('{0} can contain only {1} '
64+
'instances'.format(cls, cls.provider_type))
65+
6466

6567
@six.add_metaclass(DeclarativeContainerMetaClass)
6668
class DeclarativeContainer(object):
6769
"""Declarative inversion of control container."""
6870

69-
__IS_CATALOG__ = True
71+
__IS_CONTAINER__ = True
7072

71-
provider_type = None
73+
provider_type = providers.Provider
7274

7375
providers = dict()
7476
cls_providers = dict()
@@ -89,7 +91,7 @@ def override(cls, overriding):
8991
:rtype: None
9092
"""
9193
if issubclass(cls, overriding):
92-
raise errors.Error('Catalog {0} could not be overridden '
94+
raise errors.Error('Container {0} could not be overridden '
9395
'with itself or its subclasses'.format(cls))
9496

9597
cls.overridden_by += (overriding,)
@@ -100,6 +102,31 @@ def override(cls, overriding):
100102
except AttributeError:
101103
pass
102104

105+
@classmethod
106+
def reset_last_overriding(cls):
107+
"""Reset last overriding provider for each container providers.
108+
109+
:rtype: None
110+
"""
111+
if not cls.overridden_by:
112+
raise errors.Error('Container {0} is not overridden'.format(cls))
113+
114+
cls.overridden_by = cls.overridden_by[:-1]
115+
116+
for provider in six.itervalues(cls.providers):
117+
provider.reset_last_overriding()
118+
119+
@classmethod
120+
def reset_override(cls):
121+
"""Reset all overridings for each container providers.
122+
123+
:rtype: None
124+
"""
125+
cls.overridden_by = tuple()
126+
127+
for provider in six.itervalues(cls.providers):
128+
provider.reset_override()
129+
103130

104131
def override(container):
105132
""":py:class:`DeclarativeContainer` overriding decorator.

dependency_injector/providers/base.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class Provider(object):
6464

6565
def __init__(self):
6666
"""Initializer."""
67-
self.overridden_by = None
67+
self.overridden_by = tuple()
6868
super(Provider, self).__init__()
6969
# Enable __call__() / _provide() optimization
7070
if self.__class__.__OPTIMIZED_CALLS__:
@@ -124,10 +124,7 @@ def override(self, provider):
124124
if not is_provider(provider):
125125
provider = Object(provider)
126126

127-
if not self.is_overridden:
128-
self.overridden_by = (ensure_is_provider(provider),)
129-
else:
130-
self.overridden_by += (ensure_is_provider(provider),)
127+
self.overridden_by += (ensure_is_provider(provider),)
131128

132129
# Disable __call__() / _provide() optimization
133130
if self.__class__.__OPTIMIZED_CALLS__:
@@ -145,6 +142,7 @@ def reset_last_overriding(self):
145142
"""
146143
if not self.overridden_by:
147144
raise Error('Provider {0} is not overridden'.format(str(self)))
145+
148146
self.overridden_by = self.overridden_by[:-1]
149147

150148
if not self.is_overridden:
@@ -157,7 +155,7 @@ def reset_override(self):
157155
158156
:rtype: None
159157
"""
160-
self.overridden_by = None
158+
self.overridden_by = tuple()
161159

162160
# Enable __call__() / _provide() optimization
163161
if self.__class__.__OPTIMIZED_CALLS__:

dependency_injector/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,18 @@ def ensure_is_provider(instance):
5959
return instance
6060

6161

62+
def is_container(instance):
63+
"""Check if instance is container instance.
64+
65+
:param instance: Instance to be checked.
66+
:type instance: object
67+
68+
:rtype: bool
69+
"""
70+
return (hasattr(instance, '__IS_CONTAINER__') and
71+
getattr(instance, '__IS_CONTAINER__', False) is True)
72+
73+
6274
def is_catalog(instance):
6375
"""Check if instance is catalog instance.
6476

tests/test_containers.py

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from dependency_injector import (
66
containers,
77
providers,
8+
errors,
89
)
910

1011

@@ -28,7 +29,7 @@ class ContainerB(ContainerA):
2829
class DeclarativeContainerTests(unittest.TestCase):
2930
"""Declarative container tests."""
3031

31-
def test_providers_attribute_with(self):
32+
def test_providers_attribute(self):
3233
"""Test providers attribute."""
3334
self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11,
3435
p12=ContainerA.p12))
@@ -37,7 +38,7 @@ def test_providers_attribute_with(self):
3738
p21=ContainerB.p21,
3839
p22=ContainerB.p22))
3940

40-
def test_cls_providers_attribute_with(self):
41+
def test_cls_providers_attribute(self):
4142
"""Test cls_providers attribute."""
4243
self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11,
4344
p12=ContainerA.p12))
@@ -51,7 +52,7 @@ def test_inherited_providers_attribute(self):
5152
dict(p11=ContainerA.p11,
5253
p12=ContainerA.p12))
5354

54-
def test_set_get_del_provider_attribute(self):
55+
def test_set_get_del_providers(self):
5556
"""Test set/get/del provider attributes."""
5657
a_p13 = providers.Provider()
5758
b_p23 = providers.Provider()
@@ -90,6 +91,120 @@ def test_set_get_del_provider_attribute(self):
9091
self.assertEqual(ContainerB.cls_providers, dict(p21=ContainerB.p21,
9192
p22=ContainerB.p22))
9293

94+
def test_declare_with_valid_provider_type(self):
95+
"""Test declaration of container with valid provider type."""
96+
class _Container(containers.DeclarativeContainer):
97+
provider_type = providers.Object
98+
px = providers.Object(object())
99+
100+
self.assertIsInstance(_Container.px, providers.Object)
101+
102+
def test_declare_with_invalid_provider_type(self):
103+
"""Test declaration of container with invalid provider type."""
104+
with self.assertRaises(errors.Error):
105+
class _Container(containers.DeclarativeContainer):
106+
provider_type = providers.Object
107+
px = providers.Provider()
108+
109+
def test_seth_valid_provider_type(self):
110+
"""Test setting of valid provider."""
111+
class _Container(containers.DeclarativeContainer):
112+
provider_type = providers.Object
113+
114+
_Container.px = providers.Object(object())
115+
116+
self.assertIsInstance(_Container.px, providers.Object)
117+
118+
def test_set_invalid_provider_type(self):
119+
"""Test setting of invalid provider."""
120+
class _Container(containers.DeclarativeContainer):
121+
provider_type = providers.Object
122+
123+
with self.assertRaises(errors.Error):
124+
_Container.px = providers.Provider()
125+
126+
def test_override(self):
127+
"""Test override."""
128+
class _Container(containers.DeclarativeContainer):
129+
p11 = providers.Provider()
130+
131+
class _OverridingContainer1(containers.DeclarativeContainer):
132+
p11 = providers.Provider()
133+
134+
class _OverridingContainer2(containers.DeclarativeContainer):
135+
p11 = providers.Provider()
136+
p12 = providers.Provider()
137+
138+
_Container.override(_OverridingContainer1)
139+
_Container.override(_OverridingContainer2)
140+
141+
self.assertEqual(_Container.overridden_by,
142+
(_OverridingContainer1,
143+
_OverridingContainer2))
144+
self.assertEqual(_Container.p11.overridden_by,
145+
(_OverridingContainer1.p11,
146+
_OverridingContainer2.p11))
147+
148+
def test_override_decorator(self):
149+
"""Test override decorator."""
150+
class _Container(containers.DeclarativeContainer):
151+
p11 = providers.Provider()
152+
153+
@containers.override(_Container)
154+
class _OverridingContainer1(containers.DeclarativeContainer):
155+
p11 = providers.Provider()
156+
157+
@containers.override(_Container)
158+
class _OverridingContainer2(containers.DeclarativeContainer):
159+
p11 = providers.Provider()
160+
p12 = providers.Provider()
161+
162+
self.assertEqual(_Container.overridden_by,
163+
(_OverridingContainer1,
164+
_OverridingContainer2))
165+
self.assertEqual(_Container.p11.overridden_by,
166+
(_OverridingContainer1.p11,
167+
_OverridingContainer2.p11))
168+
169+
def test_reset_last_overridding(self):
170+
"""Test reset of last overriding."""
171+
class _Container(containers.DeclarativeContainer):
172+
p11 = providers.Provider()
173+
174+
class _OverridingContainer1(containers.DeclarativeContainer):
175+
p11 = providers.Provider()
176+
177+
class _OverridingContainer2(containers.DeclarativeContainer):
178+
p11 = providers.Provider()
179+
p12 = providers.Provider()
180+
181+
_Container.override(_OverridingContainer1)
182+
_Container.override(_OverridingContainer2)
183+
_Container.reset_last_overriding()
184+
185+
self.assertEqual(_Container.overridden_by,
186+
(_OverridingContainer1,))
187+
self.assertEqual(_Container.p11.overridden_by,
188+
(_OverridingContainer1.p11,))
189+
190+
def test_reset_override(self):
191+
"""Test reset all overridings."""
192+
class _Container(containers.DeclarativeContainer):
193+
p11 = providers.Provider()
194+
195+
class _OverridingContainer1(containers.DeclarativeContainer):
196+
p11 = providers.Provider()
197+
198+
class _OverridingContainer2(containers.DeclarativeContainer):
199+
p11 = providers.Provider()
200+
p12 = providers.Provider()
201+
202+
_Container.override(_OverridingContainer1)
203+
_Container.override(_OverridingContainer2)
204+
_Container.reset_override()
205+
206+
self.assertEqual(_Container.overridden_by, tuple())
207+
self.assertEqual(_Container.p11.overridden_by, tuple())
93208

94209
if __name__ == '__main__':
95210
unittest.main()

0 commit comments

Comments
 (0)