Skip to content

Commit

Permalink
added list iter tests and dict and list len tests
Browse files Browse the repository at this point in the history
While adding tests for the list iteration, I noticed a bug with len where assertion error
would be triggered when taking length of a missing/interpolation/none ListConfig.
DictConfig was also not being consistent here (and actually returned the wrong result in those cases).
I think this is mostly not user visible because users do not normally gain access to such objects, so
I am not bothering with a news fragment.

I also noticed that list iteration and dict iteration are not handling missing/interpolatio values consistently.
  • Loading branch information
omry committed Mar 18, 2021
1 parent ad7c9af commit 8ce5eee
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 23 deletions.
3 changes: 1 addition & 2 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,8 +533,7 @@ def _get_value(value: Any) -> Any:
from .nodes import ValueNode

if isinstance(value, ValueNode):
boxed = value._value()
return boxed
return value._value()
elif isinstance(value, Container):
boxed = value._value()
if boxed is None or _is_missing_literal(boxed) or _is_interpolation(boxed):
Expand Down
3 changes: 2 additions & 1 deletion omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,9 @@ def __setstate__(self, d: Dict[str, Any]) -> None:
def __delitem__(self, key: Any) -> None:
...

@abstractmethod
def __len__(self) -> int:
return self.__dict__["_content"].__len__() # type: ignore
...

def merge_with_cli(self) -> None:
args_list = sys.argv[1:]
Expand Down
7 changes: 7 additions & 0 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,13 @@ def __dir__(self) -> Iterable[str]:
return []
return self.__dict__["_content"].keys() # type: ignore

def __len__(self) -> int:
if self._is_none() or self._is_missing() or self._is_interpolation():
return 0
content = self.__dict__["_content"]
assert isinstance(content, dict)
return len(content)

def __setattr__(self, key: str, value: Any) -> None:
"""
Allow assigning attributes to DictConfig
Expand Down
17 changes: 9 additions & 8 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,11 @@ def __dir__(self) -> Iterable[str]:
return [str(x) for x in range(0, len(self))]

def __len__(self) -> int:
if self._is_none():
return 0
if self._is_missing():
if self._is_none() or self._is_missing() or self._is_interpolation():
return 0
assert isinstance(self.__dict__["_content"], list)
return len(self.__dict__["_content"])
content = self.__dict__["_content"]
assert isinstance(content, list)
return len(content)

def __setattr__(self, key: str, value: Any) -> None:
self._format_and_raise(
Expand Down Expand Up @@ -511,15 +510,17 @@ def __next__(self) -> Any:
x = next(self.iterator)
if self.resolve:
x = x._dereference_node()

if x._is_missing():
raise MissingMandatoryValue(f"Missing value at index {self.index}")
if x._is_missing():
raise MissingMandatoryValue(f"Missing value at index {self.index}")

self.index = self.index + 1
if isinstance(x, self.ValueNode):
return x._value()
return x

def __repr__(self) -> str:
return f"ListConfig.ListIterator(resolve={self.resolve})"

def _iter_ex(self, resolve: bool) -> Iterator[Any]:
try:
if self._is_none():
Expand Down
21 changes: 17 additions & 4 deletions tests/test_basic_ops_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,10 +598,23 @@ def test_dict_nested_structured_delitem() -> None:
assert "name" not in c.user


@pytest.mark.parametrize("d, expected", [({}, 0), ({"a": 10, "b": 11}, 2)])
def test_dict_len(d: Any, expected: Any) -> None:
c = OmegaConf.create(d)
assert len(c) == expected
@pytest.mark.parametrize(
"d, expected",
[
pytest.param(DictConfig({}), 0, id="empty"),
pytest.param(DictConfig({"a": 10}), 1, id="full"),
pytest.param(DictConfig(None), 0, id="none"),
pytest.param(DictConfig("???"), 0, id="missing"),
pytest.param(
DictConfig("${foo}", parent=OmegaConf.create({"foo": {"a": 10}})),
0,
id="interpolation",
),
pytest.param(DictConfig("${foo}"), 0, id="broken_interpolation"),
],
)
def test_dict_len(d: DictConfig, expected: Any) -> None:
assert d.__len__() == expected


def test_dict_assign_illegal_value() -> None:
Expand Down
58 changes: 50 additions & 8 deletions tests/test_basic_ops_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,31 @@ def test_list_get_do_not_return_default(


@pytest.mark.parametrize(
"input_, expected, list_key",
"input_, expected, expected_no_resolve, list_key",
[
pytest.param([1, 2], [1, 2], None, id="simple"),
pytest.param(["${1}", 2], [2, 2], None, id="interpolation"),
pytest.param([1, 2], [1, 2], [1, 2], None, id="simple"),
pytest.param(["${1}", 2], [2, 2], ["${1}", 2], None, id="interpolation"),
pytest.param(
[ListConfig(None), ListConfig("${.2}"), [1, 2]],
[ListConfig(None), ListConfig([1, 2]), ListConfig([1, 2])],
[ListConfig(None), ListConfig("${.2}"), ListConfig([1, 2])],
None,
id="iter_over_lists",
),
pytest.param(
[DictConfig(None), DictConfig("${.2}"), {"a": 10}],
[DictConfig(None), DictConfig({"a": 10}), DictConfig({"a": 10})],
[DictConfig(None), DictConfig("${.2}"), DictConfig({"a": 10})],
None,
id="iter_over_dicts",
),
pytest.param(
["???", ListConfig("???"), DictConfig("???")],
pytest.raises(MissingMandatoryValue),
["???", ListConfig("???"), DictConfig("???")],
None,
id="iter_over_missing",
),
pytest.param(
{
"defaults": [
Expand All @@ -82,21 +103,40 @@ def test_list_get_do_not_return_default(
OmegaConf.create({"dataset": "imagenet"}),
OmegaConf.create({"foo": "adam_imagenet"}),
],
[
OmegaConf.create({"optimizer": "adam"}),
OmegaConf.create({"dataset": "imagenet"}),
OmegaConf.create(
{"foo": "${defaults.0.optimizer}_${defaults.1.dataset}"}
),
],
"defaults",
id="str_interpolation",
),
],
)
def test_iterate_list(input_: Any, expected: Any, list_key: str) -> None:
def test_iterate_list(
input_: Any, expected: Any, expected_no_resolve: Any, list_key: str
) -> None:
c = OmegaConf.create(input_)
if list_key is not None:
lst = c.get(list_key)
else:
lst = c
items = [x for x in lst]
assert items == expected
for idx in range(len(items)):
assert type(items[idx]) is type(expected[idx]) # noqa

def test_iter(iterator: Any, expected_output: Any) -> None:
if isinstance(expected_output, list):
items = [x for x in iterator]
assert items == expected_output
for idx in range(len(items)):
assert type(items[idx]) is type(expected_output[idx]) # noqa
else:
with expected_output:
for x in iterator:
pass

test_iter(iter(lst), expected)
test_iter(lst._iter_ex(resolve=False), expected_no_resolve)


def test_iterate_list_with_missing_interpolation() -> None:
Expand Down Expand Up @@ -248,6 +288,8 @@ def test_list_delitem() -> None:
(OmegaConf.create([1, 2]), 2),
(ListConfig(content=None), 0),
(ListConfig(content="???"), 0),
(ListConfig(content="${foo}"), 0),
(ListConfig(content="${foo}", parent=DictConfig({"foo": [1, 2]})), 0),
],
)
def test_list_len(lst: Any, expected: Any) -> None:
Expand Down

0 comments on commit 8ce5eee

Please sign in to comment.