Skip to content

Commit

Permalink
New implementation of oc.dict.values based on interpolations
Browse files Browse the repository at this point in the history
Of particular note:

* When the result of an interpolation is a node whose parent is the
  current node's parent, but has no key, then we set its key to the
  current node's key. This makes it possible to use its full key as an
  identifier.

* _get_and_validate_dict_input() now properly raises an exception if the
  desired key does not exist
  • Loading branch information
odelalleau committed Mar 29, 2021
1 parent f2555ae commit 2e604f9
Show file tree
Hide file tree
Showing 4 changed files with 347 additions and 26 deletions.
16 changes: 11 additions & 5 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.. testsetup:: *

from omegaconf import OmegaConf, DictConfig, open_dict, read_write
from omegaconf import OmegaConf, DictConfig, ListConfig, open_dict, read_write
import os
import sys
import tempfile
Expand Down Expand Up @@ -483,10 +483,16 @@ If a string is given as input, ``OmegaConf.select()`` is used to access the corr
... "ips": "${oc.dict.values:machines}",
... }
... )
>>> show(cfg.nodes)
type: ListConfig, value: ['node007', 'node012', 'node075']
>>> show(cfg.ips)
type: ListConfig, value: ['10.0.0.7', '10.0.0.3', '10.0.1.8']
>>> nodes = cfg.nodes
>>> ips = cfg.ips
>>> # The corresponding lists of keys / values are ListConfig nodes.
>>> assert isinstance(nodes, ListConfig)
>>> assert isinstance(ips, ListConfig)
>>> assert nodes == ['node007', 'node012', 'node075']
>>> assert ips == ['10.0.0.7', '10.0.0.3', '10.0.1.8']
>>> # Values are dynamically updated with the underlying dict.
>>> cfg.machines.node012 = "10.0.0.5"
>>> assert ips[1] == "10.0.0.5"


Custom interpolations
Expand Down
11 changes: 11 additions & 0 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def _resolve_interpolation_from_parse_tree(
node that is created to wrap the interpolated value. It is `None` if and only if
`throw_on_resolution_failure` is `False` and an error occurs during resolution.
"""

try:
resolved = self.resolve_parse_tree(
parse_tree=parse_tree,
Expand Down Expand Up @@ -515,6 +516,16 @@ def _validate_and_convert_interpolation_result(
)
else:
assert isinstance(resolved, Node)
if (
parent is not None
and resolved._parent is parent
and resolved._key() is None
and value._key() is not None
):
# The interpolation is returning a transient (key-less) node attached
# to the current parent. By setting its key to this node's key, we make
# it possible to refer to it through an interpolation path.
resolved._set_key(value._key())
return resolved

def _wrap_interpolation_result(
Expand Down
63 changes: 59 additions & 4 deletions omegaconf/built_in_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,22 @@
import warnings

# from collections.abc import Mapping, MutableMapping
from typing import Any, Mapping, Optional, Union
from typing import Any, List, Mapping, Optional, Union

from ._utils import _DEFAULT_MARKER_, _get_value, decode_primitive
from ._utils import _DEFAULT_MARKER_, Marker, _get_value, decode_primitive
from .base import Container
from .basecontainer import BaseContainer
from .errors import ValidationError
from .dictconfig import DictConfig
from .errors import ConfigKeyError, ValidationError
from .grammar_parser import parse
from .listconfig import ListConfig
from .nodes import AnyNode
from .omegaconf import OmegaConf

# Special marker use as default value when calling `OmegaConf.select()`. It must be
# different from `_DEFAULT_MARKER_`, which is used by `OmegaConf.select()`.
_DEFAULT_SELECT_MARKER_: Any = Marker("_DEFAULT_SELECT_MARKER_")


def decode(expr: Optional[str], _parent_: Container) -> Any:
"""
Expand Down Expand Up @@ -42,6 +48,7 @@ def dict_keys(
in_dict, root=_root_, resolver_name="oc.dict.keys"
)
assert isinstance(_parent_, BaseContainer)

ret = OmegaConf.create(list(in_dict.keys()), parent=_parent_)
assert isinstance(ret, ListConfig)
return ret
Expand All @@ -53,6 +60,49 @@ def dict_values(
in_dict = _get_and_validate_dict_input(
in_dict, root=_root_, resolver_name="oc.dict.values"
)

if isinstance(in_dict, DictConfig):
# DictConfig objects are handled in a special way: the goal is to make the
# returned ListConfig point to the DictConfig nodes through interpolations.

dict_key: Optional[str] = None
if in_dict._get_root() is _root_:
# Try to obtain the full key through which we can access `in_dict`.
if in_dict is _root_:
dict_key = ""
else:
dict_key = in_dict._get_full_key(None)
if dict_key:
dict_key += "." # append dot for future concatenation
else:
# This can happen e.g. if `in_dict` is a transient node.
dict_key = None

if dict_key is None:
# No path to `in_dict` in the existing config.
raise NotImplementedError(
"`oc.dict.values` currently only supports input config nodes that "
"are accessible through the root config. See "
"https://github.com/omry/omegaconf/issues/650 for details."
)

ret = ListConfig([])
content = in_dict._content
assert isinstance(content, dict)

for key, node in content.items():
ref_node = AnyNode(f"${{{dict_key}{key}}}")
ret.append(ref_node)

# Finalize result by setting proper type and parent.
element_type: Any = in_dict._metadata.element_type
ret._metadata.element_type = element_type
ret._metadata.ref_type = List[element_type]
ret._set_parent(_parent_)

return ret

# Other dict-like object: simply create a ListConfig from its values.
assert isinstance(_parent_, BaseContainer)
ret = OmegaConf.create(list(in_dict.values()), parent=_parent_)
assert isinstance(ret, ListConfig)
Expand Down Expand Up @@ -101,7 +151,12 @@ def _get_and_validate_dict_input(
) -> Mapping[Any, Any]:
if isinstance(in_dict, str):
# Path to an existing key in the config: use `select()`.
in_dict = OmegaConf.select(root, in_dict, throw_on_missing=True)
key = in_dict
in_dict = OmegaConf.select(
root, key, throw_on_missing=True, default=_DEFAULT_SELECT_MARKER_
)
if in_dict is _DEFAULT_SELECT_MARKER_:
raise ConfigKeyError(f"Key not found: '{key}'")

if not isinstance(in_dict, Mapping):
raise TypeError(
Expand Down
Loading

0 comments on commit 2e604f9

Please sign in to comment.