Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Struct assign #587

Merged
merged 9 commits into from
Mar 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions news/586.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Assignment of a dict/list to an existing node in a parent in struct mode no longer raises ValidationError
27 changes: 13 additions & 14 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,26 +629,25 @@ def _set_value_impl(
self.__dict__["_content"] = {}
if is_structured_config(value):
self._metadata.object_type = None
data = get_structured_config_data(
value,
allow_objects=self._get_flag("allow_objects"),
)
for k, v in data.items():
self.__setitem__(k, v)
ao = self._get_flag("allow_objects")
data = get_structured_config_data(value, allow_objects=ao)
with flag_override(self, ["struct", "readonly"], False):
for k, v in data.items():
self.__setitem__(k, v)
self._metadata.object_type = get_type_of(value)

elif isinstance(value, DictConfig):
self.__dict__["_metadata"] = copy.deepcopy(value._metadata)
self._metadata.flags = copy.deepcopy(flags)
# disable struct and readonly for the construction phase
# retaining other flags like allow_objects. The real flags are restored at the end of this function
with flag_override(self, "struct", False):
with flag_override(self, "readonly", False):
for k, v in value.__dict__["_content"].items():
self.__setitem__(k, v)
with flag_override(self, ["struct", "readonly"], False):
for k, v in value.__dict__["_content"].items():
self.__setitem__(k, v)

elif isinstance(value, dict):
for k, v in value.items():
self.__setitem__(k, v)
with flag_override(self, ["struct", "readonly"], False):
for k, v in value.items():
self.__setitem__(k, v)

else: # pragma: no cover
msg = f"Unsupported value type : {value}"
raise ValidationError(msg)
Expand Down
12 changes: 6 additions & 6 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,13 +593,13 @@ def _set_value_impl(
self._metadata.flags = copy.deepcopy(flags)
# disable struct and readonly for the construction phase
# retaining other flags like allow_objects. The real flags are restored at the end of this function
with flag_override(self, "struct", False):
with flag_override(self, "readonly", False):
for item in value._iter_ex(resolve=False):
self.append(item)
with flag_override(self, ["struct", "readonly"], False):
for item in value._iter_ex(resolve=False):
self.append(item)
elif is_primitive_list(value):
for item in value:
self.append(item)
with flag_override(self, ["struct", "readonly"], False):
for item in value:
self.append(item)

@staticmethod
def _list_eq(l1: Optional["ListConfig"], l2: Optional["ListConfig"]) -> bool:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_basic_ops_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,3 +1023,12 @@ def test_dict_getitem_not_found() -> None:
def test_dict_getitem_none_output() -> None:
cfg = OmegaConf.create({"a": None})
assert cfg["a"] is None


@pytest.mark.parametrize("data", [{"b": 0}, User])
@pytest.mark.parametrize("flag", ["struct", "readonly"])
def test_dictconfig_creation_with_parent_flag(flag: str, data: Any) -> None:
parent = OmegaConf.create({"a": 10})
parent._set_flag(flag, True)
cfg = DictConfig(data, parent=parent)
assert cfg == data
9 changes: 9 additions & 0 deletions tests/test_basic_ops_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,3 +712,12 @@ def test_shallow_copy_none() -> None:
c._set_value([1])
assert c[0] == 1
assert cfg._is_none()


@pytest.mark.parametrize("flag", ["struct", "readonly"])
def test_listconfig_creation_with_parent_flag(flag: str) -> None:
parent = OmegaConf.create([])
parent._set_flag(flag, True)
d = [1, 2, 3]
cfg = ListConfig(d, parent=parent)
assert cfg == d
6 changes: 6 additions & 0 deletions tests/test_readonly.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
raises(ReadonlyConfigError, match="a"),
id="dict_setitem",
),
pytest.param(
{"a": None},
lambda c: c.__setitem__("a", {"b": 10}),
raises(ReadonlyConfigError, match="a"),
id="dict_setitem",
),
pytest.param(
{"a": {"b": {"c": 1}}},
lambda c: c.__getattr__("a").__getattr__("b").__setitem__("c", 1),
Expand Down
23 changes: 15 additions & 8 deletions tests/test_struct.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import Any, Dict

import pytest
from pytest import mark, raises

from omegaconf import OmegaConf
from omegaconf.errors import ConfigKeyError
Expand All @@ -16,40 +16,40 @@ def test_struct_set_on_dict() -> None:
c = OmegaConf.create({"a": {}})
OmegaConf.set_struct(c, True)
# Throwing when it hits foo, so exception key is a.foo and not a.foo.bar
with pytest.raises(AttributeError, match=re.escape("a.foo")):
with raises(AttributeError, match=re.escape("a.foo")):
# noinspection PyStatementEffect
c.a.foo.bar


def test_struct_set_on_nested_dict() -> None:
c = OmegaConf.create({"a": {"b": 10}})
OmegaConf.set_struct(c, True)
with pytest.raises(AttributeError):
with raises(AttributeError):
# noinspection PyStatementEffect
c.foo

assert "a" in c
assert c.a.b == 10
with pytest.raises(AttributeError, match=re.escape("a.foo")):
with raises(AttributeError, match=re.escape("a.foo")):
# noinspection PyStatementEffect
c.a.foo


def test_merge_dotlist_into_struct() -> None:
c = OmegaConf.create({"a": {"b": 10}})
OmegaConf.set_struct(c, True)
with pytest.raises(AttributeError, match=re.escape("foo")):
with raises(AttributeError, match=re.escape("foo")):
c.merge_with_dotlist(["foo=1"])


@pytest.mark.parametrize("in_base, in_merged", [(dict(), dict(a=10))])
@mark.parametrize("in_base, in_merged", [({}, {"a": 10})])
def test_merge_config_with_struct(
in_base: Dict[str, Any], in_merged: Dict[str, Any]
) -> None:
base = OmegaConf.create(in_base)
merged = OmegaConf.create(in_merged)
OmegaConf.set_struct(base, True)
with pytest.raises(ConfigKeyError):
with raises(ConfigKeyError):
OmegaConf.merge(base, merged)


Expand All @@ -59,6 +59,13 @@ def test_struct_contain_missing() -> None:
assert "foo" not in c


@pytest.mark.parametrize("cfg", [{}, OmegaConf.create({}, flags={"struct": True})])
@mark.parametrize("cfg", [{}, OmegaConf.create({}, flags={"struct": True})])
def test_struct_dict_get(cfg: Any) -> None:
assert cfg.get("z") is None


def test_struct_dict_assign() -> None:
cfg = OmegaConf.create({"a": {}})
OmegaConf.set_struct(cfg, True)
cfg.a = {"b": 10}
assert cfg.a == {"b": 10}