Skip to content

Commit

Permalink
API change: both insert and save return inserted _id(s)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike Dirolf committed Jul 28, 2009
1 parent 819778a commit 0ac83e4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 15 deletions.
2 changes: 1 addition & 1 deletion gridfs/grid_file.py
Expand Up @@ -107,7 +107,7 @@ class directly - instead see the `gridfs.GridFS.open` method.
file_spec["length"] = 0
file_spec["uploadDate"] = datetime.datetime.utcnow()
file_spec.setdefault("chunkSize", 256000)
self.__id = self.__collection.files.insert(file_spec)["_id"]
self.__id = self.__collection.files.insert(file_spec)

# we use repr(self.__id) here because we need it to be string and
# filename gets tricky with renaming. this is a hack.
Expand Down
24 changes: 12 additions & 12 deletions pymongo/collection.py
Expand Up @@ -129,24 +129,23 @@ def database(self):
return self.__database

def save(self, to_save, manipulate=True, safe=False):
"""Save a SON object in this collection.
"""Save a document in this collection.
Raises TypeError if to_save is not an instance of dict. If `safe`
is True then the save will be checked for errors, raising
OperationFailure if one occurred. Checking for safety requires an extra
round-trip to the database.
round-trip to the database. Returns the _id of the saved document.
:Parameters:
- `to_save`: the SON object to be saved
- `manipulate` (optional): manipulate the son object before saving it
- `manipulate` (optional): manipulate the SON object before saving it
- `safe` (optional): check that the save succeeded?
"""
if not isinstance(to_save, types.DictType):
raise TypeError("cannot save object of type %s" % type(to_save))

if "_id" not in to_save:
result = self.insert(to_save, manipulate, safe)
return result.get("_id", None)
return self.insert(to_save, manipulate, safe)
else:
self.update({"_id": to_save["_id"]}, to_save, True,
manipulate, safe)
Expand All @@ -157,15 +156,15 @@ def insert(self, doc_or_docs,
"""Insert a document(s) into this collection.
If manipulate is set the document(s) are manipulated using any
SONManipulators that have been added to this database. Returns the
inserted object or a list of inserted objects. If `safe` is True then
the insert will be checked for errors, raising OperationFailure if one
occurred. Checking for safety requires an extra round-trip to the
database.
SONManipulators that have been added to this database. Returns the _id
of the inserted document or a list of _ids of the inserted documents.
If `safe` is True then the insert will be checked for errors, raising
OperationFailure if one occurred. Checking for safety requires an extra
round-trip to the database.
:Parameters:
- `doc_or_docs`: a SON object or list of SON objects to be inserted
- `manipulate` (optional): monipulate the objects before inserting?
- `manipulate` (optional): monipulate the documents before inserting?
- `safe` (optional): check that the insert succeeded?
- `check_keys` (optional): check if keys start with '$' or
contain '.', raising `pymongo.errors.InvalidName` in either case
Expand All @@ -188,7 +187,8 @@ def insert(self, doc_or_docs,
if error:
raise OperationFailure("insert failed: " + error["err"])

return len(docs) == 1 and docs[0] or docs
ids = [doc.get("_id", None) for doc in docs]
return len(ids) == 1 and ids[0] or ids

def update(self, spec, document,
upsert=False, manipulate=False, safe=False):
Expand Down
16 changes: 14 additions & 2 deletions test/test_collection.py
Expand Up @@ -285,9 +285,11 @@ def test_insert_find_one(self):
db.test.remove({})
self.assertEqual(db.test.find().count(), 0)
doc = {"hello": u"world"}
db.test.insert(doc)
id = db.test.insert(doc)
self.assertEqual(db.test.find().count(), 1)
self.assertEqual(doc, db.test.find_one())
self.assertEqual(doc["_id"], id)
self.assert_(isinstance(id, ObjectId))

def remove_insert_find_one(dict):
db.test.remove({})
Expand Down Expand Up @@ -396,11 +398,21 @@ def test_insert_multiple(self):
doc1 = {"hello": u"world"}
doc2 = {"hello": u"mike"}
self.assertEqual(db.test.find().count(), 0)
db.test.insert([doc1, doc2])
ids = db.test.insert([doc1, doc2])
self.assertEqual(db.test.find().count(), 2)
self.assertEqual(doc1, db.test.find_one({"hello": u"world"}))
self.assertEqual(doc2, db.test.find_one({"hello": u"mike"}))

self.assertEqual(2, len(ids))
self.assertEqual(doc1["_id"], ids[0])
self.assertEqual(doc2["_id"], ids[1])

def test_save(self):
self.db.drop_collection("test")
id = self.db.test.save({"hello": "world"})
self.assertEqual(self.db.test.find_one()["_id"], id)
self.assert_(isinstance(id, ObjectId))

def test_unique_index(self):
db = self.db

Expand Down

0 comments on commit 0ac83e4

Please sign in to comment.