Skip to content

Commit

Permalink
add batch_size() to Cursor PYTHON-161
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike Dirolf committed Sep 8, 2010
1 parent 64f9641 commit 08ffa4f
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 18 deletions.
43 changes: 33 additions & 10 deletions pymongo/cursor.py
Expand Up @@ -79,6 +79,7 @@ def __init__(self, collection, spec=None, fields=None, skip=0, limit=0,
self.__fields = fields
self.__skip = skip
self.__limit = limit
self.__batch_size = 0

# This is ugly. People want to be able to do cursor[5:5] and
# get an empty result set (old behavior was an
Expand Down Expand Up @@ -153,6 +154,7 @@ def clone(self):
copy.__ordering = self.__ordering
copy.__explain = self.__explain
copy.__hint = self.__hint
copy.__batch_size = self.__batch_size
return copy

def __die(self):
Expand Down Expand Up @@ -223,6 +225,30 @@ def limit(self, limit):
self.__limit = limit
return self

def batch_size(self, batch_size):
"""Set the size for batches of results returned by this cursor.
Raises :class:`TypeError` if `batch_size` is not an instance
of :class:`int`. Raises :class:`ValueError` if `batch_size` is
less than ``0``. Raises
:class:`~pymongo.errors.InvalidOperation` if this
:class:`Cursor` has already been used. The last `batch_size`
applied to this cursor takes precedence.
:Parameters:
- `batch_size`: The size of each batch of results requested.
.. versionadded:: 1.8.1+
"""
if not isinstance(batch_size, int):
raise TypeError("batch_size must be an int")
if batch_size < 0:
raise ValueError("batch_size must be >= 0")
self.__check_okay_to_chain()

self.__batch_size = batch_size == 1 and 2 or batch_size
return self

def skip(self, skip):
"""Skips the first `skip` results of this cursor.
Expand Down Expand Up @@ -530,24 +556,21 @@ def _refresh(self):
if len(self.__data) or self.__killed:
return len(self.__data)

if self.__id is None:
# Query
if self.__id is None: # Query
self.__send_message(
message.query(self.__query_options(),
self.__collection.full_name,
self.__skip, self.__limit,
self.__query_spec(), self.__fields))
if not self.__id:
self.__killed = True
elif self.__id:
# Get More
limit = 0
elif self.__id: # Get More
if self.__limit:
if self.__limit > self.__retrieved:
limit = self.__limit - self.__retrieved
else:
self.__killed = True
return 0
limit = self.__limit - self.__retrieved
if self.__batch_size:
limit = min(limit, self.__batch_size)
else:
limit = self.__batch_size

self.__send_message(
message.get_more(self.__collection.full_name,
Expand Down
59 changes: 51 additions & 8 deletions test/test_cursor.py
Expand Up @@ -52,8 +52,7 @@ def test_explain(self):
def test_hint(self):
db = self.db
self.assertRaises(TypeError, db.test.find().hint, 5.5)
db.test.remove({})
db.test.drop_indexes()
db.test.drop()

for i in range(100):
db.test.insert({"num": i, "foo": i})
Expand Down Expand Up @@ -94,7 +93,7 @@ def test_limit(self):
self.assertRaises(TypeError, db.test.find().limit, "hello")
self.assertRaises(TypeError, db.test.find().limit, 5.5)

db.test.remove({})
db.test.drop()
for i in range(100):
db.test.save({"x": i})

Expand Down Expand Up @@ -134,6 +133,50 @@ def test_limit(self):
break
self.assertRaises(InvalidOperation, a.limit, 5)


def test_batch_size(self):
db = self.db
db.test.drop()
for x in range(200):
db.test.save({"x": x})

self.assertRaises(TypeError, db.test.find().batch_size, None)
self.assertRaises(TypeError, db.test.find().batch_size, "hello")
self.assertRaises(TypeError, db.test.find().batch_size, 5.5)
self.assertRaises(ValueError, db.test.find().batch_size, -1)
a = db.test.find()
for _ in a:
break
self.assertRaises(InvalidOperation, a.batch_size, 5)

def cursor_count(cursor, expected_count):
count = 0
for _ in cursor:
count += 1
self.assertEqual(expected_count, count)

cursor_count(db.test.find().batch_size(0), 200)
cursor_count(db.test.find().batch_size(1), 200)
cursor_count(db.test.find().batch_size(2), 200)
cursor_count(db.test.find().batch_size(5), 200)
cursor_count(db.test.find().batch_size(100), 200)
cursor_count(db.test.find().batch_size(500), 200)

cursor_count(db.test.find().batch_size(0).limit(1), 1)
cursor_count(db.test.find().batch_size(1).limit(1), 1)
cursor_count(db.test.find().batch_size(2).limit(1), 1)
cursor_count(db.test.find().batch_size(5).limit(1), 1)
cursor_count(db.test.find().batch_size(100).limit(1), 1)
cursor_count(db.test.find().batch_size(500).limit(1), 1)

cursor_count(db.test.find().batch_size(0).limit(10), 10)
cursor_count(db.test.find().batch_size(1).limit(10), 10)
cursor_count(db.test.find().batch_size(2).limit(10), 10)
cursor_count(db.test.find().batch_size(5).limit(10), 10)
cursor_count(db.test.find().batch_size(100).limit(10), 10)
cursor_count(db.test.find().batch_size(500).limit(10), 10)


def test_skip(self):
db = self.db

Expand Down Expand Up @@ -189,7 +232,7 @@ def test_sort(self):
[("hello", DESCENDING)], DESCENDING)
self.assertRaises(TypeError, db.test.find().sort, "hello", "world")

db.test.remove({})
db.test.drop()

unsort = range(10)
random.shuffle(unsort)
Expand Down Expand Up @@ -218,7 +261,7 @@ def test_sort(self):
shuffled = list(expected)
random.shuffle(shuffled)

db.test.remove({})
db.test.drop()
for (a, b) in shuffled:
db.test.save({"a": a, "b": b})

Expand All @@ -235,7 +278,7 @@ def test_sort(self):

def test_count(self):
db = self.db
db.test.remove({})
db.test.drop()

self.assertEqual(0, db.test.find().count())

Expand All @@ -260,7 +303,7 @@ def test_count(self):

def test_where(self):
db = self.db
db.test.remove({})
db.test.drop()

a = db.test.find()
self.assertRaises(TypeError, a.where, 5)
Expand Down Expand Up @@ -405,7 +448,7 @@ def test_clone(self):
self.assertNotEqual(cursor, cursor.clone())

def test_count_with_fields(self):
self.db.test.remove({})
self.db.test.drop()
self.db.test.save({"x": 1})

if not version.at_least(self.db.connection, (1, 1, 3, -1)):
Expand Down

0 comments on commit 08ffa4f

Please sign in to comment.