Skip to content
Browse files

Functioning nesting transactions

  • Loading branch information...
1 parent b7978b5 commit 55cc288c8e7b5de5f12c873c23b51a472c1853ed @jeffjenkins committed Jul 29, 2012
Showing with 125 additions and 27 deletions.
  1. +14 −6 mongoalchemy/ops.py
  2. +55 −20 mongoalchemy/session.py
  3. +56 −1 test/test_session.py
View
20 mongoalchemy/ops.py
@@ -19,16 +19,20 @@ def ensure_indexes(self):
index.ensure(c)
class ClearCollectionOp(Operation):
- def __init__(self, session, kind):
+ def __init__(self, trans_id, session, kind):
+ self.trans_id = trans_id
self.session = session
self.type = kind
def execute(self):
+ print 'CLEAROP', self.collection
self.collection.remove()
+ print 'CLEAROP', self.collection.count()
class UpdateDocumentOp(Operation):
- def __init__(self, session, document, safe, id_expression=None, upsert=False, update_ops={}, **kwargs):
+ def __init__(self, trans_id, session, document, safe, id_expression=None, upsert=False, update_ops={}, **kwargs):
from mongoalchemy.query import Query
self.session = session
+ self.trans_id = trans_id
self.type = type(document)
self.safe = safe
self.upsert = upsert
@@ -54,8 +58,9 @@ def execute(self):
return self.collection.update(self.db_key, self.dirty_ops, upsert=self.upsert, safe=self.safe)
class UpdateOp(Operation):
- def __init__(self, session, kind, safe, update_obj):
+ def __init__(self, trans_id, session, kind, safe, update_obj):
self.session = session
+ self.trans_id = trans_id
self.type = kind
self.safe = safe
self.query = update_obj.query.query
@@ -69,8 +74,9 @@ def execute(self):
class SaveOp(Operation):
- def __init__(self, session, document, safe):
+ def __init__(self, trans_id, session, document, safe):
self.session = session
+ self.trans_id = trans_id
self.data = document.wrap()
self.type = type(document)
self.safe = safe
@@ -84,8 +90,9 @@ def execute(self):
return self.collection.save(self.data, safe=self.safe)
class RemoveOp(Operation):
- def __init__(self, session, kind, safe, query):
+ def __init__(self, trans_id, session, kind, safe, query):
self.session = session
+ self.trans_id = trans_id
self.query = query.query
self.safe = safe
self.type = kind
@@ -96,7 +103,8 @@ def execute(self):
class RemoveDocumentOp(Operation):
- def __init__(self, session, obj, safe):
+ def __init__(self, trans_id, session, obj, safe):
+ self.trans_id = trans_id
self.session = session
self.type = type(obj)
self.safe = safe
View
75 mongoalchemy/session.py
@@ -41,7 +41,7 @@
anything intelligent for ordering.
'''
-
+from uuid import uuid4
from pymongo.connection import Connection
from bson import DBRef, ObjectId
from mongoalchemy.query import Query, QueryResult, RemoveQuery
@@ -72,7 +72,6 @@ def __init__(self, database, timezone=None, safe=False, cache_size=0):
self.timezone = timezone
self.cache_size = cache_size
self.cache = {}
-
self.transactions = []
@property
def autoflush(self):
@@ -131,7 +130,8 @@ def end(self):
''' End the session. Flush all pending operations and ending the
*pymongo* request'''
self.cache = {}
- self.flush()
+ if not self.transactions:
+ self.flush()
self.db.connection.end_request()
def insert(self, item, safe=None):
@@ -143,7 +143,7 @@ def add(self, item, safe=None):
item._set_session(self)
if safe is None:
safe = self.safe
- self.queue.append(SaveOp(self, item, safe))
+ self.queue.append(SaveOp(self.transaction_id, self, item, safe))
# after the save op is recorded, the document has an _id and can be
# cached
self.cache_write(item)
@@ -178,7 +178,7 @@ def update(self, item, id_expression=None, upsert=False, update_ops={}, safe=Non
'''
if safe is None:
safe = self.safe
- self.queue.append(UpdateDocumentOp(self, item, safe, id_expression=id_expression,
+ self.queue.append(UpdateDocumentOp(self.transaction_id, self, item, safe, id_expression=id_expression,
upsert=upsert, update_ops=update_ops, **kwargs))
if self.autoflush:
return self.flush()
@@ -243,7 +243,7 @@ def remove(self, obj, safe=None):
'''
if safe is None:
safe = self.safe
- remove = RemoveDocumentOp(self, obj, safe)
+ remove = RemoveDocumentOp(self.transaction_id, self, obj, safe)
self.queue.append(remove)
if self.autoflush:
return self.flush()
@@ -256,7 +256,7 @@ def execute_remove(self, remove):
if remove.safe is not None:
safe = remove.safe
- self.queue.append(RemoveOp(self, remove.type, safe, remove))
+ self.queue.append(RemoveOp(self.transaction_id, self, remove.type, safe, remove))
if self.autoflush:
return self.flush()
@@ -269,7 +269,7 @@ def execute_update(self, update, safe=False):
# safe = remove.safe
assert len(update.update_data) > 0
- self.queue.append(UpdateOp(self, update.query.type, safe, update))
+ self.queue.append(UpdateOp(self.transaction_id, self, update.query.type, safe, update))
if self.autoflush:
return self.flush()
@@ -320,23 +320,42 @@ def execute_find_and_modify(self, fm_exp):
self.cache_write(obj)
return obj
-
+ @property
+ def transaction_id(self):
+ if not self.transactions:
+ return None
+ return self.transactions[-1]
+
def get_indexes(self, cls):
''' Get the index information for the collection associated with
`cls`. Index information is returned in the same format as *pymongo*.
'''
return self.db[cls.get_collection_name()].index_information()
- def clear(self):
+ def clear_queue(self, trans_id=None):
''' Clear the queue of database operations without executing any of
the pending operations'''
- self.queue = []
+ if not self.queue:
+ return
+ if trans_id is None:
+ self.queue = []
+ return
+
+ for index, op in enumerate(self.queue):
+ if op.trans_id == trans_id:
+ break
+ print 'GOT INDEX', index
+ self.queue = self.queue[:index]
+ print '\t', self.queue
+
+ def clear_cache(self):
+ self.cache = {}
def clear_collection(self, *classes):
''' Clear all objects from the collections associated with the
objects in `*cls`. **use with caution!**'''
for c in classes:
- self.queue.append(ClearCollectionOp(self, c))
+ self.queue.append(ClearCollectionOp(self.transaction_id, self, c))
if self.autoflush:
self.flush()
@@ -347,10 +366,10 @@ def flush(self, safe=None):
try:
result = op.execute()
except:
- self.cache = {}
- self.clear()
+ self.clear_queue()
+ self.clear_cache()
raise
- self.clear()
+ self.clear_queue()
return result
def dereference(self, ref):
@@ -386,13 +405,29 @@ def clone(self, document):
if '_id' in wrapped:
del wrapped['_id']
return type(document).unwrap(wrapped, session=self)
-
+
def __enter__(self):
- self.transactions.append(None)
+ self.transactions.append(uuid4())
return self
def __exit__(self, exc_type, exc_val, exc_tb):
- self.transactions.pop()
- self.flush()
- self.end()
+ # Pop this level of transaction from the stack
+ id = self.transactions.pop()
+
+ # If exception, set us as being in an error state
+ if exc_type:
+ self.clear_queue(trans_id=id)
+
+ # If we aren't at the top level, return
+ if self.transactions:
+ return False
+
+ if not exc_type:
+ self.flush()
+ self.end()
+ else:
+ self.clear_queue()
+ self.clear_cache()
return False
+
+
View
57 test/test_session.py
@@ -60,7 +60,7 @@ def test_session():
s = Session.connect('unit-testing')
s.clear_collection(T)
s.insert(T(i=1))
- s.clear()
+ s.clear_queue()
s.end()
def test_context_manager():
@@ -137,6 +137,61 @@ def test_cache_miss():
t2 = s.query(TExtra).filter_by(mongo_id=t.mongo_id).one()
# assert id(t) == id(t2)
+def test_transactions():
+ class Doc(Document):
+ i = IntField()
+ s = Session.connect('unit-testing')
+ s.clear_collection(Doc)
+ assert s.query(Doc).count() == 0
+ with s:
+ assert s.query(Doc).count() == 0
+ s.add(Doc(i=4))
+ assert s.query(Doc).count() == 0
+ with s:
+ assert s.query(Doc).count() == 0
+ s.add(Doc(i=2))
+ assert s.query(Doc).count() == 0
+ assert s.query(Doc).count() == 0, s.query(Doc).count()
+ assert s.query(Doc).count() == 2
+
+def test_transactions2():
+ class Doc(Document):
+ i = IntField()
+ s = Session.connect('unit-testing')
+ s.clear_collection(Doc)
+ assert s.query(Doc).count() == 0
+ try:
+ with s:
+ assert s.query(Doc).count() == 0
+ s.add(Doc(i=4))
+ assert s.query(Doc).count() == 0
+ with s:
+ assert s.query(Doc).count() == 0
+ s.add(Doc(i=2))
+ assert s.query(Doc).count() == 0
+ raise Exception()
+ assert s.query(Doc).count() == 0, s.query(Doc).count()
+ except:
+ assert s.query(Doc).count() == 0, s.query(Doc).count()
+
+def test_transactions3():
+ class Doc(Document):
+ i = IntField()
+ s = Session.connect('unit-testing')
+ s.clear_collection(Doc)
+ assert s.query(Doc).count() == 0
+ with s:
+ s.add(Doc(i=4))
+ try:
+
+ with s:
+ s.add(Doc(i=2))
+ print 'RAISE'
+ raise Exception()
+ except:
+ print 'CAUGHT'
+ assert s.query(Doc).count() == 0, s.query(Doc).count()
+ assert s.query(Doc).count() == 1, s.query(Doc).count()
def test_cache_max():

0 comments on commit 55cc288

Please sign in to comment.
Something went wrong with that request. Please try again.