From 5aebda88291ff85ca2e8b9a8d5e3343759709d84 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Wed, 17 Mar 2021 17:22:02 -0700 Subject: [PATCH] optimized ListConfig iteration and improved testing --- news/532.misc | 2 +- omegaconf/_utils.py | 12 ++++--- omegaconf/basecontainer.py | 5 ++- omegaconf/listconfig.py | 35 ++++++++++---------- tests/test_basic_ops_dict.py | 21 +++++++++--- tests/test_basic_ops_list.py | 62 ++++++++++++++++++++++++++++++++---- 6 files changed, 103 insertions(+), 34 deletions(-) diff --git a/news/532.misc b/news/532.misc index 406b209da..982c2e2c2 100644 --- a/news/532.misc +++ b/news/532.misc @@ -1 +1 @@ -Optimized ListConfig iteration by 7.7x in a benchmark +Optimized ListConfig iteration by 8.8x in a benchmark diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 08e49e9be..fa94a1d8c 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -532,12 +532,14 @@ def _get_value(value: Any) -> Any: from .base import Container from .nodes import ValueNode - if isinstance(value, Container) and ( - value._is_none() or value._is_missing() or value._is_interpolation() - ): - return value._value() if isinstance(value, ValueNode): - value = value._value() + return value._value() + elif isinstance(value, Container): + boxed = value._value() + if boxed is None or _is_missing_literal(boxed) or _is_interpolation(boxed): + return boxed + + # return primitives and regular OmegaConf Containers as is return value diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 74ce491ef..91a3f2209 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -124,7 +124,10 @@ def __delitem__(self, key: Any) -> None: ... def __len__(self) -> int: - return self.__dict__["_content"].__len__() # type: ignore + if self._is_none() or self._is_missing() or self._is_interpolation(): + return 0 + content = self.__dict__["_content"] + return len(content) def merge_with_cli(self) -> None: args_list = sys.argv[1:] diff --git a/omegaconf/listconfig.py b/omegaconf/listconfig.py index 40564d483..37e6b2b5b 100644 --- a/omegaconf/listconfig.py +++ b/omegaconf/listconfig.py @@ -16,7 +16,6 @@ from ._utils import ( ValueKind, - _get_value, _is_none, format_and_raise, get_value_kind, @@ -153,14 +152,6 @@ def __dir__(self) -> Iterable[str]: return [] return [str(x) for x in range(0, len(self))] - def __len__(self) -> int: - if self._is_none(): - return 0 - if self._is_missing(): - return 0 - assert isinstance(self.__dict__["_content"], list) - return len(self.__dict__["_content"]) - def __setattr__(self, key: str, value: Any) -> None: self._format_and_raise( key=key, @@ -500,20 +491,32 @@ def __iter__(self) -> Iterator[Any]: class ListIterator(Iterator[Any]): def __init__(self, lst: Any, resolve: bool) -> None: - self.iter = iter(lst.__dict__["_content"]) self.resolve = resolve + self.iterator = iter(lst.__dict__["_content"]) self.index = 0 + from .nodes import ValueNode + + self.ValueNode = ValueNode def __next__(self) -> Any: - v = next(self.iter) + x = next(self.iterator) if self.resolve: - v = v._dereference_node() + x = x._dereference_node() + if x._is_missing(): + raise MissingMandatoryValue(f"Missing value at index {self.index}") - if v._is_missing(): - raise MissingMandatoryValue(f"Missing value at index {self.index}") self.index = self.index + 1 - return _get_value(v) + if isinstance(x, self.ValueNode): + return x._value() + else: + # Must be omegaconf.Container. not checking for perf reasons. + if x._is_none(): + return None + return x + + def __repr__(self) -> str: # pragma: no cover + return f"ListConfig.ListIterator(resolve={self.resolve})" def _iter_ex(self, resolve: bool) -> Iterator[Any]: try: @@ -523,7 +526,7 @@ def _iter_ex(self, resolve: bool) -> Iterator[Any]: raise MissingMandatoryValue("Cannot iterate a missing ListConfig") return ListConfig.ListIterator(self, resolve) - except (ReadonlyConfigError, TypeError, MissingMandatoryValue) as e: + except (TypeError, MissingMandatoryValue) as e: self._format_and_raise(key=None, value=None, cause=e) assert False diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index beab06e7f..ddf8b8fd7 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -598,10 +598,23 @@ def test_dict_nested_structured_delitem() -> None: assert "name" not in c.user -@pytest.mark.parametrize("d, expected", [({}, 0), ({"a": 10, "b": 11}, 2)]) -def test_dict_len(d: Any, expected: Any) -> None: - c = OmegaConf.create(d) - assert len(c) == expected +@pytest.mark.parametrize( + "d, expected", + [ + pytest.param(DictConfig({}), 0, id="empty"), + pytest.param(DictConfig({"a": 10}), 1, id="full"), + pytest.param(DictConfig(None), 0, id="none"), + pytest.param(DictConfig("???"), 0, id="missing"), + pytest.param( + DictConfig("${foo}", parent=OmegaConf.create({"foo": {"a": 10}})), + 0, + id="interpolation", + ), + pytest.param(DictConfig("${foo}"), 0, id="broken_interpolation"), + ], +) +def test_dict_len(d: DictConfig, expected: Any) -> None: + assert d.__len__() == expected def test_dict_assign_illegal_value() -> None: diff --git a/tests/test_basic_ops_list.py b/tests/test_basic_ops_list.py index 0c292b686..eb526b7ed 100644 --- a/tests/test_basic_ops_list.py +++ b/tests/test_basic_ops_list.py @@ -65,10 +65,31 @@ def test_list_get_do_not_return_default( @pytest.mark.parametrize( - "input_, expected, list_key", + "input_, expected, expected_no_resolve, list_key", [ - pytest.param([1, 2], [1, 2], None, id="simple"), - pytest.param(["${1}", 2], [2, 2], None, id="interpolation"), + pytest.param([1, 2], [1, 2], [1, 2], None, id="simple"), + pytest.param(["${1}", 2], [2, 2], ["${1}", 2], None, id="interpolation"), + pytest.param( + [ListConfig(None), ListConfig("${.2}"), [1, 2]], + [None, ListConfig([1, 2]), ListConfig([1, 2])], + [None, ListConfig("${.2}"), ListConfig([1, 2])], + None, + id="iter_over_lists", + ), + pytest.param( + [DictConfig(None), DictConfig("${.2}"), {"a": 10}], + [None, DictConfig({"a": 10}), DictConfig({"a": 10})], + [None, DictConfig("${.2}"), DictConfig({"a": 10})], + None, + id="iter_over_dicts", + ), + pytest.param( + ["???", ListConfig("???"), DictConfig("???")], + pytest.raises(MissingMandatoryValue), + ["???", ListConfig("???"), DictConfig("???")], + None, + id="iter_over_missing", + ), pytest.param( { "defaults": [ @@ -77,20 +98,45 @@ def test_list_get_do_not_return_default( {"foo": "${defaults.0.optimizer}_${defaults.1.dataset}"}, ] }, - [{"optimizer": "adam"}, {"dataset": "imagenet"}, {"foo": "adam_imagenet"}], + [ + OmegaConf.create({"optimizer": "adam"}), + OmegaConf.create({"dataset": "imagenet"}), + OmegaConf.create({"foo": "adam_imagenet"}), + ], + [ + OmegaConf.create({"optimizer": "adam"}), + OmegaConf.create({"dataset": "imagenet"}), + OmegaConf.create( + {"foo": "${defaults.0.optimizer}_${defaults.1.dataset}"} + ), + ], "defaults", id="str_interpolation", ), ], ) -def test_iterate_list(input_: Any, expected: Any, list_key: str) -> None: +def test_iterate_list( + input_: Any, expected: Any, expected_no_resolve: Any, list_key: str +) -> None: c = OmegaConf.create(input_) if list_key is not None: lst = c.get(list_key) else: lst = c - items = [x for x in lst] - assert items == expected + + def test_iter(iterator: Any, expected_output: Any) -> None: + if isinstance(expected_output, list): + items = [x for x in iterator] + assert items == expected_output + for idx in range(len(items)): + assert type(items[idx]) is type(expected_output[idx]) # noqa + else: + with expected_output: + for _ in iterator: + pass + + test_iter(iter(lst), expected) + test_iter(lst._iter_ex(resolve=False), expected_no_resolve) def test_iterate_list_with_missing_interpolation() -> None: @@ -242,6 +288,8 @@ def test_list_delitem() -> None: (OmegaConf.create([1, 2]), 2), (ListConfig(content=None), 0), (ListConfig(content="???"), 0), + (ListConfig(content="${foo}"), 0), + (ListConfig(content="${foo}", parent=DictConfig({"foo": [1, 2]})), 0), ], ) def test_list_len(lst: Any, expected: Any) -> None: