Skip to content

Commit

Permalink
Fix update() ordering to be more consistent with add() ordering (#159)
Browse files Browse the repository at this point in the history
* Fix update() ordering to be more consistent with add() ordering
* Add tests for update order consistency
* Fix update order consistency test to use `modulo` key and to compare internally against using the `add()` method
* Fix update order consistency test to use `negate` key and to compare internally against using the `add()` method
* Improve performance by using `reduce`/`iadd` to construct `values` list
  • Loading branch information
bamartin125 committed Nov 4, 2020
1 parent 13d30bc commit 7dc426c
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 3 deletions.
6 changes: 4 additions & 2 deletions sortedcontainers/sortedlist.py
Expand Up @@ -339,7 +339,8 @@ def update(self, iterable):

if _maxes:
if len(values) * 4 >= self._len:
values.extend(chain.from_iterable(_lists))
_lists.append(values)
values = reduce(iadd, _lists, [])
values.sort()
self._clear()
else:
Expand Down Expand Up @@ -1878,7 +1879,8 @@ def update(self, iterable):

if _maxes:
if len(values) * 4 >= self._len:
values.extend(chain.from_iterable(_lists))
_lists.append(values)
values = reduce(iadd, _lists, [])
values.sort(key=self._key)
self._clear()
else:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_coverage_sortedkeylist_modulo.py
Expand Up @@ -5,6 +5,7 @@
import random
from .context import sortedcontainers
from sortedcontainers import SortedList, SortedKeyList
from itertools import chain, repeat
import pytest

if hexversion < 0x03000000:
Expand Down Expand Up @@ -107,6 +108,45 @@ def test_update():
assert len(slt) == 11000
slt._check()

def test_update_order_consistency():
def modulo_el0(tup):
return tup[0] % 10

slt1 = SortedKeyList(key=modulo_el0)
slt2 = SortedKeyList(key=modulo_el0)

def add_from_iterable(slt, it):
for item in it:
slt.add(item)

def add_from_all_iterables(slt, its):
for it in its:
add_from_iterable(slt, it)

def update_from_all_iterables(slt, its):
for it in its:
slt.update(it)

# the following iterators are set up (from large to small) such that they
# attempt to force the two kinds of internal update logic (extending upon
# the incoming iterable or appending to the existing elements by use of
# `add()`)
it1 = list(zip(repeat(0), range(5)))
it2 = list(zip(repeat(0), range(4)))
it3 = list(zip(repeat(0), range(3)))
it4 = list(zip(repeat(0), range(2)))
it5 = list(zip(repeat(0), range(1)))

it12345 = [it1, it2, it3, it4, it5]

add_from_all_iterables(slt1, it12345)
update_from_all_iterables(slt2, it12345)

slt1._check()
slt2._check()

assert all(tup[0] == tup[1] for tup in zip(slt1, slt2))

def test_contains():
slt = SortedKeyList(key=modulo)
slt._reset(7)
Expand Down
41 changes: 40 additions & 1 deletion tests/test_coverage_sortedkeylist_negate.py
Expand Up @@ -5,7 +5,7 @@
import random
from .context import sortedcontainers
from sortedcontainers import SortedKeyList, SortedListWithKey
from itertools import chain
from itertools import chain, repeat
import pytest

if hexversion < 0x03000000:
Expand Down Expand Up @@ -84,6 +84,45 @@ def test_update():
values = sorted((val for val in chain(range(100), range(1000), range(10000))), key=negate)
assert all(tup[0] == tup[1] for tup in zip(slt, values))

def test_update_order_consistency():
def negate_el0(tup):
return -tup[0]

slt1 = SortedKeyList(key=negate_el0)
slt2 = SortedKeyList(key=negate_el0)

def add_from_iterable(slt, it):
for item in it:
slt.add(item)

def add_from_all_iterables(slt, its):
for it in its:
add_from_iterable(slt, it)

def update_from_all_iterables(slt, its):
for it in its:
slt.update(it)

# the following iterators are set up (from large to small) such that they
# attempt to force the two kinds of internal update logic (extending upon
# the incoming iterable or appending to the existing elements by use of
# `add()`)
it1 = list(zip(repeat(0), range(5)))
it2 = list(zip(repeat(0), range(4)))
it3 = list(zip(repeat(0), range(3)))
it4 = list(zip(repeat(0), range(2)))
it5 = list(zip(repeat(0), range(1)))

it12345 = [it1, it2, it3, it4, it5]

add_from_all_iterables(slt1, it12345)
update_from_all_iterables(slt2, it12345)

slt1._check()
slt2._check()

assert all(tup[0] == tup[1] for tup in zip(slt1, slt2))

def test_contains():
slt = SortedKeyList(key=negate)
assert 0 not in slt
Expand Down

0 comments on commit 7dc426c

Please sign in to comment.