Skip to content

Commit

Permalink
Add some functionality and tests for declarative containers
Browse files Browse the repository at this point in the history
+ Add checks for valid provider type
+ Add some wider functionality for overriding
  • Loading branch information
rmk135 committed May 30, 2016
1 parent 68ae1b8 commit a35db58
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 21 deletions.
51 changes: 39 additions & 12 deletions dependency_injector/containers.py
Expand Up @@ -3,6 +3,7 @@
import six

from dependency_injector import (
providers,
utils,
errors,
)
Expand All @@ -18,7 +19,8 @@ def __new__(mcs, class_name, bases, attributes):
if utils.is_provider(provider))

inherited_providers = tuple((name, provider)
for base in bases if utils.is_catalog(base)
for base in bases if utils.is_container(
base)
for name, provider in six.iteritems(
base.cls_providers))

Expand All @@ -28,14 +30,8 @@ def __new__(mcs, class_name, bases, attributes):

cls = type.__new__(mcs, class_name, bases, attributes)

if cls.provider_type:
for provider in six.itervalues(cls.providers):
try:
assert isinstance(provider, cls.provider_type)
except AssertionError:
raise errors.Error('{0} can contain only {1} '
'instances'.format(cls,
cls.provider_type))
for provider in six.itervalues(cls.providers):
cls._check_provider_type(provider)

return cls

Expand All @@ -46,6 +42,7 @@ def __setattr__(cls, name, value):
dictionary.
"""
if utils.is_provider(value):
cls._check_provider_type(value)
cls.providers[name] = value
cls.cls_providers[name] = value
super(DeclarativeContainerMetaClass, cls).__setattr__(name, value)
Expand All @@ -61,14 +58,19 @@ def __delattr__(cls, name):
del cls.cls_providers[name]
super(DeclarativeContainerMetaClass, cls).__delattr__(name)

def _check_provider_type(cls, provider):
if not isinstance(provider, cls.provider_type):
raise errors.Error('{0} can contain only {1} '
'instances'.format(cls, cls.provider_type))


@six.add_metaclass(DeclarativeContainerMetaClass)
class DeclarativeContainer(object):
"""Declarative inversion of control container."""

__IS_CATALOG__ = True
__IS_CONTAINER__ = True

provider_type = None
provider_type = providers.Provider

providers = dict()
cls_providers = dict()
Expand All @@ -89,7 +91,7 @@ def override(cls, overriding):
:rtype: None
"""
if issubclass(cls, overriding):
raise errors.Error('Catalog {0} could not be overridden '
raise errors.Error('Container {0} could not be overridden '
'with itself or its subclasses'.format(cls))

cls.overridden_by += (overriding,)
Expand All @@ -100,6 +102,31 @@ def override(cls, overriding):
except AttributeError:
pass

@classmethod
def reset_last_overriding(cls):
"""Reset last overriding provider for each container providers.
:rtype: None
"""
if not cls.overridden_by:
raise errors.Error('Container {0} is not overridden'.format(cls))

cls.overridden_by = cls.overridden_by[:-1]

for provider in six.itervalues(cls.providers):
provider.reset_last_overriding()

@classmethod
def reset_override(cls):
"""Reset all overridings for each container providers.
:rtype: None
"""
cls.overridden_by = tuple()

for provider in six.itervalues(cls.providers):
provider.reset_override()


def override(container):
""":py:class:`DeclarativeContainer` overriding decorator.
Expand Down
10 changes: 4 additions & 6 deletions dependency_injector/providers/base.py
Expand Up @@ -64,7 +64,7 @@ class Provider(object):

def __init__(self):
"""Initializer."""
self.overridden_by = None
self.overridden_by = tuple()
super(Provider, self).__init__()
# Enable __call__() / _provide() optimization
if self.__class__.__OPTIMIZED_CALLS__:
Expand Down Expand Up @@ -124,10 +124,7 @@ def override(self, provider):
if not is_provider(provider):
provider = Object(provider)

if not self.is_overridden:
self.overridden_by = (ensure_is_provider(provider),)
else:
self.overridden_by += (ensure_is_provider(provider),)
self.overridden_by += (ensure_is_provider(provider),)

# Disable __call__() / _provide() optimization
if self.__class__.__OPTIMIZED_CALLS__:
Expand All @@ -145,6 +142,7 @@ def reset_last_overriding(self):
"""
if not self.overridden_by:
raise Error('Provider {0} is not overridden'.format(str(self)))

self.overridden_by = self.overridden_by[:-1]

if not self.is_overridden:
Expand All @@ -157,7 +155,7 @@ def reset_override(self):
:rtype: None
"""
self.overridden_by = None
self.overridden_by = tuple()

# Enable __call__() / _provide() optimization
if self.__class__.__OPTIMIZED_CALLS__:
Expand Down
12 changes: 12 additions & 0 deletions dependency_injector/utils.py
Expand Up @@ -59,6 +59,18 @@ def ensure_is_provider(instance):
return instance


def is_container(instance):
"""Check if instance is container instance.
:param instance: Instance to be checked.
:type instance: object
:rtype: bool
"""
return (hasattr(instance, '__IS_CONTAINER__') and
getattr(instance, '__IS_CONTAINER__', False) is True)


def is_catalog(instance):
"""Check if instance is catalog instance.
Expand Down
121 changes: 118 additions & 3 deletions tests/test_containers.py
Expand Up @@ -5,6 +5,7 @@
from dependency_injector import (
containers,
providers,
errors,
)


Expand All @@ -28,7 +29,7 @@ class ContainerB(ContainerA):
class DeclarativeContainerTests(unittest.TestCase):
"""Declarative container tests."""

def test_providers_attribute_with(self):
def test_providers_attribute(self):
"""Test providers attribute."""
self.assertEqual(ContainerA.providers, dict(p11=ContainerA.p11,
p12=ContainerA.p12))
Expand All @@ -37,7 +38,7 @@ def test_providers_attribute_with(self):
p21=ContainerB.p21,
p22=ContainerB.p22))

def test_cls_providers_attribute_with(self):
def test_cls_providers_attribute(self):
"""Test cls_providers attribute."""
self.assertEqual(ContainerA.cls_providers, dict(p11=ContainerA.p11,
p12=ContainerA.p12))
Expand All @@ -51,7 +52,7 @@ def test_inherited_providers_attribute(self):
dict(p11=ContainerA.p11,
p12=ContainerA.p12))

def test_set_get_del_provider_attribute(self):
def test_set_get_del_providers(self):
"""Test set/get/del provider attributes."""
a_p13 = providers.Provider()
b_p23 = providers.Provider()
Expand Down Expand Up @@ -90,6 +91,120 @@ def test_set_get_del_provider_attribute(self):
self.assertEqual(ContainerB.cls_providers, dict(p21=ContainerB.p21,
p22=ContainerB.p22))

def test_declare_with_valid_provider_type(self):
"""Test declaration of container with valid provider type."""
class _Container(containers.DeclarativeContainer):
provider_type = providers.Object
px = providers.Object(object())

self.assertIsInstance(_Container.px, providers.Object)

def test_declare_with_invalid_provider_type(self):
"""Test declaration of container with invalid provider type."""
with self.assertRaises(errors.Error):
class _Container(containers.DeclarativeContainer):
provider_type = providers.Object
px = providers.Provider()

def test_seth_valid_provider_type(self):
"""Test setting of valid provider."""
class _Container(containers.DeclarativeContainer):
provider_type = providers.Object

_Container.px = providers.Object(object())

self.assertIsInstance(_Container.px, providers.Object)

def test_set_invalid_provider_type(self):
"""Test setting of invalid provider."""
class _Container(containers.DeclarativeContainer):
provider_type = providers.Object

with self.assertRaises(errors.Error):
_Container.px = providers.Provider()

def test_override(self):
"""Test override."""
class _Container(containers.DeclarativeContainer):
p11 = providers.Provider()

class _OverridingContainer1(containers.DeclarativeContainer):
p11 = providers.Provider()

class _OverridingContainer2(containers.DeclarativeContainer):
p11 = providers.Provider()
p12 = providers.Provider()

_Container.override(_OverridingContainer1)
_Container.override(_OverridingContainer2)

self.assertEqual(_Container.overridden_by,
(_OverridingContainer1,
_OverridingContainer2))
self.assertEqual(_Container.p11.overridden_by,
(_OverridingContainer1.p11,
_OverridingContainer2.p11))

def test_override_decorator(self):
"""Test override decorator."""
class _Container(containers.DeclarativeContainer):
p11 = providers.Provider()

@containers.override(_Container)
class _OverridingContainer1(containers.DeclarativeContainer):
p11 = providers.Provider()

@containers.override(_Container)
class _OverridingContainer2(containers.DeclarativeContainer):
p11 = providers.Provider()
p12 = providers.Provider()

self.assertEqual(_Container.overridden_by,
(_OverridingContainer1,
_OverridingContainer2))
self.assertEqual(_Container.p11.overridden_by,
(_OverridingContainer1.p11,
_OverridingContainer2.p11))

def test_reset_last_overridding(self):
"""Test reset of last overriding."""
class _Container(containers.DeclarativeContainer):
p11 = providers.Provider()

class _OverridingContainer1(containers.DeclarativeContainer):
p11 = providers.Provider()

class _OverridingContainer2(containers.DeclarativeContainer):
p11 = providers.Provider()
p12 = providers.Provider()

_Container.override(_OverridingContainer1)
_Container.override(_OverridingContainer2)
_Container.reset_last_overriding()

self.assertEqual(_Container.overridden_by,
(_OverridingContainer1,))
self.assertEqual(_Container.p11.overridden_by,
(_OverridingContainer1.p11,))

def test_reset_override(self):
"""Test reset all overridings."""
class _Container(containers.DeclarativeContainer):
p11 = providers.Provider()

class _OverridingContainer1(containers.DeclarativeContainer):
p11 = providers.Provider()

class _OverridingContainer2(containers.DeclarativeContainer):
p11 = providers.Provider()
p12 = providers.Provider()

_Container.override(_OverridingContainer1)
_Container.override(_OverridingContainer2)
_Container.reset_override()

self.assertEqual(_Container.overridden_by, tuple())
self.assertEqual(_Container.p11.overridden_by, tuple())

if __name__ == '__main__':
unittest.main()

0 comments on commit a35db58

Please sign in to comment.