From 8972c91577b6c5c205c32778842b169bbe37013d Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Sun, 11 Apr 2021 02:49:04 -0400 Subject: [PATCH] Add new resolvers `oc.dict.keys` and `oc.dict.values` (#644) --- docs/source/usage.rst | 34 +- news/643.feature | 1 + omegaconf/base.py | 1 + omegaconf/built_in_resolvers.py | 140 +++++++ omegaconf/omegaconf.py | 58 +-- .../built_in_resolvers/test_dict.py | 353 ++++++++++++++++++ tests/interpolation/test_custom_resolvers.py | 24 +- 7 files changed, 556 insertions(+), 55 deletions(-) create mode 100644 news/643.feature create mode 100644 omegaconf/built_in_resolvers.py create mode 100644 tests/interpolation/built_in_resolvers/test_dict.py diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 0e56495e0..e21afe276 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -1,6 +1,6 @@ .. testsetup:: * - from omegaconf import OmegaConf, DictConfig, open_dict, read_write + from omegaconf import OmegaConf, DictConfig, ListConfig, open_dict, read_write import os import sys import tempfile @@ -466,6 +466,38 @@ This can be useful for instance to parse environment variables: type: int, value: 3308 +Extracting lists of keys / values from a dictionary +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Some config options that are stored as a ``DictConfig`` may sometimes be easier to manipulate as lists, +when we care only about the keys or the associated values. + +The resolvers ``oc.dict.keys`` and ``oc.dict.values`` simplify such operations by offering an alternative +view of a dictionary's keys or values as a list. +They take as input a string that is the path to another config node (using the same syntax +as interpolations) and return a ``ListConfig`` with its keys / values. + +.. doctest:: + + >>> cfg = OmegaConf.create( + ... { + ... "workers": { + ... "node3": "10.0.0.2", + ... "node7": "10.0.0.9", + ... }, + ... "nodes": "${oc.dict.keys: workers}", + ... 
"ips": "${oc.dict.values: workers}", + ... } + ... ) + >>> # Keys are copied from the DictConfig: + >>> show(cfg.nodes) + type: ListConfig, value: ['node3', 'node7'] + >>> # Values are dynamically fetched through interpolations: + >>> show(cfg.ips) + type: ListConfig, value: ['${workers.node3}', '${workers.node7}'] + >>> assert cfg.ips == ["10.0.0.2", "10.0.0.9"] + + Custom interpolations ^^^^^^^^^^^^^^^^^^^^^ diff --git a/news/643.feature b/news/643.feature new file mode 100644 index 000000000..a7fd0f083 --- /dev/null +++ b/news/643.feature @@ -0,0 +1 @@ +New resolvers `oc.dict.keys` and `oc.dict.values` provide a list view of the keys or values of a DictConfig node. diff --git a/omegaconf/base.py b/omegaconf/base.py index 33bc9f996..1a9082525 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -451,6 +451,7 @@ def _resolve_interpolation_from_parse_tree( node that is created to wrap the interpolated value. It is `None` if and only if `throw_on_resolution_failure` is `False` and an error occurs during resolution. """ + try: resolved = self.resolve_parse_tree( parse_tree=parse_tree, diff --git a/omegaconf/built_in_resolvers.py b/omegaconf/built_in_resolvers.py new file mode 100644 index 000000000..885d6e4b3 --- /dev/null +++ b/omegaconf/built_in_resolvers.py @@ -0,0 +1,140 @@ +import os +import warnings +from typing import Any, List, Optional + +from ._utils import _DEFAULT_MARKER_, Marker, _get_value, decode_primitive +from .base import Container +from .basecontainer import BaseContainer +from .dictconfig import DictConfig +from .errors import ConfigKeyError, ValidationError +from .grammar_parser import parse +from .listconfig import ListConfig +from .nodes import AnyNode +from .omegaconf import OmegaConf + +# Special marker used as default value when calling `OmegaConf.select()`. It must be +# different from `_DEFAULT_MARKER_`, which is used by `OmegaConf.select()`. 
+_DEFAULT_SELECT_MARKER_: Any = Marker("_DEFAULT_SELECT_MARKER_") + + +def decode(expr: Optional[str], _parent_: Container) -> Any: + """ + Parse and evaluate `expr` according to the `singleElement` rule of the grammar. + + If `expr` is `None`, then return `None`. + """ + if expr is None: + return None + + if not isinstance(expr, str): + raise TypeError( + f"`oc.decode` can only take strings or None as input, " + f"but `{expr}` is of type {type(expr).__name__}" + ) + + parse_tree = parse(expr, parser_rule="singleElement", lexer_mode="VALUE_MODE") + val = _parent_.resolve_parse_tree(parse_tree) + return _get_value(val) + + +def dict_keys( + key: str, + _parent_: Container, +) -> ListConfig: + assert isinstance(_parent_, BaseContainer) + + in_dict = _get_and_validate_dict_input( + key, parent=_parent_, resolver_name="oc.dict.keys" + ) + + ret = OmegaConf.create(list(in_dict.keys()), parent=_parent_) + assert isinstance(ret, ListConfig) + return ret + + +def dict_values(key: str, _root_: BaseContainer, _parent_: Container) -> ListConfig: + assert isinstance(_parent_, BaseContainer) + in_dict = _get_and_validate_dict_input( + key, parent=_parent_, resolver_name="oc.dict.values" + ) + + content = in_dict._content + assert isinstance(content, dict) + + ret = ListConfig([]) + for k in content: + ref_node = AnyNode(f"${{{key}.{k}}}") + ret.append(ref_node) + + # Finalize result by setting proper type and parent. + element_type: Any = in_dict._metadata.element_type + ret._metadata.element_type = element_type + ret._metadata.ref_type = List[element_type] + ret._set_parent(_parent_) + + return ret + + +def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]: + """ + :param key: Environment variable key + :param default: Optional default value to use in case the key environment variable is not set. + If default is not a string, it is converted with str(default). + None default is returned as is. + :return: The environment variable 'key'. 
If the environment variable is not set and a default is + provided, the default is used. If used, the default is converted to a string with str(default). + If the default is None, None is returned (without a string conversion). + """ + try: + return os.environ[key] + except KeyError: + if default is not _DEFAULT_MARKER_: + return str(default) if default is not None else None + else: + raise KeyError(f"Environment variable '{key}' not found") + + +# DEPRECATED: remove in 2.2 +def legacy_env(key: str, default: Optional[str] = None) -> Any: + warnings.warn( + "The `env` resolver is deprecated, see https://github.com/omry/omegaconf/issues/573" + ) + + try: + return decode_primitive(os.environ[key]) + except KeyError: + if default is not None: + return decode_primitive(default) + else: + raise ValidationError(f"Environment variable '{key}' not found") + + +def _get_and_validate_dict_input( + key: str, + parent: BaseContainer, + resolver_name: str, +) -> DictConfig: + if not isinstance(key, str): + raise TypeError( + f"`{resolver_name}` requires a string as input, but obtained `{key}` " + f"of type: {type(key).__name__}" + ) + + in_dict = OmegaConf.select( + parent, + key, + throw_on_missing=True, + absolute_key=True, + default=_DEFAULT_SELECT_MARKER_, + ) + + if in_dict is _DEFAULT_SELECT_MARKER_: + raise ConfigKeyError(f"Key not found: '{key}'") + + if not isinstance(in_dict, DictConfig): + raise TypeError( + f"`{resolver_name}` cannot be applied to objects of type: " + f"{type(in_dict).__name__}" + ) + + return in_dict diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index e32059cda..b2390c193 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -32,7 +32,6 @@ _ensure_container, _get_value, _is_none, - decode_primitive, format_and_raise, get_dict_key_value_types, get_list_element_type, @@ -60,7 +59,6 @@ UnsupportedInterpolationType, ValidationError, ) -from .grammar_parser import parse from .nodes import ( AnyNode, BooleanNode, @@ -95,60 
+93,14 @@ def SI(interpolation: str) -> Any: def register_default_resolvers() -> None: - # DEPRECATED: remove in 2.2 - def legacy_env(key: str, default: Optional[str] = None) -> Any: - warnings.warn( - "The `env` resolver is deprecated, see https://github.com/omry/omegaconf/issues/573" - ) - - try: - return decode_primitive(os.environ[key]) - except KeyError: - if default is not None: - return decode_primitive(default) - else: - raise ValidationError(f"Environment variable '{key}' not found") - - def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]: - """ - :param key: Environment variable key - :param default: Optional default value to use in case the key environment variable is not set. - If default is not a string, it is converted with str(default). - None default is returned as is. - :return: The environment variable 'key'. If the environment variable is not set and a default is - provided, the default is used. If used, the default is converted to a string with str(default). - If the default is None, None is returned (without a string conversion). - """ - try: - return os.environ[key] - except KeyError: - if default is not _DEFAULT_MARKER_: - return str(default) if default is not None else None - else: - raise KeyError(f"Environment variable '{key}' not found") - - def decode(expr: Optional[str], _parent_: Container) -> Any: - """ - Parse and evaluate `expr` according to the `singleElement` rule of the grammar. - - If `expr` is `None`, then return `None`. 
- """ - if expr is None: - return None - - if not isinstance(expr, str): - raise TypeError( - f"`oc.decode` can only take strings or None as input, " - f"but `{expr}` is of type {type(expr).__name__}" - ) + from .built_in_resolvers import decode, dict_keys, dict_values, env, legacy_env - parse_tree = parse(expr, parser_rule="singleElement", lexer_mode="VALUE_MODE") - val = _parent_.resolve_parse_tree(parse_tree) - return _get_value(val) + OmegaConf.register_new_resolver("oc.decode", decode) + OmegaConf.register_new_resolver("oc.dict.keys", dict_keys) + OmegaConf.register_new_resolver("oc.dict.values", dict_values) + OmegaConf.register_new_resolver("oc.env", env) OmegaConf.legacy_register_resolver("env", legacy_env) - OmegaConf.register_new_resolver("oc.env", env, use_cache=False) - OmegaConf.register_new_resolver("oc.decode", decode, use_cache=False) class OmegaConf: diff --git a/tests/interpolation/built_in_resolvers/test_dict.py b/tests/interpolation/built_in_resolvers/test_dict.py new file mode 100644 index 000000000..d96cfa92a --- /dev/null +++ b/tests/interpolation/built_in_resolvers/test_dict.py @@ -0,0 +1,353 @@ +import re +from typing import Any, List + +from pytest import mark, param, raises + +from omegaconf import DictConfig, ListConfig, OmegaConf +from omegaconf.built_in_resolvers import _get_and_validate_dict_input +from omegaconf.errors import ( + InterpolationResolutionError, + InterpolationToMissingValueError, +) +from tests import User, Users + + +@mark.parametrize( + ("cfg", "key", "expected"), + [ + param( + {"foo": "${oc.dict.keys:bar}", "bar": {"a": 0, "b": 1}}, + "foo", + ["a", "b"], + id="dictconfig", + ), + param( + {"foo": "${oc.dict.keys:bar}", "bar": "${boz}", "boz": {"a": 0, "b": 1}}, + "foo", + ["a", "b"], + id="dictconfig_chained_interpolation", + ), + ], +) +def test_dict_keys(cfg: Any, key: Any, expected: Any) -> None: + cfg = OmegaConf.create(cfg) + val = cfg[key] + assert val == expected + assert isinstance(val, ListConfig) + 
assert val._parent is cfg + + +@mark.parametrize( + ("cfg", "key", "expected"), + [ + param( + {"x": "${oc.dict.keys_or_values:y}", "y": "???"}, + "x", + raises( + InterpolationResolutionError, + match=re.escape( + "MissingMandatoryValue raised while resolving interpolation: " + "Missing mandatory value: y" + ), + ), + id="select_missing", + ), + param( + {"foo": "${oc.dict.keys_or_values:bar}"}, + "foo", + raises( + InterpolationResolutionError, + match=re.escape( + "ConfigKeyError raised while resolving interpolation: " + "Key not found: 'bar'" + ), + ), + id="config_key_error", + ), + param( + # This might be allowed in the future. Currently it fails. + {"foo": "${oc.dict.keys_or_values:''}"}, + "foo", + raises( + InterpolationResolutionError, + match=re.escape( + "ConfigKeyError raised while resolving interpolation: " + "Key not found: ''" + ), + ), + id="config_key_error_empty", + ), + param( + {"foo": "${oc.dict.keys_or_values:bar}", "bar": 0}, + "foo", + raises( + InterpolationResolutionError, + match=re.escape( + "TypeError raised while resolving interpolation: " + "`oc.dict.keys_or_values` cannot be applied to objects of type: int" + ), + ), + id="type_error", + ), + param( + {"foo": "${oc.dict.keys_or_values:bar}", "bar": DictConfig(None)}, + "foo", + raises( + InterpolationResolutionError, + match=re.escape( + "TypeError raised while resolving interpolation: " + "`oc.dict.keys_or_values` cannot be applied to objects of type: NoneType" + ), + ), + id="type_error_dictconfig", + ), + ], +) +def test_get_and_validate_dict_input( + restore_resolvers: Any, cfg: Any, key: Any, expected: Any +) -> None: + OmegaConf.register_new_resolver( + "oc.dict.keys_or_values", + lambda in_dict, _parent_: _get_and_validate_dict_input( + in_dict, parent=_parent_, resolver_name="oc.dict.keys_or_values" + ), + ) + cfg = OmegaConf.create(cfg) + with expected: + cfg[key] + + +@mark.parametrize( + ("cfg", "key", "expected_val", "expected_content"), + [ + param( + {"foo": 
"${oc.dict.values:bar}", "bar": {"a": 0, "b": 1}}, + "foo", + [0, 1], + ["${bar.a}", "${bar.b}"], + id="dictconfig", + ), + param( + { + "foo": "${oc.dict.values:bar}", + "bar": {"a": {"x": 0, "y": 1}, "b": {"x": 0}}, + }, + "foo", + [{"x": 0, "y": 1}, {"x": 0}], + ["${bar.a}", "${bar.b}"], + id="dictconfig_deep", + ), + param( + { + "foo": "${oc.dict.values:bar}", + "bar": {"key": "${val_ref}"}, + "val_ref": "value", + }, + "foo", + ["value"], + ["${bar.key}"], + id="dictconfig_with_interpolated_value", + ), + param( + { + "foo": "${oc.dict.values:bar}", + "bar": "${boz}", + "boz": {"a": 0, "b": 1}, + }, + "foo", + [0, 1], + ["${bar.a}", "${bar.b}"], + id="dictconfig_chained_interpolation", + ), + ], +) +def test_dict_values( + cfg: Any, key: Any, expected_val: Any, expected_content: Any +) -> None: + + cfg = OmegaConf.create(cfg) + val = cfg[key] + assert val == expected_val + assert isinstance(val, ListConfig) + assert val._parent is cfg + content = val._content + assert content == expected_content + + +def test_dict_values_with_missing_value() -> None: + cfg = OmegaConf.create({"foo": "${oc.dict.values:bar}", "bar": {"missing": "???"}}) + foo = cfg.foo + with raises(InterpolationToMissingValueError): + foo[0] + cfg.bar.missing = 1 + assert foo[0] == 1 + + +@mark.parametrize( + ("make_resolver", "key", "expected"), + [ + param( + lambda _parent_: OmegaConf.create({"x": 1}, parent=_parent_), + 0, + 1, + id="basic", + ), + param( + lambda _parent_: OmegaConf.create({"x": "${y}", "y": 1}, parent=_parent_), + 0, + 999, # referring to the `y` node from the global config + id="inter_abs", + ), + param( + lambda _parent_: OmegaConf.create({"x": "${.y}", "y": 1}, parent=_parent_), + 0, + 1, + id="inter_rel", + ), + param(lambda: OmegaConf.create({"x": 1}), 0, 1, id="basic_no_parent"), + param( + lambda: OmegaConf.create({"x": "${y}", "y": 1}), + 0, + 1, # no parent => referring to the local `y` node in the generated config + id="inter_abs_no_parent", + ), + param( + 
lambda: OmegaConf.create({"x": "${.y}", "y": 1}), + 0, + 1, + id="inter_rel_no_parent", + ), + ], +) +def test_dict_values_dictconfig_resolver_output( + restore_resolvers: Any, make_resolver: Any, key: Any, expected: Any +) -> None: + OmegaConf.register_new_resolver("make", make_resolver) + cfg = OmegaConf.create( + { + "foo": "${oc.dict.values:bar}", + "bar": "${make:}", + "y": 999, + } + ) + + assert cfg.foo[key] == expected + + +@mark.parametrize( + ("make_resolver", "expected_value", "expected_content"), + [ + param( + lambda _parent_: OmegaConf.create({"a": 0, "b": 1}, parent=_parent_), + [0, 1], + ["${y.a}", "${y.b}"], + id="dictconfig_with_parent", + ), + param( + lambda: {"a": 0, "b": 1}, + [0, 1], + ["${y.a}", "${y.b}"], + id="plain_dict", + ), + ], +) +def test_dict_values_transient_interpolation( + restore_resolvers: Any, + make_resolver: Any, + expected_value: Any, + expected_content: Any, +) -> None: + OmegaConf.register_new_resolver("make", make_resolver) + cfg = OmegaConf.create({"x": "${oc.dict.values:y}", "y": "${make:}"}) + assert cfg.x == expected_value + assert cfg.x._content == expected_content + + +def test_dict_values_are_typed() -> None: + cfg = OmegaConf.create( + { + "x": "${oc.dict.values: y.name2user}", + "y": Users( + name2user={ + "john": User(name="john", age=30), + "jane": User(name="jane", age=33), + } + ), + } + ) + x = cfg.x + assert x._metadata.ref_type == List[User] + assert x._metadata.element_type == User + + +@mark.parametrize( + ("cfg", "expected"), + [ + param({"x": "${oc.dict.values:y}", "y": {"a": 1}}, [1], id="values_inter"), + param({"x": "${oc.dict.keys:y}", "y": {"a": 1}}, ["a"], id="keys_inter"), + ], +) +def test_readonly_parent(cfg: Any, expected: Any) -> None: + cfg = OmegaConf.create(cfg) + cfg._set_flag("readonly", True) + assert cfg.x == expected + + +@mark.parametrize( + ("cfg", "expected"), + [ + param( + {"x": "${sum:${oc.dict.values:y}}", "y": {"one": 1, "two": 2}}, + 3, + id="values", + ), + param( + {"x": 
"${sum:${oc.dict.keys:y}}", "y": {1: "one", 2: "two"}}, + 3, + id="keys", + ), + ], +) +def test_nested_oc_dict(restore_resolvers: Any, cfg: Any, expected: Any) -> None: + OmegaConf.register_new_resolver("sum", sum) + cfg = OmegaConf.create(cfg) + assert cfg.x == expected + + +@mark.parametrize( + "cfg", + [ + param({"x": "${oc.dict.keys:[]}"}, id="list"), + param({"x": "${oc.dict.keys:${bool}}", "bool": True}, id="bool_interpolation"), + param({"x": "${oc.dict.keys:int}", "int": 0}, id="int_select"), + ], +) +def test_dict_keys_invalid_type(cfg: Any) -> None: + cfg = OmegaConf.create(cfg) + with raises(InterpolationResolutionError, match="TypeError"): + cfg.x + + +@mark.parametrize( + "cfg", + [ + param({"x": "${oc.dict.values:[]}"}, id="list"), + param( + {"x": "${oc.dict.values:${bool}}", "bool": True}, id="bool_interpolation" + ), + param({"x": "${oc.dict.values:int}", "int": 0}, id="int_select"), + ], +) +def test_dict_values_invalid_type(cfg: Any) -> None: + cfg = OmegaConf.create(cfg) + with raises(InterpolationResolutionError, match="TypeError"): + cfg.x + + +def test_dict_values_of_root(restore_resolvers: Any) -> None: + cfg = OmegaConf.create({"x": {"a": "${oc.dict.values:z}"}, "y": 0, "z": "${}"}) + a = cfg.x.a + assert a._content == ["${z.x}", "${z.y}", "${z.z}"] + assert a[1] == 0 + # We can recurse indefinitely within the first value if we fancy it. 
+ assert cfg.x.a[0].a[0].a[1] == 0 diff --git a/tests/interpolation/test_custom_resolvers.py b/tests/interpolation/test_custom_resolvers.py index 3e3d9e351..42c7ce0ec 100644 --- a/tests/interpolation/test_custom_resolvers.py +++ b/tests/interpolation/test_custom_resolvers.py @@ -2,7 +2,7 @@ import re from typing import Any, Dict, List -from pytest import mark, raises, warns +from pytest import mark, param, raises, warns from omegaconf import DictConfig, ListConfig, OmegaConf, Resolver from tests.interpolation import dereference_node @@ -466,3 +466,25 @@ def parent_and_default(default: int = 10, *, _parent_: Any) -> Any: assert cfg.no_param == 20 assert cfg.param == 30 + + +@mark.parametrize( + ("cfg2", "expected"), + [ + param({"foo": {"b": 1}}, {"foo": {"a": 0, "b": 1}}, id="extend"), + param({"foo": {"b": "${.a}"}}, {"foo": {"a": 0, "b": 0}}, id="extend_inter"), + param({"foo": {"a": 1}}, {"foo": {"a": 1}}, id="override_int"), + param({"foo": {"a": {"b": 1}}}, {"foo": {"a": {"b": 1}}}, id="override_dict"), + param({"foo": 10}, {"foo": 10}, id="replace_interpolation"), + param({"bar": 10}, {"foo": {"a": 0}, "bar": 10}, id="other_node"), + ], +) +def test_merge_into_resolver_output( + restore_resolvers: Any, cfg2: Any, expected: Any +) -> None: + OmegaConf.register_new_resolver( + "make", lambda _parent_: OmegaConf.create({"a": 0}, parent=_parent_) + ) + + cfg = OmegaConf.create({"foo": "${make:}"}) + assert OmegaConf.merge(cfg, cfg2) == expected