Permalink
6480 lines (5291 sloc) 209 KB
from bisect import bisect_left
from bisect import bisect_right
from contextlib import contextmanager
from copy import deepcopy
from functools import wraps
from inspect import isclass
import calendar
import collections
import datetime
import decimal
import hashlib
import itertools
import logging
import operator
import re
import socket
import struct
import sys
import threading
import time
import uuid
import warnings
try:
from pysqlite3 import dbapi2 as pysq3
except ImportError:
try:
from pysqlite2 import dbapi2 as pysq3
except ImportError:
pysq3 = None
try:
import sqlite3
except ImportError:
sqlite3 = pysq3
else:
if pysq3 and pysq3.sqlite_version_info >= sqlite3.sqlite_version_info:
sqlite3 = pysq3
try:
from psycopg2cffi import compat
compat.register()
except ImportError:
pass
try:
import psycopg2
from psycopg2 import extensions as pg_extensions
except ImportError:
psycopg2 = None
mysql_passwd = False
try:
import pymysql as mysql
except ImportError:
try:
import MySQLdb as mysql # prefer the C module.
mysql_passwd = True
except ImportError:
mysql = None
# Library version string.
__version__ = '3.5.0'
# Public API: the names exported by ``from peewee import *``.
__all__ = [
    'AsIs',
    'AutoField',
    'BareField',
    'BigAutoField',
    'BigBitField',
    'BigIntegerField',
    'BinaryUUIDField',
    'BitField',
    'BlobField',
    'BooleanField',
    'Case',
    'Cast',
    'CharField',
    'Check',
    'Column',
    'CompositeKey',
    'Context',
    'Database',
    'DatabaseError',
    'DataError',
    'DateField',
    'DateTimeField',
    'DecimalField',
    'DeferredForeignKey',
    'DeferredThroughModel',
    'DJANGO_MAP',
    'DoesNotExist',
    'DoubleField',
    'DQ',
    'Field',
    'FixedCharField',
    'FloatField',
    'fn',
    'ForeignKeyField',
    'ImproperlyConfigured',
    'Index',
    'IntegerField',
    'IntegrityError',
    'InterfaceError',
    'InternalError',
    'IPField',
    'JOIN',
    'ManyToManyField',
    'Model',
    'ModelIndex',
    'MySQLDatabase',
    'NotSupportedError',
    'OP',
    'OperationalError',
    'PostgresqlDatabase',
    'PrimaryKeyField',  # XXX: Deprecated, change to AutoField.
    'prefetch',
    'ProgrammingError',
    'Proxy',
    'QualifiedNames',
    'SchemaManager',
    'SmallIntegerField',
    'Select',
    'SQL',
    'SqliteDatabase',
    'Table',
    'TextField',
    'TimeField',
    'TimestampField',
    'Tuple',
    'UUIDField',
    'Value',
    'ValuesList',
    'Window',
]
# ``logging.NullHandler`` exists on Python 2.7+; older interpreters get an
# equivalent handler that discards every record. Attaching it keeps the
# library's logger silent unless the application configures logging.
NullHandler = getattr(logging, 'NullHandler', None)
if NullHandler is None:  # Python < 2.7
    class NullHandler(logging.Handler):
        """Handler that ignores every record it receives."""
        def emit(self, record):
            pass

logger = logging.getLogger('peewee')
logger.addHandler(NullHandler())
# Import any speedups or provide alternate implementations.
try:
    from playhouse._speedups import quote
except ImportError:
    def quote(path, quote_chars):
        """Quote every part of ``path`` and join the parts with ``.``.

        Each part is used as the separator in ``part.join(quote_chars)``, so
        a two-character ``quote_chars`` such as ``'""'`` yields ``"part"``.
        """
        if len(path) != 1:
            return '.'.join(part.join(quote_chars) for part in path)
        return path[0].join(quote_chars)
# Python 2 / Python 3 compatibility shims used throughout the module.
if sys.version_info[0] == 2:
    text_type = unicode
    bytes_type = str
    buffer_type = buffer
    izip_longest = itertools.izip_longest
    # The three-argument raise statement is a syntax error on Python 3, so it
    # is compiled lazily with exec().
    exec('def reraise(tp, value, tb=None): raise tp, value, tb')

    def print_(s):
        """Write ``s`` followed by a newline to stdout."""
        sys.stdout.write(s)
        sys.stdout.write('\n')
else:
    import builtins
    # ``Callable`` lives in ``collections.abc`` since Python 3.3. The old
    # alias in ``collections`` was removed in Python 3.10, where importing it
    # made this whole module fail to load.
    try:
        from collections.abc import Callable
    except ImportError:  # Python 3.0 - 3.2
        from collections import Callable
    from functools import reduce
    callable = lambda c: isinstance(c, Callable)
    text_type = str
    bytes_type = bytes
    buffer_type = memoryview
    basestring = str
    long = int
    print_ = getattr(builtins, 'print')
    izip_longest = itertools.zip_longest

    def reraise(tp, value, tb=None):
        """Re-raise ``value``, attaching traceback ``tb`` if it differs."""
        if value.__traceback__ is not tb:
            raise value.with_traceback(tb)
        raise value
# Register string adapters so these Python types can be bound as sqlite
# parameters, and record the linked SQLite library version ((0, 0, 0) when
# no sqlite driver could be imported).
if not sqlite3:
    __sqlite_version__ = (0, 0, 0)
else:
    sqlite3.register_adapter(datetime.time, str)
    sqlite3.register_adapter(datetime.date, str)
    sqlite3.register_adapter(decimal.Decimal, str)
    __sqlite_version__ = sqlite3.sqlite_version_info
# Date components accepted by _sqlite_date_part().
__date_parts__ = set(('year', 'month', 'day', 'hour', 'minute', 'second'))

# Sqlite does not support the `date_part` SQL function, so we will define an
# implementation in python. These are the formats tried when parsing the
# stored date/time strings.
__sqlite_datetime_formats__ = (
    '%Y-%m-%d %H:%M:%S',
    '%Y-%m-%d %H:%M:%S.%f',
    '%Y-%m-%d',
    '%H:%M:%S',
    '%H:%M:%S.%f',
    '%H:%M')

# strftime() formats used to truncate a datetime to the given precision.
__sqlite_date_trunc__ = {
    'year': '%Y',
    'month': '%Y-%m',
    'day': '%Y-%m-%d',
    'hour': '%Y-%m-%d %H',
    'minute': '%Y-%m-%d %H:%M',
    'second': '%Y-%m-%d %H:%M:%S'}

# MySQL's DATE_FORMAT uses %i (not %M) for minutes.
__mysql_date_trunc__ = __sqlite_date_trunc__.copy()
__mysql_date_trunc__['minute'] = '%Y-%m-%d %H:%i'
__mysql_date_trunc__['second'] = '%Y-%m-%d %H:%i:%S'


def _sqlite_date_part(lookup_type, datetime_string):
    """Return component ``lookup_type`` of ``datetime_string``.

    Returns None for an empty/None value. Raises ValueError when
    ``lookup_type`` is not one of ``__date_parts__``. (This previously used
    ``assert``, which is stripped under ``python -O`` and then let arbitrary
    attribute names through to ``getattr``.)
    """
    if lookup_type not in __date_parts__:
        raise ValueError('Unsupported date part: %r' % (lookup_type,))
    if not datetime_string:
        return
    dt = format_date_time(datetime_string, __sqlite_datetime_formats__)
    return getattr(dt, lookup_type)


def _sqlite_date_trunc(lookup_type, datetime_string):
    """Truncate ``datetime_string`` to ``lookup_type`` precision.

    The result is a string formatted with ``__sqlite_date_trunc__``. Returns
    None for an empty/None value. Raises ValueError for an unsupported
    ``lookup_type``.
    """
    if lookup_type not in __sqlite_date_trunc__:
        raise ValueError('Unsupported date truncation: %r' % (lookup_type,))
    if not datetime_string:
        return
    dt = format_date_time(datetime_string, __sqlite_datetime_formats__)
    return dt.strftime(__sqlite_date_trunc__[lookup_type])
def __deprecated__(s):
    """Issue ``s`` as a :class:`DeprecationWarning`."""
    warnings.warn(message=s, category=DeprecationWarning)
class attrdict(dict):
    """A ``dict`` whose keys can also be read and assigned as attributes.

    ``+`` returns a new merged attrdict; ``+=`` merges in place.
    """

    def __getattr__(self, attr):
        # Missing keys surface as AttributeError, as attribute access expects.
        if attr in self:
            return self[attr]
        raise AttributeError(attr)

    def __setattr__(self, attr, value):
        self[attr] = value

    def __iadd__(self, rhs):
        self.update(rhs)
        return self

    def __add__(self, rhs):
        combined = attrdict(self)
        combined.update(rhs)
        return combined
# Unique module-level marker object.
SENTINEL = object()

#: Operations for use in SQL expressions.
# Keys are the symbolic names used by the expression API; values are the SQL
# operator text emitted (contexts may remap them via an ``operations``
# setting).
OP = attrdict(
    AND='AND',
    OR='OR',
    ADD='+',
    SUB='-',
    MUL='*',
    DIV='/',
    BIN_AND='&',
    BIN_OR='|',
    XOR='#',
    MOD='%',
    EQ='=',
    LT='<',
    LTE='<=',
    GT='>',
    GTE='>=',
    NE='!=',
    IN='IN',
    NOT_IN='NOT IN',
    IS='IS',
    IS_NOT='IS NOT',
    LIKE='LIKE',
    ILIKE='ILIKE',
    BETWEEN='BETWEEN',
    REGEXP='REGEXP',
    IREGEXP='IREGEXP',
    CONCAT='||',
    BITWISE_NEGATION='~')

# To support "django-style" double-underscore filters, create a mapping between
# operation name and operation code, e.g. "__eq" == OP.EQ.
DJANGO_MAP = attrdict({
    'eq': OP.EQ,
    'lt': OP.LT,
    'lte': OP.LTE,
    'gt': OP.GT,
    'gte': OP.GTE,
    'ne': OP.NE,
    'in': OP.IN,
    'is': OP.IS,
    'like': OP.LIKE,
    'ilike': OP.ILIKE,
    'regexp': OP.REGEXP})

#: Mapping of field type to the data-type supported by the database. Databases
#: may override or add to this list.
FIELD = attrdict(
    AUTO='INTEGER',
    BIGAUTO='BIGINT',
    BIGINT='BIGINT',
    BLOB='BLOB',
    BOOL='SMALLINT',
    CHAR='CHAR',
    DATE='DATE',
    DATETIME='DATETIME',
    DECIMAL='DECIMAL',
    DEFAULT='',
    DOUBLE='REAL',
    FLOAT='REAL',
    INT='INTEGER',
    SMALLINT='SMALLINT',
    TEXT='TEXT',
    TIME='TIME',
    UUID='TEXT',
    UUIDB='BLOB',
    VARCHAR='VARCHAR')

#: Join helpers (for convenience) -- all join types are supported, this object
#: is just to help avoid introducing errors by using strings everywhere.
# Each value is the text placed before "JOIN" in the generated SQL.
JOIN = attrdict(
    INNER='INNER',
    LEFT_OUTER='LEFT OUTER',
    RIGHT_OUTER='RIGHT OUTER',
    FULL='FULL',
    FULL_OUTER='FULL OUTER',
    CROSS='CROSS',
    NATURAL='NATURAL')

# Row representations.
# Integer codes selecting how result rows are returned.
ROW = attrdict(
    TUPLE=1,
    DICT=2,
    NAMED_TUPLE=3,
    CONSTRUCTOR=4,
    MODEL=5)

# Rendering scopes stored in Context.state.scope. Distinct powers of two.
SCOPE_NORMAL = 1
SCOPE_SOURCE = 2
SCOPE_VALUES = 4
SCOPE_CTE = 8
SCOPE_COLUMN = 16
# Helper functions that are used in various parts of the codebase.

# Class name given to the intermediate base created by with_metaclass().
MODEL_BASE = '_metaclass_helper_'


def with_metaclass(meta, base=object):
    """Return a base class constructed by ``meta`` (works on py2 and py3)."""
    return meta(MODEL_BASE, (base,), {})


def merge_dict(source, overrides):
    """Return a copy of ``source`` updated with ``overrides`` (may be None)."""
    merged = source.copy()
    merged.update(overrides or {})
    return merged


def is_model(o):
    """True when ``o`` is a class deriving from ``Model``."""
    return isclass(o) and issubclass(o, Model)


def ensure_tuple(value):
    """Wrap a non-list/tuple ``value`` in a 1-tuple; None passes through."""
    if value is None:
        return None
    if isinstance(value, (list, tuple)):
        return value
    return (value,)


def ensure_entity(value):
    """Wrap a non-Node ``value`` in an ``Entity``; None passes through."""
    if value is None:
        return None
    return value if isinstance(value, Node) else Entity(value)
def chunked(it, n):
    """Yield successive lists of at most ``n`` items taken from ``it``.

    The last list may be shorter than ``n``. Nothing is yielded when ``it``
    is empty or when ``n`` is less than 1.
    """
    # Slice with islice instead of padding with zip_longest and a marker.
    # The old version located the marker with list.index(), i.e. by
    # equality, so items whose __eq__ returns something truthy (SQL
    # expression nodes build an Expression on ==) caused chunks to be
    # truncated or emptied.
    if n < 1:
        return
    iterator = iter(it)
    while True:
        group = list(itertools.islice(iterator, n))
        if not group:
            return
        yield group
class _callable_context_manager(object):
    """Mixin that lets a context manager double as a function decorator.

    Decorating ``fn`` runs each call of ``fn`` inside ``with self:``.
    """

    def __call__(self, fn):
        manager = self

        @wraps(fn)
        def wrapper(*args, **kwargs):
            with manager:
                return fn(*args, **kwargs)

        return wrapper
class Proxy(object):
    """
    Create a proxy or placeholder for another object.

    Attribute reads are forwarded to the wrapped object once
    :meth:`initialize` has supplied one. Callbacks registered with
    :meth:`attach_callback` run on every ``initialize`` call.
    """
    __slots__ = ('obj', '_callbacks')

    def __init__(self):
        self._callbacks = []
        self.initialize(None)

    def initialize(self, obj):
        self.obj = obj
        for notify in self._callbacks:
            notify(obj)

    def attach_callback(self, callback):
        # Returns the callback so this can also be used as a decorator.
        self._callbacks.append(callback)
        return callback

    def __getattr__(self, attr):
        target = self.obj
        if target is None:
            raise AttributeError('Cannot use uninitialized Proxy.')
        return getattr(target, attr)

    def __setattr__(self, attr, value):
        # Only the slot attributes may be assigned on the proxy itself.
        if attr in self.__slots__:
            return super(Proxy, self).__setattr__(attr, value)
        raise AttributeError('Cannot set attribute on proxy.')
# SQL Generation.
class AliasManager(object):
    """Hands out table aliases ("t1", "t2", ...) per query nesting depth.

    The counter is shared by all depths. Mapping dicts are kept when a depth
    is popped and reused when that depth is pushed again.
    """

    def __init__(self):
        # A list of dictionaries containing mappings at various depths.
        self._counter = 0
        self._current_index = 0
        self._mapping = []
        self.push()

    @property
    def mapping(self):
        """The {source: alias} dict for the current depth."""
        return self._mapping[self._current_index - 1]

    def add(self, source):
        """Return the alias of ``source`` at the current depth, creating one
        if needed."""
        current = self.mapping
        if source not in current:
            self._counter += 1
            self[source] = 't%d' % self._counter
        return current[source]

    def get(self, source, any_depth=False):
        """Look up ``source``'s alias. With ``any_depth``, search from the
        current depth outwards before creating a new alias."""
        if any_depth:
            for depth in range(self._current_index - 1, -1, -1):
                layer = self._mapping[depth]
                if source in layer:
                    return layer[source]
        return self.add(source)

    def __getitem__(self, source):
        return self.get(source)

    def __setitem__(self, source, alias):
        self.mapping[source] = alias

    def push(self):
        """Enter a deeper nesting level."""
        self._current_index += 1
        if len(self._mapping) < self._current_index:
            self._mapping.append({})

    def pop(self):
        """Leave the current nesting level; the outermost cannot be popped."""
        if self._current_index == 1:
            raise ValueError('Cannot pop() from empty alias manager.')
        self._current_index -= 1
class State(collections.namedtuple('_State', ('scope', 'parentheses',
                                              'subquery', 'settings'))):
    """Immutable rendering state.

    Holds the scope, the parentheses and subquery flags, and a ``settings``
    dict of extra options. Unknown attributes are looked up in ``settings``
    and default to None.
    """

    def __new__(cls, scope=SCOPE_NORMAL, parentheses=False, subquery=False,
                **kwargs):
        return super(State, cls).__new__(cls, scope, parentheses, subquery,
                                         kwargs)

    def __call__(self, scope=None, parentheses=None, subquery=None, **kwargs):
        """Derive a new State.

        ``scope``, ``subquery`` and ``settings`` are inherited unless
        overridden. ``parentheses`` is never inherited.
        """
        if scope is None:
            scope = self.scope
        if subquery is None:
            subquery = self.subquery

        # Reuse existing dicts where possible to avoid needless copying.
        if not kwargs:
            settings = self.settings
        elif not self.settings:
            settings = kwargs
        else:
            settings = dict(self.settings, **kwargs)
        return State(scope, parentheses, subquery, **settings)

    def __getattr__(self, attr_name):
        return self.settings.get(attr_name)
def __scope_context__(scope):
    """Build a context-manager method that switches the receiver to
    ``scope`` (plus any extra state overrides) for the ``with`` block."""
    def inner(self, **kwargs):
        with self(scope=scope, **kwargs):
            yield self
    return contextmanager(inner)
class Context(object):
    """Collects SQL text fragments and parameter values while nodes render.

    ``state`` holds the current rendering settings. Calling the context, as
    in ``with ctx(scope=...)``, pushes a derived state that ``__exit__``
    restores.
    """

    def __init__(self, **settings):
        self.stack = []
        self._sql = []
        self._values = []
        self.alias_manager = AliasManager()
        self.state = State(**settings)

    def column_sort_key(self, item):
        return item[0].get_sort_key(self)

    @property
    def scope(self):
        return self.state.scope

    @property
    def parentheses(self):
        return self.state.parentheses

    @property
    def subquery(self):
        return self.state.subquery

    def __call__(self, **overrides):
        # A scope override equal to the current scope is redundant; drop it.
        if overrides and overrides.get('scope') == self.scope:
            del overrides['scope']
        self.stack.append(self.state)
        self.state = self.state(**overrides)
        return self

    scope_normal = __scope_context__(SCOPE_NORMAL)
    scope_source = __scope_context__(SCOPE_SOURCE)
    scope_values = __scope_context__(SCOPE_VALUES)
    scope_cte = __scope_context__(SCOPE_CTE)
    scope_column = __scope_context__(SCOPE_COLUMN)

    def __enter__(self):
        if self.parentheses:
            self.literal('(')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.parentheses:
            self.literal(')')
        self.state = self.stack.pop()

    @contextmanager
    def push_alias(self):
        """Use a fresh alias-manager depth for the duration of the block."""
        self.alias_manager.push()
        try:
            yield
        finally:
            # Restore the previous depth even when rendering raises, so the
            # alias manager is never left one level too deep.
            self.alias_manager.pop()

    def sql(self, obj):
        """Render ``obj``.

        Nodes and contexts render through ``__sql__``, model classes through
        their table, and anything else as a parameter ``Value``.
        """
        if isinstance(obj, (Node, Context)):
            return obj.__sql__(self)
        elif is_model(obj):
            return obj._meta.table.__sql__(self)
        else:
            return self.sql(Value(obj))

    def literal(self, keyword):
        """Append raw SQL text."""
        self._sql.append(keyword)
        return self

    def value(self, value, converter=None, add_param=True):
        """Record a parameter value, optionally emitting a placeholder.

        An explicit ``converter`` is applied first. Otherwise the state's
        converter is used, unless ``converter`` is ``False``. A value that
        converts to a Node is rendered as SQL instead of being bound.
        """
        if converter:
            value = converter(value)
            if isinstance(value, Node):
                return self.sql(value)
        elif converter is None and self.state.converter:
            # Explicitly check for None so that "False" can be used to signify
            # that no conversion should be applied.
            value = self.state.converter(value)

        if isinstance(value, Node):
            with self(converter=None):
                return self.sql(value)

        self._values.append(value)
        return self.literal(self.state.param or '?') if add_param else self

    def __sql__(self, ctx):
        # Splice this context's SQL and parameters into ``ctx``.
        ctx._sql.extend(self._sql)
        ctx._values.extend(self._values)
        return ctx

    def parse(self, node):
        """Render ``node`` and return ``(sql, params)``."""
        return self.sql(node).query()

    def query(self):
        return ''.join(self._sql), self._values
# AST.
class Node(object):
    """Base class of every element in the SQL abstract syntax tree."""
    _coerce = True

    def clone(self):
        """Shallow copy: a new instance sharing attribute values."""
        obj = self.__class__.__new__(self.__class__)
        obj.__dict__ = self.__dict__.copy()
        return obj

    def __sql__(self, ctx):
        raise NotImplementedError

    @staticmethod
    def copy(method):
        """Decorator for generative methods.

        ``method`` runs against a clone, and the clone is returned. The
        original node is left untouched.
        """
        # functools.wraps keeps the decorated method's name and docstring,
        # which the bare wrapper previously discarded.
        @wraps(method)
        def inner(self, *args, **kwargs):
            clone = self.clone()
            method(clone, *args, **kwargs)
            return clone
        return inner

    def coerce(self, _coerce=True):
        """Return a node with the given coercion flag (self if unchanged)."""
        if _coerce != self._coerce:
            clone = self.clone()
            clone._coerce = _coerce
            return clone
        return self

    def is_alias(self):
        return False

    def unwrap(self):
        return self
class ColumnFactory(object):
    """Turns attribute access into ``Column(node, attribute_name)``."""
    __slots__ = ('node',)

    def __init__(self, node):
        self.node = node

    def __getattr__(self, attr):
        return Column(self.node, attr)


class _DynamicColumn(object):
    """Descriptor: ``instance.c`` gives a ColumnFactory for that instance.

    Access on the class returns the descriptor itself.
    """
    __slots__ = ()

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self
        return ColumnFactory(instance)  # Implements __getattr__().


class _ExplicitColumn(object):
    """Descriptor that refuses dynamic ``.c`` lookups on instances."""
    __slots__ = ()

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self
        raise AttributeError(
            '%s specifies columns explicitly, and does not support '
            'dynamic column lookups.' % instance)
class Source(Node):
    """Something rows can be selected from (table, join, CTE, ...).

    A source may carry an explicit alias; otherwise the context's alias
    manager assigns one.
    """
    c = _DynamicColumn()

    def __init__(self, alias=None):
        super(Source, self).__init__()
        self._alias = alias

    @Node.copy
    def alias(self, name):
        self._alias = name

    def select(self, *columns):
        return Select((self,), columns)

    def join(self, dest, join_type='INNER', on=None):
        return Join(self, dest, join_type, on)

    def left_outer_join(self, dest, on=None):
        return Join(self, dest, JOIN.LEFT_OUTER, on)

    def get_sort_key(self, ctx):
        key = self._alias if self._alias else ctx.alias_manager[self]
        return (key,)

    def apply_alias(self, ctx):
        # If we are defining the source, include the "AS alias" declaration. An
        # alias is created for the source if one is not already defined.
        if ctx.scope != SCOPE_SOURCE:
            return ctx
        if self._alias:
            ctx.alias_manager[self] = self._alias
        ctx.literal(' AS ').sql(Entity(ctx.alias_manager[self]))
        return ctx

    def apply_column(self, ctx):
        # Refer to the source by its (possibly generated) alias.
        if self._alias:
            ctx.alias_manager[self] = self._alias
        alias = ctx.alias_manager[self]
        return ctx.sql(Entity(alias))
class _HashableSource(object):
    """Mixin giving sources value-based hashing and equality.

    The default ``_get_hash`` reads ``self._path`` and ``self._alias``.
    Subclasses must have those set by the time this ``__init__`` finishes,
    or must override ``_get_hash``.
    """

    def __init__(self, *args, **kwargs):
        super(_HashableSource, self).__init__(*args, **kwargs)
        self._update_hash()

    @Node.copy
    def alias(self, name):
        self._alias = name
        self._update_hash()

    def _update_hash(self):
        # Cache the hash; it must be refreshed whenever its inputs change.
        self._hash = self._get_hash()

    def _get_hash(self):
        return hash((self.__class__, self._path, self._alias))

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        # Only other hashable sources carry ``_hash``. For anything else
        # (None, strings, other nodes) defer to Python's default handling
        # instead of failing with AttributeError.
        if not isinstance(other, _HashableSource):
            return NotImplemented
        return self._hash == other._hash

    def __ne__(self, other):
        return not (self == other)
def __bind_database__(meth):
    """Decorator: bind the query returned by ``meth`` to ``self._database``
    whenever the receiver has a database."""
    @wraps(meth)
    def inner(self, *args, **kwargs):
        result = meth(self, *args, **kwargs)
        database = self._database
        return result.bind(database) if database else result
    return inner


def __join__(join_type='INNER', inverted=False):
    """Build a binary-operator method that joins two sources.

    The join uses ``join_type``. ``inverted`` swaps the operands, for
    reflected operators.
    """
    def method(self, other):
        lhs, rhs = (other, self) if inverted else (self, other)
        return Join(lhs, rhs, join_type=join_type)
    return method
class BaseTable(Source):
    """Source that can be joined with Python operators.

    ``a & b`` is INNER, ``a + b`` LEFT OUTER, ``a - b`` RIGHT OUTER,
    ``a | b`` FULL OUTER and ``a * b`` CROSS. The reflected forms swap the
    operands.
    """
    __and__ = __join__(JOIN.INNER)
    __rand__ = __join__(JOIN.INNER, inverted=True)

    __add__ = __join__(JOIN.LEFT_OUTER)
    __radd__ = __join__(JOIN.LEFT_OUTER, inverted=True)

    __sub__ = __join__(JOIN.RIGHT_OUTER)
    __rsub__ = __join__(JOIN.RIGHT_OUTER, inverted=True)

    __or__ = __join__(JOIN.FULL_OUTER)
    __ror__ = __join__(JOIN.FULL_OUTER, inverted=True)

    __mul__ = __join__(JOIN.CROSS)
    __rmul__ = __join__(JOIN.CROSS, inverted=True)
class _BoundTableContext(_callable_context_manager):
    """Temporarily bind ``table`` (and its model, when set) to ``database``.

    The table's previous database is remembered on entry and re-bound on
    exit.
    """

    def __init__(self, table, database):
        self.table = table
        self.database = database

    def __enter__(self):
        table = self.table
        self._orig_database = table._database
        table.bind(self.database)
        model = table._model
        if model is not None:
            model.bind(self.database)
        return table

    def __exit__(self, exc_type, exc_val, exc_tb):
        table, previous = self.table, self._orig_database
        table.bind(previous)
        if table._model is not None:
            table._model.bind(previous)
class Table(_HashableSource, BaseTable):
    """A database table, optionally schema-qualified and aliased.

    When ``columns`` is given, each name becomes a ``Column`` attribute and
    dynamic ``.c`` lookups are disabled. ``primary_key`` names the column
    exposed as ``self.primary_key``.
    """

    def __init__(self, name, columns=None, primary_key=None, schema=None,
                 alias=None, _model=None, _database=None):
        self.__name__ = name
        self._columns = columns
        self._primary_key = primary_key
        self._schema = schema
        self._path = (schema, name) if schema else (name,)
        self._model = _model
        self._database = _database
        super(Table, self).__init__(alias=alias)

        # Allow tables to restrict what columns are available.
        if columns is not None:
            self.c = _ExplicitColumn()
            for column in columns:
                setattr(self, column, Column(self, column))

        if primary_key:
            col_src = self if self._columns else self.c
            self.primary_key = getattr(col_src, primary_key)
        else:
            self.primary_key = None

    def clone(self):
        # Rebuild through the constructor so explicit Column attributes refer
        # to the new table rather than the original.
        return Table(
            self.__name__,
            columns=self._columns,
            primary_key=self._primary_key,
            schema=self._schema,
            alias=self._alias,
            _model=self._model,
            _database=self._database)

    def bind(self, database=None):
        self._database = database
        return self

    def bind_ctx(self, database=None):
        return _BoundTableContext(self, database)

    def _get_hash(self):
        return hash((self.__class__, self._path, self._alias, self._model))

    @__bind_database__
    def select(self, *columns):
        if not columns and self._columns:
            columns = [Column(self, column) for column in self._columns]
        return Select((self,), columns)

    @__bind_database__
    def insert(self, insert=None, columns=None, **kwargs):
        """Build an INSERT.

        Keyword arguments are added as ``{Column: value}`` to ``insert``.
        """
        if kwargs:
            # Build a new mapping; never modify a dict supplied by the caller.
            insert = {} if insert is None else dict(insert)
            src = self if self._columns else self.c
            for key, value in kwargs.items():
                insert[getattr(src, key)] = value
        return Insert(self, insert=insert, columns=columns)

    @__bind_database__
    def replace(self, insert=None, columns=None, **kwargs):
        """Build an INSERT with ``on_conflict('REPLACE')``."""
        # Forward kwargs; they were previously dropped silently.
        return (self
                .insert(insert=insert, columns=columns, **kwargs)
                .on_conflict('REPLACE'))

    @__bind_database__
    def update(self, update=None, **kwargs):
        """Build an UPDATE.

        Keyword arguments are added as ``{Column: value}`` to ``update``.
        """
        if kwargs:
            # Build a new mapping; never modify a dict supplied by the caller.
            update = {} if update is None else dict(update)
            src = self if self._columns else self.c
            for key, value in kwargs.items():
                update[getattr(src, key)] = value
        return Update(self, update=update)

    @__bind_database__
    def delete(self):
        return Delete(self)

    def __sql__(self, ctx):
        if ctx.scope == SCOPE_VALUES:
            # Return the quoted table name.
            return ctx.sql(Entity(*self._path))

        if self._alias:
            ctx.alias_manager[self] = self._alias

        if ctx.scope == SCOPE_SOURCE:
            # Define the table and its alias.
            return self.apply_alias(ctx.sql(Entity(*self._path)))
        else:
            # Refer to the table using the alias.
            return self.apply_column(ctx)
class Join(BaseTable):
    """Renders ``lhs <join_type> JOIN rhs [ON predicate]``."""

    def __init__(self, lhs, rhs, join_type=JOIN.INNER, on=None, alias=None):
        super(Join, self).__init__(alias=alias)
        self.lhs, self.rhs = lhs, rhs
        self.join_type = join_type
        self._on = on

    def on(self, predicate):
        # Sets the predicate on this join in place and returns it.
        self._on = predicate
        return self

    def __sql__(self, ctx):
        ctx.sql(self.lhs)
        ctx.literal(' %s JOIN ' % self.join_type)
        ctx.sql(self.rhs)
        if self._on is not None:
            ctx.literal(' ON ')
            ctx.sql(self._on)
        return ctx
class ValuesList(BaseTable):
    """An inline ``VALUES`` list usable as a source.

    In source scope it renders as ``(VALUES (...), ...) AS alias [(cols)]``;
    elsewhere it renders as just its alias.
    """

    def __init__(self, values, columns=None, alias=None):
        super(ValuesList, self).__init__(alias=alias)
        self._values = values
        self._columns = columns

    def _get_hash(self):
        return hash((self.__class__, id(self._values), self._alias))

    @Node.copy
    def columns(self, *names):
        self._columns = names

    def __sql__(self, ctx):
        if self._alias:
            ctx.alias_manager[self] = self._alias

        if ctx.scope != SCOPE_SOURCE:
            ctx.sql(Entity(ctx.alias_manager[self]))
            return ctx

        rows = CommaNodeList([EnclosedNodeList(row) for row in self._values])
        ctx = (ctx
               .literal('(VALUES ')
               .sql(rows)
               .literal(') AS ')
               .sql(Entity(ctx.alias_manager[self])))
        if self._columns:
            ctx.sql(EnclosedNodeList([Entity(c) for c in self._columns]))
        return ctx
class CTE(_HashableSource, Source):
    """Common table expression.

    In CTE scope it renders as ``name [(columns)] AS (query)``; elsewhere as
    ``name``.
    """

    def __init__(self, name, query, recursive=False, columns=None):
        self._alias = name
        self._query = query
        self._recursive = recursive
        if columns is None:
            self._columns = None
        else:
            # Plain strings become quoted identifiers.
            self._columns = [
                Entity(col) if isinstance(col, basestring) else col
                for col in columns]
        # Note: resets ``_cte_list`` on the query object passed in (the
        # caller's object, not a copy).
        query._cte_list = ()
        super(CTE, self).__init__(alias=name)

    def select_from(self, *columns):
        """Select ``columns`` from this CTE, bound to the query's database."""
        query = Select((self,), columns).with_cte(self)
        query = query.bind(self._query._database)
        try:
            query = query.objects(self._query.model)
        except AttributeError:
            pass
        return query

    def _get_hash(self):
        return hash((self.__class__, self._alias, id(self._query)))

    def union_all(self, rhs):
        """New CTE whose query is a clone of this one combined with ``rhs``
        through ``+``."""
        clone = self._query.clone()
        return CTE(self._alias, clone + rhs, self._recursive, self._columns)
    __add__ = union_all

    def __sql__(self, ctx):
        if ctx.scope != SCOPE_CTE:
            return ctx.sql(Entity(self._alias))

        with ctx.push_alias():
            ctx.alias_manager[self] = self._alias
            ctx.sql(Entity(self._alias))
            if self._columns:
                ctx.literal(' ').sql(EnclosedNodeList(self._columns))
            ctx.literal(' AS (')
            with ctx.scope_normal():
                ctx.sql(self._query)
            ctx.literal(')')
        return ctx
class ColumnBase(Node):
    """Base class for column-like nodes.

    Python operators build ``Expression`` trees instead of evaluating, e.g.
    ``col == 1``, ``col << [1, 2]`` (IN) and ``col % 'a%'`` (LIKE).
    """

    def alias(self, alias):
        if alias:
            return Alias(self, alias)
        return self

    def unalias(self):
        return self

    def cast(self, as_type):
        return Cast(self, as_type)

    def asc(self, collation=None, nulls=None):
        return Asc(self, collation=collation, nulls=nulls)
    __pos__ = asc

    def desc(self, collation=None, nulls=None):
        return Desc(self, collation=collation, nulls=nulls)
    __neg__ = desc

    def __invert__(self):
        return Negated(self)

    def _e(op, inv=False):
        """
        Lightweight factory which returns a method that builds an Expression
        consisting of the left-hand and right-hand operands, using `op`.
        With ``inv`` the operands are swapped (for reflected operators).
        """
        def inner(self, rhs):
            if inv:
                return Expression(rhs, op, self)
            return Expression(self, op, rhs)
        return inner
    __and__ = _e(OP.AND)
    __or__ = _e(OP.OR)

    __add__ = _e(OP.ADD)
    __sub__ = _e(OP.SUB)
    __mul__ = _e(OP.MUL)
    __div__ = __truediv__ = _e(OP.DIV)
    __xor__ = _e(OP.XOR)
    __radd__ = _e(OP.ADD, inv=True)
    __rsub__ = _e(OP.SUB, inv=True)
    __rmul__ = _e(OP.MUL, inv=True)
    __rdiv__ = __rtruediv__ = _e(OP.DIV, inv=True)
    __rand__ = _e(OP.AND, inv=True)
    __ror__ = _e(OP.OR, inv=True)
    __rxor__ = _e(OP.XOR, inv=True)

    def __eq__(self, rhs):
        # ``== None`` renders as ``IS NULL``.
        op = OP.IS if rhs is None else OP.EQ
        return Expression(self, op, rhs)

    def __ne__(self, rhs):
        # ``!= None`` renders as ``IS NOT NULL``.
        op = OP.IS_NOT if rhs is None else OP.NE
        return Expression(self, op, rhs)

    __lt__ = _e(OP.LT)
    __le__ = _e(OP.LTE)
    __gt__ = _e(OP.GT)
    __ge__ = _e(OP.GTE)
    __lshift__ = _e(OP.IN)
    __rshift__ = _e(OP.IS)
    __mod__ = _e(OP.LIKE)
    __pow__ = _e(OP.ILIKE)

    bin_and = _e(OP.BIN_AND)
    bin_or = _e(OP.BIN_OR)
    in_ = _e(OP.IN)
    not_in = _e(OP.NOT_IN)

    # Special expressions.
    def is_null(self, is_null=True):
        op = OP.IS if is_null else OP.IS_NOT
        return Expression(self, op, None)

    # The pattern helpers below interpolate ``rhs`` as-is; LIKE wildcards
    # inside it are not escaped. ``rhs`` is wrapped in a 1-tuple so that
    # tuple arguments are not unpacked by ``%`` formatting.
    def contains(self, rhs):
        return Expression(self, OP.ILIKE, '%%%s%%' % (rhs,))

    def startswith(self, rhs):
        return Expression(self, OP.ILIKE, '%s%%' % (rhs,))

    def endswith(self, rhs):
        return Expression(self, OP.ILIKE, '%%%s' % (rhs,))

    def between(self, lo, hi):
        return Expression(self, OP.BETWEEN, NodeList((lo, SQL('AND'), hi)))

    def concat(self, rhs):
        return StringExpression(self, OP.CONCAT, rhs)

    # Single definition of ``regexp``. The earlier ``regexp = _e(OP.REGEXP)``
    # class attribute was dead code, always overridden by this method.
    def regexp(self, rhs):
        return Expression(self, OP.REGEXP, rhs)

    def iregexp(self, rhs):
        return Expression(self, OP.IREGEXP, rhs)

    def __getitem__(self, item):
        # col[lo:hi] -> BETWEEN lo AND hi; col[x] -> col == x.
        if isinstance(item, slice):
            if item.start is None or item.stop is None:
                raise ValueError('BETWEEN range must have both a start- and '
                                 'end-point.')
            return self.between(item.start, item.stop)
        return self == item

    def distinct(self):
        return NodeList((SQL('DISTINCT'), self))

    def collate(self, collation):
        return NodeList((self, SQL('COLLATE %s' % collation)))

    def get_sort_key(self, ctx):
        return ()
class Column(ColumnBase):
    """Column ``name`` belonging to ``source``.

    Renders as ``source.name``, or as the bare ``name`` in VALUES scope.
    """

    def __init__(self, source, name):
        self.source = source
        self.name = name

    def get_sort_key(self, ctx):
        if ctx.scope != SCOPE_VALUES:
            return self.source.get_sort_key(ctx) + (self.name,)
        return (self.name,)

    def __hash__(self):
        return hash((self.source, self.name))

    def __sql__(self, ctx):
        if ctx.scope != SCOPE_VALUES:
            with ctx.scope_column():
                return ctx.sql(self.source).literal('.').sql(Entity(self.name))
        return ctx.sql(Entity(self.name))
class WrappedNode(ColumnBase):
    """Base for nodes that decorate another node.

    ``is_alias`` and ``unwrap`` are delegated to the wrapped node.
    """

    def __init__(self, node):
        self.node = node
        # Take the wrapped node's coercion flag, defaulting to True.
        self._coerce = getattr(node, '_coerce', True)

    def unwrap(self):
        return self.node.unwrap()

    def is_alias(self):
        return self.node.is_alias()
class EntityFactory(object):
    """Turns attribute access into ``Entity(node, attribute_name)``."""
    __slots__ = ('node',)

    def __init__(self, node):
        self.node = node

    def __getattr__(self, attr):
        return Entity(self.node, attr)


class _DynamicEntity(object):
    """Descriptor: ``instance.c`` gives an EntityFactory over the instance's
    ``_alias``. Access on the class returns the descriptor itself."""
    __slots__ = ()

    def __get__(self, instance, instance_type=None):
        if instance is None:
            return self
        return EntityFactory(instance._alias)  # Implements __getattr__().
class Alias(WrappedNode):
    """Gives ``node`` an alias.

    Renders as ``node AS alias`` in source scope and as ``alias`` elsewhere.
    """
    c = _DynamicEntity()

    def __init__(self, node, alias):
        super(Alias, self).__init__(node)
        self._alias = alias

    def alias(self, alias=None):
        # Re-alias the underlying node, or drop the alias entirely on None.
        return self.node if alias is None else Alias(self.node, alias)

    def unalias(self):
        return self.node

    def is_alias(self):
        return True

    def __sql__(self, ctx):
        if ctx.scope != SCOPE_SOURCE:
            return ctx.sql(Entity(self._alias))
        return (ctx
                .sql(self.node)
                .literal(' AS ')
                .sql(Entity(self._alias)))
class Negated(WrappedNode):
    """Logical ``NOT node``. Inverting it again returns the original node."""

    def __invert__(self):
        return self.node

    def __sql__(self, ctx):
        ctx.literal('NOT ')
        return ctx.sql(self.node)
class BitwiseMixin(object):
    """Maps ``&``, ``|``, ``-`` and ``~`` onto bitwise helper methods.

    The host class must provide ``bin_and``/``bin_or``.
    """

    def __and__(self, other):
        return self.bin_and(other)

    def __or__(self, other):
        return self.bin_or(other)

    def __sub__(self, other):
        # Rendered as self AND-ed with other's ``bin_negated()`` (presumably
        # a & ~b, i.e. clearing b's bits; bin_negated is defined elsewhere).
        negated = other.bin_negated()
        return self.bin_and(negated)

    def __invert__(self):
        return BitwiseNegated(self)
class BitwiseNegated(BitwiseMixin, WrappedNode):
    """Bitwise negation of the wrapped node, rendered as ``~node``.

    The operator may be remapped by the context's ``operations`` setting.
    """
    # Operator emitted before the node. ``__sql__`` reads ``self.op``, which
    # nothing in the class hierarchy defined, so rendering used to raise
    # AttributeError.
    op = OP.BITWISE_NEGATION

    def __invert__(self):
        # Double negation collapses back to the original node.
        return self.node

    def __sql__(self, ctx):
        if ctx.state.operations:
            op_sql = ctx.state.operations.get(self.op, self.op)
        else:
            op_sql = self.op
        return ctx.literal(op_sql).sql(self.node)
class Value(ColumnBase):
    """A parameter value.

    When ``unpack`` is true and the value is a list/set/tuple, it renders as
    a parenthesised, comma-separated list. Non-Node items are wrapped in
    ``Value`` with the same converter.
    """

    def __init__(self, value, converter=None, unpack=True):
        self.value = value
        self.converter = converter
        self.multi = isinstance(self.value, (list, set, tuple)) and unpack
        if self.multi:
            self.values = [
                item if isinstance(item, Node)
                else Value(item, self.converter)
                for item in self.value]

    def __sql__(self, ctx):
        if not self.multi:
            return ctx.value(self.value, self.converter)
        # For multi-part values (e.g. lists of IDs).
        return ctx.sql(EnclosedNodeList(self.values))
def AsIs(value):
    """Wrap ``value`` as a single parameter, never expanding sequences."""
    return Value(value, converter=None, unpack=False)
class Cast(WrappedNode):
    """Renders ``CAST(node AS <cast>)``."""

    def __init__(self, node, cast):
        super(Cast, self).__init__(node)
        self.cast = cast
        # Always disable coercion, overriding the flag copied from ``node``.
        self._coerce = False

    def __sql__(self, ctx):
        ctx.literal('CAST(')
        ctx.sql(self.node)
        return ctx.literal(' AS %s)' % self.cast)
class Ordering(WrappedNode):
    """ORDER BY term.

    Renders as ``node <direction> [COLLATE <collation>] [NULLS <nulls>]``.
    """

    def __init__(self, node, direction, collation=None, nulls=None):
        super(Ordering, self).__init__(node)
        self.direction = direction
        self.collation = collation
        self.nulls = nulls

    def collate(self, collation=None):
        """Return a copy of this ordering with a different collation."""
        # Carry the NULLS placement over; it used to be silently dropped.
        return Ordering(self.node, self.direction, collation,
                        nulls=self.nulls)

    def __sql__(self, ctx):
        ctx.sql(self.node).literal(' %s' % self.direction)
        if self.collation:
            ctx.literal(' COLLATE %s' % self.collation)
        if self.nulls:
            ctx.literal(' NULLS %s' % self.nulls)
        return ctx
def Asc(node, collation=None, nulls=None):
    """Ascending ``Ordering`` of ``node``."""
    return Ordering(node, 'ASC', collation=collation, nulls=nulls)


def Desc(node, collation=None, nulls=None):
    """Descending ``Ordering`` of ``node``."""
    return Ordering(node, 'DESC', collation=collation, nulls=nulls)
class Expression(ColumnBase):
    """Binary SQL expression ``lhs <op> rhs``.

    Parameters:
        lhs, rhs: operands (nodes or plain Python values).
        op: operator string, normally an ``OP`` value. It may be remapped by
            the context's ``operations`` setting at render time.
        flat: when False (the default) the expression is rendered inside
            parentheses.
    """
    def __init__(self, lhs, op, rhs, flat=False):
        self.lhs = lhs
        self.op = op
        self.rhs = rhs
        self.flat = flat

    def __sql__(self, ctx):
        overrides = {'parentheses': not self.flat}
        # When the left-hand side is a Field, plain values rendered inside
        # this expression go through that field's ``db_value``. Otherwise
        # any inherited converter is reset to None.
        if isinstance(self.lhs, Field):
            overrides['converter'] = self.lhs.db_value
        else:
            overrides['converter'] = None

        # Apply database-specific operator overrides, if configured.
        if ctx.state.operations:
            op_sql = ctx.state.operations.get(self.op, self.op)
        else:
            op_sql = self.op

        with ctx(**overrides):
            # Postgresql reports an error for IN/NOT IN (), so convert to
            # the equivalent boolean expression.
            # NOTE(review): the RHS is rendered once more with a fresh default
            # Context just to detect an empty list.
            if (self.op == OP.IN or self.op == OP.NOT_IN) and \
               Context().parse(self.rhs)[0] == '()':
                return ctx.literal('0 = 1' if self.op == OP.IN else '1 = 1')

            return (ctx
                    .sql(self.lhs)
                    .literal(' %s ' % op_sql)
                    .sql(self.rhs))
class StringExpression(Expression):
    """Expression on which ``+`` means string concatenation (``||``)."""

    def __radd__(self, lhs):
        return StringExpression(lhs, OP.CONCAT, self)

    def __add__(self, rhs):
        return self.concat(rhs)
class Entity(ColumnBase):
    """Quoted identifier path such as ``"schema"."table"``.

    Falsy path parts are dropped and embedded double quotes are doubled.
    Attribute access extends the path: ``Entity('t').col``.
    """

    def __init__(self, *path):
        self._path = [part.replace('"', '""') for part in path if part]

    def __getattr__(self, attr):
        # Only reached for attributes not found normally. Dunder names belong
        # to Python protocols (copy, pickle, inspect) and must not become
        # identifiers. ``_path`` is missing only on instances created without
        # __init__ (e.g. by copy.copy), where resolving it here would recurse
        # forever.
        if attr == '_path' or (attr.startswith('__') and attr.endswith('__')):
            raise AttributeError(attr)
        return Entity(*self._path + [attr])

    def get_sort_key(self, ctx):
        return tuple(self._path)

    def __hash__(self):
        return hash((self.__class__.__name__, tuple(self._path)))

    def __sql__(self, ctx):
        return ctx.literal(quote(self._path, ctx.state.quote or '""'))
class SQL(ColumnBase):
    """Raw SQL text with optional parameters."""

    def __init__(self, sql, params=None):
        self.sql = sql
        self.params = params

    def __sql__(self, ctx):
        ctx.literal(self.sql)
        # Parameters are recorded without conversion and without emitting
        # placeholders; the raw SQL text must contain its own.
        for param in (self.params or ()):
            ctx.value(param, False, add_param=False)
        return ctx
def Check(constraint):
    """Raw ``CHECK (<constraint>)`` clause; ``constraint`` is SQL text."""
    clause = 'CHECK (%s)' % constraint
    return SQL(clause)
class Function(ColumnBase):
    """SQL function call ``name(arg, ...)``.

    Attribute access builds new calls, so ``fn.COUNT(x)`` creates
    ``Function('COUNT', (x,))``.
    """

    def __init__(self, name, arguments, coerce=True):
        self.name = name
        self.arguments = arguments
        # sum/count/cast (any case) always get _coerce = False.
        if name and name.lower() in ('sum', 'count', 'cast'):
            self._coerce = False
        else:
            self._coerce = coerce

    def __getattr__(self, attr):
        # Dunder lookups come from Python protocols (copy's __setstate__,
        # inspect.unwrap's __wrapped__, pickling, ...). Answering them with a
        # function factory breaks those protocols, so refuse them.
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)

        def decorator(*args, **kwargs):
            return Function(attr, args, **kwargs)
        return decorator

    def over(self, partition_by=None, order_by=None, start=None, end=None,
             window=None):
        """Attach an OVER clause.

        The clause is either an inline Window or a reference to the named
        ``window``.
        """
        if isinstance(partition_by, Window) and window is None:
            window = partition_by
        if start is not None and not isinstance(start, SQL):
            start = SQL(*start)
        if end is not None and not isinstance(end, SQL):
            end = SQL(*end)

        if window is None:
            node = Window(partition_by=partition_by, order_by=order_by,
                          start=start, end=end)
        else:
            node = SQL(window._alias)
        return NodeList((self, SQL('OVER'), node))

    def __sql__(self, ctx):
        ctx.literal(self.name)
        if not len(self.arguments):
            return ctx.literal('()')

        with ctx(in_function=True):
            return ctx.sql(EnclosedNodeList([
                (argument if isinstance(argument, Node)
                 else Value(argument))
                for argument in self.arguments]))


# Factory for arbitrary SQL functions, e.g. fn.LOWER(col).
fn = Function(None, None)
class Window(Node):
    """Window definition: PARTITION BY, ORDER BY and an optional ROWS frame.

    Outside source scope it renders as ``alias AS (...)``; in source scope
    only the parenthesised body is emitted.
    """
    CURRENT_ROW = 'CURRENT ROW'

    def __init__(self, partition_by=None, order_by=None, start=None, end=None,
                 alias=None):
        super(Window, self).__init__()
        self.partition_by = partition_by
        self.order_by = order_by
        self.start = start
        self.end = end
        if self.start is None and self.end is not None:
            raise ValueError('Cannot specify WINDOW end without start.')
        self._alias = alias or 'w'

    def alias(self, alias=None):
        # Sets the alias in place (defaulting to 'w') and returns self.
        self._alias = alias or 'w'
        return self

    @staticmethod
    def following(value=None):
        if value is not None:
            return SQL('%d FOLLOWING' % value)
        return SQL('UNBOUNDED FOLLOWING')

    @staticmethod
    def preceding(value=None):
        if value is not None:
            return SQL('%d PRECEDING' % value)
        return SQL('UNBOUNDED PRECEDING')

    def __sql__(self, ctx):
        if ctx.scope != SCOPE_SOURCE:
            ctx.literal(self._alias)
            ctx.literal(' AS ')

        with ctx(parentheses=True):
            parts = []
            if self.partition_by:
                parts += [SQL('PARTITION BY'),
                          CommaNodeList(self.partition_by)]
            if self.order_by:
                parts += [SQL('ORDER BY'), CommaNodeList(self.order_by)]
            if self.start is not None:
                if self.end is not None:
                    parts += [SQL('ROWS BETWEEN'), self.start,
                              SQL('AND'), self.end]
                else:
                    parts += [SQL('ROWS'), self.start]
            ctx.sql(NodeList(parts))
        return ctx

    def clone_base(self):
        # Copies only the partitioning and ordering (no frame or alias).
        return Window(self.partition_by, self.order_by)
def Case(predicate, expression_tuples, default=None):
    """Build ``CASE [predicate] WHEN e THEN v ... [ELSE default] END``."""
    parts = [SQL('CASE')]
    if predicate is not None:
        parts.append(predicate)
    for when_expr, then_value in expression_tuples:
        parts += [SQL('WHEN'), when_expr, SQL('THEN'), then_value]
    if default is not None:
        parts += [SQL('ELSE'), default]
    parts.append(SQL('END'))
    return NodeList(parts)
class NodeList(ColumnBase):
    """Renders ``nodes`` separated by ``glue``, in parentheses if ``parens``.

    An empty list renders as ``()`` when ``parens`` is set and as nothing
    otherwise.
    """

    def __init__(self, nodes, glue=' ', parens=False):
        self.nodes = nodes
        self.glue = glue
        self.parens = parens
        if parens and len(self.nodes) == 1:
            node = self.nodes[0]
            if isinstance(node, Expression):
                # Hack to avoid double-parentheses: the list already supplies
                # them, so render the lone expression flat. Flatten a clone
                # so the caller's Expression, which may be reused elsewhere
                # (e.g. ``(a + b) * 2``), keeps its own parentheses.
                flat = node.clone()
                flat.flat = True
                if isinstance(self.nodes, tuple):
                    self.nodes = (flat,)
                else:
                    self.nodes = [flat]

    def __sql__(self, ctx):
        n_nodes = len(self.nodes)
        if n_nodes == 0:
            return ctx.literal('()') if self.parens else ctx
        with ctx(parentheses=self.parens):
            for i in range(n_nodes - 1):
                ctx.sql(self.nodes[i])
                ctx.literal(self.glue)
            ctx.sql(self.nodes[n_nodes - 1])
        return ctx
def CommaNodeList(nodes):
    """Return a :class:`NodeList` that separates *nodes* with ``', '``."""
    return NodeList(nodes, glue=', ')


def EnclosedNodeList(nodes):
    """Return a comma-separated :class:`NodeList` wrapped in parentheses."""
    return NodeList(nodes, glue=', ', parens=True)
class DQ(ColumnBase):
    """Container for keyword-argument lookups plus a negation flag.

    The keyword arguments are stored unchanged in ``query``. How they are
    interpreted is not visible here (presumably Django-style lookups).
    ``~dq`` yields a copy with the negation flag toggled.
    """
    def __init__(self, **query):
        super(DQ, self).__init__()
        self.query = query
        self._negated = False

    @Node.copy
    def __invert__(self):
        # Node.copy presumably runs this on a clone, leaving the original as-is.
        self._negated = not self._negated

    def clone(self):
        duplicate = DQ(**self.query)
        duplicate._negated = self._negated
        return duplicate
def Tuple(*a):
    """Represent a row tuple: the arguments rendered as ``(a, b, ...)``.

    A named function instead of an assigned lambda (PEP 8 E731), so it shows
    up with a real name in tracebacks and introspection.
    """
    return EnclosedNodeList(a)
class QualifiedNames(WrappedNode):
    """Render the wrapped node inside the context's column scope.

    The database layer uses this for ON CONFLICT update values and WHERE
    clauses. The exact effect of ``scope_column`` on name rendering is
    defined by Context (not visible here).
    """
    def __sql__(self, ctx):
        column_scope = ctx.scope_column()
        with column_scope:
            return ctx.sql(self.node)
class OnConflict(Node):
    """Description of an INSERT conflict-resolution (upsert) clause.

    The SQL itself is produced by the active database through
    ``ctx.state.conflict_statement`` and ``ctx.state.conflict_update``.

    :param action: conflict action, e.g. ``'IGNORE'`` or ``'REPLACE'``.
    :param update: mapping of column -> new value for an update-on-conflict.
    :param preserve: columns whose inserted value should be kept.
    :param where: optional filter on the conflict update.
    :param conflict_target: columns/constraints identifying the conflict.
    """
    def __init__(self, action=None, update=None, preserve=None, where=None,
                 conflict_target=None):
        self._action = action
        self._update = update
        self._preserve = ensure_tuple(preserve)
        self._where = where
        self._conflict_target = ensure_tuple(conflict_target)

    def get_conflict_statement(self, ctx):
        return ctx.state.conflict_statement(self)

    def get_conflict_update(self, ctx):
        return ctx.state.conflict_update(self)

    @Node.copy
    def preserve(self, *columns):
        self._preserve = columns

    @Node.copy
    def update(self, _data=None, **kwargs):
        """Set the column -> value mapping applied on conflict.

        Accepts a dict, keyword arguments, or a dict plus keyword arguments.
        A non-dict *_data* cannot be combined with keyword arguments.

        :raises ValueError: when a non-dict *_data* is mixed with kwargs.
        """
        if _data and kwargs and not isinstance(_data, dict):
            raise ValueError('Cannot mix data with keyword arguments in the '
                             'OnConflict update method.')
        if kwargs:
            # Merge into a copy. Calling update() on _data directly would
            # mutate the dictionary owned by the caller.
            merged = dict(_data or {})
            merged.update(kwargs)
            _data = merged
        else:
            _data = _data or {}
        self._update = _data

    @Node.copy
    def where(self, *expressions):
        # AND the new expressions onto any existing filter.
        if self._where is not None:
            expressions = (self._where,) + expressions
        self._where = reduce(operator.and_, expressions)

    @Node.copy
    def conflict_target(self, *constraints):
        self._conflict_target = constraints
def database_required(method):
    """Decorator for query methods that must run against a database.

    The wrapped method is invoked as ``method(self, database, *args,
    **kwargs)``. ``database`` is the explicit argument, or the query's bound
    ``_database`` when the argument is ``None``.

    :raises Exception: if neither yields a truthy database.
    """
    @wraps(method)
    def inner(self, database=None, *args, **kwargs):
        if database is None:
            database = self._database
        if not database:
            raise Exception('Query must be bound to a database in order '
                            'to call "%s".' % method.__name__)
        return method(self, database, *args, **kwargs)
    return inner
# BASE QUERY INTERFACE.
class BaseQuery(Node):
    """Behaviour shared by every executable query.

    Tracks the bound database (``_database``), the cursor wrapper produced
    by the last execution (``_cursor_wrapper``) and the requested row format
    (``_row_type`` / ``_constructor``).
    """
    default_row_type = ROW.DICT

    def __init__(self, _database=None, **kwargs):
        self._database = _database
        self._cursor_wrapper = None
        self._row_type = None
        self._constructor = None
        super(BaseQuery, self).__init__(**kwargs)

    def bind(self, database=None):
        """Bind this query (in place) to *database* and return it."""
        self._database = database
        return self

    def clone(self):
        # A copy has not been executed, so it must not share cached results.
        query = super(BaseQuery, self).clone()
        query._cursor_wrapper = None
        return query

    @Node.copy
    def dicts(self, as_dict=True):
        """Return rows as dicts (``as_dict=False`` clears the choice)."""
        self._row_type = ROW.DICT if as_dict else None
        return self

    @Node.copy
    def tuples(self, as_tuple=True):
        """Return rows as tuples (``as_tuple=False`` clears the choice)."""
        self._row_type = ROW.TUPLE if as_tuple else None
        return self

    @Node.copy
    def namedtuples(self, as_namedtuple=True):
        """Return rows as named tuples."""
        self._row_type = ROW.NAMED_TUPLE if as_namedtuple else None
        return self

    @Node.copy
    def objects(self, constructor=None):
        """Return rows built with *constructor* (``None`` clears the choice)."""
        self._row_type = ROW.CONSTRUCTOR if constructor else None
        self._constructor = constructor
        return self

    def _get_cursor_wrapper(self, cursor):
        """Wrap *cursor* according to the requested row type.

        :raises ValueError: for an unknown row type.
        """
        row_type = self._row_type or self.default_row_type
        if row_type == ROW.DICT:
            return DictCursorWrapper(cursor)
        elif row_type == ROW.TUPLE:
            return CursorWrapper(cursor)
        elif row_type == ROW.NAMED_TUPLE:
            return NamedTupleCursorWrapper(cursor)
        elif row_type == ROW.CONSTRUCTOR:
            return ObjectCursorWrapper(cursor, self._constructor)
        else:
            raise ValueError('Unrecognized row type: "%s".' % row_type)

    def __sql__(self, ctx):
        raise NotImplementedError

    def sql(self):
        """Render this query via ``Context.parse``.

        Uses the bound database's SQL context when there is one, so that
        database-specific options apply. Otherwise uses a default Context.
        """
        if self._database:
            context = self._database.get_sql_context()
        else:
            context = Context()
        return context.parse(self)

    @database_required
    def execute(self, database):
        return self._execute(database)

    def _execute(self, database):
        raise NotImplementedError

    def iterator(self, database=None):
        """Iterate over results using the cursor wrapper's ``iterator()``."""
        return iter(self.execute(database).iterator())

    def _ensure_execution(self):
        """Execute the query if it has not been executed yet.

        :raises ValueError: if not executed and no database is bound.
        """
        # Compare against None rather than testing truthiness. The cursor
        # wrapper supports len(), so an executed query that returned zero
        # rows is falsy. It used to be treated as "not executed", which
        # raised ValueError for queries run with an explicit database
        # argument. Evaluating truthiness may also force len() on the wrapper.
        if self._cursor_wrapper is None:
            if not self._database:
                raise ValueError('Query has not been executed.')
            self.execute()

    def __iter__(self):
        self._ensure_execution()
        return iter(self._cursor_wrapper)

    def __getitem__(self, value):
        self._ensure_execution()
        if isinstance(value, slice):
            index = value.stop
        else:
            index = value
        if index is not None:
            # Fill enough rows to cover the requested position. A negative
            # index passes 0. How fill_cache treats 0/None is defined by the
            # cursor wrapper (presumably "fill everything").
            index = index + 1 if index >= 0 else 0
        self._cursor_wrapper.fill_cache(index)
        return self._cursor_wrapper.row_cache[value]

    def __len__(self):
        self._ensure_execution()
        return len(self._cursor_wrapper)
class RawQuery(BaseQuery):
    """A query made of a literal SQL string plus optional parameter values."""
    def __init__(self, sql=None, params=None, **kwargs):
        super(RawQuery, self).__init__(**kwargs)
        self._sql = sql
        self._params = params

    def __sql__(self, ctx):
        ctx.literal(self._sql)
        # Values are passed with add_param=False. Presumably the raw SQL
        # already contains its own placeholders. TODO confirm against
        # Context.value.
        for param in (self._params or ()):
            ctx.value(param, add_param=False)
        return ctx

    def _execute(self, database):
        # Executed at most once; later calls reuse the cached wrapper.
        if self._cursor_wrapper is None:
            self._cursor_wrapper = self._get_cursor_wrapper(
                database.execute(self))
        return self._cursor_wrapper
class Query(BaseQuery):
    """Base for queries with WHERE / ORDER BY / LIMIT / OFFSET and CTEs.

    ``__sql__`` here renders only the ``WITH`` prefix. Subclasses render the
    rest of the statement and call :meth:`_apply_ordering`.
    """
    def __init__(self, where=None, order_by=None, limit=None, offset=None,
                 **kwargs):
        super(Query, self).__init__(**kwargs)
        self._where = where
        self._order_by = order_by
        self._limit = limit
        self._offset = offset
        self._cte_list = None

    @Node.copy
    def with_cte(self, *cte_list):
        self._cte_list = cte_list

    @Node.copy
    def where(self, *expressions):
        # New expressions are ANDed onto any existing filter.
        if self._where is not None:
            expressions = (self._where,) + expressions
        self._where = reduce(operator.and_, expressions)

    @Node.copy
    def order_by(self, *values):
        self._order_by = values

    @Node.copy
    def order_by_extend(self, *values):
        # tuple() so that an ordering passed to the constructor as a list can
        # be extended too (list + tuple raises TypeError). This mirrors
        # Select.group_by_extend.
        self._order_by = (tuple(self._order_by or ()) + values) or None

    @Node.copy
    def limit(self, value=None):
        self._limit = value

    @Node.copy
    def offset(self, value=None):
        self._offset = value

    @Node.copy
    def paginate(self, page, paginate_by=20):
        """Select page *page* (1-based) of *paginate_by* rows."""
        if page > 0:
            page -= 1
        self._limit = paginate_by
        self._offset = page * paginate_by

    def _apply_ordering(self, ctx):
        """Append the ORDER BY, LIMIT and OFFSET clauses to *ctx*."""
        if self._order_by:
            (ctx
             .literal(' ORDER BY ')
             .sql(CommaNodeList(self._order_by)))
        if self._limit is not None or (self._offset is not None and
                                       ctx.state.limit_max):
            # Use the database's "no limit" value only when no limit was
            # given. The old `self._limit or limit_max` turned an explicit
            # LIMIT 0 into LIMIT <limit_max>, which returned every row.
            limit = ctx.state.limit_max if self._limit is None else self._limit
            ctx.literal(' LIMIT ').sql(limit)
        if self._offset is not None:
            ctx.literal(' OFFSET ').sql(self._offset)
        return ctx

    def __sql__(self, ctx):
        if self._cte_list:
            # The CTE scope is only used at the very beginning of the query,
            # when we are describing the various CTEs we will be using.
            recursive = any(cte._recursive for cte in self._cte_list)
            with ctx.scope_cte():
                (ctx
                 .literal('WITH RECURSIVE ' if recursive else 'WITH ')
                 .sql(CommaNodeList(self._cte_list))
                 .literal(' '))
        return ctx
def __compound_select__(operation, inverted=False):
    """Create a method that combines two queries with a set operator.

    *operation* (e.g. ``'UNION ALL'``) becomes the operator of the resulting
    :class:`CompoundSelectQuery`. With *inverted* the operands are swapped,
    as needed by the reflected operators (``__radd__`` and friends).
    """
    def method(self, other):
        lhs, rhs = (other, self) if inverted else (self, other)
        return CompoundSelectQuery(lhs, operation, rhs)
    return method
class SelectQuery(Query):
    """Base for SELECT-like queries: set operators and CTE creation.

    ``+``, ``|``, ``&`` and ``-`` map to UNION ALL, UNION, INTERSECT and
    EXCEPT. The reflected operators put the other operand on the left.
    """
    union_all = __compound_select__('UNION ALL')
    __add__ = union_all
    union = __compound_select__('UNION')
    __or__ = union
    intersect = __compound_select__('INTERSECT')
    __and__ = intersect
    except_ = __compound_select__('EXCEPT')
    __sub__ = except_

    __radd__ = __compound_select__('UNION ALL', inverted=True)
    __ror__ = __compound_select__('UNION', inverted=True)
    __rand__ = __compound_select__('INTERSECT', inverted=True)
    __rsub__ = __compound_select__('EXCEPT', inverted=True)

    def cte(self, name, recursive=False, columns=None):
        """Wrap this query as a common table expression named *name*."""
        return CTE(name, self, recursive=recursive, columns=columns)
class SelectBase(_HashableSource, Source, SelectQuery):
    """Executable SELECT base with convenience result helpers."""
    def _get_hash(self):
        return hash((self.__class__, self._alias or id(self)))

    def _execute(self, database):
        # Executed at most once; later calls reuse the cached wrapper.
        if self._cursor_wrapper is None:
            cursor = database.execute(self)
            self._cursor_wrapper = self._get_cursor_wrapper(cursor)
        return self._cursor_wrapper

    @database_required
    def peek(self, database, n=1):
        """Return the first row (n == 1) or the first *n* rows.

        Returns ``None`` when there are no rows.
        """
        rows = self.execute(database)[:n]
        if rows:
            return rows[0] if n == 1 else rows

    @database_required
    def first(self, database, n=1):
        """Like :meth:`peek`, but also applies LIMIT *n* to this query.

        Note: the limit is changed in place (no clone), and cached results are
        discarded when it changes.
        """
        if self._limit != n:
            self._limit = n
            self._cursor_wrapper = None
        return self.peek(database, n=n)

    @database_required
    def scalar(self, database, as_tuple=False):
        """Return the first column of the first row (or the whole row)."""
        row = self.tuples().peek(database)
        return row[0] if row and not as_tuple else row

    @database_required
    def count(self, database, clear_limit=False):
        """Return ``SELECT COUNT(1)`` over this query used as a subquery."""
        clone = self.order_by().alias('_wrapped')
        if clear_limit:
            clone._limit = clone._offset = None
        try:
            # Selecting a constant is only safe when the column list does not
            # affect the row count (no HAVING, windows or DISTINCT).
            if clone._having is None and clone._windows is None and \
               clone._distinct is None and clone._simple_distinct is not True:
                clone = clone.select(SQL('1'))
        except AttributeError:
            # e.g. compound queries, which lack these attributes.
            pass
        return Select([clone], [fn.COUNT(SQL('1'))]).scalar(database)

    @database_required
    def exists(self, database):
        """Return True if the query yields at least one row."""
        clone = self.columns(SQL('1'))
        clone._limit = 1
        clone._offset = None
        # Pass the resolved database along. The original called
        # clone.scalar() without it, which raised for queries that were not
        # bound but were given a database explicitly.
        return bool(clone.scalar(database))

    @database_required
    def get(self, database):
        """Re-execute and return the first row, or ``None`` if there is none."""
        self._cursor_wrapper = None
        try:
            return self.execute(database)[0]
        except IndexError:
            pass
# QUERY IMPLEMENTATIONS.
class CompoundSelectQuery(SelectBase):
    """Two queries combined by a set operator (``op``, e.g. ``'UNION'``).

    ``lhs`` and ``rhs`` are the operand queries. ORDER BY / LIMIT / OFFSET
    set on this object are rendered after both operands.
    """
    def __init__(self, lhs, op, rhs):
        super(CompoundSelectQuery, self).__init__()
        self.lhs = lhs
        self.op = op
        self.rhs = rhs

    @property
    def _returning(self):
        # The result columns of the compound query are those of its left side.
        return self.lhs._returning

    def _get_query_key(self):
        return (self.lhs.get_query_key(), self.rhs.get_query_key())

    def __sql__(self, ctx):
        if ctx.scope == SCOPE_COLUMN:
            return self.apply_column(ctx)

        # Whether each operand gets its own parentheses is a per-database
        # option (Database.compound_select_parentheses).
        parens_around_query = ctx.state.compound_select_parentheses
        # Wrap the whole compound statement when used as a subquery/source.
        outer_parens = ctx.subquery or (ctx.scope == SCOPE_SOURCE)
        with ctx(parentheses=outer_parens):
            with ctx.scope_normal(parentheses=parens_around_query,
                                  subquery=False):
                ctx.sql(self.lhs)
            ctx.literal(' %s ' % self.op)
            # NOTE(review): push_alias presumably gives the right-hand query
            # its own alias namespace — confirm in Context.
            with ctx.push_alias():
                with ctx.scope_normal(parentheses=parens_around_query,
                                      subquery=False):
                    ctx.sql(self.rhs)

            # Apply ORDER BY, LIMIT, OFFSET.
            self._apply_ordering(ctx)
        return self.apply_alias(ctx)
class Select(SelectBase):
    """A SELECT statement.

    :param from_list: FROM sources (a tuple is copied into a list).
    :param columns: selected expressions.
    :param group_by: GROUP BY expressions.
    :param having: HAVING expression.
    :param distinct: ``True``/``False`` for plain DISTINCT, or expressions
        for ``DISTINCT ON (...)``.
    :param windows: named window definitions for the WINDOW clause.
    :param for_update: ``True`` for ``FOR UPDATE`` or a custom locking clause.
    """
    def __init__(self, from_list=None, columns=None, group_by=None,
                 having=None, distinct=None, windows=None, for_update=None,
                 **kwargs):
        super(Select, self).__init__(**kwargs)
        self._from_list = (list(from_list) if isinstance(from_list, tuple)
                           else from_list) or []
        self._returning = columns
        self._group_by = group_by
        self._having = having
        # Previously hard-coded to None, which silently discarded the
        # ``windows`` argument. Normalised the same way as window().
        self._windows = windows if windows else None
        self._for_update = 'FOR UPDATE' if for_update is True else for_update

        self._distinct = self._simple_distinct = None
        if distinct:
            if isinstance(distinct, bool):
                self._simple_distinct = distinct
            else:
                self._distinct = distinct
        self._cursor_wrapper = None

    def clone(self):
        # Copy the FROM list so join()/from_() on the clone don't affect us.
        clone = super(Select, self).clone()
        if clone._from_list:
            clone._from_list = list(clone._from_list)
        return clone

    @Node.copy
    def columns(self, *columns, **kwargs):
        self._returning = columns
    select = columns

    @Node.copy
    def select_extend(self, *columns):
        self._returning = tuple(self._returning) + columns

    @Node.copy
    def from_(self, *sources):
        self._from_list = list(sources)

    @Node.copy
    def join(self, dest, join_type='INNER', on=None):
        """Join *dest* onto the most recently added FROM source.

        :raises ValueError: if there is no FROM source yet.
        """
        if not self._from_list:
            raise ValueError('No sources to join on.')
        item = self._from_list.pop()
        self._from_list.append(Join(item, dest, join_type, on))

    @Node.copy
    def group_by(self, *columns):
        """Set GROUP BY; a Table expands to its explicitly declared columns."""
        grouping = []
        for column in columns:
            if isinstance(column, Table):
                if not column._columns:
                    raise ValueError('Cannot pass a table to group_by() that '
                                     'does not have columns explicitly '
                                     'declared.')
                grouping.extend([getattr(column, col_name)
                                 for col_name in column._columns])
            else:
                grouping.append(column)
        self._group_by = grouping

    @Node.copy
    def group_by_extend(self, *values):
        group_by = tuple(self._group_by or ()) + values
        return self.group_by(*group_by)

    @Node.copy
    def having(self, *expressions):
        # New expressions are ANDed onto any existing HAVING filter.
        if self._having is not None:
            expressions = (self._having,) + expressions
        self._having = reduce(operator.and_, expressions)

    @Node.copy
    def distinct(self, *columns):
        """``distinct(True/False)`` toggles plain DISTINCT.

        With columns it renders ``DISTINCT ON (...)``. With no arguments it
        renders plain DISTINCT.
        """
        if len(columns) == 1 and (columns[0] is True or columns[0] is False):
            self._simple_distinct = columns[0]
        else:
            self._simple_distinct = False
            self._distinct = columns

    @Node.copy
    def window(self, *windows):
        self._windows = windows if windows else None

    @Node.copy
    def for_update(self, for_update=True):
        self._for_update = 'FOR UPDATE' if for_update is True else for_update

    def _get_query_key(self):
        return self._alias

    def __sql_selection__(self, ctx, is_subquery=False):
        return ctx.sql(CommaNodeList(self._returning))

    def __sql__(self, ctx):
        super(Select, self).__sql__(ctx)
        if ctx.scope == SCOPE_COLUMN:
            return self.apply_column(ctx)

        is_subquery = ctx.subquery
        state = {
            'converter': None,
            'in_function': False,
            'parentheses': is_subquery or (ctx.scope == SCOPE_SOURCE),
            'subquery': True,
        }
        if ctx.state.in_function:
            state['parentheses'] = False

        with ctx.scope_normal(**state):
            ctx.literal('SELECT ')
            if self._simple_distinct or self._distinct is not None:
                ctx.literal('DISTINCT ')
                if self._distinct:
                    (ctx
                     .literal('ON ')
                     .sql(EnclosedNodeList(self._distinct))
                     .literal(' '))

            with ctx.scope_source():
                ctx = self.__sql_selection__(ctx, is_subquery)

            if self._from_list:
                with ctx.scope_source(parentheses=False):
                    ctx.literal(' FROM ').sql(CommaNodeList(self._from_list))

            if self._where is not None:
                ctx.literal(' WHERE ').sql(self._where)

            if self._group_by:
                ctx.literal(' GROUP BY ').sql(CommaNodeList(self._group_by))

            if self._having is not None:
                ctx.literal(' HAVING ').sql(self._having)

            if self._windows is not None:
                ctx.literal(' WINDOW ')
                ctx.sql(CommaNodeList(self._windows))

            # Apply ORDER BY, LIMIT, OFFSET.
            self._apply_ordering(ctx)

            if self._for_update:
                if not ctx.state.for_update:
                    raise ValueError('FOR UPDATE specified but not supported '
                                     'by database.')
                ctx.literal(' ')
                ctx.sql(SQL(self._for_update))

        # Checked after leaving scope_normal (which forces in_function=False),
        # so this reflects the caller's state: no alias inside a function.
        if not ctx.state.in_function:
            ctx = self.apply_alias(ctx)
        return ctx
class _WriteQuery(Query):
    """Shared base for INSERT, UPDATE and DELETE.

    When RETURNING columns are requested, execution yields a cursor wrapper.
    Otherwise ``handle_result`` reports the number of affected rows.
    """
    def __init__(self, table, returning=None, **kwargs):
        self.table = table
        self._returning = returning
        self._return_cursor = bool(returning)
        super(_WriteQuery, self).__init__(**kwargs)

    @Node.copy
    def returning(self, *returning):
        self._returning = returning
        self._return_cursor = bool(returning)

    def apply_returning(self, ctx):
        if not self._returning:
            return ctx
        ctx.literal(' RETURNING ').sql(CommaNodeList(self._returning))
        return ctx

    def _execute(self, database):
        cursor = (self.execute_returning(database) if self._returning
                  else database.execute(self))
        return self.handle_result(database, cursor)

    def execute_returning(self, database):
        # Executed at most once; later calls reuse the cached wrapper.
        if self._cursor_wrapper is None:
            raw_cursor = database.execute(self)
            self._cursor_wrapper = self._get_cursor_wrapper(raw_cursor)
        return self._cursor_wrapper

    def handle_result(self, database, cursor):
        return cursor if self._return_cursor else database.rows_affected(cursor)

    def _set_table_alias(self, ctx):
        ctx.alias_manager[self.table] = self.table.__name__

    def __sql__(self, ctx):
        super(_WriteQuery, self).__sql__(ctx)
        # Pin the table's alias to its own name so that a sub-select that
        # references the outer table renders it as table.column rather than
        # under a freshly generated alias (e.g. t2).
        self._set_table_alias(ctx)
        return ctx
class Update(_WriteQuery):
    """``UPDATE <table> SET ... [FROM ...] [WHERE ...]``.

    :param update: mapping of column -> new value. Non-Node values are
        wrapped in ``Value`` using the field's ``db_value`` converter.
    """
    def __init__(self, table, update=None, **kwargs):
        super(Update, self).__init__(table, **kwargs)
        self._update = update
        self._from = None

    @Node.copy
    def from_(self, *sources):
        self._from = sources

    def __sql__(self, ctx):
        super(Update, self).__sql__(ctx)

        with ctx.scope_values(subquery=True):
            ctx.literal('UPDATE ')

            expressions = []
            for k, v in sorted(self._update.items(), key=ctx.column_sort_key):
                if not isinstance(v, Node):
                    converter = k.db_value if isinstance(k, Field) else None
                    v = Value(v, converter=converter, unpack=False)
                expressions.append(NodeList((k, SQL('='), v)))

            (ctx
             .sql(self.table)
             .literal(' SET ')
             .sql(CommaNodeList(expressions)))

            if self._from:
                with ctx.scope_source(parentheses=False):
                    ctx.literal(' FROM ').sql(CommaNodeList(self._from))

            # Test against None, as Select/Delete do, instead of relying on
            # the truthiness of an expression node.
            if self._where is not None:
                ctx.literal(' WHERE ').sql(self._where)
            self._apply_ordering(ctx)
            return self.apply_returning(ctx)
class Insert(_WriteQuery):
    """An INSERT statement.

    ``_insert`` may be:

    - a dict (one row, SIMPLE), used when no explicit columns are given;
    - a SelectQuery (INSERT ... SELECT, QUERY);
    - any iterable of rows (MULTI).

    ``_query_type`` records the rendered form and is passed to
    ``Database.last_insert_id``.
    """
    SIMPLE = 0
    QUERY = 1
    MULTI = 2

    class DefaultValuesException(Exception): pass

    def __init__(self, table, insert=None, columns=None, on_conflict=None,
                 **kwargs):
        super(Insert, self).__init__(table, **kwargs)
        self._insert = insert
        self._columns = columns
        self._on_conflict = on_conflict
        self._query_type = None

    def where(self, *expressions):
        raise NotImplementedError('INSERT queries cannot have a WHERE clause.')

    @Node.copy
    def on_conflict_ignore(self, ignore=True):
        self._on_conflict = OnConflict('IGNORE') if ignore else None

    @Node.copy
    def on_conflict_replace(self, replace=True):
        self._on_conflict = OnConflict('REPLACE') if replace else None

    @Node.copy
    def on_conflict(self, *args, **kwargs):
        self._on_conflict = (OnConflict(*args, **kwargs) if (args or kwargs)
                             else None)

    def _simple_insert(self, ctx):
        """Render a single-dict insert.

        :raises DefaultValuesException: if the dict is empty.
        """
        if not self._insert:
            raise self.DefaultValuesException('Error: no data to insert.')
        return self._generate_insert((self._insert,), ctx)

    def get_default_data(self):
        # Mapping of column -> default value (or callable producing one).
        # Empty here; _generate_insert fills missing columns from it.
        return {}

    def _generate_insert(self, insert, ctx):
        """Render ``(columns) VALUES (row), (row), ...`` for the rows in *insert*.

        :raises DefaultValuesException: if *insert* yields no rows.
        :raises ValueError: if a row lacks a value for a column that has no
            default.
        """
        # collections.Mapping was removed in Python 3.10. Prefer the ABC
        # module and fall back for Python 2.
        try:
            from collections.abc import Mapping
        except ImportError:
            from collections import Mapping

        rows_iter = iter(insert)
        columns = self._columns

        # Load and organize column defaults (if provided).
        defaults = self.get_default_data()

        if not columns:
            # Infer the column list from the first row's keys.
            uses_strings = False
            try:
                row = next(rows_iter)
            except StopIteration:
                raise self.DefaultValuesException('Error: no rows to insert.')
            else:
                accum = []
                value_lookups = {}
                for key in row:
                    if isinstance(key, basestring):
                        column = getattr(self.table, key)
                        uses_strings = True
                    else:
                        column = key
                    accum.append(column)
                    value_lookups[column] = key

                column_set = set(accum)
                for column in (set(defaults) - column_set):
                    accum.append(column)
                    value_lookups[column] = (column.name if uses_strings
                                             else column)

                columns = sorted(accum, key=lambda obj: obj.get_sort_key(ctx))
                # Put the consumed first row back in front of the iterator.
                rows_iter = itertools.chain(iter((row,)), rows_iter)
        else:
            columns = list(columns)
            value_lookups = dict((column, column) for column in columns)
            for col in sorted(defaults, key=lambda obj: obj.get_sort_key(ctx)):
                if col not in value_lookups:
                    columns.append(col)
                    value_lookups[col] = col

        ctx.sql(EnclosedNodeList(columns)).literal(' VALUES ')
        columns_converters = [
            (column, column.db_value if isinstance(column, Field) else None)
            for column in columns]

        all_values = []
        for row in rows_iter:
            values = []
            is_dict = isinstance(row, Mapping)
            for i, (column, converter) in enumerate(columns_converters):
                try:
                    if is_dict:
                        val = row[value_lookups[column]]
                    else:
                        val = row[i]
                except (KeyError, IndexError):
                    if column in defaults:
                        val = defaults[column]
                        if callable(val):
                            val = val()
                    else:
                        raise ValueError('Missing value for "%s".' % column)

                if not isinstance(val, Node):
                    val = Value(val, converter=converter, unpack=False)
                values.append(val)

            all_values.append(EnclosedNodeList(values))

        if not all_values:
            # Explicit columns but no rows: emitting "... VALUES " with
            # nothing after it would be invalid SQL.
            raise self.DefaultValuesException('Error: no rows to insert.')

        with ctx.scope_values(subquery=True):
            return ctx.sql(CommaNodeList(all_values))

    def _query_insert(self, ctx):
        return (ctx
                .sql(EnclosedNodeList(self._columns))
                .literal(' ')
                .sql(self._insert))

    def _default_values(self, ctx):
        if not self._database:
            return ctx.literal('DEFAULT VALUES')
        return self._database.default_values_insert(ctx)

    def __sql__(self, ctx):
        super(Insert, self).__sql__(ctx)
        with ctx.scope_values():
            statement = None
            if self._on_conflict is not None:
                statement = self._on_conflict.get_conflict_statement(ctx)

            (ctx
             .sql(statement or SQL('INSERT'))
             .literal(' INTO ')
             .sql(self.table)
             .literal(' '))

            if isinstance(self._insert, dict) and not self._columns:
                try:
                    self._simple_insert(ctx)
                except self.DefaultValuesException:
                    self._default_values(ctx)
                self._query_type = Insert.SIMPLE
            elif isinstance(self._insert, SelectQuery):
                self._query_insert(ctx)
                self._query_type = Insert.QUERY
            else:
                # An empty row set raises DefaultValuesException here. It is
                # allowed to propagate (and handled in _execute). Previously a
                # bare `return` made __sql__ return None, which broke callers
                # chaining on the returned context.
                self._generate_insert(self._insert, ctx)
                self._query_type = Insert.MULTI

            if self._on_conflict is not None:
                update = self._on_conflict.get_conflict_update(ctx)
                if update is not None:
                    ctx.literal(' ').sql(update)

            return self.apply_returning(ctx)

    def _execute(self, database):
        # Ask for the primary key back when the database supports RETURNING.
        if self._returning is None and database.returning_clause \
           and self.table._primary_key:
            self._returning = (self.table._primary_key,)
        try:
            return super(Insert, self)._execute(database)
        except self.DefaultValuesException:
            # Nothing to insert (e.g. an empty list of rows): a no-op.
            return None

    def handle_result(self, database, cursor):
        if self._return_cursor:
            return cursor
        return database.last_insert_id(cursor, self._query_type)
class Delete(_WriteQuery):
    """``DELETE FROM <table> [WHERE ...]`` plus ordering/limit/RETURNING."""
    def __sql__(self, ctx):
        super(Delete, self).__sql__(ctx)
        with ctx.scope_values(subquery=True):
            ctx.literal('DELETE FROM ')
            ctx.sql(self.table)
            if self._where is not None:
                ctx.literal(' WHERE ')
                ctx.sql(self._where)
            self._apply_ordering(ctx)
            return self.apply_returning(ctx)
class Index(Node):
    """A ``CREATE [UNIQUE] INDEX [IF NOT EXISTS]`` statement.

    *table* values that are not a Table are wrapped in an Entity. Raw string
    entries in *expressions* are wrapped in SQL(). Optional ``USING`` method
    and partial-index ``WHERE`` are supported.
    """
    def __init__(self, name, table, expressions, unique=False, safe=False,
                 where=None, using=None):
        self._name = name
        self._table = table if isinstance(table, Table) else Entity(table)
        self._expressions = expressions
        self._where = where
        self._unique = unique
        self._safe = safe
        self._using = using

    @Node.copy
    def safe(self, _safe=True):
        self._safe = _safe

    @Node.copy
    def where(self, *expressions):
        # New expressions are ANDed onto any existing filter.
        if self._where is not None:
            expressions = (self._where,) + expressions
        self._where = reduce(operator.and_, expressions)

    @Node.copy
    def using(self, _using=None):
        self._using = _using

    def __sql__(self, ctx):
        if self._unique:
            statement = 'CREATE UNIQUE INDEX '
        else:
            statement = 'CREATE INDEX '
        with ctx.scope_values(subquery=True):
            ctx.literal(statement)
            if self._safe:
                ctx.literal('IF NOT EXISTS ')
            (ctx
             .sql(Entity(self._name))
             .literal(' ON ')
             .sql(self._table)
             .literal(' '))
            if self._using is not None:
                ctx.literal('USING %s ' % self._using)

            rendered = [SQL(expr) if isinstance(expr, basestring) else expr
                        for expr in self._expressions]
            ctx.sql(EnclosedNodeList(rendered))
            if self._where is not None:
                ctx.literal(' WHERE ').sql(self._where)
        return ctx
class ModelIndex(Index):
    """An :class:`Index` defined on a model's table from its fields.

    When *name* is omitted it is derived from the model and field names.
    When *using* is omitted, the ``index_type`` of the fields is used; if
    several fields set one, the last field wins.
    """
    def __init__(self, model, fields, unique=False, safe=True, where=None,
                 using=None, name=None):
        self._model = model
        if name is None:
            name = self._generate_name_from_fields(model, fields)
        if using is None:
            for field in fields:
                if getattr(field, 'index_type', None):
                    using = field.index_type
        super(ModelIndex, self).__init__(
            name=name,
            table=model._meta.table,
            expressions=fields,
            unique=unique,
            safe=safe,
            where=where,
            using=using)

    def _generate_name_from_fields(self, model, fields):
        """Build ``<model name>_<col>_<col>...``.

        Non-word characters are stripped. Names longer than 64 characters are
        shortened to 56 characters plus an md5 suffix.

        :raises ValueError: if no column name can be derived.
        """
        accum = []
        for field in fields:
            if isinstance(field, basestring):
                # Raw SQL such as "col DESC": use the first word. Blank
                # strings are skipped instead of raising IndexError.
                words = field.split()
                if words:
                    accum.append(words[0])
            else:
                if isinstance(field, Node) and not isinstance(field, Field):
                    field = field.unwrap()
                if isinstance(field, Field):
                    accum.append(field.column_name)

        if not accum:
            raise ValueError('Unable to generate a name for the index, please '
                             'explicitly specify a name.')

        # Raw string: '\w' in a normal literal is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning on Python 3.12+).
        index_name = re.sub(r'[^\w]+', '',
                            '%s_%s' % (model._meta.name, '_'.join(accum)))
        if len(index_name) > 64:
            index_hash = hashlib.md5(index_name.encode('utf-8')).hexdigest()
            index_name = '%s_%s' % (index_name[:56], index_hash[:7])
        return index_name
# DB-API 2.0 EXCEPTIONS.
class PeeweeException(Exception):
    """Root of the peewee exception hierarchy."""


class ImproperlyConfigured(PeeweeException):
    """Raised when required configuration or a driver is missing."""


class DatabaseError(PeeweeException):
    """Base for errors reported by the database (DB-API naming)."""


class DataError(DatabaseError):
    """DB-API DataError equivalent."""


class IntegrityError(DatabaseError):
    """DB-API IntegrityError equivalent."""


class InterfaceError(PeeweeException):
    """DB-API InterfaceError equivalent (not a DatabaseError here)."""


class InternalError(DatabaseError):
    """DB-API InternalError equivalent."""


class NotSupportedError(DatabaseError):
    """DB-API NotSupportedError equivalent."""


class OperationalError(DatabaseError):
    """DB-API OperationalError equivalent."""


class ProgrammingError(DatabaseError):
    """DB-API ProgrammingError equivalent."""
class ExceptionWrapper(object):
    """Context manager that translates driver exceptions by class name.

    *exceptions* maps an exception class name to a replacement class. A
    matching exception is re-raised via ``reraise`` as the replacement type,
    built from the same args and passed the original traceback. All other
    exceptions propagate unchanged.
    """
    __slots__ = ('exceptions',)

    def __init__(self, exceptions):
        self.exceptions = exceptions

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            return
        name = exc_type.__name__
        if name not in self.exceptions:
            return
        new_type = self.exceptions[name]
        reraise(new_type, new_type(*exc_value.args), traceback)
# Driver exception class name -> peewee exception class, as consumed by
# ExceptionWrapper. A driver "ConstraintError" counts as an integrity error.
EXCEPTIONS = dict(
    ConstraintError=IntegrityError,
    DatabaseError=DatabaseError,
    DataError=DataError,
    IntegrityError=IntegrityError,
    InterfaceError=InterfaceError,
    InternalError=InternalError,
    NotSupportedError=NotSupportedError,
    OperationalError=OperationalError,
    ProgrammingError=ProgrammingError,
)

# Shared translator used around driver calls (connect, execute, close).
__exception_wrapper__ = ExceptionWrapper(EXCEPTIONS)
# DATABASE INTERFACE AND CONNECTION MANAGEMENT.
# Lightweight records for schema introspection results (presumably returned
# by the get_indexes / get_columns / get_foreign_keys implementations).
IndexMetadata = collections.namedtuple(
    'IndexMetadata', ['name', 'sql', 'columns', 'unique', 'table'])

ColumnMetadata = collections.namedtuple(
    'ColumnMetadata',
    ['name', 'data_type', 'null', 'primary_key', 'table', 'default'])

ForeignKeyMetadata = collections.namedtuple(
    'ForeignKeyMetadata', ['column', 'dest_table', 'dest_column', 'table'])
class _ConnectionState(object):
    """Connection bookkeeping.

    Tracks the ``closed`` flag, the driver connection (``conn``) and the
    stack of active ``transactions``.
    """
    def __init__(self, **kwargs):
        super(_ConnectionState, self).__init__(**kwargs)
        self.reset()

    def reset(self):
        self.closed, self.conn, self.transactions = True, None, []

    def set_connection(self, conn):
        self.conn, self.closed = conn, False


class _ConnectionLocal(_ConnectionState, threading.local):
    """Thread-local variant: every thread sees its own connection state."""
class _NoopLock(object):
    """Lock stand-in that does nothing; used when thread_safe is False."""
    __slots__ = ()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None lets any exception propagate.
        return None
class ConnectionContext(_callable_context_manager):
    """Open *db* on entry if it is closed, and close it on exit."""
    __slots__ = ('db',)

    def __init__(self, db):
        self.db = db

    def __enter__(self):
        if not self.db.is_closed():
            return
        self.db.connect()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.db.close()
class Database(_callable_context_manager):
context_class = Context
field_types = {}
operations = {}
param = '?'
quote = '""'
# Feature toggles.
commit_select = False
compound_select_parentheses = False
for_update = False
limit_max = None
returning_clause = False
safe_create_index = True
safe_drop_index = True
sequences = False
def __init__(self, database, thread_safe=True, autorollback=False,
field_types=None, operations=None, autocommit=None, **kwargs):
self._field_types = merge_dict(FIELD, self.field_types)
self._operations = merge_dict(OP, self.operations)
if field_types:
self._field_types.update(field_types)
if operations:
self._operations.update(operations)
self.autorollback = autorollback
self.thread_safe = thread_safe
if thread_safe:
self._state = _ConnectionLocal()
self._lock = threading.Lock()
else:
self._state = _ConnectionState()
self._lock = _NoopLock()
if autocommit is not None:
__deprecated__('Peewee no longer uses the "autocommit" option, as '
'the semantics now require it to always be True. '
'Because some database-drivers also use the '
'"autocommit" parameter, you are receiving a '
'warning so you may update your code and remove '
'the parameter, as in the future, specifying '
'autocommit could impact the behavior of the '
'database driver you are using.')
self.connect_params = {}
self.init(database, **kwargs)
def init(self, database, **kwargs):
if not self.is_closed():
self.close()
self.database = database
self.connect_params.update(kwargs)
self.deferred = not bool(database)
def __enter__(self):
if self.is_closed():
self.connect()
self.transaction().__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
top = self._state.transactions[-1]
try:
top.__exit__(exc_type, exc_val, exc_tb)
finally:
self.close()
def connection_context(self):
return ConnectionContext(self)
def _connect(self):
raise NotImplementedError
def connect(self, reuse_if_open=False):
with self._lock:
if self.deferred:
raise Exception('Error, database must be initialized before '
'opening a connection.')
if not self._state.closed:
if reuse_if_open:
return False
raise OperationalError('Connection already opened.')
self._state.reset()
with __exception_wrapper__:
self._state.set_connection(self._connect())
self._initialize_connection(self._state.conn)
return True
def _initialize_connection(self, conn):
pass
def close(self):
with self._lock:
if self.deferred:
raise Exception('Error, database must be initialized before '
'opening a connection.')
if self.in_transaction():
raise OperationalError('Attempting to close database while '
'transaction is open.')
is_open = not self._state.closed
try:
if is_open:
with __exception_wrapper__:
self._close(self._state.conn)
finally:
self._state.reset()
return is_open
def _close(self, conn):
conn.close()
def is_closed(self):
return self._state.closed
def connection(self):
if self.is_closed():
self.connect()
return self._state.conn
def cursor(self, commit=None):
if self.is_closed():
self.connect()
return self._state.conn.cursor()
def execute_sql(self, sql, params=None, commit=SENTINEL):
logger.debug((sql, params))
if commit is SENTINEL:
if self.in_transaction():
commit = False
elif self.commit_select:
commit = True
else:
commit = not sql[:6].lower().startswith('select')
with __exception_wrapper__:
cursor = self.cursor(commit)
try:
cursor.execute(sql, params or ())
except Exception:
if self.autorollback and not self.in_transaction():
self.rollback()
raise
else:
if commit and not self.in_transaction():
self.commit()
return cursor
def execute(self, query, commit=SENTINEL, **context_options):
ctx = self.get_sql_context(**context_options)
sql, params = ctx.sql(query).query()
return self.execute_sql(sql, params, commit=commit)
def get_context_options(self):
return {
'field_types': self._field_types,
'operations': self._operations,
'param': self.param,
'quote': self.quote,
'compound_select_parentheses': self.compound_select_parentheses,
'conflict_statement': self.conflict_statement,
'conflict_update': self.conflict_update,
'for_update': self.for_update,
'limit_max': self.limit_max,
}
def get_sql_context(self, **context_options):
context = self.get_context_options()
if context_options:
context.update(context_options)
return self.context_class(**context)
def conflict_statement(self, on_conflict):
raise NotImplementedError
def conflict_update(self, on_conflict):
raise NotImplementedError
def _build_on_conflict_update(self, on_conflict):
target = EnclosedNodeList([
Entity(col) if isinstance(col, basestring) else col
for col in on_conflict._conflict_target])
updates = []
if on_conflict._preserve:
for column in on_conflict._preserve:
excluded = NodeList((SQL('EXCLUDED'), ensure_entity(column)),
glue='.')
expression = NodeList((ensure_entity(column), SQL('='),
excluded))
updates.append(expression)
if on_conflict._update:
for k, v in on_conflict._update.items():
if not isinstance(v, Node):
converter = k.db_value if isinstance(k, Field) else None
v = Value(v, converter=converter, unpack=False)
else:
v = QualifiedNames(v)
updates.append(NodeList((ensure_entity(k), SQL('='), v)))
parts = [SQL('ON CONFLICT'),
target,
SQL('DO UPDATE SET'),
CommaNodeList(updates)]
if on_conflict._where:
parts.extend((SQL('WHERE'), QualifiedNames(on_conflict._where)))
return NodeList(parts)
def last_insert_id(self, cursor, query_type=None):
return cursor.lastrowid
def rows_affected(self, cursor):
return cursor.rowcount
def default_values_insert(self, ctx):
return ctx.literal('DEFAULT VALUES')
def in_transaction(self):
return bool(self._state.transactions)
def push_transaction(self, transaction):
self._state.transactions.append(transaction)
def pop_transaction(self):
return self._state.transactions.pop()
def transaction_depth(self):
return len(self._state.transactions)
def top_transaction(self):
if self._state.transactions:
return self._state.transactions[-1]
def atomic(self):
return _atomic(self)
def manual_commit(self):
return _manual(self)
def transaction(self):
return _transaction(self)
def savepoint(self):
return _savepoint(self)
def begin(self):
if self.is_closed():
self.connect()
def commit(self):
return self._state.conn.commit()
def rollback(self):
return self._state.conn.rollback()
def batch_commit(self, it, n):
for group in chunked(it, n):
with self.atomic():
for obj in group:
yield obj
def table_exists(self, table_name, schema=None):
return table_name in self.get_tables(schema=schema)
def get_tables(self, schema=None):
raise NotImplementedError
def get_indexes(self, table, schema=None):
raise NotImplementedError
def get_columns(self, table, schema=None):
raise NotImplementedError
def get_primary_keys(self, table, schema=None):
raise NotImplementedError
def get_foreign_keys(self, table, schema=None):
raise NotImplementedError
def sequence_exists(self, seq):
raise NotImplementedError
def create_tables(self, models, **options):
for model in sort_models(models):
model.create_table(**options)
def drop_tables(self, models, **kwargs):
for model in reversed(sort_models(models)):
model.drop_table(**kwargs)
def extract_date(self, date_part, date_field):
raise NotImplementedError
def truncate_date(self, date_part, date_field):
raise NotImplementedError
def bind(self, models, bind_refs=True, bind_backrefs=True):
for model in models:
model.bind(self, bind_refs=bind_refs, bind_backrefs=bind_backrefs)
def bind_ctx(self, models, bind_refs=True, bind_backrefs=True):
return _BoundModelsContext(models, self, bind_refs, bind_backrefs)
def get_noop_select(self, ctx):
    """Render a SELECT of the literal ``0`` filtered by ``WHERE 0`` into *ctx*.

    The WHERE condition is the constant 0, so the query presumably yields
    no rows.
    """
    no_rows_query = Select().columns(SQL('0'))
    no_rows_query = no_rows_query.where(SQL('0'))
    return ctx.sql(no_rows_query)
def __pragma__(name):
    """Build a read/write ``property`` that proxies the PRAGMA *name*.

    Reading the property returns ``self.pragma(name)``. Assigning to it
    calls ``self.pragma(name, value)``.
    """
    def _read(self):
        return self.pragma(name)

    def _write(self, value):
        return self.pragma(name, value)

    return property(fget=_read, fset=_write)
class SqliteDatabase(Database):
    """SQLite implementation of ``Database``.

    Keeps registries of user-defined aggregates, collations, scalar
    functions, table functions, loadable extensions, attached databases and
    permanent pragmas. ``_add_conn_hooks`` re-applies all of them to every
    new connection. Most registration methods also apply the new item
    immediately when a connection is already open.
    """
    field_types = {
        'BIGAUTO': FIELD.AUTO,
        'BIGINT': FIELD.INT,
        'BOOL': FIELD.INT,
        'DOUBLE': FIELD.FLOAT,
        'SMALLINT': FIELD.INT,
        'UUID': FIELD.TEXT}
    # SQLite's GLOB is case-sensitive and its LIKE is case-insensitive for
    # ASCII, so they back LIKE and ILIKE respectively.
    operations = {
        'LIKE': 'GLOB',
        'ILIKE': 'LIKE'}
    # SQLite interprets a negative LIMIT as "no upper bound".
    limit_max = -1

    def __init__(self, database, *args, **kwargs):
        # ``pragmas`` is popped so it is not forwarded to Database.__init__.
        # init() keeps this value unless it is given its own ``pragmas``.
        self._pragmas = kwargs.pop('pragmas', ())
        super(SqliteDatabase, self).__init__(database, *args, **kwargs)
        # Per-connection hook registries, replayed by _add_conn_hooks().
        self._aggregates = {}
        self._collations = {}
        self._functions = {}
        self._table_functions = []
        self._extensions = set()
        self._attached = {}
        # SQL-level helpers used by extract_date() / truncate_date().
        self.register_function(_sqlite_date_part, 'date_part', 2)
        self.register_function(_sqlite_date_trunc, 'date_trunc', 2)

    def init(self, database, pragmas=None, timeout=5, **kwargs):
        """(Re)configure the database.

        :param pragmas: dict or sequence of ``(name, value)`` pairs applied
            to every new connection; ``None`` keeps the current pragmas.
        :param timeout: busy timeout in seconds, passed to
            ``sqlite3.connect``.
        """
        if pragmas is not None:
            self._pragmas = pragmas
        if isinstance(self._pragmas, dict):
            # Normalize to a list of (name, value) pairs.
            self._pragmas = list(self._pragmas.items())
        self._timeout = timeout
        super(SqliteDatabase, self).init(database, **kwargs)

    def _connect(self):
        """Open a new sqlite3 connection and install every registered hook.

        :raises ImproperlyConfigured: if no SQLite driver could be imported.
        """
        if sqlite3 is None:
            raise ImproperlyConfigured('SQLite driver not installed!')
        conn = sqlite3.connect(self.database, timeout=self._timeout,
                               **self.connect_params)
        # isolation_level=None puts the driver in autocommit mode: it never
        # opens transactions implicitly, so only explicit BEGINs do.
        conn.isolation_level = None
        try:
            self._add_conn_hooks(conn)
        except:
            # Don't leak a half-configured connection. Always re-raise.
            conn.close()
            raise
        return conn

    def _add_conn_hooks(self, conn):
        """Apply attachments, pragmas, functions and extensions to *conn*."""
        for db_name, filename in self._attached.items():
            conn.execute('ATTACH DATABASE "%s" AS "%s"' % (filename, db_name))
        self._set_pragmas(conn)
        self._load_aggregates(conn)
        self._load_collations(conn)
        self._load_functions(conn)
        if self._table_functions:
            for table_function in self._table_functions:
                table_function.register(conn)
        if self._extensions:
            self._load_extensions(conn)

    def _set_pragmas(self, conn):
        """Execute every stored ``(pragma, value)`` pair on *conn*."""
        if self._pragmas:
            cursor = conn.cursor()
            try:
                for pragma, value in self._pragmas:
                    cursor.execute('PRAGMA %s = %s;' % (pragma, value))
            finally:
                # Close the cursor even if one of the pragmas fails.
                cursor.close()

    def pragma(self, key, value=SENTINEL, permanent=False, schema=None):
        """Read or assign a PRAGMA on the current connection.

        :param key: pragma name.
        :param value: when given, the pragma is assigned this value. Falsy
            values (``None``, ``False``, ``''``) are sent as ``0``.
        :param permanent: also remember the assignment so ``_set_pragmas``
            re-applies it on every new connection (requires ``value``).
        :param schema: optional schema name used to qualify the pragma.
        :returns: the first column of the first result row, or ``None``.
        :raises ValueError: if ``permanent`` is set without a ``value``.
        """
        if schema is not None:
            key = '"%s".%s' % (schema, key)
        sql = 'PRAGMA %s' % key
        if value is not SENTINEL:
            # Normalize once so the value executed now and the value
            # replayed on reconnect are identical. Previously a raw '' was
            # stored and later replayed as the invalid "PRAGMA key = ;".
            value = value or 0
            sql += ' = %s' % value
            if permanent:
                pragmas = dict(self._pragmas or ())
                pragmas[key] = value
                self._pragmas = list(pragmas.items())
        elif permanent:
            raise ValueError('Cannot specify a permanent pragma without value')
        row = self.execute_sql(sql).fetchone()
        if row:
            return row[0]
        return None

    cache_size = __pragma__('cache_size')
    foreign_keys = __pragma__('foreign_keys')
    journal_mode = __pragma__('journal_mode')
    journal_size_limit = __pragma__('journal_size_limit')
    mmap_size = __pragma__('mmap_size')
    page_size = __pragma__('page_size')
    read_uncommitted = __pragma__('read_uncommitted')
    synchronous = __pragma__('synchronous')
    wal_autocheckpoint = __pragma__('wal_autocheckpoint')

    @property
    def timeout(self):
        """Busy timeout in seconds."""
        return self._timeout

    @timeout.setter
    def timeout(self, seconds):
        if self._timeout == seconds:
            return
        self._timeout = seconds
        if not self.is_closed():
            # sqlite3.connect() takes seconds, but the busy_timeout PRAGMA
            # is expressed in milliseconds.
            self.execute_sql('PRAGMA busy_timeout=%d;' % (seconds * 1000))

    def _load_aggregates(self, conn):
        """Register every stored aggregate class on *conn*."""
        for name, (klass, num_params) in self._aggregates.items():
            conn.create_aggregate(name, num_params, klass)

    def _load_collations(self, conn):
        """Register every stored collation function on *conn*."""
        for name, fn in self._collations.items():
            conn.create_collation(name, fn)

    def _load_functions(self, conn):
        """Register every stored scalar function on *conn*."""
        for name, (fn, num_params) in self._functions.items():
            conn.create_function(name, num_params, fn)

    def register_aggregate(self, klass, name=None, num_params=-1):
        """Register aggregate *klass* under *name* (default: lowercased class name).

        Stored for future connections and loaded into the open one, if any.
        """
        self._aggregates[name or klass.__name__.lower()] = (klass, num_params)
        if not self.is_closed():
            self._load_aggregates(self.connection())

    def aggregate(self, name=None, num_params=-1):
        """Decorator form of ``register_aggregate``. Returns the class unchanged."""
        def decorator(klass):
            self.register_aggregate(klass, name, num_params)
            return klass
        return decorator

    def register_collation(self, fn, name=None):
        """Register collation *fn* under *name* (default: ``fn.__name__``).

        Also attaches ``fn.collation(*args)``, which returns a ``NodeList``
        of *args* followed by ``SQL('collate <name>')``.
        """
        name = name or fn.__name__
        def _collation(*args):
            expressions = args + (SQL('collate %s' % name),)
            return NodeList(expressions)
        fn.collation = _collation
        self._collations[name] = fn
        if not self.is_closed():
            self._load_collations(self.connection())

    def collation(self, name=None):
        """Decorator form of ``register_collation``. Returns *fn* unchanged."""
        def decorator(fn):
            self.register_collation(fn, name)
            return fn
        return decorator

    def register_function(self, fn, name=None, num_params=-1):
        """Register scalar function *fn* under *name* (default: ``fn.__name__``)."""
        self._functions[name or fn.__name__] = (fn, num_params)
        if not self.is_closed():
            self._load_functions(self.connection())

    def func(self, name=None, num_params=-1):
        """Decorator form of ``register_function``. Returns *fn* unchanged."""
        def decorator(fn):
            self.register_function(fn, name, num_params)
            return fn
        return decorator

    def register_table_function(self, klass, name=None):
        """Register table-function *klass*, optionally overriding ``klass.name``."""
        if name is not None:
            klass.name = name
        self._table_functions.append(klass)
        if not self.is_closed():
            klass.register(self.connection())

    def table_function(self, name=None):
        """Decorator form of ``register_table_function``. Returns the class."""
        def decorator(klass):
            self.register_table_function(klass, name)
            return klass
        return decorator

    # The unregister_* methods only update the registries used for future
    # connections. They raise KeyError for unknown names and do not touch
    # an already-open connection.
    def unregister_aggregate(self, name):
        del(self._aggregates[name])

    def unregister_collation(self, name):
        del(self._collations[name])

    def unregister_function(self, name):
        del(self._functions[name])

    def unregister_table_function(self, name):
        """Forget the first table function named *name*; True if one was found."""
        for idx, klass in enumerate(self._table_functions):
            if klass.name == name:
                break
        else:
            return False
        self._table_functions.pop(idx)
        return True

    def _load_extensions(self, conn):
        """Enable extension loading on *conn* and load every stored extension."""
        conn.enable_load_extension(True)
        for extension in self._extensions:
            conn.load_extension(extension)

    def load_extension(self, extension):
        """Remember *extension* and load it into the open connection, if any."""
        self._extensions.add(extension)
        if not self.is_closed():
            conn = self.connection()
            conn.enable_load_extension(True)
            conn.load_extension(extension)

    def unload_extension(self, extension):
        """Stop loading *extension* on future connections (KeyError if unknown)."""
        self._extensions.remove(extension)

    def attach(self, filename, name):
        """Attach database *filename* as schema *name*.

        Returns False if *name* is already attached to the same file, and
        True after a new attachment is made.

        :raises OperationalError: if *name* is already attached to a
            different file.
        """
        if name in self._attached:
            if self._attached[name] == filename:
                return False
            raise OperationalError('schema "%s" already attached.' % name)
        self._attached[name] = filename
        if not self.is_closed():
            self.execute_sql('ATTACH DATABASE "%s" AS "%s"' % (filename, name))
        return True

    def detach(self, name):
        """Detach schema *name*; returns False if it was not attached."""
        if name not in self._attached:
            return False
        del self._attached[name]
        if not self.is_closed():
            self.execute_sql('DETACH DATABASE "%s"' % name)
        return True

    def atomic(self, lock_type=None):
        """Like ``Database.atomic`` but accepts a SQLite BEGIN *lock_type*."""
        return _atomic(self, lock_type=lock_type)

    def transaction(self, lock_type=None):
        """Like ``Database.transaction`` but accepts a SQLite BEGIN *lock_type*."""
        return _transaction(self, lock_type=lock_type)

    def begin(self, lock_type=None):
        """Issue ``BEGIN`` or ``BEGIN <lock_type>`` (e.g. DEFERRED/IMMEDIATE/EXCLUSIVE)."""
        statement = 'BEGIN %s' % lock_type if lock_type else 'BEGIN'
        self.execute_sql(statement, commit=False)

    def get_tables(self, schema=None):
        """Return the sorted table names in *schema* (default ``main``)."""
        schema = schema or 'main'
        cursor = self.execute_sql('SELECT name FROM "%s".sqlite_master WHERE '
                                  'type=? ORDER BY name' % schema, ('table',))
        return [row for row, in cursor.fetchall()]

    def get_indexes(self, table, schema=None):
        """Return ``IndexMetadata`` for every index on *table*, sorted by name."""
        schema = schema or 'main'
        query = ('SELECT name, sql FROM "%s".sqlite_master '
                 'WHERE tbl_name = ? AND type = ? ORDER BY name') % schema
        cursor = self.execute_sql(query, (table, 'index'))
        index_to_sql = dict(cursor.fetchall())

        # Determine which indexes have a unique constraint. index_list rows
        # are (seq, name, unique, ...).
        unique_indexes = set()
        cursor = self.execute_sql('PRAGMA "%s".index_list("%s")' %
                                  (schema, table))
        for row in cursor.fetchall():
            name = row[1]
            is_unique = int(row[2]) == 1
            if is_unique:
                unique_indexes.add(name)

        # Retrieve the indexed columns. index_info rows are
        # (seqno, cid, name).
        index_columns = {}
        for index_name in sorted(index_to_sql):
            cursor = self.execute_sql('PRAGMA "%s".index_info("%s")' %
                                      (schema, index_name))
            index_columns[index_name] = [row[2] for row in cursor.fetchall()]

        return [
            IndexMetadata(
                name,
                index_to_sql[name],
                index_columns[name],
                name in unique_indexes,
                table)
            for name in sorted(index_to_sql)]

    def get_columns(self, table, schema=None):
        """Return ``ColumnMetadata`` for each column of *table*.

        table_info rows are (cid, name, type, notnull, dflt_value, pk).
        """
        cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' %
                                  (schema or 'main', table))
        return [ColumnMetadata(r[1], r[2], not r[3], bool(r[5]), table, r[4])
                for r in cursor.fetchall()]

    def get_primary_keys(self, table, schema=None):
        """Return the primary-key column names of *table*, in key order."""
        cursor = self.execute_sql('PRAGMA "%s".table_info("%s")' %
                                  (schema or 'main', table))
        # Column 5 ("pk") is 0 for non-key columns, otherwise the 1-based
        # position of the column within the primary key. Sort on it so
        # composite keys come back in key order rather than table order.
        # The sort is stable, so builds that report pk as 0/1 keep table
        # order.
        pk_rows = sorted((row for row in cursor.fetchall() if row[5]),
                         key=lambda row: row[5])
        return [row[1] for row in pk_rows]

    def get_foreign_keys(self, table, schema=None):
        """Return ``ForeignKeyMetadata`` for each foreign key on *table*.

        foreign_key_list rows are (id, seq, table, from, to, ...).
        """
        cursor = self.execute_sql('PRAGMA "%s".foreign_key_list("%s")' %
                                  (schema or 'main', table))
        return [ForeignKeyMetadata(row[3], row[2], row[4], table)
                for row in cursor.fetchall()]

    def get_binary_type(self):
        """Return the driver's binary wrapper type (``sqlite3.Binary``)."""
        return sqlite3.Binary

    def conflict_statement(self, on_conflict):
        """Return ``INSERT OR <ACTION>`` for actions other than NOTHING/UPDATE.

        Returns None (plain INSERT) when there is no action, or when the
        action is NOTHING or UPDATE; those are handled by conflict_update().
        """
        action = on_conflict._action.lower() if on_conflict._action else ''
        if action and action not in ('nothing', 'update'):
            return SQL('INSERT OR %s' % on_conflict._action.upper())

    def conflict_update(self, on_conflict):
        """Build the ``ON CONFLICT ...`` clause, or None if not applicable.

        :raises ValueError: when the requested upsert cannot be expressed
            in SQLite (old SQLite, missing update/preserve, missing target).
        """
        # Sqlite prior to 3.24.0 does not support Postgres-style upsert.
        if __sqlite_version__ < (3, 24, 0) and \
           any((on_conflict._preserve, on_conflict._update, on_conflict._where,
                on_conflict._conflict_target)):
            raise ValueError('SQLite does not support specifying which values '
                             'to preserve or update.')

        action = on_conflict._action.lower() if on_conflict._action else ''
        if action and action not in ('nothing', 'update'):
            # Other actions are expressed via conflict_statement().
            return

        if action == 'nothing':
            return SQL('ON CONFLICT DO NOTHING')
        elif not on_conflict._update and not on_conflict._preserve:
            raise ValueError('If you are not performing any updates (or '
                             'preserving any INSERTed values), then the '
                             'conflict resolution action should be set to '
                             '"NOTHING".')
        elif not on_conflict._conflict_target:
            raise ValueError('SQLite requires that a conflict target be '
                             'specified when doing an upsert.')

        return self._build_on_conflict_update(on_conflict)

    def extract_date(self, date_part, date_field):
        """Call the ``date_part`` SQL function registered in __init__."""
        return fn.date_part(date_part, date_field)

    def truncate_date(self, date_part, date_field):
        """Call the ``date_trunc`` SQL function registered in __init__."""
        return fn.date_trunc(date_part, date_field)
class PostgresqlDatabase(Database):
field_types = {
'AUTO': 'SERIAL',
'BIGAUTO': 'BIGSERIAL',
'BLOB': 'BYTEA',
'BOOL': 'BOOLEAN',
'DATETIME': 'TIMESTAMP',
'DECIMAL': 'NUMERIC',
'DOUBLE': 'DOUBLE PRECISION',
'UUID': 'UUID',
'UUIDB': 'BYTEA'}
operations = {'REGEXP': '~', 'IREGEXP': '~*'}
param = '%s'
commit_select = True
compound_select_parentheses = True
for_update = True
returning_clause = True
safe_create_index = False
sequences = True
def init(self, database, register_unicode=True, encoding=None, **kwargs):
self._register_unicode = register_unicode
self._encoding = encoding
self._need_server_version = True
super(PostgresqlDatabase, self).init(database, **kwargs)
def _connect(self):
if psycopg2 is None:
raise ImproperlyConfigured('Postgres driver not installed!')
conn = psycopg2.connect(database=self.database, **self.connect_params)
if self._register_unicode:
pg_extensions.register_type(pg_extensions.UNICODE, conn)
pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn)
if self._encoding:
conn.set_client_encoding(self._encoding)
if self._need_server_version:
self.set_server_version(conn.server_version)
self._need_server_version = False
return conn
def set_server_version(self, version):
if version >= 90600:
self.safe_create_index = True
def last_insert_id(self, cursor, query_type=None):
try:
return cursor if query_type else cursor[0][0]
except (IndexError, KeyError, TypeError):
pass
def get_tables(self, schema=None):
query = ('SELECT tablename FROM pg_catalog.pg_tables '
'WHERE schemaname = %s ORDER BY tablename')
cursor = self.execute_sql(query, (schema or 'public',))
return [table for table, in cursor.fetchall()]
def get_indexes(self, table, schema=None):
query = """
SELECT
i.relname, idxs.indexdef, idx.indisunique,
array_to_string(array_agg(cols.attname), ',')
FROM pg_catalog.pg_class AS t
INNER JOIN pg_catalog.pg_index AS idx ON t.oid = idx.indrelid
INNER JOIN pg_catalog.pg_class AS i ON idx.indexrelid = i.oid
INNER JOIN pg_catalog.pg_indexes AS idxs ON
(idxs.tablename = t.relname AND idxs.indexname = i.relname)
LEFT OUTER JOIN pg_catalog.pg_attribute AS cols ON
(cols.attrelid = t.oid AND cols.attnum = ANY(idx.indkey))
WHERE t.relname = %s AND t.relkind = %s AND idxs.schemaname = %s
GROUP BY i.relname, idxs.indexdef, idx.indisunique
ORDER BY idx.indisunique DESC, i.relname;"""
cursor = self.execute_sql(query, (table, 'r', schema or 'public'))
return [IndexMetadata(row[0], row[1], row[3].split(','), row[2], table)
for row in cursor.fetchall()]
def get_columns(self, table, schema=None):
query = """
SELECT column_name, is_nullable, data_type, column_default
FROM information_schema.columns
WHERE table_name = %s AND table_schema = %s
ORDER BY ordinal_position"""
cursor = self.execute_sql(query, (table, schema or 'public'))
pks = set(self.get_primary_keys(table, schema))
return [ColumnMetadata(name, dt, null == 'YES', name in pks, table, df)
for name, null, dt, df in cursor.fetchall()]
def get_primary_keys(self, table