Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dict keys values #8

Draft
wants to merge 17 commits into
base: tmp
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion docs/source/usage.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.. testsetup:: *

from omegaconf import OmegaConf, DictConfig, open_dict, read_write
from omegaconf import OmegaConf, DictConfig, ListConfig, open_dict, read_write
import os
import sys
import tempfile
Expand Down Expand Up @@ -466,6 +466,38 @@ This can be useful for instance to parse environment variables:
type: int, value: 3308


Extracting lists of keys / values from a dictionary
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Some config options that are stored as a ``DictConfig`` may sometimes be easier to manipulate as lists,
when we care only about the keys or the associated values.

The resolvers ``oc.dict.keys`` and ``oc.dict.values`` simplify such operations by offering an alternative
view of a dictionary's keys or values as a list.
They take as input a string that is the path to another config node (using the same syntax
as interpolations) and return a ``ListConfig`` with its keys / values.

.. doctest::

>>> cfg = OmegaConf.create(
... {
... "workers": {
... "node3": "10.0.0.2",
... "node7": "10.0.0.9",
... },
... "nodes": "${oc.dict.keys: workers}",
... "ips": "${oc.dict.values: workers}",
... }
... )
>>> # Keys are copied from the DictConfig:
>>> show(cfg.nodes)
type: ListConfig, value: ['node3', 'node7']
>>> # Values are dynamically fetched through interpolations:
>>> show(cfg.ips)
type: ListConfig, value: ['${workers.node3}', '${workers.node7}']
>>> assert cfg.ips == ["10.0.0.2", "10.0.0.9"]


Custom interpolations
^^^^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions news/643.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
New resolvers `oc.dict.keys` and `oc.dict.values` provide a list view of the keys or values of a DictConfig node.
1 change: 1 addition & 0 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ def _resolve_interpolation_from_parse_tree(
node that is created to wrap the interpolated value. It is `None` if and only if
`throw_on_resolution_failure` is `False` and an error occurs during resolution.
"""

try:
resolved = self.resolve_parse_tree(
parse_tree=parse_tree,
Expand Down
84 changes: 81 additions & 3 deletions omegaconf/built_in_resolvers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import os
import warnings
from typing import Any, Optional
from typing import Any, List, Optional

from ._utils import _DEFAULT_MARKER_, _get_value, decode_primitive
from ._utils import _DEFAULT_MARKER_, Marker, _get_value, decode_primitive
from .base import Container
from .errors import ValidationError
from .basecontainer import BaseContainer
from .dictconfig import DictConfig
from .errors import ConfigKeyError, ValidationError
from .grammar_parser import parse
from .listconfig import ListConfig
from .nodes import AnyNode
from .omegaconf import OmegaConf

# Special marker use as default value when calling `OmegaConf.select()`. It must be
# different from `_DEFAULT_MARKER_`, which is used by `OmegaConf.select()`.
_DEFAULT_SELECT_MARKER_: Any = Marker("_DEFAULT_SELECT_MARKER_")


def decode(expr: Optional[str], _parent_: Container) -> Any:
Expand All @@ -28,6 +37,44 @@ def decode(expr: Optional[str], _parent_: Container) -> Any:
return _get_value(val)


def dict_keys(
key: str,
_parent_: Container,
) -> ListConfig:
assert isinstance(_parent_, BaseContainer)

in_dict = _get_and_validate_dict_input(
key, parent=_parent_, resolver_name="oc.dict.keys"
)

ret = OmegaConf.create(list(in_dict.keys()), parent=_parent_)
assert isinstance(ret, ListConfig)
return ret


def dict_values(key: str, _root_: BaseContainer, _parent_: Container) -> ListConfig:
assert isinstance(_parent_, BaseContainer)
in_dict = _get_and_validate_dict_input(
key, parent=_parent_, resolver_name="oc.dict.values"
)

content = in_dict._content
assert isinstance(content, dict)

ret = ListConfig([])
for k in content:
ref_node = AnyNode(f"${{{key}.{k}}}")
ret.append(ref_node)

# Finalize result by setting proper type and parent.
element_type: Any = in_dict._metadata.element_type
ret._metadata.element_type = element_type
ret._metadata.ref_type = List[element_type]
ret._set_parent(_parent_)

return ret


def env(key: str, default: Any = _DEFAULT_MARKER_) -> Optional[str]:
"""
:param key: Environment variable key
Expand Down Expand Up @@ -60,3 +107,34 @@ def legacy_env(key: str, default: Optional[str] = None) -> Any:
return decode_primitive(default)
else:
raise ValidationError(f"Environment variable '{key}' not found")


def _get_and_validate_dict_input(
key: str,
parent: BaseContainer,
resolver_name: str,
) -> DictConfig:
if not isinstance(key, str):
raise TypeError(
f"`{resolver_name}` requires a string as input, but obtained `{key}` "
f"of type: {type(key).__name__}"
)

in_dict = OmegaConf.select(
parent,
key,
throw_on_missing=True,
absolute_key=True,
default=_DEFAULT_SELECT_MARKER_,
)

if in_dict is _DEFAULT_SELECT_MARKER_:
raise ConfigKeyError(f"Key not found: '{key}'")

if not isinstance(in_dict, DictConfig):
raise TypeError(
f"`{resolver_name}` cannot be applied to objects of type: "
f"{type(in_dict).__name__}"
)

return in_dict
9 changes: 6 additions & 3 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,14 @@ def SI(interpolation: str) -> Any:


def register_default_resolvers() -> None:
from .built_in_resolvers import decode, env, legacy_env
from .built_in_resolvers import decode, dict_keys, dict_values, env, legacy_env

OmegaConf.register_new_resolver("oc.decode", decode)
OmegaConf.register_new_resolver("oc.dict.keys", dict_keys)
OmegaConf.register_new_resolver("oc.dict.values", dict_values)
OmegaConf.register_new_resolver("oc.env", env)

OmegaConf.legacy_register_resolver("env", legacy_env)
OmegaConf.register_new_resolver("oc.env", env, use_cache=False)
OmegaConf.register_new_resolver("oc.decode", decode, use_cache=False)


class OmegaConf:
Expand Down
Loading