Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix update() ordering to be more consistent with add() ordering #159

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 4 additions & 2 deletions sortedcontainers/sortedlist.py
Expand Up @@ -339,7 +339,8 @@ def update(self, iterable):

if _maxes:
if len(values) * 4 >= self._len:
values.extend(chain.from_iterable(_lists))
_lists.append(values)
values = reduce(iadd, _lists, [])
values.sort()
self._clear()
else:
Expand Down Expand Up @@ -1878,7 +1879,8 @@ def update(self, iterable):

if _maxes:
if len(values) * 4 >= self._len:
values.extend(chain.from_iterable(_lists))
_lists.append(values)
values = reduce(iadd, _lists, [])
values.sort(key=self._key)
self._clear()
else:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_coverage_sortedkeylist_modulo.py
Expand Up @@ -5,6 +5,7 @@
import random
from .context import sortedcontainers
from sortedcontainers import SortedList, SortedKeyList
from itertools import chain, repeat
import pytest

if hexversion < 0x03000000:
Expand Down Expand Up @@ -107,6 +108,45 @@ def test_update():
assert len(slt) == 11000
slt._check()

def test_update_order_consistency():
def modulo_el0(tup):
return tup[0] % 10

slt1 = SortedKeyList(key=modulo_el0)
slt2 = SortedKeyList(key=modulo_el0)

def add_from_iterable(slt, it):
for item in it:
slt.add(item)

def add_from_all_iterables(slt, its):
for it in its:
add_from_iterable(slt, it)

def update_from_all_iterables(slt, its):
for it in its:
slt.update(it)

# the following iterators are set up (from large to small) such that they
# attempt to force the two kinds of internal update logic (extending upon
# the incoming iterable or appending to the existing elements by use of
# `add()`)
it1 = list(zip(repeat(0), range(5)))
it2 = list(zip(repeat(0), range(4)))
it3 = list(zip(repeat(0), range(3)))
it4 = list(zip(repeat(0), range(2)))
it5 = list(zip(repeat(0), range(1)))

it12345 = [it1, it2, it3, it4, it5]

add_from_all_iterables(slt1, it12345)
update_from_all_iterables(slt2, it12345)

slt1._check()
slt2._check()

assert all(tup[0] == tup[1] for tup in zip(slt1, slt2))

def test_contains():
slt = SortedKeyList(key=modulo)
slt._reset(7)
Expand Down
41 changes: 40 additions & 1 deletion tests/test_coverage_sortedkeylist_negate.py
Expand Up @@ -5,7 +5,7 @@
import random
from .context import sortedcontainers
from sortedcontainers import SortedKeyList, SortedListWithKey
from itertools import chain
from itertools import chain, repeat
import pytest

if hexversion < 0x03000000:
Expand Down Expand Up @@ -84,6 +84,45 @@ def test_update():
values = sorted((val for val in chain(range(100), range(1000), range(10000))), key=negate)
assert all(tup[0] == tup[1] for tup in zip(slt, values))

def test_update_order_consistency():
def negate_el0(tup):
return -tup[0]

slt1 = SortedKeyList(key=negate_el0)
slt2 = SortedKeyList(key=negate_el0)

def add_from_iterable(slt, it):
for item in it:
slt.add(item)

def add_from_all_iterables(slt, its):
for it in its:
add_from_iterable(slt, it)

def update_from_all_iterables(slt, its):
for it in its:
slt.update(it)

# the following iterators are set up (from large to small) such that they
# attempt to force the two kinds of internal update logic (extending upon
# the incoming iterable or appending to the existing elements by use of
# `add()`)
it1 = list(zip(repeat(0), range(5)))
it2 = list(zip(repeat(0), range(4)))
it3 = list(zip(repeat(0), range(3)))
it4 = list(zip(repeat(0), range(2)))
it5 = list(zip(repeat(0), range(1)))

it12345 = [it1, it2, it3, it4, it5]

add_from_all_iterables(slt1, it12345)
update_from_all_iterables(slt2, it12345)

slt1._check()
slt2._check()

assert all(tup[0] == tup[1] for tup in zip(slt1, slt2))

def test_contains():
slt = SortedKeyList(key=negate)
assert 0 not in slt
Expand Down