Permalink
Browse files

Refactored qs.add_q() and utils/tree.py

The sql/query.py add_q method did a lot of where/having tree hacking to
get complex queries to work correctly. The logic was refactored so that
it should be simpler to understand. The new logic should also produce
leaner WHERE conditions.

The changes cascade somewhat, as some other parts of Django (like
add_filter() and WhereNode) expect boolean trees in certain format or
they fail to work. So to fix the add_q() one must fix utils/tree.py,
some things in add_filter(), WhereNode and so on.

This commit also fixed add_filter to see negate clauses up the path.
A query like .exclude(Q(reversefk__in=a_list)) didn't work similarly to
.filter(~Q(reversefk__in=a_list)). The reason for this is that only
the immediate parent negate clauses were seen by add_filter, and thus a
tree like AND: (NOT AND: (AND: condition)) will not be handled
correctly, as there is one intermediary AND node in the tree. The
example tree is generated by .exclude(~Q(reversefk__in=a_list)).

Still, aggregation lost connectors in OR cases, and F() objects and
aggregates in same filter clause caused GROUP BY problems on some
databases.

Fixed #17600, fixed #13198, fixed #17025, fixed #17000, fixed #11293.
  • Loading branch information...
1 parent d744c55 commit d3f00bd5706b35961390d3814dd7e322ead3a9a3 @akaariai akaariai committed May 24, 2012
@@ -32,13 +32,14 @@ class GeoWhereNode(WhereNode):
Used to represent the SQL where-clause for spatial databases --
these are tied to the GeoQuery class that created it.
"""
- def add(self, data, connector):
+
+ def _prepare_data(self, data):
if isinstance(data, (list, tuple)):
obj, lookup_type, value = data
if ( isinstance(obj, Constraint) and
isinstance(obj.field, GeometryField) ):
data = (GeoConstraint(obj), lookup_type, value)
- super(GeoWhereNode, self).add(data, connector)
+ return super(GeoWhereNode, self)._prepare_data(data)
def make_atom(self, child, qn, connection):
lvalue, lookup_type, value_annot, params_or_value = child
@@ -1,6 +1,19 @@
"""
Classes to represent the definitions of aggregate functions.
"""
+from django.db.models.constants import LOOKUP_SEP
+
+def refs_aggregate(lookup_parts, aggregates):
+ """
+ A little helper method to check if the lookup_parts contains references
+ to the given aggregates set. Because the LOOKUP_SEP is contained in the
+ default annotation names we must check each prefix of the lookup_parts
+ for match.
+ """
+ for i in range(len(lookup_parts) + 1):
+ if LOOKUP_SEP.join(lookup_parts[0:i]) in aggregates:
+ return True
+ return False
class Aggregate(object):
"""
@@ -4,4 +4,3 @@
# Separator used to split filter strings apart.
LOOKUP_SEP = '__'
-
@@ -1,4 +1,7 @@
import datetime
+
+from django.db.models.aggregates import refs_aggregate
+from django.db.models.constants import LOOKUP_SEP
from django.utils import tree
class ExpressionNode(tree.Node):
@@ -37,6 +40,18 @@ def _combine(self, other, connector, reversed, node=None):
obj.add(other, connector)
return obj
+ def contains_aggregate(self, existing_aggregates):
+ if self.children:
+ return any(child.contains_aggregate(existing_aggregates)
+ for child in self.children
+ if hasattr(child, 'contains_aggregate'))
+ else:
+ return refs_aggregate(self.name.split(LOOKUP_SEP),
+ existing_aggregates)
+
+ def prepare_database_save(self, unused):
+ return self
+
###################
# VISITOR METHODS #
###################
@@ -113,9 +128,6 @@ def __ror__(self, other):
"Use .bitand() and .bitor() for bitwise logical operations."
)
- def prepare_database_save(self, unused):
- return self
-
class F(ExpressionNode):
"""
An expression representing the value of the given field.
@@ -47,6 +47,7 @@ def _combine(self, other, conn):
if not isinstance(other, Q):
raise TypeError(other)
obj = type(self)()
+ obj.connector = conn
obj.add(self, conn)
obj.add(other, conn)
return obj
@@ -63,6 +64,16 @@ def __invert__(self):
obj.negate()
return obj
+ def clone(self):
+ clone = self.__class__._new_instance(
+ children=[], connector=self.connector, negated=self.negated)
+ for child in self.children:
+ if hasattr(child, 'clone'):
+ clone.children.append(child.clone())
+ else:
+ clone.children.append(child)
+ return clone
+
class DeferredAttribute(object):
"""
A wrapper for a deferred-loading field. When the value is read from this
@@ -87,6 +87,7 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
+ having_group_by = self.query.having.get_cols()
params = []
for val in six.itervalues(self.query.extra_select):
params.extend(val[1])
@@ -107,7 +108,7 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
result.append('WHERE %s' % where)
params.extend(w_params)
- grouping, gb_params = self.get_grouping(ordering_group_by)
+ grouping, gb_params = self.get_grouping(having_group_by, ordering_group_by)
if grouping:
if distinct_fields:
raise NotImplementedError(
@@ -534,7 +535,7 @@ def get_from_clause(self):
first = False
return result, from_params
- def get_grouping(self, ordering_group_by):
+ def get_grouping(self, having_group_by, ordering_group_by):
"""
Returns a tuple representing the SQL elements in the "group by" clause.
"""
@@ -551,7 +552,7 @@ def get_grouping(self, ordering_group_by):
]
select_cols = []
seen = set()
- cols = self.query.group_by + select_cols
+ cols = self.query.group_by + having_group_by + select_cols
for col in cols:
col_params = ()
if isinstance(col, (list, tuple)):
@@ -7,23 +7,30 @@ class SQLEvaluator(object):
def __init__(self, expression, query, allow_joins=True, reuse=None):
self.expression = expression
self.opts = query.get_meta()
- self.cols = []
-
- self.contains_aggregate = False
self.reuse = reuse
+ self.cols = []
self.expression.prepare(self, query, allow_joins)
def relabeled_clone(self, change_map):
clone = copy.copy(self)
clone.cols = []
- for node, col in self.cols[:]:
+ for node, col in self.cols:
if hasattr(col, 'relabeled_clone'):
clone.cols.append((node, col.relabeled_clone(change_map)))
else:
clone.cols.append((node,
(change_map.get(col[0], col[0]), col[1])))
return clone
+ def get_cols(self):
+ cols = []
+ for node, col in self.cols:
+ if hasattr(node, 'get_cols'):
+ cols.extend(node.get_cols())
+ elif isinstance(col, tuple):
+ cols.append(col)
+ return cols
+
def prepare(self):
return self
@@ -44,9 +51,7 @@ def prepare_leaf(self, node, query, allow_joins):
raise FieldError("Joined field references are not permitted in this query")
field_list = node.name.split(LOOKUP_SEP)
- if (len(field_list) == 1 and
- node.name in query.aggregate_select.keys()):
- self.contains_aggregate = True
+ if node.name in query.aggregates:
self.cols.append((node, query.aggregate_select[node.name]))
else:
try:
Oops, something went wrong.

0 comments on commit d3f00bd

Please sign in to comment.