Permalink
Browse files

Fixed #7210 -- Added F() expressions to query language. See the docum…

…entation for details on usage.

Many thanks to:
    * Nicolas Lara, who worked on this feature during the 2008 Google Summer of Code.
    * Alex Gaynor for his help debugging and fixing a number of issues.
    * Malcolm Tredinnick for his invaluable review notes.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@9792 bcc190cf-cafb-0310-a4f2-bffc1f526a37
  • Loading branch information...
1 parent 08dd417 commit cf37e4624a967f936ecbb5a4eefc9d38ed9d7892 @freakboy3742 freakboy3742 committed Jan 29, 2009
@@ -3,6 +3,7 @@
from django.db import connection
from django.db.models.loading import get_apps, get_app, get_models, get_model, register_models
from django.db.models.query import Q
+from django.db.models.expressions import F
from django.db.models.manager import Manager
from django.db.models.base import Model
from django.db.models.aggregates import *
@@ -0,0 +1,110 @@
+from copy import deepcopy
+from datetime import datetime
+
+from django.utils import tree
+
+class ExpressionNode(tree.Node):
+ """
+ Base class for all query expressions.
+ """
+ # Arithmetic connectors
+ ADD = '+'
+ SUB = '-'
+ MUL = '*'
+ DIV = '/'
+ MOD = '%%' # This is a quoted % operator - it is quoted
+ # because it can be used in strings that also
+ # have parameter substitution.
+
+ # Bitwise operators
+ AND = '&'
+ OR = '|'
+
+ def __init__(self, children=None, connector=None, negated=False):
+ if children is not None and len(children) > 1 and connector is None:
+ raise TypeError('You have to specify a connector.')
+ super(ExpressionNode, self).__init__(children, connector, negated)
+
+ def _combine(self, other, connector, reversed, node=None):
+ if reversed:
+ obj = ExpressionNode([other], connector)
+ obj.add(node or self, connector)
+ else:
+ obj = node or ExpressionNode([self], connector)
+ obj.add(other, connector)
+ return obj
+
+ ###################
+ # VISITOR METHODS #
+ ###################
+
+ def prepare(self, evaluator, query, allow_joins):
+ return evaluator.prepare_node(self, query, allow_joins)
+
+ def evaluate(self, evaluator, qn):
+ return evaluator.evaluate_node(self, qn)
+
+ #############
+ # OPERATORS #
+ #############
+
+ def __add__(self, other):
+ return self._combine(other, self.ADD, False)
+
+ def __sub__(self, other):
+ return self._combine(other, self.SUB, False)
+
+ def __mul__(self, other):
+ return self._combine(other, self.MUL, False)
+
+ def __div__(self, other):
+ return self._combine(other, self.DIV, False)
+
+ def __mod__(self, other):
+ return self._combine(other, self.MOD, False)
+
+ def __and__(self, other):
+ return self._combine(other, self.AND, False)
+
+ def __or__(self, other):
+ return self._combine(other, self.OR, False)
+
+ def __radd__(self, other):
+ return self._combine(other, self.ADD, True)
+
+ def __rsub__(self, other):
+ return self._combine(other, self.SUB, True)
+
+ def __rmul__(self, other):
+ return self._combine(other, self.MUL, True)
+
+ def __rdiv__(self, other):
+ return self._combine(other, self.DIV, True)
+
+ def __rmod__(self, other):
+ return self._combine(other, self.MOD, True)
+
+ def __rand__(self, other):
+ return self._combine(other, self.AND, True)
+
+ def __ror__(self, other):
+ return self._combine(other, self.OR, True)
+
+class F(ExpressionNode):
+ """
+ An expression representing the value of the given field.
+ """
+ def __init__(self, name):
+ super(F, self).__init__(None, None, False)
+ self.name = name
+
+ def __deepcopy__(self, memodict):
+ obj = super(F, self).__deepcopy__(memodict)
+ obj.name = self.name
+ return obj
+
+ def prepare(self, evaluator, query, allow_joins):
+ return evaluator.prepare_leaf(self, query, allow_joins)
+
+ def evaluate(self, evaluator, qn):
+ return evaluator.evaluate_leaf(self, qn)
@@ -194,8 +194,13 @@ def get_db_prep_save(self, value):
def get_db_prep_lookup(self, lookup_type, value):
"Returns field's value prepared for database lookup."
if hasattr(value, 'as_sql'):
+ # If the value has a relabel_aliases method, it will need to
+ # be invoked before the final SQL is evaluated
+ if hasattr(value, 'relabel_aliases'):
+ return value
sql, params = value.as_sql()
return QueryWrapper(('(%s)' % sql), params)
+
if lookup_type in ('regex', 'iregex', 'month', 'day', 'search'):
return [value]
elif lookup_type in ('exact', 'gt', 'gte', 'lt', 'lte'):
@@ -309,7 +314,7 @@ def formfield(self, form_class=forms.CharField, **kwargs):
if callable(self.default):
defaults['show_hidden_initial'] = True
if self.choices:
- # Fields with choices get special treatment.
+ # Fields with choices get special treatment.
include_blank = self.blank or not (self.has_default() or 'initial' in kwargs)
defaults['choices'] = self.get_choices(include_blank=include_blank)
defaults['coerce'] = self.to_python
@@ -141,6 +141,10 @@ def pk_trace(value):
return v
if hasattr(value, 'as_sql'):
+ # If the value has a relabel_aliases method, it will need to
+ # be invoked before the final SQL is evaluated
+ if hasattr(value, 'relabel_aliases'):
+ return value
sql, params = value.as_sql()
return QueryWrapper(('(%s)' % sql), params)
@@ -17,6 +17,9 @@ class QueryWrapper(object):
def __init__(self, sql, params):
self.data = sql, params
+ def as_sql(self, qn=None):
+ return self.data
+
class Q(tree.Node):
"""
Encapsulates filters as objects that can then be combined logically (using
@@ -0,0 +1,92 @@
+from django.core.exceptions import FieldError
+from django.db import connection
+from django.db.models.fields import FieldDoesNotExist
+from django.db.models.sql.constants import LOOKUP_SEP
+
+class SQLEvaluator(object):
+ def __init__(self, expression, query, allow_joins=True):
+ self.expression = expression
+ self.opts = query.get_meta()
+ self.cols = {}
+
+ self.contains_aggregate = False
+ self.expression.prepare(self, query, allow_joins)
+
+ def as_sql(self, qn=None):
+ return self.expression.evaluate(self, qn)
+
+ def relabel_aliases(self, change_map):
+ for node, col in self.cols.items():
+ self.cols[node] = (change_map.get(col[0], col[0]), col[1])
+
+ #####################################################
+ # Vistor methods for initial expression preparation #
+ #####################################################
+
+ def prepare_node(self, node, query, allow_joins):
+ for child in node.children:
+ if hasattr(child, 'prepare'):
+ child.prepare(self, query, allow_joins)
+
+ def prepare_leaf(self, node, query, allow_joins):
+ if not allow_joins and LOOKUP_SEP in node.name:
+ raise FieldError("Joined field references are not permitted in this query")
+
+ field_list = node.name.split(LOOKUP_SEP)
+ if (len(field_list) == 1 and
+ node.name in query.aggregate_select.keys()):
+ self.contains_aggregate = True
+ self.cols[node] = query.aggregate_select[node.name]
+ else:
+ try:
+ field, source, opts, join_list, last, _ = query.setup_joins(
+ field_list, query.get_meta(),
+ query.get_initial_alias(), False)
+ _, _, col, _, join_list = query.trim_joins(source, join_list, last, False)
+
+ self.cols[node] = (join_list[-1], col)
+ except FieldDoesNotExist:
+ raise FieldError("Cannot resolve keyword %r into field. "
+ "Choices are: %s" % (self.name,
+ [f.name for f in self.opts.fields]))
+
+ ##################################################
+ # Vistor methods for final expression evaluation #
+ ##################################################
+
+ def evaluate_node(self, node, qn):
+ if not qn:
+ qn = connection.ops.quote_name
+
+ expressions = []
+ expression_params = []
+ for child in node.children:
+ if hasattr(child, 'evaluate'):
+ sql, params = child.evaluate(self, qn)
+ else:
+ try:
+ sql, params = qn(child), ()
+ except:
+ sql, params = str(child), ()
+
+ if hasattr(child, 'children') > 1:
+ format = '(%s)'
+ else:
+ format = '%s'
+
+ if sql:
+ expressions.append(format % sql)
+ expression_params.extend(params)
+ conn = ' %s ' % node.connector
+
+ return conn.join(expressions), expression_params
+
+ def evaluate_leaf(self, node, qn):
+ if not qn:
+ qn = connection.ops.quote_name
+
+ col = self.cols[node]
+ if hasattr(col, 'as_sql'):
+ return col.as_sql(qn), ()
+ else:
+ return '%s.%s' % (qn(col[0]), qn(col[1])), ()
@@ -18,6 +18,7 @@
from django.db.models.fields import FieldDoesNotExist
from django.db.models.query_utils import select_related_descend
from django.db.models.sql import aggregates as base_aggregates_module
+from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.where import WhereNode, Constraint, EverythingNode, AND, OR
from django.core.exceptions import FieldError
from datastructures import EmptyResultSet, Empty, MultiJoin
@@ -1271,6 +1272,10 @@ def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
else:
lookup_type = parts.pop()
+ # By default, this is a WHERE clause. If an aggregate is referenced
+ # in the value, the filter will be promoted to a HAVING
+ having_clause = False
+
# Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
# uses of None as a query value.
if value is None:
@@ -1284,6 +1289,10 @@ def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
value = True
elif callable(value):
value = value()
+ elif hasattr(value, 'evaluate'):
+ # If value is a query expression, evaluate it
+ value = SQLEvaluator(value, self)
+ having_clause = value.contains_aggregate
for alias, aggregate in self.aggregate_select.items():
if alias == parts[0]:
@@ -1340,8 +1349,13 @@ def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
self.promote_alias_chain(join_it, join_promote)
self.promote_alias_chain(table_it, table_promote)
- self.where.add((Constraint(alias, col, field), lookup_type, value),
- connector)
+
+ if having_clause:
+ self.having.add((Constraint(alias, col, field), lookup_type, value),
+ connector)
+ else:
+ self.where.add((Constraint(alias, col, field), lookup_type, value),
+ connector)
if negate:
self.promote_alias_chain(join_list)
@@ -5,6 +5,7 @@
from django.core.exceptions import FieldError
from django.db.models.sql.constants import *
from django.db.models.sql.datastructures import Date
+from django.db.models.sql.expressions import SQLEvaluator
from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, Constraint
@@ -136,7 +137,11 @@ def as_sql(self):
result.append('SET')
values, update_params = [], []
for name, val, placeholder in self.values:
- if val is not None:
+ if hasattr(val, 'as_sql'):
+ sql, params = val.as_sql(qn)
+ values.append('%s = %s' % (qn(name), sql))
+ update_params.extend(params)
+ elif val is not None:
values.append('%s = %s' % (qn(name), placeholder))
update_params.append(val)
else:
@@ -251,6 +256,8 @@ def add_update_fields(self, values_seq):
else:
placeholder = '%s'
+ if hasattr(val, 'evaluate'):
+ val = SQLEvaluator(val, self, allow_joins=False)
if model:
self.add_related_update(model, field.column, val, placeholder)
else:
@@ -97,6 +97,7 @@ def as_sql(self, qn=None):
else:
# A leaf node in the tree.
sql, params = self.make_atom(child, qn)
+
except EmptyResultSet:
if self.connector == AND and not self.negated:
# We can bail out early in this particular case (only).
@@ -114,6 +115,7 @@ def as_sql(self, qn=None):
if self.negated:
empty = True
continue
+
empty = False
if sql:
result.append(sql)
@@ -151,8 +153,9 @@ def make_atom(self, child, qn):
else:
cast_sql = '%s'
- if isinstance(params, QueryWrapper):
- extra, params = params.data
+ if hasattr(params, 'as_sql'):
+ extra, params = params.as_sql(qn)
+ cast_sql = ''
else:
extra = ''
@@ -214,6 +217,9 @@ def relabel_aliases(self, change_map, node=None):
if elt[0] in change_map:
elt[0] = change_map[elt[0]]
node.children[pos] = (tuple(elt),) + child[1:]
+ # Check if the query value also requires relabelling
+ if hasattr(child[3], 'relabel_aliases'):
+ child[3].relabel_aliases(change_map)
class EverythingNode(object):
"""
Oops, something went wrong. Retry.

0 comments on commit cf37e46

Please sign in to comment.