diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 9cb310c84..f9f9aff6c 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -1,7 +1,6 @@ from .config import Config +from .errors import ReadonlyConfigError, MissingMandatoryValue from .nodes import BaseNode, UntypedNode -from .errors import ReadonlyConfigError -import copy class DictConfig(Config): @@ -98,6 +97,22 @@ def pop(self, key, default=__marker): def keys(self): return self.content.keys() + def __contains__(self, key): + """ + A key is contained in a DictConfig if there is an associated value and it is not a mandatory missing value ('???'). + :param key: + :return: + """ + node = self.get_node(key) + if node is None: + return False + else: + try: + self._resolve_with_default(key, node, None) + return True + except (MissingMandatoryValue, KeyError): + return False + def __iter__(self): return iter(self.keys()) diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index f99e03675..c5c154d99 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -207,15 +207,19 @@ def test_dict_pop(): c.pop('not_found') -def test_in_dict(): - c = OmegaConf.create(dict( - a=1, - b=2, - c={})) - assert 'a' in c - assert 'b' in c - assert 'c' in c - assert 'd' not in c +@pytest.mark.parametrize("conf,key,expected", [ + ({"a": 1, "b": {}}, "a", True), + ({"a": 1, "b": {}}, "b", True), + ({"a": 1, "b": {}}, "c", False), + ({"a": 1, "b": "${a}"}, "b", True), + ({"a": 1, "b": "???"}, "b", False), + ({"a": 1, "b": "???", "c": "${b}"}, "c", False), + ({"a": 1, "b": "${not_found}"}, "b", False), +]) +def test_in_dict(conf, key, expected): + conf = OmegaConf.create(conf) + ret = key in conf + assert ret == expected def test_get_root():