Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
293 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
Copyright 2012 Scott Sadler | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#!/usr/bin/env python | ||
|
||
from setuptools import setup, find_packages | ||
|
||
setup( | ||
name='squirrel', | ||
description='Psycopg2 wrapper for tornadp', | ||
version='0.1', | ||
author='scott sadler', | ||
py_modules=['squirrel'], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
import logging as _logging | ||
import weakref | ||
import psycopg2 | ||
import psycopg2.extensions | ||
from collections import deque | ||
from functools import partial | ||
from tornado import ioloop, stack_context, gen | ||
|
||
|
||
logger = _logging.getLogger('squirrel.pool') | ||
|
||
|
||
class ConnectionPool(object): | ||
""" | ||
Manages connections and provisions cursors. | ||
""" | ||
def __init__(self, io_loop, max_connections=10, **connect_kwargs): | ||
self.io_loop = io_loop | ||
self.max_connections = max_connections | ||
connect_kwargs['async'] = 1 | ||
self.connect_kwargs = connect_kwargs | ||
self.connections = deque() | ||
self.queue = deque() | ||
self.owed = 0 | ||
self._refs = set() | ||
self._closed = False | ||
|
||
def cursor(self, callback): | ||
""" Get a cursor """ | ||
dispatch = partial(self._dispatch, callback) | ||
dispatch = stack_context.wrap(dispatch) | ||
self.queue.append(dispatch) | ||
self._process_queue() | ||
|
||
def execute(self, sql, args, callback): | ||
""" Shortcut to execute a query """ | ||
self.cursor(lambda cursor: cursor.execute(sql, args, callback)) | ||
|
||
def close(self): | ||
""" | ||
Closes the pool. The effect of this is not that the pool can no | ||
longer be used, but that connections will be closed if not in use. | ||
""" | ||
self._closed = True | ||
while self.connections: | ||
self.connections.pop().close() | ||
|
||
def _process_queue(self): | ||
with stack_context.NullContext(): | ||
while self.queue and self.owed < self.max_connections: | ||
self.owed += 1 | ||
self.queue.popleft()() | ||
|
||
def _checkin(self, connection, err=None): | ||
""" Called automatically on dereference of CursorFairy """ | ||
self.owed -= 1 | ||
if err: | ||
logger.error("Error: %s, closing connection" % err) | ||
connection.close() | ||
elif self._closed: | ||
connection.close() | ||
elif not connection.closed: | ||
self.connections.append(connection) | ||
|
||
@gen.engine | ||
def _dispatch(self, callback): | ||
try: | ||
connection = self.connections.popleft() | ||
except IndexError: | ||
connection = psycopg2.connect(**self.connect_kwargs) | ||
try: | ||
yield gen.Task(poll, connection, self.io_loop) | ||
except Exception as e: | ||
self._checkin(connection, e) | ||
raise | ||
else: | ||
fairy = self._make_fairy(connection.cursor()) | ||
callback(fairy) | ||
|
||
def _make_fairy(self, cursor): | ||
# We must be careful here not to make a reference to the fairy, | ||
# or it will never be dereferenced. But, we must make a reference to | ||
# the cursor, or it will be dereferenced with the fairy. | ||
on_deref = partial(self._on_fairy_deref, cursor) | ||
fairy = CursorFairy(self.io_loop, cursor) | ||
ref = weakref.ref(fairy, on_deref) | ||
self._refs.add(ref) | ||
return fairy | ||
|
||
def _on_fairy_deref(self, cursor, ref): | ||
with stack_context.NullContext(): | ||
self._refs.remove(ref) | ||
self._checkin(cursor.connection) | ||
self.io_loop.add_callback(self._process_queue) | ||
|
||
|
||
class CursorFairy(object): | ||
CONNECTION_WARN = False | ||
|
||
def __init__(self, io_loop, cursor): | ||
self._io_loop = io_loop | ||
self._cursor = cursor | ||
|
||
def __getattr__(self, name): | ||
""" Proxy missing attribute lookups to the cursor """ | ||
return getattr(self._cursor, name) | ||
|
||
@property | ||
def connection(self): | ||
if not self.CONNECTION_WARN: | ||
self.CONNECTION_WARN = True | ||
logger.warning("Using the connection directly may cause " | ||
"inconsistent state of the poller!") | ||
return self._cursor.connection | ||
|
||
def execute(self, sql, args, callback): | ||
self._cursor.execute(sql, args) | ||
self.poll(callback) | ||
|
||
def poll(self, callback): | ||
# bind self as first argument of callback. | ||
# this makes the cursor available to the | ||
# callee and ensures we aren't dereferenced until | ||
# the query has finished executing. | ||
callback = partial(callback, self) | ||
poll(self._cursor.connection, self._io_loop, callback) | ||
|
||
|
||
class poll(object): | ||
""" | ||
A poller that polls the PostgreSQL connection and calls the callback | ||
when the connection state is `POLL_OK`, or an error occurs. | ||
""" | ||
def __init__(self, connection, io_loop, callback): | ||
self.connection = connection | ||
self.io_loop = io_loop | ||
self.callback = callback | ||
self.tick(connection.fileno(), 0) | ||
|
||
def tick(self, fd, events): | ||
mask = -1 | ||
try: | ||
mask = STATE_MAP.get(self.connection.poll()) | ||
if mask > 0: | ||
if events == 0: | ||
self.io_loop.add_handler(fd, self.tick, mask) | ||
elif events > 0: | ||
self.io_loop.update_handler(fd, mask) | ||
elif mask < 0: | ||
raise psycopg2.OperationalError("Connection has unknown error state") | ||
except: | ||
self.callback = None | ||
raise | ||
finally: | ||
if mask <= 0: | ||
if events: | ||
self.io_loop.remove_handler(fd) | ||
if mask == 0: | ||
self.callback() | ||
|
||
STATE_MAP = { | ||
psycopg2.extensions.POLL_OK: 0, | ||
psycopg2.extensions.POLL_READ: ioloop.IOLoop.ERROR | ioloop.IOLoop.READ, | ||
psycopg2.extensions.POLL_WRITE: ioloop.IOLoop.ERROR | ioloop.IOLoop.WRITE, | ||
psycopg2.extensions.POLL_ERROR: -1, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import psycopg2 | ||
import itertools | ||
from mock import patch | ||
from squirrel import ConnectionPool | ||
from tornado.testing import AsyncTestCase | ||
from tornado import gen | ||
|
||
|
||
class ConnectionPoolTestCase(AsyncTestCase): | ||
def async(self, func, *args, **kwargs): | ||
args = args + (self.stop,) | ||
func(*args, **kwargs) | ||
return self.wait() | ||
|
||
def setUp(self): | ||
super(ConnectionPoolTestCase, self).setUp() | ||
dsn = "host=127.0.0.1 dbname=test port=5432" | ||
self.provider = ConnectionPool(dsn=dsn, | ||
io_loop=self.io_loop) | ||
|
||
def test_connect_error_propagates_exception(self): | ||
provider = ConnectionPool(dsn="host=127.0.0.1 dbname=test port=6432", | ||
io_loop=self.io_loop) | ||
provider.cursor(self.stop) | ||
self.assertRaises(psycopg2.OperationalError, self.wait) | ||
|
||
def test_query(self): | ||
cursor = self.async(self.provider.cursor) | ||
self.async(cursor.execute, 'select 1', ()) | ||
self.assertEqual((1,), cursor.fetchone()) | ||
|
||
def test_query_error_propagates_error(self): | ||
cursor = self.async(self.provider.cursor) | ||
cursor.execute("I AM BAD", (), self.stop) | ||
self.assertRaises(psycopg2.ProgrammingError, self.wait) | ||
|
||
def test_query_error_dereferences(self): | ||
return | ||
try: | ||
self.async(self.provider.execute, "I AM BAD", ()) | ||
except: | ||
import sys | ||
sys.exc_clear() | ||
self.assertEqual(1, len(self.provider.connections)) | ||
|
||
def test_query_error_propagates_error_2(self): | ||
self.provider.execute("I AM BAD", (), self.stop) | ||
self.assertRaises(psycopg2.ProgrammingError, self.wait) | ||
|
||
@patch.object(ConnectionPool, '_checkin') | ||
def test_connection_checkin_on_deref(self, checkin): | ||
cursor = self.async(self.provider.cursor) | ||
connection = cursor.connection | ||
cursor = None | ||
checkin.assert_called_once_with(connection) | ||
|
||
@patch.object(ConnectionPool, '_checkin') | ||
def test_shorthand_execute_doesnt_deref_fairy(self, checkin): | ||
self.provider.cursor(self.stop) | ||
self.wait().execute("select 1", (), self.stop) | ||
self.assertEqual(0, checkin.call_count) | ||
self.wait() # cursor returned here but we dont reference it | ||
self.assertEqual(1, checkin.call_count) | ||
|
||
def test_100_queries(self): | ||
n = 100 | ||
c = itertools.count().next | ||
|
||
@gen.engine | ||
def query(i): | ||
if i % 2: | ||
cursor = yield gen.Task(self.provider.execute, 'select %s', (i,)) | ||
self.assertEqual((i,), cursor.fetchone()) | ||
else: | ||
try: | ||
cursor = yield gen.Task(self.provider.cursor) | ||
yield gen.Task(cursor.execute, 'select _%s' % i, ()) | ||
except psycopg2.ProgrammingError as e: | ||
self.assertIn('select _%s' % i, e.pgerror) | ||
|
||
if c() == n - 1: | ||
self.stop() | ||
|
||
for i in range(n): | ||
query(i) | ||
|
||
self.wait() | ||
|
||
def test_close_pool_eventually_closes_everything(self): | ||
cursor1 = self.async(self.provider.cursor) | ||
cursor2 = self.async(self.provider.cursor) | ||
conn2 = cursor2.connection | ||
cursor2 = None | ||
self.assertEqual(1, len(self.provider.connections)) | ||
self.assertEqual(1, self.provider.owed) | ||
self.provider.close() | ||
self.assertTrue(conn2.closed) | ||
self.assertEqual(0, len(self.provider.connections)) | ||
conn1 = cursor1.connection | ||
cursor1 = None | ||
self.assertEqual(0, self.provider.owed) | ||
self.assertTrue(conn1.closed) | ||
|