Skip to content

Commit

Permalink
fixed a few bugs creating with allow_objects
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Sep 24, 2020
1 parent bd64232 commit 8ba28be
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 28 deletions.
2 changes: 1 addition & 1 deletion news/382.feature
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Experimental support for objects in config
Experimental support for enabling objects in config via "allow_objects" flag
4 changes: 1 addition & 3 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,7 @@ def format_and_raise(
from omegaconf import OmegaConf
from omegaconf.base import Node

# Uncomment to make debugging easier.
# Note that this will cause some tests to fail
#
# Uncomment to make debugging easier. Note that this will cause some tests to fail
# raise cause

if isinstance(cause, AssertionError):
Expand Down
2 changes: 1 addition & 1 deletion omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _value(self) -> Any:
...

@abstractmethod
def _set_value(self, value: Any) -> None:
def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
...

@abstractmethod
Expand Down
24 changes: 17 additions & 7 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __init__(
flags: Optional[Dict[str, bool]] = None,
) -> None:
try:
if isinstance(content, DictConfig):
if flags is None:
flags = content._metadata.flags
super().__init__(
parent=parent,
metadata=ContainerMetadata(
Expand All @@ -84,21 +87,21 @@ def __init__(
raise KeyValidationError(f"Unsupported key type {key_type}")

if is_structured_config(content) or is_structured_config(ref_type):
self._set_value(content)
self._set_value(content, flags=flags)
if is_structured_config_frozen(content) or is_structured_config_frozen(
ref_type
):
self._set_flag("readonly", True)

else:
self._set_value(content)
if isinstance(content, DictConfig):
metadata = copy.deepcopy(content._metadata)
metadata.key = key
metadata.optional = is_optional
metadata.element_type = element_type
metadata.key_type = key_type
self.__dict__["_metadata"] = metadata
self._set_value(content, flags=flags)
except Exception as ex:
format_and_raise(node=None, key=None, value=None, cause=ex, msg=str(ex))

Expand Down Expand Up @@ -529,8 +532,11 @@ def _promote(self, type_or_prototype: Optional[Type[Any]]) -> None:
# restore the type.
self._metadata.object_type = object_type

def _set_value(self, value: Any) -> None:
from omegaconf import OmegaConf
def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
from omegaconf import OmegaConf, flag_override

if flags is None:
flags = {}

assert not isinstance(value, ValueNode)
self._validate_set(key=None, value=value)
Expand All @@ -556,10 +562,14 @@ def _set_value(self, value: Any) -> None:
self.__setitem__(k, v)
self._metadata.object_type = get_type_of(value)
elif isinstance(value, DictConfig):
self._metadata.object_type = dict
for k, v in value.__dict__["_content"].items():
self.__setitem__(k, v)
self.__dict__["_metadata"] = copy.deepcopy(value._metadata)
self._metadata.flags = copy.deepcopy(flags)
# disable struct and readonly for the construction phase
# retaining other flags like allow_objects. The real flags are restored at the end of this function
with flag_override(self, "struct", False):
with flag_override(self, "readonly", False):
for k, v in value.__dict__["_content"].items():
self.__setitem__(k, v)

elif isinstance(value, dict):
for k, v in value.items():
Expand Down
32 changes: 18 additions & 14 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def __init__(
flags: Optional[Dict[str, bool]] = None,
) -> None:
try:
if isinstance(content, ListConfig):
if flags is None:
flags = content._metadata.flags
super().__init__(
parent=parent,
metadata=ContainerMetadata(
Expand All @@ -71,7 +74,7 @@ def __init__(
)

self.__dict__["_content"] = None
self._set_value(value=content)
self._set_value(value=content, flags=flags)
except Exception as ex:
format_and_raise(node=None, key=None, value=None, cause=ex, msg=str(ex))

Expand Down Expand Up @@ -500,8 +503,11 @@ def __contains__(self, item: Any) -> bool:
return True
return False

def _set_value(self, value: Any) -> None:
from omegaconf import OmegaConf
def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
from omegaconf import OmegaConf, flag_override

if flags is None:
flags = {}

if OmegaConf.is_none(value):
if not self._is_optional():
Expand All @@ -519,25 +525,23 @@ def _set_value(self, value: Any) -> None:
else:
if not (is_primitive_list(value) or isinstance(value, ListConfig)):
type_ = type(value)
msg = (
f"Invalid value assigned : {type_.__name__} is not a "
f"subclass of ListConfig or list."
)
msg = f"Invalid value assigned : {type_.__name__} is not a ListConfig, list or tuple."
raise ValidationError(msg)

self.__dict__["_content"] = []
if isinstance(value, ListConfig):
self.__dict__["_metadata"] = copy.deepcopy(value._metadata)
self.__dict__["_metadata"].flags = {}
for item in value._iter_ex(resolve=False):
self.append(item)
self.__dict__["_metadata"].flags = copy.deepcopy(value._metadata.flags)
self._metadata.flags = copy.deepcopy(flags)
# disable struct and readonly for the construction phase
# retaining other flags like allow_objects. The real flags are restored at the end of this function
with flag_override(self, "struct", False):
with flag_override(self, "readonly", False):
for item in value._iter_ex(resolve=False):
self.append(item)
elif is_primitive_list(value):
for item in value:
self.append(item)

if isinstance(value, ListConfig):
self.__dict__["_metadata"].flags = value._metadata.flags

@staticmethod
def _list_eq(l1: Optional["ListConfig"], l2: Optional["ListConfig"]) -> bool:

Expand Down
2 changes: 1 addition & 1 deletion omegaconf/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, parent: Optional[Container], value: Any, metadata: Metadata):
def _value(self) -> Any:
return self._val

def _set_value(self, value: Any) -> None:
def _set_value(self, value: Any, flags: Optional[Dict[str, bool]] = None) -> None:
from ._utils import ValueKind, get_value_kind

if self._get_flag("readonly"):
Expand Down
50 changes: 50 additions & 0 deletions tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,56 @@ def test_create_value(input_: Any, expected: Any) -> None:
assert OmegaConf.create(input_) == expected


@pytest.mark.parametrize( # type: ignore
"input_",
[
# top level dict
{"x": IllegalType()},
{"x": {"y": IllegalType()}},
{"x": [IllegalType()]},
# top level list
[IllegalType()],
[[IllegalType()]],
[{"x": IllegalType()}],
[{"x": [IllegalType()]}],
],
)
def test_create_allow_objects(input_: Any) -> None:
# test creating from a primitive container
cfg = OmegaConf.create(input_, flags={"allow_objects": True})
assert cfg == input_

# test creating from an OmegaConf object, inheriting the allow_objects flag
cfg = OmegaConf.create(cfg)
assert cfg == input_

# test creating from an OmegaConf object
cfg = OmegaConf.create(cfg, flags={"allow_objects": True})
assert cfg == input_


@pytest.mark.parametrize( # type: ignore
"input_",
[
pytest.param({"foo": "bar"}, id="dict"),
pytest.param([1, 2, 3], id="list"),
],
)
def test_create_flags_overriding(input_: Any) -> Any:
cfg = OmegaConf.create(input_)
OmegaConf.set_struct(cfg, True)

# by default flags are inherited
cfg2 = OmegaConf.create(cfg)
assert OmegaConf.is_struct(cfg2)
assert not OmegaConf.is_readonly(cfg2)

# but specified flags are replacing all of the flags (even those that are not specified)
cfg2 = OmegaConf.create(cfg, flags={"readonly": True})
assert not OmegaConf.is_struct(cfg2)
assert OmegaConf.is_readonly(cfg2)


def test_create_from_cli() -> None:
sys.argv = ["program.py", "a=1", "b.c=2"]
c = OmegaConf.from_cli()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ def finalize(self, cfg: Any) -> None:
op=lambda cfg: cfg._set_value(True),
exception_type=ValidationError,
object_type=None,
msg="Invalid value assigned : bool is not a subclass of ListConfig or list",
msg="Invalid value assigned : bool is not a ListConfig, list or tuple.",
ref_type=List[int],
low_level=True,
),
Expand Down

0 comments on commit 8ba28be

Please sign in to comment.