Skip to content

Commit

Permalink
Fix cursor limit&skip when passed as slice
Browse files Browse the repository at this point in the history
  • Loading branch information
touilleMan committed May 13, 2017
1 parent 7e8b68a commit d7aeac1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
29 changes: 25 additions & 4 deletions mongomock/collection.py
Expand Up @@ -1707,8 +1707,9 @@ def __iter__(self):
return self

def clone(self):
return Cursor(self.collection,
self._spec, self._sort, self._projection, self._skip, self._limit)
cursor = Cursor(self.collection,
self._spec, self._sort, self._projection, self._skip, self._limit)
return cursor

def __next__(self):
if self._skip and not self._skipped:
Expand Down Expand Up @@ -1792,8 +1793,28 @@ def distinct(self, key):

def __getitem__(self, index):
if isinstance(index, slice):
# Limit the cursor to the given slice
self._dataset = (x for x in list(self._dataset)[index])
if index.step is not None:
raise IndexError("Cursor instances do not support slice steps")

skip = 0
if index.start is not None:
if index.start < 0:
raise IndexError("Cursor instances do not support"
"negative indices")
skip = index.start

if index.stop is not None:
limit = index.stop - skip
if limit < 0:
raise IndexError("stop index must be greater than start"
"index for slice %r" % index)
if limit == 0:
self.__empty = True
else:
limit = 0

self._skip = skip
self._limit = limit
return self
elif not isinstance(index, int):
raise TypeError("index '%s' cannot be applied to Cursor instances" % index)
Expand Down
16 changes: 16 additions & 0 deletions tests/test__collection_api.py
Expand Up @@ -159,6 +159,20 @@ def test__cursor_clone(self):
with self.assertRaises(StopIteration):
next(iterator2)

def test__cursor_clone_keep_limit_skip(self):
self.db.collection.insert([{"a": "b"}, {"b": "c"}, {"c": "d"}])
cursor1 = self.db.collection.find()[1:2]
cursor2 = cursor1.clone()
result1 = list(cursor1)
result2 = list(cursor2)
self.assertEqual(result1, result2)

cursor3 = self.db.collection.find(skip=1, limit=1)
cursor4 = cursor3.clone()
result3 = list(cursor3)
result4 = list(cursor4)
self.assertEqual(result3, result4)

def test_cursor_returns_document_copies(self):
obj = {'a': 1, 'b': 2}
self.db.collection.insert(obj)
Expand Down Expand Up @@ -654,6 +668,8 @@ def test__cursor_getitem_slice(self):
ret = cursor[1:4]
self.assertIs(ret, cursor)
count = cursor.count()
self.assertEqual(count, 3)
count = cursor.count(with_limit_and_skip=True)
self.assertEqual(count, 2)

def test__cursor_getitem_negative_index(self):
Expand Down

0 comments on commit d7aeac1

Please sign in to comment.