Skip to content

Commit

Permalink
added OmegaConf.unsafe_merge
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Jan 22, 2021
1 parent d575834 commit eac62f4
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 57 deletions.
18 changes: 17 additions & 1 deletion benchmark/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict

from pytest import lazy_fixture # type: ignore
from pytest import fixture, mark
from pytest import fixture, mark, param

from omegaconf import OmegaConf

Expand Down Expand Up @@ -51,6 +51,11 @@ def large_dict_config(large_dict: Any) -> Any:
return OmegaConf.create(large_dict)


@fixture(scope="module")
def merge_data(small_dict: Any) -> Any:
return [OmegaConf.create(small_dict) for _ in range(5)]


@mark.parametrize(
"data",
[
Expand All @@ -63,3 +68,14 @@ def large_dict_config(large_dict: Any) -> Any:
)
def test_omegaconf_create(data: Any, benchmark: Any) -> None:
benchmark(OmegaConf.create, data)


@mark.parametrize(
"merge_function",
[
param(OmegaConf.merge, id="merge"),
param(OmegaConf.unsafe_merge, id="unsafe_merge"),
],
)
def test_omegaconf_merge(merge_function: Any, merge_data: Any, benchmark: Any) -> None:
benchmark(merge_function, merge_data)
37 changes: 27 additions & 10 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Iterator, Optional, Tuple, Type, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union

from ._utils import ValueKind, _get_value, format_and_raise, get_value_kind
from .errors import ConfigKeyError, MissingMandatoryValue, UnsupportedInterpolationType
Expand Down Expand Up @@ -73,15 +73,32 @@ def _get_parent(self) -> Optional["Container"]:
assert parent is None or isinstance(parent, Container)
return parent

def _set_flag(self, flag: str, value: Optional[bool]) -> "Node":
assert value is None or isinstance(value, bool)
if value is None:
assert self._metadata.flags is not None
if flag in self._metadata.flags:
del self._metadata.flags[flag]
else:
assert self._metadata.flags is not None
self._metadata.flags[flag] = value
def _set_flag(
self,
flags: Union[List[str], str],
values: Union[List[Optional[bool]], Optional[bool]],
) -> "Node":
if isinstance(flags, str):
flags = [flags]

if values is None or isinstance(values, bool):
values = [values]

if len(values) == 1:
values = len(flags) * values

if len(flags) != len(values):
raise ValueError("Inconsistent lengths of input flag names and values")

for idx, flag in enumerate(flags):
value = values[idx]
if value is None:
assert self._metadata.flags is not None
if flag in self._metadata.flags:
del self._metadata.flags[flag]
else:
assert self._metadata.flags is not None
self._metadata.flags[flag] = value
self._invalidate_flags_cache()
return self

Expand Down
9 changes: 8 additions & 1 deletion omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,14 @@ def _set_item_impl(self, key: Any, value: Any) -> None:
from .nodes import AnyNode, ValueNode

if isinstance(value, Node):
value = copy.deepcopy(value)
do_deepcopy = not self._get_flag("no_deepcopy_set_nodes")
if not do_deepcopy and isinstance(value, Container):
# if value is from the same config, perform a deepcopy no matter what.
if self._get_root() is value._get_root():
do_deepcopy = True

if do_deepcopy:
value = copy.deepcopy(value)
value._set_parent(None)

try:
Expand Down
67 changes: 57 additions & 10 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,16 +343,61 @@ def from_dotlist(dotlist: List[str]) -> DictConfig:

@staticmethod
def merge(
*others: Union[BaseContainer, Dict[str, Any], List[Any], Tuple[Any, ...], Any]
*configs: Union[
DictConfig,
ListConfig,
Dict[DictKeyType, Any],
List[Any],
Tuple[Any, ...],
Any,
],
) -> Union[ListConfig, DictConfig]:
"""Merge a list of previously created configs into a single one"""
assert len(others) > 0
target = copy.deepcopy(others[0])
"""
Merge a list of previously created configs into a single one
:param configs: Input configs
:return: the merged config object.
"""
assert len(configs) > 0
target = copy.deepcopy(configs[0])
target = _ensure_container(target)
assert isinstance(target, (DictConfig, ListConfig))

with flag_override(target, "readonly", False):
target.merge_with(*others[1:])
target.merge_with(*configs[1:])
turned_readonly = target._get_flag("readonly") is True

if turned_readonly:
OmegaConf.set_readonly(target, True)

return target

@staticmethod
def unsafe_merge(
*configs: Union[
DictConfig,
ListConfig,
Dict[DictKeyType, Any],
List[Any],
Tuple[Any, ...],
Any,
],
) -> Union[ListConfig, DictConfig]:
"""
Merge a list of previously created configs into a single one
This is much faster than OmegaConf.merge() as the input configs are not copied.
However, the input configs must not be used after this operation as will become inconsistent.
:param configs: Input configs
:return: the merged config object.
"""
assert len(configs) > 0
target = configs[0]
target = _ensure_container(target)
assert isinstance(target, (DictConfig, ListConfig))

with flag_override(
target, ["readonly", "no_deepcopy_set_nodes"], [False, True]
):
target.merge_with(*configs[1:])
turned_readonly = target._get_flag("readonly") is True

if turned_readonly:
Expand Down Expand Up @@ -712,21 +757,23 @@ def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str:

@contextmanager
def flag_override(
config: Node, names: Union[List[str], str], value: Optional[bool]
config: Node,
names: Union[List[str], str],
values: Union[List[Optional[bool]], Optional[bool]],
) -> Generator[Node, None, None]:

if isinstance(names, str):
names = [names]
if values is None or isinstance(values, bool):
values = [values]

prev_states = [config._get_flag(name) for name in names]

try:
for idx, name in enumerate(names):
config._set_flag(name, value)
config._set_flag(names, values)
yield config
finally:
for idx, name in enumerate(names):
config._set_flag(name, prev_states[idx])
config._set_flag(names, prev_states)


@contextmanager
Expand Down
18 changes: 17 additions & 1 deletion tests/test_base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def test_flag_dict(flag: str) -> None:

@pytest.mark.parametrize("flag", ["readonly", "struct"])
def test_freeze_nested_dict(flag: str) -> None:
c = OmegaConf.create(dict(a=dict(b=2)))
c = OmegaConf.create({"a": {"b": 2}})
assert not c._get_flag(flag)
assert not c.a._get_flag(flag)
c._set_flag(flag, True)
Expand All @@ -308,6 +308,22 @@ def test_freeze_nested_dict(flag: str) -> None:
assert c.a._get_flag(flag)


def test_set_flags() -> None:
c = OmegaConf.create({"a": {"b": 2}})
assert not c._get_flag("readonly")
assert not c._get_flag("struct")
c._set_flag(["readonly", "struct"], True)
assert c._get_flag("readonly")
assert c._get_flag("struct")

c._set_flag(["readonly", "struct"], [False, True])
assert not c._get_flag("readonly")
assert c._get_flag("struct")

with pytest.raises(ValueError):
c._set_flag(["readonly", "struct"], [True, False, False])


@pytest.mark.parametrize(
"src", [[], [1, 2, 3], dict(), dict(a=10), StructuredWithMissing]
)
Expand Down
Loading

0 comments on commit eac62f4

Please sign in to comment.