Skip to content

Commit

Permalink
fix: proper __setitem__ and insert for RepeatedComposite (#178)
Browse files Browse the repository at this point in the history
Fixes #177; please refer to the issue for context.

The proposed solution is not elegant at all, but might be the best in the short term until the underlying protocol buffer object supports the __setitem__ method.

Nota bene: the underlying RepeatedCompositeFieldContainer from google.protobuf.internal.containers already supports the insert method, and we only need to perform type conversion for this particular case.
  • Loading branch information
0x2b3bfa0 committed Feb 10, 2021
1 parent 30265d6 commit 1157a76
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 17 deletions.
68 changes: 53 additions & 15 deletions proto/marshal/collections/repeated.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,22 +119,60 @@ def __getitem__(self, key):
return self._marshal.to_python(self._pb_type, self.pb[key])

def __setitem__(self, key, value):
pb_value = self._marshal.to_proto(self._pb_type, value, strict=True)

# Protocol buffers does not define a useful __setitem__, so we
# have to pop everything after this point off the list and reload it.
after = [pb_value]
while self.pb[key:]:
after.append(self.pb.pop(key))
self.pb.extend(after)
# The underlying protocol buffer does not define __setitem__, so we
# have to implement all the operations on our own.

# If ``key`` is an integer, as in list[index] = value:
if isinstance(key, int):
if -len(self) <= key < len(self):
self.pop(key) # Delete the old item.
self.insert(key, value) # Insert the new item in its place.
else:
raise IndexError("list assignment index out of range")

# If ``key`` is a slice object, as in list[start:stop:step] = [values]:
elif isinstance(key, slice):
start, stop, step = key.indices(len(self))

if not isinstance(value, collections.abc.Iterable):
raise TypeError("can only assign an iterable")

if step == 1: # Is not an extended slice.
# Assign all the new values to the sliced part, replacing the
# old values, if any, and unconditionally inserting those
# values whose indices already exceed the slice length.
for index, item in enumerate(value):
if start + index < stop:
self.pop(start + index)
self.insert(start + index, item)

# If there are less values than the length of the slice, remove
# the remaining elements so that the slice adapts to the
# newly provided values.
for _ in range(stop - start - len(value)):
self.pop(start + len(value))

else: # Is an extended slice.
indices = range(start, stop, step)

if len(value) != len(indices): # XXX: Use PEP 572 on 3.8+
raise ValueError(
f"attempt to assign sequence of size "
f"{len(value)} to extended slice of size "
f"{len(indices)}"
)

# Assign each value to its index, calling this function again
# with individual integer indexes that get processed above.
for index, item in zip(indices, value):
self[index] = item

else:
raise TypeError(
f"list indices must be integers or slices, not {type(key).__name__}"
)

def insert(self, index: int, value):
"""Insert ``value`` in the sequence before ``index``."""
pb_value = self._marshal.to_proto(self._pb_type, value, strict=True)

# Protocol buffers does not define a useful insert, so we have
# to pop everything after this point off the list and reload it.
after = [pb_value]
while self.pb[index:]:
after.append(self.pb.pop(index))
self.pb.extend(after)
self.pb.insert(index, pb_value)
96 changes: 94 additions & 2 deletions tests/test_fields_repeated_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class Baz(proto.Message):
assert baz.foos[1].bar == 48


def test_repeated_composite_set():
def test_repeated_composite_set_index():
class Foo(proto.Message):
bar = proto.Field(proto.INT32, number=1)

Expand All @@ -176,9 +176,101 @@ class Baz(proto.Message):
baz.foos[1] = Foo(bar=55)
assert baz.foos[0].bar == 96
assert baz.foos[1].bar == 55
assert len(baz.foos) == 2


def test_repeated_composite_set_index_error():
class Foo(proto.Message):
bar = proto.Field(proto.INT32, number=1)

class Baz(proto.Message):
foos = proto.RepeatedField(proto.MESSAGE, message=Foo, number=1)

baz = Baz(foos=[])
with pytest.raises(IndexError):
baz.foos[0] = Foo(bar=55)


def test_repeated_composite_set_slice_less():
class Foo(proto.Message):
bar = proto.Field(proto.INT32, number=1)

class Baz(proto.Message):
foos = proto.RepeatedField(proto.MESSAGE, message=Foo, number=1)

baz = Baz(foos=[{"bar": 96}, {"bar": 48}, {"bar": 24}])
baz.foos[:2] = [{"bar": 12}]
assert baz.foos[0].bar == 12
assert baz.foos[1].bar == 24
assert len(baz.foos) == 2


def test_repeated_composite_set_slice_more():
class Foo(proto.Message):
bar = proto.Field(proto.INT32, number=1)

class Baz(proto.Message):
foos = proto.RepeatedField(proto.MESSAGE, message=Foo, number=1)

baz = Baz(foos=[{"bar": 12}])
baz.foos[:2] = [{"bar": 96}, {"bar": 48}, {"bar": 24}]
assert baz.foos[0].bar == 96
assert baz.foos[1].bar == 48
assert baz.foos[2].bar == 24
assert len(baz.foos) == 3


def test_repeated_composite_set_slice_not_iterable():
class Foo(proto.Message):
bar = proto.Field(proto.INT32, number=1)

class Baz(proto.Message):
foos = proto.RepeatedField(proto.MESSAGE, message=Foo, number=1)

baz = Baz(foos=[])
with pytest.raises(TypeError):
baz.foos[:1] = None


def test_repeated_composite_set_extended_slice():
class Foo(proto.Message):
bar = proto.Field(proto.INT32, number=1)

class Baz(proto.Message):
foos = proto.RepeatedField(proto.MESSAGE, message=Foo, number=1)

baz = Baz(foos=[{"bar": 96}, {"bar": 48}])
baz.foos[::-1] = [{"bar": 96}, {"bar": 48}]
assert baz.foos[0].bar == 48
assert baz.foos[1].bar == 96
assert len(baz.foos) == 2


def test_repeated_composite_set_extended_slice_wrong_length():
class Foo(proto.Message):
bar = proto.Field(proto.INT32, number=1)

class Baz(proto.Message):
foos = proto.RepeatedField(proto.MESSAGE, message=Foo, number=1)

baz = Baz(foos=[{"bar": 96}])
with pytest.raises(ValueError):
baz.foos[::-1] = []


def test_repeated_composite_set_wrong_key_type():
class Foo(proto.Message):
bar = proto.Field(proto.INT32, number=1)

class Baz(proto.Message):
foos = proto.RepeatedField(proto.MESSAGE, message=Foo, number=1)

baz = Baz(foos=[])
with pytest.raises(TypeError):
baz.foos[None] = Foo(bar=55)


def test_repeated_composite_set_wrong_type():
def test_repeated_composite_set_wrong_value_type():
class Foo(proto.Message):
bar = proto.Field(proto.INT32, number=1)

Expand Down

0 comments on commit 1157a76

Please sign in to comment.