Skip to content

Commit

Permalink
support slice assignment in ListConfig (#736)
Browse files Browse the repository at this point in the history
  • Loading branch information
pixelb committed Feb 7, 2022
1 parent c9f4835 commit f3e5b2b
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 2 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 2.2.0
### Features

- ListConfig now implements slice assignment ([#736](https://github.com/omry/omegaconf/issues/736))

## 2.1.1 (2021-08-17)
### Features

Expand Down
35 changes: 34 additions & 1 deletion omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,40 @@ def _set_at_index(self, index: Union[int, slice], value: Any) -> None:

def __setitem__(self, index: Union[int, slice], value: Any) -> None:
try:
self._set_at_index(index, value)
if isinstance(index, slice):
_ = iter(value) # check iterable
self_indices = index.indices(len(self))
indexes = range(*self_indices)

# Ensure lengths match for extended slice assignment
if index.step not in (None, 1):
if len(indexes) != len(value):
raise ValueError(
f"attempt to assign sequence of size {len(value)}"
f" to extended slice of size {len(indexes)}"
)

# Initialize insertion offsets for empty slices
if len(indexes) == 0:
curr_index = self_indices[0] - 1
val_i = -1

# Delete and optionally replace non empty slices
only_removed = 0
for val_i, i in enumerate(indexes):
curr_index = i - only_removed
del self[curr_index]
if val_i < len(value):
self.insert(curr_index, value[val_i])
else:
only_removed += 1

# Insert any remaining input items
for val_i in range(val_i + 1, len(value)):
curr_index += 1
self.insert(curr_index, value[val_i])
else:
self._set_at_index(index, value)
except Exception as e:
self._format_and_raise(key=index, value=value, cause=e)

Expand Down
140 changes: 139 additions & 1 deletion tests/test_basic_ops_list.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
import re
from textwrap import dedent
from typing import Any, List, Optional, Union
from typing import Any, Callable, List, MutableSequence, Optional, Union

from _pytest.python_api import RaisesContext
from pytest import mark, param, raises

from omegaconf import MISSING, AnyNode, DictConfig, ListConfig, OmegaConf, flag_override
Expand Down Expand Up @@ -845,6 +846,143 @@ def test_getitem_slice(sli: slice) -> None:
assert olst.__getitem__(sli) == expected


@mark.parametrize(
"constructor",
[OmegaConf.create, list, lambda lst: OmegaConf.create({"foo": lst}).foo],
)
@mark.parametrize(
"lst, idx, value, expected",
[
param(
["a", "b", "c", "d"],
slice(1, 3),
["x", "y"],
["a", "x", "y", "d"],
id="same-number-of-elements",
),
param(
["a", "x", "y", "d"],
slice(1, 3),
["x", "y", "z"],
["a", "x", "y", "z", "d"],
id="extra-elements",
),
param(
["a", "x", "y", "z", "d"],
slice(1, 1),
["b"],
["a", "b", "x", "y", "z", "d"],
id="insert only",
),
param(
["a", "b", "x", "y", "z", "d"],
slice(1, 1),
[],
["a", "b", "x", "y", "z", "d"],
id="nop",
),
param(
["a", "b", "x", "y", "z", "d"],
slice(1, 3),
[],
["a", "y", "z", "d"],
id="less-elements",
),
param(
["a", "y", "z", "d"],
slice(1, 2, 1),
["b"],
["a", "b", "z", "d"],
id="extended-slice",
),
param(
["a", "b", "c", "d"],
slice(1, 3, 1),
["x", "y"],
["a", "x", "y", "d"],
id="extended-slice2",
),
param(
["a", "b", "z", "d"],
slice(0, 3, 2),
["a", "c"],
["a", "b", "c", "d"],
id="extended-slice-disjoint",
),
param(
["a", "b", "c", "d"],
slice(1, 3),
1,
raises(TypeError),
id="non-iterable-input",
),
param(
["a", "b", "c", "d"],
slice(1, 3, 1),
["x", "y", "z"],
["a", "x", "y", "z", "d"],
id="extended-slice-length-mismatch",
),
param(
["a", "b", "c", "d", "e", "f"],
slice(1, 5, 2),
["x", "y", "z"],
raises(ValueError),
id="extended-slice-length-mismatch2",
),
param(
["a", "b", "c", "d", "e", "f"],
slice(-1, -3, -1),
["F", "E"],
["a", "b", "c", "d", "E", "F"],
id="extended-slice-reverse",
),
param(
["a", "b", "c", "d", "e", "g"],
slice(-1, -3, None),
["f"],
["a", "b", "c", "d", "e", "f", "g"],
id="slice-reverse-insert",
),
param(
["a", "b", "c", "r", "r", "e"],
slice(-3, -1, None),
["d"],
["a", "b", "c", "d", "e"],
id="slice-reverse-replace",
),
param(
["c", "d"],
slice(-10, -10, None),
["a", "b"],
["a", "b", "c", "d"],
id="slice-reverse-insert-underflow",
),
param(
["a", "b"],
slice(10, 10, None),
["c", "d"],
["a", "b", "c", "d"],
id="slice-reverse-insert-overflow",
),
],
)
def test_setitem_slice(
lst: List[Any],
idx: slice,
value: Union[List[Any], Any],
expected: Union[List[Any], RaisesContext[Any]],
constructor: Callable[[List[Any]], MutableSequence[Any]],
) -> None:
cfg = constructor(lst)
if isinstance(expected, list):
cfg[idx] = value
assert cfg == expected
else:
with expected:
cfg[idx] = value


@mark.parametrize(
"lst,idx,expected",
[
Expand Down

0 comments on commit f3e5b2b

Please sign in to comment.