diff --git a/staticconf/testing.py b/staticconf/testing.py index 4a3bcc6..e85a594 100644 --- a/staticconf/testing.py +++ b/staticconf/testing.py @@ -1,11 +1,13 @@ """ Facilitate testing of code which uses staticconf. """ +import copy + from staticconf import config, loader class MockConfiguration(object): - """A a context manager which patches the configuration namespace + """A context manager which replaces the configuration namespace while inside the context. When the context exits the old configuration values will be restored to that namespace. @@ -56,3 +58,35 @@ def __enter__(self): def __exit__(self, *args): self.teardown() + + +class PatchConfiguration(MockConfiguration): + """A context manager which updates the configuration namespace while inside + the context. When the context exits the old configuration values will be + restored to that namespace. + + Unlike MockConfiguration which completely replaces the configuration with + the new one, this class instead only updates the keys in the configuration + which are passed to it. It preserves all previous values that weren't + updated. + + .. code-block:: python + + import staticconf.testing + + config = { + ... + } + with staticconf.testing.PatchConfiguration(config, namespace='special'): + # Run your tests. + ... + + The arguments are identical to MockConfiguration. + """ + + def setup(self): + self.old_values = copy.deepcopy(dict(self.namespace.get_config_values())) + new_configuration = copy.deepcopy(self.old_values) + new_configuration.update(self.config_data) + self.reset_namespace(new_configuration) + config.reload(name=self.namespace.name) diff --git a/tests/testing_test.py b/tests/testing_test.py index be9fd41..8775c43 100644 --- a/tests/testing_test.py +++ b/tests/testing_test.py @@ -22,3 +22,20 @@ def test_init_nested(self): with testing.MockConfiguration(conf): assert_equal(staticconf.get('a.b'), 'two') assert_equal(staticconf.get('c'), 'three') + + +class TestPatchConfiguration(object): + + def test_nested(self): + with testing.MockConfiguration(a='one', b='two'): + with testing.PatchConfiguration(a='three'): + assert_equal(staticconf.get('a'), 'three') + assert_equal(staticconf.get('b'), 'two') + + assert_equal(staticconf.get('a'), 'one') + assert_equal(staticconf.get('b'), 'two') + + def test_not_nested(self): + with testing.PatchConfiguration(a='one', b='two'): + assert_equal(staticconf.get('a'), 'one') + assert_equal(staticconf.get('b'), 'two')