Skip to content

Commit

Permalink
Merge 9d8ecb3 into 5f1c1b1
Browse files Browse the repository at this point in the history
  • Loading branch information
josenavas committed Jun 25, 2015
2 parents 5f1c1b1 + 9d8ecb3 commit 463497e
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 32 deletions.
52 changes: 30 additions & 22 deletions qiita_db/sql_connection.py
Expand Up @@ -587,7 +587,7 @@ def _checker(func):
"""Decorator to check that methods are executed inside the context"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self._is_inside_context:
if self._contexts_entered == 0:
raise RuntimeError(
"Operation not permitted. Transaction methods can only be "
"invoked within the context manager.")
Expand Down Expand Up @@ -624,32 +624,40 @@ def __init__(self, name):
self._results = []
self.index = 0
self._conn_handler = SQLConnectionHandler()
self._is_inside_context = False
self._contexts_entered = 0

def __enter__(self):
self._is_inside_context = True
self._contexts_entered += 1
return self

def _clean_up(self, exc_type):
status = self._conn_handler._connection.get_transaction_status()
if exc_type is not None:
# An exception occurred during the execution of the transaction
# Make sure that we leave the DB w/o any modification
self.rollback()
elif self._queries:
# There are still queries to be executed, execute them
# It is safe to use the execute method here, as internally is
# wrapped in a try/except and rollbacks in case of failure
self.execute()
elif status != TRANSACTION_STATUS_IDLE:
# There are no queries to be executed, however, the transaction
# is still not committed. Commit it so the changes are not lost
self.commit()

def __exit__(self, exc_type, exc_value, traceback):
# We need to wrap the entire function in a try/finally because
# at the end of the function we need to set _is_inside_context to false
try:
status = self._conn_handler._connection.get_transaction_status()
if exc_type is not None:
# An exception occurred during the execution of the transaction
# Make sure that we leave the DB w/o any modification
self.rollback()
elif self._queries:
# There are still queries to be executed, execute them
# It is safe to use the execute method here, as internally is
# wrapped in a try/except and rollbacks in case of failure
self.execute()
elif status != TRANSACTION_STATUS_IDLE:
# There are no queries to be executed, however, the transaction
# is still not committed. Commit it so the changes are not lost
self.commit()
finally:
self._is_inside_context = False
# We only need to perform some action if this is the last context
# that we are entering
if self._contexts_entered == 1:
# We need to wrap the entire function in a try/finally because
# at the end we need to decrement _contexts_entered
try:
self._clean_up(exc_type)
finally:
self._contexts_entered -= 1
else:
self._contexts_entered -= 1

def _raise_execution_error(self, sql, sql_args, error):
"""Rollbacks the current transaction and raises a useful error
Expand Down
50 changes: 40 additions & 10 deletions qiita_db/test/test_sql_connection.py
Expand Up @@ -199,15 +199,14 @@ def test_execute_fetchall_with_sql_args(self):

class TestTransaction(TestBase):
def test_init(self):
with Transaction("test_init") as obs:
obs = Transaction("test_init")
self.assertEqual(obs._name, "test_init")
self.assertEqual(obs._queries, [])
self.assertEqual(obs._results, [])
self.assertEqual(obs.index, 0)
self.assertTrue(
isinstance(obs._conn_handler, SQLConnectionHandler))
self.assertFalse(obs._is_inside_context)
obs = Transaction("test_init")
self.assertEqual(obs._name, "test_init")
self.assertEqual(obs._queries, [])
self.assertEqual(obs._results, [])
self.assertEqual(obs.index, 0)
self.assertTrue(
isinstance(obs._conn_handler, SQLConnectionHandler))
self.assertEqual(obs._contexts_entered, 0)

def test_replace_placeholders(self):
with Transaction("test_replace_placeholders") as trans:
Expand Down Expand Up @@ -535,7 +534,38 @@ def test_context_manager_no_commit(self):
trans._conn_handler._connection.get_transaction_status(),
TRANSACTION_STATUS_IDLE)

def test_context_managet_checker(self):
def test_context_manager_multiple(self):
trans = Transaction("test_context_manager_multiple")
self.assertEqual(trans._contexts_entered, 0)

with trans:
self.assertEqual(trans._contexts_entered, 1)

trans.add("SELECT 42")
with trans:
self.assertEqual(trans._contexts_entered, 2)
sql = """INSERT INTO qiita.test_table (str_column, int_column)
VALUES (%s, %s) RETURNING str_column, int_column"""
args = [['insert1', 1], ['insert2', 2], ['insert3', 3]]
trans.add(sql, args, many=True)

# We exited the second context, anything should have been executed
self.assertEqual(trans._contexts_entered, 1)
self.assertEqual(
trans._conn_handler._connection.get_transaction_status(),
TRANSACTION_STATUS_IDLE)
self._assert_sql_equal([])

# We have exited the first context, everything should have been
# executed and committed
self.assertEqual(trans._contexts_entered, 0)
self._assert_sql_equal([('insert1', True, 1), ('insert2', True, 2),
('insert3', True, 3)])
self.assertEqual(
trans._conn_handler._connection.get_transaction_status(),
TRANSACTION_STATUS_IDLE)

def test_context_manager_checker(self):
t = Transaction("test_context_managet_checker")

with self.assertRaises(RuntimeError):
Expand Down

0 comments on commit 463497e

Please sign in to comment.