Skip to content

Commit

Permalink
Add: ability to bulk insert using a generator
Browse files Browse the repository at this point in the history
  • Loading branch information
yegle authored and ajdavis committed Oct 17, 2013
1 parent c3d13b8 commit ec6b7bc
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
11 changes: 5 additions & 6 deletions pymongo/collection.py
Expand Up @@ -334,15 +334,14 @@ def insert(self, doc_or_docs, manipulate=True,
docs = [docs]

if manipulate:
docs = [self.__database._fix_incoming(doc, self) for doc in docs]
docs = (self.__database._fix_incoming(doc, self) for doc in docs)

safe, options = self._get_write_mode(safe, **kwargs)
message._do_batched_insert(self.__full_name, docs,
check_keys, safe, options,
continue_on_error, self.uuid_subtype,
self.database.connection)
ids = message._do_batched_insert(self.__full_name, docs,
check_keys, safe, options,
continue_on_error, self.uuid_subtype,
self.database.connection)

ids = [doc.get("_id", None) for doc in docs]
if return_one:
return ids[0]
else:
Expand Down
5 changes: 4 additions & 1 deletion pymongo/message.py
Expand Up @@ -210,7 +210,9 @@ def _insert_message(insert_message, send_safe):
begin += bson._make_c_string(collection_name)
message_length = len(begin)
data = [begin]
ids = []
for doc in docs:
ids.append(doc.get("_id", None))
encoded = bson.BSON.encode(doc, check_keys, uuid_subtype)
encoded_length = len(encoded)
if encoded_length > client.max_bson_size:
Expand Down Expand Up @@ -238,7 +240,7 @@ def _insert_message(insert_message, send_safe):
last_error = exc
# With unacknowledged writes just return at the first error.
elif not safe:
return
return ids
# With acknowledged writes raise immediately.
else:
raise
Expand All @@ -250,5 +252,6 @@ def _insert_message(insert_message, send_safe):
# Re-raise any exception stored due to continue_on_error
if last_error is not None:
raise last_error
return ids
if _use_c:
_do_batched_insert = _cmessage._do_batched_insert
4 changes: 3 additions & 1 deletion test/test_collection.py
Expand Up @@ -761,7 +761,9 @@ def test_insert_multiple(self):
self.assertTrue(isinstance(id, list))
self.assertEqual(1, len(id))

self.assertRaises(InvalidOperation, db.test.insert, [])
# InvalidOperation: empty list
# OperationFailure: exhausted generator
self.assertRaises((InvalidOperation, OperationFailure), db.test.insert, [])

def test_insert_multiple_with_duplicate(self):
db = self.db
Expand Down

0 comments on commit ec6b7bc

Please sign in to comment.