Permalink
Fetching contributors…
Cannot retrieve contributors at this time
337 lines (271 sloc) 10.2 KB
import logging
import weakref
from threading import local as thread_local
from threading import Event
from threading import Thread
try:
from Queue import Queue
except ImportError:
from queue import Queue
try:
import gevent
from gevent import Greenlet as GThread
from gevent.event import Event as GEvent
from gevent.local import local as greenlet_local
from gevent.queue import Queue as GQueue
except ImportError:
GThread = GQueue = GEvent = None
from playhouse.sqlite_ext import SqliteExtDatabase
logger = logging.getLogger('peewee.sqliteq')
class ResultTimeout(Exception):
    """Raised when a queued query's results are not ready within the timeout."""
class WriterPaused(Exception):
    """Delivered as the result of a query that reaches the writer while paused."""
class ShutdownException(Exception):
    """Raised inside the writer to unwind its loop after a shutdown request."""
class AsyncCursor(object):
    """Cursor-like handle for a query whose result is filled in later.

    The object is created with an event and the query details. Another party
    (the writer) calls :meth:`set_result` once the query has run, which stores
    the rows (or the error) and sets the event. Every accessor that needs the
    result first blocks on that event via :meth:`_wait`.
    """
    __slots__ = ('sql', 'params', 'commit', 'timeout',
                 '_event', '_cursor', '_exc', '_idx', '_rows', '_ready')

    def __init__(self, event, sql, params, commit, timeout):
        """Store the query details; no result is available yet.

        :param event: object with ``set()``, ``wait(timeout=...)`` and
            ``is_set()`` (e.g. ``threading.Event``).
        :param sql: SQL string to execute.
        :param params: query parameters (passed through untouched).
        :param commit: commit flag (passed through untouched).
        :param timeout: default number of seconds to wait for the result, or
            ``None`` to wait indefinitely.
        """
        self._event = event
        self.sql = sql
        self.params = params
        self.commit = commit
        self.timeout = timeout
        self._cursor = self._exc = self._idx = self._rows = None
        self._ready = False

    def set_result(self, cursor, exc=None):
        """Record the outcome of the query and wake up any waiters.

        :param cursor: the executed cursor, or ``None`` when ``exc`` is given.
        :param exc: exception raised while executing the query, if any.
        :returns: ``self``.
        """
        rows = []
        if exc is None:
            try:
                rows = cursor.fetchall()
            except Exception as fetch_err:
                # Hand the failure to the waiter instead of letting it escape
                # here, which would leave the event unset forever.
                exc = fetch_err
        self._cursor = cursor
        self._exc = exc
        self._idx = 0
        self._rows = rows
        self._event.set()
        return self

    def _wait(self, timeout=None):
        """Block until the result is available.

        :param timeout: seconds to wait; falls back to ``self.timeout``.
        :raises ResultTimeout: if the result is not ready in time.
        :raises Exception: whatever error was recorded by ``set_result``.
        """
        timeout = timeout if timeout is not None else self.timeout
        self._event.wait(timeout=timeout)
        # Check the flag itself rather than the timeout's truthiness so that
        # a timeout of 0 on a pending result is reported as a timeout instead
        # of marking the cursor ready with no rows loaded.
        if not self._event.is_set():
            raise ResultTimeout('results not ready, timed out.')
        if self._exc is not None:
            raise self._exc
        self._ready = True

    def __iter__(self):
        if not self._ready:
            self._wait()
        if self._exc is not None:
            raise self._exc
        return self

    def next(self):
        """Return the next row, raising ``StopIteration`` when exhausted."""
        if not self._ready:
            self._wait()
        try:
            obj = self._rows[self._idx]
        except IndexError:
            raise StopIteration
        else:
            self._idx += 1
            return obj
    __next__ = next

    @property
    def lastrowid(self):
        if not self._ready:
            self._wait()
        return self._cursor.lastrowid

    @property
    def rowcount(self):
        if not self._ready:
            self._wait()
        return self._cursor.rowcount

    @property
    def description(self):
        # Wait like the other result accessors do; before the result arrives
        # there is no underlying cursor to ask.
        if not self._ready:
            self._wait()
        return self._cursor.description

    def close(self):
        """Close the underlying cursor, if there is one.

        A pending or failed query has no underlying cursor; closing is then a
        no-op so that cleanup code does not fail with ``AttributeError``.
        """
        if self._cursor is not None:
            self._cursor.close()

    def fetchall(self):
        """Return all remaining rows as a list (waits for the result)."""
        return list(self)  # Iterating implies waiting until populated.

    def fetchone(self):
        """Return the next row, or ``None`` when there are no more rows."""
        if not self._ready:
            self._wait()
        try:
            return next(self)
        except StopIteration:
            return None
# Control messages placed on the write queue. The writer compares queue items
# against these by identity (``is``).
SHUTDOWN = StopIteration  # Makes the writer raise ShutdownException and exit.
PAUSE = object()  # Makes a running writer close its connection and idle.
UNPAUSE = object()  # Makes a paused writer reconnect and resume.
class Writer(object):
    """Consumes items from the write queue and executes queued queries.

    The writer alternates between two states: running (it holds a connection
    and executes ``AsyncCursor`` objects) and paused (no connection; queries
    are rejected with ``WriterPaused``). The control objects ``PAUSE``,
    ``UNPAUSE`` and ``SHUTDOWN`` switch between the states or end the loop.
    """
    __slots__ = ('database', 'queue')

    def __init__(self, database, queue):
        self.database = database
        self.queue = queue

    def run(self):
        """Process queue items until a shutdown request is received.

        ``conn`` doubles as the state flag: ``None`` means paused. The
        connection, if one is still held, is closed on the way out.
        """
        conn = self.database.get_conn()
        try:
            while True:
                try:
                    if conn is None:  # Paused.
                        if self.wait_unpause():
                            conn = self.database.get_conn()
                    else:
                        conn = self.loop(conn)
                except ShutdownException:
                    logger.info('writer received shutdown request, exiting.')
                    return
        finally:
            if conn is not None:
                self.database._close(conn)
                self.database._local.closed = True

    def wait_unpause(self):
        """Handle one queue item while paused.

        :returns: ``True`` if the item was ``UNPAUSE``, a false value
            otherwise.
        :raises ShutdownException: if the item was ``SHUTDOWN``.
        """
        obj = self.queue.get()
        if obj is UNPAUSE:
            logger.info('writer unpaused - reconnecting to database.')
            return True
        elif obj is SHUTDOWN:
            raise ShutdownException()
        elif obj is PAUSE:
            logger.error('writer received pause, but is already paused.')
        elif isinstance(obj, AsyncCursor):
            # Fail the query right away so its waiter is not left blocked.
            obj.set_result(None, WriterPaused())
            logger.warning('writer paused, not handling %s', obj)
        else:
            # Same treatment as in loop(): an unknown object must not crash
            # the writer with an AttributeError on set_result.
            logger.error('writer received unsupported object: %s', obj)
        return False

    def loop(self, conn):
        """Handle one queue item while running.

        :param conn: the connection currently held by the writer.
        :returns: ``conn`` to keep running, or ``None`` after a pause request
            (the connection has been closed by then).
        :raises ShutdownException: if the item was ``SHUTDOWN``.
        """
        obj = self.queue.get()
        if isinstance(obj, AsyncCursor):
            self.execute(obj)
        elif obj is PAUSE:
            logger.info('writer paused - closing database connection.')
            self.database._close(conn)
            self.database._local.closed = True
            return
        elif obj is UNPAUSE:
            logger.error('writer received unpause, but is already running.')
        elif obj is SHUTDOWN:
            raise ShutdownException()
        else:
            logger.error('writer received unsupported object: %s', obj)
        return conn

    def execute(self, obj):
        """Run the query described by ``obj`` and publish its outcome.

        Any exception raised by the database call is captured and delivered
        through ``obj.set_result`` rather than propagated.
        """
        logger.debug('received query %s', obj.sql)
        try:
            cursor = self.database._execute(obj.sql, obj.params, obj.commit)
        except Exception as execute_err:
            cursor = None
            # Python 3 unbinds the ``as`` name when the except clause ends,
            # so keep a separate reference.
            exc = execute_err
        else:
            exc = None
        return obj.set_result(cursor, exc)
class SqliteQueueDatabase(SqliteExtDatabase):
    """SQLite database that funnels committing queries through one writer.

    Queries that require a commit are wrapped in an ``AsyncCursor`` and put on
    a queue consumed by a single ``Writer`` running in a thread or greenlet.
    Queries that do not require a commit are executed directly. The journal
    mode is forced to WAL.
    """
    WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL '
                              'journal mode when using this feature. WAL mode '
                              'allows one or more readers to continue reading '
                              'while another connection writes to the '
                              'database.')

    def __init__(self, database, use_gevent=False, autostart=True,
                 queue_max_size=None, results_timeout=None, *args, **kwargs):
        """Configure the database and, optionally, start the writer.

        :param database: passed through to the parent constructor.
        :param use_gevent: use gevent primitives instead of threads.
        :param autostart: start the writer immediately.
        :param queue_max_size: maximum size of the write queue (``None`` or 0
            is passed to the queue as 0).
        :param results_timeout: default timeout, in seconds, applied to
            cursors returned by :meth:`execute_sql`.
        :raises ValueError: if ``threadlocals`` is passed as a false value or
            a journal mode other than WAL is requested.
        """
        if 'threadlocals' in kwargs and not kwargs['threadlocals']:
            raise ValueError('"threadlocals" must be true to use the '
                             'SqliteQueueDatabase.')

        kwargs['threadlocals'] = True
        kwargs['check_same_thread'] = False

        # Ensure that journal_mode is WAL. This value is passed to the parent
        # class constructor below.
        pragmas = self._validate_journal_mode(
            kwargs.pop('journal_mode', None),
            kwargs.pop('pragmas', None))

        # Reference to execute_sql on the parent class. Since we've overridden
        # execute_sql(), this is just a handy way to reference the real
        # implementation.
        Parent = super(SqliteQueueDatabase, self)
        self._execute = Parent.execute_sql

        # Call the parent class constructor with our modified pragmas.
        Parent.__init__(database, pragmas=pragmas, *args, **kwargs)

        self._autostart = autostart
        self._results_timeout = results_timeout
        self._is_stopped = True

        # Get different objects depending on the threading implementation.
        self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size)

        # Create the writer thread, optionally starting it.
        self._create_write_queue()
        if self._autostart:
            self.start()

    def get_thread_impl(self, use_gevent):
        """Return the helper class providing event/queue/thread factories.

        :raises ImportError: if gevent was requested but could not be
            imported (the module-level fallback sets ``GThread`` to ``None``).
        """
        if use_gevent and GThread is None:
            # Fail with a clear message instead of a later
            # "'NoneType' object is not callable".
            raise ImportError('gevent is required when use_gevent is true.')
        return GreenletHelper if use_gevent else ThreadHelper

    def _validate_journal_mode(self, journal_mode=None, pragmas=None):
        """Return the pragma list with ``journal_mode`` forced to ``wal``.

        :param journal_mode: requested journal mode, if any.
        :param pragmas: iterable of ``(name, value)`` pairs, if any.
        :raises ValueError: if either argument asks for a non-WAL mode.
        """
        if journal_mode and journal_mode.lower() != 'wal':
            raise ValueError(self.WAL_MODE_ERROR_MESSAGE)

        if not pragmas:
            return [('journal_mode', 'wal')]

        # Materialize once: the pairs are walked twice below, which would
        # silently drop everything from a one-shot iterable.
        pragmas = list(pragmas)
        pdict = dict((k.lower(), v) for (k, v) in pragmas)
        if pdict.get('journal_mode', 'wal').lower() != 'wal':
            raise ValueError(self.WAL_MODE_ERROR_MESSAGE)

        # Drop any caller-supplied journal_mode entry, matching the name the
        # same case-insensitive way the check above does, then append ours.
        return [(k, v) for (k, v) in pragmas
                if k.lower() != 'journal_mode'] + [('journal_mode', 'wal')]

    def _create_write_queue(self):
        self._write_queue = self._thread_helper.queue()

    def queue_size(self):
        """Return the write queue's ``qsize()``."""
        return self._write_queue.qsize()

    def execute_sql(self, sql, params=None, require_commit=True, timeout=None):
        """Execute ``sql`` directly, or enqueue it for the writer.

        Queries with ``require_commit`` false go straight to the parent
        implementation. All others are returned as an ``AsyncCursor`` whose
        result is filled in once the writer has processed the query.

        :param timeout: per-query result timeout; defaults to the
            ``results_timeout`` given to the constructor.
        """
        if not require_commit:
            return self._execute(sql, params, require_commit=require_commit)

        cursor = AsyncCursor(
            event=self._thread_helper.event(),
            sql=sql,
            params=params,
            commit=require_commit,
            timeout=self._results_timeout if timeout is None else timeout)
        self._write_queue.put(cursor)
        return cursor

    def start(self):
        """Start the writer.

        :returns: ``False`` if it was already running, ``True`` otherwise.
        """
        with self._conn_lock:
            if not self._is_stopped:
                return False

            def run():
                writer = Writer(self, self._write_queue)
                writer.run()

            self._writer = self._thread_helper.thread(run)
            self._writer.start()
            self._is_stopped = False
            return True

    def stop(self):
        """Ask the writer to shut down and wait for it to finish.

        :returns: ``False`` if it was already stopped, ``True`` otherwise.
        """
        logger.debug('environment stop requested.')
        with self._conn_lock:
            if self._is_stopped:
                return False
            self._write_queue.put(SHUTDOWN)
            self._writer.join()
            self._is_stopped = True
            return True

    def is_stopped(self):
        with self._conn_lock:
            return self._is_stopped

    def pause(self):
        """Enqueue a request for the writer to pause."""
        with self._conn_lock:
            self._write_queue.put(PAUSE)

    def unpause(self):
        """Enqueue a request for the writer to resume."""
        with self._conn_lock:
            self._write_queue.put(UNPAUSE)

    def __unsupported__(self, *args, **kwargs):
        raise ValueError('This method is not supported by %r.' % type(self))
    # These names always raise ValueError on this class.
    atomic = transaction = savepoint = __unsupported__
class ThreadHelper(object):
    """Factory for the thread-based event, queue and worker objects."""
    __slots__ = ('queue_max_size',)

    def __init__(self, queue_max_size=None):
        self.queue_max_size = queue_max_size

    def event(self):
        """Return a new ``threading.Event``."""
        return Event()

    def queue(self, max_size=None):
        """Return a new queue.

        ``max_size`` falls back to ``self.queue_max_size`` when it is
        ``None``; any false value ends up as a ``maxsize`` of 0.
        """
        if max_size is None:
            max_size = self.queue_max_size
        if not max_size:
            max_size = 0
        return Queue(maxsize=max_size)

    def thread(self, fn, *args, **kwargs):
        """Return an unstarted daemon thread that will call ``fn``."""
        worker = Thread(target=fn, args=args, kwargs=kwargs)
        worker.daemon = True
        return worker
class GreenletHelper(ThreadHelper):
    """Factory for the gevent-based event, queue and worker objects."""
    __slots__ = ()

    def event(self):
        """Return a new gevent event."""
        return GEvent()

    def queue(self, max_size=None):
        """Return a new gevent queue.

        ``max_size`` falls back to ``self.queue_max_size`` when it is
        ``None``; any false value ends up as a ``maxsize`` of 0.
        """
        if max_size is None:
            max_size = self.queue_max_size
        return GQueue(maxsize=max_size if max_size else 0)

    def thread(self, fn, *args, **kwargs):
        """Return a greenlet that calls ``fn`` with the given arguments."""
        def yield_then_call(*call_args, **call_kwargs):
            # gevent.sleep() with no arguments runs before fn; presumably this
            # yields to other greenlets first -- confirm against gevent docs.
            gevent.sleep()
            return fn(*call_args, **call_kwargs)
        return GThread(yield_then_call, *args, **kwargs)