Permalink
Browse files

Placeholder generation (useful with the likes of 'IN (...)')

  • Loading branch information...
1 parent 9f25d46 commit 6f47bac8aa3f8ab4c1a3fd1ac01804ded07afe0e @kgaughan committed Apr 4, 2013
Showing with 106 additions and 1 deletion.
  1. +29 −1 dbkit.py
  2. +65 −0 tests/test_dbkit.py
  3. +12 −0 tests/utils.py
View
@@ -13,6 +13,7 @@
import contextlib
import datetime
import functools
+import itertools
import pprint
import sys
import textwrap
@@ -90,7 +91,7 @@ class Context(object):
"""A database connection context."""
__slots__ = (
- 'mdr', '_depth', 'logger', 'default_factory',
+ 'mdr', '_depth', 'logger', 'default_factory', 'param_style',
'last_row_count', 'last_row_id') + _EXCEPTIONS
stack = _ContextStack()
@@ -105,6 +106,7 @@ def __init__(self, module, mdr):
self.default_factory = TupleFactory
self.last_row_count = None
self.last_row_id = None
+ self.param_style = module.paramstyle
# Copy driver module's exception references.
for exc in _EXCEPTIONS:
setattr(self, exc, getattr(module, exc))
@@ -853,6 +855,32 @@ def to_dict(key, resultset):
return dict((row[key], row) for row in resultset)
+def make_placeholders(seq, start=1):
+ """
+ Generate placeholders for the given sequence.
+ """
+ if len(seq) == 0:
+ raise ValueError('Sequence must have at least one element.')
+ param_style = Context.current().param_style
+ placeholders = None
+ if isinstance(seq, dict):
+ if param_style in ('named', 'pyformat'):
+ template = ':%s' if param_style == 'named' else '%%(%s)s'
+ placeholders = (template % key for key in seq.iterkeys())
+ elif isinstance(seq, (list, tuple)):
+ if param_style == 'numeric':
+ placeholders = (':%d' % i for i in xrange(start, start + len(seq)))
+ elif param_style in ('qmark', 'format', 'pyformat'):
+ placeholders = itertools.repeat(
+ '?' if param_style == 'qmark' else '%s',
+ times=len(seq))
+ if placeholders is None:
+ raise NotSupported(
+ "Param style '%s' does not support sequence type '%s'" % (
+ param_style, seq.__class__.__name__))
+ return ', '.join(placeholders)
+
+
def null_logger(_stmt, _args):
"""A logger that discards everything sent to it."""
pass
View
@@ -1,4 +1,5 @@
from __future__ import with_statement
+
import sqlite3
import StringIO
import types
@@ -265,4 +266,68 @@ def test_to_dict_sequence():
assert 'barney' in result
assert result['barney'] is row
+
+def test_make_placeholders():
+ with dbkit.connect(fakedb, 'db') as ctx:
+ try:
+ dbkit.make_placeholders([])
+ assert False, "Expected ValueError"
+ except ValueError:
+ pass
+
+ with utils.set_temporarily(fakedb, 'paramstyle', 'qmark'):
+ with dbkit.connect(fakedb, 'db') as ctx:
+ assert dbkit.make_placeholders([0]) == '?'
+ assert dbkit.make_placeholders([0, 1]) == '?, ?'
+ assert dbkit.make_placeholders([0, 1, 4]) == '?, ?, ?'
+
+ for style in ('format', 'pyformat'):
+ with utils.set_temporarily(fakedb, 'paramstyle', style):
+ with dbkit.connect(fakedb, 'db') as ctx:
+ assert dbkit.make_placeholders([0]) == '%s'
+ assert dbkit.make_placeholders([0, 2]) == '%s, %s'
+ assert dbkit.make_placeholders([0, 2, 7]) == '%s, %s, %s'
+
+ with utils.set_temporarily(fakedb, 'paramstyle', 'numeric'):
+ with dbkit.connect(fakedb, 'db') as ctx:
+ assert dbkit.make_placeholders([0], 7) == ':7'
+ assert dbkit.make_placeholders([0, 1], 7) == ':7, :8'
+ assert dbkit.make_placeholders([0, 1, 4], 7) == ':7, :8, :9'
+
+ def sort_fields(fields):
+ """Helper to ensure named fields are sorted for the test."""
+ return ', '.join(sorted(field.lstrip() for field in fields.split(',')))
+
+ def make_sorted(seq):
+ """Wrap repetitive code for the next few checks."""
+ return sort_fields(dbkit.make_placeholders(seq))
+
+ with utils.set_temporarily(fakedb, 'paramstyle', 'pyformat'):
+ with dbkit.connect(fakedb, 'db') as ctx:
+ assert make_sorted({'foo': None}) == '%(foo)s'
+ assert make_sorted({'foo': None, 'bar': None}) == '%(bar)s, %(foo)s'
+ assert make_sorted({'foo': None, 'bar': None, 'baz': None}) == '%(bar)s, %(baz)s, %(foo)s'
+
+ with utils.set_temporarily(fakedb, 'paramstyle', 'named'):
+ with dbkit.connect(fakedb, 'db') as ctx:
+ assert make_sorted({'foo': None}) == ':foo'
+ assert make_sorted({'foo': None, 'bar': None}) == ':bar, :foo'
+ assert make_sorted({'foo': None, 'bar': None, 'baz': None}) == ':bar, :baz, :foo'
+
+ with utils.set_temporarily(fakedb, 'paramstyle', 'qmark'):
+ with dbkit.connect(fakedb, 'db') as ctx:
+ try:
+ print dbkit.make_placeholders({'foo': None})
+ assert False, "Should've got 'NotSupported' exception."
+ except dbkit.NotSupported, exc:
+ assert exc.message == "Param style 'qmark' does not support sequence type 'dict'"
+
+ with utils.set_temporarily(fakedb, 'paramstyle', 'named'):
+ with dbkit.connect(fakedb, 'db') as ctx:
+ try:
+ print dbkit.make_placeholders(['foo'])
+ assert False, "Should've got 'NotSupported' exception."
+ except dbkit.NotSupported, exc:
+ assert exc.message == "Param style 'named' does not support sequence type 'list'"
+
# vim:set et ai:
View
@@ -2,6 +2,7 @@
Utility functions used by the tests.
"""
+import contextlib
import threading
@@ -20,3 +21,14 @@ def spawn(targets):
threads.append(thread)
for thread in threads:
thread.join()
+
+
+@contextlib.contextmanager
+def set_temporarily(obj, attr, value):
+ """Temporarily change the value of an object's attribute."""
+ try:
+ original = getattr(obj, attr)
+ setattr(obj, attr, value)
+ yield
+ finally:
+ setattr(obj, attr, original)

0 comments on commit 6f47bac

Please sign in to comment.