diff --git a/news/532.feature b/news/532.feature new file mode 100644 index 000000000..4919c73be --- /dev/null +++ b/news/532.feature @@ -0,0 +1 @@ +Optimized ListConfig iterator (550x to 150x slower than native list iterator) diff --git a/omegaconf/listconfig.py b/omegaconf/listconfig.py index 5c1aaee90..1f60b7f2f 100644 --- a/omegaconf/listconfig.py +++ b/omegaconf/listconfig.py @@ -497,6 +497,23 @@ def __hash__(self) -> int: def __iter__(self) -> Iterator[Any]: return self._iter_ex(resolve=True) + class ListIterator(Iterator[Any]): + def __init__(self, lst: Any, resolve: bool) -> None: + self.iter = iter(lst.__dict__["_content"]) + self.resolve = resolve + self.index = 0 + + def __next__(self) -> Any: + v = next(self.iter) + + if self.resolve: + v = v._dereference_node() + + if v._is_missing(): + raise MissingMandatoryValue(f"Missing value at index {self.index}") + self.index = self.index + 1 + return _get_value(v) + def _iter_ex(self, resolve: bool) -> Iterator[Any]: try: if self._is_none(): @@ -504,25 +521,7 @@ def _iter_ex(self, resolve: bool) -> Iterator[Any]: if self._is_missing(): raise MissingMandatoryValue("Cannot iterate a missing ListConfig") - class MyItems(Iterator[Any]): - def __init__(self, lst: ListConfig) -> None: - self.lst = lst - self.index = 0 - - def __next__(self) -> Any: - if self.index == len(self.lst): - raise StopIteration() - if resolve: - v = self.lst[self.index] - else: - v = self.lst.__dict__["_content"][self.index] - if v is not None: - v = _get_value(v) - self.index = self.index + 1 - return v - - assert isinstance(self.__dict__["_content"], list) - return MyItems(self) + return ListConfig.ListIterator(self, resolve) except (ReadonlyConfigError, TypeError, MissingMandatoryValue) as e: self._format_and_raise(key=None, value=None, cause=e) assert False @@ -571,15 +570,16 @@ def _set_value_impl( if flags is None: flags = {} + vk = get_value_kind(value) if OmegaConf.is_none(value): if not self._is_optional(): raise ValidationError( "Non optional ListConfig cannot be constructed from None" ) self.__dict__["_content"] = None - elif get_value_kind(value) == ValueKind.MANDATORY_MISSING: + elif vk is ValueKind.MANDATORY_MISSING: self.__dict__["_content"] = "???" - elif get_value_kind(value) in ( + elif vk in ( ValueKind.INTERPOLATION, ValueKind.STR_INTERPOLATION, ):