Skip to content


Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
4101 lines (3448 sloc) 137.646 kb
# May you do good and not evil
# May you find forgiveness for yourself and forgive others
# May you share freely, never taking more than you give. -- SQLite source code
# As we enjoy great advantages from the inventions of others, we should be glad
# of an opportunity to serve others by an invention of ours, and this we should
# do freely and generously. -- Ben Franklin
# (\
# ( \ /(o)\ caw!
# ( \/ ()/ /)
# ( `;.))'".)
# `(/////.-'
# =====))=))===()
# ///'
# //
# '
import datetime
import decimal
import hashlib
import logging
import operator
import re
import sys
import threading
import uuid
from collections import deque
from collections import namedtuple
from collections import OrderedDict
except ImportError:
OrderedDict = dict
from copy import deepcopy
from functools import wraps
from inspect import isclass
__version__ = '2.4.7'
__all__ = [
# Set default logging handler to avoid "No handlers could be found for logger
# "peewee"" warnings.
try: # Python 2.7+
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
# All peewee-generated logs are logged to this namespace.
logger = logging.getLogger('peewee')
# Python 2/3 compatibility helpers. These helpers are used internally and are
# not exported.
def with_metaclass(meta, base=object):
    """Return a temporary base class created by the metaclass ``meta``.

    Subclassing the returned class applies ``meta`` without using either the
    Python 2 ``__metaclass__`` attribute or the Python 3 ``metaclass=``
    keyword, so the same class statement works on both.

    :param meta: the metaclass to instantiate.
    :param base: single base class for the temporary class.
    :returns: a new class named ``NewBase`` whose type is ``meta``.
    """
    return meta("NewBase", (base,), {})
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
if PY3:
    import builtins
    try:
        # The ABC aliases were removed from `collections` in Python 3.10.
        from collections.abc import Callable
    except ImportError:
        from collections import Callable
    from functools import reduce
    callable = lambda c: isinstance(c, Callable)
    unicode_type = str
    string_type = bytes
    basestring = str
    print_ = getattr(builtins, 'print')
    binary_construct = lambda s: bytes(s.encode('raw_unicode_escape'))

    def reraise(tp, value, tb=None):
        """Re-raise ``value``, attaching traceback ``tb`` if it differs."""
        if value.__traceback__ is not tb:
            raise value.with_traceback(tb)
        raise value
elif PY2:
    unicode_type = unicode
    string_type = basestring
    binary_construct = buffer

    def print_(s):
        sys.stdout.write(s)
        sys.stdout.write('\n')

    # The three-argument raise is a syntax error on Python 3, so the
    # definition has to be hidden from the compiler inside a string.
    exec('def reraise(tp, value, tb=None): raise tp, value, tb')
else:
    raise RuntimeError('Unsupported python version.')
# By default, peewee supports Sqlite, MySQL and Postgresql.
import sqlite3
except ImportError:
from pysqlite2 import dbapi2 as sqlite3
except ImportError:
sqlite3 = None
from psycopg2cffi import compat
except ImportError:
import psycopg2
from psycopg2 import extensions as pg_extensions
except ImportError:
psycopg2 = None
import MySQLdb as mysql # prefer the C module.
except ImportError:
import pymysql as mysql
except ImportError:
mysql = None
if sqlite3:
sqlite3.register_adapter(decimal.Decimal, str)
sqlite3.register_adapter(, str)
sqlite3.register_adapter(datetime.time, str)
DATETIME_PARTS = ['year', 'month', 'day', 'hour', 'minute', 'second']
# Sqlite does not support the `date_part` SQL function, so we will define an
# implementation in python.
'%Y-%m-%d %H:%M:%S',
'%Y-%m-%d %H:%M:%S.%f',
def _sqlite_date_part(lookup_type, datetime_string):
assert lookup_type in DATETIME_LOOKUPS
if not datetime_string:
dt = format_date_time(datetime_string, SQLITE_DATETIME_FORMATS)
return getattr(dt, lookup_type)
'year': '%Y',
'month': '%Y-%m',
'day': '%Y-%m-%d',
'hour': '%Y-%m-%d %H',
'minute': '%Y-%m-%d %H:%M',
'second': '%Y-%m-%d %H:%M:%S'}
MYSQL_DATE_TRUNC_MAPPING['minute'] = '%Y-%m-%d %H:%i'
MYSQL_DATE_TRUNC_MAPPING['second'] = '%Y-%m-%d %H:%i:%S'
def _sqlite_date_trunc(lookup_type, datetime_string):
assert lookup_type in SQLITE_DATE_TRUNC_MAPPING
if not datetime_string:
dt = format_date_time(datetime_string, SQLITE_DATETIME_FORMATS)
return dt.strftime(SQLITE_DATE_TRUNC_MAPPING[lookup_type])
def _sqlite_regexp(regex, value):
return, value, re.I) is not None
# Operators used in binary expressions.
OP_AND = 'and'
OP_OR = 'or'
OP_ADD = '+'
OP_SUB = '-'
OP_MUL = '*'
OP_DIV = '/'
OP_BIN_AND = '&'
OP_BIN_OR = '|'
OP_XOR = '^'
OP_MOD = '%'
OP_EQ = '='
OP_LT = '<'
OP_LTE = '<='
OP_GT = '>'
OP_GTE = '>='
OP_NE = '!='
OP_IN = 'in'
OP_NOT_IN = 'not in'
OP_IS = 'is'
OP_IS_NOT = 'is not'
OP_LIKE = 'like'
OP_ILIKE = 'ilike'
OP_BETWEEN = 'between'
OP_REGEXP = 'regexp'
OP_CONCAT = '||'
# To support "django-style" double-underscore filters, create a mapping between
# operation name and operation code, e.g. "__eq" == OP_EQ.
'eq': OP_EQ,
'lt': OP_LT,
'lte': OP_LTE,
'gt': OP_GT,
'gte': OP_GTE,
'ne': OP_NE,
'in': OP_IN,
'is': OP_IS,
'like': OP_LIKE,
'ilike': OP_ILIKE,
'regexp': OP_REGEXP,
JOIN_INNER = 'inner'
JOIN_LEFT_OUTER = 'left outer'
JOIN_RIGHT_OUTER = 'right outer'
JOIN_FULL = 'full'
# Helper functions that are used in various parts of the codebase.
def merge_dict(source, overrides):
    """Return a shallow copy of ``source`` updated with ``overrides``.

    Neither argument is modified; keys present in both take the value from
    ``overrides``.
    """
    merged = source.copy()
    merged.update(overrides)
    return merged
def pythonify_name(name):
    """Convert a CamelCase-style name into a lower-case identifier.

    A space is inserted at each lower-to-upper boundary that is followed by
    a lower-case letter or underscore, the result is lower-cased, and every
    character that is not a word character or ``+`` becomes ``_``.
    """
    # Raw strings keep `\w` and the group references as regex escapes rather
    # than (invalid) string escapes.
    name = re.sub(r'([a-z_])([A-Z][_a-z])', r'\1 \2', name)
    return re.sub(r'[^\w+]', '_', name.lower())
def returns_clone(func):
    """
    Method decorator that will "clone" the object before applying the given
    method. This ensures that state is mutated in a more predictable fashion,
    and promotes the use of method-chaining.

    The wrapped method's own return value is discarded; the clone is returned
    instead. The undecorated function stays reachable as ``call_local``.
    """
    def inner(self, *args, **kwargs):
        clone = self.clone()  # Assumes object implements `clone`.
        func(clone, *args, **kwargs)
        return clone
    inner.call_local = func  # Provide a way to call without cloning.
    return inner
def not_allowed(func):
    """
    Method decorator to indicate a method is not allowed to be called. Will
    raise a `NotImplementedError`.

    The error message contains the decorated function and the name of the
    type it was invoked on.
    """
    def inner(self, *args, **kwargs):
        raise NotImplementedError('%s is not allowed on %s instances' % (
            func, type(self).__name__))
    return inner
class Proxy(object):
    """
    Proxy class useful for situations when you wish to defer the initialization
    of an object.

    Attribute reads are forwarded to the wrapped object once `initialize` has
    been called with it; before that they raise `AttributeError`. Only the
    proxy's own slots may be assigned.
    """
    __slots__ = ['obj', '_callbacks']

    def __init__(self):
        self._callbacks = []
        # Sets `obj` to None so `__getattr__` can detect the uninitialized
        # state; no callbacks are registered yet at this point.
        self.initialize(None)

    def initialize(self, obj):
        """Store the real object and notify every attached callback."""
        self.obj = obj
        for callback in self._callbacks:
            callback(obj)

    def attach_callback(self, callback):
        """Register ``callback(obj)`` to run on `initialize`; returns it."""
        self._callbacks.append(callback)
        return callback

    def __getattr__(self, attr):
        # Only reached for names that are not slots of the proxy itself.
        if self.obj is None:
            raise AttributeError('Cannot use uninitialized Proxy.')
        return getattr(self.obj, attr)

    def __setattr__(self, attr, value):
        if attr not in self.__slots__:
            raise AttributeError('Cannot set attribute on proxy.')
        return super(Proxy, self).__setattr__(attr, value)
class _CDescriptor(object):
    """Descriptor that exposes an `Entity` named after the owner's alias."""

    def __get__(self, instance, instance_type=None):
        # Instance access builds an Entity from the instance's `_alias`;
        # class-level access returns the descriptor itself.
        if instance is not None:
            return Entity(instance._alias)
        return self
# Classes representing the query tree.
class Node(object):
    """Base-class for any part of a query which shall be composable.

    Operators are overloaded to build `Expression` trees instead of
    evaluating; e.g. ``node == 1`` yields ``Expression(node, OP_EQ, 1)``.
    """
    c = _CDescriptor()
    _node_type = 'node'

    def __init__(self):
        self._negated = False
        self._alias = None
        self._ordering = None  # ASC or DESC.

    def clone_base(self):
        """Return a fresh instance of the same type; subclasses override."""
        return type(self)()

    def clone(self):
        """Copy the node, carrying over negation, alias and ordering."""
        inst = self.clone_base()
        inst._negated = self._negated
        inst._alias = self._alias
        inst._ordering = self._ordering
        return inst

    # The mutators below operate on a clone and return it (via
    # `returns_clone`), since callers chain on the return value.
    @returns_clone
    def __invert__(self):
        self._negated = not self._negated

    @returns_clone
    def alias(self, a=None):
        self._alias = a

    @returns_clone
    def asc(self):
        self._ordering = 'ASC'

    @returns_clone
    def desc(self):
        self._ordering = 'DESC'

    def _e(op, inv=False):
        """
        Lightweight factory which returns a method that builds an Expression
        consisting of the left-hand and right-hand operands, using `op`.
        """
        def inner(self, rhs):
            if inv:
                # Reflected operator: `self` is the right-hand operand.
                return Expression(rhs, op, self)
            return Expression(self, op, rhs)
        return inner
    __and__ = _e(OP_AND)
    __or__ = _e(OP_OR)

    __add__ = _e(OP_ADD)
    __sub__ = _e(OP_SUB)
    __mul__ = _e(OP_MUL)
    __div__ = __truediv__ = _e(OP_DIV)
    __xor__ = _e(OP_XOR)
    __radd__ = _e(OP_ADD, inv=True)
    __rsub__ = _e(OP_SUB, inv=True)
    __rmul__ = _e(OP_MUL, inv=True)
    __rdiv__ = __rtruediv__ = _e(OP_DIV, inv=True)
    __rand__ = _e(OP_AND, inv=True)
    __ror__ = _e(OP_OR, inv=True)
    __rxor__ = _e(OP_XOR, inv=True)

    def __eq__(self, rhs):
        # Comparison with None is expressed with IS rather than `=`.
        if rhs is None:
            return Expression(self, OP_IS, None)
        return Expression(self, OP_EQ, rhs)

    def __ne__(self, rhs):
        if rhs is None:
            return Expression(self, OP_IS_NOT, None)
        return Expression(self, OP_NE, rhs)

    __lt__ = _e(OP_LT)
    __le__ = _e(OP_LTE)
    __gt__ = _e(OP_GT)
    __ge__ = _e(OP_GTE)
    __lshift__ = _e(OP_IN)
    __rshift__ = _e(OP_IS)
    __mod__ = _e(OP_LIKE)
    __pow__ = _e(OP_ILIKE)

    bin_and = _e(OP_BIN_AND)
    bin_or = _e(OP_BIN_OR)

    # Special expressions.
    def in_(self, *rhs):
        return Expression(self, OP_IN, rhs)

    def not_in(self, *rhs):
        return Expression(self, OP_NOT_IN, rhs)

    def is_null(self, is_null=True):
        if is_null:
            return Expression(self, OP_IS, None)
        return Expression(self, OP_IS_NOT, None)

    def contains(self, rhs):
        # `rhs` is interpolated as-is between `%` wildcards.
        return Expression(self, OP_ILIKE, '%%%s%%' % rhs)

    def startswith(self, rhs):
        return Expression(self, OP_ILIKE, '%s%%' % rhs)

    def endswith(self, rhs):
        return Expression(self, OP_ILIKE, '%%%s' % rhs)

    def between(self, low, high):
        return Expression(self, OP_BETWEEN, Clause(low, R('AND'), high))

    def regexp(self, expression):
        return Expression(self, OP_REGEXP, expression)

    def concat(self, rhs):
        return Expression(self, OP_CONCAT, rhs)
class Expression(Node):
    """A binary expression, e.g `foo + 1` or `bar < 7`.

    `flat` is stored alongside the operands; the query compiler omits the
    surrounding parentheses for a flat expression.
    """
    _node_type = 'expression'

    def __init__(self, lhs, op, rhs, flat=False):
        super(Expression, self).__init__()
        self.lhs = lhs
        self.op = op
        self.rhs = rhs
        self.flat = flat

    def clone_base(self):
        return Expression(self.lhs, self.op, self.rhs, self.flat)
class DQ(Node):
    """A "django-style" filter expression, e.g. {'foo__eq': 'x'}.

    The keyword arguments are stored unmodified on `query`.
    """
    def __init__(self, **query):
        super(DQ, self).__init__()
        self.query = query

    def clone_base(self):
        return DQ(**self.query)
class Param(Node):
    """
    Arbitrary parameter passed into a query. Instructs the query compiler to
    specifically treat this value as a parameter, useful for `list` which is
    special-cased for `IN` lookups.

    `conv`, when given, is a callable applied to `value` by the compiler
    before the value is bound.
    """
    _node_type = 'param'

    def __init__(self, value, conv=None):
        self.value = value
        self.conv = conv
        super(Param, self).__init__()

    def clone_base(self):
        return Param(self.value, self.conv)
class Passthrough(Param):
    """A `Param` with its own node type.

    The compiler parses it like a param, but 'passthrough' is not in its set
    of unknown types, so the field's `db_value` conversion is not applied.
    """
    _node_type = 'passthrough'
class SQL(Node):
    """An unescaped SQL string, with optional parameters.

    `value` is emitted verbatim by the compiler, so it must never be built
    from untrusted input; pass such data through `params` instead.
    """
    _node_type = 'sql'

    def __init__(self, value, *params):
        self.value = value
        self.params = params
        super(SQL, self).__init__()

    def clone_base(self):
        return SQL(self.value, *self.params)
R = SQL  # backwards-compat.
class Func(Node):
    """An arbitrary SQL function call."""
    _node_type = 'func'

    def __init__(self, name, *arguments):
        self.name = name
        self.arguments = arguments
        # When true, the compiler keeps applying the enclosing field's
        # conversion to the function's arguments.
        self._coerce = True
        super(Func, self).__init__()

    @returns_clone
    def coerce(self, coerce=True):
        self._coerce = coerce

    def clone_base(self):
        res = Func(self.name, *self.arguments)
        res._coerce = self._coerce
        return res

    def over(self, partition_by=None, order_by=None, window=None):
        """Build ``<func> OVER <window>`` as a `Clause`.

        A `Window` passed positionally as `partition_by` is treated as
        `window`. A given window is referenced by its alias; otherwise an
        inline window is built from `partition_by` / `order_by`.
        """
        if isinstance(partition_by, Window) and window is None:
            window = partition_by
        if window is None:
            sql = Window(
                partition_by=partition_by, order_by=order_by).__sql__()
        else:
            sql = SQL(window._alias)
        return Clause(self, SQL('OVER'), sql)

    def __getattr__(self, attr):
        # Any unknown attribute becomes a factory for a function of that
        # name, which is what makes `fn.LOWER(...)` work.
        def dec(*args, **kwargs):
            return Func(attr, *args, **kwargs)
        return dec
# fn is a factory for creating `Func` objects and supports a more friendly
# API. So instead of `Func("LOWER", param)`, `fn.LOWER(param)`.
fn = Func(None)
class Window(Node):
    """A window definition (PARTITION BY / ORDER BY) for `Func.over`."""

    def __init__(self, partition_by=None, order_by=None):
        super(Window, self).__init__()
        self.partition_by = partition_by
        self.order_by = order_by
        self._alias = self._alias or 'w'

    def __sql__(self):
        """Return the parenthesised window body as an `EnclosedClause`."""
        # NOTE(review): the bodies of the two `if` statements were missing
        # from this copy of the file and have been reconstructed -- confirm
        # against the upstream source.
        over_clauses = []
        if self.partition_by:
            over_clauses.append(Clause(
                SQL('PARTITION BY'),
                CommaClause(*self.partition_by)))
        if self.order_by:
            over_clauses.append(Clause(
                SQL('ORDER BY'),
                CommaClause(*self.order_by)))
        return EnclosedClause(Clause(*over_clauses))

    def clone_base(self):
        return Window(self.partition_by, self.order_by)
class Clause(Node):
    """A SQL clause, one or more Node objects joined by spaces."""
    _node_type = 'clause'

    glue = ' '      # Separator placed between the rendered nodes.
    parens = False  # Whether the rendered clause is wrapped in parentheses.

    def __init__(self, *nodes):
        super(Clause, self).__init__()
        self.nodes = list(nodes)

    def clone_base(self):
        clone = Clause(*self.nodes)
        # Copy the rendering options explicitly: the clone is always a plain
        # Clause, even when `self` is a subclass with different defaults.
        clone.glue = self.glue
        clone.parens = self.parens
        return clone
class CommaClause(Clause):
    """One or more Node objects joined by commas, no parens."""
    glue = ', '
class EnclosedClause(CommaClause):
    """One or more Node objects joined by commas and enclosed in parens."""
    parens = True
class Entity(Node):
    """A quoted-name or entity, e.g. "table"."column"."""
    _node_type = 'entity'

    def __init__(self, *path):
        super(Entity, self).__init__()
        self.path = path

    def clone_base(self):
        return Entity(*self.path)

    def __getattr__(self, attr):
        # Unknown attributes extend the dotted path; falsy components (such
        # as a `None` alias) are dropped.
        return Entity(*filter(None, self.path + (attr,)))
class Check(SQL):
    """Check constraint, usage: `Check('price > 10')`.

    `value` is interpolated verbatim into ``CHECK (...)``.
    """
    def __init__(self, value):
        super(Check, self).__init__('CHECK (%s)' % value)
class _StripParens(Node):
    """Wrapper asking the compiler to strip redundant outer parentheses
    from the SQL generated for `node`."""
    _node_type = 'strip_parens'

    def __init__(self, node):
        super(_StripParens, self).__init__()
        self.node = node
JoinMetadata = namedtuple('JoinMetadata', (
'source', 'target_attr', 'dest', 'to_field', 'related_name'))
class Join(namedtuple('_Join', ('dest', 'join_type', 'on'))):
def get_foreign_key(self, source, dest):
fk_field = source._meta.rel_for_model(dest)
if fk_field is not None:
return fk_field, False
reverse_rel = source._meta.reverse_rel_for_model(dest)
if reverse_rel is not None:
return reverse_rel, True
return None, None
def get_join_type(self):
return self.join_type or JOIN_INNER
def join_metadata(self, source):
is_model_alias = isinstance(self.dest, ModelAlias)
if is_model_alias:
dest = self.dest.model_class
dest = self.dest
is_expr = isinstance(self.on, Expression)
join_alias = is_expr and self.on._alias or None
target_attr = to_field = related_name = None
fk_field, is_backref = self.get_foreign_key(source, dest)
if fk_field is not None:
if is_backref:
target_attr = dest._meta.db_table
related_name = fk_field.related_name
target_attr =
to_field =
elif is_expr and hasattr(self.on.lhs, 'name'):
target_attr =
target_attr = dest._meta.db_table
return JoinMetadata(
join_alias or target_attr,
class FieldDescriptor(object):
    # Fields are exposed as descriptors in order to control access to the
    # underlying "raw" data.
    def __init__(self, field):
        self.field = field
        # Key under which the value is stored in the instance's `_data`.
        self.att_name = self.field.name

    def __get__(self, instance, instance_type=None):
        # Instance access reads the raw value (None when unset); class-level
        # access returns the field object itself.
        if instance is not None:
            return instance._data.get(self.att_name)
        return self.field

    def __set__(self, instance, value):
        instance._data[self.att_name] = value
class Field(Node):
"""A column on a table."""
_field_counter = 0
_order = 0
_node_type = 'field'
db_field = 'unknown'
def __init__(self, null=False, index=False, unique=False,
verbose_name=None, help_text=None, db_column=None,
default=None, choices=None, primary_key=False, sequence=None,
constraints=None, schema=None):
self.null = null
self.index = index
self.unique = unique
self.verbose_name = verbose_name
self.help_text = help_text
self.db_column = db_column
self.default = default
self.choices = choices # Used for metadata purposes, not enforced.
self.primary_key = primary_key
self.sequence = sequence # Name of sequence, e.g. foo_id_seq.
self.constraints = constraints # List of column constraints.
self.schema = schema # Name of schema, e.g. 'public'.
# Used internally for recovering the order in which Fields were defined
# on the Model class.
Field._field_counter += 1
self._order = Field._field_counter
self._sort_key = (self.primary_key and 1 or 2), self._order
self._is_bound = False # Whether the Field is "bound" to a Model.
super(Field, self).__init__()
def clone_base(self, **kwargs):
inst = type(self)(
if self._is_bound: =
inst.model_class = self.model_class
inst._is_bound = self._is_bound
return inst
def add_to_class(self, model_class, name):
Hook that replaces the `Field` attribute on a class with a named
`FieldDescriptor`. Called by the metaclass during construction of the
""" = name
self.model_class = model_class
self.db_column = self.db_column or
if not self.verbose_name:
self.verbose_name = re.sub('_+', ' ', name).title()
model_class._meta.fields[] = self
model_class._meta.columns[self.db_column] = self
setattr(model_class, name, FieldDescriptor(self))
self._is_bound = True
def get_database(self):
return self.model_class._meta.database
def get_column_type(self):
field_type = self.get_db_field()
return self.get_database().compiler().get_column_type(field_type)
def get_db_field(self):
return self.db_field
def get_modifiers(self):
return None
def coerce(self, value):
return value
def db_value(self, value):
"""Convert the python value for storage in the database."""
return value if value is None else self.coerce(value)
def python_value(self, value):
"""Convert the database value to a pythonic value."""
return value if value is None else self.coerce(value)
def _as_entity(self, with_table=False):
if with_table:
return Entity(self.model_class._meta.db_table, self.db_column)
return Entity(self.db_column)
def __ddl_column__(self, column_type):
"""Return the column type, e.g. VARCHAR(255) or REAL."""
modifiers = self.get_modifiers()
if modifiers:
return SQL(
'%s(%s)' % (column_type, ', '.join(map(str, modifiers))))
return SQL(column_type)
def __ddl__(self, column_type):
"""Return a list of Node instances that defines the column."""
ddl = [self._as_entity(), self.__ddl_column__(column_type)]
if not self.null:
ddl.append(SQL('NOT NULL'))
if self.primary_key:
ddl.append(SQL('PRIMARY KEY'))
if self.sequence:
ddl.append(SQL("DEFAULT NEXTVAL('%s')" % self.sequence))
if self.constraints:
return ddl
def __hash__(self):
return hash( + '.' + self.model_class.__name__)
class BareField(Field):
db_field = 'bare'
class IntegerField(Field):
db_field = 'int'
coerce = int
class BigIntegerField(IntegerField):
db_field = 'bigint'
class PrimaryKeyField(IntegerField):
db_field = 'primary_key'
def __init__(self, *args, **kwargs):
kwargs['primary_key'] = True
super(PrimaryKeyField, self).__init__(*args, **kwargs)
class FloatField(Field):
db_field = 'float'
coerce = float
class DoubleField(FloatField):
db_field = 'double'
class DecimalField(Field):
    """Field storing `decimal.Decimal` values."""
    db_field = 'decimal'

    def __init__(self, max_digits=10, decimal_places=5, auto_round=False,
                 rounding=None, *args, **kwargs):
        self.max_digits = max_digits
        self.decimal_places = decimal_places
        self.auto_round = auto_round
        self.rounding = rounding or decimal.DefaultContext.rounding
        super(DecimalField, self).__init__(*args, **kwargs)

    def clone_base(self, **kwargs):
        # Forward this class's own constructor options so that the clone is
        # configured identically.
        return super(DecimalField, self).clone_base(
            max_digits=self.max_digits,
            decimal_places=self.decimal_places,
            auto_round=self.auto_round,
            rounding=self.rounding,
            **kwargs)

    def get_modifiers(self):
        return [self.max_digits, self.decimal_places]

    def db_value(self, value):
        """Convert ``value`` for storage.

        None is passed through, any other falsy value becomes ``Decimal(0)``.
        With `auto_round` the value is quantized to `decimal_places` using
        `rounding`; otherwise it is returned unchanged.
        """
        D = decimal.Decimal
        if not value:
            return value if value is None else D(0)
        if self.auto_round:
            exp = D(10) ** (-self.decimal_places)
            rounding = self.rounding
            return D(str(value)).quantize(exp, rounding=rounding)
        return value

    def python_value(self, value):
        """Return ``value`` as a `Decimal`; None stays None."""
        if value is not None:
            if isinstance(value, decimal.Decimal):
                return value
            # Go through str() so floats are not converted with their
            # binary representation error.
            return decimal.Decimal(str(value))
def coerce_to_unicode(s, encoding='utf-8'):
    """Return ``s`` as text (`unicode_type`).

    Text is returned unchanged, byte strings (`string_type`) are decoded
    with ``encoding``, and anything else is converted with `unicode_type`.
    """
    if isinstance(s, unicode_type):
        return s
    elif isinstance(s, string_type):
        return s.decode(encoding)
    return unicode_type(s)
class CharField(Field):
    """Field storing text with an optional maximum length modifier."""
    db_field = 'string'

    def __init__(self, max_length=255, *args, **kwargs):
        self.max_length = max_length
        super(CharField, self).__init__(*args, **kwargs)

    def clone_base(self, **kwargs):
        return super(CharField, self).clone_base(
            max_length=self.max_length,
            **kwargs)

    def get_modifiers(self):
        # A falsy max_length means the column type gets no length modifier.
        return self.max_length and [self.max_length] or None

    def coerce(self, value):
        # Falsy values are stored as the empty string.
        return coerce_to_unicode(value or '')
class TextField(Field):
db_field = 'text'
def coerce(self, value):
return coerce_to_unicode(value or '')
class BlobField(Field):
db_field = 'blob'
def db_value(self, value):
if isinstance(value, basestring):
return binary_construct(value)
return value
class UUIDField(Field):
    """Field storing `uuid.UUID` values as their string form."""
    db_field = 'uuid'

    def db_value(self, value):
        return None if value is None else str(value)

    def python_value(self, value):
        if value is None:
            return None
        # A value that is already a UUID cannot be passed to `uuid.UUID()`
        # again, so return it unchanged.
        if isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(value)
def format_date_time(value, formats, post_process=None):
    """Parse ``value`` with the first format in ``formats`` that matches.

    :param value: string to parse.
    :param formats: iterable of `strptime` format strings, tried in order.
    :param post_process: optional callable applied to the parsed
        `datetime.datetime` (e.g. to reduce it to a date or time).
    :returns: the (post-processed) datetime, or ``value`` unchanged when no
        format matches.
    """
    post_process = post_process or (lambda x: x)
    for fmt in formats:
        try:
            return post_process(datetime.datetime.strptime(value, fmt))
        except ValueError:
            # This format does not match; try the next one.
            pass
    return value
def _date_part(date_part):
def dec(self):
return self.model_class._meta.database.extract_date(date_part, self)
return dec
class _BaseFormattedField(Field):
formats = None
def __init__(self, formats=None, *args, **kwargs):
if formats is not None:
self.formats = formats
super(_BaseFormattedField, self).__init__(*args, **kwargs)
def clone_base(self, **kwargs):
return super(_BaseFormattedField, self).clone_base(
class DateTimeField(_BaseFormattedField):
db_field = 'datetime'
formats = [
'%Y-%m-%d %H:%M:%S.%f',
'%Y-%m-%d %H:%M:%S',
def python_value(self, value):
if value and isinstance(value, basestring):
return format_date_time(value, self.formats)
return value
year = property(_date_part('year'))
month = property(_date_part('month'))
day = property(_date_part('day'))
hour = property(_date_part('hour'))
minute = property(_date_part('minute'))
second = property(_date_part('second'))
class DateField(_BaseFormattedField):
db_field = 'date'
formats = [
'%Y-%m-%d %H:%M:%S',
'%Y-%m-%d %H:%M:%S.%f',
def python_value(self, value):
if value and isinstance(value, basestring):
pp = lambda x:
return format_date_time(value, self.formats, pp)
elif value and isinstance(value, datetime.datetime):
return value
year = property(_date_part('year'))
month = property(_date_part('month'))
day = property(_date_part('day'))
class TimeField(_BaseFormattedField):
db_field = 'time'
formats = [
'%Y-%m-%d %H:%M:%S.%f',
'%Y-%m-%d %H:%M:%S',
def python_value(self, value):
if value and isinstance(value, basestring):
pp = lambda x: x.time()
return format_date_time(value, self.formats, pp)
elif value and isinstance(value, datetime.datetime):
return value.time()
return value
hour = property(_date_part('hour'))
minute = property(_date_part('minute'))
second = property(_date_part('second'))
class BooleanField(Field):
db_field = 'bool'
coerce = bool
class RelationDescriptor(FieldDescriptor):
"""Foreign-key abstraction to replace a related PK with a related model."""
def __init__(self, field, rel_model):
self.rel_model = rel_model
super(RelationDescriptor, self).__init__(field)
def get_object_or_id(self, instance):
rel_id = instance._data.get(self.att_name)
if rel_id is not None or self.att_name in instance._obj_cache:
if self.att_name not in instance._obj_cache:
obj = self.rel_model.get(self.field.to_field == rel_id)
instance._obj_cache[self.att_name] = obj
return instance._obj_cache[self.att_name]
elif not self.field.null:
raise self.rel_model.DoesNotExist
return rel_id
def __get__(self, instance, instance_type=None):
if instance is not None:
return self.get_object_or_id(instance)
return self.field
def __set__(self, instance, value):
if isinstance(value, self.rel_model):
instance._data[self.att_name] = getattr(
instance._obj_cache[self.att_name] = value
orig_value = instance._data.get(self.att_name)
instance._data[self.att_name] = value
if orig_value != value and self.att_name in instance._obj_cache:
del instance._obj_cache[self.att_name]
class ReverseRelationDescriptor(object):
"""Back-reference to expose related objects as a `SelectQuery`."""
def __init__(self, field):
self.field = field
self.rel_model = field.model_class
def __get__(self, instance, instance_type=None):
if instance is not None:
self.field == getattr(instance,
return self
class ForeignKeyField(IntegerField):
def __init__(self, rel_model, related_name=None, on_delete=None,
on_update=None, extra=None, to_field=None, *args, **kwargs):
if rel_model != 'self' and not isinstance(rel_model, Proxy) and not \
issubclass(rel_model, Model):
raise TypeError('Unexpected value for `rel_model`. Expected '
'`Model`, `Proxy` or "self"')
self.rel_model = rel_model
self._related_name = related_name
self.deferred = isinstance(rel_model, Proxy)
self.on_delete = on_delete
self.on_update = on_update
self.extra = extra
self.to_field = to_field
super(ForeignKeyField, self).__init__(*args, **kwargs)
def clone_base(self, **kwargs):
return super(ForeignKeyField, self).clone_base(
def _get_descriptor(self):
return RelationDescriptor(self, self.rel_model)
def _get_backref_descriptor(self):
return ReverseRelationDescriptor(self)
def _get_related_name(self):
return self._related_name or ('%s_set' %
def add_to_class(self, model_class, name):
if isinstance(self.rel_model, Proxy):
def callback(rel_model):
self.rel_model = rel_model
self.add_to_class(model_class, name)
return = name
self.model_class = model_class
self.db_column = self.db_column or '%s_id' %
if not self.verbose_name:
self.verbose_name = re.sub('_+', ' ', name).title()
model_class._meta.fields[] = self
model_class._meta.columns[self.db_column] = self
self.related_name = self._get_related_name()
if self.rel_model == 'self':
self.rel_model = self.model_class
if self.to_field is not None:
if not isinstance(self.to_field, Field):
self.to_field = getattr(self.rel_model, self.to_field)
self.to_field = self.rel_model._meta.primary_key
if model_class._meta.validate_backrefs:
if self.related_name in self.rel_model._meta.fields:
error = ('Foreign key: %s.%s related name "%s" collision with '
'model field of the same name.')
raise AttributeError(error % (,, self.related_name))
if self.related_name in self.rel_model._meta.reverse_rel:
error = ('Foreign key: %s.%s related name "%s" collision with '
'foreign key using same related_name.')
raise AttributeError(error % (,, self.related_name))
setattr(model_class, name, self._get_descriptor())
self._is_bound = True
model_class._meta.rel[] = self
self.rel_model._meta.reverse_rel[self.related_name] = self
def get_db_field(self):
Overridden to ensure Foreign Keys use same column type as the primary
key they point to.
if not isinstance(self.to_field, PrimaryKeyField):
return self.to_field.get_db_field()
return super(ForeignKeyField, self).get_db_field()
def get_modifiers(self):
if not isinstance(self.to_field, PrimaryKeyField):
return self.to_field.get_modifiers()
return super(ForeignKeyField, self).get_modifiers()
def coerce(self, value):
return self.to_field.coerce(value)
def db_value(self, value):
if isinstance(value, self.rel_model):
value = value._get_pk_value()
return self.to_field.db_value(value)
class CompositeKey(object):
    """A primary key composed of multiple columns."""
    sequence = None

    def __init__(self, *field_names):
        self.field_names = field_names

    def add_to_class(self, model_class, name):
        """Bind the key to ``model_class`` as the attribute ``name``."""
        self.name = name
        self.model_class = model_class
        setattr(model_class, name, self)

    def __get__(self, instance, instance_type=None):
        # Instance access yields the tuple of component values; class-level
        # access returns the key object itself.
        if instance is not None:
            return tuple([getattr(instance, field_name)
                          for field_name in self.field_names])
        return self

    def __set__(self, instance, value):
        # Assignment is deliberately a no-op: the value is always derived
        # from the component attributes.
        pass

    def __eq__(self, other):
        """AND together ``field == value`` for each component of ``other``."""
        expressions = [(self.model_class._meta.fields[field] == value)
                       for field, value in zip(self.field_names, other)]
        return reduce(operator.and_, expressions)
class AliasMap(object):
    """Mapping of objects to aliases, generated as ``prefix`` + counter."""
    prefix = 't'

    def __init__(self):
        self._alias_map = {}
        self._counter = 0

    def __repr__(self):
        return '<AliasMap: %s>' % self._alias_map

    def add(self, obj, alias=None):
        """Register ``obj``; an object that is already present is kept."""
        if obj in self._alias_map:
            return
        # The counter advances even when an explicit alias is supplied.
        self._counter += 1
        self._alias_map[obj] = alias or '%s%s' % (self.prefix, self._counter)

    def __getitem__(self, obj):
        # Looking up an unknown object registers it with a generated alias.
        if obj not in self._alias_map:
            self.add(obj)
        return self._alias_map[obj]

    def __contains__(self, obj):
        return obj in self._alias_map

    def update(self, alias_map):
        """Copy entries from ``alias_map`` that are not present here.

        ``alias_map`` may be None or empty. Existing entries win and the
        counter is left untouched. Returns ``self``.
        """
        if alias_map:
            for obj, alias in alias_map._alias_map.items():
                if obj not in self:
                    self._alias_map[obj] = alias
        return self
class QueryCompiler(object):
# Mapping of `db_type` to actual column type used by database driver.
# Database classes may provide additional column types or overrides.
field_map = {
'bare': '',
'bigint': 'BIGINT',
'blob': 'BLOB',
'bool': 'SMALLINT',
'date': 'DATE',
'datetime': 'DATETIME',
'decimal': 'DECIMAL',
'double': 'REAL',
'float': 'REAL',
'int': 'INTEGER',
'primary_key': 'INTEGER',
'string': 'VARCHAR',
'text': 'TEXT',
'time': 'TIME',
# Mapping of OP_ to actual SQL operation. For most databases this will be
# the same, but some column types or databases may support additional ops.
# Like `field_map`, Database classes may extend or override these.
op_map = {
OP_EQ: '=',
OP_LT: '<',
OP_LTE: '<=',
OP_GT: '>',
OP_GTE: '>=',
OP_NE: '!=',
OP_IN: 'IN',
OP_IS: 'IS',
OP_BIN_AND: '&',
OP_BIN_OR: '|',
OP_ADD: '+',
OP_SUB: '-',
OP_MUL: '*',
OP_DIV: '/',
OP_XOR: '#',
OP_OR: 'OR',
OP_MOD: '%',
OP_CONCAT: '||',
join_map = {
alias_map_class = AliasMap
def __init__(self, quote_char='"', interpolation='?', field_overrides=None,
self.quote_char = quote_char
self.interpolation = interpolation
self._field_map = merge_dict(self.field_map, field_overrides or {})
self._op_map = merge_dict(self.op_map, op_overrides or {})
self._parse_map = self.get_parse_map()
self._unknown_types = set(['param'])
def get_parse_map(self):
# To avoid O(n) lookups when parsing nodes, use a lookup table for
# common node types O(1).
return {
'expression': self._parse_expression,
'param': self._parse_param,
'passthrough': self._parse_param,
'func': self._parse_func,
'clause': self._parse_clause,
'entity': self._parse_entity,
'field': self._parse_field,
'sql': self._parse_sql,
'select_query': self._parse_select_query,
'compound_select_query': self._parse_compound_select_query,
'strip_parens': self._parse_strip_parens,
def quote(self, s):
return '%s%s%s' % (self.quote_char, s, self.quote_char)
def get_column_type(self, f):
return self._field_map[f]
def get_op(self, q):
return self._op_map[q]
def _sorted_fields(self, field_dict):
return sorted(field_dict.items(), key=lambda i: i[0]._sort_key)
def _clean_extra_parens(self, s):
# Quick sanity check.
if not s or s[0] != '(':
return s
ct = i = 0
l = len(s)
while i < l:
if s[i] == '(' and s[l - 1] == ')':
ct += 1
i += 1
l -= 1
if ct:
# If we ever end up with negatively-balanced parentheses, then we
# know that one of the outer parentheses was required.
unbalanced_ct = 0
required = 0
for i in range(ct, l - ct):
if s[i] == '(':
unbalanced_ct += 1
elif s[i] == ')':
unbalanced_ct -= 1
if unbalanced_ct < 0:
required += 1
unbalanced_ct = 0
if required == ct:
ct -= required
if ct > 0:
return s[ct:-ct]
return s
def _parse_default(self, node, alias_map, conv):
return self.interpolation, [node]
def _parse_expression(self, node, alias_map, conv):
if isinstance(node.lhs, Field):
conv = node.lhs
lhs, lparams = self.parse_node(node.lhs, alias_map, conv)
rhs, rparams = self.parse_node(node.rhs, alias_map, conv)
template = '%s %s %s' if node.flat else '(%s %s %s)'
sql = template % (lhs, self.get_op(node.op), rhs)
return sql, lparams + rparams
def _parse_param(self, node, alias_map, conv):
if node.conv:
params = [node.conv(node.value)]
params = [node.value]
return self.interpolation, params
def _parse_func(self, node, alias_map, conv):
conv = node._coerce and conv or None
sql, params = self.parse_node_list(node.arguments, alias_map, conv)
return '%s(%s)' % (, self._clean_extra_parens(sql)), params
def _parse_clause(self, node, alias_map, conv):
sql, params = self.parse_node_list(
node.nodes, alias_map, conv, node.glue)
if node.parens:
sql = '(%s)' % self._clean_extra_parens(sql)
return sql, params
def _parse_entity(self, node, alias_map, conv):
return '.'.join(map(self.quote, node.path)), []
def _parse_sql(self, node, alias_map, conv):
return node.value, list(node.params)
def _parse_field(self, node, alias_map, conv):
if alias_map:
sql = '.'.join((
sql = self.quote(node.db_column)
return sql, []
def _parse_compound_select_query(self, node, alias_map, conv):
l, lp = self.generate_select(node.lhs, alias_map)
r, rp = self.generate_select(node.rhs, alias_map)
sql = '(%s %s %s)' % (l, node.operator, r)
return sql, lp + rp
def _parse_select_query(self, node, alias_map, conv):
clone = node.clone()
if not node._explicit_selection:
if conv and isinstance(conv, ForeignKeyField):
select_field = conv.to_field
select_field = clone.model_class._meta.primary_key
clone._select = (select_field,)
sub, params = self.generate_select(clone, alias_map)
return '(%s)' % self._clean_extra_parens(sub), params
def _parse_strip_parens(self, node, alias_map, conv):
sql, params = self.parse_node(node.node, alias_map, conv)
return self._clean_extra_parens(sql), params
def _parse(self, node, alias_map, conv):
    """Dispatch ``node`` to the matching compile routine.

    Returns ``(sql, params, unknown)``. ``unknown`` is True when the
    parameters are raw values; ``parse_node`` then passes them through
    ``conv.db_value``.

    Dispatch order:
      1. nodes whose ``_node_type`` is registered in ``self._parse_map``;
      2. lists/tuples, compiled as a parenthesised, comma-separated list;
      3. model instances, bound as a single parameter;
      4. model classes / model aliases, compiled as an aliased entity;
      5. anything else, handed to ``self._parse_default``.
    """
    # By default treat the incoming node as a raw value that should be
    # parameterized.
    node_type = getattr(node, '_node_type', None)
    unknown = False
    if node_type in self._parse_map:
        sql, params = self._parse_map[node_type](node, alias_map, conv)
        unknown = node_type in self._unknown_types
    elif isinstance(node, (list, tuple)):
        # If you're wondering how to pass a list into your query, simply
        # wrap it in Param().
        sql, params = self.parse_node_list(node, alias_map, conv)
        sql = '(%s)' % sql
    elif isinstance(node, Model):
        sql = self.interpolation
        if conv and isinstance(conv, ForeignKeyField):
            # Bind the value of the attribute the foreign key points at.
            params = [getattr(node, conv.to_field.name)]
        else:
            params = [node._get_pk_value()]
    elif (isclass(node) and issubclass(node, Model)) or \
            isinstance(node, ModelAlias):
        entity = node._as_entity().alias(alias_map[node])
        sql, params = self.parse_node(entity, alias_map, conv)
    else:
        sql, params = self._parse_default(node, alias_map, conv)
        unknown = True
    return sql, params, unknown
def parse_node(self, node, alias_map=None, conv=None):
    """Compile ``node`` and apply converter, negation, alias and ordering.

    Raw ("unknown") parameters are converted with ``conv.db_value`` when a
    converter is supplied. For ``Node`` instances the SQL is then, in this
    order, prefixed with ``NOT``, suffixed with ``AS <alias>`` and suffixed
    with the ordering keyword, each only if the corresponding attribute is
    truthy.
    """
    sql, params, is_raw = self._parse(node, alias_map, conv)
    if is_raw and conv and params:
        params = [conv.db_value(value) for value in params]

    if not isinstance(node, Node):
        return sql, params

    if node._negated:
        sql = 'NOT %s' % sql
    if node._alias:
        sql = ' '.join([sql, 'AS', node._alias])
    if node._ordering:
        sql = ' '.join([sql, node._ordering])
    return sql, params
def parse_node_list(self, nodes, alias_map, conv=None, glue=', '):
    """Compile every node in ``nodes`` and join the SQL with ``glue``.

    Returns ``(joined_sql, params)`` where ``params`` is the concatenation
    of each node's parameters in iteration order.
    """
    sql = []
    params = []
    for node in nodes:
        node_sql, node_params = self.parse_node(node, alias_map, conv)
        # Restored: the results were computed but never accumulated, so the
        # method always returned an empty string and no parameters.
        sql.append(node_sql)
        params.extend(node_params)
    return glue.join(sql), params
def calculate_alias_map(self, query, alias_map=None):
    """Build the table-alias map for ``query``.

    A fresh map is created; when a parent ``alias_map`` is given its
    counter is carried over before adding entries, and the parent is
    finally merged in through ``update``. The query's model, every join
    source and every join destination are registered.

    Returns whatever ``new_map.update(alias_map)`` returns.
    """
    new_map = self.alias_map_class()
    if alias_map is not None:
        new_map._counter = alias_map._counter

    new_map.add(query.model_class, query.model_class._meta.table_alias)
    for src_model, joined_models in query._joins.items():
        new_map.add(src_model, src_model._meta.table_alias)
        for join_obj in joined_models:
            if isinstance(join_obj.dest, Node):
                new_map.add(join_obj.dest, join_obj.dest.alias)
            else:
                # Restored ``else``: without it a Node destination was also
                # dereferenced through ``_meta``.
                new_map.add(join_obj.dest, join_obj.dest._meta.table_alias)

    return new_map.update(alias_map)
def build_query(self, clauses, alias_map=None):
    """Wrap ``clauses`` in a single ``Clause`` and compile it."""
    statement = Clause(*clauses)
    return self.parse_node(statement, alias_map)
def generate_joins(self, joins, model_class, alias_map):
    """Produce one JOIN clause per edge reachable from ``model_class``.

    ``joins`` maps a source model to a list of join objects. For each join
    the ON constraint is either the user-supplied expression (with any
    alias cleared) or ``left_field == right_field`` derived from the
    foreign key between the two models, looked up first on the source and
    then on the destination.

    Returns the list of ``Clause`` objects in traversal order.
    """
    # Joins are implemented as an adjacency-list graph. Perform a
    # depth-first search of the graph to generate all the necessary JOINs.
    clauses = []
    seen = set()
    q = [model_class]
    while q:
        curr = q.pop()
        if curr not in joins or curr in seen:
            continue
        seen.add(curr)
        for join in joins[curr]:
            src = curr
            dest = join.dest
            if isinstance(join.on, Expression):
                # Clear any alias on the join expression.
                constraint = join.on.clone().alias()
            else:
                field = src._meta.rel_for_model(dest, join.on)
                if field:
                    left_field = field
                    right_field = field.to_field
                else:
                    # The foreign key lives on the destination model.
                    field = dest._meta.rel_for_model(src, join.on)
                    left_field = field.to_field
                    right_field = field
                constraint = (left_field == right_field)

            if isinstance(dest, Node):
                # TODO: ensure alias?
                dest_n = dest
            else:
                # Only model destinations can have further joins.
                q.append(dest)
                dest_n = dest._as_entity().alias(alias_map[dest])

            join_type = join.get_join_type()
            if join_type in self.join_map:
                join_sql = SQL(self.join_map[join_type])
            else:
                join_sql = SQL(join_type)
            clauses.append(
                Clause(join_sql, dest_n, SQL('ON'), constraint))

    return clauses
def generate_select(self, query, alias_map=None):
    """Compile a SELECT (or compound SELECT) query to ``(sql, params)``.

    Clause order: SELECT [DISTINCT [ON (...)]] columns FROM sources,
    JOINs, WHERE, GROUP BY, HAVING, WINDOW, ORDER BY, LIMIT, OFFSET,
    FOR UPDATE [NOWAIT]. For a ``CompoundSelect`` the SELECT/FROM part is
    replaced by the compound query itself with its outer parentheses
    stripped.
    """
    model = query.model_class
    db = model._meta.database

    alias_map = self.calculate_alias_map(query, alias_map)

    if isinstance(query, CompoundSelect):
        clauses = [_StripParens(query)]
    else:
        if not query._distinct:
            clauses = [SQL('SELECT')]
        else:
            clauses = [SQL('SELECT DISTINCT')]
            # Anything other than a plain boolean is a list of columns for
            # DISTINCT ON.
            if query._distinct not in (True, False):
                clauses += [SQL('ON'), EnclosedClause(*query._distinct)]

        select_clause = Clause(*query._select)
        select_clause.glue = ', '

        clauses.extend((select_clause, SQL('FROM')))
        if query._from is None:
            clauses.append(model._as_entity().alias(alias_map[model]))
        else:
            clauses.append(CommaClause(*query._from))

    join_clauses = self.generate_joins(query._joins, model, alias_map)
    if join_clauses:
        clauses.extend(join_clauses)

    if query._where is not None:
        clauses.extend([SQL('WHERE'), query._where])

    if query._group_by:
        clauses.extend([SQL('GROUP BY'), CommaClause(*query._group_by)])

    if query._having:
        clauses.extend([SQL('HAVING'), query._having])

    # The WINDOW clause belongs after HAVING and before ORDER BY. It was
    # previously emitted straight after FROM, ahead of JOIN and WHERE.
    if query._windows is not None:
        clauses.append(SQL('WINDOW'))
        clauses.append(CommaClause(*[
            window.__sql__()
            for window in query._windows]))

    if query._order_by:
        clauses.extend([SQL('ORDER BY'), CommaClause(*query._order_by)])

    # Some databases require a LIMIT whenever OFFSET is used; ``limit_max``
    # supplies the stand-in value in that case.
    if query._limit or (query._offset and db.limit_max):
        limit = query._limit or db.limit_max
        clauses.append(SQL('LIMIT %s' % limit))
    if query._offset:
        clauses.append(SQL('OFFSET %s' % query._offset))

    for_update, no_wait = query._for_update
    if for_update:
        stmt = 'FOR UPDATE NOWAIT' if no_wait else 'FOR UPDATE'
        clauses.append(SQL(stmt))

    return self.build_query(clauses, alias_map)
def generate_update(self, query):
    """Compile an UPDATE query to ``(sql, params)``.

    Plain Python values are wrapped in ``Param`` with the field's
    ``db_value`` converter; ``Node`` and ``Model`` values are passed
    through unchanged. Assignments are emitted in the order produced by
    ``self._sorted_fields``.
    """
    model = query.model_class
    alias_map = self.alias_map_class()
    alias_map.add(model, model._meta.db_table)
    clauses = [SQL('UPDATE'), model._as_entity(), SQL('SET')]

    update = []
    for field, value in self._sorted_fields(query._update):
        if not isinstance(value, (Node, Model)):
            value = Param(value, conv=field.db_value)
        update.append(Expression(
            field._as_entity(with_table=False),
            OP_EQ,
            value,
            flat=True))  # No outer parens, no table alias.
    clauses.append(CommaClause(*update))

    if query._where:
        clauses.extend([SQL('WHERE'), query._where])

    return self.build_query(clauses, alias_map)
def _get_field_clause(self, fields):
    """Return a parenthesised list of the fields' unqualified entities."""
    column_entities = [
        column._as_entity(with_table=False) for column in fields]
    return EnclosedClause(*column_entities)
def generate_insert(self, query):
    """Compile an INSERT query to ``(sql, params)``.

    Two forms are supported:

    * ``INSERT INTO ... SELECT ...`` when ``query._query`` is set; the
      optional column list comes from ``query._fields``.
    * ``INSERT INTO ... (cols) VALUES (...), (...)`` built from
      ``query._iter_rows()``. The column order is fixed by the first row
      (sorted by each field's ``_sort_key``) and reused for every row.

    ``query._upsert`` switches the statement to ``INSERT OR REPLACE INTO``.
    """
    model = query.model_class
    alias_map = self.alias_map_class()
    alias_map.add(model, model._meta.db_table)
    if query._upsert:
        statement = 'INSERT OR REPLACE INTO'
    else:
        statement = 'INSERT INTO'
    clauses = [SQL(statement), model._as_entity()]

    if query._query is not None:
        # This INSERT query is of the form INSERT INTO ... SELECT FROM.
        if query._fields:
            clauses.append(self._get_field_clause(query._fields))
        clauses.append(_StripParens(query._query))

    elif query._rows is not None:
        fields, value_clauses = [], []
        have_fields = False

        for row_dict in query._iter_rows():
            if not have_fields:
                fields = sorted(
                    row_dict.keys(), key=operator.attrgetter('_sort_key'))
                have_fields = True

            values = []
            for field in fields:
                value = row_dict[field]
                if not isinstance(value, (Node, Model)):
                    value = Param(value, conv=field.db_value)
                values.append(value)

            value_clauses.append(EnclosedClause(*values))

        # With no fields at all nothing beyond the table name is emitted.
        if fields:
            clauses.extend([
                self._get_field_clause(fields),
                SQL('VALUES'),
                CommaClause(*value_clauses)])

    return self.build_query(clauses, alias_map)
def generate_delete(self, query):
    """Compile a DELETE query, adding WHERE only when a filter is set."""
    clauses = [SQL('DELETE FROM'), query.model_class._as_entity()]
    if query._where:
        clauses += [SQL('WHERE'), query._where]
    return self.build_query(clauses)
def field_definition(self, field):
    """Return the column-definition clause for ``field``.

    The field's generic type is mapped to the database's column type,
    which the field then expands into its DDL nodes.
    """
    db_field = field.get_db_field()
    column_type = self.get_column_type(db_field)
    return Clause(*field.__ddl__(column_type))
def foreign_key_constraint(self, field):
    """Return the ``REFERENCES table (column)`` clause for a foreign key.

    ``ON DELETE`` / ``ON UPDATE`` actions are appended when the field
    defines them.
    """
    # Restored list contents; only the opening bracket had survived.
    ddl = [
        SQL('REFERENCES'),
        field.rel_model._as_entity(),
        EnclosedClause(field.to_field._as_entity())]
    if field.on_delete:
        ddl.append(SQL('ON DELETE %s' % field.on_delete))
    if field.on_update:
        ddl.append(SQL('ON UPDATE %s' % field.on_update))
    return Clause(*ddl)
def return_parsed_node(function_name):
    """Build a method that compiles the clause returned by another method.

    The returned function looks up ``function_name`` on the instance at
    call time, invokes it with the given arguments and feeds the resulting
    node to ``self.parse_node``.
    """
    # TODO: treat all `generate_` functions as returning clauses, instead
    # of SQL/params.
    def inner(self, *args, **kwargs):
        clause_builder = getattr(self, function_name)
        clause = clause_builder(*args, **kwargs)
        return self.parse_node(clause)
    return inner
# Build the clause that adds a foreign key constraint for `field` on
# `model_class`'s table. When `constraint` is not given, a name is generated
# from the 'fk_%s_%s_refs_%s' template.
# NOTE(review): the three arguments of the name template and every argument
# of the `Clause(` call are missing here (both calls are left open) -- restore
# them from the upstream source before use.
def _create_foreign_key(self, model_class, field, constraint=None):
constraint = constraint or 'fk_%s_%s_refs_%s' % (
fk_clause = self.foreign_key_constraint(field)
return Clause(
# Public wrapper: compiles the clause returned by `_create_foreign_key`.
create_foreign_key = return_parsed_node('_create_foreign_key')
def _create_table(self, model_class, safe=False):
    """Return the CREATE TABLE clause for ``model_class``.

    ``safe`` adds ``IF NOT EXISTS``. Column definitions come first,
    followed by table-level constraints: a composite primary key (if the
    model uses one) and a REFERENCES constraint for every non-deferred
    foreign key.
    """
    statement = 'CREATE TABLE IF NOT EXISTS' if safe else 'CREATE TABLE'
    meta = model_class._meta

    columns, constraints = [], []
    if isinstance(meta.primary_key, CompositeKey):
        pk_cols = [meta.fields[f]._as_entity()
                   for f in meta.primary_key.field_names]
        constraints.append(Clause(
            SQL('PRIMARY KEY'), EnclosedClause(*pk_cols)))
    for field in meta.get_fields():
        columns.append(self.field_definition(field))
        if isinstance(field, ForeignKeyField) and not field.deferred:
            constraints.append(self.foreign_key_constraint(field))

    return Clause(
        SQL(statement),
        model_class._as_entity(),
        EnclosedClause(*(columns + constraints)))
create_table = return_parsed_node('_create_table')
def _drop_table(self, model_class, fail_silently=False, cascade=False):
    """Return the DROP TABLE clause for ``model_class``.

    ``fail_silently`` adds ``IF EXISTS``; ``cascade`` appends ``CASCADE``.
    """
    statement = 'DROP TABLE IF EXISTS' if fail_silently else 'DROP TABLE'
    ddl = [SQL(statement), model_class._as_entity()]
    if cascade:
        # Restored: the ``if`` had no body, so ``cascade`` had no effect.
        ddl.append(SQL('CASCADE'))
    return Clause(*ddl)
drop_table = return_parsed_node('_drop_table')
def index_name(self, table, columns):
    """Return ``<table>_<col>_<col>...``, hashed when longer than 64 chars.

    Over-long names are replaced by ``<table>_<md5 hex digest>`` where the
    digest is taken over the UTF-8 encoding of the full name.
    """
    full_name = '%s_%s' % (table, '_'.join(columns))
    if len(full_name) <= 64:
        return full_name
    digest = hashlib.md5(full_name.encode('utf-8')).hexdigest()
    return '%s_%s' % (table, digest)
def _create_index(self, model_class, fields, unique, *extra):
    """Return the CREATE [UNIQUE] INDEX clause covering ``fields``.

    The index name is derived from the table name and the fields'
    database column names via ``self.index_name``. Any ``extra`` nodes are
    appended after the column list.
    """
    tbl_name = model_class._meta.db_table
    statement = 'CREATE UNIQUE INDEX' if unique else 'CREATE INDEX'
    index_name = self.index_name(tbl_name, [f.db_column for f in fields])
    return Clause(
        SQL(statement),
        Entity(index_name),
        SQL('ON'),
        model_class._as_entity(),
        EnclosedClause(*[field._as_entity() for field in fields]),
        *extra)
create_index = return_parsed_node('_create_index')
def _create_sequence(self, sequence_name):
    """Return the ``CREATE SEQUENCE <name>`` clause."""
    keyword = SQL('CREATE SEQUENCE')
    return Clause(keyword, Entity(sequence_name))
create_sequence = return_parsed_node('_create_sequence')
def _drop_sequence(self, sequence_name):
    """Return the ``DROP SEQUENCE <name>`` clause."""
    keyword = SQL('DROP SEQUENCE')
    return Clause(keyword, Entity(sequence_name))
drop_sequence = return_parsed_node('_drop_sequence')
class QueryResultWrapper(object):
    """
    Provides an iterator over the results of a raw Query, additionally doing
    two things:
    - converts rows from the database into python representations
    - ensures that multiple iterations do not result in multiple queries
    """
    def __init__(self, model, cursor, meta=None):
        """Wrap ``cursor``; ``meta`` is an optional (columns, joins) pair."""
        self.model = model
        self.cursor = cursor

        # ``__ct`` counts cached rows, ``__idx`` is the read position of the
        # current iteration.
        self.__ct = 0
        self.__idx = 0

        self._result_cache = []
        self._populated = False
        self._initialized = False

        if meta is not None:
            self.column_meta, self.join_meta = meta
        else:
            self.column_meta = self.join_meta = None

    def __iter__(self):
        """Iterate from the start, replaying the cache once fully read."""
        self.__idx = 0

        if not self._populated:
            return self
        return iter(self._result_cache)

    def initialize(self, description):
        """Hook called once with the cursor description; no-op by default."""

    def process_row(self, row):
        """Convert one raw row; the base implementation returns it as-is."""
        return row

    def iterate(self):
        """Fetch and convert the next row, bypassing the cache.

        Raises StopIteration when the cursor is exhausted. At that point the
        wrapper is marked populated and the cursor is closed, unless the
        cursor has a truthy ``name`` attribute.
        """
        row = self.cursor.fetchone()
        if not row:
            self._populated = True
            if not getattr(self.cursor, 'name', None):
                self.cursor.close()
            raise StopIteration
        elif not self._initialized:
            self.initialize(self.cursor.description)
            self._initialized = True
        return self.process_row(row)

    def iterator(self):
        """Yield converted rows without storing them in the cache."""
        while True:
            # A StopIteration escaping a generator body is turned into a
            # RuntimeError (PEP 479), so end the generator explicitly.
            try:
                row = self.iterate()
            except StopIteration:
                return
            yield row

    def next(self):
        """Return the next row, from the cache when already fetched."""
        if self.__idx < self.__ct:
            inst = self._result_cache[self.__idx]
            self.__idx += 1
            return inst
        elif self._populated:
            # Do not touch the (possibly closed) cursor again.
            raise StopIteration

        obj = self.iterate()
        self._result_cache.append(obj)
        self.__ct += 1
        self.__idx += 1
        return obj
    __next__ = next

    def fill_cache(self, n=None):
        """Ensure at least ``n`` rows are cached (all rows when ``n`` is falsy).

        Raises ValueError for negative ``n``.
        """
        n = n or float('Inf')
        if n < 0:
            raise ValueError('Negative values are not supported.')

        self.__idx = self.__ct
        while not self._populated and (n > self.__ct):
            try:
                self.next()
            except StopIteration:
                break
class ExtQueryResultWrapper(QueryResultWrapper):
    """Base for wrappers that convert rows column-by-column.

    ``initialize`` builds ``self.conv``: one ``(index, name, converter)``
    triple per result column, consumed by the subclasses' ``process_row``.
    """
    def initialize(self, description):
        """Work out the attribute name and converter of every column.

        Per column, in order of preference:
          * a selected ``Field``: its ``python_value`` and its alias/name;
          * a selected ``Func`` whose first argument is a ``Field``: that
            field's ``python_value`` if the function coerces, keeping the
            cursor's column name;
          * a cursor column name matching one of the model's columns: that
            field's name and ``python_value``;
          * otherwise the cursor's column name with an identity converter.
        """
        model = self.model
        conv = []
        identity = lambda x: x
        for i, column_description in enumerate(description):
            func = identity
            column = column_description[0]
            found = False
            if self.column_meta is not None:
                try:
                    select_column = self.column_meta[i]
                except IndexError:
                    # More result columns than selected nodes.
                    pass
                else:
                    if isinstance(select_column, Field):
                        func = select_column.python_value
                        column = select_column._alias or select_column.name
                        found = True
                    elif (isinstance(select_column, Func) and
                            len(select_column.arguments) and
                            isinstance(select_column.arguments[0], Field)):
                        if select_column._coerce:
                            # Special-case handling aggregations.
                            func = select_column.arguments[0].python_value
                        found = True

            if not found and column in model._meta.columns:
                field_obj = model._meta.columns[column]
                column = field_obj.name
                func = field_obj.python_value
            conv.append((i, column, func))
        self.conv = conv
class TuplesQueryResultWrapper(ExtQueryResultWrapper):
    """Result wrapper that yields each row as a tuple of converted values."""
    def process_row(self, row):
        converters = self.conv
        converted = []
        for position, raw_value in enumerate(row):
            converted.append(converters[position][2](raw_value))
        return tuple(converted)
class NaiveQueryResultWrapper(ExtQueryResultWrapper):
    """Result wrapper that sets every column on one model instance."""
    def process_row(self, row):
        obj = self.model()
        for position, attr_name, convert in self.conv:
            converted = convert(row[position])
            setattr(obj, attr_name, converted)
        return obj
class DictQueryResultWrapper(ExtQueryResultWrapper):
    """Result wrapper that yields each row as a ``{column: value}`` dict."""
    def process_row(self, row):
        record = {}
        for position, key, convert in self.conv:
            record[key] = convert(row[position])
        return record
# Result wrapper that builds one instance per selected model from each row and
# then links the instances together along the query's joins.
# NOTE(review): this block has lost lines (every `else:`/`continue`, the body
# of `generate_join_list`'s inner `if`, the body of the loop in `process_row`)
# and the right-hand side of `attr =`; restore them from upstream before use.
class ModelQueryResultWrapper(QueryResultWrapper):
# Called with the cursor description; caches the column map and join list.
def initialize(self, description):
self.column_map, model_set = self.generate_column_map()
self.join_list = self.generate_join_list(model_set)
# Build one (key, constructor, attr, conv) tuple per selected node and return
# it together with a set of models.
def generate_column_map(self):
column_map = []
models = set([self.model])
for i, node in enumerate(self.column_meta):
attr = conv = None
if isinstance(node, Field):
if isinstance(node, FieldProxy):
# Proxied fields are keyed by their model alias but still construct
# instances of the underlying model.
key = node._model_alias
constructor = node.model
key = constructor = node.model_class
attr =
conv = node.python_value
# Non-field nodes are attributed to the query's own model.
key = constructor = self.model
if isinstance(node, Expression) and node._alias:
attr = node._alias
column_map.append((key, constructor, attr, conv))
return column_map, models
# Walk the join graph from the query's model using an explicit stack.
# NOTE(review): nothing is appended to `join_list` in the visible code.
def generate_join_list(self, models):
join_list = []
joins = self.join_meta
stack = [self.model]
while stack:
current = stack.pop()
if current not in joins:
for join in joins[current]:
if join.dest in models:
return join_list
# Convert one row: construct the instances, link them, return the first.
def process_row(self, row):
collected = self.construct_instances(row)
instances = self.follow_joins(collected)
for i in instances:
return instances[0]
# Create (or reuse) one instance per key and assign each column value to it.
# `keys`, when given, restricts which keys are considered.
def construct_instances(self, row, keys=None):
collected_models = {}
for i, (key, constructor, attr, conv) in enumerate(self.column_map):
if keys is not None and key not in keys:
value = row[i]
if key not in collected_models:
collected_models[key] = constructor()
instance = collected_models[key]
if attr is None:
# Fall back to the column name reported by the cursor.
attr = self.cursor.description[i][0]
if conv is not None:
value = conv(value)
setattr(instance, attr, value)
return collected_models
# Attach each joined instance to its source instance under `attr`.
def follow_joins(self, collected):
prepared = [collected[self.model]]
for (lhs, attr, rhs, to_field, related_name) in self.join_list:
inst = collected[lhs]
joined_inst = collected[rhs]
# Can we populate a value on the joined instance using the current?
if to_field is not None and attr in inst._data:
if getattr(joined_inst, to_field) is None:
setattr(joined_inst, to_field, inst._data[attr])
setattr(inst, attr, joined_inst)
return prepared
# Result wrapper that reads several consecutive cursor rows per returned
# object, de-duplicating instances in per-model identity maps keyed by primary
# key and attaching related instances to lists on their parents.
# NOTE(review): this block has lost lines (every `else:`, `try:`, `break`,
# `continue`, several loop bodies) and several `.name` tokens; the code below
# is not runnable as-is and must be restored from upstream.
class AggregateQueryResultWrapper(ModelQueryResultWrapper):
def __init__(self, *args, **kwargs):
# Holds a row that was fetched but belongs to the next object.
self._row = []
super(AggregateQueryResultWrapper, self).__init__(*args, **kwargs)
# Extends the parent setup with the bookkeeping used to compare rows.
def initialize(self, description):
super(AggregateQueryResultWrapper, self).initialize(description)
# Collect the set of all models queried.
self.all_models = set()
for key, _, _, _ in self.column_map:
# Prepare data structure for analyzing unique rows.
self.models_with_aggregate = set()
self.back_references = {}
for (src_model, _, dest_model, _, related_name) in self.join_list:
if related_name:
self.back_references[dest_model] = (src_model, related_name)
# Map each aggregate model to the (row index, column name) pairs to compare.
self.columns_to_compare = {}
for idx, (_, model_class, col_name, _) in enumerate(self.column_map):
if model_class in self.models_with_aggregate:
self.columns_to_compare.setdefault(model_class, [])
self.columns_to_compare[model_class].append((idx, col_name))
# Build a per-model list from `row` for the columns listed in
# `columns_to_compare`.
# NOTE(review): the inner loop body is missing.
def read_model_data(self, row):
models = {}
for model_class, column_data in self.columns_to_compare.items():
models[model_class] = []
for idx, col_name in column_data:
return models
# Produce the next primary instance, reading further rows from the cursor.
def iterate(self):
# Prefer a row left over from the previous call.
if self._row:
row = self._row.pop()
row = self.cursor.fetchone()
if not row:
self._populated = True
if not getattr(self.cursor, 'name', None):
raise StopIteration
elif not self._initialized:
self._initialized = True
# Identity key for an instance: a tuple for composite keys, otherwise the
# primary key value.
def _get_pk(instance):
if isinstance(instance._meta.primary_key, CompositeKey):
return tuple([
for field_name in instance._meta.primary_key.field_names])
return instance._get_pk_value()
identity_map = {}
_constructed = self.construct_instances(row)
primary_instance = _constructed[self.model]
for model_class, instance in _constructed.items():
identity_map[model_class] = OrderedDict()
identity_map[model_class][_get_pk(instance)] = instance
model_data = self.read_model_data(row)
# Read further rows and compare their model data with the first row's.
while True:
cur_row = self.cursor.fetchone()
if cur_row is None:
duplicate_models = set()
cur_row_data = self.read_model_data(cur_row)
for model_class, data in cur_row_data.items():
if model_data[model_class] == data:
if not duplicate_models:
different_models = self.all_models - duplicate_models
new_instances = self.construct_instances(cur_row, different_models)
for model_class, instance in new_instances.items():
# Do not include any instances which are comprised solely of
# NULL values.
pk_value = _get_pk(instance)
if [val for val in instance._data.values() if val is not None]:
identity_map[model_class][pk_value] = instance
# Walk the join graph and connect the collected instances.
stack = [self.model]
instances = [primary_instance]
while stack:
current = stack.pop()
if current not in self.join_meta:
for join in self.join_meta[current]:
foreign_key = current._meta.rel_for_model(join.dest, join.on)
if foreign_key:
if join.dest not in identity_map:
for pk, instance in identity_map[current].items():
joined_inst = identity_map[join.dest][
setattr(instance,, joined_inst)
if not isinstance(join.dest, Node):
backref = current._meta.reverse_rel_for_model(
join.dest, join.on)
if not backref:
attr_name = backref.related_name
# Start every parent with an empty list of related instances.
for instance in identity_map[current].values():
setattr(instance, attr_name, [])
if join.dest not in identity_map:
for pk, instance in identity_map[join.dest].items():
if pk is None:
joined_inst = identity_map[current][
except KeyError:
getattr(joined_inst, attr_name).append(instance)
for instance in instances:
return primary_instance
class Query(Node):
    """Base class representing a database query on one or more tables."""
    require_commit = True

    def __init__(self, model_class):
        super(Query, self).__init__()

        self.model_class = model_class
        self.database = model_class._meta.database

        self._dirty = True
        # Model that subsequent ``join`` calls start from.
        self._query_ctx = model_class
        self._joins = {self.model_class: []}  # Join graph as adjacency list.
        self._where = None

    def __repr__(self):
        sql, params = self.sql()
        return '%s %s %s' % (self.model_class, sql, params)

    def clone(self):
        """Return a copy of this query bound to the same database."""
        query = type(self)(self.model_class)
        query.database = self.database
        return self._clone_attributes(query)

    def _clone_attributes(self, query):
        """Copy the WHERE expression, join graph and context onto ``query``."""
        if self._where is not None:
            query._where = self._where.clone()
        query._joins = self._clone_joins()
        query._query_ctx = self._query_ctx
        return query

    def _clone_joins(self):
        """Copy the join graph one level deep (new dict, new lists)."""
        return dict(
            (mc, list(j)) for mc, j in self._joins.items())

    def _add_query_clauses(self, initial, expressions, conjunction=None):
        """AND ``expressions`` together and combine them with ``initial``.

        ``conjunction`` (default ``operator.and_``) joins the reduced
        expressions to the existing ``initial`` expression, if any.
        """
        reduced = reduce(operator.and_, expressions)
        if initial is None:
            return reduced
        conjunction = conjunction or operator.and_
        return conjunction(initial, reduced)

    # The builder methods below mutate ``self`` and return nothing;
    # ``returns_clone`` is what makes them chainable (``ensure_join`` and
    # ``filter`` rely on their return values).
    @returns_clone
    def where(self, *expressions):
        """AND the given expressions onto the WHERE clause."""
        self._where = self._add_query_clauses(self._where, expressions)

    @returns_clone
    def orwhere(self, *expressions):
        """OR the (AND-ed) expressions onto the WHERE clause."""
        self._where = self._add_query_clauses(
            self._where, expressions, operator.or_)

    @returns_clone
    def join(self, dest, join_type=None, on=None):
        """Join ``dest`` to the current query context.

        Raises ValueError when no ``on`` is given and ``dest`` is a
        ``SelectQuery`` or a class with no relation to the current context.
        A string ``on`` is resolved to the field of that name on the
        current context. Unless ``dest`` is a ``SelectQuery`` it becomes
        the new query context.
        """
        if not on:
            require_join_condition = [
                isinstance(dest, SelectQuery),
                (isclass(dest) and not self._query_ctx._meta.rel_exists(dest))]
            if any(require_join_condition):
                raise ValueError('A join condition must be specified.')
        elif isinstance(on, basestring):
            on = self._query_ctx._meta.fields[on]
        self._joins.setdefault(self._query_ctx, [])
        self._joins[self._query_ctx].append(Join(dest, join_type, on))
        if not isinstance(dest, SelectQuery):
            self._query_ctx = dest

    @returns_clone
    def switch(self, model_class=None):
        """Change or reset the query context."""
        self._query_ctx = model_class or self.model_class

    def ensure_join(self, lm, rm, on=None):
        """Join ``lm`` to ``rm`` unless already joined; keep the context."""
        ctx = self._query_ctx
        for join in self._joins.get(lm, []):
            if join.dest == rm:
                return self
        return self.switch(lm).join(rm, on=on).switch(ctx)

    def convert_dict_to_node(self, qdict):
        """Translate Django-style lookups into expressions.

        Keys look like ``rel__field__op``; a trailing component found in
        ``DJANGO_MAP`` selects the operator, otherwise equality is used.
        Returns ``(expressions, joins)`` where ``joins`` lists the
        relationship attributes traversed along the way.
        """
        accum = []
        joins = []
        relationship = (ForeignKeyField, ReverseRelationDescriptor)
        for key, value in sorted(qdict.items()):
            curr = self.model_class
            if '__' in key and key.rsplit('__', 1)[1] in DJANGO_MAP:
                key, op = key.rsplit('__', 1)
                op = DJANGO_MAP[op]
            else:
                op = OP_EQ
            for piece in key.split('__'):
                model_attr = getattr(curr, piece)
                if isinstance(model_attr, relationship):
                    curr = model_attr.rel_model
                    joins.append(model_attr)
            accum.append(Expression(model_attr, op, value))
        return accum, joins

    def filter(self, *args, **kwargs):
        """Return a clone filtered by expressions and Django-style kwargs.

        Joins needed by the lookups are added automatically.
        """
        # normalize args and kwargs into a new expression
        dq_node = Node()
        if args:
            dq_node &= reduce(operator.and_, [a.clone() for a in args])
        if kwargs:
            dq_node &= DQ(**kwargs)

        # dq_node should now be an Expression, lhs = Node(), rhs = ...
        q = deque([dq_node])
        dq_joins = set()
        while q:
            curr = q.popleft()
            if not isinstance(curr, Expression):
                continue
            for side, piece in (('lhs', curr.lhs), ('rhs', curr.rhs)):
                if isinstance(piece, DQ):
                    query, joins = self.convert_dict_to_node(piece.query)
                    dq_joins.update(joins)
                    expression = reduce(operator.and_, query)
                    # Apply values from the DQ object.
                    expression._negated = piece._negated
                    expression._alias = piece._alias
                    setattr(curr, side, expression)
                else:
                    q.append(piece)

        # Drop the placeholder Node() that seeded the expression.
        dq_node = dq_node.rhs

        query = self.clone()
        for field in dq_joins:
            if isinstance(field, ForeignKeyField):
                lm, rm = field.model_class, field.rel_model
                field_obj = field
            elif isinstance(field, ReverseRelationDescriptor):
                lm, rm = field.field.rel_model, field.rel_model
                field_obj = field.field
            query = query.ensure_join(lm, rm, field_obj)
        return query.where(dq_node)

    def compiler(self):
        return self.database.compiler()

    def sql(self):
        raise NotImplementedError

    def _execute(self):
        """Compile the query and run it, returning the database cursor."""
        sql, params = self.sql()
        return self.database.execute_sql(sql, params, self.require_commit)

    def execute(self):
        raise NotImplementedError

    def scalar(self, as_tuple=False, convert=False):
        """Return the first column of the first row (or the whole row).

        With ``convert`` the row is fetched through ``tuples().first()``;
        otherwise it is read straight from the cursor. ``as_tuple`` returns
        the entire row. An empty result is returned as-is.
        """
        if convert:
            row = self.tuples().first()
        else:
            row = self._execute().fetchone()
        if row and not as_tuple:
            return row[0]
        return row
class RawQuery(Query):
    """
    Execute a SQL query, returning a standard iterable interface that returns
    model instances.
    """
    def __init__(self, model, query, *params):
        self._sql = query
        self._params = list(params)
        self._qr = None
        self._tuples = False
        self._dicts = False
        super(RawQuery, self).__init__(model)

    def clone(self):
        """Return a new RawQuery with the same SQL, params and row format."""
        query = RawQuery(self.model_class, self._sql, *self._params)
        query._tuples = self._tuples
        query._dicts = self._dicts
        return query

    join = not_allowed('joining')
    where = not_allowed('where')
    switch = not_allowed('switch')

    @returns_clone
    def tuples(self, tuples=True):
        """Return rows as tuples instead of model instances."""
        self._tuples = tuples

    @returns_clone
    def dicts(self, dicts=True):
        """Return rows as dictionaries instead of model instances."""
        self._dicts = dicts

    def sql(self):
        return self._sql, self._params

    def execute(self):
        """Run the query once and return the cached result wrapper."""
        if self._qr is None:
            if self._tuples:
                ResultWrapper = TuplesQueryResultWrapper
            elif self._dicts:
                ResultWrapper = DictQueryResultWrapper
            else:
                ResultWrapper = NaiveQueryResultWrapper
            self._qr = ResultWrapper(self.model_class, self._execute(), None)
        return self._qr

    def __iter__(self):
        return iter(self.execute())
class SelectQuery(Query):
    """SELECT query builder; results are fetched lazily and cached."""
    _node_type = 'select_query'

    def __init__(self, model_class, *selection):
        super(SelectQuery, self).__init__(model_class)
        self.require_commit = self.database.commit_select
        # Sets ``_select`` and ``_explicit_selection``.
        self.__select(*selection)
        self._from = None
        self._group_by = None
        self._having = None
        self._order_by = None
        self._windows = None
        self._limit = None
        self._offset = None
        self._distinct = False
        self._for_update = (False, False)
        self._naive = False
        self._tuples = False
        self._dicts = False
        self._aggregate_rows = False
        self._alias = None
        self._qr = None

    def _clone_attributes(self, query):
        """Copy all SELECT-specific state onto ``query``."""
        query = super(SelectQuery, self)._clone_attributes(query)
        query._explicit_selection = self._explicit_selection
        query._select = list(self._select)
        if self._from is not None:
            query._from = []
            for f in self._from:
                # Nodes are cloned; anything else is shared.
                if isinstance(f, Node):
                    query._from.append(f.clone())
                else:
                    query._from.append(f)
        if self._group_by is not None:
            query._group_by = list(self._group_by)
        if self._having:
            query._having = self._having.clone()
        if self._order_by is not None:
            query._order_by = list(self._order_by)
        if self._windows is not None:
            query._windows = list(self._windows)
        query._limit = self._limit
        query._offset = self._offset
        query._distinct = self._distinct
        query._for_update = self._for_update
        query._naive = self._naive
        query._tuples = self._tuples
        query._dicts = self._dicts
        query._aggregate_rows = self._aggregate_rows
        query._alias = self._alias
        return query

    def _model_shorthand(self, args):
        """Expand model classes and aliases in ``args`` into their fields.

        Nodes and queries are kept as-is; other values are dropped.
        """
        accum = []
        for arg in args:
            if isinstance(arg, Node):
                accum.append(arg)
            elif isinstance(arg, Query):
                accum.append(arg)
            elif isinstance(arg, ModelAlias):
                accum.extend(arg.get_proxy_fields())
            elif isclass(arg) and issubclass(arg, Model):
                accum.extend(arg._meta.get_fields())
        return accum

    def compound_op(operator):
        """Build a method combining two queries with a set operator."""
        def inner(self, other):
            supported_ops = self.model_class._meta.database.compound_operations
            if operator not in supported_ops:
                raise ValueError(
                    'Your database does not support %s' % operator)
            return CompoundSelect(self.model_class, self, operator, other)
        return inner
    _compound_op_static = staticmethod(compound_op)
    __or__ = compound_op('UNION')
    __and__ = compound_op('INTERSECT')
    __sub__ = compound_op('EXCEPT')

    def __xor__(self, rhs):
        # Symmetric difference, should just be (self | rhs) - (self & rhs)...
        wrapped_rhs = self.model_class.select(SQL('*')).from_(
            EnclosedClause((self & rhs)).alias('_')).order_by()
        return (self | rhs) - wrapped_rhs

    def union_all(self, rhs):
        return SelectQuery._compound_op_static('UNION ALL')(self, rhs)

    def __select(self, *selection):
        self._explicit_selection = len(selection) > 0
        selection = selection or self.model_class._meta.get_fields()
        self._select = self._model_shorthand(selection)
    select = returns_clone(__select)

    # The builder methods below mutate ``self``; ``returns_clone`` makes them
    # chainable (``get``, ``exists``, ``_aggregate`` and ``wrapped_count``
    # use their return values).
    @returns_clone
    def from_(self, *args):
        self._from = None
        if args:
            self._from = list(args)

    @returns_clone
    def group_by(self, *args):
        self._group_by = self._model_shorthand(args)

    @returns_clone
    def having(self, *expressions):
        self._having = self._add_query_clauses(self._having, expressions)

    @returns_clone
    def order_by(self, *args):
        self._order_by = list(args)

    @returns_clone
    def window(self, *windows):
        self._windows = list(windows)

    @returns_clone
    def limit(self, lim):
        self._limit = lim

    @returns_clone
    def offset(self, off):
        self._offset = off

    @returns_clone
    def paginate(self, page, paginate_by=20):
        """Select page ``page`` (1-based) of ``paginate_by`` rows."""
        if page > 0:
            page -= 1
        self._limit = paginate_by
        self._offset = page * paginate_by

    @returns_clone
    def distinct(self, is_distinct=True):
        self._distinct = is_distinct

    @returns_clone
    def for_update(self, for_update=True, nowait=False):
        self._for_update = (for_update, nowait)

    @returns_clone
    def naive(self, naive=True):
        self._naive = naive

    @returns_clone
    def tuples(self, tuples=True):
        self._tuples = tuples

    @returns_clone
    def dicts(self, dicts=True):
        self._dicts = dicts

    @returns_clone
    def aggregate_rows(self, aggregate_rows=True):
        self._aggregate_rows = aggregate_rows

    @returns_clone
    def alias(self, alias=None):
        self._alias = alias

    def annotate(self, rel_model, annotation=None):
        """Add an aggregate over ``rel_model`` (default: ``Count`` of its pk).

        Joins ``rel_model`` if needed and, when no GROUP BY is set, groups
        by the currently selected columns.
        """
        if annotation is None:
            annotation = fn.Count(rel_model._meta.primary_key).alias('count')
        if self._query_ctx == rel_model:
            query = self.switch(self.model_class)
        else:
            query = self.clone()
        query = query.ensure_join(query._query_ctx, rel_model)
        if not query._group_by:
            query._group_by = [x.alias() for x in query._select]
        query._select = tuple(query._select) + (annotation,)
        return query

    def _aggregate(self, aggregation=None):
        """Return an unordered clone selecting only ``aggregation``."""
        if aggregation is None:
            aggregation = fn.Count(SQL('*'))
        query = self.order_by()
        query._select = [aggregation]
        return query

    def aggregate(self, aggregation=None, convert=True):
        return self._aggregate(aggregation).scalar(convert=convert)

    def count(self, clear_limit=False):
        """Return the number of rows the query matches."""
        if self._distinct or self._group_by or self._limit or self._offset:
            return self.wrapped_count(clear_limit=clear_limit)

        # defaults to a count() of the primary key
        return self.aggregate(convert=False) or 0

    def wrapped_count(self, clear_limit=False):
        """Count by wrapping the query in ``SELECT COUNT(1) FROM (...)``."""
        clone = self.order_by()
        if clear_limit:
            clone._limit = clone._offset = None

        sql, params = clone.sql()
        wrapped = 'SELECT COUNT(1) FROM (%s) AS wrapped_select' % sql
        rq = self.model_class.raw(wrapped, *params)
        return rq.scalar() or 0

    def exists(self):
        clone = self.paginate(1, 1)
        clone._select = [SQL('1')]
        return bool(clone.scalar())

    def get(self):
        """Return the first matching instance or raise ``DoesNotExist``."""
        clone = self.paginate(1, 1)
        try:
            return clone.execute().next()
        except StopIteration:
            raise self.model_class.DoesNotExist(
                'Instance matching query does not exist:\nSQL: %s\nPARAMS: %s'
                % self.sql())

    def first(self):
        """Return the first result, or ``None`` when there are no rows."""
        res = self.execute()
        res.fill_cache(1)
        try:
            return res._result_cache[0]
        except IndexError:
            return None

    def sql(self):
        return self.compiler().generate_select(self)

    def verify_naive(self):
        """Return True when every selected field belongs to this model."""
        model_class = self.model_class
        for node in self._select:
            if isinstance(node, Field) and node.model_class != model_class:
                return False
        return True

    def get_query_meta(self):
        return (self._select, self._joins)

    def execute(self):
        """Run the query if needed and return the cached result wrapper."""
        if self._dirty or not self._qr:
            model_class = self.model_class
            query_meta = self.get_query_meta()
            if self._tuples:
                ResultWrapper = TuplesQueryResultWrapper
            elif self._dicts:
                ResultWrapper = DictQueryResultWrapper
            elif self._naive or not self._joins or self.verify_naive():
                ResultWrapper = NaiveQueryResultWrapper
            elif self._aggregate_rows:
                ResultWrapper = AggregateQueryResultWrapper
            else:
                ResultWrapper = ModelQueryResultWrapper
            self._qr = ResultWrapper(model_class, self._execute(), query_meta)
            self._dirty = False
        return self._qr

    def __iter__(self):
        return iter(self.execute())

    def iterator(self):
        return iter(self.execute().iterator())

    def __getitem__(self, value):
        """Index or slice the results, fetching only as many rows as needed."""
        res = self.execute()
        if isinstance(value, slice):
            index = value.stop
        else:
            index = value
        # ``fill_cache`` takes a row count, hence the +1 for an index.
        if index is not None and index >= 0:
            index += 1
        res.fill_cache(index)
        return res._result_cache[value]

    if PY3:
        def __hash__(self):
            return id(self)
class CompoundSelect(SelectQuery):
    """Two SELECT queries combined by a set operator (``lhs <op> rhs``)."""
    _node_type = 'compound_select_query'

    def __init__(self, model_class, lhs=None, operator=None, rhs=None):
        # The operands are stored before the parent initializer runs.
        self.lhs, self.operator, self.rhs = lhs, operator, rhs
        super(CompoundSelect, self).__init__(model_class, [])

    def _clone_attributes(self, query):
        """Copy the parent's state plus both operands and the operator."""
        cloned = super(CompoundSelect, self)._clone_attributes(query)
        cloned.lhs = self.lhs
        cloned.operator = self.operator
        cloned.rhs = self.rhs
        return cloned

    def get_query_meta(self):
        """Result metadata is taken from the left-hand query."""
        return self.lhs.get_query_meta()
class UpdateQuery(Query):
    """UPDATE query; ``update`` maps fields to their new values."""

    def __init__(self, model_class, update=None):
        self._update = update
        super(UpdateQuery, self).__init__(model_class)

    def _clone_attributes(self, query):
        """Copy the parent's state and a shallow copy of the update dict."""
        cloned = super(UpdateQuery, self)._clone_attributes(query)
        cloned._update = dict(self._update)
        return cloned

    join = not_allowed('joining')

    def sql(self):
        compiler = self.compiler()
        return compiler.generate_update(self)

    def execute(self):
        """Run the update and return the number of affected rows."""
        cursor = self._execute()
        return self.database.rows_affected(cursor)
class InsertQuery(Query):
    """INSERT query for a single row, many rows, or ``INSERT ... SELECT``."""

    def __init__(self, model_class, field_dict=None, rows=None,
                 fields=None, query=None):
        super(InsertQuery, self).__init__(model_class)

        self._upsert = False
        self._is_multi_row_insert = rows is not None or query is not None
        if rows is not None:
            self._rows = rows
        else:
            # A single-row insert is stored as a one-element list.
            self._rows = [field_dict or {}]

        self._fields = fields
        self._query = query

    def _iter_rows(self):
        """Yield one ``{field: value}`` dict per row, with defaults applied.

        Row keys may be field names or field objects. Static defaults seed
        each row; callable defaults are evaluated per row for fields the
        row did not supply. Raises KeyError for an unrecognized key.
        """
        model_meta = self.model_class._meta
        valid_fields = (set(model_meta.fields.keys()) |
                        set(model_meta.fields.values()))
        def validate_field(field):
            if field not in valid_fields:
                raise KeyError('"%s" is not a recognized field.' % field)

        defaults = model_meta._default_dict
        callables = model_meta._default_callables

        for row_dict in self._rows:
            field_row = defaults.copy()
            seen = set()
            for key in row_dict:
                validate_field(key)
                if key in model_meta.fields:
                    field = model_meta.fields[key]
                else:
                    field = key
                field_row[field] = row_dict[key]
                seen.add(field)
            if callables:
                for field in callables:
                    if field not in seen:
                        field_row[field] = callables[field]()
            yield field_row

    def _clone_attributes(self, query):
        query = super(InsertQuery, self)._clone_attributes(query)
        query._rows = self._rows
        query._upsert = self._upsert
        query._is_multi_row_insert = self._is_multi_row_insert
        query._fields = self._fields
        query._query = self._query
        return query

    join = not_allowed('joining')
    where = not_allowed('where clause')

    # ``execute`` chains ``.upsert(...).execute()``, so this must return the
    # query; ``returns_clone`` provides that.
    @returns_clone
    def upsert(self, upsert=True):
        self._upsert = upsert

    def sql(self):
        return self.compiler().generate_insert(self)

    def execute(self):
        """Run the insert and return the last inserted id.

        When the database cannot insert many rows in one statement, each
        row is inserted separately and the id of the last one is returned.
        """
        if self._is_multi_row_insert and self._query is None:
            if not self.database.insert_many:
                last_id = None
                for row in self._rows:
                    last_id = (InsertQuery(self.model_class, row)
                               .upsert(self._upsert)
                               .execute())
                return last_id
        return self.database.last_insert_id(self._execute(), self.model_class)
class DeleteQuery(Query):
    """DELETE query; joins are not allowed."""
    join = not_allowed('joining')

    def sql(self):
        compiler = self.compiler()
        return compiler.generate_delete(self)

    def execute(self):
        """Run the delete and return the number of affected rows."""
        cursor = self._execute()
        return self.database.rows_affected(cursor)
# Lightweight records describing database schema objects. ``namedtuple``
# requires the type name as its first argument; it was missing from each call.
IndexMetadata = namedtuple(
    'IndexMetadata',
    ('name', 'sql', 'columns', 'unique', 'table'))
ColumnMetadata = namedtuple(
    'ColumnMetadata',
    ('name', 'data_type', 'null', 'primary_key', 'table'))
ForeignKeyMetadata = namedtuple(
    'ForeignKeyMetadata',
    ('column', 'dest_table', 'dest_column', 'table'))
class PeeweeException(Exception):
    """Root of the exception hierarchy defined in this module."""


class ImproperlyConfigured(PeeweeException):
    """Direct subclass of PeeweeException."""


class DatabaseError(PeeweeException):
    """Base class for the database errors defined below."""


class DataError(DatabaseError):
    """DatabaseError subclass."""


class IntegrityError(DatabaseError):
    """DatabaseError subclass."""


class InterfaceError(PeeweeException):
    """Direct subclass of PeeweeException (not of DatabaseError)."""


class InternalError(DatabaseError):
    """DatabaseError subclass."""


class NotSupportedError(DatabaseError):
    """DatabaseError subclass."""


class OperationalError(DatabaseError):
    """DatabaseError subclass."""


class ProgrammingError(DatabaseError):
    """DatabaseError subclass."""
class ExceptionWrapper(object):
    """Context manager that re-raises mapped exceptions as other types.

    ``exceptions`` maps an exception class *name* to the replacement class.
    Exceptions whose class name is not in the mapping propagate unchanged.
    """
    __slots__ = ['exceptions']

    def __init__(self, exceptions):
        self.exceptions = exceptions

    def __enter__(self): pass

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            # No exception was raised inside the block.
            return
        if exc_type.__name__ in self.exceptions:
            new_type = self.exceptions[exc_type.__name__]
            # Re-raise as the mapped type, keeping the original arguments
            # and traceback.
            reraise(new_type, new_type(*exc_value.args), traceback)
class _BaseConnectionLocal(object):
    """Holder for per-connection state; starts closed with no connection."""

    def __init__(self, **kwargs):
        super(_BaseConnectionLocal, self).__init__(**kwargs)
        # Connection handle and its status flags.
        self.conn = None
        self.closed = True
        self.autocommit = None
        # Fresh, independent stacks for every instance.
        self.context_stack = list()
        self.transactions = list()
class _ConnectionLocal(_BaseConnectionLocal, threading.local):
    """``_BaseConnectionLocal`` whose attributes are thread-local."""
    pass
class Database(object):
    """Base class for database backends.

    Holds the connection state (thread-local when ``threadlocals`` is true),
    executes SQL through the driver cursor, and exposes transaction,
    savepoint and schema helpers.  Subclasses supply ``_connect`` and the
    introspection methods, all of which raise ``NotImplementedError`` here.
    """
    # Capability flags and SQL dialect settings; subclasses override these.
    commit_select = False
    compiler_class = QueryCompiler
    compound_operations = ['UNION', 'INTERSECT', 'EXCEPT', 'UNION ALL']
    distinct_on = False
    drop_cascade = False
    field_overrides = {}
    foreign_keys = True
    for_update = False
    for_update_nowait = False
    insert_many = True
    interpolation = '?'
    limit_max = None
    op_overrides = {}
    quote_char = '"'
    reserved_tables = []
    savepoints = True
    sequences = False
    subquery_delete_same_table = True
    window_functions = False

    # Driver exception class name -> exception class raised in its place
    # (consumed by ExceptionWrapper).
    exceptions = {
        'ConstraintError': IntegrityError,
        'DatabaseError': DatabaseError,
        'DataError': DataError,
        'IntegrityError': IntegrityError,
        'InterfaceError': InterfaceError,
        'InternalError': InternalError,
        'NotSupportedError': NotSupportedError,
        'OperationalError': OperationalError,
        'ProgrammingError': ProgrammingError}

    def __init__(self, database, threadlocals=True, autocommit=True,
                 fields=None, ops=None, autorollback=False, **connect_kwargs):
        """Store configuration; no connection is opened here.

        :param database: database name, or ``None`` to defer initialization.
        :param threadlocals: keep connection state per thread when true.
        :param autocommit: default autocommit setting.
        :param fields: extra field-type overrides merged over the class's.
        :param ops: extra operator overrides merged over the class's.
        :param autorollback: roll back after a failed statement when
            autocommit is on.
        :param connect_kwargs: passed to ``_connect`` on every connect.
        """
        self.init(database, **connect_kwargs)

        if threadlocals:
            self.__local = _ConnectionLocal()
        else:
            self.__local = _BaseConnectionLocal()

        self._conn_lock = threading.Lock()
        self.autocommit = autocommit
        self.autorollback = autorollback

        self.field_overrides = merge_dict(self.field_overrides, fields or {})
        self.op_overrides = merge_dict(self.op_overrides, ops or {})

    def init(self, database, **connect_kwargs):
        """(Re)configure the database name and connection arguments."""
        self.deferred = database is None
        self.database = database
        self.connect_kwargs = connect_kwargs

    def exception_wrapper(self):
        """Return a context manager translating driver exceptions."""
        return ExceptionWrapper(self.exceptions)

    def connect(self):
        """Open a connection and store it in the connection state.

        :raises Exception: if the database is still deferred.
        """
        with self._conn_lock:
            if self.deferred:
                raise Exception('Error, database not properly initialized '
                                'before opening connection')
            with self.exception_wrapper():
                self.__local.conn = self._connect(
                    self.database,
                    **self.connect_kwargs)
                self.__local.closed = False

    def close(self):
        """Close the stored connection and mark the state closed.

        :raises Exception: if the database is still deferred.
        """
        with self._conn_lock:
            if self.deferred:
                raise Exception('Error, database not properly initialized '
                                'before closing connection')
            with self.exception_wrapper():
                self._close(self.__local.conn)
                self.__local.closed = True

    def get_conn(self):
        """Return the active connection, connecting first if closed.

        The connection of the innermost execution context, when one is
        active, takes precedence over the stored connection.
        """
        if self.__local.context_stack:
            return self.__local.context_stack[-1].connection
        if self.__local.closed:
            self.connect()
        return self.__local.conn

    def is_closed(self):
        return self.__local.closed

    def get_cursor(self):
        return self.get_conn().cursor()

    def _close(self, conn):
        conn.close()

    def _connect(self, database, **kwargs):
        """Create and return a driver connection (backend-specific)."""
        raise NotImplementedError

    @classmethod
    def register_fields(cls, fields):
        """Merge ``fields`` into the class-level field overrides."""
        cls.field_overrides = merge_dict(cls.field_overrides, fields)

    @classmethod
    def register_ops(cls, ops):
        """Merge ``ops`` into the class-level operator overrides."""
        cls.op_overrides = merge_dict(cls.op_overrides, ops)

    def last_insert_id(self, cursor, model):
        """Return ``cursor.lastrowid`` for auto-increment models, else None."""
        if model._meta.auto_increment:
            return cursor.lastrowid

    def rows_affected(self, cursor):
        return cursor.rowcount

    def sql_error_handler(self, exception, sql, params, require_commit):
        """Hook called on a failed statement; true means re-raise."""
        return True

    def compiler(self):
        """Instantiate the query compiler with this database's settings."""
        return self.compiler_class(
            self.quote_char, self.interpolation, self.field_overrides,
            self.op_overrides)

    def execute_sql(self, sql, params=None, require_commit=True):
        """Execute ``sql`` with ``params`` and return the cursor.

        On failure the statement is rolled back when both autocommit and
        ``autorollback`` are on, and the error is re-raised unless
        ``sql_error_handler`` returns a false value.  On success the
        statement is committed when ``require_commit`` and autocommit are
        both on.
        """
        logger.debug((sql, params))
        with self.exception_wrapper():
            cursor = self.get_cursor()
            try:
                cursor.execute(sql, params or ())
            except Exception as exc:
                if self.get_autocommit() and self.autorollback:
                    self.rollback()
                if self.sql_error_handler(exc, sql, params, require_commit):
                    raise
            else:
                if require_commit and self.get_autocommit():
                    self.commit()
        return cursor

    def begin(self):
        """Start a transaction; a no-op unless a backend overrides it."""
        pass

    def commit(self):
        self.get_conn().commit()

    def rollback(self):
        self.get_conn().rollback()

    def set_autocommit(self, autocommit):
        self.__local.autocommit = autocommit

    def get_autocommit(self):
        """Return the current autocommit flag, seeding it from the default."""
        if self.__local.autocommit is None:
            self.set_autocommit(self.autocommit)
        return self.__local.autocommit

    def push_execution_context(self, transaction):
        self.__local.context_stack.append(transaction)

    def pop_execution_context(self):
        self.__local.context_stack.pop()

    def execution_context_depth(self):
        return len(self.__local.context_stack)

    def execution_context(self, with_transaction=True):
        return ExecutionContext(self, with_transaction=with_transaction)

    def push_transaction(self, transaction):
        self.__local.transactions.append(transaction)

    def pop_transaction(self):
        self.__local.transactions.pop()

    def transaction_depth(self):
        return len(self.__local.transactions)

    def transaction(self):
        return transaction(self)

    def commit_on_success(self, func):
        """Decorator running ``func`` inside ``self.transaction()``."""
        @wraps(func)
        def inner(*args, **kwargs):
            with self.transaction():
                return func(*args, **kwargs)
        return inner

    def savepoint(self, sid=None):
        """Return a savepoint context manager.

        :raises NotImplementedError: if the backend disables savepoints.
        """
        if not self.savepoints:
            raise NotImplementedError
        return savepoint(self, sid)

    def atomic(self):
        return _atomic(self)

    # Introspection: backends must implement these.
    def get_tables(self, schema=None):
        raise NotImplementedError

    def get_indexes(self, table, schema=None):
        raise NotImplementedError

    def get_columns(self, table, schema=None):
        raise NotImplementedError

    def get_primary_keys(self, table, schema=None):
        raise NotImplementedError

    def get_foreign_keys(self, table, schema=None):
        raise NotImplementedError

    def sequence_exists(self, seq):
        raise NotImplementedError

    def create_table(self, model_class, safe=False):
        qc = self.compiler()
        return self.execute_sql(*qc.create_table(model_class, safe))

    def create_tables(self, models, safe=False):
        create_model_tables(models, fail_silently=safe)

    def create_index(self, model_class, fields, unique=False):
        """Create an index over ``fields`` (field names or field objects).

        :raises ValueError: if ``fields`` is not a list or tuple.
        """
        qc = self.compiler()
        if not isinstance(fields, (list, tuple)):
            raise ValueError('Fields passed to "create_index" must be a list '
                             'or tuple: "%s"' % fields)
        # Names are resolved to field objects; objects pass through as-is.
        fobjs = [
            model_class._meta.fields[f] if isinstance(f, basestring) else f
            for f in fields]
        return self.execute_sql(*qc.create_index(model_class, fobjs, unique))

    def create_foreign_key(self, model_class, field, constraint=None):
        qc = self.compiler()
        return self.execute_sql(*qc.create_foreign_key(
            model_class, field, constraint))

    def create_sequence(self, seq):
        """Create ``seq``; does nothing when sequences are unsupported."""
        if self.sequences:
            qc = self.compiler()
            return self.execute_sql(*qc.create_sequence(seq))

    def drop_table(self, model_class, fail_silently=False, cascade=False):
        qc = self.compiler()
        return self.execute_sql(*qc.drop_table(
            model_class, fail_silently, cascade))

    def drop_tables(self, models, safe=False, cascade=False):
        drop_model_tables(models, fail_silently=safe, cascade=cascade)

    def drop_sequence(self, seq):
        """Drop ``seq``; does nothing when sequences are unsupported."""
        if self.sequences:
            qc = self.compiler()
            return self.execute_sql(*qc.drop_sequence(seq))

    def extract_date(self, date_part, date_field):
        return fn.EXTRACT(Clause(date_part, R('FROM'), date_field))

    def truncate_date(self, date_part, date_field):
        return fn.DATE_TRUNC(SQL(date_part), date_field)
class SqliteDatabase(Database):
    """SQLite backend built on the ``sqlite3`` module."""
    foreign_keys = False
    insert_many = sqlite3 and sqlite3.sqlite_version_info >= (3, 7, 11, 0)
    limit_max = -1
    # NOTE(review): the entries of this mapping were lost from the reviewed
    # text and are restored from upstream peewee -- confirm.
    op_overrides = {
        OP.LIKE: 'GLOB',
        OP.ILIKE: 'LIKE',
    }

    def __init__(self, *args, **kwargs):
        """Accepts the ``Database`` arguments plus ``journal_mode``.

        A falsy database name falls back to ``':memory:'``.
        """
        self._journal_mode = kwargs.pop('journal_mode', None)
        super(SqliteDatabase, self).__init__(*args, **kwargs)
        if not self.database:
            self.database = ':memory:'

    def _connect(self, database, **kwargs):
        conn = sqlite3.connect(database, **kwargs)
        # Transactions are managed explicitly (see begin/savepoint_sqlite).
        conn.isolation_level = None
        self._add_conn_hooks(conn)
        return conn

    def _add_conn_hooks(self, conn):
        """Register the SQL helper functions on a new connection."""
        conn.create_function('date_part', 2, _sqlite_date_part)
        conn.create_function('date_trunc', 2, _sqlite_date_trunc)
        conn.create_function('regexp', 2, _sqlite_regexp)
        if self._journal_mode:
            # Run the pragma on ``conn`` itself: this hook executes while
            # the connection is still being set up, so going through
            # execute_sql()/get_conn() would try to connect again.
            conn.execute('PRAGMA journal_mode=%s;' % self._journal_mode)

    def begin(self, lock_type='DEFERRED'):
        self.execute_sql('BEGIN %s' % lock_type, require_commit=False)

    def get_tables(self, schema=None):
        cursor = self.execute_sql('SELECT name FROM sqlite_master WHERE '
                                  'type = ? ORDER BY name;', ('table',))
        return [row[0] for row in cursor.fetchall()]

    def get_indexes(self, table, schema=None):
        """Return an ``IndexMetadata`` per index on ``table``, by name."""
        query = ('SELECT name, sql FROM sqlite_master '
                 'WHERE tbl_name = ? AND type = ? ORDER BY name')
        cursor = self.execute_sql(query, (table, 'index'))
        index_to_sql = dict(cursor.fetchall())

        # Determine which indexes have a unique constraint.  Rows are
        # indexed rather than unpacked because the number of columns
        # returned by index_list differs between SQLite versions.
        unique_indexes = set()
        cursor = self.execute_sql('PRAGMA index_list("%s")' % table)
        for row in cursor.fetchall():
            if row[2]:
                unique_indexes.add(row[1])

        # Retrieve the indexed columns.
        index_columns = {}
        for index_name in sorted(index_to_sql):
            cursor = self.execute_sql('PRAGMA index_info("%s")' % index_name)
            index_columns[index_name] = [row[2] for row in cursor.fetchall()]

        return [
            IndexMetadata(
                name,
                index_to_sql[name],
                index_columns[name],
                name in unique_indexes,
                table)
            for name in sorted(index_to_sql)]

    def get_columns(self, table, schema=None):
        cursor = self.execute_sql('PRAGMA table_info("%s")' % table)
        return [ColumnMetadata(row[1], row[2], not row[3], bool(row[5]), table)
                for row in cursor.fetchall()]

    def get_primary_keys(self, table, schema=None):
        cursor = self.execute_sql('PRAGMA table_info("%s")' % table)
        return [row[1] for row in cursor.fetchall() if row[-1]]

    def get_foreign_keys(self, table, schema=None):
        cursor = self.execute_sql('PRAGMA foreign_key_list("%s")' % table)
        return [ForeignKeyMetadata(row[3], row[2], row[4], table)
                for row in cursor.fetchall()]

    def savepoint(self, sid=None):
        return savepoint_sqlite(self, sid)

    def extract_date(self, date_part, date_field):
        return fn.date_part(date_part, date_field)

    def truncate_date(self, date_part, date_field):
        return fn.strftime(SQLITE_DATE_TRUNC_MAPPING[date_part], date_field)
class PostgresqlDatabase(Database):
    """PostgreSQL backend built on ``psycopg2``."""
    commit_select = True
    distinct_on = True
    drop_cascade = True
    field_overrides = {
        'blob': 'BYTEA',
        'bool': 'BOOLEAN',
        'datetime': 'TIMESTAMP',
        'decimal': 'NUMERIC',
        'primary_key': 'SERIAL',
        'uuid': 'UUID',
    }
    for_update = True
    for_update_nowait = True
    interpolation = '%s'
    # NOTE(review): the entries of this mapping were lost from the reviewed
    # text and are restored from upstream peewee -- confirm.
    op_overrides = {
        OP.REGEXP: '~',
    }
    reserved_tables = ['user']
    sequences = True
    window_functions = True

    # When true, new connections get psycopg2's unicode type adapters.
    register_unicode = True

    def _connect(self, database, **kwargs):
        """Open a psycopg2 connection.

        :raises ImproperlyConfigured: if psycopg2 is not available.
        """
        if not psycopg2:
            raise ImproperlyConfigured('psycopg2 must be installed.')
        conn = psycopg2.connect(database=database, **kwargs)
        if self.register_unicode:
            pg_extensions.register_type(pg_extensions.UNICODE, conn)
            pg_extensions.register_type(pg_extensions.UNICODEARRAY, conn)
        return conn

    def last_insert_id(self, cursor, model):
        """Return the current value of the model's primary-key sequence.

        The sequence is the one named on the primary key, or
        ``<table>_<column>_seq`` for auto-increment models.  Returns
        ``None`` when neither applies.
        """
        meta = model._meta
        schema = ''
        if meta.schema:
            schema = '%s.' % meta.schema

        if meta.primary_key.sequence:
            seq = meta.primary_key.sequence
        elif meta.auto_increment:
            seq = '%s_%s_seq' % (meta.db_table, meta.primary_key.db_column)
        else:
            seq = None

        if seq:
            cursor.execute("SELECT CURRVAL('%s\"%s\"')" % (schema, seq))
            result = cursor.fetchone()[0]
            if self.get_autocommit():
                self.commit()
            return result

    def get_tables(self, schema='public'):
        query = ('SELECT tablename FROM pg_catalog.pg_tables '
                 'WHERE schemaname = %s ORDER BY tablename')
        return [r for r, in self.execute_sql(query, (schema,)).fetchall()]

    def get_indexes(self, table, schema='public'):
        query = """
            SELECT
                i.relname, idxs.indexdef, idx.indisunique,
                array_to_string(array_agg(cols.attname), ',')
            FROM pg_catalog.pg_class AS t
            INNER JOIN pg_catalog.pg_index AS idx ON t.oid = idx.indrelid
            INNER JOIN pg_catalog.pg_class AS i ON idx.indexrelid = i.oid
            INNER JOIN pg_catalog.pg_indexes AS idxs ON
                (idxs.tablename = t.relname AND idxs.indexname = i.relname)
            LEFT OUTER JOIN pg_catalog.pg_attribute AS cols ON
                (cols.attrelid = t.oid AND cols.attnum = ANY(idx.indkey))
            WHERE t.relname = %s AND t.relkind = %s AND idxs.schemaname = %s
            GROUP BY i.relname, idxs.indexdef, idx.indisunique
            ORDER BY idx.indisunique DESC, i.relname;"""
        cursor = self.execute_sql(query, (table, 'r', schema))
        return [IndexMetadata(row[0], row[1], row[3].split(','), row[2], table)
                for row in cursor.fetchall()]

    def get_columns(self, table, schema='public'):
        query = """
            SELECT column_name, is_nullable, data_type
            FROM information_schema.columns
            WHERE table_name = %s AND table_schema = %s"""
        cursor = self.execute_sql(query, (table, schema))
        pks = set(self.get_primary_keys(table, schema))
        return [ColumnMetadata(name, dt, null == 'YES', name in pks, table)
                for name, null, dt in cursor.fetchall()]

    def get_primary_keys(self, table, schema='public'):
        query = """
            SELECT kc.column_name
            FROM information_schema.table_constraints AS tc
            INNER JOIN information_schema.key_column_usage AS kc ON (
                tc.table_name = kc.table_name AND
                tc.table_schema = kc.table_schema AND
                tc.constraint_name = kc.constraint_name)
            WHERE
                tc.constraint_type = %s AND
                tc.table_name = %s AND
                tc.table_schema = %s"""
        cursor = self.execute_sql(query, ('PRIMARY KEY', table, schema))
        return [row for row, in cursor.fetchall()]

    def get_foreign_keys(self, table, schema='public'):
        sql = """
            SELECT
                kcu.column_name, ccu.table_name, ccu.column_name
            FROM information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
                ON (tc.constraint_name = kcu.constraint_name AND
                    tc.constraint_schema = kcu.constraint_schema)
            JOIN information_schema.constraint_column_usage AS ccu
                ON (ccu.constraint_name = tc.constraint_name AND
                    ccu.constraint_schema = tc.constraint_schema)
            WHERE
                tc.constraint_type = 'FOREIGN KEY' AND
                tc.table_name = %s AND
                tc.table_schema = %s"""
        cursor = self.execute_sql(sql, (table, schema))
        return [ForeignKeyMetadata(row[0], row[1], row[2], table)
                for row in cursor.fetchall()]

    def sequence_exists(self, sequence):
        res = self.execute_sql("""
            SELECT COUNT(*) FROM pg_class, pg_namespace
            WHERE relkind='S'
                AND pg_class.relnamespace = pg_namespace.oid
                AND relname=%s""", (sequence,))
        return bool(res.fetchone()[0])

    def set_search_path(self, *search_path):
        """Issue ``SET search_path`` with one placeholder per schema name."""
        path_params = ','.join(['%s'] * len(search_path))
        self.execute_sql('SET search_path TO %s' % path_params, search_path)
class MySQLDatabase(Database):
    """MySQL backend built on MySQLdb or PyMySQL."""
    commit_select = True
    compound_operations = ['UNION', 'UNION ALL']
    field_overrides = {
        'bool': 'BOOL',
        'decimal': 'NUMERIC',
        'float': 'FLOAT',
        'primary_key': 'INTEGER AUTO_INCREMENT',
        'text': 'LONGTEXT',
    }
    for_update = True
    interpolation = '%s'
    limit_max = 2 ** 64 - 1  # MySQL quirk
    # NOTE(review): the entries of this mapping were lost from the reviewed
    # text and are restored from upstream peewee -- confirm.
    op_overrides = {
        OP.LIKE: 'LIKE BINARY',
        OP.ILIKE: 'LIKE',
        OP.XOR: 'XOR',
    }
    quote_char = '`'
    subquery_delete_same_table = False

    def _connect(self, database, **kwargs):
        """Open a driver connection (utf8, unicode on by default).

        Caller-supplied keyword arguments override the defaults, and a
        ``password`` argument is passed to the driver as ``passwd``.

        :raises ImproperlyConfigured: if no MySQL driver is available.
        """
        if not mysql:
            raise ImproperlyConfigured('MySQLdb or PyMySQL must be installed.')
        conn_kwargs = {
            'charset': 'utf8',
            'use_unicode': True,
        }
        conn_kwargs.update(kwargs)
        if 'password' in conn_kwargs:
            conn_kwargs['passwd'] = conn_kwargs.pop('password')
        return mysql.connect(db=database, **conn_kwargs)

    def get_tables(self, schema=None):
        return [row for row, in self.execute_sql('SHOW TABLES')]

    def get_indexes(self, table, schema=None):
        """Return an ``IndexMetadata`` per index reported by SHOW INDEX."""
        cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table)
        unique = set()
        indexes = {}
        for row in cursor.fetchall():
            # row[1] is falsy for unique indexes, row[2] is the index name
            # and row[4] the column name (same layout get_primary_keys uses).
            if not row[1]:
                unique.add(row[2])
            indexes.setdefault(row[2], [])
            indexes[row[2]].append(row[4])
        return [IndexMetadata(name, None, indexes[name], name in unique, table)
                for name in indexes]

    def get_columns(self, table, schema=None):
        sql = """
            SELECT column_name, is_nullable, data_type
            FROM information_schema.columns
            WHERE table_name = %s AND table_schema = DATABASE()"""
        cursor = self.execute_sql(sql, (table,))
        pks = set(self.get_primary_keys(table))
        return [ColumnMetadata(name, dt, null == 'YES', name in pks, table)
                for name, null, dt in cursor.fetchall()]

    def get_primary_keys(self, table, schema=None):
        cursor = self.execute_sql('SHOW INDEX FROM `%s`' % table)
        return [row[4] for row in cursor.fetchall() if row[2] == 'PRIMARY']

    def get_foreign_keys(self, table, schema=None):
        query = """
            SELECT column_name, referenced_table_name, referenced_column_name
            FROM information_schema.key_column_usage
            WHERE table_name = %s
                AND table_schema = DATABASE()
                AND referenced_table_name IS NOT NULL
                AND referenced_column_name IS NOT NULL"""
        cursor = self.execute_sql(query, (table,))
        return [
            ForeignKeyMetadata(column, dest_table, dest_column, table)
            for column, dest_table, dest_column in cursor.fetchall()]

    def extract_date(self, date_part, date_field):
        return fn.EXTRACT(Clause(R(date_part), R('FROM'), date_field))

    def truncate_date(self, date_part, date_field):
        return fn.DATE_FORMAT(date_field, MYSQL_DATE_TRUNC_MAPPING[date_part])
class _callable_context_manager(object):
def __call__(self, fn):
def inner(*args, **kwargs):
with self:
return fn(*args, **kwargs)
return inner
class ExecutionContext(_callable_context_manager):
    """Run a block on a dedicated connection, optionally in a transaction.

    While active, the context sits on the database's execution-context
    stack so that ``Database.get_conn`` returns ``self.connection``.
    """

    def __init__(self, database, with_transaction=True):
        self.database = database
        self.with_transaction = with_transaction

    def __enter__(self):
        with self.database._conn_lock:
            self.database.push_execution_context(self)
            self.connection = self.database._connect(
                self.database.database,
                **self.database.connect_kwargs)
            if self.with_transaction:
                self.txn = self.database.transaction()
                self.txn.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        with self.database._conn_lock:
            try:
                if self.with_transaction:
                    if not exc_type:
                        self.txn.commit(False)
                    self.txn.__exit__(exc_type, exc_val, exc_tb)
            finally:
                # Always unwind the stack and release the connection, even
                # if committing or rolling back failed.
                self.database.pop_execution_context()
                self.database._close(self.connection)
class Using(ExecutionContext):
    """Execution context that temporarily rebinds models to ``database``.

    On entry each model's ``_meta.database`` is pointed at the given
    database; on exit the previous databases are put back.
    """

    def __init__(self, database, models, with_transaction=True):
        super(Using, self).__init__(database, with_transaction)
        self.models = models

    def __enter__(self):
        # Remember each model's current database, in model order, so that
        # __exit__ can restore it.
        self._orig = []
        for model in self.models:
            self._orig.append(model._meta.database)
            model._meta.database = self.database
        return super(Using, self).__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            super(Using, self).__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Restore the bindings even when closing the context fails.
            for model, original in zip(self.models, self._orig):
                model._meta.database = original
class _atomic(_callable_context_manager):
    """Transaction at the outermost level, savepoint when nested.

    The choice is made on entry from ``db.transaction_depth()``; entering
    and exiting are delegated to the chosen helper.
    """

    def __init__(self, db):
        self.db = db

    def __enter__(self):
        if self.db.transaction_depth() == 0:
            self._helper = self.db.transaction()
        else:
            self._helper = self.db.savepoint()
        return self._helper.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        return self._helper.__exit__(exc_type, exc_val, exc_tb)
class transaction(_callable_context_manager):
    """Context manager / decorator wrapping a block in a transaction.

    Autocommit is switched off for the duration and restored on exit.
    Only the outermost transaction issues BEGIN and performs the final
    commit; an exception at any depth triggers a rollback.
    """

    def __init__(self, db):
        self.db = db

    def _begin(self):
        self.db.begin()

    def commit(self, begin=True):
        """Commit, then start a new transaction unless ``begin`` is false."""
        self.db.commit()
        if begin:
            self._begin()

    def rollback(self, begin=True):
        """Roll back, then start a new transaction unless ``begin`` is false."""
        self.db.rollback()
        if begin:
            self._begin()

    def __enter__(self):
        self._orig = self.db.get_autocommit()
        self.db.set_autocommit(False)
        if self.db.transaction_depth() == 0:
            self._begin()
        self.db.push_transaction(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type:
                self.rollback(False)
            elif self.db.transaction_depth() == 1:
                try:
                    self.commit(False)
                except:  # noqa: E722 -- roll back on anything, then re-raise
                    self.rollback(False)
                    raise
        finally:
            self.db.set_autocommit(self._orig)
            self.db.pop_transaction()
class savepoint(_callable_context_manager):
    """Context manager / decorator wrapping a block in a SAVEPOINT.

    The savepoint is released on success and rolled back to on error.
    Autocommit is switched off for the duration and restored on exit.
    """

    def __init__(self, db, sid=None):
        """
        :param db: database the savepoint runs against.
        :param sid: savepoint name; a random ``s<hex>`` name by default.
        """
        self.db = db
        _compiler = db.compiler()
        self.sid = sid or 's' + uuid.uuid4().hex
        self.quoted_sid = _compiler.quote(self.sid)

    def _execute(self, query):
        self.db.execute_sql(query, require_commit=False)

    def commit(self):
        self._execute('RELEASE SAVEPOINT %s;' % self.quoted_sid)

    def rollback(self):
        self._execute('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid)

    def __enter__(self):
        self._orig_autocommit = self.db.get_autocommit()
        self.db.set_autocommit(False)
        self._execute('SAVEPOINT %s;' % self.quoted_sid)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type:
                self.rollback()
            else:
                try:
                    self.commit()
                except:  # noqa: E722 -- roll back on anything, then re-raise
                    self.rollback()
                    raise
        finally:
            self.db.set_autocommit(self._orig_autocommit)
class savepoint_sqlite(savepoint):
    """Savepoint that also manages the sqlite connection's isolation level."""

    def __enter__(self):
        conn = self.db.get_conn()
        # For sqlite, the connection's isolation_level *must* be set to None.
        # The act of setting it, though, will break any existing savepoints,
        # so only write to it if necessary.
        if conn.isolation_level is not None:
            self._orig_isolation_level = conn.isolation_level
            conn.isolation_level = None
        else:
            self._orig_isolation_level = None
        return super(savepoint_sqlite, self).__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            return super(savepoint_sqlite, self).__exit__(
                exc_type, exc_val, exc_tb)
        finally:
            # Put back the isolation level only if __enter__ changed it.
            if self._orig_isolation_level is not None:
                self.db.get_conn().isolation_level = self._orig_isolation_level
class FieldProxy(Field):
    """Stand-in for a field accessed through a ``ModelAlias``.

    Value conversion and any attribute not found on the proxy are
    delegated to the wrapped field; ``model_class`` resolves to the alias.
    """

    def __init__(self, alias, field_instance):
        self._model_alias = alias
        self.model = self._model_alias.model_class
        self.field_instance = field_instance

    def clone_base(self):
        return FieldProxy(self._model_alias, self.field_instance)

    def coerce(self, value):
        return self.field_instance.coerce(value)

    def python_value(self, value):
        return self.field_instance.python_value(value)

    def db_value(self, value):
        return self.field_instance.db_value(value)

    def __getattr__(self, attr):
        # Only reached for attributes missing on the proxy itself.
        if attr == 'model_class':
            return self._model_alias
        return getattr(self.field_instance, attr)
class ModelAlias(object):
    """Read-only alias of a model class.

    Attribute reads are forwarded to the model; fields come back wrapped
    in ``FieldProxy`` objects bound to this alias.  Attribute writes are
    rejected.
    """

    def __init__(self, model_class):
        # Written through __dict__ because __setattr__ always raises.
        self.__dict__['model_class'] = model_class

    def __getattr__(self, attr):
        model_attr = getattr(self.model_class, attr)
        if isinstance(model_attr, Field):
            return FieldProxy(self, model_attr)
        return model_attr

    def __setattr__(self, attr, value):
        raise AttributeError('Cannot set attributes on ModelAlias instances')

    def get_proxy_fields(self):
        """Return a ``FieldProxy`` for every field of the aliased model."""
        return [
            FieldProxy(self, f) for f in self.model_class._meta.get_fields()]

    def select(self, *selection):
        """Build a ``SelectQuery`` on the alias with the default ordering."""
        query = SelectQuery(self, *selection)
        if self._meta.order_by:
            query = query.order_by(*self._meta.order_by)
        return query
class DoesNotExist(Exception):
    """Plain ``Exception`` subclass; deliberately outside PeeweeException."""
# Database used by models that do not specify one: a SQLite file when the
# sqlite3 driver is available, otherwise nothing.
if sqlite3:
    default_database = SqliteDatabase('peewee.db')
else:
    default_database = None
class ModelOptions(object):
    """Per-model metadata: fields, defaults, database and relationships."""

    def __init__(self, cls, database=None, db_table=None, indexes=None,
                 order_by=None, primary_key=None, table_alias=None,
                 constraints=None, schema=None, validate_backrefs=True,
                 **kwargs):
        """
        :param cls: the model class these options describe.
        :param database: database to use; falls back to ``default_database``.
        :param kwargs: extra options, set as attributes and remembered in
            ``_additional_keys``.
        """
        self.model_class = cls
        self.name = cls.__name__.lower()
        self.fields = {}
        self.columns = {}
        self.defaults = {}
        self._default_by_name = {}
        self._default_dict = {}
        self._default_callables = {}

        self.database = database or default_database
        self.db_table = db_table
        self.indexes = list(indexes or [])
        self.order_by = order_by
        self.primary_key = primary_key
        self.table_alias = table_alias
        self.constraints = constraints
        self.schema = schema
        self.validate_backrefs = validate_backrefs

        self.auto_increment = None
        self.rel = {}
        self.reverse_rel = {}

        for key, value in kwargs.items():
            setattr(self, key, value)
        self._additional_keys = set(kwargs.keys())

    def prepared(self):
        """Index field defaults and normalize ``order_by``.

        Callable defaults are kept apart from plain values so they can be
        evaluated on each ``get_default_dict`` call.  ``order_by`` entries
        (field objects or names, ``-`` prefix meaning descending) are
        turned into the fields' ``asc()`` / ``desc()`` results.
        """
        for field in self.fields.values():
            if field.default is not None:
                self.defaults[field] = field.default
                if callable(field.default):
                    self._default_callables[field] = field.default
                else:
                    self._default_dict[field] = field.default
                    self._default_by_name[field.name] = field.default

        if self.order_by:
            norm_order_by = []
            for item in self.order_by:
                if isinstance(item, Field):
                    prefix = '-' if item._ordering == 'DESC' else ''
                    item = prefix + item.name
                field = self.fields[item.lstrip('-')]
                if item.startswith('-'):
                    norm_order_by.append(field.desc())
                else:
                    norm_order_by.append(field.asc())
            self.order_by = norm_order_by

    def get_default_dict(self):
        """Return ``{field name: default}``, calling callable defaults."""
        dd = self._default_by_name.copy()
        if self._default_callables:
            for field, default in self._default_callables.items():
                dd[field.name] = default()
        return dd

    def get_sorted_fields(self):
        """Return ``(name, field)`` pairs ordered by ``field._sort_key``."""
        key = lambda i: i[1]._sort_key
        return sorted(self.fields.items(), key=key)

    def get_field_names(self):
        return [f[0] for f in self.get_sorted_fields()]

    def get_fields(self):
        return [f[1] for f in self.get_sorted_fields()]

    def get_field_index(self, field):
        """Return the sorted position of ``field`` by name, or -1."""
        for i, (field_name, field_obj) in enumerate(self.get_sorted_fields()):
            if field_name == field.name:
                return i
        return -1

    def rel_for_model(self, model, field_obj=None):
        """Return the first foreign key on this model pointing at ``model``.

        When ``field_obj`` is given, only a foreign key whose name equals
        the field's name (or the node's alias) matches.  Returns ``None``
        when nothing matches.
        """
        is_field = isinstance(field_obj, Field)
        is_node = not is_field and isinstance(field_obj, Node)
        for field in self.get_fields():
            if isinstance(field, ForeignKeyField) and field.rel_model == model:
                is_match = any((
                    field_obj is None,
                    is_field and field_obj.name == field.name,
                    is_node and field_obj._alias == field.name))
                if is_match:
                    return field

    def reverse_rel_for_model(self, model, field_obj=None):
        return model._meta.rel_for_model(self.model_class, field_obj)

    def rel_exists(self, model):
        return self.rel_for_model(model) or self.reverse_rel_for_model(model)

    def related_models(self, backrefs=False):
        """Return every model reachable through foreign keys from this one.

        The model itself is included.  With ``backrefs`` true, models
        holding a foreign key to a visited model are followed as well.
        """
        models = []
        stack = [self.model_class]
        while stack:
            model = stack.pop()
            if model in models:
                continue
            models.append(model)
            for fk in model._meta.rel.values():
                stack.append(fk.rel_model)
            if backrefs:
                for fk in model._meta.reverse_rel.values():
                    stack.append(fk.model_class)
        return models
class BaseModel(type):
inheritable = set(['constraints', 'database', 'indexes', 'order_by',
'primary_key', 'schema', 'validate_backrefs'])
def __new__(cls, name, bases, attrs):
if not bases:
return super(BaseModel, cls).__new__(cls, name, bases, attrs)
meta_options = {}
meta = attrs.pop('Meta', None)
if meta:
for k, v in meta.__dict__.items():
if not k.startswith('_'):
meta_options[k] = v
model_pk = getattr(meta, 'primary_key', None)
parent_pk = None
# inherit any field descriptors by deep copying the underlying field
# into the attrs of the new model, additionally see if the bases define
# inheritable model options and swipe them
for b in bases:
if not hasattr(b, '_meta'):
base_meta = getattr(b, '_meta')
if parent_pk is None:
parent_pk = deepcopy(base_meta.primary_key)
all_inheritable = cls.inheritable | base_meta._additional_keys
for (k, v) in base_meta.__dict__.items():
if k in all_inheritable and k not in meta_options:
meta_options[k] = v
for (k, v) in b.__dict__.items():