Permalink
Browse files

Have the execution functions deal with entering the mediator.

  • Loading branch information...
1 parent eeecb7e commit a88044710ccd4d1f0741de730042b2418149e903 @kgaughan committed Oct 14, 2012
Showing with 84 additions and 76 deletions.
  1. +67 −58 dbkit.py
  2. +5 −6 tests/fakedb.py
  3. +11 −11 tests/test_dbkit.py
  4. +1 −1 tests/test_pool.py
View
125 dbkit.py
@@ -94,7 +94,7 @@ class Context(object):
"""A database connection context."""
__slots__ = (
- '_mdr', '_depth', 'logger', 'default_factory',
+ 'mdr', '_depth', 'logger', 'default_factory',
'last_row_count', 'last_row_id') + _EXCEPTIONS
stack = _ContextStack()
@@ -103,7 +103,7 @@ def __init__(self, module, mdr):
Initialise a context with a given driver module and connection.
"""
super(Context, self).__init__()
- self._mdr = mdr
+ self.mdr = mdr
self._depth = 0
self.logger = null_logger
self.default_factory = tuple_set
@@ -140,42 +140,41 @@ def transaction(self):
# The idea here is to fake the nesting of transactions. Only when
# we've gotten back to the topmost transaction context do we actually
# commit or rollback.
- with self._mdr:
+ with self.mdr:
try:
self._depth += 1
yield self
self._depth -= 1
- except self._mdr.OperationalError:
+ except self.mdr.OperationalError:
# We've lost the connection, so there's no sense in
# attempting to roll back back the transaction.
self._depth -= 1
raise
except:
self._depth -= 1
if self._depth == 0:
- self._mdr.rollback()
+ self.mdr.rollback()
raise
if self._depth == 0:
- self._mdr.commit()
+ self.mdr.commit()
@contextlib.contextmanager
def cursor(self):
"""Get a cursor for the current connection. For internal use only."""
- with self._mdr:
- logger.debug("Creating cursor")
- cursor = self._mdr.cursor()
- try:
- logger.debug("Yielding cursor")
- yield cursor
- if cursor.rowcount != -1:
- self.last_row_count = cursor.rowcount
- self.last_row_id = getattr(cursor, 'lastrowid', None)
- except:
- self.last_row_count = None
- self.last_row_id = None
- logger.debug("Closing cursor")
- _safe_close(cursor)
- raise
+ logger.debug("Creating cursor")
+ cursor = self.mdr.cursor()
+ try:
+ logger.debug("Yielding cursor")
+ yield cursor
+ if cursor.rowcount != -1:
+ self.last_row_count = cursor.rowcount
+ self.last_row_id = getattr(cursor, 'lastrowid', None)
+ except:
+ self.last_row_count = None
+ self.last_row_id = None
+ logger.debug("Closing cursor")
+ _safe_close(cursor)
+ raise
def execute(self, stmt, args):
"""Execute a statement, returning a cursor. For internal use only."""
@@ -205,9 +204,9 @@ def close(self):
for exc in _EXCEPTIONS:
setattr(self, exc, None)
try:
- self._mdr.close()
+ self.mdr.close()
finally:
- self._mdr = None
+ self.mdr = None
# pylint: disable-msg=R0903
@@ -331,6 +330,7 @@ def __exit__(self, exc_type, _exc_value, _traceback):
self.conn = None
def cursor(self):
+ cursor = None
try:
cursor = self.conn.cursor()
cursor.execute('SELECT 1')
@@ -620,17 +620,20 @@ def last_row_id():
def execute(stmt, args=()):
"""Execute an SQL statement. Returns the number of affected rows."""
- cursor = Context.current().execute(stmt, args)
- row_count = cursor.rowcount
- _safe_close(cursor)
+ ctx = Context.current()
+ with ctx.mdr:
+ cursor = ctx.execute(stmt, args)
+ row_count = cursor.rowcount
+ _safe_close(cursor)
return row_count
def query(stmt, args=(), factory=None):
"""Execute a query. This returns an iterator of the result set."""
ctx = Context.current()
factory = ctx.default_factory if factory is None else factory
- return factory(ctx.execute(stmt, args))
+ with ctx.mdr:
+ return factory(ctx.execute(stmt, args), ctx.mdr)
def query_row(stmt, args=(), factory=None):
@@ -658,9 +661,11 @@ def query_column(stmt, args=()):
def execute_proc(procname, args=()):
"""Execute a stored procedure. Returns the number of affected rows."""
- cursor = Context.current().execute_proc(procname, args)
- row_count = cursor.rowcount
- _safe_close(cursor)
+ ctx = Context.current()
+ with ctx.mdr:
+ cursor = ctx.execute_proc(procname, args)
+ row_count = cursor.rowcount
+ _safe_close(cursor)
return row_count
@@ -670,7 +675,8 @@ def query_proc(procname, args=(), factory=None):
"""
ctx = Context.current()
factory = ctx.default_factory if factory is None else factory
- return factory(ctx.execute_proc(procname, args))
+ with ctx.mdr:
+ return factory(ctx.execute_proc(procname, args), ctx.mdr)
def query_proc_row(procname, args=(), factory=None):
@@ -701,41 +707,44 @@ def query_proc_column(procname, args=()):
return query_proc(procname, args, column_set)
-def dict_set(cursor):
+def dict_set(cursor, mdr):
"""Iterator over a statement's results as a dict."""
columns = [col[0] for col in cursor.description]
- try:
- while True:
- row = cursor.fetchone()
- if row is None:
- break
- yield AttrDict(zip(columns, row))
- finally:
- _safe_close(cursor)
+ with mdr:
+ try:
+ while True:
+ row = cursor.fetchone()
+ if row is None:
+ break
+ yield AttrDict(zip(columns, row))
+ finally:
+ _safe_close(cursor)
-def tuple_set(cursor):
+def tuple_set(cursor, mdr):
"""Iterator over a statement's results where each row is a tuple."""
- try:
- while True:
- row = cursor.fetchone()
- if row is None:
- break
- yield row
- finally:
- _safe_close(cursor)
+ with mdr:
+ try:
+ while True:
+ row = cursor.fetchone()
+ if row is None:
+ break
+ yield row
+ finally:
+ _safe_close(cursor)
-def column_set(cursor):
+def column_set(cursor, mdr):
"""Iterator over the first column of a statement's results."""
- try:
- while True:
- row = cursor.fetchone()
- if row is None:
- break
- yield row[0]
- finally:
- _safe_close(cursor)
+ with mdr:
+ try:
+ while True:
+ row = cursor.fetchone()
+ if row is None:
+ break
+ yield row[0]
+ finally:
+ _safe_close(cursor)
class AttrDict(dict):
View
@@ -5,6 +5,7 @@
# DB names used to trigger certain behaviours.
INVALID_DB = 'invalid-db'
INVALID_CURSOR = 'invalid-cursor'
+HAPPY_OUT = 'happy-out'
apilevel = '2.0'
threadsafety = 2
@@ -35,7 +36,7 @@ def close(self):
if not self.valid:
raise ProgrammingError("Cannot close a closed connection.")
self.valid = False
- for cursor in cursors:
+ for cursor in self.cursors:
cursor.close()
self.session.append('close')
if self.database == INVALID_DB:
@@ -59,22 +60,20 @@ class Cursor(object):
A fake cursor.
"""
- __slots__ = ['connection', 'valid', 'result', 'rowcount']
-
def __init__(self, connection):
self.connection = connection
self.result = None
if connection.database == INVALID_CURSOR:
self.valid = False
- raise OperationalError()
+ raise OperationalError("You've tripped INVALID_CURSOR!")
connection.cursors.add(self)
self.valid = True
self.rowcount = -1
def close(self):
self.connection.session.append('cursor-close')
if not self.valid:
- raise InterfaceError()
+ raise InterfaceError("Cursor is closed")
self.connection.cursors.remove(self)
self.valid = False
@@ -103,7 +102,7 @@ def callproc(self, procname, args=()):
def fetchone(self):
if not self.valid:
- raise InterfaceError()
+ raise InterfaceError("Cursor is closed")
result = self.result
self.result = None
return result
View
@@ -55,7 +55,7 @@ def test_bad_connect():
try:
with dbkit.connect(sqlite3, '/nonexistent.db') as ctx:
# Wouldn't do this in real code as the mediator is private.
- with ctx._mdr:
+ with ctx.mdr:
pass
assert False, "Should not have been able to open database."
except sqlite3.OperationalError:
@@ -74,15 +74,15 @@ def test_context():
assert len(ctx.stack) == 1
assert dbkit.Context.current(with_exception=False) is ctx
- assert ctx._mdr is not None
+ assert ctx.mdr is not None
assert ctx.logger is not None
ctx.close()
try:
dbkit.context()
assert False, "Should not have been able to access context."
except:
pass
- assert ctx._mdr is None
+ assert ctx.mdr is None
assert ctx.logger is None
assert len(ctx.stack) == 0
@@ -177,26 +177,26 @@ def test_unpooled_disconnect():
with ctx:
try:
with dbkit.transaction():
- assert ctx._mdr.depth == 1
- assert ctx._mdr.conn is not None
+ assert ctx.mdr.depth == 1
+ assert ctx.mdr.conn is not None
assert dbkit.query_value(GET_COUNTER, ('foo',)) == 42
raise ctx.OperationalError("Simulating disconnect")
except:
- assert ctx._mdr.depth == 0
- assert ctx._mdr.conn is None
+ assert ctx.mdr.depth == 0
+ assert ctx.mdr.conn is None
raise
assert False, "Should've raised OperationalError"
except ctx.OperationalError, exc:
- assert ctx._mdr.depth == 0
- assert ctx._mdr.conn is None
+ assert ctx.mdr.depth == 0
+ assert ctx.mdr.conn is None
assert exc.message == "Simulating disconnect"
# Test reconnect. As we're running this all against an in-memory DB,
# everything in it will have been throttled, thus the only query we can
# do is query the list of tables, which will be empty.
with ctx:
assert len(list(dbkit.query_column(LIST_TABLES))) == 0
- assert ctx._mdr.conn is not None
+ assert ctx.mdr.conn is not None
ctx.close()
@@ -236,7 +236,7 @@ def test_procs():
dbkit.query_proc_row('query_proc_row')
dbkit.query_proc_value('query_proc_value')
list(dbkit.query_proc_column('query_proc_column'))
- conn = ctx._mdr.conn
+ conn = ctx.mdr.conn
assert conn.executed == 4
assert conn.session == [
'cursor', 'proc:execute_proc', 'cursor-close',
View
@@ -19,7 +19,7 @@ def test_check_pool():
def test_lazy_connect():
assert len(POOL._pool) == 0
with POOL.connect() as ctx:
- assert isinstance(ctx._mdr, dbkit.PooledConnectionMediator)
+ assert isinstance(ctx.mdr, dbkit.PooledConnectionMediator)
assert POOL._allocated == 0
assert len(POOL._pool) == 0
assert POOL._allocated == 0

0 comments on commit a880447

Please sign in to comment.