diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 0e56495e0..e21afe276 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -1,6 +1,6 @@ .. testsetup:: * - from omegaconf import OmegaConf, DictConfig, open_dict, read_write + from omegaconf import OmegaConf, DictConfig, ListConfig, open_dict, read_write import os import sys import tempfile @@ -466,6 +466,38 @@ This can be useful for instance to parse environment variables: type: int, value: 3308 +Extracting lists of keys / values from a dictionary +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Some config options that are stored as a ``DictConfig`` may sometimes be easier to manipulate as lists, +when we care only about the keys or the associated values. + +The resolvers ``oc.dict.keys`` and ``oc.dict.values`` simplify such operations by offering an alternative +view of a dictionary's keys or values as a list. +They take as input a string that is the path to another config node (using the same syntax +as interpolations) and return a ``ListConfig`` with its keys / values. + +.. doctest:: + + >>> cfg = OmegaConf.create( + ... { + ... "workers": { + ... "node3": "10.0.0.2", + ... "node7": "10.0.0.9", + ... }, + ... "nodes": "${oc.dict.keys: workers}", + ... "ips": "${oc.dict.values: workers}", + ... } + ... ) + >>> # Keys are copied from the DictConfig: + >>> show(cfg.nodes) + type: ListConfig, value: ['node3', 'node7'] + >>> # Values are dynamically fetched through interpolations: + >>> show(cfg.ips) + type: ListConfig, value: ['${workers.node3}', '${workers.node7}'] + >>> assert cfg.ips == ["10.0.0.2", "10.0.0.9"] + + Custom interpolations ^^^^^^^^^^^^^^^^^^^^^ diff --git a/news/643.feature b/news/643.feature new file mode 100644 index 000000000..a7fd0f083 --- /dev/null +++ b/news/643.feature @@ -0,0 +1 @@ +New resolvers `oc.dict.keys` and `oc.dict.values` provide a list view of the keys or values of a DictConfig node. diff --git a/omegaconf/base.py b/omegaconf/base.py index 33bc9f996..1a9082525 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -451,6 +451,7 @@ def _resolve_interpolation_from_parse_tree( node that is created to wrap the interpolated value. It is `None` if and only if `throw_on_resolution_failure` is `False` and an error occurs during resolution. """ + try: resolved = self.resolve_parse_tree( parse_tree=parse_tree, diff --git a/omegaconf/built_in_resolvers.py b/omegaconf/built_in_resolvers.py index 373c01a4d..885d6e4b3 100644 --- a/omegaconf/built_in_resolvers.py +++ b/omegaconf/built_in_resolvers.py @@ -1,11 +1,20 @@ import os import warnings -from typing import Any, Optional +from typing import Any, List, Optional -from ._utils import _DEFAULT_MARKER_, _get_value, decode_primitive +from ._utils import _DEFAULT_MARKER_, Marker, _get_value, decode_primitive from .base import Container -from .errors import ValidationError +from .basecontainer import BaseContainer +from .dictconfig import DictConfig +from .errors import ConfigKeyError, ValidationError from .grammar_parser import parse +from .listconfig import ListConfig +from .nodes import AnyNode +from .omegaconf import OmegaConf + +# Special marker use as default value when calling `OmegaConf.select()`. It must be +# different from `_DEFAULT_MARKER_`, which is used by `OmegaConf.select()`. +_DEFAULT_SELECT_MARKER_: Any = Marker("_DEFAULT_SELECT_MARKER_") def decode(expr: Optional[str], _parent_: Container) -> Any: @@ -28,6 +37,44 @@ def decode(expr: Optional[str], _parent_: Container) -> Any: return _get_value(val) +def dict_keys( + key: str, + _parent_: Container, +) -> ListConfig: + assert isinstance(_parent_, BaseContainer) + + in_dict = _get_and_validate_dict_input( + key, parent=_parent_, resolver_name="oc.dict.keys" + ) + + ret = OmegaConf.create(list(in_dict.keys()), parent=_parent_) + assert isinstance(ret, ListConfig) + return ret + + +def dict_values(key: str, _root_: BaseContainer, _parent_: Container) -> ListConfig: + assert isinstance(_parent_, BaseContainer) + in_dict = _get_and_validate_dict_input( + key, parent=_parent_, resolver_name="oc.dict.values" + ) + + content = in_dict._content + assert isinstance(content, dict) + + ret = ListConfig([]) + for k in content: + ref_node = AnyNode(f"${{{key}.{k}}}") + ret.append(ref_node) + + # Finalize result by setting proper type and parent. + element_type: Any = in_dict._metadata.element_type + ret._metadata.element_type = element_type + ret._metadata.ref_type = List[element_type] + ret._set_parent(_parent_) + + return ret + + def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]: """ :param key: Environment variable key @@ -60,3 +107,34 @@ def legacy_env(key: str, default: Optional[str] = None) -> Any: return decode_primitive(default) else: raise ValidationError(f"Environment variable '{key}' not found") + + +def _get_and_validate_dict_input( + key: str, + parent: BaseContainer, + resolver_name: str, +) -> DictConfig: + if not isinstance(key, str): + raise TypeError( + f"`{resolver_name}` requires a string as input, but obtained `{key}` " + f"of type: {type(key).__name__}" + ) + + in_dict = OmegaConf.select( + parent, + key, + throw_on_missing=True, + absolute_key=True, + default=_DEFAULT_SELECT_MARKER_, + ) + + if in_dict is _DEFAULT_SELECT_MARKER_: + raise ConfigKeyError(f"Key not found: '{key}'") + + if not isinstance(in_dict, DictConfig): + raise TypeError( + f"`{resolver_name}` cannot be applied to objects of type: " + f"{type(in_dict).__name__}" + ) + + return in_dict diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 731da2f68..b2390c193 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -93,11 +93,14 @@ def SI(interpolation: str) -> Any: def register_default_resolvers() -> None: - from .built_in_resolvers import decode, env, legacy_env + from .built_in_resolvers import decode, dict_keys, dict_values, env, legacy_env + + OmegaConf.register_new_resolver("oc.decode", decode) + OmegaConf.register_new_resolver("oc.dict.keys", dict_keys) + OmegaConf.register_new_resolver("oc.dict.values", dict_values) + OmegaConf.register_new_resolver("oc.env", env) OmegaConf.legacy_register_resolver("env", legacy_env) - OmegaConf.register_new_resolver("oc.env", env, use_cache=False) - OmegaConf.register_new_resolver("oc.decode", decode, use_cache=False) class OmegaConf: diff --git a/tests/interpolation/built_in_resolvers/test_dict.py b/tests/interpolation/built_in_resolvers/test_dict.py new file mode 100644 index 000000000..d96cfa92a --- /dev/null +++ b/tests/interpolation/built_in_resolvers/test_dict.py @@ -0,0 +1,353 @@ +import re +from typing import Any, List + +from pytest import mark, param, raises + +from omegaconf import DictConfig, ListConfig, OmegaConf +from omegaconf.built_in_resolvers import _get_and_validate_dict_input +from omegaconf.errors import ( + InterpolationResolutionError, + InterpolationToMissingValueError, +) +from tests import User, Users + + +@mark.parametrize( + ("cfg", "key", "expected"), + [ + param( + {"foo": "${oc.dict.keys:bar}", "bar": {"a": 0, "b": 1}}, + "foo", + ["a", "b"], + id="dictconfig", + ), + param( + {"foo": "${oc.dict.keys:bar}", "bar": "${boz}", "boz": {"a": 0, "b": 1}}, + "foo", + ["a", "b"], + id="dictconfig_chained_interpolation", + ), + ], +) +def test_dict_keys(cfg: Any, key: Any, expected: Any) -> None: + cfg = OmegaConf.create(cfg) + val = cfg[key] + assert val == expected + assert isinstance(val, ListConfig) + assert val._parent is cfg + + +@mark.parametrize( + ("cfg", "key", "expected"), + [ + param( + {"x": "${oc.dict.keys_or_values:y}", "y": "???"}, + "x", + raises( + InterpolationResolutionError, + match=re.escape( + "MissingMandatoryValue raised while resolving interpolation: " + "Missing mandatory value: y" + ), + ), + id="select_missing", + ), + param( + {"foo": "${oc.dict.keys_or_values:bar}"}, + "foo", + raises( + InterpolationResolutionError, + match=re.escape( + "ConfigKeyError raised while resolving interpolation: " + "Key not found: 'bar'" + ), + ), + id="config_key_error", + ), + param( + # This might be allowed in the future. Currently it fails. + {"foo": "${oc.dict.keys_or_values:''}"}, + "foo", + raises( + InterpolationResolutionError, + match=re.escape( + "ConfigKeyError raised while resolving interpolation: " + "Key not found: ''" + ), + ), + id="config_key_error_empty", + ), + param( + {"foo": "${oc.dict.keys_or_values:bar}", "bar": 0}, + "foo", + raises( + InterpolationResolutionError, + match=re.escape( + "TypeError raised while resolving interpolation: " + "`oc.dict.keys_or_values` cannot be applied to objects of type: int" + ), + ), + id="type_error", + ), + param( + {"foo": "${oc.dict.keys_or_values:bar}", "bar": DictConfig(None)}, + "foo", + raises( + InterpolationResolutionError, + match=re.escape( + "TypeError raised while resolving interpolation: " + "`oc.dict.keys_or_values` cannot be applied to objects of type: NoneType" + ), + ), + id="type_error_dictconfig", + ), + ], +) +def test_get_and_validate_dict_input( + restore_resolvers: Any, cfg: Any, key: Any, expected: Any +) -> None: + OmegaConf.register_new_resolver( + "oc.dict.keys_or_values", + lambda in_dict, _parent_: _get_and_validate_dict_input( + in_dict, parent=_parent_, resolver_name="oc.dict.keys_or_values" + ), + ) + cfg = OmegaConf.create(cfg) + with expected: + cfg[key] + + +@mark.parametrize( + ("cfg", "key", "expected_val", "expected_content"), + [ + param( + {"foo": "${oc.dict.values:bar}", "bar": {"a": 0, "b": 1}}, + "foo", + [0, 1], + ["${bar.a}", "${bar.b}"], + id="dictconfig", + ), + param( + { + "foo": "${oc.dict.values:bar}", + "bar": {"a": {"x": 0, "y": 1}, "b": {"x": 0}}, + }, + "foo", + [{"x": 0, "y": 1}, {"x": 0}], + ["${bar.a}", "${bar.b}"], + id="dictconfig_deep", + ), + param( + { + "foo": "${oc.dict.values:bar}", + "bar": {"key": "${val_ref}"}, + "val_ref": "value", + }, + "foo", + ["value"], + ["${bar.key}"], + id="dictconfig_with_interpolated_value", + ), + param( + { + "foo": "${oc.dict.values:bar}", + "bar": "${boz}", + "boz": {"a": 0, "b": 1}, + }, + "foo", + [0, 1], + ["${bar.a}", "${bar.b}"], + id="dictconfig_chained_interpolation", + ), + ], +) +def test_dict_values( + cfg: Any, key: Any, expected_val: Any, expected_content: Any +) -> None: + + cfg = OmegaConf.create(cfg) + val = cfg[key] + assert val == expected_val + assert isinstance(val, ListConfig) + assert val._parent is cfg + content = val._content + assert content == expected_content + + +def test_dict_values_with_missing_value() -> None: + cfg = OmegaConf.create({"foo": "${oc.dict.values:bar}", "bar": {"missing": "???"}}) + foo = cfg.foo + with raises(InterpolationToMissingValueError): + foo[0] + cfg.bar.missing = 1 + assert foo[0] == 1 + + +@mark.parametrize( + ("make_resolver", "key", "expected"), + [ + param( + lambda _parent_: OmegaConf.create({"x": 1}, parent=_parent_), + 0, + 1, + id="basic", + ), + param( + lambda _parent_: OmegaConf.create({"x": "${y}", "y": 1}, parent=_parent_), + 0, + 999, # referring to the `y` node from the global config + id="inter_abs", + ), + param( + lambda _parent_: OmegaConf.create({"x": "${.y}", "y": 1}, parent=_parent_), + 0, + 1, + id="inter_rel", + ), + param(lambda: OmegaConf.create({"x": 1}), 0, 1, id="basic_no_parent"), + param( + lambda: OmegaConf.create({"x": "${y}", "y": 1}), + 0, + 1, # no parent => referring to the local `y` node in the generated config + id="inter_abs_no_parent", + ), + param( + lambda: OmegaConf.create({"x": "${.y}", "y": 1}), + 0, + 1, + id="inter_rel_no_parent", + ), + ], +) +def test_dict_values_dictconfig_resolver_output( + restore_resolvers: Any, make_resolver: Any, key: Any, expected: Any +) -> None: + OmegaConf.register_new_resolver("make", make_resolver) + cfg = OmegaConf.create( + { + "foo": "${oc.dict.values:bar}", + "bar": "${make:}", + "y": 999, + } + ) + + assert cfg.foo[key] == expected + + +@mark.parametrize( + ("make_resolver", "expected_value", "expected_content"), + [ + param( + lambda _parent_: OmegaConf.create({"a": 0, "b": 1}, parent=_parent_), + [0, 1], + ["${y.a}", "${y.b}"], + id="dictconfig_with_parent", + ), + param( + lambda: {"a": 0, "b": 1}, + [0, 1], + ["${y.a}", "${y.b}"], + id="plain_dict", + ), + ], +) +def test_dict_values_transient_interpolation( + restore_resolvers: Any, + make_resolver: Any, + expected_value: Any, + expected_content: Any, +) -> None: + OmegaConf.register_new_resolver("make", make_resolver) + cfg = OmegaConf.create({"x": "${oc.dict.values:y}", "y": "${make:}"}) + assert cfg.x == expected_value + assert cfg.x._content == expected_content + + +def test_dict_values_are_typed() -> None: + cfg = OmegaConf.create( + { + "x": "${oc.dict.values: y.name2user}", + "y": Users( + name2user={ + "john": User(name="john", age=30), + "jane": User(name="jane", age=33), + } + ), + } + ) + x = cfg.x + assert x._metadata.ref_type == List[User] + assert x._metadata.element_type == User + + +@mark.parametrize( + ("cfg", "expected"), + [ + param({"x": "${oc.dict.values:y}", "y": {"a": 1}}, [1], id="values_inter"), + param({"x": "${oc.dict.keys:y}", "y": {"a": 1}}, ["a"], id="keys_inter"), + ], +) +def test_readonly_parent(cfg: Any, expected: Any) -> None: + cfg = OmegaConf.create(cfg) + cfg._set_flag("readonly", True) + assert cfg.x == expected + + +@mark.parametrize( + ("cfg", "expected"), + [ + param( + {"x": "${sum:${oc.dict.values:y}}", "y": {"one": 1, "two": 2}}, + 3, + id="values", + ), + param( + {"x": "${sum:${oc.dict.keys:y}}", "y": {1: "one", 2: "two"}}, + 3, + id="keys", + ), + ], +) +def test_nested_oc_dict(restore_resolvers: Any, cfg: Any, expected: Any) -> None: + OmegaConf.register_new_resolver("sum", sum) + cfg = OmegaConf.create(cfg) + assert cfg.x == expected + + +@mark.parametrize( + "cfg", + [ + param({"x": "${oc.dict.keys:[]}"}, id="list"), + param({"x": "${oc.dict.keys:${bool}}", "bool": True}, id="bool_interpolation"), + param({"x": "${oc.dict.keys:int}", "int": 0}, id="int_select"), + ], +) +def test_dict_keys_invalid_type(cfg: Any) -> None: + cfg = OmegaConf.create(cfg) + with raises(InterpolationResolutionError, match="TypeError"): + cfg.x + + +@mark.parametrize( + "cfg", + [ + param({"x": "${oc.dict.values:[]}"}, id="list"), + param( + {"x": "${oc.dict.values:${bool}}", "bool": True}, id="bool_interpolation" + ), + param({"x": "${oc.dict.values:int}", "int": 0}, id="int_select"), + ], +) +def test_dict_values_invalid_type(cfg: Any) -> None: + cfg = OmegaConf.create(cfg) + with raises(InterpolationResolutionError, match="TypeError"): + cfg.x + + +def test_dict_values_of_root(restore_resolvers: Any) -> None: + cfg = OmegaConf.create({"x": {"a": "${oc.dict.values:z}"}, "y": 0, "z": "${}"}) + a = cfg.x.a + assert a._content == ["${z.x}", "${z.y}", "${z.z}"] + assert a[1] == 0 + # We can recurse indefinitely within the first value if we fancy it. + assert cfg.x.a[0].a[0].a[1] == 0 diff --git a/tests/interpolation/test_custom_resolvers.py b/tests/interpolation/test_custom_resolvers.py index 3e3d9e351..42c7ce0ec 100644 --- a/tests/interpolation/test_custom_resolvers.py +++ b/tests/interpolation/test_custom_resolvers.py @@ -2,7 +2,7 @@ import re from typing import Any, Dict, List -from pytest import mark, raises, warns +from pytest import mark, param, raises, warns from omegaconf import DictConfig, ListConfig, OmegaConf, Resolver from tests.interpolation import dereference_node @@ -466,3 +466,25 @@ def parent_and_default(default: int = 10, *, _parent_: Any) -> Any: assert cfg.no_param == 20 assert cfg.param == 30 + + +@mark.parametrize( + ("cfg2", "expected"), + [ + param({"foo": {"b": 1}}, {"foo": {"a": 0, "b": 1}}, id="extend"), + param({"foo": {"b": "${.a}"}}, {"foo": {"a": 0, "b": 0}}, id="extend_inter"), + param({"foo": {"a": 1}}, {"foo": {"a": 1}}, id="override_int"), + param({"foo": {"a": {"b": 1}}}, {"foo": {"a": {"b": 1}}}, id="override_dict"), + param({"foo": 10}, {"foo": 10}, id="replace_interpolation"), + param({"bar": 10}, {"foo": {"a": 0}, "bar": 10}, id="other_node"), + ], +) +def test_merge_into_resolver_output( + restore_resolvers: Any, cfg2: Any, expected: Any +) -> None: + OmegaConf.register_new_resolver( + "make", lambda _parent_: OmegaConf.create({"a": 0}, parent=_parent_) + ) + + cfg = OmegaConf.create({"foo": "${make:}"}) + assert OmegaConf.merge(cfg, cfg2) == expected