5
5
from dependency_injector import (
6
6
containers ,
7
7
providers ,
8
+ errors ,
8
9
)
9
10
10
11
@@ -28,7 +29,7 @@ class ContainerB(ContainerA):
28
29
class DeclarativeContainerTests (unittest .TestCase ):
29
30
"""Declarative container tests."""
30
31
31
- def test_providers_attribute_with (self ):
32
+ def test_providers_attribute (self ):
32
33
"""Test providers attribute."""
33
34
self .assertEqual (ContainerA .providers , dict (p11 = ContainerA .p11 ,
34
35
p12 = ContainerA .p12 ))
@@ -37,7 +38,7 @@ def test_providers_attribute_with(self):
37
38
p21 = ContainerB .p21 ,
38
39
p22 = ContainerB .p22 ))
39
40
40
- def test_cls_providers_attribute_with (self ):
41
+ def test_cls_providers_attribute (self ):
41
42
"""Test cls_providers attribute."""
42
43
self .assertEqual (ContainerA .cls_providers , dict (p11 = ContainerA .p11 ,
43
44
p12 = ContainerA .p12 ))
@@ -51,7 +52,7 @@ def test_inherited_providers_attribute(self):
51
52
dict (p11 = ContainerA .p11 ,
52
53
p12 = ContainerA .p12 ))
53
54
54
- def test_set_get_del_provider_attribute (self ):
55
+ def test_set_get_del_providers (self ):
55
56
"""Test set/get/del provider attributes."""
56
57
a_p13 = providers .Provider ()
57
58
b_p23 = providers .Provider ()
@@ -90,6 +91,120 @@ def test_set_get_del_provider_attribute(self):
90
91
self .assertEqual (ContainerB .cls_providers , dict (p21 = ContainerB .p21 ,
91
92
p22 = ContainerB .p22 ))
92
93
94
+ def test_declare_with_valid_provider_type (self ):
95
+ """Test declaration of container with valid provider type."""
96
+ class _Container (containers .DeclarativeContainer ):
97
+ provider_type = providers .Object
98
+ px = providers .Object (object ())
99
+
100
+ self .assertIsInstance (_Container .px , providers .Object )
101
+
102
+ def test_declare_with_invalid_provider_type (self ):
103
+ """Test declaration of container with invalid provider type."""
104
+ with self .assertRaises (errors .Error ):
105
+ class _Container (containers .DeclarativeContainer ):
106
+ provider_type = providers .Object
107
+ px = providers .Provider ()
108
+
109
+ def test_seth_valid_provider_type (self ):
110
+ """Test setting of valid provider."""
111
+ class _Container (containers .DeclarativeContainer ):
112
+ provider_type = providers .Object
113
+
114
+ _Container .px = providers .Object (object ())
115
+
116
+ self .assertIsInstance (_Container .px , providers .Object )
117
+
118
+ def test_set_invalid_provider_type (self ):
119
+ """Test setting of invalid provider."""
120
+ class _Container (containers .DeclarativeContainer ):
121
+ provider_type = providers .Object
122
+
123
+ with self .assertRaises (errors .Error ):
124
+ _Container .px = providers .Provider ()
125
+
126
+ def test_override (self ):
127
+ """Test override."""
128
+ class _Container (containers .DeclarativeContainer ):
129
+ p11 = providers .Provider ()
130
+
131
+ class _OverridingContainer1 (containers .DeclarativeContainer ):
132
+ p11 = providers .Provider ()
133
+
134
+ class _OverridingContainer2 (containers .DeclarativeContainer ):
135
+ p11 = providers .Provider ()
136
+ p12 = providers .Provider ()
137
+
138
+ _Container .override (_OverridingContainer1 )
139
+ _Container .override (_OverridingContainer2 )
140
+
141
+ self .assertEqual (_Container .overridden_by ,
142
+ (_OverridingContainer1 ,
143
+ _OverridingContainer2 ))
144
+ self .assertEqual (_Container .p11 .overridden_by ,
145
+ (_OverridingContainer1 .p11 ,
146
+ _OverridingContainer2 .p11 ))
147
+
148
+ def test_override_decorator (self ):
149
+ """Test override decorator."""
150
+ class _Container (containers .DeclarativeContainer ):
151
+ p11 = providers .Provider ()
152
+
153
+ @containers .override (_Container )
154
+ class _OverridingContainer1 (containers .DeclarativeContainer ):
155
+ p11 = providers .Provider ()
156
+
157
+ @containers .override (_Container )
158
+ class _OverridingContainer2 (containers .DeclarativeContainer ):
159
+ p11 = providers .Provider ()
160
+ p12 = providers .Provider ()
161
+
162
+ self .assertEqual (_Container .overridden_by ,
163
+ (_OverridingContainer1 ,
164
+ _OverridingContainer2 ))
165
+ self .assertEqual (_Container .p11 .overridden_by ,
166
+ (_OverridingContainer1 .p11 ,
167
+ _OverridingContainer2 .p11 ))
168
+
169
+ def test_reset_last_overridding (self ):
170
+ """Test reset of last overriding."""
171
+ class _Container (containers .DeclarativeContainer ):
172
+ p11 = providers .Provider ()
173
+
174
+ class _OverridingContainer1 (containers .DeclarativeContainer ):
175
+ p11 = providers .Provider ()
176
+
177
+ class _OverridingContainer2 (containers .DeclarativeContainer ):
178
+ p11 = providers .Provider ()
179
+ p12 = providers .Provider ()
180
+
181
+ _Container .override (_OverridingContainer1 )
182
+ _Container .override (_OverridingContainer2 )
183
+ _Container .reset_last_overriding ()
184
+
185
+ self .assertEqual (_Container .overridden_by ,
186
+ (_OverridingContainer1 ,))
187
+ self .assertEqual (_Container .p11 .overridden_by ,
188
+ (_OverridingContainer1 .p11 ,))
189
+
190
+ def test_reset_override (self ):
191
+ """Test reset all overridings."""
192
+ class _Container (containers .DeclarativeContainer ):
193
+ p11 = providers .Provider ()
194
+
195
+ class _OverridingContainer1 (containers .DeclarativeContainer ):
196
+ p11 = providers .Provider ()
197
+
198
+ class _OverridingContainer2 (containers .DeclarativeContainer ):
199
+ p11 = providers .Provider ()
200
+ p12 = providers .Provider ()
201
+
202
+ _Container .override (_OverridingContainer1 )
203
+ _Container .override (_OverridingContainer2 )
204
+ _Container .reset_override ()
205
+
206
+ self .assertEqual (_Container .overridden_by , tuple ())
207
+ self .assertEqual (_Container .p11 .overridden_by , tuple ())
93
208
94
209
if __name__ == '__main__' :
95
210
unittest .main ()
0 commit comments