Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Whitespace, hee-YAH!

  • Loading branch information...
commit e469be1e3d1f75ab79f54fff6e08214539c13902 1 parent 6cf9de8
@medwards medwards authored
View
2  bench/run_bench.py
@@ -76,7 +76,7 @@ def test_list_entries_for_user(m):
b = m.create_blog(u, 'blog%d' % i)
for j in xrange(10):
e = m.create_entry(b, 'entry%d' % i, '')
-
+
for user in m.list_users():
for i in xrange(100):
entries = m.list_entries_by_user(user)
View
2  bench/sqlalc_bench/models.py
@@ -10,7 +10,7 @@
class User(Base):
__tablename__ = 'sqlalc_users'
-
+
id = Column(Integer, primary_key=True)
username = Column(String)
active = Column(Boolean)
View
4 example/app.py
@@ -186,7 +186,7 @@ def join():
email=request.form['email'],
join_date=datetime.datetime.now()
)
-
+
# mark the user as being 'authenticated' by setting the session vars
auth_user(user)
return redirect(url_for('homepage'))
@@ -237,7 +237,7 @@ def user_detail(username):
# using the "get_object_or_404" shortcut here to get a user with a valid
# username or short-circuit and display a 404 if no user exists in the db
user = get_object_or_404(User, username=username)
-
+
# get all the users messages ordered newest-first -- note how we're accessing
# the messages -- user.message_set. could also have written it as:
# Message.select().where(user=user).order_by(('pub_date', 'desc'))
View
686 peewee.py
@@ -1,10 +1,10 @@
# (\
# ( \ /(o)\ caw!
# ( \/ ()/ /)
-# ( `;.))'".)
+# ( `;.))'".)
# `(/////.-'
-# =====))=))===()
-# ///'
+# =====))=))===()
+# ///'
# //
# '
from __future__ import with_statement
@@ -37,7 +37,7 @@
'ImproperlyConfigured', 'SqliteDatabase', 'MySQLDatabase', 'PostgresqlDatabase',
'asc', 'desc', 'Count', 'Max', 'Min', 'Sum', 'Q', 'Field', 'CharField', 'TextField',
'DateTimeField', 'BooleanField', 'DecimalField', 'FloatField', 'IntegerField',
- 'PrimaryKeyField', 'ForeignKeyField', 'DoubleField', 'BigIntegerField', 'Model',
+ 'PrimaryKeyField', 'ForeignKeyField', 'DoubleField', 'BigIntegerField', 'Model',
'filter_query', 'annotate_query', 'F', 'R',
]
@@ -67,11 +67,11 @@ class BaseAdapter(object):
level `Database` abstraction and the underlying python libraries like
psycopg2. It also provides a way to unify the pythonic field types with
the underlying column types used by the database engine.
-
- The `BaseAdapter` provides two types of mappings:
+
+ The `BaseAdapter` provides two types of mappings:
- mapping between filter operations and their database equivalents
- mapping between basic field types and their database column types
-
+
The `BaseAdapter` also is the mechanism used by the `Database` class to:
- handle connections with the database
- extract information from the database cursor
@@ -83,7 +83,7 @@ class BaseAdapter(object):
subquery_delete_same_table = True
reserved_tables = []
quote_char = '"'
-
+
def get_field_types(self):
field_types = {
'integer': 'INTEGER',
@@ -102,16 +102,16 @@ def get_field_types(self):
}
field_types.update(self.get_field_overrides())
return field_types
-
+
def get_field_overrides(self):
return {}
-
+
def connect(self, database, **kwargs):
raise NotImplementedError
-
+
def close(self, conn):
conn.close()
-
+
def lookup_cast(self, lookup, value):
"""
When a lookup is being performed as a part of a WHERE clause, provides
@@ -123,10 +123,10 @@ def lookup_cast(self, lookup, value):
elif lookup in ('startswith', 'istartswith'):
return '%s%%' % value
return value
-
+
def last_insert_id(self, cursor, model):
return cursor.lastrowid
-
+
def rows_affected(self, cursor):
return cursor.rowcount
@@ -150,12 +150,12 @@ class SqliteAdapter(BaseAdapter):
'startswith': "GLOB %s",
}
interpolation = '?'
-
+
def connect(self, database, **kwargs):
if not sqlite3:
raise ImproperlyConfigured('sqlite3 must be installed on the system')
return sqlite3.connect(database, **kwargs)
-
+
def lookup_cast(self, lookup, value):
if lookup == 'contains':
return '*%s*' % value
@@ -193,7 +193,7 @@ def connect(self, database, **kwargs):
if not psycopg2:
raise ImproperlyConfigured('psycopg2 must be installed on the system')
return psycopg2.connect(database=database, **kwargs)
-
+
def get_field_overrides(self):
return {
'primary_key': 'SERIAL',
@@ -205,7 +205,7 @@ def get_field_overrides(self):
'boolean': 'BOOLEAN',
'blob': 'BYTEA',
}
-
+
def last_insert_id(self, cursor, model):
if model._meta.pk_sequence:
cursor.execute("SELECT CURRVAL('\"%s\"')" % (
@@ -214,7 +214,7 @@ def last_insert_id(self, cursor, model):
cursor.execute("SELECT CURRVAL('\"%s_%s_seq\"')" % (
model._meta.db_table, model._meta.pk_name))
return cursor.fetchone()[0]
-
+
class MySQLAdapter(BaseAdapter):
operations = {
@@ -273,38 +273,38 @@ def inner(self, *args, **kwargs):
raise ValueError('%s adapter does not support sequences' % (self.adapter))
return func(self, *args, **kwargs)
return inner
-
+
def __init__(self, adapter, database, threadlocals=False, autocommit=True, **connect_kwargs):
self.adapter = adapter
self.database = database
self.connect_kwargs = connect_kwargs
-
+
if threadlocals:
self.__local = threading.local()
else:
self.__local = type('DummyLocal', (object,), {})
-
+
self._conn_lock = threading.Lock()
self.autocommit = autocommit
-
+
def connect(self):
with self._conn_lock:
self.__local.conn = self.adapter.connect(self.database, **self.connect_kwargs)
self.__local.closed = False
-
+
def close(self):
with self._conn_lock:
self.adapter.close(self.__local.conn)
self.__local.closed = True
-
+
def get_conn(self):
if not hasattr(self.__local, 'closed') or self.__local.closed:
self.connect()
return self.__local.conn
-
+
def get_cursor(self):
return self.get_conn().cursor()
-
+
def execute(self, sql, params=None):
cursor = self.get_cursor()
res = cursor.execute(sql, params or ())
@@ -312,21 +312,21 @@ def execute(self, sql, params=None):
self.commit()
logger.debug((sql, params))
return cursor
-
+
def commit(self):
self.get_conn().commit()
-
+
def rollback(self):
self.get_conn().rollback()
-
+
def set_autocommit(self, autocommit):
self.__local.autocommit = autocommit
-
+
def get_autocommit(self):
if not hasattr(self.__local, 'autocommit'):
self.set_autocommit(self.autocommit)
return self.__local.autocommit
-
+
def commit_on_success(self, func):
def inner(*args, **kwargs):
orig = self.get_autocommit()
@@ -342,20 +342,20 @@ def inner(*args, **kwargs):
finally:
self.set_autocommit(orig)
return inner
-
+
def last_insert_id(self, cursor, model):
if model._meta.auto_increment:
return self.adapter.last_insert_id(cursor, model)
-
+
def rows_affected(self, cursor):
return self.adapter.rows_affected(cursor)
def quote_name(self, name):
return ''.join((self.adapter.quote_char, name, self.adapter.quote_char))
-
+
def column_for_field(self, field):
return self.column_for_field_type(field.get_db_field())
-
+
def column_for_field_type(self, db_field_type):
try:
return self.adapter.get_field_types()[db_field_type]
@@ -382,22 +382,22 @@ def create_table_query(self, model_class, safe):
def create_table(self, model_class, safe=False):
self.execute(self.create_table_query(model_class, safe))
-
+
def create_index_query(self, model_class, field_name, unique):
framing = 'CREATE %(unique)s INDEX %(index)s ON %(table)s(%(field)s);'
-
+
if field_name not in model_class._meta.fields:
raise AttributeError(
'Field %s not on model %s' % (field_name, model_class)
)
-
+
field_obj = model_class._meta.fields[field_name]
db_table = model_class._meta.db_table
index_name = self.quote_name('%s_%s' % (db_table, field_obj.db_column))
-
+
unique_expr = ternary(unique, 'UNIQUE', '')
-
+
return framing % {
'unique': unique_expr,
'index': index_name,
@@ -410,18 +410,18 @@ def create_index(self, model_class, field_name, unique=False):
def create_foreign_key(self, model_class, field):
return self.create_index(model_class, field.name, field.unique)
-
+
def drop_table(self, model_class, fail_silently=False):
framing = fail_silently and 'DROP TABLE IF EXISTS %s;' or 'DROP TABLE %s;'
self.execute(framing % self.quote_name(model_class._meta.db_table))
-
+
def add_column_sql(self, model_class, field_name):
field = model_class._meta.fields[field_name]
return 'ALTER TABLE %s ADD COLUMN %s' % (
self.quote_name(model_class._meta.db_table),
self.field_sql(field),
)
-
+
def rename_column_sql(self, model_class, field_name, new_name):
# this assumes that the field on the model points to the *old* fieldname
field = model_class._meta.fields[field_name]
@@ -430,28 +430,28 @@ def rename_column_sql(self, model_class, field_name, new_name):
self.quote_name(field.db_column),
self.quote_name(new_name),
)
-
+
def drop_column_sql(self, model_class, field_name):
field = model_class._meta.fields[field_name]
return 'ALTER TABLE %s DROP COLUMN %s' % (
self.quote_name(model_class._meta.db_table),
self.quote_name(field.db_column),
)
-
+
@require_sequence_support
def create_sequence(self, sequence_name):
return self.execute('CREATE SEQUENCE %s;' % self.quote_name(sequence_name))
-
+
@require_sequence_support
def drop_sequence(self, sequence_name):
return self.execute('DROP SEQUENCE %s;' % self.quote_name(sequence_name))
-
+
def get_indexes_for_table(self, table):
raise NotImplementedError
-
+
def get_tables(self):
raise NotImplementedError
-
+
def sequence_exists(self, sequence):
raise NotImplementedError
@@ -459,19 +459,19 @@ def sequence_exists(self, sequence):
class SqliteDatabase(Database):
def __init__(self, database, **connect_kwargs):
super(SqliteDatabase, self).__init__(SqliteAdapter(), database, **connect_kwargs)
-
+
def get_indexes_for_table(self, table):
res = self.execute('PRAGMA index_list(%s);' % self.quote_name(table))
rows = sorted([(r[1], r[2] == 1) for r in res.fetchall()])
return rows
-
+
def get_tables(self):
res = self.execute('select name from sqlite_master where type="table" order by name')
return [r[0] for r in res.fetchall()]
-
+
def drop_column_sql(self, model_class, field_name):
raise NotImplementedError('Sqlite3 does not have direct support for dropping columns')
-
+
def rename_column_sql(self, model_class, field_name, new_name):
raise NotImplementedError('Sqlite3 does not have direct support for renaming columns')
@@ -479,7 +479,7 @@ def rename_column_sql(self, model_class, field_name, new_name):
class PostgresqlDatabase(Database):
def __init__(self, database, **connect_kwargs):
super(PostgresqlDatabase, self).__init__(PostgresqlAdapter(), database, **connect_kwargs)
-
+
def get_indexes_for_table(self, table):
res = self.execute("""
SELECT c2.relname, i.indisprimary, i.indisunique
@@ -487,7 +487,7 @@ def get_indexes_for_table(self, table):
WHERE c.relname = %s AND c.oid = i.indrelid AND i.indexrelid = c2.oid
ORDER BY i.indisprimary DESC, i.indisunique DESC, c2.relname""", (table,))
return sorted([(r[0], r[1]) for r in res.fetchall()])
-
+
def get_tables(self):
res = self.execute("""
SELECT c.relname
@@ -498,7 +498,7 @@ def get_tables(self):
AND pg_catalog.pg_table_is_visible(c.oid)
ORDER BY c.relname""")
return [row[0] for row in res.fetchall()]
-
+
def sequence_exists(self, sequence):
res = self.execute("""
SELECT COUNT(*)
@@ -512,9 +512,9 @@ def sequence_exists(self, sequence):
class MySQLDatabase(Database):
def __init__(self, database, **connect_kwargs):
super(MySQLDatabase, self).__init__(MySQLAdapter(), database, **connect_kwargs)
-
+
def create_foreign_key(self, model_class, field):
- framing = """
+ framing = """
ALTER TABLE %(table)s ADD CONSTRAINT %(constraint)s
FOREIGN KEY (%(field)s) REFERENCES %(to)s(%(to_field)s)%(cascade)s;
"""
@@ -524,7 +524,7 @@ def create_foreign_key(self, model_class, field):
field.to._meta.db_table,
field.db_column,
)
-
+
query = framing % {
'table': self.quote_name(db_table),
'constraint': self.quote_name(constraint),
@@ -533,10 +533,10 @@ def create_foreign_key(self, model_class, field):
'to_field': self.quote_name(field.to._meta.pk_name),
'cascade': ' ON DELETE CASCADE' if field.cascade else '',
}
-
+
self.execute(query)
return super(MySQLDatabase, self).create_foreign_key(model_class, field)
-
+
def rename_column_sql(self, model_class, field_name, new_name):
field = model_class._meta.fields[field_name]
return 'ALTER TABLE %s CHANGE COLUMN %s %s %s' % (
@@ -545,12 +545,12 @@ def rename_column_sql(self, model_class, field_name, new_name):
self.quote_name(new_name),
field.render_field_template(),
)
-
+
def get_indexes_for_table(self, table):
res = self.execute('SHOW INDEXES IN %s;' % self.quote_name(table))
rows = sorted([(r[2], r[1] == 0) for r in res.fetchall()])
return rows
-
+
def get_tables(self):
res = self.execute('SHOW TABLES;')
return [r[0] for r in res.fetchall()]
@@ -569,13 +569,13 @@ def __init__(self, model, cursor, meta=None):
self.query_meta = meta or {}
self.column_meta = self.query_meta.get('columns')
self.join_meta = self.query_meta.get('graph')
-
+
self.__ct = 0
self.__idx = 0
-
+
self._result_cache = []
self._populated = False
-
+
def model_from_rowset(self, model_class, attr_dict):
instance = model_class()
for attr, value in attr_dict.iteritems():
@@ -585,11 +585,11 @@ def model_from_rowset(self, model_class, attr_dict):
else:
setattr(instance, attr, value)
return instance
-
+
def _row_to_dict(self, row):
return dict((self.cursor.description[i][0], value)
for i, value in enumerate(row))
-
+
def construct_instance(self, row):
if not self.column_meta:
# use attribute names pulled from the result cursor description,
@@ -601,7 +601,7 @@ def construct_instance(self, row):
collected_models = {}
for i, (model, col) in enumerate(self.column_meta):
value = row[i]
-
+
if isinstance(col, tuple):
if len(col) == 3:
model = self.model # special-case aggregates
@@ -610,42 +610,42 @@ def construct_instance(self, row):
col_name, attr = col
else:
col_name = attr = col
-
+
if model not in collected_models:
collected_models[model] = model()
-
+
instance = collected_models[model]
-
+
if col_name in instance._meta.columns:
field = instance._meta.columns[col_name]
setattr(instance, field.name, field.python_value(value))
else:
setattr(instance, attr, value)
-
+
return self.follow_joins(self.join_meta, collected_models, self.model)
-
+
def follow_joins(self, joins, collected_models, current):
inst = collected_models[current]
-
+
if current not in joins:
return inst
-
+
for joined_model, _, _ in joins[current]:
if joined_model in collected_models:
joined_inst = self.follow_joins(joins, collected_models, joined_model)
fk_field = current._meta.get_related_field_for_model(joined_model)
-
+
if not fk_field:
continue
-
+
if not joined_inst.get_pk():
joined_inst.set_pk(getattr(inst, fk_field.id_storage))
-
+
setattr(inst, fk_field.name, joined_inst)
setattr(inst, fk_field.id_storage, joined_inst.get_pk())
-
+
return inst
-
+
def __iter__(self):
self.__idx = 0
@@ -653,14 +653,14 @@ def __iter__(self):
return self
else:
return iter(self._result_cache)
-
+
def first(self):
try:
self.__idx = 0 # move to beginning of the list
inst = self.next()
except StopIteration:
inst = None
-
+
self.__idx = 0
return inst
@@ -671,7 +671,7 @@ def fill_cache(self):
for x in self:
pass
self.__idx = idx
-
+
def iterate(self):
row = self.cursor.fetchone()
if row:
@@ -679,17 +679,17 @@ def iterate(self):
else:
self._populated = True
raise StopIteration
-
+
def iterator(self):
while 1:
yield self.iterate()
-
+
def next(self):
if self.__idx < self.__ct:
inst = self._result_cache[self.__idx]
self.__idx += 1
return inst
-
+
instance = self.iterate()
self._result_cache.append(instance)
self.__ct += 1
@@ -740,7 +740,7 @@ def __init__(self, connector='AND', children=None):
self.connector = connector
self.children = children or []
self.negated = False
-
+
def connect(self, rhs, connector):
if isinstance(rhs, Leaf):
if connector == self.connector:
@@ -754,20 +754,20 @@ def connect(self, rhs, connector):
p = Node(connector)
p.children = [self, rhs]
return p
-
+
def __or__(self, rhs):
return self.connect(rhs, 'OR')
def __and__(self, rhs):
return self.connect(rhs, 'AND')
-
+
def __invert__(self):
self.negated = not self.negated
return self
def __nonzero__(self):
return bool(self.children)
-
+
def __unicode__(self):
query = []
nodes = []
@@ -787,20 +787,20 @@ def __unicode__(self):
class Leaf(object):
def __init__(self):
self.parent = None
-
+
def connect(self, connector):
if self.parent is None:
self.parent = Node(connector)
self.parent.children.append(self)
-
+
def __or__(self, rhs):
self.connect('OR')
return self.parent | rhs
-
+
def __and__(self, rhs):
self.connect('AND')
return self.parent & rhs
-
+
def __invert__(self):
self.negated = not self.negated
return self
@@ -812,7 +812,7 @@ def __init__(self, _model=None, **kwargs):
self.query = kwargs
self.negated = False
super(Q, self).__init__()
-
+
def __unicode__(self):
bits = ['%s = %s' % (k, v) for k, v in self.query.items()]
if len(self.query.items()) > 1:
@@ -847,13 +847,13 @@ class R(Leaf):
def __init__(self, *params):
self.params = params
super(R, self).__init__()
-
+
def sql_select(self):
if len(self.params) == 2:
return self.params
else:
raise ValueError('Incorrect number of argument provided for R() expression')
-
+
def sql_where(self):
return self.params[0], self.params[1:]
@@ -878,7 +878,7 @@ def parseq(model, *args, **kwargs):
of where clauses when querying.
"""
node = Node()
-
+
for piece in args:
apply_model(model, piece)
if isinstance(piece, (Q, R, Node)):
@@ -913,41 +913,41 @@ class EmptyResultException(Exception):
class BaseQuery(object):
query_separator = '__'
force_alias = False
-
+
def __init__(self, model):
self.model = model
self.query_context = model
self.database = self.model._meta.database
self.operations = self.database.adapter.operations
self.interpolation = self.database.adapter.interpolation
-
+
self._dirty = True
self._where = []
self._where_models = set()
self._joins = {}
self._joined_models = set()
-
+
def _clone_dict_graph(self, dg):
cloned = {}
for node, edges in dg.items():
cloned[node] = list(edges)
return cloned
-
+
def clone_where(self):
return list(self._where)
-
+
def clone_joins(self):
return self._clone_dict_graph(self._joins)
-
+
def clone(self):
raise NotImplementedError
def qn(self, name):
return self.database.quote_name(name)
-
+
def lookup_cast(self, lookup, value):
return self.database.adapter.lookup_cast(lookup, value)
-
+
def parse_query_args(self, model, **query):
"""
Parse out and normalize clauses in a query. The query is composed of
@@ -961,21 +961,21 @@ def parse_query_args(self, model, **query):
lhs, op = lhs.rsplit(self.query_separator, 1)
else:
op = 'eq'
-
+
if lhs in model._meta.columns:
lhs = model._meta.columns[lhs].name
-
+
try:
field = model._meta.get_field_by_name(lhs)
except AttributeError:
field = model._meta.get_related_field_by_name(lhs)
if field is None:
raise
-
+
if isinstance(rhs, R):
expr, params = rhs.sql_where()
lookup_value = [field.db_value(o) for o in params]
-
+
combined_expr = self.operations[op] % expr
operation = combined_expr % tuple(self.interpolation for p in params)
elif isinstance(rhs, F):
@@ -1009,13 +1009,13 @@ def parse_query_args(self, model, **query):
else:
lookup_value = field.db_value(rhs)
operation = self.operations[op] % self.interpolation
-
+
parsed.append(
(field.db_column, (operation, self.lookup_cast(op, lookup_value)))
)
-
+
return parsed
-
+
@returns_clone
def where(self, *args, **kwargs):
parsed = parseq(self.query_context, *args, **kwargs)
@@ -1044,7 +1044,7 @@ def switch(self, model):
self.query_context = model
return
raise AttributeError('You must JOIN on %s' % model.__name__)
-
+
def use_aliases(self):
return len(self._joined_models) > 0 or self.force_alias
@@ -1060,23 +1060,23 @@ def safe_combine(self, model, alias, col):
elif col in model._meta.fields:
return self.combine_field(alias, model._meta.fields[col].db_column)
return col
-
+
def follow_joins(self, current, alias_map, alias_required, alias_count, seen=None):
computed = []
seen = seen or set()
-
+
if current not in self._joins:
return computed
-
+
for i, (model, join_type, on) in enumerate(self._joins[current]):
seen.add(model)
-
+
if alias_required:
alias_count += 1
alias_map[model] = 't%d' % alias_count
else:
alias_map[model] = ''
-
+
from_model = current
field = from_model._meta.get_related_field_for_model(model, on)
if field:
@@ -1086,13 +1086,13 @@ def follow_joins(self, current, alias_map, alias_required, alias_count, seen=Non
field = from_model._meta.get_reverse_related_field_for_model(model, on)
left_field = from_model._meta.pk_name
right_field = field.db_column
-
+
if join_type is None:
if field.null and model not in self._where_models:
join_type = 'LEFT OUTER'
else:
join_type = 'INNER'
-
+
computed.append(
'%s JOIN %s AS %s ON %s = %s' % (
join_type,
@@ -1102,11 +1102,11 @@ def follow_joins(self, current, alias_map, alias_required, alias_count, seen=Non
self.combine_field(alias_map[model], right_field),
)
)
-
+
computed.extend(self.follow_joins(model, alias_map, alias_required, alias_count, seen))
-
+
return computed
-
+
def compile_where(self):
alias_count = 0
alias_map = {}
@@ -1117,13 +1117,13 @@ def compile_where(self):
alias_map[self.model] = 't%d' % alias_count
else:
alias_map[self.model] = ''
-
+
computed_joins = self.follow_joins(self.model, alias_map, alias_required, alias_count)
-
+
clauses = [self.parse_node(node, alias_map) for node in self._where]
-
+
return computed_joins, clauses, alias_map
-
+
def flatten_clauses(self, clauses):
where_with_alias = []
where_data = []
@@ -1131,7 +1131,7 @@ def flatten_clauses(self, clauses):
where_with_alias.append(query)
where_data.extend(data)
return where_with_alias, where_data
-
+
def convert_where_to_params(self, where_data):
flattened = []
for clause in where_data:
@@ -1140,7 +1140,7 @@ def convert_where_to_params(self, where_data):
else:
flattened.append(clause)
return flattened
-
+
def parse_node(self, node, alias_map):
query = []
query_data = []
@@ -1162,7 +1162,7 @@ def parse_node(self, node, alias_map):
if node.negated:
query = 'NOT (%s)' % query
return query, query_data
-
+
def parse_q(self, q, alias_map):
model = q.model or self.model
query = []
@@ -1179,27 +1179,27 @@ def parse_q(self, q, alias_map):
operation = operation % self.parse_f(value, f_model, alias_map)
else:
query_data.append(value)
-
+
combined = self.combine_field(alias_map[model], name)
query.append('%s %s' % (combined, operation))
-
+
if len(query) > 1:
query = '(%s)' % (' AND '.join(query))
else:
query = query[0]
-
+
if q.negated:
query = 'NOT %s' % query
-
+
return query, query_data
-
+
def parse_f(self, f_object, model, alias_map):
combined = self.combine_field(alias_map[model], f_object.field)
if f_object.op is not None:
combined = '(%s %s %s)' % (combined, f_object.op, f_object.value)
return combined
-
+
def parse_r(self, r_object, alias_map):
return r_object.sql_where()
@@ -1207,25 +1207,25 @@ def convert_subquery(self, subquery):
orig_query = subquery.query
if subquery.query == '*':
subquery.query = subquery.model._meta.pk_name
-
+
subquery.force_alias, orig_alias = True, subquery.force_alias
sql, data = subquery.sql()
subquery.query = orig_query
subquery.force_alias = orig_alias
return sql, data
-
+
def sorted_models(self, alias_map):
return [
(model, alias) \
for (model, alias) in sorted(alias_map.items(), key=lambda i: i[1])
]
-
+
def sql(self):
raise NotImplementedError
-
+
def execute(self):
raise NotImplementedError
-
+
def raw_execute(self, query, params):
return self.database.execute(query, params)
@@ -1238,22 +1238,22 @@ def __init__(self, model, query, *params):
def clone(self):
return RawQuery(self.model, self._sql, *self._params)
-
+
def sql(self):
return self._sql, self._params
-
+
def execute(self):
return QueryResultWrapper(self.model, self.raw_execute(*self.sql()))
-
+
def join(self):
raise AttributeError('Raw queries do not support joining programmatically')
-
+
def where(self):
raise AttributeError('Raw queries do not support querying programmatically')
-
+
def switch(self):
raise AttributeError('Raw queries do not support switching contexts')
-
+
def __iter__(self):
return iter(self.execute())
@@ -1270,7 +1270,7 @@ def __init__(self, model, query=None):
self._qr = None
self._for_update = False
super(SelectQuery, self).__init__(model)
-
+
def clone(self):
query = SelectQuery(self.model, self.query)
query.query_context = self.query_context
@@ -1287,53 +1287,53 @@ def clone(self):
query._joined_models = self._joined_models.copy()
query._joins = self.clone_joins()
return query
-
+
@returns_clone
def paginate(self, page, paginate_by=20):
if page > 0:
page -= 1
self._limit = paginate_by
self._offset = page * paginate_by
-
+
@returns_clone
def limit(self, num_rows):
self._limit = num_rows
-
+
@returns_clone
def offset(self, num_rows):
self._offset = num_rows
-
+
@returns_clone
def for_update(self, for_update=True):
self._for_update = for_update
-
+
def count(self):
if self._distinct or self._group_by:
return self.wrapped_count()
-
+
clone = self.order_by()
clone._limit = clone._offset = None
-
+
if clone.use_aliases():
clone.query = 'COUNT(t1.%s)' % (clone.model._meta.pk_name)
else:
clone.query = 'COUNT(%s)' % (clone.model._meta.pk_name)
-
+
res = clone.database.execute(*clone.sql())
-
+
return (res.fetchone() or [0])[0]
-
+
def wrapped_count(self):
clone = self.order_by()
clone._limit = clone._offset = None
-
+
sql, params = clone.sql()
query = 'SELECT COUNT(1) FROM (%s) AS wrapped_select' % sql
-
+
res = clone.database.execute(query, params)
-
+
return res.fetchone()[0]
-
+
@returns_clone
def group_by(self, *clauses):
model = self.query_context
@@ -1345,21 +1345,21 @@ def group_by(self, *clauses):
elif issubclass(clause, Model):
model = clause
fields = clause._meta.get_field_names()
-
+
self._group_by.append((model, fields))
-
+
@returns_clone
def having(self, *clauses):
self._having = clauses
-
+
@returns_clone
def distinct(self):
self._distinct = True
-
+
@returns_clone
def order_by(self, *clauses):
order_by = []
-
+
for clause in clauses:
if isinstance(clause, tuple):
if len(clause) == 3:
@@ -1377,19 +1377,19 @@ def order_by(self, *clauses):
model = self.query_context
field = clause
ordering = 'ASC'
-
+
order_by.append(
(model, field, ordering)
)
-
+
self._order_by = order_by
-
+
def exists(self):
clone = self.paginate(1, 1)
clone.query = '(1) AS a'
curs = self.database.execute(*clone.sql())
return bool(curs.fetchone())
-
+
def get(self, *args, **kwargs):
orig_ctx = self.query_context
self.query_context = self.model
@@ -1403,16 +1403,16 @@ def get(self, *args, **kwargs):
))
finally:
self.query_context = orig_ctx
-
+
def filter(self, *args, **kwargs):
return filter_query(self, *args, **kwargs)
-
+
def annotate(self, related_model, aggregation=None):
return annotate_query(self, related_model, aggregation)
def parse_select_query(self, alias_map):
q = self.query
-
+
if isinstance(q, (list, tuple)):
q = {self.model: self.query}
elif isinstance(q, basestring):
@@ -1423,30 +1423,30 @@ def parse_select_query(self, alias_map):
q = {self.model: [self.model._meta.pk_name]}
else:
return q, []
-
+
# by now we should have a dictionary if a valid type was passed in
if not isinstance(q, dict):
raise TypeError('Unknown type encountered parsing select query')
-
+
# gather aliases and models
sorted_models = self.sorted_models(alias_map)
-
+
# normalize if we are working with a dictionary
columns = []
model_cols = []
-
+
for model, alias in sorted_models:
if model not in q:
continue
-
+
if '*' in q[model]:
idx = q[model].index('*')
q[model] = q[model][:idx] + model._meta.get_field_names() + q[model][idx+1:]
-
+
for clause in q[model]:
if isinstance(clause, R):
clause = clause.sql_select()
-
+
if isinstance(clause, tuple):
if len(clause) == 3:
func, col_name, col_alias = clause
@@ -1468,20 +1468,20 @@ def parse_select_query(self, alias_map):
column = model._meta.get_column(clause)
columns.append(self.safe_combine(model, alias, column))
model_cols.append((model, column))
-
+
return ', '.join(columns), model_cols
-
+
def sql_meta(self):
joins, clauses, alias_map = self.compile_where()
where, where_data = self.flatten_clauses(clauses)
-
+
table = self.qn(self.model._meta.db_table)
params = []
group_by = []
use_aliases = self.use_aliases()
-
+
if use_aliases:
table = '%s AS %s' % (table, alias_map[self.model])
@@ -1499,18 +1499,18 @@ def sql_meta(self):
'columns': model_cols,
'graph': self._joins,
}
-
+
if self._distinct:
sel = 'SELECT DISTINCT'
else:
sel = 'SELECT'
-
+
select = '%s %s FROM %s' % (sel, parsed_query, table)
joins = '\n'.join(joins)
where = ' AND '.join(where)
group_by = ', '.join(group_by)
having = ' AND '.join(self._having)
-
+
order_by = []
for piece in self._order_by:
model, field, ordering = piece
@@ -1518,17 +1518,17 @@ def sql_meta(self):
alias = alias_map[model]
else:
alias = ''
-
+
order_by.append('%s %s' % (self.safe_combine(model, alias, field), ordering))
-
+
pieces = [select]
-
+
if joins:
pieces.append(joins)
if where:
pieces.append('WHERE %s' % where)
params.extend(self.convert_where_to_params(where_data))
-
+
if group_by:
pieces.append('GROUP BY %s' % group_by)
if having:
@@ -1539,16 +1539,16 @@ def sql_meta(self):
pieces.append('LIMIT %d' % self._limit)
if self._offset:
pieces.append('OFFSET %d' % self._offset)
-
+
if self._for_update and self.database.adapter.for_update_support:
pieces.append('FOR UPDATE')
-
+
return ' '.join(pieces), params, query_meta
-
+
def sql(self):
query, params, meta = self.sql_meta()
return query, params
-
+
def execute(self):
if self._dirty or not self._qr:
try:
@@ -1562,7 +1562,7 @@ def execute(self):
else:
# call the __iter__ method directly
return self._qr
-
+
def __iter__(self):
return iter(self.execute())
@@ -1571,7 +1571,7 @@ class UpdateQuery(BaseQuery):
def __init__(self, model, **kwargs):
self.update_query = kwargs
super(UpdateQuery, self).__init__(model)
-
+
def clone(self):
query = UpdateQuery(self.model, **self.update_query)
query._where = self.clone_where()
@@ -1579,27 +1579,27 @@ def clone(self):
query._joined_models = self._joined_models.copy()
query._joins = self.clone_joins()
return query
-
+
def parse_update(self):
sets = {}
for k, v in self.update_query.iteritems():
if k in self.model._meta.columns:
k = self.model._meta.columns[k].name
-
+
try:
field = self.model._meta.get_field_by_name(k)
except AttributeError:
field = self.model._meta.get_related_field_by_name(k)
if field is None:
raise
-
+
if not isinstance(v, F):
v = field.db_value(v)
-
+
sets[field.db_column] = v
-
+
return sets
-
+
def sql(self):
joins, clauses, alias_map = self.compile_where()
where, where_data = self.flatten_clauses(clauses)
@@ -1607,7 +1607,7 @@ def sql(self):
params = []
update_params = []
-
+
alias = alias_map.get(self.model)
for k, v in set_statement.iteritems():
@@ -1616,24 +1616,24 @@ def sql(self):
else:
params.append(v)
value = self.interpolation
-
+
update_params.append('%s=%s' % (self.combine_field(alias, k), value))
-
+
update = 'UPDATE %s SET %s' % (
self.qn(self.model._meta.db_table), ', '.join(update_params))
where = ' AND '.join(where)
-
+
pieces = [update]
-
+
if where:
pieces.append('WHERE %s' % where)
params.extend(self.convert_where_to_params(where_data))
-
+
return ' '.join(pieces), params
-
+
def join(self, *args, **kwargs):
raise AttributeError('Update queries do not support JOINs in sqlite')
-
+
def execute(self):
result = self.raw_execute(*self.sql())
return self.database.rows_affected(result)
@@ -1647,27 +1647,27 @@ def clone(self):
query._joined_models = self._joined_models.copy()
query._joins = self.clone_joins()
return query
-
+
def sql(self):
joins, clauses, alias_map = self.compile_where()
where, where_data = self.flatten_clauses(clauses)
params = []
-
+
delete = 'DELETE FROM %s' % (self.qn(self.model._meta.db_table))
where = ' AND '.join(where)
-
+
pieces = [delete]
-
+
if where:
pieces.append('WHERE %s' % where)
params.extend(self.convert_where_to_params(where_data))
-
+
return ' '.join(pieces), params
-
+
def join(self, *args, **kwargs):
raise AttributeError('Update queries do not support JOINs in sqlite')
-
+
def execute(self):
result = self.raw_execute(*self.sql())
return self.database.rows_affected(result)
@@ -1677,43 +1677,43 @@ class InsertQuery(BaseQuery):
def __init__(self, model, **kwargs):
self.insert_query = kwargs
super(InsertQuery, self).__init__(model)
-
+
def parse_insert(self):
cols = []
vals = []
for k, v in self.insert_query.iteritems():
if k in self.model._meta.columns:
k = self.model._meta.columns[k].name
-
+
try:
field = self.model._meta.get_field_by_name(k)
except AttributeError:
field = self.model._meta.get_related_field_by_name(k)
if field is None:
raise
-
+
cols.append(self.qn(field.db_column))
vals.append(field.db_value(v))
-
+
return cols, vals
-
+
def sql(self):
cols, vals = self.parse_insert()
-
+
insert = 'INSERT INTO %s (%s) VALUES (%s)' % (
self.qn(self.model._meta.db_table),
','.join(cols),
','.join(self.interpolation for v in vals)
)
-
+
return insert, vals
-
+
def where(self, *args, **kwargs):
raise AttributeError('Insert queries do not support WHERE clauses')
-
+
def join(self, *args, **kwargs):
raise AttributeError('Insert queries do not support JOINs')
-
+
def execute(self):
result = self.raw_execute(*self.sql())
return self.database.last_insert_id(result, self.model)
@@ -1733,22 +1733,22 @@ def convert_lookup(model, joins, lookup):
"""
Given a model, a graph of joins, and a lookup, return a tuple containing
a normalized lookup:
-
+
(model actually being queried, updated graph of joins, normalized lookup)
"""
operations = model._meta.database.adapter.operations
-
+
pieces = lookup.split('__')
operation = None
-
+
query_model = model
-
+
if len(pieces) > 1:
if pieces[-1] in operations:
operation = pieces.pop()
-
+
lookup = pieces.pop()
-
+
# we have some joins
if len(pieces):
for piece in pieces:
@@ -1757,10 +1757,10 @@ def convert_lookup(model, joins, lookup):
for field in query_model._meta.get_fields():
if not isinstance(field, ForeignKeyField):
continue
-
+
if piece in (field.name, field.db_column, field.related_name):
joined_model = field.to
-
+
if not joined_model:
try:
joined_model = query_model._meta.reverse_relations[piece]
@@ -1769,14 +1769,14 @@ def convert_lookup(model, joins, lookup):
piece,
query_model,
))
-
+
joins.setdefault(query_model, set())
joins[query_model].add(joined_model)
query_model = joined_model
-
+
if operation:
lookup = '%s__%s' % (lookup, operation)
-
+
return query_model, joins, lookup
@@ -1785,10 +1785,10 @@ def filter_query(model_or_query, *args, **kwargs):
Provide a django-like interface for executing queries
"""
model, select_query = model_or_select(model_or_query)
-
+
query = {} # mapping of models to queries
joins = {} # a graph of joins needed, passed into the convert_lookup function
-
+
# traverse Q() objects, find any joins that may be lurking -- clean up the
# lookups and assign the correct model
def fix_q(node_or_q, joins):
@@ -1803,16 +1803,16 @@ def fix_q(node_or_q, joins):
new_query[lookup] = value
node_or_q.model = query_model
node_or_q.query = new_query
-
+
for node_or_q in args:
fix_q(node_or_q, joins)
-
+
# iterate over keyword lookups and determine lookups and necessary joins
for raw_lookup, value in kwargs.items():
queried_model, joins, lookup = convert_lookup(model, joins, raw_lookup)
query.setdefault(queried_model, [])
query[queried_model].append((lookup, value))
-
+
def follow_joins(current, query):
if current in joins:
for joined_model in joins[current]:
@@ -1822,10 +1822,10 @@ def follow_joins(current, query):
query = follow_joins(joined_model, query)
return query
select_query = follow_joins(model, select_query)
-
+
for node in args:
select_query = select_query.where(node)
-
+
for model, lookups in query.items():
qargs, qkwargs = [], {}
for lookup in lookups:
@@ -1843,14 +1843,14 @@ def annotate_query(select_query, related_model, aggregation):
"""
aggregation = aggregation or Count(related_model._meta.pk_name)
model = select_query.model
-
+
select_query = select_query.switch(model)
cols = select_query.query
-
+
# ensure the join is there
if related_model not in select_query._joined_models:
select_query = select_query.join(related_model).switch(model)
-
+
# query for it
if isinstance(cols, dict):
selection = cols
@@ -1866,10 +1866,10 @@ def annotate_query(select_query, related_model, aggregation):
group_by = cols
else:
raise ValueError('Unknown type passed in to select query: "%s"' % type(cols))
-
+
# query for the related object
selection[related_model] = [aggregation]
-
+
select_query.query = selection
return select_query.group_by(group_by)
@@ -1900,10 +1900,10 @@ def render(self, db):
class VarCharColumn(Column):
db_field = 'string'
template = '%(column_type)s(%(max_length)d)'
-
+
def get_attributes(self):
return {'max_length': 255}
-
+
def db_value(self, value):
value = value or ''
return value[:self.attributes['max_length']]
@@ -1911,14 +1911,14 @@ def db_value(self, value):
class TextColumn(Column):
db_field = 'text'
-
+
def db_value(self, value):
return value or ''
class DateTimeColumn(Column):
db_field = 'datetime'
-
+
def python_value(self, value):
if isinstance(value, basestring):
value = value.rsplit('.', 1)[0]
@@ -1928,10 +1928,10 @@ def python_value(self, value):
class IntegerColumn(Column):
db_field = 'integer'
-
+
def db_value(self, value):
return value or 0
-
+
def python_value(self, value):
if value is not None:
return int(value)
@@ -1943,20 +1943,20 @@ class BigIntegerColumn(IntegerColumn):
class BooleanColumn(Column):
db_field = 'boolean'
-
+
def db_value(self, value):
return bool(value)
-
+
def python_value(self, value):
return bool(value)
class FloatColumn(Column):
db_field = 'float'
-
+
def db_value(self, value):
return value or 0.0
-
+
def python_value(self, value):
if value is not None:
return float(value)
@@ -1969,16 +1969,16 @@ class DoubleColumn(FloatColumn):
class DecimalColumn(Column):
db_field = 'decimal'
field_template = '%(column_type)s(%(max_digits)d, %(decimal_places)d)'
-
+
def get_attributes(self):
return {
'max_digits': 10,
'decimal_places': 5,
}
-
+
def db_value(self, value):
return value or decimal.Decimal(0)
-
+
def python_value(self, value):
if value is not None:
if isinstance(value, decimal.Decimal):
@@ -1997,12 +1997,12 @@ class FieldDescriptor(object):
def __init__(self, field):
self.field = field
self._cache_name = '__%s' % self.field.name
-
+
def __get__(self, instance, instance_type=None):
if instance:
return getattr(instance, self._cache_name, None)
return self.field
-
+
def __set__(self, instance, value):
setattr(instance, self._cache_name, value)
@@ -2025,10 +2025,10 @@ def __init__(self, null=False, db_index=False, unique=False, verbose_name=None,
self.default = default
self.attributes = kwargs
-
+
Field._field_counter += 1
self._order = Field._field_counter
-
+
def add_to_class(self, klass, name):
self.name = name
self.model = klass
@@ -2037,10 +2037,10 @@ def add_to_class(self, klass, name):
self.column = self.get_column()
setattr(klass, name, FieldDescriptor(self))
-
+
def get_column(self):
return self.column_class(**self.attributes)
-
+
def render_field_template(self):
params = {
'column': self.column.render(self.model._meta.database),
@@ -2048,15 +2048,15 @@ def render_field_template(self):
}
params.update(self.column.attributes)
return self.field_template % params
-
+
def db_value(self, value):
if (self.null and value is None):
return None
return self.column.db_value(value)
-
+
def python_value(self, value):
return self.column.python_value(value)
-
+
def lookup_value(self, lookup_type, value):
return self.db_value(value)
@@ -2066,7 +2066,7 @@ def class_prepared(self):
class CharField(Field):
column_class = VarCharColumn
-
+
class TextField(Field):
column_class = TextColumn
@@ -2121,23 +2121,23 @@ def get_column_class(self):
if self.model._meta.pk_sequence != None and self.model._meta.database.adapter.sequence_support:
self.column_class = PrimaryKeySequenceColumn
return self.column_class
-
+
def get_column(self):
return self.get_column_class()(**self.attributes)
-class ForeignRelatedObject(object):
+class ForeignRelatedObject(object):
def __init__(self, to, field):
self.to = to
self.field = field
self.field_name = self.field.name
self.field_column = self.field.id_storage
self.cache_name = '_cache_%s' % self.field_name
-
+
def __get__(self, instance, instance_type=None):
if not instance:
return self.field
-
+
if not getattr(instance, self.cache_name, None):
id = getattr(instance, self.field_column, 0)
qr = self.to.select().where(**{self.to._meta.pk_name: id})
@@ -2147,7 +2147,7 @@ def __get__(self, instance, instance_type=None):
if not self.field.null:
raise
return getattr(instance, self.cache_name, None)
-
+
def __set__(self, instance, obj):
if self.field.null and obj is None:
setattr(instance, self.field_column, None)
@@ -2165,7 +2165,7 @@ class ReverseForeignRelatedObject(object):
def __init__(self, related_model, name):
self.field_name = name
self.related_model = related_model
-
+
def __get__(self, instance, instance_type=None):
query = {self.field_name: instance.get_pk()}
qr = self.related_model.select().where(**query)
@@ -2174,7 +2174,7 @@ def __get__(self, instance, instance_type=None):
class ForeignKeyField(IntegerField):
field_template = '%(column)s%(nullable)s REFERENCES %(to_table)s (%(to_pk)s)%(cascade)s%(extra)s'
-
+
def __init__(self, to, null=False, related_name=None, cascade=False, extra=None, *args, **kwargs):
self.to = to
self._related_name = related_name
@@ -2186,12 +2186,12 @@ def __init__(self, to, null=False, related_name=None, cascade=False, extra=None,
'extra': self.extra or '',
})
super(ForeignKeyField, self).__init__(null=null, *args, **kwargs)
-
+
def add_to_class(self, klass, name):
self.name = name
self.model = klass
self.db_column = self.db_column or self.name + '_id'
-
+
if self.name == self.db_column:
self.id_storage = self.db_column + '_id'
else:
@@ -2201,32 +2201,32 @@ def add_to_class(self, klass, name):
self.to = self.model
self.verbose_name = self.verbose_name or re.sub('_', ' ', name).title()
-
+
if self._related_name is not None:
self.related_name = self._related_name
else:
self.related_name = klass._meta.db_table + '_set'
-
+
klass._meta.rel_fields[name] = self.name
setattr(klass, self.name, ForeignRelatedObject(self.to, self))
setattr(klass, self.id_storage, None)
-
+
reverse_rel = ReverseForeignRelatedObject(klass, self.name)
setattr(self.to, self.related_name, reverse_rel)
self.to._meta.reverse_relations[self.related_name] = klass
-
+
def lookup_value(self, lookup_type, value):
if isinstance(value, Model):
return value.get_pk()
return value or None
-
+
def db_value(self, value):
if isinstance(value, Model):
return value.get_pk()
if self.null and value is None:
return None
return self.column.db_value(value)
-
+
def get_column(self):
to_pk = self.to._meta.get_field_by_name(self.to._meta.pk_name)
to_col_class = to_pk.get_column_class()
@@ -2258,57 +2258,57 @@ def __init__(self, model_class, options=None):
options = options or {'database': database}
for k, v in options.items():
setattr(self, k, v)
-
+
self.rel_fields = {}
self.reverse_relations = {}
self.fields = {}
self.columns = {}
self.model_class = model_class
-
+
def get_sorted_fields(self):
return sorted(self.fields.items(), key=lambda (k,v): (k == self.pk_name and 1 or 2, v._order))
-
+
def get_field_names(self):
return [f[0] for f in self.get_sorted_fields()]
-
+
def get_fields(self):
return [f[1] for f in self.get_sorted_fields()]
-
+
def get_field_by_name(self, name):
if name in self.fields:
return self.fields[name]
raise AttributeError('Field named %s not found' % name)
-
+
def get_column_names(self):
return self.columns.keys()
-
+
def get_column(self, field_or_col):
if field_or_col in self.fields:
return self.fields[field_or_col].db_column
return field_or_col
-
+
def get_related_field_by_name(self, name):
if name in self.rel_fields:
return self.fields[self.rel_fields[name]]
-
+
def get_related_field_for_model(self, model, name=None):
for field in self.fields.values():
if isinstance(field, ForeignKeyField) and field.to == model:
if name is None or name == field.name or name == field.db_column:
return field
-
+
def get_reverse_related_field_for_model(self, model, name=None):
for field in model._meta.fields.values():
if isinstance(field, ForeignKeyField) and field.to == self.model_class:
if name is None or name == field.name or name == field.db_column:
return field
-
+
def get_field_for_related_name(self, model, related_name):
for field in model._meta.fields.values():
if isinstance(field, ForeignKeyField) and field.to == self.model_class:
if field.related_name == related_name:
return field
-
+
def rel_exists(self, model):
return self.get_related_field_for_model(model) or \
self.get_reverse_related_field_for_model(model)
@@ -2316,7 +2316,7 @@ def rel_exists(self, model):
class BaseModel(type):
inheritable_options = ['database', 'ordering', 'pk_sequence']
-
+
def __new__(cls, name, bases, attrs):
cls = super(BaseModel, cls).__new__(cls, name, bases, attrs)
@@ -2327,12 +2327,12 @@ def __new__(cls, name, bases, attrs):
meta = attrs.pop('Meta', None)
if meta:
attr_dict = meta.__dict__
-
+
for b in bases:
base_meta = getattr(b, '_meta', None)
if not base_meta:
continue
-
+
for (k, v) in base_meta.__dict__.items():
if k in cls.inheritable_options and k not in attr_dict:
attr_dict[k] = v
@@ -2346,7 +2346,7 @@ def __new__(cls, name, bases, attrs):
setattr(cls, field_name, field_copy)
_meta = BaseModelOptions(cls, attr_dict)
-
+
if not hasattr(_meta, 'db_table'):
_meta.db_table = re.sub('[^\w]+', '_', cls.__name__.lower())
@@ -2366,7 +2366,7 @@ def __new__(cls, name, bases, attrs):
_meta.columns[attr.db_column] = attr
if isinstance(attr, PrimaryKeyField):
_meta.pk_name = attr.name
-
+
if _meta.pk_name is None:
_meta.pk_name = 'id'
pk = PrimaryKeyField()
@@ -2374,7 +2374,7 @@ def __new__(cls, name, bases, attrs):
_meta.fields[_meta.pk_name] = pk
_meta.model_name = cls.__name__
-
+
pk_field = _meta.fields[_meta.pk_name]
pk_col = pk_field.column
if _meta.pk_sequence and _meta.database.adapter.sequence_support:
@@ -2384,26 +2384,26 @@ def __new__(cls, name, bases, attrs):
for field in _meta.fields.values():
field.class_prepared()
-
+
if hasattr(cls, '__unicode__'):
setattr(cls, '__repr__', lambda self: '<%s: %s>' % (
_meta.model_name, self.__unicode__()))
exception_class = type('%sDoesNotExist' % _meta.model_name, (DoesNotExist,), {})
cls.DoesNotExist = exception_class
-
+
return cls
class Model(object):
__metaclass__ = BaseModel
-
+
def __init__(self, *args, **kwargs):
self.initialize_defaults()
-
+
for k, v in kwargs.items():
setattr(self, k, v)
-
+
def initialize_defaults(self):
for field in self._meta.fields.values():
if field.default is not None:
@@ -2412,67 +2412,67 @@ def initialize_defaults(self):
else:
field_value = field.default
setattr(self, field.name, field_value)
-
+
def __eq__(self, other):
return other.__class__ == self.__class__ and \
self.get_pk() and \
other.get_pk() == self.get_pk()
-
+
def get_field_dict(self):
field_dict = {}
-
+
for field in self._meta.fields.values():
if isinstance(field, ForeignKeyField):
field_dict[field.name] = getattr(self, field.id_storage)
else:
field_dict[field.name] = getattr(self, field.name)
-
+
return field_dict
-
+
@classmethod
def table_exists(cls):
return cls._meta.db_table in cls._meta.database.get_tables()
-
+
@classmethod
def create_table(cls, fail_silently=False):
if fail_silently and cls.table_exists():
return
cls._meta.database.create_table(cls)
-
+
for field_name, field_obj in cls._meta.fields.items():
if isinstance(field_obj, ForeignKeyField):
cls._meta.database.create_foreign_key(cls, field_obj)
elif field_obj.db_index or field_obj.unique:
cls._meta.database.create_index(cls, field_obj.name, field_obj.unique)
-
+
@classmethod
def drop_table(cls, fail_silently=False):
cls._meta.database.drop_table(cls, fail_silently)
-
+
@classmethod
def filter(cls, *args, **kwargs):
return filter_query(cls, *args, **kwargs)
-
+
@classmethod
def select(cls, query=None):
select_query = SelectQuery(cls, query)
if cls._meta.ordering:
select_query = select_query.order_by(*cls._meta.ordering)
return select_query
-
+
@classmethod
def update(cls, **query):
return UpdateQuery(cls, **query)
-
+
@classmethod
def insert(cls, **query):
return InsertQuery(cls, **query)
-
+
@classmethod
def delete(cls, **query):
return DeleteQuery(cls, **query)
-
+
@classmethod
def raw(cls, sql, *params):