Skip to content

Commit

Permalink
Revert "Prevent OmegaConf.structured(...) from mutating its input (#909
Browse files Browse the repository at this point in the history
…)" (#915)

This reverts commit c8fc02c.
  • Loading branch information
Jasha10 committed May 6, 2022
1 parent 04ecbc1 commit fe5cd18
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 21 deletions.
12 changes: 6 additions & 6 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _resolve_forward(type_: Type[Any], module: str) -> Type[Any]:

def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]]:
"""Check if obj is an instance of a subclass of Dict. If so, extract the Dict keys/values."""
from omegaconf.omegaconf import _node_wrap
from omegaconf.omegaconf import _maybe_wrap

is_type = isinstance(obj, type)
obj_type = obj if is_type else type(obj)
Expand All @@ -242,7 +242,7 @@ def extract_dict_subclass_data(obj: Any, parent: Any) -> Optional[Dict[str, Any]
is_optional, type_ = _resolve_optional(element_type)
type_ = _resolve_forward(type_, obj.__module__)
try:
dict_subclass_data[name] = _node_wrap(
dict_subclass_data[name] = _maybe_wrap(
ref_type=type_,
is_optional=is_optional,
key=name,
Expand All @@ -269,7 +269,7 @@ def get_attr_class_init_field_names(obj: Any) -> List[str]:


def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, Any]:
from omegaconf.omegaconf import OmegaConf, _node_wrap
from omegaconf.omegaconf import OmegaConf, _maybe_wrap

flags = {"allow_objects": allow_objects} if allow_objects is not None else {}

Expand Down Expand Up @@ -297,7 +297,7 @@ def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, A
format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))

try:
d[name] = _node_wrap(
d[name] = _maybe_wrap(
ref_type=type_,
is_optional=is_optional,
key=name,
Expand All @@ -324,7 +324,7 @@ def get_dataclass_data(
) -> Dict[str, Any]:
from typing import get_type_hints

from omegaconf.omegaconf import MISSING, OmegaConf, _node_wrap
from omegaconf.omegaconf import MISSING, OmegaConf, _maybe_wrap

flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
d = {}
Expand Down Expand Up @@ -356,7 +356,7 @@ def get_dataclass_data(
)
format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))
try:
d[name] = _node_wrap(
d[name] = _maybe_wrap(
ref_type=type_,
is_optional=is_optional,
key=name,
Expand Down
4 changes: 2 additions & 2 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,11 +589,11 @@ def assign(value_key: Any, val: Node) -> None:
self._wrap_value_and_set(key, value, type_hint)

def _wrap_value_and_set(self, key: Any, val: Any, type_hint: Any) -> None:
from omegaconf.omegaconf import _node_wrap
from omegaconf.omegaconf import _maybe_wrap

is_optional, ref_type = _resolve_optional(type_hint)

wrapped = _node_wrap(
wrapped = _maybe_wrap(
ref_type=ref_type,
key=key,
value=val,
Expand Down
13 changes: 12 additions & 1 deletion omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def _update_keys(self) -> None:
node._metadata.key = i

def insert(self, index: int, item: Any) -> None:
from omegaconf.omegaconf import _maybe_wrap

try:
if self._get_flag("readonly"):
raise ReadonlyConfigError("Cannot insert into a read-only ListConfig")
Expand All @@ -317,7 +319,16 @@ def insert(self, index: int, item: Any) -> None:
assert isinstance(self.__dict__["_content"], list)
# insert place holder
self.__dict__["_content"].insert(index, None)
self._set_at_index(index, item)
is_optional, ref_type = _resolve_optional(self._metadata.element_type)
node = _maybe_wrap(
ref_type=ref_type,
key=index,
value=item,
is_optional=is_optional,
parent=self,
)
self._validate_set(key=index, value=node)
self._set_at_index(index, node)
self._update_keys()
except Exception:
del self.__dict__["_content"][index]
Expand Down
23 changes: 23 additions & 0 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,29 @@ def _node_wrap(
return node


def _maybe_wrap(
ref_type: Any,
key: Any,
value: Any,
is_optional: bool,
parent: Optional[BaseContainer],
) -> Node:
# if already a node, update key and parent and return as is.
# NOTE: that this mutate the input node!
if isinstance(value, Node):
value._set_key(key)
value._set_parent(parent)
return value
else:
return _node_wrap(
ref_type=ref_type,
parent=parent,
is_optional=is_optional,
value=value,
key=key,
)


def _select_one(
c: Container, key: str, throw_on_missing: bool, throw_on_type_error: bool = True
) -> Tuple[Optional[Node], Union[str, int]]:
Expand Down
10 changes: 0 additions & 10 deletions tests/structured_conf/test_structured_basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
from copy import deepcopy
from importlib import import_module
from typing import Any, Optional

Expand Down Expand Up @@ -323,12 +322,3 @@ def test_allow_objects(self, module: Any) -> None:
with flag_override(cfg, "allow_objects", True):
cfg.plugin = pwo
assert cfg.plugin == pwo

def test_structured_creation_does_not_mutate_input(self, module: Any) -> None:
cfg1 = OmegaConf.structured(module.MissingUserField(module.User("Bond", 7)))
user1 = cfg1.user
prev_user = deepcopy(user1)
cfg2 = OmegaConf.structured(module.MissingUserField(user1))
assert user1._metadata == prev_user._metadata
assert user1._parent == prev_user._parent
assert user1 is not cfg2.user
4 changes: 2 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_node_wrap(
],
)
def test_node_wrap2(target_type: Any, value: Any, expected: Any) -> None:
from omegaconf.omegaconf import _node_wrap
from omegaconf.omegaconf import _maybe_wrap

if isinstance(expected, Node):
res = _node_wrap(
Expand All @@ -200,7 +200,7 @@ def test_node_wrap2(target_type: Any, value: Any, expected: Any) -> None:
assert res._key() == "foo"
else:
with raises(expected):
_node_wrap(
_maybe_wrap(
ref_type=target_type,
key=None,
value=value,
Expand Down

0 comments on commit fe5cd18

Please sign in to comment.