Skip to content

Commit

Permalink
Add new resolvers oc.dict.keys and oc.dict.values (#644)
Browse files Browse the repository at this point in the history
  • Loading branch information
odelalleau committed Apr 11, 2021
1 parent e1a599f commit 8972c91
Show file tree
Hide file tree
Showing 7 changed files with 556 additions and 55 deletions.
34 changes: 33 additions & 1 deletion docs/source/usage.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.. testsetup:: *

from omegaconf import OmegaConf, DictConfig, open_dict, read_write
from omegaconf import OmegaConf, DictConfig, ListConfig, open_dict, read_write
import os
import sys
import tempfile
Expand Down Expand Up @@ -466,6 +466,38 @@ This can be useful for instance to parse environment variables:
type: int, value: 3308


Extracting lists of keys / values from a dictionary
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Some config options that are stored as a ``DictConfig`` may sometimes be easier to manipulate as lists,
when we care only about the keys or the associated values.

The resolvers ``oc.dict.keys`` and ``oc.dict.values`` simplify such operations by offering an alternative
view of a dictionary's keys or values as a list.
They take as input a string that is the path to another config node (using the same syntax
as interpolations) and return a ``ListConfig`` with its keys / values.

.. doctest::

>>> cfg = OmegaConf.create(
... {
... "workers": {
... "node3": "10.0.0.2",
... "node7": "10.0.0.9",
... },
... "nodes": "${oc.dict.keys: workers}",
... "ips": "${oc.dict.values: workers}",
... }
... )
>>> # Keys are copied from the DictConfig:
>>> show(cfg.nodes)
type: ListConfig, value: ['node3', 'node7']
>>> # Values are dynamically fetched through interpolations:
>>> show(cfg.ips)
type: ListConfig, value: ['${workers.node3}', '${workers.node7}']
>>> assert cfg.ips == ["10.0.0.2", "10.0.0.9"]


Custom interpolations
^^^^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions news/643.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
New resolvers `oc.dict.keys` and `oc.dict.values` provide a list view of the keys or values of a DictConfig node.
1 change: 1 addition & 0 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ def _resolve_interpolation_from_parse_tree(
node that is created to wrap the interpolated value. It is `None` if and only if
`throw_on_resolution_failure` is `False` and an error occurs during resolution.
"""

try:
resolved = self.resolve_parse_tree(
parse_tree=parse_tree,
Expand Down
140 changes: 140 additions & 0 deletions omegaconf/built_in_resolvers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import os
import warnings
from typing import Any, List, Optional

from ._utils import _DEFAULT_MARKER_, Marker, _get_value, decode_primitive
from .base import Container
from .basecontainer import BaseContainer
from .dictconfig import DictConfig
from .errors import ConfigKeyError, ValidationError
from .grammar_parser import parse
from .listconfig import ListConfig
from .nodes import AnyNode
from .omegaconf import OmegaConf

# Special marker use as default value when calling `OmegaConf.select()`. It must be
# different from `_DEFAULT_MARKER_`, which is used by `OmegaConf.select()`.
_DEFAULT_SELECT_MARKER_: Any = Marker("_DEFAULT_SELECT_MARKER_")


def decode(expr: Optional[str], _parent_: Container) -> Any:
"""
Parse and evaluate `expr` according to the `singleElement` rule of the grammar.
If `expr` is `None`, then return `None`.
"""
if expr is None:
return None

if not isinstance(expr, str):
raise TypeError(
f"`oc.decode` can only take strings or None as input, "
f"but `{expr}` is of type {type(expr).__name__}"
)

parse_tree = parse(expr, parser_rule="singleElement", lexer_mode="VALUE_MODE")
val = _parent_.resolve_parse_tree(parse_tree)
return _get_value(val)


def dict_keys(
key: str,
_parent_: Container,
) -> ListConfig:
assert isinstance(_parent_, BaseContainer)

in_dict = _get_and_validate_dict_input(
key, parent=_parent_, resolver_name="oc.dict.keys"
)

ret = OmegaConf.create(list(in_dict.keys()), parent=_parent_)
assert isinstance(ret, ListConfig)
return ret


def dict_values(key: str, _root_: BaseContainer, _parent_: Container) -> ListConfig:
assert isinstance(_parent_, BaseContainer)
in_dict = _get_and_validate_dict_input(
key, parent=_parent_, resolver_name="oc.dict.values"
)

content = in_dict._content
assert isinstance(content, dict)

ret = ListConfig([])
for k in content:
ref_node = AnyNode(f"${{{key}.{k}}}")
ret.append(ref_node)

# Finalize result by setting proper type and parent.
element_type: Any = in_dict._metadata.element_type
ret._metadata.element_type = element_type
ret._metadata.ref_type = List[element_type]
ret._set_parent(_parent_)

return ret


def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]:
"""
:param key: Environment variable key
:param default: Optional default value to use in case the key environment variable is not set.
If default is not a string, it is converted with str(default).
None default is returned as is.
:return: The environment variable 'key'. If the environment variable is not set and a default is
provided, the default is used. If used, the default is converted to a string with str(default).
If the default is None, None is returned (without a string conversion).
"""
try:
return os.environ[key]
except KeyError:
if default is not _DEFAULT_MARKER_:
return str(default) if default is not None else None
else:
raise KeyError(f"Environment variable '{key}' not found")


# DEPRECATED: remove in 2.2
def legacy_env(key: str, default: Optional[str] = None) -> Any:
warnings.warn(
"The `env` resolver is deprecated, see https://github.com/omry/omegaconf/issues/573"
)

try:
return decode_primitive(os.environ[key])
except KeyError:
if default is not None:
return decode_primitive(default)
else:
raise ValidationError(f"Environment variable '{key}' not found")


def _get_and_validate_dict_input(
key: str,
parent: BaseContainer,
resolver_name: str,
) -> DictConfig:
if not isinstance(key, str):
raise TypeError(
f"`{resolver_name}` requires a string as input, but obtained `{key}` "
f"of type: {type(key).__name__}"
)

in_dict = OmegaConf.select(
parent,
key,
throw_on_missing=True,
absolute_key=True,
default=_DEFAULT_SELECT_MARKER_,
)

if in_dict is _DEFAULT_SELECT_MARKER_:
raise ConfigKeyError(f"Key not found: '{key}'")

if not isinstance(in_dict, DictConfig):
raise TypeError(
f"`{resolver_name}` cannot be applied to objects of type: "
f"{type(in_dict).__name__}"
)

return in_dict
58 changes: 5 additions & 53 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
_ensure_container,
_get_value,
_is_none,
decode_primitive,
format_and_raise,
get_dict_key_value_types,
get_list_element_type,
Expand Down Expand Up @@ -60,7 +59,6 @@
UnsupportedInterpolationType,
ValidationError,
)
from .grammar_parser import parse
from .nodes import (
AnyNode,
BooleanNode,
Expand Down Expand Up @@ -95,60 +93,14 @@ def SI(interpolation: str) -> Any:


def register_default_resolvers() -> None:
# DEPRECATED: remove in 2.2
def legacy_env(key: str, default: Optional[str] = None) -> Any:
warnings.warn(
"The `env` resolver is deprecated, see https://github.com/omry/omegaconf/issues/573"
)

try:
return decode_primitive(os.environ[key])
except KeyError:
if default is not None:
return decode_primitive(default)
else:
raise ValidationError(f"Environment variable '{key}' not found")

def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]:
"""
:param key: Environment variable key
:param default: Optional default value to use in case the key environment variable is not set.
If default is not a string, it is converted with str(default).
None default is returned as is.
:return: The environment variable 'key'. If the environment variable is not set and a default is
provided, the default is used. If used, the default is converted to a string with str(default).
If the default is None, None is returned (without a string conversion).
"""
try:
return os.environ[key]
except KeyError:
if default is not _DEFAULT_MARKER_:
return str(default) if default is not None else None
else:
raise KeyError(f"Environment variable '{key}' not found")

def decode(expr: Optional[str], _parent_: Container) -> Any:
"""
Parse and evaluate `expr` according to the `singleElement` rule of the grammar.
If `expr` is `None`, then return `None`.
"""
if expr is None:
return None

if not isinstance(expr, str):
raise TypeError(
f"`oc.decode` can only take strings or None as input, "
f"but `{expr}` is of type {type(expr).__name__}"
)
from .built_in_resolvers import decode, dict_keys, dict_values, env, legacy_env

parse_tree = parse(expr, parser_rule="singleElement", lexer_mode="VALUE_MODE")
val = _parent_.resolve_parse_tree(parse_tree)
return _get_value(val)
OmegaConf.register_new_resolver("oc.decode", decode)
OmegaConf.register_new_resolver("oc.dict.keys", dict_keys)
OmegaConf.register_new_resolver("oc.dict.values", dict_values)
OmegaConf.register_new_resolver("oc.env", env)

OmegaConf.legacy_register_resolver("env", legacy_env)
OmegaConf.register_new_resolver("oc.env", env, use_cache=False)
OmegaConf.register_new_resolver("oc.decode", decode, use_cache=False)


class OmegaConf:
Expand Down
Loading

0 comments on commit 8972c91

Please sign in to comment.