Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimized list iterator some more #618

Merged
merged 1 commit into from
Mar 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion news/532.misc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Optimized ListConfig iteration by 7.7x in a benchmark
Optimized ListConfig iteration by 8.8x in a benchmark
12 changes: 7 additions & 5 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,12 +532,14 @@ def _get_value(value: Any) -> Any:
from .base import Container
from .nodes import ValueNode

if isinstance(value, Container) and (
value._is_none() or value._is_missing() or value._is_interpolation()
):
return value._value()
if isinstance(value, ValueNode):
value = value._value()
return value._value()
elif isinstance(value, Container):
boxed = value._value()
if boxed is None or _is_missing_literal(boxed) or _is_interpolation(boxed):
return boxed

# return primitives and regular OmegaConf Containers as is
return value


Expand Down
5 changes: 4 additions & 1 deletion omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ def __delitem__(self, key: Any) -> None:
...

def __len__(self) -> int:
return self.__dict__["_content"].__len__() # type: ignore
if self._is_none() or self._is_missing() or self._is_interpolation():
return 0
content = self.__dict__["_content"]
return len(content)
odelalleau marked this conversation as resolved.
Show resolved Hide resolved

def merge_with_cli(self) -> None:
args_list = sys.argv[1:]
Expand Down
35 changes: 19 additions & 16 deletions omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from ._utils import (
ValueKind,
_get_value,
_is_none,
format_and_raise,
get_value_kind,
Expand Down Expand Up @@ -153,14 +152,6 @@ def __dir__(self) -> Iterable[str]:
return []
return [str(x) for x in range(0, len(self))]

def __len__(self) -> int:
if self._is_none():
return 0
if self._is_missing():
return 0
assert isinstance(self.__dict__["_content"], list)
return len(self.__dict__["_content"])

def __setattr__(self, key: str, value: Any) -> None:
self._format_and_raise(
key=key,
Expand Down Expand Up @@ -500,20 +491,32 @@ def __iter__(self) -> Iterator[Any]:

class ListIterator(Iterator[Any]):
def __init__(self, lst: Any, resolve: bool) -> None:
self.iter = iter(lst.__dict__["_content"])
self.resolve = resolve
self.iterator = iter(lst.__dict__["_content"])
self.index = 0
from .nodes import ValueNode

self.ValueNode = ValueNode

def __next__(self) -> Any:
v = next(self.iter)

x = next(self.iterator)
if self.resolve:
v = v._dereference_node()
x = x._dereference_node()
if x._is_missing():
raise MissingMandatoryValue(f"Missing value at index {self.index}")

if v._is_missing():
raise MissingMandatoryValue(f"Missing value at index {self.index}")
self.index = self.index + 1
return _get_value(v)
if isinstance(x, self.ValueNode):
return x._value()
odelalleau marked this conversation as resolved.
Show resolved Hide resolved
else:
# Must be omegaconf.Container. not checking for perf reasons.
if x._is_none():
return None
return x

def __repr__(self) -> str: # pragma: no cover
return f"ListConfig.ListIterator(resolve={self.resolve})"

def _iter_ex(self, resolve: bool) -> Iterator[Any]:
try:
Expand All @@ -523,7 +526,7 @@ def _iter_ex(self, resolve: bool) -> Iterator[Any]:
raise MissingMandatoryValue("Cannot iterate a missing ListConfig")

return ListConfig.ListIterator(self, resolve)
except (ReadonlyConfigError, TypeError, MissingMandatoryValue) as e:
except (TypeError, MissingMandatoryValue) as e:
self._format_and_raise(key=None, value=None, cause=e)
assert False

Expand Down
21 changes: 17 additions & 4 deletions tests/test_basic_ops_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,10 +598,23 @@ def test_dict_nested_structured_delitem() -> None:
assert "name" not in c.user


@pytest.mark.parametrize("d, expected", [({}, 0), ({"a": 10, "b": 11}, 2)])
def test_dict_len(d: Any, expected: Any) -> None:
c = OmegaConf.create(d)
assert len(c) == expected
@pytest.mark.parametrize(
"d, expected",
[
pytest.param(DictConfig({}), 0, id="empty"),
pytest.param(DictConfig({"a": 10}), 1, id="full"),
pytest.param(DictConfig(None), 0, id="none"),
pytest.param(DictConfig("???"), 0, id="missing"),
pytest.param(
DictConfig("${foo}", parent=OmegaConf.create({"foo": {"a": 10}})),
0,
id="interpolation",
),
pytest.param(DictConfig("${foo}"), 0, id="broken_interpolation"),
],
)
def test_dict_len(d: DictConfig, expected: Any) -> None:
assert d.__len__() == expected


def test_dict_assign_illegal_value() -> None:
Expand Down
62 changes: 55 additions & 7 deletions tests/test_basic_ops_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,31 @@ def test_list_get_do_not_return_default(


@pytest.mark.parametrize(
"input_, expected, list_key",
"input_, expected, expected_no_resolve, list_key",
[
pytest.param([1, 2], [1, 2], None, id="simple"),
pytest.param(["${1}", 2], [2, 2], None, id="interpolation"),
pytest.param([1, 2], [1, 2], [1, 2], None, id="simple"),
pytest.param(["${1}", 2], [2, 2], ["${1}", 2], None, id="interpolation"),
pytest.param(
[ListConfig(None), ListConfig("${.2}"), [1, 2]],
[None, ListConfig([1, 2]), ListConfig([1, 2])],
[None, ListConfig("${.2}"), ListConfig([1, 2])],
None,
id="iter_over_lists",
),
pytest.param(
[DictConfig(None), DictConfig("${.2}"), {"a": 10}],
[None, DictConfig({"a": 10}), DictConfig({"a": 10})],
[None, DictConfig("${.2}"), DictConfig({"a": 10})],
None,
id="iter_over_dicts",
),
pytest.param(
["???", ListConfig("???"), DictConfig("???")],
pytest.raises(MissingMandatoryValue),
["???", ListConfig("???"), DictConfig("???")],
None,
id="iter_over_missing",
),
pytest.param(
{
"defaults": [
Expand All @@ -77,20 +98,45 @@ def test_list_get_do_not_return_default(
{"foo": "${defaults.0.optimizer}_${defaults.1.dataset}"},
]
},
[{"optimizer": "adam"}, {"dataset": "imagenet"}, {"foo": "adam_imagenet"}],
[
OmegaConf.create({"optimizer": "adam"}),
OmegaConf.create({"dataset": "imagenet"}),
OmegaConf.create({"foo": "adam_imagenet"}),
],
[
OmegaConf.create({"optimizer": "adam"}),
OmegaConf.create({"dataset": "imagenet"}),
OmegaConf.create(
{"foo": "${defaults.0.optimizer}_${defaults.1.dataset}"}
),
],
"defaults",
id="str_interpolation",
),
],
)
def test_iterate_list(input_: Any, expected: Any, list_key: str) -> None:
def test_iterate_list(
input_: Any, expected: Any, expected_no_resolve: Any, list_key: str
) -> None:
c = OmegaConf.create(input_)
if list_key is not None:
lst = c.get(list_key)
else:
lst = c
items = [x for x in lst]
assert items == expected

def test_iter(iterator: Any, expected_output: Any) -> None:
if isinstance(expected_output, list):
items = [x for x in iterator]
assert items == expected_output
for idx in range(len(items)):
assert type(items[idx]) is type(expected_output[idx]) # noqa
else:
with expected_output:
for _ in iterator:
pass

test_iter(iter(lst), expected)
test_iter(lst._iter_ex(resolve=False), expected_no_resolve)


def test_iterate_list_with_missing_interpolation() -> None:
Expand Down Expand Up @@ -242,6 +288,8 @@ def test_list_delitem() -> None:
(OmegaConf.create([1, 2]), 2),
(ListConfig(content=None), 0),
(ListConfig(content="???"), 0),
(ListConfig(content="${foo}"), 0),
(ListConfig(content="${foo}", parent=DictConfig({"foo": [1, 2]})), 0),
],
)
def test_list_len(lst: Any, expected: Any) -> None:
Expand Down