Skip to content

Commit

Permalink
Assignment of a dict to an existing node in a parent in struct mode n…
Browse files Browse the repository at this point in the history
…o longer raises ValidationError
  • Loading branch information
omry committed Mar 9, 2021
1 parent 88586d8 commit 4d26d4c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
2 changes: 2 additions & 0 deletions news/586.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Assignment of a dict to an existing node in a parent in struct mode no longer raises ValidationError

27 changes: 16 additions & 11 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,10 @@ def _node_wrap(
ref_type: Any = Any,
) -> Node:
node: Node
allow_objects = parent is not None and parent._get_flag("allow_objects") is True
flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
dummy = OmegaConf.create(flags=flags)

is_dict = is_primitive_dict(value) or is_dict_annotation(type_)
is_list = (
type(value) in (list, tuple)
Expand All @@ -909,7 +913,7 @@ def _node_wrap(
node = DictConfig(
content=value,
key=key,
parent=parent,
parent=dummy,
ref_type=type_,
is_optional=is_optional,
key_type=key_type,
Expand All @@ -920,7 +924,7 @@ def _node_wrap(
node = ListConfig(
content=value,
key=key,
parent=parent,
parent=dummy,
is_optional=is_optional,
element_type=element_type,
ref_type=ref_type,
Expand All @@ -932,33 +936,34 @@ def _node_wrap(
is_optional=is_optional,
content=value,
key=key,
parent=parent,
parent=dummy,
key_type=key_type,
element_type=element_type,
)
elif type_ == Any or type_ is None:
node = AnyNode(value=value, key=key, parent=parent, is_optional=is_optional)
node = AnyNode(value=value, key=key, parent=dummy, is_optional=is_optional)
elif issubclass(type_, Enum):
node = EnumNode(
enum_type=type_,
value=value,
key=key,
parent=parent,
parent=dummy,
is_optional=is_optional,
)
elif type_ == int:
node = IntegerNode(value=value, key=key, parent=parent, is_optional=is_optional)
node = IntegerNode(value=value, key=key, parent=dummy, is_optional=is_optional)
elif type_ == float:
node = FloatNode(value=value, key=key, parent=parent, is_optional=is_optional)
node = FloatNode(value=value, key=key, parent=dummy, is_optional=is_optional)
elif type_ == bool:
node = BooleanNode(value=value, key=key, parent=parent, is_optional=is_optional)
node = BooleanNode(value=value, key=key, parent=dummy, is_optional=is_optional)
elif type_ == str:
node = StringNode(value=value, key=key, parent=parent, is_optional=is_optional)
node = StringNode(value=value, key=key, parent=dummy, is_optional=is_optional)
else:
if parent is not None and parent._get_flag("allow_objects") is True:
node = AnyNode(value=value, key=key, parent=parent, is_optional=is_optional)
if allow_objects:
node = AnyNode(value=value, key=key, parent=dummy, is_optional=is_optional)
else:
raise ValidationError(f"Unexpected object type : {type_str(type_)}")
node._set_parent(parent)
return node


Expand Down
7 changes: 7 additions & 0 deletions tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,10 @@ def test_struct_contain_missing() -> None:
@mark.parametrize("cfg", [{}, OmegaConf.create({}, flags={"struct": True})])
def test_struct_dict_get(cfg: Any) -> None:
assert cfg.get("z") is None


def test_struct_dict_assign() -> None:
cfg = OmegaConf.create({"a": {}})
OmegaConf.set_struct(cfg, True)
cfg.a = {"b": 10}
assert cfg.a == {"b": 10}

0 comments on commit 4d26d4c

Please sign in to comment.