-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
sqliteq.py
264 lines (210 loc) · 8.34 KB
/
sqliteq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import logging
import weakref
from threading import Event
from threading import Thread
try:
from Queue import Queue
except ImportError:
from queue import Queue
try:
import gevent
from gevent import Greenlet as GThread
from gevent.event import Event as GEvent
from gevent.queue import Queue as GQueue
except ImportError:
GThread = GQueue = GEvent = None
from playhouse.sqlite_ext import SqliteExtDatabase
logger = logging.getLogger('peewee.sqliteq')
class ResultTimeout(Exception):
    """Raised when query results are not ready within the caller's timeout."""
class AsyncCursor(object):
    """Cursor-like handle for a query executed on a worker thread.

    The worker publishes the outcome via :meth:`set_result`; consumers block
    in :meth:`_wait` until the event is set (or ``timeout`` elapses, raising
    ``ResultTimeout``). Rows are fetched eagerly by the worker so iteration
    on the consumer side never touches the underlying connection.
    """
    __slots__ = ('sql', 'params', 'commit', 'timeout',
                 '_event', '_cursor', '_exc', '_idx', '_rows')

    def __init__(self, event, sql, params, commit, timeout):
        self._event = event  # Set by the worker when results are ready.
        self.sql = sql
        self.params = params
        self.commit = commit
        self.timeout = timeout
        self._cursor = self._exc = self._idx = self._rows = None

    def set_result(self, cursor, exc=None):
        """Called by the worker thread to publish the query's outcome.

        :param cursor: the real DB cursor (or None when ``exc`` is set).
        :param exc: exception raised during execution, if any.
        :returns: self, for chaining.
        """
        self._cursor = cursor
        self._exc = exc
        self._idx = 0
        # Fetch all rows up-front so the worker can move on to the next query
        # and the consumer never touches the shared cursor concurrently.
        self._rows = cursor.fetchall() if exc is None else []
        self._event.set()
        return self

    def _wait(self, timeout=None):
        """Block until results are published; re-raise any worker exception."""
        timeout = timeout if timeout is not None else self.timeout
        if not self._event.wait(timeout=timeout) and timeout:
            raise ResultTimeout('results not ready, timed out.')
        if self._exc is not None:
            raise self._exc

    def __iter__(self):
        # Bug fix: the original re-checked `self._exc` here and raised
        # `self._exec` -- a typo'd, nonexistent attribute. The check was also
        # dead code because _wait() above already raises `self._exc`.
        self._wait()
        return self

    def next(self):
        try:
            obj = self._rows[self._idx]
        except IndexError:
            raise StopIteration
        else:
            self._idx += 1
            return obj
    __next__ = next

    @property
    def lastrowid(self):
        self._wait()
        return self._cursor.lastrowid

    @property
    def rowcount(self):
        self._wait()
        return self._cursor.rowcount

    @property
    def description(self):
        # Bug fix: wait for the worker, otherwise `self._cursor` may still
        # be None when accessed before set_result() has run.
        self._wait()
        return self._cursor.description

    def close(self):
        # Bug fix: same race as `description` -- wait before closing.
        self._wait()
        self._cursor.close()

    def fetchall(self):
        return list(self)  # Iterating implies waiting until populated.

    def fetchone(self):
        self._wait()
        try:
            return next(self)
        except StopIteration:
            return None
# Raised at construction time: this database serializes every query through a
# single worker-owned connection, so per-thread connections are contradictory.
THREADLOCAL_ERROR_MESSAGE = ('threadlocals cannot be set to True when using '
                             'the Sqlite thread / queue database. All queries '
                             'are serialized through a single connection, so '
                             'allowing multiple threads to connect defeats '
                             'the purpose of this database.')

# Raised when the caller configures a journal_mode other than WAL; the
# reader/writer split below relies on WAL semantics.
WAL_MODE_ERROR_MESSAGE = ('SQLite must be configured to use the WAL journal '
                          'mode when using this feature. WAL mode allows '
                          'one or more readers to continue reading while '
                          'another connection writes to the database.')
class SqliteQueueDatabase(SqliteExtDatabase):
    """SQLite database that funnels queries through worker threads/greenlets.

    All committing queries are serialized through a single writer worker; read
    queries are dispatched to one or more reader workers (WAL journal mode
    lets readers proceed while the writer writes). ``execute_sql()`` returns
    an ``AsyncCursor`` immediately; callers block only when they consume it.
    """

    def __init__(self, database, use_gevent=False, autostart=False, readers=1,
                 queue_max_size=None, results_timeout=None, *args, **kwargs):
        if kwargs.get('threadlocals'):
            raise ValueError(THREADLOCAL_ERROR_MESSAGE)

        kwargs['threadlocals'] = False
        kwargs['check_same_thread'] = False

        # Ensure that journal_mode is WAL. This value is passed to the parent
        # class constructor below.
        pragmas = self._validate_journal_mode(
            kwargs.pop('journal_mode', None),
            kwargs.pop('pragmas', None))

        # Reference to execute_sql on the parent class. Since we've overridden
        # execute_sql(), this is just a handy way to reference the real
        # implementation.
        Parent = super(SqliteQueueDatabase, self)
        self.__execute_sql = Parent.execute_sql

        # Call the parent class constructor with our modified pragmas.
        Parent.__init__(database, pragmas=pragmas, *args, **kwargs)

        self._autostart = autostart
        self._results_timeout = results_timeout
        self._num_readers = readers
        self._is_stopped = True
        self._thread_helper = self.get_thread_impl(use_gevent)(queue_max_size)
        self._create_queues_and_workers()
        if self._autostart:
            self.start()

    def get_thread_impl(self, use_gevent):
        """Return the concurrency-primitive factory (threads or greenlets)."""
        return GreenletHelper if use_gevent else ThreadHelper

    def _validate_journal_mode(self, journal_mode=None, pragmas=None):
        """Return a pragma list guaranteed to contain journal_mode=wal.

        Raises ValueError if the caller explicitly requested a non-WAL mode,
        either via ``journal_mode`` or inside ``pragmas`` (a list of
        (name, value) pairs).
        """
        if journal_mode and journal_mode.lower() != 'wal':
            raise ValueError(WAL_MODE_ERROR_MESSAGE)

        if pragmas:
            pdict = dict((k.lower(), v) for (k, v) in pragmas)
            if pdict.get('journal_mode', 'wal').lower() != 'wal':
                raise ValueError(WAL_MODE_ERROR_MESSAGE)
            # Strip any caller-supplied journal_mode and append the WAL one.
            return [(k, v) for (k, v) in pragmas
                    if k != 'journal_mode'] + [('journal_mode', 'wal')]
        else:
            return [('journal_mode', 'wal')]

    def _create_queues_and_workers(self):
        # One queue per role: the single writer serializes commits; all
        # reader workers pull from the shared read queue.
        self._write_queue = self._thread_helper.queue()
        self._read_queue = self._thread_helper.queue()

        target = self._run_worker_loop
        self._writer = self._thread_helper.thread(target, self._write_queue)
        self._readers = [self._thread_helper.thread(target, self._read_queue)
                         for _ in range(self._num_readers)]

    def _run_worker_loop(self, queue):
        """Worker main loop: execute queued AsyncCursors until sentinel."""
        while True:
            async_cursor = queue.get()
            if async_cursor is StopIteration:  # Shutdown sentinel.
                logger.info('worker shutting down.')
                return

            logger.debug('received query %s', async_cursor.sql)
            self._process_execution(async_cursor)

    def _process_execution(self, async_cursor):
        """Run one query on this worker and publish cursor-or-exception."""
        try:
            cursor = self.__execute_sql(async_cursor.sql, async_cursor.params,
                                        async_cursor.commit)
        except Exception as exc:
            cursor = None
        else:
            exc = None
        return async_cursor.set_result(cursor, exc)

    def queue_size(self):
        """Return (pending writes, pending reads)."""
        return (self._write_queue.qsize(), self._read_queue.qsize())

    def execute_sql(self, sql, params=None, require_commit=True, timeout=None):
        """Enqueue a query and return an AsyncCursor immediately.

        Committing queries go to the writer queue; read-only queries go to
        the reader queue. ``timeout`` overrides the database-level
        ``results_timeout`` for this query only.
        """
        cursor = AsyncCursor(
            event=self._thread_helper.event(),
            sql=sql,
            params=params,
            commit=require_commit,
            timeout=self._results_timeout if timeout is None else timeout)
        queue = self._write_queue if require_commit else self._read_queue
        queue.put(cursor)
        return cursor

    def start(self):
        """Start the writer and reader workers. Returns False if running."""
        with self._conn_lock:
            if not self._is_stopped:
                return False
            self._writer.start()
            for reader in self._readers:
                reader.start()
            logger.info('workers started.')
            self._is_stopped = False
            return True

    def stop(self):
        """Signal all workers to exit and wait for them. Returns False if
        already stopped.

        NOTE(review): joined threads cannot be restarted; a stopped database
        would need its workers recreated before start() could succeed again.
        """
        logger.debug('environment stop requested.')
        with self._conn_lock:
            if self._is_stopped:
                return False
            self._write_queue.put(StopIteration)
            for _ in self._readers:
                self._read_queue.put(StopIteration)
            self._writer.join()
            for reader in self._readers:
                reader.join()
            # Bug fix: the original never flipped this flag, so is_stopped()
            # kept returning False and a second stop() would re-enqueue
            # sentinels and re-join already-dead workers.
            self._is_stopped = True
            return True

    def is_stopped(self):
        with self._conn_lock:
            return self._is_stopped
class ThreadHelper(object):
    """Factory for threading-based primitives: Event, Queue and Thread."""
    __slots__ = ('queue_max_size',)

    def __init__(self, queue_max_size=None):
        self.queue_max_size = queue_max_size

    def event(self):
        """Return a fresh threading.Event."""
        return Event()

    def queue(self, max_size=None):
        """Return a Queue, defaulting to the configured maximum size."""
        if max_size is None:
            max_size = self.queue_max_size
        return Queue(maxsize=max_size or 0)

    def thread(self, fn, *args, **kwargs):
        """Return an unstarted daemon Thread that will run ``fn``."""
        worker = Thread(target=fn, args=args, kwargs=kwargs)
        worker.daemon = True
        return worker
class GreenletHelper(ThreadHelper):
    """Factory for gevent-based primitives; requires gevent to be installed."""
    # Bug fix: the original re-declared __slots__ = ('queue_max_size',),
    # duplicating the slot already defined on ThreadHelper. Per the Python
    # data-model docs, redeclaring a parent's slot creates a redundant
    # shadowing descriptor and wastes per-instance space.
    __slots__ = ()

    def event(self):
        """Return a fresh gevent Event."""
        return GEvent()

    def queue(self, max_size=None):
        """Return a gevent Queue, defaulting to the configured maximum size."""
        max_size = max_size if max_size is not None else self.queue_max_size
        return GQueue(maxsize=max_size or 0)

    def thread(self, fn, *args, **kwargs):
        """Return an unstarted Greenlet that will run ``fn``."""
        def wrap(*a, **k):
            # Yield once so the greenlet starts cooperatively before running.
            gevent.sleep()
            return fn(*a, **k)
        return GThread(wrap, *args, **kwargs)