Skip to content

Commit

Permalink
woo code
Browse files Browse the repository at this point in the history
  • Loading branch information
ssadler committed Oct 28, 2012
1 parent 27b86fb commit eb31129
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 0 deletions.
13 changes: 13 additions & 0 deletions LICENSE
@@ -0,0 +1,13 @@
Copyright 2012 Scott Sadler

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
11 changes: 11 additions & 0 deletions setup.py
@@ -0,0 +1,11 @@
#!/usr/bin/env python

from setuptools import setup, find_packages

setup(
name='squirrel',
description='Psycopg2 wrapper for tornadp',
version='0.1',
author='scott sadler',
py_modules=['squirrel'],
)
166 changes: 166 additions & 0 deletions squirrel.py
@@ -0,0 +1,166 @@
import logging as _logging
import weakref
import psycopg2
import psycopg2.extensions
from collections import deque
from functools import partial
from tornado import ioloop, stack_context, gen


logger = _logging.getLogger('squirrel.pool')


class ConnectionPool(object):
"""
Manages connections and provisions cursors.
"""
def __init__(self, io_loop, max_connections=10, **connect_kwargs):
self.io_loop = io_loop
self.max_connections = max_connections
connect_kwargs['async'] = 1
self.connect_kwargs = connect_kwargs
self.connections = deque()
self.queue = deque()
self.owed = 0
self._refs = set()
self._closed = False

def cursor(self, callback):
""" Get a cursor """
dispatch = partial(self._dispatch, callback)
dispatch = stack_context.wrap(dispatch)
self.queue.append(dispatch)
self._process_queue()

def execute(self, sql, args, callback):
""" Shortcut to execute a query """
self.cursor(lambda cursor: cursor.execute(sql, args, callback))

def close(self):
"""
Closes the pool. The effect of this is not that the pool can no
longer be used, but that connections will be closed if not in use.
"""
self._closed = True
while self.connections:
self.connections.pop().close()

def _process_queue(self):
with stack_context.NullContext():
while self.queue and self.owed < self.max_connections:
self.owed += 1
self.queue.popleft()()

def _checkin(self, connection, err=None):
""" Called automatically on dereference of CursorFairy """
self.owed -= 1
if err:
logger.error("Error: %s, closing connection" % err)
connection.close()
elif self._closed:
connection.close()
elif not connection.closed:
self.connections.append(connection)

@gen.engine
def _dispatch(self, callback):
try:
connection = self.connections.popleft()
except IndexError:
connection = psycopg2.connect(**self.connect_kwargs)
try:
yield gen.Task(poll, connection, self.io_loop)
except Exception as e:
self._checkin(connection, e)
raise
else:
fairy = self._make_fairy(connection.cursor())
callback(fairy)

def _make_fairy(self, cursor):
# We must be careful here not to make a reference to the fairy,
# or it will never be dereferenced. But, we must make a reference to
# the cursor, or it will be dereferenced with the fairy.
on_deref = partial(self._on_fairy_deref, cursor)
fairy = CursorFairy(self.io_loop, cursor)
ref = weakref.ref(fairy, on_deref)
self._refs.add(ref)
return fairy

def _on_fairy_deref(self, cursor, ref):
with stack_context.NullContext():
self._refs.remove(ref)
self._checkin(cursor.connection)
self.io_loop.add_callback(self._process_queue)


class CursorFairy(object):
CONNECTION_WARN = False

def __init__(self, io_loop, cursor):
self._io_loop = io_loop
self._cursor = cursor

def __getattr__(self, name):
""" Proxy missing attribute lookups to the cursor """
return getattr(self._cursor, name)

@property
def connection(self):
if not self.CONNECTION_WARN:
self.CONNECTION_WARN = True
logger.warning("Using the connection directly may cause "
"inconsistent state of the poller!")
return self._cursor.connection

def execute(self, sql, args, callback):
self._cursor.execute(sql, args)
self.poll(callback)

def poll(self, callback):
# bind self as first argument of callback.
# this makes the cursor available to the
# callee and ensures we aren't dereferenced until
# the query has finished executing.
callback = partial(callback, self)
poll(self._cursor.connection, self._io_loop, callback)


class poll(object):
"""
A poller that polls the PostgreSQL connection and calls the callback
when the connection state is `POLL_OK`, or an error occurs.
"""
def __init__(self, connection, io_loop, callback):
self.connection = connection
self.io_loop = io_loop
self.callback = callback
self.tick(connection.fileno(), 0)

def tick(self, fd, events):
mask = -1
try:
mask = STATE_MAP.get(self.connection.poll())
if mask > 0:
if events == 0:
self.io_loop.add_handler(fd, self.tick, mask)
elif events > 0:
self.io_loop.update_handler(fd, mask)
elif mask < 0:
raise psycopg2.OperationalError("Connection has unknown error state")
except:
self.callback = None
raise
finally:
if mask <= 0:
if events:
self.io_loop.remove_handler(fd)
if mask == 0:
self.callback()

STATE_MAP = {
psycopg2.extensions.POLL_OK: 0,
psycopg2.extensions.POLL_READ: ioloop.IOLoop.ERROR | ioloop.IOLoop.READ,
psycopg2.extensions.POLL_WRITE: ioloop.IOLoop.ERROR | ioloop.IOLoop.WRITE,
psycopg2.extensions.POLL_ERROR: -1,
}
103 changes: 103 additions & 0 deletions test_squirrel.py
@@ -0,0 +1,103 @@
import psycopg2
import itertools
from mock import patch
from squirrel import ConnectionPool
from tornado.testing import AsyncTestCase
from tornado import gen


class ConnectionPoolTestCase(AsyncTestCase):
def async(self, func, *args, **kwargs):
args = args + (self.stop,)
func(*args, **kwargs)
return self.wait()

def setUp(self):
super(ConnectionPoolTestCase, self).setUp()
dsn = "host=127.0.0.1 dbname=test port=5432"
self.provider = ConnectionPool(dsn=dsn,
io_loop=self.io_loop)

def test_connect_error_propagates_exception(self):
provider = ConnectionPool(dsn="host=127.0.0.1 dbname=test port=6432",
io_loop=self.io_loop)
provider.cursor(self.stop)
self.assertRaises(psycopg2.OperationalError, self.wait)

def test_query(self):
cursor = self.async(self.provider.cursor)
self.async(cursor.execute, 'select 1', ())
self.assertEqual((1,), cursor.fetchone())

def test_query_error_propagates_error(self):
cursor = self.async(self.provider.cursor)
cursor.execute("I AM BAD", (), self.stop)
self.assertRaises(psycopg2.ProgrammingError, self.wait)

def test_query_error_dereferences(self):
return
try:
self.async(self.provider.execute, "I AM BAD", ())
except:
import sys
sys.exc_clear()
self.assertEqual(1, len(self.provider.connections))

def test_query_error_propagates_error_2(self):
self.provider.execute("I AM BAD", (), self.stop)
self.assertRaises(psycopg2.ProgrammingError, self.wait)

@patch.object(ConnectionPool, '_checkin')
def test_connection_checkin_on_deref(self, checkin):
cursor = self.async(self.provider.cursor)
connection = cursor.connection
cursor = None
checkin.assert_called_once_with(connection)

@patch.object(ConnectionPool, '_checkin')
def test_shorthand_execute_doesnt_deref_fairy(self, checkin):
self.provider.cursor(self.stop)
self.wait().execute("select 1", (), self.stop)
self.assertEqual(0, checkin.call_count)
self.wait() # cursor returned here but we dont reference it
self.assertEqual(1, checkin.call_count)

def test_100_queries(self):
n = 100
c = itertools.count().next

@gen.engine
def query(i):
if i % 2:
cursor = yield gen.Task(self.provider.execute, 'select %s', (i,))
self.assertEqual((i,), cursor.fetchone())
else:
try:
cursor = yield gen.Task(self.provider.cursor)
yield gen.Task(cursor.execute, 'select _%s' % i, ())
except psycopg2.ProgrammingError as e:
self.assertIn('select _%s' % i, e.pgerror)

if c() == n - 1:
self.stop()

for i in range(n):
query(i)

self.wait()

def test_close_pool_eventually_closes_everything(self):
cursor1 = self.async(self.provider.cursor)
cursor2 = self.async(self.provider.cursor)
conn2 = cursor2.connection
cursor2 = None
self.assertEqual(1, len(self.provider.connections))
self.assertEqual(1, self.provider.owed)
self.provider.close()
self.assertTrue(conn2.closed)
self.assertEqual(0, len(self.provider.connections))
conn1 = cursor1.connection
cursor1 = None
self.assertEqual(0, self.provider.owed)
self.assertTrue(conn1.closed)

0 comments on commit eb31129

Please sign in to comment.