Skip to content

Commit

Permalink
optimized ListConfig iteration and improved testing
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Mar 18, 2021
1 parent f24fe6a commit 5aebda8
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 34 deletions.
2 changes: 1 addition & 1 deletion news/532.misc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Optimized ListConfig iteration by 7.7x in a benchmark
Optimized ListConfig iteration by 8.8x in a benchmark
12 changes: 7 additions & 5 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,12 +532,14 @@ def _get_value(value: Any) -> Any:
from .base import Container
from .nodes import ValueNode

if isinstance(value, Container) and (
value._is_none() or value._is_missing() or value._is_interpolation()
):
return value._value()
if isinstance(value, ValueNode):
value = value._value()
return value._value()
elif isinstance(value, Container):
boxed = value._value()
if boxed is None or _is_missing_literal(boxed) or _is_interpolation(boxed):
return boxed

# return primitives and regular OmegaConf Containers as is
return value


Expand Down
5 changes: 4 additions & 1 deletion omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ def __delitem__(self, key: Any) -> None:
...

def __len__(self) -> int:
return self.__dict__["_content"].__len__() # type: ignore
if self._is_none() or self._is_missing() or self._is_interpolation():
return 0
content = self.__dict__["_content"]
return len(content)

def merge_with_cli(self) -> None:
args_list = sys.argv[1:]
Expand Down
35 changes: 19 additions & 16 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from ._utils import (
ValueKind,
_get_value,
_is_none,
format_and_raise,
get_value_kind,
Expand Down Expand Up @@ -153,14 +152,6 @@ def __dir__(self) -> Iterable[str]:
return []
return [str(x) for x in range(0, len(self))]

def __len__(self) -> int:
if self._is_none():
return 0
if self._is_missing():
return 0
assert isinstance(self.__dict__["_content"], list)
return len(self.__dict__["_content"])

def __setattr__(self, key: str, value: Any) -> None:
self._format_and_raise(
key=key,
Expand Down Expand Up @@ -500,20 +491,32 @@ def __iter__(self) -> Iterator[Any]:

class ListIterator(Iterator[Any]):
def __init__(self, lst: Any, resolve: bool) -> None:
self.iter = iter(lst.__dict__["_content"])
self.resolve = resolve
self.iterator = iter(lst.__dict__["_content"])
self.index = 0
from .nodes import ValueNode

self.ValueNode = ValueNode

def __next__(self) -> Any:
v = next(self.iter)

x = next(self.iterator)
if self.resolve:
v = v._dereference_node()
x = x._dereference_node()
if x._is_missing():
raise MissingMandatoryValue(f"Missing value at index {self.index}")

if v._is_missing():
raise MissingMandatoryValue(f"Missing value at index {self.index}")
self.index = self.index + 1
return _get_value(v)
if isinstance(x, self.ValueNode):
return x._value()
else:
# Must be omegaconf.Container. not checking for perf reasons.
if x._is_none():
return None
return x

def __repr__(self) -> str: # pragma: no cover
return f"ListConfig.ListIterator(resolve={self.resolve})"

def _iter_ex(self, resolve: bool) -> Iterator[Any]:
try:
Expand All @@ -523,7 +526,7 @@ def _iter_ex(self, resolve: bool) -> Iterator[Any]:
raise MissingMandatoryValue("Cannot iterate a missing ListConfig")

return ListConfig.ListIterator(self, resolve)
except (ReadonlyConfigError, TypeError, MissingMandatoryValue) as e:
except (TypeError, MissingMandatoryValue) as e:
self._format_and_raise(key=None, value=None, cause=e)
assert False

Expand Down
21 changes: 17 additions & 4 deletions tests/test_basic_ops_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,10 +598,23 @@ def test_dict_nested_structured_delitem() -> None:
assert "name" not in c.user


@pytest.mark.parametrize("d, expected", [({}, 0), ({"a": 10, "b": 11}, 2)])
def test_dict_len(d: Any, expected: Any) -> None:
c = OmegaConf.create(d)
assert len(c) == expected
@pytest.mark.parametrize(
"d, expected",
[
pytest.param(DictConfig({}), 0, id="empty"),
pytest.param(DictConfig({"a": 10}), 1, id="full"),
pytest.param(DictConfig(None), 0, id="none"),
pytest.param(DictConfig("???"), 0, id="missing"),
pytest.param(
DictConfig("${foo}", parent=OmegaConf.create({"foo": {"a": 10}})),
0,
id="interpolation",
),
pytest.param(DictConfig("${foo}"), 0, id="broken_interpolation"),
],
)
def test_dict_len(d: DictConfig, expected: Any) -> None:
assert d.__len__() == expected


def test_dict_assign_illegal_value() -> None:
Expand Down
62 changes: 55 additions & 7 deletions tests/test_basic_ops_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,31 @@ def test_list_get_do_not_return_default(


@pytest.mark.parametrize(
"input_, expected, list_key",
"input_, expected, expected_no_resolve, list_key",
[
pytest.param([1, 2], [1, 2], None, id="simple"),
pytest.param(["${1}", 2], [2, 2], None, id="interpolation"),
pytest.param([1, 2], [1, 2], [1, 2], None, id="simple"),
pytest.param(["${1}", 2], [2, 2], ["${1}", 2], None, id="interpolation"),
pytest.param(
[ListConfig(None), ListConfig("${.2}"), [1, 2]],
[None, ListConfig([1, 2]), ListConfig([1, 2])],
[None, ListConfig("${.2}"), ListConfig([1, 2])],
None,
id="iter_over_lists",
),
pytest.param(
[DictConfig(None), DictConfig("${.2}"), {"a": 10}],
[None, DictConfig({"a": 10}), DictConfig({"a": 10})],
[None, DictConfig("${.2}"), DictConfig({"a": 10})],
None,
id="iter_over_dicts",
),
pytest.param(
["???", ListConfig("???"), DictConfig("???")],
pytest.raises(MissingMandatoryValue),
["???", ListConfig("???"), DictConfig("???")],
None,
id="iter_over_missing",
),
pytest.param(
{
"defaults": [
Expand All @@ -77,20 +98,45 @@ def test_list_get_do_not_return_default(
{"foo": "${defaults.0.optimizer}_${defaults.1.dataset}"},
]
},
[{"optimizer": "adam"}, {"dataset": "imagenet"}, {"foo": "adam_imagenet"}],
[
OmegaConf.create({"optimizer": "adam"}),
OmegaConf.create({"dataset": "imagenet"}),
OmegaConf.create({"foo": "adam_imagenet"}),
],
[
OmegaConf.create({"optimizer": "adam"}),
OmegaConf.create({"dataset": "imagenet"}),
OmegaConf.create(
{"foo": "${defaults.0.optimizer}_${defaults.1.dataset}"}
),
],
"defaults",
id="str_interpolation",
),
],
)
def test_iterate_list(input_: Any, expected: Any, list_key: str) -> None:
def test_iterate_list(
input_: Any, expected: Any, expected_no_resolve: Any, list_key: str
) -> None:
c = OmegaConf.create(input_)
if list_key is not None:
lst = c.get(list_key)
else:
lst = c
items = [x for x in lst]
assert items == expected

def test_iter(iterator: Any, expected_output: Any) -> None:
if isinstance(expected_output, list):
items = [x for x in iterator]
assert items == expected_output
for idx in range(len(items)):
assert type(items[idx]) is type(expected_output[idx]) # noqa
else:
with expected_output:
for _ in iterator:
pass

test_iter(iter(lst), expected)
test_iter(lst._iter_ex(resolve=False), expected_no_resolve)


def test_iterate_list_with_missing_interpolation() -> None:
Expand Down Expand Up @@ -242,6 +288,8 @@ def test_list_delitem() -> None:
(OmegaConf.create([1, 2]), 2),
(ListConfig(content=None), 0),
(ListConfig(content="???"), 0),
(ListConfig(content="${foo}"), 0),
(ListConfig(content="${foo}", parent=DictConfig({"foo": [1, 2]})), 0),
],
)
def test_list_len(lst: Any, expected: Any) -> None:
Expand Down

0 comments on commit 5aebda8

Please sign in to comment.